[RLlib] Cleanup some deprecated metric keys and classes. (#26036)

Author: Sven Mika
Date: 2022-06-23 21:30:01 +02:00 (committed by GitHub)
Parent: 33b30aed15
Commit: 59a967a3a0
20 changed files with 146 additions and 105 deletions
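
For orientation, a minimal sketch (not part of the diff) of the names this commit standardizes on; the comments note the deprecated keys and classes they replace. The snippet only imports and prints the constants and assumes nothing beyond the post-commit module layout shown in the hunks below.

from ray.util.timer import _Timer  # replaces ray.rllib.utils.timer.TimerStat
from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_TRAINED,     # replaces STEPS_TRAINED_COUNTER (agent-step view)
    NUM_ENV_STEPS_TRAINED,       # replaces STEPS_TRAINED_COUNTER (env-step view)
    NUM_SYNCH_WORKER_WEIGHTS,    # named constant for the old raw "num_weight_broadcasts" key
    SYNCH_WORKER_WEIGHTS_TIMER,  # replaces the deprecated WORKER_UPDATE_TIMER ("update")
)

# The new timer class is used as a context manager, just like the old TimerStat.
timer = _Timer()
with timer:
    pass  # timed work, e.g. a weight sync, would go here

print(NUM_ENV_STEPS_TRAINED, SYNCH_WORKER_WEIGHTS_TIMER)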

View file

@@ -4,10 +4,6 @@ from typing import Optional
 from ray.rllib.algorithms.algorithm import Algorithm
 from ray.rllib.algorithms.a3c.a3c import A3CConfig, A3C
-from ray.rllib.execution.common import (
-    STEPS_TRAINED_COUNTER,
-    STEPS_TRAINED_THIS_ITER_COUNTER,
-)
 from ray.rllib.execution.rollout_ops import (
     synchronous_parallel_sample,
 )
@@ -18,8 +14,10 @@ from ray.rllib.utils.metrics import (
     APPLY_GRADS_TIMER,
     COMPUTE_GRADS_TIMER,
     NUM_AGENT_STEPS_SAMPLED,
+    NUM_AGENT_STEPS_TRAINED,
     NUM_ENV_STEPS_SAMPLED,
-    WORKER_UPDATE_TIMER,
+    NUM_ENV_STEPS_TRAINED,
+    SYNCH_WORKER_WEIGHTS_TIMER,
 )
 from ray.rllib.utils.typing import (
     PartialAlgorithmConfigDict,
@@ -188,8 +186,8 @@ class A2C(A3C):
             )
             if self._num_microbatches >= num_microbatches:
                 # Update counters.
-                self._counters[STEPS_TRAINED_COUNTER] += self._microbatches_counts
-                self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = self._microbatches_counts
+                self._counters[NUM_ENV_STEPS_TRAINED] += self._microbatches_counts
+                self._counters[NUM_AGENT_STEPS_TRAINED] += self._microbatches_counts

                 # Apply gradients.
                 apply_timer = self._timers[APPLY_GRADS_TIMER]
@@ -206,7 +204,7 @@ class A2C(A3C):
                 global_vars = {
                     "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
                 }
-                with self._timers[WORKER_UPDATE_TIMER]:
+                with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                     self.workers.sync_weights(
                         policies=self.workers.local_worker().get_policies_to_train(),
                         global_vars=global_vars,

View file

@@ -42,7 +42,9 @@ from ray.rllib.evaluation.metrics import (
 )
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.worker_set import WorkerSet
-from ray.rllib.execution.common import WORKER_UPDATE_TIMER
+from ray.rllib.execution.common import (
+    STEPS_TRAINED_THIS_ITER_COUNTER,  # TODO: Backward compatibility.
+)
 from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
 from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step
 from ray.rllib.offline import get_offline_io_resource_bundles
@@ -73,6 +75,7 @@ from ray.rllib.utils.metrics import (
     NUM_ENV_STEPS_SAMPLED,
     NUM_ENV_STEPS_SAMPLED_THIS_ITER,
     NUM_ENV_STEPS_TRAINED,
+    SYNCH_WORKER_WEIGHTS_TIMER,
     TRAINING_ITERATION_TIMER,
 )
 from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
@@ -569,7 +572,6 @@ class Algorithm(Trainable):
         # Results dict for training (and if appolicable: evaluation).
         results: ResultDict = {}
-        self._rollout_worker_metrics = []
         local_worker = (
             self.workers.local_worker()
             if hasattr(self.workers, "local_worker")
@@ -593,13 +595,12 @@ class Algorithm(Trainable):
                 results.update(self._run_one_evaluation(train_future=None))

            # Collect rollout worker metrics.
-            episodes, self._episodes_to_be_collected = collect_episodes(
+            episodes_this_iter, self._episodes_to_be_collected = collect_episodes(
                local_worker,
                self._remote_workers_for_metrics,
                self._episodes_to_be_collected,
                timeout_seconds=self.config["metrics_episode_collection_timeout_s"],
            )
-            self._rollout_worker_metrics.extend(episodes)

        # Attach latest available evaluation results to train results,
        # if necessary.
@@ -613,9 +614,10 @@ class Algorithm(Trainable):
        # Sync filters on workers.
        self._sync_filters_if_needed(self.workers)

-        # Collect worker metrics.
+        # Collect worker metrics and add combine them with `results`.
        if self.config["_disable_execution_plan_api"]:
            results = self._compile_iteration_results(
+                episodes_this_iter=episodes_this_iter,
                step_ctx=train_iter_ctx,
                iteration_results=results,
            )
@@ -780,19 +782,20 @@ class Algorithm(Trainable):
                        < units_left_to_do
                    ]
                )
-                agent_steps_this_iter = sum(b.agent_steps() for b in batches)
-                env_steps_this_iter = sum(b.env_steps() for b in batches)
+                _agent_steps = sum(b.agent_steps() for b in batches)
+                _env_steps = sum(b.env_steps() for b in batches)
                # 1 episode per returned batch.
                if unit == "episodes":
                    num_units_done += len(batches)
                # n timesteps per returned batch.
                else:
                    num_units_done += (
-                        agent_steps_this_iter
-                        if self._by_agent_steps
-                        else env_steps_this_iter
+                        _agent_steps if self._by_agent_steps else _env_steps
                    )
+                agent_steps_this_iter += _agent_steps
+                env_steps_this_iter += _env_steps

                logger.info(
                    f"Ran round {round_} of parallel evaluation "
                    f"({num_units_done}/{duration if not auto else '?'} "
@@ -862,7 +865,7 @@ class Algorithm(Trainable):
            global_vars = {
                "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
            }
-            with self._timers[WORKER_UPDATE_TIMER]:
+            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                self.workers.sync_weights(global_vars=global_vars)

        return train_results
@@ -2366,7 +2369,9 @@ class Algorithm(Trainable):
            * eval_cfg["num_envs_per_worker"]
        )

-    def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
+    def _compile_iteration_results(
+        self, *, episodes_this_iter, step_ctx, iteration_results=None
+    ):
        # Return dict.
        results: ResultDict = {}
        iteration_results = iteration_results or {}
@@ -2382,18 +2387,33 @@ class Algorithm(Trainable):
        # Learner info.
        results["info"] = {LEARNER_INFO: iteration_results}

-        episodes = self._rollout_worker_metrics
-        orig_episodes = list(episodes)
-        missing = self.config["metrics_num_episodes_for_smoothing"] - len(episodes)
+        # Calculate how many (if any) of older, historical episodes we have to add to
+        # `episodes_this_iter` in order to reach the required smoothing window.
+        episodes_for_metrics = episodes_this_iter[:]
+        missing = self.config["metrics_num_episodes_for_smoothing"] - len(
+            episodes_this_iter
+        )
+        # We have to add some older episodes to reach the smoothing window size.
        if missing > 0:
-            episodes = self._episode_history[-missing:] + episodes
-            assert len(episodes) <= self.config["metrics_num_episodes_for_smoothing"]
-        self._episode_history.extend(orig_episodes)
+            episodes_for_metrics = self._episode_history[-missing:] + episodes_this_iter
+            assert (
+                len(episodes_for_metrics)
+                <= self.config["metrics_num_episodes_for_smoothing"]
+            )
+        # Note that when there are more than `metrics_num_episodes_for_smoothing`
+        # episodes in `episodes_for_metrics`, leave them as-is. In this case, we'll
+        # compute the stats over that larger number.

+        # Add new episodes to our history and make sure it doesn't grow larger than
+        # needed.
+        self._episode_history.extend(episodes_this_iter)
        self._episode_history = self._episode_history[
            -self.config["metrics_num_episodes_for_smoothing"] :
        ]
        results["sampler_results"] = summarize_episodes(
-            episodes, orig_episodes, self.config["keep_per_episode_custom_metrics"]
+            episodes_for_metrics,
+            episodes_this_iter,
+            self.config["keep_per_episode_custom_metrics"],
        )
        # TODO: Don't dump sampler results into top-level.
        results.update(results["sampler_results"])
@@ -2413,12 +2433,16 @@ class Algorithm(Trainable):
            results[NUM_AGENT_STEPS_TRAINED + "_this_iter"] = step_ctx.trained
            # TODO: For CQL and other algos, count by trained steps.
            results["timesteps_total"] = self._counters[NUM_AGENT_STEPS_SAMPLED]
+            # TODO: Backward compatibility.
+            results[STEPS_TRAINED_THIS_ITER_COUNTER] = step_ctx.trained
        else:
            results[NUM_ENV_STEPS_SAMPLED + "_this_iter"] = step_ctx.sampled
            results[NUM_ENV_STEPS_TRAINED + "_this_iter"] = step_ctx.trained
            # TODO: For CQL and other algos, count by trained steps.
            results["timesteps_total"] = self._counters[NUM_ENV_STEPS_SAMPLED]
+            # TODO: Backward compatibility.
+            results[STEPS_TRAINED_THIS_ITER_COUNTER] = step_ctx.trained
        # TODO: Backward compatibility.
        results["agent_timesteps_total"] = self._counters[NUM_AGENT_STEPS_SAMPLED]

        # Process timer results.

View file

@@ -1111,7 +1111,14 @@ class AlgorithmConfig:
             metrics_episode_collection_timeout_s: Wait for metric batches for at most
                 this many seconds. Those that have not returned in time will be
                 collected in the next train iteration.
-            metrics_num_episodes_for_smoothing: Smooth metrics over this many episodes.
+            metrics_num_episodes_for_smoothing: Smooth rollout metrics over this many
+                episodes, if possible.
+                In case rollouts (sample collection) just started, there may be fewer
+                than this many episodes in the buffer and we'll compute metrics
+                over this smaller number of available episodes.
+                In case there are more than this many episodes collected in a single
+                training iteration, use all of these episodes for metrics computation,
+                meaning don't ever cut any "excess" episodes.
             min_time_s_per_iteration: Minimum time to accumulate within a single
                 `train()` call. This value does not affect learning,
                 only the number of times `Algorithm.training_step()` is called by
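
The two cases described in the new docstring above (pad with history when an iteration produced fewer episodes than the window; never truncate when it produced more) can be illustrated with a short, self-contained sketch. The helper name and the plain-list episode stand-ins are hypothetical, not RLlib APIs.

from typing import List

def select_episodes_for_metrics(
    episodes_this_iter: List[dict],
    episode_history: List[dict],
    smoothing_window: int,
) -> List[dict]:
    # Pad with the most recent historical episodes, up to the window size.
    missing = smoothing_window - len(episodes_this_iter)
    if missing > 0:
        return episode_history[-missing:] + episodes_this_iter
    # More new episodes than the window: keep all of them, cut nothing.
    return episodes_this_iter

# E.g. with a window of 100: 30 new episodes -> 70 historical + 30 new are used;
# 150 new episodes -> all 150 are used.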

View file

@@ -25,10 +25,6 @@ from ray.rllib.algorithms import Algorithm
 from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig
 from ray.rllib.algorithms.dqn.learner_thread import LearnerThread
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
-from ray.rllib.execution.common import (
-    STEPS_TRAINED_COUNTER,
-    STEPS_TRAINED_THIS_ITER_COUNTER,
-)
 from ray.rllib.execution.parallel_requests import AsyncRequestsManager
 from ray.rllib.policy.sample_batch import MultiAgentBatch
 from ray.rllib.utils.actors import create_colocated_actors
@@ -507,6 +503,7 @@ class ApexDQN(DQN):
         Args:
             _num_samples_ready: A mapping from ActorHandle (RolloutWorker) to
                 the number of samples returned by the remote worker.
+
         Returns:
             The number of remote workers whose weights were updated.
         """
@@ -517,6 +514,9 @@ class ApexDQN(DQN):
             self.learner_thread.weights_updated = False
             weights = self.workers.local_worker().get_weights()
             self.curr_learner_weights = ray.put(weights)
+
+        num_workers_updated = 0
+
         with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
             for (
                 remote_sampler_worker,
@@ -529,11 +529,21 @@ class ApexDQN(DQN):
                 ):
                     remote_sampler_worker.set_weights.remote(
                         self.curr_learner_weights,
-                        {"timestep": self._counters[STEPS_TRAINED_COUNTER]},
+                        {
+                            "timestep": self._counters[
+                                NUM_AGENT_STEPS_TRAINED
+                                if self._by_agent_steps
+                                else NUM_ENV_STEPS_TRAINED
+                            ]
+                        },
                     )
                     self.steps_since_update[remote_sampler_worker] = 0
+                    num_workers_updated += 1
+
                 self._counters["num_weight_syncs"] += 1

+        return num_workers_updated
+
     def sample_from_replay_buffer_place_on_learner_queue_non_blocking(
         self, num_samples_collected: Dict[ActorHandle, int]
     ) -> None:
@@ -617,7 +627,6 @@ class ApexDQN(DQN):
             else:
                 raise RuntimeError("The learner thread died in while training")

-        self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = num_samples_trained_this_itr
         self._timers["learner_dequeue"] = self.learner_thread.queue_timer
         self._timers["learner_grad"] = self.learner_thread.grad_timer
         self._timers["learner_overall"] = self.learner_thread.overall_timer
@@ -637,7 +646,9 @@ class ApexDQN(DQN):
             )
             self._counters[NUM_TARGET_UPDATES] += 1
             self._counters[LAST_TARGET_UPDATE_TS] = self._counters[
-                STEPS_TRAINED_COUNTER
+                NUM_AGENT_STEPS_TRAINED
+                if self._by_agent_steps
+                else NUM_ENV_STEPS_TRAINED
             ]

     @override(Algorithm)
@@ -657,10 +668,8 @@ class ApexDQN(DQN):
             self._sampling_actor_manager.add_workers(new_workers)

     @override(Algorithm)
-    def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
-        result = super()._compile_iteration_results(
-            step_ctx=step_ctx, iteration_results=iteration_results
-        )
+    def _compile_iteration_results(self, *args, **kwargs):
+        result = super()._compile_iteration_results(*args, **kwargs)
         replay_stats = ray.get(
             self._replay_actors[0].stats.remote(self.config["optimizer"].get("debug"))
         )
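
A minimal sketch of the timestep lookup ApexDQN now performs instead of reading the removed STEPS_TRAINED_COUNTER key; the standalone function and its arguments are hypothetical and only mirror the `self._by_agent_steps` branch in the hunks above.

from ray.rllib.utils.metrics import NUM_AGENT_STEPS_TRAINED, NUM_ENV_STEPS_TRAINED

def trained_timestep(counters: dict, by_agent_steps: bool) -> int:
    # Pick the counter that matches how the algorithm counts steps.
    key = NUM_AGENT_STEPS_TRAINED if by_agent_steps else NUM_ENV_STEPS_TRAINED
    return counters.get(key, 0)

print(trained_timestep({NUM_ENV_STEPS_TRAINED: 1024}, by_agent_steps=False))  # -> 1024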

View file

@@ -24,9 +24,6 @@ import ray
 from ray.rllib.algorithms.ppo import PPOConfig, PPO
 from ray.rllib.evaluation.postprocessing import Postprocessing
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
-from ray.rllib.execution.common import (
-    STEPS_TRAINED_THIS_ITER_COUNTER,
-)
 from ray.rllib.execution.parallel_requests import AsyncRequestsManager
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import Deprecated
@@ -297,10 +294,8 @@ class DDPPO(PPO):
         # - Update the worker's global_vars.
         # - Build info dict using a LearnerInfoBuilder object.
         learner_info_builder = LearnerInfoBuilder(num_devices=1)
-        steps_this_iter = 0
         for worker, results in sample_and_update_results.items():
             for result in results:
-                steps_this_iter += result["env_steps"]
                 self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
                 self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
                 self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
@@ -315,8 +310,6 @@ class DDPPO(PPO):
         for worker in self.workers.remote_workers():
             worker.set_global_vars.remote(global_vars)

-        self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = steps_this_iter
-
         # Sync down the weights from 1st remote worker (only if we have received
         # some results from it).
         # As with the sync up, this is not really needed unless the user is

View file

@@ -1,10 +1,10 @@
 import queue
 import threading

+from ray.util.timer import _Timer
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
 from ray.rllib.utils.metrics.window_stat import WindowStat
-from ray.rllib.utils.timer import TimerStat

 LEARNER_QUEUE_MAX_SIZE = 16
@@ -26,9 +26,9 @@ class LearnerThread(threading.Thread):
         self.local_worker = local_worker
         self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
         self.outqueue = queue.Queue()
-        self.queue_timer = TimerStat()
-        self.grad_timer = TimerStat()
-        self.overall_timer = TimerStat()
+        self.queue_timer = _Timer()
+        self.grad_timer = _Timer()
+        self.overall_timer = _Timer()
         self.daemon = True
         self.weights_updated = False
         self.stopped = False

View file

@@ -16,6 +16,10 @@ from ray.rllib.execution.rollout_ops import (
 )
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import Deprecated
+from ray.rllib.utils.metrics import (
+    NUM_AGENT_STEPS_SAMPLED,
+    NUM_ENV_STEPS_SAMPLED,
+)
 from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
 from ray.rllib.utils.typing import (
     PartialAlgorithmConfigDict,
@@ -383,10 +387,11 @@ class Dreamer(Algorithm):
         # Number of sub-iterations for Dreamer
         dreamer_train_iters = self.config["dreamer_train_iters"]
         batch_size = self.config["batch_size"]
-        action_repeat = self.config["action_repeat"]

         # Collect SampleBatches from rollout workers.
         batch = synchronous_parallel_sample(worker_set=self.workers)
+        self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
+        self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()

         fetches = {}
@@ -398,25 +403,16 @@ class Dreamer(Algorithm):
             fetches = local_worker.learn_on_batch(batch)

         if fetches:
-            # Custom Logging
+            # Custom logging.
             policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
             if "log_gif" in policy_fetches:
                 gif = policy_fetches["log_gif"]
                 policy_fetches["log_gif"] = self._postprocess_gif(gif)

-        self._counters[STEPS_SAMPLED_COUNTER] = (
-            self.local_replay_buffer.timesteps * action_repeat
-        )
-
         self.local_replay_buffer.add(batch)

         return fetches

-    def _compile_iteration_results(self, *args, **kwargs):
-        results = super()._compile_iteration_results(*args, **kwargs)
-        results["timesteps_total"] = self._counters[STEPS_SAMPLED_COUNTER]
-        return results
-

 # Deprecated: Use ray.rllib.algorithms.dreamer.DreamerConfig instead!
 class _deprecated_default_config(dict):

View file

@@ -37,6 +37,8 @@ from ray.rllib.utils.metrics import (
     NUM_AGENT_STEPS_TRAINED,
     NUM_ENV_STEPS_SAMPLED,
     NUM_ENV_STEPS_TRAINED,
+    NUM_SYNCH_WORKER_WEIGHTS,
+    NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS,
 )
 from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
 from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
@@ -822,7 +824,6 @@ class Impala(Algorithm):
         final_learner_info = builder.finalize()

         # Update the steps trained counters.
-        self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = num_agent_steps_trained
         self._counters[NUM_ENV_STEPS_TRAINED] += num_env_steps_trained
         self._counters[NUM_AGENT_STEPS_TRAINED] += num_agent_steps_trained
@@ -874,17 +875,17 @@ class Impala(Algorithm):
     def update_workers_if_necessary(self) -> None:
         # Only need to update workers if there are remote workers.
         global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_TRAINED]}
-        self._counters["steps_since_broadcast"] += 1
+        self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] += 1
         if (
             self.workers.remote_workers()
-            and self._counters["steps_since_broadcast"]
+            and self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS]
             >= self.config["broadcast_interval"]
             and self.workers_that_need_updates
         ):
             weights = ray.put(self.workers.local_worker().get_weights())
-            self._counters["steps_since_broadcast"] = 0
+            self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] = 0
             self._learner_thread.weights_updated = False
-            self._counters["num_weight_broadcasts"] += 1
+            self._counters[NUM_SYNCH_WORKER_WEIGHTS] += 1
             for worker in self.workers_that_need_updates:
                 worker.set_weights.remote(weights, global_vars)
@@ -910,10 +911,8 @@ class Impala(Algorithm):
             self._sampling_actor_manager.add_workers(new_workers)

     @override(Algorithm)
-    def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
-        result = super()._compile_iteration_results(
-            step_ctx=step_ctx, iteration_results=iteration_results
-        )
+    def _compile_iteration_results(self, *args, **kwargs):
+        result = super()._compile_iteration_results(*args, **kwargs)
         result = self._learner_thread.add_learner_metrics(
             result, overwrite_learner_info=False
         )

View file

@@ -17,7 +17,7 @@ from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
 from ray.rllib.utils.metrics import (
     NUM_AGENT_STEPS_SAMPLED,
     NUM_ENV_STEPS_SAMPLED,
-    WORKER_UPDATE_TIMER,
+    SYNCH_WORKER_WEIGHTS_TIMER,
 )
 from ray.rllib.utils.typing import (
     ResultDict,
@@ -284,7 +284,7 @@ class MARWIL(Algorithm):
         # Update weights - after learning on the local worker - on all remote
         # workers.
         if self.workers.remote_workers():
-            with self._timers[WORKER_UPDATE_TIMER]:
+            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                 self.workers.sync_weights(global_vars=global_vars)

         # Update global vars on local worker as well.

View file

@@ -38,7 +38,7 @@ from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
 from ray.rllib.utils.metrics import (
     NUM_AGENT_STEPS_SAMPLED,
     NUM_ENV_STEPS_SAMPLED,
-    WORKER_UPDATE_TIMER,
+    SYNCH_WORKER_WEIGHTS_TIMER,
 )

 logger = logging.getLogger(__name__)
@@ -426,7 +426,7 @@ class PPO(Algorithm):
         # Update weights - after learning on the local worker - on all remote
         # workers.
         if self.workers.remote_workers():
-            with self._timers[WORKER_UPDATE_TIMER]:
+            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                 self.workers.sync_weights(global_vars=global_vars)

         # For each policy: update KL scale and warn about possible issues

View file

@@ -139,9 +139,10 @@ def summarize_episodes(
     """Summarizes a set of episode metrics tuples.

     Args:
-        episodes: smoothed set of episodes including historical ones
-        new_episodes: just the new episodes in this iteration. This must be
-            a subset of `episodes`. If None, assumes all episodes are new.
+        episodes: List of most recent n episodes. This may include historical ones
+            (not newly collected in this iteration) in order to achieve the size of
+            the smoothing window.
+        new_episodes: All the episodes that were completed in this iteration.
     """
     if new_episodes is None:
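
As a toy illustration of the relationship between the two arguments (plain ints stand in for episode metric objects; the variable names are made up): `new_episodes` holds only what finished in the current iteration, while `episodes` is that list padded with history up to the smoothing window, as described in the Algorithm changes above.

episode_history = [1, 2, 3, 4, 5]   # older, previously summarized episodes
new_episodes = [6, 7]               # completed during this iteration
window_size = 4

missing = window_size - len(new_episodes)
episodes = (episode_history[-missing:] if missing > 0 else []) + new_episodes
assert episodes == [4, 5, 6, 7]     # what would be passed as `episodes`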

View file

@@ -3,11 +3,11 @@ import platform
 import random
 from typing import Optional

+from ray.util.timer import _Timer
 from ray.rllib.execution.replay_ops import SimpleReplayBuffer
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
 from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode
 from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
-from ray.rllib.utils.timer import TimerStat
 from ray.rllib.utils.typing import PolicyID, SampleBatchType
@@ -87,9 +87,9 @@ class MixInMultiAgentReplayBuffer:
         self.replay_buffers = collections.defaultdict(new_buffer)

         # Metrics.
-        self.add_batch_timer = TimerStat()
-        self.replay_timer = TimerStat()
-        self.update_priorities_timer = TimerStat()
+        self.add_batch_timer = _Timer()
+        self.replay_timer = _Timer()
+        self.update_priorities_timer = _Timer()

         # Added timesteps over lifetime.
         self.num_added = 0

View file

@@ -9,7 +9,7 @@ from ray.rllib.utils.metrics import (  # noqa: F401
     NUM_TARGET_UPDATES,
     APPLY_GRADS_TIMER,
     COMPUTE_GRADS_TIMER,
-    WORKER_UPDATE_TIMER,
+    SYNCH_WORKER_WEIGHTS_TIMER as WORKER_UPDATE_TIMER,
     GRAD_WAIT_TIMER,
     SAMPLE_TIMER,
     LEARN_ON_BATCH_TIMER,

View file

@@ -3,11 +3,11 @@ from six.moves import queue
 import threading
 from typing import Dict, Optional

+from ray.util.timer import _Timer
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, LEARNER_INFO
-from ray.rllib.utils.timer import TimerStat
 from ray.rllib.utils.metrics.window_stat import WindowStat
 from ray.util.iter import _NextValueNotReady
@@ -56,10 +56,10 @@ class LearnerThread(threading.Thread):
             num_passes=num_sgd_iter,
             init_num_passes=num_sgd_iter,
         )
-        self.queue_timer = TimerStat()
-        self.grad_timer = TimerStat()
-        self.load_timer = TimerStat()
-        self.load_wait_timer = TimerStat()
+        self.queue_timer = _Timer()
+        self.grad_timer = _Timer()
+        self.load_timer = _Timer()
+        self.load_wait_timer = _Timer()
         self.daemon = True
         self.weights_updated = False
         self.learner_info = {}

View file

@@ -2,6 +2,7 @@ import logging
 from six.moves import queue
 import threading

+from ray.util.timer import _Timer
 from ray.rllib.execution.learner_thread import LearnerThread
 from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
 from ray.rllib.policy.sample_batch import SampleBatch
@@ -9,7 +10,6 @@ from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
-from ray.rllib.utils.timer import TimerStat
 from ray.rllib.evaluation.rollout_worker import RolloutWorker

 tf1, tf, tfv = try_import_tf()
@@ -192,8 +192,8 @@ class _MultiGPULoaderThread(threading.Thread):
             self.queue_timer = multi_gpu_learner_thread.queue_timer
             self.load_timer = multi_gpu_learner_thread.load_timer
         else:
-            self.queue_timer = TimerStat()
-            self.load_timer = TimerStat()
+            self.queue_timer = _Timer()
+            self.load_timer = _Timer()

     def run(self) -> None:
         while True:

View file

@@ -16,7 +16,6 @@ from ray.rllib.execution.common import (
     STEPS_SAMPLED_COUNTER,
     STEPS_TRAINED_COUNTER,
     STEPS_TRAINED_THIS_ITER_COUNTER,
-    WORKER_UPDATE_TIMER,
     _check_sample_batch_type,
     _get_global_vars,
     _get_shared_metrics,
@@ -25,7 +24,11 @@ from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
 from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.metrics import NUM_ENV_STEPS_TRAINED, NUM_AGENT_STEPS_TRAINED
+from ray.rllib.utils.metrics import (
+    NUM_ENV_STEPS_TRAINED,
+    NUM_AGENT_STEPS_TRAINED,
+    SYNCH_WORKER_WEIGHTS_TIMER,
+)
 from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, LEARNER_INFO
 from ray.rllib.utils.sgd import do_minibatch_sgd
 from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients
@@ -231,7 +234,7 @@ class TrainOneStep:
         # Update weights - after learning on the local worker - on all remote
         # workers.
         if self.workers.remote_workers():
-            with metrics.timers[WORKER_UPDATE_TIMER]:
+            with metrics.timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                 weights = ray.put(
                     lw.get_weights(self.policies or lw.get_policies_to_train(batch))
                 )
@@ -354,7 +357,7 @@ class MultiGPUTrainOneStep:
         metrics.info[LEARNER_INFO] = learner_info

         if self.workers.remote_workers():
-            with metrics.timers[WORKER_UPDATE_TIMER]:
+            with metrics.timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                 weights = ray.put(
                     self.workers.local_worker().get_weights(
                         self.local_worker.get_policies_to_train()
@@ -453,7 +456,7 @@ class ApplyGradients:
         if self.update_all:
             if self.workers.remote_workers():
-                with metrics.timers[WORKER_UPDATE_TIMER]:
+                with metrics.timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                     weights = ray.put(
                         self.local_worker.get_weights(
                             self.policies or self.local_worker.get_policies_to_train()
@@ -468,7 +471,7 @@ class ApplyGradients:
                     "update_all=False, `current_actor` must be set "
                     "in the iterator context."
                 )
-            with metrics.timers[WORKER_UPDATE_TIMER]:
+            with metrics.timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                 weights = self.local_worker.get_weights(
                     self.policies or self.local_worker.get_policies_to_train()
                 )

View file

@@ -8,11 +8,18 @@ NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained"
 NUM_ENV_STEPS_TRAINED_THIS_ITER = "num_env_steps_trained_this_iter"
 NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter"

+# Counters for keeping track of worker weight updates (synchronization
+# between local worker and remote workers).
+NUM_SYNCH_WORKER_WEIGHTS = "num_weight_broadcasts"
+NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS = (
+    "num_training_step_calls_since_last_synch_worker_weights"
+)
+
 # Counters to track target network updates.
 LAST_TARGET_UPDATE_TS = "last_target_update_ts"
 NUM_TARGET_UPDATES = "num_target_updates"

-# Performance timers (keys for Algorithm._timers or metrics.timers).
+# Performance timers (keys for Algorithm._timers).
 TRAINING_ITERATION_TIMER = "training_iteration"
 APPLY_GRADS_TIMER = "apply_grad"
 COMPUTE_GRADS_TIMER = "compute_grads"
@@ -22,6 +29,3 @@ SAMPLE_TIMER = "sample"
 LEARN_ON_BATCH_TIMER = "learn"
 LOAD_BATCH_TIMER = "load"
 TARGET_NET_UPDATE_TIMER = "target_net_update"
-
-# Deprecated: Use `SYNCH_WORKER_WEIGHTS_TIMER` instead.
-WORKER_UPDATE_TIMER = "update"
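
A minimal sketch (assumed plain defaultdict counters and a made-up broadcast_interval) of how the two new weight-sync counters above are intended to be used, mirroring the Impala change in this commit.

from collections import defaultdict

from ray.rllib.utils.metrics import (
    NUM_SYNCH_WORKER_WEIGHTS,
    NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS,
)

counters = defaultdict(int)
broadcast_interval = 4

def maybe_sync_weights() -> bool:
    # Count every training_step() call; broadcast once the interval is reached.
    counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] += 1
    calls = counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS]
    if calls < broadcast_interval:
        return False
    counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] = 0
    counters[NUM_SYNCH_WORKER_WEIGHTS] += 1
    return True  # the caller would broadcast weights to remote workers here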

View file

@@ -2,6 +2,7 @@ from typing import Dict
 import logging
 import numpy as np

+from ray.util.timer import _Timer
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
     MultiAgentReplayBuffer,
@@ -16,7 +17,6 @@ from ray.rllib.utils.replay_buffers.replay_buffer import (
 )
 from ray.rllib.utils.typing import PolicyID, SampleBatchType
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.timer import TimerStat
 from ray.util.debug import log_once
 from ray.util.annotations import DeveloperAPI
 from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
@@ -137,7 +137,7 @@ class MultiAgentPrioritizedReplayBuffer(
         )
         self.prioritized_replay_eps = prioritized_replay_eps
-        self.update_priorities_timer = TimerStat()
+        self.update_priorities_timer = _Timer()

     @DeveloperAPI
     @override(MultiAgentReplayBuffer)

View file

@@ -3,6 +3,7 @@ import logging
 from enum import Enum
 from typing import Any, Dict, Optional

+from ray.util.timer import _Timer
 from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.utils.annotations import override
@@ -13,7 +14,6 @@ from ray.rllib.utils.replay_buffers.replay_buffer import (
     ReplayBuffer,
     StorageUnit,
 )
-from ray.rllib.utils.timer import TimerStat
 from ray.rllib.utils.typing import PolicyID, SampleBatchType
 from ray.util.annotations import DeveloperAPI
 from ray.util.debug import log_once
@@ -184,8 +184,8 @@ class MultiAgentReplayBuffer(ReplayBuffer):
         self.replay_buffers = collections.defaultdict(new_buffer)

         # Metrics.
-        self.add_batch_timer = TimerStat()
-        self.replay_timer = TimerStat()
+        self.add_batch_timer = _Timer()
+        self.replay_timer = _Timer()
         self._num_added = 0

     def __len__(self) -> int:

View file

@@ -1,3 +1,10 @@
 from ray.util.timer import _Timer
+from ray.rllib.utils.deprecation import deprecation_warning
+
+deprecation_warning(
+    old="ray.rllib.utils.timer::TimerStat",
+    new="ray.util.timer::_Timer",
+    error=False,
+)

 TimerStat = _Timer  # backwards compatibility alias
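
A brief usage sketch of the shim above: importing the old name still works, but the module now emits a deprecation warning when first imported; new code should use `_Timer` directly. The `.mean` attribute shown here is assumed to carry over unchanged from the old TimerStat interface.

from ray.util.timer import _Timer  # preferred
# from ray.rllib.utils.timer import TimerStat  # still works, but warns

t = _Timer()
with t:
    sum(range(100_000))  # any timed work
print(t.mean)  # average duration of the timed blocks, in seconds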