2022-07-20 23:25:53 +01:00
|
|
|
import gym
|
2022-01-05 11:29:44 +01:00
|
|
|
import numpy as np
|
|
|
|
import tree # pip install dm_tree
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
import ray
|
|
|
|
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
|
|
|
from ray.rllib.utils.numpy import flatten_inputs_to_1d_tensor as flatten_np
|
2022-04-29 10:27:06 +02:00
|
|
|
from ray.rllib.utils.numpy import make_action_immutable
|
2022-01-05 11:29:44 +01:00
|
|
|
from ray.rllib.utils.test_utils import check
|
2022-07-20 23:25:53 +01:00
|
|
|
from ray.rllib.utils.tf_utils import (
|
|
|
|
flatten_inputs_to_1d_tensor as flatten_tf,
|
|
|
|
one_hot as one_hot_tf,
|
|
|
|
)
|
|
|
|
from ray.rllib.utils.torch_utils import (
|
|
|
|
flatten_inputs_to_1d_tensor as flatten_torch,
|
|
|
|
one_hot as one_hot_torch,
|
|
|
|
)
|
2022-01-05 11:29:44 +01:00
|
|
|
|
|
|
|
# Framework handles used by the tests below: `tf1` (enables eager execution
# in setUpClass), `tf` (tensor conversion / variables) and `torch` (tensor
# conversion). `tfv` and the second value returned by try_import_torch() are
# not referenced in this file.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
|
|
|
|
|
|
|
|
|
|
|
|
class TestUtils(unittest.TestCase):
    """Tests for RLlib's flattening, one-hot and action-immutability utils.

    The flattening tests feed the same nested struct through the numpy, TF
    and Torch implementations and compare each result against one shared
    expected array; the one-hot test does the same for TF and Torch.
    """

    # Nested struct of data with B=3.
    struct = {
        "a": np.array([1, 3, 2]),
        "b": (
            np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
            np.array(
                [[[8.0], [7.0], [6.0]], [[5.0], [4.0], [3.0]], [[2.0], [1.0], [0.0]]]
            ),
        ),
        "c": {
            "ca": np.array([[1, 2], [3, 5], [0, 1]]),
            "cb": np.array([1.0, 2.0, 3.0]),
        },
    }

    # Nested struct of data with B=2 and T=1. Its two batch items are the
    # first two batch items of `struct`, each given an extra time axis.
    struct_w_time_axis = {
        "a": np.array([[1], [3]]),
        "b": (
            np.array([[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]),
            np.array([[[[8.0], [7.0], [6.0]]], [[[5.0], [4.0], [3.0]]]]),
        ),
        "c": {"ca": np.array([[[1, 2]], [[3, 5]]]), "cb": np.array([[1.0], [2.0]])},
    }

    # Corresponding space struct.
    spaces = {
        "a": gym.spaces.Discrete(4),
        "b": (gym.spaces.Box(-1.0, 10.0, (3,)), gym.spaces.Box(-1.0, 1.0, (3, 1))),
        "c": {
            "ca": gym.spaces.MultiDiscrete([4, 6]),
            "cb": gym.spaces.Box(-1.0, 1.0, ()),
        },
    }

    # Expected result of flattening `struct` against `spaces` (shape (3, 21)):
    # one row per batch item, with the segments laid out as annotated below.
    # fmt: off
    expected_flat = np.array(
        [
            [
                0.0, 1.0, 0.0, 0.0,  # one_hot(a=1, depth 4)
                1.0, 2.0, 3.0,  # b[0]
                8.0, 7.0, 6.0,  # b[1], (3, 1) flattened
                0.0, 1.0, 0.0, 0.0,  # one_hot(ca[0]=1, depth 4)
                0.0, 0.0, 1.0, 0.0, 0.0, 0.0,  # one_hot(ca[1]=2, depth 6)
                1.0,  # cb
            ],
            [
                0.0, 0.0, 0.0, 1.0,  # one_hot(a=3, depth 4)
                4.0, 5.0, 6.0,  # b[0]
                5.0, 4.0, 3.0,  # b[1], (3, 1) flattened
                0.0, 0.0, 0.0, 1.0,  # one_hot(ca[0]=3, depth 4)
                0.0, 0.0, 0.0, 0.0, 0.0, 1.0,  # one_hot(ca[1]=5, depth 6)
                2.0,  # cb
            ],
            [
                0.0, 0.0, 1.0, 0.0,  # one_hot(a=2, depth 4)
                7.0, 8.0, 9.0,  # b[0]
                2.0, 1.0, 0.0,  # b[1], (3, 1) flattened
                1.0, 0.0, 0.0, 0.0,  # one_hot(ca[0]=0, depth 4)
                0.0, 1.0, 0.0, 0.0, 0.0, 0.0,  # one_hot(ca[1]=1, depth 6)
                3.0,  # cb
            ],
        ]
    )
    # fmt: on

    # `struct_w_time_axis` holds the first two batch items of `struct` with
    # T=1, so its flattened form is the first two rows above with a time
    # axis inserted: shape (2, 1, 21).
    expected_flat_w_time_axis = expected_flat[:2, np.newaxis, :]

    @classmethod
    def setUpClass(cls) -> None:
        tf1.enable_eager_execution()
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_make_action_immutable(self):
        from types import MappingProxyType

        # Test Box space.
        space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)
        action = make_action_immutable(space.sample())
        self.assertFalse(action.flags["WRITEABLE"])

        # Test Discrete space.
        # Nothing to be tested as sampled actions are integers
        # and integers are immutable by nature.

        # Test MultiDiscrete space.
        space = gym.spaces.MultiDiscrete([3, 3, 3])
        action = make_action_immutable(space.sample())
        self.assertFalse(action.flags["WRITEABLE"])

        # Test MultiBinary space.
        space = gym.spaces.MultiBinary([2, 2, 2])
        action = make_action_immutable(space.sample())
        self.assertFalse(action.flags["WRITEABLE"])

        # Test Tuple space.
        space = gym.spaces.Tuple(
            (
                gym.spaces.Discrete(2),
                gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
            )
        )
        action = tree.traverse(make_action_immutable, space.sample(), top_down=False)
        self.assertFalse(action[1].flags["WRITEABLE"])

        # Test Dict space.
        space = gym.spaces.Dict(
            {
                "a": gym.spaces.Discrete(2),
                "b": gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
                "c": gym.spaces.Tuple(
                    (
                        gym.spaces.Discrete(2),
                        gym.spaces.Box(
                            low=-1.0, high=1.0, shape=(8,), dtype=np.float32
                        ),
                    )
                ),
            }
        )
        action = tree.traverse(make_action_immutable, space.sample(), top_down=False)

        # The dict itself must reject item assignment ...
        with self.assertRaises(TypeError):
            action["a"] = 5
        # ... and every (nested) array leaf must be read-only.
        self.assertFalse(action["b"].flags["WRITEABLE"])
        self.assertFalse(action["c"][1].flags["WRITEABLE"])
        self.assertIsInstance(action, MappingProxyType)

    def test_flatten_inputs_to_1d_tensor(self):
        # B=3; no time axis.
        check(
            flatten_np(self.struct, spaces_struct=self.spaces),
            self.expected_flat,
        )

        struct_tf = tree.map_structure(tf.convert_to_tensor, self.struct)
        check(
            flatten_tf(struct_tf, spaces_struct=self.spaces),
            self.expected_flat,
        )

        struct_torch = tree.map_structure(torch.from_numpy, self.struct)
        check(
            flatten_torch(struct_torch, spaces_struct=self.spaces),
            self.expected_flat,
        )

    def test_flatten_inputs_to_1d_tensor_w_time_axis(self):
        # B=2; T=1
        check(
            flatten_np(
                self.struct_w_time_axis, spaces_struct=self.spaces, time_axis=True
            ),
            self.expected_flat_w_time_axis,
        )

        struct_tf = tree.map_structure(tf.convert_to_tensor, self.struct_w_time_axis)
        check(
            flatten_tf(struct_tf, spaces_struct=self.spaces, time_axis=True),
            self.expected_flat_w_time_axis,
        )

        struct_torch = tree.map_structure(torch.from_numpy, self.struct_w_time_axis)
        check(
            flatten_torch(struct_torch, spaces_struct=self.spaces, time_axis=True),
            self.expected_flat_w_time_axis,
        )

    def test_one_hot(self):
        space = gym.spaces.MultiDiscrete([[3, 3], [3, 3]])
        # Each of the 4 sub-actions [0, 2, 1, 0] is one-hot encoded with
        # depth 3 and the encodings are concatenated along the last axis.
        expected = np.array([[1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0]])

        # TF
        x = tf.Variable([[0, 2, 1, 0]], dtype=tf.int32)
        y = one_hot_tf(x, space)
        check(y.numpy(), expected)

        # Torch
        x = torch.tensor([[0, 2, 1, 0]], dtype=torch.int32)
        y = one_hot_torch(x, space)
        check(y.numpy(), expected)
|
|
|
|
|
2022-01-05 11:29:44 +01:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    import pytest

    # Run this module's tests verbosely under pytest and propagate its
    # exit status to the shell.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|