[RLlib] Fixing Memory Leak In Multi-Agent environments. Adding tooling for finding memory leaks in workers. (#15815)

Chris Bamford 2021-05-18 12:23:00 +01:00 committed by GitHub
parent d2c755ccef
commit 0be83d9a95
3 changed files with 183 additions and 5 deletions


@@ -1,6 +1,11 @@
from ray.rllib.agents.callbacks import DefaultCallbacks, \
    MemoryTrackingCallbacks, MultiCallbacks
from ray.rllib.agents.trainer import Trainer, with_common_config

__all__ = [
    "DefaultCallbacks",
    "MemoryTrackingCallbacks",
    "MultiCallbacks",
    "Trainer",
    "with_common_config",
]
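
With these exports in place, the new callback utilities should be importable directly from the package root (a quick sanity check, assuming a Ray build that includes this commit):

    from ray.rllib.agents import MemoryTrackingCallbacks, MultiCallbacks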


@@ -1,3 +1,6 @@
import os
import psutil
import tracemalloc
from typing import Dict, Optional, TYPE_CHECKING

from ray.rllib.env import BaseEnv
@@ -206,3 +209,171 @@ class DefaultCallbacks:
"trainer": trainer,
"result": result,
})


class MemoryTrackingCallbacks(DefaultCallbacks):
    """Callbacks for tracing and tracking memory usage in rollout workers.

    MemoryTrackingCallbacks uses tracemalloc and psutil to track Python
    allocations during rollouts, in training as well as evaluation.

    The tracking data is logged to an episode's custom_metrics and can
    therefore be viewed in TensorBoard (or in WandB, etc.).

    Warning: This class is meant for debugging and should not be used in
    production code, as tracemalloc incurs a significant slowdown in
    execution speed.

    To enable it, add MemoryTrackingCallbacks to the tune config, e.g.:
    {..., "callbacks": MemoryTrackingCallbacks, ...}
    """

    def __init__(self):
        super().__init__()
        # Record up to 10 stack frames per traced memory allocation.
        tracemalloc.start(10)

    def on_episode_end(self,
                       *,
                       worker: "RolloutWorker",
                       base_env: BaseEnv,
                       policies: Dict[PolicyID, Policy],
                       episode: MultiAgentEpisode,
                       env_index: Optional[int] = None,
                       **kwargs) -> None:
        # Log the 10 largest allocation sites as per-episode custom metrics.
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics("lineno")
        for stat in top_stats[:10]:
            count = stat.count
            size = stat.size
            trace = str(stat.traceback)
            episode.custom_metrics[f"tracemalloc/{trace}/size"] = size
            episode.custom_metrics[f"tracemalloc/{trace}/count"] = count

        # Also record the worker process' overall memory usage via psutil.
        memory_info = psutil.Process(os.getpid()).memory_info()
        episode.custom_metrics["tracemalloc/worker/rss"] = memory_info.rss
        episode.custom_metrics["tracemalloc/worker/data"] = memory_info.data
        episode.custom_metrics["tracemalloc/worker/vms"] = memory_info.vms

class MultiCallbacks(DefaultCallbacks):
    """MultiCallbacks allows multiple callback classes to be registered at
    the same time in the trainer's "callbacks" config.

    For example:

    'callbacks': MultiCallbacks([
        MyCustomStatsCallbacks,
        MyCustomVideoCallbacks,
        MyCustomTraceCallbacks,
        ...
    ])
    """

    def __init__(self, callback_class_list):
        super().__init__()
        self._callback_class_list = callback_class_list
        self._callback_list = []

    def __call__(self, *args, **kwargs):
        # RLlib calls the configured callbacks object to instantiate it;
        # instantiate all wrapped callback classes here and forward events.
        self._callback_list = [
            callback_class() for callback_class in self._callback_class_list
        ]
        return self

    def on_episode_start(self,
                         *,
                         worker: "RolloutWorker",
                         base_env: BaseEnv,
                         policies: Dict[PolicyID, Policy],
                         episode: MultiAgentEpisode,
                         env_index: Optional[int] = None,
                         **kwargs) -> None:
        for callback in self._callback_list:
            callback.on_episode_start(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode,
                env_index=env_index,
                **kwargs)

    def on_episode_step(self,
                        *,
                        worker: "RolloutWorker",
                        base_env: BaseEnv,
                        episode: MultiAgentEpisode,
                        env_index: Optional[int] = None,
                        **kwargs) -> None:
        for callback in self._callback_list:
            callback.on_episode_step(
                worker=worker,
                base_env=base_env,
                episode=episode,
                env_index=env_index,
                **kwargs)

    def on_episode_end(self,
                       *,
                       worker: "RolloutWorker",
                       base_env: BaseEnv,
                       policies: Dict[PolicyID, Policy],
                       episode: MultiAgentEpisode,
                       env_index: Optional[int] = None,
                       **kwargs) -> None:
        for callback in self._callback_list:
            callback.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode,
                env_index=env_index,
                **kwargs)

    def on_postprocess_trajectory(
            self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
            agent_id: AgentID, policy_id: PolicyID,
            policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
            original_batches: Dict[AgentID, SampleBatch], **kwargs) -> None:
        for callback in self._callback_list:
            callback.on_postprocess_trajectory(
                worker=worker,
                episode=episode,
                agent_id=agent_id,
                policy_id=policy_id,
                policies=policies,
                postprocessed_batch=postprocessed_batch,
                original_batches=original_batches,
                **kwargs)

    def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch,
                      **kwargs) -> None:
        for callback in self._callback_list:
            callback.on_sample_end(worker=worker, samples=samples, **kwargs)

    def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
                          result: dict, **kwargs) -> None:
        for callback in self._callback_list:
            callback.on_learn_on_batch(
                policy=policy,
                train_batch=train_batch,
                result=result,
                **kwargs)

    def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
        for callback in self._callback_list:
            callback.on_train_result(trainer=trainer, result=result, **kwargs)
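
Because a MultiCallbacks instance returns itself from __call__, it can be passed where RLlib expects a callbacks class and will fan every event out to each wrapped callback object. A sketch combining a user-defined callback with the new memory tracker; MyCustomStatsCallbacks is a hypothetical stand-in for any DefaultCallbacks subclass:

    from ray.rllib.agents.callbacks import (DefaultCallbacks,
                                            MemoryTrackingCallbacks,
                                            MultiCallbacks)

    class MyCustomStatsCallbacks(DefaultCallbacks):
        # Hypothetical user callback: record episode length as a custom metric.
        def on_episode_end(self, *, worker, base_env, policies, episode,
                           env_index=None, **kwargs):
            episode.custom_metrics["episode_len"] = episode.length

    config = {
        "env": "CartPole-v0",
        "callbacks": MultiCallbacks([
            MyCustomStatsCallbacks,
            MemoryTrackingCallbacks,
        ]),
    }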


@@ -644,7 +644,7 @@ class SimpleListCollector(SampleCollector):
            if is_done and check_dones and \
                    not pre_batch[SampleBatch.DONES][-1]:
                raise ValueError(
                    "Episode {} terminated for all agents, but we still don't "
                    "Episode {} terminated for all agents, but we still "
                    "don't have a last observation for agent {} (policy "
                    "{}). ".format(
                        episode_id, agent_id, self.agent_key_to_policy_id[(
@@ -653,9 +653,6 @@ class SimpleListCollector(SampleCollector):
                    "of all live agents when setting done[__all__] to "
                    "True. Alternatively, set no_done_at_end=True to "
                    "allow this.")
            # If (only this?) agent is done, erase its buffer entirely.
            if pre_batch[SampleBatch.DONES][-1]:
                del self.agent_collectors[(episode_id, agent_id)]
            other_batches = pre_batches.copy()
            del other_batches[agent_id]
@@ -683,7 +680,8 @@ class SimpleListCollector(SampleCollector):
        # Append into policy batches and reset.
        from ray.rllib.evaluation.rollout_worker import get_global_worker
        for agent_id, post_batch in sorted(post_batches.items()):
            pid = self.agent_key_to_policy_id[(episode_id, agent_id)]
            agent_key = (episode_id, agent_id)
            pid = self.agent_key_to_policy_id[agent_key]
            policy = self.policy_map[pid]
            self.callbacks.on_postprocess_trajectory(
                worker=get_global_worker(),
@@ -699,6 +697,10 @@ class SimpleListCollector(SampleCollector):
                    pid].add_postprocessed_batch_for_training(
                        post_batch, policy.view_requirements)
            if is_done:
                del self.agent_key_to_policy_id[agent_key]
                del self.agent_collectors[agent_key]
        if policy_collector_group:
            env_steps = self.episode_steps[episode_id]
            policy_collector_group.env_steps += env_steps
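
Taken together, these hunks are the actual leak fix: per-agent state is no longer dropped as soon as an individual agent reports done, and the (episode_id, agent_id) entries in both agent_collectors and agent_key_to_policy_id are now deleted once the episode as a whole is done and its batches have been postprocessed. A stripped-down illustration of the cleanup pattern (simplified names, not RLlib's actual classes):

    # (episode_id, agent_id) -> accumulated per-agent data
    agent_buffers = {}

    def finish_episode(episode_id, agent_ids, is_done):
        for agent_id in agent_ids:
            key = (episode_id, agent_id)
            # ... postprocess agent_buffers.get(key) into a training batch ...
            if is_done and key in agent_buffers:
                # Without this cleanup, entries pile up across episodes:
                # the multi-agent memory leak this commit addresses.
                del agent_buffers[key]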