mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib]: Add Off-Policy Estimation docs (#26809)
Co-authored-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
This commit is contained in:
parent
2ca11d61b3
commit
deccf33912
6 changed files with 128 additions and 51 deletions
Binary file not shown.
Before Width: | Height: | Size: 65 KiB |
BIN
doc/source/rllib/images/rllib-offline.png
Normal file
BIN
doc/source/rllib/images/rllib-offline.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.3 MiB |
|
@ -13,7 +13,7 @@ RLlib's offline dataset APIs enable working with experiences read from offline s
|
|||
RLlib represents trajectory sequences (i.e., ``(s, a, r, s', ...)`` tuples) with `SampleBatch <https://github.com/ray-project/ray/blob/master/rllib/policy/sample_batch.py>`__ objects. Using a batch format enables efficient encoding and compression of experiences. During online training, RLlib uses `policy evaluation <rllib-concepts.html#policy-evaluation>`__ actors to generate batches of experiences in parallel using the current policy. RLlib also uses this same batch format for reading and writing experiences to offline storage.
|
||||
|
||||
Example: Training on previously saved experiences
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
-------------------------------------------------
|
||||
|
||||
.. note::
|
||||
|
||||
|
@ -23,7 +23,7 @@ In this example, we will save batches of experiences generated during online tra
|
|||
|
||||
.. code-block:: bash
|
||||
|
||||
$ rllib train
|
||||
$ rllib train \
|
||||
--run=PG \
|
||||
--env=CartPole-v0 \
|
||||
--config='{"output": "/tmp/cartpole-out", "output_max_file_size": 5000000}' \
|
||||
|
@ -50,35 +50,97 @@ Then, we can tell DQN to train using these previously generated experiences with
|
|||
"input": "/tmp/cartpole-out",
|
||||
"explore": false}'
|
||||
|
||||
.. _is:
|
||||
Off-Policy Estimation (OPE)
|
||||
---------------------------
|
||||
|
||||
**Off-policy estimation:** Since the input experiences are not from running simulations, RLlib cannot report the true policy performance during training. However, you can use ``tensorboard --logdir=~/ray_results`` to monitor training progress via other metrics such as estimated Q-value. Alternatively, `off-policy estimation <https://arxiv.org/pdf/1511.03722.pdf>`__ can be used, which requires both the source and target action probabilities to be available (i.e., the ``action_prob`` batch key). For DQN, this means enabling soft Q learning so that actions are sampled from a probability distribution:
|
||||
In practice, when training on offline data, it is usually not straightforward to evaluate the trained policies using a simulator as in online RL. For example, in recommeder systems, rolling out a policy trained on offline data in a real-world environment can jeopardize your business if the policy is suboptimal. For these situations we can use `off-policy estimation <https://arxiv.org/abs/1911.06854>`__ methods which avoid the risk of evaluating a possibly sub-optimal policy in a real-world environment.
|
||||
|
||||
With RLlib's evaluation framework you can:
|
||||
|
||||
- Evaluate policies on a simulated environement, if available, using ``evaluation_config["input"] = "sampler"``. You can then monitor your policy's performance on tensorboard as it is getting trained (by using ``tensorboard --logdir=~/ray_results``).
|
||||
|
||||
- Use RLlib's off-policy estimation methods, which estimate the policy's performance on a separate offline dataset. To be able to use this feature, the evaluation dataset should contain ``action_prob`` key that represents the action probability distribution of the collected data so that we can do counterfactual evaluation.
|
||||
|
||||
RLlib supports the following off-policy estimators:
|
||||
|
||||
- `Importance Sampling (IS) <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/importance_sampling.py>`__
|
||||
- `Weighted Importance Sampling (WIS) <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/weighted_importance_sampling.py>`__
|
||||
- `Direct Method (DM) <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/direct_method.py>`__
|
||||
- `Doubly Robust (DR) <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/doubly_robust.py>`__
|
||||
|
||||
IS and WIS compute the ratio between the action probabilities under the behavior policy (from the dataset) and the target policy (the policy under evaluation), and use this ratio to estimate the policy's return. More details on this can be found in their respective papers.
|
||||
|
||||
DM and DR train a Q-model to compute the estimated return. By default, RLlib uses `Fitted-Q Evaluation (FQE) <https://arxiv.org/abs/1911.06854>`__ to train the Q-model. See `fqe_torch_model.py <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/fqe_torch_model.py>`__ for more details.
|
||||
|
||||
.. note:: For a contextual bandit dataset, the ``dones`` key should always be set to ``True``. In this case, FQE reduces to fitting a reward model to the data.
|
||||
|
||||
RLlib's OPE estimators output six metrics:
|
||||
|
||||
- ``v_behavior``: The discounted sum over rewards in the offline episode, averaged over episodes in the batch.
|
||||
- ``v_behavior_std``: The standard deviation corresponding to v_behavior.
|
||||
- ``v_target``: The OPE's estimated discounted return for the target policy, averaged over episodes in the batch.
|
||||
- ``v_target_std``: The standard deviation corresponding to v_target.
|
||||
- ``v_gain``: ``v_target / max(v_behavior, 1e-8)``, averaged over episodes in the batch. ``v_gain > 1.0`` indicates that the policy is better than the policy that generated the behavior data.
|
||||
- ``v_gain_std``: The standard deviation corresponding to v_gain.
|
||||
|
||||
As an example, we generate an evaluation dataset for off-policy estimation:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ rllib train \
|
||||
--run=DQN \
|
||||
--run=PG \
|
||||
--env=CartPole-v0 \
|
||||
--config='{
|
||||
"input": "/tmp/cartpole-out",
|
||||
"off_policy_estimation_methods": {
|
||||
"is": {
|
||||
"type": "ray.rllib.offline.estimators.ImportanceSampling",
|
||||
--config='{"output": "/tmp/cartpole-eval", "output_max_file_size": 5000000}' \
|
||||
--stop='{"timesteps_total": 10000}'
|
||||
|
||||
.. hint:: You should use separate datasets for algorithm training and OPE, as shown here.
|
||||
|
||||
We can now train a DQN algorithm offline and evaluate it using OPE:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.rllib.algorithms.dqn import DQNConfig
|
||||
from ray.rllib.offline.estimators import (
|
||||
ImportanceSampling,
|
||||
WeightedImportanceSampling,
|
||||
DirectMethod,
|
||||
DoublyRobust,
|
||||
)
|
||||
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
|
||||
|
||||
config = (
|
||||
DQNConfig()
|
||||
.environment(env="CartPole-v0")
|
||||
.framework("torch")
|
||||
.offline_data(input_="/tmp/cartpole-out")
|
||||
.evaluation(
|
||||
evaluation_interval=1,
|
||||
evaluation_duration=10,
|
||||
evaluation_num_workers=1,
|
||||
evaluation_duration_unit="episodes",
|
||||
evaluation_config={"input": "/tmp/cartpole-eval"},
|
||||
off_policy_estimation_methods={
|
||||
"is": {"type": ImportanceSampling},
|
||||
"wis": {"type": WeightedImportanceSampling},
|
||||
"dm_fqe": {
|
||||
"type": DirectMethod,
|
||||
"q_model_config": {"type": FQETorchModel, "tau": 0.05},
|
||||
},
|
||||
"wis": {
|
||||
"type": "ray.rllib.offline.estimators.WeightedImportanceSampling",
|
||||
}
|
||||
"dr_fqe": {
|
||||
"type": DoublyRobust,
|
||||
"q_model_config": {"type": FQETorchModel, "tau": 0.05},
|
||||
},
|
||||
"exploration_config": {
|
||||
"type": "SoftQ",
|
||||
"temperature": 1.0,
|
||||
}'
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
This example plot shows the Q-value metric in addition to importance sampling (IS) and weighted importance sampling (WIS) gain estimates (>1.0 means there is an estimated improvement over the original policy):
|
||||
algo = config.build()
|
||||
for _ in range(100):
|
||||
algo.train()
|
||||
|
||||
.. image:: images/offline-q.png
|
||||
.. image:: images/rllib-offline.png
|
||||
|
||||
**Estimator Python API:** For greater control over the evaluation process, you can create off-policy estimators in your Python code and call ``estimator.estimate(episode_batch)`` to perform counterfactual estimation as needed. The estimators take in a policy object and gamma value for the environment:
|
||||
**Estimator Python API:** For greater control over the evaluation process, you can create off-policy estimators in your Python code and call ``estimator.train(batch)`` to perform any neccessary training and ``estimator.estimate(batch)`` to perform counterfactual estimation. The estimators take in an RLlib Policy object and gamma value for the environment, along with additional estimator-specific arguments (e.g. ``q_model_config`` for DM and DR). You can take a look at the example config parameters of the q_model_config `here <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/fqe_torch_model.py>`__. You can also write your own off-policy estimator by subclassing from the `OffPolicyEstimator <https://github.com/ray-project/ray/blob/master/rllib/offline/estimators/off_policy_estimator.py>`__ base class.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -86,17 +148,32 @@ This example plot shows the Q-value metric in addition to importance sampling (I
|
|||
... # train policy offline
|
||||
|
||||
from ray.rllib.offline.json_reader import JsonReader
|
||||
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
|
||||
from ray.rllib.offline.estimators import DoublyRobust
|
||||
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
|
||||
|
||||
estimator = WeightedImportanceSamplingEstimator(algo.get_policy(), gamma=0.99)
|
||||
reader = JsonReader("/path/to/data")
|
||||
for _ in range(1000):
|
||||
estimator = DoublyRobust(
|
||||
policy=algo.get_policy(),
|
||||
gamma=0.99,
|
||||
q_model_config={"type": FQETorchModel, "n_iters": 160},
|
||||
)
|
||||
|
||||
# Train estimator's Q-model; only required for DM and DR estimators
|
||||
reader = JsonReader("/tmp/cartpole-out")
|
||||
for _ in range(100):
|
||||
batch = reader.next()
|
||||
for episode in batch.split_by_episode():
|
||||
print(estimator.estimate(episode))
|
||||
print(estimator.train(batch))
|
||||
# {'loss': ...}
|
||||
|
||||
reader = JsonReader("/tmp/cartpole-eval")
|
||||
# Compute off-policy estimates
|
||||
for _ in range(100):
|
||||
batch = reader.next()
|
||||
print(estimator.estimate(batch))
|
||||
# {'v_behavior': ..., 'v_target': ..., 'v_gain': ...,
|
||||
# 'v_behavior_std': ..., 'v_target_std': ..., 'v_gain_std': ...}
|
||||
|
||||
Example: Converting external experiences to batch format
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
--------------------------------------------------------
|
||||
|
||||
When the env does not support simulation (e.g., it is a web application), it is necessary to generate the ``*.json`` experience batch files outside of RLlib. This can be done by using the `JsonWriter <https://github.com/ray-project/ray/blob/master/rllib/offline/json_writer.py>`__ class to write out batches.
|
||||
This `runnable example <https://github.com/ray-project/ray/blob/master/rllib/examples/saving_experiences.py>`__ shows how to generate and save experience batches for CartPole-v0 to disk:
|
||||
|
@ -107,7 +184,7 @@ This `runnable example <https://github.com/ray-project/ray/blob/master/rllib/exa
|
|||
:end-before: __sphinx_doc_end__
|
||||
|
||||
On-policy algorithms and experience postprocessing
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
----------------------------------------------------
|
||||
|
||||
RLlib assumes that input batches are of
|
||||
`postprocessed experiences <https://github.com/ray-project/ray/blob/master/rllib/policy/policy.py#L434>`__.
|
||||
|
@ -121,7 +198,7 @@ However, for on-policy algorithms like PPO, you'll need to pass in the extra val
|
|||
Note that for on-policy algorithms, you'll also have to throw away experiences generated by prior versions of the policy. This greatly reduces sample efficiency, which is typically undesirable for offline training, but can make sense for certain applications.
|
||||
|
||||
Mixing simulation and offline data
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
-----------------------------------
|
||||
|
||||
RLlib supports multiplexing inputs from multiple input sources, including simulation. For example, in the following example we read 40% of our experiences from ``/tmp/cartpole-out``, 30% from ``hdfs:/archive/cartpole``, and the last 30% is produced via policy evaluation. Input sources are multiplexed using `np.random.choice <https://docs.scipy.org/doc/numpy-1.15.0/reference/generated/numpy.random.choice.html>`__:
|
||||
|
||||
|
@ -139,12 +216,12 @@ RLlib supports multiplexing inputs from multiple input sources, including simula
|
|||
"explore": false}'
|
||||
|
||||
Scaling I/O throughput
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
-----------------------
|
||||
|
||||
Similar to scaling online training, you can scale offline I/O throughput by increasing the number of RLlib workers via the ``num_workers`` config. Each worker accesses offline storage independently in parallel, for linear scaling of I/O throughput. Within each read worker, files are chosen in random order for reads, but file contents are read sequentially.
|
||||
|
||||
Ray Dataset Integration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
--------------------------
|
||||
|
||||
RLlib has experimental support for reading/writing training samples from/to large offline datasets using
|
||||
`Ray Dataset <https://docs.ray.io/en/latest/data/dataset.html>`__.
|
||||
|
@ -189,15 +266,15 @@ To write sample data to JSON or Parquet files using Dataset, specify output and
|
|||
}
|
||||
|
||||
Writing Environment Data
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
--------------------------
|
||||
|
||||
To include environment data in the training sample datasets you can use the optional
|
||||
``store_infos`` parameter that is part of the ``output_config`` dictionary. This parameter
|
||||
ensures that the ``infos`` dictionary, as returned by the RL environment, is included in the output files.
|
||||
|
||||
Note 1: It is the responsibility of the user to ensure that the content of ``infos`` can be serialized
|
||||
to file.
|
||||
Note 2: This setting is only relevant for the TensorFlow based agents, for PyTorch agents the ``infos`` data is always stored.
|
||||
.. note:: It is the responsibility of the user to ensure that the content of ``infos`` can be serialized to file.
|
||||
|
||||
.. note:: This setting is only relevant for the TensorFlow based agents, for PyTorch agents the ``infos`` data is always stored.
|
||||
|
||||
To write the ``infos`` data to JSON or Parquet files using Dataset, specify output and output_config keys like the following:
|
||||
|
||||
|
@ -279,12 +356,8 @@ You can configure experience input for an agent using the following options:
|
|||
# ray.rllib.offline.estimators.is::ImportanceSampling or your own custom
|
||||
# subclass.
|
||||
"off_policy_estimation_methods": {
|
||||
"is": {
|
||||
"type": ImportanceSampling,
|
||||
},
|
||||
"wis": {
|
||||
"type": WeightedImportanceSampling,
|
||||
}
|
||||
"is": {"type": ImportanceSampling},
|
||||
"wis": {"type": WeightedImportanceSampling}
|
||||
},
|
||||
# Whether to run postprocess_trajectory() on the trajectory fragments from
|
||||
# offline inputs. Note that postprocessing will be done using the *current*
|
||||
|
@ -303,7 +376,7 @@ The interface for a custom input reader is as follows:
|
|||
:noindex:
|
||||
|
||||
Example Custom Input API
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
------------------------
|
||||
|
||||
You can create a custom input reader like the following:
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional
|
||||
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
|
||||
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, override
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
@ -35,7 +36,7 @@ class DirectMethod(OffPolicyEstimator):
|
|||
self,
|
||||
policy: Policy,
|
||||
gamma: float,
|
||||
q_model_config: Dict = None,
|
||||
q_model_config: Optional[Dict] = None,
|
||||
):
|
||||
"""Initializes a Direct Method OPE Estimator.
|
||||
|
||||
|
@ -55,7 +56,8 @@ class DirectMethod(OffPolicyEstimator):
|
|||
), "DirectMethod estimator only works with torch!"
|
||||
super().__init__(policy, gamma)
|
||||
|
||||
model_cls = q_model_config.pop("type")
|
||||
q_model_config = q_model_config or {}
|
||||
model_cls = q_model_config.pop("type", FQETorchModel)
|
||||
self.model = model_cls(
|
||||
policy=policy,
|
||||
gamma=gamma,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import logging
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, override
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
@ -9,6 +9,7 @@ from ray.rllib.utils.numpy import convert_to_numpy
|
|||
from ray.rllib.utils.policy import compute_log_likelihoods_from_input_dict
|
||||
|
||||
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
|
||||
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
|
@ -46,7 +47,7 @@ class DoublyRobust(OffPolicyEstimator):
|
|||
self,
|
||||
policy: Policy,
|
||||
gamma: float,
|
||||
q_model_config: Dict = None,
|
||||
q_model_config: Optional[Dict] = None,
|
||||
):
|
||||
"""Initializes a Doubly Robust OPE Estimator.
|
||||
|
||||
|
@ -63,7 +64,8 @@ class DoublyRobust(OffPolicyEstimator):
|
|||
"""
|
||||
|
||||
super().__init__(policy, gamma)
|
||||
model_cls = q_model_config.pop("type")
|
||||
q_model_config = q_model_config or {}
|
||||
model_cls = q_model_config.pop("type", FQETorchModel)
|
||||
|
||||
self.model = model_cls(
|
||||
policy=policy,
|
||||
|
|
|
@ -18,7 +18,7 @@ torch, nn = try_import_torch()
|
|||
@DeveloperAPI
|
||||
class FQETorchModel:
|
||||
"""Pytorch implementation of the Fitted Q-Evaluation (FQE) model from
|
||||
https://arxiv.org/pdf/1911.06854.pdf
|
||||
https://arxiv.org/abs/1911.06854
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -44,9 +44,9 @@ class FQETorchModel:
|
|||
"vf_share_layers": True,
|
||||
},
|
||||
n_iters: Number of gradient steps to run on batch, defaults to 1
|
||||
lr: Learning rate for Q-model optimizer
|
||||
lr: Learning rate for Adam optimizer
|
||||
delta: Early stopping threshold if the mean loss < delta
|
||||
clip_grad_norm: Clip gradients to this maximum value
|
||||
clip_grad_norm: Clip loss gradients to this maximum value
|
||||
minibatch_size: Minibatch size for training Q-function;
|
||||
if None, train on the whole batch
|
||||
tau: Polyak averaging factor for target Q-function
|
||||
|
|
Loading…
Add table
Reference in a new issue