import logging
import platform
from typing import Any, Dict, List

import ray
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    AGENT_STEPS_SAMPLED_COUNTER,
    STEPS_SAMPLED_COUNTER,
    _get_shared_metrics,
)
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.typing import ModelWeights, SampleBatchType
from ray.util.iter import (
    LocalIterator,
    ParallelIterator,
    ParallelIteratorWorker,
    from_actors,
)

logger = logging.getLogger(__name__)


@ray.remote(num_cpus=0)
class Aggregator(ParallelIteratorWorker):
    """An aggregation worker used by gather_experiences_tree_aggregation().

    Each of these actors is a shard of a parallel iterator that consumes
    batches from RolloutWorker actors, and emits batches of size
    `train_batch_size`. This allows expensive decompression / concatenation
    work to be offloaded to these actors instead of run in the learner.
    """

    def __init__(
        self, config: Dict, rollout_group: "ParallelIterator[SampleBatchType]"
    ):
        self.weights = None
        self.global_vars = None

        def generator():
            it = rollout_group.gather_async(
                num_async=config["max_requests_in_flight_per_aggregator_worker"]
            )

            # Update the rollout worker with our latest policy weights.
            def update_worker(item):
                worker, batch = item
                if self.weights:
                    worker.set_weights.remote(self.weights, self.global_vars)
                return batch

            # Augment with replay and concat to desired train batch size.
            it = (
                it.zip_with_source_actor()
                .for_each(update_worker)
                .for_each(lambda batch: batch.decompress_if_needed())
                .for_each(
                    MixInReplay(
                        num_slots=config["replay_buffer_num_slots"],
                        replay_proportion=config["replay_proportion"],
                    )
                )
                .flatten()
                .combine(
                    ConcatBatches(
                        min_batch_size=config["train_batch_size"],
                        count_steps_by=config["multiagent"]["count_steps_by"],
                    )
                )
            )

            for train_batch in it:
                yield train_batch

        super().__init__(generator, repeat=False)

    def get_host(self) -> str:
        return platform.node()

    def set_weights(self, weights: ModelWeights, global_vars: Dict) -> None:
        self.weights = weights
        self.global_vars = global_vars


def gather_experiences_tree_aggregation(
    workers: WorkerSet, config: Dict
) -> "LocalIterator[Any]":
    """Tree aggregation version of gather_experiences_directly()."""

    rollouts = ParallelRollouts(workers, mode="raw")

    # Divide up the workers between aggregators.
    worker_assignments = [[] for _ in range(config["num_aggregation_workers"])]
    i = 0
    for worker_idx in range(len(workers.remote_workers())):
        worker_assignments[i].append(worker_idx)
        i += 1
        i %= len(worker_assignments)
    logger.info("Worker assignments: {}".format(worker_assignments))

    # Create parallel iterators that represent each aggregation group.
    rollout_groups: List["ParallelIterator[SampleBatchType]"] = [
        rollouts.select_shards(assigned) for assigned in worker_assignments
    ]

    # This spawns |num_aggregation_workers| intermediate actors that aggregate
    # experiences in parallel. We force colocation on the same node (localhost)
    # to maximize data bandwidth between them and the driver.
    localhost = platform.node()
    assert localhost != "", (
        "ERROR: Cannot determine local node name! "
        "`platform.node()` returned empty string."
    )
    all_co_located = create_colocated_actors(
        actor_specs=[
            # (class, args, kwargs={}, count=1)
            (Aggregator, [config, g], {}, 1)
            for g in rollout_groups
        ],
        node=localhost,
    )

    # Use the first ([0]) of each created group (each group only has one
    # actor: count=1).
    train_batches = from_actors([group[0] for group in all_co_located])

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.agent_steps()
        else:
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    return train_batches.gather_async().for_each(record_steps_sampled)
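

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): one way a
# caller could consume the iterator returned by
# gather_experiences_tree_aggregation(). The function name
# `example_consume_batches` and the `num_batches` parameter are hypothetical
# names introduced for this sketch; the real IMPALA execution plan additionally
# feeds batches into learner threads and broadcasts updated weights back to the
# Aggregator actors via their set_weights() method.
# ---------------------------------------------------------------------------
def example_consume_batches(
    workers: WorkerSet, config: Dict, num_batches: int = 10
) -> List[SampleBatchType]:
    """Pull a few aggregated train batches from the tree-aggregation pipeline."""
    train_batches = gather_experiences_tree_aggregation(workers, config)
    collected: List[SampleBatchType] = []
    for batch in train_batches:
        # Each item is a (possibly multi-agent) batch of at least
        # config["train_batch_size"] steps, already decompressed and
        # concatenated by the Aggregator actors.
        collected.append(batch)
        if len(collected) >= num_batches:
            break
    return collected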