[rllib] First pass at pipeline implementation of DQN (#7433)
* wip iters * add test * speed up * update docs * document it * support serial sampling * add test * spacing * annotate it * update * rename to pipeline * comment * iter2 wip * update * update * context test * update * fix * fix * a3c pipeline * doc * update * move timer * comment * add piepline test * fix * clean up * document * iter s * wip dqn * wip * wip * metrics * metrics rename * metrics ctx * wip * constants * add todo * suppport .union * wip * support union * remove prints * add todo * remove auto timer * fix up * fix pipeline test * typing * fix breakage * remove bad assert * wip * fix multiagent example * fixapply * update a3c * remove a2c pl * 0 workers * wip * wip * share metrics * wip * wip * doc * fix weight sync and global var updates * mode * fix * fix * doc * fix
Parent: beb9b02dbd
Commit: a644060daa
8 changed files with 258 additions and 74 deletions
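For orientation, here is a minimal sketch of how the experimental pipeline implementation added by this commit would be switched on. The "use_pipeline_impl" flag is the one referenced in the diff below; the environment name and the remaining config values are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch -- not part of this commit.
import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init()
trainer = DQNTrainer(
    env="CartPole-v0",  # assumed example environment
    config={
        "use_pipeline_impl": True,  # route training through training_pipeline()
        "num_workers": 2,
    })
print(trainer.train())
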
@@ -47,14 +47,14 @@ class TuneReporterBase(ProgressReporter):
     """Abstract base class for the default Tune reporters."""
 
     # Truncated representations of column names (to accommodate small screens).
-    DEFAULT_COLUMNS = {
-        EPISODE_REWARD_MEAN: "reward",
+    DEFAULT_COLUMNS = collections.OrderedDict({
         MEAN_ACCURACY: "acc",
         MEAN_LOSS: "loss",
+        TRAINING_ITERATION: "iter",
         TIME_TOTAL_S: "total time (s)",
         TIMESTEPS_TOTAL: "ts",
-        TRAINING_ITERATION: "iter",
-    }
+        EPISODE_REWARD_MEAN: "reward",
+    })
 
     def __init__(self,
                  metric_columns=None,

@@ -301,7 +301,6 @@ def trial_progress_str(trials, metric_columns, fmt="psql", max_rows=None):
         k for k in keys if any(
             t.last_result.get(k) is not None for t in trials)
     ]
-    keys = sorted(keys)
     # Build trial rows.
     params = sorted(set().union(*[t.evaluated_params for t in trials]))
     trial_table = [_get_trial_info(trial, params, keys) for trial in trials]

@@ -776,36 +776,35 @@ class LocalIterator(Generic[T]):
             if i >= n:
                 break
 
-    def union(self, other: "LocalIterator[T]",
+    def union(self, *others: "LocalIterator[T]",
               deterministic: bool = False) -> "LocalIterator[T]":
-        """Return an iterator that is the union of this and the other.
+        """Return an iterator that is the union of this and the others.
 
         If deterministic=True, we alternate between reading from one iterator
-        and the other. Otherwise we return items from iterators as they
+        and the others. Otherwise we return items from iterators as they
         become ready.
         """
 
-        if not isinstance(other, LocalIterator):
-            raise ValueError(
-                "other must be of type LocalIterator, got {}".format(
-                    type(other)))
+        for it in others:
+            if not isinstance(it, LocalIterator):
+                raise ValueError(
+                    "other must be of type LocalIterator, got {}".format(
+                        type(it)))
 
         if deterministic:
            timeout = None
         else:
            timeout = 0
 
-        it1 = LocalIterator(
-            self.base_iterator,
-            self.metrics,
-            self.local_transforms,
-            timeout=timeout)
-        it2 = LocalIterator(
-            other.base_iterator,
-            other.metrics,
-            other.local_transforms,
-            timeout=timeout)
-        active = [it1, it2]
+        active = []
+        shared_metrics = MetricsContext()
+        for it in [self] + list(others):
+            active.append(
+                LocalIterator(
+                    it.base_iterator,
+                    shared_metrics,
+                    it.local_transforms,
+                    timeout=timeout))
 
         def build_union(timeout=None):
             while True:

@@ -826,15 +825,11 @@
                 if not active:
                     break
 
-        # TODO(ekl) is this the best way to represent union() of metrics?
-        new_ctx = MetricsContext()
-        new_ctx.parent_metrics.append(self.metrics)
-        new_ctx.parent_metrics.append(other.metrics)
-
         return LocalIterator(
             build_union,
-            new_ctx, [],
-            name="LocalUnion[{}, {}]".format(self, other))
+            shared_metrics, [],
+            name="LocalUnion[{}, {}]".format(self, ", ".join(map(str,
+                                                                 others))))
 
 
 class ParallelIteratorWorker(object):

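The variadic union() above can be exercised on its own. A small sketch, assuming ray.util.iter.from_items (which exists independently of this commit) and a local Ray instance:

import ray
from ray.util.iter import from_items

ray.init()

# Gather three parallel iterators into local iterators.
a = from_items([1, 2, 3], num_shards=1).gather_sync()
b = from_items([4, 5, 6], num_shards=1).gather_sync()
c = from_items([7, 8, 9], num_shards=1).gather_sync()

# With this change, union() accepts any number of iterators, and all of them
# share a single MetricsContext instead of nesting parent contexts.
merged = a.union(b, c, deterministic=True)
print(list(merged))  # alternates between sources, e.g. [1, 4, 7, 2, 5, 8, ...]
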
@@ -5,9 +5,13 @@ from ray.rllib.agents.trainer_template import build_trainer
 from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
 from ray.rllib.agents.dqn.simple_q_policy import SimpleQPolicy
 from ray.rllib.optimizers import SyncReplayOptimizer
+from ray.rllib.optimizers.replay_buffer import ReplayBuffer
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
 from ray.rllib.utils.exploration import PerWorkerEpsilonGreedy
+from ray.rllib.utils.experimental_dsl import (
+    ParallelRollouts, Concurrently, StoreToReplayBuffer, LocalReplay,
+    TrainOneStep, StandardMetricsReporting, UpdateTargetNetwork)
 
 logger = logging.getLogger(__name__)
 

@@ -308,6 +312,30 @@ def update_target_if_needed(trainer, fetches):
         trainer.state["num_target_updates"] += 1
 
 
+# Experimental pipeline-based impl; enable with "use_pipeline_impl": True.
+def training_pipeline(workers, config):
+    local_replay_buffer = ReplayBuffer(config["buffer_size"])
+    rollouts = ParallelRollouts(workers, mode="bulk_sync")
+
+    # We execute the following steps concurrently:
+    # (1) Generate rollouts and store them in our local replay buffer. Calling
+    # next() on store_op drives this.
+    store_op = rollouts.for_each(StoreToReplayBuffer(local_replay_buffer))
+
+    # (2) Read and train on experiences from the replay buffer. Every batch
+    # returned from the LocalReplay() iterator is passed to TrainOneStep to
+    # take a SGD step, and then we decide whether to update the target network.
+    replay_op = LocalReplay(local_replay_buffer, config["train_batch_size"]) \
+        .for_each(TrainOneStep(workers)) \
+        .for_each(UpdateTargetNetwork(
+            workers, config["target_network_update_freq"]))
+
+    # Alternate deterministically between (1) and (2).
+    train_op = Concurrently([store_op, replay_op], mode="round_robin")
+
+    return StandardMetricsReporting(train_op, workers, config)
+
+
 GenericOffPolicyTrainer = build_trainer(
     name="GenericOffPolicyAlgorithm",
     default_policy=None,

@@ -317,7 +345,8 @@ GenericOffPolicyTrainer = build_trainer(
     make_policy_optimizer=make_policy_optimizer,
     before_train_step=update_worker_exploration,
     after_optimizer_step=update_target_if_needed,
-    after_train_result=after_train_result)
+    after_train_result=after_train_result,
+    training_pipeline=training_pipeline)
 
 DQNTrainer = GenericOffPolicyTrainer.with_updates(
     name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG)

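The comments in training_pipeline above describe the driving model: each next() on the combined op alternately stores freshly sampled experience and trains on a replayed batch. A plain-Python sketch of that round-robin control flow, using ordinary generators instead of RLlib's operators:

import itertools

def store_op():
    # Stands in for rollouts.for_each(StoreToReplayBuffer(...)).
    for i in itertools.count():
        yield "stored batch {}".format(i)

def replay_op():
    # Stands in for LocalReplay(...).for_each(TrainOneStep(...)).
    for i in itertools.count():
        yield "trained on replayed batch {}".format(i)

def round_robin(*ops):
    # Deterministic alternation, as in Concurrently(..., mode="round_robin").
    for op in itertools.cycle(ops):
        yield next(op)

train_op = round_robin(store_op(), replay_op())
for _ in range(4):
    print(next(train_op))
# stored batch 0
# trained on replayed batch 0
# stored batch 1
# trained on replayed batch 1
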
@@ -168,10 +168,10 @@ def build_trainer(name,
 
         def _train_pipeline(self):
             if before_train_step:
-                before_train_step(self)
+                logger.warning("Ignoring before_train_step callback")
             res = next(self.train_pipeline)
             if after_train_result:
-                after_train_result(self, res)
+                logger.warning("Ignoring after_train_result callback")
             return res
 
         @override(Trainer)

@@ -546,9 +546,11 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
         }
 
     @override(EvaluatorInterface)
-    def set_weights(self, weights):
+    def set_weights(self, weights, global_vars=None):
         for pid, w in weights.items():
             self.policy_map[pid].set_weights(w)
+        if global_vars:
+            self.set_global_vars(global_vars)
 
     @override(EvaluatorInterface)
     def compute_gradients(self, samples):

@@ -1,6 +1,7 @@
 import logging
 from types import FunctionType
 
+import ray
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
     _validate_multiagent_config

@@ -71,6 +72,12 @@ class WorkerSet:
         """Return a list of remote rollout workers."""
         return self._remote_workers
 
+    def sync_weights(self):
+        """Syncs weights of remote workers with the local worker."""
+        weights = ray.put(self.local_worker().get_weights())
+        for e in self.remote_workers():
+            e.set_weights.remote(weights)
+
     def add_workers(self, num_workers):
         """Creates and add a number of remote workers to this worker set.
 

@@ -97,18 +97,18 @@ def check_support(alg, config, stats, check_bounds=False, name=None):
         if alg not in ["DDPG", "ES", "ARS", "SAC"]:
             if o_name in ["atari", "image"]:
                 if torch:
-                    assert isinstance(
-                        a.get_policy().model, TorchVisionNetV2)
+                    assert isinstance(a.get_policy().model,
+                                      TorchVisionNetV2)
                 else:
-                    assert isinstance(
-                        a.get_policy().model, VisionNetV2)
+                    assert isinstance(a.get_policy().model,
+                                      VisionNetV2)
             elif o_name in ["vector", "vector2"]:
                 if torch:
-                    assert isinstance(
-                        a.get_policy().model, TorchFCNetV2)
+                    assert isinstance(a.get_policy().model,
+                                      TorchFCNetV2)
                 else:
-                    assert isinstance(
-                        a.get_policy().model, FCNetV2)
+                    assert isinstance(a.get_policy().model,
+                                      FCNetV2)
         a.train()
         covered_a.add(a_name)
         covered_o.add(o_name)

@@ -159,12 +159,7 @@ class ModelSupportedSpaces(unittest.TestCase):
         ray.shutdown()
 
     def test_a3c(self):
-        config = {
-            "num_workers": 1,
-            "optimizer": {
-                "grads_per_step": 1
-            }
-        }
+        config = {"num_workers": 1, "optimizer": {"grads_per_step": 1}}
         check_support("A3C", config, self.stats, check_bounds=True)
         config["use_pytorch"] = True
         check_support("A3C", config, self.stats, check_bounds=True)

@@ -228,10 +223,7 @@ class ModelSupportedSpaces(unittest.TestCase):
         check_support("PPO", config, self.stats, check_bounds=True)
 
     def test_pg(self):
-        config = {
-            "num_workers": 1,
-            "optimizer": {}
-        }
+        config = {"num_workers": 1, "optimizer": {}}
         check_support("PG", config, self.stats, check_bounds=True)
         config["use_pytorch"] = True
         check_support("PG", config, self.stats, check_bounds=True)

@@ -4,28 +4,40 @@ TODO(ekl): describe the concepts."""
 
 import logging
 from typing import List, Any, Tuple, Union
+import numpy as np
 import time
 
 import ray
 from ray.util.iter import from_actors, LocalIterator
 from ray.util.iter_metrics import MetricsContext
+from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
 from ray.rllib.evaluation.metrics import collect_episodes, \
     summarize_episodes, get_learner_stats
 from ray.rllib.evaluation.rollout_worker import get_global_worker
 from ray.rllib.evaluation.worker_set import WorkerSet
-from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
+from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, \
+    DEFAULT_POLICY_ID
+from ray.rllib.utils.compression import pack_if_needed
 
 logger = logging.getLogger(__name__)
 
-# Metrics context key definitions.
+# Counters for training progress (keys for metrics.counters).
 STEPS_SAMPLED_COUNTER = "num_steps_sampled"
 STEPS_TRAINED_COUNTER = "num_steps_trained"
 
+# Counters to track target network updates.
+LAST_TARGET_UPDATE_TS = "last_target_update_ts"
+NUM_TARGET_UPDATES = "num_target_updates"
+
+# Performance timers (keys for metrics.timers).
 APPLY_GRADS_TIMER = "apply_grad"
 COMPUTE_GRADS_TIMER = "compute_grads"
 WORKER_UPDATE_TIMER = "update"
 GRAD_WAIT_TIMER = "grad_wait"
 SAMPLE_TIMER = "sample"
 LEARN_ON_BATCH_TIMER = "learn"
 
+# Instant metrics (keys for metrics.info).
 LEARNER_INFO = "learner"
 
 # Type aliases.

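The counter, timer, and info keys above are read and written through the iterator's shared metrics context. A sketch of that pattern, mirroring the report_timesteps helper in the ParallelRollouts hunk further down; the .metrics attribute used at the end is an assumption based on the union() change above, not something this commit documents:

import ray
from ray.util.iter import from_items, LocalIterator

ray.init()

def count_items(item):
    # Any function passed to for_each() may update the current metrics
    # context; CollectMetrics later folds these counters into the result.
    metrics = LocalIterator.get_metrics()
    metrics.counters["num_steps_sampled"] += 1  # i.e. STEPS_SAMPLED_COUNTER
    return item

it = from_items([10, 20, 30], num_shards=1).gather_sync().for_each(count_items)
for _ in it:
    pass
print(it.metrics.counters["num_steps_sampled"])  # assumed accessor -> 3
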
@@ -33,12 +45,19 @@ GradientType = dict
 SampleBatchType = Union[SampleBatch, MultiAgentBatch]
 
 
+# Asserts that an object is a type of SampleBatch.
 def _check_sample_batch_type(batch):
     if not isinstance(batch, SampleBatchType.__args__):
         raise ValueError("Expected either SampleBatch or MultiAgentBatch, "
                          "got {}: {}".format(type(batch), batch))
 
 
+# Returns pipeline global vars that should be periodically sent to each worker.
+def _get_global_vars():
+    metrics = LocalIterator.get_metrics()
+    return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]}
+
+
 def ParallelRollouts(workers: WorkerSet,
                      mode="bulk_sync") -> LocalIterator[SampleBatch]:
     """Operator to collect experiences in parallel from rollout workers.

@@ -71,6 +90,9 @@ def ParallelRollouts(workers: WorkerSet,
     Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
     """
 
+    # Ensure workers are initially in sync.
+    workers.sync_weights()
+
     def report_timesteps(batch):
         metrics = LocalIterator.get_metrics()
         metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count

@@ -119,6 +141,9 @@ def AsyncGradients(
     local iterator context.
     """
 
+    # Ensure workers are initially in sync.
+    workers.sync_weights()
+
     # This function will be applied remotely on the workers.
     def samples_to_grads(samples):
         return get_global_worker().compute_gradients(samples), samples.count

@@ -240,7 +265,9 @@ class TrainOneStep:
             with metrics.timers[WORKER_UPDATE_TIMER]:
                 weights = ray.put(self.workers.local_worker().get_weights())
                 for e in self.workers.remote_workers():
-                    e.set_weights.remote(weights)
+                    e.set_weights.remote(weights, _get_global_vars())
+        # Also update global vars of the local worker.
+        self.workers.local_worker().set_global_vars(_get_global_vars())
         return info
 
 

@@ -266,9 +293,7 @@ class CollectMetrics:
         self.timeout_seconds = timeout_seconds
 
     def __call__(self, _):
-        metrics = LocalIterator.get_metrics()
-        if metrics.parent_metrics:
-            raise ValueError("TODO: support nested metrics")
+        # Collect worker metrics.
         episodes, self.to_be_collected = collect_episodes(
             self.workers.local_worker(),
             self.workers.remote_workers(),

@@ -282,22 +307,31 @@ class CollectMetrics:
         self.episode_history.extend(orig_episodes)
         self.episode_history = self.episode_history[-self.min_history:]
         res = summarize_episodes(episodes, orig_episodes)
-        res.update(info=metrics.info)
-        res["info"].update({
-            STEPS_SAMPLED_COUNTER: metrics.counters[STEPS_SAMPLED_COUNTER],
-            STEPS_TRAINED_COUNTER: metrics.counters[STEPS_TRAINED_COUNTER],
-        })
+
+        # Add in iterator metrics.
+        metrics = LocalIterator.get_metrics()
+        if metrics.parent_metrics:
+            print("TODO: support nested metrics better")
+        all_metrics = [metrics] + metrics.parent_metrics
         timers = {}
+        counters = {}
+        info = {}
+        for metrics in all_metrics:
+            info.update(metrics.info)
+            for k, counter in metrics.counters.items():
+                counters[k] = counter
             for k, timer in metrics.timers.items():
                 timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
                 if timer.has_units_processed():
                     timers["{}_throughput".format(k)] = round(
                         timer.mean_throughput, 3)
-        res["timers"] = timers
         res.update({
             "num_healthy_workers": len(self.workers.remote_workers()),
             "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
         })
+        res["timers"] = timers
+        res["info"] = info
+        res["info"].update(counters)
         return res
 
 

@@ -392,13 +426,16 @@ class ApplyGradients:
             self.workers.local_worker().apply_gradients(gradients)
             apply_timer.push_units_processed(count)
 
+        # Also update global vars of the local worker.
+        self.workers.local_worker().set_global_vars(_get_global_vars())
+
         if self.update_all:
             if self.workers.remote_workers():
                 with metrics.timers[WORKER_UPDATE_TIMER]:
                     weights = ray.put(
                         self.workers.local_worker().get_weights())
                     for e in self.workers.remote_workers():
-                        e.set_weights.remote(weights)
+                        e.set_weights.remote(weights, _get_global_vars())
         else:
             if metrics.cur_actor is None:
                 raise ValueError("Could not find actor to update. When "

@@ -406,7 +443,8 @@ class ApplyGradients:
                                  "in the iterator context.")
             with metrics.timers[WORKER_UPDATE_TIMER]:
                 weights = self.workers.local_worker().get_weights()
-                metrics.cur_actor.set_weights.remote(weights)
+                metrics.cur_actor.set_weights.remote(weights,
+                                                     _get_global_vars())
 
 
 class AverageGradients:

@@ -434,3 +472,125 @@ class AverageGradients:
             logger.info("Computing average of {} microbatch gradients "
                         "({} samples total)".format(len(gradients), sum_count))
         return acc, sum_count
+
+
+class StoreToReplayBuffer:
+    def __init__(self, replay_buffer):
+        self.replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
+
+    def __call__(self, batch: SampleBatchType):
+        # Handle everything as if multiagent
+        if isinstance(batch, SampleBatch):
+            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
+
+        for policy_id, s in batch.policy_batches.items():
+            for row in s.rows():
+                self.replay_buffers[policy_id].add(
+                    pack_if_needed(row["obs"]),
+                    row["actions"],
+                    row["rewards"],
+                    pack_if_needed(row["new_obs"]),
+                    row["dones"],
+                    weight=None)
+
+
+def LocalReplay(replay_buffer, train_batch_size):
+    replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
+    # TODO(ekl) support more options
+    synchronize_sampling = False
+    prioritized_replay_beta = None
+
+    def gen_replay(timeout):
+        while True:
+            samples = {}
+            idxes = None
+            for policy_id, replay_buffer in replay_buffers.items():
+                if synchronize_sampling:
+                    if idxes is None:
+                        idxes = replay_buffer.sample_idxes(train_batch_size)
+                else:
+                    idxes = replay_buffer.sample_idxes(train_batch_size)
+
+                if isinstance(replay_buffer, PrioritizedReplayBuffer):
+                    metrics = LocalIterator.get_metrics()
+                    num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
+                    (obses_t, actions, rewards, obses_tp1, dones, weights,
+                     batch_indexes) = replay_buffer.sample_with_idxes(
+                         idxes,
+                         beta=prioritized_replay_beta.value(num_steps_trained))
+                else:
+                    (obses_t, actions, rewards, obses_tp1,
+                     dones) = replay_buffer.sample_with_idxes(idxes)
+                    weights = np.ones_like(rewards)
+                    batch_indexes = -np.ones_like(rewards)
+                samples[policy_id] = SampleBatch({
+                    "obs": obses_t,
+                    "actions": actions,
+                    "rewards": rewards,
+                    "new_obs": obses_tp1,
+                    "dones": dones,
+                    "weights": weights,
+                    "batch_indexes": batch_indexes
+                })
+            yield MultiAgentBatch(samples, train_batch_size)
+
+    return LocalIterator(gen_replay, MetricsContext())
+
+
+def Concurrently(ops: List[LocalIterator], mode="round_robin"):
+    """Operator that runs the given parent iterators concurrently.
+
+    Arguments:
+        mode (str): One of {'round_robin', 'async'}.
+            - In 'round_robin' mode, we alternate between pulling items from
+              each parent iterator in order deterministically.
+            - In 'async' mode, we pull from each parent iterator as fast as
+              they are produced. This is non-deterministic.
+
+    >>> sim_op = ParallelRollouts(...).for_each(...)
+    >>> replay_op = LocalReplay(...).for_each(...)
+    >>> combined_op = Concurrently([sim_op, replay_op])
+    """
+
+    if len(ops) < 2:
+        raise ValueError("Should specify at least 2 ops.")
+    if mode == "round_robin":
+        deterministic = True
+    elif mode == "async":
+        deterministic = False
+    else:
+        raise ValueError("Unknown mode {}".format(mode))
+    return ops[0].union(*ops[1:], deterministic=deterministic)
+
+
+class UpdateTargetNetwork:
+    """Periodically call policy.update_target() on all trainable policies.
+
+    This should be used with the .for_each() operator after training step
+    has been taken.
+
+    Examples:
+        >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
+        >>> update_op = train_op.for_each(
+        ...     UpdateTargetIfNeeded(workers, target_update_freq=500))
+        >>> print(next(update_op))
+        None
+
+    Updates the LAST_TARGET_UPDATE_TS and NUM_TARGET_UPDATES counters in the
+    local iterator context. The value of the last update counter is used to
+    track when we should update the target next.
+    """
+
+    def __init__(self, workers, target_update_freq):
+        self.workers = workers
+        self.target_update_freq = target_update_freq
+
+    def __call__(self, _):
+        metrics = LocalIterator.get_metrics()
+        cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER]
+        last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
+        if cur_ts - last_update > self.target_update_freq:
+            self.workers.local_worker().foreach_trainable_policy(
+                lambda p, _: p.update_target())
+            metrics.counters[NUM_TARGET_UPDATES] += 1
+            metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts

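To make the replay pair above concrete, here is a small sketch of storing a hand-built batch and sampling it back. It assumes the ray.rllib.utils.experimental_dsl module path introduced by this commit and a plain (non-prioritized) ReplayBuffer; the field values are dummies.

import numpy as np

from ray.rllib.optimizers.replay_buffer import ReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.experimental_dsl import StoreToReplayBuffer, LocalReplay

buf = ReplayBuffer(1000)  # capacity
store = StoreToReplayBuffer(buf)

# A tiny batch with the fields StoreToReplayBuffer reads from each row.
batch = SampleBatch({
    "obs": np.zeros((4, 2)),
    "actions": np.array([0, 1, 0, 1]),
    "rewards": np.ones(4),
    "new_obs": np.zeros((4, 2)),
    "dones": np.array([False, False, False, True]),
})
store(batch)  # adds each row of the batch to the underlying buffer

replay = LocalReplay(buf, 2)  # train_batch_size=2
print(next(replay))  # MultiAgentBatch holding a 2-row sample per policy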