ray/rllib/execution/train_ops.py

563 lines
22 KiB
Python
Raw Normal View History

import logging
import numpy as np
import math
from typing import Dict, List, Tuple, Any
import ray
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
AGENT_STEPS_TRAINED_COUNTER,
APPLY_GRADS_TIMER,
COMPUTE_GRADS_TIMER,
LAST_TARGET_UPDATE_TS,
LEARN_ON_BATCH_TIMER,
LOAD_BATCH_TIMER,
NUM_TARGET_UPDATES,
STEPS_SAMPLED_COUNTER,
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
_check_sample_batch_type,
_get_global_vars,
_get_shared_metrics,
)
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import (
NUM_ENV_STEPS_TRAINED,
NUM_AGENT_STEPS_TRAINED,
SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, LEARNER_INFO
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients
# TensorFlow module handles returned by RLlib's `try_import_tf` helper. None of
# the three names is referenced in the code visible in this module section.
# NOTE(review): presumably these are placeholders when TF is not installed —
# confirm against `ray.rllib.utils.framework.try_import_tf`.
tf1, tf, tfv = try_import_tf()
# Module-level logger used by the training ops below.
logger = logging.getLogger(__name__)
@DeveloperAPI
def train_one_step(algorithm, train_batch, policies_to_train=None) -> Dict:
    """Improves all trainable policies in `train_batch` on the local worker.

    If `config["num_sgd_iter"] > 1` or `config["sgd_minibatch_size"] > 0`, the
    train batch is split into minibatches of size `sgd_minibatch_size` and
    iterated over `num_sgd_iter` times via `do_minibatch_sgd`. Otherwise, a
    single `learn_on_batch` update step is performed on the whole batch.

    Examples:
        >>> from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
        >>> algo = [...] # doctest: +SKIP
        >>> train_batch = synchronous_parallel_sample(algo.workers) # doctest: +SKIP
        >>> # This trains the policy on one batch.
        >>> results = train_one_step(algo, train_batch) # doctest: +SKIP
        {"default_policy": ...}

    Updates the NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED counters as well as
    the LEARN_ON_BATCH_TIMER timer of the `algorithm` object.

    Args:
        algorithm: The Algorithm instance; its `config`, `workers`, `_timers`
            and `_counters` attributes are used.
        train_batch: The batch to train on.
        policies_to_train: Optional collection of policy IDs to update in the
            minibatch-SGD path. If None or empty, the local worker's
            `get_policies_to_train(train_batch)` is used instead.
            NOTE: The single-step `learn_on_batch` path does not receive this
            argument.

    Returns:
        The learner info dict returned by `do_minibatch_sgd` or
        `learn_on_batch`.
    """
    config = algorithm.config
    local_worker = algorithm.workers.local_worker()
    num_sgd_iter = config.get("num_sgd_iter", 1)
    sgd_minibatch_size = config.get("sgd_minibatch_size", 0)

    learn_timer = algorithm._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        if num_sgd_iter > 1 or sgd_minibatch_size > 0:
            # Subsample minibatches (size=`sgd_minibatch_size`) from the
            # train batch and loop through train batch `num_sgd_iter` times.
            policy_ids = (
                policies_to_train or local_worker.get_policies_to_train(train_batch)
            )
            info = do_minibatch_sgd(
                train_batch,
                {pid: local_worker.get_policy(pid) for pid in policy_ids},
                local_worker,
                num_sgd_iter,
                sgd_minibatch_size,
                [],
            )
        else:
            # Single update step using the whole train batch.
            info = local_worker.learn_on_batch(train_batch)

    learn_timer.push_units_processed(train_batch.count)
    algorithm._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    algorithm._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    return info
@DeveloperAPI
def multi_gpu_train_one_step(algorithm, train_batch) -> Dict:
    """Multi-GPU version of train_one_step.

    Uses the policies' `load_batch_into_buffer` and `learn_on_loaded_batch` methods
    to be more efficient wrt CPU/GPU data transfers. For example, when doing multiple
    passes through a train batch (e.g. for PPO) using `config.num_sgd_iter`, the
    actual train batch is only split once and loaded once into the GPU(s).

    Examples:
        >>> from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
        >>> algo = [...] # doctest: +SKIP
        >>> train_batch = synchronous_parallel_sample(algo.workers) # doctest: +SKIP
        >>> # This trains the policy on one batch.
        >>> results = multi_gpu_train_one_step(algo, train_batch) # doctest: +SKIP
        {"default_policy": ...}

    Updates the NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED counters as well as
    the LOAD_BATCH_TIMER and LEARN_ON_BATCH_TIMER timers of the Algorithm instance.

    Args:
        algorithm: The Algorithm instance. Reads `config` ("num_sgd_iter",
            "sgd_minibatch_size", "train_batch_size", "num_gpus") and uses
            `workers`, `_timers` and `_counters`.
        train_batch: The batch to train on. It is converted with
            `as_multi_agent()` before loading.

    Returns:
        The finalized learner info built by `LearnerInfoBuilder.finalize()`.

    Raises:
        AssertionError: If `sgd_minibatch_size` is smaller than the number of
            devices (i.e. the per-device batch size would be 0).
    """
    config = algorithm.config
    local_worker = algorithm.workers.local_worker()
    num_sgd_iter = config.get("num_sgd_iter", 1)
    sgd_minibatch_size = config.get("sgd_minibatch_size", config["train_batch_size"])

    # Determine the number of devices (GPUs, or a single CPU if `num_gpus` is
    # 0/None) we use.
    num_devices = int(math.ceil(config["num_gpus"] or 1))

    # Per-tower batch size: the minibatch is divided evenly across the devices
    # (floor division, so any remainder of `sgd_minibatch_size` is dropped).
    per_device_batch_size = sgd_minibatch_size // num_devices
    assert per_device_batch_size >= 1, "Batch size too small!"

    # Handle everything as if multi-agent.
    train_batch = train_batch.as_multi_agent()

    # Load data into GPUs.
    load_timer = algorithm._timers[LOAD_BATCH_TIMER]
    with load_timer:
        num_loaded_samples = {}
        for policy_id, batch in train_batch.policy_batches.items():
            # Not a policy-to-train.
            if not local_worker.is_policy_to_train(policy_id, train_batch):
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have >1 buffers, if we are training
            # asynchronously.
            num_loaded_samples[policy_id] = local_worker.policy_map[
                policy_id
            ].load_batch_into_buffer(batch, buffer_index=0)

    # Execute minibatch SGD on loaded data.
    learn_timer = algorithm._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Use LearnerInfoBuilder as a unified way to build the final
        # results dict from `learn_on_loaded_batch` call(s).
        # This makes sure results dicts always have the same structure
        # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
        # tf vs torch).
        learner_info_builder = LearnerInfoBuilder(num_devices=num_devices)

        for policy_id, samples_per_device in num_loaded_samples.items():
            policy = local_worker.policy_map[policy_id]
            num_batches = max(1, int(samples_per_device) // int(per_device_batch_size))
            logger.debug("== sgd epochs for %s ==", policy_id)
            for _ in range(num_sgd_iter):
                # Visit the minibatches in a fresh random order each epoch.
                for minibatch_index in np.random.permutation(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    results = policy.learn_on_loaded_batch(
                        minibatch_index * per_device_batch_size, buffer_index=0
                    )
                    learner_info_builder.add_learn_on_batch_results(results, policy_id)

        # Tower reduce and finalize results.
        learner_info = learner_info_builder.finalize()

    load_timer.push_units_processed(train_batch.count)
    learn_timer.push_units_processed(train_batch.count)

    # TODO: Move this into Trainer's `training_iteration` method for
    # better transparency.
    algorithm._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    algorithm._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    return learner_info
class TrainOneStep:
    """Callable that trains the local worker's policies on one batch.

    This should be used with the .for_each() operator. After learning, the
    updated weights are broadcast to all remote workers (if there are any)
    and the global vars of the local worker are updated.

    Updates the STEPS_TRAINED_COUNTER and STEPS_TRAINED_THIS_ITER_COUNTER
    counters, the AGENT_STEPS_TRAINED_COUNTER counter (only for
    MultiAgentBatch inputs), the LEARN_ON_BATCH_TIMER and
    SYNCH_WORKER_WEIGHTS_TIMER timers and the LEARNER_INFO info field in the
    local iterator context.
    """

    def __init__(
        self,
        workers: WorkerSet,
        policies: List[PolicyID] = frozenset([]),
        num_sgd_iter: int = 1,
        sgd_minibatch_size: int = 0,
    ):
        """Creates a TrainOneStep instance.

        Args:
            workers: The WorkerSet whose local worker is trained and whose
                remote workers receive the updated weights.
            policies: IDs of the policies to train (minibatch-SGD path) and to
                sync. If empty, the local worker's
                `get_policies_to_train(batch)` is used.
            num_sgd_iter: Number of passes over each batch. A value > 1 (or a
                `sgd_minibatch_size` > 0) switches to minibatch SGD.
            sgd_minibatch_size: Minibatch size used for minibatch SGD. With the
                defaults (0 and `num_sgd_iter=1`), a single `learn_on_batch`
                step is taken instead.
        """
        self.workers = workers
        self.local_worker = workers.local_worker()
        self.policies = policies
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size

    def __call__(self, batch: SampleBatchType) -> Tuple[SampleBatchType, Dict]:
        """Trains on `batch` and returns it together with the learner info."""
        _check_sample_batch_type(batch)
        metrics = _get_shared_metrics()
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        lw = self.local_worker

        with learn_timer:
            if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
                # Subsample minibatches (size=`sgd_minibatch_size`) from the
                # train batch and loop through train batch `num_sgd_iter` times.
                policy_ids = self.policies or lw.get_policies_to_train(batch)
                learner_info = do_minibatch_sgd(
                    batch,
                    {pid: lw.get_policy(pid) for pid in policy_ids},
                    lw,
                    self.num_sgd_iter,
                    self.sgd_minibatch_size,
                    [],
                )
            else:
                # Single update step using train batch.
                learner_info = lw.learn_on_batch(batch)

            metrics.info[LEARNER_INFO] = learner_info
            learn_timer.push_units_processed(batch.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += batch.agent_steps()

        # Computed once and shared: nothing between here and the local
        # `set_global_vars` call modifies the counters these are derived from.
        global_vars = _get_global_vars()

        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.remote_workers():
            with metrics.timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                weights = ray.put(
                    lw.get_weights(self.policies or lw.get_policies_to_train(batch))
                )
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, global_vars)

        # Also update global vars of the local worker.
        lw.set_global_vars(global_vars)
        return batch, learner_info
class MultiGPUTrainOneStep:
    """Multi-GPU version of TrainOneStep.

    This should be used with the .for_each() operator. Each trainable policy's
    batch is loaded once into the policy's buffer (`load_batch_into_buffer`).
    Then `num_sgd_iter` epochs of minibatch SGD are run on the pre-loaded data
    via `learn_on_loaded_batch`, visiting the minibatches in a freshly shuffled
    order each epoch. Afterwards, the updated weights are broadcast to all
    remote workers (if there are any).

    Updates the STEPS_TRAINED_COUNTER, STEPS_TRAINED_THIS_ITER_COUNTER and
    AGENT_STEPS_TRAINED_COUNTER counters, the LOAD_BATCH_TIMER,
    LEARN_ON_BATCH_TIMER and SYNCH_WORKER_WEIGHTS_TIMER timers and the
    LEARNER_INFO info field in the local iterator context.
    """

    def __init__(
        self,
        *,
        workers: WorkerSet,
        sgd_minibatch_size: int,
        num_sgd_iter: int,
        num_gpus: int,
        _fake_gpus: bool = False,
        # Deprecated args.
        shuffle_sequences=DEPRECATED_VALUE,
        framework=DEPRECATED_VALUE
    ):
        """Creates a MultiGPUTrainOneStep instance.

        Args:
            workers: The WorkerSet whose local worker is trained and whose
                remote workers receive the updated weights.
            sgd_minibatch_size: Total minibatch size, split evenly across the
                devices (floor division).
            num_sgd_iter: Number of epochs over each loaded batch.
            num_gpus: Number of GPUs; fractional values are rounded up. If
                0/None, a single fake CPU device is used.
            _fake_gpus: If True, use CPU device names ("/cpu:0") instead of
                GPU ones.
            shuffle_sequences: Deprecated; only triggers a deprecation warning.
            framework: Deprecated; only triggers a deprecation warning.

        Raises:
            AssertionError: If `sgd_minibatch_size` is smaller than the number
                of devices.
        """
        if framework != DEPRECATED_VALUE or shuffle_sequences != DEPRECATED_VALUE:
            deprecation_warning(
                old="MultiGPUTrainOneStep(framework=..., shuffle_sequences=...)",
                error=False,
            )

        self.workers = workers
        self.local_worker = workers.local_worker()
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size
        self.shuffle_sequences = shuffle_sequences

        # Collect actual GPU devices to use. Without GPUs, fall back to one
        # fake (CPU) device.
        if not num_gpus:
            _fake_gpus = True
            num_gpus = 1
        type_ = "cpu" if _fake_gpus else "gpu"
        self.devices = [
            "/{}:{}".format(type_, 0 if _fake_gpus else i)
            for i in range(int(math.ceil(num_gpus)))
        ]

        # Batch size per tower: the minibatch is divided evenly across the
        # devices (floor division, so any remainder is dropped).
        self.per_device_batch_size = sgd_minibatch_size // len(self.devices)
        # Total batch size (by construction a multiple of the device count).
        self.batch_size = self.per_device_batch_size * len(self.devices)
        assert self.batch_size >= len(self.devices), "Batch size too small!"

    def __call__(self, samples: SampleBatchType) -> Tuple[SampleBatchType, Dict]:
        """Trains on `samples` and returns them together with the learner info."""
        _check_sample_batch_type(samples)

        # Handle everything as if multi agent.
        samples = samples.as_multi_agent()

        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
        lw = self.local_worker

        # Load data into GPUs.
        with load_timer:
            num_loaded_samples = {}
            for policy_id, batch in samples.policy_batches.items():
                # Not a policy-to-train.
                if not lw.is_policy_to_train(policy_id, samples):
                    continue

                # Decompress SampleBatch, in case some columns are compressed.
                batch.decompress_if_needed()

                # Load the entire train batch into the Policy's only buffer
                # (idx=0). Policies only have >1 buffers, if we are training
                # asynchronously.
                num_loaded_samples[policy_id] = lw.policy_map[
                    policy_id
                ].load_batch_into_buffer(batch, buffer_index=0)

        # Execute minibatch SGD on loaded data.
        with learn_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            learner_info_builder = LearnerInfoBuilder(num_devices=len(self.devices))
            for policy_id, samples_per_device in num_loaded_samples.items():
                policy = lw.policy_map[policy_id]
                num_batches = max(
                    1, int(samples_per_device) // int(self.per_device_batch_size)
                )
                logger.debug("== sgd epochs for %s ==", policy_id)
                for _ in range(self.num_sgd_iter):
                    # Visit the minibatches in a fresh random order each epoch.
                    for minibatch_index in np.random.permutation(num_batches):
                        # Learn on the pre-loaded data in the buffer.
                        # Note: For minibatch SGD, the data is an offset into
                        # the pre-loaded entire train batch.
                        results = policy.learn_on_loaded_batch(
                            minibatch_index * self.per_device_batch_size,
                            buffer_index=0,
                        )
                        learner_info_builder.add_learn_on_batch_results(
                            results, policy_id
                        )

            # Tower reduce and finalize results.
            learner_info = learner_info_builder.finalize()

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
        metrics.counters[AGENT_STEPS_TRAINED_COUNTER] += samples.agent_steps()
        metrics.info[LEARNER_INFO] = learner_info

        # Computed once and shared: nothing between here and the local
        # `set_global_vars` call modifies the counters these are derived from.
        global_vars = _get_global_vars()

        if self.workers.remote_workers():
            with metrics.timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                weights = ray.put(lw.get_weights(lw.get_policies_to_train()))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, global_vars)

        # Also update global vars of the local worker.
        lw.set_global_vars(global_vars)
        return samples, learner_info
# Backward compatibility: `TrainTFMultiGPU` is kept as an alias of
# `MultiGPUTrainOneStep` so code referring to the older name keeps working.
TrainTFMultiGPU = MultiGPUTrainOneStep
class ComputeGradients:
    """Callable that computes gradients with respect to the policy loss.

    Meant to be plugged into the .for_each() operator: every incoming sample
    batch is handed to the local worker's `compute_gradients` (single-agent
    mode). The resulting gradients are returned together with the batch's
    sample count.

    Examples:
        >>> from ray.rllib.execution.train_ops import ComputeGradients
        >>> rollouts, workers = ... # doctest: +SKIP
        >>> grads_op = rollouts.for_each(ComputeGradients(workers)) # doctest: +SKIP
        >>> print(next(grads_op)) # doctest: +SKIP
        {"var_0": ..., ...}, 50 # grads, batch count

    Updates the LEARNER_INFO info field in the local iterator context. The
    computation is timed with the COMPUTE_GRADS_TIMER timer.
    """

    def __init__(self, workers: WorkerSet):
        """Stores the WorkerSet whose local worker computes the gradients."""
        self.workers = workers

    def __call__(self, samples: SampleBatchType) -> Tuple[ModelGradients, int]:
        """Returns (gradients, samples.count) for the given batch."""
        _check_sample_batch_type(samples)
        shared_metrics = _get_shared_metrics()

        with shared_metrics.timers[COMPUTE_GRADS_TIMER]:
            worker = self.workers.local_worker()
            gradients, stats = worker.compute_gradients(samples, single_agent=True)

        # With `single_agent=True`, the returned stats belong to a single
        # policy, so they are reported under the default policy ID.
        shared_metrics.info[LEARNER_INFO] = {DEFAULT_POLICY_ID: stats}
        return gradients, samples.count
class ApplyGradients:
    """Callable that applies gradients and updates workers.

    This should be used with the .for_each() operator.

    Examples:
        >>> from ray.rllib.execution.train_ops import ApplyGradients
        >>> grads_op, workers = ... # doctest: +SKIP
        >>> apply_op = grads_op.for_each(ApplyGradients(workers)) # doctest: +SKIP
        >>> print(next(apply_op)) # doctest: +SKIP
        None

    Updates the STEPS_TRAINED_COUNTER and STEPS_TRAINED_THIS_ITER_COUNTER
    counters and the APPLY_GRADS_TIMER and SYNCH_WORKER_WEIGHTS_TIMER timers
    in the local iterator context.
    """

    def __init__(
        self, workers, policies: List[PolicyID] = frozenset([]), update_all=True
    ):
        """Creates an ApplyGradients instance.

        Args:
            workers: workers to apply gradients to.
            policies: IDs of the policies whose weights are synced after the
                gradients have been applied. If empty, the local worker's
                `get_policies_to_train()` is used.
            update_all: If true, updates all workers. Otherwise, only
                update the worker that produced the sample batch we are
                currently processing (i.e., A3C style).
        """
        self.workers = workers
        self.local_worker = workers.local_worker()
        self.policies = policies
        self.update_all = update_all

    def _get_weights_to_sync(self):
        """Returns the local worker's weights for the policies to be synced."""
        return self.local_worker.get_weights(
            self.policies or self.local_worker.get_policies_to_train()
        )

    def __call__(self, item: Tuple[ModelGradients, int]) -> None:
        """Applies `item`'s gradients locally, then syncs weights.

        Args:
            item: A (gradients, count) tuple, e.g. as produced by
                ComputeGradients.

        Raises:
            ValueError: If `item` is not a 2-tuple, or if `update_all` is False
                and no `current_actor` is set in the iterator context.
        """
        if not isinstance(item, tuple) or len(item) != 2:
            raise ValueError(
                "Input must be a tuple of (grad_dict, count), got {}".format(item)
            )
        gradients, count = item
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_TRAINED_COUNTER] += count
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count

        apply_timer = metrics.timers[APPLY_GRADS_TIMER]
        with apply_timer:
            self.local_worker.apply_gradients(gradients)
            apply_timer.push_units_processed(count)

        # Also update global vars of the local worker.
        self.local_worker.set_global_vars(_get_global_vars())

        if self.update_all:
            if self.workers.remote_workers():
                with metrics.timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                    weights = ray.put(self._get_weights_to_sync())
                    for e in self.workers.remote_workers():
                        e.set_weights.remote(weights, _get_global_vars())
        else:
            # A3C style: only the worker that produced this batch is updated.
            if metrics.current_actor is None:
                raise ValueError(
                    "Could not find actor to update. When "
                    "update_all=False, `current_actor` must be set "
                    "in the iterator context."
                )
            with metrics.timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                weights = self._get_weights_to_sync()
            metrics.current_actor.set_weights.remote(weights, _get_global_vars())
class AverageGradients:
    """Callable that combines a batch of gradients into one.

    This should be used with the .for_each() operator after a set of gradients
    have been batched with .batch(). Gradients are combined position-wise (the
    i-th entries of all gradient lists are added together), and the sample
    counts are summed.

    NOTE(review): Despite the class name (and its log message), the gradients
    are *summed*, not divided by the number of microbatches. This is kept
    unchanged on purpose, because switching to a true mean would change the
    size of the applied update for existing users. Confirm what callers expect
    before changing it.

    Examples:
        >>> from ray.rllib.execution.train_ops import AverageGradients
        >>> grads_op = ... # doctest: +SKIP
        >>> batched_grads = grads_op.batch(32) # doctest: +SKIP
        >>> avg_grads = batched_grads.for_each(AverageGradients()) # doctest: +SKIP
        >>> print(next(avg_grads)) # doctest: +SKIP
        [...], 1600 # summed grads, summed batch count
    """

    def __call__(
        self, gradients: List[Tuple[ModelGradients, int]]
    ) -> Tuple[ModelGradients, int]:
        """Sums `gradients` position-wise and adds up their sample counts.

        Args:
            gradients: List of (gradients, sample_count) tuples, e.g. as
                produced by ComputeGradients and grouped via .batch().

        Returns:
            A (summed_gradients, total_count) tuple. With a single entry, its
            gradients object is returned as-is; with no entries, (None, 0).

        Raises:
            ValueError: If two gradient entries have different lengths (these
                would otherwise be silently truncated by `zip`).
        """
        acc = None
        sum_count = 0
        for grad, count in gradients:
            if acc is None:
                acc = grad
            else:
                if len(acc) != len(grad):
                    raise ValueError(
                        "Cannot combine gradients of different lengths: "
                        "{} vs {}".format(len(acc), len(grad))
                    )
                acc = [a + b for a, b in zip(acc, grad)]
            sum_count += count
        logger.info(
            "Computing average of %s microbatch gradients (%s samples total)",
            len(gradients),
            sum_count,
        )
        return acc, sum_count
class UpdateTargetNetwork:
    """Periodically call policy.update_target() on all trainable policies.

    This should be used with the .for_each() operator after training step
    has been taken.

    Examples:
        >>> from ray.rllib.execution.train_ops import UpdateTargetNetwork
        >>> from ray.rllib.execution import ParallelRollouts, TrainOneStep
        >>> workers = ... # doctest: +SKIP
        >>> train_op = ParallelRollouts(...).for_each( # doctest: +SKIP
        ...     TrainOneStep(...))
        >>> update_op = train_op.for_each( # doctest: +SKIP
        ...     UpdateTargetNetwork(workers, target_update_freq=500))
        >>> print(next(update_op)) # doctest: +SKIP
        None

    Updates the LAST_TARGET_UPDATE_TS and NUM_TARGET_UPDATES counters in the
    local iterator context. The value of the last update counter is used to
    track when we should update the target next.
    """

    def __init__(
        self,
        workers: WorkerSet,
        target_update_freq: int,
        by_steps_trained: bool = False,
        policies: List[PolicyID] = frozenset([]),
    ):
        """Creates an UpdateTargetNetwork instance.

        Args:
            workers: The WorkerSet whose local worker's policies are updated.
            target_update_freq: Minimum number of timesteps (as measured by
                the selected counter) between two target-network updates.
            by_steps_trained: If True, measure timesteps with
                STEPS_TRAINED_COUNTER; otherwise with STEPS_SAMPLED_COUNTER.
            policies: IDs of the policies whose target networks are updated.
                If empty, the local worker's `get_policies_to_train()` is
                used. Only policies also visited by the local worker's
                `foreach_policy_to_train` are updated.
        """
        self.workers = workers
        self.local_worker = workers.local_worker()
        self.target_update_freq = target_update_freq
        self.policies = policies
        if by_steps_trained:
            self.metric = STEPS_TRAINED_COUNTER
        else:
            self.metric = STEPS_SAMPLED_COUNTER

    def __call__(self, _: Any) -> None:
        """Updates target networks once `target_update_freq` steps have passed.

        The argument (the output of the previous operator) is ignored.
        """
        metrics = _get_shared_metrics()
        cur_ts = metrics.counters[self.metric]
        last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
        if cur_ts - last_update < self.target_update_freq:
            return

        to_update = self.policies or self.local_worker.get_policies_to_train()

        def _update_if_selected(policy, policy_id):
            # Explicit `if` instead of `cond and side_effect()`.
            if policy_id in to_update:
                policy.update_target()

        self.local_worker.foreach_policy_to_train(_update_if_selected)
        metrics.counters[NUM_TARGET_UPDATES] += 1
        metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts