2020-06-16 09:01:20 +02:00
|
|
|
import functools
|
2019-06-01 16:58:49 +08:00
|
|
|
import numpy as np
|
2020-01-25 22:36:43 -08:00
|
|
|
import time
|
2019-05-20 16:46:05 -07:00
|
|
|
|
2020-06-16 09:01:20 +02:00
|
|
|
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
2020-04-01 09:43:21 +02:00
|
|
|
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
2020-01-18 07:26:28 +01:00
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
2020-04-01 07:00:28 +02:00
|
|
|
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
|
2020-04-15 13:25:16 +02:00
|
|
|
from ray.rllib.utils import force_list
|
2020-01-18 07:26:28 +01:00
|
|
|
from ray.rllib.utils.annotations import override, DeveloperAPI
|
2020-02-22 23:19:49 +01:00
|
|
|
from ray.rllib.utils.framework import try_import_torch
|
2020-01-18 07:26:28 +01:00
|
|
|
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
|
2020-04-06 20:56:16 +02:00
|
|
|
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
|
|
|
|
convert_to_torch_tensor
|
2020-02-22 23:19:49 +01:00
|
|
|
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
2020-01-18 07:26:28 +01:00
|
|
|
|
|
|
|
torch, _ = try_import_torch()
|
2019-05-20 16:46:05 -07:00
|
|
|
|
|
|
|
|
|
|
|
class TorchPolicy(Policy):
    """Template for a PyTorch policy and loss to use with RLlib.

    This is similar to TFPolicy, but for PyTorch.

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        config (dict): config of the policy.
        model (TorchModel): Torch model instance.
        dist_class (type): Torch action distribution class.
    """

    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 *,
                 model,
                 loss,
                 action_distribution_class,
                 action_sampler_fn=None,
                 action_distribution_fn=None,
                 max_seq_len=20,
                 get_batch_divisibility_req=None):
        """Build a policy from policy and loss torch modules.

        Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
        is set. Only single GPU is supported for now.

        Arguments:
            observation_space (gym.Space): observation space of the policy.
            action_space (gym.Space): action space of the policy.
            config (dict): The Policy config dict.
            model (nn.Module): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (func): Function that takes (policy, model, dist_class,
                train_batch) and returns a single scalar loss.
            action_distribution_class (ActionDistribution): Class for action
                distribution.
            action_sampler_fn (Optional[callable]): A callable returning a
                sampled action and its log-likelihood given some (obs and
                state) inputs.
            action_distribution_fn (Optional[callable]): A callable returning
                distribution inputs (parameters), a dist-class to generate an
                action distribution object from, and internal-state outputs
                (or an empty list if not applicable).
                Note: No Exploration hooks have to be called from within
                `action_distribution_fn`. It should only perform a simple
                forward pass through some model.
                If None, pass inputs through `self.model()` to get the
                distribution inputs.
            max_seq_len (int): Max sequence length for LSTM training.
            get_batch_divisibility_req (Optional[callable]): Optional callable
                that returns the divisibility requirement for sample batches.
        """
        # NOTE(review): set before super().__init__() — presumably the base
        # Policy setup reads `self.framework`; keep this ordering.
        self.framework = "torch"
        super().__init__(observation_space, action_space, config)
        # Place the model on GPU if CUDA is available, else on CPU.
        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))
        self.model = model.to(self.device)
        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        self._loss = loss
        # `self.optimizer()` may return one optimizer or a list; normalize.
        self._optimizers = force_list(self.optimizer())

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        # Divisibility requirement defaults to 1 (no requirement).
        self.batch_divisibility_req = \
            get_batch_divisibility_req(self) if get_batch_divisibility_req \
            else 1
