diff --git a/rllib/agents/__init__.py b/rllib/agents/__init__.py
index cffe01243..ab5bc784a 100644
--- a/rllib/agents/__init__.py
+++ b/rllib/agents/__init__.py
@@ -1,6 +1,11 @@
+from ray.rllib.agents.callbacks import DefaultCallbacks, \
+    MemoryTrackingCallbacks, MultiCallbacks
 from ray.rllib.agents.trainer import Trainer, with_common_config
 
 __all__ = [
+    "DefaultCallbacks",
+    "MemoryTrackingCallbacks",
+    "MultiCallbacks",
     "Trainer",
     "with_common_config",
 ]
diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py
index 1972fabec..4f6fdda1e 100644
--- a/rllib/agents/callbacks.py
+++ b/rllib/agents/callbacks.py
@@ -1,3 +1,6 @@
+import os
+import psutil
+import tracemalloc
 from typing import Dict, Optional, TYPE_CHECKING
 
 from ray.rllib.env import BaseEnv
@@ -206,3 +209,171 @@ class DefaultCallbacks:
             "trainer": trainer,
             "result": result,
         })
+
+
+class MemoryTrackingCallbacks(DefaultCallbacks):
+    """
+    MemoryTrackingCallbacks can be used to trace and track memory usage
+    in rollout workers.
+
+    MemoryTrackingCallbacks uses tracemalloc and psutil to track Python
+    allocations during rollouts, in training
+    as well as in evaluation.
+
+    The tracking data is logged to an episode's custom_metrics and can
+    therefore be viewed in TensorBoard
+    (or in WandB, etc.).
+
+    Warning: This class is meant for debugging and should not be used
+    in production code, as tracemalloc incurs
+    a significant slowdown in execution speed.
+
+    Add MemoryTrackingCallbacks to the Tune/Trainer config,
+    e.g. {..., 'callbacks': MemoryTrackingCallbacks, ...}
+    """
+
+    def __init__(self):
+        super().__init__()
+
+        # Will track the top 10 lines where memory is allocated
+        tracemalloc.start(10)
+
+    def on_episode_end(self,
+                       *,
+                       worker: "RolloutWorker",
+                       base_env: BaseEnv,
+                       policies: Dict[PolicyID, Policy],
+                       episode: MultiAgentEpisode,
+                       env_index: Optional[int] = None,
+                       **kwargs) -> None:
+        snapshot = tracemalloc.take_snapshot()
+        top_stats = snapshot.statistics("lineno")
+
+        for stat in top_stats[:10]:
+            count = stat.count
+            size = stat.size
+
+            trace = str(stat.traceback)
+
+            episode.custom_metrics[f"tracemalloc/{trace}/size"] = size
+            episode.custom_metrics[f"tracemalloc/{trace}/count"] = count
+
+        process = psutil.Process(os.getpid())
+        worker_rss = process.memory_info().rss
+        worker_data = process.memory_info().data
+        worker_vms = process.memory_info().vms
+        episode.custom_metrics["tracemalloc/worker/rss"] = worker_rss
+        episode.custom_metrics["tracemalloc/worker/data"] = worker_data
+        episode.custom_metrics["tracemalloc/worker/vms"] = worker_vms
+
+
+class MultiCallbacks(DefaultCallbacks):
+    """
+    MultiCallbacks allows multiple callback classes to be registered at
+    the same time in the callbacks config.
+
+    For example:
+
+        'callbacks': MultiCallbacks([
+            MyCustomStatsCallbacks,
+            MyCustomVideoCallbacks,
+            MyCustomTraceCallbacks,
+            ....
+        ])
+    """
+
+    def __init__(self, callback_class_list):
+        super().__init__()
+        self._callback_class_list = callback_class_list
+
+        self._callback_list = []
+
+    def __call__(self, *args, **kwargs):
+        self._callback_list = [
+            callback_class() for callback_class in self._callback_class_list
+        ]
+
+        return self
+
+    def on_episode_start(self,
+                         *,
+                         worker: "RolloutWorker",
+                         base_env: BaseEnv,
+                         policies: Dict[PolicyID, Policy],
+                         episode: MultiAgentEpisode,
+                         env_index: Optional[int] = None,
+                         **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_episode_start(
+                worker=worker,
+                base_env=base_env,
+                policies=policies,
+                episode=episode,
+                env_index=env_index,
+                **kwargs)
+
+    def on_episode_step(self,
+                        *,
+                        worker: "RolloutWorker",
+                        base_env: BaseEnv,
+                        episode: MultiAgentEpisode,
+                        env_index: Optional[int] = None,
+                        **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_episode_step(
+                worker=worker,
+                base_env=base_env,
+                episode=episode,
+                env_index=env_index,
+                **kwargs)
+
+    def on_episode_end(self,
+                       *,
+                       worker: "RolloutWorker",
+                       base_env: BaseEnv,
+                       policies: Dict[PolicyID, Policy],
+                       episode: MultiAgentEpisode,
+                       env_index: Optional[int] = None,
+                       **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_episode_end(
+                worker=worker,
+                base_env=base_env,
+                policies=policies,
+                episode=episode,
+                env_index=env_index,
+                **kwargs)
+
+    def on_postprocess_trajectory(
+            self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
+            agent_id: AgentID, policy_id: PolicyID,
+            policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
+            original_batches: Dict[AgentID, SampleBatch], **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_postprocess_trajectory(
+                worker=worker,
+                episode=episode,
+                agent_id=agent_id,
+                policy_id=policy_id,
+                policies=policies,
+                postprocessed_batch=postprocessed_batch,
+                original_batches=original_batches,
+                **kwargs)
+
+    def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch,
+                      **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_sample_end(worker=worker, samples=samples, **kwargs)
+
+    def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
+                          result: dict, **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_learn_on_batch(
+                policy=policy,
+                train_batch=train_batch,
+                result=result,
+                **kwargs)
+
+    def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_train_result(trainer=trainer, result=result, **kwargs)
diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py
index 2846f4c84..4f28d806a 100644
--- a/rllib/evaluation/collectors/simple_list_collector.py
+++ b/rllib/evaluation/collectors/simple_list_collector.py
@@ -644,7 +644,7 @@ class SimpleListCollector(SampleCollector):
             if is_done and check_dones and \
                     not pre_batch[SampleBatch.DONES][-1]:
                 raise ValueError(
-                    "Episode {} terminated for all agents, but we still don't "
+                    "Episode {} terminated for all agents, but we still "
                     "don't have a last observation for agent {} (policy "
                     "{}). ".format(
                         episode_id, agent_id, self.agent_key_to_policy_id[(
@@ -653,9 +653,6 @@ class SimpleListCollector(SampleCollector):
                         "of all live agents when setting done[__all__] to "
                         "True. Alternatively, set no_done_at_end=True to "
                         "allow this.")
-            # If (only this?) agent is done, erase its buffer entirely.
-            if pre_batch[SampleBatch.DONES][-1]:
-                del self.agent_collectors[(episode_id, agent_id)]
 
             other_batches = pre_batches.copy()
             del other_batches[agent_id]
@@ -683,7 +680,8 @@ class SimpleListCollector(SampleCollector):
         # Append into policy batches and reset.
         from ray.rllib.evaluation.rollout_worker import get_global_worker
         for agent_id, post_batch in sorted(post_batches.items()):
-            pid = self.agent_key_to_policy_id[(episode_id, agent_id)]
+            agent_key = (episode_id, agent_id)
+            pid = self.agent_key_to_policy_id[agent_key]
             policy = self.policy_map[pid]
             self.callbacks.on_postprocess_trajectory(
                 worker=get_global_worker(),
@@ -699,6 +697,10 @@ class SimpleListCollector(SampleCollector):
                     pid].add_postprocessed_batch_for_training(
                         post_batch, policy.view_requirements)
 
+            if is_done:
+                del self.agent_key_to_policy_id[agent_key]
+                del self.agent_collectors[agent_key]
+
         if policy_collector_group:
             env_steps = self.episode_steps[episode_id]
             policy_collector_group.env_steps += env_steps
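
For reference, below is a minimal usage sketch of the two callback classes added in this patch. It is not part of the diff itself; the algorithm ("PPO"), environment ("CartPole-v0"), and stop criterion are illustrative assumptions, not mandated by this change.

    # Usage sketch: registering the new callbacks through a Tune experiment.
    import ray
    from ray import tune
    from ray.rllib.agents.callbacks import (DefaultCallbacks,
                                            MemoryTrackingCallbacks,
                                            MultiCallbacks)

    if __name__ == "__main__":
        ray.init()
        tune.run(
            "PPO",
            stop={"training_iteration": 2},
            config={
                "env": "CartPole-v0",
                # Debugging only: tracemalloc slows rollouts down considerably.
                "callbacks": MemoryTrackingCallbacks,
                # To run several callback classes side by side, pass a
                # MultiCallbacks instance instead, e.g.:
                # "callbacks": MultiCallbacks(
                #     [DefaultCallbacks, MemoryTrackingCallbacks]),
            })

Note that MultiCallbacks is passed as an already-constructed instance holding the callback classes; its __call__ method instantiates the wrapped classes and returns itself, so it satisfies the config's expectation of a callable that yields a DefaultCallbacks object.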