# Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00).
# 248 lines, 8.6 KiB, Python.
import gym
|
|
import numpy as np
|
|
import random
|
|
import unittest
|
|
import uuid
|
|
|
|
import ray
|
|
from ray.rllib.algorithms.dqn import DQNTrainer
|
|
from ray.rllib.algorithms.pg import PGTrainer
|
|
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
|
from ray.rllib.env.external_env import ExternalEnv
|
|
from ray.rllib.evaluation.tests.test_rollout_worker import BadPolicy, MockPolicy
|
|
from ray.rllib.examples.env.mock_env import MockEnv
|
|
from ray.rllib.utils.test_utils import framework_iterator
|
|
from ray.tune.registry import register_env
|
|
|
|
|
|
def make_simple_serving(multiagent, superclass):
    """Create a ``SimpleServing`` class derived from ``superclass``.

    Args:
        multiagent: If truthy, returns are logged without the ``info``
            keyword; otherwise the ``info`` returned by ``env.step`` is
            forwarded to ``log_returns``.
        superclass: Base class of the generated class. Its ``__init__`` is
            invoked with ``(action_space, observation_space)`` and it must
            provide ``start_episode``, ``get_action``, ``log_returns`` and
            ``end_episode``.

    Returns:
        The generated ``SimpleServing`` class (not an instance).
    """

    class SimpleServing(superclass):
        """Drives a wrapped env, asking ``get_action`` for every action."""

        def __init__(self, env):
            superclass.__init__(self, env.action_space, env.observation_space)
            self.env = env

        def run(self):
            """Step the wrapped env forever, one episode after another."""
            episode_id = self.start_episode()
            observation = self.env.reset()
            while True:
                chosen = self.get_action(episode_id, observation)
                observation, rew, finished, extra = self.env.step(chosen)
                # Only the non-multiagent variant forwards `info`.
                log_kwargs = {} if multiagent else {"info": extra}
                self.log_returns(episode_id, rew, **log_kwargs)
                if not finished:
                    continue
                # Episode is over: close it, then reset and open a new one.
                self.end_episode(episode_id, observation)
                observation = self.env.reset()
                episode_id = self.start_episode()

    return SimpleServing
|
|
|
|
|
|
# Module-level SimpleServing class: the non-multiagent variant on ExternalEnv.
SimpleServing = make_simple_serving(multiagent=False, superclass=ExternalEnv)
|
|
|
|
|
|
class PartOffPolicyServing(ExternalEnv):
    """External env that mixes sampled (off-policy) and queried actions.

    On every step, with probability ``off_pol_frac`` an action is sampled
    from the wrapped env's action space and reported through ``log_action``;
    otherwise the action is obtained from ``get_action``.
    """

    def __init__(self, env, off_pol_frac):
        super().__init__(env.action_space, env.observation_space)
        self.env = env
        self.off_pol_frac = off_pol_frac

    def run(self):
        """Step the wrapped env forever, one episode after another."""
        episode_id = self.start_episode()
        observation = self.env.reset()
        while True:
            go_off_policy = random.random() < self.off_pol_frac
            if not go_off_policy:
                chosen = self.get_action(episode_id, observation)
            else:
                # Action picked here rather than via get_action: log it.
                chosen = self.env.action_space.sample()
                self.log_action(episode_id, observation, chosen)
            observation, rew, finished, extra = self.env.step(chosen)
            self.log_returns(episode_id, rew, info=extra)
            if not finished:
                continue
            # Episode is over: close it, then reset and open a new one.
            self.end_episode(episode_id, observation)
            observation = self.env.reset()
            episode_id = self.start_episode()
