[RLlib] Preparatory PR for multi-agent, multi-GPU learning agent (alpha-star style) #02. (#21649)

Sven Mika 2022-01-27 22:07:05 +01:00 committed by GitHub
parent 8ebc50f844
commit ee41800c16
12 changed files with 596 additions and 170 deletions


@ -0,0 +1,146 @@
import collections
import platform
import random
from typing import Optional
from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType
class MixInMultiAgentReplayBuffer:
"""This buffer adds replayed samples to a stream of new experiences.
- Any newly added batch (`add_batch()`) is immediately returned upon
the next `replay` call (close to on-policy) as well as being moved
into the buffer.
- Additionally, a certain number of old samples is mixed into the
returned sample according to a given "replay ratio".
- If more than one call to `add_batch()` is made without any `replay()` calls
in between, all newly added batches are returned (plus some older samples
according to the "replay ratio").
Examples:
# replay ratio 0.66 (2/3 replayed, 1/3 new samples):
>>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
... replay_ratio=0.66)
>>> buffer.add_batch(<A>)
>>> buffer.add_batch(<B>)
>>> buffer.replay()
... [<A>, <B>, <B>]
>>> buffer.add_batch(<C>)
>>> buffer.replay()
... [<C>, <A>, <B>]
>>> # or: [<C>, <A>, <A>] or [<C>, <B>, <B>], but always <C> as it
>>> # is the newest sample
>>> buffer.add_batch(<D>)
>>> buffer.replay()
... [<D>, <A>, <C>]
# replay ratio 0.0 -> replay disabled:
>>> buffer = MixInMultiAgentReplayBuffer(capacity=100, replay_ratio=0.0)
>>> buffer.add_batch(<A>)
>>> buffer.replay()
... [<A>]
>>> buffer.add_batch(<B>)
>>> buffer.replay()
... [<B>]
"""
def __init__(self, capacity: int, replay_ratio: float):
"""Initializes MixInReplay instance.
Args:
capacity (int): Number of batches to store in total.
replay_ratio (float): Ratio of replayed samples in the returned
batches. E.g. a ratio of 0.0 means only return new samples
(no replay), a ratio of 0.5 means always return the newest sample
plus one old one (1:1), and a ratio of 0.66 means always return
the newest sample plus two old (replayed) ones (1:2).
"""
self.capacity = capacity
self.replay_ratio = replay_ratio
self.replay_proportion = None
if self.replay_ratio != 1.0:
self.replay_proportion = self.replay_ratio / (
1.0 - self.replay_ratio)
def new_buffer():
return SimpleReplayBuffer(num_slots=capacity)
self.replay_buffers = collections.defaultdict(new_buffer)
# Metrics.
self.add_batch_timer = TimerStat()
self.replay_timer = TimerStat()
self.update_priorities_timer = TimerStat()
# Added timesteps over lifetime.
self.num_added = 0
# Last added batch(es).
self.last_added_batches = collections.defaultdict(list)
def add_batch(self, batch: SampleBatchType) -> None:
"""Adds a batch to the appropriate policy's replay buffer.
Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
it is not a MultiAgentBatch. Subsequently adds the individual policy
batches to the storage.
Args:
batch: The batch to be added.
"""
# Make a copy so the replay buffer doesn't pin plasma memory.
batch = batch.copy()
batch = batch.as_multi_agent()
with self.add_batch_timer:
for policy_id, sample_batch in batch.policy_batches.items():
self.replay_buffers[policy_id].add_batch(sample_batch)
self.last_added_batches[policy_id].append(sample_batch)
self.num_added += batch.count
def replay(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
Optional[SampleBatchType]:
buffer = self.replay_buffers[policy_id]
# Return None, if:
# - Buffer empty or
# - `replay_ratio` < 1.0 (new samples required in returned batch)
# and no new samples to mix with replayed ones.
if len(buffer) == 0 or (len(self.last_added_batches[policy_id]) == 0
and self.replay_ratio < 1.0):
return None
# Mix buffer's last added batches with older replayed batches.
with self.replay_timer:
output_batches = self.last_added_batches[policy_id]
self.last_added_batches[policy_id] = []
# No replay desired -> Return here.
if self.replay_ratio == 0.0:
return SampleBatch.concat_samples(output_batches)
# Only replay desired -> Return a (replayed) sample from the
# buffer.
elif self.replay_ratio == 1.0:
return buffer.replay()
# Replay ratio = old / [old + new]
# Replay proportion: old / new
num_new = len(output_batches)
replay_proportion = self.replay_proportion
while random.random() < num_new * replay_proportion:
replay_proportion -= 1
output_batches.append(buffer.replay())
return SampleBatch.concat_samples(output_batches)
def get_host(self) -> str:
"""Returns the computer's network name.
Returns:
The computer's network name, or an empty string if it could not
be determined.
"""
return platform.node()
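
A minimal usage sketch of the buffer above, mirroring the unit test added further down in this commit (the import path is the one used there). With replay_ratio=0.5, each `replay()` is expected to return the newly added rows plus roughly the same number of replayed ones:

import numpy as np

from ray.rllib.execution.buffers.mixin_replay_buffer import \
    MixInMultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

buffer = MixInMultiAgentReplayBuffer(capacity=10, replay_ratio=0.5)
buffer.add_batch(SampleBatch({"obs": [np.random.random((4, ))]}))
sample = buffer.replay()  # the new row plus (on average) one replayed row
assert sample is not None and len(sample) >= 1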


@ -1,6 +1,6 @@
import collections
import platform
from typing import Any, Dict
from typing import Any, Dict, Optional
import numpy as np
import ray
@ -13,7 +13,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import deprecation_warning
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.utils.typing import PolicyID, SampleBatchType
from ray.util.iter import ParallelIteratorWorker
@ -195,7 +195,7 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker):
time_slice, weight=weight)
self.num_added += batch.count
def replay(self) -> SampleBatchType:
def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
"""If this buffer was given a fake batch, return it, otherwise return
a MultiAgentBatch with samples.
"""
@ -211,8 +211,13 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker):
# Lockstep mode: Sample from all policies at the same time an
# equal amount of steps.
if self.replay_mode == "lockstep":
assert policy_id is None, \
"`policy_id` specifier not allowed in `locksetp` mode!"
return self.replay_buffers[_ALL_POLICIES].sample(
self.replay_batch_size, beta=self.prioritized_replay_beta)
elif policy_id is not None:
return self.replay_buffers[policy_id].sample(
self.replay_batch_size, beta=self.prioritized_replay_beta)
else:
samples = {}
for policy_id, replay_buffer in self.replay_buffers.items():
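
A short sketch of the new per-policy replay path above; `ma_buffer` stands for an already-constructed, non-lockstep MultiAgentReplayBuffer that has received samples for a policy named "policy_1" (both names are hypothetical here):

# Sample a batch for one specific policy only (not allowed in lockstep mode).
p1_batch = ma_buffer.replay(policy_id="policy_1")
# Old behavior is unchanged: without `policy_id`, sample for all policies.
all_batches = ma_buffer.replay()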


@ -132,19 +132,25 @@ class ReplayBuffer:
@DeveloperAPI
def sample(self, num_items: int, beta: float = 0.0) -> SampleBatchType:
"""Sample a batch of experiences.
"""Sample a batch of size `num_items` from this buffer.
If fewer than `num_items` records are in this buffer, some samples
in the result may be repeated to fulfill the requested batch size
(`num_items`).
Args:
num_items: Number of items to sample from this buffer.
beta: This is ignored (only used by prioritized replay buffers).
beta: The prioritized replay beta value. Only relevant if this
ReplayBuffer is a PrioritizedReplayBuffer.
Returns:
Concatenated batch of items.
"""
idxes = [
random.randint(0,
len(self._storage) - 1) for _ in range(num_items)
]
# If we don't have any samples yet in this buffer, return None.
if len(self) == 0:
return None
idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
sample = self._encode_sample(idxes)
# Update our timesteps counters.
self._num_timesteps_sampled += len(sample)
@ -282,6 +288,10 @@ class PrioritizedReplayBuffer(ReplayBuffer):
"batch_indexes" fields denoting IS of each sampled
transition and original idxes in buffer of sampled experiences.
"""
# If we don't have any samples yet in this buffer, return None.
if len(self) == 0:
return None
assert beta >= 0.0
idxes = self._sample_proportional(num_items)
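
The two changes above (returning None on an empty buffer, and index-based sampling with replacement) can be pictured in isolation; the following is just the index math on a plain Python list, not an RLlib API call:

import random

storage = ["batch_a", "batch_b", "batch_c"]  # stand-ins for stored batches
num_items = 5

if len(storage) == 0:
    sample = None  # empty buffer -> None, as in the new code above
else:
    # Sampling with replacement: with fewer than `num_items` entries,
    # some stored items necessarily repeat to fill the request.
    idxes = [random.randint(0, len(storage) - 1) for _ in range(num_items)]
    sample = [storage[i] for i in idxes]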


@ -0,0 +1,152 @@
import logging
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set
import ray
from ray.actor import ActorHandle
from ray.rllib.utils.annotations import ExperimentalAPI
logger = logging.getLogger(__name__)
@ExperimentalAPI
def asynchronous_parallel_requests(
remote_requests_in_flight: DefaultDict[ActorHandle, Set[
ray.ObjectRef]],
actors: List[ActorHandle],
ray_wait_timeout_s: Optional[float] = None,
max_remote_requests_in_flight_per_actor: int = 2,
remote_fn: Optional[Callable[[ActorHandle, Any, Any], Any]] = None,
remote_args: Optional[List[List[Any]]] = None,
remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Dict[ActorHandle, Any]:
"""Runs parallel and asynchronous rollouts on all remote workers.
May use a timeout (if provided) on `ray.wait()` and returns only those
samples that could be gathered in the timeout window. Allows a maximum
of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight
per remote actor.
Alternatively to calling `actor.sample.remote()`, the user can provide a
`remote_fn()`, which will be applied to the actor(s) instead.
Args:
remote_requests_in_flight: Dict mapping actor handles to a set of
their currently-in-flight pending requests (those we expect to
ray.get results for next). If you have an RLlib Trainer that calls
this function, you can use its `self.remote_requests_in_flight`
property here.
actors: The List of ActorHandles to perform the remote requests on.
ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
`ray.wait()` calls. If None (default), never time out (block
until at least one actor returns something).
max_remote_requests_in_flight_per_actor: Maximum number of remote
requests that may be in flight per actor at any time. 2 (default)
is probably sufficient to avoid idle times between two requests.
remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
`actor.sample.remote()` to generate the requests.
remote_args: If provided, use this list (per-actor) of lists (call
args) as *args to be passed to the `remote_fn`.
E.g.: actors=[A, B],
remote_args=[[...] <- *args for A, [...] <- *args for B].
remote_kwargs: If provided, use this list (per-actor) of dicts
(kwargs) as **kwargs to be passed to the `remote_fn`.
E.g.: actors=[A, B],
remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].
Returns:
A dict mapping actor handles to the results received for their
ready requests. An empty dict, if no requests completed within the
timeout window.
Examples:
>>> # 2 remote rollout workers (num_workers=2):
>>> results = asynchronous_parallel_requests(
...     trainer.remote_requests_in_flight,
...     actors=trainer.workers.remote_workers(),
...     ray_wait_timeout_s=0.1,
...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
... )
>>> # Expect a timeout to have happened: no request is ready yet.
>>> print(results)
... {}
"""
if remote_args is not None:
assert len(remote_args) == len(actors)
if remote_kwargs is not None:
assert len(remote_kwargs) == len(actors)
# For faster hash lookup.
actor_set = set(actors)
# Collect all currently pending remote requests into a single set of
# object refs.
pending_remotes = set()
# Also build a map to get the associated actor for each remote request.
remote_to_actor = {}
for actor, set_ in remote_requests_in_flight.items():
# Only consider those actors' pending requests that are in
# the given `actors` list.
if actor in actor_set:
pending_remotes |= set_
for r in set_:
remote_to_actor[r] = actor
# Add new requests, if possible (if
# `max_remote_requests_in_flight_per_actor` setting allows it).
for actor_idx, actor in enumerate(actors):
# Still room for another request to this actor.
if len(remote_requests_in_flight[actor]) < \
max_remote_requests_in_flight_per_actor:
if remote_fn is None:
req = actor.sample.remote()
else:
args = remote_args[actor_idx] if remote_args else []
kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
req = actor.apply.remote(remote_fn, *args, **kwargs)
# Add to our set to send to ray.wait().
pending_remotes.add(req)
# Keep our mappings properly updated.
remote_requests_in_flight[actor].add(req)
remote_to_actor[req] = actor
# There must always be pending remote requests.
assert len(pending_remotes) > 0
pending_remote_list = list(pending_remotes)
# No timeout: Block until at least one result is returned.
if ray_wait_timeout_s is None:
# First try to do a `ray.wait` w/o timeout for efficiency.
ready, _ = ray.wait(
pending_remote_list, num_returns=len(pending_remotes), timeout=0)
# Nothing returned and `timeout` is None -> Fall back to a
# blocking wait to make sure we can return something.
if not ready:
ready, _ = ray.wait(pending_remote_list, num_returns=1)
# Timeout: Do a `ray.wait()` call w/ timeout.
else:
ready, _ = ray.wait(
pending_remote_list,
num_returns=len(pending_remotes),
timeout=ray_wait_timeout_s)
# Return empty results if nothing ready after the timeout.
if not ready:
return {}
# Remove in-flight records for ready refs.
for obj_ref in ready:
remote_requests_in_flight[remote_to_actor[obj_ref]].remove(obj_ref)
# Do one ray.get().
results = ray.get(ready)
assert len(ready) == len(results)
# Return mapping from (ready) actors to their results.
ret = {}
for obj_ref, result in zip(ready, results):
ret[remote_to_actor[obj_ref]] = result
return ret
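
A hedged usage sketch of the helper above; `trainer` is assumed to be an RLlib Trainer exposing the `remote_requests_in_flight` mapping mentioned in the Args section, with its remote rollout workers as the actors:

ready_results = asynchronous_parallel_requests(
    remote_requests_in_flight=trainer.remote_requests_in_flight,
    actors=trainer.workers.remote_workers(),
    ray_wait_timeout_s=0.1,
    max_remote_requests_in_flight_per_actor=2,
)  # no `remote_fn` given -> each new request is `actor.sample.remote()`

# Only actors whose requests finished within the 100ms window show up here;
# an empty dict means nothing was ready yet (requests remain in flight).
for actor, sample_batch in ready_results.items():
    ...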


@ -42,12 +42,12 @@ class StoreToReplayBuffer:
actors: An optional list of replay actors to use instead of
`local_buffer`.
"""
if bool(local_buffer) == bool(actors):
if local_buffer is not None and actors is not None:
raise ValueError(
"Either `local_buffer` or `replay_actors` must be given, "
"not both!")
if local_buffer:
if local_buffer is not None:
self.local_actor = local_buffer
self.replay_actors = None
else:
@ -55,7 +55,7 @@ class StoreToReplayBuffer:
self.replay_actors = actors
def __call__(self, batch: SampleBatchType):
if self.local_actor:
if self.local_actor is not None:
self.local_actor.add_batch(batch)
else:
actor = random.choice(self.replay_actors)
@ -64,8 +64,8 @@ class StoreToReplayBuffer:
def Replay(*,
local_buffer: MultiAgentReplayBuffer = None,
actors: List[ActorHandle] = None,
local_buffer: Optional[MultiAgentReplayBuffer] = None,
actors: Optional[List[ActorHandle]] = None,
num_async: int = 4) -> LocalIterator[SampleBatchType]:
"""Replay experiences from the given buffer or actors.
@ -87,11 +87,11 @@ def Replay(*,
SampleBatch(...)
"""
if bool(local_buffer) == bool(actors):
if local_buffer is not None and actors is not None:
raise ValueError(
"Exactly one of local_buffer and replay_actors must be given.")
if actors:
if actors is not None:
replay = from_actors(actors)
return replay.gather_async(
num_async=num_async).filter(lambda x: x is not None)
@ -135,6 +135,8 @@ class SimpleReplayBuffer:
self.replay_batches = []
self.replay_index = 0
self.last_added_batches = []
def add_batch(self, sample_batch: SampleBatchType) -> None:
warn_replay_capacity(item=sample_batch, num_items=self.num_slots)
if self.num_slots > 0:
@ -145,9 +147,14 @@ class SimpleReplayBuffer:
self.replay_index += 1
self.replay_index %= self.num_slots
self.last_added_batches.append(sample_batch)
def replay(self) -> SampleBatchType:
return random.choice(self.replay_batches)
def __len__(self):
return len(self.replay_batches)
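
A tiny sketch of the SimpleReplayBuffer additions above (the `last_added_batches` bookkeeping and the new `__len__`); `num_slots` is the same constructor argument used by the mixin buffer's `new_buffer()` helper earlier in this commit:

from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

buf = SimpleReplayBuffer(num_slots=3)
batch = SampleBatch({"obs": [1, 2, 3]})
buf.add_batch(batch)

assert len(buf) == 1                     # new __len__: stored batch count
assert len(buf.last_added_batches) == 1  # tracked since construction
sample = buf.replay()                    # a randomly chosen stored batch
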
class MixInReplay:
"""This operator adds replay to a stream of experiences.


@ -1,10 +1,9 @@
import logging
import time
from typing import Any, Callable, Container, Dict, List, Optional, Tuple, \
from typing import Callable, Container, List, Optional, Tuple, \
TYPE_CHECKING
import ray
from ray.actor import ActorHandle
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
@ -21,7 +20,6 @@ from ray.util.iter import from_actors, LocalIterator
from ray.util.iter_metrics import SharedMetrics
if TYPE_CHECKING:
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.rollout_worker import RolloutWorker
logger = logging.getLogger(__name__)
@ -77,135 +75,6 @@ def synchronous_parallel_sample(
return sample_batches
# TODO: Move to generic parallel ops module and rename to
# `asynchronous_parallel_requests`:
@ExperimentalAPI
def asynchronous_parallel_sample(
trainer: "Trainer",
actors: List[ActorHandle],
ray_wait_timeout_s: Optional[float] = None,
max_remote_requests_in_flight_per_actor: int = 2,
remote_fn: Optional[Callable[["RolloutWorker"], None]] = None,
remote_args: Optional[List[List[Any]]] = None,
remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Optional[List[SampleBatch]]:
"""Runs parallel and asynchronous rollouts on all remote workers.
May use a timeout (if provided) on `ray.wait()` and returns only those
samples that could be gathered in the timeout window. Allows a maximum
of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight
per remote actor.
Alternatively to calling `actor.sample.remote()`, the user can provide a
`remote_fn()`, which will be applied to the actor(s) instead.
Args:
trainer: The Trainer object that we run the sampling for.
actors: The List of ActorHandles to perform the remote requests on.
ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
`ray.wait()` calls. If None (default), never time out (block
until at least one actor returns something).
max_remote_requests_in_flight_per_actor: Maximum number of remote
requests sent to each actor. 2 (default) is probably
sufficient to avoid idle times between two requests.
remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
`actor.sample.remote()` to generate the requests.
remote_args: If provided, use this list (per-actor) of lists (call
args) as *args to be passed to the `remote_fn`.
E.g.: actors=[A, B],
remote_args=[[...] <- *args for A, [...] <- *args for B].
remote_kwargs: If provided, use this list (per-actor) of dicts
(kwargs) as **kwargs to be passed to the `remote_fn`.
E.g.: actors=[A, B],
remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].
Returns:
The list of asynchronously collected sample batch types. None, if no
samples are ready.
Examples:
>>> # 2 remote rollout workers (num_workers=2):
>>> batches = asynchronous_parallel_sample(
... trainer,
... actors=trainer.workers.remote_workers(),
... ray_wait_timeout_s=0.1,
... remote_fn=lambda w: time.sleep(1) # sleep 1sec
... )
>>> print(len(batches))
... 2
>>> # Expect a timeout to have happened.
>>> batches[0] is None and batches[1] is None
... True
"""
if remote_args is not None:
assert len(remote_args) == len(actors)
if remote_kwargs is not None:
assert len(remote_kwargs) == len(actors)
# Collect all currently pending remote requests into a single set of
# object refs.
pending_remotes = set()
# Also build a map to get the associated actor for each remote request.
remote_to_actor = {}
for actor, set_ in trainer.remote_requests_in_flight.items():
pending_remotes |= set_
for r in set_:
remote_to_actor[r] = actor
# Add new requests, if possible (if
# `max_remote_requests_in_flight_per_actor` setting allows it).
for actor_idx, actor in enumerate(actors):
# Still room for another request to this actor.
if len(trainer.remote_requests_in_flight[actor]) < \
max_remote_requests_in_flight_per_actor:
if remote_fn is None:
req = actor.sample.remote()
else:
args = remote_args[actor_idx] if remote_args else []
kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
req = actor.apply.remote(remote_fn, *args, **kwargs)
# Add to our set to send to ray.wait().
pending_remotes.add(req)
# Keep our mappings properly updated.
trainer.remote_requests_in_flight[actor].add(req)
remote_to_actor[req] = actor
# There must always be pending remote requests.
assert len(pending_remotes) > 0
pending_remote_list = list(pending_remotes)
# No timeout: Block until at least one result is returned.
if ray_wait_timeout_s is None:
# First try to do a `ray.wait` w/o timeout for efficiency.
ready, _ = ray.wait(
pending_remote_list, num_returns=len(pending_remotes), timeout=0)
# Nothing returned and `timeout` is None -> Fall back to a
# blocking wait to make sure we can return something.
if not ready:
ready, _ = ray.wait(pending_remote_list, num_returns=1)
# Timeout: Do a `ray.wait() call` w/ timeout.
else:
ready, _ = ray.wait(
pending_remote_list,
num_returns=len(pending_remotes),
timeout=ray_wait_timeout_s)
# Return None if nothing ready after the timeout.
if not ready:
return None
for obj_ref in ready:
# Remove in-flight record for this ref.
trainer.remote_requests_in_flight[remote_to_actor[obj_ref]].remove(
obj_ref)
remote_to_actor.pop(obj_ref)
results = ray.get(ready)
return results
def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
num_async=1) -> LocalIterator[SampleBatch]:
"""Operator to collect experiences in parallel from rollout workers.


@ -0,0 +1,151 @@
import numpy as np
import unittest
from ray.rllib.execution.buffers.mixin_replay_buffer import \
MixInMultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
class TestMixInMultiAgentReplayBuffer(unittest.TestCase):
"""Tests insertion and mixed sampling of the MixInMultiAgentReplayBuffer.
"""
capacity = 10
def _generate_data(self):
return SampleBatch({
"obs": [np.random.random((4, ))],
"action": [np.random.choice([0, 1])],
"reward": [np.random.rand()],
"new_obs": [np.random.random((4, ))],
"done": [np.random.choice([False, True])],
})
def test_mixin_sampling(self):
# 50% replay ratio.
buffer = MixInMultiAgentReplayBuffer(
capacity=self.capacity, replay_ratio=0.5)
# Add a new batch.
batch = self._generate_data()
buffer.add_batch(batch)
# Expect at least 1 sample to be returned.
sample = buffer.replay()
self.assertTrue(len(sample) >= 1)
# If we insert and replay n times, expect the returned batches to have
# roughly len 2 (replay_ratio=0.5 -> 50% replayed samples -> 1 new and
# 1 old sample on average in each returned value).
results = []
for _ in range(100):
buffer.add_batch(batch)
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 2.0)
# 33% replay ratio.
buffer = MixInMultiAgentReplayBuffer(
capacity=self.capacity, replay_ratio=0.333)
# Expect exactly 0 samples to be returned (buffer empty).
sample = buffer.replay()
self.assertTrue(sample is None)
# Add a new batch.
batch = self._generate_data()
buffer.add_batch(batch)
# Expect at least 1 sample to be returned.
sample = buffer.replay()
self.assertTrue(len(sample) >= 1)
# If we insert 2x and replay n times, expect the returned batches to have
# roughly len 3 (replay_ratio=0.33 -> 33% replayed samples -> 2 new and
# 1 old sample on average in each returned value).
results = []
for _ in range(100):
buffer.add_batch(batch)
buffer.add_batch(batch)
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 3.0, delta=0.1)
# If we insert 1x and replay n times, expect the returned batches to have
# roughly len 1.5 (replay_ratio=0.33 -> 33% replayed samples -> 1 new and
# 0.5 old samples on average in each returned value).
results = []
for _ in range(100):
buffer.add_batch(batch)
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 1.5, delta=0.1)
# 90% replay ratio.
buffer = MixInMultiAgentReplayBuffer(
capacity=self.capacity, replay_ratio=0.9)
# Expect exactly 0 samples to be returned (buffer empty).
sample = buffer.replay()
self.assertTrue(sample is None)
# Add a new batch.
batch = self._generate_data()
buffer.add_batch(batch)
# Expect at least 2 samples to be returned (new one plus at least one
# replay sample).
sample = buffer.replay()
self.assertTrue(len(sample) >= 2)
# If we insert and replay n times, expect the returned batches to have
# roughly len 10 (replay_ratio=0.9 -> 90% replayed samples -> 1 new and
# 9 old samples on average in each returned value).
results = []
for _ in range(100):
buffer.add_batch(batch)
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 10.0, delta=0.1)
# 0% replay ratio -> Only new samples.
buffer = MixInMultiAgentReplayBuffer(
capacity=self.capacity, replay_ratio=0.0)
# Add a new batch.
batch = self._generate_data()
buffer.add_batch(batch)
# Expect exactly 1 sample to be returned.
sample = buffer.replay()
self.assertTrue(len(sample) == 1)
# Expect None to be returned (nothing new to return and no replay
# allowed (replay_ratio=0.0)).
sample = buffer.replay()
self.assertTrue(sample is None)
# If we insert and replay n times, expect the returned batches to have
# roughly len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and
# 0 old samples on average in each returned value).
results = []
for _ in range(100):
buffer.add_batch(batch)
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 1.0)
# 100% replay ratio -> Only old (replayed) samples.
buffer = MixInMultiAgentReplayBuffer(
capacity=self.capacity, replay_ratio=1.0)
# Expect exactly 0 samples to be returned (buffer empty).
sample = buffer.replay()
self.assertTrue(sample is None)
# Add a new batch.
batch = self._generate_data()
buffer.add_batch(batch)
# Expect exactly 1 sample to be returned (the new batch).
sample = buffer.replay()
self.assertTrue(len(sample) == 1)
# Another replay -> Expect exactly 1 sample to be returned.
sample = buffer.replay()
self.assertTrue(len(sample) == 1)
# If we replay n times, expect the returned batches to have roughly
# len 1 (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old
# sample on average in each returned value).
results = []
for _ in range(100):
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 1.0)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))


@ -38,7 +38,7 @@ class SlimFC:
# By default, use Glorot uniform initializer.
if initializer is None:
initializer = flax.nn.initializers.xavier_uniform()
initializer = nn.initializers.xavier_uniform()
self.prng_key = prng_key or jax.random.PRNGKey(int(time.time()))
_, self.prng_key = jax.random.split(self.prng_key)


@ -19,6 +19,7 @@ from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import NUM_AGENT_STEPS_TRAINED
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import normalize_action
@ -296,7 +297,10 @@ def build_eager_tf_policy(
class eager_policy_cls(base):
def __init__(self, observation_space, action_space, config):
assert tf.executing_eagerly()
# If this class runs as a @ray.remote actor, eager mode may not
# have been activated yet.
if not tf1.executing_eagerly():
tf1.enable_eager_execution()
self.framework = config.get("framework", "tfe")
Policy.__init__(self, observation_space, action_space, config)
@ -600,7 +604,10 @@ def build_eager_tf_policy(
postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
postprocessed_batch.set_training(True)
stats = self._learn_on_batch_helper(postprocessed_batch)
stats.update({"custom_metrics": learn_stats})
stats.update({
"custom_metrics": learn_stats,
NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
})
return convert_to_numpy(stats)
@override(Policy)


@ -6,8 +6,11 @@ import logging
import numpy as np
import platform
import tree # pip install dm_tree
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \
TYPE_CHECKING, Union
import ray
from ray.actor import ActorHandle
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
@ -22,7 +25,7 @@ from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
get_dummy_batch_for_space, unbatch
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
T, TensorType, TensorStructType, TrainerConfigDict, Tuple, Union
PolicyID, PolicyState, T, TensorType, TensorStructType, TrainerConfigDict
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@ -451,6 +454,36 @@ class Policy(metaclass=ABCMeta):
self.apply_gradients(grads)
return grad_info
@ExperimentalAPI
def learn_on_batch_from_replay_buffer(
self, replay_actor: ActorHandle,
policy_id: PolicyID) -> Dict[str, TensorType]:
"""Samples a batch from given replay actor and performs an update.
Args:
replay_actor: The replay buffer actor to sample from.
policy_id: The ID of this policy.
Returns:
Dictionary of extra metadata from `compute_gradients()`.
"""
# Sample a batch from the given replay actor.
# Note that for better performance (less data sent through the
# network), this policy should be co-located on the same node
# as `replay_actor`. Such a co-location step is usually done during
# the Trainer's `setup()` phase.
batch = ray.get(replay_actor.replay.remote(policy_id=policy_id))
if batch is None:
return {}
# Send to own learn_on_batch method for updating.
# TODO: hack w/ `hasattr`
if hasattr(self, "devices") and len(self.devices) > 1:
self.load_batch_into_buffer(batch, buffer_index=0)
return self.learn_on_loaded_batch(offset=0, buffer_index=0)
else:
return self.learn_on_batch(batch)
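
A hedged sketch of how the new `learn_on_batch_from_replay_buffer()` is meant to be driven; `my_policy` and `replay_actor` are hypothetical handles: a built Policy and a Ray actor wrapping a replay buffer that exposes `replay(policy_id=...)` (such as the MultiAgentReplayBuffer extended earlier in this commit), ideally co-located on the same node:

stats = my_policy.learn_on_batch_from_replay_buffer(
    replay_actor=replay_actor, policy_id="default_policy")
if not stats:
    # Empty dict -> the replay actor had nothing to return for this
    # policy yet (its remote `replay()` call gave back None).
    pass
else:
    # Otherwise `stats` carries the learner info from `learn_on_batch()`
    # (or `learn_on_loaded_batch()` on multi-GPU policies).
    print(stats)
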
@DeveloperAPI
def load_batch_into_buffer(self, batch: SampleBatch,
buffer_index: int = 0) -> int:
@ -606,7 +639,7 @@ class Policy(metaclass=ABCMeta):
return []
@DeveloperAPI
def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
def get_state(self) -> PolicyState:
"""Returns the entire current state of this Policy.
Note: Not to be confused with an RNN model's internal state.
@ -626,10 +659,7 @@ class Policy(metaclass=ABCMeta):
return state
@DeveloperAPI
def set_state(
self,
state: Union[Dict[str, TensorType], List[TensorType]],
) -> None:
def set_state(self, state: PolicyState) -> None:
"""Restores the entire current state of this Policy from `state`.
Args:


@ -8,11 +8,15 @@ NUM_AGENT_STEPS_TRAINED = "num_agent_steps_trained"
LAST_TARGET_UPDATE_TS = "last_target_update_ts"
NUM_TARGET_UPDATES = "num_target_updates"
# Performance timers (keys for metrics.timers).
# Performance timers (keys for Trainer._timers or metrics.timers).
APPLY_GRADS_TIMER = "apply_grad"
COMPUTE_GRADS_TIMER = "compute_grads"
WORKER_UPDATE_TIMER = "update"
SYNCH_WORKER_WEIGHTS_TIMER = "synch_weights"
GRAD_WAIT_TIMER = "grad_wait"
SAMPLE_TIMER = "sample"
LEARN_ON_BATCH_TIMER = "learn"
LOAD_BATCH_TIMER = "load"
TARGET_NET_UPDATE_TIMER = "target_net_update"
# Deprecated: Use `SYNCH_WORKER_WEIGHTS_TIMER` instead.
WORKER_UPDATE_TIMER = "update"


@ -2,27 +2,72 @@ import numpy as np
class WindowStat:
def __init__(self, name, n):
"""Handles/stores incoming datastream and provides window-based statistics.
Examples:
>>> win_stats = WindowStat("level", 3)
>>> win_stats.push(5.0)
>>> win_stats.push(7.0)
>>> win_stats.push(7.0)
>>> win_stats.push(10.0)
>>> # Expect 8.0 as the mean of the last 3 values: (7+7+10)/3=8.0
>>> print(win_stats.mean())
... 8.0
"""
def __init__(self, name: str, n: int):
"""Initializes a WindowStat instance.
Args:
name: The name of the data; used as the key prefix in `stats()`.
n: The window size. Statistics will be computed for the last n
items received from the stream.
"""
# The window-size.
self.window_size = n
# The name of the data (used for `self.stats()`).
self.name = name
self.items = [None] * n
# List of items to do calculations over (len=self.window_size).
self.items = [None] * self.window_size
# The current index to insert the next item into `self.items`.
self.idx = 0
# How many items have been added over the lifetime of this object.
self.count = 0
def push(self, obj):
def push(self, obj) -> None:
"""Pushes a new value/object into the data buffer."""
# Insert object at current index.
self.items[self.idx] = obj
# Increase insertion index by 1.
self.idx += 1
# Increase lifetime count by 1.
self.count += 1
# Fix index in case of rollover.
self.idx %= len(self.items)
def stats(self):
def mean(self) -> float:
"""Returns the (NaN-)mean of the last `self.window_size` items.
"""
return float(np.nanmean(self.items[:self.count]))
def std(self) -> float:
"""Returns the (NaN)-stddev of the last `self.window_size` items.
"""
return float(np.nanstd(self.items[:self.count]))
def quantiles(self) -> np.ndarray:
"""Returns ndarray with 0, 10, 50, 90, and 100 percentiles.
"""
if not self.count:
_quantiles = []
return np.ndarray([], dtype=np.float32)
else:
_quantiles = np.nanpercentile(self.items[:self.count],
[0, 10, 50, 90, 100]).tolist()
return np.nanpercentile(self.items[:self.count],
[0, 10, 50, 90, 100]).tolist()
def stats(self):
return {
self.name + "_count": int(self.count),
self.name + "_mean": float(np.nanmean(self.items[:self.count])),
self.name + "_std": float(np.nanstd(self.items[:self.count])),
self.name + "_quantiles": _quantiles,
self.name + "_mean": self.mean(),
self.name + "_std": self.std(),
self.name + "_quantiles": self.quantiles(),
}
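
Continuing the docstring example above, a short sketch of the refactored `stats()` output, which is now assembled from the new `mean()`, `std()`, and `quantiles()` helpers:

win_stats = WindowStat("level", 3)
for value in [5.0, 7.0, 7.0, 10.0]:
    win_stats.push(value)

print(win_stats.stats())
# e.g. {"level_count": 4,         # lifetime number of pushed items
#       "level_mean": 8.0,        # mean of the last 3 items: (7+7+10)/3
#       "level_std": 1.414...,    # stddev of the last 3 items
#       "level_quantiles": [...]} # their [0, 10, 50, 90, 100] percentiles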