2019-03-16 13:34:09 -07:00
|
|
|
import gym
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
import ray
|
|
|
|
from ray.rllib import _register_all
|
2021-02-08 12:05:16 +01:00
|
|
|
from ray.rllib.agents.registry import get_trainer_class
|
2020-05-27 16:19:13 +02:00
|
|
|
from ray.rllib.utils.test_utils import framework_iterator
|
2019-03-16 13:34:09 -07:00
|
|
|
from ray.tune.registry import register_env
|
|
|
|
|
|
|
|
|
|
|
|
class FaultInjectEnv(gym.Env):
    """CartPole wrapper whose `step()` raises on selected worker indices.

    The env config must carry a "bad_indices" entry listing the worker
    indices (ints) for which `step()` raises a ValueError. The failure is
    only triggered while the config's `recreated_worker` flag is falsy, so a
    worker that has been recreated steps normally.

    Examples:
        >>> from ray.rllib.env.env_context import EnvContext
        >>> # Fails for workers 1 and 2, but not for the local worker or
        >>> # any other worker whose index is not 1 or 2.
        >>> bad_env = FaultInjectEnv(
        ...     EnvContext({"bad_indices": [1, 2]},
        ...                worker_index=1, num_workers=3))
    """

    def __init__(self, config):
        # The wrapped env supplies the real dynamics and the spaces.
        wrapped = gym.make("CartPole-v0")
        self.env = wrapped
        self._skip_env_checking = True
        self.action_space = wrapped.action_space
        self.observation_space = wrapped.observation_space
        # Kept around so `step()` can look up the worker index and the
        # list of indices that should fail.
        self.config = config

    def reset(self):
        """Reset the wrapped env and return whatever it returns."""
        return self.env.reset()

    def step(self, action):
        """Step the wrapped env, or raise if this worker is marked as bad.

        Raises:
            ValueError: If this env's worker index is listed under
                "bad_indices" and the worker has not been recreated.
        """
        ctx = self.config
        marked_bad = ctx.worker_index in ctx["bad_indices"]
        # Only the original workers fail; a recreated worker is healthy.
        if marked_bad and not ctx.recreated_worker:
            raise ValueError(
                "This is a simulated error from "
                f"worker-idx={self.config.worker_index}."
            )
        return self.env.step(action)
