ray/rllib/agents/tests/test_callbacks.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

111 lines
4.4 KiB
Python
Raw Normal View History

import unittest
import ray
from ray.rllib.agents.callbacks import DefaultCallbacks, MultiCallbacks
import ray.rllib.algorithms.dqn as dqn
from ray.rllib.utils.test_utils import framework_iterator
class OnSubEnvironmentCreatedCallback(DefaultCallbacks):
def on_sub_environment_created(
self, *, worker, sub_environment, env_context, **kwargs
):
# Create a vector-index-sum property per remote worker.
if not hasattr(worker, "sum_sub_env_vector_indices"):
worker.sum_sub_env_vector_indices = 0
# Add the sub-env's vector index to the counter.
worker.sum_sub_env_vector_indices += env_context.vector_index
print(
f"sub-env {sub_environment} created; "
f"worker={worker.worker_index}; "
f"vector-idx={env_context.vector_index}"
)
class TestCallbacks(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init()
@classmethod
def tearDownClass(cls):
ray.shutdown()
def test_on_sub_environment_created(self):
base_config = {
"env": "CartPole-v1",
# Create 4 sub-environments per remote worker.
"num_envs_per_worker": 4,
# Create 2 remote workers.
"num_workers": 2,
}
for callbacks in (
OnSubEnvironmentCreatedCallback,
MultiCallbacks([OnSubEnvironmentCreatedCallback]),
):
config = dict(base_config, callbacks=callbacks)
for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = dqn.DQNTrainer(config=config)
# Fake the counter on the local worker (doesn't have an env) and
# set it to -1 so the below `foreach_worker()` won't fail.
trainer.workers.local_worker().sum_sub_env_vector_indices = -1
# Get sub-env vector index sums from the 2 remote workers:
sum_sub_env_vector_indices = trainer.workers.foreach_worker(
lambda w: w.sum_sub_env_vector_indices
)
# Local worker has no environments -> Expect the -1 special
# value returned by the above lambda.
self.assertTrue(sum_sub_env_vector_indices[0] == -1)
# Both remote workers (index 1 and 2) have a vector index counter
# of 6 (sum of vector indices: 0 + 1 + 2 + 3).
self.assertTrue(sum_sub_env_vector_indices[1] == 6)
self.assertTrue(sum_sub_env_vector_indices[2] == 6)
trainer.stop()
def test_on_sub_environment_created_with_remote_envs(self):
base_config = {
"env": "CartPole-v1",
# Make each sub-environment a ray actor.
"remote_worker_envs": True,
# Create 4 sub-environments (ray remote actors) per remote
# worker.
"num_envs_per_worker": 4,
# Create 2 remote workers.
"num_workers": 2,
}
for callbacks in (
OnSubEnvironmentCreatedCallback,
MultiCallbacks([OnSubEnvironmentCreatedCallback]),
):
config = dict(base_config, callbacks=callbacks)
for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = dqn.DQNTrainer(config=config)
# Fake the counter on the local worker (doesn't have an env) and
# set it to -1 so the below `foreach_worker()` won't fail.
trainer.workers.local_worker().sum_sub_env_vector_indices = -1
# Get sub-env vector index sums from the 2 remote workers:
sum_sub_env_vector_indices = trainer.workers.foreach_worker(
lambda w: w.sum_sub_env_vector_indices
)
# Local worker has no environments -> Expect the -1 special
# value returned by the above lambda.
self.assertTrue(sum_sub_env_vector_indices[0] == -1)
# Both remote workers (index 1 and 2) have a vector index counter
# of 6 (sum of vector indices: 0 + 1 + 2 + 3).
self.assertTrue(sum_sub_env_vector_indices[1] == 6)
self.assertTrue(sum_sub_env_vector_indices[2] == 6)
trainer.stop()
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))