ray/rllib/env/env_context.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

124 lines
5.1 KiB
Python
Raw Normal View History

import copy
from typing import Optional
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import EnvConfigDict
@PublicAPI
class EnvContext(dict):
    """Wraps env configurations to include extra rllib metadata.

    These attributes can be used to parameterize environments per process.
    For example, one might use `worker_index` to control which data file an
    environment reads in on initialization.

    RLlib auto-sets these attributes when constructing registered envs.
    """

    def __init__(
        self,
        env_config: EnvConfigDict,
        worker_index: int,
        vector_index: int = 0,
        remote: bool = False,
        num_workers: Optional[int] = None,
        recreated_worker: bool = False,
    ):
        """Initializes an EnvContext instance.

        Args:
            env_config: The env's configuration defined under the
                "env_config" key in the Trainer's config.
            worker_index: When there are multiple workers created, this
                uniquely identifies the worker the env is created in.
                0 for local worker, >0 for remote workers.
            vector_index: When there are multiple envs per worker, this
                uniquely identifies the env index within the worker.
                Starts from 0.
            remote: Whether individual sub-environments (in a vectorized
                env) should be @ray.remote actors or not.
            num_workers: The total number of (remote) workers in the set.
                0 if only a local worker exists.
            recreated_worker: Whether the worker that holds this env is a
                recreated one. This means that it replaced a previous
                (failed) worker when `recreate_failed_workers=True` in the
                Trainer's config.
        """
        # Store the env_config in the (super) dict. This copies the
        # top-level key/value pairs; nested values are shared with the
        # passed-in `env_config`.
        dict.__init__(self, env_config)
        # Metadata is kept as plain attributes (not dict keys) so it never
        # collides with user-provided env_config keys.
        self.worker_index = worker_index
        self.vector_index = vector_index
        self.remote = remote
        self.num_workers = num_workers
        self.recreated_worker = recreated_worker

    def copy_with_overrides(
        self,
        env_config: Optional[EnvConfigDict] = None,
        worker_index: Optional[int] = None,
        vector_index: Optional[int] = None,
        remote: Optional[bool] = None,
        num_workers: Optional[int] = None,
        recreated_worker: Optional[bool] = None,
    ) -> "EnvContext":
        """Returns a copy of this EnvContext with some attributes overridden.

        Args:
            env_config: Optional env config to use. None for not overriding
                the one from the source (self). A provided config is
                deep-copied; when not provided, the source's key/value pairs
                are copied shallowly (nested values are shared with self).
            worker_index: Optional worker index to use. None for not
                overriding the one from the source (self).
            vector_index: Optional vector index to use. None for not
                overriding the one from the source (self).
            remote: Optional remote setting to use. None for not overriding
                the one from the source (self).
            num_workers: Optional num_workers to use. None for not overriding
                the one from the source (self).
            recreated_worker: Optional flag, indicating, whether the worker
                that holds the env is a recreated one. This means that it
                replaced a previous (failed) worker when
                `recreate_failed_workers=True` in the Trainer's config.
                None for not overriding the one from the source (self).

        Returns:
            A new EnvContext object as a copy of self plus the provided
            overrides.
        """
        # `None` acts as the "keep the source value" sentinel for every
        # argument, so an override can never explicitly set a value to None.
        return EnvContext(
            copy.deepcopy(env_config) if env_config is not None else self,
            worker_index if worker_index is not None else self.worker_index,
            vector_index if vector_index is not None else self.vector_index,
            remote if remote is not None else self.remote,
            num_workers if num_workers is not None else self.num_workers,
            recreated_worker
            if recreated_worker is not None
            else self.recreated_worker,
        )

    def set_defaults(self, defaults: dict) -> None:
        """Sets missing keys of self to the values given in `defaults`.

        If `defaults` contains keys that already exist in self, don't override
        the values with these defaults.

        Args:
            defaults: The key/value pairs to add to self, but only for those
                keys in `defaults` that don't exist yet in self.

        Examples:
            >>> from ray.rllib.env.env_context import EnvContext
            >>> env_ctx = EnvContext({"a": 1, "b": 2}, worker_index=0) # doctest: +SKIP
            >>> env_ctx.set_defaults({"a": -42, "c": 3}) # doctest: +SKIP
            >>> print(dict(env_ctx)) # doctest: +SKIP
            {'a': 1, 'b': 2, 'c': 3}
        """
        for key, value in defaults.items():
            # Only inserts `value` if `key` is absent; existing values win.
            self.setdefault(key, value)

    def __str__(self):
        """Returns the dict contents followed by the worker/vector metadata.

        E.g. `{'a': 1, worker=0/2, vector_idx=0, remote=False}`.
        """
        metadata = (
            f"worker={self.worker_index}/{self.num_workers}, "
            f"vector_idx={self.vector_index}, remote={self.remote}"
        )
        if not self:
            # An empty dict renders as "{}"; splicing into it would yield a
            # malformed "{, worker=...}", so emit only the metadata instead.
            return "{" + metadata + "}"
        # Drop the dict's closing brace, append metadata, then re-close.
        return super().__str__()[:-1] + ", " + metadata + "}"