ray/rllib/optimizers/torch_distributed_data_parallel_optimizer.py

import logging

import ray
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat

logger = logging.getLogger(__name__)


class TorchDistributedDataParallelOptimizer(PolicyOptimizer):
    """EXPERIMENTAL: torch distributed multi-node SGD."""

    def __init__(self,
                 workers,
                 expected_batch_size,
                 num_sgd_iter=1,
                 sgd_minibatch_size=0,
                 standardize_fields=frozenset([]),
                 keep_local_weights_in_sync=True,
                 backend="gloo"):
        PolicyOptimizer.__init__(self, workers)
        self.learner_stats = {}
        self.num_sgd_iter = num_sgd_iter
        self.expected_batch_size = expected_batch_size
        self.sgd_minibatch_size = sgd_minibatch_size
        self.standardize_fields = standardize_fields
        self.keep_local_weights_in_sync = keep_local_weights_in_sync
        self.sync_down_timer = TimerStat()
        self.sync_up_timer = TimerStat()
        self.learn_timer = TimerStat()

        # Setup the distributed processes.
        if not self.workers.remote_workers():
            raise ValueError("This optimizer requires >0 remote workers.")
        ip = ray.get(workers.remote_workers()[0].get_node_ip.remote())
        port = ray.get(workers.remote_workers()[0].find_free_port.remote())
        address = "tcp://{ip}:{port}".format(ip=ip, port=port)
        logger.info(
            "Creating torch process group with leader {}".format(address))
        # Get setup tasks in order to throw errors on failure.
        ray.get([
            worker.setup_torch_data_parallel.remote(
                address, i, len(workers.remote_workers()), backend)
            for i, worker in enumerate(workers.remote_workers())
        ])
        logger.info("Torch process group init completed")

    @override(PolicyOptimizer)
    def step(self):
        # Sync up the weights. In principle we don't need this, but it doesn't
        # add too much overhead and handles the case where the user manually
        # updates the local weights.
        if self.keep_local_weights_in_sync:
            with self.sync_up_timer:
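                # ray.put() stores the weights once in the object store so all
                # remote workers read the same copy; the set_weights() calls
                # are not awaited, but actor tasks execute in submission order,
                # so they complete before the sample_and_learn() calls below.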
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)
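
        # Each remote worker samples its own data and runs SGD locally; the
        # per-worker gradients are expected to be allreduced across the torch
        # process group set up in __init__, which is what keeps the model
        # replicas identical.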
        with self.learn_timer:
            results = ray.get([
                w.sample_and_learn.remote(
                    self.expected_batch_size, self.num_sgd_iter,
                    self.sgd_minibatch_size, self.standardize_fields)
                for w in self.workers.remote_workers()
            ])
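        # Each result is an (info, count) tuple: step counts from every worker
        # are accumulated, while only the first worker's learner info is kept
        # (the replicas should agree).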
        for info, count in results:
            self.num_steps_sampled += count
            self.num_steps_trained += count
        self.learner_stats = results[0][0]

        # In debug mode, check the allreduce successfully synced the weights.
        if logger.isEnabledFor(logging.DEBUG):
            weights = ray.get([
                w.get_weights.remote() for w in self.workers.remote_workers()
            ])
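            # Sum every parameter tensor into one scalar per worker; if the
            # replicas are truly in sync, these fingerprints must all match.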
            sums = []
            for w in weights:
                acc = 0
                for k, v in p.items():
                        acc += v.sum()
                sums.append(float(acc))
            logger.debug("The worker weight sums are {}".format(sums))
            assert len(set(sums)) == 1, sums

        # Sync down the weights. As with the sync up, this is not really
        # needed unless the user is reading the local weights.
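        # Any remote worker would do as the source since the replicas are kept
        # in sync; the first one is used for convenience.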
        if self.keep_local_weights_in_sync:
            with self.sync_down_timer:
                self.workers.local_worker().set_weights(
                    ray.get(
                        self.workers.remote_workers()[0].get_weights.remote()))
        return self.learner_stats

    @override(PolicyOptimizer)
    def stats(self):
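        # Timer means are tracked in seconds and reported here in
        # milliseconds, alongside the latest learner info from the first
        # remote worker.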
        return dict(
            PolicyOptimizer.stats(self), **{
                "sync_weights_up_time": round(
                    1000 * self.sync_up_timer.mean, 3),
                "sync_weights_down_time": round(
                    1000 * self.sync_down_timer.mean, 3),
                "learn_time_ms": round(1000 * self.learn_timer.mean, 3),
                "learner": self.learner_stats,
            })