# ray/rllib/policy/torch_policy.py

# 951 lines
# 39 KiB
# Python

import copy
import functools
import gym
import logging
import numpy as np
import os
import time
import threading
import tree # pip install dm_tree
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
import ray
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.utils import force_list, NullContextManager
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
convert_to_torch_tensor
from ray.rllib.utils.typing import ModelGradients, ModelWeights, TensorType, \
TrainerConfigDict
# Import torch through RLlib's framework helper (rather than `import torch`)
# so this module can be imported even where torch is unavailable —
# presumably `torch`/`nn` are placeholders in that case; verify against
# `ray.rllib.utils.framework.try_import_torch`.
torch, nn = try_import_torch()
# Module-level logger (used e.g. to report which device the policy runs on).
logger = logging.getLogger(__name__)
@DeveloperAPI
class TorchPolicy(Policy):
    """Template for a PyTorch policy and loss to use with RLlib.

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        config (dict): config of the policy.
        model (TorchModel): Torch model instance.
        dist_class (type): Torch action distribution class.
    """

    @DeveloperAPI
    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            config: TrainerConfigDict,
            *,
            model: ModelV2,
            loss: Callable[[
                Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
            ], Union[TensorType, List[TensorType]]],
            action_distribution_class: Type[TorchDistributionWrapper],
            action_sampler_fn: Optional[Callable[[
                TensorType, List[TensorType]
            ], Tuple[TensorType, TensorType]]] = None,
            action_distribution_fn: Optional[Callable[[
                Policy, ModelV2, TensorType, TensorType, TensorType
            ], Tuple[TensorType, Type[TorchDistributionWrapper], List[
                TensorType]]]] = None,
            max_seq_len: int = 20,
            get_batch_divisibility_req: Optional[Callable[[Policy],
                                                          int]] = None,
    ):
        """Build a policy from policy and loss torch modules.

        Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
        is set. Only single GPU is supported for now.

        Args:
            observation_space (gym.spaces.Space): observation space of the
                policy.
            action_space (gym.spaces.Space): action space of the policy.
            config (TrainerConfigDict): The Policy config dict.
            model (ModelV2): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (Callable[[Policy, ModelV2, Type[TorchDistributionWrapper],
                SampleBatch], Union[TensorType, List[TensorType]]]): Callable
                that returns a single scalar loss or a list of loss terms.
            action_distribution_class (Type[TorchDistributionWrapper]): Class
                for a torch action distribution.
            action_sampler_fn (Callable[[TensorType, List[TensorType]],
                Tuple[TensorType, TensorType]]): A callable returning a
                sampled action and its log-likelihood given Policy, ModelV2,
                input_dict, explore, timestep, and is_training.
            action_distribution_fn (Optional[Callable[[Policy, ModelV2,
                ModelInputDict, TensorType, TensorType],
                Tuple[TensorType, type, List[TensorType]]]]): A callable
                returning distribution inputs (parameters), a dist-class to
                generate an action distribution object from, and
                internal-state outputs (or an empty list if not applicable).
                Note: No Exploration hooks have to be called from within
                `action_distribution_fn`. It should only perform a simple
                forward pass through some model.
                If None, pass inputs through `self.model()` to get distribution
                inputs.
                The callable takes as inputs: Policy, ModelV2, ModelInputDict,
                explore, timestep, is_training.
            max_seq_len (int): Max sequence length for LSTM training.
            get_batch_divisibility_req (Optional[Callable[[Policy], int]]]):
                Optional callable that returns the divisibility requirement
                for sample batches given the Policy.
        """
        self.framework = "torch"
        super().__init__(observation_space, action_space, config)

        # Log device and worker index.
        from ray.rllib.evaluation.rollout_worker import get_global_worker
        worker = get_global_worker()
        worker_idx = worker.worker_index if worker else 0

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing on
        #   self.device.
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake) GPU or 1 CPU), no
        #   parallelization will be done.
        if config["_fake_gpus"] or config["num_gpus"] == 0 or \
                not torch.cuda.is_available():
            logger.info("TorchPolicy (worker={}) running on {}.".format(
                worker_idx if worker_idx > 0 else "local",
                "{} fake-GPUs".format(config["num_gpus"])
                if config["_fake_gpus"] else "CPU"))
            self.device = torch.device("cpu")
            # `num_gpus` may be fractional (e.g. 0.5), which `range()` cannot
            # take. Use its ceiling: the same number of towers the GPU
            # branch below would admit via `i < num_gpus` (at least 1).
            num_towers = max(1, int(np.ceil(config["num_gpus"] or 0)))
            self.devices = [self.device for _ in range(num_towers)]
            # Tower 0 *is* the main model; the others are independent copies.
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(num_towers)
            ]
            self.model = model
        else:
            logger.info("TorchPolicy (worker={}) running on {} GPU(s).".format(
                worker_idx if worker_idx > 0 else "local", config["num_gpus"]))
            gpu_ids = ray.get_gpu_ids()
            # Only the enumeration index is used: device "cuda:i" for each of
            # the first `num_gpus` GPU ids assigned by ray.
            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, _ in enumerate(gpu_ids) if i < config["num_gpus"]
            ]
            self.device = self.devices[0]
            self.model_gpu_towers = [
                copy.deepcopy(model).to(device) for device in self.devices
            ]
            self.model = self.model_gpu_towers[0]

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        self._is_recurrent = len(self._state_inputs) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        self._loss = loss
        self._optimizers = force_list(self.optimizer())

        # Store, which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to set or param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            self.multi_gpu_param_groups.append({
                main_params[p]
                for pg in o.param_groups for p in pg["params"]
            })

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        self.batch_divisibility_req = get_batch_divisibility_req(self) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)

    @override(Policy)
    @DeveloperAPI
    def compute_actions(
            self,
            obs_batch: Union[List[TensorType], TensorType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Union[List[TensorType], TensorType] = None,
            prev_reward_batch: Union[List[TensorType], TensorType] = None,
            info_batch: Optional[Dict[str, list]] = None,
            episodes: Optional[List["MultiAgentEpisode"]] = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions for a batch of raw observations (no grad).

        Builds a lazy torch-tensor input dict from the given batches and
        delegates to `_compute_action_helper` with all-ones `seq_lens`.
        """
        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict(
                SampleBatch({
                    SampleBatch.CUR_OBS: np.asarray(obs_batch),
                }))
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    np.asarray(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = \
                    np.asarray(prev_reward_batch)
            state_batches = [
                convert_to_torch_tensor(s, self.device)
                for s in (state_batches or [])
            ]
            return self._compute_action_helper(input_dict, state_batches,
                                               seq_lens, explore, timestep)

    @override(Policy)
    def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions from a (trajectory-view API) input dict (no grad).

        Internal states are collected from all keys starting with "state_in";
        `seq_lens` is all-ones if there are any, otherwise None.
        """
        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys()
                if k.startswith("state_in")
            ]
            # Calculate RNN sequence lengths.
            seq_lens = np.array([1] * len(input_dict["obs"])) \
                if state_batches else None

            return self._compute_action_helper(input_dict, state_batches,
                                               seq_lens, explore, timestep)

    @staticmethod
    def _is_signature_type_error(error: TypeError) -> bool:
        """Returns whether `error` looks like a call-signature mismatch.

        Used to detect a legacy-signature `action_distribution_fn`. Uses
        `str(error)` instead of `error.args[0]` so an args-less TypeError
        does not raise an IndexError here.
        """
        message = str(error)
        return "positional argument" in message or \
            "unexpected keyword argument" in message

    @with_lock
    def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                               explore, timestep):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Returns:
            Tuple:
                - actions, state_out, extra_fetches (all converted to
                  non-torch types). logp, if any, is contained in
                  extra_fetches.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep
        self._is_recurrent = state_batches is not None and state_batches != []

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        if self.action_sampler_fn:
            action_dist = dist_inputs = None
            actions, logp, state_out = self.action_sampler_fn(
                self,
                self.model,
                input_dict,
                state_batches,
                explore=explore,
                timestep=timestep)
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(
                explore=explore, timestep=timestep)
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict=input_dict,
                            state_batches=state_batches,
                            seq_lens=seq_lens,
                            explore=explore,
                            timestep=timestep,
                            is_training=False)
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    if not self._is_signature_type_error(e):
                        raise
                    dist_inputs, dist_class, state_out = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.CUR_OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False)
            else:
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches,
                                                    seq_lens)

            # Guard `issubclass` with an `isinstance(..., type)` check: a
            # non-class `dist_class` would otherwise raise a confusing
            # TypeError instead of the explanatory ValueError below.
            is_valid_dist_class = isinstance(
                dist_class, functools.partial) or (
                    isinstance(dist_class, type)
                    and issubclass(dist_class, TorchDistributionWrapper))
            if not is_valid_dist_class:
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(
                        getattr(dist_class, "__name__", dist_class)))
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = \
                self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore)

        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(input_dict, state_batches,
                                              self.model, action_dist)

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = \
                torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        return convert_to_non_torch_type((actions, state_out, extra_fetches))

    @with_lock
    @override(Policy)
    @DeveloperAPI
    def compute_log_likelihoods(
            self,
            actions: Union[List[TensorType], TensorType],
            obs_batch: Union[List[TensorType], TensorType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            prev_reward_batch: Optional[Union[List[
                TensorType], TensorType]] = None) -> TensorType:
        """Computes log-likelihoods of `actions` given the observations.

        Raises:
            ValueError: If only an `action_sampler_fn` (and no
                `action_distribution_fn`) is available, since no action
                distribution can then be built.
        """
        if self.action_sampler_fn and self.action_distribution_fn is None:
            raise ValueError("Cannot compute log-prob/likelihood w/o an "
                             "`action_distribution_fn` and a provided "
                             "`action_sampler_fn`!")

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
                SampleBatch.ACTIONS: actions
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device)
                for s in (state_batches or [])
            ]

            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, _ = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict=input_dict,
                            state_batches=state_batches,
                            seq_lens=seq_lens,
                            explore=False,
                            is_training=False)
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    if not self._is_signature_type_error(e):
                        raise
                    # Call positionally, exactly like
                    # `_compute_action_helper` does: the legacy function's
                    # parameter names are not part of its contract.
                    dist_inputs, dist_class, _ = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.CUR_OBS],
                            explore=False,
                            is_training=False)
            # Default action-dist inputs calculation.
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches,
                                            seq_lens)

            action_dist = dist_class(dist_inputs, self.model)
            log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])
            return log_likelihoods

    @with_lock
    @override(Policy)
    @DeveloperAPI
    def learn_on_batch(
            self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        """Computes gradients on the batch and steps all optimizers.

        Returns:
            Dict[str, TensorType]: The fetches from `compute_gradients`, plus
                "model" (model metrics) and "custom_metrics" (stats written
                by the `on_learn_on_batch` callback).
        """
        # Set Model to train mode.
        if self.model:
            self.model.train()
        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats)

        # Compute gradients (will calculate all losses and `backward()`
        # them to get the grads). The grads are already set in-place on the
        # main model's params, so only the fetches are needed here.
        _, fetches = self.compute_gradients(postprocessed_batch)

        # Step the optimizers.
        self.apply_gradients(_directStepOptimizerSingleton)

        if self.model:
            fetches["model"] = self.model.metrics()
        fetches.update({"custom_metrics": learn_stats})

        return fetches

    @with_lock
    @override(Policy)
    @DeveloperAPI
    def compute_gradients(self,
                          postprocessed_batch: SampleBatch) -> ModelGradients:
        """Computes (tower-averaged) gradients for the given train batch.

        Pads the batch if needed, splits it over `self.devices`, runs the
        loss on each tower and, for >1 towers, mean-reduces the per-tower
        grads and writes them into the main model's `.grad` fields.

        Returns:
            Tuple[List[Optional[TensorType]], dict]: The grads (one per main
                model parameter, None where no grad exists) and a fetches
                dict containing the learner stats under LEARNER_STATS_KEY.
        """
        if not isinstance(postprocessed_batch, SampleBatch) or \
                not postprocessed_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                postprocessed_batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

        # Mark the batch as "is_training" so the Model can use this
        # information.
        postprocessed_batch.is_training = True

        # Single device case: Use batch as-is (no slicing).
        if len(self.devices) == 1:
            batches = [self._lazy_tensor_dict(postprocessed_batch)]
        # Multi-GPU case: Slice inputs into n (roughly) equal batches.
        else:
            remaining = len(postprocessed_batch)
            batches = []
            start = 0
            for i, device in enumerate(self.devices):
                # Divide what's left evenly over the remaining devices.
                shard_len = remaining // (len(self.devices) - i)
                batches.append(
                    self._lazy_tensor_dict(
                        postprocessed_batch.slice(start, start + shard_len),
                        device=device))
                remaining -= shard_len
                start += shard_len

            # Copy weights of main model to all towers.
            state_dict = self.model.state_dict()
            for tower in self.model_gpu_towers:
                tower.load_state_dict(state_dict)

        # Do the (maybe parallelized) gradient calculation step.
        tower_outputs = self._multi_gpu_parallel_grad_calc(batches)

        # Multi device (GPU) case.
        if len(self.devices) > 1:
            # Mean-reduce over GPU-towers.
            all_grads = []
            for i, tower0_grad in enumerate(tower_outputs[0][0]):
                if tower0_grad is None:
                    all_grads.append(None)
                    continue
                all_grads.append(
                    torch.mean(
                        torch.stack([
                            t[0][i].to(self.device) for t in tower_outputs
                        ]),
                        dim=0))
            # Set main model's grads to mean-reduced values.
            for i, p in enumerate(self.model.parameters()):
                p.grad = all_grads[i]
            # Reduce stats over towers as well.
            from ray.rllib.execution.train_ops import all_tower_reduce
            grad_info = tree.map_structure_with_path(
                lambda p, *t: all_tower_reduce(p, *t),
                *[t[1] for t in tower_outputs])
        # Single device case.
        else:
            all_grads, grad_info = tower_outputs[0]

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(postprocessed_batch))

        fetches = self.extra_compute_grad_fetches()

        return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})

    @override(Policy)
    @DeveloperAPI
    def apply_gradients(self, gradients: ModelGradients) -> None:
        """Steps the optimizer(s).

        `_directStepOptimizerSingleton` means: the grads are already set
        in-place, so just step all optimizers. Otherwise `gradients` (tensors
        or numpy arrays, one per main-model param) are copied into the
        params' `.grad` first (single optimizer only).
        """
        if gradients == _directStepOptimizerSingleton:
            for opt in self._optimizers:
                opt.step()
        else:
            # TODO(sven): Not supported for multiple optimizers yet.
            assert len(self._optimizers) == 1
            for g, p in zip(gradients, self.model.parameters()):
                if g is not None:
                    if torch.is_tensor(g):
                        p.grad = g.to(self.device)
                    else:
                        p.grad = torch.from_numpy(g).to(self.device)

            self._optimizers[0].step()

    @override(Policy)
    @DeveloperAPI
    def get_weights(self) -> ModelWeights:
        """Returns the main model's state dict as detached CPU numpy arrays."""
        return {
            k: v.cpu().detach().numpy()
            for k, v in self.model.state_dict().items()
        }

    @override(Policy)
    @DeveloperAPI
    def set_weights(self, weights: ModelWeights) -> None:
        """Loads `weights` (converted to tensors on self.device) into model."""
        weights = convert_to_torch_tensor(weights, device=self.device)
        self.model.load_state_dict(weights)

    @override(Policy)
    @DeveloperAPI
    def is_recurrent(self) -> bool:
        return self._is_recurrent

    @override(Policy)
    @DeveloperAPI
    def num_state_tensors(self) -> int:
        return len(self.model.get_initial_state())

    @override(Policy)
    @DeveloperAPI
    def get_initial_state(self) -> List[TensorType]:
        """Returns the model's initial internal states as numpy arrays."""
        return [
            s.detach().cpu().numpy() for s in self.model.get_initial_state()
        ]

    @override(Policy)
    @DeveloperAPI
    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
        """Returns the base Policy state plus all optimizers' state dicts."""
        state = super().get_state()
        state["_optimizer_variables"] = [
            convert_to_non_torch_type(o.state_dict())
            for o in self._optimizers
        ]
        return state

    @override(Policy)
    @DeveloperAPI
    def set_state(self, state: object) -> None:
        """Restores optimizer states (if present), then the base state."""
        state = state.copy()  # shallow copy
        # Set optimizer vars first.
        optimizer_vars = state.pop("_optimizer_variables", None)
        if optimizer_vars:
            assert len(optimizer_vars) == len(self._optimizers)
            for o, s in zip(self._optimizers, optimizer_vars):
                optim_state_dict = convert_to_torch_tensor(
                    s, device=self.device)
                o.load_state_dict(optim_state_dict)
        # Then the Policy's (NN) weights.
        super().set_state(state)

    @DeveloperAPI
    def extra_grad_process(self, optimizer: "torch.optim.Optimizer",
                           loss: TensorType):
        """Called after each optimizer.zero_grad() + loss.backward() call.

        Called for each self._optimizers/loss-value pair.
        Allows for gradient processing before optimizer.step() is called.
        E.g. for gradient clipping.

        Args:
            optimizer (torch.optim.Optimizer): A torch optimizer object.
            loss (TensorType): The loss tensor associated with the optimizer.

        Returns:
            Dict[str, TensorType]: An dict with information on the gradient
                processing step.
        """
        return {}

    @DeveloperAPI
    def extra_compute_grad_fetches(self) -> Dict[str, any]:
        """Extra values to fetch and return from compute_gradients().

        Returns:
            Dict[str, any]: Extra fetch dict to be added to the fetch dict
                of the compute_gradients call.
        """
        return {LEARNER_STATS_KEY: {}}  # e.g, stats, td error, etc.

    @DeveloperAPI
    def extra_action_out(
            self, input_dict: Dict[str, TensorType],
            state_batches: List[TensorType], model: TorchModelV2,
            action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
        """Returns dict of extra info to include in experience batch.

        Args:
            input_dict (Dict[str, TensorType]): Dict of model input tensors.
            state_batches (List[TensorType]): List of state tensors.
            model (TorchModelV2): Reference to the model object.
            action_dist (TorchDistributionWrapper): Torch action dist object
                to get log-probs (e.g. for already sampled actions).

        Returns:
            Dict[str, TensorType]: Extra outputs to return in a
                compute_actions() call (3rd return value).
        """
        return {}

    @DeveloperAPI
    def extra_grad_info(self,
                        train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Return dict of extra grad info.

        Args:
            train_batch (SampleBatch): The training batch for which to produce
                extra grad info for.

        Returns:
            Dict[str, TensorType]: The info dict carrying grad info per str
                key.
        """
        return {}

    @DeveloperAPI
    def optimizer(
            self
    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
        """Custom the local PyTorch optimizer(s) to use.

        Returns:
            Union[List[torch.optim.Optimizer], torch.optim.Optimizer]:
                The local PyTorch optimizer(s) to use for this Policy.
        """
        if hasattr(self, "config"):
            return torch.optim.Adam(
                self.model.parameters(), lr=self.config["lr"])
        else:
            return torch.optim.Adam(self.model.parameters())

    @override(Policy)
    @DeveloperAPI
    def export_model(self, export_dir: str) -> None:
        """Exports the Policy's Model to local directory for serving.

        Creates a TorchScript model and saves it.

        Args:
            export_dir (str): Local writable directory or filename.
        """
        # Installs the torch-tensor get-interceptor on the dummy batch
        # (in-place), so the lookups below return torch tensors.
        self._lazy_tensor_dict(self._dummy_batch)
        # Provide dummy state inputs if not an RNN (torch cannot jit with
        # returned empty internal states list).
        if "state_in_0" not in self._dummy_batch:
            self._dummy_batch["state_in_0"] = \
                self._dummy_batch["seq_lens"] = np.array([1.0])
        seq_lens = self._dummy_batch["seq_lens"]

        state_ins = []
        i = 0
        while "state_in_{}".format(i) in self._dummy_batch:
            state_ins.append(self._dummy_batch["state_in_{}".format(i)])
            i += 1
        dummy_inputs = {
            k: self._dummy_batch[k]
            for k in self._dummy_batch.keys() if k != "is_training"
        }

        traced = torch.jit.trace(self.model,
                                 (dummy_inputs, state_ins, seq_lens))
        # `exist_ok` avoids the check-then-create race of an exists() test.
        os.makedirs(export_dir, exist_ok=True)
        file_name = os.path.join(export_dir, "model.pt")
        traced.save(file_name)

    @override(Policy)
    @DeveloperAPI
    def export_checkpoint(self, export_dir: str) -> None:
        """TODO(sven): implement for torch.
        """
        raise NotImplementedError

    @override(Policy)
    @DeveloperAPI
    def import_model_from_h5(self, import_file: str) -> None:
        """Imports weights into torch model."""
        return self.model.import_from_h5(import_file)

    def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
        """Makes a SampleBatch convert its values to torch tensors on access.

        Plain dicts are first wrapped into a SampleBatch. Values are moved to
        `device` (default: self.device) lazily, via a get-interceptor.
        """
        # TODO: (sven): Keep for a while to ensure backward compatibility.
        if not isinstance(postprocessed_batch, SampleBatch):
            postprocessed_batch = SampleBatch(postprocessed_batch)
        postprocessed_batch.set_get_interceptor(
            functools.partial(
                convert_to_torch_tensor, device=device or self.device))
        return postprocessed_batch

    def _multi_gpu_parallel_grad_calc(self, sample_batches):
        """Performs a parallelized loss and gradient calculation over the batch.

        Splits up the given train batch into n shards (n=number of this
        Policy's devices) and passes each data shard (in parallel) through
        the loss function using the individual devices' models
        (self.model_gpu_towers). Then returns each tower's outputs.

        Args:
            sample_batches (List[SampleBatch]): A list of SampleBatch shards to
                calculate loss and gradients for.

        Returns:
            List[Tuple[List[TensorType], StatsDict]]: A list (one item per
                device) of 2-tuples with 1) gradient list and 2) stats dict.

        Raises:
            ValueError: If any tower's loss/grad computation fails. The
                message names the tower and device; the original exception
                is attached as `__cause__`.
        """
        assert len(self.model_gpu_towers) == len(sample_batches)
        lock = threading.Lock()
        results = {}
        # Grad mode is thread-local in torch: propagate the caller's setting.
        grad_enabled = torch.is_grad_enabled()

        def _worker(shard_idx, model, sample_batch, device):
            torch.set_grad_enabled(grad_enabled)
            try:
                with NullContextManager(
                ) if device.type == "cpu" else torch.cuda.device(device):
                    loss_out = force_list(
                        self._loss(self, model, self.dist_class, sample_batch))

                    # Call Model's custom-loss with Policy loss outputs and
                    # train_batch.
                    loss_out = model.custom_loss(loss_out, sample_batch)

                    assert len(loss_out) == len(self._optimizers)

                    # Loop through all optimizers.
                    grad_info = {"allreduce_latency": 0.0}

                    parameters = list(model.parameters())
                    all_grads = [None for _ in range(len(parameters))]
                    for opt_idx, opt in enumerate(self._optimizers):
                        # Erase gradients in all vars of the tower that this
                        # optimizer would affect.
                        param_indices = self.multi_gpu_param_groups[opt_idx]
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices and \
                                    param.grad is not None:
                                param.grad.data.zero_()
                        # Recompute gradients of loss over all variables.
                        loss_out[opt_idx].backward(retain_graph=True)
                        grad_info.update(
                            self.extra_grad_process(opt, loss_out[opt_idx]))

                        grads = []
                        # Note that return values are just references;
                        # Calling zero_grad would modify the values.
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices:
                                if param.grad is not None:
                                    grads.append(param.grad)
                                all_grads[param_idx] = param.grad

                        if self.distributed_world_size:
                            start = time.time()
                            if torch.cuda.is_available():
                                # Sadly, allreduce_coalesced does not work with
                                # CUDA yet.
                                for g in grads:
                                    torch.distributed.all_reduce(
                                        g, op=torch.distributed.ReduceOp.SUM)
                            else:
                                torch.distributed.all_reduce_coalesced(
                                    grads, op=torch.distributed.ReduceOp.SUM)

                            for param_group in opt.param_groups:
                                for p in param_group["params"]:
                                    if p.grad is not None:
                                        p.grad /= self.distributed_world_size

                            grad_info[
                                "allreduce_latency"] += time.time() - start

                with lock:
                    results[shard_idx] = (all_grads, grad_info)
            except Exception as e:
                # Build the message defensively: `e.args` may be empty or
                # hold a non-str (e.g. `KeyError(3)`), in which case
                # `e.args[0] + "\n"` would itself raise in this handler and
                # (in the threaded case) leave `results[shard_idx]` unset.
                msg = e.args[0] if e.args else repr(e)
                error = ValueError("{}\n".format(msg) +
                                   "In tower {} on device {}".format(
                                       shard_idx, device))
                # Keep the original exception (and traceback) reachable.
                error.__cause__ = e
                with lock:
                    results[shard_idx] = error

        # Single device (GPU) case.
        if len(self.devices) == 1:
            _worker(0, self.model, sample_batches[0], self.device)
            if isinstance(results[0], Exception):
                raise results[0]
            return [results[0]]

        # Multi device (GPU) case: Parallelize via threads.
        threads = [
            threading.Thread(
                target=_worker,
                args=(shard_idx, model, sample_batch, device))
            for shard_idx, (model, sample_batch, device) in enumerate(
                zip(self.model_gpu_towers, sample_batches, self.devices))
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Gather all threads' outputs and return.
        outputs = []
        for shard_idx in range(len(sample_batches)):
            output = results[shard_idx]
            if isinstance(output, Exception):
                raise output
            outputs.append(output)
        return outputs
# TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch)
# and for all possible hyperparams, not just lr.
@DeveloperAPI
class LearningRateSchedule:
    """Mixin for TorchPolicy that adds a learning rate schedule.

    If `lr_schedule` is given (a list of [timestep, lr] pairs, interpolated
    piecewise and held at the last value afterwards), `cur_lr` follows it and
    is pushed into all optimizers' param groups on every global-var update.
    Otherwise `cur_lr` stays at the constant `lr`.
    """

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        self._lr_schedule = None
        if lr_schedule is None:
            self.cur_lr = lr
        else:
            self._lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
            self.cur_lr = self._lr_schedule.value(0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._lr_schedule is not None:
            self.cur_lr = self._lr_schedule.value(global_vars["timestep"])
            # Apply the new lr to every param group of every optimizer.
            for opt in self._optimizers:
                for p in opt.param_groups:
                    p["lr"] = self.cur_lr
@DeveloperAPI
class EntropyCoeffSchedule:
    """Mixin for TorchPolicy that adds entropy coeff decay.

    `entropy_coeff_schedule` may be:
      - None: constant `entropy_coeff`.
      - A list (or tuple) of [timestep, value] pairs: piecewise schedule (same
        format as `lr_schedule`), held at the last value afterwards.
      - Anything else (legacy form): treated as the timestep at which the
        coefficient has linearly decayed from `entropy_coeff` to 0.0.
    """

    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        if entropy_coeff_schedule is None:
            self.entropy_coeff_schedule = ConstantSchedule(
                entropy_coeff, framework=None)
        elif isinstance(entropy_coeff_schedule, (list, tuple)):
            # Allows for custom schedule similar to lr_schedule format
            self.entropy_coeff_schedule = PiecewiseSchedule(
                entropy_coeff_schedule,
                outside_value=entropy_coeff_schedule[-1][-1],
                framework=None)
        else:
            # Implements previous version but enforces outside_value
            self.entropy_coeff_schedule = PiecewiseSchedule(
                [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                outside_value=0.0,
                framework=None)
        # Start at the schedule's t=0 value (as LearningRateSchedule does), so
        # `entropy_coeff` agrees with the schedule even before the first
        # global-var update. For the constant and legacy forms this equals
        # `entropy_coeff`.
        self.entropy_coeff = self.entropy_coeff_schedule.value(0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        self.entropy_coeff = self.entropy_coeff_schedule.value(
            global_vars["timestep"])
@DeveloperAPI
class DirectStepOptimizer:
    """Typesafe method for indicating apply gradients can directly step the
    optimizers with in-place gradients.

    Process-wide singleton: every instantiation returns the same object.
    """
    _instance = None

    def __new__(cls):
        if DirectStepOptimizer._instance is None:
            DirectStepOptimizer._instance = super().__new__(cls)
        return DirectStepOptimizer._instance

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        # Defining `__eq__` alone sets `__hash__` to None (unhashable);
        # hash consistently with `__eq__`, which compares by type.
        return hash(type(self))

    def __repr__(self):
        return "DirectStepOptimizer"
# Module-level sentinel passed to `TorchPolicy.apply_gradients` to mean
# "grads are already set in-place; just step the optimizers".
_directStepOptimizerSingleton = DirectStepOptimizer()