from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import random
from collections import defaultdict

import ray
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.optimizers.multi_gpu_optimizer import _averaged
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.memory import ray_get_and_free

logger = logging.getLogger(__name__)


class SyncSamplesOptimizer(PolicyOptimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    workers, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote workers.
    """

    def __init__(self,
                 workers,
                 num_sgd_iter=1,
                 train_batch_size=1,
                 sgd_minibatch_size=0,
                 standardize_fields=frozenset([])):
        PolicyOptimizer.__init__(self, workers)

        self.update_weights_timer = TimerStat()
        self.standardize_fields = standardize_fields
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size
        self.train_batch_size = train_batch_size
        self.learner_stats = {}
        self.policies = dict(self.workers.local_worker()
                             .foreach_trainable_policy(lambda p, i: (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))

    @override(PolicyOptimizer)
    def step(self):
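        # Broadcast the latest local weights to every remote worker so that
        # sampling happens with an up-to-date policy.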
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)
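
        # Keep pulling rollouts (from the remote workers if there are any,
        # otherwise from the local worker) until at least `train_batch_size`
        # timesteps have been collected.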
        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({
                DEFAULT_POLICY_ID: samples
            }, samples.count)
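
        # Run one or more epochs of (minibatch) SGD on the local worker, one
        # trainable policy at a time.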
        fetches = {}
        with self.grad_timer:
            for policy_id, policy in self.policies.items():
                if policy_id not in samples.policy_batches:
                    continue

                batch = samples.policy_batches[policy_id]
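                # Standardize the requested fields (typically advantage
                # estimates) to zero mean and unit standard deviation before
                # running SGD on them.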
                for field in self.standardize_fields:
                    value = batch[field]
                    standardized = (value - value.mean()) / max(
                        1e-4, value.std())
                    batch[field] = standardized
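
                # `num_sgd_iter` epochs over the batch; each epoch learns on
                # the batch one minibatch at a time and averages the learner
                # stats that come back.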
                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    for minibatch in self._minibatches(batch):
                        batch_fetches = (
                            self.workers.local_worker().learn_on_batch(
                                MultiAgentBatch({
                                    policy_id: minibatch
                                }, minibatch.count)))[policy_id]
                        for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                            iter_extra_fetches[k].append(v)
                    logger.debug("{} {}".format(i,
                                                _averaged(iter_extra_fetches)))
                fetches[policy_id] = _averaged(iter_extra_fetches)

        self.grad_timer.push_units_processed(samples.count)
        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches
        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return self.learner_stats

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "sample_peak_throughput": round(
                    self.sample_timer.mean_throughput, 3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _minibatches(self, samples):
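        """Yield minibatches from a (single-policy) SampleBatch.

        If no `sgd_minibatch_size` is configured, the whole batch is yielded
        unchanged. Otherwise the batch is shuffled (unless it contains RNN
        state) and cut into `sgd_minibatch_size`-sized slices, which are
        yielded in random order.
        """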
        if not self.sgd_minibatch_size:
            yield samples
            return

        if isinstance(samples, MultiAgentBatch):
            raise NotImplementedError(
                "Minibatching not implemented for multi-agent in simple mode")

        if "state_in_0" in samples.data:
            logger.warning("Not shuffling RNN data for SGD in simple mode")
        else:
            samples.shuffle()
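
        # Cut the (possibly shuffled) batch into contiguous
        # `sgd_minibatch_size`-sized slices and visit them in random order.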
        i = 0
        slices = []
        while i < samples.count:
            slices.append((i, i + self.sgd_minibatch_size))
            i += self.sgd_minibatch_size
        random.shuffle(slices)

        for i, j in slices:
            yield samples.slice(i, j)