2021-12-11 14:57:58 +01:00
|
|
|
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
|
|
|
|
import numpy as np
|
|
|
|
import os
|
|
|
|
import shutil
|
|
|
|
import tree # pip install dm_tree
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
import ray
|
2022-06-04 07:35:24 +02:00
|
|
|
from ray.rllib.algorithms.bc import BC
|
|
|
|
from ray.rllib.algorithms.pg import PG, DEFAULT_CONFIG
|
2021-12-11 14:57:58 +01:00
|
|
|
from ray.rllib.examples.env.random_env import RandomEnv
|
|
|
|
from ray.rllib.offline.json_reader import JsonReader
|
|
|
|
from ray.rllib.utils.test_utils import framework_iterator
|
|
|
|
|
|
|
|
# Action spaces to run the test with, keyed by a short descriptive name.
SPACES = {
    # Dict space containing a nested Dict, a Discrete, a Tuple and a
    # zero-dimensional (shape=()) int64 Box.
    "dict": Dict(
        {
            "a": Dict(
                {"aa": Box(-1.0, 1.0, shape=(3,)), "ab": MultiDiscrete([4, 3])}
            ),
            "b": Discrete(3),
            "c": Tuple([Box(0, 10, (2,), dtype=np.int32), Discrete(2)]),
            "d": Box(0, 3, (), dtype=np.int64),
        }
    ),
    # Tuple space containing a nested Tuple, a MultiDiscrete and a Dict.
    "tuple": Tuple(
        [
            Tuple([Box(-1.0, 1.0, shape=(2,)), Discrete(3)]),
            MultiDiscrete([4, 3]),
            Dict({"a": Box(0, 100, (), dtype=np.int32), "b": Discrete(2)}),
        ]
    ),
    # Non-nested spaces.
    "multidiscrete": MultiDiscrete([2, 3, 4]),
    "intbox": Box(0, 100, (2,), dtype=np.int32),
}
|
|
|
|
|
|
|
|
|
|
|
|
class NestedActionSpacesTest(unittest.TestCase):
    """Tests PG and BC on `RandomEnv` with the action spaces listed in `SPACES`.

    For every action space, PG is trained once with and once without action
    flattening while writing its samples to an output directory. The written
    actions are then read back and checked, and BC is trained on the same
    offline data.
    """

    @classmethod
    def setUpClass(cls):
        """Starts Ray (with 5 CPUs) once for all tests of this class."""
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        """Shuts Ray down after all tests of this class have run."""
        ray.shutdown()

    def test_nested_action_spaces(self):
        """Checks written actions (flattened or not) and re-reads them via BC.

        With flattening enabled, the actions in the output file must be a 2D
        numpy array whose first dimension equals the batch length. With
        flattening disabled, they must have the same nested structure as the
        policy's `action_space_struct`.
        """
        import tempfile

        config = DEFAULT_CONFIG.copy()
        config["env"] = RandomEnv
        # Write output to check, whether actions are written correctly.
        # `mkdtemp` creates the directory inside the platform's temp dir, so
        # no shell call and no path fix-up is needed.
        tmp_dir = tempfile.mkdtemp()
        # Do not leave the offline files behind, even if the test fails.
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
        config["output"] = tmp_dir
        # Switch off OPE as we don't write action-probs.
        # TODO: We should probably always write those if `output` is given.
        config["off_policy_estimation_methods"] = {}

        # Pretend actions in offline files are already normalized.
        config["actions_in_input_normalized"] = True

        for _ in framework_iterator(config):
            for action_space in SPACES.values():
                config["env_config"] = {
                    "action_space": action_space,
                }
                for flatten in [True, False]:
                    print(f"A={action_space} flatten={flatten}")
                    # Start each run without the previous run's output files.
                    shutil.rmtree(config["output"])
                    config["_disable_action_flattening"] = not flatten
                    pg = PG(config)
                    pg.train()
                    pg.stop()

                    # Check actions in output file (whether properly flattened
                    # or not).
                    reader = JsonReader(
                        inputs=config["output"],
                        ioctx=pg.workers.local_worker().io_context,
                    )
                    sample_batch = reader.next()
                    actions = sample_batch["actions"]
                    if flatten:
                        self.assertIsInstance(actions, np.ndarray)
                        self.assertEqual(len(actions.shape), 2)
                        self.assertEqual(actions.shape[0], len(sample_batch))
                    else:
                        tree.assert_same_structure(
                            pg.get_policy().action_space_struct,
                            actions,
                        )

                    # Test, whether offline data can be properly read by
                    # BC, configured accordingly.

                    # Doing this for backwards compatibility until we move to
                    # parquet as default output.
                    config["input"] = lambda ioctx: JsonReader(
                        ioctx.config["input_config"]["paths"], ioctx
                    )
                    config["input_config"] = {"paths": config["output"]}
                    # BC should only read the offline data, not write any.
                    del config["output"]
                    bc = BC(config=config)
                    bc.train()
                    bc.stop()
                    # Restore the PG settings for the next loop iteration.
                    config["output"] = tmp_dir
                    config["input"] = "sampler"
|
[CI] Check test files for `if __name__...` snippet (#25322)
Bazel operates by simply running the Python scripts given to it in `py_test`. If the script doesn't invoke pytest on itself in the `if __name__ == "__main__"` snippet, no tests will be run, and the script will pass. This has led to several tests (indeed, some are fixed in this PR) that, despite having been written, have never run in CI. This PR adds a lint check to check all `py_test` sources for the presence of the `if __name__ == "__main__"` snippet, and will fail CI if there are any detected without it. This system is only enabled for libraries right now (tune, train, air, rllib), but it could be trivially extended to other modules if approved.
2022-06-02 11:30:00 +02:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly hands it to pytest and exits with pytest's
    # return code.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|