[RLlib] Fixing Memory Leak In Multi-Agent environments. Adding tooling for finding memory leaks in workers. (#15815)
This commit is contained in:
parent d2c755ccef
commit 0be83d9a95

3 changed files with 183 additions and 5 deletions
@@ -1,6 +1,11 @@
+from ray.rllib.agents.callbacks import DefaultCallbacks, \
+    MemoryTrackingCallbacks, MultiCallbacks
 from ray.rllib.agents.trainer import Trainer, with_common_config
 
 __all__ = [
+    "DefaultCallbacks",
+    "MemoryTrackingCallbacks",
+    "MultiCallbacks",
     "Trainer",
     "with_common_config",
 ]
@@ -1,3 +1,6 @@
+import os
+import psutil
+import tracemalloc
 from typing import Dict, Optional, TYPE_CHECKING
 
 from ray.rllib.env import BaseEnv
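The three new imports exist for the MemoryTrackingCallbacks class added below: tracemalloc records Python allocations per source line, and psutil (together with os.getpid) reports process-level memory of the rollout worker. A standalone sketch of the two tools, independent of RLlib (the printed numbers are machine-dependent):

# Sketch only, not RLlib code: the building blocks the new callback combines.
import os
import tracemalloc

import psutil

tracemalloc.start(10)  # keep up to 10 frames of traceback per allocation
allocations = [list(range(1000)) for _ in range(100)]

snapshot = tracemalloc.take_snapshot()
for stat in snapshot.statistics("lineno")[:3]:
    print(stat)  # e.g. "<file>:<line>: size=..., count=..., average=..."

process = psutil.Process(os.getpid())
print(process.memory_info().rss)  # resident set size of this process (bytes)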
@@ -206,3 +209,171 @@ class DefaultCallbacks:
                 "trainer": trainer,
                 "result": result,
             })
+
+
+class MemoryTrackingCallbacks(DefaultCallbacks):
+    """
+    MemoryTrackingCallbacks can be used to trace and track memory usage
+    in rollout workers.
+
+    The MemoryTrackingCallbacks class uses tracemalloc and psutil to track
+    Python allocations during rollouts, in training as well as evaluation.
+
+    The tracking data is logged to the custom_metrics of an episode and
+    can therefore be viewed in TensorBoard (or in WandB, etc.).
+
+    Warning: This class is meant for debugging and should not be used
+    in production code, as tracemalloc incurs a significant slowdown in
+    execution speed.
+
+    Add the MemoryTrackingCallbacks callback to the tune config, e.g.
+    {..., "callbacks": MemoryTrackingCallbacks, ...}
+    """
+
+    def __init__(self):
+        super().__init__()
+
+        # Will track the top 10 lines where memory is allocated.
+        tracemalloc.start(10)
+
+    def on_episode_end(self,
+                       *,
+                       worker: "RolloutWorker",
+                       base_env: BaseEnv,
+                       policies: Dict[PolicyID, Policy],
+                       episode: MultiAgentEpisode,
+                       env_index: Optional[int] = None,
+                       **kwargs) -> None:
+        snapshot = tracemalloc.take_snapshot()
+        top_stats = snapshot.statistics("lineno")
+
+        for stat in top_stats[:10]:
+            count = stat.count
+            size = stat.size
+
+            trace = str(stat.traceback)
+
+            episode.custom_metrics[f"tracemalloc/{trace}/size"] = size
+            episode.custom_metrics[f"tracemalloc/{trace}/count"] = count
+
+        process = psutil.Process(os.getpid())
+        worker_rss = process.memory_info().rss
+        worker_data = process.memory_info().data
+        worker_vms = process.memory_info().vms
+        episode.custom_metrics["tracemalloc/worker/rss"] = worker_rss
+        episode.custom_metrics["tracemalloc/worker/data"] = worker_data
+        episode.custom_metrics["tracemalloc/worker/vms"] = worker_vms
+
+
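A concrete version of the usage hint in the docstring above: pass the class under the "callbacks" config key, and the tracemalloc/psutil metrics then appear under custom_metrics in the training results (and hence in TensorBoard). In the sketch below, the PPO algorithm, CartPole-v0 environment, and stop criterion are illustrative assumptions, not part of this commit:

# Sketch only: enable MemoryTrackingCallbacks for a short debugging run.
from ray import tune
from ray.rllib.agents.callbacks import MemoryTrackingCallbacks

tune.run(
    "PPO",
    config={
        "env": "CartPole-v0",
        "num_workers": 1,
        # Pass the class itself; each rollout worker instantiates it,
        # which starts tracemalloc, so expect a noticeable slowdown.
        "callbacks": MemoryTrackingCallbacks,
    },
    stop={"training_iteration": 5},
)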
+class MultiCallbacks(DefaultCallbacks):
+    """
+    MultiCallbacks allows multiple callbacks to be registered at the same
+    time in the config of the environment.
+
+    For example:
+
+    'callbacks': MultiCallbacks([
+        MyCustomStatsCallbacks,
+        MyCustomVideoCallbacks,
+        MyCustomTraceCallbacks,
+        ...
+    ])
+    """
+
+    def __init__(self, callback_class_list):
+        super().__init__()
+        self._callback_class_list = callback_class_list
+
+        self._callback_list = []
+
+    def __call__(self, *args, **kwargs):
+        self._callback_list = [
+            callback_class() for callback_class in self._callback_class_list
+        ]
+
+        return self
+
+    def on_episode_start(self,
+                         *,
+                         worker: "RolloutWorker",
+                         base_env: BaseEnv,
+                         policies: Dict[PolicyID, Policy],
+                         episode: MultiAgentEpisode,
+                         env_index: Optional[int] = None,
+                         **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_episode_start(
+                worker=worker,
+                base_env=base_env,
+                policies=policies,
+                episode=episode,
+                env_index=env_index,
+                **kwargs)
+
+    def on_episode_step(self,
+                        *,
+                        worker: "RolloutWorker",
+                        base_env: BaseEnv,
+                        episode: MultiAgentEpisode,
+                        env_index: Optional[int] = None,
+                        **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_episode_step(
+                worker=worker,
+                base_env=base_env,
+                episode=episode,
+                env_index=env_index,
+                **kwargs)
+
+    def on_episode_end(self,
+                       *,
+                       worker: "RolloutWorker",
+                       base_env: BaseEnv,
+                       policies: Dict[PolicyID, Policy],
+                       episode: MultiAgentEpisode,
+                       env_index: Optional[int] = None,
+                       **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_episode_end(
+                worker=worker,
+                base_env=base_env,
+                policies=policies,
+                episode=episode,
+                env_index=env_index,
+                **kwargs)
+
+    def on_postprocess_trajectory(
+            self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
+            agent_id: AgentID, policy_id: PolicyID,
+            policies: Dict[PolicyID, Policy],
+            postprocessed_batch: SampleBatch,
+            original_batches: Dict[AgentID, SampleBatch], **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_postprocess_trajectory(
+                worker=worker,
+                episode=episode,
+                agent_id=agent_id,
+                policy_id=policy_id,
+                policies=policies,
+                postprocessed_batch=postprocessed_batch,
+                original_batches=original_batches,
+                **kwargs)
+
+    def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch,
+                      **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_sample_end(worker=worker, samples=samples, **kwargs)
+
+    def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
+                          result: dict, **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_learn_on_batch(
+                policy=policy,
+                train_batch=train_batch,
+                result=result,
+                **kwargs)
+
+    def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
+        for callback in self._callback_list:
+            callback.on_train_result(trainer=trainer, result=result, **kwargs)
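To spell out the pattern in the docstring: MultiCallbacks is constructed with a list of callback classes and passed where a single callbacks class would normally go. Because its __call__ instantiates each listed class and then returns the instance itself, RLlib's usual "instantiate the configured callbacks class" step works unchanged, and every hook fans out to all registered callbacks. A minimal sketch, where MyStatsCallbacks is a hypothetical user-defined callback and the environment name is an illustrative assumption:

# Sketch only: combine a hypothetical user callback with the new
# MemoryTrackingCallbacks via MultiCallbacks.
from ray.rllib.agents.callbacks import (DefaultCallbacks,
                                        MemoryTrackingCallbacks,
                                        MultiCallbacks)


class MyStatsCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode,
                       env_index=None, **kwargs):
        # Log a trivial per-episode metric next to the memory metrics.
        episode.custom_metrics["episode_len"] = episode.length


config = {
    "env": "CartPole-v0",  # illustrative assumption
    # Note: an instance, not a class. Its __call__ returns itself, so it
    # slots into the config exactly like a callbacks class would.
    "callbacks": MultiCallbacks([MyStatsCallbacks, MemoryTrackingCallbacks]),
}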
@@ -644,7 +644,7 @@ class SimpleListCollector(SampleCollector):
             if is_done and check_dones and \
                     not pre_batch[SampleBatch.DONES][-1]:
                 raise ValueError(
-                    "Episode {} terminated for all agents, but we still don't "
+                    "Episode {} terminated for all agents, but we still "
                     "don't have a last observation for agent {} (policy "
                     "{}). ".format(
                         episode_id, agent_id, self.agent_key_to_policy_id[(
@@ -653,9 +653,6 @@ class SimpleListCollector(SampleCollector):
                     "of all live agents when setting done[__all__] to "
                     "True. Alternatively, set no_done_at_end=True to "
                     "allow this.")
-            # If (only this?) agent is done, erase its buffer entirely.
-            if pre_batch[SampleBatch.DONES][-1]:
-                del self.agent_collectors[(episode_id, agent_id)]
 
             other_batches = pre_batches.copy()
             del other_batches[agent_id]
@@ -683,7 +680,8 @@ class SimpleListCollector(SampleCollector):
         # Append into policy batches and reset.
         from ray.rllib.evaluation.rollout_worker import get_global_worker
         for agent_id, post_batch in sorted(post_batches.items()):
-            pid = self.agent_key_to_policy_id[(episode_id, agent_id)]
+            agent_key = (episode_id, agent_id)
+            pid = self.agent_key_to_policy_id[agent_key]
             policy = self.policy_map[pid]
             self.callbacks.on_postprocess_trajectory(
                 worker=get_global_worker(),
@@ -699,6 +697,10 @@ class SimpleListCollector(SampleCollector):
                 pid].add_postprocessed_batch_for_training(
                     post_batch, policy.view_requirements)
 
+            if is_done:
+                del self.agent_key_to_policy_id[agent_key]
+                del self.agent_collectors[agent_key]
+
         if policy_collector_group:
             env_steps = self.episode_steps[episode_id]
             policy_collector_group.env_steps += env_steps
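Taken together, the collector hunks above are the actual leak fix: instead of eagerly deleting an agent's collector inside the validation loop, the (episode_id, agent_id) entries in both agent_key_to_policy_id and agent_collectors are removed once the agent's postprocessed batch has been handed off, so these per-agent dicts no longer grow without bound across episodes in multi-agent runs. The pattern, reduced to a hypothetical standalone sketch (not RLlib code):

# Per-(episode, agent) state must be dropped once the agent is done,
# otherwise the dict keeps one entry per agent per episode forever.
buffers = {}

def on_step(episode_id, agent_id, obs, done):
    key = (episode_id, agent_id)
    buffers.setdefault(key, []).append(obs)
    if done:
        batch = buffers[key]
        # ... hand `batch` off for training ...
        del buffers[key]  # the cleanup this commit adds, in spirit
        return batch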