# ray/rllib/tests/test_execution.py

import numpy as np
import time
import gym
import queue
import unittest
import ray
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER
from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    AsyncGradients,
    ConcatBatches,
    StandardizeFields,
)
from ray.rllib.execution.train_ops import (
    TrainOneStep,
    ComputeGradients,
    AverageGradients,
)
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer,
)
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.util.iter import LocalIterator, from_range
from ray.util.iter_metrics import SharedMetrics


def iter_list(values):
    return LocalIterator(lambda _: values, SharedMetrics())
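
# Helper shared by the tests below: builds a WorkerSet with one local
# RolloutWorker plus n remote workers, all sampling CartPole with the
# TF1 PPO policy in 100-step rollout fragments.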
def make_workers(n):
    local = RolloutWorker(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_spec=PPOTF1Policy,
        rollout_fragment_length=100,
    )
    remotes = [
        RolloutWorker.as_remote().remote(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=PPOTF1Policy,
            rollout_fragment_length=100,
        )
        for _ in range(n)
    ]
    workers = WorkerSet._from_existing(local, remotes)
    return workers

class TestExecution(unittest.TestCase):
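    # Explicit Ray setup/teardown, as a minimal sketch: Ray would also
    # auto-initialize on the first remote call, but an explicit
    # init/shutdown pair keeps test runs isolated from each other.
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    # Concurrently merges LocalIterators: "round_robin" alternates between
    # sources deterministically, while "async" yields items as they become
    # available; the interleaving is identical here because both sources
    # are local and never block.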
    def test_concurrently(self):
        a = iter_list([1, 2, 3])
        b = iter_list([4, 5, 6])
        c = Concurrently([a, b], mode="round_robin")
        assert c.take(6) == [1, 4, 2, 5, 3, 6]

        a = iter_list([1, 2, 3])
        b = iter_list([4, 5, 6])
        c = Concurrently([a, b], mode="async")
        assert c.take(6) == [1, 4, 2, 5, 3, 6]
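
    # round_robin_weights sets how many items are pulled from each source
    # per cycle (3 from a, 1 from b, 2 from c below); the special weight
    # "*" drains whatever that iterator has available on its turn.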
    def test_concurrently_weighted(self):
        a = iter_list([1, 1, 1])
        b = iter_list([2, 2, 2])
        c = iter_list([3, 3, 3])
        c = Concurrently([a, b, c], mode="round_robin", round_robin_weights=[3, 1, 2])
        assert c.take(9) == [1, 1, 1, 2, 3, 3, 2, 3, 2]

        a = iter_list([1, 1, 1])
        b = iter_list([2, 2, 2])
        c = iter_list([3, 3, 3])
        c = Concurrently([a, b, c], mode="round_robin", round_robin_weights=[1, 1, "*"])
        assert c.take(9) == [1, 2, 3, 3, 3, 1, 2, 1, 2]
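
    # output_indexes selects which sources' items appear in the merged
    # output; the unselected iterators are still stepped, just silently.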
    def test_concurrently_output(self):
        a = iter_list([1, 2, 3])
        b = iter_list([4, 5, 6])
        c = Concurrently([a, b], mode="round_robin", output_indexes=[1])
        assert c.take(6) == [4, 5, 6]

        a = iter_list([1, 2, 3])
        b = iter_list([4, 5, 6])
        c = Concurrently([a, b], mode="round_robin", output_indexes=[0, 1])
        assert c.take(6) == [1, 4, 2, 5, 3, 6]
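
    # Enqueue pushes each item into a queue.Queue as a side effect of
    # iteration; Dequeue wraps a queue as a LocalIterator that drains it.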
    def test_enqueue_dequeue(self):
        a = iter_list([1, 2, 3])
        q = queue.Queue(100)
        a.for_each(Enqueue(q)).take(3)
        assert q.qsize() == 3
        assert q.get_nowait() == 1
        assert q.get_nowait() == 2
        assert q.get_nowait() == 3

        q.put("a")
        q.put("b")
        q.put("c")
        a = Dequeue(q)
        assert a.take(3) == ["a", "b", "c"]
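
    # StandardMetricsReporting throttles result emission: with
    # min_time_s_per_iteration=2.5, pulling two results must take at
    # least ~2.5 s of wall-clock time, hence the timing assertion.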
    def test_metrics(self):
        workers = make_workers(1)
        workers.foreach_worker(lambda w: w.sample())
        a = from_range(10, repeat=True).gather_sync()
        b = StandardMetricsReporting(
            a,
            workers,
            {
                "min_time_s_per_iteration": 2.5,
                "min_sample_timesteps_per_iteration": 0,
                "metrics_num_episodes_for_smoothing": 10,
                "metrics_episode_collection_timeout_s": 10,
                "keep_per_episode_custom_metrics": False,
            },
        )
        start = time.time()
        res1 = next(b)
        assert res1["episode_reward_mean"] > 0, res1
        res2 = next(b)
        assert res2["episode_reward_mean"] > 0, res2
        assert time.time() - start > 2.4
        workers.stop()
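
    # bulk_sync concatenates one 100-step fragment from each of the two
    # remote workers into a single 200-step batch; async emits each
    # fragment individually. Sampled steps accumulate in shared metrics.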
    def test_rollouts(self):
        workers = make_workers(2)
        a = ParallelRollouts(workers, mode="bulk_sync")
        assert next(a).count == 200
        counters = a.shared_metrics.get().counters
        assert counters[STEPS_SAMPLED_COUNTER] == 200, counters

        a = ParallelRollouts(workers, mode="async")
        assert next(a).count == 100
        counters = a.shared_metrics.get().counters
        assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
        workers.stop()
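
    # With zero remote workers, ParallelRollouts falls back to sampling
    # from the local worker, yielding 100-step batches.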
    def test_rollouts_local(self):
        workers = make_workers(0)
        a = ParallelRollouts(workers, mode="bulk_sync")
        assert next(a).count == 100
        counters = a.shared_metrics.get().counters
        assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
        workers.stop()
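
    # ConcatBatches buffers incoming rollouts until at least 1000
    # timesteps have accumulated, then emits them as one batch.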
    def test_concat_batches(self):
        workers = make_workers(0)
        a = ParallelRollouts(workers, mode="async")
        b = a.combine(ConcatBatches(1000))
        assert next(b).count == 1000
        timers = b.shared_metrics.get().timers
        assert "sample" in timers
    def test_standardize(self):
        workers = make_workers(0)
        a = ParallelRollouts(workers, mode="async")
        b = a.for_each(StandardizeFields([SampleBatch.EPS_ID]))
        batch = next(b)
        assert abs(np.mean(batch[SampleBatch.EPS_ID])) < 0.001, batch
        assert abs(np.std(batch[SampleBatch.EPS_ID]) - 1.0) < 0.001, batch
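
    # AsyncGradients samples and computes gradients on the remote workers
    # asynchronously, yielding a 2-tuple per completed 100-step fragment.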
    def test_async_grads(self):
        workers = make_workers(2)
        a = AsyncGradients(workers)
        res1 = next(a)
        assert isinstance(res1, tuple) and len(res1) == 2, res1
        counters = a.shared_metrics.get().counters
        assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
        workers.stop()
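
    # TrainOneStep applies each sampled batch to the local worker's
    # policy, returning (batch, per-policy learner stats) and advancing
    # the trained-steps counter in lockstep with the sampled-steps one.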
    def test_train_one_step(self):
        workers = make_workers(0)
        a = ParallelRollouts(workers, mode="bulk_sync")
        b = a.for_each(TrainOneStep(workers))
        batch, stats = next(b)
        assert isinstance(batch, SampleBatch)
        assert DEFAULT_POLICY_ID in stats
        assert "learner_stats" in stats[DEFAULT_POLICY_ID]
        counters = a.shared_metrics.get().counters
        assert counters[STEPS_SAMPLED_COUNTER] == 100, counters
        assert counters[STEPS_TRAINED_COUNTER] == 100, counters
        timers = a.shared_metrics.get().timers
        assert "learn" in timers
        workers.stop()
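
    # ComputeGradients computes but does not apply gradients, returning
    # (grads, count) for each 100-step batch.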
    def test_compute_gradients(self):
        workers = make_workers(0)
        a = ParallelRollouts(workers, mode="bulk_sync")
        b = a.for_each(ComputeGradients(workers))
        grads, counts = next(b)
        assert counts == 100, counts
        timers = a.shared_metrics.get().timers
        assert "compute_grads" in timers
    def test_avg_gradients(self):
        workers = make_workers(0)
        a = ParallelRollouts(workers, mode="bulk_sync")
        b = a.for_each(ComputeGradients(workers)).batch(4)
        c = b.for_each(AverageGradients())
        grads, counts = next(c)
        assert counts == 400, counts
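
    # StoreToReplayBuffer writes sampled batches into a local
    # MultiAgentReplayBuffer; Replay then serves batches back out of it.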
    def test_store_to_replay_local(self):
        buf = MultiAgentReplayBuffer(
            num_shards=1,
            capacity=1000,
            prioritized_replay_alpha=0.6,
            prioritized_replay_beta=0.4,
            prioritized_replay_eps=0.0001,
        )
        workers = make_workers(0)
        a = ParallelRollouts(workers, mode="bulk_sync")
        b = a.for_each(StoreToReplayBuffer(local_buffer=buf))

        next(b)
        assert buf.sample(100).count == 100

        replay_op = Replay(local_buffer=buf, num_items_to_replay=100)
        assert next(replay_op).count == 100
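
    # Same store/replay round trip, but against a replay buffer hosted in
    # a Ray actor instead of a local object.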
    def test_store_to_replay_actor(self):
        ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
        actor = ReplayActor.remote(
            num_shards=1,
            capacity=1000,
            prioritized_replay_alpha=0.6,
            prioritized_replay_beta=0.4,
            prioritized_replay_eps=0.0001,
        )
        assert len(ray.get(actor.sample.remote(100))) == 0

        workers = make_workers(0)
        a = ParallelRollouts(workers, mode="bulk_sync")
        b = a.for_each(StoreToReplayBuffer(actors=[actor]))

        next(b)
        assert ray.get(actor.sample.remote(100)).count == 100

        replay_op = Replay(actors=[actor], num_items_to_replay=100)
        assert next(replay_op).count == 100

if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))