# ray/rllib/execution/tree_agg.py

import logging
import platform
from typing import Any, Dict, List

import ray
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    AGENT_STEPS_SAMPLED_COUNTER,
    STEPS_SAMPLED_COUNTER,
    _get_shared_metrics,
)
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.typing import ModelWeights, SampleBatchType
from ray.util.iter import (
    LocalIterator,
    ParallelIterator,
    ParallelIteratorWorker,
    from_actors,
)

logger = logging.getLogger(__name__)


@ray.remote(num_cpus=0)
class Aggregator(ParallelIteratorWorker):
"""An aggregation worker used by gather_experiences_tree_aggregation().
Each of these actors is a shard of a parallel iterator that consumes
batches from RolloutWorker actors, and emits batches of size
train_batch_size. This allows expensive decompression / concatenation
work to be offloaded to these actors instead of run in the learner.
"""

    def __init__(
        self, config: Dict, rollout_group: "ParallelIterator[SampleBatchType]"
    ):
        self.weights = None
        self.global_vars = None

        def generator():
            it = rollout_group.gather_async(
                num_async=config["max_requests_in_flight_per_aggregator_worker"]
            )

            # Update the rollout worker with our latest policy weights.
            def update_worker(item):
                worker, batch = item
                if self.weights:
                    worker.set_weights.remote(self.weights, self.global_vars)
                return batch

            # Augment with replay and concat to the desired train batch size.
            it = (
                # Pair each batch with the RolloutWorker actor it came from.
                it.zip_with_source_actor()
                .for_each(update_worker)
                # Decompress observations, if compressed.
                .for_each(lambda batch: batch.decompress_if_needed())
                .for_each(
                    MixInReplay(
                        num_slots=config["replay_buffer_num_slots"],
                        replay_proportion=config["replay_proportion"],
                    )
                )
                .flatten()
                .combine(
                    ConcatBatches(
                        min_batch_size=config["train_batch_size"],
                        count_steps_by=config["multiagent"]["count_steps_by"],
                    )
                )
            )

            for train_batch in it:
                yield train_batch

        super().__init__(generator, repeat=False)

    def get_host(self) -> str:
        return platform.node()

    def set_weights(self, weights: ModelWeights, global_vars: Dict) -> None:
        self.weights = weights
        self.global_vars = global_vars
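
    # Weights flow (sketch): the trainer process is assumed to push fresh
    # weights via `aggregator.set_weights.remote(weights, global_vars)`; each
    # aggregator then forwards them to its rollout workers in update_worker()
    # above.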


def gather_experiences_tree_aggregation(
    workers: WorkerSet, config: Dict
) -> "LocalIterator[Any]":
    """Tree aggregation version of gather_experiences_directly()."""
    rollouts = ParallelRollouts(workers, mode="raw")

    # Divide up the workers between aggregators (round-robin).
    worker_assignments = [[] for _ in range(config["num_aggregation_workers"])]
    i = 0
    for worker_idx in range(len(workers.remote_workers())):
        worker_assignments[i].append(worker_idx)
        i += 1
        i %= len(worker_assignments)
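    # E.g., 5 rollout workers and 2 aggregators produce the round-robin
    # assignment [[0, 2, 4], [1, 3]].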
    logger.info("Worker assignments: {}".format(worker_assignments))

    # Create parallel iterators that represent each aggregation group.
    rollout_groups: List["ParallelIterator[SampleBatchType]"] = [
        rollouts.select_shards(assigned) for assigned in worker_assignments
    ]

    # This spawns |num_aggregation_workers| intermediate actors that aggregate
    # experiences in parallel. We force colocation on the same node (localhost)
    # to maximize data bandwidth between them and the driver.
    localhost = platform.node()
    assert localhost != "", (
        "ERROR: Cannot determine local node name! "
        "`platform.node()` returned an empty string."
    )
    all_co_located = create_colocated_actors(
        actor_specs=[
            # Each spec is a (class, args, kwargs, count) tuple; count=1 here.
            (Aggregator, [config, g], {}, 1)
            for g in rollout_groups
        ],
        node=localhost,
    )

    # Use the first ([0]) actor of each created group (each group only has
    # one actor, since count=1).
    train_batches = from_actors([group[0] for group in all_co_located])

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.agent_steps()
        else:
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    return train_batches.gather_async().for_each(record_steps_sampled)
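

# Minimal driver-side sketch (illustrative; the config keys below are the ones
# read above, but the values are assumptions, not defaults):
#
#   config = {
#       "num_aggregation_workers": 2,
#       "max_requests_in_flight_per_aggregator_worker": 2,
#       "replay_buffer_num_slots": 100,
#       "replay_proportion": 0.5,
#       "train_batch_size": 500,
#       "multiagent": {"count_steps_by": "env_steps"},
#   }
#   train_batches = gather_experiences_tree_aggregation(workers, config)
#   for batch in train_batches:
#       ...  # hand each train batch to the learner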