diff --git a/rllib/BUILD b/rllib/BUILD
index 4a857fa89..17c628f4d 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -573,6 +573,13 @@ py_test(
 # --------------------------------------------------------------------
 # Generic (all Trainers)
 
+py_test(
+    name = "test_callbacks",
+    tags = ["team:ml", "trainers_dir"],
+    size = "medium",
+    srcs = ["agents/tests/test_callbacks.py"]
+)
+
 py_test(
     name = "test_trainer",
     tags = ["team:ml", "trainers_dir"],
diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py
index e277e12b0..b1f876b27 100644
--- a/rllib/agents/callbacks.py
+++ b/rllib/agents/callbacks.py
@@ -3,7 +3,8 @@ import os
 import tracemalloc
 from typing import Dict, Optional, TYPE_CHECKING
 
-from ray.rllib.env import BaseEnv
+from ray.rllib.env.base_env import BaseEnv
+from ray.rllib.env.env_context import EnvContext
 from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.evaluation.episode import Episode
@@ -15,7 +16,7 @@ from ray.rllib.utils.exploration.random_encoder import (
     compute_states_entropy,
     update_beta,
 )
-from ray.rllib.utils.typing import AgentID, PolicyID
+from ray.rllib.utils.typing import AgentID, EnvType, PolicyID
 
 # Import psutil after ray so the packaged version is used.
 import psutil
@@ -44,6 +45,31 @@ class DefaultCallbacks:
             )
         self.legacy_callbacks = legacy_callbacks_dict or {}
 
+    def on_sub_environment_created(
+        self,
+        *,
+        worker: "RolloutWorker",
+        sub_environment: EnvType,
+        env_context: EnvContext,
+        **kwargs,
+    ) -> None:
+        """Callback run when a new sub-environment has been created.
+
+        This method gets called after each sub-environment (usually a
+        gym.Env) has been created, validated (RLlib built-in validation
+        + possible custom validation function implemented by overriding
+        `Trainer.validate_env()`), wrapped (e.g. video-wrapper), and seeded.
+
+        Args:
+            worker: Reference to the current rollout worker.
+            sub_environment: The sub-environment instance that has been
+                created. This is usually a gym.Env object.
+            env_context: The `EnvContext` object that has been passed to
+                the env's constructor.
+            kwargs: Forward compatibility placeholder.
+        """
+        pass
+
     def on_episode_start(
         self,
         *,
diff --git a/rllib/agents/tests/test_callbacks.py b/rllib/agents/tests/test_callbacks.py
new file mode 100644
index 000000000..f4851bda4
--- /dev/null
+++ b/rllib/agents/tests/test_callbacks.py
@@ -0,0 +1,100 @@
+import unittest
+
+import ray
+from ray.rllib.agents.callbacks import DefaultCallbacks
+import ray.rllib.agents.dqn as dqn
+from ray.rllib.utils.test_utils import framework_iterator
+
+
+class OnSubEnvironmentCreatedCallback(DefaultCallbacks):
+    def on_sub_environment_created(
+        self, *, worker, sub_environment, env_context, **kwargs
+    ):
+        # Create a vector-index-sum property per remote worker.
+        if not hasattr(worker, "sum_sub_env_vector_indices"):
+            worker.sum_sub_env_vector_indices = 0
+        # Add the sub-env's vector index to the counter.
+        worker.sum_sub_env_vector_indices += env_context.vector_index
+        print(
+            f"sub-env {sub_environment} created; "
+            f"worker={worker.worker_index}; "
+            f"vector-idx={env_context.vector_index}"
+        )
+
+
+class TestCallbacks(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init()
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
+    def test_on_sub_environment_created(self):
+        config = {
+            "env": "CartPole-v1",
+            # Create 4 sub-environments per remote worker.
+            "num_envs_per_worker": 4,
+            # Create 2 remote workers.
+ "num_workers": 2, + "callbacks": OnSubEnvironmentCreatedCallback, + } + + for _ in framework_iterator(config, frameworks=("tf", "torch")): + trainer = dqn.DQNTrainer(config=config) + # Fake the counter on the local worker (doesn't have an env) and + # set it to -1 so the below `foreach_worker()` won't fail. + trainer.workers.local_worker().sum_sub_env_vector_indices = -1 + + # Get sub-env vector index sums from the 2 remote workers: + sum_sub_env_vector_indices = trainer.workers.foreach_worker( + lambda w: w.sum_sub_env_vector_indices + ) + # Local worker has no environments -> Expect the -1 special + # value returned by the above lambda. + self.assertTrue(sum_sub_env_vector_indices[0] == -1) + # Both remote workers (index 1 and 2) have a vector index counter + # of 6 (sum of vector indices: 0 + 1 + 2 + 3). + self.assertTrue(sum_sub_env_vector_indices[1] == 6) + self.assertTrue(sum_sub_env_vector_indices[2] == 6) + trainer.stop() + + def test_on_sub_environment_created_with_remote_envs(self): + config = { + "env": "CartPole-v1", + # Make each sub-environment a ray actor. + "remote_worker_envs": True, + # Create 4 sub-environments (ray remote actors) per remote + # worker. + "num_envs_per_worker": 4, + # Create 2 remote workers. + "num_workers": 2, + "callbacks": OnSubEnvironmentCreatedCallback, + } + + for _ in framework_iterator(config, frameworks=("tf", "torch")): + trainer = dqn.DQNTrainer(config=config) + # Fake the counter on the local worker (doesn't have an env) and + # set it to -1 so the below `foreach_worker()` won't fail. + trainer.workers.local_worker().sum_sub_env_vector_indices = -1 + + # Get sub-env vector index sums from the 2 remote workers: + sum_sub_env_vector_indices = trainer.workers.foreach_worker( + lambda w: w.sum_sub_env_vector_indices + ) + # Local worker has no environments -> Expect the -1 special + # value returned by the above lambda. + self.assertTrue(sum_sub_env_vector_indices[0] == -1) + # Both remote workers (index 1 and 2) have a vector index counter + # of 6 (sum of vector indices: 0 + 1 + 2 + 3). + self.assertTrue(sum_sub_env_vector_indices[1] == 6) + self.assertTrue(sum_sub_env_vector_indices[2] == 6) + trainer.stop() + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index a899185fe..85feec4a7 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.vector_env import VectorEnv + from ray.rllib.evaluation.rollout_worker import RolloutWorker ASYNC_RESET_RETURN = "async_reset_return" @@ -717,6 +718,7 @@ def convert_to_base_env( num_envs: int = 1, remote_envs: bool = False, remote_env_batch_wait_ms: int = 0, + worker: Optional["RolloutWorker"] = None, ) -> "BaseEnv": """Converts an RLlib-supported env into a BaseEnv object. @@ -748,6 +750,10 @@ def convert_to_base_env( remote_env_batch_wait_ms: The wait time (in ms) to poll remote sub-environments for, if applicable. Only used if `remote_envs` is True. + worker: An optional RolloutWorker that owns the env. This is only + used if `remote_worker_envs` is True in your config and the + `on_sub_environment_created` custom callback needs to be called + on each created actor. Returns: The resulting BaseEnv object. 
@@ -761,7 +767,7 @@ def convert_to_base_env(
     if remote_envs and num_envs == 1:
         raise ValueError(
             "Remote envs only make sense to use if num_envs > 1 "
-            "(i.e. vectorization is enabled)."
+            "(i.e. environment vectorization is enabled)."
         )
 
     # Given `env` is already a BaseEnv -> Return as is.
@@ -789,6 +795,7 @@ def convert_to_base_env(
             multiagent=multiagent,
             remote_env_batch_wait_ms=remote_env_batch_wait_ms,
             existing_envs=[env],
+            worker=worker,
         )
     # Sub-environments are not ray.remote actors.
     else:
diff --git a/rllib/env/remote_base_env.py b/rllib/env/remote_base_env.py
index f0762a39b..3ebe40e77 100644
--- a/rllib/env/remote_base_env.py
+++ b/rllib/env/remote_base_env.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
 
 import gym
 
@@ -8,6 +8,9 @@ from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import MultiEnvDict, EnvType, EnvID
 
+if TYPE_CHECKING:
+    from ray.rllib.evaluation.rollout_worker import RolloutWorker
+
 logger = logging.getLogger(__name__)
 
 
@@ -33,6 +36,7 @@ class RemoteBaseEnv(BaseEnv):
         multiagent: bool,
         remote_env_batch_wait_ms: int,
         existing_envs: Optional[List[ray.actor.ActorHandle]] = None,
+        worker: Optional["RolloutWorker"] = None,
     ):
         """Initializes a RemoteVectorEnv instance.
 
@@ -51,71 +55,105 @@ class RemoteBaseEnv(BaseEnv):
             existing_envs: Optional list of already created sub-environments.
                 These will be used as-is and only as many new sub-envs as
                 necessary (`num_envs - len(existing_envs)`) will be created.
+            worker: An optional RolloutWorker that owns the env. This is only
+                used if `remote_worker_envs` is True in your config and the
+                `on_sub_environment_created` custom callback needs to be
+                called on each created actor.
         """
         # Could be creating local or remote envs.
         self.make_env = make_env
-        # Whether the given `make_env` callable already returns ray.remote
-        # objects or not.
-        self.make_env_creates_actors = False
-        # Already existing env objects (generated by the RolloutWorker).
-        self.existing_envs = existing_envs or []
         self.num_envs = num_envs
         self.multiagent = multiagent
         self.poll_timeout = remote_env_batch_wait_ms / 1000
+        self.worker = worker
+
+        # Already existing env objects (generated by the RolloutWorker).
+        existing_envs = existing_envs or []
+
+        # Whether the given `make_env` callable already returns ActorHandles
+        # (@ray.remote class instances) or not.
+        self.make_env_creates_actors = False
+
+        self._observation_space = None
+        self._action_space = None
 
         # List of ray actor handles (each handle points to one @ray.remote
         # sub-environment).
         self.actors: Optional[List[ray.actor.ActorHandle]] = None
-        self._observation_space = None
-        self._action_space = None
+
+        # `self.make_env` already produces Actors: Use it directly.
+        if len(existing_envs) > 0 and isinstance(
+            existing_envs[0], ray.actor.ActorHandle
+        ):
+            self.make_env_creates_actors = True
+            self.actors = existing_envs
+            while len(self.actors) < self.num_envs:
+                sub_env = self.make_env(len(self.actors))
+                if self.worker is not None:
+                    self.worker.callbacks.on_sub_environment_created(
+                        worker=self.worker,
+                        sub_environment=sub_env,
+                        env_context=self.worker.env_context.copy_with_overrides(
+                            vector_index=len(self.actors)
+                        ),
+                    )
+                self.actors.append(sub_env)
+        # `self.make_env` produces gym.Envs (or children thereof, such
+        # as MultiAgentEnv): Need to auto-wrap it here. The problem with
+        # this is that custom methods will get lost. If you would like to
+        # keep your custom methods in your envs, you should provide the
+        # env class directly in your config (w/o tune.register_env()),
+        # such that your class can directly be made a @ray.remote
+        # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`).
+        # Also, if `len(existing_envs) > 0`, we have to throw those away
+        # as we need to create ray actors here.
+        else:
+
+            def make_remote_env(i):
+                logger.info("Launching env {} in remote actor".format(i))
+                if self.multiagent:
+                    sub_env = _RemoteMultiAgentEnv.remote(self.make_env, i)
+                else:
+                    sub_env = _RemoteSingleAgentEnv.remote(self.make_env, i)
+
+                if self.worker is not None:
+                    self.worker.callbacks.on_sub_environment_created(
+                        worker=self.worker,
+                        sub_environment=sub_env,
+                        env_context=self.worker.env_context.copy_with_overrides(
+                            vector_index=i
+                        ),
+                    )
+                return sub_env
+
+            self.actors = [make_remote_env(i) for i in range(self.num_envs)]
+        # Utilize existing envs for inferring observation/action spaces.
+        if len(existing_envs) > 0:
+            self._observation_space = existing_envs[0].observation_space
+            self._action_space = existing_envs[0].action_space
+        # Have to call actors' remote methods to get observation/action spaces.
+        else:
+            self._observation_space, self._action_space = ray.get(
+                [
+                    self.actors[0].observation_space.remote(),
+                    self.actors[0].action_space.remote(),
+                ]
+            )
+
         # Dict mapping object refs (return values of @ray.remote calls),
         # whose actual values we are waiting for (via ray.wait in
         # `self.poll()`) to their corresponding actor handles (the actors
         # that created these return values).
-        self.pending: Optional[Dict[ray.actor.ActorHandle]] = None
+        # Call `reset()` on all @ray.remote sub-environment actors.
+        self.pending: Dict[ray.actor.ActorHandle] = {
+            a.reset.remote(): a for a in self.actors
+        }
 
     @override(BaseEnv)
     def poll(
         self,
     ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
-        if self.actors is None:
-            # `self.make_env` already produces Actors: Use it directly.
-            if len(self.existing_envs) > 0 and isinstance(
-                self.existing_envs[0], ray.actor.ActorHandle
-            ):
-                self.make_env_creates_actors = True
-                self.actors = []
-                while len(self.actors) < self.num_envs:
-                    self.actors.append(self.make_env(len(self.actors)))
-            # `self.make_env` produces gym.Envs (or children thereof, such
-            # as MultiAgentEnv): Need to auto-wrap it here. The problem with
-            # this is that custom methods wil get lost. If you would like to
-            # keep your custom methods in your envs, you should provide the
-            # env class directly in your config (w/o tune.register_env()),
-            # such that your class will directly be made a @ray.remote
-            # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`).
-            else:
-
-                def make_remote_env(i):
-                    logger.info("Launching env {} in remote actor".format(i))
-                    if self.multiagent:
-                        return _RemoteMultiAgentEnv.remote(self.make_env, i)
-                    else:
-                        return _RemoteSingleAgentEnv.remote(self.make_env, i)
-
-                self.actors = [make_remote_env(i) for i in range(self.num_envs)]
-                self._observation_space = ray.get(
-                    self.actors[0].observation_space.remote()
-                )
-                self._action_space = ray.get(self.actors[0].action_space.remote())
-
-        # Lazy initialization. Call `reset()` on all @ray.remote
-        # sub-environment actors at the beginning.
-        if self.pending is None:
-            # Initialize our pending object ref -> actor handle mapping
-            # dict.
-            self.pending = {a.reset.remote(): a for a in self.actors}
 
         # each keyed by env_id in [0, num_remote_envs)
         obs, rewards, dones, infos = {}, {}, {}, {}
@@ -247,8 +285,8 @@ class _RemoteMultiAgentEnv:
     def step(self, action_dict):
         return self.env.step(action_dict)
 
-    # defining these 2 functions that way this information can be queried
-    # with a call to ray.get()
+    # Defining these 2 functions that way this information can be queried
+    # with a call to ray.get().
     def observation_space(self):
         return self.env.observation_space
 
@@ -276,8 +314,8 @@ class _RemoteSingleAgentEnv:
         done["__all__"] = done[_DUMMY_AGENT_ID]
         return obs, rew, done, info
 
-    # defining these 2 functions that way this information can be queried
-    # with a call to ray.get()
+    # Defining these 2 functions that way this information can be queried
+    # with a call to ray.get().
     def observation_space(self):
         return self.env.observation_space
 
diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py
index 92f780364..e37de61fd 100644
--- a/rllib/evaluation/rollout_worker.py
+++ b/rllib/evaluation/rollout_worker.py
@@ -427,6 +427,7 @@ class RolloutWorker(ParallelIteratorWorker):
             worker_index=worker_index,
             vector_index=0,
             num_workers=num_workers,
+            remote=remote_worker_envs,
         )
         self.env_context = env_context
         self.policy_config: PartialTrainerConfigDict = policy_config
@@ -478,7 +479,7 @@ class RolloutWorker(ParallelIteratorWorker):
         # 3) Seed the env, if necessary.
         # 4) Vectorize the existing single env by creating more clones of
         #    this env and wrapping it with the RLlib BaseEnv class.
-        self.env = None
+        self.env = self.make_sub_env_fn = None
 
         # Create a (single) env for this worker.
         if not (
@@ -539,34 +540,17 @@ class RolloutWorker(ParallelIteratorWorker):
             # dependency on each other right now, so we would settle on
             # duplicating the random seed setting logic for now.
             _update_env_seed_if_necessary(self.env, seed, worker_index, 0)
-
-        def make_sub_env(vector_index):
-            # Used to created additional environments during environment
-            # vectorization.
-
-            # Create the env context (config dict + meta-data) for
-            # this particular sub-env within the vectorized one.
-            env_ctx = env_context.copy_with_overrides(
-                worker_index=worker_index,
-                vector_index=vector_index,
-                remote=remote_worker_envs,
+            # Call custom callback function `on_sub_environment_created`.
+            self.callbacks.on_sub_environment_created(
+                worker=self,
+                sub_environment=self.env,
+                env_context=self.env_context,
             )
-            # Create the sub-env.
-            env = env_creator(env_ctx)
-            # Validate first.
-            _validate_env(env, env_context=env_ctx)
-            # Custom validation function given by user.
-            if validate_env is not None:
-                validate_env(env, env_ctx)
-            # Use our wrapper, defined above.
-            env = wrap(env)
-            # Make sure a deterministic random seed is set on
-            # all the sub-environments if specified.
-            _update_env_seed_if_necessary(env, seed, worker_index, vector_index)
-            return env
+        self.make_sub_env_fn = self._get_make_sub_env_fn(
+            env_creator, env_context, validate_env, wrap, seed
+        )
-        self.make_sub_env_fn = make_sub_env
 
         self.spaces = spaces
 
         self.policy_dict = _determine_spaces_for_multi_agent_dict(
@@ -687,6 +671,7 @@ class RolloutWorker(ParallelIteratorWorker):
             num_envs=num_envs,
             remote_envs=remote_worker_envs,
             remote_env_batch_wait_ms=remote_env_batch_wait_ms,
+            worker=self,
         )
 
         # `truncate_episodes`: Allow a batch to contain more than one episode
@@ -1718,6 +1703,53 @@ class RolloutWorker(ParallelIteratorWorker):
         logger.info(f"Built policy map: {self.policy_map}")
         logger.info(f"Built preprocessor map: {self.preprocessors}")
 
+    def _get_make_sub_env_fn(
+        self, env_creator, env_context, validate_env, env_wrapper, seed
+    ):
+        def _make_sub_env_local(vector_index):
+            # Used to create additional environments during environment
+            # vectorization.
+
+            # Create the env context (config dict + meta-data) for
+            # this particular sub-env within the vectorized one.
+            env_ctx = env_context.copy_with_overrides(vector_index=vector_index)
+            # Create the sub-env.
+            env = env_creator(env_ctx)
+            # Validate first.
+            _validate_env(env, env_context=env_ctx)
+            # Custom validation function given by user.
+            if validate_env is not None:
+                validate_env(env, env_ctx)
+            # Use our wrapper, defined above.
+            env = env_wrapper(env)
+
+            # Make sure a deterministic random seed is set on
+            # all the sub-environments if specified.
+            _update_env_seed_if_necessary(
+                env, seed, env_context.worker_index, vector_index
+            )
+            return env
+
+        if not env_context.remote:
+
+            def _make_sub_env_remote(vector_index):
+                sub_env = _make_sub_env_local(vector_index)
+                self.callbacks.on_sub_environment_created(
+                    worker=self,
+                    sub_environment=sub_env,
+                    env_context=env_context.copy_with_overrides(
+                        worker_index=env_context.worker_index,
+                        vector_index=vector_index,
+                        remote=False,
+                    ),
+                )
+                return sub_env
+
+            return _make_sub_env_remote
+
+        else:
+            return _make_sub_env_local
+
     @Deprecated(
         new="Trainer.get_policy().export_model([export_dir], [onnx]?)", error=False
     )
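Usage note (not part of the patch): the snippet below is a minimal sketch of how the new `on_sub_environment_created` hook can be plugged into a user config, following the same `DefaultCallbacks` subclassing pattern as the test added above. The `MySubEnvCallbacks` class name, the printed message, and the PPO config values are illustrative assumptions only. Also note that with `remote_worker_envs=True`, `sub_environment` is the `@ray.remote` actor handle wrapping the env (see the `RemoteBaseEnv` changes), not the `gym.Env` instance itself.

# Illustrative sketch only; not part of the diff above.
import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo


class MySubEnvCallbacks(DefaultCallbacks):
    def on_sub_environment_created(
        self, *, worker, sub_environment, env_context, **kwargs
    ):
        # Runs once per created (already validated, wrapped, and seeded)
        # sub-environment; `env_context` carries worker_index/vector_index.
        print(
            f"created sub-env on worker {env_context.worker_index}, "
            f"vector index {env_context.vector_index}"
        )


if __name__ == "__main__":
    ray.init()
    trainer = ppo.PPOTrainer(
        config={
            "env": "CartPole-v1",
            "num_workers": 2,
            "num_envs_per_worker": 4,
            # Pass the callbacks class (not an instance).
            "callbacks": MySubEnvCallbacks,
        }
    )
    trainer.train()
    trainer.stop()
    ray.shutdown()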