[RLlib]: Doubly Robust Off-Policy Evaluation. (#25056)

Rohan Potdar 2022-06-07 03:52:19 -07:00 committed by GitHub
parent 429d0f0eee
commit a9d8da0100
29 changed files with 1378 additions and 315 deletions

View file

@ -521,7 +521,7 @@ You can configure any Trainer to launch a policy server with the following confi
# Use the existing trainer process to run the server.
"num_workers": 0,
# Disable OPE, since the rollouts are coming from online clients.
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
}
Clients can then connect in either *local* or *remote* inference mode. In local inference mode, copies of the policy are downloaded from the server and cached on the client for a configurable period of time. This allows actions to be computed by the client without requiring a network round trip each time. In remote inference mode, each computed action requires a network call to the server.
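For example, a client could be started in local inference mode roughly like this (a minimal sketch; the server address and port are placeholders):

    from ray.rllib.env.policy_client import PolicyClient

    # "local": cache a copy of the policy on the client and refresh it periodically,
    # so actions are computed without a per-step network round trip.
    # "remote": send every compute-action request to the server instead.
    client = PolicyClient("http://localhost:9900", inference_mode="local")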

View file

@ -48,7 +48,7 @@ Then, we can tell DQN to train using these previously generated experiences with
--env=CartPole-v0 \
--config='{
"input": "/tmp/cartpole-out",
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
"explore": false}'
.. _is:
@ -62,7 +62,14 @@ Then, we can tell DQN to train using these previously generated experiences with
--env=CartPole-v0 \
--config='{
"input": "/tmp/cartpole-out",
"off_policy_estimation_methods": ["is", "wis"],
"off_policy_estimation_methods": {
"is": {
"type": "ImportanceSampling",
},
"wis": {
"type": "WeightedImportanceSampling",
}
},
"exploration_config": {
"type": "SoftQ",
"temperature": 1.0,
@ -275,10 +282,10 @@ You can configure experience input for an agent using the following options:
# - Any subclass of OffPolicyEstimator, e.g.
# ray.rllib.offline.estimators.importance_sampling::ImportanceSampling or your own custom
# subclass.
"off_policy_estimation_methods": [
ImportanceSampling,
WeightedImportanceSampling,
],
"off_policy_estimation_methods": {
ImportanceSampling: None,
WeightedImportanceSampling: None,
},
# Whether to run postprocess_trajectory() on the trajectory fragments from
# offline inputs. Note that postprocessing will be done using the *current*
# policy, not the *behavior* policy, which is typically undesirable for

View file

@ -574,10 +574,14 @@ The following is a list of the common algorithm hyper-parameters:
# - Any subclass of OffPolicyEstimator, e.g.
# ray.rllib.offline.estimators.importance_sampling::ImportanceSampling or your own custom
# subclass.
"off_policy_estimation_methods": [
ImportanceSampling,
WeightedImportanceSampling,
],
"off_policy_estimation_methods": {
"is": {
"type": ImportanceSampling,
},
"wis": {
"type": WeightedImportanceSampling,
}
},
# Whether to run postprocess_trajectory() on the trajectory fragments from
# offline inputs. Note that postprocessing will be done using the *current*
# policy, not the *behavior* policy, which is typically undesirable for

View file

@ -26,6 +26,7 @@ def deep_update(
new_keys_allowed: bool = False,
allow_new_subkey_list: Optional[List[str]] = None,
override_all_if_type_changes: Optional[List[str]] = None,
override_all_key_list: Optional[List[str]] = None,
) -> dict:
"""Updates original dict with values from new_dict recursively.
@ -37,22 +38,29 @@ def deep_update(
original: Dictionary with default values.
new_dict: Dictionary with values to be updated
new_keys_allowed: Whether new keys are allowed.
allow_new_subkey_list (Optional[List[str]]): List of keys that
allow_new_subkey_list: List of keys that
correspond to dict values where new subkeys can be introduced.
This is only at the top level.
override_all_if_type_changes(Optional[List[str]]): List of top level
override_all_if_type_changes: List of top level
keys with value=dict, for which we always simply override the
entire value (dict), iff the "type" key in that value dict changes.
override_all_key_list: List of top level keys
for which we override the entire value if the key is in the new_dict.
"""
allow_new_subkey_list = allow_new_subkey_list or []
override_all_if_type_changes = override_all_if_type_changes or []
override_all_key_list = override_all_key_list or []
for k, value in new_dict.items():
if k not in original and not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))
# Both original value and new one are dicts.
if isinstance(original.get(k), dict) and isinstance(value, dict):
if (
isinstance(original.get(k), dict)
and isinstance(value, dict)
and k not in override_all_key_list
):
# Check old type vs new one. If different, override entire value.
if (
k in override_all_if_type_changes
@ -63,10 +71,20 @@ def deep_update(
original[k] = value
# Allowed key -> ok to add new subkeys.
elif k in allow_new_subkey_list:
deep_update(original[k], value, True)
deep_update(
original[k],
value,
True,
override_all_key_list=override_all_key_list,
)
# Non-allowed key.
else:
deep_update(original[k], value, new_keys_allowed)
deep_update(
original[k],
value,
new_keys_allowed,
override_all_key_list=override_all_key_list,
)
# Original value not a dict OR new value not a dict:
# Override entire value.
else:
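To illustrate the new `override_all_key_list` argument, here is a minimal sketch using hypothetical config dicts: a key listed there is replaced wholesale instead of being deep-merged.

    original = {"off_policy_estimation_methods": {"is": {"type": "ImportanceSampling"}}}
    new_dict = {"off_policy_estimation_methods": {"dr": {"type": "DoublyRobust"}}}
    deep_update(
        original,
        new_dict,
        new_keys_allowed=True,
        override_all_key_list=["off_policy_estimation_methods"],
    )
    # original["off_policy_estimation_methods"] is now {"dr": {"type": "DoublyRobust"}};
    # without override_all_key_list, the sub-dict would have been merged and kept "is" too.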

View file

@ -11,7 +11,7 @@ marwil-halfcheetahbulletenv-v0:
input: ["~/halfcheetah_expert_sac.zip"]
actions_in_input_normalized: true
# Switch off input evaluation (data does not contain action probs).
off_policy_estimation_methods: []
off_policy_estimation_methods: {}
num_gpus: 1

View file

@ -24,6 +24,7 @@
# - `evaluation` directory tests.
# - `execution` directory tests.
# - `models` directory tests.
# - `offline` directory tests.
# - `policy` directory tests.
# - `utils` directory tests.
@ -1140,7 +1141,7 @@ py_test(
"--env", "CartPole-v0",
"--run", "DQN",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"off_policy_estimation_methods\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"off_policy_estimation_methods\": {\"wis\": {\"type\": \"wis\"}, \"is\": {\"type\": \"is\"}}, \"exploration_config\": {\"type\": \"SoftQ\"}}'"
]
)
@ -1531,6 +1532,20 @@ py_test(
srcs = ["models/tests/test_preprocessors.py"]
)
# --------------------------------------------------------------------
# Offline
# rllib/offline/
#
# Tag: offline
# --------------------------------------------------------------------
py_test(
name = "test_ope",
tags = ["team:ml", "offline", "torch_only"],
size = "medium",
srcs = ["offline/estimators/tests/test_ope.py"]
)
# --------------------------------------------------------------------
# Policies
# rllib/policy/

View file

@ -34,7 +34,6 @@ from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.utils import _gym_env_creator
from ray.rllib.evaluation.episode import Episode
from ray.rllib.utils import force_list
from ray.rllib.evaluation.metrics import (
collect_episodes,
collect_metrics,
@ -189,6 +188,9 @@ class Trainer(Trainable):
"replay_buffer_config",
]
# List of keys that are always fully overridden if present in any dict or sub-dict
_override_all_key_list = ["off_policy_estimation_methods"]
@PublicAPI
def __init__(
self,
@ -1748,6 +1750,7 @@ class Trainer(Trainable):
_allow_unknown_configs,
cls._allow_unknown_subkeys,
cls._override_all_subkeys_if_type_changes,
cls._override_all_key_list,
)
@staticmethod
@ -1924,9 +1927,22 @@ class Trainer(Trainable):
error=False,
)
config["off_policy_estimation_methods"] = input_evaluation
config["off_policy_estimation_methods"] = force_list(
config["off_policy_estimation_methods"]
)
if isinstance(config["off_policy_estimation_methods"], list) or isinstance(
config["off_policy_estimation_methods"], tuple
):
ope_dict = {
str(ope): {"type": ope} for ope in config["off_policy_estimation_methods"]
}
deprecation_warning(
old="config.off_policy_estimation_methods={}".format(
config["off_policy_estimation_methods"]
),
new="config.off_policy_estimation_methods={}".format(
ope_dict,
),
error=False,
)
config["off_policy_estimation_methods"] = ope_dict
# Check model config.
# If no preprocessing, propagate into model's config as well

View file

@ -169,7 +169,7 @@ class TrainerConfig:
self.input_ = "sampler"
self.input_config = {}
self.actions_in_input_normalized = False
self.off_policy_estimation_methods = []
self.off_policy_estimation_methods = {}
self.postprocess_inputs = False
self.shuffle_buffer_size = 0
self.output = None
@ -932,15 +932,22 @@ class TrainerConfig:
when the offline file has been generated by another RLlib algorithm
(e.g. PPO or SAC), while "normalize_actions" was set to True.
input_evaluation: DEPRECATED: Use `off_policy_estimation_methods` instead!
off_policy_estimation_methods: Specify how to evaluate the current policy.
off_policy_estimation_methods: Specify how to evaluate the current policy,
along with any optional config parameters.
This only has an effect when reading offline experiences
("input" is not "sampler").
Available options:
- "simulation": Run the environment in the background, but use
this data for evaluation only and not for learning.
- Any subclass of OffPolicyEstimator, e.g.
ray.rllib.offline.estimators.is::ImportanceSampling or your own custom
subclass.
Available keys:
- {ope_method_name: {"type": ope_type, ...}} where `ope_method_name`
is a user-defined string to save the OPE results under, and
`ope_type` can be:
- "simulation": Run the environment in the background, but use
this data for evaluation only and not for learning.
- Any subclass of OffPolicyEstimator, e.g.
ray.rllib.offline.estimators.importance_sampling::ImportanceSampling
or your own custom subclass.
You can also add additional config arguments to be passed to the
OffPolicyEstimator in the dict, e.g.
{"qreg_dr": {"type": DoublyRobust, "q_model_type": "qreg", "k": 5}}
postprocess_inputs: Whether to run postprocess_trajectory() on the
trajectory fragments from offline inputs. Note that postprocessing will
be done using the *current* policy, not the *behavior* policy, which
@ -978,9 +985,25 @@ class TrainerConfig:
),
error=True,
)
self.off_policy_estimation_methods = input_evaluation
if isinstance(off_policy_estimation_methods, list) or isinstance(
off_policy_estimation_methods, tuple
):
ope_dict = {
str(ope): {"type": ope} for ope in off_policy_estimation_methods
}
deprecation_warning(
old="offline_data(off_policy_estimation_methods={}".format(
off_policy_estimation_methods
),
new="offline_data(off_policy_estimation_methods={}".format(
ope_dict,
),
error=False,
)
off_policy_estimation_methods = ope_dict
if off_policy_estimation_methods is not None:
self.off_policy_estimation_methods = off_policy_estimation_methods
if postprocess_inputs is not None:
self.postprocess_inputs = postprocess_inputs
if shuffle_buffer_size is not None:

View file

@ -49,7 +49,7 @@ class BCConfig(MARWILConfig):
# not important for behavioral cloning.
self.postprocess_inputs = False
# No reward estimation.
self.off_policy_estimation_methods = []
self.off_policy_estimation_methods = {}
# __sphinx_doc_end__
# fmt: on

View file

@ -67,7 +67,7 @@ class CQLConfig(SACConfig):
# Changes to Trainer's/SACConfig's default:
# .offline_data()
self.off_policy_estimation_methods = []
self.off_policy_estimation_methods = {}
# .reporting()
self.min_sample_timesteps_per_reporting = 0

View file

@ -5,6 +5,7 @@ import unittest
import ray
from ray.rllib.algorithms import cql
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import (
check_compute_single_action,
@ -51,7 +52,7 @@ class TestCQL(unittest.TestCase):
# RLlib algorithm (e.g. PPO or SAC).
actions_in_input_normalized=False,
# Switch on off-policy evaluation.
off_policy_estimation_methods=["is"],
off_policy_estimation_methods={"is": {"type": ImportanceSampling}},
)
.training(
clip_actions=False,

View file

@ -103,9 +103,10 @@ class MARWILConfig(TrainerConfig):
# the same line.
self.input_ = "sampler"
# Use importance sampling estimators for reward.
self.off_policy_estimation_methods = [
ImportanceSampling, WeightedImportanceSampling
]
self.off_policy_estimation_methods = {
"is": {"type": ImportanceSampling},
"wis": {"type": WeightedImportanceSampling},
}
self.postprocess_inputs = True
self.lr = 1e-4
self.train_batch_size = 2000
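As a usage sketch (not part of this commit), these defaults could be narrowed through the dict-based API, e.g. to keep only weighted importance sampling; the input path is a placeholder and the MARWILConfig import path may differ between RLlib versions:

    from ray.rllib.offline.estimators import WeightedImportanceSampling

    config = MARWILConfig().offline_data(
        input_="/tmp/cartpole-out",  # placeholder path
        off_policy_estimation_methods={
            "wis": {"type": WeightedImportanceSampling},
        },
    )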

View file

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import ray
from ray import ObjectRef
from ray.actor import ActorHandle
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimate
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY

View file

@ -33,8 +33,14 @@ from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
from ray.rllib.offline.estimators import ImportanceSampling, WeightedImportanceSampling
from ray.rllib.offline.estimators import (
OffPolicyEstimate,
OffPolicyEstimator,
ImportanceSampling,
WeightedImportanceSampling,
DirectMethod,
DoublyRobust,
)
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
@ -242,7 +248,7 @@ class RolloutWorker(ParallelIteratorWorker):
input_creator: Callable[
[IOContext], InputReader
] = lambda ioctx: ioctx.default_sampler_input(),
off_policy_estimation_methods: List[str] = frozenset([]),
off_policy_estimation_methods: Optional[Dict[str, Dict]] = None,
output_creator: Callable[
[IOContext], OutputWriter
] = lambda ioctx: NoopOutput(),
@ -339,14 +345,25 @@ class RolloutWorker(ParallelIteratorWorker):
DefaultCallbacks for training/policy/rollout-worker callbacks.
input_creator: Function that returns an InputReader object for
loading previous generated experiences.
off_policy_estimation_methods: How to evaluate the policy performance.
Setting this only makes sense when the input is reading offline data.
Available options:
- "simulation" (str): Run the environment in the background, but use
off_policy_estimation_methods: A dict that specifies how to
evaluate the current policy.
This only has an effect when reading offline experiences
("input" is not "sampler").
Available key-value pairs:
- {"simulation": None}: Run the environment in the background, but use
this data for evaluation only and not for learning.
- Any subclass (type) of the OffPolicyEstimator API class, e.g.
`ray.rllib.offline.estimators.importance_sampling::ImportanceSampling`
- {ope_name: {"type": ope_type, ...}}, where `ope_name` is an arbitrary
string under which the metrics for this OPE estimator are saved,
and `ope_type` can be any subclass of OffPolicyEstimator, e.g.
ray.rllib.offline.estimators::ImportanceSampling
or your own custom subclass.
You can also add additional config arguments to be passed to the
OffPolicyEstimator, e.g.
off_policy_estimation_methods = {
"dr_qreg": {"type": DoublyRobust, "q_model_type": "qreg"},
"dm_64": {"type": DirectMethod, "batch_size": 64},
}
See ray/rllib/offline/estimators for more information.
output_creator: Function that returns an OutputWriter object for
saving generated experiences.
remote_worker_envs: If using num_envs_per_worker > 1,
@ -702,41 +719,50 @@ class RolloutWorker(ParallelIteratorWorker):
log_dir, policy_config, worker_index, self
)
self.reward_estimators: List[OffPolicyEstimator] = []
for method in off_policy_estimation_methods:
if method == "is":
method = ImportanceSampling
ope_types = {
"is": ImportanceSampling,
"wis": WeightedImportanceSampling,
"dm": DirectMethod,
"dr": DoublyRobust,
}
off_policy_estimation_methods = off_policy_estimation_methods or {}
for name, method_config in off_policy_estimation_methods.items():
method_type = method_config.pop("type")
if method_type in ope_types:
deprecation_warning(
old="config.off_policy_estimation_methods=[is]",
new="from ray.rllib.offline.estimators import "
f"{method.__name__}; config.off_policy_estimation_methods="
f"[{method.__name__}]",
old=method_type,
new=str(ope_types[method_type]),
error=False,
)
elif method == "wis":
method = WeightedImportanceSampling
deprecation_warning(
old="config.off_policy_estimation_methods=[wis]",
new="from ray.rllib.offline.estimators import "
f"{method.__name__}; config.off_policy_estimation_methods="
f"[{method.__name__}]",
error=False,
)
if method == "simulation":
method_type = ope_types[method_type]
if name == "simulation":
logger.warning(
"Requested 'simulation' input evaluation method: "
"will discard all sampler outputs and keep only metrics."
)
sample_async = True
elif isinstance(method, type) and issubclass(method, OffPolicyEstimator):
# TODO: Allow for this to be a full classpath string as well, then construct
# this with our `from_config` util.
elif isinstance(method_type, type) and issubclass(
method_type, OffPolicyEstimator
):
gamma = self.io_context.worker.policy_config["gamma"]
# Grab a reference to the current model
keys = list(self.io_context.worker.policy_map.keys())
if len(keys) > 1:
raise NotImplementedError(
"Off-policy estimation is not implemented for multi-agent. "
"You can set `input_evaluation: []` to resolve this."
)
policy = self.io_context.worker.get_policy(keys[0])
self.reward_estimators.append(
method.create_from_io_context(self.io_context)
method_type(name=name, policy=policy, gamma=gamma, **method_config)
)
else:
raise ValueError(
f"Unknown evaluation method: {method}! Must be "
"either `simulation` or a sub-class of ray.rllib.offline."
"off_policy_estimator::OffPolicyEstimator"
f"Unknown off_policy_estimation type: {method_type}! Must be "
"either `simulation|is|wis|dm|dr` or a sub-class of ray.rllib."
"offline.estimators.off_policy_estimator::OffPolicyEstimator"
)
render = False
@ -866,9 +892,8 @@ class RolloutWorker(ParallelIteratorWorker):
# Do off-policy estimation, if needed.
if self.reward_estimators:
for sub_batch in batch.split_by_episode():
for estimator in self.reward_estimators:
estimator.process(sub_batch)
for estimator in self.reward_estimators:
estimator.process(batch)
if log_once("sample_end"):
logger.info("Completed sample batch:\n\n{}\n".format(summarize(batch)))
@ -1138,7 +1163,7 @@ class RolloutWorker(ParallelIteratorWorker):
out = self.sampler.get_metrics()
else:
out = []
# Get metrics from our reward-estimators (if any).
# Get metrics from our reward estimators (if any).
for m in self.reward_estimators:
out.extend(m.get_metrics())
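Tying the docstring above to this construction loop: each dict entry becomes one estimator instance, with the key used as the name its metrics are reported under and the remaining keys forwarded as keyword arguments. A hedged sketch (the label "dr_fqe" is arbitrary):

    from ray.rllib.offline.estimators import DoublyRobust

    off_policy_estimation_methods = {
        "dr_fqe": {"type": DoublyRobust, "q_model_type": "fqe", "k": 5},
    }
    # Per the loop above, this entry is constructed roughly as:
    #   DoublyRobust(name="dr_fqe", policy=policy, gamma=gamma,
    #                q_model_type="fqe", k=5)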

View file

@ -627,7 +627,7 @@ class WorkerSet:
)
if config["input"] == "sampler":
off_policy_estimation_methods = []
off_policy_estimation_methods = {}
else:
off_policy_estimation_methods = config["off_policy_estimation_methods"]

View file

@ -165,7 +165,7 @@ if __name__ == "__main__":
# Use n worker processes to listen on different ports.
"num_workers": args.num_workers,
# Disable OPE, since the rollouts are coming from online clients.
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
# Create a "chatty" client/server or not.
"callbacks": MyCallbacks if args.callbacks_verbose else None,
# DL framework to use.

View file

@ -132,7 +132,7 @@ if __name__ == "__main__":
# Use n worker processes to listen on different ports.
"num_workers": args.num_workers,
# Disable OPE, since the rollouts are coming from online clients.
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
# Other settings.
"train_batch_size": 256,
"rollout_fragment_length": 20,

View file

@ -2,8 +2,18 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
WeightedImportanceSampling,
)
from ray.rllib.offline.estimators.direct_method import DirectMethod
from ray.rllib.offline.estimators.doubly_robust import DoublyRobust
from ray.rllib.offline.estimators.off_policy_estimator import (
OffPolicyEstimate,
OffPolicyEstimator,
)
__all__ = [
"OffPolicyEstimator",
"OffPolicyEstimate",
"ImportanceSampling",
"WeightedImportanceSampling",
"DirectMethod",
"DoublyRobust",
]
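As a standalone-usage sketch mirroring the new test file later in this commit (`policy` and `batch` are assumed to come from a trained torch policy with a discrete action space and a JsonReader of complete episodes):

    from ray.rllib.offline.estimators import DoublyRobust

    estimator = DoublyRobust(
        name="dr_fqe",
        policy=policy,        # assumed: torch policy, discrete actions,
        gamma=0.99,           #          batch_mode="complete_episodes"
        q_model_type="fqe",
        k=5,
    )
    estimator.process(batch)             # SampleBatch of complete episodes
    estimates = estimator.get_metrics()  # list of OffPolicyEstimate(name, metrics)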

View file

@ -0,0 +1,170 @@
from typing import Tuple, List, Generator
from ray.rllib.offline.estimators.off_policy_estimator import (
OffPolicyEstimator,
OffPolicyEstimate,
)
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
from ray.rllib.offline.estimators.qreg_torch_model import QRegTorchModel
from gym.spaces import Discrete
import numpy as np
torch, nn = try_import_torch()
# TODO (rohan): replace with AIR/parallel workers
# (And find a better name than `should_train`)
@DeveloperAPI
def k_fold_cv(
    batch: SampleBatchType, k: int, should_train: bool = True
) -> Generator[Tuple[List[SampleBatch], List[SampleBatch]], None, None]:
"""Utility function that returns a k-fold cross validation generator
over episodes from the given batch. If the number of episodes in the
batch is less than `k` or `should_train` is set to False, yields an empty
list for train_episodes and all the episodes in test_episodes.
Args:
batch: A SampleBatch of episodes to split
k: Number of cross-validation splits
should_train: True by default. If False, yield [], [episodes].
Yields:
    A tuple with two lists of SampleBatches: (train_episodes, test_episodes).
"""
episodes = batch.split_by_episode()
n_episodes = len(episodes)
if n_episodes < k or not should_train:
yield [], episodes
return
n_fold = n_episodes // k
for i in range(k):
train_episodes = episodes[: i * n_fold] + episodes[(i + 1) * n_fold :]
if i != k - 1:
test_episodes = episodes[i * n_fold : (i + 1) * n_fold]
else:
# Append remaining episodes onto the last test_episodes
test_episodes = episodes[i * n_fold :]
yield train_episodes, test_episodes
return
@DeveloperAPI
class DirectMethod(OffPolicyEstimator):
"""The Direct Method estimator.
DM estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
@override(OffPolicyEstimator)
def __init__(
self,
name: str,
policy: Policy,
gamma: float,
q_model_type: str = "fqe",
k: int = 5,
**kwargs,
):
"""
Initializes a Direct Method OPE Estimator.
Args:
name: string to save OPE results under
policy: Policy to evaluate.
gamma: Discount factor of the environment.
q_model_type: Either "fqe" for Fitted Q-Evaluation
or "qreg" for Q-Regression, or a custom model that implements:
- `estimate_q(states,actions)`
- `estimate_v(states, action_probs)`
k: k-fold cross validation for training model and evaluating OPE
kwargs: Optional arguments for the specified Q model
"""
super().__init__(name, policy, gamma)
# TODO (rohan): Add support for continuous action spaces
assert isinstance(
policy.action_space, Discrete
), "DM Estimator only supports discrete action spaces!"
assert (
policy.config["batch_mode"] == "complete_episodes"
), "DM Estimator only supports `batch_mode`=`complete_episodes`"
# TODO (rohan): Add support for TF!
if policy.framework == "torch":
if q_model_type == "qreg":
model_cls = QRegTorchModel
elif q_model_type == "fqe":
model_cls = FQETorchModel
else:
assert hasattr(
    q_model_type, "estimate_q"
), "q_model_type must implement `estimate_q`!"
assert hasattr(
    q_model_type, "estimate_v"
), "q_model_type must implement `estimate_v`!"
# Custom Q-model class: use it directly.
model_cls = q_model_type
else:
raise ValueError(
f"{self.__class__.__name__}"
"estimator only supports `policy.framework`=`torch`"
)
self.model = model_cls(
policy=policy,
gamma=gamma,
**kwargs,
)
self.k = k
self.losses = []
@override(OffPolicyEstimator)
def estimate(
    self, batch: SampleBatchType, should_train: bool = True
) -> List[OffPolicyEstimate]:
self.check_can_estimate_for(batch)
estimates = []
# Split data into train and test using k-fold cross validation
for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train):
# Train Q-function
if train_episodes:
# Reinitialize model
self.model.reset()
train_batch = SampleBatch.concat_samples(train_episodes)
losses = self.train(train_batch)
self.losses.append(losses)
# Calculate direct method OPE estimates
for episode in test_episodes:
rewards = episode["rewards"]
v_old = 0.0
v_new = 0.0
for t in range(episode.count):
v_old += rewards[t] * self.gamma ** t
init_step = episode[0:1]
init_obs = np.array([init_step[SampleBatch.OBS]])
all_actions = np.arange(self.policy.action_space.n, dtype=float)
init_step[SampleBatch.ACTIONS] = all_actions
action_probs = np.exp(self.action_log_likelihood(init_step))
v_value = self.model.estimate_v(init_obs, action_probs)
v_new = convert_to_numpy(v_value).item()
estimates.append(
OffPolicyEstimate(
self.name,
{
"v_old": v_old,
"v_new": v_new,
"v_gain": v_new / max(1e-8, v_old),
},
)
)
return estimates
@override(OffPolicyEstimator)
def train(self, batch: SampleBatchType):
return self.model.train_q(batch)
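For reference, the per-episode metrics computed above correspond to the following (behavior policy \mu, target policy \pi; a sketch in LaTeX):

    v_old = \sum_t \gamma^t r_t \quad \text{(discounted return under the behavior policy)},
    \qquad
    v_new = \hat{V}(s_0) = \sum_a \pi(a \mid s_0)\,\hat{Q}(s_0, a) \quad \text{(direct-method estimate)},
    \qquad
    v_gain = v_new / \max(10^{-8}, v_old).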

View file

@ -0,0 +1,71 @@
from typing import List
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimate
from ray.rllib.offline.estimators.direct_method import DirectMethod, k_fold_cv
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.numpy import convert_to_numpy
import numpy as np
@DeveloperAPI
class DoublyRobust(DirectMethod):
"""The Doubly Robust (DR) estimator.
DR estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
@override(DirectMethod)
def estimate(
    self, batch: SampleBatchType, should_train: bool = True
) -> List[OffPolicyEstimate]:
self.check_can_estimate_for(batch)
estimates = []
# Split data into train and test using k-fold cross validation
for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train):
# Train Q-function
if train_episodes:
# Reinitialize model
self.model.reset()
train_batch = SampleBatch.concat_samples(train_episodes)
losses = self.train(train_batch)
self.losses.append(losses)
# Calculate doubly robust OPE estimates
for episode in test_episodes:
rewards, old_prob = episode["rewards"], episode["action_prob"]
new_prob = np.exp(self.action_log_likelihood(episode))
v_old = 0.0
v_new = 0.0
q_values = self.model.estimate_q(
episode[SampleBatch.OBS], episode[SampleBatch.ACTIONS]
)
q_values = convert_to_numpy(q_values)
all_actions = np.zeros([episode.count, self.policy.action_space.n])
all_actions[:] = np.arange(self.policy.action_space.n)
# Two transposes required for torch.distributions to work
tmp_episode = episode.copy()
tmp_episode[SampleBatch.ACTIONS] = all_actions.T
action_probs = np.exp(self.action_log_likelihood(tmp_episode)).T
v_values = self.model.estimate_v(episode[SampleBatch.OBS], action_probs)
v_values = convert_to_numpy(v_values)
for t in reversed(range(episode.count)):
v_old = rewards[t] + self.gamma * v_old
v_new = v_values[t] + (new_prob[t] / old_prob[t]) * (
rewards[t] + self.gamma * v_new - q_values[t]
)
v_new = v_new.item()
estimates.append(
OffPolicyEstimate(
self.name,
{
"v_old": v_old,
"v_new": v_new,
"v_gain": v_new / max(1e-8, v_old),
},
)
)
return estimates
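The backward recursion in the loop above is the doubly robust estimate from the paper cited in the docstring (https://arxiv.org/pdf/1511.03722.pdf), sketched here in LaTeX:

    V_{DR}^{(T)} = 0, \qquad
    V_{DR}^{(t)} = \hat{V}(s_t)
        + \frac{\pi(a_t \mid s_t)}{\mu(a_t \mid s_t)}
          \big( r_t + \gamma V_{DR}^{(t+1)} - \hat{Q}(s_t, a_t) \big),

with v_new = V_{DR}^{(0)} and v_old = \sum_t \gamma^t r_t as in the Direct Method.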

View file

@ -0,0 +1,222 @@
from ray.rllib.models.utils import get_initializer
from ray.rllib.policy import Policy
from typing import List, Union
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType
torch, nn = try_import_torch()
@DeveloperAPI
class FQETorchModel:
"""Pytorch implementation of the Fitted Q-Evaluation (FQE) model from
https://arxiv.org/pdf/1911.06854.pdf
"""
def __init__(
self,
policy: Policy,
gamma: float,
model: ModelConfigDict = None,
n_iters: int = 160,
lr: float = 1e-3,
delta: float = 1e-4,
clip_grad_norm: float = 100.0,
batch_size: int = 32,
tau: float = 0.05,
) -> None:
"""
Args:
policy: Policy to evaluate.
gamma: Discount factor of the environment.
model: The ModelConfigDict for self.q_model; defaults to
    {"fcnet_hiddens": [8, 8], "fcnet_activation": "relu",
    "vf_share_layers": True}.
n_iters: Maximum number of training iterations to run on the batch.
lr: Learning rate for the Q-function optimizer.
delta: Early stopping threshold; stop if the mean loss < delta.
clip_grad_norm: Clip gradients to this maximum value.
batch_size: Minibatch size for training the Q-function.
tau: Polyak averaging factor for the target Q-function.
"""
self.policy = policy
self.gamma = gamma
self.observation_space = policy.observation_space
self.action_space = policy.action_space
if model is None:
model = {
"fcnet_hiddens": [8, 8],
"fcnet_activation": "relu",
"vf_share_layers": True,
}
self.device = self.policy.device
self.q_model: TorchModelV2 = ModelCatalog.get_model_v2(
self.observation_space,
self.action_space,
self.action_space.n,
model,
framework="torch",
name="TorchQModel",
).to(self.device)
self.target_q_model: TorchModelV2 = ModelCatalog.get_model_v2(
self.observation_space,
self.action_space,
self.action_space.n,
model,
framework="torch",
name="TargetTorchQModel",
).to(self.device)
self.n_iters = n_iters
self.lr = lr
self.delta = delta
self.clip_grad_norm = clip_grad_norm
self.batch_size = batch_size
self.tau = tau
self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr)
initializer = get_initializer("xavier_uniform", framework="torch")
# Hard update target
self.update_target(tau=1.0)
def f(m):
if isinstance(m, nn.Linear):
initializer(m.weight)
self.initializer = f
def reset(self) -> None:
"""Resets/Reinintializes the model weights."""
self.q_model.apply(self.initializer)
def train_q(self, batch: SampleBatch) -> TensorType:
"""Trains self.q_model using FQE loss on given batch.
Args:
batch: A SampleBatch of episodes to train on
Returns:
A list of losses for each training iteration
"""
losses = []
for _ in range(self.n_iters):
minibatch_losses = []
batch.shuffle()
for idx in range(0, batch.count, self.batch_size):
minibatch = batch[idx : idx + self.batch_size]
obs = torch.tensor(minibatch[SampleBatch.OBS], device=self.device)
actions = torch.tensor(
minibatch[SampleBatch.ACTIONS], device=self.device
)
rewards = torch.tensor(
minibatch[SampleBatch.REWARDS], device=self.device
)
next_obs = torch.tensor(
minibatch[SampleBatch.NEXT_OBS], device=self.device
)
dones = torch.tensor(minibatch[SampleBatch.DONES], device=self.device)
# Necessary if policy uses recurrent/attention model
num_state_inputs = 0
for k in batch.keys():
if k.startswith("state_in_"):
num_state_inputs += 1
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
# Compute action_probs for next_obs as in FQE
all_actions = torch.zeros([minibatch.count, self.policy.action_space.n])
all_actions[:] = torch.arange(self.policy.action_space.n)
next_action_prob = self.policy.compute_log_likelihoods(
actions=all_actions.T,
obs_batch=next_obs,
state_batches=[minibatch[k] for k in state_keys],
prev_action_batch=minibatch[SampleBatch.ACTIONS],
prev_reward_batch=minibatch[SampleBatch.REWARDS],
actions_normalized=False,
)
next_action_prob = (
torch.exp(next_action_prob.T).to(self.device).detach()
)
q_values, _ = self.q_model({"obs": obs}, [], None)
q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
with torch.no_grad():
next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)
next_v = torch.sum(next_q_values * next_action_prob, axis=-1)
targets = rewards + ~dones * self.gamma * next_v
loss = (targets - q_acts) ** 2
loss = torch.mean(loss)
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad.clip_grad_norm_(
self.q_model.variables(), self.clip_grad_norm
)
self.optimizer.step()
minibatch_losses.append(loss.item())
iter_loss = sum(minibatch_losses) / len(minibatch_losses)
losses.append(iter_loss)
if iter_loss < self.delta:
break
self.update_target()
return losses
def estimate_q(
self,
obs: Union[TensorType, List[TensorType]],
actions: Union[TensorType, List[TensorType]] = None,
) -> TensorType:
"""Given `obs`, a list or array or tensor of observations,
compute the Q-values for `obs` for all actions in the action space.
If `actions` is not None, return the Q-values for the actions provided,
else return Q-values for all actions for each observation in `obs`.
"""
obs = torch.tensor(obs, device=self.device)
q_values, _ = self.q_model({"obs": obs}, [], None)
if actions is not None:
actions = torch.tensor(actions, device=self.device, dtype=int)
q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
return q_values.detach()
def estimate_v(
self,
obs: Union[TensorType, List[TensorType]],
action_probs: Union[TensorType, List[TensorType]],
) -> TensorType:
"""Given `obs`, compute q-values for all actions in the action space
for each observations s in `obs`, then multiply this by `action_probs`,
the probability distribution over actions for each state s to give the
state value V(s) = sum_A pi(a|s)Q(s,a).
"""
q_values = self.estimate_q(obs)
action_probs = torch.tensor(action_probs, device=self.device)
v_values = torch.sum(q_values * action_probs, axis=-1)
return v_values.detach()
def update_target(self, tau=None):
# Update_target will be called periodically to copy Q network to
# target Q network, using (soft) tau-synching.
tau = tau or self.tau
model_state_dict = self.q_model.state_dict()
# Support partial (soft) synching.
# If tau == 1.0: Full sync from Q-model to target Q-model.
target_state_dict = self.target_q_model.state_dict()
model_state_dict = {
k: tau * model_state_dict[k] + (1 - tau) * v
for k, v in target_state_dict.items()
}
self.target_q_model.load_state_dict(model_state_dict)
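In math, each minibatch update above regresses Q(s_t, a_t) onto the FQE target (a LaTeX sketch of what the code computes):

    y_t = r_t + \gamma\,(1 - d_t) \sum_{a'} \pi(a' \mid s_{t+1})\,Q_{\text{target}}(s_{t+1}, a'),
    \qquad
    L = \frac{1}{|B|} \sum_{t \in B} \big( y_t - Q(s_t, a_t) \big)^2,

with the target network updated by Polyak averaging, Q_target <- tau * Q + (1 - tau) * Q_target.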

View file

@ -1,42 +1,52 @@
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
from ray.rllib.offline.estimators.off_policy_estimator import (
OffPolicyEstimator,
OffPolicyEstimate,
)
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.typing import SampleBatchType
from typing import List
import numpy as np
@DeveloperAPI
class ImportanceSampling(OffPolicyEstimator):
"""The step-wise IS estimator.
Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf,
https://arxiv.org/pdf/1911.06854.pdf"""
@override(OffPolicyEstimator)
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
self.check_can_estimate_for(batch)
estimates = []
for sub_batch in batch.split_by_episode():
rewards, old_prob = sub_batch["rewards"], sub_batch["action_prob"]
new_prob = np.exp(self.action_log_likelihood(sub_batch))
rewards, old_prob = batch["rewards"], batch["action_prob"]
new_prob = self.action_log_likelihood(batch)
# calculate importance ratios
p = []
for t in range(sub_batch.count):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])
# calculate importance ratios
p = []
for t in range(batch.count):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])
# calculate stepwise IS estimate
v_old = 0.0
v_new = 0.0
for t in range(sub_batch.count):
v_old += rewards[t] * self.gamma ** t
v_new += p[t] * rewards[t] * self.gamma ** t
# calculate stepwise IS estimate
V_prev, V_step_IS = 0.0, 0.0
for t in range(batch.count):
V_prev += rewards[t] * self.gamma ** t
V_step_IS += p[t] * rewards[t] * self.gamma ** t
estimation = OffPolicyEstimate(
"importance_sampling",
{
"V_prev": V_prev,
"V_step_IS": V_step_IS,
"V_gain_est": V_step_IS / max(1e-8, V_prev),
},
)
return estimation
estimates.append(
OffPolicyEstimate(
self.name,
{
"v_old": v_old,
"v_new": v_new,
"v_gain": v_new / max(1e-8, v_old),
},
)
)
return estimates
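Old and new code compute the same step-wise IS quantities; only the naming (V_prev/V_step_IS vs. v_old/v_new) and the per-episode looping changed. In LaTeX:

    p_t = \prod_{t'=0}^{t} \frac{\pi(a_{t'} \mid s_{t'})}{\mu(a_{t'} \mid s_{t'})}, \qquad
    v_old = \sum_t \gamma^t r_t, \qquad
    v_new = \sum_t \gamma^t p_t\, r_t, \qquad
    v_gain = v_new / \max(10^{-8}, v_old).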

View file

@ -0,0 +1,188 @@
from collections import namedtuple
import logging
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.offline.io_context import IOContext
from ray.rllib.utils.annotations import Deprecated
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import TensorType, SampleBatchType
from typing import List
logger = logging.getLogger(__name__)
OffPolicyEstimate = DeveloperAPI(
namedtuple("OffPolicyEstimate", ["estimator_name", "metrics"])
)
@DeveloperAPI
class OffPolicyEstimator:
"""Interface for an off policy reward estimator."""
@DeveloperAPI
def __init__(self, name: str, policy: Policy, gamma: float):
"""Initializes an OffPolicyEstimator instance.
Args:
name: string to save OPE results under
policy: Policy to evaluate.
gamma: Discount factor of the environment.
"""
self.name = name
self.policy = policy
self.gamma = gamma
self.new_estimates = []
@DeveloperAPI
def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
"""Returns a list of off policy estimates for the given batch of episodes.
Args:
batch: The batch to calculate the off policy estimates (OPE) on.
Returns:
The off-policy estimates (OPE) calculated on the given batch.
"""
raise NotImplementedError
@DeveloperAPI
def train(self, batch: SampleBatchType) -> TensorType:
"""Trains an Off-Policy Estimator on a batch of experiences.
A model-based estimator should override this and train
a transition, value, or reward model.
Args:
batch: The batch to train the model on
Returns:
any optional training/loss metrics from the model
"""
pass
@DeveloperAPI
def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
"""Returns log likelihood for actions in given batch for policy.
Computes likelihoods by passing the observations through the current
policy's `compute_log_likelihoods()` method
Args:
batch: The SampleBatch or MultiAgentBatch to calculate action
log likelihoods from. This batch/batches must contain OBS
and ACTIONS keys.
Returns:
The log likelihoods of the actions in the batch, given the
observations and the policy.
"""
num_state_inputs = 0
for k in batch.keys():
if k.startswith("state_in_"):
num_state_inputs += 1
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
actions=batch[SampleBatch.ACTIONS],
obs_batch=batch[SampleBatch.OBS],
state_batches=[batch[k] for k in state_keys],
prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
actions_normalized=True,
)
log_likelihoods = convert_to_numpy(log_likelihoods)
return log_likelihoods
@DeveloperAPI
def check_can_estimate_for(self, batch: SampleBatchType) -> None:
"""Checks if we support off policy estimation (OPE) on given batch.
Args:
batch: The batch to check.
Raises:
ValueError: In case `action_prob` key is not in batch OR batch
is a MultiAgentBatch.
"""
if isinstance(batch, MultiAgentBatch):
raise ValueError(
"Off-Policy Estimation is not implemented for multi-agent batches. "
"You can set `off_policy_estimation_methods: {}` to resolve this."
)
if "action_prob" not in batch:
raise ValueError(
"Off-policy estimation is not possible unless the inputs "
"include action probabilities (i.e., the policy is stochastic "
"and emits the 'action_prob' key). For DQN this means using "
"`exploration_config: {type: 'SoftQ'}`. You can also set "
"`off_policy_estimation_methods: {}` to disable estimation."
)
@DeveloperAPI
def process(self, batch: SampleBatchType) -> None:
"""Computes off policy estimates (OPE) on batch and stores results.
Thus-far collected results can be retrieved then by calling
`self.get_metrics` (which flushes the internal results storage).
Args:
batch: The batch to process (call `self.estimate()` on) and
store results (OPEs) for.
"""
self.new_estimates.extend(self.estimate(batch))
@DeveloperAPI
def get_metrics(self, get_losses: bool = False) -> List[OffPolicyEstimate]:
"""Returns list of new episode metric estimates since the last call.
Args:
get_losses: If True, also return self.losses for the OPE estimator
Returns:
out: List of OffPolicyEstimate objects.
losses: List of training losses for the estimator.
"""
out = self.new_estimates
self.new_estimates = []
losses = []
if hasattr(self, "losses"):
    losses = self.losses
    self.losses = []
if get_losses:
    return out, losses
return out
# TODO (rohan): Remove deprecated methods; set to error=True because changing
# from one episode per SampleBatch to full SampleBatch is a breaking change anyway
@Deprecated(help="OffPolicyEstimator.__init__(policy, gamma, config)", error=False)
@classmethod
@DeveloperAPI
def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
"""Creates an off-policy estimator from an IOContext object.
Extracts Policy and gamma (discount factor) information from the
IOContext.
Args:
ioctx: The IOContext object to create the OffPolicyEstimator
from.
Returns:
The OffPolicyEstimator object created from the IOContext object.
"""
gamma = ioctx.worker.policy_config["gamma"]
# Grab a reference to the current model
keys = list(ioctx.worker.policy_map.keys())
if len(keys) > 1:
raise NotImplementedError(
"Off-policy estimation is not implemented for multi-agent. "
"You can set `input_evaluation: []` to resolve this."
)
policy = ioctx.worker.get_policy(keys[0])
config = ioctx.input_config.get("estimator_config", {})
return cls(policy, gamma, config)
@Deprecated(new="OffPolicyEstimator.create_from_io_context", error=True)
@DeveloperAPI
def create(self, *args, **kwargs):
return self.create_from_io_context(*args, **kwargs)
@Deprecated(new="OffPolicyEstimator.compute_log_likelihoods", error=False)
@DeveloperAPI
def action_prob(self, *args, **kwargs):
return self.compute_log_likelihoods(*args, **kwargs)
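A hedged sketch of a custom estimator that plugs into the new dict config (the class and metric names below are illustrative, not part of this commit):

    import numpy as np

    from ray.rllib.offline.estimators import OffPolicyEstimate, OffPolicyEstimator


    class MyOneStepIS(OffPolicyEstimator):
        # Toy per-episode one-step IS estimate.
        def estimate(self, batch):
            self.check_can_estimate_for(batch)
            estimates = []
            for episode in batch.split_by_episode():
                # action_log_likelihood() returns log pi(a|s); exponentiate for probs.
                new_prob = np.exp(self.action_log_likelihood(episode))
                weights = new_prob / episode["action_prob"]
                estimates.append(
                    OffPolicyEstimate(
                        self.name,
                        {"v_new": float(np.mean(weights * episode["rewards"]))},
                    )
                )
            return estimates


    # Plugged in via the new dict API, e.g.:
    # config["off_policy_estimation_methods"] = {"my_is": {"type": MyOneStepIS}}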

View file

@ -0,0 +1,224 @@
from ray.rllib.models.utils import get_initializer
from ray.rllib.policy import Policy
from typing import List, Union
import numpy as np
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType, ModelConfigDict
torch, nn = try_import_torch()
@DeveloperAPI
class QRegTorchModel:
"""Pytorch implementation of the Q-Reg model from
https://arxiv.org/pdf/1911.06854.pdf
"""
def __init__(
self,
policy: Policy,
gamma: float,
model: ModelConfigDict = None,
n_iters: int = 160,
lr: float = 1e-3,
delta: float = 1e-4,
clip_grad_norm: float = 100.0,
batch_size: int = 32,
) -> None:
"""
Args:
policy: Policy to evaluate.
gamma: Discount factor of the environment.
model: The ModelConfigDict for self.q_model; defaults to
    {"fcnet_hiddens": [8, 8], "fcnet_activation": "relu",
    "vf_share_layers": True}.
n_iters: Maximum number of training iterations to run on the batch.
lr: Learning rate for the Q-function optimizer.
delta: Early stopping threshold; stop if the mean loss < delta.
clip_grad_norm: Clip gradients to this maximum value.
batch_size: Minibatch size for training the Q-function.
"""
self.policy = policy
self.gamma = gamma
self.observation_space = policy.observation_space
self.action_space = policy.action_space
if model is None:
model = {
"fcnet_hiddens": [8, 8],
"fcnet_activation": "relu",
"vf_share_layers": True,
}
self.device = self.policy.device
self.q_model: TorchModelV2 = ModelCatalog.get_model_v2(
self.observation_space,
self.action_space,
self.action_space.n,
model,
framework="torch",
name="TorchQModel",
).to(self.device)
self.n_iters = n_iters
self.lr = lr
self.delta = delta
self.clip_grad_norm = clip_grad_norm
self.batch_size = batch_size
self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr)
initializer = get_initializer("xavier_uniform", framework="torch")
def f(m):
if isinstance(m, nn.Linear):
initializer(m.weight)
self.initializer = f
def reset(self) -> None:
"""Resets/Reinintializes the model weights."""
self.q_model.apply(self.initializer)
def train_q(self, batch: SampleBatch) -> TensorType:
"""Trains self.q_model using Q-Reg loss on given batch.
Args:
batch: A SampleBatch of episodes to train on
Returns:
A list of losses for each training iteration
"""
losses = []
obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
actions = torch.tensor(batch[SampleBatch.ACTIONS], device=self.device)
ps = torch.zeros([batch.count], device=self.device)
returns = torch.zeros([batch.count], device=self.device)
discounts = torch.zeros([batch.count], device=self.device)
# Necessary if policy uses recurrent/attention model
num_state_inputs = 0
for k in batch.keys():
if k.startswith("state_in_"):
num_state_inputs += 1
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
# get rewards, old_prob, new_prob
rewards = batch[SampleBatch.REWARDS]
old_log_prob = torch.tensor(batch[SampleBatch.ACTION_LOGP])
new_log_prob = (
self.policy.compute_log_likelihoods(
actions=batch[SampleBatch.ACTIONS],
obs_batch=batch[SampleBatch.OBS],
state_batches=[batch[k] for k in state_keys],
prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
actions_normalized=False,
)
.detach()
.cpu()
)
prob_ratio = torch.exp(new_log_prob - old_log_prob)
eps_begin = 0
for episode in batch.split_by_episode():
eps_end = eps_begin + episode.count
# calculate importance ratios and returns
for t in range(episode.count):
discounts[eps_begin + t] = self.gamma ** t
if t == 0:
pt_prev = 1.0
else:
pt_prev = ps[eps_begin + t - 1]
ps[eps_begin + t] = pt_prev * prob_ratio[eps_begin + t]
# O(n^3)
# ret = 0
# for t_prime in range(t, episode.count):
# gamma = self.gamma ** (t_prime - t)
# rho_t_1_t_prime = 1.0
# for k in range(t + 1, min(t_prime + 1, episode.count)):
# rho_t_1_t_prime = rho_t_1_t_prime * prob_ratio[eps_begin + k]
# r = rewards[eps_begin + t_prime]
# ret += gamma * rho_t_1_t_prime * r
# O(n^2)
ret = 0
rho = 1
for t_ in reversed(range(t, episode.count)):
ret = rewards[eps_begin + t_] + self.gamma * rho * ret
rho = prob_ratio[eps_begin + t_]
returns[eps_begin + t] = ret
# Update before next episode
eps_begin = eps_end
indices = np.arange(batch.count)
for _ in range(self.n_iters):
minibatch_losses = []
np.random.shuffle(indices)
for idx in range(0, batch.count, self.batch_size):
idxs = indices[idx : idx + self.batch_size]
q_values, _ = self.q_model({"obs": obs[idxs]}, [], None)
q_acts = torch.gather(
q_values, -1, actions[idxs].unsqueeze(-1)
).squeeze()
loss = discounts[idxs] * ps[idxs] * (returns[idxs] - q_acts) ** 2
loss = torch.mean(loss)
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad.clip_grad_norm_(
self.q_model.variables(), self.clip_grad_norm
)
self.optimizer.step()
minibatch_losses.append(loss.item())
iter_loss = sum(minibatch_losses) / len(minibatch_losses)
losses.append(iter_loss)
if iter_loss < self.delta:
break
return losses
def estimate_q(
self,
obs: Union[TensorType, List[TensorType]],
actions: Union[TensorType, List[TensorType]] = None,
) -> TensorType:
"""Given `obs`, a list or array or tensor of observations,
compute the Q-values for `obs` for all actions in the action space.
If `actions` is not None, return the Q-values for the actions provided,
else return Q-values for all actions for each observation in `obs`.
"""
obs = torch.tensor(obs, device=self.device)
q_values, _ = self.q_model({"obs": obs}, [], None)
if actions is not None:
actions = torch.tensor(actions, device=self.device, dtype=int)
q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
return q_values.detach()
def estimate_v(
self,
obs: Union[TensorType, List[TensorType]],
action_probs: Union[TensorType, List[TensorType]],
) -> TensorType:
"""Given `obs`, compute q-values for all actions in the action space
for each observations s in `obs`, then multiply this by `action_probs`,
the probability distribution over actions for each state s to give the
state value V(s) = sum_A pi(a|s)Q(s,a).
"""
q_values = self.estimate_q(obs)
action_probs = torch.tensor(action_probs, device=self.device)
v_values = torch.sum(q_values * action_probs, axis=-1)
return v_values.detach()
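The training loop above (see the O(n^2)/O(n^3) comments) minimizes the Q-Reg weighted regression objective, sketched in LaTeX:

    \rho_k = \frac{\pi(a_k \mid s_k)}{\mu(a_k \mid s_k)}, \qquad
    R_t = \sum_{t' \ge t} \gamma^{t'-t} \Big( \prod_{k=t+1}^{t'} \rho_k \Big) r_{t'}, \qquad
    L = \frac{1}{|B|} \sum_{t \in B} \gamma^t p_t \big( R_t - Q(s_t, a_t) \big)^2,

where p_t = \prod_{k \le t} \rho_k is the cumulative importance weight (`ps` in the code) and \gamma^t is the `discounts` entry.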

View file

@ -0,0 +1,200 @@
import unittest
import ray
from ray import tune
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.offline.estimators import (
ImportanceSampling,
WeightedImportanceSampling,
DirectMethod,
DoublyRobust,
)
from ray.rllib.offline.json_reader import JsonReader
from pathlib import Path
import os
import numpy as np
import gym
class TestOPE(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=4)
def tearDown(self):
ray.shutdown()
@classmethod
def setUpClass(cls):
ray.init(ignore_reinit_error=True)
rllib_dir = Path(__file__).parent.parent.parent.parent
print("rllib dir={}".format(rllib_dir))
data_file = os.path.join(rllib_dir, "tests/data/cartpole/large.json")
print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
env_name = "CartPole-v0"
cls.gamma = 0.99
train_steps = 20000
n_batches = 20 # Approx. equal to n_episodes
n_eval_episodes = 100
config = (
DQNConfig()
.environment(env=env_name)
.training(gamma=cls.gamma)
.rollouts(num_rollout_workers=3)
.exploration(
explore=True,
exploration_config={
"type": "SoftQ",
"temperature": 1.0,
},
)
.framework("torch")
.rollouts(batch_mode="complete_episodes")
)
cls.trainer = config.build()
# Train DQN for evaluation policy
tune.run(
"DQN",
config=config.to_dict(),
stop={"timesteps_total": train_steps},
verbose=0,
)
# Read n_batches of data
reader = JsonReader(data_file)
cls.batch = reader.next()
for _ in range(n_batches - 1):
cls.batch = cls.batch.concat(reader.next())
cls.n_episodes = len(cls.batch.split_by_episode())
print("Episodes:", cls.n_episodes, "Steps:", cls.batch.count)
cls.mean_ret = {}
cls.std_ret = {}
# Simulate Monte-Carlo rollouts
mc_ret = []
env = gym.make(env_name)
for _ in range(n_eval_episodes):
obs = env.reset()
done = False
rewards = []
while not done:
act = cls.trainer.compute_single_action(obs)
obs, reward, done, _ = env.step(act)
rewards.append(reward)
ret = 0
for r in reversed(rewards):
ret = r + cls.gamma * ret
mc_ret.append(ret)
cls.mean_ret["simulation"] = np.mean(mc_ret)
cls.std_ret["simulation"] = np.std(mc_ret)
# Optional configs for the model-based estimators
cls.model_config = {"k": 2, "n_iters": 10}
ray.shutdown()
@classmethod
def tearDownClass(cls):
print("Mean:", cls.mean_ret)
print("Stddev:", cls.std_ret)
ray.shutdown()
def test_is(self):
name = "is"
estimator = ImportanceSampling(
name=name,
policy=self.trainer.get_policy(),
gamma=self.gamma,
)
estimator.process(self.batch)
estimates = estimator.get_metrics()
assert len(estimates) == self.n_episodes
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
def test_wis(self):
name = "wis"
estimator = WeightedImportanceSampling(
name=name,
policy=self.trainer.get_policy(),
gamma=self.gamma,
)
estimator.process(self.batch)
estimates = estimator.get_metrics()
assert len(estimates) == self.n_episodes
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
def test_dm_qreg(self):
name = "dm_qreg"
estimator = DirectMethod(
name=name,
policy=self.trainer.get_policy(),
gamma=self.gamma,
q_model_type="qreg",
**self.model_config,
)
estimator.process(self.batch)
estimates = estimator.get_metrics()
assert len(estimates) == self.n_episodes
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
def test_dm_fqe(self):
name = "dm_fqe"
estimator = DirectMethod(
name=name,
policy=self.trainer.get_policy(),
gamma=self.gamma,
q_model_type="fqe",
**self.model_config,
)
estimator.process(self.batch)
estimates = estimator.get_metrics()
assert len(estimates) == self.n_episodes
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
def test_dr_qreg(self):
name = "dr_qreg"
estimator = DoublyRobust(
name=name,
policy=self.trainer.get_policy(),
gamma=self.gamma,
q_model_type="qreg",
**self.model_config,
)
estimator.process(self.batch)
estimates = estimator.get_metrics()
assert len(estimates) == self.n_episodes
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
def test_dr_fqe(self):
name = "dr_fqe"
estimator = DoublyRobust(
name=name,
policy=self.trainer.get_policy(),
gamma=self.gamma,
q_model_type="fqe",
**self.model_config,
)
estimator.process(self.batch)
estimates = estimator.get_metrics()
assert len(estimates) == self.n_episodes
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
def test_ope_in_trainer(self):
# TODO (rohan): Add performance tests for off_policy_estimation_methods,
# with fixed seeds and hyperparameters
pass
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -1,56 +1,66 @@
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
from ray.rllib.offline.estimators.off_policy_estimator import (
OffPolicyEstimator,
OffPolicyEstimate,
)
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.typing import SampleBatchType
import numpy as np
@DeveloperAPI
class WeightedImportanceSampling(OffPolicyEstimator):
"""The weighted step-wise IS estimator.
Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""
Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf,
https://arxiv.org/pdf/1911.06854.pdf"""
def __init__(self, policy: Policy, gamma: float):
super().__init__(policy, gamma)
@override(OffPolicyEstimator)
def __init__(self, name: str, policy: Policy, gamma: float):
super().__init__(name, policy, gamma)
self.filter_values = []
self.filter_counts = []
@override(OffPolicyEstimator)
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
self.check_can_estimate_for(batch)
estimates = []
for sub_batch in batch.split_by_episode():
rewards, old_prob = sub_batch["rewards"], sub_batch["action_prob"]
new_prob = np.exp(self.action_log_likelihood(sub_batch))
rewards, old_prob = batch["rewards"], batch["action_prob"]
new_prob = self.action_log_likelihood(batch)
# calculate importance ratios
p = []
for t in range(sub_batch.count):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])
for t, v in enumerate(p):
if t >= len(self.filter_values):
self.filter_values.append(v)
self.filter_counts.append(1.0)
else:
self.filter_values[t] += v
self.filter_counts[t] += 1.0
# calculate importance ratios
p = []
for t in range(batch.count):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])
for t, v in enumerate(p):
if t >= len(self.filter_values):
self.filter_values.append(v)
self.filter_counts.append(1.0)
else:
self.filter_values[t] += v
self.filter_counts[t] += 1.0
# calculate stepwise weighted IS estimate
v_old = 0.0
v_new = 0.0
for t in range(sub_batch.count):
v_old += rewards[t] * self.gamma ** t
w_t = self.filter_values[t] / self.filter_counts[t]
v_new += p[t] / w_t * rewards[t] * self.gamma ** t
# calculate stepwise weighted IS estimate
V_prev, V_step_WIS = 0.0, 0.0
for t in range(batch.count):
V_prev += rewards[t] * self.gamma ** t
w_t = self.filter_values[t] / self.filter_counts[t]
V_step_WIS += p[t] / w_t * rewards[t] * self.gamma ** t
estimation = OffPolicyEstimate(
"weighted_importance_sampling",
{
"V_prev": V_prev,
"V_step_WIS": V_step_WIS,
"V_gain_est": V_step_WIS / max(1e-8, V_prev),
},
)
return estimation
estimates.append(
OffPolicyEstimate(
self.name,
{
"v_old": v_old,
"v_new": v_new,
"v_gain": v_new / max(1e-8, v_old),
},
)
)
return estimates
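As with IS, only the naming and per-episode batching changed; the estimate itself is still the step-wise weighted IS value, sketched in LaTeX:

    w_t = \frac{1}{N_t} \sum_{\text{episodes seen so far}} p_t \quad (\texttt{filter\_values[t] / filter\_counts[t]}), \qquad
    v_new = \sum_t \gamma^t \frac{p_t}{w_t}\, r_t, \qquad
    v_old = \sum_t \gamma^t r_t.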

View file

@ -1,167 +1,11 @@
from collections import namedtuple
import logging
import numpy as np
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.offline.io_context import IOContext
from ray.rllib.utils.annotations import Deprecated
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import TensorType, SampleBatchType
from typing import List
logger = logging.getLogger(__name__)
OffPolicyEstimate = DeveloperAPI(
namedtuple("OffPolicyEstimate", ["estimator_name", "metrics"])
from ray.rllib.offline.estimators.off_policy_estimator import ( # noqa: F401
OffPolicyEstimator,
OffPolicyEstimate,
)
from ray.rllib.utils.deprecation import deprecation_warning
@DeveloperAPI
class OffPolicyEstimator:
"""Interface for an off policy reward estimator."""
def __init__(self, policy: Policy, gamma: float):
"""Initializes an OffPolicyEstimator instance.
Args:
policy: Policy to evaluate.
gamma: Discount factor of the environment.
"""
self.policy = policy
self.gamma = gamma
self.new_estimates = []
@classmethod
@DeveloperAPI
def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
"""Creates an off-policy estimator from an IOContext object.
Extracts Policy and gamma (discount factor) information from the
IOContext.
Args:
ioctx: The IOContext object to create the OffPolicyEstimator
from.
Returns:
The OffPolicyEstimator object created from the IOContext object.
"""
gamma = ioctx.worker.policy_config["gamma"]
# Grab a reference to the current model
keys = list(ioctx.worker.policy_map.keys())
if len(keys) > 1:
raise NotImplementedError(
"Off-policy estimation is not implemented for multi-agent. "
"You can set `off_policy_estimation_methods: []` to resolve this."
)
policy = ioctx.worker.get_policy(keys[0])
return cls(policy, gamma)
@DeveloperAPI
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
"""Returns an off policy estimate for the given batch of experiences.
The batch will at most only contain data from one episode,
but it may also only be a fragment of an episode.
Args:
batch: The batch to calculate the off policy estimate (OPE) on.
Returns:
The off-policy estimates (OPE) calculated on the given batch.
"""
raise NotImplementedError
@DeveloperAPI
def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
"""Returns log likelihoods for actions in given batch for policy.
Computes likelihoods by passing the observations through the current
policy's `compute_log_likelihoods()` method.
Args:
batch: The SampleBatch or MultiAgentBatch to calculate action
log likelihoods from. This batch/batches must contain OBS
and ACTIONS keys.
Returns:
The log likelihoods of the actions in the batch, given the
observations and the policy.
"""
num_state_inputs = 0
for k in batch.keys():
if k.startswith("state_in_"):
num_state_inputs += 1
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
actions=batch[SampleBatch.ACTIONS],
obs_batch=batch[SampleBatch.OBS],
state_batches=[batch[k] for k in state_keys],
prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
actions_normalized=True,
)
log_likelihoods = convert_to_numpy(log_likelihoods)
return np.exp(log_likelihoods)
@DeveloperAPI
def process(self, batch: SampleBatchType) -> None:
"""Computes off policy estimates (OPE) on batch and stores results.
Thus-far collected results can be retrieved then by calling
`self.get_metrics` (which flushes the internal results storage).
Args:
batch: The batch to process (call `self.estimate()` on) and
store results (OPEs) for.
"""
self.new_estimates.append(self.estimate(batch))
@DeveloperAPI
def check_can_estimate_for(self, batch: SampleBatchType) -> None:
"""Checks if we support off policy estimation (OPE) on given batch.
Args:
batch: The batch to check.
Raises:
ValueError: In case `action_prob` key is not in batch OR batch
is a MultiAgentBatch.
"""
if isinstance(batch, MultiAgentBatch):
raise ValueError(
"off-policy estimation is not implemented for multi-agent batches. "
"You can set `off_policy_estimation_methods: []` to resolve this."
)
if "action_prob" not in batch:
raise ValueError(
"Off-policy estimation is not possible unless the inputs "
"include action probabilities (i.e., the policy is stochastic "
"and emits the 'action_prob' key). For DQN this means using "
"`exploration_config: {type: 'SoftQ'}`. You can also set "
"`off_policy_estimation_methods: []` to disable estimation."
)
@DeveloperAPI
def get_metrics(self) -> List[OffPolicyEstimate]:
"""Returns list of new episode metric estimates since the last call.
Returns:
List of OffPolicyEstimate objects.
"""
out = self.new_estimates
self.new_estimates = []
return out
@Deprecated(new="OffPolicyEstimator.create_from_io_context", error=False)
def create(self, *args, **kwargs):
return self.create_from_io_context(*args, **kwargs)
@Deprecated(new="OffPolicyEstimator.action_log_likelihood", error=False)
def action_prob(self, *args, **kwargs):
return self.action_log_likelihood(*args, **kwargs)
deprecation_warning(
old="ray.rllib.offline.off_policy_estimator",
new="ray.rllib.offline.estimators.off_policy_estimator",
error=False,
)
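The shim above keeps the old import path working while steering users to the new location. A minimal sketch of both paths, implied directly by the re-export and the deprecation_warning call:

    # Deprecated path: still importable, but triggers the deprecation warning above.
    from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate  # noqa: F401

    # New canonical path used elsewhere in this commit.
    from ray.rllib.offline.estimators.off_policy_estimator import (  # noqa: F401
        OffPolicyEstimator,
        OffPolicyEstimate,
    )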

View file

@ -98,7 +98,7 @@ class AgentIOTest(unittest.TestCase):
env="CartPole-v0",
config={
"input": self.test_dir + fw,
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
"framework": fw,
},
)
@ -141,7 +141,7 @@ class AgentIOTest(unittest.TestCase):
env="CartPole-v0",
config={
"input": self.test_dir + fw,
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
"postprocess_inputs": True, # adds back 'advantages'
"framework": fw,
},
@ -158,7 +158,9 @@ class AgentIOTest(unittest.TestCase):
env="CartPole-v0",
config={
"input": self.test_dir + fw,
"off_policy_estimation_methods": ["simulation"],
"off_policy_estimation_methods": {
"simulation": {"type": "simulation"}
},
"framework": fw,
},
)
@ -176,7 +178,7 @@ class AgentIOTest(unittest.TestCase):
env="CartPole-v0",
config={
"input": glob.glob(self.test_dir + fw + "/*.json"),
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
"rollout_fragment_length": 99,
"framework": fw,
},
@ -196,7 +198,7 @@ class AgentIOTest(unittest.TestCase):
"sampler": 0.9,
},
"train_batch_size": 2000,
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
"framework": fw,
},
)
@ -234,7 +236,9 @@ class AgentIOTest(unittest.TestCase):
config={
"num_workers": 0,
"input": self.test_dir,
"off_policy_estimation_methods": ["simulation"],
"off_policy_estimation_methods": {
"simulation": {"type": "simulation"}
},
"train_batch_size": 2000,
"multiagent": {
"policies": {"policy_1", "policy_2"},
@ -276,7 +280,7 @@ class AgentIOTest(unittest.TestCase):
config={
"input": input_procedure,
"input_config": {"input_files": self.test_dir + fw},
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
"framework": fw,
},
)

View file

@ -69,7 +69,7 @@ class NestedActionSpacesTest(unittest.TestCase):
config["output"] = tmp_dir
# Switch off OPE as we don't write action-probs.
# TODO: We should probably always write those if `output` is given.
config["off_policy_estimation_methods"] = []
config["off_policy_estimation_methods"] = {}
# Pretend actions in offline files are already normalized.
config["actions_in_input_normalized"] = True