mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)

[rllib] Improve datapath throughput of IMPALA / APPO (#4324)

commit 0d94f3eeef (parent dffe19c59c)
15 changed files with 860 additions and 345 deletions
@@ -214,6 +214,13 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    --stop '{"training_iteration": 1}' \
    --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1}'

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    /ray/ci/suppress_output /ray/python/ray/rllib/train.py \
    --env CartPole-v0 \
    --run IMPALA \
    --stop '{"training_iteration": 1}' \
    --config '{"num_gpus": 0, "num_workers": 2, "num_aggregation_workers": 2, "min_iter_time_s": 1}'

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    /ray/ci/suppress_output /ray/python/ray/rllib/train.py \
    --env CartPole-v0 \

@@ -745,7 +745,8 @@ class Agent(Trainable):
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"])
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            _fake_sampler=config.get("_fake_sampler", False))

    @override(Trainable)
    def _export_model(self, export_formats, export_dir):

@@ -8,7 +8,10 @@ from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.optimizers import AsyncSamplesOptimizer
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
from ray.rllib.utils.annotations import override
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources

OPTIMIZER_SHARED_CONFIGS = [
    "lr",

@@ -23,6 +26,7 @@ OPTIMIZER_SHARED_CONFIGS = [
    "broadcast_interval",
    "num_sgd_iter",
    "minibatch_buffer_size",
    "num_aggregation_workers",
]

# yapf: disable

@@ -71,6 +75,9 @@ DEFAULT_CONFIG = with_common_config({
    "max_sample_requests_in_flight_per_worker": 2,
    # max number of workers to broadcast one set of weights to
    "broadcast_interval": 1,
    # use intermediate actors for multi-level aggregation. This can make sense
    # if ingesting >2GB/s of samples, or if the data requires decompression.
    "num_aggregation_workers": 0,

    # Learning params.
    "grad_clip": 40.0,

@@ -85,6 +92,9 @@ DEFAULT_CONFIG = with_common_config({
    # balancing the three losses
    "vf_loss_coeff": 0.5,
    "entropy_coeff": 0.01,

    # use fake (infinite speed) sampler for testing
    "_fake_sampler": False,
})
# __sphinx_doc_end__
# yapf: enable

@@ -104,7 +114,13 @@ class ImpalaAgent(Agent):
                config["optimizer"][k] = config[k]
        policy_cls = self._get_policy_graph()
        self.local_evaluator = self.make_local_evaluator(
            env_creator, policy_cls)
            self.env_creator, policy_cls)

        if self.config["num_aggregation_workers"] > 0:
            # Create co-located aggregator actors first for placement pref
            aggregators = TreeAggregator.precreate_aggregators(
                self.config["num_aggregation_workers"])

        self.remote_evaluators = self.make_remote_evaluators(
            env_creator, policy_cls, config["num_workers"])
        self.optimizer = AsyncSamplesOptimizer(

@@ -112,6 +128,22 @@ class ImpalaAgent(Agent):
        if config["entropy_coeff"] < 0:
            raise DeprecationWarning("entropy_coeff must be >= 0")

        if self.config["num_aggregation_workers"] > 0:
            # Assign the pre-created aggregators to the optimizer
            self.optimizer.aggregator.init(aggregators)

    @classmethod
    @override(Trainable)
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        Agent._validate_config(cf)
        return Resources(
            cpu=cf["num_cpus_for_driver"],
            gpu=cf["num_gpus"],
            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"] +
            cf["num_aggregation_workers"],
            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])

    @override(Agent)
    def _train(self):
        prev_steps = self.optimizer.num_steps_sampled

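A short usage sketch for the new aggregation settings (the experiment name, environment, and worker counts here are illustrative, not part of the diff):

    import ray
    from ray.tune import run_experiments

    ray.init()
    run_experiments({
        "impala_tree_agg": {
            "run": "IMPALA",
            "env": "CartPole-v0",
            "config": {
                "num_workers": 8,
                # new in this PR: route samples through 2 intermediate actors
                "num_aggregation_workers": 2,
                "sample_batch_size": 50,
                "train_batch_size": 500,
                "broadcast_interval": 5,
            },
        },
    })
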
@@ -27,7 +27,6 @@ from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import NoPreprocessor
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.compression import pack
from ray.rllib.utils.debug import disable_log_once_globally, log_once, \
    summarize, enable_periodic_logging
from ray.rllib.utils.filter import get_filter

@@ -125,7 +124,8 @@ class PolicyEvaluator(EvaluatorInterface):
                 input_evaluation=frozenset([]),
                 output_creator=lambda ioctx: NoopOutput(),
                 remote_worker_envs=False,
                 remote_env_batch_wait_ms=0):
                 remote_env_batch_wait_ms=0,
                 _fake_sampler=False):
        """Initialize a policy evaluator.

        Arguments:

@@ -203,12 +203,12 @@ class PolicyEvaluator(EvaluatorInterface):
            remote_worker_envs (bool): If using num_envs > 1, whether to create
                those new envs in remote processes instead of in the current
                process. This adds overheads, but can make sense if your envs
                can take much time to step / reset (e.g., for StarCraft)
            remote_env_batch_wait_ms (float): Timeout that remote workers
                are waiting when polling environments. 0 (continue when at
                least one env is ready) is a reasonable default, but optimal
                value could be obtained by measuring your environment
                step / reset and model inference perf.
            _fake_sampler (bool): Use a fake (inf speed) sampler for testing.
        """

        if log_level:

@@ -237,6 +237,8 @@ class PolicyEvaluator(EvaluatorInterface):
        self.batch_mode = batch_mode
        self.compress_observations = compress_observations
        self.preprocessing_enabled = True
        self.last_batch = None
        self._fake_sampler = _fake_sampler

        self.env = _validate_env(env_creator(env_context))
        if isinstance(self.env, MultiAgentEnv) or \

@@ -403,6 +405,9 @@ class PolicyEvaluator(EvaluatorInterface):
            SampleBatch|MultiAgentBatch from evaluating the current policies.
        """

        if self._fake_sampler and self.last_batch is not None:
            return self.last_batch

        if log_once("sample_start"):
            logger.info("Generating sample batch of size {}".format(
                self.sample_batch_size))

@@ -444,15 +449,13 @@ class PolicyEvaluator(EvaluatorInterface):
            logger.info("Completed sample batch:\n\n{}\n".format(
                summarize(batch)))

        if self.compress_observations:
            if isinstance(batch, MultiAgentBatch):
                for data in batch.policy_batches.values():
                    data["obs"] = [pack(o) for o in data["obs"]]
                    data["new_obs"] = [pack(o) for o in data["new_obs"]]
            else:
                batch["obs"] = [pack(o) for o in batch["obs"]]
                batch["new_obs"] = [pack(o) for o in batch["new_obs"]]
        if self.compress_observations == "bulk":
            batch.compress(bulk=True)
        elif self.compress_observations:
            batch.compress()

        if self._fake_sampler:
            self.last_batch = batch
        return batch

    @DeveloperAPI

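A hedged sketch of the two evaluator options touched here; the env_creator lambda and PGPolicyGraph are stand-ins for whatever the agent normally supplies, and construction details may differ slightly in practice:

    import gym
    from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
    from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator

    ev = PolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=PGPolicyGraph,
        compress_observations="bulk",  # obs / new_obs columns packed via batch.compress(bulk=True)
        _fake_sampler=True)            # first real batch is cached for later sample() calls

    batch = ev.sample()    # real rollout on the first call
    batch2 = ev.sample()   # returns the cached batch (infinite-speed sampling for perf tests)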
@@ -6,7 +6,8 @@ import six
import collections
import numpy as np

from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.compression import pack, unpack, is_compressed
from ray.rllib.utils.memory import concat_aligned

# Defaults policy id for single agent environments

@@ -65,6 +66,16 @@ class MultiAgentBatch(object):
            ct += batch.count
        return ct

    @DeveloperAPI
    def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
        for batch in self.policy_batches.values():
            batch.compress(bulk=bulk, columns=columns)

    @DeveloperAPI
    def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
        for batch in self.policy_batches.values():
            batch.decompress_if_needed(columns)

    def __str__(self):
        return "MultiAgentBatch({}, count={})".format(
            str(self.policy_batches), self.count)

@@ -246,6 +257,27 @@ class SampleBatch(object):
    def __setitem__(self, key, item):
        self.data[key] = item

    @DeveloperAPI
    def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
        for key in columns:
            if key in self.data:
                if bulk:
                    self.data[key] = pack(self.data[key])
                else:
                    self.data[key] = np.array(
                        [pack(o) for o in self.data[key]])

    @DeveloperAPI
    def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
        for key in columns:
            if key in self.data:
                arr = self.data[key]
                if is_compressed(arr):
                    self.data[key] = unpack(arr)
                elif len(arr) > 0 and is_compressed(arr[0]):
                    self.data[key] = np.array(
                        [unpack(o) for o in self.data[key]])

    def __str__(self):
        return "SampleBatch({})".format(str(self.data))

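A minimal round-trip sketch of the new batch helpers (assumes the lz4 package is installed, which pack() relies on; the column shapes are arbitrary):

    import numpy as np
    from ray.rllib.evaluation.sample_batch import SampleBatch

    batch = SampleBatch({
        "obs": np.zeros((4, 84, 84, 4), dtype=np.float32),
        "new_obs": np.zeros((4, 84, 84, 4), dtype=np.float32),
        "actions": np.array([0, 1, 1, 0]),
    })
    batch.compress()              # pack each row of "obs" / "new_obs"
    batch.decompress_if_needed()  # detects the packed rows and restores the arrays
    batch.compress(bulk=True)     # alternative: pack each column as one blob
                                  # (what compress_observations="bulk" uses)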
python/ray/rllib/examples/custom_fast_model.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""Example of using a custom image env and model.

Both the model and env are trivial (and super-fast), so they are useful
for running perf microbenchmarks.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from gym.spaces import Discrete, Box
import gym
import numpy as np
import tensorflow as tf

import ray
from ray.rllib.models import Model, ModelCatalog
from ray.tune import run_experiments, sample_from


class FastModel(Model):
    def _build_layers_v2(self, input_dict, num_outputs, options):
        bias = tf.get_variable(
            dtype=tf.float32,
            name="bias",
            initializer=tf.zeros_initializer,
            shape=())
        output = bias + tf.zeros([tf.shape(input_dict["obs"])[0], num_outputs])
        return output, output


class FastImageEnv(gym.Env):
    def __init__(self, config):
        self.zeros = np.zeros((84, 84, 4))
        self.action_space = Discrete(2)
        self.observation_space = Box(
            0.0, 1.0, shape=(84, 84, 4), dtype=np.float32)
        self.i = 0

    def reset(self):
        self.i = 0
        return self.zeros

    def step(self, action):
        self.i += 1
        return self.zeros, 1, self.i > 1000, {}


if __name__ == "__main__":
    ray.init()
    ModelCatalog.register_custom_model("fast_model", FastModel)
    run_experiments({
        "demo": {
            "run": "IMPALA",
            "env": FastImageEnv,
            "config": {
                "compress_observations": True,
                "model": {
                    "custom_model": "fast_model"
                },
                "num_gpus": 0,
                "num_workers": 2,
                "num_envs_per_worker": 10,
                "num_data_loader_buffers": 1,
                "num_aggregation_workers": 1,
                "broadcast_interval": 50,
                "sample_batch_size": 100,
                "train_batch_size": sample_from(
                    lambda spec: 1000 * max(1, spec.config.num_gpus)),
                "_fake_sampler": True,
            },
        },
    })

python/ray/rllib/optimizers/aso_aggregator.py (new file, 185 lines)
@@ -0,0 +1,185 @@
"""Helper class for AsyncSamplesOptimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import random

import ray
from ray.rllib.utils.actors import TaskPool
from ray.rllib.utils.annotations import override


class Aggregator(object):
    """An aggregator collects and processes samples from evaluators.

    This class is used to abstract away the strategy for sample collection.
    For example, you may want to use a tree of actors to collect samples. The
    use of multiple actors can be necessary to offload expensive work such
    as concatenating and decompressing sample batches.

    Attributes:
        local_evaluator: local PolicyEvaluator copy
    """

    def iter_train_batches(self):
        """Returns a generator over batches ready to learn on.

        Iterating through this generator will also send out weight updates to
        remote evaluators as needed.

        This call may block until results are available.
        """
        raise NotImplementedError

    def broadcast_new_weights(self):
        """Broadcast a new set of weights from the local evaluator."""
        raise NotImplementedError

    def should_broadcast(self):
        """Returns whether broadcast() should be called to update weights."""
        raise NotImplementedError

    def stats(self):
        """Returns runtime statistics for debugging."""
        raise NotImplementedError

    def reset(self, remote_evaluators):
        """Called to change the set of remote evaluators being used."""
        raise NotImplementedError


class AggregationWorkerBase(object):
    """Aggregators should extend from this class."""

    def __init__(self, initial_weights_obj_id, remote_evaluators,
                 max_sample_requests_in_flight_per_worker, replay_proportion,
                 replay_buffer_num_slots, train_batch_size, sample_batch_size):
        self.broadcasted_weights = initial_weights_obj_id
        self.remote_evaluators = remote_evaluators
        self.sample_batch_size = sample_batch_size
        self.train_batch_size = train_batch_size

        if replay_proportion:
            if replay_buffer_num_slots * sample_batch_size <= train_batch_size:
                raise ValueError(
                    "Replay buffer size is too small to produce train, "
                    "please increase replay_buffer_num_slots.",
                    replay_buffer_num_slots, sample_batch_size,
                    train_batch_size)

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(self.broadcasted_weights)
            for _ in range(max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(ev, ev.sample.remote())

        self.batch_buffer = []

        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.replay_batches = []
        self.num_sent_since_broadcast = 0
        self.num_weight_syncs = 0
        self.num_replayed = 0

    @override(Aggregator)
    def iter_train_batches(self, max_yield=999):
        """Iterate over train batches.

        Arguments:
            max_yield (int): Max number of batches to iterate over in this
                cycle. Setting this avoids iter_train_batches returning too
                much data at once.
        """

        for ev, sample_batch in self._augment_with_replay(
                self.sample_tasks.completed_prefetch(
                    blocking_wait=True, max_yield=max_yield)):
            sample_batch.decompress_if_needed()
            self.batch_buffer.append(sample_batch)
            if sum(b.count
                   for b in self.batch_buffer) >= self.train_batch_size:
                train_batch = self.batch_buffer[0].concat_samples(
                    self.batch_buffer)
                yield train_batch
                self.batch_buffer = []

            # If the batch was replayed, skip the update below.
            if ev is None:
                continue

            # Put in replay buffer if enabled
            if self.replay_buffer_num_slots > 0:
                self.replay_batches.append(sample_batch)
                if len(self.replay_batches) > self.replay_buffer_num_slots:
                    self.replay_batches.pop(0)

            ev.set_weights.remote(self.broadcasted_weights)
            self.num_weight_syncs += 1
            self.num_sent_since_broadcast += 1

            # Kick off another sample request
            self.sample_tasks.add(ev, ev.sample.remote())

    @override(Aggregator)
    def stats(self):
        return {
            "num_weight_syncs": self.num_weight_syncs,
            "num_steps_replayed": self.num_replayed,
        }

    @override(Aggregator)
    def reset(self, remote_evaluators):
        self.sample_tasks.reset_evaluators(remote_evaluators)

    def _augment_with_replay(self, sample_futures):
        def can_replay():
            num_needed = int(
                np.ceil(self.train_batch_size / self.sample_batch_size))
            return len(self.replay_batches) > num_needed

        for ev, sample_batch in sample_futures:
            sample_batch = ray.get(sample_batch)
            yield ev, sample_batch

            if can_replay():
                f = self.replay_proportion
                while random.random() < f:
                    f -= 1
                    replay_batch = random.choice(self.replay_batches)
                    self.num_replayed += replay_batch.count
                    yield None, replay_batch


class SimpleAggregator(AggregationWorkerBase, Aggregator):
    """Simple single-threaded implementation of an Aggregator."""

    def __init__(self,
                 local_evaluator,
                 remote_evaluators,
                 max_sample_requests_in_flight_per_worker=2,
                 replay_proportion=0.0,
                 replay_buffer_num_slots=0,
                 train_batch_size=500,
                 sample_batch_size=50,
                 broadcast_interval=5):
        self.local_evaluator = local_evaluator
        self.broadcast_interval = broadcast_interval
        self.broadcast_new_weights()
        AggregationWorkerBase.__init__(
            self, self.broadcasted_weights, remote_evaluators,
            max_sample_requests_in_flight_per_worker, replay_proportion,
            replay_buffer_num_slots, train_batch_size, sample_batch_size)

    @override(Aggregator)
    def broadcast_new_weights(self):
        self.broadcasted_weights = ray.put(self.local_evaluator.get_weights())
        self.num_sent_since_broadcast = 0

    @override(Aggregator)
    def should_broadcast(self):
        return self.num_sent_since_broadcast >= self.broadcast_interval

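The optimizer drives any Aggregator through this small protocol; a sketch of the control flow, where local_ev, remote_evs, and learner stand in for the optimizer's own members:

    agg = SimpleAggregator(local_ev, remote_evs,
                           train_batch_size=500, sample_batch_size=50,
                           broadcast_interval=5)
    for train_batch in agg.iter_train_batches():
        learner.inqueue.put(train_batch)          # hand batches to the learner thread
        if learner.weights_updated and agg.should_broadcast():
            agg.broadcast_new_weights()           # lazily re-publish fresh weights to samplers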
python/ray/rllib/optimizers/aso_learner.py (new file, 60 lines)
@@ -0,0 +1,60 @@
"""Helper class for AsyncSamplesOptimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from six.moves import queue

from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.aso_minibatch_buffer import MinibatchBuffer
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat


class LearnerThread(threading.Thread):
    """Background thread that updates the local model from sample trajectories.

    This is for use with AsyncSamplesOptimizer.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient ops session runs off the main thread
    improves overall throughput.
    """

    def __init__(self, local_evaluator, minibatch_buffer_size, num_sgd_iter,
                 learner_queue_size):
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_evaluator = local_evaluator
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(
            self.inqueue, minibatch_buffer_size, num_sgd_iter)
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stats = {}
        self.stopped = False

    def run(self):
        while not self.stopped:
            self.step()

    def step(self):
        with self.queue_timer:
            batch, _ = self.minibatch_buffer.get()

        with self.grad_timer:
            fetches = self.local_evaluator.learn_on_batch(batch)
            self.weights_updated = True
            self.stats = get_learner_stats(fetches)

        self.outqueue.put(batch.count)
        self.learner_queue_size.push(self.inqueue.qsize())

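A sketch of the hand-off between the main thread and the learner; local_ev is a placeholder for a local PolicyEvaluator and train_batch for any sampled SampleBatch:

    learner = LearnerThread(local_ev, minibatch_buffer_size=1,
                            num_sgd_iter=1, learner_queue_size=16)
    learner.start()                          # step() now loops in the background
    learner.inqueue.put(train_batch)         # main thread enqueues sampled batches
    trained_count = learner.outqueue.get()   # learner reports timesteps trained per batch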
python/ray/rllib/optimizers/aso_minibatch_buffer.py (new file, 48 lines)
@@ -0,0 +1,48 @@
"""Helper class for AsyncSamplesOptimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


class MinibatchBuffer(object):
    """Ring buffer of recent data batches for minibatch SGD.

    This is for use with AsyncSamplesOptimizer.
    """

    def __init__(self, inqueue, size, num_passes):
        """Initialize a minibatch buffer.

        Arguments:
           inqueue: Queue to populate the internal ring buffer from.
           size: Max number of data items to buffer.
           num_passes: Max num times each data item should be emitted.
        """
        self.inqueue = inqueue
        self.size = size
        self.max_ttl = num_passes
        self.cur_max_ttl = 1  # ramp up slowly to better mix the input data
        self.buffers = [None] * size
        self.ttl = [0] * size
        self.idx = 0

    def get(self):
        """Get a new batch from the internal ring buffer.

        Returns:
           buf: Data item saved from inqueue.
           released: True if the item is now removed from the ring buffer.
        """
        if self.ttl[self.idx] <= 0:
            self.buffers[self.idx] = self.inqueue.get(timeout=60.0)
            self.ttl[self.idx] = self.cur_max_ttl
            if self.cur_max_ttl < self.max_ttl:
                self.cur_max_ttl += 1
        buf = self.buffers[self.idx]
        self.ttl[self.idx] -= 1
        released = self.ttl[self.idx] <= 0
        if released:
            self.buffers[self.idx] = None
        self.idx = (self.idx + 1) % len(self.buffers)
        return buf, released

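The buffer re-emits each queued item up to num_passes times, ramping the pass count up gradually; a self-contained sketch of that behavior (the string items are arbitrary placeholders):

    from six.moves import queue
    from ray.rllib.optimizers.aso_minibatch_buffer import MinibatchBuffer

    q = queue.Queue()
    for item in ["a", "b", "c"]:
        q.put(item)

    buf = MinibatchBuffer(q, size=2, num_passes=2)
    for _ in range(4):
        item, released = buf.get()
        # items repeat up to num_passes times; released marks an item's last pass
        print(item, released)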
python/ray/rllib/optimizers/aso_multi_gpu_learner.py (new file, 150 lines)
@@ -0,0 +1,150 @@
"""Helper class for AsyncSamplesOptimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import threading

from six.moves import queue

from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.aso_learner import LearnerThread
from ray.rllib.optimizers.aso_minibatch_buffer import MinibatchBuffer
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat

logger = logging.getLogger(__name__)


class TFMultiGPULearner(LearnerThread):
    """Learner that can use multiple GPUs and parallel loading.

    This is for use with AsyncSamplesOptimizer.
    """

    def __init__(self,
                 local_evaluator,
                 num_gpus=1,
                 lr=0.0005,
                 train_batch_size=500,
                 num_data_loader_buffers=1,
                 minibatch_buffer_size=1,
                 num_sgd_iter=1,
                 learner_queue_size=16,
                 num_data_load_threads=16,
                 _fake_gpus=False):
        # Multi-GPU requires TensorFlow to function.
        import tensorflow as tf

        LearnerThread.__init__(self, local_evaluator, minibatch_buffer_size,
                               num_sgd_iter, learner_queue_size)
        self.lr = lr
        self.train_batch_size = train_batch_size
        if not num_gpus:
            self.devices = ["/cpu:0"]
        elif _fake_gpus:
            self.devices = ["/cpu:{}".format(i) for i in range(num_gpus)]
        else:
            self.devices = ["/gpu:{}".format(i) for i in range(num_gpus)]
        logger.info("TFMultiGPULearner devices {}".format(self.devices))
        assert self.train_batch_size % len(self.devices) == 0
        assert self.train_batch_size >= len(self.devices), "batch too small"

        if set(self.local_evaluator.policy_map.keys()) != {DEFAULT_POLICY_ID}:
            raise NotImplementedError("Multi-gpu mode for multi-agent")
        self.policy = self.local_evaluator.policy_map[DEFAULT_POLICY_ID]

        # per-GPU graph copies created below must share vars with the policy
        # reuse is set to AUTO_REUSE because Adam nodes are created after
        # all of the device copies are created.
        self.par_opt = []
        with self.local_evaluator.tf_sess.graph.as_default():
            with self.local_evaluator.tf_sess.as_default():
                with tf.variable_scope(DEFAULT_POLICY_ID, reuse=tf.AUTO_REUSE):
                    if self.policy._state_inputs:
                        rnn_inputs = self.policy._state_inputs + [
                            self.policy._seq_lens
                        ]
                    else:
                        rnn_inputs = []
                    adam = tf.train.AdamOptimizer(self.lr)
                    for _ in range(num_data_loader_buffers):
                        self.par_opt.append(
                            LocalSyncParallelOptimizer(
                                adam,
                                self.devices,
                                [v for _, v in self.policy._loss_inputs],
                                rnn_inputs,
                                999999,  # it will get rounded down
                                self.policy.copy))

                self.sess = self.local_evaluator.tf_sess
                self.sess.run(tf.global_variables_initializer())

        self.idle_optimizers = queue.Queue()
        self.ready_optimizers = queue.Queue()
        for opt in self.par_opt:
            self.idle_optimizers.put(opt)
        for i in range(num_data_load_threads):
            self.loader_thread = _LoaderThread(self, share_stats=(i == 0))
            self.loader_thread.start()

        self.minibatch_buffer = MinibatchBuffer(
            self.ready_optimizers, minibatch_buffer_size, num_sgd_iter)

    @override(LearnerThread)
    def step(self):
        assert self.loader_thread.is_alive()
        with self.load_wait_timer:
            opt, released = self.minibatch_buffer.get()

        with self.grad_timer:
            fetches = opt.optimize(self.sess, 0)
            self.weights_updated = True
            self.stats = get_learner_stats(fetches)

        if released:
            self.idle_optimizers.put(opt)

        self.outqueue.put(opt.num_tuples_loaded)
        self.learner_queue_size.push(self.inqueue.qsize())


class _LoaderThread(threading.Thread):
    def __init__(self, learner, share_stats):
        threading.Thread.__init__(self)
        self.learner = learner
        self.daemon = True
        if share_stats:
            self.queue_timer = learner.queue_timer
            self.load_timer = learner.load_timer
        else:
            self.queue_timer = TimerStat()
            self.load_timer = TimerStat()

    def run(self):
        while True:
            self._step()

    def _step(self):
        s = self.learner
        with self.queue_timer:
            batch = s.inqueue.get()

        opt = s.idle_optimizers.get()

        with self.load_timer:
            tuples = s.policy._get_loss_inputs_dict(batch)
            data_keys = [ph for _, ph in s.policy._loss_inputs]
            if s.policy._state_inputs:
                state_keys = s.policy._state_inputs + [s.policy._seq_lens]
            else:
                state_keys = []
            opt.load_data(s.sess, [tuples[k] for k in data_keys],
                          [tuples[k] for k in state_keys])

        s.ready_optimizers.put(opt)

python/ray/rllib/optimizers/aso_tree_aggregator.py (new file, 160 lines)
@@ -0,0 +1,160 @@
"""Helper class for AsyncSamplesOptimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import logging
import os
import time

import ray
from ray.rllib.utils.actors import TaskPool, create_colocated
from ray.rllib.utils.annotations import override
from ray.rllib.optimizers.aso_aggregator import Aggregator, \
    AggregationWorkerBase

logger = logging.getLogger(__name__)


class TreeAggregator(Aggregator):
    """A hierarchical experiences aggregator.

    The given set of remote evaluators is divided into subsets and assigned to
    one of several aggregation workers. These aggregation workers collate
    experiences into batches of size `train_batch_size` and we collect them
    in this class when `iter_train_batches` is called.
    """

    def __init__(self,
                 local_evaluator,
                 remote_evaluators,
                 num_aggregation_workers,
                 max_sample_requests_in_flight_per_worker=2,
                 replay_proportion=0.0,
                 replay_buffer_num_slots=0,
                 train_batch_size=500,
                 sample_batch_size=50,
                 broadcast_interval=5):
        self.local_evaluator = local_evaluator
        self.remote_evaluators = remote_evaluators
        self.num_aggregation_workers = num_aggregation_workers
        self.max_sample_requests_in_flight_per_worker = \
            max_sample_requests_in_flight_per_worker
        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.sample_batch_size = sample_batch_size
        self.train_batch_size = train_batch_size
        self.broadcast_interval = broadcast_interval
        self.broadcasted_weights = ray.put(local_evaluator.get_weights())
        self.num_batches_processed = 0
        self.num_broadcasts = 0
        self.num_sent_since_broadcast = 0
        self.initialized = False

    def init(self, aggregators):
        """Deferred init so that we can pass in previously created workers."""

        assert len(aggregators) == self.num_aggregation_workers, aggregators
        if len(self.remote_evaluators) < self.num_aggregation_workers:
            raise ValueError(
                "The number of aggregation workers should not exceed the "
                "number of total evaluation workers ({} vs {})".format(
                    self.num_aggregation_workers, len(self.remote_evaluators)))

        assigned_evaluators = collections.defaultdict(list)
        for i, ev in enumerate(self.remote_evaluators):
            assigned_evaluators[i % self.num_aggregation_workers].append(ev)

        self.workers = aggregators
        for i, worker in enumerate(self.workers):
            worker.init.remote(
                self.broadcasted_weights, assigned_evaluators[i],
                self.max_sample_requests_in_flight_per_worker,
                self.replay_proportion, self.replay_buffer_num_slots,
                self.train_batch_size, self.sample_batch_size)

        self.agg_tasks = TaskPool()
        for agg in self.workers:
            agg.set_weights.remote(self.broadcasted_weights)
            self.agg_tasks.add(agg, agg.get_train_batches.remote())

        self.initialized = True

    @override(Aggregator)
    def iter_train_batches(self):
        assert self.initialized, "Must call init() before using this class."
        for agg, batches in self.agg_tasks.completed_prefetch():
            for b in ray.get(batches):
                self.num_sent_since_broadcast += 1
                yield b
            agg.set_weights.remote(self.broadcasted_weights)
            self.agg_tasks.add(agg, agg.get_train_batches.remote())
            self.num_batches_processed += 1

    @override(Aggregator)
    def broadcast_new_weights(self):
        self.broadcasted_weights = ray.put(self.local_evaluator.get_weights())
        self.num_sent_since_broadcast = 0
        self.num_broadcasts += 1

    @override(Aggregator)
    def should_broadcast(self):
        return self.num_sent_since_broadcast >= self.broadcast_interval

    @override(Aggregator)
    def stats(self):
        return {
            "num_broadcasts": self.num_broadcasts,
            "num_batches_processed": self.num_batches_processed,
        }

    @override(Aggregator)
    def reset(self, remote_evaluators):
        raise NotImplementedError("changing number of remote evaluators")

    @staticmethod
    def precreate_aggregators(n):
        return create_colocated(AggregationWorker, [], n)


@ray.remote(num_cpus=1)
class AggregationWorker(AggregationWorkerBase):
    def __init__(self):
        self.initialized = False

    def init(self, initial_weights_obj_id, remote_evaluators,
             max_sample_requests_in_flight_per_worker, replay_proportion,
             replay_buffer_num_slots, train_batch_size, sample_batch_size):
        """Deferred init that assigns sub-workers to this aggregator."""

        logger.info("Assigned evaluators {} to aggregation worker {}".format(
            remote_evaluators, self))
        assert remote_evaluators
        AggregationWorkerBase.__init__(
            self, initial_weights_obj_id, remote_evaluators,
            max_sample_requests_in_flight_per_worker, replay_proportion,
            replay_buffer_num_slots, train_batch_size, sample_batch_size)
        self.initialized = True

    def set_weights(self, weights):
        self.broadcasted_weights = weights

    def get_train_batches(self):
        assert self.initialized, "Must call init() before using this class."
        start = time.time()
        result = []
        for batch in self.iter_train_batches(max_yield=5):
            result.append(batch)
        while not result:
            time.sleep(0.01)
            for batch in self.iter_train_batches(max_yield=5):
                result.append(batch)
        logger.debug("Returning {} train batches, {}s".format(
            len(result),
            time.time() - start))
        return result

    def get_host(self):
        return os.uname()[1]

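The deferred-init protocol mirrors how ImpalaAgent and the tests wire things up; local and remotes below are placeholders for the local and remote PolicyEvaluators:

    from ray.rllib.optimizers import AsyncSamplesOptimizer
    from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator

    # co-located aggregation actors are created first so they get placement preference
    aggregators = TreeAggregator.precreate_aggregators(2)
    optimizer = AsyncSamplesOptimizer(local, remotes,
                                      {"num_aggregation_workers": 2})
    optimizer.aggregator.init(aggregators)   # hand the pre-created workers to the TreeAggregator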
@@ -1,4 +1,4 @@
"""Implements the IMPALA architecture.
"""Implements the IMPALA asynchronous sampling architecture.

https://arxiv.org/abs/1802.01561"""

@@ -7,27 +7,18 @@ from __future__ import division
from __future__ import print_function

import logging
import numpy as np
import random
import time
import threading

from six.moves import queue

import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.optimizers.aso_aggregator import SimpleAggregator
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
from ray.rllib.optimizers.aso_learner import LearnerThread
from ray.rllib.optimizers.aso_multi_gpu_learner import TFMultiGPULearner
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.actors import TaskPool
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.window_stat import WindowStat

logger = logging.getLogger(__name__)

NUM_DATA_LOAD_THREADS = 16


class AsyncSamplesOptimizer(PolicyOptimizer):
    """Main event loop of the IMPALA architecture.

@@ -51,11 +42,8 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
                 num_sgd_iter=1,
                 minibatch_buffer_size=1,
                 learner_queue_size=16,
                 num_aggregation_workers=0,
                 _fake_gpus=False):
        self.train_batch_size = train_batch_size
        self.sample_batch_size = sample_batch_size
        self.broadcast_interval = broadcast_interval

        self._stats_start_time = time.time()
        self._last_stats_time = {}
        self._last_stats_sum = {}

@@ -88,32 +76,32 @@ class AsyncSamplesOptimizer(PolicyOptimizer):

        # Stats
        self._optimizer_step_timer = TimerStat()
        self.num_weight_syncs = 0
        self.num_replayed = 0
        self._stats_start_time = time.time()
        self._last_stats_time = {}
        self._last_stats_val = {}

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        weights = self.local_evaluator.get_weights()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            for _ in range(max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(ev, ev.sample.remote())

        self.batch_buffer = []

        if replay_proportion:
            if replay_buffer_num_slots * sample_batch_size <= train_batch_size:
                raise ValueError(
                    "Replay buffer size is too small to produce train, "
                    "please increase replay_buffer_num_slots.",
                    replay_buffer_num_slots, sample_batch_size,
                    train_batch_size)
        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.replay_batches = []
        if num_aggregation_workers > 0:
            self.aggregator = TreeAggregator(
                self.local_evaluator,
                self.remote_evaluators,
                num_aggregation_workers,
                replay_proportion=replay_proportion,
                max_sample_requests_in_flight_per_worker=(
                    max_sample_requests_in_flight_per_worker),
                replay_buffer_num_slots=replay_buffer_num_slots,
                train_batch_size=train_batch_size,
                sample_batch_size=sample_batch_size,
                broadcast_interval=broadcast_interval)
        else:
            self.aggregator = SimpleAggregator(
                self.local_evaluator,
                self.remote_evaluators,
                replay_proportion=replay_proportion,
                max_sample_requests_in_flight_per_worker=(
                    max_sample_requests_in_flight_per_worker),
                replay_buffer_num_slots=replay_buffer_num_slots,
                train_batch_size=train_batch_size,
                sample_batch_size=sample_batch_size,
                broadcast_interval=broadcast_interval)

    def add_stat_val(self, key, val):
        if key not in self._last_stats_sum:

@@ -157,14 +145,16 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
    @override(PolicyOptimizer)
    def reset(self, remote_evaluators):
        self.remote_evaluators = remote_evaluators
        self.sample_tasks.reset_evaluators(remote_evaluators)
        self.aggregator.reset(remote_evaluators)

    @override(PolicyOptimizer)
    def stats(self):
        def timer_to_ms(timer):
            return round(1000 * timer.mean, 3)

        timing = {
        stats = self.aggregator.stats()
        stats.update(self.get_mean_stats_and_reset())
        stats["timing_breakdown"] = {
            "optimizer_step_time_ms": timer_to_ms(self._optimizer_step_timer),
            "learner_grad_time_ms": timer_to_ms(self.learner.grad_timer),
            "learner_load_time_ms": timer_to_ms(self.learner.load_timer),

@@ -172,288 +162,23 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
                self.learner.load_wait_timer),
            "learner_dequeue_time_ms": timer_to_ms(self.learner.queue_timer),
        }
        stats = dict({
            "num_weight_syncs": self.num_weight_syncs,
            "num_steps_replayed": self.num_replayed,
            "timing_breakdown": timing,
            "learner_queue": self.learner.learner_queue_size.stats(),
        }, **self.get_mean_stats_and_reset())
        self._last_stats_val.clear()
        stats["learner_queue"] = self.learner.learner_queue_size.stats()
        if self.learner.stats:
            stats["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **stats)

    def _step(self):
        sample_timesteps, train_timesteps = 0, 0
        num_sent = 0
        weights = None

        for ev, sample_batch in self._augment_with_replay(
                self.sample_tasks.completed_prefetch()):
            self.batch_buffer.append(sample_batch)
            if sum(b.count
                   for b in self.batch_buffer) >= self.train_batch_size:
                train_batch = self.batch_buffer[0].concat_samples(
                    self.batch_buffer)
                self.learner.inqueue.put(train_batch)
                self.batch_buffer = []

            # If the batch was replayed, skip the update below.
            if ev is None:
                continue

            sample_timesteps += sample_batch.count

            # Put in replay buffer if enabled
            if self.replay_buffer_num_slots > 0:
                self.replay_batches.append(sample_batch)
                if len(self.replay_batches) > self.replay_buffer_num_slots:
                    self.replay_batches.pop(0)

            # Note that it's important to pull new weights once
            # updated to avoid excessive correlation between actors
            if weights is None or (self.learner.weights_updated
                                   and num_sent >= self.broadcast_interval):
                self.learner.weights_updated = False
                weights = ray.put(self.local_evaluator.get_weights())
                num_sent = 0
            ev.set_weights.remote(weights)
            self.num_weight_syncs += 1
            num_sent += 1

            # Kick off another sample request
            self.sample_tasks.add(ev, ev.sample.remote())
        for train_batch in self.aggregator.iter_train_batches():
            sample_timesteps += train_batch.count
            self.learner.inqueue.put(train_batch)
            if (self.learner.weights_updated
                    and self.aggregator.should_broadcast()):
                self.aggregator.broadcast_new_weights()

        while not self.learner.outqueue.empty():
            count = self.learner.outqueue.get()
            train_timesteps += count

        return sample_timesteps, train_timesteps

    def _augment_with_replay(self, sample_futures):
        def can_replay():
            num_needed = int(
                np.ceil(self.train_batch_size / self.sample_batch_size))
            return len(self.replay_batches) > num_needed

        for ev, sample_batch in sample_futures:
            sample_batch = ray.get(sample_batch)
            yield ev, sample_batch

            if can_replay():
                f = self.replay_proportion
                while random.random() < f:
                    f -= 1
                    replay_batch = random.choice(self.replay_batches)
                    self.num_replayed += replay_batch.count
                    yield None, replay_batch

[The remainder of this hunk deletes the LearnerThread, TFMultiGPULearner, _LoaderThread, and MinibatchBuffer classes from this module; they reappear essentially verbatim as the new aso_learner.py, aso_multi_gpu_learner.py, and aso_minibatch_buffer.py files shown above, aside from two small changes there: the data-loader thread count becomes a num_data_load_threads constructor argument instead of the NUM_DATA_LOAD_THREADS module constant, and MinibatchBuffer.get() now calls inqueue.get(timeout=60.0).]

@@ -14,6 +14,7 @@ from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.evaluation import SampleBatch
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
from ray.rllib.tests.mock_evaluator import _MockEvaluator


@@ -157,9 +158,11 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
            "train_batch_size": 10,
        })
        self._wait_for(optimizer, 1000, 1000)
        self.assertLess(optimizer.stats()["num_steps_sampled"], 5000)
        self.assertGreater(optimizer.stats()["num_steps_replayed"], 8000)
        self.assertGreater(optimizer.stats()["num_steps_trained"], 8000)
        stats = optimizer.stats()
        self.assertLess(stats["num_steps_sampled"], 5000)
        replay_ratio = stats["num_steps_replayed"] / stats["num_steps_sampled"]
        self.assertGreater(replay_ratio, 0.7)
        self.assertLess(stats["num_steps_trained"], stats["num_steps_sampled"])

    def testReplayAndMultiplePasses(self):
        local, remotes = self._make_evs()

@@ -173,9 +176,31 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
            "train_batch_size": 10,
        })
        self._wait_for(optimizer, 1000, 1000)
        self.assertLess(optimizer.stats()["num_steps_sampled"], 5000)
        self.assertGreater(optimizer.stats()["num_steps_replayed"], 8000)
        self.assertGreater(optimizer.stats()["num_steps_trained"], 40000)

        stats = optimizer.stats()
        print(stats)
        self.assertLess(stats["num_steps_sampled"], 5000)
        replay_ratio = stats["num_steps_replayed"] / stats["num_steps_sampled"]
        train_ratio = stats["num_steps_sampled"] / stats["num_steps_trained"]
        self.assertGreater(replay_ratio, 0.7)
        self.assertLess(train_ratio, 0.4)

    def testMultiTierAggregationBadConf(self):
        local, remotes = self._make_evs()
        aggregators = TreeAggregator.precreate_aggregators(4)
        optimizer = AsyncSamplesOptimizer(local, remotes,
                                          {"num_aggregation_workers": 4})
        self.assertRaises(ValueError,
                          lambda: optimizer.aggregator.init(aggregators))

    def testMultiTierAggregation(self):
        local, remotes = self._make_evs()
        aggregators = TreeAggregator.precreate_aggregators(1)
        optimizer = AsyncSamplesOptimizer(local, remotes, {
            "num_aggregation_workers": 1,
        })
        optimizer.aggregator.init(aggregators)
        self._wait_for(optimizer, 1000, 1000)

    def testRejectBadConfigs(self):
        local, remotes = self._make_evs()

@@ -25,30 +25,35 @@ class TaskPool(object):
            self._tasks[obj_id] = worker
            self._objects[obj_id] = all_obj_ids

    def completed(self):
    def completed(self, blocking_wait=False):
        pending = list(self._tasks)
        if pending:
            ready, _ = ray.wait(
                pending, num_returns=len(pending), timeout=0.01)
            ready, _ = ray.wait(pending, num_returns=len(pending), timeout=0)
            if not ready and blocking_wait:
                ready, _ = ray.wait(pending, num_returns=1, timeout=10.0)
            for obj_id in ready:
                yield (self._tasks.pop(obj_id), self._objects.pop(obj_id))

    def completed_prefetch(self):
    def completed_prefetch(self, blocking_wait=False, max_yield=999):
        """Similar to completed but only returns once the object is local.

        Assumes obj_id only is one id."""

        for worker, obj_id in self.completed():
        for worker, obj_id in self.completed(blocking_wait=blocking_wait):
            plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.binary())
            (ray.worker.global_worker.raylet_client.fetch_or_reconstruct(
                [obj_id], True))
            self._fetching.append((worker, obj_id))

        remaining = []
        num_yielded = 0
        for worker, obj_id in self._fetching:
            plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.binary())
            if ray.worker.global_worker.plasma_client.contains(plasma_id):
            if (num_yielded < max_yield
                    and ray.worker.global_worker.plasma_client.contains(
                        plasma_id)):
                yield (worker, obj_id)
                num_yielded += 1
            else:
                remaining.append((worker, obj_id))
        self._fetching = remaining

@@ -92,8 +97,10 @@ def split_colocated(actors):

def try_create_colocated(cls, args, count):
    actors = [cls.remote(*args) for _ in range(count)]
    local, _ = split_colocated(actors)
    local, rest = split_colocated(actors)
    logger.info("Got {} colocated actors of {}".format(len(local), count))
    for a in rest:
        a.__ray_terminate__.remote()
    return local


@@ -107,4 +114,6 @@ def create_colocated(cls, args, count):
        i += 1
    if len(ok) < count:
        raise Exception("Unable to create enough colocated actors, abort.")
    for a in ok[count:]:
        a.__ray_terminate__.remote()
    return ok[:count]

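A sketch of the new TaskPool knobs; the Sampler actor is a stand-in and a running Ray instance is assumed:

    import ray
    from ray.rllib.utils.actors import TaskPool

    @ray.remote
    class Sampler(object):
        def sample(self):
            return [1, 2, 3]

    ray.init()
    workers = [Sampler.remote() for _ in range(4)]
    pool = TaskPool()
    for w in workers:
        pool.add(w, w.sample.remote())

    # blocking_wait waits up to 10s for at least one result instead of returning
    # immediately; max_yield caps how many fetched results one call hands back.
    for w, obj_id in pool.completed_prefetch(blocking_wait=True, max_yield=2):
        print(ray.get(obj_id))
        pool.add(w, w.sample.remote())  # keep the pipeline full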
@@ -2,6 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.rllib.utils.annotations import DeveloperAPI

import logging
import time
import base64

@@ -9,8 +11,6 @@ import numpy as np
import pyarrow
from six import string_types

from ray.rllib.utils.annotations import DeveloperAPI

logger = logging.getLogger(__name__)

try:

@@ -52,11 +52,16 @@ def unpack(data):

@DeveloperAPI
def unpack_if_needed(data):
    if isinstance(data, bytes) or isinstance(data, string_types):
    if is_compressed(data):
        data = unpack(data)
    return data


@DeveloperAPI
def is_compressed(data):
    return isinstance(data, bytes) or isinstance(data, string_types)


# Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz
# Compression speed: 753.664 MB/s
# Compression ratio: 87.4839812046
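A round-trip sketch of the compression helpers (assumes the lz4 package is available; without it, pack() returns its input unchanged):

    import numpy as np
    from ray.rllib.utils.compression import pack, unpack_if_needed, is_compressed

    obs = np.zeros((84, 84, 4), dtype=np.float32)
    blob = pack(obs)                    # pyarrow-serialized, lz4-compressed, base64-encoded
    assert is_compressed(blob)          # any bytes / str value is treated as packed
    restored = unpack_if_needed(blob)   # uncompressed inputs would pass through untouched
    assert restored.shape == (84, 84, 4)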