import numpy as np

from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import try_import_torch
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.tracking_dict import UsageTrackingDict
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule

torch, _ = try_import_torch()


class TorchPolicy(Policy):
    """Template for a PyTorch policy and loss to use with RLlib.

    This is similar to TFPolicy, but for PyTorch.

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        config (dict): config of the policy.
        model (TorchModel): Torch model instance.
        dist_class (type): Torch action distribution class.
    """

    def __init__(self, observation_space, action_space, config, model, loss,
                 action_distribution_class):
        """Build a policy from policy and loss torch modules.

        Note that the model will be placed on the GPU, if one is available
        and visible (e.g. via CUDA_VISIBLE_DEVICES). Only a single GPU is
        supported for now.

        Arguments:
            observation_space (gym.Space): observation space of the policy.
            action_space (gym.Space): action space of the policy.
            config (dict): The Policy config dict.
            model (nn.Module): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (func): Function that takes (policy, model, dist_class,
                train_batch) and returns a single scalar loss.
            action_distribution_class (ActionDistribution): Class for action
                distribution.
        """
        super(TorchPolicy, self).__init__(observation_space, action_space,
                                          config)
        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))
        self.model = model.to(self.device)
        self._loss = loss
        self._optimizer = self.optimizer()
        self.dist_class = action_distribution_class
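
    # compute_actions() below runs the model forward under torch.no_grad(),
    # samples from the action distribution, and returns numpy data:
    # (actions, RNN state-outs, extra_action_out() fetches).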
    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        with torch.no_grad():
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
            })
            # Use explicit None checks: these batches may be numpy arrays,
            # whose truth value is ambiguous.
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            model_out = self.model(input_dict, state_batches, [1])
            logits, state = model_out
            action_dist = self.dist_class(logits, self.model)
            actions = action_dist.sample()
            return (actions.cpu().numpy(), [h.cpu().numpy() for h in state],
                    self.extra_action_out(input_dict, state_batches,
                                          self.model))

    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        train_batch = self._lazy_tensor_dict(postprocessed_batch)

        loss_out = self._loss(self, self.model, self.dist_class, train_batch)
        self._optimizer.zero_grad()
        loss_out.backward()

        grad_process_info = self.extra_grad_process()
        self._optimizer.step()

        grad_info = self.extra_grad_info(train_batch)
        grad_info.update(grad_process_info)
        return {LEARNER_STATS_KEY: grad_info}

    @override(Policy)
    def compute_gradients(self, postprocessed_batch):
        train_batch = self._lazy_tensor_dict(postprocessed_batch)

        loss_out = self._loss(self, self.model, self.dist_class, train_batch)
        self._optimizer.zero_grad()
        loss_out.backward()

        grad_process_info = self.extra_grad_process()

        # Note that the returned gradients are just references; a later call
        # to zero_grad() will modify them in place.
        grads = []
        for p in self.model.parameters():
            if p.grad is not None:
                grads.append(p.grad.data.cpu().numpy())
            else:
                grads.append(None)

        grad_info = self.extra_grad_info(train_batch)
        grad_info.update(grad_process_info)
        return grads, {LEARNER_STATS_KEY: grad_info}

    @override(Policy)
    def apply_gradients(self, gradients):
        for g, p in zip(gradients, self.model.parameters()):
            if g is not None:
                p.grad = torch.from_numpy(g).to(self.device)
        self._optimizer.step()

    @override(Policy)
    def get_weights(self):
        return {k: v.cpu() for k, v in self.model.state_dict().items()}

    @override(Policy)
    def set_weights(self, weights):
        self.model.load_state_dict(weights)

    @override(Policy)
    def num_state_tensors(self):
        return len(self.model.get_initial_state())

    @override(Policy)
    def get_initial_state(self):
        return [s.numpy() for s in self.model.get_initial_state()]

    def extra_grad_process(self):
        """Allow subclasses to do extra processing on the gradients
        (e.g. clipping) and return processing info."""
        return {}

    def extra_action_out(self, input_dict, state_batches, model,
                         action_dist=None):
        """Returns dict of extra info to include in experience batch.

        Arguments:
            input_dict (dict): Dict of model input tensors.
            state_batches (list): List of state tensors.
            model (TorchModelV2): Reference to the model.
            action_dist (Distribution): Torch Distribution object to get
                log-probs (e.g. for already sampled actions).
        """
        return {}

    def extra_grad_info(self, train_batch):
        """Return dict of extra grad info."""
        return {}

    def optimizer(self):
        """Custom PyTorch optimizer to use."""
        if hasattr(self, "config"):
            return torch.optim.Adam(
                self.model.parameters(), lr=self.config["lr"])
        else:
            return torch.optim.Adam(self.model.parameters())
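
    # _lazy_tensor_dict() wraps a (postprocessed) sample batch in a
    # UsageTrackingDict whose get-interceptor converts values to torch tensors
    # on self.device (casting doubles to floats) lazily, i.e. only when a key
    # is actually accessed, e.g. by the loss or stats functions.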
    def _lazy_tensor_dict(self, postprocessed_batch):
        train_batch = UsageTrackingDict(postprocessed_batch)

        def convert(arr):
            tensor = torch.from_numpy(np.asarray(arr))
            if tensor.dtype == torch.double:
                tensor = tensor.float()
            return tensor.to(self.device)

        train_batch.set_get_interceptor(convert)
        return train_batch

    @override(Policy)
    def export_model(self, export_dir):
        """TODO: implement for torch."""
        raise NotImplementedError

    @override(Policy)
    def export_checkpoint(self, export_dir):
        """TODO: implement for torch."""
        raise NotImplementedError
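

# Illustrative sketch only (nothing in this module uses it): how the pieces
# above fit together. `_example_pg_loss` shows the (policy, model, dist_class,
# train_batch) loss signature that TorchPolicy expects, assuming a
# ModelV2-style model with a `from_batch()` method; `_ExamplePGTorchPolicy`
# overrides two of the extension hooks. The names are made up for this sketch;
# real RLlib algorithms normally generate their torch policies via
# `build_torch_policy()`.
def _example_pg_loss(policy, model, dist_class, train_batch):
    # Bare-bones REINFORCE-style objective on the raw rewards in the batch.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    return -(log_probs * train_batch[SampleBatch.REWARDS]).mean()


class _ExamplePGTorchPolicy(TorchPolicy):
    def __init__(self, obs_space, action_space, config, model, dist_class):
        TorchPolicy.__init__(self, obs_space, action_space, config, model,
                             _example_pg_loss, dist_class)

    @override(TorchPolicy)
    def extra_grad_process(self):
        # Extra gradient processing: clip by global norm (the 40.0 threshold
        # is arbitrary) and report the norm back as a stat.
        grad_gnorm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                    40.0)
        return {"grad_gnorm": grad_gnorm}

    @override(TorchPolicy)
    def extra_grad_info(self, train_batch):
        # Report the optimizer's current learning rate with the learner stats.
        return {"cur_lr": self._optimizer.param_groups[0]["lr"]}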


@DeveloperAPI
class LearningRateSchedule(object):
    """Mixin for TorchPolicy that adds a learning rate schedule."""

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        self.cur_lr = lr
        if lr_schedule is None:
            self.lr_schedule = ConstantSchedule(lr)
        else:
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1])

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        self.cur_lr = self.lr_schedule.value(global_vars["timestep"])

    @override(TorchPolicy)
    def optimizer(self):
        for p in self._optimizer.param_groups:
            p["lr"] = self.cur_lr
        return self._optimizer
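

# LearningRateSchedule keeps `self.cur_lr` in sync with the global timestep
# via on_global_var_update(), and optimizer() then writes that value into
# every param group before the optimizer is used. `lr_schedule` uses the
# PiecewiseSchedule endpoint format, i.e. a list of [timestep, lr] pairs,
# e.g. (values purely illustrative): lr_schedule = [[0, 1e-3], [1000000, 1e-5]]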


@DeveloperAPI
class EntropyCoeffSchedule(object):
    """Mixin for TorchPolicy that adds entropy coeff decay."""

    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        self.entropy_coeff = entropy_coeff

        if entropy_coeff_schedule is None:
            self.entropy_coeff_schedule = ConstantSchedule(entropy_coeff)
        else:
            # Allows for custom schedules in the same format as lr_schedule.
            if isinstance(entropy_coeff_schedule, list):
                self.entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1])
            else:
                # Backwards compatibility: a scalar schedule means "anneal the
                # coeff from `entropy_coeff` to 0.0 over this many timesteps",
                # with 0.0 enforced as the outside value.
                self.entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
        self.entropy_coeff = self.entropy_coeff_schedule.value(
            global_vars["timestep"])
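

# Note on the two schedule mixins above: both rely on Policy's
# on_global_var_update() hook (hence the cooperative super() calls) to keep
# self.cur_lr and self.entropy_coeff in sync with the global timestep, so a
# policy class using them should list the mixins before TorchPolicy in its
# bases. That policy's loss and optimizer are then expected to read
# self.entropy_coeff and self.cur_lr rather than the static config values.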