Commit 098a21ab authored by hazrmard's avatar hazrmard
Browse files

Added truncated back-propagation through time bpt.py

parent d43a467c
from .pytorchbridge import TorchEstimator
from .bpt import RecurrentTorchEstimator
"""
TorchEstimator which implements truncated backpropagation through time.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import trange
from .pytorchbridge import TorchEstimator
class RecurrentTorchEstimator(TorchEstimator):
    """
    A `TorchEstimator` that trains recurrent modules using truncated
    backpropagation through time (TBPTT).

    Training iterates over windows of `bpt_every` time steps. The hidden
    state entering a window is detached; the window's first
    `bpt_every - bpt_for` steps are run, their final hidden state is
    detached again, and the remaining `bpt_for` steps are run from that
    detached state. The loss over the whole window is back-propagated, so
    hidden-state gradients never cross window boundaries.
    """

    def __init__(self, module: nn.Module=None,
                 optimizer: optim.Optimizer=None,
                 loss: nn.modules.loss._Loss=None, epochs: int=2, verbose=False,
                 batch_size: int=8, cuda=True,
                 bpt_every: int=None, bpt_for: int=None):
        """
        Keyword Arguments:
            module {torch.nn.Module} -- A `nn.Module` describing the neural network,
            optimizer {torch.optim.Optimizer} -- An `Optimizer` instance which
                iteratively modifies weights,
            loss {torch.nn._Loss} -- a `_Loss` instance which calculates the loss metric,
            epochs {int} -- The number of times to iterate over the training data,
            verbose {bool} -- Whether to log training progress or not,
            batch_size {int} -- Chunk size of data for each training step,
            cuda {bool} -- Whether to use GPU acceleration if available,
            bpt_every {int} -- Length, in time steps, of each truncation
                window. Must be set before calling `fit` (the default `None`
                cannot be used as a `range` step),
            bpt_for {int} -- Number of trailing time steps in each window run
                after the mid-window hidden-state detach. Should be
                <= `bpt_every`. Must be set before calling `fit`.
        """
        super().__init__(
            module=module,
            optimizer=optimizer,
            loss=loss,
            epochs=epochs,
            verbose=verbose,
            batch_size=batch_size,
            cuda=cuda
        )
        self.bpt_every = bpt_every
        self.bpt_for = bpt_for

    def _init(self, X, y):
        """
        Set up module, device, dtype, time/batch dimension indices,
        optimizer, and loss before fitting. If no module was provided, a
        single-layer `nn.RNN` sized from `X`/`y` is created.
        """
        # pylint: disable=no-member
        # Create a default recurrent model if no module provided
        if self.module is None:
            self._device = torch.device('cuda') if \
                torch.cuda.is_available() and self.cuda \
                else torch.device('cpu')
            _, _, infeatures = self._get_shape(X)
            _, _, outfeatures = self._get_shape(y)

            class MyModule(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.rnn = nn.RNN(infeatures, outfeatures)
                    # Squeeze outputs when targets are 1-dimensional.
                    self.squeeze = y.ndim == 1

                def forward(self, x, h0=None):
                    out, hidden = self.rnn(x, h0)
                    if self.squeeze:
                        # FIX: previously this branch returned only the
                        # squeezed output tensor, which broke the
                        # `out, hidden = self.module(...)` unpacking in
                        # `fit`. Always return the (output, hidden) pair;
                        # the parent's `predict` already extracts element 0
                        # from tuple results.
                        return torch.squeeze(out), hidden
                    return out, hidden

            self.module = MyModule()
            self.module.to(self._device)
        else:
            self._device = next(self.module.parameters()).device
        self._dtype = next(self.module.parameters()).dtype
        # Infer the input layout from any submodule declaring `batch_first`.
        self._batch_first = any(map(lambda x: getattr(x, 'batch_first', False),
                                    self.module.modules()))
        self._time_dim = 1 if self._batch_first else 0
        self._batch_dim = 0 if self._batch_first else 1
        if self.optimizer is None:
            self.optimizer = optim.SGD(self.module.parameters(), lr=0.1)
        if self.loss is None:
            self.loss = nn.MSELoss()

    def fit(self, X, y, **kwargs):
        """
        Fit the module to `X`/`y` with truncated backpropagation through
        time.

        Arguments:
            X -- Inputs of shape (SeqLen, N, Features), or
                (N, SeqLen, Features) when the module is batch-first,
            y -- Targets with the same sequence/batch layout as `X`,
            **kwargs -- Forwarded to `self.module(input, hidden, **kwargs)`.

        Returns:
            self
        """
        # pylint: disable=no-member
        self._init(X, y)
        if self.verbose:
            print()
        ranger = trange(self.epochs, leave=False)
        for e in ranger:
            total_loss = 0.
            for instance, target in zip(self._to_batches(X), self._to_batches(y)):
                instance, target = instance.to(self._device), target.to(self._device)
                # Initial hidden state: zeros shaped like a target with the
                # time dimension removed.
                # NOTE(review): `nn.RNN` expects h0 of shape
                # (num_layers, batch, hidden); this builds a 2-D
                # (batch, feature) tensor -- confirm the module accepts it.
                hidden_shape = list(target.shape)
                del hidden_shape[self._time_dim]
                post_hidden = instance.new_zeros(hidden_shape)  # batch, feature
                for t in range(0, self._get_shape(instance)[0], self.bpt_every):
                    # Leading segment: steps [t, t + bpt_every - bpt_for).
                    pre_in = self._slice_time(instance, t, t + self.bpt_every - self.bpt_for)
                    # Trailing segment: the last bpt_for steps of the window.
                    post_in = self._slice_time(instance, t + self.bpt_every - self.bpt_for, t + self.bpt_every)
                    sub_target = self._slice_time(target, t, t + self.bpt_every)
                    # Skip a final partial window shorter than bpt_for.
                    if post_in.shape[self._time_dim] < self.bpt_for:
                        break
                    pre_out, pre_hidden = self.module(pre_in, post_hidden, **kwargs)
                    # Detach so gradients do not flow back through the
                    # leading segment's hidden state.
                    pre_hidden = pre_hidden.detach()
                    post_out, post_hidden = self.module(post_in, pre_hidden, **kwargs)
                    # Detach before carrying the state into the next window.
                    post_hidden = post_hidden.detach()
                    # Loss is computed over the whole window's outputs.
                    sub_output = torch.cat((pre_out, post_out), dim=self._time_dim)
                    self.module.zero_grad()
                    loss = self.loss(sub_output, sub_target)
                    loss.backward()
                    self.optimizer.step()
                    total_loss += loss.item()
            if self.verbose:
                ranger.write(f'Epoch {e+1:3d}\tLoss: {total_loss:10.2f}')
        return self

    def _slice_time(self, t: torch.Tensor, start: int=None, stop: int=None) \
        -> torch.Tensor:
        """
        Slice a tensor along its time dimension (`self._time_dim`).

        Tensors with 2 or fewer dimensions are returned unchanged.
        NOTE(review): presumably such tensors are treated as having no time
        dimension -- verify callers never pass 2-D sequence data here.
        """
        if t.ndim > 2:
            if self._time_dim == 1:
                return t[:, start:stop, :]
            else:
                return t[start:stop, :]
        return t
\ No newline at end of file
......@@ -9,18 +9,21 @@ from sklearn.base import BaseEstimator
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import trange, tqdm
from tqdm.auto import trange
class TorchEstimator(BaseEstimator):
"""
Wraps a `torch.nn.Module` instance with a scikit-learn `Estimator` API.
Note:
* All parameters in the provided module must have the same data type and device.
"""
def __init__(self, module: nn.Module=None,
optimizer: optim.Optimizer=None,
loss: nn.modules.loss._Loss=None, epochs: int=10, verbose=False,
loss: nn.modules.loss._Loss=None, epochs: int=2, verbose=False,
batch_size: int=8, cuda=True):
"""
Keyword Arguments:
......@@ -69,7 +72,7 @@ class TorchEstimator(BaseEstimator):
def __init__(self):
super().__init__()
self.linear = nn.Linear(infeatures, outfeatures)
self.squeeze = len(torch.as_tensor(y).size()) == 1
self.squeeze = y.ndim == 1
def forward(self, x):
x = self.linear(x)
......@@ -83,6 +86,8 @@ class TorchEstimator(BaseEstimator):
self._device = next(self.module.parameters()).device
self._dtype = next(self.module.parameters()).dtype
self._batch_first = any(map(lambda x: getattr(x, 'batch_first', True),
self.module.modules()))
if self.optimizer is None:
self.optimizer = optim.SGD(self.module.parameters(), lr=0.1)
......@@ -104,14 +109,13 @@ class TorchEstimator(BaseEstimator):
Returns:
self
"""
# TODO: Add instance weights to super-/sub-sample training data.
# pylint: disable=no-member
self._init(X, y)
if self.verbose:
print()
ranger = trange(self.epochs)
self._batch_first = self._is_batch_first()
ranger = trange(self.epochs, leave=False)
for e in ranger:
total_loss = 0.
......@@ -147,12 +151,14 @@ class TorchEstimator(BaseEstimator):
X = torch.as_tensor(X, dtype=self._dtype, device=self._device)
with torch.no_grad():
result = self.module(X, **kwargs)
if isinstance(result, tuple):
result = result[0] # recurrent layers return (output, hidden)
if is_numpy:
return result.numpy()
return result.cpu().numpy()
return result
def score(self, X, y) -> float:
def score(self, X, y, **kwargs) -> float:
"""
Measure how well the estimator learned through the coefficient of
determination.
......@@ -162,6 +168,7 @@ class TorchEstimator(BaseEstimator):
for recurrent modules or (N, Features) for other modules.
y {torch.Tensor} -- `Tensor` of shape ([SeqLen,] N, OutputFeatures) for recurrent
modules of (N, OutputFeatures).
**kwargs -- Keyword arguments passed to `self.module(X, **kwargs)`
Returns:
float -- Coefficient of determination.
......@@ -169,7 +176,7 @@ class TorchEstimator(BaseEstimator):
# pylint: disable=no-member
X = torch.as_tensor(X, dtype=self._dtype, device=self._device)
y = torch.as_tensor(y, dtype=self._dtype, device=self._device)
y_pred = self.predict(X)
y_pred = self.predict(X, **kwargs)
residual_squares_sum = ((y - y_pred) ** 2).sum()
total_squares_sum = ((y - y.mean()) ** 2).sum()
return (1 - residual_squares_sum / total_squares_sum).item()
......@@ -184,7 +191,7 @@ class TorchEstimator(BaseEstimator):
if isinstance(X, np.ndarray):
X = X.astype(float)
X = torch.as_tensor(X, dtype=self._dtype)
if not self._batch_first:
if not self._batch_first and X.ndim > 2:
# Recurrent layers take inputs of the shape (SeqLen, N, Features...)
# So if there is any recurrent layer in the module, assume that this
# is the expected input shape
......@@ -201,22 +208,6 @@ class TorchEstimator(BaseEstimator):
yield X[i*self.batch_size:(i+1)*self.batch_size]
def _is_recurrent(self) -> bool:
"""
Checks whether the network has any recurrent units.
"""
return any(map(lambda x: isinstance(x, nn.RNNBase), self.module.modules()))
def _is_batch_first(self) -> bool:
"""
Checks whether the features arrays are in the shape (Batch, ..., Features) or
(..., Batch, Features).
"""
# Default setting is batch_first=False for RNNBase subclasses
return any(map(lambda x: getattr(x, 'batch_first', True), self.module.modules()))
def _get_shape(self, t: torch.Tensor) -> Tuple[int, int, int]:
"""
Get size of each dimension of tensor depending on `batch_first`. The
......@@ -241,7 +232,8 @@ class TorchEstimator(BaseEstimator):
if ndims == 2:
return 0, sz[0], sz[1]
elif ndims == 3:
if self._is_batch_first():
if self._batch_first:
return sz[1], sz[0], sz[2]
else:
return sz[0], sz[1], sz[2]
......@@ -6,9 +6,10 @@ from sklearn.utils.estimator_checks import check_estimator
from sklearn.model_selection import GridSearchCV
from .pytorchbridge import TorchEstimator
from .bpt import RecurrentTorchEstimator
# pylint: disable=no-member
class TestAPI(TestCase):
......@@ -31,6 +32,19 @@ class TestAPI(TestCase):
est.fit(X, y).predict(X)
def test_recurrent_estimator(self):
    """
    Fit `RecurrentTorchEstimator` on random (SeqLen=10, N=5, Features)
    float data and check that predictions keep the target's shape.
    """
    # np.random.rand yields float64; the estimator casts batches to the
    # module's dtype internally.
    X = torch.from_numpy(np.random.rand(10, 5, 2))
    y = torch.from_numpy(np.random.rand(10, 5, 1))
    # Truncation windows of 3 steps, back-propagating through the last 2.
    est = RecurrentTorchEstimator(bpt_every=3, bpt_for=2)
    pred = est.fit(X, y).predict(X)
    # Output must match the target layout: (SeqLen, N, OutputFeatures).
    self.assertSequenceEqual(pred.shape, (10, 5, 1))
def test_truncated_bpt(self):
X = torch.ones(10, 1, 1)
y = np.cumsum(X, axis=0)
est = RecurrentTorchEstimator(bpt_every=3, bpt_for=1)
def grid_search(self):
pass
......
......@@ -3,7 +3,7 @@ from distutils.core import setup
setup(name='pytorchbridge',
version='0.1.1',
version='0.1.2',
packages=['pytorchbridge'],
install_requires=['tqdm', 'scikit-learn>=0.20'],
author='Ibrahim Ahmed',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment