ray/rllib/tests/test_ray_client.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

72 lines
2.2 KiB
Python
Raw Normal View History

import os
import sys
import unittest
import pytest
import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.util.client.ray_client_helpers import ray_start_client_server
class TestRayClient(unittest.TestCase):
def test_connection(self):
with ray_start_client_server():
assert ray.util.client.ray.is_connected()
assert ray.util.client.ray.is_connected() is False
def test_custom_train_fn(self):
with ray_start_client_server():
assert ray.util.client.ray.is_connected()
config = {
# Special flag signalling `my_train_fn` how many iters to do.
"train-iterations": 2,
"lr": 0.01,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"num_workers": 0,
"framework": "tf",
}
resources = ppo.PPOTrainer.default_resource_request(config)
from ray.rllib.examples.custom_train_fn import my_train_fn
tune.run(my_train_fn, resources_per_trial=resources, config=config)
def test_cartpole_lstm(self):
with ray_start_client_server():
assert ray.util.client.ray.is_connected()
config = {
"env": StatelessCartPole,
}
stop = {
"training_iteration": 3,
}
tune.run("PPO", config=config, stop=stop, verbose=2)
def test_custom_experiment(self):
with ray_start_client_server(ray_init_kwargs={"num_cpus": 3}):
assert ray.util.client.ray.is_connected()
config = ppo.DEFAULT_CONFIG.copy()
# Special flag signalling `experiment` how many iters to do.
config["train-iterations"] = 2
config["env"] = "CartPole-v0"
from ray.rllib.examples.custom_experiment import experiment
tune.run(
experiment,
config=config,
resources_per_trial=ppo.PPOTrainer.default_resource_request(config),
)
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))