|
|
|
|
|
|
|
|
|
|
|
|
class IgnoresWorkerFailure(unittest.TestCase):
    """Tests trainer behavior when the env on some remote workers raises.

    Three scenarios are covered, each driven through `do_test()`:
    failures are ignored ("ignore_worker_failures"), all workers fail and
    training raises, and all workers fail but are recreated
    ("recreate_failed_workers").
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Start Ray once for all tests of this class."""
        ray.init(num_cpus=6, local_mode=True)

    @classmethod
    def tearDownClass(cls) -> None:
        """Shut Ray down after the last test of this class."""
        ray.shutdown()

    def do_test(self, alg: str, config: dict, fn=None):
        """Run one fault scenario for `alg`.

        Args:
            alg: Name of the algorithm, resolved via `get_trainer_class`.
            config: Trainer config; the scenario function adds its own
                keys to this dict in place.
            fn: Scenario callable taking `(alg, config)`. Defaults to
                `_do_test_fault_ignore`.
        """
        fn = fn or self._do_test_fault_ignore
        try:
            fn(alg, config)
        finally:
            _register_all()  # re-register the evicted objects

    def _do_test_fault_ignore(self, alg: str, config: dict):
        """One of two workers fails; training must go on with the other."""
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_trainer_class(alg)

        # Test fault handling
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        # Make worker idx=1 fail. Other workers will be ok.
        config["env_config"] = {"bad_indices": [1]}

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            a = agent_cls(config=config, env="fault_env")
            try:
                result = a.train()
                # `assertTrue(x, 1)` would treat `1` as the failure message
                # and only check truthiness; compare the count instead.
                self.assertEqual(result["num_healthy_workers"], 1)
            finally:
                # Always release the trainer, even if the check fails.
                a.stop()

    def _do_test_fault_fatal(self, alg: str, config: dict):
        """Both workers fail; `train()` is expected to raise."""
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_trainer_class(alg)

        # Test raises real error when out of workers
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        # Make both worker idx=1 and 2 fail.
        config["env_config"] = {"bad_indices": [1, 2]}

        for _ in framework_iterator(config, frameworks=("torch", "tf")):
            a = agent_cls(config=config, env="fault_env")
            try:
                self.assertRaises(Exception, a.train)
            finally:
                a.stop()

    def _do_test_fault_fatal_but_recreate(self, alg: str, config: dict):
        """Both workers fail, but are recreated so training succeeds."""
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_trainer_class(alg)

        # Test that training recovers when failed workers get recreated.
        config["num_workers"] = 2
        config["recreate_failed_workers"] = True
        # Make both worker idx=1 and 2 fail.
        config["env_config"] = {"bad_indices": [1, 2]}

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            a = agent_cls(config=config, env="fault_env")
            try:
                # Before training, neither a worker nor its env context may
                # be flagged as recreated.
                self.assertTrue(
                    not any(
                        ray.get(
                            worker.apply.remote(
                                lambda w: w.recreated_worker
                                or w.env_context.recreated_worker
                            )
                        )
                        for worker in a.workers.remote_workers()
                    )
                )
                # Expect this to go well and all faulty workers are
                # recovered.
                result = a.train()
                self.assertEqual(result["num_healthy_workers"], 2)
                # After training, every worker and its env context must be
                # flagged as recreated.
                self.assertTrue(
                    all(
                        ray.get(
                            worker.apply.remote(
                                lambda w: w.recreated_worker
                                and w.env_context.recreated_worker
                            )
                        )
                        for worker in a.workers.remote_workers()
                    )
                )
                # This should also work several times.
                result = a.train()
                self.assertEqual(result["num_healthy_workers"], 2)
            finally:
                a.stop()

    def test_fatal(self):
        """PG: all workers fail and nothing is recreated."""
        # Test the case where all workers fail (w/o recovery).
        self.do_test("PG", {"optimizer": {}}, fn=self._do_test_fault_fatal)

    def test_fatal_but_recover(self):
        """PG: all workers fail, then get recreated."""
        # Test the case where all workers fail, but we chose to recover.
        self.do_test("PG", {"optimizer": {}}, fn=self._do_test_fault_fatal_but_recreate)

    def test_async_grads(self):
        """A3C: one failing worker is ignored."""
        self.do_test("A3C", {"optimizer": {"grads_per_step": 1}})

    def test_async_replay(self):
        """APEX: one failing worker is ignored."""
        self.do_test(
            "APEX",
            {
                "min_sample_timesteps_per_iteration": 1000,
                "num_gpus": 0,
                "min_time_s_per_iteration": 1,
                "explore": False,
                "learning_starts": 1000,
                "target_network_update_freq": 100,
                "optimizer": {
                    "num_replay_buffer_shards": 1,
                },
            },
        )

    def test_async_samples(self):
        """IMPALA: one failing worker is ignored."""
        self.do_test("IMPALA", {"num_gpus": 0})

    def test_sync_replay(self):
        """DQN: one failing worker is ignored."""
        self.do_test("DQN", {"min_sample_timesteps_per_iteration": 1})

    def test_multi_g_p_u(self):
        """PPO: one failing worker is ignored."""
        self.do_test(
            "PPO",
            {
                "num_sgd_iter": 1,
                "train_batch_size": 10,
                "rollout_fragment_length": 10,
                "sgd_minibatch_size": 1,
            },
        )

    def test_sync_samples(self):
        """PG: one failing worker is ignored."""
        self.do_test("PG", {"optimizer": {}})

    def test_async_sampling_option(self):
        """PG with "sample_async" enabled: one failing worker is ignored."""
        self.do_test("PG", {"optimizer": {}, "sample_async": True})
|
2019-03-16 13:34:09 -07:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed as a script, run this file's tests through pytest in
    # verbose mode and hand pytest's return value to the shell.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|