"""Experimental distributed execution API.
|
|
|
|
TODO(ekl): describe the concepts."""
|
|
|
|
import logging
|
|
from typing import List, Any, Tuple, Union
|
|
import numpy as np
|
|
import queue
|
|
import random
|
|
import time
|
|
|
|
import ray
|
|
from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
|
|
from ray.util.iter_metrics import MetricsContext
|
|
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer, \
|
|
ReplayBuffer
|
|
from ray.rllib.evaluation.metrics import collect_episodes, \
|
|
summarize_episodes, get_learner_stats
|
|
from ray.rllib.evaluation.rollout_worker import get_global_worker
|
|
from ray.rllib.evaluation.worker_set import WorkerSet
|
|
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \
|
|
DEFAULT_POLICY_ID
|
|
from ray.rllib.utils.compression import pack_if_needed
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Counters for training progress (keys for metrics.counters).
|
|
STEPS_SAMPLED_COUNTER = "num_steps_sampled"
|
|
STEPS_TRAINED_COUNTER = "num_steps_trained"
|
|
|
|
# Counters to track target network updates.
|
|
LAST_TARGET_UPDATE_TS = "last_target_update_ts"
|
|
NUM_TARGET_UPDATES = "num_target_updates"
|
|
|
|
# Performance timers (keys for metrics.timers).
|
|
APPLY_GRADS_TIMER = "apply_grad"
|
|
COMPUTE_GRADS_TIMER = "compute_grads"
|
|
WORKER_UPDATE_TIMER = "update"
|
|
GRAD_WAIT_TIMER = "grad_wait"
|
|
SAMPLE_TIMER = "sample"
|
|
LEARN_ON_BATCH_TIMER = "learn"
|
|
|
|
# Instant metrics (keys for metrics.info).
|
|
LEARNER_INFO = "learner"
|
|
|
|
# Type aliases.
|
|
GradientType = dict
|
|
SampleBatchType = Union[SampleBatch, MultiAgentBatch]


# Raises ValueError if the object is not a SampleBatch or MultiAgentBatch.
def _check_sample_batch_type(batch):
    if not isinstance(batch, SampleBatchType.__args__):
        raise ValueError("Expected either SampleBatch or MultiAgentBatch, "
                         "got {}: {}".format(type(batch), batch))


# Returns pipeline global vars that should be periodically sent to each worker.
def _get_global_vars():
    metrics = LocalIterator.get_metrics()
    return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]}


def ParallelRollouts(workers: WorkerSet, mode="bulk_sync",
                     async_queue_depth=1) -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially from
    the local worker instance instead.

    Arguments:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of {'async', 'bulk_sync'}.
            - In 'async' mode, batches are returned as soon as they are
              computed by rollout workers with no order guarantees.
            - In 'bulk_sync' mode, we collect one batch from each worker
              and concatenate them together into a large batch to return.
        async_queue_depth (int): In async mode, the max number of async
            requests in flight per actor.

    Returns:
        A local iterator over experiences collected in parallel.

    Examples:
        >>> rollouts = ParallelRollouts(workers, mode="async")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        50  # config.sample_batch_size

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        200  # config.sample_batch_size * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """

    # Ensure workers are initially in sync.
    workers.sync_weights()

    def report_timesteps(batch):
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Handle the serial sampling case.
        def sampler(_):
            while True:
                yield workers.local_worker().sample()
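
        # Wrap the generator in a LocalIterator with its own metrics context;
        # report_timesteps still updates STEPS_SAMPLED_COUNTER for each batch.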
        return (LocalIterator(sampler, MetricsContext())
                .for_each(report_timesteps))

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return rollouts \
            .batch_across_shards() \
            .for_each(lambda batches: SampleBatch.concat_samples(batches)) \
            .for_each(report_timesteps)
    elif mode == "async":
        return rollouts.gather_async(
            async_queue_depth=async_queue_depth).for_each(report_timesteps)
    else:
        raise ValueError(
            "mode must be one of 'bulk_sync', 'async', got '{}'".format(mode))


def AsyncGradients(
        workers: WorkerSet) -> LocalIterator[Tuple[GradientType, int]]:
    """Operator to compute gradients in parallel from rollout workers.

    Arguments:
        workers (WorkerSet): set of rollout workers to use.

    Returns:
        A local iterator over policy gradients computed on rollout workers.

    Examples:
        >>> grads_op = AsyncGradients(workers)
        >>> print(next(grads_op))
        {"var_0": ..., ...}, 50  # grads, batch count

    Updates the STEPS_SAMPLED_COUNTER counter and LEARNER_INFO field in the
    local iterator context.
    """

    # Ensure workers are initially in sync.
    workers.sync_weights()

    # This function will be applied remotely on the workers.
    def samples_to_grads(samples):
        return get_global_worker().compute_gradients(samples), samples.count

    # Record learner metrics and pass through (grads, count).
    class record_metrics:
        def _on_fetch_start(self):
            self.fetch_start_time = time.perf_counter()

        def __call__(self, item):
            (grads, info), count = item
            metrics = LocalIterator.get_metrics()
            metrics.counters[STEPS_SAMPLED_COUNTER] += count
            metrics.info[LEARNER_INFO] = get_learner_stats(info)
            metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() -
                                                 self.fetch_start_time)
            return grads, count
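
    # Compute gradients remotely on each rollout worker, then gather the
    # (grads, count) results asynchronously on the driver while recording
    # gradient wait time and learner stats.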
    rollouts = from_actors(workers.remote_workers())
    grads = rollouts.for_each(samples_to_grads)
    return grads.gather_async().for_each(record_metrics())


def StandardMetricsReporting(train_op: LocalIterator[Any], workers: WorkerSet,
                             config: dict) -> LocalIterator[dict]:
    """Operator to periodically collect and report metrics.

    Arguments:
        train_op (LocalIterator): Operator for executing training steps.
            We ignore the output values.
        workers (WorkerSet): Rollout workers to collect metrics from.
        config (dict): Trainer configuration, used to determine the frequency
            of stats reporting.

    Returns:
        A local iterator over training results.

    Examples:
        >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
        >>> metrics_op = StandardMetricsReporting(train_op, workers, config)
        >>> next(metrics_op)
        {"episode_reward_max": ..., "episode_reward_mean": ..., ...}
    """
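
    # Throttle reporting to at most once per min_iter_time_s (at least 2s),
    # then summarize recent episodes plus iterator metrics into a result dict.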
    output_op = train_op \
        .filter(OncePerTimeInterval(max(2, config["min_iter_time_s"]))) \
        .for_each(CollectMetrics(
            workers, min_history=config["metrics_smoothing_episodes"],
            timeout_seconds=config["collect_metrics_timeout"]))
    return output_op


class ConcatBatches:
    """Callable used to merge batches into larger batches for training.

    This should be used with the .combine() operator.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> rollouts = rollouts.combine(ConcatBatches(min_batch_size=10000))
        >>> print(next(rollouts).count)
        10000
    """

    def __init__(self, min_batch_size: int):
        self.min_batch_size = min_batch_size
        self.buffer = []
        self.count = 0
        self.batch_start_time = None

    def _on_fetch_start(self):
        if self.batch_start_time is None:
            self.batch_start_time = time.perf_counter()

    def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
        _check_sample_batch_type(batch)
        self.buffer.append(batch)
        self.count += batch.count
        if self.count >= self.min_batch_size:
            out = SampleBatch.concat_samples(self.buffer)
            timer = LocalIterator.get_metrics().timers[SAMPLE_TIMER]
            timer.push(time.perf_counter() - self.batch_start_time)
            timer.push_units_processed(self.count)
            self.batch_start_time = None
            self.buffer = []
            self.count = 0
            return [out]
        return []


class TrainOneStep:
    """Callable that improves the policy and updates workers.

    This should be used with the .for_each() operator.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> train_op = rollouts.for_each(TrainOneStep(workers))
        >>> print(next(train_op))  # This trains the policy on one batch.
        {...}  # Learner info dict returned by learn_on_batch().

    Updates the STEPS_TRAINED_COUNTER counter and LEARNER_INFO field in the
    local iterator context.
    """

    def __init__(self, workers: WorkerSet):
        self.workers = workers

    def __call__(self, batch: SampleBatchType) -> List[dict]:
        _check_sample_batch_type(batch)
        metrics = LocalIterator.get_metrics()
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        with learn_timer:
            info = self.workers.local_worker().learn_on_batch(batch)
            learn_timer.push_units_processed(batch.count)
        metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
        metrics.info[LEARNER_INFO] = get_learner_stats(info)
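        # Broadcast the updated local weights to all remote workers so that
        # subsequent rollouts use the newly trained policy.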
if self.workers.remote_workers():
|
|
with metrics.timers[WORKER_UPDATE_TIMER]:
|
|
weights = ray.put(self.workers.local_worker().get_weights())
|
|
for e in self.workers.remote_workers():
|
|
e.set_weights.remote(weights, _get_global_vars())
|
|
# Also update global vars of the local worker.
|
|
self.workers.local_worker().set_global_vars(_get_global_vars())
|
|
return info
|
|
|
|
|
|
class CollectMetrics:
|
|
"""Callable that collects metrics from workers.
|
|
|
|
The metrics are smoothed over a given history window.
|
|
|
|
This should be used with the .for_each() operator. For a higher level
|
|
API, consider using StandardMetricsReporting instead.
|
|
|
|
Examples:
|
|
>>> output_op = train_op.for_each(CollectMetrics(workers))
|
|
>>> print(next(output_op))
|
|
{"episode_reward_max": ..., "episode_reward_mean": ..., ...}
|
|
"""
|
|
|
|
def __init__(self, workers, min_history=100, timeout_seconds=180):
|
|
self.workers = workers
|
|
self.episode_history = []
|
|
self.to_be_collected = []
|
|
self.min_history = min_history
|
|
self.timeout_seconds = timeout_seconds
|
|
|
|
def __call__(self, _):
|
|
# Collect worker metrics.
|
|
episodes, self.to_be_collected = collect_episodes(
|
|
self.workers.local_worker(),
|
|
self.workers.remote_workers(),
|
|
self.to_be_collected,
|
|
timeout_seconds=self.timeout_seconds)
|
|
orig_episodes = list(episodes)
|
|
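        # Pad with episodes from earlier calls so that smoothing always covers
        # at least min_history episodes when that much history is available.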
        missing = self.min_history - len(episodes)
        if missing > 0:
            episodes.extend(self.episode_history[-missing:])
            assert len(episodes) <= self.min_history
        self.episode_history.extend(orig_episodes)
        self.episode_history = self.episode_history[-self.min_history:]
        res = summarize_episodes(episodes, orig_episodes)

        # Add in iterator metrics.
        metrics = LocalIterator.get_metrics()
        if metrics.parent_metrics:
            print("TODO: support nested metrics better")
        all_metrics = [metrics] + metrics.parent_metrics
        timers = {}
        counters = {}
        info = {}
        for metrics in all_metrics:
            info.update(metrics.info)
            for k, counter in metrics.counters.items():
                counters[k] = counter
            for k, timer in metrics.timers.items():
                timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
                if timer.has_units_processed():
                    timers["{}_throughput".format(k)] = round(
                        timer.mean_throughput, 3)
        res.update({
            "num_healthy_workers": len(self.workers.remote_workers()),
            "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
        })
        res["timers"] = timers
        res["info"] = info
        res["info"].update(counters)
        return res


class OncePerTimeInterval:
    """Callable that returns True once per given interval.

    This should be used with the .filter() operator to throttle / rate-limit
    metrics reporting. For a higher-level API, consider using
    StandardMetricsReporting instead.

    Examples:
        >>> throttled_op = train_op.filter(OncePerTimeInterval(5))
        >>> start = time.time()
        >>> next(throttled_op)
        >>> print(time.time() - start)
        5.00001  # will be greater than 5 seconds
    """

    def __init__(self, delay):
        self.delay = delay
        self.last_called = 0

    def __call__(self, item):
        now = time.time()
        if now - self.last_called > self.delay:
            self.last_called = now
            return True
        return False


class ComputeGradients:
    """Callable that computes gradients with respect to the policy loss.

    This should be used with the .for_each() operator.

    Examples:
        >>> grads_op = rollouts.for_each(ComputeGradients(workers))
        >>> print(next(grads_op))
        {"var_0": ..., ...}, 50  # grads, batch count

    Updates the LEARNER_INFO info field in the local iterator context.
    """

    def __init__(self, workers):
        self.workers = workers

    def __call__(self, samples: SampleBatchType):
        _check_sample_batch_type(samples)
        metrics = LocalIterator.get_metrics()
        with metrics.timers[COMPUTE_GRADS_TIMER]:
            grad, info = self.workers.local_worker().compute_gradients(samples)
        metrics.info[LEARNER_INFO] = get_learner_stats(info)
        return grad, samples.count


class ApplyGradients:
    """Callable that applies gradients and updates workers.

    This should be used with the .for_each() operator.

    Examples:
        >>> apply_op = grads_op.for_each(ApplyGradients(workers))
        >>> print(next(apply_op))
        None

    Updates the STEPS_TRAINED_COUNTER counter in the local iterator context.
    """

    def __init__(self, workers, update_all=True):
        """Creates an ApplyGradients instance.

        Arguments:
            workers (WorkerSet): workers to apply gradients to.
            update_all (bool): If true, updates all workers. Otherwise, only
                update the worker that produced the sample batch we are
                currently processing (i.e., A3C style).
        """
        self.workers = workers
        self.update_all = update_all

    def __call__(self, item):
        if not isinstance(item, tuple) or len(item) != 2:
            raise ValueError(
                "Input must be a tuple of (grad_dict, count), got {}".format(
                    item))
        gradients, count = item
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_TRAINED_COUNTER] += count

        apply_timer = metrics.timers[APPLY_GRADS_TIMER]
        with apply_timer:
            self.workers.local_worker().apply_gradients(gradients)
            apply_timer.push_units_processed(count)

        # Also update global vars of the local worker.
        self.workers.local_worker().set_global_vars(_get_global_vars())
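
        # Push the new weights out: either to every remote worker, or only to
        # the actor that produced this gradient (update_all=False, A3C-style).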
        if self.update_all:
            if self.workers.remote_workers():
                with metrics.timers[WORKER_UPDATE_TIMER]:
                    weights = ray.put(
                        self.workers.local_worker().get_weights())
                    for e in self.workers.remote_workers():
                        e.set_weights.remote(weights, _get_global_vars())
        else:
            if metrics.current_actor is None:
                raise ValueError(
                    "Could not find actor to update. When "
                    "update_all=False, `current_actor` must be set "
                    "in the iterator context.")
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = self.workers.local_worker().get_weights()
                metrics.current_actor.set_weights.remote(
                    weights, _get_global_vars())


class AverageGradients:
    """Callable that averages the gradients in a batch.

    This should be used with the .for_each() operator after a set of gradients
    have been batched with .batch().

    Examples:
        >>> batched_grads = grads_op.batch(32)
        >>> avg_grads = batched_grads.for_each(AverageGradients())
        >>> print(next(avg_grads))
        {"var_0": ..., ...}, 1600  # averaged grads, summed batch count
    """

    def __call__(self, gradients):
        acc = None
        sum_count = 0
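        # Accumulate the microbatch gradients element-wise and total up the
        # number of samples they were computed from.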
        for grad, count in gradients:
            if acc is None:
                acc = grad
            else:
                acc = [a + b for a, b in zip(acc, grad)]
            sum_count += count
        logger.info("Computing average of {} microbatch gradients "
                    "({} samples total)".format(len(gradients), sum_count))
        return acc, sum_count


class StoreToReplayBuffer:
    """Callable that stores data into a local replay buffer.

    This should be used with the .for_each() operator on a rollouts iterator.
    The batch that was stored is returned.

    Examples:
        >>> buf = ReplayBuffer(1000)
        >>> rollouts = ParallelRollouts(...)
        >>> store_op = rollouts.for_each(StoreToReplayBuffer(buf))
        >>> next(store_op)
        SampleBatch(...)
    """

    def __init__(self, replay_buffer: ReplayBuffer):
        assert isinstance(replay_buffer, ReplayBuffer)
        self.replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}

    def __call__(self, batch: SampleBatchType):
        # Handle everything as if multiagent
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
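
        # Add each experience row to the buffer for its policy, compressing
        # observations where needed (pack_if_needed).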
        for policy_id, s in batch.policy_batches.items():
            for row in s.rows():
                self.replay_buffers[policy_id].add(
                    pack_if_needed(row["obs"]),
                    row["actions"],
                    row["rewards"],
                    pack_if_needed(row["new_obs"]),
                    row["dones"],
                    weight=None)
        return batch


class StoreToReplayActors:
    """Callable that stores data into replay buffer actors.

    This should be used with the .for_each() operator on a rollouts iterator.
    The batch that was stored is returned.

    Examples:
        >>> actors = [ReplayActor.remote() for _ in range(4)]
        >>> rollouts = ParallelRollouts(...)
        >>> store_op = rollouts.for_each(StoreToReplayActors(actors))
        >>> next(store_op)
        SampleBatch(...)
    """

    def __init__(self, replay_actors: List["ActorHandle"]):
        self.replay_actors = replay_actors

    def __call__(self, batch: SampleBatchType):
        actor = random.choice(self.replay_actors)
        actor.add_batch.remote(batch)
        return batch


def ParallelReplay(replay_actors: List["ActorHandle"], async_queue_depth=4):
    """Replay experiences in parallel from the given actors.

    This should be combined with the StoreToReplayActors operation using the
    Concurrently() operator.

    Arguments:
        replay_actors (list): List of replay actors.
        async_queue_depth (int): In async mode, the max number of async
            requests in flight per actor.

    Examples:
        >>> actors = [ReplayActor.remote() for _ in range(4)]
        >>> replay_op = ParallelReplay(actors)
        >>> next(replay_op)
        SampleBatch(...)
    """
    replay = from_actors(replay_actors)
    return replay.gather_async(
        async_queue_depth=async_queue_depth).filter(lambda x: x is not None)


def LocalReplay(replay_buffer: ReplayBuffer, train_batch_size: int):
    """Replay experiences from a local buffer instance.

    This should be combined with the StoreToReplayBuffer operation using the
    Concurrently() operator.

    Arguments:
        replay_buffer (ReplayBuffer): Buffer to replay experiences from.
        train_batch_size (int): Batch size of fetches from the buffer.

    Examples:
        >>> buf = ReplayBuffer(1000)
        >>> replay_op = LocalReplay(buf, train_batch_size=64)
        >>> next(replay_op)
        MultiAgentBatch(...)
    """
    assert isinstance(replay_buffer, ReplayBuffer)
    replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
    # TODO(ekl) support more options, or combine with ParallelReplay (?)
    synchronize_sampling = False
    prioritized_replay_beta = None

    def gen_replay(timeout):
        while True:
            samples = {}
            idxes = None
            for policy_id, replay_buffer in replay_buffers.items():
                if synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(train_batch_size)
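
                # Prioritized buffers also return importance-sampling weights
                # and batch indexes; plain buffers get unit weights and dummy
                # indexes instead.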
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    metrics = LocalIterator.get_metrics()
                    num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=prioritized_replay_beta.value(num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            yield MultiAgentBatch(samples, train_batch_size)

    return LocalIterator(gen_replay, MetricsContext())


def Concurrently(ops: List[LocalIterator], mode="round_robin"):
    """Operator that runs the given parent iterators concurrently.

    Arguments:
        mode (str): One of {'round_robin', 'async'}.
            - In 'round_robin' mode, we alternate between pulling items from
              each parent iterator in order deterministically.
            - In 'async' mode, we pull from each parent iterator as fast as
              they are produced. This is non-deterministic.

    Examples:
        >>> sim_op = ParallelRollouts(...).for_each(...)
        >>> replay_op = LocalReplay(...).for_each(...)
        >>> combined_op = Concurrently([sim_op, replay_op])
    """

    if len(ops) < 2:
        raise ValueError("Should specify at least 2 ops.")
    if mode == "round_robin":
        deterministic = True
    elif mode == "async":
        deterministic = False
    else:
        raise ValueError("Unknown mode {}".format(mode))
    return ops[0].union(*ops[1:], deterministic=deterministic)


class UpdateTargetNetwork:
    """Periodically call policy.update_target() on all trainable policies.

    This should be used with the .for_each() operator after a training step
    has been taken.

    Examples:
        >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
        >>> update_op = train_op.for_each(
        ...     UpdateTargetNetwork(workers, target_update_freq=500))
        >>> print(next(update_op))
        None

    Updates the LAST_TARGET_UPDATE_TS and NUM_TARGET_UPDATES counters in the
    local iterator context. The value of the last update counter is used to
    track when we should update the target next.
    """

    def __init__(self, workers, target_update_freq, by_steps_trained=False):
        self.workers = workers
        self.target_update_freq = target_update_freq
        if by_steps_trained:
            self.metric = STEPS_TRAINED_COUNTER
        else:
            self.metric = STEPS_SAMPLED_COUNTER

    def __call__(self, _):
        metrics = LocalIterator.get_metrics()
        cur_ts = metrics.counters[self.metric]
        last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
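        # Update the targets once enough new steps have accumulated since the
        # last update, and record when this update happened.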
        if cur_ts - last_update > self.target_update_freq:
            self.workers.local_worker().foreach_trainable_policy(
                lambda p, _: p.update_target())
            metrics.counters[NUM_TARGET_UPDATES] += 1
            metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts


class Enqueue:
    """Enqueue data items into a queue.Queue instance.

    The enqueue is non-blocking, so Enqueue operations can be executed with
    Dequeue via the Concurrently() operator.

    Examples:
        >>> queue = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(queue))
        >>> read_op = Dequeue(queue)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """

    def __init__(self, output_queue: queue.Queue):
        if not isinstance(output_queue, queue.Queue):
            raise ValueError("Expected queue.Queue, got {}".format(
                type(output_queue)))
        self.queue = output_queue

    def __call__(self, x):
        try:
            self.queue.put_nowait(x)
        except queue.Full:
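            # Queue is full; signal "not ready" instead of blocking (the item
            # is not enqueued).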
            return _NextValueNotReady()


def Dequeue(input_queue: queue.Queue, check=lambda: True):
    """Dequeue data items from a queue.Queue instance.

    The dequeue is non-blocking, so Dequeue operations can be executed with
    Enqueue via the Concurrently() operator.

    Arguments:
        input_queue (Queue): queue to pull items from.
        check (fn): liveness check. When this function returns false,
            Dequeue() will raise an error to halt execution.

    Examples:
        >>> queue = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(queue))
        >>> read_op = Dequeue(queue)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """
    if not isinstance(input_queue, queue.Queue):
        raise ValueError("Expected queue.Queue, got {}".format(
            type(input_queue)))

    def base_iterator(timeout=None):
        while check():
            try:
                item = input_queue.get_nowait()
                yield item
            except queue.Empty:
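                # Nothing queued yet; yield a placeholder so the pipeline can
                # keep making progress instead of blocking.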
                yield _NextValueNotReady()
        raise RuntimeError("Error raised reading from queue")

    return LocalIterator(base_iterator, MetricsContext())