ray/rllib/evaluation/observation_function.py

from typing import Dict

from ray.rllib.env import BaseEnv
from ray.rllib.policy import Policy
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.utils.framework import TensorType
from ray.rllib.utils.typing import AgentID, PolicyID


class ObservationFunction:
    """Interceptor function for rewriting observations from the environment.

    These callbacks can be used for preprocessing of observations, especially
    in multi-agent scenarios.

    Observation functions can be specified in the multi-agent config by
    specifying ``{"observation_fn": your_obs_func}``. Note that
    ``your_obs_func`` can be a plain Python function.

    This API is **experimental**.
    """

    def __call__(
        self,
        agent_obs: Dict[AgentID, TensorType],
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode,
        **kw
    ) -> Dict[AgentID, TensorType]:
"""Callback run on each environment step to observe the environment.
This method takes in the original agent observation dict returned by
a MultiAgentEnv, and returns a possibly modified one. It can be
thought of as a "wrapper" around the environment.
TODO(ekl): allow end-to-end differentiation through the observation
function and policy losses.
TODO(ekl): enable batch processing.
Args:
agent_obs (dict): Dictionary of default observations from the
environment. The default implementation of observe() simply
returns this dict.
worker (RolloutWorker): Reference to the current rollout worker.
base_env (BaseEnv): BaseEnv running the episode. The underlying
sub environment objects (BaseEnvs are vectorized) can be
retrieved by calling `base_env.get_sub_environments()`.
policies (dict): Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
episode (Episode): Episode state object.
kwargs: Forward compatibility placeholder.
Returns:
new_agent_obs (dict): copy of agent obs with updates. You can
rewrite or drop data from the dict if needed (e.g., the env
can have a dummy "global" observation, and the observer can
merge the global state into individual observations.
Examples:
>>> # Observer that merges global state into individual obs. It is
... # rewriting the discrete obs into a tuple with global state.
>>> example_obs_fn1({"a": 1, "b": 2, "global_state": 101}, ...)
{"a": [1, 101], "b": [2, 101]}
>>> # Observer for e.g., custom centralized critic model. It is
... # rewriting the discrete obs into a dict with more data.
>>> example_obs_fn2({"a": 1, "b": 2}, ...)
{"a": {"self": 1, "other": 2}, "b": {"self": 2, "other": 1}}
"""
        return agent_obs
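

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the docstring above
# describes ``example_obs_fn1`` only by its input/output behavior and notes
# that an observer is registered via ``{"observation_fn": your_obs_func}`` in
# the multi-agent config. The class below is one possible implementation of
# that "merge global state into each agent's obs" pattern, followed by an
# assumed config layout showing where such an observer would be plugged in.
# The class name ``GlobalStateObserver`` and the "global_state" pseudo-agent
# key are assumptions for this sketch, not part of the RLlib API.
class GlobalStateObserver(ObservationFunction):
    """Merges a shared "global_state" entry into every agent's observation."""

    def __call__(
        self,
        agent_obs: Dict[AgentID, TensorType],
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode,
        **kw
    ) -> Dict[AgentID, TensorType]:
        # Work on a copy so the env's original obs dict is left untouched.
        new_agent_obs = dict(agent_obs)
        # Pop the dummy "global" observation emitted by the (hypothetical) env
        # and append it to every remaining per-agent observation, e.g. turning
        # {"a": 1, "b": 2, "global_state": 101} into
        # {"a": [1, 101], "b": [2, 101]}.
        global_state = new_agent_obs.pop("global_state")
        return {
            agent_id: [obs, global_state]
            for agent_id, obs in new_agent_obs.items()
        }


# Assumed registration in the multi-agent config (adapt the surrounding keys
# to your RLlib version); the "observation_fn" entry is the point of this
# sketch.
EXAMPLE_MULTIAGENT_CONFIG = {
    "multiagent": {
        "observation_fn": GlobalStateObserver(),
        # ... "policies", "policy_mapping_fn", etc. would go here.
    },
}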