[rllib] First pass at pipeline implementation of DQN (#7433)

* wip iters

* add test

* speed up

* update docs

* document it

* support serial sampling

* add test

* spacing

* annotate it

* update

* rename to pipeline

* comment

* iter2 wip

* update

* update

* context test

* update

* fix

* fix

* a3c pipeline

* doc

* update

* move timer

* comment

* add pipeline test

* fix

* clean up

* document

* iter s

* wip dqn

* wip

* wip

* metrics

* metrics rename

* metrics ctx

* wip

* constants

* add todo

* support .union

* wip

* support union

* remove prints

* add todo

* remove auto timer

* fix up

* fix pipeline test

* typing

* fix breakage

* remove bad assert

* wip

* fix multiagent example

* fix apply

* update a3c

* remove a2c pl

* 0 workers

* wip

* wip

* share metrics

* wip

* wip

* doc

* fix weight sync and global var updates

* mode

* fix

* fix

* doc

* fix
Eric Liang 2020-03-07 14:47:58 -08:00 committed by GitHub
parent beb9b02dbd
commit a644060daa
8 changed files with 258 additions and 74 deletions

View file

@@ -47,14 +47,14 @@ class TuneReporterBase(ProgressReporter):
     """Abstract base class for the default Tune reporters."""

     # Truncated representations of column names (to accommodate small screens).
-    DEFAULT_COLUMNS = {
-        EPISODE_REWARD_MEAN: "reward",
+    DEFAULT_COLUMNS = collections.OrderedDict({
         MEAN_ACCURACY: "acc",
         MEAN_LOSS: "loss",
+        TRAINING_ITERATION: "iter",
         TIME_TOTAL_S: "total time (s)",
         TIMESTEPS_TOTAL: "ts",
-        TRAINING_ITERATION: "iter",
-    }
+        EPISODE_REWARD_MEAN: "reward",
+    })

     def __init__(self,
                  metric_columns=None,
@@ -301,7 +301,6 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
         k for k in keys if any(
            t.last_result.get(k) is not None for t in trials)
     ]
-    keys = sorted(keys)
     # Build trial rows.
     params = sorted(set().union(*[t.evaluated_params for t in trials]))
     trial_table = [_get_trial_info(trial, params, keys) for trial in trials]

View file

@@ -776,36 +776,35 @@ class LocalIterator(Generic[T]):
             if i >= n:
                 break

-    def union(self, other: "LocalIterator[T]",
+    def union(self, *others: "LocalIterator[T]",
               deterministic: bool = False) -> "LocalIterator[T]":
-        """Return an iterator that is the union of this and the other.
+        """Return an iterator that is the union of this and the others.

         If deterministic=True, we alternate between reading from one iterator
-        and the other. Otherwise we return items from iterators as they
+        and the others. Otherwise we return items from iterators as they
         become ready.
         """

-        if not isinstance(other, LocalIterator):
-            raise ValueError(
-                "other must be of type LocalIterator, got {}".format(
-                    type(other)))
+        for it in others:
+            if not isinstance(it, LocalIterator):
+                raise ValueError(
+                    "other must be of type LocalIterator, got {}".format(
+                        type(it)))

         if deterministic:
             timeout = None
         else:
             timeout = 0

-        it1 = LocalIterator(
-            self.base_iterator,
-            self.metrics,
-            self.local_transforms,
-            timeout=timeout)
-        it2 = LocalIterator(
-            other.base_iterator,
-            other.metrics,
-            other.local_transforms,
-            timeout=timeout)
-        active = [it1, it2]
+        active = []
+        shared_metrics = MetricsContext()
+        for it in [self] + list(others):
+            active.append(
+                LocalIterator(
+                    it.base_iterator,
+                    shared_metrics,
+                    it.local_transforms,
+                    timeout=timeout))

         def build_union(timeout=None):
             while True:
@@ -826,15 +825,11 @@ class LocalIterator(Generic[T]):
                 if not active:
                     break

-        # TODO(ekl) is this the best way to represent union() of metrics?
-        new_ctx = MetricsContext()
-        new_ctx.parent_metrics.append(self.metrics)
-        new_ctx.parent_metrics.append(other.metrics)
-
         return LocalIterator(
             build_union,
-            new_ctx, [],
-            name="LocalUnion[{}, {}]".format(self, other))
+            shared_metrics, [],
+            name="LocalUnion[{}, {}]".format(self, ", ".join(map(str,
+                                                                 others))))


 class ParallelIteratorWorker(object):
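
A minimal sketch (not part of this commit) of the new variadic union(), assuming the from_items() and gather_sync() helpers from ray.util.iter; the interleaving noted in the comment is approximate:

    import ray
    from ray.util.iter import from_items

    ray.init()
    # Three single-shard parent iterators, gathered into LocalIterators.
    a = from_items(["a1", "a2"], num_shards=1).gather_sync()
    b = from_items(["b1", "b2"], num_shards=1).gather_sync()
    c = from_items(["c1", "c2"], num_shards=1).gather_sync()

    # deterministic=True alternates between the parents in order, which is
    # what Concurrently(..., mode="round_robin") relies on; note that all
    # parents now share a single MetricsContext.
    for item in a.union(b, c, deterministic=True):
        print(item)  # roughly: a1, b1, c1, a2, b2, c2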

View file

@@ -5,9 +5,13 @@ from ray.rllib.agents.trainer_template import build_trainer
 from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
 from ray.rllib.agents.dqn.simple_q_policy import SimpleQPolicy
 from ray.rllib.optimizers import SyncReplayOptimizer
+from ray.rllib.optimizers.replay_buffer import ReplayBuffer
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
 from ray.rllib.utils.exploration import PerWorkerEpsilonGreedy
+from ray.rllib.utils.experimental_dsl import (
+    ParallelRollouts, Concurrently, StoreToReplayBuffer, LocalReplay,
+    TrainOneStep, StandardMetricsReporting, UpdateTargetNetwork)

 logger = logging.getLogger(__name__)
@@ -308,6 +312,30 @@ def update_target_if_needed(trainer, fetches):
         trainer.state["num_target_updates"] += 1


+# Experimental pipeline-based impl; enable with "use_pipeline_impl": True.
+def training_pipeline(workers, config):
+    local_replay_buffer = ReplayBuffer(config["buffer_size"])
+    rollouts = ParallelRollouts(workers, mode="bulk_sync")
+
+    # We execute the following steps concurrently:
+    # (1) Generate rollouts and store them in our local replay buffer. Calling
+    # next() on store_op drives this.
+    store_op = rollouts.for_each(StoreToReplayBuffer(local_replay_buffer))
+
+    # (2) Read and train on experiences from the replay buffer. Every batch
+    # returned from the LocalReplay() iterator is passed to TrainOneStep to
+    # take a SGD step, and then we decide whether to update the target network.
+    replay_op = LocalReplay(local_replay_buffer, config["train_batch_size"]) \
+        .for_each(TrainOneStep(workers)) \
+        .for_each(UpdateTargetNetwork(
+            workers, config["target_network_update_freq"]))
+
+    # Alternate deterministically between (1) and (2).
+    train_op = Concurrently([store_op, replay_op], mode="round_robin")
+
+    return StandardMetricsReporting(train_op, workers, config)
+
+
 GenericOffPolicyTrainer = build_trainer(
     name="GenericOffPolicyAlgorithm",
     default_policy=None,
@@ -317,7 +345,8 @@ GenericOffPolicyTrainer = build_trainer(
     make_policy_optimizer=make_policy_optimizer,
     before_train_step=update_worker_exploration,
     after_optimizer_step=update_target_if_needed,
-    after_train_result=after_train_result)
+    after_train_result=after_train_result,
+    training_pipeline=training_pipeline)

 DQNTrainer = GenericOffPolicyTrainer.with_updates(
     name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG)
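
A hedged sketch of switching the experimental path on: the "use_pipeline_impl" key comes from the comment in this diff, while the rest is an ordinary DQN setup and not part of the commit.

    import ray
    from ray.rllib.agents.dqn import DQNTrainer

    ray.init()
    trainer = DQNTrainer(
        env="CartPole-v0",  # assumes TF and gym are installed
        config={
            "num_workers": 1,
            "use_pipeline_impl": True,  # opt into training_pipeline() above
        })
    print(trainer.train())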

View file

@@ -168,10 +168,10 @@ def build_trainer(name,

         def _train_pipeline(self):
             if before_train_step:
-                before_train_step(self)
+                logger.warning("Ignoring before_train_step callback")
             res = next(self.train_pipeline)
             if after_train_result:
-                after_train_result(self, res)
+                logger.warning("Ignoring after_train_result callback")
             return res

         @override(Trainer)

View file

@@ -546,9 +546,11 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
         }

     @override(EvaluatorInterface)
-    def set_weights(self, weights):
+    def set_weights(self, weights, global_vars=None):
         for pid, w in weights.items():
             self.policy_map[pid].set_weights(w)
+        if global_vars:
+            self.set_global_vars(global_vars)

     @override(EvaluatorInterface)
     def compute_gradients(self, samples):

View file

@@ -1,6 +1,7 @@
 import logging
 from types import FunctionType

+import ray
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
     _validate_multiagent_config
@@ -71,6 +72,12 @@ class WorkerSet:
         """Return a list of remote rollout workers."""
         return self._remote_workers

+    def sync_weights(self):
+        """Syncs weights of remote workers with the local worker."""
+        weights = ray.put(self.local_worker().get_weights())
+        for e in self.remote_workers():
+            e.set_weights.remote(weights)
+
     def add_workers(self, num_workers):
         """Creates and add a number of remote workers to this worker set.

View file

@@ -97,18 +97,18 @@ def check_support(alg, config, stats, check_bounds=False, name=None):
         if alg not in ["DDPG", "ES", "ARS", "SAC"]:
             if o_name in ["atari", "image"]:
                 if torch:
-                    assert isinstance(
-                        a.get_policy().model, TorchVisionNetV2)
+                    assert isinstance(a.get_policy().model,
+                                      TorchVisionNetV2)
                 else:
-                    assert isinstance(
-                        a.get_policy().model, VisionNetV2)
+                    assert isinstance(a.get_policy().model,
+                                      VisionNetV2)
             elif o_name in ["vector", "vector2"]:
                 if torch:
-                    assert isinstance(
-                        a.get_policy().model, TorchFCNetV2)
+                    assert isinstance(a.get_policy().model,
+                                      TorchFCNetV2)
                 else:
-                    assert isinstance(
-                        a.get_policy().model, FCNetV2)
+                    assert isinstance(a.get_policy().model,
+                                      FCNetV2)
         a.train()
         covered_a.add(a_name)
         covered_o.add(o_name)
@@ -159,12 +159,7 @@
         ray.shutdown()

     def test_a3c(self):
-        config = {
-            "num_workers": 1,
-            "optimizer": {
-                "grads_per_step": 1
-            }
-        }
+        config = {"num_workers": 1, "optimizer": {"grads_per_step": 1}}
         check_support("A3C", config, self.stats, check_bounds=True)
         config["use_pytorch"] = True
         check_support("A3C", config, self.stats, check_bounds=True)
@@ -228,10 +223,7 @@
         check_support("PPO", config, self.stats, check_bounds=True)

     def test_pg(self):
-        config = {
-            "num_workers": 1,
-            "optimizer": {}
-        }
+        config = {"num_workers": 1, "optimizer": {}}
         check_support("PG", config, self.stats, check_bounds=True)
         config["use_pytorch"] = True
         check_support("PG", config, self.stats, check_bounds=True)

View file

@@ -4,28 +4,40 @@ TODO(ekl): describe the concepts."""
 import logging
 from typing import List, Any, Tuple, Union

+import numpy as np
 import time

 import ray
 from ray.util.iter import from_actors, LocalIterator
 from ray.util.iter_metrics import MetricsContext
+from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
 from ray.rllib.evaluation.metrics import collect_episodes, \
     summarize_episodes, get_learner_stats
 from ray.rllib.evaluation.rollout_worker import get_global_worker
 from ray.rllib.evaluation.worker_set import WorkerSet
-from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
+from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \
+    DEFAULT_POLICY_ID
+from ray.rllib.utils.compression import pack_if_needed

 logger = logging.getLogger(__name__)

-# Metrics context key definitions.
+# Counters for training progress (keys for metrics.counters).
 STEPS_SAMPLED_COUNTER = "num_steps_sampled"
 STEPS_TRAINED_COUNTER = "num_steps_trained"

+# Counters to track target network updates.
+LAST_TARGET_UPDATE_TS = "last_target_update_ts"
+NUM_TARGET_UPDATES = "num_target_updates"
+
+# Performance timers (keys for metrics.timers).
 APPLY_GRADS_TIMER = "apply_grad"
 COMPUTE_GRADS_TIMER = "compute_grads"
 WORKER_UPDATE_TIMER = "update"
 GRAD_WAIT_TIMER = "grad_wait"
 SAMPLE_TIMER = "sample"
 LEARN_ON_BATCH_TIMER = "learn"
+
+# Instant metrics (keys for metrics.info).
 LEARNER_INFO = "learner"

 # Type aliases.
@@ -33,12 +45,19 @@ GradientType = dict
 SampleBatchType = Union[SampleBatch, MultiAgentBatch]


+# Asserts that an object is a type of SampleBatch.
 def _check_sample_batch_type(batch):
     if not isinstance(batch, SampleBatchType.__args__):
         raise ValueError("Expected either SampleBatch or MultiAgentBatch, "
                          "got {}: {}".format(type(batch), batch))


+# Returns pipeline global vars that should be periodically sent to each worker.
+def _get_global_vars():
+    metrics = LocalIterator.get_metrics()
+    return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]}
+
+
 def ParallelRollouts(workers: WorkerSet,
                      mode="bulk_sync") -> LocalIterator[SampleBatch]:
     """Operator to collect experiences in parallel from rollout workers.
@@ -71,6 +90,9 @@ def ParallelRollouts(workers: WorkerSet,
     Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
     """

+    # Ensure workers are initially in sync.
+    workers.sync_weights()
+
     def report_timesteps(batch):
         metrics = LocalIterator.get_metrics()
         metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
@@ -119,6 +141,9 @@ def AsyncGradients(
     local iterator context.
     """

+    # Ensure workers are initially in sync.
+    workers.sync_weights()
+
     # This function will be applied remotely on the workers.
     def samples_to_grads(samples):
         return get_global_worker().compute_gradients(samples), samples.count
@@ -240,7 +265,9 @@ class TrainOneStep:
             with metrics.timers[WORKER_UPDATE_TIMER]:
                 weights = ray.put(self.workers.local_worker().get_weights())
                 for e in self.workers.remote_workers():
-                    e.set_weights.remote(weights)
+                    e.set_weights.remote(weights, _get_global_vars())
+        # Also update global vars of the local worker.
+        self.workers.local_worker().set_global_vars(_get_global_vars())
         return info
@@ -266,9 +293,7 @@ class CollectMetrics:
         self.timeout_seconds = timeout_seconds

     def __call__(self, _):
-        metrics = LocalIterator.get_metrics()
-        if metrics.parent_metrics:
-            raise ValueError("TODO: support nested metrics")
+        # Collect worker metrics.
         episodes, self.to_be_collected = collect_episodes(
             self.workers.local_worker(),
             self.workers.remote_workers(),
@@ -282,22 +307,31 @@
         self.episode_history.extend(orig_episodes)
         self.episode_history = self.episode_history[-self.min_history:]
         res = summarize_episodes(episodes, orig_episodes)
-        res.update(info=metrics.info)
-        res["info"].update({
-            STEPS_SAMPLED_COUNTER: metrics.counters[STEPS_SAMPLED_COUNTER],
-            STEPS_TRAINED_COUNTER: metrics.counters[STEPS_TRAINED_COUNTER],
-        })
+
+        # Add in iterator metrics.
+        metrics = LocalIterator.get_metrics()
+        if metrics.parent_metrics:
+            print("TODO: support nested metrics better")
+        all_metrics = [metrics] + metrics.parent_metrics
         timers = {}
+        counters = {}
+        info = {}
+        for metrics in all_metrics:
+            info.update(metrics.info)
+            for k, counter in metrics.counters.items():
+                counters[k] = counter
             for k, timer in metrics.timers.items():
                 timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
                 if timer.has_units_processed():
                     timers["{}_throughput".format(k)] = round(
                         timer.mean_throughput, 3)
-        res["timers"] = timers
         res.update({
             "num_healthy_workers": len(self.workers.remote_workers()),
             "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
         })
+        res["timers"] = timers
+        res["info"] = info
+        res["info"].update(counters)
         return res
@@ -392,13 +426,16 @@ class ApplyGradients:
             self.workers.local_worker().apply_gradients(gradients)
             apply_timer.push_units_processed(count)

+        # Also update global vars of the local worker.
+        self.workers.local_worker().set_global_vars(_get_global_vars())
+
         if self.update_all:
             if self.workers.remote_workers():
                 with metrics.timers[WORKER_UPDATE_TIMER]:
                     weights = ray.put(
                         self.workers.local_worker().get_weights())
                     for e in self.workers.remote_workers():
-                        e.set_weights.remote(weights)
+                        e.set_weights.remote(weights, _get_global_vars())
         else:
             if metrics.cur_actor is None:
                 raise ValueError("Could not find actor to update. When "
@@ -406,7 +443,8 @@
                                  "in the iterator context.")
             with metrics.timers[WORKER_UPDATE_TIMER]:
                 weights = self.workers.local_worker().get_weights()
-                metrics.cur_actor.set_weights.remote(weights)
+                metrics.cur_actor.set_weights.remote(weights,
+                                                     _get_global_vars())


 class AverageGradients:
@@ -434,3 +472,125 @@ class AverageGradients:
             logger.info("Computing average of {} microbatch gradients "
                         "({} samples total)".format(len(gradients), sum_count))
         return acc, sum_count
+
+
+class StoreToReplayBuffer:
+    def __init__(self, replay_buffer):
+        self.replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
+
+    def __call__(self, batch: SampleBatchType):
+        # Handle everything as if multiagent
+        if isinstance(batch, SampleBatch):
+            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
+        for policy_id, s in batch.policy_batches.items():
+            for row in s.rows():
+                self.replay_buffers[policy_id].add(
+                    pack_if_needed(row["obs"]),
+                    row["actions"],
+                    row["rewards"],
+                    pack_if_needed(row["new_obs"]),
+                    row["dones"],
+                    weight=None)
+
+
+def LocalReplay(replay_buffer, train_batch_size):
+    replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
+    # TODO(ekl) support more options
+    synchronize_sampling = False
+    prioritized_replay_beta = None
+
+    def gen_replay(timeout):
+        while True:
+            samples = {}
+            idxes = None
+            for policy_id, replay_buffer in replay_buffers.items():
+                if synchronize_sampling:
+                    if idxes is None:
+                        idxes = replay_buffer.sample_idxes(train_batch_size)
+                else:
+                    idxes = replay_buffer.sample_idxes(train_batch_size)
+
+                if isinstance(replay_buffer, PrioritizedReplayBuffer):
+                    metrics = LocalIterator.get_metrics()
+                    num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
+                    (obses_t, actions, rewards, obses_tp1, dones, weights,
+                     batch_indexes) = replay_buffer.sample_with_idxes(
+                         idxes,
+                         beta=prioritized_replay_beta.value(num_steps_trained))
+                else:
+                    (obses_t, actions, rewards, obses_tp1,
+                     dones) = replay_buffer.sample_with_idxes(idxes)
+                    weights = np.ones_like(rewards)
+                    batch_indexes = -np.ones_like(rewards)
+                samples[policy_id] = SampleBatch({
+                    "obs": obses_t,
+                    "actions": actions,
+                    "rewards": rewards,
+                    "new_obs": obses_tp1,
+                    "dones": dones,
+                    "weights": weights,
+                    "batch_indexes": batch_indexes
+                })
+            yield MultiAgentBatch(samples, train_batch_size)
+
+    return LocalIterator(gen_replay, MetricsContext())
+
+
+def Concurrently(ops: List[LocalIterator], mode="round_robin"):
+    """Operator that runs the given parent iterators concurrently.
+
+    Arguments:
+        mode (str): One of {'round_robin', 'async'}.
+            - In 'round_robin' mode, we alternate between pulling items from
+              each parent iterator in order deterministically.
+            - In 'async' mode, we pull from each parent iterator as fast as
+              they are produced. This is non-deterministic.
+
+    >>> sim_op = ParallelRollouts(...).for_each(...)
+    >>> replay_op = LocalReplay(...).for_each(...)
+    >>> combined_op = Concurrently([sim_op, replay_op])
+    """
+
+    if len(ops) < 2:
+        raise ValueError("Should specify at least 2 ops.")
+    if mode == "round_robin":
+        deterministic = True
+    elif mode == "async":
+        deterministic = False
+    else:
+        raise ValueError("Unknown mode {}".format(mode))
+    return ops[0].union(*ops[1:], deterministic=deterministic)
+
+
+class UpdateTargetNetwork:
+    """Periodically call policy.update_target() on all trainable policies.
+
+    This should be used with the .for_each() operator after training step
+    has been taken.
+
+    Examples:
+        >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
+        >>> update_op = train_op.for_each(
+        ...     UpdateTargetIfNeeded(workers, target_update_freq=500))
+        >>> print(next(update_op))
+        None
+
+    Updates the LAST_TARGET_UPDATE_TS and NUM_TARGET_UPDATES counters in the
+    local iterator context. The value of the last update counter is used to
+    track when we should update the target next.
+    """
+
+    def __init__(self, workers, target_update_freq):
+        self.workers = workers
+        self.target_update_freq = target_update_freq
+
+    def __call__(self, _):
+        metrics = LocalIterator.get_metrics()
+        cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER]
+        last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
+        if cur_ts - last_update > self.target_update_freq:
+            self.workers.local_worker().foreach_trainable_policy(
+                lambda p, _: p.update_target())
+            metrics.counters[NUM_TARGET_UPDATES] += 1
+            metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
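
A small self-contained sketch (not part of this commit) that exercises the new replay operators in isolation; the import paths are the ones used in this diff, and the fake batch contents are purely illustrative:

    import numpy as np
    from ray.rllib.optimizers.replay_buffer import ReplayBuffer
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.utils.experimental_dsl import StoreToReplayBuffer, LocalReplay

    # Store a fake 4-step rollout into the buffer, row by row.
    buf = ReplayBuffer(1000)
    store = StoreToReplayBuffer(buf)
    store(SampleBatch({
        "obs": np.zeros((4, 2)),
        "actions": np.zeros(4),
        "rewards": np.ones(4),
        "new_obs": np.zeros((4, 2)),
        "dones": np.array([False, False, False, True]),
    }))

    # LocalReplay() yields MultiAgentBatch training batches sampled from it.
    replay = LocalReplay(buf, train_batch_size=2)
    print(next(replay))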