mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Bandit tf2 fix (+ add tf2 to test cases). (#24908)
This commit is contained in:
parent
fb60d68bbb
commit
628ee4b5f0
2 changed files with 7 additions and 3 deletions
|
@@ -44,7 +44,7 @@ class OnlineLinearRegression(tf.Module if tf else object):
|
|||
x = tf.squeeze(x, axis=0)
|
||||
y = y[0]
|
||||
self.time += 1
|
||||
self.delta_f += y * x
|
||||
self.delta_f += tf.cast(y, tf.float32) * x
|
||||
self.delta_b += tf.tensordot(x, x, axes=0)
|
||||
# Can follow an update schedule if not doing Sherman-Morrison updates
|
||||
if self.time % self.update_schedule == 0:
|
||||
|
|
|
@@ -24,7 +24,9 @@ class TestBandits(unittest.TestCase):
|
|||
)
|
||||
num_iterations = 5
|
||||
|
||||
for _ in framework_iterator(config, frameworks="torch"):
|
||||
for _ in framework_iterator(
|
||||
config, frameworks=("tf2", "torch"), with_eager_tracing=True
|
||||
):
|
||||
for train_batch_size in [1, 10]:
|
||||
config.training(train_batch_size=train_batch_size)
|
||||
trainer = config.build()
|
||||
|
@@ -47,7 +49,9 @@ class TestBandits(unittest.TestCase):
|
|||
|
||||
num_iterations = 5
|
||||
|
||||
for _ in framework_iterator(config, frameworks="torch"):
|
||||
for _ in framework_iterator(
|
||||
config, frameworks=("tf2", "torch"), with_eager_tracing=True
|
||||
):
|
||||
for train_batch_size in [1, 10]:
|
||||
config.training(train_batch_size=train_batch_size)
|
||||
trainer = config.build()
|
||||
|
|
Loading…
Add table
Reference in a new issue