mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
wip.
This commit is contained in:
parent
b138f6ce8c
commit
dbb0d1d42e
3 changed files with 7 additions and 4 deletions
|
@@ -134,9 +134,8 @@ def cql_loss(policy: Policy, model: ModelV2,
|
||||||
# Q-values for the batched actions.
|
# Q-values for the batched actions.
|
||||||
action_dist_tp1 = policy.dist_class(
|
action_dist_tp1 = policy.dist_class(
|
||||||
model.get_policy_output(model_out_tp1), policy.model)
|
model.get_policy_output(model_out_tp1), policy.model)
|
||||||
policy_tp1, log_pis_tp1 = action_dist_tp1.sample_logp()
|
policy_tp1, _ = action_dist_tp1.sample_logp()
|
||||||
|
|
||||||
log_pis_tp1 = tf.expand_dims(log_pis_tp1, -1)
|
|
||||||
q_t = model.get_q_values(model_out_t, actions)
|
q_t = model.get_q_values(model_out_t, actions)
|
||||||
q_t_selected = tf.squeeze(q_t, axis=-1)
|
q_t_selected = tf.squeeze(q_t, axis=-1)
|
||||||
if twin_q:
|
if twin_q:
|
||||||
|
|
|
@@ -143,9 +143,8 @@ def cql_loss(policy: Policy, model: ModelV2,
|
||||||
# Q-values for the batched actions.
|
# Q-values for the batched actions.
|
||||||
action_dist_tp1 = action_dist_class(
|
action_dist_tp1 = action_dist_class(
|
||||||
model.get_policy_output(model_out_tp1), policy.model)
|
model.get_policy_output(model_out_tp1), policy.model)
|
||||||
policy_tp1, log_pis_tp1 = action_dist_tp1.sample_logp()
|
policy_tp1, _ = action_dist_tp1.sample_logp()
|
||||||
|
|
||||||
log_pis_tp1 = torch.unsqueeze(log_pis_tp1, -1)
|
|
||||||
q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
|
q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
|
||||||
q_t_selected = torch.squeeze(q_t, dim=-1)
|
q_t_selected = torch.squeeze(q_t, dim=-1)
|
||||||
if twin_q:
|
if twin_q:
|
||||||
|
|
|
@@ -587,6 +587,11 @@ class TorchPolicy(Policy):
|
||||||
"sgd_minibatch_size", self.config["train_batch_size"]) // \
|
"sgd_minibatch_size", self.config["train_batch_size"]) // \
|
||||||
len(self.devices)
|
len(self.devices)
|
||||||
|
|
||||||
|
# Set Model to train mode.
|
||||||
|
if self.model_gpu_towers:
|
||||||
|
for t in self.model_gpu_towers:
|
||||||
|
t.train()
|
||||||
|
|
||||||
# Shortcut for 1 CPU only: Batch should already be stored in
|
# Shortcut for 1 CPU only: Batch should already be stored in
|
||||||
# `self._loaded_batches`.
|
# `self._loaded_batches`.
|
||||||
if len(self.devices) == 1 and self.devices[0].type == "cpu":
|
if len(self.devices) == 1 and self.devices[0].type == "cpu":
|
||||||
|
|
Loading…
Add table
Reference in a new issue