Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib]: Fix FQE Policy call (#26671)

parent adf24bfa97
commit 4fded80813
2 changed files with 72 additions and 14 deletions
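In brief: the model change makes FQETorchModel default its minibatch size to the full batch count, shuffle a shallow copy of the training batch rather than the caller's batch, and call policy.action_distribution_fn with the TorchPolicyV2 signature first, falling back to the TorchPolicyV1 signature on TypeError. The test change adds test_fqe_model, which covers batch immutability, the FQE stopping criteria, and action-probability computation.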
@@ -110,8 +110,9 @@ class FQETorchModel:
             A list of losses for each training iteration
         """
         losses = []
-        if self.minibatch_size is None:
-            minibatch_size = batch.count
+        minibatch_size = self.minibatch_size or batch.count
+        # Copy batch for shuffling
+        batch = batch.copy(shallow=True)
         for _ in range(self.n_iters):
             minibatch_losses = []
             batch.shuffle()
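This first hunk does two things: it collapses the minibatch-size default into a single "or" expression, and it shuffles a shallow copy of the batch so that training never reorders the caller's data. Below is a minimal standalone sketch of that pattern; the Batch class here is a hypothetical stand-in, not RLlib's SampleBatch.

import random

class Batch:
    def __init__(self, rows):
        self.rows = rows
        self.count = len(rows)

    def copy(self, shallow=False):
        # Shallow copy: a new container holding the same underlying rows.
        return Batch(list(self.rows))

    def shuffle(self):
        random.shuffle(self.rows)

def train(batch, minibatch_size=None, n_iters=3):
    # Fall back to the full batch size when no minibatch size is configured.
    minibatch_size = minibatch_size or batch.count
    # Shuffle a shallow copy so the caller's batch keeps its original order.
    batch = batch.copy(shallow=True)
    for _ in range(n_iters):
        batch.shuffle()
        # ... iterate over minibatches of size minibatch_size ...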
@@ -209,18 +210,29 @@ class FQETorchModel:
         input_dict = {SampleBatch.OBS: obs}
         seq_lens = torch.ones(len(obs), device=self.device, dtype=int)
         state_batches = []
-        if self.policy.action_distribution_fn and is_overridden(
-            self.policy.action_distribution_fn
-        ):
-            dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
-                self.policy,
-                self.policy.model,
-                input_dict=input_dict,
-                state_batches=state_batches,
-                seq_lens=seq_lens,
-                explore=False,
-                is_training=False,
-            )
+        if is_overridden(self.policy.action_distribution_fn):
+            try:
+                # TorchPolicyV2 function signature
+                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
+                    self.policy.model,
+                    obs_batch=input_dict,
+                    state_batches=state_batches,
+                    seq_lens=seq_lens,
+                    explore=False,
+                    is_training=False,
+                )
+            except TypeError:
+                # TorchPolicyV1 function signature for compatibility with DQN
+                # TODO: Remove this once DQNTorchPolicy is migrated to PolicyV2
+                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
+                    self.policy,
+                    self.policy.model,
+                    input_dict=input_dict,
+                    state_batches=state_batches,
+                    seq_lens=seq_lens,
+                    explore=False,
+                    is_training=False,
+                )
         else:
             dist_class = self.policy.dist_class
             dist_inputs, _ = self.policy.model(input_dict, state_batches, seq_lens)
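This second hunk is the core of the fix: instead of always calling policy.action_distribution_fn with the TorchPolicyV1 positional signature, the FQE model now tries the TorchPolicyV2 keyword signature first and falls back to the V1 form when the call raises TypeError. A minimal sketch of that dispatch pattern, with a hypothetical helper name and argument set:

def call_action_distribution_fn(fn, policy, input_dict, **common):
    try:
        # Newer (V2-style) signature: model first, observations as a keyword.
        return fn(policy.model, obs_batch=input_dict, **common)
    except TypeError:
        # Legacy (V1-style) signature: policy and model passed positionally,
        # observations under input_dict.
        return fn(policy, policy.model, input_dict=input_dict, **common)

One caveat of this pattern: a TypeError raised inside a V2-style function body would also trigger the fallback. The diff accepts that trade-off as a stopgap until DQNTorchPolicy is migrated to PolicyV2, as its TODO notes.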
@@ -10,10 +10,14 @@ from ray.rllib.offline.estimators import (
 from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
 from ray.rllib.offline.json_reader import JsonReader
 from ray.rllib.policy.sample_batch import concat_samples
 from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.numpy import convert_to_numpy
 from pathlib import Path
 import os
+import copy
+import numpy as np
 import gym
+import torch


 class TestOPE(unittest.TestCase):
@@ -162,6 +166,48 @@ class TestOPE(unittest.TestCase):
         print(*list(std_est.items()), sep="\n")
         print("\n\n\n")

+    def test_fqe_model(self):
+        # Test FQETorchModel for:
+        # (1) Check that it does not modify the underlying batch during training
+        # (2) Check that the stopping criteria from FQE are working correctly
+        # (3) Check that fqe._compute_action_probs matches brute-force
+        # iteration over all actions with policy.compute_log_likelihoods
+        fqe = FQETorchModel(
+            policy=self.algo.get_policy(),
+            gamma=self.gamma,
+            **self.q_model_config,
+        )
+        tmp_batch = copy.deepcopy(self.batch)
+        losses = fqe.train(self.batch)
+
+        # Make sure FQETorchModel.train() does not modify self.batch
+        check(tmp_batch, self.batch)
+
+        # Make sure FQE stopping criteria are respected
+        assert (
+            len(losses) == fqe.n_iters or losses[-1] < fqe.delta
+        ), f"FQE.train() terminated early in {len(losses)} steps with final loss " \
+           f"{losses[-1]} for n_iters: {fqe.n_iters} and delta: {fqe.delta}"
+
+        # Test fqe._compute_action_probs against "brute force" method
+        # of computing log_prob for each possible action individually
+        # using policy.compute_log_likelihoods
+        obs = torch.tensor(self.batch["obs"], device=fqe.device)
+        action_probs = fqe._compute_action_probs(obs)
+        action_probs = convert_to_numpy(action_probs)
+
+        tmp_probs = []
+        for act in range(fqe.policy.action_space.n):
+            tmp_actions = np.zeros_like(self.batch["actions"]) + act
+            log_probs = fqe.policy.compute_log_likelihoods(
+                actions=tmp_actions,
+                obs_batch=self.batch["obs"],
+            )
+            tmp_probs.append(torch.exp(log_probs))
+        tmp_probs = torch.stack(tmp_probs).transpose(0, 1)
+        tmp_probs = convert_to_numpy(tmp_probs)
+        check(action_probs, tmp_probs, decimals=3)
+
     def test_multiple_inputs(self):
         # TODO (Rohan138): Test with multiple input files
         pass
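The new test_fqe_model builds its brute-force reference by stacking one probability vector per action and transposing to a (batch, num_actions) matrix. A tiny self-contained illustration of that stack/transpose step, using toy tensors rather than RLlib outputs:

import torch

batch_size, num_actions = 4, 2
# One probability vector of shape (batch_size,) per action.
per_action = [torch.full((batch_size,), 0.5) for _ in range(num_actions)]
# stack -> (num_actions, batch_size); transpose -> (batch_size, num_actions).
probs = torch.stack(per_action).transpose(0, 1)
assert probs.shape == (batch_size, num_actions)
# Each row is now a distribution over actions.
assert torch.allclose(probs.sum(dim=1), torch.ones(batch_size))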