ray/rllib/utils/torch_ops.py
Sven Mika e2edca45d4
[RLlib] PPO torch memory leak and unnecessary torch.Tensor creation and gc'ing. (#7238)
* Take out stats to analyze memory leak in torch PPO.

* WIP


* LINT.

* Fix determine_tests_to_run.py.

* minor change to re-test after determine_tests_to_run.py.

* LINT.

* update comments.

* WIP

* FIX.

* Fix sequence_mask being dependent on torch being installed.

* Fix strange ray-core tf-error in test_memory_scheduling test case.

2020-02-22 11:02:31 -08:00

40 lines
1.1 KiB
Python

from ray.rllib.utils.framework import try_import_torch
torch, _ = try_import_torch()


def sequence_mask(lengths, maxlen=None, dtype=None):
    """Exact same behavior as tf.sequence_mask.

    Thanks to Dimitris Papatheodorou
    (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/
    39036).
    """
    if maxlen is None:
        maxlen = lengths.max()
    mask = ~(torch.ones((len(lengths), maxlen)).cumsum(dim=1).t() >
             lengths).t()
    # Tensor.type() is not in-place; assign its result, otherwise the
    # dtype conversion is silently discarded.
    mask = mask.type(dtype or torch.bool)
    return mask
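To make the masking logic easy to verify without a torch install, here is a pure-Python sketch of what `sequence_mask` computes (the helper name `sequence_mask_py` is made up for illustration): row `i` of the result is `True` for the first `lengths[i]` positions and `False` afterwards, exactly like `tf.sequence_mask`.

```python
def sequence_mask_py(lengths, maxlen=None):
    """Pure-Python sketch of sequence_mask (no torch required).

    Row i is True for the first lengths[i] positions, False afterwards.
    """
    if maxlen is None:
        maxlen = max(lengths)
    return [[j < length for j in range(maxlen)] for length in lengths]
```

For example, `sequence_mask_py([1, 3], 4)` yields one row with a single leading `True` and one row with three, matching the cumsum-and-compare trick in the torch version above.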


def convert_to_non_torch_type(stats_dict):
    """Converts values in stats_dict to non-Tensor numpy or python types.

    Args:
        stats_dict (dict): A flat key, value dict, the values of which will be
            converted and returned as a new dict.

    Returns:
        dict: A new dict with the same structure as stats_dict, but with all
            values converted to non-torch Tensor types.
    """
    ret = {}
    for k, v in stats_dict.items():
        if isinstance(v, torch.Tensor):
            ret[k] = v.item() if len(v.size()) == 0 else v.numpy()
        else:
            ret[k] = v
    return ret
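The same conversion pattern can be sketched with NumPy arrays standing in for torch Tensors (assumption: NumPy is available; the function name `convert_to_plain_types` is invented for this sketch). As in `convert_to_non_torch_type`, 0-dimensional values become Python scalars and everything else is unwrapped from its array type.

```python
import numpy as np


def convert_to_plain_types(stats_dict):
    """Sketch of the conversion above, with np.ndarray in place of Tensor."""
    ret = {}
    for k, v in stats_dict.items():
        if isinstance(v, np.ndarray):
            # 0-d arrays -> Python scalars (like Tensor.item()),
            # everything else -> plain nested lists.
            ret[k] = v.item() if v.ndim == 0 else v.tolist()
        else:
            ret[k] = v
    return ret
```

Returning plain Python/NumPy values is the point of the original change: keeping torch Tensors alive inside stats dicts was one source of the PPO memory leak this commit addresses.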