mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00

- Translate all vtrace functionality to torch and add torch to the framework_iterator loop in all existing vtrace test cases.
- Add learning test cases for APPO torch (both with and without v-trace).
- Add quick compilation tests for APPO (tf and torch, v-trace and no v-trace).
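A hedged sketch of the test pattern described above, assuming RLlib's framework_iterator helper from ray.rllib.utils.test_utils (its exact signature varies across Ray versions; run_vtrace_checks is a hypothetical stand-in for the actual test body):

    from ray.rllib.utils.test_utils import framework_iterator

    def test_vtrace_all_frameworks(config):
        # framework_iterator yields one framework string per pass;
        # "torch" now runs alongside "tf".
        for fw in framework_iterator(config, frameworks=("tf", "torch")):
            run_vtrace_checks(config, fw)  # hypothetical test body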
import logging

from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)

def make_time_major(policy, seq_lens, tensor, drop_last=False):
    """Swaps batch and trajectory axes.

    Arguments:
        policy: Policy reference.
        seq_lens: Sequence lengths if recurrent or None.
        tensor: A tensor or list of tensors to reshape.
        drop_last: A bool indicating whether to drop the last
            trajectory item.

    Returns:
        res: A tensor with swapped axes or a list of tensors with
            swapped axes.
    """
    # Apply recursively to each tensor in a list/tuple.
    if isinstance(tensor, (list, tuple)):
        return [
            make_time_major(policy, seq_lens, t, drop_last) for t in tensor
        ]

    if policy.is_recurrent():
        B = seq_lens.shape[0]
        T = tensor.shape[0] // B
    else:
        # Important: chop the tensor into batches at known episode cut
        # boundaries. TODO(ekl) this is kind of a hack.
        T = policy.config["rollout_fragment_length"]
        B = tensor.shape[0] // T
    rs = torch.reshape(tensor, [B, T] + list(tensor.shape[1:]))

    # Swap B and T axes.
    res = torch.transpose(rs, 1, 0)

    if drop_last:
        return res[:-1]
    return res

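# A minimal usage sketch of make_time_major (illustrative only, not part of
# the original module; the stub policy below is hypothetical and provides
# just the attributes the function reads):
def _example_make_time_major():
    class _StubPolicy:
        config = {"rollout_fragment_length": 4}

        def is_recurrent(self):
            return False

    flat = torch.arange(8)  # Two trajectories of length T=4, batch-major.
    tm = make_time_major(_StubPolicy(), None, flat)
    assert tm.shape == (4, 2)  # Now time-major: [T, B].
    # drop_last=True removes the final time step (e.g. for bootstrap values):
    assert make_time_major(_StubPolicy(), None, flat, drop_last=True).shape == (3, 2)
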
def choose_optimizer(policy, config):
    if policy.config["opt_type"] == "adam":
        return torch.optim.Adam(
            params=policy.model.parameters(), lr=policy.cur_lr)
    else:
        # Note: PyTorch's class is `RMSprop`, not `RMSProp`.
        return torch.optim.RMSprop(
            params=policy.model.parameters(),
            lr=policy.cur_lr,
            weight_decay=config["decay"],
            momentum=config["momentum"],
            eps=config["epsilon"])
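
# A minimal usage sketch of choose_optimizer (illustrative only, not part of
# the original module; the stub policy below is hypothetical and mimics the
# attributes the function accesses):
def _example_choose_optimizer():
    class _StubPolicy:
        config = {"opt_type": "rmsprop"}  # Anything but "adam" -> RMSprop.
        cur_lr = 5e-4
        model = nn.Linear(4, 2)

    config = {"decay": 0.99, "momentum": 0.0, "epsilon": 0.1}
    opt = choose_optimizer(_StubPolicy(), config)
    assert isinstance(opt, torch.optim.RMSprop)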