|
|
|
|
|
|
class SimpleOffPolicyServing(ExternalEnv):
    """External env that always plays ``fixed_action`` and logs it.

    ``get_action`` is never called; every step reports the constant action
    through ``log_action`` before stepping the wrapped env.
    """

    def __init__(self, env, fixed_action):
        super().__init__(env.action_space, env.observation_space)
        self.env = env
        self.fixed_action = fixed_action

    def run(self):
        """Step the wrapped env forever, one episode after another."""
        episode_id = self.start_episode()
        observation = self.env.reset()
        while True:
            chosen = self.fixed_action
            self.log_action(episode_id, observation, chosen)
            observation, rew, finished, extra = self.env.step(chosen)
            self.log_returns(episode_id, rew, info=extra)
            if not finished:
                continue
            # Episode is over: close it, then reset and open a new one.
            self.end_episode(episode_id, observation)
            observation = self.env.reset()
            episode_id = self.start_episode()
|
|
|
|
|
|
class MultiServing(ExternalEnv):
    """External env that interleaves episodes of several sub-environments.

    ``run`` creates ``num_envs`` sub-environments and, on every iteration,
    steps a random subset of ``num_active`` of them, each under its own
    episode id.

    Args:
        env_creator: Zero-argument callable returning a new environment.
            It is called once in ``__init__`` (to read the action and
            observation spaces) and ``num_envs`` more times in ``run``.
        num_envs: Number of sub-environments created by ``run``.
        num_active: Number of distinct sub-environments stepped per loop
            iteration.

    Raises:
        ValueError: If ``num_active`` is not within ``[1, num_envs]``.
    """

    def __init__(self, env_creator, num_envs=5, num_active=2):
        if not 1 <= num_active <= num_envs:
            raise ValueError(
                "num_active must be in [1, num_envs], got "
                "num_active={} and num_envs={}".format(num_active, num_envs)
            )
        self.env_creator = env_creator
        self.num_envs = num_envs
        self.num_active = num_active
        self.env = env_creator()
        ExternalEnv.__init__(self, self.env.action_space, self.env.observation_space)

    def run(self):
        """Step random subsets of the sub-environments forever."""
        envs = [self.env_creator() for _ in range(self.num_envs)]
        # Both dicts are keyed by sub-env index and hold only running episodes.
        cur_obs = {}
        eids = {}
        while True:
            # Sample without replacement so each chosen index is distinct.
            active = np.random.choice(
                range(self.num_envs), self.num_active, replace=False
            )
            for i in active:
                if i not in cur_obs:
                    # No running episode for this sub-env: start a fresh one.
                    eids[i] = uuid.uuid4().hex
                    self.start_episode(episode_id=eids[i])
                    cur_obs[i] = envs[i].reset()
            # Query all actions first, then step the sub-envs.
            actions = [self.get_action(eids[i], cur_obs[i]) for i in active]
            for i, action in zip(active, actions):
                obs, reward, done, _ = envs[i].step(action)
                cur_obs[i] = obs
                self.log_returns(eids[i], reward)
                if done:
                    self.end_episode(eids[i], obs)
                    # Forget the finished episode so that the next time this
                    # sub-env is chosen a new episode is started.
                    del cur_obs[i]
                    del eids[i]
