mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
69 lines
2.2 KiB
Python
69 lines
2.2 KiB
Python
import os
|
|
import sys
|
|
import unittest
|
|
|
|
import pytest
|
|
import ray
|
|
from ray import tune
|
|
from ray.rllib.agents import ppo
|
|
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
|
|
from ray.util.client.ray_client_helpers import ray_start_client_server
|
|
|
|
|
|
class TestRayClient(unittest.TestCase):
|
|
def test_connection(self):
|
|
with ray_start_client_server():
|
|
assert ray.util.client.ray.is_connected()
|
|
assert ray.util.client.ray.is_connected() is False
|
|
|
|
def test_custom_train_fn(self):
|
|
with ray_start_client_server():
|
|
assert ray.util.client.ray.is_connected()
|
|
|
|
config = {
|
|
# Special flag signalling `my_train_fn` how many iters to do.
|
|
"train-iterations": 2,
|
|
"lr": 0.01,
|
|
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
|
|
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
|
"num_workers": 0,
|
|
"framework": "tf",
|
|
}
|
|
resources = ppo.PPOTrainer.default_resource_request(config)
|
|
from ray.rllib.examples.custom_train_fn import my_train_fn
|
|
tune.run(my_train_fn, resources_per_trial=resources, config=config)
|
|
|
|
def test_cartpole_lstm(self):
|
|
with ray_start_client_server():
|
|
assert ray.util.client.ray.is_connected()
|
|
|
|
config = {
|
|
"env": StatelessCartPole,
|
|
}
|
|
|
|
stop = {
|
|
"training_iteration": 3,
|
|
}
|
|
|
|
tune.run("PPO", config=config, stop=stop, verbose=2)
|
|
|
|
def test_custom_experiment(self):
|
|
|
|
with ray_start_client_server(ray_init_kwargs={"num_cpus": 3}):
|
|
assert ray.util.client.ray.is_connected()
|
|
|
|
config = ppo.DEFAULT_CONFIG.copy()
|
|
# Special flag signalling `experiment` how many iters to do.
|
|
config["train-iterations"] = 2
|
|
config["env"] = "CartPole-v0"
|
|
|
|
from ray.rllib.examples.custom_experiment import experiment
|
|
tune.run(
|
|
experiment,
|
|
config=config,
|
|
resources_per_trial=ppo.PPOTrainer.default_resource_request(
|
|
config))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
sys.exit(pytest.main(["-v", __file__]))
|