From 138b273136343c4c896f02eaf4896215d65ba648 Mon Sep 17 00:00:00 2001
From: matthewdeng
Date: Wed, 9 Jun 2021 10:39:14 -0700
Subject: [PATCH] [rllib] Add tests for examples using ray client (#16271)

* [rllib] add tests for examples using ray client

* rename test_client to test_ray_client
---
 rllib/BUILD                    |  7 +++
 rllib/tests/test_ray_client.py | 92 ++++++++++++++++++++++++++++++++++
 2 files changed, 99 insertions(+)
 create mode 100644 rllib/tests/test_ray_client.py

diff --git a/rllib/BUILD b/rllib/BUILD
index 286010b93..99b5a3677 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -1550,6 +1550,13 @@ py_test(
     srcs = ["tests/test_placement_groups.py"]
 )
 
+py_test(
+    name = "tests/test_ray_client",
+    tags = ["tests_dir", "tests_dir_R"],
+    size = "large",
+    srcs = ["tests/test_ray_client.py"]
+)
+
 py_test(
     name = "tests/test_reproducibility",
     tags = ["tests_dir", "tests_dir_R"],
diff --git a/rllib/tests/test_ray_client.py b/rllib/tests/test_ray_client.py
new file mode 100644
index 000000000..e4c8ec592
--- /dev/null
+++ b/rllib/tests/test_ray_client.py
@@ -0,0 +1,92 @@
+import os
+import sys
+import unittest
+
+import pytest
+import ray
+from ray import tune
+from ray.job_config import JobConfig
+from ray.rllib.agents import ppo
+from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
+from ray.rllib.utils.test_utils import check_learning_achieved
+from ray.util.client.ray_client_helpers import ray_start_client_server
+
+
+class TestRayClient(unittest.TestCase):
+    def test_connection(self):
+        with ray_start_client_server():
+            assert ray.util.client.ray.is_connected()
+        assert ray.util.client.ray.is_connected() is False
+
+    def test_custom_train_fn(self):
+        with ray_start_client_server():
+            assert ray.util.client.ray.is_connected()
+
+            config = {
+                "lr": 0.01,
+                # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+                "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
+                "num_workers": 0,
+                "framework": "tf",
+            }
+            resources = ppo.PPOTrainer.default_resource_request(config)
+            from ray.rllib.examples.custom_train_fn import my_train_fn
+            tune.run(my_train_fn, resources_per_trial=resources, config=config)
+
+    def test_cartpole_lstm(self):
+        with ray_start_client_server():
+            assert ray.util.client.ray.is_connected()
+
+            config = dict(
+                {
+                    "num_sgd_iter": 5,
+                    "model": {
+                        "vf_share_layers": True,
+                    },
+                    "vf_loss_coeff": 0.0001,
+                },
+                **{
+                    "env": StatelessCartPole,
+                    # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
+                    "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
+                    "model": {
+                        "use_lstm": True,
+                        "lstm_cell_size": 256,
+                        "lstm_use_prev_action": None,
+                        "lstm_use_prev_reward": None,
+                    },
+                    "framework": "tf",
+                    # Run with tracing enabled for tfe/tf2?
+                    "eager_tracing": None,
+                })
+
+            stop = {
+                "training_iteration": 200,
+                "timesteps_total": 100000,
+                "episode_reward_mean": 150.0,
+            }
+
+            results = tune.run("PPO", config=config, stop=stop, verbose=2)
+            check_learning_achieved(results, 150.0)
+
+    def test_custom_experiment(self):
+        def ray_connect_handler(job_config: JobConfig = None):
+            ray.init(num_cpus=3)
+
+        with ray_start_client_server(ray_connect_handler=ray_connect_handler):
+            assert ray.util.client.ray.is_connected()
+
+            config = ppo.DEFAULT_CONFIG.copy()
+            config["train-iterations"] = 10
+            config["env"] = "CartPole-v0"
+
+            from ray.rllib.examples.custom_experiment import experiment
+            tune.run(
+                experiment,
+                config=config,
+                resources_per_trial=ppo.PPOTrainer.default_resource_request(
+                    config))
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(["-v", __file__]))