ray/rllib/connectors/agent/state_buffer.py

83 lines
2.8 KiB
Python
Raw Normal View History

from collections import defaultdict
from typing import Any, List
import numpy as np
import tree # dm_tree
from ray.rllib.connectors.connector import (
AgentConnector,
ConnectorContext,
register_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
from ray.util.annotations import PublicAPI
@PublicAPI(stability="alpha")
class StateBufferConnector(AgentConnector):
    """Agent connector that feeds buffered policy outputs into the next input.

    For every ``(env_id, agent_id)`` pair it remembers the last
    ``(action, states, fetches)`` tuple passed to ``on_policy_output`` and, on
    the following ``transform`` call, writes into the agent's input dict:

    * ``SampleBatch.ACTIONS``: the buffered action, or an all-zeros action built
      from the action space when nothing has been buffered yet;
    * ``state_out_<i>``: each buffered state, or the context's initial states
      when nothing has been buffered yet;
    * the buffered extra fetches (merged with ``dict.update``), if non-empty.
    """

    def __init__(self, ctx: ConnectorContext):
        super().__init__(ctx)
        # When True, reset() keeps an env's buffered outputs across episodes.
        self._soft_horizon = ctx.config.get("soft_horizon", False)
        self._initial_states = ctx.initial_states
        self._action_space_struct = get_base_struct_from_space(ctx.action_space)
        # env_id -> agent_id -> (action, states, fetches). Unseen pairs read as
        # (None, None, None), which makes transform() fall back to defaults.
        self._states = defaultdict(lambda: defaultdict(lambda: (None, None, None)))

    def reset(self, env_id: str):
        """Forget buffered outputs for ``env_id`` unless soft horizon is on.

        Args:
            env_id: Environment whose buffered outputs should be dropped.
        """
        # If soft horizon, states should be carried over between episodes.
        if not self._soft_horizon:
            self._states.pop(env_id, None)

    def on_policy_output(self, ac_data: ActionConnectorDataType):
        """Buffer the latest policy output for the next ``transform`` call.

        ``ac_data.output`` is stored as-is; ``transform`` later unpacks it as an
        ``(action, states, fetches)`` triple.
        """
        self._states[ac_data.env_id][ac_data.agent_id] = ac_data.output

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        """Inject last action, recurrent states and extra fetches into the data.

        Args:
            ac_data: Single-agent connector data; ``ac_data.data`` must be a
                dict and both ``env_id`` and ``agent_id`` must be set.

        Returns:
            The same ``ac_data`` object, whose ``data`` dict was updated in place.
        """
        d = ac_data.data
        assert isinstance(
            d, dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        env_id = ac_data.env_id
        agent_id = ac_data.agent_id
        assert (
            env_id is not None and agent_id is not None
        ), f"StateBufferConnector requires env_id({env_id}) and agent_id({agent_id})"

        action, states, fetches = self._states[env_id][agent_id]

        # TODO(jungong): Support buffering more than 1 prev actions.
        if action is not None:
            d[SampleBatch.ACTIONS] = action  # Last action
        else:
            # Default zero action. Sub-spaces exposing a ``dtype`` use it
            # explicitly; otherwise zeros_like keeps the sampled value's dtype.
            d[SampleBatch.ACTIONS] = tree.map_structure(
                lambda s: np.zeros_like(s.sample(), s.dtype)
                if hasattr(s, "dtype")
                else np.zeros_like(s.sample()),
                self._action_space_struct,
            )

        if states is None:
            states = self._initial_states
        for i, v in enumerate(states):
            d[f"state_out_{i}"] = v

        # Also add extra fetches if available.
        if fetches:
            d.update(fetches)

        return ac_data

    def to_config(self):
        """Return ``(name, params)``; no params are needed, all state comes from ctx."""
        return StateBufferConnector.__name__, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Rebuild the connector from ``ctx``; ``params`` is ignored."""
        return StateBufferConnector(ctx)
# Register the class under the same name that to_config() returns, presumably so
# the connector can be found by name when restored from a saved config.
register_connector(StateBufferConnector.__name__, StateBufferConnector)