Bug fix in the contextual bandit's linear_regression.py model. (#8815)

This commit is contained in:
Sven Mika 2020-06-06 22:47:42 +02:00 committed by GitHub
parent be26a7b1b0
commit ad695a818b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -34,7 +34,7 @@ class OnlineLinearRegression(nn.Module):
def partial_fit(self, x, y):
# TODO: Handle batch of data rather than individual points
self._check_inputs(x, y)
x = x.squeeze()
x = x.squeeze(0)
y = y.item()
self.time += 1
self.delta_f += y * x