mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os

from threading import Lock

try:
    import torch
except ImportError:
    pass  # soft dep

from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.utils.annotations import override
from ray.rllib.utils.tracking_dict import UsageTrackingDict


class TorchPolicy(Policy):
    """Template for a PyTorch policy and loss to use with RLlib.

    This is similar to TFPolicy, but for PyTorch.

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        lock (Lock): Lock that must be held around PyTorch ops on this graph.
            This is necessary when using the async sampler.
    """

    def __init__(self, observation_space, action_space, model, loss,
                 action_distribution_class):
        """Build a policy from policy and loss torch modules.

        Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
        is set. Only single GPU is supported for now.

        Arguments:
            observation_space (gym.Space): observation space of the policy.
            action_space (gym.Space): action space of the policy.
            model (nn.Module): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (func): Function that takes (policy, batch_tensors)
                and returns a single scalar loss.
            action_distribution_class (ActionDistribution): Class for action
                distribution.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.lock = Lock()
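        # Place the model on GPU only when CUDA_VISIBLE_DEVICES is set to a
        # non-empty value; otherwise fall back to CPU (single GPU only).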
        self.device = (torch.device("cuda")
                       if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
                       else torch.device("cpu"))
        self._model = model.to(self.device)
        self._loss = loss
        self._optimizer = self.optimizer()
        self._action_dist_class = action_distribution_class

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        with self.lock:
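            # Inference only: run the model without tracking gradients and
            # lazily convert the numpy inputs to tensors on self.device.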
            with torch.no_grad():
                input_dict = self._lazy_tensor_dict({
                    "obs": obs_batch,
                })
                if prev_action_batch:
                    input_dict["prev_actions"] = prev_action_batch
                if prev_reward_batch:
                    input_dict["prev_rewards"] = prev_reward_batch
                model_out = self._model(input_dict, state_batches, [1])
                logits, state = model_out
                action_dist = self._action_dist_class(logits, self._model)
                actions = action_dist.sample()
                return (actions.cpu().numpy(),
                        [h.cpu().numpy() for h in state],
                        self.extra_action_out(input_dict, state_batches,
                                              self._model))

    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        batch_tensors = self._lazy_tensor_dict(postprocessed_batch)

        with self.lock:
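            # One SGD step: forward pass through the loss, backward pass,
            # optional gradient post-processing, then apply the update.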
            loss_out = self._loss(self, batch_tensors)
            self._optimizer.zero_grad()
            loss_out.backward()

            grad_process_info = self.extra_grad_process()
            self._optimizer.step()

            grad_info = self.extra_grad_info(batch_tensors)
            grad_info.update(grad_process_info)
            return {LEARNER_STATS_KEY: grad_info}

    @override(Policy)
    def compute_gradients(self, postprocessed_batch):
        batch_tensors = self._lazy_tensor_dict(postprocessed_batch)

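        # Same forward/backward pass as learn_on_batch(), but instead of
        # stepping the optimizer the gradients are copied out as numpy
        # arrays so they can be applied later via apply_gradients().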
        with self.lock:
            loss_out = self._loss(self, batch_tensors)
            self._optimizer.zero_grad()
            loss_out.backward()

            grad_process_info = self.extra_grad_process()

            # Note that return values are just references;
            # calling zero_grad will modify the values
            grads = []
            for p in self._model.parameters():
                if p.grad is not None:
                    grads.append(p.grad.data.cpu().numpy())
                else:
                    grads.append(None)

            grad_info = self.extra_grad_info(batch_tensors)
            grad_info.update(grad_process_info)
            return grads, {LEARNER_STATS_KEY: grad_info}

    @override(Policy)
    def apply_gradients(self, gradients):
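        # Copy the numpy gradients (e.g. from compute_gradients()) back onto
        # the model parameters, then take an optimizer step.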
        with self.lock:
            for g, p in zip(gradients, self._model.parameters()):
                if g is not None:
                    p.grad = torch.from_numpy(g).to(self.device)
            self._optimizer.step()

    @override(Policy)
    def get_weights(self):
        with self.lock:
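            # Move all tensors in the state dict to CPU so the weights can
            # be serialized and shipped between processes.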
            return {k: v.cpu() for k, v in self._model.state_dict().items()}

    @override(Policy)
    def set_weights(self, weights):
        with self.lock:
            self._model.load_state_dict(weights)

    @override(Policy)
    def get_initial_state(self):
        return [s.numpy() for s in self._model.get_initial_state()]

    def extra_grad_process(self):
        """Allow subclasses to do extra processing on gradients
        (e.g. gradient clipping) and return processing info."""
        return {}

    def extra_action_out(self, input_dict, state_batches, model):
        """Returns dict of extra info to include in experience batch.

        Arguments:
            input_dict (dict): Dict of model input tensors.
            state_batches (list): List of state tensors.
            model (TorchModelV2): Reference to the model.
        """
        return {}

    def extra_grad_info(self, batch_tensors):
        """Return dict of extra grad info."""
        return {}

    def optimizer(self):
        """Custom PyTorch optimizer to use."""
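        # Default to Adam, using self.config["lr"] when a config has been
        # attached to the policy; subclasses may override this method to
        # use a different optimizer.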
        if hasattr(self, "config"):
            return torch.optim.Adam(
                self._model.parameters(), lr=self.config["lr"])
        else:
            return torch.optim.Adam(self._model.parameters())

    def _lazy_tensor_dict(self, postprocessed_batch):
        batch_tensors = UsageTrackingDict(postprocessed_batch)

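        # Entries are converted to tensors on self.device only when first
        # accessed; float64 arrays are downcast to float32.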
        def convert(arr):
            tensor = torch.from_numpy(np.asarray(arr))
            if tensor.dtype == torch.double:
                tensor = tensor.float()
            return tensor.to(self.device)

        batch_tensors.set_get_interceptor(convert)
        return batch_tensors
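

# ---------------------------------------------------------------------------
# Usage sketch (illustrative comment only, not part of this file's API): a
# concrete policy needs a model called as model(input_dict, state, seq_lens)
# -> (logits, state), a loss taking (policy, batch_tensors) -> scalar tensor,
# and an action distribution class constructed as dist_class(inputs, model).
# The names below (_FCNet, _Categorical, _pg_loss, obs_space, act_space,
# obs_batch, postprocessed_batch) are made up for this example; real policies
# would plug in RLlib's model and distribution classes instead.
#
#   import torch
#   import torch.nn as nn
#
#   class _FCNet(nn.Module):
#       def __init__(self, obs_dim, num_actions):
#           super(_FCNet, self).__init__()
#           self.fc = nn.Linear(obs_dim, num_actions)
#
#       def forward(self, input_dict, state, seq_lens):
#           # "obs" arrives as a tensor already placed on policy.device.
#           return self.fc(input_dict["obs"].float()), state or []
#
#       def get_initial_state(self):
#           return []  # no recurrent state
#
#   class _Categorical(object):
#       def __init__(self, inputs, model):
#           self.dist = torch.distributions.Categorical(logits=inputs)
#
#       def sample(self):
#           return self.dist.sample()
#
#   def _pg_loss(policy, batch_tensors):
#       logits, _ = policy._model({"obs": batch_tensors["obs"]}, [], [1])
#       log_probs = torch.distributions.Categorical(logits=logits).log_prob(
#           batch_tensors["actions"])
#       return -(log_probs * batch_tensors["advantages"].float()).mean()
#
#   policy = TorchPolicy(obs_space, act_space, _FCNet(4, 2), _pg_loss,
#                        _Categorical)
#   actions, state_out, extra = policy.compute_actions(obs_batch)
#   stats = policy.learn_on_batch(postprocessed_batch)
# ---------------------------------------------------------------------------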