ray/rllib/tests/test_rollout_worker.py
Sven Mika 4fd8977eaf
[RLlib] Minor cleanup in preparation to tf2.x support. (#9130)
* WIP.

* Fixes.

* LINT.

* Fixes.

* Fixes and LINT.

* WIP.
2020-06-25 19:01:32 +02:00

547 lines
19 KiB
Python

from collections import Counter
import gym
import numpy as np
import os
import random
import time
import unittest
import ray
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.tune.registry import register_env
class MockPolicy(RandomPolicy):
def compute_actions(self,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None,
explore=None,
timestep=None,
**kwargs):
return np.array([random.choice([0, 1])] * len(obs_batch)), [], {}
def postprocess_trajectory(self,
batch,
other_agent_batches=None,
episode=None):
assert episode is not None
return compute_advantages(
batch, 100.0, 0.9, use_gae=False, use_critic=False)
class BadPolicy(RandomPolicy):
def compute_actions(self,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None,
explore=None,
timestep=None,
**kwargs):
raise Exception("intentional error")
class FailOnStepEnv(gym.Env):
def __init__(self):
self.observation_space = gym.spaces.Discrete(1)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
raise ValueError("kaboom")
def step(self, action):
raise ValueError("kaboom")
class MockEnv(gym.Env):
def __init__(self, episode_length, config=None):
self.episode_length = episode_length
self.config = config
self.i = 0
self.observation_space = gym.spaces.Discrete(1)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.i = 0
return self.i
def step(self, action):
self.i += 1
return 0, 1, self.i >= self.episode_length, {}
class MockEnv2(gym.Env):
def __init__(self, episode_length):
self.episode_length = episode_length
self.i = 0
self.observation_space = gym.spaces.Discrete(100)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.i = 0
return self.i
def step(self, action):
self.i += 1
return self.i, 100, self.i >= self.episode_length, {}
class MockVectorEnv(VectorEnv):
def __init__(self, episode_length, num_envs):
super().__init__(
observation_space=gym.spaces.Discrete(1),
action_space=gym.spaces.Discrete(2),
num_envs=num_envs)
self.envs = [MockEnv(episode_length) for _ in range(num_envs)]
def vector_reset(self):
return [e.reset() for e in self.envs]
def reset_at(self, index):
return self.envs[index].reset()
def vector_step(self, actions):
obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
for i in range(len(self.envs)):
obs, rew, done, info = self.envs[i].step(actions[i])
obs_batch.append(obs)
rew_batch.append(rew)
done_batch.append(done)
info_batch.append(info)
return obs_batch, rew_batch, done_batch, info_batch
def get_unwrapped(self):
return self.envs
class TestRolloutWorker(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(num_cpus=5)
@classmethod
def tearDownClass(cls):
ray.shutdown()
def test_basic(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
batch = ev.sample()
for key in [
"obs", "actions", "rewards", "dones", "advantages",
"prev_rewards", "prev_actions"
]:
self.assertIn(key, batch)
self.assertGreater(np.abs(np.mean(batch[key])), 0)
def to_prev(vec):
out = np.zeros_like(vec)
for i, v in enumerate(vec):
if i + 1 < len(out) and not batch["dones"][i]:
out[i + 1] = v
return out.tolist()
self.assertEqual(batch["prev_rewards"].tolist(),
to_prev(batch["rewards"]))
self.assertEqual(batch["prev_actions"].tolist(),
to_prev(batch["actions"]))
self.assertGreater(batch["advantages"][0], 1)
ev.stop()
def test_batch_ids(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
rollout_fragment_length=1)
batch1 = ev.sample()
batch2 = ev.sample()
self.assertEqual(len(set(batch1["unroll_id"])), 1)
self.assertEqual(len(set(batch2["unroll_id"])), 1)
self.assertEqual(
len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)
ev.stop()
def test_global_vars_update(self):
# Allow for Unittest run.
ray.init(num_cpus=5, ignore_reinit_error=True)
for fw in framework_iterator(frameworks=()):
agent = A2CTrainer(
env="CartPole-v0",
config={
"num_workers": 1,
"lr_schedule": [[0, 0.1], [100000, 0.000001]],
"framework": fw,
})
result = agent.train()
for i in range(10):
result = agent.train()
print("num_steps_sampled={}".format(
result["info"]["num_steps_sampled"]))
print("num_steps_trained={}".format(
result["info"]["num_steps_trained"]))
print("num_steps_sampled={}".format(
result["info"]["num_steps_sampled"]))
print("num_steps_trained={}".format(
result["info"]["num_steps_trained"]))
if i == 0:
self.assertGreater(
result["info"]["learner"]["default_policy"]["cur_lr"],
0.01)
if result["info"]["learner"]["default_policy"]["cur_lr"] < \
0.07:
break
self.assertLess(
result["info"]["learner"]["default_policy"]["cur_lr"], 0.07)
agent.stop()
def test_no_step_on_init(self):
register_env("fail", lambda _: FailOnStepEnv())
for fw in framework_iterator(frameworks=()):
pg = PGTrainer(
env="fail", config={
"num_workers": 1,
"framework": fw,
})
self.assertRaises(Exception, lambda: pg.train())
pg.stop()
def test_callbacks(self):
for fw in framework_iterator(frameworks=("torch", "tf")):
counts = Counter()
pg = PGTrainer(
env="CartPole-v0", config={
"num_workers": 0,
"rollout_fragment_length": 50,
"train_batch_size": 50,
"callbacks": {
"on_episode_start":
lambda x: counts.update({"start": 1}),
"on_episode_step":
lambda x: counts.update({"step": 1}),
"on_episode_end": lambda x: counts.update({"end": 1}),
"on_sample_end":
lambda x: counts.update({"sample": 1}),
},
"framework": fw,
})
pg.train()
pg.train()
self.assertGreater(counts["sample"], 0)
self.assertGreater(counts["start"], 0)
self.assertGreater(counts["end"], 0)
self.assertGreater(counts["step"], 0)
pg.stop()
def test_query_evaluators(self):
register_env("test", lambda _: gym.make("CartPole-v0"))
for fw in framework_iterator(frameworks=("torch", "tf")):
pg = PGTrainer(
env="test",
config={
"num_workers": 2,
"rollout_fragment_length": 5,
"num_envs_per_worker": 2,
"framework": fw,
})
results = pg.workers.foreach_worker(
lambda ev: ev.rollout_fragment_length)
results2 = pg.workers.foreach_worker_with_index(
lambda ev, i: (i, ev.rollout_fragment_length))
results3 = pg.workers.foreach_worker(
lambda ev: ev.foreach_env(lambda env: 1))
self.assertEqual(results, [10, 10, 10])
self.assertEqual(results2, [(0, 10), (1, 10), (2, 10)])
self.assertEqual(results3, [[1, 1], [1, 1], [1, 1]])
pg.stop()
def test_reward_clipping(self):
# clipping on
ev = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),
policy=MockPolicy,
clip_rewards=True,
batch_mode="complete_episodes")
self.assertEqual(max(ev.sample()["rewards"]), 1)
result = collect_metrics(ev, [])
self.assertEqual(result["episode_reward_mean"], 1000)
ev.stop()
# clipping off
ev2 = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),
policy=MockPolicy,
clip_rewards=False,
batch_mode="complete_episodes")
self.assertEqual(max(ev2.sample()["rewards"]), 100)
result2 = collect_metrics(ev2, [])
self.assertEqual(result2["episode_reward_mean"], 1000)
ev2.stop()
def test_hard_horizon(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),
policy=MockPolicy,
batch_mode="complete_episodes",
rollout_fragment_length=10,
episode_horizon=4,
soft_horizon=False)
samples = ev.sample()
# Three logical episodes and correct episode resets (always after 4
# steps).
self.assertEqual(len(set(samples["eps_id"])), 3)
for i in range(4):
self.assertEqual(np.argmax(samples["obs"][i]), i)
self.assertEqual(np.argmax(samples["obs"][4]), 0)
# 3 done values.
self.assertEqual(sum(samples["dones"]), 3)
ev.stop()
# A gym env's max_episode_steps is smaller than Trainer's horizon.
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
batch_mode="complete_episodes",
rollout_fragment_length=10,
episode_horizon=6,
soft_horizon=False)
samples = ev.sample()
# 12 steps due to `complete_episodes` batch_mode.
self.assertEqual(len(samples["eps_id"]), 12)
# Two logical episodes and correct episode resets (always after 6(!)
# steps).
self.assertEqual(len(set(samples["eps_id"])), 2)
# 2 done values after 6 and 12 steps.
check(samples["dones"], [
False, False, False, False, False, True, False, False, False,
False, False, True
])
ev.stop()
def test_soft_horizon(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=10),
policy=MockPolicy,
batch_mode="complete_episodes",
rollout_fragment_length=10,
episode_horizon=4,
soft_horizon=True)
samples = ev.sample()
# three logical episodes
self.assertEqual(len(set(samples["eps_id"])), 3)
# only 1 hard done value
self.assertEqual(sum(samples["dones"]), 1)
ev.stop()
def test_metrics(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=10),
policy=MockPolicy,
batch_mode="complete_episodes")
remote_ev = RolloutWorker.as_remote().remote(
env_creator=lambda _: MockEnv(episode_length=10),
policy=MockPolicy,
batch_mode="complete_episodes")
ev.sample()
ray.get(remote_ev.sample.remote())
result = collect_metrics(ev, [remote_ev])
self.assertEqual(result["episodes_this_iter"], 20)
self.assertEqual(result["episode_reward_mean"], 10)
ev.stop()
def test_async(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
sample_async=True,
policy=MockPolicy)
batch = ev.sample()
for key in ["obs", "actions", "rewards", "dones", "advantages"]:
self.assertIn(key, batch)
self.assertGreater(batch["advantages"][0], 1)
ev.stop()
def test_auto_vectorization(self):
ev = RolloutWorker(
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
policy=MockPolicy,
batch_mode="truncate_episodes",
rollout_fragment_length=2,
num_envs=8)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_this_iter"], 0)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_this_iter"], 8)
indices = []
for env in ev.async_env.vector_env.envs:
self.assertEqual(env.unwrapped.config.worker_index, 0)
indices.append(env.unwrapped.config.vector_index)
self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])
ev.stop()
def test_batches_larger_when_vectorized(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=8),
policy=MockPolicy,
batch_mode="truncate_episodes",
rollout_fragment_length=4,
num_envs=4)
batch = ev.sample()
self.assertEqual(batch.count, 16)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_this_iter"], 0)
batch = ev.sample()
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_this_iter"], 4)
ev.stop()
def test_vector_env_support(self):
ev = RolloutWorker(
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
policy=MockPolicy,
batch_mode="truncate_episodes",
rollout_fragment_length=10)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, 10)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_this_iter"], 0)
for _ in range(8):
batch = ev.sample()
self.assertEqual(batch.count, 10)
result = collect_metrics(ev, [])
self.assertEqual(result["episodes_this_iter"], 8)
ev.stop()
def test_truncate_episodes(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
policy=MockPolicy,
rollout_fragment_length=15,
batch_mode="truncate_episodes")
batch = ev.sample()
self.assertEqual(batch.count, 15)
ev.stop()
def test_complete_episodes(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
policy=MockPolicy,
rollout_fragment_length=5,
batch_mode="complete_episodes")
batch = ev.sample()
self.assertEqual(batch.count, 10)
ev.stop()
def test_complete_episodes_packing(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
policy=MockPolicy,
rollout_fragment_length=15,
batch_mode="complete_episodes")
batch = ev.sample()
self.assertEqual(batch.count, 20)
self.assertEqual(
batch["t"].tolist(),
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
ev.stop()
def test_filter_sync(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
time.sleep(2)
ev.sample()
filters = ev.get_filters(flush_after=True)
obs_f = filters[DEFAULT_POLICY_ID]
self.assertNotEqual(obs_f.rs.n, 0)
self.assertNotEqual(obs_f.buffer.n, 0)
ev.stop()
def test_get_filters(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
self.sample_and_flush(ev)
filters = ev.get_filters(flush_after=False)
time.sleep(2)
filters2 = ev.get_filters(flush_after=False)
obs_f = filters[DEFAULT_POLICY_ID]
obs_f2 = filters2[DEFAULT_POLICY_ID]
self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
ev.stop()
def test_sync_filter(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
obs_f = self.sample_and_flush(ev)
# Current State
filters = ev.get_filters(flush_after=False)
obs_f = filters[DEFAULT_POLICY_ID]
self.assertLessEqual(obs_f.buffer.n, 20)
new_obsf = obs_f.copy()
new_obsf.rs._n = 100
ev.sync_filters({DEFAULT_POLICY_ID: new_obsf})
filters = ev.get_filters(flush_after=False)
obs_f = filters[DEFAULT_POLICY_ID]
self.assertGreaterEqual(obs_f.rs.n, 100)
self.assertLessEqual(obs_f.buffer.n, 20)
ev.stop()
def test_extra_python_envs(self):
extra_envs = {"env_key_1": "env_value_1", "env_key_2": "env_value_2"}
self.assertFalse("env_key_1" in os.environ)
self.assertFalse("env_key_2" in os.environ)
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
policy=MockPolicy,
extra_python_environs=extra_envs)
self.assertTrue("env_key_1" in os.environ)
self.assertTrue("env_key_2" in os.environ)
ev.stop()
# reset to original
del os.environ["env_key_1"]
del os.environ["env_key_2"]
def sample_and_flush(self, ev):
time.sleep(2)
ev.sample()
filters = ev.get_filters(flush_after=True)
obs_f = filters[DEFAULT_POLICY_ID]
self.assertNotEqual(obs_f.rs.n, 0)
self.assertNotEqual(obs_f.buffer.n, 0)
return obs_f
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))