Commit edc9fae2 authored by hazrmard

refactoring, minor updates

parent 9851db1b
@@ -85,44 +85,44 @@ class RecurrentTorchEstimator(TorchEstimator):
self.loss = nn.MSELoss()
def fit(self, X, y, **kwargs):
# pylint: disable=no-member
self._init(X, y)
if self.verbose:
print()
ranger = trange(self.epochs, leave=False)
for e in ranger:
total_loss = 0.
for instance, target in zip(self._to_batches(X), self._to_batches(y)):
instance, target = instance.to(self._device), target.to(self._device)
hidden_shape = list(target.shape)
del hidden_shape[self._time_dim]
post_hidden = instance.new_zeros(hidden_shape) # batch, feature
for t in range(0, self._get_shape(instance)[0], self.bpt_every):
pre_in = self._slice_time(instance, t, t + self.bpt_every - self.bpt_for)
post_in = self._slice_time(instance, t + self.bpt_every - self.bpt_for, t + self.bpt_every)
sub_target = self._slice_time(target, t, t + self.bpt_every)
if post_in.shape[self._time_dim] < self.bpt_for: break
pre_out, pre_hidden = self.module(pre_in, post_hidden, **kwargs)
pre_hidden = pre_hidden.detach()
post_out, post_hidden = self.module(post_in, pre_hidden, **kwargs)
post_hidden = post_hidden.detach()
sub_output = torch.cat((pre_out, post_out), dim=self._time_dim)
self.module.zero_grad()
loss = self.loss(sub_output, sub_target)
loss.backward()
self.optimizer.step()
total_loss += loss.item()
if self.verbose:
ranger.write(f'Epoch {e+1:3d}\tLoss: {total_loss:10.2f}')
return self
def partial_fit(self, X: torch.Tensor, y: torch.Tensor, **kwargs) -> float:
"""
Fit a single batch of tensors to the module. No type/device checking is
done.
Parameters
----------
X : torch.Tensor
A tensor containing features.
y : torch.Tensor
A tensor containing targets.
Returns
-------
float
The loss between the module outputs and the targets.
"""
hidden_shape = list(y.shape)
del hidden_shape[self._time_dim]
post_hidden = X.new_zeros(hidden_shape) # batch, feature
for t in range(0, self._get_shape(X)[0], self.bpt_every):
pre_in = self._slice_time(X, t, t + self.bpt_every - self.bpt_for)
post_in = self._slice_time(X, t + self.bpt_every - self.bpt_for, t + self.bpt_every)
sub_target = self._slice_time(y, t, t + self.bpt_every)
if post_in.shape[self._time_dim] < self.bpt_for: break
pre_out, pre_hidden = self.module(pre_in, post_hidden, **kwargs)
pre_hidden = pre_hidden.detach()
post_out, post_hidden = self.module(post_in, pre_hidden, **kwargs)
post_hidden = post_hidden.detach()
# pylint: disable=no-member
sub_output = torch.cat((pre_out, post_out), dim=self._time_dim)
self.module.zero_grad()
loss = self.loss(sub_output, sub_target)
loss.backward()
self.optimizer.step()
return loss.item()
def _slice_time(self, t: torch.Tensor, start: int=None, stop: int=None) \
......
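The windowed loop in `fit` and `partial_fit` above implements truncated backpropagation through time: the input is consumed in windows of `bpt_every` steps, the hidden state is detached inside each window so that the trailing `bpt_for` steps carry a fresh gradient path, and it is detached again before the next window. A minimal standalone sketch of the same pattern, assuming a plain `nn.RNN` with time-first tensors standing in for `self.module`:

```python
import torch
import torch.nn as nn

# Standalone sketch of the truncated-BPTT windowing used above.
# Assumed stand-ins: nn.RNN for self.module, time-first layout.
rnn = nn.RNN(input_size=2, hidden_size=1)
opt = torch.optim.Adam(rnn.parameters())
loss_fn = nn.MSELoss()

X = torch.rand(16, 4, 2)             # (SeqLen, N, Features)
y = torch.rand(16, 4, 1)             # (SeqLen, N, OutputFeatures)
bpt_every, bpt_for = 8, 4            # window size / trailing gradient steps

hidden = torch.zeros(1, 4, 1)        # (Layers, N, Hidden)
for t in range(0, X.shape[0], bpt_every):
    pre = X[t:t + bpt_every - bpt_for]
    post = X[t + bpt_every - bpt_for:t + bpt_every]
    target = y[t:t + bpt_every]
    if post.shape[0] < bpt_for:
        break
    pre_out, hidden = rnn(pre, hidden)
    hidden = hidden.detach()         # cut the graph inside the window
    post_out, hidden = rnn(post, hidden)
    hidden = hidden.detach()         # cut the graph between windows
    out = torch.cat((pre_out, post_out), dim=0)
    opt.zero_grad()
    loss_fn(out, target).backward()
    opt.step()
```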
@@ -9,16 +9,24 @@ from sklearn.base import BaseEstimator
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.auto import trange
# Defining custom types
TensorLike = Union[torch.Tensor, np.ndarray]
class TorchEstimator(BaseEstimator):
"""
Wraps a `torch.nn.Module` instance with a scikit-learn `Estimator` API.
Note:
* All parameters in the provided module must have the same data type and device.
*Note:*
All parameters in the provided module must have the same data type and
device. The `_init()` method uses the dtype/device of the first element
in `module.parameters()` to set default casting options.
"""
def __init__(self, module: nn.Module=None,
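A minimal usage sketch of the wrapper: only the `module` argument is taken from the signature above; all other constructor options are assumed defaults.

```python
import numpy as np
import torch.nn as nn

# Wrap an nn.Module and drive it with the scikit-learn style API.
# Constructor arguments besides `module` are assumed defaults.
net = nn.Linear(2, 1)
est = TorchEstimator(module=net)
X = np.random.rand(10, 2)
y = np.random.rand(10, 1)
est.fit(X, y)
preds = est.predict(X)    # ndarray in, ndarray out
r2 = est.score(X, y)      # coefficient of determination
```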
@@ -78,7 +86,7 @@ class TorchEstimator(BaseEstimator):
def _init(self, X, y):
def _init(self, X: TensorLike, y: TensorLike):
"""
Initializes internal parameters before fitting, including the device and
data types for network parameters.
@@ -123,11 +131,20 @@ class TorchEstimator(BaseEstimator):
self.loss = nn.MSELoss()
def parameters(self):
def parameters(self) -> Iterator[torch.Tensor]:
"""
Convenience method for `self.module.parameters()`.
Returns
-------
Iterator
Iterator over a module's parameters.
"""
return self.module.parameters()
def fit(self, X: torch.Tensor, y: torch.Tensor, **kwargs) -> 'TorchEstimator':
def fit(self, X: Union[TensorLike, DataLoader], y: TensorLike=None, \
**kwargs) -> 'TorchEstimator':
"""
Fit target to features.
@@ -135,7 +152,8 @@
X {torch.Tensor} -- `Tensor` of shape (SeqLen, N, Features) or (N, SeqLen, Features)
for recurrent modules or (N, Features) for other modules.
y {torch.Tensor} -- `Tensor` of shape ([SeqLen,] N, OutputFeatures) for recurrent
modules of (N, OutputFeatures).
modules or (N, OutputFeatures). Optional if X is a `DataLoader` which
already contains features and targets.
**kwargs -- Keyword arguments passed to `self.module(X, **kwargs)`
Returns:
@@ -151,7 +169,7 @@
vidx = np.random.choice(len(X),
size=int(self.validation_fraction * len(X)))
Xval, yval = X[vidx], y[vidx]
idx = np.asarray(set(range(len(X))) - set(vidx), dtype=int)
idx = np.asarray(list(set(range(len(X))) - set(vidx)), dtype=int)
X, y = X[idx], y[idx]
else:
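The one-line change above (wrapping the set in `list` before `np.asarray`) matters because NumPy does not treat a `set` as a sequence: it becomes a 0-d object array that cannot be cast to an integer index. A quick demonstration:

```python
import numpy as np

s = set(range(5)) - {1, 3}
# np.asarray(s, dtype=int)        # raises TypeError: a set is not a sequence
idx = np.asarray(list(s), dtype=int)
print(idx)                        # [0 2 4] -- usable for fancy indexing
```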
@@ -170,18 +188,14 @@
for e in ranger:
total_loss = 0.
for instance, target in zip(self._to_batches(X), self._to_batches(y)):
if isinstance(X, DataLoader):
iterable = X
else:
iterable = zip(self._to_batches(X), self._to_batches(y))
for instance, target in iterable:
instance, target = instance.to(self._device), target.to(self._device)
self.module.zero_grad()
output = self.module(instance, **kwargs)
# For recurrent networks, the outputs may also return
# a tuple of hidden states/cell states
if isinstance(output, tuple):
output, _ = output
loss = self.loss(output, target)
loss.backward()
self.optimizer.step()
total_loss += loss.item()
total_loss += self.partial_fit(instance, target, **kwargs)
if self.verbose:
ranger.write(f'Epoch {e+1:3d}\tLoss: {total_loss:10.2f}')
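With this change, `fit` can consume a `DataLoader` that already yields `(features, targets)` pairs, so `y` may be omitted. A sketch of that path, with the constructor arguments besides `module` assumed:

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# The loader carries both features and targets, so fit() gets no y.
ds = TensorDataset(torch.rand(100, 2), torch.rand(100, 1))
loader = DataLoader(ds, batch_size=8, shuffle=True)
est = TorchEstimator(module=nn.Linear(2, 1))
est.fit(loader)                   # each batch goes through partial_fit
```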
@@ -204,18 +218,53 @@
return self
def predict(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor:
def partial_fit(self, X: torch.Tensor, y: torch.Tensor, **kwargs) -> float:
"""
Fit a single batch of tensors to the module. No type/device checking is
done.
Parameters
----------
X : torch.Tensor
A tensor containing features.
y : torch.Tensor
A tensor containing targets.
Returns
-------
float
The loss between the module outputs and the targets.
"""
self.module.zero_grad()
output = self.module(X, **kwargs)
# For recurrent networks, the module may also return
# a tuple of hidden/cell states along with the output
if isinstance(output, tuple):
output, _ = output
loss = self.loss(output, y)
loss.backward()
self.optimizer.step()
return loss.item()
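`partial_fit` is also useful on its own for custom training loops, since it performs exactly one zero_grad/forward/backward/step cycle. A hedged sketch, reusing the `loader` and `est` from the sketch above; as the docstring notes, tensors must already match the module's device and dtype:

```python
# Manual loop built on partial_fit; device/dtype handling is the
# caller's responsibility (no checking is done inside).
for epoch in range(10):
    total = 0.0
    for xb, yb in loader:
        total += est.partial_fit(xb, yb)
    print(f'epoch {epoch + 1}: loss {total:.4f}')
```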
def predict(self, X: TensorLike, *args, **kwargs) -> torch.Tensor:
"""
Predict output from inputs.
Arguments:
X {torch.Tensor} -- `Tensor` of shape (SeqLen, N, Features) or (N, SeqLen, Features)
Parameters
----------
X : torch.Tensor
`Tensor` of shape (SeqLen, N, Features) or (N, SeqLen, Features)
for recurrent modules or (N, Features) for other modules.
*args -- positional arguments passed to `self.module(X, *args, **kwargs)`
**kwargs -- Keyword arguments passed to `self.module(X, *args, **kwargs)`
Returns:
torch.Tensor -- of shape ([SeqLen,] N, OutputFeatures) for recurrent
*args
Positional arguments passed to `self.module(X, *args, **kwargs)`
**kwargs
Keyword arguments passed to `self.module(X, *args, **kwargs)`
Returns
-------
torch.Tensor
Of shape ([SeqLen,] N, OutputFeatures) for recurrent
modules or (N, OutputFeatures).
"""
# pylint: disable=no-member
@@ -227,6 +276,7 @@ class TorchEstimator(BaseEstimator):
result = result[0] # recurrent layers return (output, hidden)
if is_numpy:
# TODO: Iterate over arbitrarily nested tuples of returned tensors
if isinstance(result, tuple): # If hidden units are returned
h = result[0].cpu().numpy()
if isinstance(result[1], tuple):  # LSTM case
@@ -241,7 +291,7 @@
return result
def score(self, X, y, **kwargs) -> float:
def score(self, X: TensorLike, y: TensorLike, **kwargs) -> float:
"""
Measure how well the estimator has learned, using the coefficient of
determination (R-squared).
@@ -265,7 +315,7 @@
return (1 - residual_squares_sum / total_squares_sum).item()
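The expression above is the standard coefficient of determination; as a sanity check, it agrees with scikit-learn's `r2_score` (a small verification sketch under that assumption):

```python
import numpy as np
from sklearn.metrics import r2_score

y_true = np.random.rand(50)
y_pred = y_true + 0.1 * np.random.rand(50)
ss_res = ((y_true - y_pred) ** 2).sum()
ss_tot = ((y_true - y_true.mean()) ** 2).sum()
assert np.isclose(1 - ss_res / ss_tot, r2_score(y_true, y_pred))
```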
def _to_batches(self, X: torch.Tensor) -> Iterator[torch.Tensor]:
def _to_batches(self, X: TensorLike) -> Iterator[torch.Tensor]:
"""
Convert ([SeqLen,] N, Features) into a generator of ([SeqLen,] n, Features)
mini-batches, so that recurrent layers can also be trained in batches.
@@ -291,7 +341,7 @@
yield X[i*self.batch_size:(i+1)*self.batch_size]
def _get_shape(self, t: torch.Tensor) -> Tuple[int, int, int]:
def _get_shape(self, t: TensorLike) -> Tuple[int, int, int]:
"""
Get size of each dimension of tensor depending on `batch_first`. The
size is returned in order of time, batch, features.
......
@@ -17,7 +17,32 @@ class TestAPI(TestCase):
# check_estimator(TorchEstimator())
def test_numpy_arrays(self):
def create_arrays(self, recurrent=False, batch_first=True, tensor=True):
t, n, fin, fout = 1, 10, 2, 1  # time, batch, features in, features out
if recurrent:
if batch_first:
X = np.random.rand(n, t, fin)
y = np.random.rand(n, fout)
else:
X = np.random.rand(t, n, fin)
y = np.random.rand(n, fout)
else:
X = np.random.rand(n, fin)
y = np.random.rand(n, fout)
if tensor:
X = torch.as_tensor(X)
y = torch.as_tensor(y)
return X, y
def test_numpy_dense(self):
X = np.random.rand(10, 2)
y = np.random.rand(10)
est = TorchEstimator()
est.fit(X, y).predict(X)
def test_numpy_recurrent(self):
X = np.random.rand(10, 2)
y = np.random.rand(10)
est = TorchEstimator()
......