mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)

[RLlib] Added DefaultCallbacks which replaces old callbacks dict interface (#6972)

parent 35ae7f0e68
commit dbcad35022

11 changed files with 392 additions and 161 deletions
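For orientation (this sketch is not part of the diff below): the user-facing change is that callbacks move from free functions wired up through a config dict to methods overridden on a DefaultCallbacks subclass. The class and metric names here are purely illustrative.

from ray.rllib.agents.callbacks import DefaultCallbacks


# Old interface (now deprecated): plain functions taking an "info" dict.
def on_episode_end(info):
    info["episode"].custom_metrics["my_metric"] = 1.0

old_style_config = {
    "env": "CartPole-v0",
    "callbacks": {"on_episode_end": on_episode_end},
}


# New interface: override only the hooks you need; the rest stay no-ops.
class EpisodeMetricCallbacks(DefaultCallbacks):
    def on_episode_end(self, worker, base_env, policies, episode, **kwargs):
        episode.custom_metrics["my_metric"] = 1.0

new_style_config = {
    "env": "CartPole-v0",
    "callbacks": EpisodeMetricCallbacks,  # the class itself, not an instance
}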
.gitignore (vendored): 1 change

@@ -106,6 +106,7 @@ scripts/nodes.txt
 # Generated documentation files
 /doc/_build
 /doc/source/_static/thumbs
+/doc/source/tune/generated_guides/
 
 # User-specific stuff:
 .idea/**/workspace.xml
@@ -496,51 +496,12 @@ Ray actors provide high levels of performance, so in more complex cases they can
 Callbacks and Custom Metrics
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-You can provide callback functions to be called at points during policy evaluation. These functions have access to an info dict containing state for the current `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__. Custom state can be stored for the `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__ in the ``info["episode"].user_data`` dict, and custom scalar metrics reported by saving values to the ``info["episode"].custom_metrics`` dict. These custom metrics will be aggregated and reported as part of training results. The following example (full code `here <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_metrics_and_callbacks.py>`__) logs a custom metric from the environment:
-
-.. code-block:: python
-
-    def on_episode_start(info):
-        print(info.keys())  # -> "env", "episode"
-        episode = info["episode"]
-        print("episode {} started".format(episode.episode_id))
-        episode.user_data["pole_angles"] = []
-
-    def on_episode_step(info):
-        episode = info["episode"]
-        pole_angle = abs(episode.last_observation_for()[2])
-        episode.user_data["pole_angles"].append(pole_angle)
-
-    def on_episode_end(info):
-        episode = info["episode"]
-        pole_angle = np.mean(episode.user_data["pole_angles"])
-        print("episode {} ended with length {} and pole angles {}".format(
-            episode.episode_id, episode.length, pole_angle))
-        episode.custom_metrics["pole_angle"] = pole_angle
-
-    def on_train_result(info):
-        print("trainer.train() result: {} -> {} episodes".format(
-            info["trainer"].__name__, info["result"]["episodes_this_iter"]))
-
-    def on_postprocess_traj(info):
-        episode = info["episode"]
-        batch = info["post_batch"]  # note: you can mutate this
-        print("postprocessed {} steps".format(batch.count))
-
-    ray.init()
-    analysis = tune.run(
-        "PG",
-        config={
-            "env": "CartPole-v0",
-            "callbacks": {
-                "on_episode_start": on_episode_start,
-                "on_episode_step": on_episode_step,
-                "on_episode_end": on_episode_end,
-                "on_train_result": on_train_result,
-                "on_postprocess_traj": on_postprocess_traj,
-            },
-        },
-    )
+You can provide callbacks to be called at points during policy evaluation. These callbacks have access to state for the current `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__. Certain callbacks such as ``on_postprocess_trajectory``, ``on_sample_end``, and ``on_train_result`` are also places where custom postprocessing can be applied to intermediate data or results.
+
+User-defined state can be stored for the `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__ in the ``episode.user_data`` dict, and custom scalar metrics reported by saving values to the ``episode.custom_metrics`` dict. These custom metrics will be aggregated and reported as part of training results. For a full example, see `custom_metrics_and_callbacks.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_metrics_and_callbacks.py>`__.
+
+.. autoclass:: ray.rllib.agents.callbacks.DefaultCallbacks
+    :members:
 
 Visualizing Custom Metrics
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
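The replacement text above points out that on_postprocess_trajectory, on_sample_end, and on_train_result can also postprocess intermediate data. A minimal sketch of what that looks like under the new class interface (illustrative only; adapted from the example file further down in this diff):

from ray.rllib.agents.callbacks import DefaultCallbacks


class PostprocessingCallbacks(DefaultCallbacks):
    """Illustrative only: hooks that touch intermediate data and results."""

    def on_sample_end(self, worker, samples, **kwargs):
        # `samples` is the batch about to be returned by the rollout worker;
        # it can be inspected (or mutated) here.
        print("collected a batch of {} steps".format(samples.count))

    def on_train_result(self, trainer, result, **kwargs):
        # Fields added to `result` are returned as part of the training
        # results for this iteration.
        result["callback_ok"] = True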
@@ -1321,6 +1321,14 @@ py_test(
     args = ["--num-iters=2"]
 )
 
+py_test(
+    name = "examples/custom_metrics_and_callbacks_legacy",
+    tags = ["examples", "examples_C"],
+    size = "small",
+    srcs = ["examples/custom_metrics_and_callbacks_legacy.py"],
+    args = ["--num-iters=2"]
+)
+
 py_test(
     name = "examples/custom_tf_policy",
     tags = ["examples", "examples_C"],
rllib/agents/callbacks.py (new file): 166 lines

@@ -0,0 +1,166 @@
from typing import Dict

from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning


@PublicAPI
class DefaultCallbacks:
    """Abstract base class for RLlib callbacks (similar to Keras callbacks).

    These callbacks can be used for custom metrics and custom postprocessing.

    By default, all of these callbacks are no-ops. To configure custom
    training callbacks, subclass DefaultCallbacks and then set
    {"callbacks": YourCallbacksClass} in the trainer config.
    """

    def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
        if legacy_callbacks_dict:
            deprecation_warning(
                "callbacks dict interface",
                "a class extending rllib.agents.callbacks.DefaultCallbacks")
        self.legacy_callbacks = legacy_callbacks_dict or {}

    def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv,
                         policies: Dict[str, Policy],
                         episode: MultiAgentEpisode, **kwargs):
        """Callback run on the rollout worker before each episode starts.

        Args:
            worker (RolloutWorker): Reference to the current rollout worker.
            base_env (BaseEnv): BaseEnv running the episode. The underlying
                env object can be gotten by calling base_env.get_unwrapped().
            policies (dict): Mapping of policy id to policy objects. In single
                agent mode there will only be a single "default" policy.
            episode (MultiAgentEpisode): Episode object which contains episode
                state. You can use the `episode.user_data` dict to store
                temporary data, and `episode.custom_metrics` to store custom
                metrics for the episode.
            kwargs: Forward compatibility placeholder.
        """

        if self.legacy_callbacks.get("on_episode_start"):
            self.legacy_callbacks["on_episode_start"]({
                "env": base_env,
                "policy": policies,
                "episode": episode,
            })

    def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv,
                        episode: MultiAgentEpisode, **kwargs):
        """Runs on each episode step.

        Args:
            worker (RolloutWorker): Reference to the current rollout worker.
            base_env (BaseEnv): BaseEnv running the episode. The underlying
                env object can be gotten by calling base_env.get_unwrapped().
            episode (MultiAgentEpisode): Episode object which contains episode
                state. You can use the `episode.user_data` dict to store
                temporary data, and `episode.custom_metrics` to store custom
                metrics for the episode.
            kwargs: Forward compatibility placeholder.
        """

        if self.legacy_callbacks.get("on_episode_step"):
            self.legacy_callbacks["on_episode_step"]({
                "env": base_env,
                "episode": episode
            })

    def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv,
                       policies: Dict[str, Policy], episode: MultiAgentEpisode,
                       **kwargs):
        """Runs when an episode is done.

        Args:
            worker (RolloutWorker): Reference to the current rollout worker.
            base_env (BaseEnv): BaseEnv running the episode. The underlying
                env object can be gotten by calling base_env.get_unwrapped().
            policies (dict): Mapping of policy id to policy objects. In single
                agent mode there will only be a single "default" policy.
            episode (MultiAgentEpisode): Episode object which contains episode
                state. You can use the `episode.user_data` dict to store
                temporary data, and `episode.custom_metrics` to store custom
                metrics for the episode.
            kwargs: Forward compatibility placeholder.
        """

        if self.legacy_callbacks.get("on_episode_end"):
            self.legacy_callbacks["on_episode_end"]({
                "env": base_env,
                "policy": policies,
                "episode": episode,
            })

    def on_postprocess_trajectory(
            self, worker: RolloutWorker, episode: MultiAgentEpisode,
            agent_id: str, policy_id: str, policies: Dict[str, Policy],
            postprocessed_batch: SampleBatch,
            original_batches: Dict[str, SampleBatch], **kwargs):
        """Called immediately after a policy's postprocess_fn is called.

        You can use this callback to do additional postprocessing for a policy,
        including looking at the trajectory data of other agents in multi-agent
        settings.

        Args:
            worker (RolloutWorker): Reference to the current rollout worker.
            episode (MultiAgentEpisode): Episode object.
            agent_id (str): Id of the current agent.
            policy_id (str): Id of the current policy for the agent.
            policies (dict): Mapping of policy id to policy objects. In single
                agent mode there will only be a single "default" policy.
            postprocessed_batch (SampleBatch): The postprocessed sample batch
                for this agent. You can mutate this object to apply your own
                trajectory postprocessing.
            original_batches (dict): Mapping of agents to their unpostprocessed
                trajectory data. You should not mutate this object.
            kwargs: Forward compatibility placeholder.
        """

        if self.legacy_callbacks.get("on_postprocess_traj"):
            self.legacy_callbacks["on_postprocess_traj"]({
                "episode": episode,
                "agent_id": agent_id,
                "pre_batch": original_batches[agent_id],
                "post_batch": postprocessed_batch,
                "all_pre_batches": original_batches,
            })

    def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch,
                      **kwargs):
        """Called at the end of RolloutWorker.sample().

        Args:
            worker (RolloutWorker): Reference to the current rollout worker.
            samples (SampleBatch): Batch to be returned. You can mutate this
                object to modify the samples generated.
            kwargs: Forward compatibility placeholder.
        """

        if self.legacy_callbacks.get("on_sample_end"):
            self.legacy_callbacks["on_sample_end"]({
                "worker": worker,
                "samples": samples,
            })

    def on_train_result(self, trainer, result: dict, **kwargs):
        """Called at the end of Trainable.train().

        Args:
            trainer (Trainer): Current trainer instance.
            result (dict): Dict of results returned from trainer.train() call.
                You can mutate this object to add additional metrics.
            kwargs: Forward compatibility placeholder.
        """

        if self.legacy_callbacks.get("on_train_result"):
            self.legacy_callbacks["on_train_result"]({
                "trainer": trainer,
                "result": result,
            })
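A short usage sketch, not part of the diff, of the legacy shim defined above: wrapping an old-style callbacks dict in DefaultCallbacks logs a deprecation warning once and forwards each new-style method call to the matching dict entry using the old info-dict shape. The function body and arguments below are illustrative.

from ray.rllib.agents.callbacks import DefaultCallbacks


# Old-style callback function, keyed by name in a legacy dict.
def legacy_on_train_result(info):
    print("legacy hook saw result keys:", list(info["result"].keys()))


# Constructing with a legacy dict triggers the deprecation warning.
callbacks = DefaultCallbacks(
    legacy_callbacks_dict={"on_train_result": legacy_on_train_result})

# Calling the new-style hook invokes the legacy function with
# {"trainer": ..., "result": ...}, exactly as the old interface did.
# trainer=None and the result dict are placeholders for illustration.
callbacks.on_train_result(trainer=None, result={"episode_reward_mean": 0.0})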
@@ -9,6 +9,7 @@ import tempfile
 
 import ray
 from ray.exceptions import RayError
+from ray.rllib.agents.callbacks import DefaultCallbacks
 from ray.rllib.models import MODEL_DEFAULTS
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.evaluation.metrics import collect_metrics
@@ -126,24 +127,10 @@ COMMON_CONFIG = {
     # `rllib train` command, you can also use the `-v` and `-vv` flags as
     # shorthand for INFO and DEBUG.
     "log_level": "WARN",
-    # Callbacks that will be run during various phases of training. These all
-    # take a single "info" dict as an argument. For episode callbacks, custom
-    # metrics can be attached to the episode by updating the episode object's
-    # custom metrics dict (see examples/custom_metrics_and_callbacks.py). You
-    # may also mutate the passed in batch data in your callback.
-    "callbacks": {
-        "on_episode_start": None,     # arg: {"env": .., "episode": ...}
-        "on_episode_step": None,      # arg: {"env": .., "episode": ...}
-        "on_episode_end": None,       # arg: {"env": .., "episode": ...}
-        "on_sample_end": None,        # arg: {"samples": .., "worker": ...}
-        "on_train_result": None,      # arg: {"trainer": ..., "result": ...}
-        "on_postprocess_traj": None,  # arg: {
-        #   "agent_id": ..., "episode": ...,
-        #   "pre_batch": (before processing),
-        #   "post_batch": (after processing),
-        #   "all_pre_batches": (other agent ids),
-        # }
-    },
+    # Callbacks that will be run during various phases of training. See the
+    # `DefaultCallbacks` class and `examples/custom_metrics_and_callbacks.py`
+    # for more usage information.
+    "callbacks": DefaultCallbacks,
     # Whether to attempt to continue training if a worker crashes. The number
     # of currently healthy workers is reported as the "num_healthy_workers"
     # metric.
@@ -542,11 +529,7 @@ class Trainer(Trainable):
 
     @override(Trainable)
     def _log_result(self, result):
-        if self.config["callbacks"].get("on_train_result"):
-            self.config["callbacks"]["on_train_result"]({
-                "trainer": self,
-                "result": result,
-            })
+        self.callbacks.on_train_result(trainer=self, result=result)
         # log after the callback is invoked, so that the user has a chance
         # to mutate the result
         Trainable._log_result(self, result)
@@ -584,6 +567,12 @@ class Trainer(Trainable):
             self.env_creator = lambda env_config: normalize(inner(env_config))
 
         Trainer._validate_config(self.config)
+        if not callable(self.config["callbacks"]):
+            raise ValueError(
+                "`callbacks` must be a callable method that "
+                "returns a subclass of DefaultCallbacks, got {}".format(
+                    self.config["callbacks"]))
+        self.callbacks = self.config["callbacks"]()
         log_level = self.config.get("log_level")
         if log_level in ["WARN", "ERROR"]:
             logger.info("Current log_level is {}. For more information, "
@@ -918,6 +907,15 @@ class Trainer(Trainable):
                 "sample_batch_size", new="rollout_fragment_length")
            config2["rollout_fragment_length"] = config2["sample_batch_size"]
            del config2["sample_batch_size"]
+        if "callbacks" in config2 and type(config2["callbacks"]) is dict:
+            legacy_callbacks_dict = config2["callbacks"]
+
+            def make_callbacks():
+                # Deprecation warning will be logged by DefaultCallbacks.
+                return DefaultCallbacks(
+                    legacy_callbacks_dict=legacy_callbacks_dict)
+
+            config2["callbacks"] = make_callbacks
         return deep_update(config1, config2, cls._allow_unknown_configs,
                            cls._allow_unknown_subkeys,
                            cls._override_all_subkeys_if_type_changes)
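The validation added above only checks that config["callbacks"] is callable and instantiates it with no arguments, which is also why the legacy path can hand in the make_callbacks closure. A sketch, not part of the diff, of passing any zero-argument factory (here functools.partial, with an assumed threshold parameter and metric name) instead of a bare class:

from functools import partial

from ray.rllib.agents.callbacks import DefaultCallbacks


class ThresholdCallbacks(DefaultCallbacks):
    """Illustrative only: a callbacks class with a constructor argument."""

    def __init__(self, threshold=0.1):
        super().__init__()
        self.threshold = threshold

    def on_episode_end(self, worker, base_env, policies, episode, **kwargs):
        # Record whether this episode's total reward cleared the threshold.
        episode.custom_metrics["above_threshold"] = int(
            episode.total_reward > self.threshold)


config = {
    "env": "CartPole-v0",
    # Either of these should pass the callable check in the trainer setup:
    # "callbacks": ThresholdCallbacks,
    "callbacks": partial(ThresholdCallbacks, threshold=0.5),
}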
@@ -1,6 +1,7 @@
 import logging
 
 from ray.rllib.agents import with_common_config
+from ray.rllib.agents.callbacks import DefaultCallbacks
 from ray.rllib.agents.trainer_template import build_trainer
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.model import restore_original_dimensions
@@ -21,11 +22,17 @@ torch, nn = try_import_torch()
 logger = logging.getLogger(__name__)
 
 
-def on_episode_start(info):
+class AlphaZeroDefaultCallbacks(DefaultCallbacks):
+    """AlphaZero callbacks.
+
+    If you use custom callbacks, you must extend this class and call super()
+    for on_episode_start.
+    """
+
+    def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
         # save env state when an episode starts
-        env = info["env"].get_unwrapped()[0]
+        env = base_env.get_unwrapped()[0]
         state = env.get_state()
-        episode = info["episode"]
         episode.user_data["initial_state"] = state
 
 
@@ -94,9 +101,7 @@ DEFAULT_CONFIG = with_common_config({
     },
 
     # === Callbacks ===
-    "callbacks": {
-        "on_episode_start": on_episode_start,
-    },
+    "callbacks": AlphaZeroDefaultCallbacks,
 
     "use_pytorch": True,
 })
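The docstring above notes that custom callbacks for this trainer must extend AlphaZeroDefaultCallbacks and call super() for on_episode_start, since that hook snapshots the env state the trainer relies on. A minimal sketch of such a subclass; the import path and added names are assumptions, not taken from the diff:

# Import path assumed; adjust to wherever AlphaZeroDefaultCallbacks lives.
from ray.rllib.contrib.alpha_zero.core.alpha_zero_trainer import \
    AlphaZeroDefaultCallbacks


class MyAlphaZeroCallbacks(AlphaZeroDefaultCallbacks):
    def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
        # Keep the base behavior: it stores the env state snapshot in
        # episode.user_data["initial_state"], which the trainer expects.
        super().on_episode_start(worker, base_env, policies, episode, **kwargs)
        # Then add custom per-episode state of your own (illustrative).
        episode.user_data["my_moves"] = []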
@@ -214,7 +214,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
                 directory if specified.
             log_dir (str): Directory where logs can be placed.
             log_level (str): Set the root log level on creation.
-            callbacks (dict): Dict of custom debug callbacks.
+            callbacks (DefaultCallbacks): Custom training callbacks.
             input_creator (func): Function that returns an InputReader object
                 for loading previous generated experiences.
             input_evaluation (list): How to evaluate the policy performance.
@@ -279,7 +279,11 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
 
         env_context = EnvContext(env_config or {}, worker_index)
         self.policy_config = policy_config
-        self.callbacks = callbacks or {}
+        if callbacks:
+            self.callbacks = callbacks()
+        else:
+            from ray.rllib.agents.callbacks import DefaultCallbacks
+            self.callbacks = DefaultCallbacks()
         self.worker_index = worker_index
         self.num_workers = num_workers
         model_config = model_config or {}
@@ -444,6 +448,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
 
         if sample_async:
             self.sampler = AsyncSampler(
+                self,
                 self.async_env,
                 self.policy_map,
                 policy_mapping_fn,
@@ -462,6 +467,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
             self.sampler.start()
         else:
             self.sampler = SyncSampler(
+                self,
                 self.async_env,
                 self.policy_map,
                 policy_mapping_fn,
@@ -518,8 +524,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
             batches.append(batch)
         batch = batches[0].concat_samples(batches)
 
-        if self.callbacks.get("on_sample_end"):
-            self.callbacks["on_sample_end"]({"worker": self, "samples": batch})
+        self.callbacks.on_sample_end(worker=self, samples=batch)
 
         # Always do writes prior to compression for consistency and to allow
         # for better compression inside the writer.
@@ -72,13 +72,13 @@ class MultiAgentSampleBatchBuilder:
         corresponding policy batch for the agent's policy.
     """
 
-    def __init__(self, policy_map, clip_rewards, postp_callback):
+    def __init__(self, policy_map, clip_rewards, callbacks):
         """Initialize a MultiAgentSampleBatchBuilder.
 
         Arguments:
             policy_map (dict): Maps policy ids to policy instances.
             clip_rewards (bool): Whether to clip rewards before postprocessing.
-            postp_callback: function to call on each postprocessed batch.
+            callbacks (DefaultCallbacks): RLlib callbacks.
         """
 
         self.policy_map = policy_map
@@ -89,7 +89,7 @@ class MultiAgentSampleBatchBuilder:
         }
         self.agent_builders = {}
         self.agent_to_policy = {}
-        self.postp_callback = postp_callback
+        self.callbacks = callbacks
         self.count = 0  # increment this manually
 
     def total(self):
@@ -161,15 +161,16 @@ class MultiAgentSampleBatchBuilder:
                 format(summarize(post_batches)))
 
         # Append into policy batches and reset
+        from ray.rllib.evaluation.rollout_worker import get_global_worker
         for agent_id, post_batch in sorted(post_batches.items()):
-            if self.postp_callback:
-                self.postp_callback({
-                    "episode": episode,
-                    "agent_id": agent_id,
-                    "pre_batch": pre_batches[agent_id],
-                    "post_batch": post_batch,
-                    "all_pre_batches": pre_batches,
-                })
+            self.callbacks.on_postprocess_trajectory(
+                worker=get_global_worker(),
+                episode=episode,
+                agent_id=agent_id,
+                policy_id=self.agent_to_policy[agent_id],
+                policies=self.policy_map,
+                postprocessed_batch=post_batch,
+                original_batches=pre_batches)
             self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
                 post_batch)
 
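The call above passes the full pre_batches mapping as original_batches, so a callback can inspect the other agents' trajectories in multi-agent setups. An illustrative sketch, not part of the diff; the class name, metric name, and flag column are made up:

import numpy as np

from ray.rllib.agents.callbacks import DefaultCallbacks


class OpponentAwareCallbacks(DefaultCallbacks):
    """Illustrative only: look at other agents during postprocessing."""

    def on_postprocess_trajectory(self, worker, episode, agent_id, policy_id,
                                  policies, postprocessed_batch,
                                  original_batches, **kwargs):
        # original_batches maps every agent in the episode to its
        # unpostprocessed trajectory data; treat it as read-only.
        other_agents = [aid for aid in original_batches if aid != agent_id]
        episode.custom_metrics["num_other_agents"] = len(other_agents)
        # postprocessed_batch may be mutated, e.g. to tag the trajectory
        # with an extra per-timestep column.
        postprocessed_batch["my_flag"] = np.ones_like(
            postprocessed_batch["rewards"])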
@@ -60,6 +60,7 @@ class SamplerInput(InputReader):
 
 class SyncSampler(SamplerInput):
     def __init__(self,
+                 worker,
                  env,
                  policies,
                  policy_mapping_fn,
@@ -84,7 +85,7 @@ class SyncSampler(SamplerInput):
         self.extra_batches = queue.Queue()
         self.perf_stats = PerfStats()
         self.rollout_provider = _env_runner(
-            self.base_env, self.extra_batches.put, self.policies,
+            worker, self.base_env, self.extra_batches.put, self.policies,
             self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
             self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
             pack, callbacks, tf_sess, self.perf_stats, soft_horizon,
@@ -121,6 +122,7 @@ class SyncSampler(SamplerInput):
 
 class AsyncSampler(threading.Thread, SamplerInput):
     def __init__(self,
+                 worker,
                  env,
                  policies,
                  policy_mapping_fn,
@@ -139,6 +141,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
         for _, f in obs_filters.items():
             assert getattr(f, "is_concurrent", False), \
                 "Observation Filter must support concurrent updates."
+        self.worker = worker
         self.base_env = BaseEnv.to_base_env(env)
         threading.Thread.__init__(self)
         self.queue = queue.Queue(5)
@@ -178,7 +181,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
         extra_batches_putter = (
             lambda x: self.extra_batches.put(x, timeout=600.0))
         rollout_provider = _env_runner(
-            self.base_env, extra_batches_putter, self.policies,
+            self.worker, self.base_env, extra_batches_putter, self.policies,
             self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
             self.preprocessors, self.obs_filters, self.clip_rewards,
             self.clip_actions, self.pack, self.callbacks, self.tf_sess,
@@ -224,13 +227,14 @@ class AsyncSampler(threading.Thread, SamplerInput):
         return extra
 
 
-def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
-                rollout_fragment_length, horizon, preprocessors, obs_filters,
-                clip_rewards, clip_actions, pack, callbacks, tf_sess,
-                perf_stats, soft_horizon, no_done_at_end):
+def _env_runner(worker, base_env, extra_batch_callback, policies,
+                policy_mapping_fn, rollout_fragment_length, horizon,
+                preprocessors, obs_filters, clip_rewards, clip_actions, pack,
+                callbacks, tf_sess, perf_stats, soft_horizon, no_done_at_end):
     """This implements the common experience collection logic.
 
     Args:
+        worker (RolloutWorker): reference to the current rollout worker.
         base_env (BaseEnv): env implementing BaseEnv.
         extra_batch_callback (fn): function to send extra batch data to.
         policies (dict): Map of policy ids to Policy instances.
@@ -250,7 +254,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
             guarantees batches will be exactly `rollout_fragment_length` in
             size.
         clip_actions (bool): Whether to clip actions to the space range.
-        callbacks (dict): User callbacks to run on episode events.
+        callbacks (DefaultCallbacks): User callbacks to run on episode events.
         tf_sess (Session|None): Optional tensorflow session to use for batching
             TF policy evaluations.
         perf_stats (PerfStats): Record perf stats into this object.
@@ -300,8 +304,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
         if batch_builder_pool:
             return batch_builder_pool.pop()
         else:
-            return MultiAgentSampleBatchBuilder(
-                policies, clip_rewards, callbacks.get("on_postprocess_traj"))
+            return MultiAgentSampleBatchBuilder(policies, clip_rewards,
+                                                callbacks)
 
     def new_episode():
         episode = MultiAgentEpisode(policies, policy_mapping_fn,
@@ -313,13 +317,11 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
                 environment=base_env,
                 episode=episode,
                 tf_sess=getattr(p, "_sess", None))
-        # Call custom on_episode_start callback.
-        if callbacks.get("on_episode_start"):
-            callbacks["on_episode_start"]({
-                "env": base_env,
-                "policy": policies,
-                "episode": episode,
-            })
+        callbacks.on_episode_start(
+            worker=worker,
+            base_env=base_env,
+            policies=policies,
+            episode=episode)
         return episode
 
     active_episodes = defaultdict(new_episode)
@@ -340,7 +342,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
         # Process observations and prepare for policy evaluation
         t1 = time.time()
         active_envs, to_eval, outputs = _process_observations(
-            base_env, policies, batch_builder_pool, active_episodes,
+            worker, base_env, policies, batch_builder_pool, active_episodes,
             unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
             preprocessors, obs_filters, rollout_fragment_length, pack,
             callbacks, soft_horizon, no_done_at_end)
@@ -368,7 +370,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
         perf_stats.env_wait_time += time.time() - t4
 
 
-def _process_observations(base_env, policies, batch_builder_pool,
+def _process_observations(worker, base_env, policies, batch_builder_pool,
                           active_episodes, unfiltered_obs, rewards, dones,
                           infos, off_policy_actions, horizon, preprocessors,
                           obs_filters, rollout_fragment_length, pack,
@@ -482,8 +484,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
                 **episode.last_pi_info_for(agent_id))
 
         # Invoke the step callback after the step is logged to the episode
-        if callbacks.get("on_episode_step"):
-            callbacks["on_episode_step"]({"env": base_env, "episode": episode})
+        callbacks.on_episode_step(
+            worker=worker, base_env=base_env, episode=episode)
 
         # Cut the batch if we're not packing multiple episodes into one,
         # or if we've exceeded the requested batch size.
@@ -508,12 +510,11 @@ def _process_observations(base_env, policies, batch_builder_pool,
                     episode=episode,
                     tf_sess=getattr(p, "_sess", None))
             # Call custom on_episode_end callback.
-            if callbacks.get("on_episode_end"):
-                callbacks["on_episode_end"]({
-                    "env": base_env,
-                    "policy": policies,
-                    "episode": episode
-                })
+            callbacks.on_episode_end(
+                worker=worker,
+                base_env=base_env,
+                policies=policies,
+                episode=episode)
             if hit_horizon and soft_horizon:
                 episode.soft_reset()
                 resetted_obs = agent_obs
@@ -4,52 +4,59 @@ Here we use callbacks to track the average CartPole pole angle magnitude as a
 custom metric.
 """
 
+from typing import Dict
 import argparse
 import numpy as np
 
 import ray
 from ray import tune
+from ray.rllib.env import BaseEnv
+from ray.rllib.policy import Policy
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
+from ray.rllib.agents.callbacks import DefaultCallbacks
 
 
-def on_episode_start(info):
-    episode = info["episode"]
-    print("episode {} started".format(episode.episode_id))
-    episode.user_data["pole_angles"] = []
-    episode.hist_data["pole_angles"] = []
+class MyCallbacks(DefaultCallbacks):
+    def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv,
+                         policies: Dict[str, Policy],
+                         episode: MultiAgentEpisode, **kwargs):
+        print("episode {} started".format(episode.episode_id))
+        episode.user_data["pole_angles"] = []
+        episode.hist_data["pole_angles"] = []
 
-
-def on_episode_step(info):
-    episode = info["episode"]
-    pole_angle = abs(episode.last_observation_for()[2])
-    raw_angle = abs(episode.last_raw_obs_for()[2])
-    assert pole_angle == raw_angle
-    episode.user_data["pole_angles"].append(pole_angle)
+    def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv,
+                        episode: MultiAgentEpisode, **kwargs):
+        pole_angle = abs(episode.last_observation_for()[2])
+        raw_angle = abs(episode.last_raw_obs_for()[2])
+        assert pole_angle == raw_angle
+        episode.user_data["pole_angles"].append(pole_angle)
 
-
-def on_episode_end(info):
-    episode = info["episode"]
-    pole_angle = np.mean(episode.user_data["pole_angles"])
-    print("episode {} ended with length {} and pole angles {}".format(
-        episode.episode_id, episode.length, pole_angle))
-    episode.custom_metrics["pole_angle"] = pole_angle
-    episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
+    def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv,
+                       policies: Dict[str, Policy], episode: MultiAgentEpisode,
+                       **kwargs):
+        pole_angle = np.mean(episode.user_data["pole_angles"])
+        print("episode {} ended with length {} and pole angles {}".format(
+            episode.episode_id, episode.length, pole_angle))
+        episode.custom_metrics["pole_angle"] = pole_angle
+        episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
 
+    def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch,
+                      **kwargs):
+        print("returned sample batch of size {}".format(samples.count))
 
-def on_sample_end(info):
-    print("returned sample batch of size {}".format(info["samples"].count))
-
-
-def on_train_result(info):
-    print("trainer.train() result: {} -> {} episodes".format(
-        info["trainer"], info["result"]["episodes_this_iter"]))
-    # you can mutate the result dict to add new fields to return
-    info["result"]["callback_ok"] = True
+    def on_train_result(self, trainer, result: dict, **kwargs):
+        print("trainer.train() result: {} -> {} episodes".format(
+            trainer, result["episodes_this_iter"]))
+        # you can mutate the result dict to add new fields to return
+        result["callback_ok"] = True
 
-
-def on_postprocess_traj(info):
-    episode = info["episode"]
-    batch = info["post_batch"]
-    print("postprocessed {} steps".format(batch.count))
-    if "num_batches" not in episode.custom_metrics:
-        episode.custom_metrics["num_batches"] = 0
-    episode.custom_metrics["num_batches"] += 1
+    def on_postprocess_trajectory(
+            self, worker: RolloutWorker, episode: MultiAgentEpisode,
+            agent_id: str, policy_id: str, policies: Dict[str, Policy],
+            postprocessed_batch: SampleBatch,
+            original_batches: Dict[str, SampleBatch], **kwargs):
+        print("postprocessed {} steps".format(postprocessed_batch.count))
+        if "num_batches" not in episode.custom_metrics:
+            episode.custom_metrics["num_batches"] = 0
+        episode.custom_metrics["num_batches"] += 1
@@ -68,14 +75,7 @@ if __name__ == "__main__":
         },
         config={
             "env": "CartPole-v0",
-            "callbacks": {
-                "on_episode_start": on_episode_start,
-                "on_episode_step": on_episode_step,
-                "on_episode_end": on_episode_end,
-                "on_sample_end": on_sample_end,
-                "on_train_result": on_train_result,
-                "on_postprocess_traj": on_postprocess_traj,
-            },
+            "callbacks": MyCallbacks,
         },
         return_trials=True)
 
rllib/examples/custom_metrics_and_callbacks_legacy.py (new file): 85 lines

@@ -0,0 +1,85 @@
"""Deprecated API; see custom_metrics_and_callbacks.py instead."""

import argparse
import numpy as np

import ray
from ray import tune


def on_episode_start(info):
    episode = info["episode"]
    print("episode {} started".format(episode.episode_id))
    episode.user_data["pole_angles"] = []
    episode.hist_data["pole_angles"] = []


def on_episode_step(info):
    episode = info["episode"]
    pole_angle = abs(episode.last_observation_for()[2])
    raw_angle = abs(episode.last_raw_obs_for()[2])
    assert pole_angle == raw_angle
    episode.user_data["pole_angles"].append(pole_angle)


def on_episode_end(info):
    episode = info["episode"]
    pole_angle = np.mean(episode.user_data["pole_angles"])
    print("episode {} ended with length {} and pole angles {}".format(
        episode.episode_id, episode.length, pole_angle))
    episode.custom_metrics["pole_angle"] = pole_angle
    episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]


def on_sample_end(info):
    print("returned sample batch of size {}".format(info["samples"].count))


def on_train_result(info):
    print("trainer.train() result: {} -> {} episodes".format(
        info["trainer"], info["result"]["episodes_this_iter"]))
    # you can mutate the result dict to add new fields to return
    info["result"]["callback_ok"] = True


def on_postprocess_traj(info):
    episode = info["episode"]
    batch = info["post_batch"]
    print("postprocessed {} steps".format(batch.count))
    if "num_batches" not in episode.custom_metrics:
        episode.custom_metrics["num_batches"] = 0
    episode.custom_metrics["num_batches"] += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-iters", type=int, default=2000)
    args = parser.parse_args()

    ray.init()
    trials = tune.run(
        "PG",
        stop={
            "training_iteration": args.num_iters,
        },
        config={
            "env": "CartPole-v0",
            "callbacks": {
                "on_episode_start": on_episode_start,
                "on_episode_step": on_episode_step,
                "on_episode_end": on_episode_end,
                "on_sample_end": on_sample_end,
                "on_train_result": on_train_result,
                "on_postprocess_traj": on_postprocess_traj,
            },
        },
        return_trials=True)

    # verify custom metrics for integration tests
    custom_metrics = trials[0].last_result["custom_metrics"]
    print(custom_metrics)
    assert "pole_angle_mean" in custom_metrics
    assert "pole_angle_min" in custom_metrics
    assert "pole_angle_max" in custom_metrics
    assert "num_batches_mean" in custom_metrics
    assert "callback_ok" in trials[0].last_result
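Consistent with the assertions above, each scalar written to episode.custom_metrics is aggregated over the episodes of a training iteration and reported with _mean, _min, and _max suffixes under custom_metrics in the training result. A small sketch with illustrative values only:

# Sketch, not part of the diff: the result dict below is made up, but the key
# names match what the assertions in the example expect to find.
result = {
    "custom_metrics": {
        "pole_angle_mean": 0.081,
        "pole_angle_min": 0.002,
        "pole_angle_max": 0.207,
        "num_batches_mean": 1.0,
    },
}

for name, value in sorted(result["custom_metrics"].items()):
    print("{}: {:.3f}".format(name, value))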