|
|
|
|
|
2019-05-20 16:46:05 -07:00
|
|
|
    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):
        """Computes actions for the given (batched) observations.

        Returns the numpy-converted tuple of (actions, RNN state-outs,
        extra-fetches dict). Extra fetches include action log-probs/probs
        (if a log-prob was produced) and the raw distribution inputs.
        """
        # Fall back to config/global defaults when not explicitly given.
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Pure inference: no gradients needed.
        with torch.no_grad():
            # One timestep per batch row (no time-major unrolling here).
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
                "is_training": False,
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            # Normalize to a (possibly empty) list of device-placed tensors.
            state_batches = [
                self._convert_to_tensor(s) for s in (state_batches or [])
            ]

            # Custom sampler-fn: computes actions (and logp) directly and
            # bypasses the exploration/action-dist machinery below.
            if self.action_sampler_fn:
                action_dist = dist_inputs = None
                state_out = []
                actions, logp = self.action_sampler_fn(
                    self,
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=explore,
                    timestep=timestep)
            else:
                # Call the exploration before_compute_actions hook.
                self.exploration.before_compute_actions(
                    explore=explore, timestep=timestep)
                if self.action_distribution_fn:
                    dist_inputs, dist_class, state_out = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.CUR_OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False)
                else:
                    # Default: plain model forward pass yields dist inputs.
                    dist_class = self.dist_class
                    dist_inputs, state_out = self.model(
                        input_dict, state_batches, seq_lens)
                # `functools.partial` is accepted as-is since it cannot be
                # passed to `issubclass` directly.
                if not (isinstance(dist_class, functools.partial)
                        or issubclass(dist_class, TorchDistributionWrapper)):
                    raise ValueError(
                        "`dist_class` ({}) not a TorchDistributionWrapper "
                        "subclass! Make sure your `action_distribution_fn` or "
                        "`make_model_and_action_dist` return a correct "
                        "distribution class.".format(dist_class.__name__))
                action_dist = dist_class(dist_inputs, self.model)

                # Get the exploration action from the forward results.
                actions, logp = \
                    self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

            input_dict[SampleBatch.ACTIONS] = actions

            # Add default and custom fetches.
            extra_fetches = self.extra_action_out(input_dict, state_batches,
                                                  self.model, action_dist)
            # Action-logp and action-prob.
            if logp is not None:
                logp = convert_to_non_torch_type(logp)
                extra_fetches[SampleBatch.ACTION_PROB] = np.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            return convert_to_non_torch_type((actions, state_out,
                                              extra_fetches))
|
2019-05-20 16:46:05 -07:00
|
|
|
|
2020-02-22 23:19:49 +01:00
|
|
|
@override(Policy)
|
|
|
|
def compute_log_likelihoods(self,
|
|
|
|
actions,
|
|
|
|
obs_batch,
|
|
|
|
state_batches=None,
|
|
|
|
prev_action_batch=None,
|
|
|
|
prev_reward_batch=None):
|
2020-04-01 09:43:21 +02:00
|
|
|
|
|
|
|
if self.action_sampler_fn and self.action_distribution_fn is None:
|
|
|
|
raise ValueError("Cannot compute log-prob/likelihood w/o an "
|
|
|
|
"`action_distribution_fn` and a provided "
|
|
|
|
"`action_sampler_fn`!")
|
|
|
|
|
2020-02-22 23:19:49 +01:00
|
|
|
with torch.no_grad():
|
|
|
|
input_dict = self._lazy_tensor_dict({
|
|
|
|
SampleBatch.CUR_OBS: obs_batch,
|
|
|
|
SampleBatch.ACTIONS: actions
|
|
|
|
})
|
2020-04-07 01:38:50 +02:00
|
|
|
if prev_action_batch is not None:
|
2020-02-22 23:19:49 +01:00
|
|
|
input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
|
2020-04-07 01:38:50 +02:00
|
|
|
if prev_reward_batch is not None:
|
2020-02-22 23:19:49 +01:00
|
|
|
input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
|
2020-04-01 09:43:21 +02:00
|
|
|
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
|
|
|
|
|
|
|
|
# Exploration hook before each forward pass.
|
|
|
|
self.exploration.before_compute_actions(explore=False)
|
|
|
|
|
|
|
|
# Action dist class and inputs are generated via custom function.
|
|
|
|
if self.action_distribution_fn:
|
|
|
|
dist_inputs, dist_class, _ = self.action_distribution_fn(
|
2020-04-15 13:25:16 +02:00
|
|
|
policy=self,
|
|
|
|
model=self.model,
|
|
|
|
obs_batch=input_dict[SampleBatch.CUR_OBS],
|
|
|
|
explore=False,
|
|
|
|
is_training=False)
|
2020-04-01 09:43:21 +02:00
|
|
|
# Default action-dist inputs calculation.
|
|
|
|
else:
|
|
|
|
dist_class = self.dist_class
|
|
|
|
dist_inputs, _ = self.model(input_dict, state_batches,
|
|
|
|
seq_lens)
|
2020-02-22 23:19:49 +01:00
|
|
|
|
2020-04-01 09:43:21 +02:00
|
|
|
action_dist = dist_class(dist_inputs, self.model)
|
2020-02-22 23:19:49 +01:00
|
|
|
log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])
|
|
|
|
return log_likelihoods
|
|
|
|
|
2019-05-20 16:46:05 -07:00
|
|
|
    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        """Runs one learning update on the given batch and returns stats.

        Computes loss(es), backprops, (optionally) allreduces gradients
        across distributed workers, then steps each optimizer.
        """
        # Get batch ready for RNNs, if applicable.
        pad_batch_to_sequences_of_same_size(
            postprocessed_batch,
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req)

        train_batch = self._lazy_tensor_dict(postprocessed_batch)
        # One loss term per optimizer (order must match `self._optimizers`).
        loss_out = force_list(
            self._loss(self, self.model, self.dist_class, train_batch))
        assert len(loss_out) == len(self._optimizers)
        # assert not any(torch.isnan(l) for l in loss_out)

        # Loop through all optimizers.
        grad_info = {"allreduce_latency": 0.0}
        for i, opt in enumerate(self._optimizers):
            # Erase gradients in all vars of this optimizer.
            opt.zero_grad()
            # Recompute gradients of loss over all variables.
            # Retain the graph for all but the last backward pass, since the
            # losses may share parts of it.
            loss_out[i].backward(retain_graph=(i < len(self._optimizers) - 1))
            grad_info.update(self.extra_grad_process(opt, loss_out[i]))

            if self.distributed_world_size:
                # Collect all existing gradients for the allreduce.
                grads = []
                for param_group in opt.param_groups:
                    for p in param_group["params"]:
                        if p.grad is not None:
                            grads.append(p.grad)

                start = time.time()
                if torch.cuda.is_available():
                    # Sadly, allreduce_coalesced does not work with CUDA yet.
                    for g in grads:
                        torch.distributed.all_reduce(
                            g, op=torch.distributed.ReduceOp.SUM)
                else:
                    torch.distributed.all_reduce_coalesced(
                        grads, op=torch.distributed.ReduceOp.SUM)

                # Turn the SUM from the allreduce into a mean over workers.
                for param_group in opt.param_groups:
                    for p in param_group["params"]:
                        if p.grad is not None:
                            p.grad /= self.distributed_world_size

                grad_info["allreduce_latency"] += time.time() - start

            # Step the optimizer.
            opt.step()

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(train_batch))
        return {LEARNER_STATS_KEY: grad_info}
