mirror of https://github.com/vale981/ray (synced 2025-03-05 10:01:43 -05:00)
commit ee41800c16 (parent: 8ebc50f844)
12 changed files with 596 additions and 170 deletions
rllib/execution/buffers/mixin_replay_buffer.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import collections
import platform
import random
from typing import Optional

from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType


class MixInMultiAgentReplayBuffer:
    """This buffer adds replayed samples to a stream of new experiences.

    - Any newly added batch (`add_batch()`) is immediately returned upon
      the next `replay` call (close to on-policy) as well as being moved
      into the buffer.
    - Additionally, a certain number of old samples is mixed into the
      returned sample according to a given "replay ratio".
    - If >1 calls to `add_batch()` are made without any `replay()` calls
      in between, all newly added batches are returned (plus some older
      samples according to the "replay ratio").

    Examples:
        # replay ratio 0.66 (2/3 replayed, 1/3 new samples):
        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
        ...                                      replay_ratio=0.66)
        >>> buffer.add_batch(<A>)
        >>> buffer.add_batch(<B>)
        >>> buffer.replay()
        ... [<A>, <B>, <B>]
        >>> buffer.add_batch(<C>)
        >>> buffer.replay()
        ... [<C>, <A>, <B>]
        >>> # or: [<C>, <A>, <A>] or [<C>, <B>, <B>], but always <C> as it
        >>> # is the newest sample

        >>> buffer.add_batch(<D>)
        >>> buffer.replay()
        ... [<D>, <A>, <C>]

        # replay ratio 0.0 -> replay disabled:
        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
        ...                                      replay_ratio=0.0)
        >>> buffer.add_batch(<A>)
        >>> buffer.replay()
        ... [<A>]
        >>> buffer.add_batch(<B>)
        >>> buffer.replay()
        ... [<B>]
    """

    def __init__(self, capacity: int, replay_ratio: float):
        """Initializes MixInReplay instance.

        Args:
            capacity (int): Number of batches to store in total.
            replay_ratio (float): Ratio of replayed samples in the returned
                batches. E.g. a ratio of 0.0 means only return new samples
                (no replay), a ratio of 0.5 means always return the newest
                sample plus one old one (1:1), a ratio of 0.66 means always
                return the newest sample plus 2 old (replayed) ones (1:2),
                etc.
        """
        self.capacity = capacity
        self.replay_ratio = replay_ratio
        self.replay_proportion = None
        if self.replay_ratio != 1.0:
            self.replay_proportion = self.replay_ratio / (
                1.0 - self.replay_ratio)

        def new_buffer():
            return SimpleReplayBuffer(num_slots=capacity)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics.
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()

        # Added timesteps over lifetime.
        self.num_added = 0

        # Last added batch(es).
        self.last_added_batches = collections.defaultdict(list)

    def add_batch(self, batch: SampleBatchType) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch. Subsequently adds the individual policy
        batches to the storage.

        Args:
            batch: The batch to be added.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            for policy_id, sample_batch in batch.policy_batches.items():
                self.replay_buffers[policy_id].add_batch(sample_batch)
                self.last_added_batches[policy_id].append(sample_batch)
        self.num_added += batch.count

    def replay(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
            Optional[SampleBatchType]:
        buffer = self.replay_buffers[policy_id]
        # Return None, if:
        # - Buffer empty or
        # - `replay_ratio` < 1.0 (new samples required in returned batch)
        #   and no new samples to mix with replayed ones.
        if len(buffer) == 0 or (len(self.last_added_batches[policy_id]) == 0
                                and self.replay_ratio < 1.0):
            return None

        # Mix buffer's last added batches with older replayed batches.
        with self.replay_timer:
            output_batches = self.last_added_batches[policy_id]
            self.last_added_batches[policy_id] = []

            # No replay desired -> Return here.
            if self.replay_ratio == 0.0:
                return SampleBatch.concat_samples(output_batches)
            # Only replay desired -> Return a (replayed) sample from the
            # buffer.
            elif self.replay_ratio == 1.0:
                return buffer.replay()

            # Replay ratio = old / [old + new]
            # Replay proportion: old / new
            num_new = len(output_batches)
            replay_proportion = self.replay_proportion
            while random.random() < num_new * replay_proportion:
                replay_proportion -= 1
                output_batches.append(buffer.replay())
            return SampleBatch.concat_samples(output_batches)

    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's network name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()
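The replay-ratio mixing in `replay()` above is stochastic: with `num_new` freshly added batches and `replay_proportion = replay_ratio / (1 - replay_ratio)`, replayed batches are appended while a uniform draw stays below `num_new * replay_proportion`, and the proportion is decremented after each draw. The following standalone sketch (plain Python, no RLlib imports; the helper name `expected_mix` is made up for illustration) reproduces just this rule and shows the resulting average output sizes:

import random

def expected_mix(replay_ratio, num_new=1, trials=100_000):
    """Average replay() output size under the mixing rule sketched above."""
    proportion = replay_ratio / (1.0 - replay_ratio)
    total = 0
    for _ in range(trials):
        p, n_old = proportion, 0
        # Same loop as in MixInMultiAgentReplayBuffer.replay().
        while random.random() < num_new * p:
            p -= 1
            n_old += 1
        total += num_new + n_old
    return total / trials

print(expected_mix(0.5))   # ~2.0: 1 new batch + ~1 replayed batch
print(expected_mix(0.9))   # ~10.0: 1 new batch + ~9 replayed batches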
@@ -1,6 +1,6 @@
 import collections
 import platform
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import numpy as np
 import ray
@@ -13,7 +13,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch
 from ray.rllib.utils import deprecation_warning
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.timer import TimerStat
-from ray.rllib.utils.typing import SampleBatchType
+from ray.rllib.utils.typing import PolicyID, SampleBatchType
 from ray.util.iter import ParallelIteratorWorker


@@ -195,7 +195,7 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker):
                            time_slice, weight=weight)
        self.num_added += batch.count

-    def replay(self) -> SampleBatchType:
+    def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
        """If this buffer was given a fake batch, return it, otherwise return
        a MultiAgentBatch with samples.
        """
@@ -211,8 +211,13 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker):
            # Lockstep mode: Sample from all policies at the same time an
            # equal amount of steps.
            if self.replay_mode == "lockstep":
+                assert policy_id is None, \
+                    "`policy_id` specifier not allowed in `lockstep` mode!"
                return self.replay_buffers[_ALL_POLICIES].sample(
                    self.replay_batch_size, beta=self.prioritized_replay_beta)
+            elif policy_id is not None:
+                return self.replay_buffers[policy_id].sample(
+                    self.replay_batch_size, beta=self.prioritized_replay_beta)
            else:
                samples = {}
                for policy_id, replay_buffer in self.replay_buffers.items():

@@ -132,19 +132,25 @@ class ReplayBuffer:

    @DeveloperAPI
    def sample(self, num_items: int, beta: float = 0.0) -> SampleBatchType:
-        """Sample a batch of experiences.
+        """Sample a batch of size `num_items` from this buffer.
+
+        If fewer than `num_items` records are in this buffer, some samples in
+        the results may be repeated to fulfil the batch size (`num_items`)
+        request.

        Args:
            num_items: Number of items to sample from this buffer.
-            beta: This is ignored (only used by prioritized replay buffers).
+            beta: The prioritized replay beta value. Only relevant if this
+                ReplayBuffer is a PrioritizedReplayBuffer.

        Returns:
            Concatenated batch of items.
        """
-        idxes = [
-            random.randint(0,
-                           len(self._storage) - 1) for _ in range(num_items)
-        ]
+        # If we don't have any samples yet in this buffer, return None.
+        if len(self) == 0:
+            return None
+
+        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
        sample = self._encode_sample(idxes)
        # Update our timesteps counters.
        self._num_timesteps_sampled += len(sample)

@@ -282,6 +288,10 @@ class PrioritizedReplayBuffer(ReplayBuffer):
                "batch_indexes" fields denoting IS of each sampled
                transition and original idxes in buffer of sampled experiences.
        """
+        # If we don't have any samples yet in this buffer, return None.
+        if len(self) == 0:
+            return None
+
        assert beta >= 0.0

        idxes = self._sample_proportional(num_items)
rllib/execution/parallel_requests.py (new file, 152 lines)
@@ -0,0 +1,152 @@
import logging
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set

import ray
from ray.actor import ActorHandle
from ray.rllib.utils.annotations import ExperimentalAPI

logger = logging.getLogger(__name__)


@ExperimentalAPI
def asynchronous_parallel_requests(
        remote_requests_in_flight: DefaultDict[ActorHandle, Set[
            ray.ObjectRef]],
        actors: List[ActorHandle],
        ray_wait_timeout_s: Optional[float] = None,
        max_remote_requests_in_flight_per_actor: int = 2,
        remote_fn: Optional[Callable[[ActorHandle, Any, Any], Any]] = None,
        remote_args: Optional[List[List[Any]]] = None,
        remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Dict[ActorHandle, Any]:
    """Runs parallel and asynchronous rollouts on all remote workers.

    May use a timeout (if provided) on `ray.wait()` and returns only those
    samples that could be gathered in the timeout window. Allows a maximum
    of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight
    per remote actor.

    Alternatively to calling `actor.sample.remote()`, the user can provide a
    `remote_fn()`, which will be applied to the actor(s) instead.

    Args:
        remote_requests_in_flight: Dict mapping actor handles to a set of
            their currently-in-flight pending requests (those we expect to
            ray.get results for next). If you have an RLlib Trainer that calls
            this function, you can use its `self.remote_requests_in_flight`
            property here.
        actors: The List of ActorHandles to perform the remote requests on.
        ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
            `ray.wait()` calls. If None (default), never time out (block
            until at least one actor returns something).
        max_remote_requests_in_flight_per_actor: Maximum number of remote
            requests sent to each actor. 2 (default) is probably
            sufficient to avoid idle times between two requests.
        remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
            `actor.sample.remote()` to generate the requests.
        remote_args: If provided, use this list (per-actor) of lists (call
            args) as *args to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_args=[[...] <- *args for A, [...] <- *args for B].
        remote_kwargs: If provided, use this list (per-actor) of dicts
            (kwargs) as **kwargs to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].

    Returns:
        A dict mapping actor handles to the results received by sending
        requests to these actors.
        None, if no samples are ready.

    Examples:
        >>> # 2 remote rollout workers (num_workers=2):
        >>> batches = asynchronous_parallel_requests(
        ...     trainer.remote_requests_in_flight,
        ...     actors=trainer.workers.remote_workers(),
        ...     ray_wait_timeout_s=0.1,
        ...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
        ... )
        >>> print(len(batches))
        ... 2
        >>> # Expect a timeout to have happened.
        >>> batches[0] is None and batches[1] is None
        ... True
    """

    if remote_args is not None:
        assert len(remote_args) == len(actors)
    if remote_kwargs is not None:
        assert len(remote_kwargs) == len(actors)

    # For faster hash lookup.
    actor_set = set(actors)

    # Collect all currently pending remote requests into a single set of
    # object refs.
    pending_remotes = set()
    # Also build a map to get the associated actor for each remote request.
    remote_to_actor = {}
    for actor, set_ in remote_requests_in_flight.items():
        # Only consider those actors' pending requests that are in
        # the given `actors` list.
        if actor in actor_set:
            pending_remotes |= set_
            for r in set_:
                remote_to_actor[r] = actor

    # Add new requests, if possible (if
    # `max_remote_requests_in_flight_per_actor` setting allows it).
    for actor_idx, actor in enumerate(actors):
        # Still room for another request to this actor.
        if len(remote_requests_in_flight[actor]) < \
                max_remote_requests_in_flight_per_actor:
            if remote_fn is None:
                req = actor.sample.remote()
            else:
                args = remote_args[actor_idx] if remote_args else []
                kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
                req = actor.apply.remote(remote_fn, *args, **kwargs)
            # Add to our set to send to ray.wait().
            pending_remotes.add(req)
            # Keep our mappings properly updated.
            remote_requests_in_flight[actor].add(req)
            remote_to_actor[req] = actor

    # There must always be pending remote requests.
    assert len(pending_remotes) > 0
    pending_remote_list = list(pending_remotes)

    # No timeout: Block until at least one result is returned.
    if ray_wait_timeout_s is None:
        # First try to do a `ray.wait` w/o timeout for efficiency.
        ready, _ = ray.wait(
            pending_remote_list, num_returns=len(pending_remotes), timeout=0)
        # Nothing returned and `timeout` is None -> Fall back to a
        # blocking wait to make sure we can return something.
        if not ready:
            ready, _ = ray.wait(pending_remote_list, num_returns=1)
    # Timeout: Do a `ray.wait()` call w/ timeout.
    else:
        ready, _ = ray.wait(
            pending_remote_list,
            num_returns=len(pending_remotes),
            timeout=ray_wait_timeout_s)

        # Return empty results if nothing ready after the timeout.
        if not ready:
            return {}

    # Remove in-flight records for ready refs.
    for obj_ref in ready:
        remote_requests_in_flight[remote_to_actor[obj_ref]].remove(obj_ref)

    # Do one ray.get().
    results = ray.get(ready)
    assert len(ready) == len(results)

    # Return mapping from (ready) actors to their results.
    ret = {}
    for obj_ref, result in zip(ready, results):
        ret[remote_to_actor[obj_ref]] = result

    return ret
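For orientation, here is a minimal sketch of the same in-flight bookkeeping pattern applied to a toy Ray actor (the `Doubler` actor and its `work()` method are made up for illustration; only core `ray.remote`, `ray.wait` and `ray.get` calls are used):

import collections
import ray

@ray.remote
class Doubler:
    def work(self, x):
        return 2 * x

ray.init()
actors = [Doubler.remote() for _ in range(2)]
in_flight = collections.defaultdict(set)

# Keep at most 2 requests in flight per actor, mirroring
# `max_remote_requests_in_flight_per_actor`.
for actor in actors:
    while len(in_flight[actor]) < 2:
        in_flight[actor].add(actor.work.remote(21))

# Collect whatever finished within 0.1s, like `ray_wait_timeout_s=0.1`.
pending = [ref for refs in in_flight.values() for ref in refs]
ready, _ = ray.wait(pending, num_returns=len(pending), timeout=0.1)
results = ray.get(ready)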
@@ -42,12 +42,12 @@ class StoreToReplayBuffer:
            actors: An optional list of replay actors to use instead of
                `local_buffer`.
        """
-        if bool(local_buffer) == bool(actors):
+        if local_buffer is not None and actors is not None:
            raise ValueError(
                "Either `local_buffer` or `replay_actors` must be given, "
                "not both!")

-        if local_buffer:
+        if local_buffer is not None:
            self.local_actor = local_buffer
            self.replay_actors = None
        else:
@@ -55,7 +55,7 @@ class StoreToReplayBuffer:
            self.replay_actors = actors

    def __call__(self, batch: SampleBatchType):
-        if self.local_actor:
+        if self.local_actor is not None:
            self.local_actor.add_batch(batch)
        else:
            actor = random.choice(self.replay_actors)
@@ -64,8 +64,8 @@ class StoreToReplayBuffer:


def Replay(*,
-           local_buffer: MultiAgentReplayBuffer = None,
-           actors: List[ActorHandle] = None,
+           local_buffer: Optional[MultiAgentReplayBuffer] = None,
+           actors: Optional[List[ActorHandle]] = None,
           num_async: int = 4) -> LocalIterator[SampleBatchType]:
    """Replay experiences from the given buffer or actors.

@@ -87,11 +87,11 @@ def Replay(*,
        SampleBatch(...)
    """

-    if bool(local_buffer) == bool(actors):
+    if local_buffer is not None and actors is not None:
        raise ValueError(
            "Exactly one of local_buffer and replay_actors must be given.")

-    if actors:
+    if actors is not None:
        replay = from_actors(actors)
        return replay.gather_async(
            num_async=num_async).filter(lambda x: x is not None)
@@ -135,6 +135,8 @@ class SimpleReplayBuffer:
        self.replay_batches = []
        self.replay_index = 0

+        self.last_added_batches = []
+
    def add_batch(self, sample_batch: SampleBatchType) -> None:
        warn_replay_capacity(item=sample_batch, num_items=self.num_slots)
        if self.num_slots > 0:
@@ -145,9 +147,14 @@ class SimpleReplayBuffer:
            self.replay_index += 1
            self.replay_index %= self.num_slots

+        self.last_added_batches.append(sample_batch)
+
    def replay(self) -> SampleBatchType:
        return random.choice(self.replay_batches)

+    def __len__(self):
+        return len(self.replay_batches)
+

class MixInReplay:
    """This operator adds replay to a stream of experiences.
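Note that the argument check above changes behavior, not just style: `bool(local_buffer) == bool(actors)` raised both when both arguments were given and when neither (or a falsy buffer) was given, whereas the `is not None` form only rejects the both-given case. A small standalone comparison (the two helper functions below are illustrative only, not part of the change):

def old_check(local_buffer, actors):
    # Old condition: True -> raise.
    return bool(local_buffer) == bool(actors)

def new_check(local_buffer, actors):
    # New condition: True -> raise.
    return local_buffer is not None and actors is not None

assert old_check("buf", ["actor"]) and new_check("buf", ["actor"])  # both given
assert old_check(None, None) and not new_check(None, None)          # neither given
assert old_check([], None) and not new_check([], None)              # empty-but-valid buffer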
@@ -1,10 +1,9 @@
import logging
import time
-from typing import Any, Callable, Container, Dict, List, Optional, Tuple, \
+from typing import Callable, Container, List, Optional, Tuple, \
    TYPE_CHECKING

import ray
-from ray.actor import ActorHandle
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
@@ -21,7 +20,6 @@ from ray.util.iter import from_actors, LocalIterator
from ray.util.iter_metrics import SharedMetrics

if TYPE_CHECKING:
    from ray.rllib.agents.trainer import Trainer
-    from ray.rllib.evaluation.rollout_worker import RolloutWorker

logger = logging.getLogger(__name__)
@@ -77,135 +75,6 @@ def synchronous_parallel_sample(
    return sample_batches


-# TODO: Move to generic parallel ops module and rename to
-# `asynchronous_parallel_requests`:
-@ExperimentalAPI
-def asynchronous_parallel_sample(
-        trainer: "Trainer",
-        actors: List[ActorHandle],
-        ray_wait_timeout_s: Optional[float] = None,
-        max_remote_requests_in_flight_per_actor: int = 2,
-        remote_fn: Optional[Callable[["RolloutWorker"], None]] = None,
-        remote_args: Optional[List[List[Any]]] = None,
-        remote_kwargs: Optional[List[Dict[str, Any]]] = None,
-) -> Optional[List[SampleBatch]]:
-    """Runs parallel and asynchronous rollouts on all remote workers.
-
-    May use a timeout (if provided) on `ray.wait()` and returns only those
-    samples that could be gathered in the timeout window. Allows a maximum
-    of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight
-    per remote actor.
-
-    Alternatively to calling `actor.sample.remote()`, the user can provide a
-    `remote_fn()`, which will be applied to the actor(s) instead.
-
-    Args:
-        trainer: The Trainer object that we run the sampling for.
-        actors: The List of ActorHandles to perform the remote requests on.
-        ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
-            `ray.wait()` calls. If None (default), never time out (block
-            until at least one actor returns something).
-        max_remote_requests_in_flight_per_actor: Maximum number of remote
-            requests sent to each actor. 2 (default) is probably
-            sufficient to avoid idle times between two requests.
-        remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
-            `actor.sample.remote()` to generate the requests.
-        remote_args: If provided, use this list (per-actor) of lists (call
-            args) as *args to be passed to the `remote_fn`.
-            E.g.: actors=[A, B],
-            remote_args=[[...] <- *args for A, [...] <- *args for B].
-        remote_kwargs: If provided, use this list (per-actor) of dicts
-            (kwargs) as **kwargs to be passed to the `remote_fn`.
-            E.g.: actors=[A, B],
-            remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].
-
-    Returns:
-        The list of asynchronously collected sample batch types. None, if no
-        samples are ready.
-
-    Examples:
-        >>> # 2 remote rollout workers (num_workers=2):
-        >>> batches = asynchronous_parallel_sample(
-        ...     trainer,
-        ...     actors=trainer.workers.remote_workers(),
-        ...     ray_wait_timeout_s=0.1,
-        ...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
-        ... )
-        >>> print(len(batches))
-        ... 2
-        >>> # Expect a timeout to have happened.
-        >>> batches[0] is None and batches[1] is None
-        ... True
-    """
-
-    if remote_args is not None:
-        assert len(remote_args) == len(actors)
-    if remote_kwargs is not None:
-        assert len(remote_kwargs) == len(actors)
-
-    # Collect all currently pending remote requests into a single set of
-    # object refs.
-    pending_remotes = set()
-    # Also build a map to get the associated actor for each remote request.
-    remote_to_actor = {}
-    for actor, set_ in trainer.remote_requests_in_flight.items():
-        pending_remotes |= set_
-        for r in set_:
-            remote_to_actor[r] = actor
-
-    # Add new requests, if possible (if
-    # `max_remote_requests_in_flight_per_actor` setting allows it).
-    for actor_idx, actor in enumerate(actors):
-        # Still room for another request to this actor.
-        if len(trainer.remote_requests_in_flight[actor]) < \
-                max_remote_requests_in_flight_per_actor:
-            if remote_fn is None:
-                req = actor.sample.remote()
-            else:
-                args = remote_args[actor_idx] if remote_args else []
-                kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
-                req = actor.apply.remote(remote_fn, *args, **kwargs)
-            # Add to our set to send to ray.wait().
-            pending_remotes.add(req)
-            # Keep our mappings properly updated.
-            trainer.remote_requests_in_flight[actor].add(req)
-            remote_to_actor[req] = actor
-
-    # There must always be pending remote requests.
-    assert len(pending_remotes) > 0
-    pending_remote_list = list(pending_remotes)
-
-    # No timeout: Block until at least one result is returned.
-    if ray_wait_timeout_s is None:
-        # First try to do a `ray.wait` w/o timeout for efficiency.
-        ready, _ = ray.wait(
-            pending_remote_list, num_returns=len(pending_remotes), timeout=0)
-        # Nothing returned and `timeout` is None -> Fall back to a
-        # blocking wait to make sure we can return something.
-        if not ready:
-            ready, _ = ray.wait(pending_remote_list, num_returns=1)
-    # Timeout: Do a `ray.wait()` call w/ timeout.
-    else:
-        ready, _ = ray.wait(
-            pending_remote_list,
-            num_returns=len(pending_remotes),
-            timeout=ray_wait_timeout_s)
-
-        # Return None if nothing ready after the timeout.
-        if not ready:
-            return None
-
-    for obj_ref in ready:
-        # Remove in-flight record for this ref.
-        trainer.remote_requests_in_flight[remote_to_actor[obj_ref]].remove(
-            obj_ref)
-        remote_to_actor.pop(obj_ref)
-
-    results = ray.get(ready)
-
-    return results
-
-
def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
                     num_async=1) -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.
rllib/execution/tests/test_mixin_multi_agent_replay_buffer.py (new file, 151 lines)
@@ -0,0 +1,151 @@
import numpy as np
import unittest

from ray.rllib.execution.buffers.mixin_replay_buffer import \
    MixInMultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch


class TestMixInMultiAgentReplayBuffer(unittest.TestCase):
    """Tests insertion and mixed sampling of the MixInMultiAgentReplayBuffer.
    """

    capacity = 10

    def _generate_data(self):
        return SampleBatch({
            "obs": [np.random.random((4, ))],
            "action": [np.random.choice([0, 1])],
            "reward": [np.random.rand()],
            "new_obs": [np.random.random((4, ))],
            "done": [np.random.choice([False, True])],
        })

    def test_mixin_sampling(self):
        # 50% replay ratio.
        buffer = MixInMultiAgentReplayBuffer(
            capacity=self.capacity, replay_ratio=0.5)
        # Add a new batch.
        batch = self._generate_data()
        buffer.add_batch(batch)
        # Expect at least 1 sample to be returned.
        sample = buffer.replay()
        self.assertTrue(len(sample) >= 1)
        # If we insert and replay n times, expect returned batches of roughly
        # len 2 (replay_ratio=0.5 -> 50% replayed samples -> 1 new and 1 old
        # sample on average in each returned value).
        results = []
        for _ in range(100):
            buffer.add_batch(batch)
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 2.0)

        # 33% replay ratio.
        buffer = MixInMultiAgentReplayBuffer(
            capacity=self.capacity, replay_ratio=0.333)
        # Expect exactly 0 samples to be returned (buffer empty).
        sample = buffer.replay()
        self.assertTrue(sample is None)
        # Add a new batch.
        batch = self._generate_data()
        buffer.add_batch(batch)
        # Expect at least 1 sample to be returned.
        sample = buffer.replay()
        self.assertTrue(len(sample) >= 1)
        # If we insert 2x and replay n times, expect returned batches of
        # roughly len 3 (replay_ratio=0.33 -> 33% replayed samples -> 2 new
        # and 1 old sample on average in each returned value).
        results = []
        for _ in range(100):
            buffer.add_batch(batch)
            buffer.add_batch(batch)
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 3.0, delta=0.1)

        # If we insert 1x and replay n times, expect returned batches of
        # roughly len 1.5 (replay_ratio=0.33 -> 33% replayed samples -> 1 new
        # and 0.5 old samples on average in each returned value).
        results = []
        for _ in range(100):
            buffer.add_batch(batch)
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 1.5, delta=0.1)

        # 90% replay ratio.
        buffer = MixInMultiAgentReplayBuffer(
            capacity=self.capacity, replay_ratio=0.9)
        # Expect exactly 0 samples to be returned (buffer empty).
        sample = buffer.replay()
        self.assertTrue(sample is None)
        # Add a new batch.
        batch = self._generate_data()
        buffer.add_batch(batch)
        # Expect at least 2 samples to be returned (new one plus at least one
        # replay sample).
        sample = buffer.replay()
        self.assertTrue(len(sample) >= 2)
        # If we insert and replay n times, expect returned batches of roughly
        # len 10 (replay_ratio=0.9 -> 90% replayed samples -> 1 new and 9 old
        # samples on average in each returned value).
        results = []
        for _ in range(100):
            buffer.add_batch(batch)
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 10.0, delta=0.1)

        # 0% replay ratio -> Only new samples.
        buffer = MixInMultiAgentReplayBuffer(
            capacity=self.capacity, replay_ratio=0.0)
        # Add a new batch.
        batch = self._generate_data()
        buffer.add_batch(batch)
        # Expect exactly 1 sample to be returned.
        sample = buffer.replay()
        self.assertTrue(len(sample) == 1)
        # Expect exactly 0 samples to be returned (nothing new to be returned;
        # no replay allowed (replay_ratio=0.0)).
        sample = buffer.replay()
        self.assertTrue(sample is None)
        # If we insert and replay n times, expect returned batches of roughly
        # len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and 0 old
        # samples on average in each returned value).
        results = []
        for _ in range(100):
            buffer.add_batch(batch)
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 1.0)

        # 100% replay ratio -> Only old (replayed) samples.
        buffer = MixInMultiAgentReplayBuffer(
            capacity=self.capacity, replay_ratio=1.0)
        # Expect exactly 0 samples to be returned (buffer empty).
        sample = buffer.replay()
        self.assertTrue(sample is None)
        # Add a new batch.
        batch = self._generate_data()
        buffer.add_batch(batch)
        # Expect exactly 1 sample to be returned (the new batch).
        sample = buffer.replay()
        self.assertTrue(len(sample) == 1)
        # Another replay -> Expect exactly 1 sample to be returned.
        sample = buffer.replay()
        self.assertTrue(len(sample) == 1)
        # If we replay n times, expect returned batches of roughly len 1
        # (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old
        # sample on average in each returned value).
        results = []
        for _ in range(100):
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 1.0)


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))
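The mean lengths asserted above follow from the closed form implied by the mixing rule: with k new batches per `replay()` call, the expected output length is roughly k / (1 - replay_ratio). A quick standalone check (the helper name `expected_len` is hypothetical):

def expected_len(replay_ratio, k=1):
    # k new batches plus, on average, k * replay_ratio / (1 - replay_ratio)
    # replayed ones.
    return k / (1.0 - replay_ratio)

assert abs(expected_len(0.5) - 2.0) < 1e-6          # 50% -> len 2
assert abs(expected_len(0.333, k=2) - 3.0) < 1e-2   # 33%, 2 inserts -> len 3
assert abs(expected_len(0.9) - 10.0) < 1e-6         # 90% -> len 10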
@@ -38,7 +38,7 @@ class SlimFC:

        # By default, use Glorot uniform initializer.
        if initializer is None:
-            initializer = flax.nn.initializers.xavier_uniform()
+            initializer = nn.initializers.xavier_uniform()

        self.prng_key = prng_key or jax.random.PRNGKey(int(time.time()))
        _, self.prng_key = jax.random.split(self.prng_key)
@@ -19,6 +19,7 @@ from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.metrics import NUM_AGENT_STEPS_TRAINED
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import normalize_action
@@ -296,7 +297,10 @@ def build_eager_tf_policy(

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
-            assert tf.executing_eagerly()
+            # If this class runs as a @ray.remote actor, eager mode may not
+            # have been activated yet.
+            if not tf1.executing_eagerly():
+                tf1.enable_eager_execution()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)

@@ -600,7 +604,10 @@ def build_eager_tf_policy(
            postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
            postprocessed_batch.set_training(True)
            stats = self._learn_on_batch_helper(postprocessed_batch)
-            stats.update({"custom_metrics": learn_stats})
+            stats.update({
+                "custom_metrics": learn_stats,
+                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
+            })
            return convert_to_numpy(stats)

        @override(Policy)
@@ -6,8 +6,11 @@ import logging
import numpy as np
import platform
import tree  # pip install dm_tree
-from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \
+    TYPE_CHECKING, Union

+import ray
+from ray.actor import ActorHandle
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
@@ -22,7 +25,7 @@ from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
    get_dummy_batch_for_space, unbatch
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
-    T, TensorType, TensorStructType, TrainerConfigDict, Tuple, Union
+    PolicyID, PolicyState, T, TensorType, TensorStructType, TrainerConfigDict

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@@ -451,6 +454,36 @@ class Policy(metaclass=ABCMeta):
        self.apply_gradients(grads)
        return grad_info

+    @ExperimentalAPI
+    def learn_on_batch_from_replay_buffer(
+            self, replay_actor: ActorHandle,
+            policy_id: PolicyID) -> Dict[str, TensorType]:
+        """Samples a batch from given replay actor and performs an update.
+
+        Args:
+            replay_actor: The replay buffer actor to sample from.
+            policy_id: The ID of this policy.
+
+        Returns:
+            Dictionary of extra metadata from `compute_gradients()`.
+        """
+        # Sample a batch from the given replay actor.
+        # Note that for better performance (less data sent through the
+        # network), this policy should be co-located on the same node
+        # as `replay_actor`. Such a co-location step is usually done during
+        # the Trainer's `setup()` phase.
+        batch = ray.get(replay_actor.replay.remote(policy_id=policy_id))
+        if batch is None:
+            return {}
+
+        # Send to own learn_on_batch method for updating.
+        # TODO: hack w/ `hasattr`
+        if hasattr(self, "devices") and len(self.devices) > 1:
+            self.load_batch_into_buffer(batch, buffer_index=0)
+            return self.learn_on_loaded_batch(offset=0, buffer_index=0)
+        else:
+            return self.learn_on_batch(batch)
+
    @DeveloperAPI
    def load_batch_into_buffer(self, batch: SampleBatch,
                               buffer_index: int = 0) -> int:
@@ -606,7 +639,7 @@ class Policy(metaclass=ABCMeta):
        return []

    @DeveloperAPI
-    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
+    def get_state(self) -> PolicyState:
        """Returns the entire current state of this Policy.

        Note: Not to be confused with an RNN model's internal state.
@@ -626,10 +659,7 @@ class Policy(metaclass=ABCMeta):
        return state

    @DeveloperAPI
-    def set_state(
-            self,
-            state: Union[Dict[str, TensorType], List[TensorType]],
-    ) -> None:
+    def set_state(self, state: PolicyState) -> None:
        """Restores the entire current state of this Policy from `state`.

        Args:
@@ -8,11 +8,15 @@ NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained"
LAST_TARGET_UPDATE_TS = "last_target_update_ts"
NUM_TARGET_UPDATES = "num_target_updates"

-# Performance timers (keys for metrics.timers).
+# Performance timers (keys for Trainer._timers or metrics.timers).
APPLY_GRADS_TIMER = "apply_grad"
COMPUTE_GRADS_TIMER = "compute_grads"
-WORKER_UPDATE_TIMER = "update"
+SYNCH_WORKER_WEIGHTS_TIMER = "synch_weights"
GRAD_WAIT_TIMER = "grad_wait"
SAMPLE_TIMER = "sample"
LEARN_ON_BATCH_TIMER = "learn"
LOAD_BATCH_TIMER = "load"
TARGET_NET_UPDATE_TIMER = "target_net_update"
+
+# Deprecated: Use `SYNCH_WORKER_WEIGHTS_TIMER` instead.
+WORKER_UPDATE_TIMER = "update"
@@ -2,27 +2,72 @@ import numpy as np


class WindowStat:
-    def __init__(self, name, n):
+    """Handles/stores an incoming data stream and provides window-based statistics.
+
+    Examples:
+        >>> win_stats = WindowStat("level", 3)
+        >>> win_stats.push(5.0)
+        >>> win_stats.push(7.0)
+        >>> win_stats.push(7.0)
+        >>> win_stats.push(10.0)
+        >>> # Expect 8.0 as the mean of the last 3 values: (7+7+10)/3=8.0
+        >>> print(win_stats.mean())
+        ... 8.0
+    """
+
+    def __init__(self, name: str, n: int):
+        """Initializes a WindowStat instance.
+
+        Args:
+            name: The name of the stats to collect and return stats for.
+            n: The window size. Statistics will be computed for the last n
+                items received from the stream.
+        """
+        # The window-size.
+        self.window_size = n
+        # The name of the data (used for `self.stats()`).
        self.name = name
-        self.items = [None] * n
+        # List of items to do calculations over (len=self.n).
+        self.items = [None] * self.window_size
+        # The current index to insert the next item into `self.items`.
        self.idx = 0
+        # How many items have been added over the lifetime of this object.
        self.count = 0

-    def push(self, obj):
+    def push(self, obj) -> None:
+        """Pushes a new value/object into the data buffer."""
+        # Insert object at current index.
        self.items[self.idx] = obj
+        # Increase insertion index by 1.
        self.idx += 1
+        # Increase lifetime count by 1.
        self.count += 1
+        # Fix index in case of rollover.
        self.idx %= len(self.items)

-    def stats(self):
+    def mean(self) -> float:
+        """Returns the (NaN-)mean of the last `self.window_size` items.
+        """
+        return float(np.nanmean(self.items[:self.count]))
+
+    def std(self) -> float:
+        """Returns the (NaN)-stddev of the last `self.window_size` items.
+        """
+        return float(np.nanstd(self.items[:self.count]))
+
+    def quantiles(self) -> np.ndarray:
+        """Returns ndarray with 0, 10, 50, 90, and 100 percentiles.
+        """
        if not self.count:
-            _quantiles = []
+            return np.array([], dtype=np.float32)
        else:
-            _quantiles = np.nanpercentile(self.items[:self.count],
-                                          [0, 10, 50, 90, 100]).tolist()
+            return np.nanpercentile(self.items[:self.count],
+                                    [0, 10, 50, 90, 100]).tolist()
+
+    def stats(self):
        return {
            self.name + "_count": int(self.count),
-            self.name + "_mean": float(np.nanmean(self.items[:self.count])),
-            self.name + "_std": float(np.nanstd(self.items[:self.count])),
-            self.name + "_quantiles": _quantiles,
+            self.name + "_mean": self.mean(),
+            self.name + "_std": self.std(),
+            self.name + "_quantiles": self.quantiles(),
        }
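A short usage sketch of the refactored class; the printed values assume the three-slot rollover shown above (after four pushes, only the last three values remain in `self.items`):

win_stats = WindowStat("level", 3)
for v in (5.0, 7.0, 7.0, 10.0):
    win_stats.push(v)

print(win_stats.mean())                   # 8.0: mean of the last 3 values (7+7+10)/3
print(round(win_stats.std(), 3))          # 1.414
print(win_stats.stats()["level_count"])   # 4: pushes over the object's lifetime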