From dbb0d1d42ea6b9624a91f361f271ed7e351a4984 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Wed, 11 Aug 2021 11:40:31 +0200 Subject: [PATCH] CQL: remove log_pis_tp1 from cql_loss (tf and torch); TorchPolicy: set model towers to train mode. --- rllib/agents/cql/cql_tf_policy.py | 3 +-- rllib/agents/cql/cql_torch_policy.py | 3 +-- rllib/policy/torch_policy.py | 5 +++++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/rllib/agents/cql/cql_tf_policy.py b/rllib/agents/cql/cql_tf_policy.py index 844fc2d11..5b060e6c9 100644 --- a/rllib/agents/cql/cql_tf_policy.py +++ b/rllib/agents/cql/cql_tf_policy.py @@ -134,9 +134,8 @@ def cql_loss(policy: Policy, model: ModelV2, # Q-values for the batched actions. action_dist_tp1 = policy.dist_class( model.get_policy_output(model_out_tp1), policy.model) - policy_tp1, log_pis_tp1 = action_dist_tp1.sample_logp() + policy_tp1, _ = action_dist_tp1.sample_logp() - log_pis_tp1 = tf.expand_dims(log_pis_tp1, -1) q_t = model.get_q_values(model_out_t, actions) q_t_selected = tf.squeeze(q_t, axis=-1) if twin_q: diff --git a/rllib/agents/cql/cql_torch_policy.py b/rllib/agents/cql/cql_torch_policy.py index 31e169844..6bc43f24a 100644 --- a/rllib/agents/cql/cql_torch_policy.py +++ b/rllib/agents/cql/cql_torch_policy.py @@ -143,9 +143,8 @@ def cql_loss(policy: Policy, model: ModelV2, # Q-values for the batched actions. action_dist_tp1 = action_dist_class( model.get_policy_output(model_out_tp1), policy.model) - policy_tp1, log_pis_tp1 = action_dist_tp1.sample_logp() + policy_tp1, _ = action_dist_tp1.sample_logp() - log_pis_tp1 = torch.unsqueeze(log_pis_tp1, -1) q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) q_t_selected = torch.squeeze(q_t, dim=-1) if twin_q: diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index b3e6dbc65..8147ee7ab 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -587,6 +587,11 @@ class TorchPolicy(Policy): "sgd_minibatch_size", self.config["train_batch_size"]) // \ len(self.devices) + # Set Model to train mode. 
+ if self.model_gpu_towers: + for t in self.model_gpu_towers: + t.train() + # Shortcut for 1 CPU only: Batch should already be stored in # `self._loaded_batches`. if len(self.devices) == 1 and self.devices[0].type == "cpu":