[RLlib] Added DefaultCallbacks which replaces old callbacks dict interface (#6972)

This commit is contained in:
roireshef 2020-04-17 02:06:42 +03:00 committed by GitHub
parent 35ae7f0e68
commit dbcad35022
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 392 additions and 161 deletions

1
.gitignore vendored

@@ -106,6 +106,7 @@ scripts/nodes.txt
# Generated documentation files
/doc/_build
/doc/source/_static/thumbs
/doc/source/tune/generated_guides/
# User-specific stuff:
.idea/**/workspace.xml

doc/source/rllib-training.rst

@@ -496,51 +496,12 @@ Ray actors provide high levels of performance, so in more complex cases they can
Callbacks and Custom Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
You can provide callback functions to be called at points during policy evaluation. These functions have access to an info dict containing state for the current `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__. Custom state can be stored for the `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__ in the ``info["episode"].user_data`` dict, and custom scalar metrics reported by saving values to the ``info["episode"].custom_metrics`` dict. These custom metrics will be aggregated and reported as part of training results. The following example (full code `here <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_metrics_and_callbacks.py>`__) logs a custom metric from the environment:
You can provide callbacks to be called at points during policy evaluation. These callbacks have access to state for the current `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__. Certain callbacks such as ``on_postprocess_trajectory``, ``on_sample_end``, and ``on_train_result`` are also places where custom postprocessing can be applied to intermediate data or results.
.. code-block:: python
User-defined state can be stored for the `episode <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__ in the ``episode.user_data`` dict, and custom scalar metrics reported by saving values to the ``episode.custom_metrics`` dict. These custom metrics will be aggregated and reported as part of training results. For a full example, see `custom_metrics_and_callbacks.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_metrics_and_callbacks.py>`__.
def on_episode_start(info):
print(info.keys()) # -> "env", "episode"
episode = info["episode"]
print("episode {} started".format(episode.episode_id))
episode.user_data["pole_angles"] = []
def on_episode_step(info):
episode = info["episode"]
pole_angle = abs(episode.last_observation_for()[2])
episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(info):
episode = info["episode"]
pole_angle = np.mean(episode.user_data["pole_angles"])
print("episode {} ended with length {} and pole angles {}".format(
episode.episode_id, episode.length, pole_angle))
episode.custom_metrics["pole_angle"] = pole_angle
def on_train_result(info):
print("trainer.train() result: {} -> {} episodes".format(
info["trainer"].__name__, info["result"]["episodes_this_iter"]))
def on_postprocess_traj(info):
episode = info["episode"]
batch = info["post_batch"] # note: you can mutate this
print("postprocessed {} steps".format(batch.count))
ray.init()
analysis = tune.run(
"PG",
config={
"env": "CartPole-v0",
"callbacks": {
"on_episode_start": on_episode_start,
"on_episode_step": on_episode_step,
"on_episode_end": on_episode_end,
"on_train_result": on_train_result,
"on_postprocess_traj": on_postprocess_traj,
},
},
)
.. autoclass:: ray.rllib.agents.callbacks.DefaultCallbacks
:members:
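For orientation, a minimal sketch of the new class-based interface (not part of this diff; ``MyCallbacks`` and the metric name are illustrative):

.. code-block:: python

    import ray
    from ray import tune
    from ray.rllib.agents.callbacks import DefaultCallbacks

    class MyCallbacks(DefaultCallbacks):
        def on_episode_end(self, worker, base_env, policies, episode,
                           **kwargs):
            # Custom scalars are aggregated under `custom_metrics`
            # in the training results.
            episode.custom_metrics["episode_len"] = episode.length

    ray.init()
    tune.run(
        "PG",
        config={
            "env": "CartPole-v0",
            # Pass the class itself; RLlib instantiates it on the trainer
            # and on each rollout worker.
            "callbacks": MyCallbacks,
        },
    )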
Visualizing Custom Metrics
~~~~~~~~~~~~~~~~~~~~~~~~~~

rllib/BUILD

@@ -1321,6 +1321,14 @@ py_test(
args = ["--num-iters=2"]
)
py_test(
name = "examples/custom_metrics_and_callbacks_legacy",
tags = ["examples", "examples_C"],
size = "small",
srcs = ["examples/custom_metrics_and_callbacks_legacy.py"],
args = ["--num-iters=2"]
)
py_test(
name = "examples/custom_tf_policy",
tags = ["examples", "examples_C"],

166
rllib/agents/callbacks.py Normal file

@@ -0,0 +1,166 @@
from typing import Dict
from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning
@PublicAPI
class DefaultCallbacks:
"""Abstract base class for RLlib callbacks (similar to Keras callbacks).
These callbacks can be used for custom metrics and custom postprocessing.
By default, all of these callbacks are no-ops. To configure custom training
callbacks, subclass DefaultCallbacks and then set
{"callbacks": YourCallbacksClass} in the trainer config.
"""
def __init__(self, legacy_callbacks_dict: Dict[str, callable] = None):
if legacy_callbacks_dict:
deprecation_warning(
"callbacks dict interface",
"a class extending rllib.agents.callbacks.DefaultCallbacks")
self.legacy_callbacks = legacy_callbacks_dict or {}
def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy],
episode: MultiAgentEpisode, **kwargs):
"""Callback run on the rollout worker before each episode starts.
Args:
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): BaseEnv running the episode. The underlying
env object can be gotten by calling base_env.get_unwrapped().
policies (dict): Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
episode (MultiAgentEpisode): Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_episode_start"):
self.legacy_callbacks["on_episode_start"]({
"env": base_env,
"policy": policies,
"episode": episode,
})
def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv,
episode: MultiAgentEpisode, **kwargs):
"""Runs on each episode step.
Args:
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): BaseEnv running the episode. The underlying
env object can be gotten by calling base_env.get_unwrapped().
episode (MultiAgentEpisode): Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_episode_step"):
self.legacy_callbacks["on_episode_step"]({
"env": base_env,
"episode": episode
})
def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy], episode: MultiAgentEpisode,
**kwargs):
"""Runs when an episode is done.
Args:
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): BaseEnv running the episode. The underlying
env object can be gotten by calling base_env.get_unwrapped().
policies (dict): Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
episode (MultiAgentEpisode): Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_episode_end"):
self.legacy_callbacks["on_episode_end"]({
"env": base_env,
"policy": policies,
"episode": episode,
})
def on_postprocess_trajectory(
self, worker: RolloutWorker, episode: MultiAgentEpisode,
agent_id: str, policy_id: str, policies: Dict[str, Policy],
postprocessed_batch: SampleBatch,
original_batches: Dict[str, SampleBatch], **kwargs):
"""Called immediately after a policy's postprocess_fn is called.
You can use this callback to do additional postprocessing for a policy,
including looking at the trajectory data of other agents in multi-agent
settings.
Args:
worker (RolloutWorker): Reference to the current rollout worker.
episode (MultiAgentEpisode): Episode object.
agent_id (str): Id of the current agent.
policy_id (str): Id of the current policy for the agent.
policies (dict): Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
postprocessed_batch (SampleBatch): The postprocessed sample batch
for this agent. You can mutate this object to apply your own
trajectory postprocessing.
original_batches (dict): Mapping of agents to their unpostprocessed
trajectory data. You should not mutate this object.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_postprocess_traj"):
self.legacy_callbacks["on_postprocess_traj"]({
"episode": episode,
"agent_id": agent_id,
"pre_batch": original_batches[agent_id],
"post_batch": postprocessed_batch,
"all_pre_batches": original_batches,
})
def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch,
**kwargs):
"""Called at the end RolloutWorker.sample().
Args:
worker (RolloutWorker): Reference to the current rollout worker.
samples (SampleBatch): Batch to be returned. You can mutate this
object to modify the samples generated.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_sample_end"):
self.legacy_callbacks["on_sample_end"]({
"worker": worker,
"samples": samples,
})
def on_train_result(self, trainer, result: dict, **kwargs):
"""Called at the end of Trainable.train().
Args:
trainer (Trainer): Current trainer instance.
result (dict): Dict of results returned from trainer.train() call.
You can mutate this object to add additional metrics.
kwargs: Forward compatibility placeholder.
"""
if self.legacy_callbacks.get("on_train_result"):
self.legacy_callbacks["on_train_result"]({
"trainer": trainer,
"result": result,
})
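For illustration only (not part of this file): an old-style callbacks dict can still be routed through the new class by passing it as `legacy_callbacks_dict`; the deprecation warning above fires at construction time and each hook forwards the old info-dict shape. A rough sketch:

from ray.rllib.agents.callbacks import DefaultCallbacks

def old_on_episode_end(info):
    # Old-style callback: takes a single info dict.
    info["episode"].custom_metrics["finished"] = 1.0

callbacks = DefaultCallbacks(
    legacy_callbacks_dict={"on_episode_end": old_on_episode_end})
# callbacks.on_episode_end(worker=..., base_env=..., policies=..., episode=...)
# would forward {"env": ..., "policy": ..., "episode": ...} to the old function.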

rllib/agents/trainer.py

@@ -9,6 +9,7 @@ import tempfile
import ray
from ray.exceptions import RayError
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.evaluation.metrics import collect_metrics
@@ -126,24 +127,10 @@ COMMON_CONFIG = {
# `rllib train` command, you can also use the `-v` and `-vv` flags as
# shorthand for INFO and DEBUG.
"log_level": "WARN",
# Callbacks that will be run during various phases of training. These all
# take a single "info" dict as an argument. For episode callbacks, custom
# metrics can be attached to the episode by updating the episode object's
# custom metrics dict (see examples/custom_metrics_and_callbacks.py). You
# may also mutate the passed in batch data in your callback.
"callbacks": {
"on_episode_start": None, # arg: {"env": .., "episode": ...}
"on_episode_step": None, # arg: {"env": .., "episode": ...}
"on_episode_end": None, # arg: {"env": .., "episode": ...}
"on_sample_end": None, # arg: {"samples": .., "worker": ...}
"on_train_result": None, # arg: {"trainer": ..., "result": ...}
"on_postprocess_traj": None, # arg: {
# "agent_id": ..., "episode": ...,
# "pre_batch": (before processing),
# "post_batch": (after processing),
# "all_pre_batches": (other agent ids),
# }
},
# Callbacks that will be run during various phases of training. See the
# `DefaultCallbacks` class and `examples/custom_metrics_and_callbacks.py`
# for more usage information.
"callbacks": DefaultCallbacks,
# Whether to attempt to continue training if a worker crashes. The number
# of currently healthy workers is reported as the "num_healthy_workers"
# metric.
@@ -542,11 +529,7 @@ class Trainer(Trainable):
@override(Trainable)
def _log_result(self, result):
if self.config["callbacks"].get("on_train_result"):
self.config["callbacks"]["on_train_result"]({
"trainer": self,
"result": result,
})
self.callbacks.on_train_result(trainer=self, result=result)
# log after the callback is invoked, so that the user has a chance
# to mutate the result
Trainable._log_result(self, result)
@@ -584,6 +567,12 @@ class Trainer(Trainable):
self.env_creator = lambda env_config: normalize(inner(env_config))
Trainer._validate_config(self.config)
if not callable(self.config["callbacks"]):
raise ValueError(
"`callbacks` must be a callable method that "
"returns a subclass of DefaultCallbacks, got {}".format(
self.config["callbacks"]))
self.callbacks = self.config["callbacks"]()
log_level = self.config.get("log_level")
if log_level in ["WARN", "ERROR"]:
logger.info("Current log_level is {}. For more information, "
@@ -918,6 +907,15 @@ class Trainer(Trainable):
"sample_batch_size", new="rollout_fragment_length")
config2["rollout_fragment_length"] = config2["sample_batch_size"]
del config2["sample_batch_size"]
if "callbacks" in config2 and type(config2["callbacks"]) is dict:
legacy_callbacks_dict = config2["callbacks"]
def make_callbacks():
# Deprecation warning will be logged by DefaultCallbacks.
return DefaultCallbacks(
legacy_callbacks_dict=legacy_callbacks_dict)
config2["callbacks"] = make_callbacks
return deep_update(config1, config2, cls._allow_unknown_configs,
cls._allow_unknown_subkeys,
cls._override_all_subkeys_if_type_changes)
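To make the new `callable` check and legacy-dict conversion above concrete, a sketch (names illustrative, not part of the diff):

from ray.rllib.agents.callbacks import DefaultCallbacks

class MyCallbacks(DefaultCallbacks):
    pass

good_config = {"callbacks": MyCallbacks}          # a class is callable -> accepted
also_good = {"callbacks": lambda: MyCallbacks()}  # zero-arg factory -> accepted
# {"callbacks": MyCallbacks()}  -> an instance is not callable, raises ValueError
# {"callbacks": {"on_train_result": fn}}  -> legacy dict, auto-wrapped into a
#     factory by the config merge step above (DefaultCallbacks then logs a
#     deprecation warning when instantiated).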

rllib/contrib/alpha_zero/core/alpha_zero_trainer.py

@@ -1,6 +1,7 @@
import logging
from ray.rllib.agents import with_common_config
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.model import restore_original_dimensions
@@ -21,12 +22,18 @@ torch, nn = try_import_torch()
logger = logging.getLogger(__name__)
def on_episode_start(info):
# save env state when an episode starts
env = info["env"].get_unwrapped()[0]
state = env.get_state()
episode = info["episode"]
episode.user_data["initial_state"] = state
class AlphaZeroDefaultCallbacks(DefaultCallbacks):
"""AlphaZero callbacks.
If you use custom callbacks, you must extend this class and call super()
for on_episode_start.
"""
def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
# save env state when an episode starts
env = base_env.get_unwrapped()[0]
state = env.get_state()
episode.user_data["initial_state"] = state
# yapf: disable
@@ -94,9 +101,7 @@ DEFAULT_CONFIG = with_common_config({
},
# === Callbacks ===
"callbacks": {
"on_episode_start": on_episode_start,
},
"callbacks": AlphaZeroDefaultCallbacks,
"use_pytorch": True,
})
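A hedged sketch of the pattern the docstring above requires when combining AlphaZero with user callbacks (the import path and the class name MyAZCallbacks are assumptions):

from ray.rllib.contrib.alpha_zero.core.alpha_zero_trainer import \
    AlphaZeroDefaultCallbacks

class MyAZCallbacks(AlphaZeroDefaultCallbacks):
    def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
        # Call super() so the env-state snapshot AlphaZero's MCTS needs is
        # still stored in episode.user_data["initial_state"].
        super().on_episode_start(worker, base_env, policies, episode, **kwargs)
        episode.user_data["my_stats"] = []

# Then pass it via the trainer config: {"callbacks": MyAZCallbacks, ...}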

rllib/evaluation/rollout_worker.py

@@ -214,7 +214,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
directory if specified.
log_dir (str): Directory where logs can be placed.
log_level (str): Set the root log level on creation.
callbacks (dict): Dict of custom debug callbacks.
callbacks (DefaultCallbacks): Custom training callbacks.
input_creator (func): Function that returns an InputReader object
for loading previous generated experiences.
input_evaluation (list): How to evaluate the policy performance.
@@ -279,7 +279,11 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
env_context = EnvContext(env_config or {}, worker_index)
self.policy_config = policy_config
self.callbacks = callbacks or {}
if callbacks:
self.callbacks = callbacks()
else:
from ray.rllib.agents.callbacks import DefaultCallbacks
self.callbacks = DefaultCallbacks()
self.worker_index = worker_index
self.num_workers = num_workers
model_config = model_config or {}
@@ -444,6 +448,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
if sample_async:
self.sampler = AsyncSampler(
self,
self.async_env,
self.policy_map,
policy_mapping_fn,
@@ -462,6 +467,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
self.sampler.start()
else:
self.sampler = SyncSampler(
self,
self.async_env,
self.policy_map,
policy_mapping_fn,
@@ -518,8 +524,7 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
batches.append(batch)
batch = batches[0].concat_samples(batches)
if self.callbacks.get("on_sample_end"):
self.callbacks["on_sample_end"]({"worker": self, "samples": batch})
self.callbacks.on_sample_end(worker=self, samples=batch)
# Always do writes prior to compression for consistency and to allow
# for better compression inside the writer.
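As a sketch of what a user-defined callbacks class can now do at this hook (illustrative, not part of the diff):

from ray.rllib.agents.callbacks import DefaultCallbacks

class SampleSizeLogger(DefaultCallbacks):
    def on_sample_end(self, worker, samples, **kwargs):
        # `samples` is the batch about to be returned from
        # RolloutWorker.sample(); it may be mutated in place.
        print("worker {} returning {} sampled steps".format(
            worker.worker_index, samples.count))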

rllib/evaluation/sample_batch_builder.py

@@ -72,13 +72,13 @@ class MultiAgentSampleBatchBuilder:
corresponding policy batch for the agent's policy.
"""
def __init__(self, policy_map, clip_rewards, postp_callback):
def __init__(self, policy_map, clip_rewards, callbacks):
"""Initialize a MultiAgentSampleBatchBuilder.
Arguments:
policy_map (dict): Maps policy ids to policy instances.
clip_rewards (bool): Whether to clip rewards before postprocessing.
postp_callback: function to call on each postprocessed batch.
callbacks (DefaultCallbacks): RLlib callbacks.
"""
self.policy_map = policy_map
@@ -89,7 +89,7 @@
}
self.agent_builders = {}
self.agent_to_policy = {}
self.postp_callback = postp_callback
self.callbacks = callbacks
self.count = 0 # increment this manually
def total(self):
@@ -161,15 +161,16 @@
format(summarize(post_batches)))
# Append into policy batches and reset
from ray.rllib.evaluation.rollout_worker import get_global_worker
for agent_id, post_batch in sorted(post_batches.items()):
if self.postp_callback:
self.postp_callback({
"episode": episode,
"agent_id": agent_id,
"pre_batch": pre_batches[agent_id],
"post_batch": post_batch,
"all_pre_batches": pre_batches,
})
self.callbacks.on_postprocess_trajectory(
worker=get_global_worker(),
episode=episode,
agent_id=agent_id,
policy_id=self.agent_to_policy[agent_id],
policies=self.policy_map,
postprocessed_batch=post_batch,
original_batches=pre_batches)
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
post_batch)
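For context, a sketch of what the richer on_postprocess_trajectory signature enables in multi-agent setups (class and metric names are illustrative):

from ray.rllib.agents.callbacks import DefaultCallbacks

class OpponentAwareCallbacks(DefaultCallbacks):
    def on_postprocess_trajectory(self, worker, episode, agent_id, policy_id,
                                  policies, postprocessed_batch,
                                  original_batches, **kwargs):
        # original_batches holds every agent's unpostprocessed data;
        # only postprocessed_batch should be mutated.
        other_agents = [a for a in original_batches if a != agent_id]
        episode.custom_metrics["num_other_agents"] = len(other_agents)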

rllib/evaluation/sampler.py

@@ -60,6 +60,7 @@ class SamplerInput(InputReader):
class SyncSampler(SamplerInput):
def __init__(self,
worker,
env,
policies,
policy_mapping_fn,
@@ -84,7 +85,7 @@ class SyncSampler(SamplerInput):
self.extra_batches = queue.Queue()
self.perf_stats = PerfStats()
self.rollout_provider = _env_runner(
self.base_env, self.extra_batches.put, self.policies,
worker, self.base_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
pack, callbacks, tf_sess, self.perf_stats, soft_horizon,
@@ -121,6 +122,7 @@ class SyncSampler(SamplerInput):
class AsyncSampler(threading.Thread, SamplerInput):
def __init__(self,
worker,
env,
policies,
policy_mapping_fn,
@@ -139,6 +141,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
self.worker = worker
self.base_env = BaseEnv.to_base_env(env)
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
@@ -178,7 +181,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
extra_batches_putter = (
lambda x: self.extra_batches.put(x, timeout=600.0))
rollout_provider = _env_runner(
self.base_env, extra_batches_putter, self.policies,
self.worker, self.base_env, extra_batches_putter, self.policies,
self.policy_mapping_fn, self.rollout_fragment_length, self.horizon,
self.preprocessors, self.obs_filters, self.clip_rewards,
self.clip_actions, self.pack, self.callbacks, self.tf_sess,
@@ -224,13 +227,14 @@ class AsyncSampler(threading.Thread, SamplerInput):
return extra
def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
rollout_fragment_length, horizon, preprocessors, obs_filters,
clip_rewards, clip_actions, pack, callbacks, tf_sess,
perf_stats, soft_horizon, no_done_at_end):
def _env_runner(worker, base_env, extra_batch_callback, policies,
policy_mapping_fn, rollout_fragment_length, horizon,
preprocessors, obs_filters, clip_rewards, clip_actions, pack,
callbacks, tf_sess, perf_stats, soft_horizon, no_done_at_end):
"""This implements the common experience collection logic.
Args:
worker (RolloutWorker): reference to the current rollout worker.
base_env (BaseEnv): env implementing BaseEnv.
extra_batch_callback (fn): function to send extra batch data to.
policies (dict): Map of policy ids to Policy instances.
@@ -250,7 +254,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
guarantees batches will be exactly `rollout_fragment_length` in
size.
clip_actions (bool): Whether to clip actions to the space range.
callbacks (dict): User callbacks to run on episode events.
callbacks (DefaultCallbacks): User callbacks to run on episode events.
tf_sess (Session|None): Optional tensorflow session to use for batching
TF policy evaluations.
perf_stats (PerfStats): Record perf stats into this object.
@@ -300,8 +304,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
if batch_builder_pool:
return batch_builder_pool.pop()
else:
return MultiAgentSampleBatchBuilder(
policies, clip_rewards, callbacks.get("on_postprocess_traj"))
return MultiAgentSampleBatchBuilder(policies, clip_rewards,
callbacks)
def new_episode():
episode = MultiAgentEpisode(policies, policy_mapping_fn,
@@ -313,13 +317,11 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
environment=base_env,
episode=episode,
tf_sess=getattr(p, "_sess", None))
# Call custom on_episode_start callback.
if callbacks.get("on_episode_start"):
callbacks["on_episode_start"]({
"env": base_env,
"policy": policies,
"episode": episode,
})
callbacks.on_episode_start(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode)
return episode
active_episodes = defaultdict(new_episode)
@@ -340,7 +342,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
# Process observations and prepare for policy evaluation
t1 = time.time()
active_envs, to_eval, outputs = _process_observations(
base_env, policies, batch_builder_pool, active_episodes,
worker, base_env, policies, batch_builder_pool, active_episodes,
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
preprocessors, obs_filters, rollout_fragment_length, pack,
callbacks, soft_horizon, no_done_at_end)
@@ -368,7 +370,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
perf_stats.env_wait_time += time.time() - t4
def _process_observations(base_env, policies, batch_builder_pool,
def _process_observations(worker, base_env, policies, batch_builder_pool,
active_episodes, unfiltered_obs, rewards, dones,
infos, off_policy_actions, horizon, preprocessors,
obs_filters, rollout_fragment_length, pack,
@@ -482,8 +484,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
**episode.last_pi_info_for(agent_id))
# Invoke the step callback after the step is logged to the episode
if callbacks.get("on_episode_step"):
callbacks["on_episode_step"]({"env": base_env, "episode": episode})
callbacks.on_episode_step(
worker=worker, base_env=base_env, episode=episode)
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
@@ -508,12 +510,11 @@ def _process_observations(base_env, policies, batch_builder_pool,
episode=episode,
tf_sess=getattr(p, "_sess", None))
# Call custom on_episode_end callback.
if callbacks.get("on_episode_end"):
callbacks["on_episode_end"]({
"env": base_env,
"policy": policies,
"episode": episode
})
callbacks.on_episode_end(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode)
if hit_horizon and soft_horizon:
episode.soft_reset()
resetted_obs = agent_obs

rllib/examples/custom_metrics_and_callbacks.py

@@ -4,55 +4,62 @@ Here we use callbacks to track the average CartPole pole angle magnitude as a
custom metric.
"""
from typing import Dict
import argparse
import numpy as np
import ray
from ray import tune
from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.agents.callbacks import DefaultCallbacks
def on_episode_start(info):
episode = info["episode"]
print("episode {} started".format(episode.episode_id))
episode.user_data["pole_angles"] = []
episode.hist_data["pole_angles"] = []
class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy],
episode: MultiAgentEpisode, **kwargs):
print("episode {} started".format(episode.episode_id))
episode.user_data["pole_angles"] = []
episode.hist_data["pole_angles"] = []
def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv,
episode: MultiAgentEpisode, **kwargs):
pole_angle = abs(episode.last_observation_for()[2])
raw_angle = abs(episode.last_raw_obs_for()[2])
assert pole_angle == raw_angle
episode.user_data["pole_angles"].append(pole_angle)
def on_episode_step(info):
episode = info["episode"]
pole_angle = abs(episode.last_observation_for()[2])
raw_angle = abs(episode.last_raw_obs_for()[2])
assert pole_angle == raw_angle
episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy], episode: MultiAgentEpisode,
**kwargs):
pole_angle = np.mean(episode.user_data["pole_angles"])
print("episode {} ended with length {} and pole angles {}".format(
episode.episode_id, episode.length, pole_angle))
episode.custom_metrics["pole_angle"] = pole_angle
episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch,
**kwargs):
print("returned sample batch of size {}".format(samples.count))
def on_episode_end(info):
episode = info["episode"]
pole_angle = np.mean(episode.user_data["pole_angles"])
print("episode {} ended with length {} and pole angles {}".format(
episode.episode_id, episode.length, pole_angle))
episode.custom_metrics["pole_angle"] = pole_angle
episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
def on_train_result(self, trainer, result: dict, **kwargs):
print("trainer.train() result: {} -> {} episodes".format(
trainer, result["episodes_this_iter"]))
# you can mutate the result dict to add new fields to return
result["callback_ok"] = True
def on_sample_end(info):
print("returned sample batch of size {}".format(info["samples"].count))
def on_train_result(info):
print("trainer.train() result: {} -> {} episodes".format(
info["trainer"], info["result"]["episodes_this_iter"]))
# you can mutate the result dict to add new fields to return
info["result"]["callback_ok"] = True
def on_postprocess_traj(info):
episode = info["episode"]
batch = info["post_batch"]
print("postprocessed {} steps".format(batch.count))
if "num_batches" not in episode.custom_metrics:
episode.custom_metrics["num_batches"] = 0
episode.custom_metrics["num_batches"] += 1
def on_postprocess_trajectory(
self, worker: RolloutWorker, episode: MultiAgentEpisode,
agent_id: str, policy_id: str, policies: Dict[str, Policy],
postprocessed_batch: SampleBatch,
original_batches: Dict[str, SampleBatch], **kwargs):
print("postprocessed {} steps".format(postprocessed_batch.count))
if "num_batches" not in episode.custom_metrics:
episode.custom_metrics["num_batches"] = 0
episode.custom_metrics["num_batches"] += 1
if __name__ == "__main__":
@@ -68,14 +75,7 @@ if __name__ == "__main__":
},
config={
"env": "CartPole-v0",
"callbacks": {
"on_episode_start": on_episode_start,
"on_episode_step": on_episode_step,
"on_episode_end": on_episode_end,
"on_sample_end": on_sample_end,
"on_train_result": on_train_result,
"on_postprocess_traj": on_postprocess_traj,
},
"callbacks": MyCallbacks,
},
return_trials=True)
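For reference (assuming the run result is bound to `trials` as in the legacy script below), the aggregated custom metrics can then be checked on the returned trials:

custom_metrics = trials[0].last_result["custom_metrics"]
assert "pole_angle_mean" in custom_metrics
assert "num_batches_mean" in custom_metrics
assert "callback_ok" in trials[0].last_result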

rllib/examples/custom_metrics_and_callbacks_legacy.py Normal file

@@ -0,0 +1,85 @@
"""Deprecated API; see custom_metrics_and_callbacks.py instead."""
import argparse
import numpy as np
import ray
from ray import tune
def on_episode_start(info):
episode = info["episode"]
print("episode {} started".format(episode.episode_id))
episode.user_data["pole_angles"] = []
episode.hist_data["pole_angles"] = []
def on_episode_step(info):
episode = info["episode"]
pole_angle = abs(episode.last_observation_for()[2])
raw_angle = abs(episode.last_raw_obs_for()[2])
assert pole_angle == raw_angle
episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(info):
episode = info["episode"]
pole_angle = np.mean(episode.user_data["pole_angles"])
print("episode {} ended with length {} and pole angles {}".format(
episode.episode_id, episode.length, pole_angle))
episode.custom_metrics["pole_angle"] = pole_angle
episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
def on_sample_end(info):
print("returned sample batch of size {}".format(info["samples"].count))
def on_train_result(info):
print("trainer.train() result: {} -> {} episodes".format(
info["trainer"], info["result"]["episodes_this_iter"]))
# you can mutate the result dict to add new fields to return
info["result"]["callback_ok"] = True
def on_postprocess_traj(info):
episode = info["episode"]
batch = info["post_batch"]
print("postprocessed {} steps".format(batch.count))
if "num_batches" not in episode.custom_metrics:
episode.custom_metrics["num_batches"] = 0
episode.custom_metrics["num_batches"] += 1
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-iters", type=int, default=2000)
args = parser.parse_args()
ray.init()
trials = tune.run(
"PG",
stop={
"training_iteration": args.num_iters,
},
config={
"env": "CartPole-v0",
"callbacks": {
"on_episode_start": on_episode_start,
"on_episode_step": on_episode_step,
"on_episode_end": on_episode_end,
"on_sample_end": on_sample_end,
"on_train_result": on_train_result,
"on_postprocess_traj": on_postprocess_traj,
},
},
return_trials=True)
# verify custom metrics for integration tests
custom_metrics = trials[0].last_result["custom_metrics"]
print(custom_metrics)
assert "pole_angle_mean" in custom_metrics
assert "pole_angle_min" in custom_metrics
assert "pole_angle_max" in custom_metrics
assert "num_batches_mean" in custom_metrics
assert "callback_ok" in trials[0].last_result