[RLlib] Decentralized multi-agent learning; PR #01 (#21421)

Sven Mika 2022-01-13 10:52:55 +01:00 committed by GitHub
parent d392f97331
commit 90c6b10498
7 changed files with 475 additions and 73 deletions

View file

@@ -23,6 +23,15 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
"buffer_size": 2000000,
# TODO(jungong) : update once Apex supports replay_buffer_config.
"replay_buffer_config": None,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
@@ -31,6 +40,7 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
"worker_side_prioritization": True,
"min_iter_time_s": 30,
},
_allow_unknown_configs=True,
)
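For illustration, a minimal usage sketch of the new flag (hedged: the trainer class import, env name, and the small worker/GPU counts are assumptions chosen so the snippet can run on a single machine; they are not part of this PR):

import ray
from ray.rllib.agents.ddpg import ApexDDPGTrainer

ray.init()
config = {
    "env": "Pendulum-v1",  # any continuous-action env works here
    "num_workers": 4,
    "num_gpus": 0,
    # Keep all replay shards on the learner (driver) node for fast,
    # network-free access to sampled batches.
    "replay_buffer_shards_colocated_with_driver": True,
}
trainer = ApexDDPGTrainer(config=config)
print(trainer.train())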

View file

@@ -14,6 +14,7 @@ https://docs.ray.io/en/master/rllib-algorithms.html#distributed-prioritized-expe
import collections
import copy
import platform
from typing import Tuple
import ray
@@ -32,7 +33,7 @@ from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import UpdateTargetNetwork
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.actors import create_colocated
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
@@ -55,10 +56,21 @@ APEX_DEFAULT_CONFIG = merge_dicts(
"n_step": 3,
"num_gpus": 1,
"num_workers": 32,
"buffer_size": 2000000,
# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": None,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
@@ -129,7 +141,8 @@ class ApexTrainer(DQNTrainer):
# Create a number of replay buffer actors.
num_replay_buffer_shards = config["optimizer"][
"num_replay_buffer_shards"]
replay_actors = create_colocated(ReplayActor, [
replay_actor_args = [
num_replay_buffer_shards,
config["learning_starts"],
config["buffer_size"],
@@ -139,7 +152,24 @@ class ApexTrainer(DQNTrainer):
config["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config.get("replay_sequence_length", 1),
], num_replay_buffer_shards)
]
# Place all replay buffer shards on the same node as the learner
# (driver process that runs this execution plan).
if config["replay_buffer_shards_colocated_with_driver"]:
replay_actors = create_colocated_actors(
actor_specs=[
# (class, args, kwargs={}, count)
(ReplayActor, replay_actor_args, {},
num_replay_buffer_shards)
],
node=platform.node(), # localhost
)[0] # [0]=only one item in `actor_specs`.
# Place replay buffer shards on any node(s).
else:
replay_actors = [
ReplayActor.remote(*replay_actor_args)
for _ in range(num_replay_buffer_shards)
]
# Start the learner thread.
learner_thread = LearnerThread(workers.local_worker())
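A compact, self-contained sketch of the same branching logic with a toy actor (the `DummyReplayActor` class and the sizes are illustrative assumptions; note the trailing `[0]`, which selects the actor list of the single spec passed in):

import platform
import ray
from ray.rllib.utils.actors import create_colocated_actors

@ray.remote(num_cpus=0)
class DummyReplayActor:
    def __init__(self, capacity):
        self.capacity = capacity
    def get_host(self):
        # `create_colocated_actors` checks placement via this method.
        return platform.node()

ray.init()
colocate_with_driver = True
if colocate_with_driver:
    # Force all shards onto the driver's (local) node.
    shards = create_colocated_actors(
        actor_specs=[(DummyReplayActor, [1000], {}, 4)],
        node=platform.node(),
    )[0]  # [0]: only one spec was passed in.
else:
    # Let Ray place the shards on any node(s).
    shards = [DummyReplayActor.remote(1000) for _ in range(4)]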

View file

@@ -11,9 +11,11 @@ import os
import pickle
import tempfile
import time
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Callable, DefaultDict, Dict, List, Optional, Set, Tuple, \
Type, Union
import ray
from ray.actor import ActorHandle
from ray.exceptions import RayError
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env.env_context import EnvContext
@@ -722,8 +724,9 @@ class Trainer(Trainable):
self._episode_history = []
self._episodes_to_be_collected = []
# Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
self.evaluation_workers = None
# Evaluation WorkerSet.
self.evaluation_workers: Optional[WorkerSet] = None
# Metrics most recently returned by `self.evaluate()`.
self.evaluation_metrics = {}
super().__init__(config, logger_creator, remote_checkpoint_dir,
@@ -798,12 +801,19 @@ class Trainer(Trainable):
self.local_replay_buffer = (
self._create_local_replay_buffer_if_necessary(self.config))
# Create a dict, mapping ActorHandles to sets of open remote
# requests (object refs). This way, we keep track of how many
# requests (e.g. `sample()` calls) have already been sent to each
# actor inside this Trainer (e.g. a remote RolloutWorker).
self.remote_requests_in_flight: \
DefaultDict[ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
# Deprecated way of implementing Trainer sub-classes (or "templates"
# via the soon-to-be deprecated `build_trainer` utility function).
# Instead, sub-classes should override the Trainable's `setup()`
# method and call super().setup() from within that override at some
# point.
self.workers = None
self.workers: Optional[WorkerSet] = None
self.train_exec_impl = None
# Old design: Override `Trainer._init` (or use `build_trainer()`, which
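A runnable mini-sketch of the bookkeeping this dict enables (the `Worker` actor below is a stand-in assumption for a remote RolloutWorker):

from collections import defaultdict
import ray

@ray.remote
class Worker:
    def sample(self):
        return "batch"

ray.init()
# Same structure the Trainer now builds in setup().
remote_requests_in_flight = defaultdict(set)
w = Worker.remote()
req = w.sample.remote()
remote_requests_in_flight[w].add(req)         # request sent -> in-flight
ready, _ = ray.wait([req], num_returns=1)
for ref in ready:
    remote_requests_in_flight[w].remove(ref)  # result ready -> done
print(ray.get(ready))                         # -> ['batch']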
@@ -845,13 +855,10 @@ class Trainer(Trainable):
self.workers, self.config,
**self._kwargs_for_execution_plan())
# TODO: Now that workers have been created, update our policy
# specs in the config[multiagent] dict with the correct spaces.
# However, this leads to a problem with the evaluation
# workers' observation one-hot preprocessor in
# `examples/documentation/rllib_in_6sec.py` script.
# self.config["multiagent"]["policies"] = \
# self.workers.local_worker().policy_map.policy_specs
# Now that workers have been created, update our policy
# specs in the config[multiagent] dict with the correct spaces.
self.config["multiagent"]["policies"] = \
self.workers.local_worker().policy_dict
# Evaluation WorkerSet setup.
# User would like to setup a separate evaluation worker set.
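As a result of the change above (the inferred policy specs are now written back into config["multiagent"]["policies"]), the auto-detected spaces can be read off the Trainer after setup. A hedged sketch (PPO, the env, and the default policy ID are assumptions, not part of this PR):

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(config={"env": "CartPole-v0", "num_workers": 0})
# Each entry is a PolicySpec whose spaces have been filled in.
spec = trainer.config["multiagent"]["policies"]["default_policy"]
print(spec.observation_space, spec.action_space)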
@@ -912,7 +919,7 @@ class Trainer(Trainable):
# If evaluation_num_workers=0, use the evaluation set's local
# worker for evaluation, otherwise, use its remote workers
# (parallelized evaluation).
self.evaluation_workers = self._make_workers(
self.evaluation_workers: WorkerSet = self._make_workers(
env_creator=self.env_creator,
validate_env=None,
policy_class=self.get_default_policy_class(self.config),

View file

@@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \
TYPE_CHECKING, Union
import ray
from ray import ObjectRef
from ray import cloudpickle as pickle
from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
from ray.rllib.env.env_context import EnvContext
@@ -537,16 +538,16 @@ class RolloutWorker(ParallelIteratorWorker):
self.make_sub_env_fn = make_sub_env
self.spaces = spaces
policy_dict = _determine_spaces_for_multi_agent_dict(
self.policy_dict = _determine_spaces_for_multi_agent_dict(
policy_spec,
self.env,
spaces=self.spaces,
policy_config=policy_config)
# List of IDs of those policies, which should be trained.
# By default, these are all policies found in the policy_dict.
# By default, these are all policies found in `self.policy_dict`.
self.policies_to_train: List[PolicyID] = policies_to_train or list(
policy_dict.keys())
self.policy_dict.keys())
self.set_policies_to_train(self.policies_to_train)
self.policy_map: PolicyMap = None
@@ -583,7 +584,7 @@ class RolloutWorker(ParallelIteratorWorker):
f"is ignored.")
self._build_policy_map(
policy_dict,
self.policy_dict,
policy_config,
session_creator=tf_session_creator,
seed=seed)
@@ -1111,7 +1112,7 @@ class RolloutWorker(ParallelIteratorWorker):
"""
if policy_id in self.policy_map:
raise ValueError(f"Policy ID '{policy_id}' already in policy map!")
policy_dict = _determine_spaces_for_multi_agent_dict(
policy_dict_to_add = _determine_spaces_for_multi_agent_dict(
{
policy_id: PolicySpec(policy_cls, observation_space,
action_space, config or {})
@@ -1120,8 +1121,9 @@ class RolloutWorker(ParallelIteratorWorker):
spaces=self.spaces,
policy_config=self.policy_config,
)
self.policy_dict.update(policy_dict_to_add)
self._build_policy_map(
policy_dict,
policy_dict_to_add,
self.policy_config,
seed=self.policy_config.get("seed"))
new_policy = self.policy_map[policy_id]
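Keeping `self.policy_dict` up to date is what lets policies be added to a running Trainer. A hedged sketch (the `trainer` instance, the existing policy ID "main", and the new ID are assumptions):

# `trainer` is an already-built multi-agent Trainer (assumption).
new_policy = trainer.add_policy(
    policy_id="learned_opponent",
    policy_cls=type(trainer.get_policy("main")),
    # observation_space/action_space may be omitted; they are then
    # inferred and merged into the worker's `policy_dict` as above.
)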
@@ -1386,6 +1388,14 @@ class RolloutWorker(ParallelIteratorWorker):
>>> # Set `global_vars` (timestep) as well.
>>> worker.set_weights(weights, {"timestep": 42})
"""
# If per-policy weights are object refs, `ray.get()` them first.
if weights and isinstance(next(iter(weights.values())), ObjectRef):
actual_weights = ray.get(list(weights.values()))
weights = {
pid: actual_weights[i]
for i, pid in enumerate(weights.keys())
}
for pid, w in weights.items():
self.policy_map[pid].set_weights(w)
if global_vars:
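This check lets callers put each policy's weights into the object store once and broadcast only the (small) references. A hedged sketch (the `trainer` instance is an assumption):

import ray

# Serialize each policy's weights into the object store exactly once ...
weights = trainer.workers.local_worker().get_weights()
weights_refs = {pid: ray.put(w) for pid, w in weights.items()}
# ... then ship only the refs; each worker resolves them via the
# `ray.get()` branch in `set_weights()` shown above.
for worker in trainer.workers.remote_workers():
    worker.set_weights.remote(weights_refs)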

View file

@@ -1,10 +1,10 @@
import logging
from typing import List, Tuple
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, \
TYPE_CHECKING
import ray
from ray.util.iter import from_actors, LocalIterator
from ray.util.iter_metrics import SharedMetrics
from ray.actor import ActorHandle
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
@@ -12,27 +12,200 @@ from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
_check_sample_batch_type, _get_shared_metrics
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
LEARNER_STATS_KEY
from ray.rllib.utils.sgd import standardized
from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients
from ray.util.iter import from_actors, LocalIterator
from ray.util.iter_metrics import SharedMetrics
if TYPE_CHECKING:
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.rollout_worker import RolloutWorker
logger = logging.getLogger(__name__)
def synchronous_parallel_sample(workers: WorkerSet) -> List[SampleBatch]:
@ExperimentalAPI
def synchronous_parallel_sample(
worker_set: WorkerSet,
remote_fn: Optional[Callable[["RolloutWorker"], None]] = None,
) -> List[SampleBatch]:
"""Runs parallel and synchronous rollouts on all remote workers.
Waits for all workers to return from the remote calls.
If no remote workers exist (num_workers == 0), use the local worker
for sampling.
As an alternative to calling `worker.sample.remote()`, the user can provide a
`remote_fn()`, which will be applied to the worker(s) instead.
Args:
worker_set: The WorkerSet to use for sampling.
remote_fn: If provided, use `worker.apply.remote(remote_fn)` instead
of `worker.sample.remote()` to generate the requests.
Returns:
The list of collected sample batch types (one for each parallel
rollout worker in the given `worker_set`).
Examples:
>>> # 2 remote workers (num_workers=2):
>>> batches = synchronous_parallel_sample(trainer.workers)
>>> print(len(batches))
... 2
>>> print(batches[0])
... SampleBatch(16: ['obs', 'actions', 'rewards', 'dones'])
>>> # 0 remote workers (num_workers=0): Using the local worker.
>>> batches = synchronous_parallel_sample(trainer.workers)
>>> print(len(batches))
... 1
"""
# No remote workers in the set -> Use local worker for collecting
# samples.
if not workers.remote_workers():
return [workers.local_worker().sample()]
if not worker_set.remote_workers():
return [worker_set.local_worker().sample()]
# Loop over remote workers' `sample()` method in parallel.
sample_batches = ray.get(
[r.sample.remote() for r in workers.remote_workers()])
[r.sample.remote() for r in worker_set.remote_workers()])
# Return all collected batches.
return sample_batches
# TODO: Move to generic parallel ops module and rename to
# `asynchronous_parallel_requests`:
@ExperimentalAPI
def asynchronous_parallel_sample(
trainer: "Trainer",
actors: List[ActorHandle],
ray_wait_timeout_s: Optional[float] = None,
max_remote_requests_in_flight_per_actor: int = 2,
remote_fn: Optional[Callable[["RolloutWorker"], None]] = None,
remote_args: Optional[List[List[Any]]] = None,
remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Optional[List[SampleBatch]]:
"""Runs parallel and asynchronous rollouts on all remote workers.
May use a timeout (if provided) on `ray.wait()` and returns only those
samples that could be gathered in the timeout window. Allows a maximum
of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight
per remote actor.
As an alternative to calling `actor.sample.remote()`, the user can provide a
`remote_fn()`, which will be applied to the actor(s) instead.
Args:
trainer: The Trainer object that we run the sampling for.
actors: The List of ActorHandles to perform the remote requests on.
ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
`ray.wait()` calls. If None (default), never time out (block
until at least one actor returns something).
max_remote_requests_in_flight_per_actor: Maximum number of remote
requests sent to each actor. 2 (default) is probably
sufficient to avoid idle times between two requests.
remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
`actor.sample.remote()` to generate the requests.
remote_args: If provided, use this list (per-actor) of lists (call
args) as *args to be passed to the `remote_fn`.
E.g.: actors=[A, B],
remote_args=[[...] <- *args for A, [...] <- *args for B].
remote_kwargs: If provided, use this list (per-actor) of dicts
(kwargs) as **kwargs to be passed to the `remote_fn`.
E.g.: actors=[A, B],
remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].
Returns:
The list of asynchronously collected sample batch types. None, if no
samples are ready.
Examples:
>>> # 2 remote rollout workers (num_workers=2):
>>> batches = asynchronous_parallel_sample(
... trainer,
... actors=trainer.workers.remote_workers(),
... ray_wait_timeout_s=0.1,
... remote_fn=lambda w: time.sleep(1) # sleep 1sec
... )
>>> # The timeout (0.1s) is shorter than the (1s) sleep in
>>> # `remote_fn`, so nothing is ready yet and None is returned.
>>> print(batches)
... None
>>> # A later call (once the pending requests have completed)
>>> # returns the two results of the `remote_fn` calls.
"""
if remote_args is not None:
assert len(remote_args) == len(actors)
if remote_kwargs is not None:
assert len(remote_kwargs) == len(actors)
# Collect all currently pending remote requests into a single set of
# object refs.
pending_remotes = set()
# Also build a map to get the associated actor for each remote request.
remote_to_actor = {}
for actor, set_ in trainer.remote_requests_in_flight.items():
pending_remotes |= set_
for r in set_:
remote_to_actor[r] = actor
# Add new requests, if possible (if
# `max_remote_requests_in_flight_per_actor` setting allows it).
for actor_idx, actor in enumerate(actors):
# Still room for another request to this actor.
if len(trainer.remote_requests_in_flight[actor]) < \
max_remote_requests_in_flight_per_actor:
if remote_fn is None:
req = actor.sample.remote()
else:
args = remote_args[actor_idx] if remote_args else []
kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
req = actor.apply.remote(remote_fn, *args, **kwargs)
# Add to our set to send to ray.wait().
pending_remotes.add(req)
# Keep our mappings properly updated.
trainer.remote_requests_in_flight[actor].add(req)
remote_to_actor[req] = actor
# There must always be pending remote requests.
assert len(pending_remotes) > 0
pending_remote_list = list(pending_remotes)
# No timeout: Block until at least one result is returned.
if ray_wait_timeout_s is None:
# First, do a quick, non-blocking `ray.wait()` (timeout=0) for efficiency.
ready, _ = ray.wait(
pending_remote_list, num_returns=len(pending_remotes), timeout=0)
# Nothing returned and `timeout` is None -> Fall back to a
# blocking wait to make sure we can return something.
if not ready:
ready, _ = ray.wait(pending_remote_list, num_returns=1)
# Timeout given: Do a `ray.wait()` call w/ the given timeout.
else:
ready, _ = ray.wait(
pending_remote_list,
num_returns=len(pending_remotes),
timeout=ray_wait_timeout_s)
# Return None if nothing ready after the timeout.
if not ready:
return None
for obj_ref in ready:
# Remove in-flight record for this ref.
trainer.remote_requests_in_flight[remote_to_actor[obj_ref]].remove(
obj_ref)
remote_to_actor.pop(obj_ref)
results = ray.get(ready)
return results
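A hedged usage sketch of the new helper inside a (hypothetical) custom training loop; `trainer` is assumed to be a built Trainer with remote workers:

from ray.rllib.execution.rollout_ops import asynchronous_parallel_sample

# Keep collecting batches without ever blocking on the slowest worker;
# `trainer.remote_requests_in_flight` caps outstanding calls per actor.
collected = []
while len(collected) < 10:
    batches = asynchronous_parallel_sample(
        trainer,
        actors=trainer.workers.remote_workers(),
        ray_wait_timeout_s=0.01,
        max_remote_requests_in_flight_per_actor=2,
    )
    # None -> nothing became ready within the timeout window; retry.
    if batches is not None:
        collected.extend(batches)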
def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
num_async=1) -> LocalIterator[SampleBatch]:
"""Operator to collect experiences in parallel from rollout workers.

View file

@@ -9,7 +9,7 @@ from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.actors import create_colocated
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.typing import SampleBatchType, ModelWeights
from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \
from_actors, LocalIterator
@@ -91,11 +91,22 @@ def gather_experiences_tree_aggregation(workers: WorkerSet,
]
# This spawns |num_aggregation_workers| intermediate actors that aggregate
# experiences in parallel. We force colocation on the same node to maximize
# data bandwidth between them and the driver.
train_batches = from_actors([
create_colocated(Aggregator, [config, g], 1)[0] for g in rollout_groups
])
# experiences in parallel. We force colocation on the same node (localhost)
# to maximize data bandwidth between them and the driver.
localhost = platform.node()
assert localhost != "", \
"ERROR: Cannot determine local node name! " \
"`platform.node()` returned empty string."
all_co_located = create_colocated_actors(
actor_specs=[
# (class, args, kwargs={}, count=1)
(Aggregator, [config, g], {}, 1) for g in rollout_groups
],
node=localhost)
# Use the first ([0]) of each created group (each group only has one
# actor: count=1).
train_batches = from_actors([group[0] for group in all_co_located])
# TODO(ekl) properly account for replay.
def record_steps_sampled(batch):

View file

@@ -1,7 +1,11 @@
from collections import defaultdict, deque
import logging
import platform
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
import ray
from collections import deque
from ray.actor import ActorClass, ActorHandle
from ray.rllib.utils.deprecation import Deprecated
logger = logging.getLogger(__name__)
@@ -65,45 +69,202 @@ class TaskPool:
return len(self._tasks)
def drop_colocated(actors):
def create_colocated_actors(
actor_specs: Sequence[Tuple[Type, Any, Any, int]],
node: Optional[str] = "localhost",
max_attempts: int = 10,
) -> List[List[ActorHandle]]:
"""Create co-located actors of any type(s) on any node.
Args:
actor_specs: Tuple/list with tuples consisting of: 1) The
(already @ray.remote) class(es) to construct, 2) c'tor args,
3) c'tor kwargs, and 4) the number of actors of that class with
given args/kwargs to construct.
node: The node to co-locate the actors on. By default ("localhost"),
place the actors on the node the caller of this function is
located on. Use None for indicating that any (resource fulfilling)
node in the cluster may be used.
max_attempts: The maximum number of co-location attempts to
perform before throwing an error.
Returns:
A list with one item per spec in `actor_specs`; each item is the
list of n co-located ActorHandles created for that spec.
"""
if node == "localhost":
node = platform.node()
# Maps each entry in `actor_specs` to lists of already co-located actors.
ok = [[] for _ in range(len(actor_specs))]
# Try n times to co-locate all given actor types (`actor_specs`).
# With each (failed) attempt, increase the number of actors we try to
# create (on the same node), then kill the ones that have been created in
# excess.
for attempt in range(max_attempts):
# If any attempt to co-locate fails, set this to False and we'll do
# another attempt.
all_good = True
# Process all `actor_specs` in sequence.
for i, (typ, args, kwargs, count) in enumerate(actor_specs):
args = args or [] # Allow None.
kwargs = kwargs or {} # Allow None.
# We don't have enough actors yet of this spec co-located on
# the desired node.
if len(ok[i]) < count:
co_located = try_create_colocated(
cls=typ,
args=args,
kwargs=kwargs,
count=count * (attempt + 1),
node=node)
# If node did not matter (None), from here on, use the host
# that the first actor(s) are already co-located on.
if node is None:
node = ray.get(co_located[0].get_host.remote())
# Add the newly co-located actors to the `ok` list.
ok[i].extend(co_located)
# If we still don't have enough -> We'll have to do another
# attempt.
if len(ok[i]) < count:
all_good = False
# We created too many actors for this spec -> Kill/truncate
# the excess ones.
if len(ok[i]) > count:
for a in ok[i][count:]:
a.__ray_terminate__.remote()
ok[i] = ok[i][:count]
# All `actor_specs` have been fulfilled, return lists of
# co-located actors.
if all_good:
return ok
raise Exception("Unable to create enough colocated actors -> aborting.")
def try_create_colocated(
cls: Type[ActorClass],
args: List[Any],
count: int,
kwargs: Optional[Dict[str, Any]] = None,
node: Optional[str] = "localhost",
) -> List[ActorHandle]:
"""Tries to co-locate (same node) a set of Actors of the same type.
Returns a list of successfully co-located actors. All actors that could
not be co-located (with the others on the given node) will not be in this
list.
Creates each actor via its remote() constructor and then checks whether
it has been co-located (on the same node) with the other (already created)
ones. If not, terminates the just created actor.
Args:
cls: The Actor class to use (already @ray.remote "converted").
args: List of args to pass to each actor's constructor (the same
args are used for all `count` actors to create).
count: Number of actors of the given `cls` to construct.
kwargs: Optional dict of kwargs to pass to each actor's constructor
(the same kwargs are used for all `count` actors to create).
node: The node to co-locate the actors on. By default ("localhost"),
place the actors on the node the caller of this function is
located on. If None, will try to co-locate all actors on
any available node.
Returns:
List containing all successfully co-located actor handles.
"""
if node == "localhost":
node = platform.node()
kwargs = kwargs or {}
actors = [cls.remote(*args, **kwargs) for _ in range(count)]
co_located, non_co_located = split_colocated(actors, node=node)
logger.info("Got {} colocated actors of {}".format(len(co_located), count))
for a in non_co_located:
a.__ray_terminate__.remote()
return co_located
def split_colocated(
actors: List[ActorHandle],
node: Optional[str] = "localhost",
) -> Tuple[List[ActorHandle], List[ActorHandle]]:
Splits up the given actors into co-located (on the same node) and non-co-located ones.
The co-location criterion depends on the `node` given:
If given (or default: platform.node()): Consider all actors that are on
that node "colocated".
If None: Consider the largest sub-set of actors that are all located on
the same node (whatever that node is) as "colocated".
Args:
actors: The list of actor handles to split into "colocated" and
"non colocated".
node: The node defining the "colocation" criterion. If provided,
consider those actors "colocated" that sit on this node. If None,
use the largest subset within `actors` that are sitting on the same
(any) node.
Returns:
Tuple of two lists: 1) Co-located ActorHandles, 2) non co-located
ActorHandles.
"""
if node == "localhost":
node = platform.node()
# Get nodes of all created actors.
hosts = ray.get([a.get_host.remote() for a in actors])
# If `node` not provided, use the largest group of actors that sit on the
# same node, regardless of what that node is.
if node is None:
node_groups = defaultdict(set)
for host, actor in zip(hosts, actors):
node_groups[host].add(actor)
max_ = -1
largest_group = None
for host in node_groups:
if max_ < len(node_groups[host]):
max_ = len(node_groups[host])
largest_group = host
non_co_located = []
for host in node_groups:
if host != largest_group:
non_co_located.extend(list(node_groups[host]))
return list(node_groups[largest_group]), non_co_located
# Node provided (or default: localhost): Consider those actors "colocated"
# that were placed on `node`.
else:
# Split into co-located (on `node`) and non-co-located (not on `node`).
co_located = []
non_co_located = []
for host, a in zip(hosts, actors):
# This actor has been placed on the correct node.
if host == node:
co_located.append(a)
# This actor has been placed on a different node.
else:
non_co_located.append(a)
return co_located, non_co_located
@Deprecated(new="create_colocated_actors", error=False)
def create_colocated(cls, arg, count):
kwargs = {}
args = arg
return create_colocated_actors(
actor_specs=[(cls, args, kwargs, count)],
node=platform.node(), # force on localhost
)[0]  # Only one spec was passed in -> Return its list of actors.
@Deprecated(error=False)
def drop_colocated(actors: List[ActorHandle]) -> List[ActorHandle]:
colocated, non_colocated = split_colocated(actors)
for a in colocated:
a.__ray_terminate__.remote()
return non_colocated
def split_colocated(actors):
localhost = platform.node()
hosts = ray.get([a.get_host.remote() for a in actors])
local = []
non_local = []
for host, a in zip(hosts, actors):
if host == localhost:
local.append(a)
else:
non_local.append(a)
return local, non_local
def try_create_colocated(cls, args, count):
actors = [cls.remote(*args) for _ in range(count)]
local, rest = split_colocated(actors)
logger.info("Got {} colocated actors of {}".format(len(local), count))
for a in rest:
a.__ray_terminate__.remote()
return local
def create_colocated(cls, args, count):
logger.info("Trying to create {} colocated actors".format(count))
ok = []
i = 1
while len(ok) < count and i < 10:
attempt = try_create_colocated(cls, args, count * i)
ok.extend(attempt)
i += 1
if len(ok) < count:
raise Exception("Unable to create enough colocated actors, abort.")
for a in ok[count:]:
a.__ray_terminate__.remote()
return ok[:count]