ray/rllib/env/tests/test_remote_worker_envs.py


import unittest

import gym
import numpy as np
from pettingzoo.butterfly import pistonball_v6
from supersuit import normalize_obs_v0, dtype_v0, color_reduction_v0

import ray
from ray import tune
from ray.rllib.algorithms.pg import pg
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv

# from ray.rllib.examples.env.random_env import RandomEnv

# Function that outputs the environment you wish to register.
def env_creator(config):
    env = pistonball_v6.env()
    # Cast observations to float32.
    env = dtype_v0(env, dtype=np.float32)
    # Keep only the red channel of the RGB observations.
    env = color_reduction_v0(env, mode="R")
    # Scale observations into the [0, 1] range.
    env = normalize_obs_v0(env)
    return env


tune.register_env("cartpole", lambda env_ctx: gym.make("CartPole-v0"))
tune.register_env("pistonball", lambda config: PettingZooEnv(env_creator(config)))


class TestRemoteWorkerEnvSetting(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_remote_worker_env(self):
        config = pg.DEFAULT_CONFIG.copy()
        config["remote_worker_envs"] = True
        config["num_envs_per_worker"] = 4

        # Simple string env definition (gym.make(...)).
        config["env"] = "CartPole-v0"
        algo = pg.PG(config=config)
        print(algo.train())
        algo.stop()

        # Using tune.register.
        config["env"] = "cartpole"
        algo = pg.PG(config=config)
        print(algo.train())
        algo.stop()
        # Using class directly.
        # This doesn't work anymore as of gym==0.23.
        # config["env"] = RandomEnv
        # algo = pg.PG(config=config)
        # print(algo.train())
        # algo.stop()

        # Using class directly: Sub-class of gym.Env,
        # which implements its own API.
        # This doesn't work anymore as of gym==0.23.
        # config["env"] = NonVectorizedEnvToBeVectorizedIntoRemoteBaseEnv
        # algo = pg.PG(config=config)
        # print(algo.train())
        # algo.stop()

    def test_remote_worker_env_multi_agent(self):
        config = pg.DEFAULT_CONFIG.copy()
        config["remote_worker_envs"] = True
        config["num_envs_per_worker"] = 4

        # Full classpath provided.
        config["env"] = "ray.rllib.examples.env.random_env.RandomMultiAgentEnv"
        algo = pg.PG(config=config)
        print(algo.train())
        algo.stop()

        # Using tune.register.
        config["env"] = "pistonball"
        algo = pg.PG(config=config)
        print(algo.train())
        algo.stop()
        # Using class directly.
        # This doesn't work anymore as of gym==0.23.
        # config["env"] = RandomMultiAgentEnv
        # algo = pg.PG(config=config)
        # print(algo.train())
        # algo.stop()


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))