import time
import unittest
from collections import defaultdict

import gym
import numpy as np

import ray
from ray.rllib.algorithms.pg import PG, PGConfig
from ray.rllib.algorithms.registry import get_algorithm_class
from ray.rllib.env.multi_agent_env import make_multi_agent
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.test_utils import framework_iterator
from ray.tune.registry import register_env


@ray.remote
class Counter:
    """Remote counter service that survives restarts."""

    def __init__(self):
        self.reset()

    def _key(self, eval, worker_index, vector_index):
        return f"{eval}:{worker_index}:{vector_index}"

    def increment(self, eval, worker_index, vector_index):
        self.counter[self._key(eval, worker_index, vector_index)] += 1

    def get(self, eval, worker_index, vector_index):
        return self.counter[self._key(eval, worker_index, vector_index)]

    def reset(self):
        self.counter = defaultdict(int)
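
# NOTE: The tests below create this actor under a fixed name (via
# `Counter.options(name=...).remote()`), and `FaultInjectEnv` looks it up with
# `ray.get_actor(config["counter"])`. Because the actor lives outside the
# rollout/eval worker processes, the interaction counts survive worker
# failures and re-creations.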


class FaultInjectEnv(gym.Env):
    """Env that fails from `reset()` or `step()`, but only for some worker indices.

    The worker indices that should produce the failure (a ValueError) can be
    provided by a list (of ints) under the "bad_indices" key in the env's
    config.

    Examples:
        >>> from ray.rllib.env.env_context import EnvContext
        >>> # This env will fail for workers 1 and 2 (not for the local worker
        >>> # or any others with an index != [1|2]).
        >>> bad_env = FaultInjectEnv(
        ...     EnvContext(
        ...         {"bad_indices": [1, 2]},
        ...         worker_index=1,
        ...         num_workers=3,
        ...     )
        ... )

        >>> from ray.rllib.env.env_context import EnvContext
        >>> # This env will fail only on the first evaluation worker, not on
        >>> # the first regular rollout worker.
        >>> bad_env = FaultInjectEnv(
        ...     EnvContext(
        ...         {"bad_indices": [1], "eval_only": True},
        ...         worker_index=2,
        ...         num_workers=5,
        ...     )
        ... )
    """

    def __init__(self, config):
        # Use RandomEnv to control episode length if needed.
        self.env = RandomEnv(config)
        self._skip_env_checking = True
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space
        self.config = config
        # External counter service.
        if "counter" in config:
            self.counter = ray.get_actor(config["counter"])
        else:
            self.counter = None

        if config.get("init_delay", 0) > 0.0:
            # Simulate an initialization delay.
            time.sleep(config.get("init_delay"))

    def _increment_count(self):
        if self.counter:
            eval = self.config.get("evaluation", False)
            worker_index = self.config.worker_index
            vector_index = self.config.vector_index
            ray.wait([self.counter.increment.remote(eval, worker_index, vector_index)])

    def _get_count(self):
        if self.counter:
            eval = self.config.get("evaluation", False)
            worker_index = self.config.worker_index
            vector_index = self.config.vector_index
            return ray.get(self.counter.get.remote(eval, worker_index, vector_index))
        return -1
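
    # NOTE: Counts are keyed by (evaluation-flag, worker_index, vector_index),
    # so each sub-environment tracks its own number of `reset()`/`step()`
    # interactions in the shared Counter actor. `_get_count()` returns -1 if
    # no counter service was configured for this env.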

    def _maybe_raise_error(self):
        # Do not raise the simulated error if this worker's index is not in
        # the list of bad indices.
        if self.config.worker_index not in self.config.get("bad_indices", []):
            return

        # Do not raise the simulated error if this is a recreated worker and
        # recreated workers are not allowed to fail.
        if (
            not self.config.get("recreated_worker_can_fail", False)
            and self.config.recreated_worker
        ):
            return

        if self.counter:
            count = self._get_count()
            if self.config.get(
                "failure_start_count", -1
            ) >= 0 and count < self.config.get("failure_start_count"):
                return

            if self.config.get(
                "failure_stop_count", -1
            ) >= 0 and count >= self.config.get("failure_stop_count"):
                return

        raise ValueError(
            "This is a simulated error from "
            f"{'eval-' if self.config.get('evaluation', False) else ''}"
            f"worker-idx={self.config.worker_index}!"
        )
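
    # NOTE: In summary, the error is only raised on workers whose index is in
    # `bad_indices`; recreated workers do not fail again unless
    # `recreated_worker_can_fail` is set; and if a counter service is
    # configured, failures are restricted to the interaction-count window
    # [failure_start_count, failure_stop_count).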

    def reset(self):
        self._increment_count()
        self._maybe_raise_error()
        return self.env.reset()

    def step(self, action):
        self._increment_count()
        self._maybe_raise_error()
        return self.env.step(action)


def is_recreated(w):
    return w.apply.remote(
        lambda w: w.recreated_worker or w.env_context.recreated_worker
    )
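
# NOTE: `is_recreated()` returns an (un-fetched) ObjectRef; callers below
# resolve it via `ray.get()`. A worker counts as recreated if either the
# RolloutWorker itself or its EnvContext carries the `recreated_worker` flag.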


class TestWorkerFailure(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

        register_env("fault_env", lambda c: FaultInjectEnv(c))
        register_env(
            "multi-agent-fault_env", lambda c: make_multi_agent(FaultInjectEnv)(c)
        )

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()
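
    # The `_do_test_*` helpers below cover the three fault-tolerance modes
    # exercised by the individual tests: ignoring failed workers
    # (`ignore_worker_failures=True`), failing hard
    # (`ignore_worker_failures=False`), and failing but re-creating workers
    # (`recreate_failed_workers=True`).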

    def _do_test_fault_ignore(self, algo: str, config: dict, fail_eval: bool = False):
        algo_cls = get_algorithm_class(algo)

        # Test fault handling.
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        # Make worker idx=1 fail. Other workers will be ok.
        config["env_config"] = {"bad_indices": [1]}
        if fail_eval:
            config["evaluation_num_workers"] = 2
            config["evaluation_interval"] = 1
            config["evaluation_config"] = {
                "ignore_worker_failures": True,
                "env_config": {
                    # Make worker idx=1 fail. Other workers will be ok.
                    "bad_indices": [1],
                    "evaluation": True,
                },
            }

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            algo = algo_cls(config=config, env="fault_env")
            result = algo.train()

            # One of the two rollout workers failed; only one is left healthy.
            self.assertTrue(result["num_healthy_workers"] == 1)
            if fail_eval:
                # One of the eval workers failed.
                self.assertTrue(result["evaluation"]["num_healthy_workers"] == 1)

            algo.stop()

    def _do_test_fault_fatal(self, alg, config, fail_eval=False):
        agent_cls = get_algorithm_class(alg)

        # Test that a real error is raised once we run out of workers.
        config["num_workers"] = 2
        config["ignore_worker_failures"] = False
        # Make both workers idx=1 and 2 fail.
        config["env_config"] = {"bad_indices": [1, 2]}
        if fail_eval:
            config["evaluation_num_workers"] = 2
            config["evaluation_interval"] = 1
            config["evaluation_config"] = {
                "ignore_worker_failures": False,
                # Make eval worker (index 1) fail.
                "env_config": {
                    "bad_indices": [1],
                    "evaluation": True,
                },
            }

        for _ in framework_iterator(config, frameworks=("torch", "tf")):
            a = agent_cls(config=config, env="fault_env")
            self.assertRaises(Exception, lambda: a.train())
            a.stop()

    def _do_test_fault_fatal_but_recreate(self, alg, config):
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_algorithm_class(alg)

        # One rollout worker plus one eval worker that fails but gets recreated.
        config["num_workers"] = 1
        config["evaluation_num_workers"] = 1
        config["evaluation_interval"] = 1
        config["evaluation_config"] = {
            "recreate_failed_workers": True,
            # Make eval worker (index 1) fail.
            "env_config": {
                "bad_indices": [1],
            },
        }

        for _ in framework_iterator(config, frameworks=("tf", "tf2", "torch")):
            a = agent_cls(config=config, env="fault_env")
            # Expect this to go well and all faulty workers to be recovered.
            self.assertTrue(
                not any(
                    ray.get(
                        worker.apply.remote(
                            lambda w: w.recreated_worker
                            or w.env_context.recreated_worker
                        )
                    )
                    for worker in a.workers.remote_workers()
                )
            )
            result = a.train()
            self.assertTrue(result["num_healthy_workers"] == 1)
            self.assertTrue(result["evaluation"]["num_healthy_workers"] == 1)
            # This should also work several times.
            result = a.train()
            self.assertTrue(result["num_healthy_workers"] == 1)
            self.assertTrue(result["evaluation"]["num_healthy_workers"] == 1)
            a.stop()
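
    # The short tests below reuse `_do_test_fault_ignore()` and
    # `_do_test_fault_fatal()` with per-algorithm configs (PG, A3C, APEX,
    # IMPALA, DQN, PPO) to make sure worker-failure handling does not depend
    # on the particular algorithm.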

    def test_fatal(self):
        # Test the case where all workers fail (w/o recovery).
        self._do_test_fault_fatal("PG", {"optimizer": {}})

    def test_async_grads(self):
        self._do_test_fault_ignore("A3C", {"optimizer": {"grads_per_step": 1}})

    def test_async_replay(self):
        self._do_test_fault_ignore(
            "APEX",
            {
                "num_gpus": 0,
                "min_sample_timesteps_per_iteration": 1000,
                "min_time_s_per_iteration": 1,
                "explore": False,
                "num_steps_sampled_before_learning_starts": 1000,
                "target_network_update_freq": 100,
                "optimizer": {
                    "num_replay_buffer_shards": 1,
                },
            },
        )

    def test_async_samples(self):
        self._do_test_fault_ignore("IMPALA", {"num_gpus": 0})

    def test_sync_replay(self):
        self._do_test_fault_ignore("DQN", {"min_sample_timesteps_per_iteration": 1})

    def test_multi_g_p_u(self):
        self._do_test_fault_ignore(
            "PPO",
            {
                "num_sgd_iter": 1,
                "train_batch_size": 10,
                "rollout_fragment_length": 10,
                "sgd_minibatch_size": 1,
            },
        )

    def test_sync_samples(self):
        self._do_test_fault_ignore("PG", {"optimizer": {}})

    def test_async_sampling_option(self):
        self._do_test_fault_ignore("PG", {"optimizer": {}, "sample_async": True})

    def test_eval_workers_failing_ignore(self):
        # Test the case where one eval worker fails, but we choose to ignore it.
        self._do_test_fault_ignore(
            "PG",
            config={"model": {"fcnet_hiddens": [4]}},
            fail_eval=True,
        )
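
    # NOTE (assumption): With `enable_async_evaluation=True`,
    # `evaluation_parallel_to_training=True` and `evaluation_duration="auto"`,
    # evaluation requests are handled by the async request manager in parallel
    # to training; this is the setup `_do_test_fault_fatal_but_recreate()`
    # relies on to recover the failing eval worker without crashing training.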

    def test_recreate_eval_workers_parallel_to_training_w_async_req_manager(self):
        # Test the case where all eval workers fail, but we choose to recover.
        config = (
            PGConfig()
            .evaluation(
                enable_async_evaluation=True,
                evaluation_parallel_to_training=True,
                evaluation_duration="auto",
            )
            .training(model={"fcnet_hiddens": [4]})
        )

        self._do_test_fault_fatal_but_recreate("PG", config=config.to_dict())

    def test_eval_workers_failing_fatal(self):
        # Test the case where all eval workers fail (w/o recovery).
        self._do_test_fault_fatal(
            "PG",
            config={"model": {"fcnet_hiddens": [4]}},
            fail_eval=True,
        )
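
    # The remaining tests use the named Counter actor together with
    # `failure_start_count` / `failure_stop_count`: the env only starts to
    # fail after a certain number of interactions (so workers come up healthy
    # at first) and stops failing once the window has passed (so re-created
    # workers or restarted sub-envs can keep sampling).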

    def test_workers_fatal_but_recover(self):
        # Counter that will survive restarts.
        COUNTER_NAME = "test_workers_fatal_but_recover"
        counter = Counter.options(name=COUNTER_NAME).remote()

        config = {
            "num_workers": 2,
            # Worker fault tolerance.
            "ignore_worker_failures": False,  # Do not ignore ...
            "recreate_failed_workers": True,  # ... but recover.
            "model": {"fcnet_hiddens": [4]},
            "env_config": {
                # Make both workers idx=1 and 2 fail.
                "bad_indices": [1, 2],
                # Env throws error between steps 100 and 102.
                "failure_start_count": 100,
                "failure_stop_count": 102,
                "counter": COUNTER_NAME,
            },
        }

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            # Reset interaction counter.
            ray.wait([counter.reset.remote()])

            a = PG(config=config, env="fault_env")

            # Before the train loop, workers are fresh and not recreated.
            self.assertTrue(
                not any(
                    ray.get(
                        [is_recreated(worker) for worker in a.workers.remote_workers()]
                    )
                )
            )

            result = a.train()

            self.assertEqual(result["num_healthy_workers"], 2)
            # Workers are re-created.
            self.assertEqual(result["num_recreated_workers"], 2)
            self.assertTrue(
                all(
                    ray.get(
                        [is_recreated(worker) for worker in a.workers.remote_workers()]
                    )
                )
            )

    def test_eval_workers_fault_but_recover(self):
        # Counter that will survive restarts.
        COUNTER_NAME = "test_eval_workers_fault_but_recover"
        counter = Counter.options(name=COUNTER_NAME).remote()

        config = {
            "num_workers": 2,
            # Worker fault tolerance.
            "ignore_worker_failures": True,  # Ignore failure.
            "recreate_failed_workers": True,  # And recover.
            "model": {"fcnet_hiddens": [4]},
            # 2 eval workers.
            "evaluation_num_workers": 2,
            "evaluation_interval": 1,
            "evaluation_config": {
                "env_config": {
                    "evaluation": True,
                    "p_done": 0.0,
                    "max_episode_len": 20,
                    # Make both eval workers fail.
                    "bad_indices": [1, 2],
                    # Env throws error between steps 10 and 12.
                    "failure_start_count": 10,
                    "failure_stop_count": 12,
                    "counter": COUNTER_NAME,
                }
            },
        }

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            # Reset interaction counter.
            ray.wait([counter.reset.remote()])

            a = PG(config=config, env="fault_env")

            # Before the train loop, workers are fresh and not recreated.
            self.assertTrue(
                not any(
                    ray.get(
                        [is_recreated(worker) for worker in a.workers.remote_workers()]
                    )
                )
            )
            # Eval workers are also fresh and not recreated.
            self.assertTrue(
                not any(
                    ray.get(
                        [
                            is_recreated(worker)
                            for worker in a.evaluation_workers.remote_workers()
                        ]
                    )
                )
            )

            result = a.train()

            self.assertEqual(result["num_healthy_workers"], 2)
            # Nothing happens to the rollout workers. They are still not re-created.
            self.assertEqual(result["num_recreated_workers"], 0)
            self.assertTrue(
                not any(
                    ray.get(
                        [is_recreated(worker) for worker in a.workers.remote_workers()]
                    )
                )
            )

            self.assertEqual(result["evaluation"]["num_healthy_workers"], 2)
            # But all the eval workers are re-created.
            self.assertEqual(result["evaluation"]["num_recreated_workers"], 2)
            self.assertTrue(
                all(
                    ray.get(
                        [
                            is_recreated(worker)
                            for worker in a.evaluation_workers.remote_workers()
                        ]
                    )
                )
            )

    def test_eval_workers_fault_but_restore_env(self):
        # Counter that will survive restarts.
        COUNTER_NAME = "test_eval_workers_fault_but_restore_env"
        counter = Counter.options(name=COUNTER_NAME).remote()

        config = {
            "num_workers": 2,
            # Worker fault tolerance.
            "ignore_worker_failures": True,
            "recreate_failed_workers": True,
            "model": {"fcnet_hiddens": [4]},
            "env_config": {
                # Make both workers idx=1 and 2 fail.
                "bad_indices": [1, 2],
                # Env throws error before step 2.
                "failure_stop_count": 2,
                "counter": COUNTER_NAME,
            },
            # 2 eval workers.
            "evaluation_num_workers": 2,
            "evaluation_interval": 1,
            "evaluation_config": {
                "ignore_worker_failures": True,
                "recreate_failed_workers": True,
                # Instead of recreating failed eval workers,
                # restart the failed sub-environment instead.
                "restart_failed_sub_environments": True,
                "env_config": {
                    "evaluation": True,
                    # Make eval worker (index 1) fail.
                    "bad_indices": [1],
                },
            },
        }

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            # Reset interaction counter.
            ray.wait([counter.reset.remote()])

            a = PG(config=config, env="fault_env")

            # Before the train loop, workers are fresh and not recreated.
            self.assertTrue(
                not any(
                    ray.get(
                        [is_recreated(worker) for worker in a.workers.remote_workers()]
                    )
                )
            )

            result = a.train()

            self.assertTrue(result["num_healthy_workers"] == 2)
            # Rollout workers are re-created.
            self.assertEqual(result["num_recreated_workers"], 2)
            self.assertTrue(
                all(
                    ray.get(
                        [is_recreated(worker) for worker in a.workers.remote_workers()]
                    )
                )
            )

            self.assertTrue(result["evaluation"]["num_healthy_workers"] == 2)
            # However, the eval workers are not re-created, since their failed
            # sub-environments get restarted instead.
            self.assertEqual(result["evaluation"]["num_recreated_workers"], 0)
            self.assertTrue(
                not any(
                    ray.get(
                        [
                            is_recreated(worker)
                            for worker in a.evaluation_workers.remote_workers()
                        ]
                    )
                )
            )

            # This should also work several times.
            result = a.train()

            self.assertTrue(result["num_healthy_workers"] == 2)
            self.assertTrue(result["evaluation"]["num_healthy_workers"] == 2)

            a.stop()
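
    # The multi-agent variant below wraps `FaultInjectEnv` via
    # `make_multi_agent()` and relies only on `restart_failed_sub_environments`,
    # so eval failures are expected to show up as `num_faulty_episodes` rather
    # than as re-created workers.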

    def test_multi_agent_env_eval_workers_fault_but_restore_env(self):
        # Counter that will survive restarts.
        COUNTER_NAME = "test_multi_agent_env_eval_workers_fault_but_restore_env"
        counter = Counter.options(name=COUNTER_NAME).remote()

        config = {
            "num_workers": 2,
            "model": {"fcnet_hiddens": [4]},
            # Rollout workers do not fail, so no fault tolerance needed there.
            "env_config": {},
            "multiagent": {
                "policies": {
                    "main_agent": PolicySpec(),
                },
                "policies_to_train": ["main_agent"],
                "policy_mapping_fn": lambda _: "main_agent",
            },
            # 2 eval workers.
            "evaluation_num_workers": 2,
            "evaluation_interval": 1,
            "evaluation_config": {
                # Instead of recreating failed eval workers,
                # restart the failed sub-environment instead.
                "restart_failed_sub_environments": True,
                "env_config": {
                    "evaluation": True,
                    "p_done": 0.0,
                    "max_episode_len": 20,
                    # Make eval worker (index 1) fail.
                    "bad_indices": [1],
                    "counter": COUNTER_NAME,
                    "failure_start_count": 10,
                    "failure_stop_count": 12,
                },
            },
            "disable_env_checking": True,
        }

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            # Reset interaction counter.
            ray.wait([counter.reset.remote()])

            a = PG(config=config, env="multi-agent-fault_env")

            result = a.train()

            self.assertTrue(result["num_healthy_workers"] == 2)
            self.assertEqual(result["num_faulty_episodes"], 0)
            self.assertTrue(result["evaluation"]["num_healthy_workers"] == 2)
            self.assertEqual(result["evaluation"]["num_recreated_workers"], 0)
            # There should be 2 faulty (evaluation) episodes.
            self.assertEqual(result["evaluation"]["num_faulty_episodes"], 2)

            # This should also work several times.
            result = a.train()

            self.assertTrue(result["num_healthy_workers"] == 2)
            self.assertEqual(result["num_faulty_episodes"], 0)
            self.assertTrue(result["evaluation"]["num_healthy_workers"] == 2)
            self.assertEqual(result["evaluation"]["num_recreated_workers"], 0)
            # There shouldn't be any faulty episodes anymore.
            self.assertEqual(result["evaluation"]["num_faulty_episodes"], 0)

            a.stop()
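
    # NOTE: The next test uses a long failure window (interaction counts 30
    # through 79), so many episodes within a single training iteration are
    # expected to be faulty; hence the `assertGreaterEqual(..., 50)` checks.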

    def test_long_failure_period_restore_env(self):
        # Counter that will survive restarts.
        COUNTER_NAME = "test_long_failure_period_restore_env"
        counter = Counter.options(name=COUNTER_NAME).remote()

        config = {
            "num_workers": 1,
            "create_env_on_driver": False,
            # Worker fault tolerance.
            "recreate_failed_workers": True,  # Restore failed workers.
            "restart_failed_sub_environments": True,  # And restart failed sub-envs.
            "model": {"fcnet_hiddens": [4]},
            "env_config": {
                "p_done": 0.0,
                "max_episode_len": 100,
                "bad_indices": [1],
                # Env throws error between steps 30 and 80.
                "failure_start_count": 30,
                "failure_stop_count": 80,
                "counter": COUNTER_NAME,
            },
            # 1 eval worker.
            "evaluation_num_workers": 1,
            "evaluation_interval": 1,
            "evaluation_config": {
                "env_config": {
                    "evaluation": True,
                }
            },
        }

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            # Reset interaction counter.
            ray.wait([counter.reset.remote()])

            a = PG(config=config, env="fault_env")

            # Before the train loop, workers are fresh and not recreated.
            self.assertTrue(
                not any(
                    ray.get(
                        [is_recreated(worker) for worker in a.workers.remote_workers()]
                    )
                )
            )
            # Eval workers are also fresh and not recreated.
            self.assertTrue(
                not any(
                    ray.get(
                        [
                            is_recreated(worker)
                            for worker in a.evaluation_workers.remote_workers()
                        ]
                    )
                )
            )

            result = a.train()

            # Should see a lot of faulty episodes.
            self.assertGreaterEqual(result["num_faulty_episodes"], 50)
            self.assertGreaterEqual(result["evaluation"]["num_faulty_episodes"], 50)

            self.assertTrue(result["num_healthy_workers"] == 1)
            # No rollout workers were re-created, since the failed sub-envs get
            # restarted instead.
            self.assertTrue(
                not any(
                    ray.get(
                        [is_recreated(worker) for worker in a.workers.remote_workers()]
                    )
                )
            )

            self.assertTrue(result["evaluation"]["num_healthy_workers"] == 1)
            # No eval workers were re-created either, for the same reason.
            self.assertTrue(
                not any(
                    ray.get(
                        [
                            is_recreated(worker)
                            for worker in a.evaluation_workers.remote_workers()
                        ]
                    )
                )
            )
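
    # NOTE: The next test measures sampler latency: `init_delay=10` makes
    # re-constructing the failed sub-env slow, and with
    # `sampler_perf_stats_ema_coef=0.5` plus
    # `metrics_num_episodes_for_smoothing=1` the restore cost should show up
    # as a spike in `mean_env_wait_ms` for the iteration in which the env
    # failed and was restored.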

    def test_env_wait_time_workers_restore_env(self):
        # Counter that will survive restarts.
        COUNTER_NAME = "test_env_wait_time_workers_restore_env"
        counter = Counter.options(name=COUNTER_NAME).remote()

        config = {
            "num_workers": 1,
            # Worker fault tolerance.
            "ignore_worker_failures": False,  # Do not ignore ...
            "recreate_failed_workers": True,  # ... but recover.
            "restart_failed_sub_environments": True,
            "model": {"fcnet_hiddens": [4]},
            "rollout_fragment_length": 10,
            "train_batch_size": 10,
            "env_config": {
                "p_done": 0.0,
                "max_episode_len": 10,
                "init_delay": 10,  # 10 sec init delay.
                # Make worker idx=1 fail.
                "bad_indices": [1],
                # Env throws error between steps 7 and 8.
                "failure_start_count": 7,
                "failure_stop_count": 8,
                "counter": COUNTER_NAME,
            },
            # Use a really large EMA coefficient for the sampler perf stats to
            # make the difference in env_wait_time_ms visible (effectively,
            # only the last couple of data points count).
            "sampler_perf_stats_ema_coef": 0.5,
            # Important: do not smooth over all episodes, otherwise the
            # latency spike would be averaged away.
            "metrics_num_episodes_for_smoothing": 1,
        }

        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
            # Reset interaction counter.
            ray.wait([counter.reset.remote()])

            a = PG(config=config, env="fault_env")

            # Had to restore the env during this iteration.
            result = a.train()
            self.assertEqual(result["num_faulty_episodes"], 1)
            time_with_restore = result["sampler_perf"]["mean_env_wait_ms"]

            # Does not have to restore the env during this iteration.
            result = a.train()
            # No new faulty episodes during this iteration.
            self.assertEqual(result["num_faulty_episodes"], 0)
            time_without_restore = result["sampler_perf"]["mean_env_wait_ms"]

            # Wait time with restore should be at least twice the wait time
            # without restore.
            self.assertGreater(time_with_restore, 2 * time_without_restore)

    def test_eval_workers_on_infinite_episodes(self):
        """Tests whether eval workers warn appropriately after some episode timeout."""
        # Create infinitely running episodes, but with a horizon setting (RLlib
        # will auto-terminate the episode). However, in the eval workers, don't
        # set a horizon -> Expect a warning and no proper evaluation results.
        config = (
            PGConfig()
            .rollouts(num_rollout_workers=2, horizon=100)
            .reporting(metrics_episode_collection_timeout_s=5.0)
            .environment(env=RandomEnv, env_config={"p_done": 0.0})
            .evaluation(
                evaluation_num_workers=2,
                evaluation_interval=1,
                evaluation_sample_timeout_s=5.0,
                evaluation_config={
                    "horizon": None,
                },
            )
        )
        algo = config.build()
        results = algo.train()
        self.assertTrue(np.isnan(results["evaluation"]["episode_reward_mean"]))


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))