From 4fded808130e6d560bf5671ed2b06fc525066df3 Mon Sep 17 00:00:00 2001
From: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Date: Tue, 19 Jul 2022 00:58:31 -0700
Subject: [PATCH] [RLlib]: Fix FQE Policy call (#26671)

---
 rllib/offline/estimators/fqe_torch_model.py | 40 +++++++++++-------
 rllib/offline/estimators/tests/test_ope.py  | 46 +++++++++++++++++++++
 2 files changed, 72 insertions(+), 14 deletions(-)

diff --git a/rllib/offline/estimators/fqe_torch_model.py b/rllib/offline/estimators/fqe_torch_model.py
index f976104ea..a68e74a4e 100644
--- a/rllib/offline/estimators/fqe_torch_model.py
+++ b/rllib/offline/estimators/fqe_torch_model.py
@@ -110,8 +110,9 @@ class FQETorchModel:
             A list of losses for each training iteration
         """
         losses = []
-        if self.minibatch_size is None:
-            minibatch_size = batch.count
+        minibatch_size = self.minibatch_size or batch.count
+        # Copy batch for shuffling
+        batch = batch.copy(shallow=True)
         for _ in range(self.n_iters):
             minibatch_losses = []
             batch.shuffle()
@@ -209,18 +210,29 @@ class FQETorchModel:
         input_dict = {SampleBatch.OBS: obs}
         seq_lens = torch.ones(len(obs), device=self.device, dtype=int)
         state_batches = []
-        if self.policy.action_distribution_fn and is_overridden(
-            self.policy.action_distribution_fn
-        ):
-            dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
-                self.policy,
-                self.policy.model,
-                input_dict=input_dict,
-                state_batches=state_batches,
-                seq_lens=seq_lens,
-                explore=False,
-                is_training=False,
-            )
+        if is_overridden(self.policy.action_distribution_fn):
+            try:
+                # TorchPolicyV2 function signature
+                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
+                    self.policy.model,
+                    obs_batch=input_dict,
+                    state_batches=state_batches,
+                    seq_lens=seq_lens,
+                    explore=False,
+                    is_training=False,
+                )
+            except TypeError:
+                # TorchPolicyV1 function signature for compatibility with DQN
+                # TODO: Remove this once DQNTorchPolicy is migrated to PolicyV2
+                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
+                    self.policy,
+                    self.policy.model,
+                    input_dict=input_dict,
+                    state_batches=state_batches,
+                    seq_lens=seq_lens,
+                    explore=False,
+                    is_training=False,
+                )
         else:
             dist_class = self.policy.dist_class
             dist_inputs, _ = self.policy.model(input_dict, state_batches, seq_lens)
diff --git a/rllib/offline/estimators/tests/test_ope.py b/rllib/offline/estimators/tests/test_ope.py
index 749d3ca9f..c282375e3 100644
--- a/rllib/offline/estimators/tests/test_ope.py
+++ b/rllib/offline/estimators/tests/test_ope.py
@@ -10,10 +10,14 @@ from ray.rllib.offline.estimators import (
 from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
 from ray.rllib.offline.json_reader import JsonReader
 from ray.rllib.policy.sample_batch import concat_samples
+from ray.rllib.utils.test_utils import check
+from ray.rllib.utils.numpy import convert_to_numpy
 from pathlib import Path
 import os
+import copy
 import numpy as np
 import gym
+import torch
 
 
 class TestOPE(unittest.TestCase):
@@ -162,6 +166,48 @@ class TestOPE(unittest.TestCase):
         print(*list(std_est.items()), sep="\n")
         print("\n\n\n")
 
+    def test_fqe_model(self):
+        # Test FQETorchModel for:
+        # (1) Check that it does not modify the underlying batch during training
+        # (2) Check that the stopping criteria from FQE are working correctly
+        # (3) Check that using fqe._compute_action_probs equals brute force
+        # iterating over all actions with policy.compute_log_likelihoods
+        fqe = FQETorchModel(
+            policy=self.algo.get_policy(),
+            gamma=self.gamma,
+            **self.q_model_config,
+        )
+        tmp_batch = copy.deepcopy(self.batch)
+        losses = fqe.train(self.batch)
+
+        # Make sure FQETorchModel.train() does not modify self.batch
+        check(tmp_batch, self.batch)
+
+        # Make sure FQE stopping criteria are respected
+        assert (
+            len(losses) == fqe.n_iters or losses[-1] < fqe.delta
+        ), (f"FQE.train() terminated early in {len(losses)} steps with final loss "
+            f"{losses[-1]} for n_iters: {fqe.n_iters} and delta: {fqe.delta}")
+
+        # Test fqe._compute_action_probs against "brute force" method
+        # of computing log_prob for each possible action individually
+        # using policy.compute_log_likelihoods
+        obs = torch.tensor(self.batch["obs"], device=fqe.device)
+        action_probs = fqe._compute_action_probs(obs)
+        action_probs = convert_to_numpy(action_probs)
+
+        tmp_probs = []
+        for act in range(fqe.policy.action_space.n):
+            tmp_actions = np.zeros_like(self.batch["actions"]) + act
+            log_probs = fqe.policy.compute_log_likelihoods(
+                actions=tmp_actions,
+                obs_batch=self.batch["obs"],
+            )
+            tmp_probs.append(torch.exp(log_probs))
+        tmp_probs = torch.stack(tmp_probs).transpose(0, 1)
+        tmp_probs = convert_to_numpy(tmp_probs)
+        check(action_probs, tmp_probs, decimals=3)
+
     def test_multiple_inputs(self):
         # TODO (Rohan138): Test with multiple input files
         pass
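
Reviewer note (not part of the patch): the core of the fix is the signature dispatch
inside _compute_action_probs. Below is a minimal Python sketch of that pattern,
pulled out into a hypothetical call_action_distribution_fn helper for clarity; the
patch itself inlines the try/except rather than factoring it out.

# Sketch only, assuming `policy` is a torch Policy whose subclass overrides
# action_distribution_fn. TorchPolicyV2 overrides expect
# (model, obs_batch=..., ...), while older policies such as DQNTorchPolicy
# still expect (policy, model, input_dict=..., ...).
def call_action_distribution_fn(policy, input_dict, state_batches, seq_lens):
    """Hypothetical helper mirroring the patched call in FQETorchModel."""
    try:
        # TorchPolicyV2 signature: action_distribution_fn(model, obs_batch=..., ...)
        return policy.action_distribution_fn(
            policy.model,
            obs_batch=input_dict,
            state_batches=state_batches,
            seq_lens=seq_lens,
            explore=False,
            is_training=False,
        )
    except TypeError:
        # TorchPolicyV1 signature: action_distribution_fn(policy, model, input_dict=..., ...)
        # Kept for compatibility until DQNTorchPolicy is migrated to PolicyV2.
        return policy.action_distribution_fn(
            policy,
            policy.model,
            input_dict=input_dict,
            state_batches=state_batches,
            seq_lens=seq_lens,
            explore=False,
            is_training=False,
        )

Either branch returns the same (dist_inputs, dist_class, extra) tuple that
_compute_action_probs unpacks before building the action distribution.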