mirror of https://github.com/vale981/ray, synced 2025-03-10 13:26:39 -04:00
from collections import defaultdict

import numpy as np
import tree  # dm_tree
from typing import Any, List

from ray.rllib.connectors.connector import (
    ConnectorContext,
    AgentConnector,
    register_connector,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import (
    AgentConnectorDataType,
    PolicyOutputType,
)


@DeveloperAPI
class _AgentState(object):
    """Per-agent buffer: current timestep, last action, and last RNN states."""

    def __init__(self):
        self.t = 0
        self.action = None
        self.states = None


@DeveloperAPI
class StateBufferConnector(AgentConnector):
    """Buffers the latest policy output per (env_id, agent_id).

    On each call, the buffered action and RNN states (or the policy's initial
    states and a zero action on the first step) are merged into the incoming
    agent data, together with the current timestep and env_id.
    """

    def __init__(self, ctx: ConnectorContext):
        super().__init__(ctx)

        self._initial_states = ctx.initial_states
        self._action_space_struct = get_base_struct_from_space(ctx.action_space)
        self._states = defaultdict(lambda: defaultdict(_AgentState))

    def reset(self, env_id: str):
        # Drop all buffered per-agent states for this env.
        del self._states[env_id]

    def on_policy_output(self, env_id: str, agent_id: str, output: PolicyOutputType):
        # Buffer latest output states for next input __call__.
        action, states, _ = output
        agent_state = self._states[env_id][agent_id]
        agent_state.action = convert_to_numpy(action)
        agent_state.states = convert_to_numpy(states)

    def __call__(
        self, ctx: ConnectorContext, ac_data: AgentConnectorDataType
    ) -> List[AgentConnectorDataType]:
        d = ac_data.data
        assert (
            type(d) == dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        env_id = ac_data.env_id
        agent_id = ac_data.agent_id
        assert env_id and agent_id, "StateBufferConnector requires env_id and agent_id"

        agent_state = self._states[env_id][agent_id]

        d.update(
            {
                SampleBatch.T: agent_state.t,
                SampleBatch.ENV_ID: env_id,
            }
        )

        if agent_state.states is not None:
            states = agent_state.states
        else:
            states = self._initial_states
        for i, v in enumerate(states):
            d["state_out_{}".format(i)] = v

        if agent_state.action is not None:
            d[SampleBatch.ACTIONS] = agent_state.action  # Last action
        else:
            # Default zero action.
            d[SampleBatch.ACTIONS] = tree.map_structure(
                lambda s: np.zeros_like(s.sample(), s.dtype)
                if hasattr(s, "dtype")
                else np.zeros_like(s.sample()),
                self._action_space_struct,
            )

        agent_state.t += 1

        return [ac_data]

    def to_config(self):
        return StateBufferConnector.__name__, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        return StateBufferConnector(ctx)


register_connector(StateBufferConnector.__name__, StateBufferConnector)
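
Usage sketch (not part of the original file): the snippet below exercises the connector in isolation to show the buffering behavior. It assumes that ConnectorContext accepts `action_space` and `initial_states` keyword arguments (the only two attributes this connector reads), that AgentConnectorDataType takes (env_id, agent_id, data), and that the classic `gym` spaces are in use, as in the RLlib version this file belongs to; treat it as an illustration, not a definitive API reference.

# Usage sketch under the assumptions stated above.
import gym
import numpy as np

from ray.rllib.connectors.connector import ConnectorContext
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType

ctx = ConnectorContext(
    action_space=gym.spaces.Discrete(4),
    initial_states=[np.zeros(8, dtype=np.float32)],  # e.g. one RNN state slot
)
connector = StateBufferConnector(ctx)

# First call for this (env, agent): nothing buffered yet, so the connector
# fills in t=0, the initial RNN state, and a zero "last action".
data = AgentConnectorDataType("env_0", "agent_0", {SampleBatch.OBS: np.zeros(3)})
(out,) = connector(ctx, data)
print(out.data[SampleBatch.ACTIONS])  # zero action for Discrete(4)
print(out.data["state_out_0"].shape)  # (8,), the initial state

# After the policy has run, buffer its output (action, states, fetches).
connector.on_policy_output("env_0", "agent_0", (2, [np.ones(8, np.float32)], {}))

# The next call carries the buffered action/state and an incremented timestep.
data = AgentConnectorDataType("env_0", "agent_0", {SampleBatch.OBS: np.zeros(3)})
(out,) = connector(ctx, data)
print(out.data[SampleBatch.ACTIONS])  # 2, the buffered last action
print(out.data[SampleBatch.T])        # 1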