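"""Tests whether rollout workers can run with remote (Ray actor) sub-envs.

Both tests below set config["remote_worker_envs"] = True together with
config["num_envs_per_worker"] = 4, covering single- and multi-agent cases
and the different ways an env can be specified under config["env"].
"""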
import gym
import numpy as np
from pettingzoo.butterfly import pistonball_v6
from supersuit import normalize_obs_v0, dtype_v0, color_reduction_v0
import unittest

import ray
from ray.rllib.algorithms.pg import pg
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.rllib.examples.env.random_env import RandomEnv, RandomMultiAgentEnv
from ray.rllib.examples.remote_base_env_with_custom_api import (
    NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv,
)
from ray import tune


# Function that returns the (wrapped) environment to be registered below.
def env_creator(config):
    env = pistonball_v6.env()
    # Cast image observations to float32 (normalize_obs_v0 below requires a
    # float dtype).
    env = dtype_v0(env, dtype=np.float32)
    # Reduce the RGB image observations to their red ("R") channel only.
    env = color_reduction_v0(env, mode="R")
    # Rescale observations into the [0.0, 1.0] range.
    env = normalize_obs_v0(env)
    return env
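

# Note: env_creator() returns a raw PettingZoo (AEC) env; RLlib cannot step
# it directly, so it gets wrapped in the PettingZooEnv adapter at
# registration time below.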
tune.register_env("cartpole", lambda env_ctx: gym.make("CartPole-v0"))
tune.register_env("pistonball", lambda config: PettingZooEnv(env_creator(config)))


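# Note (based on RLlib's documented semantics of these settings): with
# "remote_worker_envs" = True and "num_envs_per_worker" > 1, each
# sub-environment is created inside its own Ray actor, so (slow) env
# stepping can happen in parallel processes rather than sequentially
# inside the rollout worker itself.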
class TestRemoteWorkerEnvSetting(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_remote_worker_env(self):
        config = pg.DEFAULT_CONFIG.copy()
        config["remote_worker_envs"] = True
        config["num_envs_per_worker"] = 4

        # Simple string env definition (gym.make(...)).
        config["env"] = "CartPole-v0"
        trainer = pg.PGTrainer(config=config)
        print(trainer.train())
        trainer.stop()

        # Using tune.register.
        config["env"] = "cartpole"
        trainer = pg.PGTrainer(config=config)
        print(trainer.train())
        trainer.stop()

        # Using class directly.
        config["env"] = RandomEnv
        trainer = pg.PGTrainer(config=config)
        print(trainer.train())
        trainer.stop()

        # Using class directly: A sub-class of gym.Env, which implements its
        # own, non-vectorized API and must therefore first be vectorized into
        # a remote base env.
        config["env"] = NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv
        trainer = pg.PGTrainer(config=config)
        print(trainer.train())
        trainer.stop()
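
    # The cases above and in the next test together cover the supported ways
    # of specifying config["env"]: a gym string id, a tune-registered name,
    # an env class, and a full "module.ClassName" classpath.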

    def test_remote_worker_env_multi_agent(self):
        config = pg.DEFAULT_CONFIG.copy()
        config["remote_worker_envs"] = True
        config["num_envs_per_worker"] = 4

        # Full classpath provided.
        config["env"] = "ray.rllib.examples.env.random_env.RandomMultiAgentEnv"
        trainer = pg.PGTrainer(config=config)
        print(trainer.train())
        trainer.stop()

        # Using tune.register.
        config["env"] = "pistonball"
        trainer = pg.PGTrainer(config=config)
        print(trainer.train())
        trainer.stop()

        # Using class directly.
        config["env"] = RandomMultiAgentEnv
        trainer = pg.PGTrainer(config=config)
        print(trainer.train())
        trainer.stop()
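

# A minimal sketch (not part of the test) of running one of these configs
# through Tune instead of instantiating the trainer directly. The stop
# criterion below is an assumption, not taken from this file:
#
#   tune.run(
#       "PG",
#       stop={"training_iteration": 1},
#       config={
#           "env": "pistonball",
#           "remote_worker_envs": True,
#           "num_envs_per_worker": 4,
#       },
#   )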


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))