From a9d8da0100e680e5fa88c88fe99f139f243c86f8 Mon Sep 17 00:00:00 2001 From: Rohan Potdar <105385119+rapotdar@users.noreply.github.com> Date: Tue, 7 Jun 2022 03:52:19 -0700 Subject: [PATCH] [RLlib]: Doubly Robust Off-Policy Evaluation. (#25056) --- doc/source/rllib/rllib-env.rst | 2 +- doc/source/rllib/rllib-offline.rst | 19 +- doc/source/rllib/rllib-training.rst | 12 +- python/ray/util/ml_utils/dict.py | 28 ++- .../marwil-halfcheetahbulletenv-v0.yaml | 2 +- rllib/BUILD | 17 +- rllib/agents/trainer.py | 24 +- rllib/agents/trainer_config.py | 41 +++- rllib/algorithms/bc/bc.py | 2 +- rllib/algorithms/cql/cql.py | 2 +- rllib/algorithms/cql/tests/test_cql.py | 3 +- rllib/algorithms/marwil/marwil.py | 7 +- rllib/evaluation/metrics.py | 2 +- rllib/evaluation/rollout_worker.py | 97 +++++--- rllib/evaluation/worker_set.py | 2 +- rllib/examples/serving/cartpole_server.py | 2 +- rllib/examples/serving/unity3d_server.py | 2 +- rllib/offline/estimators/__init__.py | 10 + rllib/offline/estimators/direct_method.py | 170 +++++++++++++ rllib/offline/estimators/doubly_robust.py | 71 ++++++ rllib/offline/estimators/fqe_torch_model.py | 222 +++++++++++++++++ .../offline/estimators/importance_sampling.py | 66 +++--- .../estimators/off_policy_estimator.py | 188 +++++++++++++++ rllib/offline/estimators/qreg_torch_model.py | 224 ++++++++++++++++++ rllib/offline/estimators/tests/test_ope.py | 200 ++++++++++++++++ .../weighted_importance_sampling.py | 84 ++++--- rllib/offline/off_policy_estimator.py | 174 +------------- rllib/tests/test_io.py | 18 +- rllib/tests/test_nested_action_spaces.py | 2 +- 29 files changed, 1378 insertions(+), 315 deletions(-) create mode 100644 rllib/offline/estimators/direct_method.py create mode 100644 rllib/offline/estimators/doubly_robust.py create mode 100644 rllib/offline/estimators/fqe_torch_model.py create mode 100644 rllib/offline/estimators/off_policy_estimator.py create mode 100644 rllib/offline/estimators/qreg_torch_model.py create mode 100644 rllib/offline/estimators/tests/test_ope.py diff --git a/doc/source/rllib/rllib-env.rst b/doc/source/rllib/rllib-env.rst index b87a2fafd..44ab79110 100644 --- a/doc/source/rllib/rllib-env.rst +++ b/doc/source/rllib/rllib-env.rst @@ -521,7 +521,7 @@ You can configure any Trainer to launch a policy server with the following confi # Use the existing trainer process to run the server. "num_workers": 0, # Disable OPE, since the rollouts are coming from online clients. - "off_policy_estimation_methods": [], + "off_policy_estimation_methods": {}, } Clients can then connect in either *local* or *remote* inference mode. In local inference mode, copies of the policy are downloaded from the server and cached on the client for a configurable period of time. This allows actions to be computed by the client without requiring a network round trip each time. In remote inference mode, each computed action requires a network call to the server. diff --git a/doc/source/rllib/rllib-offline.rst b/doc/source/rllib/rllib-offline.rst index 5435cfbf4..57d7a6a4c 100644 --- a/doc/source/rllib/rllib-offline.rst +++ b/doc/source/rllib/rllib-offline.rst @@ -48,7 +48,7 @@ Then, we can tell DQN to train using these previously generated experiences with --env=CartPole-v0 \ --config='{ "input": "/tmp/cartpole-out", - "off_policy_estimation_methods": [], + "off_policy_estimation_methods": {}, "explore": false}' .. 
_is: @@ -62,7 +62,14 @@ Then, we can tell DQN to train using these previously generated experiences with --env=CartPole-v0 \ --config='{ "input": "/tmp/cartpole-out", - "off_policy_estimation_methods": ["is", "wis"], + "off_policy_estimation_methods": { + "is": { + "type": "ImportanceSampling", + }, + "wis": { + "type": "WeightedImportanceSampling", + } + }, "exploration_config": { "type": "SoftQ", "temperature": 1.0, @@ -275,10 +282,10 @@ You can configure experience input for an agent using the following options: # - Any subclass of OffPolicyEstimator, e.g. # ray.rllib.offline.estimators.is::ImportanceSampling or your own custom # subclass. - "off_policy_estimation_methods": [ - ImportanceSampling, - WeightedImportanceSampling, - ], + "off_policy_estimation_methods": { + ImportanceSampling: None, + WeightedImportanceSampling: None, + }, # Whether to run postprocess_trajectory() on the trajectory fragments from # offline inputs. Note that postprocessing will be done using the *current* # policy, not the *behavior* policy, which is typically undesirable for diff --git a/doc/source/rllib/rllib-training.rst b/doc/source/rllib/rllib-training.rst index 77b6afdd5..ce63f7288 100644 --- a/doc/source/rllib/rllib-training.rst +++ b/doc/source/rllib/rllib-training.rst @@ -574,10 +574,14 @@ The following is a list of the common algorithm hyper-parameters: # - Any subclass of OffPolicyEstimator, e.g. # ray.rllib.offline.estimators.is::ImportanceSampling or your own custom # subclass. - "off_policy_estimation_methods": [ - ImportanceSampling, - WeightedImportanceSampling, - ], + "off_policy_estimation_methods": { + "is": { + "type": ImportanceSampling, + }, + "wis": { + "type": WeightedImportanceSampling, + } + }, # Whether to run postprocess_trajectory() on the trajectory fragments from # offline inputs. Note that postprocessing will be done using the *current* # policy, not the *behavior* policy, which is typically undesirable for diff --git a/python/ray/util/ml_utils/dict.py b/python/ray/util/ml_utils/dict.py index 706350161..1dc09e60b 100644 --- a/python/ray/util/ml_utils/dict.py +++ b/python/ray/util/ml_utils/dict.py @@ -26,6 +26,7 @@ def deep_update( new_keys_allowed: bool = False, allow_new_subkey_list: Optional[List[str]] = None, override_all_if_type_changes: Optional[List[str]] = None, + override_all_key_list: Optional[List[str]] = None, ) -> dict: """Updates original dict with values from new_dict recursively. @@ -37,22 +38,29 @@ def deep_update( original: Dictionary with default values. new_dict: Dictionary with values to be updated new_keys_allowed: Whether new keys are allowed. - allow_new_subkey_list (Optional[List[str]]): List of keys that + allow_new_subkey_list: List of keys that correspond to dict values where new subkeys can be introduced. This is only at the top level. - override_all_if_type_changes(Optional[List[str]]): List of top level + override_all_if_type_changes: List of top level keys with value=dict, for which we always simply override the entire value (dict), iff the "type" key in that value dict changes. + override_all_key_list: List of top level keys + for which we override the entire value if the key is in the new_dict. 
""" allow_new_subkey_list = allow_new_subkey_list or [] override_all_if_type_changes = override_all_if_type_changes or [] + override_all_key_list = override_all_key_list or [] for k, value in new_dict.items(): if k not in original and not new_keys_allowed: raise Exception("Unknown config parameter `{}` ".format(k)) # Both orginal value and new one are dicts. - if isinstance(original.get(k), dict) and isinstance(value, dict): + if ( + isinstance(original.get(k), dict) + and isinstance(value, dict) + and k not in override_all_key_list + ): # Check old type vs old one. If different, override entire value. if ( k in override_all_if_type_changes @@ -63,10 +71,20 @@ def deep_update( original[k] = value # Allowed key -> ok to add new subkeys. elif k in allow_new_subkey_list: - deep_update(original[k], value, True) + deep_update( + original[k], + value, + True, + override_all_key_list=override_all_key_list, + ) # Non-allowed key. else: - deep_update(original[k], value, new_keys_allowed) + deep_update( + original[k], + value, + new_keys_allowed, + override_all_key_list=override_all_key_list, + ) # Original value not a dict OR new value not a dict: # Override entire value. else: diff --git a/release/rllib_tests/learning_tests/yaml_files/marwil-halfcheetahbulletenv-v0.yaml b/release/rllib_tests/learning_tests/yaml_files/marwil-halfcheetahbulletenv-v0.yaml index 27d59e39c..049af63ff 100644 --- a/release/rllib_tests/learning_tests/yaml_files/marwil-halfcheetahbulletenv-v0.yaml +++ b/release/rllib_tests/learning_tests/yaml_files/marwil-halfcheetahbulletenv-v0.yaml @@ -11,7 +11,7 @@ marwil-halfcheetahbulletenv-v0: input: ["~/halfcheetah_expert_sac.zip"] actions_in_input_normalized: true # Switch off input evaluation (data does not contain action probs). - off_policy_estimation_methods: [] + off_policy_estimation_methods: {} num_gpus: 1 diff --git a/rllib/BUILD b/rllib/BUILD index 2cda0d33f..1da27e935 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -24,6 +24,7 @@ # - `evaluation` directory tests. # - `execution` directory tests. # - `models` directory tests. +# - `offline` directory tests. # - `policy` directory tests. # - `utils` directory tests. 
@@ -1140,7 +1141,7 @@ py_test( "--env", "CartPole-v0", "--run", "DQN", "--stop", "'{\"training_iteration\": 1}'", - "--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"off_policy_estimation_methods\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'" + "--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"off_policy_estimation_methods\": {\"wis\": {\"type\": \"wis\"}, \"is\": {\"type\": \"is\"}}, \"exploration_config\": {\"type\": \"SoftQ\"}}'" ] ) @@ -1531,6 +1532,20 @@ py_test( srcs = ["models/tests/test_preprocessors.py"] ) +# -------------------------------------------------------------------- +# Offline +# rllib/offline/ +# +# Tag: offline +# -------------------------------------------------------------------- + +py_test( + name = "test_ope", + tags = ["team:ml", "offline", "torch_only"], + size = "medium", + srcs = ["offline/estimators/tests/test_ope.py"] +) + # -------------------------------------------------------------------- # Policies # rllib/policy/ diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 773bab9d5..e39af13e6 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -34,7 +34,6 @@ from ray.rllib.agents.trainer_config import TrainerConfig from ray.rllib.env.env_context import EnvContext from ray.rllib.env.utils import _gym_env_creator from ray.rllib.evaluation.episode import Episode -from ray.rllib.utils import force_list from ray.rllib.evaluation.metrics import ( collect_episodes, collect_metrics, @@ -189,6 +188,9 @@ class Trainer(Trainable): "replay_buffer_config", ] + # List of keys that are always fully overridden if present in any dict or sub-dict + _override_all_key_list = ["off_policy_estimation_methods"] + @PublicAPI def __init__( self, @@ -1748,6 +1750,7 @@ class Trainer(Trainable): _allow_unknown_configs, cls._allow_unknown_subkeys, cls._override_all_subkeys_if_type_changes, + cls._override_all_key_list, ) @staticmethod @@ -1924,9 +1927,22 @@ class Trainer(Trainable): error=False, ) config["off_policy_estimation_methods"] = input_evaluation - config["off_policy_estimation_methods"] = force_list( - config["off_policy_estimation_methods"] - ) + if isinstance(config["off_policy_estimation_methods"], list) or isinstance( + config["off_policy_estimation_methods"], tuple + ): + ope_dict = { + str(ope): {"type": ope} for ope in self.off_policy_estimation_methods + } + deprecation_warning( + old="config.off_policy_estimation_methods={}".format( + self.off_policy_estimation_methods + ), + new="config.off_policy_estimation_methods={}".format( + ope_dict, + ), + error=False, + ) + config["off_policy_estimation_methods"] = ope_dict # Check model config. # If no preprocessing, propagate into model's config as well diff --git a/rllib/agents/trainer_config.py b/rllib/agents/trainer_config.py index 671869f79..e7bfc0b61 100644 --- a/rllib/agents/trainer_config.py +++ b/rllib/agents/trainer_config.py @@ -169,7 +169,7 @@ class TrainerConfig: self.input_ = "sampler" self.input_config = {} self.actions_in_input_normalized = False - self.off_policy_estimation_methods = [] + self.off_policy_estimation_methods = {} self.postprocess_inputs = False self.shuffle_buffer_size = 0 self.output = None @@ -932,15 +932,22 @@ class TrainerConfig: when the offline file has been generated by another RLlib algorithm (e.g. PPO or SAC), while "normalize_actions" was set to True. 
input_evaluation: DEPRECATED: Use `off_policy_estimation_methods` instead! - off_policy_estimation_methods: Specify how to evaluate the current policy. + off_policy_estimation_methods: Specify how to evaluate the current policy, + along with any optional config parameters. This only has an effect when reading offline experiences ("input" is not "sampler"). - Available options: - - "simulation": Run the environment in the background, but use - this data for evaluation only and not for learning. - - Any subclass of OffPolicyEstimator, e.g. - ray.rllib.offline.estimators.is::ImportanceSampling or your own custom - subclass. + Available keys: + - {ope_method_name: {"type": ope_type, ...}} where `ope_method_name` + is a user-defined string to save the OPE results under, and + `ope_type` can be: + - "simulation": Run the environment in the background, but use + this data for evaluation only and not for learning. + - Any subclass of OffPolicyEstimator, e.g. + ray.rllib.offline.estimators.is::ImportanceSampling + or your own custom subclass. + You can also add additional config arguments to be passed to the + OffPolicyEstimator in the dict, e.g. + {"qreg_dr": {"type": DoublyRobust, "q_model_type": "qreg", "k": 5}} postprocess_inputs: Whether to run postprocess_trajectory() on the trajectory fragments from offline inputs. Note that postprocessing will be done using the *current* policy, not the *behavior* policy, which @@ -978,9 +985,25 @@ class TrainerConfig: ), error=True, ) - self.off_policy_estimation_methods = input_evaluation + if isinstance(off_policy_estimation_methods, list) or isinstance( + off_policy_estimation_methods, tuple + ): + ope_dict = { + str(ope): {"type": ope} for ope in off_policy_estimation_methods + } + deprecation_warning( + old="offline_data(off_policy_estimation_methods={}".format( + off_policy_estimation_methods + ), + new="offline_data(off_policy_estimation_methods={}".format( + ope_dict, + ), + error=False, + ) + off_policy_estimation_methods = ope_dict if off_policy_estimation_methods is not None: self.off_policy_estimation_methods = off_policy_estimation_methods + if postprocess_inputs is not None: self.postprocess_inputs = postprocess_inputs if shuffle_buffer_size is not None: diff --git a/rllib/algorithms/bc/bc.py b/rllib/algorithms/bc/bc.py index 6b038a1a4..2350d35fd 100644 --- a/rllib/algorithms/bc/bc.py +++ b/rllib/algorithms/bc/bc.py @@ -49,7 +49,7 @@ class BCConfig(MARWILConfig): # not important for behavioral cloning. self.postprocess_inputs = False # No reward estimation. 
- self.off_policy_estimation_methods = [] + self.off_policy_estimation_methods = {} # __sphinx_doc_end__ # fmt: on diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py index c2fd1d5a1..5ac0dfffb 100644 --- a/rllib/algorithms/cql/cql.py +++ b/rllib/algorithms/cql/cql.py @@ -67,7 +67,7 @@ class CQLConfig(SACConfig): # Changes to Trainer's/SACConfig's default: # .offline_data() - self.off_policy_estimation_methods = [] + self.off_policy_estimation_methods = {} # .reporting() self.min_sample_timesteps_per_reporting = 0 diff --git a/rllib/algorithms/cql/tests/test_cql.py b/rllib/algorithms/cql/tests/test_cql.py index 0d03c3907..44c22df30 100644 --- a/rllib/algorithms/cql/tests/test_cql.py +++ b/rllib/algorithms/cql/tests/test_cql.py @@ -5,6 +5,7 @@ import unittest import ray from ray.rllib.algorithms import cql +from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import ( check_compute_single_action, @@ -51,7 +52,7 @@ class TestCQL(unittest.TestCase): # RLlib algorithm (e.g. PPO or SAC). actions_in_input_normalized=False, # Switch on off-policy evaluation. - off_policy_estimation_methods=["is"], + off_policy_estimation_methods={"is": {"type": ImportanceSampling}}, ) .training( clip_actions=False, diff --git a/rllib/algorithms/marwil/marwil.py b/rllib/algorithms/marwil/marwil.py index dcc87f3f7..6243e175c 100644 --- a/rllib/algorithms/marwil/marwil.py +++ b/rllib/algorithms/marwil/marwil.py @@ -103,9 +103,10 @@ class MARWILConfig(TrainerConfig): # the same line. self.input_ = "sampler" # Use importance sampling estimators for reward. - self.off_policy_estimation_methods = [ - ImportanceSampling, WeightedImportanceSampling - ] + self.off_policy_estimation_methods = { + "is": {"type": ImportanceSampling}, + "wis": {"type": WeightedImportanceSampling}, + } self.postprocess_inputs = True self.lr = 1e-4 self.train_batch_size = 2000 diff --git a/rllib/evaluation/metrics.py b/rllib/evaluation/metrics.py index b90f922f4..565f23b48 100644 --- a/rllib/evaluation/metrics.py +++ b/rllib/evaluation/metrics.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import ray from ray import ObjectRef from ray.actor import ActorHandle -from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate +from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimate from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index a0202a693..5d2dd0012 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -33,8 +33,14 @@ from ray.rllib.evaluation.metrics import RolloutMetrics from ray.rllib.models import ModelCatalog from ray.rllib.models.preprocessors import Preprocessor from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader -from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate -from ray.rllib.offline.estimators import ImportanceSampling, WeightedImportanceSampling +from ray.rllib.offline.estimators import ( + OffPolicyEstimate, + OffPolicyEstimator, + ImportanceSampling, + WeightedImportanceSampling, + DirectMethod, + DoublyRobust, +) from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID from 
ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.policy_map import PolicyMap @@ -242,7 +248,7 @@ class RolloutWorker(ParallelIteratorWorker): input_creator: Callable[ [IOContext], InputReader ] = lambda ioctx: ioctx.default_sampler_input(), - off_policy_estimation_methods: List[str] = frozenset([]), + off_policy_estimation_methods: Optional[Dict[str, Dict]] = None, output_creator: Callable[ [IOContext], OutputWriter ] = lambda ioctx: NoopOutput(), @@ -339,14 +345,25 @@ class RolloutWorker(ParallelIteratorWorker): DefaultCallbacks for training/policy/rollout-worker callbacks. input_creator: Function that returns an InputReader object for loading previous generated experiences. - off_policy_estimation_methods: How to evaluate the policy performance. - Setting this only makes sense when the input is reading offline data. - Available options: - - "simulation" (str): Run the environment in the background, but use + off_policy_estimation_methods: A dict that specifies how to + evaluate the current policy. + This only has an effect when reading offline experiences + ("input" is not "sampler"). + Available key-value pairs: + - {"simulation": None}: Run the environment in the background, but use this data for evaluation only and not for learning. - - Any subclass (type) of the OffPolicyEstimator API class, e.g. - `ray.rllib.offline.estimators.importance_sampling::ImportanceSampling` + - {ope_name: {"type": ope_type, args}}. where `ope_name` is an arbitrary + string under which the metrics for this OPE estimator are saved, + and `ope_type` can be any subclass of OffPolicyEstimator, e.g. + ray.rllib.offline.estimators::ImportanceSampling or your own custom subclass. + You can also add additional config arguments to be passed to the + OffPolicyEstimator e.g. + off_policy_estimation_methods = { + "dr_qreg": {"type": DoublyRobust, "q_model_type": "qreg"}, + "dm_64": {"type": DirectMethod, "batch_size": 64}, + } + See ray/rllib/offline/estimators for more information. output_creator: Function that returns an OutputWriter object for saving generated experiences. 
remote_worker_envs: If using num_envs_per_worker > 1, @@ -702,41 +719,50 @@ class RolloutWorker(ParallelIteratorWorker): log_dir, policy_config, worker_index, self ) self.reward_estimators: List[OffPolicyEstimator] = [] - for method in off_policy_estimation_methods: - if method == "is": - method = ImportanceSampling + ope_types = { + "is": ImportanceSampling, + "wis": WeightedImportanceSampling, + "dm": DirectMethod, + "dr": DoublyRobust, + } + off_policy_estimation_methods = off_policy_estimation_methods or {} + for name, method_config in off_policy_estimation_methods.items(): + method_type = method_config.pop("type") + if method_type in ope_types: deprecation_warning( - old="config.off_policy_estimation_methods=[is]", - new="from ray.rllib.offline.estimators import " - f"{method.__name__}; config.off_policy_estimation_methods=" - f"[{method.__name__}]", + old=method_type, + new=str(ope_types[method_type]), error=False, ) - elif method == "wis": - method = WeightedImportanceSampling - deprecation_warning( - old="config.off_policy_estimation_methods=[wis]", - new="from ray.rllib.offline.estimators import " - f"{method.__name__}; config.off_policy_estimation_methods=" - f"[{method.__name__}]", - error=False, - ) - - if method == "simulation": + method_type = ope_types[method_type] + if name == "simulation": logger.warning( "Requested 'simulation' input evaluation method: " "will discard all sampler outputs and keep only metrics." ) sample_async = True - elif isinstance(method, type) and issubclass(method, OffPolicyEstimator): + # TODO: Allow for this to be a full classpath string as well, then construct + # this with our `from_config` util. + elif isinstance(method_type, type) and issubclass( + method_type, OffPolicyEstimator + ): + gamma = self.io_context.worker.policy_config["gamma"] + # Grab a reference to the current model + keys = list(self.io_context.worker.policy_map.keys()) + if len(keys) > 1: + raise NotImplementedError( + "Off-policy estimation is not implemented for multi-agent. " + "You can set `input_evaluation: []` to resolve this." + ) + policy = self.io_context.worker.get_policy(keys[0]) self.reward_estimators.append( - method.create_from_io_context(self.io_context) + method_type(name=name, policy=policy, gamma=gamma, **method_config) ) else: raise ValueError( - f"Unknown evaluation method: {method}! Must be " - "either `simulation` or a sub-class of ray.rllib.offline." - "off_policy_estimator::OffPolicyEstimator" + f"Unknown off_policy_estimation type: {method_type}! Must be " + "either `simulation|is|wis|dm|dr` or a sub-class of ray.rllib." + "offline.estimators.off_policy_estimator::OffPolicyEstimator" ) render = False @@ -866,9 +892,8 @@ class RolloutWorker(ParallelIteratorWorker): # Do off-policy estimation, if needed. if self.reward_estimators: - for sub_batch in batch.split_by_episode(): - for estimator in self.reward_estimators: - estimator.process(sub_batch) + for estimator in self.reward_estimators: + estimator.process(batch) if log_once("sample_end"): logger.info("Completed sample batch:\n\n{}\n".format(summarize(batch))) @@ -1138,7 +1163,7 @@ class RolloutWorker(ParallelIteratorWorker): out = self.sampler.get_metrics() else: out = [] - # Get metrics from our reward-estimators (if any). + # Get metrics from our reward estimators (if any). 
for m in self.reward_estimators: out.extend(m.get_metrics()) diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 7213b037d..232599e74 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -627,7 +627,7 @@ class WorkerSet: ) if config["input"] == "sampler": - off_policy_estimation_methods = [] + off_policy_estimation_methods = {} else: off_policy_estimation_methods = config["off_policy_estimation_methods"] diff --git a/rllib/examples/serving/cartpole_server.py b/rllib/examples/serving/cartpole_server.py index e024258a0..c05e5e13f 100755 --- a/rllib/examples/serving/cartpole_server.py +++ b/rllib/examples/serving/cartpole_server.py @@ -165,7 +165,7 @@ if __name__ == "__main__": # Use n worker processes to listen on different ports. "num_workers": args.num_workers, # Disable OPE, since the rollouts are coming from online clients. - "off_policy_estimation_methods": [], + "off_policy_estimation_methods": {}, # Create a "chatty" client/server or not. "callbacks": MyCallbacks if args.callbacks_verbose else None, # DL framework to use. diff --git a/rllib/examples/serving/unity3d_server.py b/rllib/examples/serving/unity3d_server.py index 5e1132aa9..693e7427a 100755 --- a/rllib/examples/serving/unity3d_server.py +++ b/rllib/examples/serving/unity3d_server.py @@ -132,7 +132,7 @@ if __name__ == "__main__": # Use n worker processes to listen on different ports. "num_workers": args.num_workers, # Disable OPE, since the rollouts are coming from online clients. - "off_policy_estimation_methods": [], + "off_policy_estimation_methods": {}, # Other settings. "train_batch_size": 256, "rollout_fragment_length": 20, diff --git a/rllib/offline/estimators/__init__.py b/rllib/offline/estimators/__init__.py index 0a6d41d90..ef5eec9b8 100644 --- a/rllib/offline/estimators/__init__.py +++ b/rllib/offline/estimators/__init__.py @@ -2,8 +2,18 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling from ray.rllib.offline.estimators.weighted_importance_sampling import ( WeightedImportanceSampling, ) +from ray.rllib.offline.estimators.direct_method import DirectMethod +from ray.rllib.offline.estimators.doubly_robust import DoublyRobust +from ray.rllib.offline.estimators.off_policy_estimator import ( + OffPolicyEstimate, + OffPolicyEstimator, +) __all__ = [ + "OffPolicyEstimator", + "OffPolicyEstimate", "ImportanceSampling", "WeightedImportanceSampling", + "DirectMethod", + "DoublyRobust", ] diff --git a/rllib/offline/estimators/direct_method.py b/rllib/offline/estimators/direct_method.py new file mode 100644 index 000000000..91b4f486b --- /dev/null +++ b/rllib/offline/estimators/direct_method.py @@ -0,0 +1,170 @@ +from typing import Tuple, List, Generator +from ray.rllib.offline.estimators.off_policy_estimator import ( + OffPolicyEstimator, + OffPolicyEstimate, +) +from ray.rllib.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import DeveloperAPI, override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.typing import SampleBatchType +from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel +from ray.rllib.offline.estimators.qreg_torch_model import QRegTorchModel +from gym.spaces import Discrete +import numpy as np + +torch, nn = try_import_torch() + + +# TODO (rohan): replace with AIR/parallel workers +# (And find a better name than `should_train`) +@DeveloperAPI +def k_fold_cv( + 
batch: SampleBatchType, k: int, should_train: bool = True +) -> Generator[Tuple[List[SampleBatch]], None, None]: + """Utility function that returns a k-fold cross validation generator + over episodes from the given batch. If the number of episodes in the + batch is less than `k` or `should_train` is set to False, yields an empty + list for train_episodes and all the episodes in test_episodes. + + Args: + batch: A SampleBatch of episodes to split + k: Number of cross-validation splits + should_train: True by default. If False, yield [], [episodes]. + + Returns: + A tuple with two lists of SampleBatches (train_episodes, test_episodes) + """ + episodes = batch.split_by_episode() + n_episodes = len(episodes) + if n_episodes < k or not should_train: + yield [], episodes + return + n_fold = n_episodes // k + for i in range(k): + train_episodes = episodes[: i * n_fold] + episodes[(i + 1) * n_fold :] + if i != k - 1: + test_episodes = episodes[i * n_fold : (i + 1) * n_fold] + else: + # Append remaining episodes onto the last test_episodes + test_episodes = episodes[i * n_fold :] + yield train_episodes, test_episodes + return + + +@DeveloperAPI +class DirectMethod(OffPolicyEstimator): + """The Direct Method estimator. + + DM estimator described in https://arxiv.org/pdf/1511.03722.pdf""" + + @override(OffPolicyEstimator) + def __init__( + self, + name: str, + policy: Policy, + gamma: float, + q_model_type: str = "fqe", + k: int = 5, + **kwargs, + ): + """ + Initializes a Direct Method OPE Estimator. + + Args: + name: string to save OPE results under + policy: Policy to evaluate. + gamma: Discount factor of the environment. + q_model_type: Either "fqe" for Fitted Q-Evaluation + or "qreg" for Q-Regression, or a custom model that implements: + - `estimate_q(states,actions)` + - `estimate_v(states, action_probs)` + k: k-fold cross validation for training model and evaluating OPE + kwargs: Optional arguments for the specified Q model + """ + + super().__init__(name, policy, gamma) + # TODO (rohan): Add support for continuous action spaces + assert isinstance( + policy.action_space, Discrete + ), "DM Estimator only supports discrete action spaces!" + assert ( + policy.config["batch_mode"] == "complete_episodes" + ), "DM Estimator only supports `batch_mode`=`complete_episodes`" + + # TODO (rohan): Add support for TF! + if policy.framework == "torch": + if q_model_type == "qreg": + model_cls = QRegTorchModel + elif q_model_type == "fqe": + model_cls = FQETorchModel + else: + assert hasattr( + q_model_type, "estimate_q" + ), "q_model_type must implement `estimate_q`!" + assert hasattr( + q_model_type, "estimate_v" + ), "q_model_type must implement `estimate_v`!" 
+ else: + raise ValueError( + f"{self.__class__.__name__}" + "estimator only supports `policy.framework`=`torch`" + ) + + self.model = model_cls( + policy=policy, + gamma=gamma, + **kwargs, + ) + self.k = k + self.losses = [] + + @override(OffPolicyEstimator) + def estimate( + self, batch: SampleBatchType, should_train: bool = True + ) -> OffPolicyEstimate: + self.check_can_estimate_for(batch) + estimates = [] + # Split data into train and test using k-fold cross validation + for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train): + + # Train Q-function + if train_episodes: + # Reinitialize model + self.model.reset() + train_batch = SampleBatch.concat_samples(train_episodes) + losses = self.train(train_batch) + self.losses.append(losses) + + # Calculate direct method OPE estimates + for episode in test_episodes: + rewards = episode["rewards"] + v_old = 0.0 + v_new = 0.0 + for t in range(episode.count): + v_old += rewards[t] * self.gamma ** t + + init_step = episode[0:1] + init_obs = np.array([init_step[SampleBatch.OBS]]) + all_actions = np.arange(self.policy.action_space.n, dtype=float) + init_step[SampleBatch.ACTIONS] = all_actions + action_probs = np.exp(self.action_log_likelihood(init_step)) + v_value = self.model.estimate_v(init_obs, action_probs) + v_new = convert_to_numpy(v_value).item() + + estimates.append( + OffPolicyEstimate( + self.name, + { + "v_old": v_old, + "v_new": v_new, + "v_gain": v_new / max(1e-8, v_old), + }, + ) + ) + return estimates + + @override(OffPolicyEstimator) + def train(self, batch: SampleBatchType): + return self.model.train_q(batch) diff --git a/rllib/offline/estimators/doubly_robust.py b/rllib/offline/estimators/doubly_robust.py new file mode 100644 index 000000000..6e66c9678 --- /dev/null +++ b/rllib/offline/estimators/doubly_robust.py @@ -0,0 +1,71 @@ +from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimate +from ray.rllib.offline.estimators.direct_method import DirectMethod, k_fold_cv +from ray.rllib.utils.annotations import DeveloperAPI, override +from ray.rllib.utils.typing import SampleBatchType +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.numpy import convert_to_numpy +import numpy as np + + +@DeveloperAPI +class DoublyRobust(DirectMethod): + """The Doubly Robust (DR) estimator. 
+ + DR estimator described in https://arxiv.org/pdf/1511.03722.pdf""" + + @override(DirectMethod) + def estimate( + self, batch: SampleBatchType, should_train: bool = True + ) -> OffPolicyEstimate: + self.check_can_estimate_for(batch) + estimates = [] + # Split data into train and test using k-fold cross validation + for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train): + + # Train Q-function + if train_episodes: + # Reinitialize model + self.model.reset() + train_batch = SampleBatch.concat_samples(train_episodes) + losses = self.train(train_batch) + self.losses.append(losses) + + # Calculate doubly robust OPE estimates + for episode in test_episodes: + rewards, old_prob = episode["rewards"], episode["action_prob"] + new_prob = np.exp(self.action_log_likelihood(episode)) + + v_old = 0.0 + v_new = 0.0 + q_values = self.model.estimate_q( + episode[SampleBatch.OBS], episode[SampleBatch.ACTIONS] + ) + q_values = convert_to_numpy(q_values) + + all_actions = np.zeros([episode.count, self.policy.action_space.n]) + all_actions[:] = np.arange(self.policy.action_space.n) + # Two transposes required for torch.distributions to work + tmp_episode = episode.copy() + tmp_episode[SampleBatch.ACTIONS] = all_actions.T + action_probs = np.exp(self.action_log_likelihood(tmp_episode)).T + v_values = self.model.estimate_v(episode[SampleBatch.OBS], action_probs) + v_values = convert_to_numpy(v_values) + + for t in reversed(range(episode.count)): + v_old = rewards[t] + self.gamma * v_old + v_new = v_values[t] + (new_prob[t] / old_prob[t]) * ( + rewards[t] + self.gamma * v_new - q_values[t] + ) + v_new = v_new.item() + + estimates.append( + OffPolicyEstimate( + self.name, + { + "v_old": v_old, + "v_new": v_new, + "v_gain": v_new / max(1e-8, v_old), + }, + ) + ) + return estimates diff --git a/rllib/offline/estimators/fqe_torch_model.py b/rllib/offline/estimators/fqe_torch_model.py new file mode 100644 index 000000000..d7a2ed71f --- /dev/null +++ b/rllib/offline/estimators/fqe_torch_model.py @@ -0,0 +1,222 @@ +from ray.rllib.models.utils import get_initializer +from ray.rllib.policy import Policy +from typing import List, Union + +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import ModelConfigDict, TensorType + +torch, nn = try_import_torch() + + +@DeveloperAPI +class FQETorchModel: + """Pytorch implementation of the Fitted Q-Evaluation (FQE) model from + https://arxiv.org/pdf/1911.06854.pdf + """ + + def __init__( + self, + policy: Policy, + gamma: float, + model: ModelConfigDict = None, + n_iters: int = 160, + lr: float = 1e-3, + delta: float = 1e-4, + clip_grad_norm: float = 100.0, + batch_size: int = 32, + tau: float = 0.05, + ) -> None: + """ + Args: + policy: Policy to evaluate. + gamma: Discount factor of the environment. 
+ # The ModelConfigDict for self.q_model + model = { + "fcnet_hiddens": [8, 8], + "fcnet_activation": "relu", + "vf_share_layers": True, + }, + # Maximum number of training iterations to run on the batch + n_iters = 160, + # Learning rate for Q-function optimizer + lr = 1e-3, + # Early stopping if the mean loss < delta + delta = 1e-4, + # Clip gradients to this maximum value + clip_grad_norm = 100.0, + # Minibatch size for training Q-function + batch_size = 32, + # Polyak averaging factor for target Q-function + tau = 0.05 + """ + self.policy = policy + self.gamma = gamma + self.observation_space = policy.observation_space + self.action_space = policy.action_space + + if model is None: + model = { + "fcnet_hiddens": [8, 8], + "fcnet_activation": "relu", + "vf_share_layers": True, + } + + self.device = self.policy.device + self.q_model: TorchModelV2 = ModelCatalog.get_model_v2( + self.observation_space, + self.action_space, + self.action_space.n, + model, + framework="torch", + name="TorchQModel", + ).to(self.device) + self.target_q_model: TorchModelV2 = ModelCatalog.get_model_v2( + self.observation_space, + self.action_space, + self.action_space.n, + model, + framework="torch", + name="TargetTorchQModel", + ).to(self.device) + self.n_iters = n_iters + self.lr = lr + self.delta = delta + self.clip_grad_norm = clip_grad_norm + self.batch_size = batch_size + self.tau = tau + self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr) + initializer = get_initializer("xavier_uniform", framework="torch") + # Hard update target + self.update_target(tau=1.0) + + def f(m): + if isinstance(m, nn.Linear): + initializer(m.weight) + + self.initializer = f + + def reset(self) -> None: + """Resets/Reinintializes the model weights.""" + self.q_model.apply(self.initializer) + + def train_q(self, batch: SampleBatch) -> TensorType: + """Trains self.q_model using FQE loss on given batch. 
+ + Args: + batch: A SampleBatch of episodes to train on + + Returns: + A list of losses for each training iteration + """ + losses = [] + for _ in range(self.n_iters): + minibatch_losses = [] + batch.shuffle() + for idx in range(0, batch.count, self.batch_size): + minibatch = batch[idx : idx + self.batch_size] + obs = torch.tensor(minibatch[SampleBatch.OBS], device=self.device) + actions = torch.tensor( + minibatch[SampleBatch.ACTIONS], device=self.device + ) + rewards = torch.tensor( + minibatch[SampleBatch.REWARDS], device=self.device + ) + next_obs = torch.tensor( + minibatch[SampleBatch.NEXT_OBS], device=self.device + ) + dones = torch.tensor(minibatch[SampleBatch.DONES], device=self.device) + + # Neccessary if policy uses recurrent/attention model + num_state_inputs = 0 + for k in batch.keys(): + if k.startswith("state_in_"): + num_state_inputs += 1 + state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)] + + # Compute action_probs for next_obs as in FQE + all_actions = torch.zeros([minibatch.count, self.policy.action_space.n]) + all_actions[:] = torch.arange(self.policy.action_space.n) + next_action_prob = self.policy.compute_log_likelihoods( + actions=all_actions.T, + obs_batch=next_obs, + state_batches=[minibatch[k] for k in state_keys], + prev_action_batch=minibatch[SampleBatch.ACTIONS], + prev_reward_batch=minibatch[SampleBatch.REWARDS], + actions_normalized=False, + ) + next_action_prob = ( + torch.exp(next_action_prob.T).to(self.device).detach() + ) + + q_values, _ = self.q_model({"obs": obs}, [], None) + q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze() + with torch.no_grad(): + next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None) + next_v = torch.sum(next_q_values * next_action_prob, axis=-1) + targets = rewards + ~dones * self.gamma * next_v + loss = (targets - q_acts) ** 2 + loss = torch.mean(loss) + self.optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad.clip_grad_norm_( + self.q_model.variables(), self.clip_grad_norm + ) + self.optimizer.step() + minibatch_losses.append(loss.item()) + iter_loss = sum(minibatch_losses) / len(minibatch_losses) + losses.append(iter_loss) + if iter_loss < self.delta: + break + self.update_target() + return losses + + def estimate_q( + self, + obs: Union[TensorType, List[TensorType]], + actions: Union[TensorType, List[TensorType]] = None, + ) -> TensorType: + """Given `obs`, a list or array or tensor of observations, + compute the Q-values for `obs` for all actions in the action space. + If `actions` is not None, return the Q-values for the actions provided, + else return Q-values for all actions for each observation in `obs`. + """ + obs = torch.tensor(obs, device=self.device) + q_values, _ = self.q_model({"obs": obs}, [], None) + if actions is not None: + actions = torch.tensor(actions, device=self.device, dtype=int) + q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze() + return q_values.detach() + + def estimate_v( + self, + obs: Union[TensorType, List[TensorType]], + action_probs: Union[TensorType, List[TensorType]], + ) -> TensorType: + """Given `obs`, compute q-values for all actions in the action space + for each observations s in `obs`, then multiply this by `action_probs`, + the probability distribution over actions for each state s to give the + state value V(s) = sum_A pi(a|s)Q(s,a). 
+ """ + q_values = self.estimate_q(obs) + action_probs = torch.tensor(action_probs, device=self.device) + v_values = torch.sum(q_values * action_probs, axis=-1) + return v_values.detach() + + def update_target(self, tau=None): + # Update_target will be called periodically to copy Q network to + # target Q network, using (soft) tau-synching. + tau = tau or self.tau + model_state_dict = self.q_model.state_dict() + # Support partial (soft) synching. + # If tau == 1.0: Full sync from Q-model to target Q-model. + target_state_dict = self.target_q_model.state_dict() + model_state_dict = { + k: tau * model_state_dict[k] + (1 - tau) * v + for k, v in target_state_dict.items() + } + + self.target_q_model.load_state_dict(model_state_dict) diff --git a/rllib/offline/estimators/importance_sampling.py b/rllib/offline/estimators/importance_sampling.py index 7138125fd..4e1d692df 100644 --- a/rllib/offline/estimators/importance_sampling.py +++ b/rllib/offline/estimators/importance_sampling.py @@ -1,42 +1,52 @@ -from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate +from ray.rllib.offline.estimators.off_policy_estimator import ( + OffPolicyEstimator, + OffPolicyEstimate, +) from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.typing import SampleBatchType +from typing import List +import numpy as np @DeveloperAPI class ImportanceSampling(OffPolicyEstimator): """The step-wise IS estimator. - Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf""" + Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf, + https://arxiv.org/pdf/1911.06854.pdf""" @override(OffPolicyEstimator) - def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: + def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]: self.check_can_estimate_for(batch) + estimates = [] + for sub_batch in batch.split_by_episode(): + rewards, old_prob = sub_batch["rewards"], sub_batch["action_prob"] + new_prob = np.exp(self.action_log_likelihood(sub_batch)) - rewards, old_prob = batch["rewards"], batch["action_prob"] - new_prob = self.action_log_likelihood(batch) + # calculate importance ratios + p = [] + for t in range(sub_batch.count): + if t == 0: + pt_prev = 1.0 + else: + pt_prev = p[t - 1] + p.append(pt_prev * new_prob[t] / old_prob[t]) - # calculate importance ratios - p = [] - for t in range(batch.count): - if t == 0: - pt_prev = 1.0 - else: - pt_prev = p[t - 1] - p.append(pt_prev * new_prob[t] / old_prob[t]) + # calculate stepwise IS estimate + v_old = 0.0 + v_new = 0.0 + for t in range(sub_batch.count): + v_old += rewards[t] * self.gamma ** t + v_new += p[t] * rewards[t] * self.gamma ** t - # calculate stepwise IS estimate - V_prev, V_step_IS = 0.0, 0.0 - for t in range(batch.count): - V_prev += rewards[t] * self.gamma ** t - V_step_IS += p[t] * rewards[t] * self.gamma ** t - - estimation = OffPolicyEstimate( - "importance_sampling", - { - "V_prev": V_prev, - "V_step_IS": V_step_IS, - "V_gain_est": V_step_IS / max(1e-8, V_prev), - }, - ) - return estimation + estimates.append( + OffPolicyEstimate( + self.name, + { + "v_old": v_old, + "v_new": v_new, + "v_gain": v_new / max(1e-8, v_old), + }, + ) + ) + return estimates diff --git a/rllib/offline/estimators/off_policy_estimator.py b/rllib/offline/estimators/off_policy_estimator.py new file mode 100644 index 000000000..199f024e3 --- /dev/null +++ b/rllib/offline/estimators/off_policy_estimator.py @@ -0,0 +1,188 @@ +from collections import namedtuple +import logging +from 
ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch +from ray.rllib.policy import Policy +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.offline.io_context import IOContext +from ray.rllib.utils.annotations import Deprecated +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.typing import TensorType, SampleBatchType +from typing import List + +logger = logging.getLogger(__name__) + +OffPolicyEstimate = DeveloperAPI( + namedtuple("OffPolicyEstimate", ["estimator_name", "metrics"]) +) + + +@DeveloperAPI +class OffPolicyEstimator: + """Interface for an off policy reward estimator.""" + + @DeveloperAPI + def __init__(self, name: str, policy: Policy, gamma: float): + """Initializes an OffPolicyEstimator instance. + + Args: + name: string to save OPE results under + policy: Policy to evaluate. + gamma: Discount factor of the environment. + """ + self.name = name + self.policy = policy + self.gamma = gamma + self.new_estimates = [] + + @DeveloperAPI + def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]: + """Returns a list of off policy estimates for the given batch of episodes. + + Args: + batch: The batch to calculate the off policy estimates (OPE) on. + + Returns: + The off-policy estimates (OPE) calculated on the given batch. + """ + raise NotImplementedError + + @DeveloperAPI + def train(self, batch: SampleBatchType) -> TensorType: + """Trains an Off-Policy Estimator on a batch of experiences. + A model-based estimator should override this and train + a transition, value, or reward model. + + Args: + batch: The batch to train the model on + + Returns: + any optional training/loss metrics from the model + """ + pass + + @DeveloperAPI + def action_log_likelihood(self, batch: SampleBatchType) -> TensorType: + """Returns log likelihood for actions in given batch for policy. + + Computes likelihoods by passing the observations through the current + policy's `compute_log_likelihoods()` method + + Args: + batch: The SampleBatch or MultiAgentBatch to calculate action + log likelihoods from. This batch/batches must contain OBS + and ACTIONS keys. + + Returns: + The probabilities of the actions in the batch, given the + observations and the policy. + """ + num_state_inputs = 0 + for k in batch.keys(): + if k.startswith("state_in_"): + num_state_inputs += 1 + state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)] + log_likelihoods: TensorType = self.policy.compute_log_likelihoods( + actions=batch[SampleBatch.ACTIONS], + obs_batch=batch[SampleBatch.OBS], + state_batches=[batch[k] for k in state_keys], + prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS), + prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS), + actions_normalized=True, + ) + log_likelihoods = convert_to_numpy(log_likelihoods) + return log_likelihoods + + @DeveloperAPI + def check_can_estimate_for(self, batch: SampleBatchType) -> None: + """Checks if we support off policy estimation (OPE) on given batch. + + Args: + batch: The batch to check. + + Raises: + ValueError: In case `action_prob` key is not in batch OR batch + is a MultiAgentBatch. + """ + + if isinstance(batch, MultiAgentBatch): + raise ValueError( + "Off-Policy Estimation is not implemented for multi-agent batches. " + "You can set `off_policy_estimation_methods: {}` to resolve this." 
+ ) + + if "action_prob" not in batch: + raise ValueError( + "Off-policy estimation is not possible unless the inputs " + "include action probabilities (i.e., the policy is stochastic " + "and emits the 'action_prob' key). For DQN this means using " + "`exploration_config: {type: 'SoftQ'}`. You can also set " + "`off_policy_estimation_methods: {}` to disable estimation." + ) + + @DeveloperAPI + def process(self, batch: SampleBatchType) -> None: + """Computes off policy estimates (OPE) on batch and stores results. + Thus-far collected results can be retrieved then by calling + `self.get_metrics` (which flushes the internal results storage). + Args: + batch: The batch to process (call `self.estimate()` on) and + store results (OPEs) for. + """ + self.new_estimates.extend(self.estimate(batch)) + + @DeveloperAPI + def get_metrics(self, get_losses: bool = False) -> List[OffPolicyEstimate]: + """Returns list of new episode metric estimates since the last call. + + Args: + get_losses: If True, also return self.losses for the OPE estimator + Returns: + out: List of OffPolicyEstimate objects. + losses: List of training losses for the estimator. + """ + out = self.new_estimates + self.new_estimates = [] + if hasattr(self, "losses"): + losses = self.losses + self.losses = [] + if get_losses: + return out, losses + return out + + # TODO (rohan): Remove deprecated methods; set to error=True because changing + # from one episode per SampleBatch to full SampleBatch is a breaking change anyway + + @Deprecated(help="OffPolicyEstimator.__init__(policy, gamma, config)", error=False) + @classmethod + @DeveloperAPI + def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator": + """Creates an off-policy estimator from an IOContext object. + Extracts Policy and gamma (discount factor) information from the + IOContext. + Args: + ioctx: The IOContext object to create the OffPolicyEstimator + from. + Returns: + The OffPolicyEstimator object created from the IOContext object. + """ + gamma = ioctx.worker.policy_config["gamma"] + # Grab a reference to the current model + keys = list(ioctx.worker.policy_map.keys()) + if len(keys) > 1: + raise NotImplementedError( + "Off-policy estimation is not implemented for multi-agent. " + "You can set `input_evaluation: []` to resolve this." 
+ ) + policy = ioctx.worker.get_policy(keys[0]) + config = ioctx.input_config.get("estimator_config", {}) + return cls(policy, gamma, config) + + @Deprecated(new="OffPolicyEstimator.create_from_io_context", error=True) + @DeveloperAPI + def create(self, *args, **kwargs): + return self.create_from_io_context(*args, **kwargs) + + @Deprecated(new="OffPolicyEstimator.compute_log_likelihoods", error=False) + @DeveloperAPI + def action_prob(self, *args, **kwargs): + return self.compute_log_likelihoods(*args, **kwargs) diff --git a/rllib/offline/estimators/qreg_torch_model.py b/rllib/offline/estimators/qreg_torch_model.py new file mode 100644 index 000000000..4243d52e6 --- /dev/null +++ b/rllib/offline/estimators/qreg_torch_model.py @@ -0,0 +1,224 @@ +from ray.rllib.models.utils import get_initializer +from ray.rllib.policy import Policy +from typing import List, Union +import numpy as np + +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType, ModelConfigDict + +torch, nn = try_import_torch() + + +@DeveloperAPI +class QRegTorchModel: + """Pytorch implementation of the Q-Reg model from + https://arxiv.org/pdf/1911.06854.pdf + """ + + def __init__( + self, + policy: Policy, + gamma: float, + model: ModelConfigDict = None, + n_iters: int = 160, + lr: float = 1e-3, + delta: float = 1e-4, + clip_grad_norm: float = 100.0, + batch_size: int = 32, + ) -> None: + """ + Args: + policy: Policy to evaluate. + gamma: Discount factor of the environment. + # The ModelConfigDict for self.q_model + model = { + "fcnet_hiddens": [8, 8], + "fcnet_activation": "relu", + "vf_share_layers": True, + }, + # Maximum number of training iterations to run on the batch + n_iters = 160, + # Learning rate for Q-function optimizer + lr = 1e-3, + # Early stopping if the mean loss < delta + delta = 1e-4, + # Clip gradients to this maximum value + clip_grad_norm = 100.0, + # Minibatch size for training Q-function + batch_size = 32, + """ + self.policy = policy + self.gamma = gamma + self.observation_space = policy.observation_space + self.action_space = policy.action_space + + if model is None: + model = { + "fcnet_hiddens": [8, 8], + "fcnet_activation": "relu", + "vf_share_layers": True, + } + + self.device = self.policy.device + self.q_model: TorchModelV2 = ModelCatalog.get_model_v2( + self.observation_space, + self.action_space, + self.action_space.n, + model, + framework="torch", + name="TorchQModel", + ).to(self.device) + self.n_iters = n_iters + self.lr = lr + self.delta = delta + self.clip_grad_norm = clip_grad_norm + self.batch_size = batch_size + self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr) + initializer = get_initializer("xavier_uniform", framework="torch") + + def f(m): + if isinstance(m, nn.Linear): + initializer(m.weight) + + self.initializer = f + + def reset(self) -> None: + """Resets/Reinintializes the model weights.""" + self.q_model.apply(self.initializer) + + def train_q(self, batch: SampleBatch) -> TensorType: + """Trains self.q_model using Q-Reg loss on given batch. 
+ + Args: + batch: A SampleBatch of episodes to train on + + Returns: + A list of losses for each training iteration + """ + losses = [] + obs = torch.tensor(batch[SampleBatch.OBS], device=self.device) + actions = torch.tensor(batch[SampleBatch.ACTIONS], device=self.device) + ps = torch.zeros([batch.count], device=self.device) + returns = torch.zeros([batch.count], device=self.device) + discounts = torch.zeros([batch.count], device=self.device) + + # Neccessary if policy uses recurrent/attention model + num_state_inputs = 0 + for k in batch.keys(): + if k.startswith("state_in_"): + num_state_inputs += 1 + state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)] + + # get rewards, old_prob, new_prob + rewards = batch[SampleBatch.REWARDS] + old_log_prob = torch.tensor(batch[SampleBatch.ACTION_LOGP]) + new_log_prob = ( + self.policy.compute_log_likelihoods( + actions=batch[SampleBatch.ACTIONS], + obs_batch=batch[SampleBatch.OBS], + state_batches=[batch[k] for k in state_keys], + prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS), + prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS), + actions_normalized=False, + ) + .detach() + .cpu() + ) + prob_ratio = torch.exp(new_log_prob - old_log_prob) + + eps_begin = 0 + for episode in batch.split_by_episode(): + eps_end = eps_begin + episode.count + + # calculate importance ratios and returns + + for t in range(episode.count): + discounts[eps_begin + t] = self.gamma ** t + if t == 0: + pt_prev = 1.0 + else: + pt_prev = ps[eps_begin + t - 1] + ps[eps_begin + t] = pt_prev * prob_ratio[eps_begin + t] + + # O(n^3) + # ret = 0 + # for t_prime in range(t, episode.count): + # gamma = self.gamma ** (t_prime - t) + # rho_t_1_t_prime = 1.0 + # for k in range(t + 1, min(t_prime + 1, episode.count)): + # rho_t_1_t_prime = rho_t_1_t_prime * prob_ratio[eps_begin + k] + # r = rewards[eps_begin + t_prime] + # ret += gamma * rho_t_1_t_prime * r + + # O(n^2) + ret = 0 + rho = 1 + for t_ in reversed(range(t, episode.count)): + ret = rewards[eps_begin + t_] + self.gamma * rho * ret + rho = prob_ratio[eps_begin + t_] + + returns[eps_begin + t] = ret + + # Update before next episode + eps_begin = eps_end + + indices = np.arange(batch.count) + for _ in range(self.n_iters): + minibatch_losses = [] + np.random.shuffle(indices) + for idx in range(0, batch.count, self.batch_size): + idxs = indices[idx : idx + self.batch_size] + q_values, _ = self.q_model({"obs": obs[idxs]}, [], None) + q_acts = torch.gather( + q_values, -1, actions[idxs].unsqueeze(-1) + ).squeeze() + loss = discounts[idxs] * ps[idxs] * (returns[idxs] - q_acts) ** 2 + loss = torch.mean(loss) + self.optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad.clip_grad_norm_( + self.q_model.variables(), self.clip_grad_norm + ) + self.optimizer.step() + minibatch_losses.append(loss.item()) + iter_loss = sum(minibatch_losses) / len(minibatch_losses) + losses.append(iter_loss) + if iter_loss < self.delta: + break + return losses + + def estimate_q( + self, + obs: Union[TensorType, List[TensorType]], + actions: Union[TensorType, List[TensorType]] = None, + ) -> TensorType: + """Given `obs`, a list or array or tensor of observations, + compute the Q-values for `obs` for all actions in the action space. + If `actions` is not None, return the Q-values for the actions provided, + else return Q-values for all actions for each observation in `obs`. 
+ """ + obs = torch.tensor(obs, device=self.device) + q_values, _ = self.q_model({"obs": obs}, [], None) + if actions is not None: + actions = torch.tensor(actions, device=self.device, dtype=int) + q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze() + return q_values.detach() + + def estimate_v( + self, + obs: Union[TensorType, List[TensorType]], + action_probs: Union[TensorType, List[TensorType]], + ) -> TensorType: + """Given `obs`, compute q-values for all actions in the action space + for each observations s in `obs`, then multiply this by `action_probs`, + the probability distribution over actions for each state s to give the + state value V(s) = sum_A pi(a|s)Q(s,a). + """ + q_values = self.estimate_q(obs) + action_probs = torch.tensor(action_probs, device=self.device) + v_values = torch.sum(q_values * action_probs, axis=-1) + return v_values.detach() diff --git a/rllib/offline/estimators/tests/test_ope.py b/rllib/offline/estimators/tests/test_ope.py new file mode 100644 index 000000000..c7e567443 --- /dev/null +++ b/rllib/offline/estimators/tests/test_ope.py @@ -0,0 +1,200 @@ +import unittest +import ray +from ray import tune +from ray.rllib.algorithms.dqn import DQNConfig +from ray.rllib.offline.estimators import ( + ImportanceSampling, + WeightedImportanceSampling, + DirectMethod, + DoublyRobust, +) +from ray.rllib.offline.json_reader import JsonReader +from pathlib import Path +import os +import numpy as np +import gym + + +class TestOPE(unittest.TestCase): + def setUp(self): + ray.init(num_cpus=4) + + def tearDown(self): + ray.shutdown() + + @classmethod + def setUpClass(cls): + ray.init(ignore_reinit_error=True) + rllib_dir = Path(__file__).parent.parent.parent.parent + print("rllib dir={}".format(rllib_dir)) + data_file = os.path.join(rllib_dir, "tests/data/cartpole/large.json") + print("data_file={} exists={}".format(data_file, os.path.isfile(data_file))) + + env_name = "CartPole-v0" + cls.gamma = 0.99 + train_steps = 20000 + n_batches = 20 # Approx. 
equal to n_episodes + n_eval_episodes = 100 + + config = ( + DQNConfig() + .environment(env=env_name) + .training(gamma=cls.gamma) + .rollouts(num_rollout_workers=3) + .exploration( + explore=True, + exploration_config={ + "type": "SoftQ", + "temperature": 1.0, + }, + ) + .framework("torch") + .rollouts(batch_mode="complete_episodes") + ) + cls.trainer = config.build() + + # Train DQN for evaluation policy + tune.run( + "DQN", + config=config.to_dict(), + stop={"timesteps_total": train_steps}, + verbose=0, + ) + + # Read n_batches of data + reader = JsonReader(data_file) + cls.batch = reader.next() + for _ in range(n_batches - 1): + cls.batch = cls.batch.concat(reader.next()) + cls.n_episodes = len(cls.batch.split_by_episode()) + print("Episodes:", cls.n_episodes, "Steps:", cls.batch.count) + + cls.mean_ret = {} + cls.std_ret = {} + + # Simulate Monte-Carlo rollouts + mc_ret = [] + env = gym.make(env_name) + for _ in range(n_eval_episodes): + obs = env.reset() + done = False + rewards = [] + while not done: + act = cls.trainer.compute_single_action(obs) + obs, reward, done, _ = env.step(act) + rewards.append(reward) + ret = 0 + for r in reversed(rewards): + ret = r + cls.gamma * ret + mc_ret.append(ret) + + cls.mean_ret["simulation"] = np.mean(mc_ret) + cls.std_ret["simulation"] = np.std(mc_ret) + + # Optional configs for the model-based estimators + cls.model_config = {"k": 2, "n_iters": 10} + ray.shutdown() + + @classmethod + def tearDownClass(cls): + print("Mean:", cls.mean_ret) + print("Stddev:", cls.std_ret) + ray.shutdown() + + def test_is(self): + name = "is" + estimator = ImportanceSampling( + name=name, + policy=self.trainer.get_policy(), + gamma=self.gamma, + ) + estimator.process(self.batch) + estimates = estimator.get_metrics() + assert len(estimates) == self.n_episodes + self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates]) + self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates]) + + def test_wis(self): + name = "wis" + estimator = WeightedImportanceSampling( + name=name, + policy=self.trainer.get_policy(), + gamma=self.gamma, + ) + estimator.process(self.batch) + estimates = estimator.get_metrics() + assert len(estimates) == self.n_episodes + self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates]) + self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates]) + + def test_dm_qreg(self): + name = "dm_qreg" + estimator = DirectMethod( + name=name, + policy=self.trainer.get_policy(), + gamma=self.gamma, + q_model_type="qreg", + **self.model_config, + ) + estimator.process(self.batch) + estimates = estimator.get_metrics() + assert len(estimates) == self.n_episodes + self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates]) + self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates]) + + def test_dm_fqe(self): + name = "dm_fqe" + estimator = DirectMethod( + name=name, + policy=self.trainer.get_policy(), + gamma=self.gamma, + q_model_type="fqe", + **self.model_config, + ) + estimator.process(self.batch) + estimates = estimator.get_metrics() + assert len(estimates) == self.n_episodes + self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates]) + self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates]) + + def test_dr_qreg(self): + name = "dr_qreg" + estimator = DoublyRobust( + name=name, + policy=self.trainer.get_policy(), + gamma=self.gamma, + q_model_type="qreg", + **self.model_config, + ) + estimator.process(self.batch) + estimates = estimator.get_metrics() + assert len(estimates) 
== self.n_episodes + self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates]) + self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates]) + + def test_dr_fqe(self): + name = "dr_fqe" + estimator = DoublyRobust( + name=name, + policy=self.trainer.get_policy(), + gamma=self.gamma, + q_model_type="fqe", + **self.model_config, + ) + estimator.process(self.batch) + estimates = estimator.get_metrics() + assert len(estimates) == self.n_episodes + self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates]) + self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates]) + + def test_ope_in_trainer(self): + # TODO (rohan): Add performance tests for off_policy_estimation_methods, + # with fixed seeds and hyperparameters + pass + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/offline/estimators/weighted_importance_sampling.py b/rllib/offline/estimators/weighted_importance_sampling.py index bf772cea6..cdd4335e6 100644 --- a/rllib/offline/estimators/weighted_importance_sampling.py +++ b/rllib/offline/estimators/weighted_importance_sampling.py @@ -1,56 +1,66 @@ -from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate +from ray.rllib.offline.estimators.off_policy_estimator import ( + OffPolicyEstimator, + OffPolicyEstimate, +) from ray.rllib.policy import Policy from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.typing import SampleBatchType +import numpy as np @DeveloperAPI class WeightedImportanceSampling(OffPolicyEstimator): """The weighted step-wise IS estimator. - Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf""" + Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf, + https://arxiv.org/pdf/1911.06854.pdf""" - def __init__(self, policy: Policy, gamma: float): - super().__init__(policy, gamma) + @override(OffPolicyEstimator) + def __init__(self, name: str, policy: Policy, gamma: float): + super().__init__(name, policy, gamma) self.filter_values = [] self.filter_counts = [] @override(OffPolicyEstimator) def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: self.check_can_estimate_for(batch) + estimates = [] + for sub_batch in batch.split_by_episode(): + rewards, old_prob = sub_batch["rewards"], sub_batch["action_prob"] + new_prob = np.exp(self.action_log_likelihood(sub_batch)) - rewards, old_prob = batch["rewards"], batch["action_prob"] - new_prob = self.action_log_likelihood(batch) + # calculate importance ratios + p = [] + for t in range(sub_batch.count): + if t == 0: + pt_prev = 1.0 + else: + pt_prev = p[t - 1] + p.append(pt_prev * new_prob[t] / old_prob[t]) + for t, v in enumerate(p): + if t >= len(self.filter_values): + self.filter_values.append(v) + self.filter_counts.append(1.0) + else: + self.filter_values[t] += v + self.filter_counts[t] += 1.0 - # calculate importance ratios - p = [] - for t in range(batch.count): - if t == 0: - pt_prev = 1.0 - else: - pt_prev = p[t - 1] - p.append(pt_prev * new_prob[t] / old_prob[t]) - for t, v in enumerate(p): - if t >= len(self.filter_values): - self.filter_values.append(v) - self.filter_counts.append(1.0) - else: - self.filter_values[t] += v - self.filter_counts[t] += 1.0 + # calculate stepwise weighted IS estimate + v_old = 0.0 + v_new = 0.0 + for t in range(sub_batch.count): + v_old += rewards[t] * self.gamma ** t + w_t = self.filter_values[t] / self.filter_counts[t] + v_new += p[t] / w_t * rewards[t] * self.gamma ** t - # calculate 
stepwise weighted IS estimate - V_prev, V_step_WIS = 0.0, 0.0 - for t in range(batch.count): - V_prev += rewards[t] * self.gamma ** t - w_t = self.filter_values[t] / self.filter_counts[t] - V_step_WIS += p[t] / w_t * rewards[t] * self.gamma ** t - - estimation = OffPolicyEstimate( - "weighted_importance_sampling", - { - "V_prev": V_prev, - "V_step_WIS": V_step_WIS, - "V_gain_est": V_step_WIS / max(1e-8, V_prev), - }, - ) - return estimation + estimates.append( + OffPolicyEstimate( + self.name, + { + "v_old": v_old, + "v_new": v_new, + "v_gain": v_new / max(1e-8, v_old), + }, + ) + ) + return estimates diff --git a/rllib/offline/off_policy_estimator.py b/rllib/offline/off_policy_estimator.py index 14c037955..0b9169d28 100644 --- a/rllib/offline/off_policy_estimator.py +++ b/rllib/offline/off_policy_estimator.py @@ -1,167 +1,11 @@ -from collections import namedtuple -import logging - -import numpy as np - -from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch -from ray.rllib.policy import Policy -from ray.rllib.utils.annotations import DeveloperAPI -from ray.rllib.offline.io_context import IOContext -from ray.rllib.utils.annotations import Deprecated -from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.utils.typing import TensorType, SampleBatchType -from typing import List - -logger = logging.getLogger(__name__) - -OffPolicyEstimate = DeveloperAPI( - namedtuple("OffPolicyEstimate", ["estimator_name", "metrics"]) +from ray.rllib.offline.estimators.off_policy_estimator import ( # noqa: F401 + OffPolicyEstimator, + OffPolicyEstimate, ) +from ray.rllib.utils.deprecation import deprecation_warning - -@DeveloperAPI -class OffPolicyEstimator: - """Interface for an off policy reward estimator.""" - - def __init__(self, policy: Policy, gamma: float): - """Initializes an OffPolicyEstimator instance. - - Args: - policy: Policy to evaluate. - gamma: Discount factor of the environment. - """ - self.policy = policy - self.gamma = gamma - self.new_estimates = [] - - @classmethod - @DeveloperAPI - def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator": - """Creates an off-policy estimator from an IOContext object. - - Extracts Policy and gamma (discount factor) information from the - IOContext. - - Args: - ioctx: The IOContext object to create the OffPolicyEstimator - from. - - Returns: - The OffPolicyEstimator object created from the IOContext object. - """ - gamma = ioctx.worker.policy_config["gamma"] - # Grab a reference to the current model - keys = list(ioctx.worker.policy_map.keys()) - if len(keys) > 1: - raise NotImplementedError( - "Off-policy estimation is not implemented for multi-agent. " - "You can set `off_policy_estimation_methods: []` to resolve this." - ) - policy = ioctx.worker.get_policy(keys[0]) - return cls(policy, gamma) - - @DeveloperAPI - def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: - """Returns an off policy estimate for the given batch of experiences. - - The batch will at most only contain data from one episode, - but it may also only be a fragment of an episode. - - Args: - batch: The batch to calculate the off policy estimate (OPE) on. - - Returns: - The off-policy estimates (OPE) calculated on the given batch. - """ - raise NotImplementedError - - @DeveloperAPI - def action_log_likelihood(self, batch: SampleBatchType) -> TensorType: - """Returns log likelihoods for actions in given batch for policy. 
- - Computes likelihoods by passing the observations through the current - policy's `compute_log_likelihoods()` method. - - Args: - batch: The SampleBatch or MultiAgentBatch to calculate action - log likelihoods from. This batch/batches must contain OBS - and ACTIONS keys. - - Returns: - The log likelihoods of the actions in the batch, given the - observations and the policy. - """ - num_state_inputs = 0 - for k in batch.keys(): - if k.startswith("state_in_"): - num_state_inputs += 1 - state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)] - log_likelihoods: TensorType = self.policy.compute_log_likelihoods( - actions=batch[SampleBatch.ACTIONS], - obs_batch=batch[SampleBatch.OBS], - state_batches=[batch[k] for k in state_keys], - prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS), - prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS), - actions_normalized=True, - ) - log_likelihoods = convert_to_numpy(log_likelihoods) - return np.exp(log_likelihoods) - - @DeveloperAPI - def process(self, batch: SampleBatchType) -> None: - """Computes off policy estimates (OPE) on batch and stores results. - - Thus-far collected results can be retrieved then by calling - `self.get_metrics` (which flushes the internal results storage). - - Args: - batch: The batch to process (call `self.estimate()` on) and - store results (OPEs) for. - """ - self.new_estimates.append(self.estimate(batch)) - - @DeveloperAPI - def check_can_estimate_for(self, batch: SampleBatchType) -> None: - """Checks if we support off policy estimation (OPE) on given batch. - - Args: - batch: The batch to check. - - Raises: - ValueError: In case `action_prob` key is not in batch OR batch - is a MultiAgentBatch. - """ - - if isinstance(batch, MultiAgentBatch): - raise ValueError( - "off-policy estimation is not implemented for multi-agent batches. " - "You can set `off_policy_estimation_methods: []` to resolve this." - ) - - if "action_prob" not in batch: - raise ValueError( - "Off-policy estimation is not possible unless the inputs " - "include action probabilities (i.e., the policy is stochastic " - "and emits the 'action_prob' key). For DQN this means using " - "`exploration_config: {type: 'SoftQ'}`. You can also set " - "`off_policy_estimation_methods: []` to disable estimation." - ) - - @DeveloperAPI - def get_metrics(self) -> List[OffPolicyEstimate]: - """Returns list of new episode metric estimates since the last call. - - Returns: - List of OffPolicyEstimate objects. 
- """ - out = self.new_estimates - self.new_estimates = [] - return out - - @Deprecated(new="OffPolicyEstimator.create_from_io_context", error=False) - def create(self, *args, **kwargs): - return self.create_from_io_context(*args, **kwargs) - - @Deprecated(new="OffPolicyEstimator.action_log_likelihood", error=False) - def action_prob(self, *args, **kwargs): - return self.action_log_likelihood(*args, **kwargs) +deprecation_warning( + old="ray.rllib.offline.off_policy_estimator", + new="ray.rllib.offline.estimators.off_policy_estimator", + error=False, +) diff --git a/rllib/tests/test_io.py b/rllib/tests/test_io.py index bc1a709b1..1c9b2ed4a 100644 --- a/rllib/tests/test_io.py +++ b/rllib/tests/test_io.py @@ -98,7 +98,7 @@ class AgentIOTest(unittest.TestCase): env="CartPole-v0", config={ "input": self.test_dir + fw, - "off_policy_estimation_methods": [], + "off_policy_estimation_methods": {}, "framework": fw, }, ) @@ -141,7 +141,7 @@ class AgentIOTest(unittest.TestCase): env="CartPole-v0", config={ "input": self.test_dir + fw, - "off_policy_estimation_methods": [], + "off_policy_estimation_methods": {}, "postprocess_inputs": True, # adds back 'advantages' "framework": fw, }, @@ -158,7 +158,9 @@ class AgentIOTest(unittest.TestCase): env="CartPole-v0", config={ "input": self.test_dir + fw, - "off_policy_estimation_methods": ["simulation"], + "off_policy_estimation_methods": { + "simulation": {"type": "simulation"} + }, "framework": fw, }, ) @@ -176,7 +178,7 @@ class AgentIOTest(unittest.TestCase): env="CartPole-v0", config={ "input": glob.glob(self.test_dir + fw + "/*.json"), - "off_policy_estimation_methods": [], + "off_policy_estimation_methods": {}, "rollout_fragment_length": 99, "framework": fw, }, @@ -196,7 +198,7 @@ class AgentIOTest(unittest.TestCase): "sampler": 0.9, }, "train_batch_size": 2000, - "off_policy_estimation_methods": [], + "off_policy_estimation_methods": {}, "framework": fw, }, ) @@ -234,7 +236,9 @@ class AgentIOTest(unittest.TestCase): config={ "num_workers": 0, "input": self.test_dir, - "off_policy_estimation_methods": ["simulation"], + "off_policy_estimation_methods": { + "simulation": {"type": "simulation"} + }, "train_batch_size": 2000, "multiagent": { "policies": {"policy_1", "policy_2"}, @@ -276,7 +280,7 @@ class AgentIOTest(unittest.TestCase): config={ "input": input_procedure, "input_config": {"input_files": self.test_dir + fw}, - "off_policy_estimation_methods": [], + "off_policy_estimation_methods": {}, "framework": fw, }, ) diff --git a/rllib/tests/test_nested_action_spaces.py b/rllib/tests/test_nested_action_spaces.py index 643950798..01b045a66 100644 --- a/rllib/tests/test_nested_action_spaces.py +++ b/rllib/tests/test_nested_action_spaces.py @@ -69,7 +69,7 @@ class NestedActionSpacesTest(unittest.TestCase): config["output"] = tmp_dir # Switch off OPE as we don't write action-probs. # TODO: We should probably always write those if `output` is given. - config["off_policy_estimation_methods"] = [] + config["off_policy_estimation_methods"] = {} # Pretend actions in offline files are already normalized. config["actions_in_input_normalized"] = True