Commit 9851db1b authored by hazrmard

added validation scores for early stopping

parent 098a21ab
@@ -18,7 +18,7 @@ class RecurrentTorchEstimator(TorchEstimator):
def __init__(self, module: nn.Module=None,
optimizer: optim.Optimizer=None,
loss: nn.modules.loss._Loss=None, epochs: int=2, verbose=False,
-                 batch_size: int=8, cuda=True,
+                 batch_size: int=8, cuda=True, return_hidden: bool=False,
bpt_every: int=None, bpt_for: int=None):
"""
Keyword Arguments:
@@ -38,7 +38,8 @@ class RecurrentTorchEstimator(TorchEstimator):
epochs=epochs,
verbose=verbose,
batch_size=batch_size,
-            cuda=cuda
+            cuda=cuda,
+            return_hidden=return_hidden
)
self.bpt_every = bpt_every
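For context, here is a minimal usage sketch of the new `return_hidden` flag on `RecurrentTorchEstimator`. The import path, data shapes, and the truncated-BPTT settings are illustrative assumptions, not part of this commit:

```python
# Sketch only: assumes the estimator is importable from the pytorchbridge
# package and wraps a batch-first LSTM.
import torch
import torch.nn as nn
from pytorchbridge import RecurrentTorchEstimator

net = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
est = RecurrentTorchEstimator(module=net, epochs=5, batch_size=16, cuda=False,
                              return_hidden=True,
                              bpt_every=5, bpt_for=5)  # BPTT windows (semantics assumed)

X = torch.randn(32, 10, 4)   # (N, SeqLen, Features) because batch_first=True
y = torch.randn(32, 10, 8)
est.fit(X, y)

# With return_hidden=True, hidden and cell states come back alongside the output.
output, (hn, cn) = est.predict(X)
```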
@@ -23,26 +23,54 @@ class TorchEstimator(BaseEstimator):
def __init__(self, module: nn.Module=None,
optimizer: optim.Optimizer=None,
-                 loss: nn.modules.loss._Loss=None, epochs: int=2, verbose=False,
-                 batch_size: int=8, cuda=True):
+                 loss: nn.modules.loss._Loss=None, epochs: int=2, tol: float=1e-4,
+                 max_tol_iter: int=5, verbose=False,
+                 early_stopping: bool=False, validation_fraction: float=0.1,
+                 batch_size: int=8, cuda=True, return_hidden: bool=False):
"""
-        Keyword Arguments:
-            module {torch.nn.Module} -- A `nn.Module` describing the neural network,
-            optimizer {torch.optim.Optimizer} -- An `Optimizer` instance which
-                iteratively modifies weights,
-            loss {torch.nn._Loss} -- a `_Loss` instance which calculates the loss metric,
-            epochs {int} -- The number of times to iterate over the training data,
-            verbose {bool} -- Whether to log training progress or not,
-            batch_size {int} -- Chunk size of data for each training step,
-            cuda {bool} -- Whether to use GPU acceleration if available.
+        Parameters
+        ----------
+        module: torch.nn.Module
+            A `nn.Module` describing the neural network.
+        optimizer: torch.optim.Optimizer
+            An `Optimizer` instance which iteratively modifies weights.
+        loss: torch.nn._Loss
+            A `_Loss` instance which calculates the loss metric.
+        epochs: int
+            The number of times to iterate over the training data.
+        tol: float
+            Tolerance for the change in loss between epochs.
+        max_tol_iter: int
+            If the loss does not change by more than `tol` for `max_tol_iter`
+            consecutive epochs, training is stopped.
+        verbose: bool
+            Whether to log training progress or not.
+        early_stopping: bool, optional
+            Whether to stop training early if the validation score does not
+            improve by `tol` for `max_tol_iter` consecutive epochs. By default
+            False.
+        validation_fraction: float, optional
+            The fraction of the training data to set aside for validation when
+            `early_stopping=True`. By default 0.1.
+        batch_size: int
+            Chunk size of data for each training step.
+        cuda: bool
+            Whether to use GPU acceleration if available.
+        return_hidden: bool, optional
+            Whether to return the hidden (and cell) state tuple for recurrent
+            networks. By default False.
"""
self.module = module
self.optimizer = optimizer
self.loss = loss
self.epochs = epochs
+        self.tol = tol
+        self.max_tol_iter = max_tol_iter
self.verbose = verbose
+        self.early_stopping = early_stopping
+        self.validation_fraction = validation_fraction
self.batch_size = batch_size
self.cuda = cuda
+        self.return_hidden = return_hidden
# pylint: disable=no-member
self._device = torch.device('cpu')
self._batch_first = None
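To make the new early-stopping knobs concrete, a hypothetical construction could look like this (the import path is an assumption):

```python
# Sketch of the new options; the import path is an assumption.
import torch.nn as nn
from pytorchbridge import TorchEstimator

est = TorchEstimator(
    module=nn.Linear(4, 1),
    epochs=100,
    tol=1e-4,                 # minimum epoch-to-epoch change that counts as progress
    max_tol_iter=5,           # stop after 5 consecutive epochs below tol
    early_stopping=True,      # score a held-out split instead of training loss
    validation_fraction=0.1,  # 10% of the training data goes to validation
)
```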
@@ -84,17 +112,21 @@ class TorchEstimator(BaseEstimator):
self.module.to(self._device)
else:
self._device = next(self.module.parameters()).device
self._dtype = next(self.module.parameters()).dtype
-        self._batch_first = any(map(lambda x: getattr(x, 'batch_first', True),
-                                    self.module.modules()))
+        # Treat inputs as batch-first only if every submodule defaults to or
+        # declares batch_first=True.
+        self._batch_first = all(map(lambda x: getattr(x, 'batch_first', True),
+                                    self.module.modules()))
if self.optimizer is None:
-            self.optimizer = optim.SGD(self.module.parameters(), lr=0.1)
+            self.optimizer = optim.SGD(self.module.parameters(), lr=0.01)
if self.loss is None:
self.loss = nn.MSELoss()
+    def parameters(self):
+        return self.module.parameters()
def fit(self, X: torch.Tensor, y: torch.Tensor, **kwargs) -> 'TorchEstimator':
"""
Fit target to features.
@@ -113,9 +145,28 @@ class TorchEstimator(BaseEstimator):
# pylint: disable=no-member
self._init(X, y)
+        # Set up a train/validation split for early stopping.
+        if self.early_stopping and self.validation_fraction > 0.:
+            if self._batch_first:
+                # Sample validation indices without replacement so the
+                # train and validation sets do not overlap.
+                vidx = np.random.choice(len(X),
+                                        size=int(self.validation_fraction * len(X)),
+                                        replace=False)
+                Xval, yval = X[vidx], y[vidx]
+                idx = np.asarray(list(set(range(len(X))) - set(vidx)), dtype=int)
+                X, y = X[idx], y[idx]
+            else:
+                # Time-major input: the batch dimension is dim 1.
+                vidx = np.random.choice(len(X[0]),
+                                        size=int(self.validation_fraction * len(X[0])),
+                                        replace=False)
+                Xval, yval = X[:, vidx], y[:, vidx]
+                idx = np.asarray(list(set(range(len(X[0]))) - set(vidx)), dtype=int)
+                X, y = X[:, idx], y[:, idx]
+            Xval = torch.as_tensor(Xval, dtype=self._dtype, device=self._device)
+            yval = torch.as_tensor(yval, dtype=self._dtype, device=self._device)
if self.verbose:
print()
ranger = trange(self.epochs, leave=False)
+        loss_hist = []
for e in ranger:
total_loss = 0.
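The split above, sketched standalone for clarity; the tensor shapes are illustrative assumptions:

```python
# Standalone sketch of the validation split for a time-major (SeqLen, N, F)
# tensor, mirroring the else-branch above; shapes are illustrative.
import numpy as np
import torch

X = torch.randn(10, 32, 4)   # (SeqLen, N, Features), so batch is dim 1
y = torch.randn(10, 32, 8)
validation_fraction = 0.1

vidx = np.random.choice(X.shape[1], size=int(validation_fraction * X.shape[1]),
                        replace=False)
idx = np.asarray(list(set(range(X.shape[1])) - set(vidx)), dtype=int)
Xval, yval = X[:, vidx], y[:, vidx]     # held-out batch entries
X, y = X[:, idx], y[:, idx]             # remaining training entries
print(X.shape, Xval.shape)              # torch.Size([10, 29, 4]) torch.Size([10, 3, 4])
```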
@@ -123,6 +174,10 @@
instance, target = instance.to(self._device), target.to(self._device)
self.module.zero_grad()
output = self.module(instance, **kwargs)
+                # For recurrent networks, the module may also return a tuple
+                # of hidden/cell states along with the output.
+                if isinstance(output, tuple):
+                    output, _ = output
loss = self.loss(output, target)
loss.backward()
self.optimizer.step()
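The tuple unpacking above exists because recurrent modules return more than the output tensor; a quick illustration:

```python
# Why the isinstance(output, tuple) check is needed: recurrent modules
# return (output, hidden) rather than a bare tensor.
import torch
import torch.nn as nn

x = torch.randn(5, 3, 4)                 # (SeqLen, N, Features)
out = nn.Linear(4, 2)(x)                 # plain modules: a single tensor
print(type(out))                         # <class 'torch.Tensor'>

out = nn.LSTM(input_size=4, hidden_size=2)(x)
output, (hn, cn) = out                   # LSTM: (output, (hidden, cell))
print(type(out), output.shape)           # <class 'tuple'> torch.Size([5, 3, 2])
```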
@@ -130,17 +185,34 @@
if self.verbose:
ranger.write(f'Epoch {e+1:3d}\tLoss: {total_loss:10.2f}')
+            if self.early_stopping and self.validation_fraction > 0.:
+                # Score on the held-out validation set instead of training loss.
+                with torch.no_grad():
+                    output = self.module(Xval, **kwargs)
+                # Recurrent modules return (output, hidden); keep only the output.
+                if isinstance(output, tuple):
+                    output, _ = output
+                total_loss = self.loss(output, yval).item()
+            loss_hist.append(total_loss)
+            if len(loss_hist) > 1:
+                # Get the last max_tol_iter + 1 loss values...
+                arr_loss = np.asarray(loss_hist[-self.max_tol_iter - 1:])
+                # ...and from them the last max_tol_iter changes in loss.
+                delta_loss = np.abs(arr_loss[1:] - arr_loss[:-1])
+                # Stop once max_tol_iter consecutive changes are all below tol.
+                thresh_loss = delta_loss < self.tol
+                if sum(thresh_loss) == self.max_tol_iter:
+                    ranger.close()
+                    break
return self
-    def predict(self, X: torch.Tensor, **kwargs) -> torch.Tensor:
+    def predict(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Predict output from inputs.
Arguments:
X {torch.Tensor} -- `Tensor` of shape (SeqLen, N, Features) or (N, SeqLen, Features)
for recurrent modules or (N, Features) for other modules.
-            **kwargs -- Keyword arguments passed to `self.module(X, **kwargs)`
+            *args -- Positional arguments passed to `self.module(X, *args, **kwargs)`
+            **kwargs -- Keyword arguments passed to `self.module(X, *args, **kwargs)`
Returns:
torch.Tensor -- of shape ([SeqLen,] N, OutputFeatures) for recurrent
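As a sanity check on the stopping rule added to `fit` above, here is the plateau test run standalone on toy numbers (values invented for illustration):

```python
# Standalone sketch of the plateau criterion from fit(); toy loss values.
import numpy as np

tol, max_tol_iter = 1e-4, 5
loss_hist = [0.90, 0.50, 0.30, 0.25000, 0.25004, 0.25001, 0.25003, 0.25002, 0.25000]

arr_loss = np.asarray(loss_hist[-max_tol_iter - 1:])   # last 6 losses
delta_loss = np.abs(arr_loss[1:] - arr_loss[:-1])      # last 5 epoch-to-epoch changes
print((delta_loss < tol).sum() == max_tol_iter)        # True -> training would stop
```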
@@ -150,11 +222,22 @@
is_numpy = isinstance(X, np.ndarray)
X = torch.as_tensor(X, dtype=self._dtype, device=self._device)
with torch.no_grad():
-            result = self.module(X, **kwargs)
-            if isinstance(result, tuple):
+            result = self.module(X, *args, **kwargs)
+            if isinstance(result, tuple) and not self.return_hidden:
result = result[0] # recurrent layers return (output, hidden)
if is_numpy:
-            return result.cpu().numpy()
+            if isinstance(result, tuple):           # hidden states were requested
+                h = result[0].cpu().numpy()
+                if isinstance(result[1], tuple):    # LSTM case: (hn, cn)
+                    hn = result[1][0].cpu()         # secondary results are not
+                    cn = result[1][1].cpu()         # converted to numpy
+                    return h, (hn, cn)
+                else:                               # RNN/GRU case: hn only
+                    hn = result[1].cpu()
+                    return h, hn
+            else:                                   # only the output was returned
+                return result.cpu().numpy()
return result
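Under `return_hidden=True` with a numpy input, the sketch below illustrates the return convention the new branch implements; the import path, shapes, and the assumption that `fit` accepts these tensors directly are all hypothetical:

```python
# Sketch of predict()'s return convention with return_hidden=True and a
# numpy input; import path and shapes are assumptions.
import numpy as np
import torch
import torch.nn as nn
from pytorchbridge import TorchEstimator

est = TorchEstimator(module=nn.GRU(input_size=4, hidden_size=8, batch_first=True),
                     return_hidden=True)
est.fit(torch.randn(32, 10, 4), torch.randn(32, 10, 8))   # (N, SeqLen, Features)

h, hn = est.predict(np.random.rand(32, 10, 4).astype('float32'))
print(type(h))    # numpy.ndarray -- the main output is converted back to numpy
print(type(hn))   # torch.Tensor -- hidden states stay as CPU tensors
```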
@@ -3,7 +3,7 @@ from distutils.core import setup
setup(name='pytorchbridge',
-      version='0.1.2',
+      version='0.1.3',
packages=['pytorchbridge'],
install_requires=['tqdm', 'scikit-learn>=0.20'],
author='Ibrahim Ahmed',