2022-01-05 11:29:44 +01:00
|
|
|
from gym.spaces import Discrete, MultiDiscrete
|
2019-12-30 15:27:32 -05:00
|
|
|
import numpy as np
|
2021-04-16 09:16:24 +02:00
|
|
|
import tree # pip install dm_tree
|
2021-11-01 21:46:02 +01:00
|
|
|
from typing import List, Optional
|
2019-12-30 15:27:32 -05:00
|
|
|
|
[RLlib] Upgrade gym version to 0.21 and deprecate pendulum-v0. (#19535)
* Fix QMix, SAC, and MADDPA too.
* Unpin gym and deprecate pendulum v0
Many tests in rllib depended on pendulum v0,
however in gym 0.21, pendulum v0 was deprecated
in favor of pendulum v1. This may change reward
thresholds, so will have to potentially rerun
all of the pendulum v1 benchmarks, or use another
environment in favor. The same applies to frozen
lake v0 and frozen lake v1
Lastly, all of the RLlib tests and have
been moved to python 3.7
* Add gym installation based on python version.
Pin python<= 3.6 to gym 0.19 due to install
issues with atari roms in gym 0.20
* Reformatting
* Fixing tests
* Move atari-py install conditional to req.txt
* migrate to new ale install method
* Fix QMix, SAC, and MADDPA too.
* Unpin gym and deprecate pendulum v0
Many tests in rllib depended on pendulum v0,
however in gym 0.21, pendulum v0 was deprecated
in favor of pendulum v1. This may change reward
thresholds, so will have to potentially rerun
all of the pendulum v1 benchmarks, or use another
environment in favor. The same applies to frozen
lake v0 and frozen lake v1
Lastly, all of the RLlib tests and have
been moved to python 3.7
* Add gym installation based on python version.
Pin python<= 3.6 to gym 0.19 due to install
issues with atari roms in gym 0.20
Move atari-py install conditional to req.txt
migrate to new ale install method
Make parametric_actions_cartpole return float32 actions/obs
Adding type conversions if obs/actions don't match space
Add utils to make elements match gym space dtypes
Co-authored-by: Jun Gong <jungong@anyscale.com>
Co-authored-by: sven1977 <svenmika1977@gmail.com>
2021-11-03 08:24:00 -07:00
|
|
|
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
2020-04-01 07:00:28 +02:00
|
|
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
2022-01-29 18:41:57 -08:00
|
|
|
from ray.rllib.utils.typing import SpaceStruct, TensorType, TensorStructType, Union
|
2020-02-22 23:19:49 +01:00
|
|
|
|
2020-06-30 10:13:20 +02:00
|
|
|
tf1, tf, tfv = try_import_tf()
|
2020-02-22 23:19:49 +01:00
|
|
|
torch, _ = try_import_torch()
|
2019-12-30 15:27:32 -05:00
|
|
|
|
|
|
|
# Tiny epsilon; used e.g. as the lower clip bound in `softmax` below to keep
# results strictly positive.
SMALL_NUMBER = 1e-6
# Some large int number. May be increased here, if needed.
LARGE_INTEGER = 100000000
# Min and Max outputs (clipped) from an NN-output layer interpreted as the
# log(x) of some x (e.g. a stddev of a normal
# distribution).
MIN_LOG_NN_OUTPUT = -5
MAX_LOG_NN_OUTPUT = 2
|
|
|
|
|
|
|
|
|
2021-11-01 21:46:02 +01:00
|
|
|
def aligned_array(size: int, dtype, align: int = 64) -> np.ndarray:
    """Returns an array of a given size that is 64-byte aligned.

    The returned array can be efficiently copied into GPU memory by TensorFlow.

    Args:
        size: The size (total number of items) of the array. For example,
            array([[0.0, 1.0], [2.0, 3.0]]) would have size=4.
        dtype: The numpy dtype of the array.
        align: The alignment to use.

    Returns:
        A np.ndarray with the given specifications.
    """
    num_bytes = size * dtype.itemsize
    # Over-allocate raw bytes so that an aligned start offset always exists.
    raw = np.empty(num_bytes + (align - 1), dtype=np.uint8)
    misalignment = raw.ctypes.data % align
    start = (align - misalignment) if misalignment else 0
    if num_bytes:
        aligned = raw[start : start + num_bytes].view(dtype)
    else:
        # Keep a zero-length slice INTO `raw` so numpy cannot optimize the
        # empty-array reference away.
        aligned = raw[start : start + 1][0:0].view(dtype)
    assert len(aligned) == size, len(aligned)
    assert aligned.ctypes.data % align == 0, aligned.ctypes.data
    return aligned
|
2019-12-30 15:27:32 -05:00
|
|
|
|
|
|
|
|
2022-01-29 18:41:57 -08:00
|
|
|
def concat_aligned(
    items: List[np.ndarray], time_major: Optional[bool] = None
) -> np.ndarray:
    """Concatenate arrays, ensuring the output is 64-byte aligned.

    We only align float arrays; other arrays are concatenated as normal.

    This should be used instead of np.concatenate() to improve performance
    when the output array is likely to be fed into TensorFlow.

    Args:
        items: The list of items to concatenate and align.
        time_major: Whether the data in items is time-major, in which
            case, we will concatenate along axis=1.

    Returns:
        The concat'd and aligned array.
    """
    if not items:
        return []
    if len(items) == 1:
        # Assume a single input is already aligned; forcing alignment here
        # would only add a needless copy.
        return items[0]

    concat_axis = 1 if time_major else 0
    first = items[0]
    if not (
        isinstance(first, np.ndarray)
        and first.dtype in [np.float32, np.float64, np.uint8]
    ):
        # Non-float/uint8 data: plain concatenation, no alignment.
        return np.concatenate(items, axis=concat_axis)

    # Pre-allocate a 64-byte-aligned flat buffer, then concat directly into
    # a reshaped view of it.
    flat = aligned_array(sum(arr.size for arr in items), first.dtype)
    if time_major is None:
        stacked_dim = sum(arr.shape[0] for arr in items)
        new_shape = (stacked_dim,) + first.shape[1:]
    elif time_major:
        stacked_dim = sum(arr.shape[1] for arr in items)
        new_shape = (first.shape[0], stacked_dim) + first.shape[2:]
    else:
        stacked_dim = sum(arr.shape[0] for arr in items)
        new_shape = (stacked_dim, first.shape[1]) + first.shape[2:]
    output = flat.reshape(new_shape)
    assert output.ctypes.data % 64 == 0, output.ctypes.data
    np.concatenate(items, out=output, axis=concat_axis)
    return output
|
2019-12-30 15:27:32 -05:00
|
|
|
|
2021-11-01 21:46:02 +01:00
|
|
|
|
2022-01-29 18:41:57 -08:00
|
|
|
def convert_to_numpy(
    x: TensorStructType, reduce_type: bool = True, reduce_floats=DEPRECATED_VALUE
):
    """Converts values in `stats` to non-Tensor numpy or python types.

    Args:
        x: Any (possibly nested) struct, the values in which will be
            converted and returned as a new struct with all torch/tf tensors
            being converted to numpy types.
        reduce_type: Whether to automatically reduce all float64 and int64 data
            into float32 and int32 data, respectively.
        reduce_floats: Deprecated; use `reduce_type` instead.

    Returns:
        A new struct with the same structure as `x`, but with all
        values converted to numpy arrays (on CPU).
    """
    if reduce_floats != DEPRECATED_VALUE:
        # Fixed: point users at the actual replacement arg (`reduce_type`);
        # the message previously named a non-existent `reduce_types`.
        deprecation_warning(old="reduce_floats", new="reduce_type", error=False)
        reduce_type = reduce_floats

    # The mapping function used to numpyize torch/tf Tensors (and move them
    # to the CPU beforehand).
    def mapping(item):
        if torch and isinstance(item, torch.Tensor):
            # 0-dim tensors become plain python scalars; all others become
            # (detached, CPU) numpy arrays.
            ret = (
                item.cpu().item()
                if len(item.size()) == 0
                else item.detach().cpu().numpy()
            )
        elif (
            tf and isinstance(item, (tf.Tensor, tf.Variable)) and hasattr(item, "numpy")
        ):
            assert tf.executing_eagerly()
            ret = item.numpy()
        else:
            ret = item
        # Optionally reduce 64-bit to 32-bit dtypes (only applies to actual
        # numpy arrays; python scalars from 0-dim tensors pass through).
        if reduce_type and isinstance(ret, np.ndarray):
            if np.issubdtype(ret.dtype, np.floating):
                ret = ret.astype(np.float32)
            elif np.issubdtype(ret.dtype, int):
                ret = ret.astype(np.int32)
        return ret

    return tree.map_structure(mapping, x)
|
2019-12-30 15:27:32 -05:00
|
|
|
|
|
|
|
|
2022-01-29 18:41:57 -08:00
|
|
|
def fc(
    x: np.ndarray,
    weights: np.ndarray,
    biases: Optional[np.ndarray] = None,
    framework: Optional[str] = None,
) -> np.ndarray:
    """Calculates FC (dense) layer outputs given weights/biases and input.

    Args:
        x: The input to the dense layer.
        weights: The weights matrix.
        biases: The biases vector. All 0s if None.
        framework: An optional framework hint (to figure out,
            e.g. whether to transpose torch weight matrices).

    Returns:
        The dense layer's output.
    """

    def _to_numpy(value, transpose=False):
        # Bring torch tensors and eager-tf variables back into numpy land;
        # optionally transpose afterwards.
        if torch and isinstance(value, torch.Tensor):
            value = value.cpu().detach().numpy()
        if tf and tf.executing_eagerly() and isinstance(value, tf.Variable):
            value = value.numpy()
        return np.transpose(value) if transpose else value

    x = _to_numpy(x)
    # Torch stores matrices in transpose (faster for backprop); detect that
    # layout from the shapes and undo it.
    needs_transpose = framework == "torch" and (
        x.shape[1] != weights.shape[0] and x.shape[1] == weights.shape[1]
    )
    weights = _to_numpy(weights, transpose=needs_transpose)
    biases = _to_numpy(biases)

    return np.matmul(x, weights) + (0.0 if biases is None else biases)
|
|
|
|
|
|
|
|
|
2022-01-29 18:41:57 -08:00
|
|
|
def flatten_inputs_to_1d_tensor(
    inputs: TensorStructType,
    spaces_struct: Optional[SpaceStruct] = None,
    time_axis: bool = False,
) -> TensorType:
    """Flattens arbitrary input structs according to the given spaces struct.

    Returns a single 1D tensor resulting from the different input
    components' values.

    Thereby:
    - Boxes (any shape) get flattened to (B, [T]?, -1). Note that image boxes
      are not treated differently from other types of Boxes and get
      flattened as well.
    - Discrete (int) values are one-hot'd, e.g. a batch of [1, 0, 3] (B=3 with
      Discrete(4) space) results in [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]].
    - MultiDiscrete values are multi-one-hot'd, e.g. a batch of
      [[0, 2], [1, 4]] (B=2 with MultiDiscrete([2, 5]) space) results in
      [[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 1]].

    Args:
        inputs: The inputs to be flattened.
        spaces_struct: The structure of the spaces that behind the input
        time_axis: Whether all inputs have a time-axis (after the batch axis).
            If True, will keep not only the batch axis (0th), but the time axis
            (1st) as-is and flatten everything from the 2nd axis up.

    Returns:
        A single 1D tensor resulting from concatenating all
        flattened/one-hot'd input components. Depending on the time_axis flag,
        the shape is (B, n) or (B, T, n).

    Examples:
        >>> # B=2
        >>> out = flatten_inputs_to_1d_tensor(
        ...     {"a": [1, 0], "b": [[[0.0], [0.1]], [[1.0], [1.1]]]},
        ...     spaces_struct=dict(a=Discrete(2), b=Box(shape=(2, 1)))
        ... )
        >>> print(out)
        ... [[0.0, 1.0, 0.0, 0.1], [1.0, 0.0, 1.0, 1.1]]  # B=2 n=4

        >>> # B=2; T=2
        >>> out = flatten_inputs_to_1d_tensor(
        ...     ([[1, 0], [0, 1]],
        ...      [[[0.0, 0.1], [1.0, 1.1]], [[2.0, 2.1], [3.0, 3.1]]]),
        ...     spaces_struct=tuple([Discrete(2), Box(shape=(2, ))]),
        ...     time_axis=True
        ... )
        >>> print(out)
        ... [[[0.0, 1.0, 0.0, 0.1], [1.0, 0.0, 1.0, 1.1]],
        ...  [[1.0, 0.0, 2.0, 2.1], [0.0, 1.0, 3.0, 3.1]]]  # B=2 T=2 n=4
    """
    flat_inputs = tree.flatten(inputs)
    if spaces_struct is None:
        flat_spaces = [None] * len(flat_inputs)
    else:
        flat_spaces = tree.flatten(spaces_struct)

    batch_dim = None
    time_dim = None
    processed = []
    for value, space in zip(flat_inputs, flat_spaces):
        assert isinstance(value, np.ndarray)

        # Remember batch (and, if applicable, time) dimension from the first
        # component seen.
        if batch_dim is None:
            batch_dim = value.shape[0]
            if time_axis:
                time_dim = value.shape[1]

        if isinstance(space, Discrete):
            # One-hot encoding.
            if time_axis:
                value = np.reshape(value, [batch_dim * time_dim])
            processed.append(one_hot(value, depth=space.n).astype(np.float32))
        elif isinstance(space, MultiDiscrete):
            # Multi one-hot encoding: one-hot each sub-space, then concat.
            if time_axis:
                value = np.reshape(value, [batch_dim * time_dim, -1])
            sub_one_hots = [
                one_hot(value[:, i], depth=n).astype(np.float32)
                for i, n in enumerate(space.nvec)
            ]
            processed.append(np.concatenate(sub_one_hots, axis=-1))
        else:
            # Box (or unknown space): simply flatten.
            if time_axis:
                value = np.reshape(value, [batch_dim * time_dim, -1])
            else:
                value = np.reshape(value, [batch_dim, -1])
            processed.append(value.astype(np.float32))

    merged = np.concatenate(processed, axis=-1)
    # Restore the time-dimension, if applicable.
    if time_axis:
        merged = np.reshape(merged, [batch_dim, time_dim, -1])
    return merged
|
|
|
|
|
|
|
|
|
2021-11-01 21:46:02 +01:00
|
|
|
def huber_loss(x: np.ndarray, delta: float = 1.0) -> np.ndarray:
    """Computes the Huber loss of `x`, elementwise.

    Quadratic for |x| < delta, linear beyond that.
    Reference: https://en.wikipedia.org/wiki/Huber_loss.
    """
    abs_x = np.abs(x)
    quadratic = 0.5 * np.power(x, 2.0)
    linear = delta * (abs_x - 0.5 * delta)
    return np.where(abs_x < delta, quadratic, linear)
|
2021-11-01 21:46:02 +01:00
|
|
|
|
|
|
|
|
|
|
|
def l2_loss(x: np.ndarray) -> np.ndarray:
    """Computes half the L2 norm of a tensor (w/o the sqrt): sum(x**2) / 2.

    Args:
        x: The input tensor.

    Returns:
        The l2-loss output according to the above formula given `x`.
    """
    return 0.5 * np.sum(np.square(x))
|
2019-12-30 15:27:32 -05:00
|
|
|
|
|
|
|
|
2022-01-29 18:41:57 -08:00
|
|
|
def lstm(
    x,
    weights: np.ndarray,
    biases: Optional[np.ndarray] = None,
    initial_internal_states: Optional[np.ndarray] = None,
    time_major: bool = False,
    forget_bias: float = 1.0,
):
    """Calculates LSTM layer output given weights/biases, states, and input.

    Args:
        x: The inputs to the LSTM layer including time-rank
            (0th if time-major, else 1st) and the batch-rank
            (1st if time-major, else 0th).
        weights: The weights matrix.
        biases: The biases vector. All 0s if None.
        initial_internal_states: The initial internal
            states to pass into the layer. All 0s if None.
        time_major: Whether to use time-major or not. Default: False.
        forget_bias: Gets added to first sigmoid (forget gate) output.
            Default: 1.0.

    Returns:
        Tuple consisting of 1) The LSTM layer's output and
        2) Tuple: Last (c-state, h-state).
    """
    sequence_length = x.shape[0 if time_major else 1]
    batch_size = x.shape[1 if time_major else 0]
    units = weights.shape[1] // 4  # 4 internal layers (3x sigmoid, 1x tanh)

    # Fixed: `biases=None` used to crash in the matmul-add below, even though
    # the docstring promises "All 0s if None".
    if biases is None:
        biases = np.zeros(shape=(weights.shape[1],))

    if initial_internal_states is None:
        c_states = np.zeros(shape=(batch_size, units))
        h_states = np.zeros(shape=(batch_size, units))
    else:
        c_states = initial_internal_states[0]
        h_states = initial_internal_states[1]

    # Create a placeholder for all n-time step outputs.
    if time_major:
        unrolled_outputs = np.zeros(shape=(sequence_length, batch_size, units))
    else:
        unrolled_outputs = np.zeros(shape=(batch_size, sequence_length, units))

    # Push the batch through the LSTM cell once per time step and capture the
    # outputs plus the final h- and c-states.
    for t in range(sequence_length):
        input_matrix = x[t, :, :] if time_major else x[:, t, :]
        input_matrix = np.concatenate((input_matrix, h_states), axis=1)
        input_matmul_matrix = np.matmul(input_matrix, weights) + biases
        # Forget gate (3rd slot in tf output matrix). Add static forget bias.
        sigmoid_1 = sigmoid(input_matmul_matrix[:, units * 2 : units * 3] + forget_bias)
        c_states = np.multiply(c_states, sigmoid_1)
        # Add gate (1st and 2nd slots in tf output matrix).
        sigmoid_2 = sigmoid(input_matmul_matrix[:, 0:units])
        tanh_3 = np.tanh(input_matmul_matrix[:, units : units * 2])
        c_states = np.add(c_states, np.multiply(sigmoid_2, tanh_3))
        # Output gate (last slot in tf output matrix).
        sigmoid_4 = sigmoid(input_matmul_matrix[:, units * 3 : units * 4])
        h_states = np.multiply(sigmoid_4, np.tanh(c_states))

        # Store this output time-slice.
        if time_major:
            unrolled_outputs[t, :, :] = h_states
        else:
            unrolled_outputs[:, t, :] = h_states

    return unrolled_outputs, (c_states, h_states)
|
2020-07-14 04:27:49 +02:00
|
|
|
|
|
|
|
|
2022-01-29 18:41:57 -08:00
|
|
|
def one_hot(
|
|
|
|
x: Union[TensorType, int],
|
|
|
|
depth: int = 0,
|
|
|
|
on_value: float = 1.0,
|
|
|
|
off_value: float = 0.0,
|
|
|
|
) -> np.ndarray:
|
2021-11-01 21:46:02 +01:00
|
|
|
"""One-hot utility function for numpy.
|
|
|
|
|
|
|
|
Thanks to qianyizhang:
|
|
|
|
https://gist.github.com/qianyizhang/07ee1c15cad08afb03f5de69349efc30.
|
2020-07-14 04:27:49 +02:00
|
|
|
|
|
|
|
Args:
|
2021-11-01 21:46:02 +01:00
|
|
|
x: The input to be one-hot encoded.
|
|
|
|
depth: The max. number to be one-hot encoded (size of last rank).
|
|
|
|
on_value: The value to use for on. Default: 1.0.
|
|
|
|
off_value: The value to use for off. Default: 0.0.
|
2020-07-14 04:27:49 +02:00
|
|
|
|
|
|
|
Returns:
|
2021-11-01 21:46:02 +01:00
|
|
|
The one-hot encoded equivalent of the input array.
|
2020-07-14 04:27:49 +02:00
|
|
|
"""
|
|
|
|
|
2021-11-01 21:46:02 +01:00
|
|
|
# Handle simple ints properly.
|
|
|
|
if isinstance(x, int):
|
|
|
|
x = np.array(x, dtype=np.int32)
|
|
|
|
# Handle torch arrays properly.
|
|
|
|
elif torch and isinstance(x, torch.Tensor):
|
|
|
|
x = x.numpy()
|
2020-07-14 04:27:49 +02:00
|
|
|
|
2021-11-01 21:46:02 +01:00
|
|
|
# Handle bool arrays correctly.
|
|
|
|
if x.dtype == np.bool_:
|
|
|
|
x = x.astype(np.int)
|
|
|
|
depth = 2
|
|
|
|
|
|
|
|
# If depth is not given, try to infer it from the values in the array.
|
|
|
|
if depth == 0:
|
|
|
|
depth = np.max(x) + 1
|
2022-01-29 18:41:57 -08:00
|
|
|
assert (
|
|
|
|
np.max(x) < depth
|
|
|
|
), "ERROR: The max. index of `x` ({}) is larger than depth ({})!".format(
|
|
|
|
np.max(x), depth
|
|
|
|
)
|
2021-11-01 21:46:02 +01:00
|
|
|
shape = x.shape
|
|
|
|
|
|
|
|
# Python 2.7 compatibility, (*shape, depth) is not allowed.
|
|
|
|
shape_list = list(shape[:])
|
|
|
|
shape_list.append(depth)
|
|
|
|
out = np.ones(shape_list) * off_value
|
|
|
|
indices = []
|
|
|
|
for i in range(x.ndim):
|
|
|
|
tiles = [1] * x.ndim
|
|
|
|
s = [1] * x.ndim
|
|
|
|
s[i] = -1
|
|
|
|
r = np.arange(shape[i]).reshape(s)
|
|
|
|
if i > 0:
|
|
|
|
tiles[i - 1] = shape[i - 1]
|
|
|
|
r = np.tile(r, tiles)
|
|
|
|
indices.append(r)
|
|
|
|
indices.append(x)
|
|
|
|
out[tuple(indices)] = on_value
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
def relu(x: np.ndarray, alpha: float = 0.0) -> np.ndarray:
    """Implementation of the leaky ReLU function.

    y = x * alpha if x < 0 else x

    Args:
        x: The input values.
        alpha: A scaling ("leak") factor to use for negative x.

    Returns:
        The leaky ReLU output for x.
    """
    # Fixed: the original passed `x` as np.maximum's `out=` argument,
    # silently overwriting the caller's input array in place.
    return np.maximum(x, x * alpha)
|
|
|
|
|
|
|
|
|
|
|
|
def sigmoid(x: np.ndarray, derivative: bool = False) -> np.ndarray:
    """Returns the sigmoid function applied to x.

    Alternatively, can return the derivative of the sigmoid function.

    Args:
        x: The input to the sigmoid function.
        derivative: Whether to return the derivative or not.
            Default: False.

    Returns:
        The sigmoid function (or its derivative) applied to x.
    """
    if not derivative:
        return 1 / (1 + np.exp(-x))
    # NOTE(review): the derivative form s'(x)=x*(1-x) presumes `x` is
    # already a sigmoid output — confirm against callers.
    return x * (1 - x)
|
|
|
|
|
|
|
|
|
2022-01-29 18:41:57 -08:00
|
|
|
def softmax(
    x: Union[np.ndarray, list], axis: int = -1, epsilon: Optional[float] = None
) -> np.ndarray:
    """Returns the softmax values for x.

    The exact formula used is:
    S(xi) = e^xi / SUMj(e^xj), where j goes over all elements in x.

    Args:
        x: The input to the softmax function.
        axis: The axis along which to softmax.
        epsilon: Optional epsilon as a minimum value. If None, use
            `SMALL_NUMBER`.

    Returns:
        The softmax over x.
    """
    # Fall back to the module default for a None (or zero) epsilon.
    if not epsilon:
        epsilon = SMALL_NUMBER
    exponentials = np.exp(x)
    normalized = exponentials / np.sum(exponentials, axis, keepdims=True)
    # Never return values below `epsilon`.
    return np.maximum(normalized, epsilon)
|