[RLlib] Add on_sub_environment_created to DefaultCallbacks class. (#21893)

Sven Mika 2022-02-04 22:22:47 +01:00 committed by GitHub
parent d9dc388082
commit f6617506a2
6 changed files with 289 additions and 79 deletions


@ -573,6 +573,13 @@ py_test(
# --------------------------------------------------------------------
# Generic (all Trainers)
py_test(
name = "test_callbacks",
tags = ["team:ml", "trainers_dir"],
size = "medium",
srcs = ["agents/tests/test_callbacks.py"]
)
py_test(
name = "test_trainer",
tags = ["team:ml", "trainers_dir"],


@ -3,7 +3,8 @@ import os
import tracemalloc
from typing import Dict, Optional, TYPE_CHECKING
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.episode import Episode
@ -15,7 +16,7 @@ from ray.rllib.utils.exploration.random_encoder import (
compute_states_entropy,
update_beta,
)
from ray.rllib.utils.typing import AgentID, EnvType, PolicyID
# Import psutil after ray so the packaged version is used.
import psutil
@ -44,6 +45,31 @@ class DefaultCallbacks:
)
self.legacy_callbacks = legacy_callbacks_dict or {}
def on_sub_environment_created(
self,
*,
worker: "RolloutWorker",
sub_environment: EnvType,
env_context: EnvContext,
**kwargs,
) -> None:
"""Callback run when a new sub-environment has been created.
This method gets callled after each sub-environment (usually a
gym.Env) has been created, validated (RLlib built-in validation
+ possible custom validation function implemented by overriding
`Trainer.validate_env()`), wrapped (e.g. video-wrapper), and seeded.
Args:
worker: Reference to the current rollout worker.
sub_environment: The sub-environment instance that has been
created. This is usually a gym.Env object.
env_context: The `EnvContext` object that has been passed to
the env's constructor.
kwargs: Forward compatibility placeholder.
"""
pass
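For orientation (not part of this diff): a minimal sketch of how a user could hook into this new callback, mirroring the pattern exercised by the new test file further below. The class name and config values here are illustrative only.

    import ray
    import ray.rllib.agents.ppo as ppo
    from ray.rllib.agents.callbacks import DefaultCallbacks

    class LogSubEnvCreated(DefaultCallbacks):
        def on_sub_environment_created(
            self, *, worker, sub_environment, env_context, **kwargs
        ):
            # Report which sub-env was created on which rollout worker.
            print(
                f"created {sub_environment} on worker={env_context.worker_index}, "
                f"vector_index={env_context.vector_index}"
            )

    ray.init()
    trainer = ppo.PPOTrainer(
        config={
            "env": "CartPole-v1",
            "num_workers": 1,
            "num_envs_per_worker": 2,
            # Pass the callbacks class itself (not an instance).
            "callbacks": LogSubEnvCreated,
        }
    )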
def on_episode_start(
self,
*,


@ -0,0 +1,100 @@
import unittest
import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.dqn as dqn
from ray.rllib.utils.test_utils import framework_iterator
class OnSubEnvironmentCreatedCallback(DefaultCallbacks):
def on_sub_environment_created(
self, *, worker, sub_environment, env_context, **kwargs
):
# Create a vector-index-sum property per remote worker.
if not hasattr(worker, "sum_sub_env_vector_indices"):
worker.sum_sub_env_vector_indices = 0
# Add the sub-env's vector index to the counter.
worker.sum_sub_env_vector_indices += env_context.vector_index
print(
f"sub-env {sub_environment} created; "
f"worker={worker.worker_index}; "
f"vector-idx={env_context.vector_index}"
)
class TestCallbacks(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init()
@classmethod
def tearDownClass(cls):
ray.shutdown()
def test_on_sub_environment_created(self):
config = {
"env": "CartPole-v1",
# Create 4 sub-environments per remote worker.
"num_envs_per_worker": 4,
# Create 2 remote workers.
"num_workers": 2,
"callbacks": OnSubEnvironmentCreatedCallback,
}
for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = dqn.DQNTrainer(config=config)
# Fake the counter on the local worker (doesn't have an env) and
# set it to -1 so the below `foreach_worker()` won't fail.
trainer.workers.local_worker().sum_sub_env_vector_indices = -1
# Get sub-env vector index sums from the 2 remote workers:
sum_sub_env_vector_indices = trainer.workers.foreach_worker(
lambda w: w.sum_sub_env_vector_indices
)
# Local worker has no environments -> Expect the -1 special
# value returned by the above lambda.
self.assertTrue(sum_sub_env_vector_indices[0] == -1)
# Both remote workers (index 1 and 2) have a vector index counter
# of 6 (sum of vector indices: 0 + 1 + 2 + 3).
self.assertTrue(sum_sub_env_vector_indices[1] == 6)
self.assertTrue(sum_sub_env_vector_indices[2] == 6)
trainer.stop()
def test_on_sub_environment_created_with_remote_envs(self):
config = {
"env": "CartPole-v1",
# Make each sub-environment a ray actor.
"remote_worker_envs": True,
# Create 4 sub-environments (ray remote actors) per remote
# worker.
"num_envs_per_worker": 4,
# Create 2 remote workers.
"num_workers": 2,
"callbacks": OnSubEnvironmentCreatedCallback,
}
for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = dqn.DQNTrainer(config=config)
# Fake the counter on the local worker (doesn't have an env) and
# set it to -1 so the below `foreach_worker()` won't fail.
trainer.workers.local_worker().sum_sub_env_vector_indices = -1
# Get sub-env vector index sums from the 2 remote workers:
sum_sub_env_vector_indices = trainer.workers.foreach_worker(
lambda w: w.sum_sub_env_vector_indices
)
# Local worker has no environments -> Expect the -1 special
# value returned by the above lambda.
self.assertTrue(sum_sub_env_vector_indices[0] == -1)
# Both remote workers (index 1 and 2) have a vector index counter
# of 6 (sum of vector indices: 0 + 1 + 2 + 3).
self.assertTrue(sum_sub_env_vector_indices[1] == 6)
self.assertTrue(sum_sub_env_vector_indices[2] == 6)
trainer.stop()
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))


@ -11,6 +11,7 @@ if TYPE_CHECKING:
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
ASYNC_RESET_RETURN = "async_reset_return"
@ -717,6 +718,7 @@ def convert_to_base_env(
num_envs: int = 1,
remote_envs: bool = False,
remote_env_batch_wait_ms: int = 0,
worker: Optional["RolloutWorker"] = None,
) -> "BaseEnv": ) -> "BaseEnv":
"""Converts an RLlib-supported env into a BaseEnv object. """Converts an RLlib-supported env into a BaseEnv object.
@ -748,6 +750,10 @@ def convert_to_base_env(
remote_env_batch_wait_ms: The wait time (in ms) to poll remote
sub-environments for, if applicable. Only used if
`remote_envs` is True.
worker: An optional RolloutWorker that owns the env. This is only
used if `remote_worker_envs` is True in your config and the
`on_sub_environment_created` custom callback needs to be called
on each created actor.
Returns:
The resulting BaseEnv object.
@ -761,7 +767,7 @@ def convert_to_base_env(
if remote_envs and num_envs == 1:
raise ValueError(
"Remote envs only make sense to use if num_envs > 1 "
"(i.e. environment vectorization is enabled)."
)
# Given `env` is already a BaseEnv -> Return as is.
@ -789,6 +795,7 @@ def convert_to_base_env(
multiagent=multiagent,
remote_env_batch_wait_ms=remote_env_batch_wait_ms,
existing_envs=[env],
worker=worker,
)
# Sub-environments are not ray.remote actors.
else:


@ -1,5 +1,5 @@
import logging
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
import gym
@ -8,6 +8,9 @@ from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import MultiEnvDict, EnvType, EnvID
if TYPE_CHECKING:
from ray.rllib.evaluation.rollout_worker import RolloutWorker
logger = logging.getLogger(__name__)
@ -33,6 +36,7 @@ class RemoteBaseEnv(BaseEnv):
multiagent: bool,
remote_env_batch_wait_ms: int,
existing_envs: Optional[List[ray.actor.ActorHandle]] = None,
worker: Optional["RolloutWorker"] = None,
):
"""Initializes a RemoteBaseEnv instance.
@ -51,71 +55,105 @@ class RemoteBaseEnv(BaseEnv):
existing_envs: Optional list of already created sub-environments.
These will be used as-is and only as many new sub-envs as
necessary (`num_envs - len(existing_envs)`) will be created.
worker: An optional RolloutWorker that owns the env. This is only
used if `remote_worker_envs` is True in your config and the
`on_sub_environment_created` custom callback needs to be
called on each created actor.
""" """
# Could be creating local or remote envs. # Could be creating local or remote envs.
self.make_env = make_env self.make_env = make_env
# Whether the given `make_env` callable already returns ray.remote
# objects or not.
self.make_env_creates_actors = False
# Already existing env objects (generated by the RolloutWorker).
self.existing_envs = existing_envs or []
self.num_envs = num_envs
self.multiagent = multiagent
self.poll_timeout = remote_env_batch_wait_ms / 1000
self.worker = worker
# Already existing env objects (generated by the RolloutWorker).
existing_envs = existing_envs or []
# Whether the given `make_env` callable already returns ActorHandles
# (@ray.remote class instances) or not.
self.make_env_creates_actors = False
self._observation_space = None
self._action_space = None
# List of ray actor handles (each handle points to one @ray.remote
# sub-environment).
self.actors: Optional[List[ray.actor.ActorHandle]] = None
# `self.make_env` already produces Actors: Use it directly.
if len(existing_envs) > 0 and isinstance(
existing_envs[0], ray.actor.ActorHandle
):
self.make_env_creates_actors = True
self.actors = existing_envs
while len(self.actors) < self.num_envs:
sub_env = self.make_env(len(self.actors))
if self.worker is not None:
self.worker.callbacks.on_sub_environment_created(
worker=self.worker,
sub_environment=sub_env,
env_context=self.worker.env_context.copy_with_overrides(
vector_index=len(self.actors)
),
)
self.actors.append(sub_env)
# `self.make_env` produces gym.Envs (or children thereof, such
# as MultiAgentEnv): Need to auto-wrap it here. The problem with
# this is that custom methods will get lost. If you would like to
# keep your custom methods in your envs, you should provide the
# env class directly in your config (w/o tune.register_env()),
# such that your class can directly be made a @ray.remote
# (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`).
# Also, if `len(existing_envs) > 0`, we have to throw those away
# as we need to create ray actors here.
else:
def make_remote_env(i):
logger.info("Launching env {} in remote actor".format(i))
if self.multiagent:
sub_env = _RemoteMultiAgentEnv.remote(self.make_env, i)
else:
sub_env = _RemoteSingleAgentEnv.remote(self.make_env, i)
if self.worker is not None:
self.worker.callbacks.on_sub_environment_created(
worker=self.worker,
sub_environment=sub_env,
env_context=self.worker.env_context.copy_with_overrides(
vector_index=i
),
)
return sub_env
self.actors = [make_remote_env(i) for i in range(self.num_envs)]
# Utilize existing envs for inferring observation/action spaces.
if len(existing_envs) > 0:
self._observation_space = existing_envs[0].observation_space
self._action_space = existing_envs[0].action_space
# Have to call actors' remote methods to get observation/action spaces.
else:
self._observation_space, self._action_space = ray.get(
[
self.actors[0].observation_space.remote(),
self.actors[0].action_space.remote(),
]
)
# Dict mapping object refs (return values of @ray.remote calls),
# whose actual values we are waiting for (via ray.wait in
# `self.poll()`) to their corresponding actor handles (the actors
# that created these return values).
# Call `reset()` on all @ray.remote sub-environment actors.
self.pending: Dict[ray.actor.ActorHandle] = {
a.reset.remote(): a for a in self.actors
}
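As a hedged aside (not part of this commit): the comment above recommends passing the env class directly in the config, rather than a tune.register_env() name, if custom env methods must survive `remote_worker_envs=True`. A minimal illustrative sketch; all names below are made up:

    import gym

    class CartPoleWithExtras(gym.Env):
        """CartPole plus a custom method we want to keep on the remote actor."""

        def __init__(self, env_config=None):
            self.env = gym.make("CartPole-v1")
            self.observation_space = self.env.observation_space
            self.action_space = self.env.action_space

        def reset(self):
            return self.env.reset()

        def step(self, action):
            return self.env.step(action)

        def custom_metric(self):
            # Would not be reachable behind the generic _Remote*AgentEnv wrapper.
            return 42.0

    config = {
        # Pass the class itself so RLlib can make it a @ray.remote actor as-is.
        "env": CartPoleWithExtras,
        "remote_worker_envs": True,
        "num_envs_per_worker": 4,
    }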
@override(BaseEnv)
def poll(
self,
) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
if self.actors is None:
# `self.make_env` already produces Actors: Use it directly.
if len(self.existing_envs) > 0 and isinstance(
self.existing_envs[0], ray.actor.ActorHandle
):
self.make_env_creates_actors = True
self.actors = []
while len(self.actors) < self.num_envs:
self.actors.append(self.make_env(len(self.actors)))
# `self.make_env` produces gym.Envs (or children thereof, such
# as MultiAgentEnv): Need to auto-wrap it here. The problem with
# this is that custom methods will get lost. If you would like to
# keep your custom methods in your envs, you should provide the
# env class directly in your config (w/o tune.register_env()),
# such that your class will directly be made a @ray.remote
# (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`).
else:
def make_remote_env(i):
logger.info("Launching env {} in remote actor".format(i))
if self.multiagent:
return _RemoteMultiAgentEnv.remote(self.make_env, i)
else:
return _RemoteSingleAgentEnv.remote(self.make_env, i)
self.actors = [make_remote_env(i) for i in range(self.num_envs)]
self._observation_space = ray.get(
self.actors[0].observation_space.remote()
)
self._action_space = ray.get(self.actors[0].action_space.remote())
# Lazy initialization. Call `reset()` on all @ray.remote
# sub-environment actors at the beginning.
if self.pending is None:
# Initialize our pending object ref -> actor handle mapping
# dict.
self.pending = {a.reset.remote(): a for a in self.actors}
# each keyed by env_id in [0, num_remote_envs)
obs, rewards, dones, infos = {}, {}, {}, {}
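For background (not part of this commit): the `self.pending` dict built in the constructor is consumed in `poll()` via `ray.wait()`. A simplified, hedged sketch of that bookkeeping pattern (the real poll() additionally unpacks per-agent dicts and re-issues step() calls):

    import ray

    def _drain_pending(pending, timeout):
        # Wait (up to `timeout` seconds) for outstanding reset()/step() refs.
        ready, _ = ray.wait(
            list(pending), num_returns=len(pending), timeout=timeout or None
        )
        results = {}
        for obj_ref in ready:
            actor = pending.pop(obj_ref)  # actor that produced this object ref
            results[actor] = ray.get(obj_ref)
        return results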
@ -247,8 +285,8 @@ class _RemoteMultiAgentEnv:
def step(self, action_dict):
return self.env.step(action_dict)
# Defining these 2 functions this way so this information can be
# queried with a call to ray.get().
def observation_space(self):
return self.env.observation_space
@ -276,8 +314,8 @@ class _RemoteSingleAgentEnv:
done["__all__"] = done[_DUMMY_AGENT_ID] done["__all__"] = done[_DUMMY_AGENT_ID]
return obs, rew, done, info return obs, rew, done, info
# defining these 2 functions that way this information can be queried # Defining these 2 functions that way this information can be queried
# with a call to ray.get() # with a call to ray.get().
def observation_space(self): def observation_space(self):
return self.env.observation_space return self.env.observation_space


@ -427,6 +427,7 @@ class RolloutWorker(ParallelIteratorWorker):
worker_index=worker_index,
vector_index=0,
num_workers=num_workers,
remote=remote_worker_envs,
)
self.env_context = env_context
self.policy_config: PartialTrainerConfigDict = policy_config
@ -478,7 +479,7 @@ class RolloutWorker(ParallelIteratorWorker):
# 3) Seed the env, if necessary.
# 4) Vectorize the existing single env by creating more clones of
# this env and wrapping it with the RLlib BaseEnv class.
self.env = self.make_sub_env_fn = None
# Create a (single) env for this worker.
if not (
@ -539,34 +540,17 @@ class RolloutWorker(ParallelIteratorWorker):
# dependency on each other right now, so we would settle on
# duplicating the random seed setting logic for now.
_update_env_seed_if_necessary(self.env, seed, worker_index, 0)
# Call custom callback function `on_sub_environment_created`.
self.callbacks.on_sub_environment_created(
worker=self,
sub_environment=self.env,
env_context=self.env_context,
)
self.make_sub_env_fn = self._get_make_sub_env_fn(
env_creator, env_context, validate_env, wrap, seed
)
self.spaces = spaces
self.policy_dict = _determine_spaces_for_multi_agent_dict(
@ -687,6 +671,7 @@ class RolloutWorker(ParallelIteratorWorker):
num_envs=num_envs,
remote_envs=remote_worker_envs,
remote_env_batch_wait_ms=remote_env_batch_wait_ms,
worker=self,
)
# `truncate_episodes`: Allow a batch to contain more than one episode
@ -1718,6 +1703,53 @@ class RolloutWorker(ParallelIteratorWorker):
logger.info(f"Built policy map: {self.policy_map}") logger.info(f"Built policy map: {self.policy_map}")
logger.info(f"Built preprocessor map: {self.preprocessors}") logger.info(f"Built preprocessor map: {self.preprocessors}")
def _get_make_sub_env_fn(
self, env_creator, env_context, validate_env, env_wrapper, seed
):
def _make_sub_env_local(vector_index):
# Used to create additional environments during environment
# vectorization.
# Create the env context (config dict + meta-data) for
# this particular sub-env within the vectorized one.
env_ctx = env_context.copy_with_overrides(vector_index=vector_index)
# Create the sub-env.
env = env_creator(env_ctx)
# Validate first.
_validate_env(env, env_context=env_ctx)
# Custom validation function given by user.
if validate_env is not None:
validate_env(env, env_ctx)
# Use our wrapper, defined above.
env = env_wrapper(env)
# Make sure a deterministic random seed is set on
# all the sub-environments if specified.
_update_env_seed_if_necessary(
env, seed, env_context.worker_index, vector_index
)
return env
if not env_context.remote:
def _make_sub_env_remote(vector_index):
sub_env = _make_sub_env_local(vector_index)
self.callbacks.on_sub_environment_created(
worker=self,
sub_environment=sub_env,
env_context=env_context.copy_with_overrides(
worker_index=env_context.worker_index,
vector_index=vector_index,
remote=False,
),
)
return sub_env
return _make_sub_env_remote
else:
return _make_sub_env_local
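For reference (not part of this commit): both the local and remote code paths above derive the per-sub-env context via `EnvContext.copy_with_overrides()`, which keeps the user's env config and swaps out metadata fields. A hedged illustration of the resulting object:

    from ray.rllib.env.env_context import EnvContext

    # Worker-level context: the env_config dict plus worker metadata.
    ctx = EnvContext({"fancy_param": 1.0}, worker_index=2, num_workers=2)

    # Per-sub-env context, as handed to `on_sub_environment_created`.
    sub_ctx = ctx.copy_with_overrides(vector_index=3, remote=False)

    assert sub_ctx["fancy_param"] == 1.0  # config entries are preserved
    assert sub_ctx.worker_index == 2      # metadata is carried over ...
    assert sub_ctx.vector_index == 3      # ... except the overridden fields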
@Deprecated(
new="Trainer.get_policy().export_model([export_dir], [onnx]?)", error=False
)