[RLlib] Issue 8769 broken OOM tests_dir cases (R & S). (#8770)

Sven Mika, 2020-06-05 08:34:21 +02:00 (committed by GitHub)
parent 9410e5884d
commit 97d524c075
9 changed files with 91 additions and 65 deletions


@@ -303,7 +303,7 @@ script:
# ray serve tests
- if [ $RAY_CI_SERVE_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci --test_tag_filters=-jenkins_only python/ray/serve/...; fi
# ray dashboard tests
- if [ "$RAY_CI_DASHBOARD_AFFECTED" == "1" ]; then ./ci/keep_alive bazel test python/ray/dashboard/...; fi


@@ -1174,6 +1174,20 @@ py_test(
srcs = ["tests/test_evaluators.py"]
)
py_test(
name = "tests/test_exec_api",
tags = ["tests_dir", "tests_dir_E"],
size = "medium",
srcs = ["tests/test_exec_api.py"]
)
py_test(
name = "tests/test_execution",
tags = ["tests_dir", "tests_dir_E"],
size = "medium",
srcs = ["tests/test_execution.py"]
)
py_test(
name = "tests/test_export",
tags = ["tests_dir", "tests_dir_E"],
@@ -1216,13 +1230,6 @@ py_test(
srcs = ["tests/test_io.py"]
)
py_test(
name = "tests/test_execution",
tags = ["tests_dir", "tests_dir_E"],
size = "medium",
srcs = ["tests/test_execution.py"]
)
py_test(
name = "tests/test_local",
tags = ["tests_dir", "tests_dir_L"],
@@ -1248,7 +1255,7 @@ py_test(
py_test(
name = "tests/test_multi_agent_env",
tags = ["tests_dir", "tests_dir_M"],
size = "large",
size = "medium",
srcs = ["tests/test_multi_agent_env.py"]
)
@@ -1267,17 +1274,10 @@ py_test(
srcs = ["tests/test_nested_observation_spaces.py"]
)
py_test(
name = "tests/test_exec_api",
tags = ["tests_dir", "tests_dir_E"],
size = "medium",
srcs = ["tests/test_exec_api.py"]
)
py_test(
name = "tests/test_reproducibility",
tags = ["tests_dir", "tests_dir_R"],
size = "large",
size = "medium",
srcs = ["tests/test_reproducibility.py"]
)
@@ -1293,7 +1293,7 @@ py_test(
py_test(
name = "tests/test_rollout_worker",
tags = ["tests_dir", "tests_dir_R"],
size = "large",
size = "medium",
srcs = ["tests/test_rollout_worker.py"]
)
@@ -1307,7 +1307,7 @@ py_test(
py_test(
name = "tests/test_supported_spaces",
tags = ["tests_dir", "tests_dir_S"],
size = "large",
size = "enormous",
srcs = ["tests/test_supported_spaces.py"]
)


@@ -54,6 +54,10 @@ def get_policy_class(config):
def validate_config(config):
if config["entropy_coeff"] < 0:
raise DeprecationWarning("entropy_coeff must be >= 0")
if config["sample_async"] and config["framework"] == "torch":
config["sample_async"] = False
logger.warning("`sample_async=True` is not supported for PyTorch! "
"Multithreading can lead to crashes.")
def execution_plan(workers, config):


@@ -11,8 +11,8 @@ import ray
from ray.rllib.agents import Trainer, with_common_config
from ray.rllib.agents.ars.ars_tf_policy import ARSTFPolicy
from ray.rllib.agents.es import optimizers
from ray.rllib.agents.es import utils
from ray.rllib.agents.es import optimizers, utils
from ray.rllib.agents.es.es import validate_config
from ray.rllib.agents.es.es_tf_policy import rollout
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
@@ -179,6 +179,7 @@ class ARSTrainer(Trainer):
@override(Trainer)
def _init(self, config, env_creator):
validate_config(config)
env_context = EnvContext(config["env_config"] or {}, worker_index=0)
env = env_creator(env_context)


@@ -169,6 +169,11 @@ def get_policy_class(config):
return policy_cls
def validate_config(config):
if config["num_workers"] <= 0:
raise ValueError("`num_workers` must be > 0 for ES!")
class ESTrainer(Trainer):
"""Large-scale implementation of Evolution Strategies in Ray."""
@@ -177,6 +182,7 @@ class ESTrainer(Trainer):
@override(Trainer)
def _init(self, config, env_creator):
validate_config(config)
env_context = EnvContext(config["env_config"] or {}, worker_index=0)
env = env_creator(env_context)
policy_cls = get_policy_class(config)


@@ -6,9 +6,13 @@ import unittest
from ray.rllib.utils.test_utils import framework_iterator
def rollout_test(algo, env="CartPole-v0"):
def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False):
extra_config = ""
if algo == "ES":
extra_config = ",\"episodes_per_batch\": 1,\"train_batch_size\": 10, "\
"\"noise_size\": 250000"
for fw in framework_iterator(frameworks=("torch", "tf")):
for fw in framework_iterator(frameworks=("tf", "torch")):
fw_ = ", \"framework\": \"{}\"".format(fw)
tmp_dir = os.popen("mktemp -d").read()[:-1]
@@ -22,8 +26,8 @@ def rollout_test(algo, env="CartPole-v0"):
os.path.exists(rllib_dir)))
os.system("python {}/train.py --local-dir={} --run={} "
"--checkpoint-freq=1 ".format(rllib_dir, tmp_dir, algo) +
"--config='{" +
"\"num_workers\": 0, \"num_gpus\": 0{}".format(fw_) +
"--config='{" + "\"num_workers\": 1, \"num_gpus\": 0{}{}".
format(fw_, extra_config) +
", \"model\": {\"fcnet_hiddens\": [10]}"
"}' --stop='{\"training_iteration\": 1, "
"\"timesteps_per_iter\": 5, "
@@ -44,12 +48,13 @@ def rollout_test(algo, env="CartPole-v0"):
print("rollout output (10 steps) exists!".format(checkpoint_path))
# Test rolling out 1 episode.
os.popen("python {}/rollout.py --run={} \"{}\" --episodes=1 "
"--out=\"{}/rollouts_1episode.pkl\" --no-render".format(
rllib_dir, algo, checkpoint_path, tmp_dir)).read()
if not os.path.exists(tmp_dir + "/rollouts_1episode.pkl"):
sys.exit(1)
print("rollout output (1 ep) exists!".format(checkpoint_path))
if test_episode_rollout:
os.popen("python {}/rollout.py --run={} \"{}\" --episodes=1 "
"--out=\"{}/rollouts_1episode.pkl\" --no-render".format(
rllib_dir, algo, checkpoint_path, tmp_dir)).read()
if not os.path.exists(tmp_dir + "/rollouts_1episode.pkl"):
sys.exit(1)
print("rollout output (1 ep) exists!".format(checkpoint_path))
# Cleanup.
os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
@@ -72,13 +77,10 @@ class TestRollout(unittest.TestCase):
rollout_test("ES")
def test_impala(self):
rollout_test("IMPALA", env="Pong-ram-v4")
def test_pg(self):
rollout_test("PG")
rollout_test("IMPALA", env="CartPole-v0")
def test_ppo(self):
rollout_test("PPO", env="Pendulum-v0")
rollout_test("PPO", env="CartPole-v0", test_episode_rollout=True)
def test_sac(self):
rollout_test("SAC", env="Pendulum-v0")


@@ -14,23 +14,25 @@ def check_support_multiagent(alg, config):
register_env("multi_agent_cartpole",
lambda _: MultiAgentCartPole({"num_agents": 2}))
config["log_level"] = "ERROR"
for _ in framework_iterator(config, frameworks=("tf", "torch")):
for _ in framework_iterator(config, frameworks=("torch", "tf")):
if alg in ["DDPG", "APEX_DDPG", "SAC"]:
a = get_agent_class(alg)(
config=config, env="multi_agent_mountaincar")
else:
a = get_agent_class(alg)(config=config, env="multi_agent_cartpole")
try:
a.train()
print(a.train())
finally:
a.stop()
class ModelSupportedSpaces(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=4, ignore_reinit_error=True)
class TestSupportedMultiAgent(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init(num_cpus=4)
def tearDown(self):
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_a3c_multiagent(self):
@@ -45,10 +47,11 @@ class ModelSupportedSpaces(unittest.TestCase):
check_support_multiagent(
"APEX", {
"num_workers": 2,
"timesteps_per_iteration": 1000,
"timesteps_per_iteration": 100,
"num_gpus": 0,
"buffer_size": 1000,
"min_iter_time_s": 1,
"learning_starts": 1000,
"learning_starts": 10,
"target_network_update_freq": 100,
})
@@ -56,10 +59,11 @@ class ModelSupportedSpaces(unittest.TestCase):
check_support_multiagent(
"APEX_DDPG", {
"num_workers": 2,
"timesteps_per_iteration": 1000,
"timesteps_per_iteration": 100,
"buffer_size": 1000,
"num_gpus": 0,
"min_iter_time_s": 1,
"learning_starts": 1000,
"learning_starts": 10,
"target_network_update_freq": 100,
"use_state_preprocessor": True,
})
@@ -68,12 +72,16 @@ class ModelSupportedSpaces(unittest.TestCase):
check_support_multiagent(
"DDPG", {
"timesteps_per_iteration": 1,
"buffer_size": 1000,
"use_state_preprocessor": True,
"learning_starts": 500,
})
def test_dqn_multiagent(self):
check_support_multiagent("DQN", {"timesteps_per_iteration": 1})
check_support_multiagent("DQN", {
"timesteps_per_iteration": 1,
"buffer_size": 1000,
})
def test_impala_multiagent(self):
check_support_multiagent("IMPALA", {"num_gpus": 0})
@@ -94,6 +102,7 @@ class ModelSupportedSpaces(unittest.TestCase):
def test_sac_multiagent(self):
check_support_multiagent("SAC", {
"num_workers": 0,
"buffer_size": 1000,
"normalize_actions": False,
})


@@ -48,7 +48,7 @@ OBSERVATION_SPACES_TO_TEST = {
}
def check_support(alg, config, check_bounds=False, tfe=False):
def check_support(alg, config, train=True, check_bounds=False, tfe=False):
config["log_level"] = "ERROR"
def _do_check(alg, config, a_name, o_name):
@@ -83,7 +83,8 @@ def check_support(alg, config, check_bounds=False, tfe=False):
assert isinstance(a.get_policy().model, TorchFCNetV2)
else:
assert isinstance(a.get_policy().model, FCNetV2)
a.train()
if train:
a.train()
except UnsupportedSpaceException:
stat = "unsupported"
finally:
@@ -99,19 +100,22 @@ def check_support(alg, config, check_bounds=False, tfe=False):
if tfe:
frameworks += ("tfe", )
for _ in framework_iterator(config, frameworks=frameworks):
# Check all action spaces.
# Check all action spaces (using a discrete obs-space).
for a_name, action_space in ACTION_SPACES_TO_TEST.items():
_do_check(alg, config, a_name, "discrete")
# Check all obs spaces.
# Check all obs spaces (using a supported action-space).
for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items():
_do_check(alg, config, "discrete", o_name)
a_name = "discrete" if alg not in ["DDPG", "SAC"] else "vector"
_do_check(alg, config, a_name, o_name)
class ModelSupportedSpaces(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=4, ignore_reinit_error=True, local_mode=True)
class TestSupportedSpaces(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init(num_cpus=4)
def tearDown(self):
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_a3c(self):
@@ -119,7 +123,7 @@ class ModelSupportedSpaces(unittest.TestCase):
check_support("A3C", config, check_bounds=True)
def test_appo(self):
check_support("APPO", {"num_gpus": 0, "vtrace": False})
check_support("APPO", {"num_gpus": 0, "vtrace": False}, train=False)
check_support("APPO", {"num_gpus": 0, "vtrace": True})
def test_ars(self):
@@ -138,12 +142,13 @@ class ModelSupportedSpaces(unittest.TestCase):
"ou_base_scale": 100.0
},
"timesteps_per_iteration": 1,
"buffer_size": 1000,
"use_state_preprocessor": True,
},
check_bounds=True)
def test_dqn(self):
config = {"timesteps_per_iteration": 1}
config = {"timesteps_per_iteration": 1, "buffer_size": 1000}
check_support("DQN", config, tfe=True)
def test_es(self):
@@ -170,10 +175,10 @@ class ModelSupportedSpaces(unittest.TestCase):
def test_pg(self):
config = {"num_workers": 1, "optimizer": {}}
check_support("PG", config, check_bounds=True, tfe=True)
check_support("PG", config, train=False, check_bounds=True, tfe=True)
def test_sac(self):
check_support("SAC", {}, check_bounds=True)
check_support("SAC", {"buffer_size": 1000}, check_bounds=True)
if __name__ == "__main__":


@@ -8,16 +8,15 @@ pendulum-ppo:
config:
# Works for both torch and tf.
framework: tf
train_batch_size: 2048
train_batch_size: 512
vf_clip_param: 10.0
num_workers: 0
num_envs_per_worker: 10
num_envs_per_worker: 20
lambda: 0.1
gamma: 0.95
lr: 0.0003
sgd_minibatch_size: 64
num_sgd_iter: 10
num_sgd_iter: 6
model:
fcnet_hiddens: [64, 64]
batch_mode: complete_episodes
fcnet_hiddens: [256, 256]
observation_filter: MeanStdFilter