|
|
|
|
|
|
class TestExternalEnv(unittest.TestCase):
    """Tests for sampling from and training on ``ExternalEnv`` subclasses."""

    @classmethod
    def setUpClass(cls) -> None:
        ray.init(ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def _train_until_reward(self, trainer, max_iters, min_reward=80):
        """Train ``trainer`` until its mean episode reward is high enough.

        Calls ``trainer.train()`` at most ``max_iters`` times, printing the
        progress of every iteration, and returns as soon as
        ``episode_reward_mean`` reaches ``min_reward``. The trainer is
        always stopped afterwards, including when ``train()`` raises.

        Args:
            trainer: Object providing ``train()`` (returning a dict with the
                keys ``episode_reward_mean`` and ``timesteps_total``) and
                ``stop()``.
            max_iters: Maximum number of training iterations.
            min_reward: Mean episode reward that counts as success.

        Raises:
            AssertionError: If ``min_reward`` is not reached within
                ``max_iters`` iterations.
        """
        try:
            for i in range(max_iters):
                result = trainer.train()
                print(
                    "Iteration {}, reward {}, timesteps {}".format(
                        i, result["episode_reward_mean"], result["timesteps_total"]
                    )
                )
                if result["episode_reward_mean"] >= min_reward:
                    return
        finally:
            # Release the trainer's resources before the next one is built.
            trainer.stop()
        self.fail("failed to improve reward")

    def test_external_env_complete_episodes(self):
        """Batches sampled in ``complete_episodes`` mode hold 50 timesteps."""
        ev = RolloutWorker(
            env_creator=lambda _: SimpleServing(MockEnv(25)),
            policy_spec=MockPolicy,
            rollout_fragment_length=40,
            batch_mode="complete_episodes",
        )
        for _ in range(3):
            batch = ev.sample()
            self.assertEqual(batch.count, 50)

    def test_external_env_truncate_episodes(self):
        """Batches sampled in ``truncate_episodes`` mode hold 40 timesteps."""
        ev = RolloutWorker(
            env_creator=lambda _: SimpleServing(MockEnv(25)),
            policy_spec=MockPolicy,
            rollout_fragment_length=40,
            batch_mode="truncate_episodes",
        )
        for _ in range(3):
            batch = ev.sample()
            self.assertEqual(batch.count, 40)

    def test_external_env_off_policy(self):
        """Actions reported via ``log_action`` show up in the sampled batch."""
        ev = RolloutWorker(
            env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
            policy_spec=MockPolicy,
            rollout_fragment_length=40,
            batch_mode="complete_episodes",
        )
        for _ in range(3):
            batch = ev.sample()
            self.assertEqual(batch.count, 50)
            self.assertEqual(batch["actions"][0], 42)
            self.assertEqual(batch["actions"][-1], 42)

    def test_external_env_bad_actions(self):
        """Sampling with ``BadPolicy`` raises an exception."""
        ev = RolloutWorker(
            env_creator=lambda _: SimpleServing(MockEnv(25)),
            policy_spec=BadPolicy,
            sample_async=True,
            rollout_fragment_length=40,
            batch_mode="truncate_episodes",
        )
        with self.assertRaises(Exception):
            ev.sample()

    def test_train_cartpole_off_policy(self):
        """DQN reaches a mean reward of 80 with partly off-policy actions."""
        register_env(
            "test3",
            lambda _: PartOffPolicyServing(gym.make("CartPole-v0"), off_pol_frac=0.2),
        )
        config = {
            "num_workers": 0,
            "exploration_config": {"epsilon_timesteps": 100},
        }
        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            dqn = DQNTrainer(env="test3", config=config)
            self._train_until_reward(dqn, max_iters=50)

    def test_train_cartpole(self):
        """PG reaches a mean reward of 80 on CartPole via ``SimpleServing``."""
        register_env("test", lambda _: SimpleServing(gym.make("CartPole-v0")))
        config = {"num_workers": 0}
        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            pg = PGTrainer(env="test", config=config)
            self._train_until_reward(pg, max_iters=80)

    def test_train_cartpole_multi(self):
        """PG reaches a mean reward of 80 on CartPole via ``MultiServing``."""
        register_env("test2", lambda _: MultiServing(lambda: gym.make("CartPole-v0")))
        config = {"num_workers": 0}
        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            pg = PGTrainer(env="test2", config=config)
            self._train_until_reward(pg, max_iters=80)

    def test_external_env_horizon_not_supported(self):
        """Sampling with ``episode_horizon`` set raises ``ValueError``."""
        ev = RolloutWorker(
            env_creator=lambda _: SimpleServing(MockEnv(25)),
            policy_spec=MockPolicy,
            episode_horizon=20,
            rollout_fragment_length=10,
            batch_mode="complete_episodes",
        )
        with self.assertRaises(ValueError):
            ev.sample()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely and use pytest's result as exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|