from typing import Dict

from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.utils.framework import TensorType
from ray.rllib.utils.typing import AgentID, PolicyID


class ObservationFunction:
    """Interceptor function for rewriting observations from the environment.

    These callbacks can be used for preprocessing of observations,
    especially in multi-agent scenarios.

    Observation functions can be specified in the multi-agent config by
    setting ``{"observation_fn": your_obs_func}``. Note that
    ``your_obs_func`` can be a plain Python function.

    This API is **experimental**.
    """

    def __call__(
        self,
        agent_obs: Dict[AgentID, TensorType],
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode,
        **kw
    ) -> Dict[AgentID, TensorType]:
"""Callback run on each environment step to observe the environment.
|
|
|
|
This method takes in the original agent observation dict returned by
|
|
a MultiAgentEnv, and returns a possibly modified one. It can be
|
|
thought of as a "wrapper" around the environment.
|
|
|
|
TODO(ekl): allow end-to-end differentiation through the observation
|
|
function and policy losses.
|
|
|
|
TODO(ekl): enable batch processing.
|
|
|
|
Args:
|
|
agent_obs (dict): Dictionary of default observations from the
|
|
environment. The default implementation of observe() simply
|
|
returns this dict.
|
|
worker (RolloutWorker): Reference to the current rollout worker.
|
|
base_env (BaseEnv): BaseEnv running the episode. The underlying
|
|
sub environment objects (BaseEnvs are vectorized) can be
|
|
retrieved by calling `base_env.get_sub_environments()`.
|
|
policies (dict): Mapping of policy id to policy objects. In single
|
|
agent mode there will only be a single "default" policy.
|
|
episode (Episode): Episode state object.
|
|
kwargs: Forward compatibility placeholder.
|
|
|
|
Returns:
|
|
new_agent_obs (dict): copy of agent obs with updates. You can
|
|
rewrite or drop data from the dict if needed (e.g., the env
|
|
can have a dummy "global" observation, and the observer can
|
|
merge the global state into individual observations.
|
|
|
|
Examples:
|
|
>>> # Observer that merges global state into individual obs. It is
|
|
... # rewriting the discrete obs into a tuple with global state.
|
|
>>> example_obs_fn1({"a": 1, "b": 2, "global_state": 101}, ...)
|
|
{"a": [1, 101], "b": [2, 101]}
|
|
|
|
>>> # Observer for e.g., custom centralized critic model. It is
|
|
... # rewriting the discrete obs into a dict with more data.
|
|
>>> example_obs_fn2({"a": 1, "b": 2}, ...)
|
|
{"a": {"self": 1, "other": 2}, "b": {"self": 2, "other": 1}}
|
|
"""
|
|
|
|
return agent_obs
|
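

# A minimal illustrative sketch: one possible implementation of the merging
# observer described by ``example_obs_fn1`` in the docstring above. The class
# name ``GlobalStateObserver`` is hypothetical; the "global_state" key follows
# the docstring example.
class GlobalStateObserver(ObservationFunction):
    """Merges a shared "global_state" entry into each agent's observation."""

    def __call__(
        self,
        agent_obs: Dict[AgentID, TensorType],
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode,
        **kw
    ) -> Dict[AgentID, TensorType]:
        # Read the dummy "global" observation emitted by the env, then return
        # a new dict that appends it to every per-agent observation (the
        # "global_state" pseudo-agent itself is dropped from the result).
        global_state = agent_obs["global_state"]
        return {
            agent_id: [obs, global_state]
            for agent_id, obs in agent_obs.items()
            if agent_id != "global_state"
        }


# An instance of the observer would then be passed through the multi-agent
# config as described in the class docstring, e.g.:
#     config["multiagent"]["observation_fn"] = GlobalStateObserver()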