from collections import defaultdict
import numpy as np
import tree  # dm_tree
from typing import Any, List

from ray.rllib.connectors.connector import (
    ConnectorContext,
    AgentConnector,
    register_connector,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import (
    AgentConnectorDataType,
    PolicyOutputType,
)


@DeveloperAPI
class _AgentState:
    """Mutable per-agent record holding a step counter and the last outputs.

    All three attributes start out "empty": the counter at zero and both
    output slots as None until something is stored in them.
    """

    def __init__(self):
        # Most recently stored state outputs; None until first assigned.
        self.states = None
        # Most recently stored action; None until first assigned.
        self.action = None
        # Step counter, starting at zero.
        self.t = 0


@DeveloperAPI
class StateBufferConnector(AgentConnector):
    """Agent connector that buffers the latest policy output per (env, agent).

    The action and state outputs handed to `on_policy_output()` are stored
    per env id and agent id. On the next `__call__()` for the same agent they
    are written back into the agent's data dict, together with a per-agent
    call counter and the env id. When nothing has been buffered yet, the
    context's initial states and an all-zero action are used instead.
    """

    def __init__(self, ctx: ConnectorContext):
        """Set up the connector from `ctx`.

        Args:
            ctx: Connector context; its `initial_states` and `action_space`
                attributes are read here.
        """
        super().__init__(ctx)

        self._initial_states = ctx.initial_states
        self._action_space_struct = get_base_struct_from_space(ctx.action_space)
        # env_id -> agent_id -> _AgentState, created lazily on first access.
        self._states = defaultdict(lambda: defaultdict(_AgentState))

    def reset(self, env_id: str):
        """Discard all buffered agent state for `env_id`.

        This is a no-op when nothing has been buffered for `env_id`, so it
        is safe to call before the first step or more than once in a row.

        Args:
            env_id: Id of the environment whose buffered state is dropped.
        """
        # pop() with a default neither raises on a missing key nor triggers
        # the defaultdict factory.
        self._states.pop(env_id, None)

    def on_policy_output(self, env_id: str, agent_id: str, output: PolicyOutputType):
        """Buffer the latest policy output for use by the next `__call__()`.

        Args:
            env_id: Id of the environment the output belongs to.
            agent_id: Id of the agent the output belongs to.
            output: 3-tuple whose first element is the action and whose
                second element is the state outputs; the third element is
                ignored here.
        """
        action, states, _ = output
        agent_state = self._states[env_id][agent_id]
        agent_state.action = convert_to_numpy(action)
        agent_state.states = convert_to_numpy(states)

    def __call__(
        self, ctx: ConnectorContext, ac_data: AgentConnectorDataType
    ) -> List[AgentConnectorDataType]:
        """Write buffered timestep, env id, states and last action into the data.

        The data dict of `ac_data` is modified in place.

        Args:
            ctx: Connector context (not used by this method).
            ac_data: Single-agent data; `ac_data.data` must be a dict and
                both `ac_data.env_id` and `ac_data.agent_id` must be set.

        Returns:
            A one-element list containing the same `ac_data` object.

        Raises:
            AssertionError: If `ac_data.data` is not a dict, or if
                `env_id` or `agent_id` is None.
        """
        d = ac_data.data
        assert isinstance(
            d, dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        env_id = ac_data.env_id
        agent_id = ac_data.agent_id
        # Compare against None explicitly: falsy ids such as 0 are valid.
        assert (
            env_id is not None and agent_id is not None
        ), "StateBufferConnector requires env_id and agent_id"

        agent_state = self._states[env_id][agent_id]

        d.update(
            {
                SampleBatch.T: agent_state.t,
                SampleBatch.ENV_ID: env_id,
            }
        )

        # Fall back to the context's initial states until the policy has
        # produced state outputs for this agent.
        if agent_state.states is not None:
            states = agent_state.states
        else:
            states = self._initial_states
        for i, v in enumerate(states):
            d["state_out_{}".format(i)] = v

        if agent_state.action is not None:
            d[SampleBatch.ACTIONS] = agent_state.action  # Last action
        else:
            # Default zero action, shaped like a sample from each sub-space.
            def _zeros_for(space):
                sample = space.sample()
                if hasattr(space, "dtype"):
                    return np.zeros_like(sample, space.dtype)
                return np.zeros_like(sample)

            d[SampleBatch.ACTIONS] = tree.map_structure(
                _zeros_for, self._action_space_struct
            )

        # The counter is advanced only after the current value was written.
        agent_state.t += 1

        return [ac_data]

    def to_config(self):
        """Return `(class name, None)`; this connector has no parameters."""
        return StateBufferConnector.__name__, None

    @staticmethod
    def from_config(ctx: ConnectorContext, params: List[Any]):
        """Build a `StateBufferConnector` from `ctx`; `params` is ignored."""
        return StateBufferConnector(ctx)


# Runs at import time: hands the class to register_connector under its own
# class name, the same string to_config() returns as its first element.
register_connector(StateBufferConnector.__name__, StateBufferConnector)