[RLlib] Bandit tf2 fix (+ add tf2 to test cases). (#24908)

This commit is contained in:
Sven Mika 2022-05-18 18:58:42 +02:00 committed by GitHub
parent fb60d68bbb
commit 628ee4b5f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 3 deletions

View file

@@ -44,7 +44,7 @@ class OnlineLinearRegression(tf.Module if tf else object):
x = tf.squeeze(x, axis=0)
y = y[0]
self.time += 1
- self.delta_f += y * x
+ self.delta_f += tf.cast(y, tf.float32) * x
self.delta_b += tf.tensordot(x, x, axes=0)
# Can follow an update schedule if not doing Sherman-Morrison updates
if self.time % self.update_schedule == 0:

View file

@@ -24,7 +24,9 @@ class TestBandits(unittest.TestCase):
)
num_iterations = 5
- for _ in framework_iterator(config, frameworks="torch"):
+ for _ in framework_iterator(
+ config, frameworks=("tf2", "torch"), with_eager_tracing=True
+ ):
for train_batch_size in [1, 10]:
config.training(train_batch_size=train_batch_size)
trainer = config.build()
@@ -47,7 +49,9 @@ class TestBandits(unittest.TestCase):
num_iterations = 5
- for _ in framework_iterator(config, frameworks="torch"):
+ for _ in framework_iterator(
+ config, frameworks=("tf2", "torch"), with_eager_tracing=True
+ ):
for train_batch_size in [1, 10]:
config.training(train_batch_size=train_batch_size)
trainer = config.build()