This commit is contained in:
sven1977 2021-08-11 11:40:31 +02:00
parent b138f6ce8c
commit dbb0d1d42e
3 changed files with 7 additions and 4 deletions

View file

@ -134,9 +134,8 @@ def cql_loss(policy: Policy, model: ModelV2,
# Q-values for the batched actions.
action_dist_tp1 = policy.dist_class(
model.get_policy_output(model_out_tp1), policy.model)
policy_tp1, log_pis_tp1 = action_dist_tp1.sample_logp()
policy_tp1, _ = action_dist_tp1.sample_logp()
log_pis_tp1 = tf.expand_dims(log_pis_tp1, -1)
q_t = model.get_q_values(model_out_t, actions)
q_t_selected = tf.squeeze(q_t, axis=-1)
if twin_q:

View file

@ -143,9 +143,8 @@ def cql_loss(policy: Policy, model: ModelV2,
# Q-values for the batched actions.
action_dist_tp1 = action_dist_class(
model.get_policy_output(model_out_tp1), policy.model)
policy_tp1, log_pis_tp1 = action_dist_tp1.sample_logp()
policy_tp1, _ = action_dist_tp1.sample_logp()
log_pis_tp1 = torch.unsqueeze(log_pis_tp1, -1)
q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
q_t_selected = torch.squeeze(q_t, dim=-1)
if twin_q:

View file

@ -587,6 +587,11 @@ class TorchPolicy(Policy):
"sgd_minibatch_size", self.config["train_batch_size"]) // \
len(self.devices)
# Set Model to train mode.
if self.model_gpu_towers:
for t in self.model_gpu_towers:
t.train()
# Shortcut for 1 CPU only: Batch should already be stored in
# `self._loaded_batches`.
if len(self.devices) == 1 and self.devices[0].type == "cpu":