diff --git a/rllib/contrib/bandits/models/linear_regression.py b/rllib/contrib/bandits/models/linear_regression.py index 568fcd56d..76ff2203a 100644 --- a/rllib/contrib/bandits/models/linear_regression.py +++ b/rllib/contrib/bandits/models/linear_regression.py @@ -34,7 +34,7 @@ class OnlineLinearRegression(nn.Module): def partial_fit(self, x, y): # TODO: Handle batch of data rather than individual points self._check_inputs(x, y) - x = x.squeeze() + x = x.squeeze(0) y = y.item() self.time += 1 self.delta_f += y * x