[RLlib] Add on_sub_environment_created to DefaultCallbacks class. (#21893)

parent: d9dc388082
commit: f6617506a2

6 changed files with 289 additions and 79 deletions
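The new hook lets user code run once per created (and validated, wrapped, seeded) sub-environment. A minimal usage sketch follows; the `MyCallbacks` class and the printed message are illustrative only and mirror the test added by this commit, they are not part of the diff itself:

import ray
import ray.rllib.agents.dqn as dqn
from ray.rllib.agents.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    # Illustrative override: log every sub-environment as it is created.
    def on_sub_environment_created(
        self, *, worker, sub_environment, env_context, **kwargs
    ):
        # `env_context` carries per-sub-env metadata such as vector_index.
        print(
            f"worker={env_context.worker_index} "
            f"vector_index={env_context.vector_index}: {sub_environment}"
        )


ray.init()
config = {
    "env": "CartPole-v1",
    "num_envs_per_worker": 2,
    "callbacks": MyCallbacks,
}
trainer = dqn.DQNTrainer(config=config)
trainer.stop()
ray.shutdown()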
@@ -573,6 +573,13 @@ py_test(
 # --------------------------------------------------------------------
 # Generic (all Trainers)

+py_test(
+    name = "test_callbacks",
+    tags = ["team:ml", "trainers_dir"],
+    size = "medium",
+    srcs = ["agents/tests/test_callbacks.py"]
+)
+
 py_test(
     name = "test_trainer",
     tags = ["team:ml", "trainers_dir"],
rllib/agents/callbacks.py

@@ -3,7 +3,8 @@ import os
 import tracemalloc
 from typing import Dict, Optional, TYPE_CHECKING

-from ray.rllib.env import BaseEnv
+from ray.rllib.env.base_env import BaseEnv
+from ray.rllib.env.env_context import EnvContext
 from ray.rllib.policy import Policy
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.evaluation.episode import Episode

@@ -15,7 +16,7 @@ from ray.rllib.utils.exploration.random_encoder import (
     compute_states_entropy,
     update_beta,
 )
-from ray.rllib.utils.typing import AgentID, PolicyID
+from ray.rllib.utils.typing import AgentID, EnvType, PolicyID

 # Import psutil after ray so the packaged version is used.
 import psutil

@@ -44,6 +45,31 @@ class DefaultCallbacks:
         )
         self.legacy_callbacks = legacy_callbacks_dict or {}

+    def on_sub_environment_created(
+        self,
+        *,
+        worker: "RolloutWorker",
+        sub_environment: EnvType,
+        env_context: EnvContext,
+        **kwargs,
+    ) -> None:
+        """Callback run when a new sub-environment has been created.
+
+        This method gets called after each sub-environment (usually a
+        gym.Env) has been created, validated (RLlib built-in validation
+        + possible custom validation function implemented by overriding
+        `Trainer.validate_env()`), wrapped (e.g. video-wrapper), and seeded.
+
+        Args:
+            worker: Reference to the current rollout worker.
+            sub_environment: The sub-environment instance that has been
+                created. This is usually a gym.Env object.
+            env_context: The `EnvContext` object that has been passed to
+                the env's constructor.
+            kwargs: Forward compatibility placeholder.
+        """
+        pass
+
     def on_episode_start(
         self,
         *,
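As a condensed illustration of the creation pipeline the docstring above describes (create, validate, wrap, seed, then fire the callback), here is a hedged, self-contained toy; all names are stand-ins, not the RolloutWorker internals shown further down in this diff:

import gym


def toy_make_sub_env(vector_index, seed=42):
    env = gym.make("CartPole-v1")                     # 1) create
    assert hasattr(env, "step")                       # 2) (stand-in) validation
    env = gym.wrappers.RecordEpisodeStatistics(env)   # 3) wrap (e.g. a stats/video wrapper)
    env.seed(seed + vector_index)                     # 4) seed deterministically per sub-env
    # 5) only after all of the above would RLlib fire
    #    callbacks.on_sub_environment_created(...)
    return env


env = toy_make_sub_env(vector_index=0)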
rllib/agents/tests/test_callbacks.py (new file, 100 lines)

@@ -0,0 +1,100 @@
+import unittest
+
+import ray
+from ray.rllib.agents.callbacks import DefaultCallbacks
+import ray.rllib.agents.dqn as dqn
+from ray.rllib.utils.test_utils import framework_iterator
+
+
+class OnSubEnvironmentCreatedCallback(DefaultCallbacks):
+    def on_sub_environment_created(
+        self, *, worker, sub_environment, env_context, **kwargs
+    ):
+        # Create a vector-index-sum property per remote worker.
+        if not hasattr(worker, "sum_sub_env_vector_indices"):
+            worker.sum_sub_env_vector_indices = 0
+        # Add the sub-env's vector index to the counter.
+        worker.sum_sub_env_vector_indices += env_context.vector_index
+        print(
+            f"sub-env {sub_environment} created; "
+            f"worker={worker.worker_index}; "
+            f"vector-idx={env_context.vector_index}"
+        )
+
+
+class TestCallbacks(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init()
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
+    def test_on_sub_environment_created(self):
+        config = {
+            "env": "CartPole-v1",
+            # Create 4 sub-environments per remote worker.
+            "num_envs_per_worker": 4,
+            # Create 2 remote workers.
+            "num_workers": 2,
+            "callbacks": OnSubEnvironmentCreatedCallback,
+        }
+
+        for _ in framework_iterator(config, frameworks=("tf", "torch")):
+            trainer = dqn.DQNTrainer(config=config)
+            # Fake the counter on the local worker (doesn't have an env) and
+            # set it to -1 so the below `foreach_worker()` won't fail.
+            trainer.workers.local_worker().sum_sub_env_vector_indices = -1
+
+            # Get sub-env vector index sums from the 2 remote workers:
+            sum_sub_env_vector_indices = trainer.workers.foreach_worker(
+                lambda w: w.sum_sub_env_vector_indices
+            )
+            # Local worker has no environments -> Expect the -1 special
+            # value returned by the above lambda.
+            self.assertTrue(sum_sub_env_vector_indices[0] == -1)
+            # Both remote workers (index 1 and 2) have a vector index counter
+            # of 6 (sum of vector indices: 0 + 1 + 2 + 3).
+            self.assertTrue(sum_sub_env_vector_indices[1] == 6)
+            self.assertTrue(sum_sub_env_vector_indices[2] == 6)
+            trainer.stop()
+
+    def test_on_sub_environment_created_with_remote_envs(self):
+        config = {
+            "env": "CartPole-v1",
+            # Make each sub-environment a ray actor.
+            "remote_worker_envs": True,
+            # Create 4 sub-environments (ray remote actors) per remote
+            # worker.
+            "num_envs_per_worker": 4,
+            # Create 2 remote workers.
+            "num_workers": 2,
+            "callbacks": OnSubEnvironmentCreatedCallback,
+        }
+
+        for _ in framework_iterator(config, frameworks=("tf", "torch")):
+            trainer = dqn.DQNTrainer(config=config)
+            # Fake the counter on the local worker (doesn't have an env) and
+            # set it to -1 so the below `foreach_worker()` won't fail.
+            trainer.workers.local_worker().sum_sub_env_vector_indices = -1
+
+            # Get sub-env vector index sums from the 2 remote workers:
+            sum_sub_env_vector_indices = trainer.workers.foreach_worker(
+                lambda w: w.sum_sub_env_vector_indices
+            )
+            # Local worker has no environments -> Expect the -1 special
+            # value returned by the above lambda.
+            self.assertTrue(sum_sub_env_vector_indices[0] == -1)
+            # Both remote workers (index 1 and 2) have a vector index counter
+            # of 6 (sum of vector indices: 0 + 1 + 2 + 3).
+            self.assertTrue(sum_sub_env_vector_indices[1] == 6)
+            self.assertTrue(sum_sub_env_vector_indices[2] == 6)
+            trainer.stop()
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
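A short hedged note on the indices asserted in both tests above: `foreach_worker()` returns the local worker's result first, followed by one result per remote worker, and with `num_envs_per_worker=4` each remote worker accumulates the vector indices 0 through 3:

# Expected results list shape: [local worker, remote worker 1, remote worker 2].
expected_per_remote_worker = sum(range(4))  # 0 + 1 + 2 + 3 == 6
expected = [-1, expected_per_remote_worker, expected_per_remote_worker]
assert expected == [-1, 6, 6]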
rllib/env/base_env.py

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
     from ray.rllib.env.external_env import ExternalEnv
     from ray.rllib.env.multi_agent_env import MultiAgentEnv
     from ray.rllib.env.vector_env import VectorEnv
+    from ray.rllib.evaluation.rollout_worker import RolloutWorker

 ASYNC_RESET_RETURN = "async_reset_return"

@@ -717,6 +718,7 @@ def convert_to_base_env(
     num_envs: int = 1,
     remote_envs: bool = False,
     remote_env_batch_wait_ms: int = 0,
+    worker: Optional["RolloutWorker"] = None,
 ) -> "BaseEnv":
     """Converts an RLlib-supported env into a BaseEnv object.

@@ -748,6 +750,10 @@ def convert_to_base_env(
         remote_env_batch_wait_ms: The wait time (in ms) to poll remote
             sub-environments for, if applicable. Only used if
             `remote_envs` is True.
+        worker: An optional RolloutWorker that owns the env. This is only
+            used if `remote_worker_envs` is True in your config and the
+            `on_sub_environment_created` custom callback needs to be called
+            on each created actor.

     Returns:
         The resulting BaseEnv object.

@@ -761,7 +767,7 @@ def convert_to_base_env(
     if remote_envs and num_envs == 1:
         raise ValueError(
             "Remote envs only make sense to use if num_envs > 1 "
-            "(i.e. vectorization is enabled)."
+            "(i.e. environment vectorization is enabled)."
         )

     # Given `env` is already a BaseEnv -> Return as is.

@@ -789,6 +795,7 @@ def convert_to_base_env(
             multiagent=multiagent,
             remote_env_batch_wait_ms=remote_env_batch_wait_ms,
             existing_envs=[env],
+            worker=worker,
         )
     # Sub-environments are not ray.remote actors.
     else:
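The new `worker` argument exists so that `RemoteBaseEnv` (next file) can fire `on_sub_environment_created` with a per-sub-environment `EnvContext`. An `EnvContext` is the env config dict plus metadata (`worker_index`, `vector_index`, `remote`, `num_workers`), and `copy_with_overrides()` returns a copy with selected fields replaced; the exact constructor arguments below are assumptions about the EnvContext API, not part of this commit:

from ray.rllib.env.env_context import EnvContext

# Assumed usage: a dict-like env config enriched with worker metadata.
base_ctx = EnvContext({"some_key": 4}, worker_index=1, num_workers=2)

# Per-sub-env copies, as done below when stamping each created sub-env:
per_env_ctxs = [base_ctx.copy_with_overrides(vector_index=i) for i in range(4)]
print([ctx.vector_index for ctx in per_env_ctxs])  # -> [0, 1, 2, 3]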
rllib/env/remote_base_env.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING

 import gym

@@ -8,6 +8,9 @@ from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import MultiEnvDict, EnvType, EnvID

+if TYPE_CHECKING:
+    from ray.rllib.evaluation.rollout_worker import RolloutWorker
+
 logger = logging.getLogger(__name__)


@@ -33,6 +36,7 @@ class RemoteBaseEnv(BaseEnv):
         multiagent: bool,
         remote_env_batch_wait_ms: int,
         existing_envs: Optional[List[ray.actor.ActorHandle]] = None,
+        worker: Optional["RolloutWorker"] = None,
     ):
         """Initializes a RemoteVectorEnv instance.

@@ -51,71 +55,105 @@ class RemoteBaseEnv(BaseEnv):
             existing_envs: Optional list of already created sub-environments.
                 These will be used as-is and only as many new sub-envs as
                 necessary (`num_envs - len(existing_envs)`) will be created.
+            worker: An optional RolloutWorker that owns the env. This is only
+                used if `remote_worker_envs` is True in your config and the
+                `on_sub_environment_created` custom callback needs to be
+                called on each created actor.
         """

         # Could be creating local or remote envs.
         self.make_env = make_env
-        # Whether the given `make_env` callable already returns ray.remote
-        # objects or not.
-        self.make_env_creates_actors = False
-        # Already existing env objects (generated by the RolloutWorker).
-        self.existing_envs = existing_envs or []
         self.num_envs = num_envs
         self.multiagent = multiagent
         self.poll_timeout = remote_env_batch_wait_ms / 1000
+        self.worker = worker
+
+        # Already existing env objects (generated by the RolloutWorker).
+        existing_envs = existing_envs or []
+
+        # Whether the given `make_env` callable already returns ActorHandles
+        # (@ray.remote class instances) or not.
+        self.make_env_creates_actors = False
+
+        self._observation_space = None
+        self._action_space = None
+
         # List of ray actor handles (each handle points to one @ray.remote
         # sub-environment).
         self.actors: Optional[List[ray.actor.ActorHandle]] = None
-        self._observation_space = None
-        self._action_space = None
+        # `self.make_env` already produces Actors: Use it directly.
+        if len(existing_envs) > 0 and isinstance(
+            existing_envs[0], ray.actor.ActorHandle
+        ):
+            self.make_env_creates_actors = True
+            self.actors = existing_envs
+            while len(self.actors) < self.num_envs:
+                sub_env = self.make_env(len(self.actors))
+                if self.worker is not None:
+                    self.worker.callbacks.on_sub_environment_created(
+                        worker=self.worker,
+                        sub_environment=sub_env,
+                        env_context=self.worker.env_context.copy_with_overrides(
+                            vector_index=len(self.actors)
+                        ),
+                    )
+                self.actors.append(sub_env)
+        # `self.make_env` produces gym.Envs (or children thereof, such
+        # as MultiAgentEnv): Need to auto-wrap it here. The problem with
+        # this is that custom methods will get lost. If you would like to
+        # keep your custom methods in your envs, you should provide the
+        # env class directly in your config (w/o tune.register_env()),
+        # such that your class can directly be made a @ray.remote
+        # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`).
+        # Also, if `len(existing_envs) > 0`, we have to throw those away
+        # as we need to create ray actors here.
+        else:
+
+            def make_remote_env(i):
+                logger.info("Launching env {} in remote actor".format(i))
+                if self.multiagent:
+                    sub_env = _RemoteMultiAgentEnv.remote(self.make_env, i)
+                else:
+                    sub_env = _RemoteSingleAgentEnv.remote(self.make_env, i)
+
+                if self.worker is not None:
+                    self.worker.callbacks.on_sub_environment_created(
+                        worker=self.worker,
+                        sub_environment=sub_env,
+                        env_context=self.worker.env_context.copy_with_overrides(
+                            vector_index=i
+                        ),
+                    )
+                return sub_env
+
+            self.actors = [make_remote_env(i) for i in range(self.num_envs)]
+
+        # Utilize existing envs for inferring observation/action spaces.
+        if len(existing_envs) > 0:
+            self._observation_space = existing_envs[0].observation_space
+            self._action_space = existing_envs[0].action_space
+        # Have to call actors' remote methods to get observation/action spaces.
+        else:
+            self._observation_space, self._action_space = ray.get(
+                [
+                    self.actors[0].observation_space.remote(),
+                    self.actors[0].action_space.remote(),
+                ]
+            )

         # Dict mapping object refs (return values of @ray.remote calls),
         # whose actual values we are waiting for (via ray.wait in
         # `self.poll()`) to their corresponding actor handles (the actors
         # that created these return values).
-        self.pending: Optional[Dict[ray.actor.ActorHandle]] = None
+        # Call `reset()` on all @ray.remote sub-environment actors.
+        self.pending: Dict[ray.actor.ActorHandle] = {
+            a.reset.remote(): a for a in self.actors
+        }

     @override(BaseEnv)
     def poll(
         self,
     ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
-        if self.actors is None:
-            # `self.make_env` already produces Actors: Use it directly.
-            if len(self.existing_envs) > 0 and isinstance(
-                self.existing_envs[0], ray.actor.ActorHandle
-            ):
-                self.make_env_creates_actors = True
-                self.actors = []
-                while len(self.actors) < self.num_envs:
-                    self.actors.append(self.make_env(len(self.actors)))
-            # `self.make_env` produces gym.Envs (or children thereof, such
-            # as MultiAgentEnv): Need to auto-wrap it here. The problem with
-            # this is that custom methods will get lost. If you would like to
-            # keep your custom methods in your envs, you should provide the
-            # env class directly in your config (w/o tune.register_env()),
-            # such that your class will directly be made a @ray.remote
-            # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`).
-            else:
-
-                def make_remote_env(i):
-                    logger.info("Launching env {} in remote actor".format(i))
-                    if self.multiagent:
-                        return _RemoteMultiAgentEnv.remote(self.make_env, i)
-                    else:
-                        return _RemoteSingleAgentEnv.remote(self.make_env, i)
-
-                self.actors = [make_remote_env(i) for i in range(self.num_envs)]
-                self._observation_space = ray.get(
-                    self.actors[0].observation_space.remote()
-                )
-                self._action_space = ray.get(self.actors[0].action_space.remote())
-
-        # Lazy initialization. Call `reset()` on all @ray.remote
-        # sub-environment actors at the beginning.
-        if self.pending is None:
-            # Initialize our pending object ref -> actor handle mapping
-            # dict.
-            self.pending = {a.reset.remote(): a for a in self.actors}
-
         # each keyed by env_id in [0, num_remote_envs)
         obs, rewards, dones, infos = {}, {}, {}, {}

@@ -247,8 +285,8 @@ class _RemoteMultiAgentEnv:
     def step(self, action_dict):
         return self.env.step(action_dict)

-    # defining these 2 functions that way this information can be queried
-    # with a call to ray.get()
+    # Defining these 2 functions that way this information can be queried
+    # with a call to ray.get().
     def observation_space(self):
         return self.env.observation_space

@@ -276,8 +314,8 @@ class _RemoteSingleAgentEnv:
         done["__all__"] = done[_DUMMY_AGENT_ID]
         return obs, rew, done, info

-    # defining these 2 functions that way this information can be queried
-    # with a call to ray.get()
+    # Defining these 2 functions that way this information can be queried
+    # with a call to ray.get().
     def observation_space(self):
         return self.env.observation_space
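The `self.pending` dict built above maps in-flight object refs (from `reset.remote()` and, later, `step.remote()` calls) back to the actor that produced them, and `poll()` resolves batches of them via `ray.wait`. A small hedged sketch of that bookkeeping pattern in isolation, using a toy actor rather than RLlib code:

import ray


@ray.remote
class ToyEnvActor:
    def reset(self):
        return {"obs": 0}


ray.init()
actors = [ToyEnvActor.remote() for _ in range(4)]
# Same idea as RemoteBaseEnv.pending: object ref -> owning actor.
pending = {a.reset.remote(): a for a in actors}

# Resolve whichever refs are ready within the timeout (cf. poll_timeout above).
ready, _ = ray.wait(list(pending), num_returns=len(pending), timeout=0.1)
for ref in ready:
    actor = pending.pop(ref)  # which actor produced this result
    obs = ray.get(ref)        # the actual reset() return value
ray.shutdown()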
rllib/evaluation/rollout_worker.py

@@ -427,6 +427,7 @@ class RolloutWorker(ParallelIteratorWorker):
             worker_index=worker_index,
             vector_index=0,
             num_workers=num_workers,
+            remote=remote_worker_envs,
         )
         self.env_context = env_context
         self.policy_config: PartialTrainerConfigDict = policy_config

@@ -478,7 +479,7 @@ class RolloutWorker(ParallelIteratorWorker):
         # 3) Seed the env, if necessary.
         # 4) Vectorize the existing single env by creating more clones of
         #    this env and wrapping it with the RLlib BaseEnv class.
-        self.env = None
+        self.env = self.make_sub_env_fn = None

         # Create a (single) env for this worker.
         if not (

@@ -539,34 +540,17 @@ class RolloutWorker(ParallelIteratorWorker):
             # dependency on each other right now, so we would settle on
             # duplicating the random seed setting logic for now.
             _update_env_seed_if_necessary(self.env, seed, worker_index, 0)
-
-        def make_sub_env(vector_index):
-            # Used to create additional environments during environment
-            # vectorization.
-
-            # Create the env context (config dict + meta-data) for
-            # this particular sub-env within the vectorized one.
-            env_ctx = env_context.copy_with_overrides(
-                worker_index=worker_index,
-                vector_index=vector_index,
-                remote=remote_worker_envs,
-            )
-            # Create the sub-env.
-            env = env_creator(env_ctx)
-            # Validate first.
-            _validate_env(env, env_context=env_ctx)
-            # Custom validation function given by user.
-            if validate_env is not None:
-                validate_env(env, env_ctx)
-            # Use our wrapper, defined above.
-            env = wrap(env)
-
-            # Make sure a deterministic random seed is set on
-            # all the sub-environments if specified.
-            _update_env_seed_if_necessary(env, seed, worker_index, vector_index)
-            return env
-
-        self.make_sub_env_fn = make_sub_env
+            # Call custom callback function `on_sub_environment_created`.
+            self.callbacks.on_sub_environment_created(
+                worker=self,
+                sub_environment=self.env,
+                env_context=self.env_context,
+            )
+
+        self.make_sub_env_fn = self._get_make_sub_env_fn(
+            env_creator, env_context, validate_env, wrap, seed
+        )

         self.spaces = spaces

         self.policy_dict = _determine_spaces_for_multi_agent_dict(

@@ -687,6 +671,7 @@ class RolloutWorker(ParallelIteratorWorker):
             num_envs=num_envs,
             remote_envs=remote_worker_envs,
             remote_env_batch_wait_ms=remote_env_batch_wait_ms,
+            worker=self,
         )

         # `truncate_episodes`: Allow a batch to contain more than one episode

@@ -1718,6 +1703,53 @@ class RolloutWorker(ParallelIteratorWorker):
         logger.info(f"Built policy map: {self.policy_map}")
         logger.info(f"Built preprocessor map: {self.preprocessors}")

+    def _get_make_sub_env_fn(
+        self, env_creator, env_context, validate_env, env_wrapper, seed
+    ):
+        def _make_sub_env_local(vector_index):
+            # Used to create additional environments during environment
+            # vectorization.
+
+            # Create the env context (config dict + meta-data) for
+            # this particular sub-env within the vectorized one.
+            env_ctx = env_context.copy_with_overrides(vector_index=vector_index)
+            # Create the sub-env.
+            env = env_creator(env_ctx)
+            # Validate first.
+            _validate_env(env, env_context=env_ctx)
+            # Custom validation function given by user.
+            if validate_env is not None:
+                validate_env(env, env_ctx)
+            # Use our wrapper, defined above.
+            env = env_wrapper(env)
+
+            # Make sure a deterministic random seed is set on
+            # all the sub-environments if specified.
+            _update_env_seed_if_necessary(
+                env, seed, env_context.worker_index, vector_index
+            )
+            return env
+
+        if not env_context.remote:
+
+            def _make_sub_env_remote(vector_index):
+                sub_env = _make_sub_env_local(vector_index)
+                self.callbacks.on_sub_environment_created(
+                    worker=self,
+                    sub_environment=sub_env,
+                    env_context=env_context.copy_with_overrides(
+                        worker_index=env_context.worker_index,
+                        vector_index=vector_index,
+                        remote=False,
+                    ),
+                )
+                return sub_env
+
+            return _make_sub_env_remote
+
+        else:
+            return _make_sub_env_local
+
     @Deprecated(
         new="Trainer.get_policy().export_model([export_dir], [onnx]?)", error=False
     )
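For clarity on how the factory returned by `_get_make_sub_env_fn` is consumed: it takes a single `vector_index` and is called once per sub-env slot during vectorization. The snippet below is an illustrative stand-in for that calling convention only; it is not the actual call site added by this commit:

import gym


def make_sub_env_fn(vector_index):
    # Stand-in for the real factory: one sub-env per vector slot.
    return gym.make("CartPole-v1")


num_envs = 4
sub_envs = [make_sub_env_fn(vector_index=i) for i in range(num_envs)]
# In the non-remote case (env_context.remote == False), each such call also
# fires callbacks.on_sub_environment_created(...) for the created sub-env.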