# ray/rllib/execution/tree_agg.py

import logging
import platform
from typing import Any, Dict, List

import ray
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    AGENT_STEPS_SAMPLED_COUNTER,
    STEPS_SAMPLED_COUNTER,
    _get_shared_metrics,
)
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.typing import ModelWeights, SampleBatchType
from ray.util.iter import (
    LocalIterator,
    ParallelIterator,
    ParallelIteratorWorker,
    from_actors,
)

logger = logging.getLogger(__name__)


@ray.remote(num_cpus=0)
class Aggregator(ParallelIteratorWorker):
    """An aggregation worker used by gather_experiences_tree_aggregation().

    Each of these actors is a shard of a parallel iterator that consumes
    batches from RolloutWorker actors and emits batches of size
    `train_batch_size`. This allows expensive decompression / concatenation
    work to be offloaded to these actors instead of being run in the learner.
    """

    def __init__(
        self, config: Dict, rollout_group: "ParallelIterator[SampleBatchType]"
    ):
        self.weights = None
        self.global_vars = None

        def generator():
            it = rollout_group.gather_async(
                num_async=config["max_requests_in_flight_per_aggregator_worker"]
            )

            # Update the rollout worker with our latest policy weights.
            def update_worker(item):
                worker, batch = item
                if self.weights:
                    worker.set_weights.remote(self.weights, self.global_vars)
                return batch

            # Augment with replay and concat to desired train batch size.
            it = (
                it.zip_with_source_actor()
                .for_each(update_worker)
                .for_each(lambda batch: batch.decompress_if_needed())
                .for_each(
                    MixInReplay(
                        num_slots=config["replay_buffer_num_slots"],
                        replay_proportion=config["replay_proportion"],
                    )
                )
                .flatten()
                .combine(
                    ConcatBatches(
                        min_batch_size=config["train_batch_size"],
                        count_steps_by=config["multiagent"]["count_steps_by"],
                    )
                )
            )

            for train_batch in it:
                yield train_batch

        super().__init__(generator, repeat=False)

    def get_host(self) -> str:
        return platform.node()

    def set_weights(self, weights: ModelWeights, global_vars: Dict) -> None:
        self.weights = weights
        self.global_vars = global_vars
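

# Usage note (illustrative sketch, not part of this module's API surface):
# Aggregator is a Ray actor (num_cpus=0). Assuming a `config` dict and a
# `rollout_group` ParallelIterator as described in __init__, plus placeholder
# `latest_weights` / `global_vars` objects, a single actor could be driven as:
#
#     agg = Aggregator.remote(config, rollout_group)
#     node_name = ray.get(agg.get_host.remote())
#     agg.set_weights.remote(latest_weights, global_vars)
#
# In practice, gather_experiences_tree_aggregation() below creates these
# actors and colocates them on the driver's node.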


def gather_experiences_tree_aggregation(
    workers: WorkerSet, config: Dict
) -> "LocalIterator[Any]":
    """Tree aggregation version of gather_experiences_directly()."""
    rollouts = ParallelRollouts(workers, mode="raw")

    # Divide up the workers between aggregators.
    worker_assignments = [[] for _ in range(config["num_aggregation_workers"])]
    i = 0
    for worker_idx in range(len(workers.remote_workers())):
        worker_assignments[i].append(worker_idx)
        i += 1
        i %= len(worker_assignments)
    logger.info("Worker assignments: {}".format(worker_assignments))
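    # For example, with 5 rollout workers and 2 aggregation workers, the
    # round-robin loop above yields worker_assignments == [[0, 2, 4], [1, 3]].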

    # Create parallel iterators that represent each aggregation group.
    rollout_groups: List["ParallelIterator[SampleBatchType]"] = [
        rollouts.select_shards(assigned) for assigned in worker_assignments
    ]

    # This spawns |num_aggregation_workers| intermediate actors that aggregate
    # experiences in parallel. We force colocation on the same node (localhost)
    # to maximize data bandwidth between them and the driver.
    localhost = platform.node()
    assert localhost != "", (
        "ERROR: Cannot determine local node name! "
        "`platform.node()` returned empty string."
    )
    all_co_located = create_colocated_actors(
        actor_specs=[
            # (class, args, kwargs={}, count=1)
            (Aggregator, [config, g], {}, 1)
            for g in rollout_groups
        ],
        node=localhost,
    )

    # Use the first ([0]) of each created group (each group only has one
    # actor: count=1).
    train_batches = from_actors([group[0] for group in all_co_located])

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.agent_steps()
        else:
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    return train_batches.gather_async().for_each(record_steps_sampled)
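

# Usage sketch (illustrative only): the config values below are example
# assumptions, and `workers` is assumed to be an already-built WorkerSet whose
# remote workers sample from some environment. Only the keys referenced above
# are shown.
#
#     config = {
#         "num_aggregation_workers": 2,
#         "max_requests_in_flight_per_aggregator_worker": 2,
#         "replay_buffer_num_slots": 100,
#         "replay_proportion": 0.5,
#         "train_batch_size": 500,
#         "multiagent": {"count_steps_by": "env_steps"},
#     }
#     train_batches = gather_experiences_tree_aggregation(workers, config)
#     for batch in train_batches:
#         ...  # feed each aggregated batch to the learner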