ray/rllib/env/env_context.py

123 lines
5.1 KiB
Python

import copy
from typing import Optional
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import EnvConfigDict
@PublicAPI
class EnvContext(dict):
"""Wraps env configurations to include extra rllib metadata.
These attributes can be used to parameterize environments per process.
For example, one might use `worker_index` to control which data file an
environment reads in on initialization.
RLlib auto-sets these attributes when constructing registered envs.
"""
def __init__(
self,
env_config: EnvConfigDict,
worker_index: int,
vector_index: int = 0,
remote: bool = False,
num_workers: Optional[int] = None,
recreated_worker: bool = False,
):
"""Initializes an EnvContext instance.
Args:
env_config: The env's configuration defined under the
"env_config" key in the Trainer's config.
worker_index: When there are multiple workers created, this
uniquely identifies the worker the env is created in.
0 for local worker, >0 for remote workers.
vector_index: When there are multiple envs per worker, this
uniquely identifies the env index within the worker.
Starts from 0.
remote: Whether individual sub-environments (in a vectorized
env) should be @ray.remote actors or not.
num_workers: The total number of (remote) workers in the set.
0 if only a local worker exists.
recreated_worker: Whether the worker that holds this env is a recreated one.
This means that it replaced a previous (failed) worker when
`recreate_failed_workers=True` in the Trainer's config.
"""
# Store the env_config in the (super) dict.
dict.__init__(self, env_config)
# Set some metadata attributes.
self.worker_index = worker_index
self.vector_index = vector_index
self.remote = remote
self.num_workers = num_workers
self.recreated_worker = recreated_worker
def copy_with_overrides(
self,
env_config: Optional[EnvConfigDict] = None,
worker_index: Optional[int] = None,
vector_index: Optional[int] = None,
remote: Optional[bool] = None,
num_workers: Optional[int] = None,
recreated_worker: Optional[bool] = None,
) -> "EnvContext":
"""Returns a copy of this EnvContext with some attributes overridden.
Args:
env_config: Optional env config to use. None for not overriding
the one from the source (self).
worker_index: Optional worker index to use. None for not
overriding the one from the source (self).
vector_index: Optional vector index to use. None for not
overriding the one from the source (self).
remote: Optional remote setting to use. None for not overriding
the one from the source (self).
num_workers: Optional num_workers to use. None for not overriding
the one from the source (self).
recreated_worker: Optional flag, indicating, whether the worker that holds
the env is a recreated one. This means that it replaced a previous
(failed) worker when `recreate_failed_workers=True` in the Trainer's
config.
Returns:
A new EnvContext object as a copy of self plus the provided
overrides.
"""
return EnvContext(
copy.deepcopy(env_config) if env_config is not None else self,
worker_index if worker_index is not None else self.worker_index,
vector_index if vector_index is not None else self.vector_index,
remote if remote is not None else self.remote,
num_workers if num_workers is not None else self.num_workers,
recreated_worker if recreated_worker is not None else self.recreated_worker,
)
def set_defaults(self, defaults: dict) -> None:
"""Sets missing keys of self to the values given in `defaults`.
If `defaults` contains keys that already exist in self, don't override
the values with these defaults.
Args:
defaults: The key/value pairs to add to self, but only for those
keys in `defaults` that don't exist yet in self.
Examples:
>>> from ray.rllib.env.env_context import EnvContext
>>> env_ctx = EnvContext({"a": 1, "b": 2}, worker_index=0) # doctest: +SKIP
>>> env_ctx.set_defaults({"a": -42, "c": 3}) # doctest: +SKIP
>>> print(env_ctx) # doctest: +SKIP
{"a": 1, "b": 2, "c": 3}
"""
for key, value in defaults.items():
if key not in self:
self[key] = value
def __str__(self):
return (
super().__str__()[:-1]
+ f", worker={self.worker_index}/{self.num_workers}, "
f"vector_idx={self.vector_index}, remote={self.remote}" + "}"
)