Commit d43a467c authored by hazrmard's avatar hazrmard
Browse files

tensor type/device conversion for score(), optional keyword-arguments for fit()/predict()

parent e70275cd
Pipeline #314 failed with stages
in 0 seconds
*.ini
pytorchbridge.egg-info
__pycache__/
\ No newline at end of file
......@@ -90,7 +90,7 @@ class TorchEstimator(BaseEstimator):
self.loss = nn.MSELoss()
def fit(self, X: torch.Tensor, y: torch.Tensor) -> 'TorchEstimator':
def fit(self, X: torch.Tensor, y: torch.Tensor, **kwargs) -> 'TorchEstimator':
"""
Fit target to features.
......@@ -99,6 +99,7 @@ class TorchEstimator(BaseEstimator):
for recurrent modules or (N, Features) for other modules.
y {torch.Tensor} -- `Tensor` of shape ([SeqLen,] N, OutputFeatures) for recurrent
modules of (N, OutputFeatures).
**kwargs -- Keyword arguments passed to `self.module(X, **kwargs)`
Returns:
self
......@@ -117,7 +118,7 @@ class TorchEstimator(BaseEstimator):
for instance, target in zip(self._to_batches(X), self._to_batches(y)):
instance, target = instance.to(self._device), target.to(self._device)
self.module.zero_grad()
output = self.module(instance)
output = self.module(instance, **kwargs)
loss = self.loss(output, target)
loss.backward()
self.optimizer.step()
......@@ -128,13 +129,14 @@ class TorchEstimator(BaseEstimator):
return self
def predict(self, X: torch.Tensor) -> torch.Tensor:
def predict(self, X: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Predict output from inputs.
Arguments:
X {torch.Tensor} -- `Tensor` of shape (SeqLen, N, Features) or (N, SeqLen, Features)
for recurrent modules or (N, Features) for other modules.
**kwargs -- Keyword arguments passed to `self.module(X, **kwargs)`
Returns:
torch.Tensor -- of shape ([SeqLen,] N, OutputFeatures) for recurrent
......@@ -144,7 +146,7 @@ class TorchEstimator(BaseEstimator):
is_numpy = isinstance(X, np.ndarray)
X = torch.as_tensor(X, dtype=self._dtype, device=self._device)
with torch.no_grad():
result = self.module(X)
result = self.module(X, **kwargs)
if is_numpy:
return result.numpy()
return result
......@@ -164,6 +166,9 @@ class TorchEstimator(BaseEstimator):
Returns:
float -- Coefficient of determination.
"""
# pylint: disable=no-member
X = torch.as_tensor(X, dtype=self._dtype, device=self._device)
y = torch.as_tensor(y, dtype=self._dtype, device=self._device)
y_pred = self.predict(X)
residual_squares_sum = ((y - y_pred) ** 2).sum()
total_squares_sum = ((y - y.mean()) ** 2).sum()
......
......@@ -3,7 +3,7 @@ from distutils.core import setup
setup(name='pytorchbridge',
version='0.1.0',
version='0.1.1',
packages=['pytorchbridge'],
install_requires=['tqdm', 'scikit-learn>=0.20'],
author='Ibrahim Ahmed',
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment