Mirror of https://github.com/vale981/ray, synced 2025-03-05 10:01:43 -05:00
[rllib] Port DQN/Ape-X to training workflow api (#8077)
parent 499ad5fbe4
commit 2298f6fb40
15 changed files with 310 additions and 161 deletions
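
For orientation before the hunks: this commit moves DQN and Ape-X off the policy-optimizer classes and onto the execution-plan ("training workflow") API, where the training loop is built by composing iterator operators. A condensed sketch of the DQN-style plan introduced below (not part of the diff; the prioritized-replay priority-update step is omitted here, see the dqn.py hunk for the full version):

from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer


def execution_plan(workers, config):
    # One local replay shard; sizes and priority settings come from the config.
    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        prioritized_replay_alpha=config["prioritized_replay_alpha"],
        prioritized_replay_beta=config["prioritized_replay_beta"],
        prioritized_replay_eps=config["prioritized_replay_eps"])

    # (1) Sample rollouts and store them into the buffer.
    store_op = ParallelRollouts(workers, mode="bulk_sync") \
        .for_each(StoreToReplayBuffer(local_buffer=local_replay_buffer))

    # (2) Replay batches, train, and periodically update the target network.
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(TrainOneStep(workers)) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Alternate (1) and (2) deterministically; only (2) yields training results.
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])
    return StandardMetricsReporting(train_op, workers, config)

Calling next() on the resulting iterator drives one round of sampling and training and yields a metrics dict.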
@@ -1,11 +1,12 @@
 import collections
+import copy
 
 import ray
 from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG
 from ray.rllib.execution.common import STEPS_TRAINED_COUNTER
 from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
-from ray.rllib.execution.replay_ops import StoreToReplayActors, ParallelReplay
+from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
 from ray.rllib.execution.train_ops import UpdateTargetNetwork
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.optimizers import AsyncReplayOptimizer

@@ -144,7 +145,7 @@ def execution_plan(workers, config):
     # the weights of the worker that generated the batch.
     rollouts = ParallelRollouts(workers, mode="async", async_queue_depth=2)
     store_op = rollouts \
-        .for_each(StoreToReplayActors(replay_actors)) \
+        .for_each(StoreToReplayBuffer(actors=replay_actors)) \
         .zip_with_source_actor() \
         .for_each(UpdateWorkerWeights(
             learner_thread, workers,

@@ -153,7 +154,7 @@ def execution_plan(workers, config):

     # (2) Read experiences from the replay buffer actors and send to the
     # learner thread via its in-queue.
-    replay_op = ParallelReplay(replay_actors, async_queue_depth=4) \
+    replay_op = Replay(actors=replay_actors, async_queue_depth=4) \
         .zip_with_source_actor() \
         .for_each(Enqueue(learner_thread.inqueue))

@@ -166,10 +167,32 @@ def execution_plan(workers, config):
             workers, config["target_network_update_freq"],
             by_steps_trained=True))
 
-    # Execute (1), (2), (3) asynchronously as fast as possible.
-    merged_op = Concurrently([store_op, replay_op, update_op], mode="async")
+    # Execute (1), (2), (3) asynchronously as fast as possible. Only output
+    # items from (3) since metrics aren't available before then.
+    merged_op = Concurrently(
+        [store_op, replay_op, update_op], mode="async", output_indexes=[2])
 
-    return StandardMetricsReporting(merged_op, workers, config)
+    # Add in extra replay and learner metrics to the training result.
+    def add_apex_metrics(result):
+        replay_stats = ray.get(replay_actors[0].stats.remote(
+            config["optimizer"].get("debug")))
+        exploration_infos = workers.foreach_trainable_policy(
+            lambda p, _: p.get_exploration_info())
+        result["info"].update({
+            "exploration_infos": exploration_infos,
+            "learner_queue": learner_thread.learner_queue_size.stats(),
+            "learner": copy.deepcopy(learner_thread.stats),
+            "replay_shard_0": replay_stats,
+        })
+        return result
+
+    # Only report metrics from the workers with the lowest 1/3 of epsilons.
+    selected_workers = workers.remote_workers()[
+        -len(workers.remote_workers()) // 3:]
+
+    return StandardMetricsReporting(
+        merged_op, workers, config,
+        selected_workers=selected_workers).for_each(add_apex_metrics)
 
 
 APEX_TRAINER_PROPERTIES = {
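
Worth noting (a sketch, not from the diff): StandardMetricsReporting yields plain result dicts, so trainer-specific stats can be appended with a trailing .for_each, which is exactly how add_apex_metrics above is attached. A minimal hypothetical form:

from ray.rllib.execution.metric_ops import StandardMetricsReporting


def add_custom_metrics(result):
    # `result` is the dict emitted by StandardMetricsReporting; anything
    # merged into result["info"] shows up in the training result.
    result["info"].update({"my_custom_stat": 123})  # hypothetical extra metric
    return result


def reporting_op(merged_op, workers, config):
    # Same chaining pattern as add_apex_metrics above; merged_op, workers and
    # config are assumed to come from an execution plan.
    return StandardMetricsReporting(merged_op, workers, config) \
        .for_each(add_custom_metrics)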
@@ -5,12 +5,13 @@ from ray.rllib.agents.trainer_template import build_trainer
 from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
 from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
 from ray.rllib.optimizers import SyncReplayOptimizer
-from ray.rllib.optimizers.replay_buffer import ReplayBuffer
+from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
+from ray.rllib.policy.policy import LEARNER_STATS_KEY
 from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
 from ray.rllib.utils.exploration import PerWorkerEpsilonGreedy
 from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.concurrency_ops import Concurrently
-from ray.rllib.execution.replay_ops import StoreToReplayBuffer, LocalReplay
+from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
 from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
 from ray.rllib.execution.metric_ops import StandardMetricsReporting

@@ -125,6 +126,9 @@ DEFAULT_CONFIG = with_common_config({
     "soft_q": DEPRECATED_VALUE,
     "parameter_noise": DEPRECATED_VALUE,
     "grad_norm_clipping": DEPRECATED_VALUE,
+
+    # Use the execution plan API instead of policy optimizers.
+    "use_exec_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable
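
A hedged usage sketch of the flag added above (the environment name is just a placeholder):

import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init()
trainer = DQNTrainer(
    env="CartPole-v0",
    config={
        # Route training through execution_plan() below rather than the
        # legacy SyncReplayOptimizer path (True is already the default here).
        "use_exec_api": True,
    })
print(trainer.train())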
@@ -297,24 +301,52 @@ def update_target_if_needed(trainer, fetches):

 # Experimental distributed execution impl; enable with "use_exec_api": True.
 def execution_plan(workers, config):
-    local_replay_buffer = ReplayBuffer(config["buffer_size"])
+    local_replay_buffer = LocalReplayBuffer(
+        num_shards=1,
+        learning_starts=config["learning_starts"],
+        buffer_size=config["buffer_size"],
+        replay_batch_size=config["train_batch_size"],
+        prioritized_replay_alpha=config["prioritized_replay_alpha"],
+        prioritized_replay_beta=config["prioritized_replay_beta"],
+        prioritized_replay_eps=config["prioritized_replay_eps"])
 
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
 
     # We execute the following steps concurrently:
     # (1) Generate rollouts and store them in our local replay buffer. Calling
     # next() on store_op drives this.
-    store_op = rollouts.for_each(StoreToReplayBuffer(local_replay_buffer))
+    store_op = rollouts.for_each(
+        StoreToReplayBuffer(local_buffer=local_replay_buffer))
+
+    def update_prio(item):
+        samples, info_dict = item
+        if config["prioritized_replay"]:
+            prio_dict = {}
+            for policy_id, info in info_dict.items():
+                # TODO(sven): This is currently structured differently for
+                # torch/tf. Clean up these results/info dicts across
+                # policies (note: fixing this in torch_policy.py will
+                # break e.g. DDPPO!).
+                td_error = info.get("td_error",
+                                    info[LEARNER_STATS_KEY].get("td_error"))
+                prio_dict[policy_id] = (samples.policy_batches[policy_id]
+                                        .data.get("batch_indexes"), td_error)
+            local_replay_buffer.update_priorities(prio_dict)
+        return info_dict
 
     # (2) Read and train on experiences from the replay buffer. Every batch
     # returned from the LocalReplay() iterator is passed to TrainOneStep to
     # take a SGD step, and then we decide whether to update the target network.
-    replay_op = LocalReplay(local_replay_buffer, config["train_batch_size"]) \
+    replay_op = Replay(local_buffer=local_replay_buffer) \
         .for_each(TrainOneStep(workers)) \
+        .for_each(update_prio) \
         .for_each(UpdateTargetNetwork(
             workers, config["target_network_update_freq"]))
 
-    # Alternate deterministically between (1) and (2).
-    train_op = Concurrently([store_op, replay_op], mode="round_robin")
+    # Alternate deterministically between (1) and (2). Only return the output
+    # of (2) since training metrics are not available until (2) runs.
+    train_op = Concurrently(
+        [store_op, replay_op], mode="round_robin", output_indexes=[1])
 
     return StandardMetricsReporting(train_op, workers, config)
@@ -19,8 +19,9 @@ class TestApex(unittest.TestCase):
         config = apex.APEX_DEFAULT_CONFIG.copy()
         config["num_workers"] = 3
         config["prioritized_replay"] = True
+        config["timesteps_per_iteration"] = 100
+        config["min_iter_time_s"] = 1
         config["optimizer"]["num_replay_buffer_shards"] = 1
         num_iterations = 1
 
         for _ in framework_iterator(config, ("torch", "tf", "eager")):
             plain_config = config.copy()

@@ -30,12 +31,14 @@ class TestApex(unittest.TestCase):
             infos = trainer.workers.foreach_policy(
                 lambda p, _: p.get_exploration_info())
             eps = [i["cur_epsilon"] for i in infos]
-            assert np.allclose(eps,
-                               [1.0, 0.016190862, 0.00065536, 2.6527108e-05])
+            assert np.allclose(eps, [0.0, 0.4, 0.016190862, 0.00065536])
 
-            for i in range(num_iterations):
-                results = trainer.train()
-                print(results)
+            # TODO(ekl) fix iterator metrics bugs w/multiple trainers.
+            # for i in range(1):
+            #     results = trainer.train()
+            #     print(results)
+
+            trainer.stop()
 
 
 if __name__ == "__main__":
@@ -258,7 +258,7 @@ COMMON_CONFIG = {
     "min_iter_time_s": 0,
     # Minimum env steps to optimize for per train call. This value does
    # not affect learning, only the length of train iterations.
-    "timesteps_per_iteration": 0,  # TODO(ekl) deprecate this
+    "timesteps_per_iteration": 0,
     # This argument, in conjunction with worker_index, sets the random seed of
     # each worker, so that identically configured trials will have identical
     # results. This makes experiments reproducible.
@@ -173,10 +173,10 @@ def build_trainer(name,

         def _train_exec_impl(self):
             if before_train_step:
-                logger.warning("Ignoring before_train_step callback")
+                logger.debug("Ignoring before_train_step callback")
             res = next(self.train_exec_impl)
             if after_train_result:
-                logger.warning("Ignoring after_train_result callback")
+                logger.debug("Ignoring after_train_result callback")
             return res
 
         @override(Trainer)
@@ -5,7 +5,10 @@ from ray.util.iter import LocalIterator, _NextValueNotReady
 from ray.util.iter_metrics import SharedMetrics
 
 
-def Concurrently(ops: List[LocalIterator], *, mode="round_robin"):
+def Concurrently(ops: List[LocalIterator],
+                 *,
+                 mode="round_robin",
+                 output_indexes=None):
     """Operator that runs the given parent iterators concurrently.
 
     Arguments:

@@ -14,6 +17,9 @@ def Concurrently(ops: List[LocalIterator], *, mode="round_robin"):
              each parent iterator in order deterministically.
           - In 'async' mode, we pull from each parent iterator as fast as
              they are produced. This is non-deterministic.
+        output_indexes (list): If specified, only output results from the
+            given ops. For example, if output_indexes=[0], only results from
+            the first op in ops will be returned.
 
         >>> sim_op = ParallelRollouts(...).for_each(...)
         >>> replay_op = LocalReplay(...).for_each(...)

@@ -28,7 +34,23 @@
         deterministic = False
     else:
         raise ValueError("Unknown mode {}".format(mode))
-    return ops[0].union(*ops[1:], deterministic=deterministic)
+
+    if output_indexes:
+        for i in output_indexes:
+            assert i in range(len(ops)), ("Index out of range", i)
+
+        def tag(op, i):
+            return op.for_each(lambda x: (i, x))
+
+        ops = [tag(op, i) for i, op in enumerate(ops)]
+
+    output = ops[0].union(*ops[1:], deterministic=deterministic)
+
+    if output_indexes:
+        output = (output.filter(lambda tup: tup[0] in output_indexes)
+                  .for_each(lambda tup: tup[1]))
+
+    return output
 
 
 class Enqueue:
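
A small usage sketch of the new output_indexes argument (the iterator construction via ray.util.iter.from_items is an assumption; the unit test added later in this diff exercises the same behavior):

import ray
from ray.util.iter import from_items
from ray.rllib.execution.concurrency_ops import Concurrently

ray.init()
a = from_items([1, 2, 3], num_shards=1).gather_sync()
b = from_items([4, 5, 6], num_shards=1).gather_sync()

# Both parents are still driven round-robin, but only b's items are emitted.
c = Concurrently([a, b], mode="round_robin", output_indexes=[1])
assert c.take(6) == [4, 5, 6]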
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, List
 import time
 
 from ray.util.iter import LocalIterator

@@ -7,8 +7,11 @@ from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER
 from ray.rllib.evaluation.worker_set import WorkerSet
 
 
-def StandardMetricsReporting(train_op: LocalIterator[Any], workers: WorkerSet,
-                             config: dict) -> LocalIterator[dict]:
+def StandardMetricsReporting(
+        train_op: LocalIterator[Any],
+        workers: WorkerSet,
+        config: dict,
+        selected_workers: List["ActorHandle"] = None) -> LocalIterator[dict]:
     """Operator to periodically collect and report metrics.
 
     Arguments:

@@ -17,6 +20,8 @@
         workers (WorkerSet): Rollout workers to collect metrics from.
         config (dict): Trainer configuration, used to determine the frequency
             of stats reporting.
+        selected_workers (list): Override the list of remote workers
+            to collect metrics from.
 
     Returns:
         A local iterator over training results.

@@ -29,10 +34,12 @@
     """
 
     output_op = train_op \
-        .filter(OncePerTimeInterval(max(2, config["min_iter_time_s"]))) \
+        .filter(OncePerTimestepsElapsed(config["timesteps_per_iteration"])) \
+        .filter(OncePerTimeInterval(config["min_iter_time_s"])) \
         .for_each(CollectMetrics(
             workers, min_history=config["metrics_smoothing_episodes"],
-            timeout_seconds=config["collect_metrics_timeout"]))
+            timeout_seconds=config["collect_metrics_timeout"],
+            selected_workers=selected_workers))
     return output_op
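
Sketch of how the new selected_workers override is used (mirrors the Ape-X plan earlier in this diff; train_op, workers and config are assumed to come from an execution plan):

from ray.rllib.execution.metric_ops import StandardMetricsReporting


def report_from_lowest_epsilon_workers(train_op, workers, config):
    # Only collect episode metrics from the last third of the remote workers,
    # i.e. the ones Ape-X assigns the lowest exploration epsilons.
    remote = workers.remote_workers()
    selected = remote[-len(remote) // 3:]
    return StandardMetricsReporting(
        train_op, workers, config, selected_workers=selected)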
@@ -50,18 +57,23 @@ class CollectMetrics:
     {"episode_reward_max": ..., "episode_reward_mean": ..., ...}
     """
 
-    def __init__(self, workers, min_history=100, timeout_seconds=180):
+    def __init__(self,
+                 workers,
+                 min_history=100,
+                 timeout_seconds=180,
+                 selected_workers: List["ActorHandle"] = None):
         self.workers = workers
         self.episode_history = []
         self.to_be_collected = []
         self.min_history = min_history
         self.timeout_seconds = timeout_seconds
+        self.selected_workers = selected_workers
 
     def __call__(self, _):
         # Collect worker metrics.
         episodes, self.to_be_collected = collect_episodes(
             self.workers.local_worker(),
-            self.workers.remote_workers(),
+            self.selected_workers or self.workers.remote_workers(),
             self.to_be_collected,
             timeout_seconds=self.timeout_seconds)
         orig_episodes = list(episodes)

@@ -116,8 +128,38 @@ class OncePerTimeInterval:
         self.last_called = 0
 
     def __call__(self, item):
+        if self.delay <= 0.0:
+            return True
         now = time.time()
         if now - self.last_called > self.delay:
             self.last_called = now
             return True
         return False
+
+
+class OncePerTimestepsElapsed:
+    """Callable that returns True once per given number of timesteps.
+
+    This should be used with the .filter() operator to throttle / rate-limit
+    metrics reporting. For a higher-level API, consider using
+    StandardMetricsReporting instead.
+
+    Examples:
+        >>> throttled_op = train_op.filter(OncePerTimestepsElapsed(1000))
+        >>> next(throttled_op)
+        # will only return after 1000 steps have elapsed
+    """
+
+    def __init__(self, delay_steps):
+        self.delay_steps = delay_steps
+        self.last_called = 0
+
+    def __call__(self, item):
+        if self.delay_steps <= 0:
+            return True
+        metrics = LocalIterator.get_metrics()
+        now = metrics.counters[STEPS_SAMPLED_COUNTER]
+        if now - self.last_called > self.delay_steps:
+            self.last_called = now
+            return True
+        return False
@@ -1,54 +1,18 @@
 from typing import List
-import numpy as np
 import random
 
-from ray.util.iter import from_actors, LocalIterator
+from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
 from ray.util.iter_metrics import SharedMetrics
-from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer, \
-    ReplayBuffer
-from ray.rllib.execution.common import SampleBatchType, STEPS_TRAINED_COUNTER
-from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \
-    DEFAULT_POLICY_ID
-from ray.rllib.utils.compression import pack_if_needed
+from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer
+from ray.rllib.execution.common import SampleBatchType
 
 
 class StoreToReplayBuffer:
-    """Callable that stores data into a local replay buffer.
-
-    This should be used with the .for_each() operator on a rollouts iterator.
-    The batch that was stored is returned.
-
-    Examples:
-        >>> buf = ReplayBuffer(1000)
-        >>> rollouts = ParallelRollouts(...)
-        >>> store_op = rollouts.for_each(StoreToReplayBuffer(buf))
-        >>> next(store_op)
-        SampleBatch(...)
-    """
-
-    def __init__(self, replay_buffer: ReplayBuffer):
-        assert isinstance(replay_buffer, ReplayBuffer)
-        self.replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
-
-    def __call__(self, batch: SampleBatchType):
-        # Handle everything as if multiagent
-        if isinstance(batch, SampleBatch):
-            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
-
-        for policy_id, s in batch.policy_batches.items():
-            for row in s.rows():
-                self.replay_buffers[policy_id].add(
-                    pack_if_needed(row["obs"]),
-                    row["actions"],
-                    row["rewards"],
-                    pack_if_needed(row["new_obs"]),
-                    row["dones"],
-                    weight=None)
-        return batch
-
-
-class StoreToReplayActors:
-    """Callable that stores data into a replay buffer actors.
+    """Callable that stores data into replay buffer actors.
+
+    If constructed with a local replay actor, data will be stored into that
+    buffer. If constructed with a list of replay actor handles, data will
+    be stored randomly among those actors.
 
     This should be used with the .for_each() operator on a rollouts iterator.
     The batch that was stored is returned.

@@ -56,96 +20,74 @@ class StoreToReplayActors:
     Examples:
         >>> actors = [ReplayActor.remote() for _ in range(4)]
         >>> rollouts = ParallelRollouts(...)
-        >>> store_op = rollouts.for_each(StoreToReplayActors(actors))
+        >>> store_op = rollouts.for_each(StoreToReplayActors(actors=actors))
         >>> next(store_op)
         SampleBatch(...)
     """
 
-    def __init__(self, replay_actors: List["ActorHandle"]):
-        self.replay_actors = replay_actors
+    def __init__(self,
+                 *,
+                 local_buffer: LocalReplayBuffer = None,
+                 actors: List["ActorHandle"] = None):
+        if bool(local_buffer) == bool(actors):
+            raise ValueError(
+                "Exactly one of local_buffer and replay_actors must be given.")
+
+        if local_buffer:
+            self.local_actor = local_buffer
+            self.replay_actors = None
+        else:
+            self.local_actor = None
+            self.replay_actors = actors
 
     def __call__(self, batch: SampleBatchType):
-        actor = random.choice(self.replay_actors)
-        actor.add_batch.remote(batch)
+        if self.local_actor:
+            self.local_actor.add_batch(batch)
+        else:
+            actor = random.choice(self.replay_actors)
+            actor.add_batch.remote(batch)
         return batch
 
 
-def ParallelReplay(replay_actors: List["ActorHandle"], async_queue_depth=4):
-    """Replay experiences in parallel from the given actors.
+def Replay(*,
+           local_buffer: LocalReplayBuffer = None,
+           actors: List["ActorHandle"] = None,
+           async_queue_depth=4):
+    """Replay experiences from the given buffer or actors.
 
     This should be combined with the StoreToReplayActors operation using the
     Concurrently() operator.
 
     Arguments:
-        replay_actors (list): List of replay actors.
+        local_buffer (LocalReplayBuffer): Local buffer to use. Only one of this
+            and replay_actors can be specified.
+        actors (list): List of replay actors. Only one of this and
+            local_buffer can be specified.
         async_queue_depth (int): In async mode, the max number of async
             requests in flight per actor.
 
     Examples:
         >>> actors = [ReplayActor.remote() for _ in range(4)]
-        >>> replay_op = ParallelReplay(actors)
+        >>> replay_op = Replay(actors=actors)
         >>> next(replay_op)
         SampleBatch(...)
     """
-    replay = from_actors(replay_actors)
-    return replay.gather_async(
-        async_queue_depth=async_queue_depth).filter(lambda x: x is not None)
-
-
-def LocalReplay(replay_buffer: ReplayBuffer, train_batch_size: int):
-    """Replay experiences from a local buffer instance.
-
-    This should be combined with the StoreToReplayBuffer operation using the
-    Concurrently() operator.
-
-    Arguments:
-        replay_buffer (ReplayBuffer): Buffer to replay experiences from.
-        train_batch_size (int): Batch size of fetches from the buffer.
-
-    Examples:
-        >>> actors = [ReplayActor.remote() for _ in range(4)]
-        >>> replay_op = ParallelReplay(actors)
-        >>> next(replay_op)
-        SampleBatch(...)
-    """
-    assert isinstance(replay_buffer, ReplayBuffer)
-    replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
-    # TODO(ekl) support more options, or combine with ParallelReplay (?)
-    synchronize_sampling = False
-    prioritized_replay_beta = None
+
+    if bool(local_buffer) == bool(actors):
+        raise ValueError(
+            "Exactly one of local_buffer and replay_actors must be given.")
+
+    if actors:
+        replay = from_actors(actors)
+        return replay.gather_async(async_queue_depth=async_queue_depth).filter(
+            lambda x: x is not None)
 
-    def gen_replay(timeout):
+    def gen_replay(_):
         while True:
-            samples = {}
-            idxes = None
-            for policy_id, replay_buffer in replay_buffers.items():
-                if synchronize_sampling:
-                    if idxes is None:
-                        idxes = replay_buffer.sample_idxes(train_batch_size)
-                else:
-                    idxes = replay_buffer.sample_idxes(train_batch_size)
-
-                if isinstance(replay_buffer, PrioritizedReplayBuffer):
-                    metrics = LocalIterator.get_metrics()
-                    num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
-                    (obses_t, actions, rewards, obses_tp1, dones, weights,
-                     batch_indexes) = replay_buffer.sample_with_idxes(
-                         idxes,
-                         beta=prioritized_replay_beta.value(num_steps_trained))
-                else:
-                    (obses_t, actions, rewards, obses_tp1,
-                     dones) = replay_buffer.sample_with_idxes(idxes)
-                    weights = np.ones_like(rewards)
-                    batch_indexes = -np.ones_like(rewards)
-                samples[policy_id] = SampleBatch({
-                    "obs": obses_t,
-                    "actions": actions,
-                    "rewards": rewards,
-                    "new_obs": obses_tp1,
-                    "dones": dones,
-                    "weights": weights,
-                    "batch_indexes": batch_indexes
-                })
-            yield MultiAgentBatch(samples, train_batch_size)
+            item = local_buffer.replay()
+            if item is None:
+                yield _NextValueNotReady()
+            else:
+                yield item
 
     return LocalIterator(gen_replay, SharedMetrics())
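
The merged operators above take exactly one of local_buffer or actors. A hedged sketch of both modes (buffer sizes are illustrative; rollouts stands for a ParallelRollouts iterator):

from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
from ray.rllib.optimizers.async_replay_optimizer import (LocalReplayBuffer,
                                                          ReplayActor)

BUFFER_KWARGS = dict(
    num_shards=1, learning_starts=1000, buffer_size=50000,
    replay_batch_size=32, prioritized_replay_alpha=0.6,
    prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6)


def local_store_and_replay(rollouts):
    # Single-process mode: one local buffer shared by the store and replay ops.
    buf = LocalReplayBuffer(**BUFFER_KWARGS)
    store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=buf))
    replay_op = Replay(local_buffer=buf)
    return store_op, replay_op


def distributed_store_and_replay(rollouts, num_actors=4):
    # Distributed (Ape-X style) mode: a set of replay actor shards.
    # In a real setup, num_shards in BUFFER_KWARGS would match len(actors).
    actors = [ReplayActor.remote(**BUFFER_KWARGS) for _ in range(num_actors)]
    store_op = rollouts.for_each(StoreToReplayBuffer(actors=actors))
    replay_op = Replay(actors=actors)
    return store_op, replay_op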
@@ -17,13 +17,14 @@ logger = logging.getLogger(__name__)
 class TrainOneStep:
     """Callable that improves the policy and updates workers.
 
-    This should be used with the .for_each() operator.
+    This should be used with the .for_each() operator. A tuple of the input
+    and learner stats will be returned.
 
     Examples:
         >>> rollouts = ParallelRollouts(...)
         >>> train_op = rollouts.for_each(TrainOneStep(workers))
         >>> print(next(train_op))  # This trains the policy on one batch.
-        {"learner_stats": ...}
+        SampleBatch(...), {"learner_stats": ...}
 
     Updates the STEPS_TRAINED_COUNTER counter and LEARNER_INFO field in the
     local iterator context.

@@ -32,7 +33,8 @@ class TrainOneStep:
     def __init__(self, workers: WorkerSet):
         self.workers = workers
 
-    def __call__(self, batch: SampleBatchType) -> List[dict]:
+    def __call__(self,
+                 batch: SampleBatchType) -> (SampleBatchType, List[dict]):
         _check_sample_batch_type(batch)
         metrics = LocalIterator.get_metrics()
         learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]

@@ -48,7 +50,7 @@ class TrainOneStep:
                 e.set_weights.remote(weights, _get_global_vars())
         # Also update global vars of the local worker.
         self.workers.local_worker().set_global_vars(_get_global_vars())
-        return info
+        return batch, info
 
 
 class ComputeGradients:
@@ -285,19 +285,19 @@ class AsyncReplayOptimizer(PolicyOptimizer):
         return sample_timesteps, train_timesteps
 
 
-@ray.remote(num_cpus=0)
-class ReplayActor(ParallelIteratorWorker):
+# TODO(ekl) move this class to common
+class LocalReplayBuffer(ParallelIteratorWorker):
     """A replay buffer shard.
 
     Ray actors are single-threaded, so for scalability multiple replay actors
     may be created to increase parallelism."""
 
     def __init__(self, num_shards, learning_starts, buffer_size,
-                 train_batch_size, prioritized_replay_alpha,
+                 replay_batch_size, prioritized_replay_alpha,
                  prioritized_replay_beta, prioritized_replay_eps):
         self.replay_starts = learning_starts // num_shards
         self.buffer_size = buffer_size // num_shards
-        self.train_batch_size = train_batch_size
+        self.replay_batch_size = replay_batch_size
         self.prioritized_replay_beta = prioritized_replay_beta
         self.prioritized_replay_eps = prioritized_replay_eps

@@ -331,7 +331,8 @@ class ReplayActor(ParallelIteratorWorker):
                 for row in s.rows():
                     self.replay_buffers[policy_id].add(
                         row["obs"], row["actions"], row["rewards"],
-                        row["new_obs"], row["dones"], row["weights"])
+                        row["new_obs"], row["dones"], row["weights"]
+                        if "weights" in row else None)
         self.num_added += batch.count
 
     def replay(self):

@@ -343,7 +344,7 @@
             for policy_id, replay_buffer in self.replay_buffers.items():
                 (obses_t, actions, rewards, obses_tp1, dones, weights,
                  batch_indexes) = replay_buffer.sample(
-                     self.train_batch_size, beta=self.prioritized_replay_beta)
+                     self.replay_batch_size, beta=self.prioritized_replay_beta)
                 samples[policy_id] = SampleBatch({
                     "obs": obses_t,
                     "actions": actions,

@@ -353,7 +354,7 @@
                     "weights": weights,
                     "batch_indexes": batch_indexes
                 })
-            return MultiAgentBatch(samples, self.train_batch_size)
+            return MultiAgentBatch(samples, self.replay_batch_size)
 
     def update_priorities(self, prio_dict):
         with self.update_priorities_timer:

@@ -377,10 +378,13 @@
         return stat
 
 
+ReplayActor = ray.remote(num_cpus=0)(LocalReplayBuffer)
+
+
 # TODO(ekl) move this class to common
-@ray.remote(num_cpus=0)
-class BatchReplayActor:
+# note: we set num_cpus=0 to avoid failing to create replay actors when
+# resources are fragmented. This isn't ideal.
+class LocalBatchReplayBuffer(LocalReplayBuffer):
     """The batch replay version of the replay actor.
 
     This allows for RNN models, but ignores prioritization params.

@@ -398,9 +402,6 @@ class BatchReplayActor:
         self.num_added = 0
         self.cur_size = 0
 
-    def get_host(self):
-        return os.uname()[1]
-
     def add_batch(self, batch):
         # Handle everything as if multiagent
         if isinstance(batch, SampleBatch):

@@ -427,6 +428,9 @@
         return stat
 
 
+BatchReplayActor = ray.remote(num_cpus=0)(LocalBatchReplayBuffer)
+
+
 class LearnerThread(threading.Thread):
     """Background thread that updates the local model from replay data.
@@ -27,7 +27,7 @@ class EvalTest(unittest.TestCase):
         def env_creator(env_config):
             return gym.make("CartPole-v0")
 
-        agent_classes = [DQNTrainer, A3CTrainer]
+        agent_classes = [A3CTrainer, DQNTrainer]
 
         for agent_cls in agent_classes:
             ray.init(object_store_memory=1000 * 1024 * 1024)
@@ -3,15 +3,20 @@ import time
 import gym
 import queue
 
 import ray
 from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
+from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
 from ray.rllib.execution.rollout_ops import ParallelRollouts, AsyncGradients, \
     ConcatBatches
 from ray.rllib.execution.train_ops import TrainOneStep, ComputeGradients, \
     AverageGradients
+from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer, \
+    ReplayActor
+from ray.rllib.policy.sample_batch import SampleBatch
 from ray.util.iter import LocalIterator, from_range
 from ray.util.iter_metrics import SharedMetrics

@@ -47,6 +52,18 @@ def test_concurrently(ray_start_regular_shared):
     assert c.take(6) == [1, 2, 3, 4, 5, 6]
 
 
+def test_concurrently_output(ray_start_regular_shared):
+    a = iter_list([1, 2, 3])
+    b = iter_list([4, 5, 6])
+    c = Concurrently([a, b], mode="round_robin", output_indexes=[1])
+    assert c.take(6) == [4, 5, 6]
+
+    a = iter_list([1, 2, 3])
+    b = iter_list([4, 5, 6])
+    c = Concurrently([a, b], mode="round_robin", output_indexes=[0, 1])
+    assert c.take(6) == [1, 4, 2, 5, 3, 6]
+
+
 def test_enqueue_dequeue(ray_start_regular_shared):
     a = iter_list([1, 2, 3])
     q = queue.Queue(100)

@@ -70,6 +87,7 @@ def test_metrics(ray_start_regular_shared):
     b = StandardMetricsReporting(
         a, workers, {
             "min_iter_time_s": 2.5,
+            "timesteps_per_iteration": 0,
             "metrics_smoothing_episodes": 10,
             "collect_metrics_timeout": 10,
         })

@@ -128,7 +146,9 @@ def test_train_one_step(ray_start_regular_shared):
     workers = make_workers(0)
     a = ParallelRollouts(workers, mode="bulk_sync")
     b = a.for_each(TrainOneStep(workers))
-    assert "learner_stats" in next(b)
+    batch, stats = next(b)
+    assert isinstance(batch, SampleBatch)
+    assert "learner_stats" in stats
     counters = a.shared_metrics.get().counters
     assert counters["num_steps_sampled"] == 100, counters
     assert counters["num_steps_trained"] == 100, counters

@@ -156,6 +176,54 @@ def test_avg_gradients(ray_start_regular_shared):
     assert counts == 400, counts
 
 
+def test_store_to_replay_local(ray_start_regular_shared):
+    buf = LocalReplayBuffer(
+        num_shards=1,
+        learning_starts=200,
+        buffer_size=1000,
+        replay_batch_size=100,
+        prioritized_replay_alpha=0.6,
+        prioritized_replay_beta=0.4,
+        prioritized_replay_eps=0.0001)
+    assert buf.replay() is None
+
+    workers = make_workers(0)
+    a = ParallelRollouts(workers, mode="bulk_sync")
+    b = a.for_each(StoreToReplayBuffer(local_buffer=buf))
+
+    next(b)
+    assert buf.replay() is None  # learning hasn't started yet
+    next(b)
+    assert buf.replay().count == 100
+
+    replay_op = Replay(local_buffer=buf)
+    assert next(replay_op).count == 100
+
+
+def test_store_to_replay_actor(ray_start_regular_shared):
+    actor = ReplayActor.remote(
+        num_shards=1,
+        learning_starts=200,
+        buffer_size=1000,
+        replay_batch_size=100,
+        prioritized_replay_alpha=0.6,
+        prioritized_replay_beta=0.4,
+        prioritized_replay_eps=0.0001)
+    assert ray.get(actor.replay.remote()) is None
+
+    workers = make_workers(0)
+    a = ParallelRollouts(workers, mode="bulk_sync")
+    b = a.for_each(StoreToReplayBuffer(actors=[actor]))
+
+    next(b)
+    assert ray.get(actor.replay.remote()) is None  # learning hasn't started
+    next(b)
+    assert ray.get(actor.replay.remote()).count == 100
+
+    replay_op = Replay(actors=[actor])
+    assert next(replay_op).count == 100
+
+
 if __name__ == "__main__":
     import sys
     sys.exit(pytest.main(["-v", __file__]))
@@ -32,7 +32,11 @@ class TestReproducibility(unittest.TestCase):
             register_env("PickLargest", env_creator)
             agent = DQNTrainer(
                 env="PickLargest",
-                config={"seed": 666 if trial in [0, 1] else 999})
+                config={
+                    "seed": 666 if trial in [0, 1] else 999,
+                    "min_iter_time_s": 0,
+                    "timesteps_per_iteration": 100,
+                })
 
             trajectory = list()
             for _ in range(8):
@@ -45,6 +45,10 @@ def create_parser(parser_creator=None):
         type=str,
         help="Connect to an existing Ray cluster at this address instead "
         "of starting a new one.")
+    parser.add_argument(
+        "--no-ray-ui",
+        action="store_true",
+        help="Whether to disable the Ray web ui.")
     parser.add_argument(
         "--ray-num-cpus",
         default=None,

@@ -197,6 +201,7 @@ def run(args, parser):
         ray.init(address=cluster.address)
     else:
         ray.init(
+            include_webui=not args.no_ray_ui,
             address=args.ray_address,
             object_store_memory=args.ray_object_store_memory,
             memory=args.ray_memory,
@@ -25,10 +25,12 @@ class PerWorkerEpsilonGreedy(EpsilonGreedy):
         # Use a fixed, different epsilon per worker. See: Ape-X paper.
         assert worker_index <= num_workers, (worker_index, num_workers)
         if num_workers > 0:
-            if worker_index >= 0:
-                exponent = (1 + worker_index / float(num_workers - 1) * 7)
+            if worker_index > 0:
+                # From page 5 of https://arxiv.org/pdf/1803.00933.pdf
+                alpha, eps, i = 7, 0.4, worker_index - 1
                 epsilon_schedule = ConstantSchedule(
-                    0.4**exponent, framework=framework)
+                    eps**(1 + i / (num_workers - 1) * alpha),
+                    framework=framework)
             # Local worker should have zero exploration so that eval
             # rollouts run properly.
             else:
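
As a worked check of the revised schedule above (not part of the diff), the values for the test_apex.py setup with num_workers=3 come out exactly as asserted there:

num_workers, alpha, eps = 3, 7, 0.4


def worker_epsilon(worker_index):
    # The local worker (index 0) explores with epsilon 0 so that eval
    # rollouts are greedy; remote workers follow the Ape-X schedule.
    if worker_index == 0:
        return 0.0
    i = worker_index - 1
    return eps ** (1 + i / (num_workers - 1) * alpha)


print([worker_epsilon(w) for w in range(num_workers + 1)])
# [0.0, 0.4, 0.016190862..., 0.00065536] -- matching the updated assertion
# in test_apex.py earlier in this diff.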