mirror of
https://github.com/vale981/ray
synced 2025-03-13 22:56:38 -04:00
58 lines
1.6 KiB
Python
import logging

from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)

def make_time_major(policy, seq_lens, tensor, drop_last=False):
    """Swaps the batch and trajectory axes.

    Arguments:
        policy: Policy reference
        seq_lens: Sequence lengths if recurrent or None
        tensor: A tensor or list of tensors to reshape.
        drop_last: A bool indicating whether to drop the last
            trajectory item.

    Returns:
        res: A tensor with swapped axes or a list of tensors with
            swapped axes.
    """
    if isinstance(tensor, (list, tuple)):
        return [
            make_time_major(policy, seq_lens, t, drop_last) for t in tensor
        ]

    if policy.is_recurrent():
        B = seq_lens.shape[0]
        T = tensor.shape[0] // B
    else:
        # Important: chop the tensor into batches at known episode cut
        # boundaries. TODO(ekl) this is kind of a hack.
        T = policy.config["rollout_fragment_length"]
        B = tensor.shape[0] // T
    rs = torch.reshape(tensor, [B, T] + list(tensor.shape[1:]))

    # Swap B and T axes.
    res = torch.transpose(rs, 1, 0)

    if drop_last:
        return res[:-1]
    return res
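
# Usage sketch (illustrative): assumes a hypothetical non-recurrent `policy`
# whose config sets "rollout_fragment_length" to 50.
#
#     flat = torch.zeros(200, 4)                 # [B*T, obs_dim] = [4*50, 4]
#     tm = make_time_major(policy, None, flat)   # reshape to [B, T, ...],
#                                                # then swap B and T
#     assert tm.shape == (50, 4, 4)              # [T, B, obs_dim]
#     tm = make_time_major(policy, None, flat, drop_last=True)
#     assert tm.shape == (49, 4, 4)              # last time step dropped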


def choose_optimizer(policy, config):
    """Returns a torch Adam or RMSprop optimizer based on the "opt_type" config."""
    if policy.config["opt_type"] == "adam":
        return torch.optim.Adam(
            params=policy.model.parameters(), lr=policy.cur_lr)
    else:
        return torch.optim.RMSprop(
            params=policy.model.parameters(),
            lr=policy.cur_lr,
            weight_decay=config["decay"],
            momentum=config["momentum"],
            eps=config["epsilon"])
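
# Usage sketch (illustrative): choose_optimizer reads "opt_type" from the
# policy's own config and the RMSprop hyperparameters from the `config`
# argument. The `policy` below is hypothetical; it only needs `config`,
# `cur_lr`, and a `model` exposing parameters().
#
#     config = {"decay": 0.99, "momentum": 0.0, "epsilon": 0.1}
#     # With policy.config["opt_type"] != "adam", this builds an RMSprop
#     # optimizer with the learning rate taken from policy.cur_lr.
#     optimizer = choose_optimizer(policy, config)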