|
2019-05-20 16:46:05 -07:00
|
|
|
|
|
|
|
@override(Policy)
|
|
|
|
def compute_gradients(self, postprocessed_batch):
|
2019-08-23 02:21:11 -04:00
|
|
|
train_batch = self._lazy_tensor_dict(postprocessed_batch)
|
2020-04-15 13:25:16 +02:00
|
|
|
loss_out = force_list(
|
|
|
|
self._loss(self, self.model, self.dist_class, train_batch))
|
|
|
|
assert len(loss_out) == len(self._optimizers)
|
2019-05-20 16:46:05 -07:00
|
|
|
|
2020-04-15 13:25:16 +02:00
|
|
|
grad_process_info = {}
|
2019-09-24 17:52:16 -07:00
|
|
|
grads = []
|
2020-04-15 13:25:16 +02:00
|
|
|
for i, opt in enumerate(self._optimizers):
|
|
|
|
opt.zero_grad()
|
|
|
|
loss_out[i].backward()
|
|
|
|
grad_process_info = self.extra_grad_process(opt, loss_out[i])
|
|
|
|
|
|
|
|
# Note that return values are just references;
|
|
|
|
# calling zero_grad will modify the values
|
|
|
|
for param_group in opt.param_groups:
|
|
|
|
for p in param_group["params"]:
|
|
|
|
if p.grad is not None:
|
|
|
|
grads.append(p.grad.data.cpu().numpy())
|
|
|
|
else:
|
|
|
|
grads.append(None)
|
2019-05-20 16:46:05 -07:00
|
|
|
|
2019-09-24 17:52:16 -07:00
|
|
|
grad_info = self.extra_grad_info(train_batch)
|
|
|
|
grad_info.update(grad_process_info)
|
|
|
|
return grads, {LEARNER_STATS_KEY: grad_info}
|
2019-05-20 16:46:05 -07:00
|
|
|
|
|
|
|
@override(Policy)
|
|
|
|
def apply_gradients(self, gradients):
|
2020-04-15 13:25:16 +02:00
|
|
|
# TODO(sven): Not supported for multiple optimizers yet.
|
|
|
|
assert len(self._optimizers) == 1
|
2019-09-24 17:52:16 -07:00
|
|
|
for g, p in zip(gradients, self.model.parameters()):
|
|
|
|
if g is not None:
|
|
|
|
p.grad = torch.from_numpy(g).to(self.device)
|
2020-04-15 13:25:16 +02:00
|
|
|
|
|
|
|
self._optimizers[0].step()
|
2019-05-20 16:46:05 -07:00
|
|
|
|
|
|
|
@override(Policy)
|
|
|
|
def get_weights(self):
|
2020-04-06 20:56:16 +02:00
|
|
|
return {
|
|
|
|
k: v.cpu().detach().numpy()
|
|
|
|
for k, v in self.model.state_dict().items()
|
|
|
|
}
|
2019-05-20 16:46:05 -07:00
|
|
|
|
|
|
|
@override(Policy)
|
|
|
|
def set_weights(self, weights):
|
2020-04-06 20:56:16 +02:00
|
|
|
weights = convert_to_torch_tensor(weights, device=self.device)
|
2019-09-24 17:52:16 -07:00
|
|
|
self.model.load_state_dict(weights)
|
2019-05-20 16:46:05 -07:00
|
|
|
|
2020-06-05 21:07:02 +02:00
|
|
|
@override(Policy)
|
|
|
|
def get_state(self):
|
|
|
|
state = super().get_state()
|
|
|
|
state["_optimizer_variables"] = []
|
|
|
|
for i, o in enumerate(self._optimizers):
|
|
|
|
state["_optimizer_variables"].append(o.state_dict())
|
|
|
|
return state
|
|
|
|
|
|
|
|
@override(Policy)
|
|
|
|
def set_state(self, state):
|
|
|
|
state = state.copy() # shallow copy
|
|
|
|
# Set optimizer vars first.
|
|
|
|
optimizer_vars = state.pop("_optimizer_variables", None)
|
|
|
|
if optimizer_vars:
|
|
|
|
assert len(optimizer_vars) == len(self._optimizers)
|
|
|
|
for o, s in zip(self._optimizers, optimizer_vars):
|
|
|
|
o.load_state_dict(s)
|
|
|
|
# Then the Policy's (NN) weights.
|
|
|
|
super().set_state(state)
|
|
|
|
|
2020-02-11 00:22:07 +01:00
|
|
|
@override(Policy)
|
|
|
|
def is_recurrent(self):
|
|
|
|
return len(self.model.get_initial_state()) > 0
|
|
|
|
|
2020-01-18 07:26:28 +01:00
|
|
|
@override(Policy)
|
|
|
|
def num_state_tensors(self):
|
|
|
|
return len(self.model.get_initial_state())
|
|
|
|
|
2019-05-20 16:46:05 -07:00
|
|
|
@override(Policy)
|
|
|
|
def get_initial_state(self):
|
2020-04-26 03:49:09 +02:00
|
|
|
return [
|
|
|
|
s.cpu().detach().numpy() for s in self.model.get_initial_state()
|
|
|
|
]
|
2019-05-20 16:46:05 -07:00
|
|
|
|
2020-04-15 13:25:16 +02:00
|
|
|
def extra_grad_process(self, optimizer, loss):
|
|
|
|
"""Called after each optimizer.zero_grad() + loss.backward() call.
|
|
|
|
|
|
|
|
Called for each self._optimizers/loss-value pair.
|
|
|
|
Allows for gradient processing before optimizer.step() is called.
|
|
|
|
E.g. for gradient clipping.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
optimizer (torch.optim.Optimizer): A torch optimizer object.
|
|
|
|
loss (torch.Tensor): The loss tensor associated with the optimizer.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
dict: An info dict.
|
|
|
|
"""
|
2019-05-20 16:46:05 -07:00
|
|
|
return {}
|
|
|
|
|
2020-04-03 19:44:25 +02:00
|
|
|
def extra_action_out(self, input_dict, state_batches, model, action_dist):
|
2019-05-20 16:46:05 -07:00
|
|
|
"""Returns dict of extra info to include in experience batch.
|
|
|
|
|
2020-04-03 19:44:25 +02:00
|
|
|
Args:
|
2019-06-01 16:58:49 +08:00
|
|
|
input_dict (dict): Dict of model input tensors.
|
|
|
|
state_batches (list): List of state tensors.
|
2020-01-18 07:26:28 +01:00
|
|
|
model (TorchModelV2): Reference to the model.
|
2020-04-03 19:44:25 +02:00
|
|
|
action_dist (TorchActionDistribution): Torch action dist object
|
|
|
|
to get log-probs (e.g. for already sampled actions).
|
2020-01-18 07:26:28 +01:00
|
|
|
"""
|
2019-05-20 16:46:05 -07:00
|
|
|
return {}
|
|
|
|
|
2019-08-23 02:21:11 -04:00
|
|
|
def extra_grad_info(self, train_batch):
|
2019-05-20 16:46:05 -07:00
|
|
|
"""Return dict of extra grad info."""
|
|
|
|
return {}
|
|
|
|
|
|
|
|
def optimizer(self):
|
|
|
|
"""Custom PyTorch optimizer to use."""
|
|
|
|
if hasattr(self, "config"):
|
|
|
|
return torch.optim.Adam(
|
2019-09-08 23:01:26 -07:00
|
|
|
self.model.parameters(), lr=self.config["lr"])
|
2019-05-20 16:46:05 -07:00
|
|
|
else:
|
2019-09-08 23:01:26 -07:00
|
|
|
return torch.optim.Adam(self.model.parameters())
|
2019-05-20 16:46:05 -07:00
|
|
|
|
|
|
|
def _lazy_tensor_dict(self, postprocessed_batch):
|
2019-08-23 02:21:11 -04:00
|
|
|
train_batch = UsageTrackingDict(postprocessed_batch)
|
2020-02-17 10:26:58 -08:00
|
|
|
train_batch.set_get_interceptor(self._convert_to_tensor)
|
2019-08-23 02:21:11 -04:00
|
|
|
return train_batch
|
2020-01-18 07:26:28 +01:00
|
|
|
|
2020-02-17 10:26:58 -08:00
|
|
|
def _convert_to_tensor(self, arr):
|
|
|
|
if torch.is_tensor(arr):
|
|
|
|
return arr.to(self.device)
|
|
|
|
tensor = torch.from_numpy(np.asarray(arr))
|
|
|
|
if tensor.dtype == torch.double:
|
|
|
|
tensor = tensor.float()
|
|
|
|
return tensor.to(self.device)
|
|
|
|
|
2020-01-18 07:26:28 +01:00
|
|
|
@override(Policy)
|
|
|
|
def export_model(self, export_dir):
|
2020-03-23 20:19:30 +01:00
|
|
|
"""TODO(sven): implement for torch.
|
2020-01-18 07:26:28 +01:00
|
|
|
"""
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
@override(Policy)
|
|
|
|
def export_checkpoint(self, export_dir):
|
2020-03-23 20:19:30 +01:00
|
|
|
"""TODO(sven): implement for torch.
|
2020-01-18 07:26:28 +01:00
|
|
|
"""
|
|
|
|
raise NotImplementedError
|
|
|
|
|
2020-03-23 20:19:30 +01:00
|
|
|
@override(Policy)
|
|
|
|
def import_model_from_h5(self, import_file):
|
|
|
|
"""Imports weights into torch model."""
|
|
|
|
return self.model.import_from_h5(import_file)
|
|
|
|
|
2020-01-18 07:26:28 +01:00
|
|
|
|
|
|
|
@DeveloperAPI
class LearningRateSchedule:
    """Mixin for TorchPolicy that adds a learning rate schedule."""

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        self.cur_lr = lr
        if lr_schedule is None:
            # No schedule given -> lr stays constant.
            self.lr_schedule = ConstantSchedule(lr, framework=None)
        else:
            # Piecewise schedule over (timestep, lr) points; holds the last
            # given value once the schedule is exhausted.
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        # Refresh the current lr from the schedule.
        self.cur_lr = self.lr_schedule.value(global_vars["timestep"])

    @override(TorchPolicy)
    def optimizer(self):
        # Push the (possibly decayed) lr into all optimizers' param groups.
        for optimizer in self._optimizers:
            for group in optimizer.param_groups:
                group["lr"] = self.cur_lr
        return self._optimizers
|
2020-01-18 07:26:28 +01:00
|
|
|
|
|
|
|
|
|
|
|
@DeveloperAPI
class EntropyCoeffSchedule:
    """Mixin for TorchPolicy that adds entropy coeff decay."""

    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        self.entropy_coeff = entropy_coeff

        if entropy_coeff_schedule is None:
            # Constant entropy coefficient throughout training.
            self.entropy_coeff_schedule = ConstantSchedule(
                entropy_coeff, framework=None)
        # Allows for custom schedule similar to lr_schedule format
        elif isinstance(entropy_coeff_schedule, list):
            self.entropy_coeff_schedule = PiecewiseSchedule(
                entropy_coeff_schedule,
                outside_value=entropy_coeff_schedule[-1][-1],
                framework=None)
        else:
            # Implements previous version but enforces outside_value:
            # decay from `entropy_coeff` to 0.0 over `entropy_coeff_schedule`
            # timesteps, then stay at 0.0.
            self.entropy_coeff_schedule = PiecewiseSchedule(
                [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                outside_value=0.0,
                framework=None)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
        self.entropy_coeff = self.entropy_coeff_schedule.value(
            global_vars["timestep"])
|