Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
wip.

commit dbb0d1d42e
parent b138f6ce8c

3 changed files with 7 additions and 4 deletions
@@ -134,9 +134,8 @@ def cql_loss(policy: Policy, model: ModelV2,
     # Q-values for the batched actions.
     action_dist_tp1 = policy.dist_class(
         model.get_policy_output(model_out_tp1), policy.model)
-    policy_tp1, log_pis_tp1 = action_dist_tp1.sample_logp()
+    policy_tp1, _ = action_dist_tp1.sample_logp()
 
-    log_pis_tp1 = tf.expand_dims(log_pis_tp1, -1)
     q_t = model.get_q_values(model_out_t, actions)
     q_t_selected = tf.squeeze(q_t, axis=-1)
     if twin_q:
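This hunk keeps the sampled next action but no longer keeps its log-probability, so the `tf.expand_dims` reshape of `log_pis_tp1` is dropped as well. Below is a minimal, standalone sketch (not part of the commit) of the shape bookkeeping these lines deal with; the tensors and the batch size of 32 are made up for illustration.

# Standalone illustration only; `log_pis` and the batch size of 32 are made up.
import tensorflow as tf

log_pis = tf.zeros([32])                   # per-sample log-probs, shape (32,)
log_pis_col = tf.expand_dims(log_pis, -1)  # shape (32, 1), the removed reshape
q_t = tf.zeros([32, 1])                    # Q-head output: one value per (s, a)
q_t_selected = tf.squeeze(q_t, axis=-1)    # shape (32,), as in the context line

print(log_pis_col.shape, q_t_selected.shape)  # (32, 1) (32,)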
@@ -143,9 +143,8 @@ def cql_loss(policy: Policy, model: ModelV2,
     # Q-values for the batched actions.
     action_dist_tp1 = action_dist_class(
         model.get_policy_output(model_out_tp1), policy.model)
-    policy_tp1, log_pis_tp1 = action_dist_tp1.sample_logp()
+    policy_tp1, _ = action_dist_tp1.sample_logp()
 
-    log_pis_tp1 = torch.unsqueeze(log_pis_tp1, -1)
     q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
     q_t_selected = torch.squeeze(q_t, dim=-1)
     if twin_q:
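The Torch version of the loss gets the same change: the log-probability returned by `sample_logp()` is unpacked into `_` and the `torch.unsqueeze` line disappears. The following standalone sketch shows the before/after pattern, with `torch.distributions.Normal` standing in for the RLlib action distribution and a hypothetical batch of 32 states with 4 action dimensions.

# Standalone sketch; torch.distributions.Normal stands in for the RLlib
# action distribution, and the batch/action sizes (32, 4) are arbitrary.
import torch

mean = torch.zeros(32, 4)      # hypothetical policy-head mean
std = torch.ones(32, 4)        # hypothetical policy-head std
dist = torch.distributions.Normal(mean, std)

# Old pattern: sample the next action and keep its log-prob.
policy_tp1 = dist.rsample()
log_pis_tp1 = dist.log_prob(policy_tp1).sum(-1, keepdim=True)  # shape (32, 1)

# New pattern (as in the hunk): only the sampled action is needed,
# so the log-prob is not kept at all.
policy_tp1 = dist.rsample()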
@@ -587,6 +587,11 @@ class TorchPolicy(Policy):
             "sgd_minibatch_size", self.config["train_batch_size"]) // \
             len(self.devices)
 
+        # Set Model to train mode.
+        if self.model_gpu_towers:
+            for t in self.model_gpu_towers:
+                t.train()
+
         # Shortcut for 1 CPU only: Batch should already be stored in
         # `self._loaded_batches`.
         if len(self.devices) == 1 and self.devices[0].type == "cpu":
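The added block puts every GPU tower into training mode before the update. As a rough standalone illustration (not the commit's code), `nn.Module.train()` recursively flips the `training` flag, which changes how layers such as `Dropout` and `BatchNorm` behave; the toy module below is an assumption, not RLlib's model.

# Standalone sketch; the toy tower below is made up and not RLlib's model.
import torch.nn as nn

tower = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.BatchNorm1d(8))

tower.eval()
print(tower.training)   # False: dropout is a no-op, batchnorm uses running stats

tower.train()           # what the added loop calls on every GPU tower
print(tower.training)   # True: dropout is active, batchnorm updates its stats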