2022-06-07 10:18:14 -07:00
|
|
|
from collections import defaultdict
|
2022-06-29 23:44:10 -07:00
|
|
|
from typing import Any, List
|
|
|
|
|
2022-06-07 10:18:14 -07:00
|
|
|
import numpy as np
|
|
|
|
import tree # dm_tree
|
|
|
|
|
|
|
|
from ray.rllib.connectors.connector import (
|
|
|
|
AgentConnector,
|
2022-06-29 23:44:10 -07:00
|
|
|
ConnectorContext,
|
2022-06-07 10:18:14 -07:00
|
|
|
register_connector,
|
|
|
|
)
|
|
|
|
from ray.rllib.policy.sample_batch import SampleBatch
|
|
|
|
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
2022-06-29 23:44:10 -07:00
|
|
|
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
|
|
|
|
from ray.util.annotations import PublicAPI
|
2022-06-07 10:18:14 -07:00
|
|
|
|
|
|
|
|
2022-06-29 23:44:10 -07:00
|
|
|
@PublicAPI(stability="alpha")
class StateBufferConnector(AgentConnector):
    """Agent connector that buffers each agent's latest policy output.

    ``on_policy_output`` stores the policy output for an ``(env_id, agent_id)``
    pair. The next ``transform`` call for that pair unpacks it as
    ``(action, states, fetches)`` and writes it back into the agent's input
    data dict:

    * ``SampleBatch.ACTIONS``: the buffered action, or a zero-filled action
      shaped like a sample of the action space when none is buffered yet.
    * ``"state_out_{i}"``: the buffered states, falling back to
      ``ctx.initial_states`` when none are buffered yet.
    * Any non-empty buffered fetches, merged into the dict as-is.
    """

    def __init__(self, ctx: ConnectorContext):
        super().__init__(ctx)

        # With a soft horizon, buffered outputs are kept across ``reset``
        # calls instead of being dropped.
        self._soft_horizon = ctx.config.get("soft_horizon", False)
        self._initial_states = ctx.initial_states
        self._action_space_struct = get_base_struct_from_space(ctx.action_space)
        # env_id -> agent_id -> policy output. Unseen entries default to an
        # all-None triple, so the first ``transform`` for an agent falls back
        # to a zero action and the initial states.
        self._states = defaultdict(lambda: defaultdict(lambda: (None, None, None)))

    def reset(self, env_id: str):
        """Forget buffered outputs for ``env_id`` unless using a soft horizon.

        Safe to call for an ``env_id`` that has not produced any output yet.
        """
        # If soft horizon, states should be carried over between episodes.
        if not self._soft_horizon:
            # pop() rather than ``del``: the env may have no buffered entry,
            # and ``del`` on a missing key raises KeyError even for a
            # defaultdict.
            self._states.pop(env_id, None)

    def on_policy_output(self, ac_data: ActionConnectorDataType):
        """Buffer the policy output of one agent for the next ``transform``."""
        # Buffer latest output states for next input __call__.
        self._states[ac_data.env_id][ac_data.agent_id] = ac_data.output

    @staticmethod
    def _zero_like_sample(space):
        """Return an all-zeros array shaped like ``space.sample()``.

        Uses ``space.dtype`` when the space defines one; otherwise the dtype
        is taken from the sample itself.
        """
        sample = space.sample()
        if hasattr(space, "dtype"):
            return np.zeros_like(sample, space.dtype)
        return np.zeros_like(sample)

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        """Write the buffered action, states and fetches into ``ac_data.data``.

        The data dict is modified in place and ``ac_data`` itself is returned.

        Raises:
            AssertionError: If ``ac_data.data`` is not exactly a ``dict``, or
                if ``env_id`` / ``agent_id`` is None.
        """
        d = ac_data.data
        assert (
            type(d) is dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        env_id = ac_data.env_id
        agent_id = ac_data.agent_id
        assert (
            env_id is not None and agent_id is not None
        ), f"StateBufferConnector requires env_id({env_id}) and agent_id({agent_id})"

        action, states, fetches = self._states[env_id][agent_id]

        if action is not None:
            d[SampleBatch.ACTIONS] = action  # Last action
        else:
            # Default zero action, one array per leaf of the action space.
            d[SampleBatch.ACTIONS] = tree.map_structure(
                self._zero_like_sample, self._action_space_struct
            )

        if states is None:
            states = self._initial_states
        for i, v in enumerate(states):
            d["state_out_{}".format(i)] = v

        # Also add extra fetches if available.
        if fetches:
            d.update(fetches)

        return ac_data

    def to_config(self):
        """Serialize as ``(class name, params)``; this connector has no params."""
        return StateBufferConnector.__name__, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Rebuild the connector from ``ctx``; ``params`` is unused."""
        return StateBufferConnector(ctx)
|
|
|
|
|
|
|
|
|
|
|
|
# Register under the same name that ``StateBufferConnector.to_config``
# reports (the class name).
register_connector(
    StateBufferConnector.__name__,
    StateBufferConnector,
)
|