[RLlib]: Doubly Robust Off-Policy Evaluation. (#25056)
This commit is contained in: parent 429d0f0eee, commit a9d8da0100
29 changed files with 1378 additions and 315 deletions
@ -521,7 +521,7 @@ You can configure any Trainer to launch a policy server with the following config:
# Use the existing trainer process to run the server.
"num_workers": 0,
# Disable OPE, since the rollouts are coming from online clients.
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
}

Clients can then connect in either *local* or *remote* inference mode. In local inference mode, copies of the policy are downloaded from the server and cached on the client for a configurable period of time. This allows actions to be computed by the client without requiring a network round trip each time. In remote inference mode, each computed action requires a network call to the server.

@ -48,7 +48,7 @@ Then, we can tell DQN to train using these previously generated experiences with
--env=CartPole-v0 \
--config='{
"input": "/tmp/cartpole-out",
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
"explore": false}'

.. _is:

@ -62,7 +62,14 @@ Then, we can tell DQN to train using these previously generated experiences with
--env=CartPole-v0 \
--config='{
"input": "/tmp/cartpole-out",
"off_policy_estimation_methods": ["is", "wis"],
"off_policy_estimation_methods": {
"is": {
"type": "ImportanceSampling",
},
"wis": {
"type": "WeightedImportanceSampling",
}
},
"exploration_config": {
"type": "SoftQ",
"temperature": 1.0,

@ -275,10 +282,10 @@ You can configure experience input for an agent using the following options:
# - Any subclass of OffPolicyEstimator, e.g.
# ray.rllib.offline.estimators.is::ImportanceSampling or your own custom
# subclass.
"off_policy_estimation_methods": [
ImportanceSampling,
WeightedImportanceSampling,
],
"off_policy_estimation_methods": {
ImportanceSampling: None,
WeightedImportanceSampling: None,
},
# Whether to run postprocess_trajectory() on the trajectory fragments from
# offline inputs. Note that postprocessing will be done using the *current*
# policy, not the *behavior* policy, which is typically undesirable for

@ -574,10 +574,14 @@ The following is a list of the common algorithm hyper-parameters:
# - Any subclass of OffPolicyEstimator, e.g.
# ray.rllib.offline.estimators.is::ImportanceSampling or your own custom
# subclass.
"off_policy_estimation_methods": [
ImportanceSampling,
WeightedImportanceSampling,
],
"off_policy_estimation_methods": {
"is": {
"type": ImportanceSampling,
},
"wis": {
"type": WeightedImportanceSampling,
}
},
# Whether to run postprocess_trajectory() on the trajectory fragments from
# offline inputs. Note that postprocessing will be done using the *current*
# policy, not the *behavior* policy, which is typically undesirable for
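To complement the CLI snippets above, here is a minimal Python-API sketch of the same dict-based setting. It is illustrative only; `DQNConfig`, `offline_data` and the `input_` keyword are taken from other files in this commit rather than from the docs above.

from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.offline.estimators import (
    ImportanceSampling,
    WeightedImportanceSampling,
)

# Dict form: {<result name>: {"type": <estimator class or shorthand>, **kwargs}}.
config = DQNConfig().offline_data(
    input_="/tmp/cartpole-out",
    off_policy_estimation_methods={
        "is": {"type": ImportanceSampling},
        "wis": {"type": WeightedImportanceSampling},
    },
)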
@ -26,6 +26,7 @@ def deep_update(
new_keys_allowed: bool = False,
allow_new_subkey_list: Optional[List[str]] = None,
override_all_if_type_changes: Optional[List[str]] = None,
override_all_key_list: Optional[List[str]] = None,
) -> dict:
"""Updates original dict with values from new_dict recursively.

@ -37,22 +38,29 @@ def deep_update(
original: Dictionary with default values.
new_dict: Dictionary with values to be updated
new_keys_allowed: Whether new keys are allowed.
allow_new_subkey_list (Optional[List[str]]): List of keys that
allow_new_subkey_list: List of keys that
correspond to dict values where new subkeys can be introduced.
This is only at the top level.
override_all_if_type_changes(Optional[List[str]]): List of top level
override_all_if_type_changes: List of top level
keys with value=dict, for which we always simply override the
entire value (dict), iff the "type" key in that value dict changes.
override_all_key_list: List of top level keys
for which we override the entire value if the key is in the new_dict.
"""
allow_new_subkey_list = allow_new_subkey_list or []
override_all_if_type_changes = override_all_if_type_changes or []
override_all_key_list = override_all_key_list or []

for k, value in new_dict.items():
if k not in original and not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))

# Both orginal value and new one are dicts.
if isinstance(original.get(k), dict) and isinstance(value, dict):
if (
isinstance(original.get(k), dict)
and isinstance(value, dict)
and k not in override_all_key_list
):
# Check old type vs old one. If different, override entire value.
if (
k in override_all_if_type_changes

@ -63,10 +71,20 @@ def deep_update(
original[k] = value
# Allowed key -> ok to add new subkeys.
elif k in allow_new_subkey_list:
deep_update(original[k], value, True)
deep_update(
original[k],
value,
True,
override_all_key_list=override_all_key_list,
)
# Non-allowed key.
else:
deep_update(original[k], value, new_keys_allowed)
deep_update(
original[k],
value,
new_keys_allowed,
override_all_key_list=override_all_key_list,
)
# Original value not a dict OR new value not a dict:
# Override entire value.
else:
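A toy illustration (not part of the diff) of what the new `override_all_key_list` argument changes; the import path of `deep_update` is an assumption here and may differ.

from ray.rllib.utils import deep_update  # assumed import path for this helper

original = {"off_policy_estimation_methods": {"is": {"type": "ImportanceSampling"}}}
new_dict = {"off_policy_estimation_methods": {"dr": {"type": "DoublyRobust"}}}

merged = deep_update(
    original,
    new_dict,
    new_keys_allowed=False,
    override_all_key_list=["off_policy_estimation_methods"],
)
# Because the key is listed in override_all_key_list, the whole sub-dict is
# replaced wholesale: merged["off_policy_estimation_methods"] now contains only
# the "dr" entry, instead of being merged key by key to hold both "is" and "dr".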
@ -11,7 +11,7 @@ marwil-halfcheetahbulletenv-v0:
input: ["~/halfcheetah_expert_sac.zip"]
actions_in_input_normalized: true
# Switch off input evaluation (data does not contain action probs).
off_policy_estimation_methods: []
off_policy_estimation_methods: {}

num_gpus: 1
rllib/BUILD (17 changed lines)

@ -24,6 +24,7 @@
# - `evaluation` directory tests.
# - `execution` directory tests.
# - `models` directory tests.
# - `offline` directory tests.
# - `policy` directory tests.
# - `utils` directory tests.

@ -1140,7 +1141,7 @@ py_test(
"--env", "CartPole-v0",
"--run", "DQN",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"off_policy_estimation_methods\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"off_policy_estimation_methods\": {\"wis\": {\"type\": \"wis\"}, \"is\": {\"type\": \"is\"}}, \"exploration_config\": {\"type\": \"SoftQ\"}}'"
]
)

@ -1531,6 +1532,20 @@ py_test(
srcs = ["models/tests/test_preprocessors.py"]
)

# --------------------------------------------------------------------
# Offline
# rllib/offline/
#
# Tag: offline
# --------------------------------------------------------------------

py_test(
name = "test_ope",
tags = ["team:ml", "offline", "torch_only"],
size = "medium",
srcs = ["offline/estimators/tests/test_ope.py"]
)

# --------------------------------------------------------------------
# Policies
# rllib/policy/
@ -34,7 +34,6 @@ from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.utils import _gym_env_creator
from ray.rllib.evaluation.episode import Episode
from ray.rllib.utils import force_list
from ray.rllib.evaluation.metrics import (
collect_episodes,
collect_metrics,

@ -189,6 +188,9 @@ class Trainer(Trainable):
"replay_buffer_config",
]

# List of keys that are always fully overridden if present in any dict or sub-dict
_override_all_key_list = ["off_policy_estimation_methods"]

@PublicAPI
def __init__(
self,

@ -1748,6 +1750,7 @@ class Trainer(Trainable):
_allow_unknown_configs,
cls._allow_unknown_subkeys,
cls._override_all_subkeys_if_type_changes,
cls._override_all_key_list,
)

@staticmethod

@ -1924,9 +1927,22 @@ class Trainer(Trainable):
error=False,
)
config["off_policy_estimation_methods"] = input_evaluation
config["off_policy_estimation_methods"] = force_list(
config["off_policy_estimation_methods"]
)
if isinstance(config["off_policy_estimation_methods"], list) or isinstance(
config["off_policy_estimation_methods"], tuple
):
ope_dict = {
str(ope): {"type": ope} for ope in self.off_policy_estimation_methods
}
deprecation_warning(
old="config.off_policy_estimation_methods={}".format(
self.off_policy_estimation_methods
),
new="config.off_policy_estimation_methods={}".format(
ope_dict,
),
error=False,
)
config["off_policy_estimation_methods"] = ope_dict

# Check model config.
# If no preprocessing, propagate into model's config as well
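The backward-compatibility shim added above boils down to this list-to-dict conversion, shown standalone for clarity:

# Old, deprecated list form ...
old_style = ["is", "wis"]
# ... is rewritten into the new dict form, keyed by the string name:
new_style = {str(ope): {"type": ope} for ope in old_style}
assert new_style == {"is": {"type": "is"}, "wis": {"type": "wis"}}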
@ -169,7 +169,7 @@ class TrainerConfig:
self.input_ = "sampler"
self.input_config = {}
self.actions_in_input_normalized = False
self.off_policy_estimation_methods = []
self.off_policy_estimation_methods = {}
self.postprocess_inputs = False
self.shuffle_buffer_size = 0
self.output = None

@ -932,15 +932,22 @@ class TrainerConfig:
when the offline file has been generated by another RLlib algorithm
(e.g. PPO or SAC), while "normalize_actions" was set to True.
input_evaluation: DEPRECATED: Use `off_policy_estimation_methods` instead!
off_policy_estimation_methods: Specify how to evaluate the current policy.
off_policy_estimation_methods: Specify how to evaluate the current policy,
along with any optional config parameters.
This only has an effect when reading offline experiences
("input" is not "sampler").
Available options:
- "simulation": Run the environment in the background, but use
this data for evaluation only and not for learning.
- Any subclass of OffPolicyEstimator, e.g.
ray.rllib.offline.estimators.is::ImportanceSampling or your own custom
subclass.
Available keys:
- {ope_method_name: {"type": ope_type, ...}} where `ope_method_name`
is a user-defined string to save the OPE results under, and
`ope_type` can be:
- "simulation": Run the environment in the background, but use
this data for evaluation only and not for learning.
- Any subclass of OffPolicyEstimator, e.g.
ray.rllib.offline.estimators.is::ImportanceSampling
or your own custom subclass.
You can also add additional config arguments to be passed to the
OffPolicyEstimator in the dict, e.g.
{"qreg_dr": {"type": DoublyRobust, "q_model_type": "qreg", "k": 5}}
postprocess_inputs: Whether to run postprocess_trajectory() on the
trajectory fragments from offline inputs. Note that postprocessing will
be done using the *current* policy, not the *behavior* policy, which

@ -978,9 +985,25 @@ class TrainerConfig:
),
error=True,
)
self.off_policy_estimation_methods = input_evaluation
if isinstance(off_policy_estimation_methods, list) or isinstance(
off_policy_estimation_methods, tuple
):
ope_dict = {
str(ope): {"type": ope} for ope in off_policy_estimation_methods
}
deprecation_warning(
old="offline_data(off_policy_estimation_methods={}".format(
off_policy_estimation_methods
),
new="offline_data(off_policy_estimation_methods={}".format(
ope_dict,
),
error=False,
)
off_policy_estimation_methods = ope_dict
if off_policy_estimation_methods is not None:
self.off_policy_estimation_methods = off_policy_estimation_methods

if postprocess_inputs is not None:
self.postprocess_inputs = postprocess_inputs
if shuffle_buffer_size is not None:
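Building on the docstring above, a hedged sketch of passing estimator-specific arguments through the config dict; `TrainerConfig` and `offline_data` are APIs touched by this commit, but the exact constructor keywords should be checked against the new estimator files further down.

from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.offline.estimators import DirectMethod, DoublyRobust

config = TrainerConfig().offline_data(
    input_="/tmp/cartpole-out",
    off_policy_estimation_methods={
        # Extra keys ("q_model_type", "k") are forwarded to the estimator's
        # constructor, as in the docstring example above.
        "dr_qreg": {"type": DoublyRobust, "q_model_type": "qreg", "k": 5},
        "dm_fqe": {"type": DirectMethod, "q_model_type": "fqe"},
    },
)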
@ -49,7 +49,7 @@ class BCConfig(MARWILConfig):
# not important for behavioral cloning.
self.postprocess_inputs = False
# No reward estimation.
self.off_policy_estimation_methods = []
self.off_policy_estimation_methods = {}
# __sphinx_doc_end__
# fmt: on
|
@ -67,7 +67,7 @@ class CQLConfig(SACConfig):
|
|||
|
||||
# Changes to Trainer's/SACConfig's default:
|
||||
# .offline_data()
|
||||
self.off_policy_estimation_methods = []
|
||||
self.off_policy_estimation_methods = {}
|
||||
|
||||
# .reporting()
|
||||
self.min_sample_timesteps_per_reporting = 0
|
||||
|
|
|
@ -5,6 +5,7 @@ import unittest

import ray
from ray.rllib.algorithms import cql
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import (
check_compute_single_action,

@ -51,7 +52,7 @@ class TestCQL(unittest.TestCase):
# RLlib algorithm (e.g. PPO or SAC).
actions_in_input_normalized=False,
# Switch on off-policy evaluation.
off_policy_estimation_methods=["is"],
off_policy_estimation_methods={"is": {"type": ImportanceSampling}},
)
.training(
clip_actions=False,
@ -103,9 +103,10 @@ class MARWILConfig(TrainerConfig):
# the same line.
self.input_ = "sampler"
# Use importance sampling estimators for reward.
self.off_policy_estimation_methods = [
ImportanceSampling, WeightedImportanceSampling
]
self.off_policy_estimation_methods = {
"is": {"type": ImportanceSampling},
"wis": {"type": WeightedImportanceSampling},
}
self.postprocess_inputs = True
self.lr = 1e-4
self.train_batch_size = 2000
@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import ray
from ray import ObjectRef
from ray.actor import ActorHandle
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimate
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
@ -33,8 +33,14 @@ from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
from ray.rllib.offline.estimators import ImportanceSampling, WeightedImportanceSampling
from ray.rllib.offline.estimators import (
OffPolicyEstimate,
OffPolicyEstimator,
ImportanceSampling,
WeightedImportanceSampling,
DirectMethod,
DoublyRobust,
)
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap

@ -242,7 +248,7 @@ class RolloutWorker(ParallelIteratorWorker):
input_creator: Callable[
[IOContext], InputReader
] = lambda ioctx: ioctx.default_sampler_input(),
off_policy_estimation_methods: List[str] = frozenset([]),
off_policy_estimation_methods: Optional[Dict[str, Dict]] = None,
output_creator: Callable[
[IOContext], OutputWriter
] = lambda ioctx: NoopOutput(),

@ -339,14 +345,25 @@ class RolloutWorker(ParallelIteratorWorker):
DefaultCallbacks for training/policy/rollout-worker callbacks.
input_creator: Function that returns an InputReader object for
loading previous generated experiences.
off_policy_estimation_methods: How to evaluate the policy performance.
Setting this only makes sense when the input is reading offline data.
Available options:
- "simulation" (str): Run the environment in the background, but use
off_policy_estimation_methods: A dict that specifies how to
evaluate the current policy.
This only has an effect when reading offline experiences
("input" is not "sampler").
Available key-value pairs:
- {"simulation": None}: Run the environment in the background, but use
this data for evaluation only and not for learning.
- Any subclass (type) of the OffPolicyEstimator API class, e.g.
`ray.rllib.offline.estimators.importance_sampling::ImportanceSampling`
- {ope_name: {"type": ope_type, args}}. where `ope_name` is an arbitrary
string under which the metrics for this OPE estimator are saved,
and `ope_type` can be any subclass of OffPolicyEstimator, e.g.
ray.rllib.offline.estimators::ImportanceSampling
or your own custom subclass.
You can also add additional config arguments to be passed to the
OffPolicyEstimator e.g.
off_policy_estimation_methods = {
"dr_qreg": {"type": DoublyRobust, "q_model_type": "qreg"},
"dm_64": {"type": DirectMethod, "batch_size": 64},
}
See ray/rllib/offline/estimators for more information.
output_creator: Function that returns an OutputWriter object for
saving generated experiences.
remote_worker_envs: If using num_envs_per_worker > 1,

@ -702,41 +719,50 @@ class RolloutWorker(ParallelIteratorWorker):
log_dir, policy_config, worker_index, self
)
self.reward_estimators: List[OffPolicyEstimator] = []
for method in off_policy_estimation_methods:
if method == "is":
method = ImportanceSampling
ope_types = {
"is": ImportanceSampling,
"wis": WeightedImportanceSampling,
"dm": DirectMethod,
"dr": DoublyRobust,
}
off_policy_estimation_methods = off_policy_estimation_methods or {}
for name, method_config in off_policy_estimation_methods.items():
method_type = method_config.pop("type")
if method_type in ope_types:
deprecation_warning(
old="config.off_policy_estimation_methods=[is]",
new="from ray.rllib.offline.estimators import "
f"{method.__name__}; config.off_policy_estimation_methods="
f"[{method.__name__}]",
old=method_type,
new=str(ope_types[method_type]),
error=False,
)
elif method == "wis":
method = WeightedImportanceSampling
deprecation_warning(
old="config.off_policy_estimation_methods=[wis]",
new="from ray.rllib.offline.estimators import "
f"{method.__name__}; config.off_policy_estimation_methods="
f"[{method.__name__}]",
error=False,
)

if method == "simulation":
method_type = ope_types[method_type]
if name == "simulation":
logger.warning(
"Requested 'simulation' input evaluation method: "
"will discard all sampler outputs and keep only metrics."
)
sample_async = True
elif isinstance(method, type) and issubclass(method, OffPolicyEstimator):
# TODO: Allow for this to be a full classpath string as well, then construct
# this with our `from_config` util.
elif isinstance(method_type, type) and issubclass(
method_type, OffPolicyEstimator
):
gamma = self.io_context.worker.policy_config["gamma"]
# Grab a reference to the current model
keys = list(self.io_context.worker.policy_map.keys())
if len(keys) > 1:
raise NotImplementedError(
"Off-policy estimation is not implemented for multi-agent. "
"You can set `input_evaluation: []` to resolve this."
)
policy = self.io_context.worker.get_policy(keys[0])
self.reward_estimators.append(
method.create_from_io_context(self.io_context)
method_type(name=name, policy=policy, gamma=gamma, **method_config)
)
else:
raise ValueError(
f"Unknown evaluation method: {method}! Must be "
"either `simulation` or a sub-class of ray.rllib.offline."
"off_policy_estimator::OffPolicyEstimator"
f"Unknown off_policy_estimation type: {method_type}! Must be "
"either `simulation|is|wis|dm|dr` or a sub-class of ray.rllib."
"offline.estimators.off_policy_estimator::OffPolicyEstimator"
)

render = False

@ -866,9 +892,8 @@ class RolloutWorker(ParallelIteratorWorker):

# Do off-policy estimation, if needed.
if self.reward_estimators:
for sub_batch in batch.split_by_episode():
for estimator in self.reward_estimators:
estimator.process(sub_batch)
for estimator in self.reward_estimators:
estimator.process(batch)

if log_once("sample_end"):
logger.info("Completed sample batch:\n\n{}\n".format(summarize(batch)))

@ -1138,7 +1163,7 @@ class RolloutWorker(ParallelIteratorWorker):
out = self.sampler.get_metrics()
else:
out = []
# Get metrics from our reward-estimators (if any).
# Get metrics from our reward estimators (if any).
for m in self.reward_estimators:
out.extend(m.get_metrics())
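For reference, what the rewritten RolloutWorker loop above effectively does for each dict entry can be reproduced standalone roughly as follows; `policy` and `batch` are placeholders assumed to come from an existing torch Trainer and an offline reader with complete episodes.

from ray.rllib.offline.estimators import DoublyRobust

# `policy` and `batch` are assumed to exist already (e.g. from a Trainer and a JsonReader).
estimator = DoublyRobust(name="dr_fqe", policy=policy, gamma=0.99, q_model_type="fqe")
estimator.process(batch)           # runs estimator.estimate(batch) and buffers the results
metrics = estimator.get_metrics()  # list of OffPolicyEstimate(estimator_name, metrics) entries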
@ -627,7 +627,7 @@ class WorkerSet:
)

if config["input"] == "sampler":
off_policy_estimation_methods = []
off_policy_estimation_methods = {}
else:
off_policy_estimation_methods = config["off_policy_estimation_methods"]
@ -165,7 +165,7 @@ if __name__ == "__main__":
# Use n worker processes to listen on different ports.
"num_workers": args.num_workers,
# Disable OPE, since the rollouts are coming from online clients.
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
# Create a "chatty" client/server or not.
"callbacks": MyCallbacks if args.callbacks_verbose else None,
# DL framework to use.

@ -132,7 +132,7 @@ if __name__ == "__main__":
# Use n worker processes to listen on different ports.
"num_workers": args.num_workers,
# Disable OPE, since the rollouts are coming from online clients.
"off_policy_estimation_methods": [],
"off_policy_estimation_methods": {},
# Other settings.
"train_batch_size": 256,
"rollout_fragment_length": 20,
@ -2,8 +2,18 @@ from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
WeightedImportanceSampling,
)
from ray.rllib.offline.estimators.direct_method import DirectMethod
from ray.rllib.offline.estimators.doubly_robust import DoublyRobust
from ray.rllib.offline.estimators.off_policy_estimator import (
OffPolicyEstimate,
OffPolicyEstimator,
)

__all__ = [
"OffPolicyEstimator",
"OffPolicyEstimate",
"ImportanceSampling",
"WeightedImportanceSampling",
"DirectMethod",
"DoublyRobust",
]
rllib/offline/estimators/direct_method.py (new file, 170 lines)
@ -0,0 +1,170 @@
|
|||
from typing import Tuple, List, Generator
|
||||
from ray.rllib.offline.estimators.off_policy_estimator import (
|
||||
OffPolicyEstimator,
|
||||
OffPolicyEstimate,
|
||||
)
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, override
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
|
||||
from ray.rllib.offline.estimators.qreg_torch_model import QRegTorchModel
|
||||
from gym.spaces import Discrete
|
||||
import numpy as np
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
# TODO (rohan): replace with AIR/parallel workers
|
||||
# (And find a better name than `should_train`)
|
||||
@DeveloperAPI
|
||||
def k_fold_cv(
|
||||
batch: SampleBatchType, k: int, should_train: bool = True
|
||||
) -> Generator[Tuple[List[SampleBatch]], None, None]:
|
||||
"""Utility function that returns a k-fold cross validation generator
|
||||
over episodes from the given batch. If the number of episodes in the
|
||||
batch is less than `k` or `should_train` is set to False, yields an empty
|
||||
list for train_episodes and all the episodes in test_episodes.
|
||||
|
||||
Args:
|
||||
batch: A SampleBatch of episodes to split
|
||||
k: Number of cross-validation splits
|
||||
should_train: True by default. If False, yield [], [episodes].
|
||||
|
||||
Returns:
|
||||
A tuple with two lists of SampleBatches (train_episodes, test_episodes)
|
||||
"""
|
||||
episodes = batch.split_by_episode()
|
||||
n_episodes = len(episodes)
|
||||
if n_episodes < k or not should_train:
|
||||
yield [], episodes
|
||||
return
|
||||
n_fold = n_episodes // k
|
||||
for i in range(k):
|
||||
train_episodes = episodes[: i * n_fold] + episodes[(i + 1) * n_fold :]
|
||||
if i != k - 1:
|
||||
test_episodes = episodes[i * n_fold : (i + 1) * n_fold]
|
||||
else:
|
||||
# Append remaining episodes onto the last test_episodes
|
||||
test_episodes = episodes[i * n_fold :]
|
||||
yield train_episodes, test_episodes
|
||||
return
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class DirectMethod(OffPolicyEstimator):
|
||||
"""The Direct Method estimator.
|
||||
|
||||
DM estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
|
||||
|
||||
@override(OffPolicyEstimator)
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
policy: Policy,
|
||||
gamma: float,
|
||||
q_model_type: str = "fqe",
|
||||
k: int = 5,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes a Direct Method OPE Estimator.
|
||||
|
||||
Args:
|
||||
name: string to save OPE results under
|
||||
policy: Policy to evaluate.
|
||||
gamma: Discount factor of the environment.
|
||||
q_model_type: Either "fqe" for Fitted Q-Evaluation
|
||||
or "qreg" for Q-Regression, or a custom model that implements:
|
||||
- `estimate_q(states,actions)`
|
||||
- `estimate_v(states, action_probs)`
|
||||
k: k-fold cross validation for training model and evaluating OPE
|
||||
kwargs: Optional arguments for the specified Q model
|
||||
"""
|
||||
|
||||
super().__init__(name, policy, gamma)
|
||||
# TODO (rohan): Add support for continuous action spaces
|
||||
assert isinstance(
|
||||
policy.action_space, Discrete
|
||||
), "DM Estimator only supports discrete action spaces!"
|
||||
assert (
|
||||
policy.config["batch_mode"] == "complete_episodes"
|
||||
), "DM Estimator only supports `batch_mode`=`complete_episodes`"
|
||||
|
||||
# TODO (rohan): Add support for TF!
|
||||
if policy.framework == "torch":
|
||||
if q_model_type == "qreg":
|
||||
model_cls = QRegTorchModel
|
||||
elif q_model_type == "fqe":
|
||||
model_cls = FQETorchModel
|
||||
else:
|
||||
assert hasattr(
|
||||
q_model_type, "estimate_q"
|
||||
), "q_model_type must implement `estimate_q`!"
|
||||
assert hasattr(
|
||||
q_model_type, "estimate_v"
|
||||
), "q_model_type must implement `estimate_v`!"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__}"
|
||||
"estimator only supports `policy.framework`=`torch`"
|
||||
)
|
||||
|
||||
self.model = model_cls(
|
||||
policy=policy,
|
||||
gamma=gamma,
|
||||
**kwargs,
|
||||
)
|
||||
self.k = k
|
||||
self.losses = []
|
||||
|
||||
@override(OffPolicyEstimator)
|
||||
def estimate(
|
||||
self, batch: SampleBatchType, should_train: bool = True
|
||||
) -> OffPolicyEstimate:
|
||||
self.check_can_estimate_for(batch)
|
||||
estimates = []
|
||||
# Split data into train and test using k-fold cross validation
|
||||
for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train):
|
||||
|
||||
# Train Q-function
|
||||
if train_episodes:
|
||||
# Reinitialize model
|
||||
self.model.reset()
|
||||
train_batch = SampleBatch.concat_samples(train_episodes)
|
||||
losses = self.train(train_batch)
|
||||
self.losses.append(losses)
|
||||
|
||||
# Calculate direct method OPE estimates
|
||||
for episode in test_episodes:
|
||||
rewards = episode["rewards"]
|
||||
v_old = 0.0
|
||||
v_new = 0.0
|
||||
for t in range(episode.count):
|
||||
v_old += rewards[t] * self.gamma ** t
|
||||
|
||||
init_step = episode[0:1]
|
||||
init_obs = np.array([init_step[SampleBatch.OBS]])
|
||||
all_actions = np.arange(self.policy.action_space.n, dtype=float)
|
||||
init_step[SampleBatch.ACTIONS] = all_actions
|
||||
action_probs = np.exp(self.action_log_likelihood(init_step))
|
||||
v_value = self.model.estimate_v(init_obs, action_probs)
|
||||
v_new = convert_to_numpy(v_value).item()
|
||||
|
||||
estimates.append(
|
||||
OffPolicyEstimate(
|
||||
self.name,
|
||||
{
|
||||
"v_old": v_old,
|
||||
"v_new": v_new,
|
||||
"v_gain": v_new / max(1e-8, v_old),
|
||||
},
|
||||
)
|
||||
)
|
||||
return estimates
|
||||
|
||||
@override(OffPolicyEstimator)
|
||||
def train(self, batch: SampleBatchType):
|
||||
return self.model.train_q(batch)
|
rllib/offline/estimators/doubly_robust.py (new file, 71 lines)
@ -0,0 +1,71 @@
|
|||
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimate
|
||||
from ray.rllib.offline.estimators.direct_method import DirectMethod, k_fold_cv
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, override
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
import numpy as np
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class DoublyRobust(DirectMethod):
|
||||
"""The Doubly Robust (DR) estimator.
|
||||
|
||||
DR estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
|
||||
|
||||
@override(DirectMethod)
|
||||
def estimate(
|
||||
self, batch: SampleBatchType, should_train: bool = True
|
||||
) -> OffPolicyEstimate:
|
||||
self.check_can_estimate_for(batch)
|
||||
estimates = []
|
||||
# Split data into train and test using k-fold cross validation
|
||||
for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train):
|
||||
|
||||
# Train Q-function
|
||||
if train_episodes:
|
||||
# Reinitialize model
|
||||
self.model.reset()
|
||||
train_batch = SampleBatch.concat_samples(train_episodes)
|
||||
losses = self.train(train_batch)
|
||||
self.losses.append(losses)
|
||||
|
||||
# Calculate doubly robust OPE estimates
|
||||
for episode in test_episodes:
|
||||
rewards, old_prob = episode["rewards"], episode["action_prob"]
|
||||
new_prob = np.exp(self.action_log_likelihood(episode))
|
||||
|
||||
v_old = 0.0
|
||||
v_new = 0.0
|
||||
q_values = self.model.estimate_q(
|
||||
episode[SampleBatch.OBS], episode[SampleBatch.ACTIONS]
|
||||
)
|
||||
q_values = convert_to_numpy(q_values)
|
||||
|
||||
all_actions = np.zeros([episode.count, self.policy.action_space.n])
|
||||
all_actions[:] = np.arange(self.policy.action_space.n)
|
||||
# Two transposes required for torch.distributions to work
|
||||
tmp_episode = episode.copy()
|
||||
tmp_episode[SampleBatch.ACTIONS] = all_actions.T
|
||||
action_probs = np.exp(self.action_log_likelihood(tmp_episode)).T
|
||||
v_values = self.model.estimate_v(episode[SampleBatch.OBS], action_probs)
|
||||
v_values = convert_to_numpy(v_values)
|
||||
|
||||
for t in reversed(range(episode.count)):
|
||||
v_old = rewards[t] + self.gamma * v_old
|
||||
v_new = v_values[t] + (new_prob[t] / old_prob[t]) * (
|
||||
rewards[t] + self.gamma * v_new - q_values[t]
|
||||
)
|
||||
v_new = v_new.item()
|
||||
|
||||
estimates.append(
|
||||
OffPolicyEstimate(
|
||||
self.name,
|
||||
{
|
||||
"v_old": v_old,
|
||||
"v_new": v_new,
|
||||
"v_gain": v_new / max(1e-8, v_old),
|
||||
},
|
||||
)
|
||||
)
|
||||
return estimates
|
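The backward recursion in the loop above matches the step-wise doubly robust estimator from the cited paper; in the code's terms (notation chosen here for illustration), with importance ratio \rho_t = \pi_{new}(a_t \mid s_t) / \pi_{old}(a_t \mid s_t) and the fitted model's \hat{Q}, \hat{V}:

V_{DR}^{(t)} = \hat{V}(s_t) + \rho_t \left( r_t + \gamma V_{DR}^{(t+1)} - \hat{Q}(s_t, a_t) \right), \qquad V_{DR}^{(H)} = 0

The reported "v_new" is V_{DR}^{(0)}, while "v_old" is the plain discounted return \sum_t \gamma^t r_t of the behavior data, and "v_gain" is their ratio.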
rllib/offline/estimators/fqe_torch_model.py (new file, 222 lines)
@ -0,0 +1,222 @@
|
|||
from ray.rllib.models.utils import get_initializer
|
||||
from ray.rllib.policy import Policy
|
||||
from typing import List, Union
|
||||
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.typing import ModelConfigDict, TensorType
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class FQETorchModel:
|
||||
"""Pytorch implementation of the Fitted Q-Evaluation (FQE) model from
|
||||
https://arxiv.org/pdf/1911.06854.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Policy,
|
||||
gamma: float,
|
||||
model: ModelConfigDict = None,
|
||||
n_iters: int = 160,
|
||||
lr: float = 1e-3,
|
||||
delta: float = 1e-4,
|
||||
clip_grad_norm: float = 100.0,
|
||||
batch_size: int = 32,
|
||||
tau: float = 0.05,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
policy: Policy to evaluate.
|
||||
gamma: Discount factor of the environment.
|
||||
# The ModelConfigDict for self.q_model
|
||||
model = {
|
||||
"fcnet_hiddens": [8, 8],
|
||||
"fcnet_activation": "relu",
|
||||
"vf_share_layers": True,
|
||||
},
|
||||
# Maximum number of training iterations to run on the batch
|
||||
n_iters = 160,
|
||||
# Learning rate for Q-function optimizer
|
||||
lr = 1e-3,
|
||||
# Early stopping if the mean loss < delta
|
||||
delta = 1e-4,
|
||||
# Clip gradients to this maximum value
|
||||
clip_grad_norm = 100.0,
|
||||
# Minibatch size for training Q-function
|
||||
batch_size = 32,
|
||||
# Polyak averaging factor for target Q-function
|
||||
tau = 0.05
|
||||
"""
|
||||
self.policy = policy
|
||||
self.gamma = gamma
|
||||
self.observation_space = policy.observation_space
|
||||
self.action_space = policy.action_space
|
||||
|
||||
if model is None:
|
||||
model = {
|
||||
"fcnet_hiddens": [8, 8],
|
||||
"fcnet_activation": "relu",
|
||||
"vf_share_layers": True,
|
||||
}
|
||||
|
||||
self.device = self.policy.device
|
||||
self.q_model: TorchModelV2 = ModelCatalog.get_model_v2(
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
self.action_space.n,
|
||||
model,
|
||||
framework="torch",
|
||||
name="TorchQModel",
|
||||
).to(self.device)
|
||||
self.target_q_model: TorchModelV2 = ModelCatalog.get_model_v2(
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
self.action_space.n,
|
||||
model,
|
||||
framework="torch",
|
||||
name="TargetTorchQModel",
|
||||
).to(self.device)
|
||||
self.n_iters = n_iters
|
||||
self.lr = lr
|
||||
self.delta = delta
|
||||
self.clip_grad_norm = clip_grad_norm
|
||||
self.batch_size = batch_size
|
||||
self.tau = tau
|
||||
self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr)
|
||||
initializer = get_initializer("xavier_uniform", framework="torch")
|
||||
# Hard update target
|
||||
self.update_target(tau=1.0)
|
||||
|
||||
def f(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
initializer(m.weight)
|
||||
|
||||
self.initializer = f
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets/Reinintializes the model weights."""
|
||||
self.q_model.apply(self.initializer)
|
||||
|
||||
def train_q(self, batch: SampleBatch) -> TensorType:
|
||||
"""Trains self.q_model using FQE loss on given batch.
|
||||
|
||||
Args:
|
||||
batch: A SampleBatch of episodes to train on
|
||||
|
||||
Returns:
|
||||
A list of losses for each training iteration
|
||||
"""
|
||||
losses = []
|
||||
for _ in range(self.n_iters):
|
||||
minibatch_losses = []
|
||||
batch.shuffle()
|
||||
for idx in range(0, batch.count, self.batch_size):
|
||||
minibatch = batch[idx : idx + self.batch_size]
|
||||
obs = torch.tensor(minibatch[SampleBatch.OBS], device=self.device)
|
||||
actions = torch.tensor(
|
||||
minibatch[SampleBatch.ACTIONS], device=self.device
|
||||
)
|
||||
rewards = torch.tensor(
|
||||
minibatch[SampleBatch.REWARDS], device=self.device
|
||||
)
|
||||
next_obs = torch.tensor(
|
||||
minibatch[SampleBatch.NEXT_OBS], device=self.device
|
||||
)
|
||||
dones = torch.tensor(minibatch[SampleBatch.DONES], device=self.device)
|
||||
|
||||
# Neccessary if policy uses recurrent/attention model
|
||||
num_state_inputs = 0
|
||||
for k in batch.keys():
|
||||
if k.startswith("state_in_"):
|
||||
num_state_inputs += 1
|
||||
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
|
||||
|
||||
# Compute action_probs for next_obs as in FQE
|
||||
all_actions = torch.zeros([minibatch.count, self.policy.action_space.n])
|
||||
all_actions[:] = torch.arange(self.policy.action_space.n)
|
||||
next_action_prob = self.policy.compute_log_likelihoods(
|
||||
actions=all_actions.T,
|
||||
obs_batch=next_obs,
|
||||
state_batches=[minibatch[k] for k in state_keys],
|
||||
prev_action_batch=minibatch[SampleBatch.ACTIONS],
|
||||
prev_reward_batch=minibatch[SampleBatch.REWARDS],
|
||||
actions_normalized=False,
|
||||
)
|
||||
next_action_prob = (
|
||||
torch.exp(next_action_prob.T).to(self.device).detach()
|
||||
)
|
||||
|
||||
q_values, _ = self.q_model({"obs": obs}, [], None)
|
||||
q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
|
||||
with torch.no_grad():
|
||||
next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)
|
||||
next_v = torch.sum(next_q_values * next_action_prob, axis=-1)
|
||||
targets = rewards + ~dones * self.gamma * next_v
|
||||
loss = (targets - q_acts) ** 2
|
||||
loss = torch.mean(loss)
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad.clip_grad_norm_(
|
||||
self.q_model.variables(), self.clip_grad_norm
|
||||
)
|
||||
self.optimizer.step()
|
||||
minibatch_losses.append(loss.item())
|
||||
iter_loss = sum(minibatch_losses) / len(minibatch_losses)
|
||||
losses.append(iter_loss)
|
||||
if iter_loss < self.delta:
|
||||
break
|
||||
self.update_target()
|
||||
return losses
|
||||
|
||||
def estimate_q(
|
||||
self,
|
||||
obs: Union[TensorType, List[TensorType]],
|
||||
actions: Union[TensorType, List[TensorType]] = None,
|
||||
) -> TensorType:
|
||||
"""Given `obs`, a list or array or tensor of observations,
|
||||
compute the Q-values for `obs` for all actions in the action space.
|
||||
If `actions` is not None, return the Q-values for the actions provided,
|
||||
else return Q-values for all actions for each observation in `obs`.
|
||||
"""
|
||||
obs = torch.tensor(obs, device=self.device)
|
||||
q_values, _ = self.q_model({"obs": obs}, [], None)
|
||||
if actions is not None:
|
||||
actions = torch.tensor(actions, device=self.device, dtype=int)
|
||||
q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
|
||||
return q_values.detach()
|
||||
|
||||
def estimate_v(
|
||||
self,
|
||||
obs: Union[TensorType, List[TensorType]],
|
||||
action_probs: Union[TensorType, List[TensorType]],
|
||||
) -> TensorType:
|
||||
"""Given `obs`, compute q-values for all actions in the action space
|
||||
for each observations s in `obs`, then multiply this by `action_probs`,
|
||||
the probability distribution over actions for each state s to give the
|
||||
state value V(s) = sum_A pi(a|s)Q(s,a).
|
||||
"""
|
||||
q_values = self.estimate_q(obs)
|
||||
action_probs = torch.tensor(action_probs, device=self.device)
|
||||
v_values = torch.sum(q_values * action_probs, axis=-1)
|
||||
return v_values.detach()
|
||||
|
||||
def update_target(self, tau=None):
|
||||
# Update_target will be called periodically to copy Q network to
|
||||
# target Q network, using (soft) tau-synching.
|
||||
tau = tau or self.tau
|
||||
model_state_dict = self.q_model.state_dict()
|
||||
# Support partial (soft) synching.
|
||||
# If tau == 1.0: Full sync from Q-model to target Q-model.
|
||||
target_state_dict = self.target_q_model.state_dict()
|
||||
model_state_dict = {
|
||||
k: tau * model_state_dict[k] + (1 - tau) * v
|
||||
for k, v in target_state_dict.items()
|
||||
}
|
||||
|
||||
self.target_q_model.load_state_dict(model_state_dict)
|
|
@ -1,42 +1,52 @@
|
|||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
|
||||
from ray.rllib.offline.estimators.off_policy_estimator import (
|
||||
OffPolicyEstimator,
|
||||
OffPolicyEstimate,
|
||||
)
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
from typing import List
|
||||
import numpy as np
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ImportanceSampling(OffPolicyEstimator):
|
||||
"""The step-wise IS estimator.
|
||||
|
||||
Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
|
||||
Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf,
|
||||
https://arxiv.org/pdf/1911.06854.pdf"""
|
||||
|
||||
@override(OffPolicyEstimator)
|
||||
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
|
||||
def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
|
||||
self.check_can_estimate_for(batch)
|
||||
estimates = []
|
||||
for sub_batch in batch.split_by_episode():
|
||||
rewards, old_prob = sub_batch["rewards"], sub_batch["action_prob"]
|
||||
new_prob = np.exp(self.action_log_likelihood(sub_batch))
|
||||
|
||||
rewards, old_prob = batch["rewards"], batch["action_prob"]
|
||||
new_prob = self.action_log_likelihood(batch)
|
||||
# calculate importance ratios
|
||||
p = []
|
||||
for t in range(sub_batch.count):
|
||||
if t == 0:
|
||||
pt_prev = 1.0
|
||||
else:
|
||||
pt_prev = p[t - 1]
|
||||
p.append(pt_prev * new_prob[t] / old_prob[t])
|
||||
|
||||
# calculate importance ratios
|
||||
p = []
|
||||
for t in range(batch.count):
|
||||
if t == 0:
|
||||
pt_prev = 1.0
|
||||
else:
|
||||
pt_prev = p[t - 1]
|
||||
p.append(pt_prev * new_prob[t] / old_prob[t])
|
||||
# calculate stepwise IS estimate
|
||||
v_old = 0.0
|
||||
v_new = 0.0
|
||||
for t in range(sub_batch.count):
|
||||
v_old += rewards[t] * self.gamma ** t
|
||||
v_new += p[t] * rewards[t] * self.gamma ** t
|
||||
|
||||
# calculate stepwise IS estimate
|
||||
V_prev, V_step_IS = 0.0, 0.0
|
||||
for t in range(batch.count):
|
||||
V_prev += rewards[t] * self.gamma ** t
|
||||
V_step_IS += p[t] * rewards[t] * self.gamma ** t
|
||||
|
||||
estimation = OffPolicyEstimate(
|
||||
"importance_sampling",
|
||||
{
|
||||
"V_prev": V_prev,
|
||||
"V_step_IS": V_step_IS,
|
||||
"V_gain_est": V_step_IS / max(1e-8, V_prev),
|
||||
},
|
||||
)
|
||||
return estimation
|
||||
estimates.append(
|
||||
OffPolicyEstimate(
|
||||
self.name,
|
||||
{
|
||||
"v_old": v_old,
|
||||
"v_new": v_new,
|
||||
"v_gain": v_new / max(1e-8, v_old),
|
||||
},
|
||||
)
|
||||
)
|
||||
return estimates
|
||||
|
|
rllib/offline/estimators/off_policy_estimator.py (new file, 188 lines)
@ -0,0 +1,188 @@
|
|||
from collections import namedtuple
|
||||
import logging
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.utils.annotations import Deprecated
|
||||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
from ray.rllib.utils.typing import TensorType, SampleBatchType
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OffPolicyEstimate = DeveloperAPI(
|
||||
namedtuple("OffPolicyEstimate", ["estimator_name", "metrics"])
|
||||
)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class OffPolicyEstimator:
|
||||
"""Interface for an off policy reward estimator."""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, name: str, policy: Policy, gamma: float):
|
||||
"""Initializes an OffPolicyEstimator instance.
|
||||
|
||||
Args:
|
||||
name: string to save OPE results under
|
||||
policy: Policy to evaluate.
|
||||
gamma: Discount factor of the environment.
|
||||
"""
|
||||
self.name = name
|
||||
self.policy = policy
|
||||
self.gamma = gamma
|
||||
self.new_estimates = []
|
||||
|
||||
@DeveloperAPI
|
||||
def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
|
||||
"""Returns a list of off policy estimates for the given batch of episodes.
|
||||
|
||||
Args:
|
||||
batch: The batch to calculate the off policy estimates (OPE) on.
|
||||
|
||||
Returns:
|
||||
The off-policy estimates (OPE) calculated on the given batch.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def train(self, batch: SampleBatchType) -> TensorType:
|
||||
"""Trains an Off-Policy Estimator on a batch of experiences.
|
||||
A model-based estimator should override this and train
|
||||
a transition, value, or reward model.
|
||||
|
||||
Args:
|
||||
batch: The batch to train the model on
|
||||
|
||||
Returns:
|
||||
any optional training/loss metrics from the model
|
||||
"""
|
||||
pass
|
||||
|
||||
@DeveloperAPI
|
||||
def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
|
||||
"""Returns log likelihood for actions in given batch for policy.
|
||||
|
||||
Computes likelihoods by passing the observations through the current
|
||||
policy's `compute_log_likelihoods()` method
|
||||
|
||||
Args:
|
||||
batch: The SampleBatch or MultiAgentBatch to calculate action
|
||||
log likelihoods from. This batch/batches must contain OBS
|
||||
and ACTIONS keys.
|
||||
|
||||
Returns:
|
||||
The probabilities of the actions in the batch, given the
|
||||
observations and the policy.
|
||||
"""
|
||||
num_state_inputs = 0
|
||||
for k in batch.keys():
|
||||
if k.startswith("state_in_"):
|
||||
num_state_inputs += 1
|
||||
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
|
||||
log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
|
||||
actions=batch[SampleBatch.ACTIONS],
|
||||
obs_batch=batch[SampleBatch.OBS],
|
||||
state_batches=[batch[k] for k in state_keys],
|
||||
prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
|
||||
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
|
||||
actions_normalized=True,
|
||||
)
|
||||
log_likelihoods = convert_to_numpy(log_likelihoods)
|
||||
return log_likelihoods
|
||||
|
||||
@DeveloperAPI
|
||||
def check_can_estimate_for(self, batch: SampleBatchType) -> None:
|
||||
"""Checks if we support off policy estimation (OPE) on given batch.
|
||||
|
||||
Args:
|
||||
batch: The batch to check.
|
||||
|
||||
Raises:
|
||||
ValueError: In case `action_prob` key is not in batch OR batch
|
||||
is a MultiAgentBatch.
|
||||
"""
|
||||
|
||||
if isinstance(batch, MultiAgentBatch):
|
||||
raise ValueError(
|
||||
"Off-Policy Estimation is not implemented for multi-agent batches. "
|
||||
"You can set `off_policy_estimation_methods: {}` to resolve this."
|
||||
)
|
||||
|
||||
if "action_prob" not in batch:
|
||||
raise ValueError(
|
||||
"Off-policy estimation is not possible unless the inputs "
|
||||
"include action probabilities (i.e., the policy is stochastic "
|
||||
"and emits the 'action_prob' key). For DQN this means using "
|
||||
"`exploration_config: {type: 'SoftQ'}`. You can also set "
|
||||
"`off_policy_estimation_methods: {}` to disable estimation."
|
||||
)
|
||||
|
||||
@DeveloperAPI
|
||||
def process(self, batch: SampleBatchType) -> None:
|
||||
"""Computes off policy estimates (OPE) on batch and stores results.
|
||||
Thus-far collected results can be retrieved then by calling
|
||||
`self.get_metrics` (which flushes the internal results storage).
|
||||
Args:
|
||||
batch: The batch to process (call `self.estimate()` on) and
|
||||
store results (OPEs) for.
|
||||
"""
|
||||
self.new_estimates.extend(self.estimate(batch))
|
||||
|
||||
@DeveloperAPI
|
||||
def get_metrics(self, get_losses: bool = False) -> List[OffPolicyEstimate]:
|
||||
"""Returns list of new episode metric estimates since the last call.
|
||||
|
||||
Args:
|
||||
get_losses: If True, also return self.losses for the OPE estimator
|
||||
Returns:
|
||||
out: List of OffPolicyEstimate objects.
|
||||
losses: List of training losses for the estimator.
|
||||
"""
|
||||
out = self.new_estimates
|
||||
self.new_estimates = []
|
||||
if hasattr(self, "losses"):
|
||||
losses = self.losses
|
||||
self.losses = []
|
||||
if get_losses:
|
||||
return out, losses
|
||||
return out
|
||||
|
||||
# TODO (rohan): Remove deprecated methods; set to error=True because changing
|
||||
# from one episode per SampleBatch to full SampleBatch is a breaking change anyway
|
||||
|
||||
@Deprecated(help="OffPolicyEstimator.__init__(policy, gamma, config)", error=False)
|
||||
@classmethod
|
||||
@DeveloperAPI
|
||||
def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
|
||||
"""Creates an off-policy estimator from an IOContext object.
|
||||
Extracts Policy and gamma (discount factor) information from the
|
||||
IOContext.
|
||||
Args:
|
||||
ioctx: The IOContext object to create the OffPolicyEstimator
|
||||
from.
|
||||
Returns:
|
||||
The OffPolicyEstimator object created from the IOContext object.
|
||||
"""
|
||||
gamma = ioctx.worker.policy_config["gamma"]
|
||||
# Grab a reference to the current model
|
||||
keys = list(ioctx.worker.policy_map.keys())
|
||||
if len(keys) > 1:
|
||||
raise NotImplementedError(
|
||||
"Off-policy estimation is not implemented for multi-agent. "
|
||||
"You can set `input_evaluation: []` to resolve this."
|
||||
)
|
||||
policy = ioctx.worker.get_policy(keys[0])
|
||||
config = ioctx.input_config.get("estimator_config", {})
|
||||
return cls(policy, gamma, config)
|
||||
|
||||
@Deprecated(new="OffPolicyEstimator.create_from_io_context", error=True)
|
||||
@DeveloperAPI
|
||||
def create(self, *args, **kwargs):
|
||||
return self.create_from_io_context(*args, **kwargs)
|
||||
|
||||
@Deprecated(new="OffPolicyEstimator.compute_log_likelihoods", error=False)
|
||||
@DeveloperAPI
|
||||
def action_prob(self, *args, **kwargs):
|
||||
return self.compute_log_likelihoods(*args, **kwargs)
|
rllib/offline/estimators/qreg_torch_model.py (new file, 224 lines)
@ -0,0 +1,224 @@
|
|||
from ray.rllib.models.utils import get_initializer
|
||||
from ray.rllib.policy import Policy
|
||||
from typing import List, Union
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.typing import TensorType, ModelConfigDict
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class QRegTorchModel:
|
||||
"""Pytorch implementation of the Q-Reg model from
|
||||
https://arxiv.org/pdf/1911.06854.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Policy,
|
||||
gamma: float,
|
||||
model: ModelConfigDict = None,
|
||||
n_iters: int = 160,
|
||||
lr: float = 1e-3,
|
||||
delta: float = 1e-4,
|
||||
clip_grad_norm: float = 100.0,
|
||||
batch_size: int = 32,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
policy: Policy to evaluate.
|
||||
gamma: Discount factor of the environment.
|
||||
# The ModelConfigDict for self.q_model
|
||||
model = {
|
||||
"fcnet_hiddens": [8, 8],
|
||||
"fcnet_activation": "relu",
|
||||
"vf_share_layers": True,
|
||||
},
|
||||
# Maximum number of training iterations to run on the batch
|
||||
n_iters = 160,
|
||||
# Learning rate for Q-function optimizer
|
||||
lr = 1e-3,
|
||||
            # Early stopping if the mean loss < delta
            delta = 1e-4,
            # Clip gradients to this maximum value
            clip_grad_norm = 100.0,
            # Minibatch size for training Q-function
            batch_size = 32,
        """
        self.policy = policy
        self.gamma = gamma
        self.observation_space = policy.observation_space
        self.action_space = policy.action_space

        if model is None:
            model = {
                "fcnet_hiddens": [8, 8],
                "fcnet_activation": "relu",
                "vf_share_layers": True,
            }

        self.device = self.policy.device
        self.q_model: TorchModelV2 = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            self.action_space.n,
            model,
            framework="torch",
            name="TorchQModel",
        ).to(self.device)
        self.n_iters = n_iters
        self.lr = lr
        self.delta = delta
        self.clip_grad_norm = clip_grad_norm
        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(self.q_model.variables(), self.lr)
        initializer = get_initializer("xavier_uniform", framework="torch")

        def f(m):
            if isinstance(m, nn.Linear):
                initializer(m.weight)

        self.initializer = f

    def reset(self) -> None:
        """Resets/Reinitializes the model weights."""
        self.q_model.apply(self.initializer)

    def train_q(self, batch: SampleBatch) -> TensorType:
        """Trains self.q_model using Q-Reg loss on given batch.

        Args:
            batch: A SampleBatch of episodes to train on

        Returns:
            A list of losses for each training iteration
        """
        losses = []
        obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
        actions = torch.tensor(batch[SampleBatch.ACTIONS], device=self.device)
        ps = torch.zeros([batch.count], device=self.device)
        returns = torch.zeros([batch.count], device=self.device)
        discounts = torch.zeros([batch.count], device=self.device)

        # Necessary if policy uses recurrent/attention model
        num_state_inputs = 0
        for k in batch.keys():
            if k.startswith("state_in_"):
                num_state_inputs += 1
        state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]

        # get rewards, old_prob, new_prob
        rewards = batch[SampleBatch.REWARDS]
        old_log_prob = torch.tensor(batch[SampleBatch.ACTION_LOGP])
        new_log_prob = (
            self.policy.compute_log_likelihoods(
                actions=batch[SampleBatch.ACTIONS],
                obs_batch=batch[SampleBatch.OBS],
                state_batches=[batch[k] for k in state_keys],
                prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
                prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
                actions_normalized=False,
            )
            .detach()
            .cpu()
        )
        prob_ratio = torch.exp(new_log_prob - old_log_prob)

        eps_begin = 0
        for episode in batch.split_by_episode():
            eps_end = eps_begin + episode.count

            # calculate importance ratios and returns

            for t in range(episode.count):
                discounts[eps_begin + t] = self.gamma ** t
                if t == 0:
                    pt_prev = 1.0
                else:
                    pt_prev = ps[eps_begin + t - 1]
                ps[eps_begin + t] = pt_prev * prob_ratio[eps_begin + t]

                # O(n^3)
                # ret = 0
                # for t_prime in range(t, episode.count):
                #     gamma = self.gamma ** (t_prime - t)
                #     rho_t_1_t_prime = 1.0
                #     for k in range(t + 1, min(t_prime + 1, episode.count)):
                #         rho_t_1_t_prime = rho_t_1_t_prime * prob_ratio[eps_begin + k]
                #     r = rewards[eps_begin + t_prime]
                #     ret += gamma * rho_t_1_t_prime * r

                # O(n^2)
                ret = 0
                rho = 1
                for t_ in reversed(range(t, episode.count)):
                    ret = rewards[eps_begin + t_] + self.gamma * rho * ret
                    rho = prob_ratio[eps_begin + t_]

                returns[eps_begin + t] = ret

            # Update before next episode
            eps_begin = eps_end

        indices = np.arange(batch.count)
        for _ in range(self.n_iters):
            minibatch_losses = []
            np.random.shuffle(indices)
            for idx in range(0, batch.count, self.batch_size):
                idxs = indices[idx : idx + self.batch_size]
                q_values, _ = self.q_model({"obs": obs[idxs]}, [], None)
                q_acts = torch.gather(
                    q_values, -1, actions[idxs].unsqueeze(-1)
                ).squeeze()
                loss = discounts[idxs] * ps[idxs] * (returns[idxs] - q_acts) ** 2
                loss = torch.mean(loss)
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad.clip_grad_norm_(
                    self.q_model.variables(), self.clip_grad_norm
                )
                self.optimizer.step()
                minibatch_losses.append(loss.item())
            iter_loss = sum(minibatch_losses) / len(minibatch_losses)
            losses.append(iter_loss)
            if iter_loss < self.delta:
                break
        return losses

    def estimate_q(
        self,
        obs: Union[TensorType, List[TensorType]],
        actions: Union[TensorType, List[TensorType]] = None,
    ) -> TensorType:
        """Given `obs`, a list or array or tensor of observations,
        compute the Q-values for `obs` for all actions in the action space.
        If `actions` is not None, return the Q-values for the actions provided,
        else return Q-values for all actions for each observation in `obs`.
        """
        obs = torch.tensor(obs, device=self.device)
        q_values, _ = self.q_model({"obs": obs}, [], None)
        if actions is not None:
            actions = torch.tensor(actions, device=self.device, dtype=int)
            q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
        return q_values.detach()

    def estimate_v(
        self,
        obs: Union[TensorType, List[TensorType]],
        action_probs: Union[TensorType, List[TensorType]],
    ) -> TensorType:
        """Given `obs`, compute Q-values for all actions in the action space
        for each observation s in `obs`, then multiply by `action_probs`,
        the probability distribution over actions for each state s, to give
        the state value V(s) = sum_a pi(a|s) Q(s, a).
        """
        q_values = self.estimate_q(obs)
        action_probs = torch.tensor(action_probs, device=self.device)
        v_values = torch.sum(q_values * action_probs, axis=-1)
        return v_values.detach()
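For context: this Q-model is what the model-based estimators fit on the logged behavior data before scoring the target policy. The sketch below shows one plausible way to use it for a Direct-Method-style value estimate: fit train_q() on a batch of logged episodes, then average estimate_v() over the initial observations under the target policy's action distribution. Only train_q(), reset(), and estimate_v() are taken from the code above; the class name QRegTorchModel, its import path, and the caller-supplied target_action_probs helper are assumptions for illustration, not part of this diff.

import numpy as np

# Hypothetical import path; class and module names are assumptions.
from ray.rllib.offline.estimators.qreg_torch_model import QRegTorchModel


def direct_method_estimate(policy, batch, target_action_probs, gamma=0.99):
    """Sketch: fit the Q-model on logged data, then score initial states.

    `target_action_probs(obs)` must return the target policy's action
    distribution for a batch of observations (caller-provided, not shown here).
    """
    q_model = QRegTorchModel(policy=policy, gamma=gamma)
    q_model.reset()
    losses = q_model.train_q(batch)  # Q-Reg regression on the behavior batch

    v_estimates = []
    for episode in batch.split_by_episode():
        init_obs = episode["obs"][:1]
        probs = target_action_probs(init_obs)
        v_estimates.append(float(q_model.estimate_v(init_obs, probs)))
    return float(np.mean(v_estimates)), losses[-1]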
200 rllib/offline/estimators/tests/test_ope.py Normal file

@@ -0,0 +1,200 @@
import unittest
import ray
from ray import tune
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.offline.estimators import (
    ImportanceSampling,
    WeightedImportanceSampling,
    DirectMethod,
    DoublyRobust,
)
from ray.rllib.offline.json_reader import JsonReader
from pathlib import Path
import os
import numpy as np
import gym


class TestOPE(unittest.TestCase):
    def setUp(self):
        ray.init(num_cpus=4)

    def tearDown(self):
        ray.shutdown()

    @classmethod
    def setUpClass(cls):
        ray.init(ignore_reinit_error=True)
        rllib_dir = Path(__file__).parent.parent.parent.parent
        print("rllib dir={}".format(rllib_dir))
        data_file = os.path.join(rllib_dir, "tests/data/cartpole/large.json")
        print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))

        env_name = "CartPole-v0"
        cls.gamma = 0.99
        train_steps = 20000
        n_batches = 20  # Approx. equal to n_episodes
        n_eval_episodes = 100

        config = (
            DQNConfig()
            .environment(env=env_name)
            .training(gamma=cls.gamma)
            .rollouts(num_rollout_workers=3)
            .exploration(
                explore=True,
                exploration_config={
                    "type": "SoftQ",
                    "temperature": 1.0,
                },
            )
            .framework("torch")
            .rollouts(batch_mode="complete_episodes")
        )
        cls.trainer = config.build()

        # Train DQN for evaluation policy
        tune.run(
            "DQN",
            config=config.to_dict(),
            stop={"timesteps_total": train_steps},
            verbose=0,
        )

        # Read n_batches of data
        reader = JsonReader(data_file)
        cls.batch = reader.next()
        for _ in range(n_batches - 1):
            cls.batch = cls.batch.concat(reader.next())
        cls.n_episodes = len(cls.batch.split_by_episode())
        print("Episodes:", cls.n_episodes, "Steps:", cls.batch.count)

        cls.mean_ret = {}
        cls.std_ret = {}

        # Simulate Monte-Carlo rollouts
        mc_ret = []
        env = gym.make(env_name)
        for _ in range(n_eval_episodes):
            obs = env.reset()
            done = False
            rewards = []
            while not done:
                act = cls.trainer.compute_single_action(obs)
                obs, reward, done, _ = env.step(act)
                rewards.append(reward)
            ret = 0
            for r in reversed(rewards):
                ret = r + cls.gamma * ret
            mc_ret.append(ret)

        cls.mean_ret["simulation"] = np.mean(mc_ret)
        cls.std_ret["simulation"] = np.std(mc_ret)

        # Optional configs for the model-based estimators
        cls.model_config = {"k": 2, "n_iters": 10}
        ray.shutdown()

    @classmethod
    def tearDownClass(cls):
        print("Mean:", cls.mean_ret)
        print("Stddev:", cls.std_ret)
        ray.shutdown()

    def test_is(self):
        name = "is"
        estimator = ImportanceSampling(
            name=name,
            policy=self.trainer.get_policy(),
            gamma=self.gamma,
        )
        estimator.process(self.batch)
        estimates = estimator.get_metrics()
        assert len(estimates) == self.n_episodes
        self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
        self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])

    def test_wis(self):
        name = "wis"
        estimator = WeightedImportanceSampling(
            name=name,
            policy=self.trainer.get_policy(),
            gamma=self.gamma,
        )
        estimator.process(self.batch)
        estimates = estimator.get_metrics()
        assert len(estimates) == self.n_episodes
        self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
        self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])

    def test_dm_qreg(self):
        name = "dm_qreg"
        estimator = DirectMethod(
            name=name,
            policy=self.trainer.get_policy(),
            gamma=self.gamma,
            q_model_type="qreg",
            **self.model_config,
        )
        estimator.process(self.batch)
        estimates = estimator.get_metrics()
        assert len(estimates) == self.n_episodes
        self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
        self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])

    def test_dm_fqe(self):
        name = "dm_fqe"
        estimator = DirectMethod(
            name=name,
            policy=self.trainer.get_policy(),
            gamma=self.gamma,
            q_model_type="fqe",
            **self.model_config,
        )
        estimator.process(self.batch)
        estimates = estimator.get_metrics()
        assert len(estimates) == self.n_episodes
        self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
        self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])

    def test_dr_qreg(self):
        name = "dr_qreg"
        estimator = DoublyRobust(
            name=name,
            policy=self.trainer.get_policy(),
            gamma=self.gamma,
            q_model_type="qreg",
            **self.model_config,
        )
        estimator.process(self.batch)
        estimates = estimator.get_metrics()
        assert len(estimates) == self.n_episodes
        self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
        self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])

    def test_dr_fqe(self):
        name = "dr_fqe"
        estimator = DoublyRobust(
            name=name,
            policy=self.trainer.get_policy(),
            gamma=self.gamma,
            q_model_type="fqe",
            **self.model_config,
        )
        estimator.process(self.batch)
        estimates = estimator.get_metrics()
        assert len(estimates) == self.n_episodes
        self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
        self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])

    def test_ope_in_trainer(self):
        # TODO (rohan): Add performance tests for off_policy_estimation_methods,
        # with fixed seeds and hyperparameters
        pass


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
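The DoublyRobust estimator exercised above combines the fitted Q-model with per-step importance weighting. As background only (the implementation itself is not shown in this part of the diff), the standard step-wise doubly-robust recursion of Jiang & Li (2016), which this estimator presumably follows, is:

% Step-wise doubly-robust value estimate for one logged episode of length H.
% \hat{Q}, \hat{V} come from the fitted Q-model; \rho_t = \pi(a_t \mid s_t) / \beta(a_t \mid s_t).
\begin{aligned}
V_{\mathrm{DR}}^{(H+1-t)} &= \hat{V}(s_t)
  + \rho_t \left( r_t + \gamma\, V_{\mathrm{DR}}^{(H-t)} - \hat{Q}(s_t, a_t) \right),
  \qquad V_{\mathrm{DR}}^{(0)} = 0, \\
\hat{V}_{\mathrm{DR}}(\pi) &= V_{\mathrm{DR}}^{(H)},
  \quad \text{computed backwards from } t = H \text{ down to } t = 1 .
\end{aligned}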
@@ -1,56 +1,66 @@
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
from ray.rllib.offline.estimators.off_policy_estimator import (
    OffPolicyEstimator,
    OffPolicyEstimate,
)
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.typing import SampleBatchType
import numpy as np


@DeveloperAPI
class WeightedImportanceSampling(OffPolicyEstimator):
    """The weighted step-wise IS estimator.

    Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""
    Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf,
    https://arxiv.org/pdf/1911.06854.pdf"""

    def __init__(self, policy: Policy, gamma: float):
        super().__init__(policy, gamma)
    @override(OffPolicyEstimator)
    def __init__(self, name: str, policy: Policy, gamma: float):
        super().__init__(name, policy, gamma)
        self.filter_values = []
        self.filter_counts = []

    @override(OffPolicyEstimator)
    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
        self.check_can_estimate_for(batch)
        estimates = []
        for sub_batch in batch.split_by_episode():
            rewards, old_prob = sub_batch["rewards"], sub_batch["action_prob"]
            new_prob = np.exp(self.action_log_likelihood(sub_batch))

        rewards, old_prob = batch["rewards"], batch["action_prob"]
        new_prob = self.action_log_likelihood(batch)
            # calculate importance ratios
            p = []
            for t in range(sub_batch.count):
                if t == 0:
                    pt_prev = 1.0
                else:
                    pt_prev = p[t - 1]
                p.append(pt_prev * new_prob[t] / old_prob[t])
            for t, v in enumerate(p):
                if t >= len(self.filter_values):
                    self.filter_values.append(v)
                    self.filter_counts.append(1.0)
                else:
                    self.filter_values[t] += v
                    self.filter_counts[t] += 1.0

        # calculate importance ratios
        p = []
        for t in range(batch.count):
            if t == 0:
                pt_prev = 1.0
            else:
                pt_prev = p[t - 1]
            p.append(pt_prev * new_prob[t] / old_prob[t])
        for t, v in enumerate(p):
            if t >= len(self.filter_values):
                self.filter_values.append(v)
                self.filter_counts.append(1.0)
            else:
                self.filter_values[t] += v
                self.filter_counts[t] += 1.0
            # calculate stepwise weighted IS estimate
            v_old = 0.0
            v_new = 0.0
            for t in range(sub_batch.count):
                v_old += rewards[t] * self.gamma ** t
                w_t = self.filter_values[t] / self.filter_counts[t]
                v_new += p[t] / w_t * rewards[t] * self.gamma ** t

        # calculate stepwise weighted IS estimate
        V_prev, V_step_WIS = 0.0, 0.0
        for t in range(batch.count):
            V_prev += rewards[t] * self.gamma ** t
            w_t = self.filter_values[t] / self.filter_counts[t]
            V_step_WIS += p[t] / w_t * rewards[t] * self.gamma ** t

        estimation = OffPolicyEstimate(
            "weighted_importance_sampling",
            {
                "V_prev": V_prev,
                "V_step_WIS": V_step_WIS,
                "V_gain_est": V_step_WIS / max(1e-8, V_prev),
            },
        )
        return estimation
            estimates.append(
                OffPolicyEstimate(
                    self.name,
                    {
                        "v_old": v_old,
                        "v_new": v_new,
                        "v_gain": v_new / max(1e-8, v_old),
                    },
                )
            )
        return estimates
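As a reading aid (not part of the diff), the per-episode quantity the new code computes can be written as follows, where p_t is the cumulative importance ratio up to step t and w_t averages p_t over all episodes processed so far (the filter_values / filter_counts running means):

% Step-wise weighted importance sampling, matching v_old / v_new above.
p_t = \prod_{t'=0}^{t} \frac{\pi(a_{t'} \mid s_{t'})}{\beta(a_{t'} \mid s_{t'})},
\qquad
w_t = \frac{1}{N} \sum_{n=1}^{N} p_t^{(n)},
\qquad
v_{\mathrm{new}} = \sum_{t=0}^{T-1} \gamma^{t} \, \frac{p_t}{w_t} \, r_t,
\qquad
v_{\mathrm{old}} = \sum_{t=0}^{T-1} \gamma^{t} r_t .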
@@ -1,167 +1,11 @@
from collections import namedtuple
import logging

import numpy as np

from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.offline.io_context import IOContext
from ray.rllib.utils.annotations import Deprecated
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import TensorType, SampleBatchType
from typing import List

logger = logging.getLogger(__name__)

OffPolicyEstimate = DeveloperAPI(
    namedtuple("OffPolicyEstimate", ["estimator_name", "metrics"])
from ray.rllib.offline.estimators.off_policy_estimator import (  # noqa: F401
    OffPolicyEstimator,
    OffPolicyEstimate,
)
from ray.rllib.utils.deprecation import deprecation_warning


@DeveloperAPI
class OffPolicyEstimator:
    """Interface for an off policy reward estimator."""

    def __init__(self, policy: Policy, gamma: float):
        """Initializes an OffPolicyEstimator instance.

        Args:
            policy: Policy to evaluate.
            gamma: Discount factor of the environment.
        """
        self.policy = policy
        self.gamma = gamma
        self.new_estimates = []

    @classmethod
    @DeveloperAPI
    def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
        """Creates an off-policy estimator from an IOContext object.

        Extracts Policy and gamma (discount factor) information from the
        IOContext.

        Args:
            ioctx: The IOContext object to create the OffPolicyEstimator
                from.

        Returns:
            The OffPolicyEstimator object created from the IOContext object.
        """
        gamma = ioctx.worker.policy_config["gamma"]
        # Grab a reference to the current model
        keys = list(ioctx.worker.policy_map.keys())
        if len(keys) > 1:
            raise NotImplementedError(
                "Off-policy estimation is not implemented for multi-agent. "
                "You can set `off_policy_estimation_methods: []` to resolve this."
            )
        policy = ioctx.worker.get_policy(keys[0])
        return cls(policy, gamma)

    @DeveloperAPI
    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
        """Returns an off policy estimate for the given batch of experiences.

        The batch will at most only contain data from one episode,
        but it may also only be a fragment of an episode.

        Args:
            batch: The batch to calculate the off policy estimate (OPE) on.

        Returns:
            The off-policy estimates (OPE) calculated on the given batch.
        """
        raise NotImplementedError

    @DeveloperAPI
    def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
        """Returns log likelihoods for actions in given batch for policy.

        Computes likelihoods by passing the observations through the current
        policy's `compute_log_likelihoods()` method.

        Args:
            batch: The SampleBatch or MultiAgentBatch to calculate action
                log likelihoods from. This batch/batches must contain OBS
                and ACTIONS keys.

        Returns:
            The log likelihoods of the actions in the batch, given the
            observations and the policy.
        """
        num_state_inputs = 0
        for k in batch.keys():
            if k.startswith("state_in_"):
                num_state_inputs += 1
        state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
        log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
            actions=batch[SampleBatch.ACTIONS],
            obs_batch=batch[SampleBatch.OBS],
            state_batches=[batch[k] for k in state_keys],
            prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
            actions_normalized=True,
        )
        log_likelihoods = convert_to_numpy(log_likelihoods)
        return np.exp(log_likelihoods)

    @DeveloperAPI
    def process(self, batch: SampleBatchType) -> None:
        """Computes off policy estimates (OPE) on batch and stores results.

        Thus-far collected results can be retrieved then by calling
        `self.get_metrics` (which flushes the internal results storage).

        Args:
            batch: The batch to process (call `self.estimate()` on) and
                store results (OPEs) for.
        """
        self.new_estimates.append(self.estimate(batch))

    @DeveloperAPI
    def check_can_estimate_for(self, batch: SampleBatchType) -> None:
        """Checks if we support off policy estimation (OPE) on given batch.

        Args:
            batch: The batch to check.

        Raises:
            ValueError: In case `action_prob` key is not in batch OR batch
                is a MultiAgentBatch.
        """

        if isinstance(batch, MultiAgentBatch):
            raise ValueError(
                "off-policy estimation is not implemented for multi-agent batches. "
                "You can set `off_policy_estimation_methods: []` to resolve this."
            )

        if "action_prob" not in batch:
            raise ValueError(
                "Off-policy estimation is not possible unless the inputs "
                "include action probabilities (i.e., the policy is stochastic "
                "and emits the 'action_prob' key). For DQN this means using "
                "`exploration_config: {type: 'SoftQ'}`. You can also set "
                "`off_policy_estimation_methods: []` to disable estimation."
            )

    @DeveloperAPI
    def get_metrics(self) -> List[OffPolicyEstimate]:
        """Returns list of new episode metric estimates since the last call.

        Returns:
            List of OffPolicyEstimate objects.
        """
        out = self.new_estimates
        self.new_estimates = []
        return out

    @Deprecated(new="OffPolicyEstimator.create_from_io_context", error=False)
    def create(self, *args, **kwargs):
        return self.create_from_io_context(*args, **kwargs)

    @Deprecated(new="OffPolicyEstimator.action_log_likelihood", error=False)
    def action_prob(self, *args, **kwargs):
        return self.action_log_likelihood(*args, **kwargs)
deprecation_warning(
    old="ray.rllib.offline.off_policy_estimator",
    new="ray.rllib.offline.estimators.off_policy_estimator",
    error=False,
)
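The net effect of this hunk is that ray.rllib.offline.off_policy_estimator becomes a thin shim: it re-exports the classes from their new home and logs a deprecation warning when imported. A minimal sketch of what callers see (module paths taken from the diff; the comments are illustrative):

# Old import path: still works via the shim, but logs a deprecation warning at import time.
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator

# New, preferred import path after this PR.
from ray.rllib.offline.estimators.off_policy_estimator import (
    OffPolicyEstimator,
    OffPolicyEstimate,
)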
@@ -98,7 +98,7 @@ class AgentIOTest(unittest.TestCase):
            env="CartPole-v0",
            config={
                "input": self.test_dir + fw,
                "off_policy_estimation_methods": [],
                "off_policy_estimation_methods": {},
                "framework": fw,
            },
        )

@@ -141,7 +141,7 @@ class AgentIOTest(unittest.TestCase):
            env="CartPole-v0",
            config={
                "input": self.test_dir + fw,
                "off_policy_estimation_methods": [],
                "off_policy_estimation_methods": {},
                "postprocess_inputs": True,  # adds back 'advantages'
                "framework": fw,
            },

@@ -158,7 +158,9 @@ class AgentIOTest(unittest.TestCase):
            env="CartPole-v0",
            config={
                "input": self.test_dir + fw,
                "off_policy_estimation_methods": ["simulation"],
                "off_policy_estimation_methods": {
                    "simulation": {"type": "simulation"}
                },
                "framework": fw,
            },
        )

@@ -176,7 +178,7 @@ class AgentIOTest(unittest.TestCase):
            env="CartPole-v0",
            config={
                "input": glob.glob(self.test_dir + fw + "/*.json"),
                "off_policy_estimation_methods": [],
                "off_policy_estimation_methods": {},
                "rollout_fragment_length": 99,
                "framework": fw,
            },

@@ -196,7 +198,7 @@ class AgentIOTest(unittest.TestCase):
                    "sampler": 0.9,
                },
                "train_batch_size": 2000,
                "off_policy_estimation_methods": [],
                "off_policy_estimation_methods": {},
                "framework": fw,
            },
        )

@@ -234,7 +236,9 @@ class AgentIOTest(unittest.TestCase):
            config={
                "num_workers": 0,
                "input": self.test_dir,
                "off_policy_estimation_methods": ["simulation"],
                "off_policy_estimation_methods": {
                    "simulation": {"type": "simulation"}
                },
                "train_batch_size": 2000,
                "multiagent": {
                    "policies": {"policy_1", "policy_2"},

@@ -276,7 +280,7 @@ class AgentIOTest(unittest.TestCase):
            config={
                "input": input_procedure,
                "input_config": {"input_files": self.test_dir + fw},
                "off_policy_estimation_methods": [],
                "off_policy_estimation_methods": {},
                "framework": fw,
            },
        )

@@ -69,7 +69,7 @@ class NestedActionSpacesTest(unittest.TestCase):
        config["output"] = tmp_dir
        # Switch off OPE as we don't write action-probs.
        # TODO: We should probably always write those if `output` is given.
        config["off_policy_estimation_methods"] = []
        config["off_policy_estimation_methods"] = {}

        # Pretend actions in offline files are already normalized.
        config["actions_in_input_normalized"] = True
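These hunks all make the same mechanical change: off_policy_estimation_methods is now a dict mapping an estimator name to its per-method options rather than a flat list. Based only on the values that appear in this diff, the two shapes look like this (a reading aid, not new API surface):

# Before this PR: a list of method names / classes.
config["off_policy_estimation_methods"] = ["simulation"]  # or [] to disable

# After this PR: a dict keyed by an estimator name, with per-method options.
config["off_policy_estimation_methods"] = {
    "simulation": {"type": "simulation"},
}
config["off_policy_estimation_methods"] = {}  # disable OPE entirely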