diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index 5db34a29a..f75808f9a 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -21,6 +21,7 @@ RLLIB_MODEL = "rllib_model" RLLIB_PREPROCESSOR = "rllib_preprocessor" RLLIB_ACTION_DIST = "rllib_action_dist" RLLIB_INPUT = "rllib_input" +RLLIB_CONNECTOR = "rllib_connector" TEST = "__test__" KNOWN_CATEGORIES = [ TRAINABLE_CLASS, @@ -29,6 +30,7 @@ KNOWN_CATEGORIES = [ RLLIB_PREPROCESSOR, RLLIB_ACTION_DIST, RLLIB_INPUT, + RLLIB_CONNECTOR, TEST, ] diff --git a/rllib/BUILD b/rllib/BUILD index 1da27e935..48b1c6b9a 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1292,6 +1292,34 @@ py_test( ] ) +# -------------------------------------------------------------------- +# Connector tests +# rllib/connector/ +# +# Tag: connector +# -------------------------------------------------------------------- + +py_test( + name = "test_connector", + tags = ["team:ml", "connector"], + size = "small", + srcs = ["connectors/tests/test_connector.py"] +) + +py_test( + name = "test_action", + tags = ["team:ml", "connector"], + size = "small", + srcs = ["connectors/tests/test_action.py"] +) + +py_test( + name = "test_agent", + tags = ["team:ml", "connector"], + size = "small", + srcs = ["connectors/tests/test_agent.py"] +) + # -------------------------------------------------------------------- # Env tests # rllib/env/ diff --git a/rllib/connectors/__init__.py b/rllib/connectors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rllib/connectors/action/clip.py b/rllib/connectors/action/clip.py new file mode 100644 index 000000000..dffbe5d40 --- /dev/null +++ b/rllib/connectors/action/clip.py @@ -0,0 +1,43 @@ +from typing import Any, List + +from ray.rllib.connectors.connector import ( + ConnectorContext, + ActionConnector, + register_connector, +) +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.spaces.space_utils import ( + clip_action, + get_base_struct_from_space, +) +from ray.rllib.utils.typing import ActionConnectorDataType + + +@DeveloperAPI +class ClipActionsConnector(ActionConnector): + def __init__(self, ctx: ConnectorContext): + super().__init__(ctx) + + self._action_space_struct = get_base_struct_from_space(ctx.action_space) + + def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType: + assert isinstance( + ac_data.output, tuple + ), "Action connector requires PolicyOutputType data." 
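+        # ac_data.output is a PolicyOutputType tuple (see rllib/utils/typing.py):
+        # (actions, RNN state batches, extra action fetches).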
+ + actions, states, fetches = ac_data.output + return ActionConnectorDataType( + ac_data.env_id, + ac_data.agent_id, + (clip_action(actions, self._action_space_struct), states, fetches), + ) + + def to_config(self): + return ClipActionsConnector.__name__, None + + @staticmethod + def from_config(ctx: ConnectorContext, params: List[Any]): + return ClipActionsConnector(ctx) + + +register_connector(ClipActionsConnector.__name__, ClipActionsConnector) diff --git a/rllib/connectors/action/lambdas.py b/rllib/connectors/action/lambdas.py new file mode 100644 index 000000000..63014c08e --- /dev/null +++ b/rllib/connectors/action/lambdas.py @@ -0,0 +1,79 @@ +from typing import Any, Callable, Dict, List, Type + +from ray.rllib.connectors.connector import ( + ConnectorContext, + ActionConnector, + register_connector, +) +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.spaces.space_utils import unbatch +from ray.rllib.utils.typing import ( + ActionConnectorDataType, + PolicyOutputType, + StateBatches, + TensorStructType, +) + + +@DeveloperAPI +def register_lambda_action_connector( + name: str, fn: Callable[[TensorStructType, StateBatches, Dict], PolicyOutputType] +) -> Type[ActionConnector]: + """A util to register any function transforming PolicyOutputType as an ActionConnector. + + The only requirement is that fn should take actions, states, and fetches as input, + and return transformed actions, states, and fetches. + + Args: + name: Name of the resulting actor connector. + fn: The function that transforms PolicyOutputType. + + Returns: + A new ActionConnector class that transforms PolicyOutputType using fn. + """ + + class LambdaActionConnector(ActionConnector): + def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType: + assert isinstance( + ac_data.output, tuple + ), "Action connector requires PolicyOutputType data." + + actions, states, fetches = ac_data.output + return ActionConnectorDataType( + ac_data.env_id, + ac_data.agent_id, + fn(actions, states, fetches), + ) + + def to_config(self): + return name, None + + @staticmethod + def from_config(ctx: ConnectorContext, params: List[Any]): + return LambdaActionConnector(ctx) + + LambdaActionConnector.__name__ = name + LambdaActionConnector.__qualname__ = name + + register_connector(name, LambdaActionConnector) + + return LambdaActionConnector + + +# Convert actions and states into numpy arrays if necessary. +ConvertToNumpyConnector = register_lambda_action_connector( + "ConvertToNumpyConnector", + lambda actions, states, fetches: ( + convert_to_numpy(actions), + convert_to_numpy(states), + fetches, + ), +) + + +# Split action-component batches into single action rows. 
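+# E.g. {"a": np.array([1, 2])} -> [{"a": 1}, {"a": 2}] (see space_utils.unbatch).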
+UnbatchActionsConnector = register_lambda_action_connector( + "UnbatchActionsConnector", + lambda actions, states, fetches: (unbatch(actions), states, fetches), +) diff --git a/rllib/connectors/action/normalize.py b/rllib/connectors/action/normalize.py new file mode 100644 index 000000000..6dc7b7b0e --- /dev/null +++ b/rllib/connectors/action/normalize.py @@ -0,0 +1,43 @@ +from typing import Any, List + +from ray.rllib.connectors.connector import ( + ConnectorContext, + ActionConnector, + register_connector, +) +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.spaces.space_utils import ( + get_base_struct_from_space, + unsquash_action, +) +from ray.rllib.utils.typing import ActionConnectorDataType + + +@DeveloperAPI +class NormalizeActionsConnector(ActionConnector): + def __init__(self, ctx: ConnectorContext): + super().__init__(ctx) + + self._action_space_struct = get_base_struct_from_space(ctx.action_space) + + def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType: + assert isinstance( + ac_data.output, tuple + ), "Action connector requires PolicyOutputType data." + + actions, states, fetches = ac_data.output + return ActionConnectorDataType( + ac_data.env_id, + ac_data.agent_id, + (unsquash_action(actions, self._action_space_struct), states, fetches), + ) + + def to_config(self): + return NormalizeActionsConnector.__name__, None + + @staticmethod + def from_config(ctx: ConnectorContext, params: List[Any]): + return NormalizeActionsConnector(ctx) + + +register_connector(NormalizeActionsConnector.__name__, NormalizeActionsConnector) diff --git a/rllib/connectors/action/pipeline.py b/rllib/connectors/action/pipeline.py new file mode 100644 index 000000000..ba9ad55e3 --- /dev/null +++ b/rllib/connectors/action/pipeline.py @@ -0,0 +1,57 @@ +import gym +from typing import Any, List + +from ray.rllib.connectors.connector import ( + ActionConnector, + Connector, + ConnectorContext, + ConnectorPipeline, + get_connector, + register_connector, +) +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.typing import ( + ActionConnectorDataType, + TrainerConfigDict, +) + + +@DeveloperAPI +class ActionConnectorPipeline(ActionConnector, ConnectorPipeline): + def __init__(self, ctx: ConnectorContext, connectors: List[Connector]): + super().__init__(ctx) + self.connectors = connectors + + def is_training(self, is_training: bool): + self.is_training = is_training + for c in self.connectors: + c.is_training(is_training) + + def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType: + for c in self.connectors: + ac_data = c(ac_data) + return ac_data + + def to_config(self): + return ActionConnectorPipeline.__name__, [ + c.to_config() for c in self.connectors + ] + + @staticmethod + def from_config(ctx: ConnectorContext, params: List[Any]): + assert ( + type(params) == list + ), "ActionConnectorPipeline takes a list of connector params." 
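+        # Each item in params is a (name, serialized_params) pair, as produced
+        # by the child connector's to_config().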
+ connectors = [get_connector(ctx, name, subparams) for name, subparams in params] + return ActionConnectorPipeline(ctx, connectors) + + +register_connector(ActionConnectorPipeline.__name__, ActionConnectorPipeline) + + +@DeveloperAPI +def get_action_connectors_from_trainer_config( + config: TrainerConfigDict, action_space: gym.Space +) -> ActionConnectorPipeline: + connectors = [] + return ActionConnectorPipeline(connectors) diff --git a/rllib/connectors/agent/clip_reward.py b/rllib/connectors/agent/clip_reward.py new file mode 100644 index 000000000..63c6e9b6c --- /dev/null +++ b/rllib/connectors/agent/clip_reward.py @@ -0,0 +1,52 @@ +import numpy as np +from typing import Any, List + +from ray.rllib.connectors.connector import ( + ConnectorContext, + AgentConnector, + register_connector, +) +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.typing import AgentConnectorDataType + + +@DeveloperAPI +class ClipRewardAgentConnector(AgentConnector): + def __init__(self, ctx: ConnectorContext, sign=False, limit=None): + super().__init__(ctx) + assert ( + not sign or not limit + ), "should not enable both sign and limit reward clipping." + self.sign = sign + self.limit = limit + + def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]: + d = ac_data.data + assert ( + type(d) == dict + ), "Single agent data must be of type Dict[str, TensorStructType]" + + assert SampleBatch.REWARDS in d, "input data does not have reward column." + if self.sign: + d[SampleBatch.REWARDS] = np.sign(d[SampleBatch.REWARDS]) + elif self.limit: + d[SampleBatch.REWARDS] = np.clip( + d[SampleBatch.REWARDS], + a_min=-self.limit, + a_max=self.limit, + ) + return [ac_data] + + def to_config(self): + return ClipRewardAgentConnector.__name__, { + "sign": self.sign, + "limit": self.limit, + } + + @staticmethod + def from_config(ctx: ConnectorContext, params: List[Any]): + return ClipRewardAgentConnector(ctx, **params) + + +register_connector(ClipRewardAgentConnector.__name__, ClipRewardAgentConnector) diff --git a/rllib/connectors/agent/env_to_agent.py b/rllib/connectors/agent/env_to_agent.py new file mode 100644 index 000000000..b5ba6cfb5 --- /dev/null +++ b/rllib/connectors/agent/env_to_agent.py @@ -0,0 +1,72 @@ +from typing import Any, List + +from ray.rllib.connectors.connector import ( + ConnectorContext, + AgentConnector, + register_connector, +) +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.typing import AgentConnectorDataType + + +@DeveloperAPI +class EnvToAgentDataConnector(AgentConnector): + """Converts per environment multi-agent obs into per agent SampleBatches.""" + + def __init__(self, ctx: ConnectorContext): + super().__init__(ctx) + self._view_requirements = ctx.view_requirements + + def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]: + if ac_data.agent_id: + # data is already for a single agent. + return [ac_data] + + assert isinstance(ac_data.data, (tuple, list)) and len(ac_data.data) == 5, ( + "EnvToPerAgentDataConnector expects a tuple of " + + "(obs, rewards, dones, infos, episode_infos)." + ) + # episode_infos contains additional training related data bits + # for each agent, such as SampleBatch.T, SampleBatch.AGENT_INDEX, + # SampleBatch.ACTIONS, SampleBatch.DONES (if hitting horizon), + # and is usually empty in inference mode. 
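+        # All five entries are MultiAgentDicts keyed by agent ID.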
+ obs, rewards, dones, infos, training_episode_infos = ac_data.data + for var, name in zip( + (obs, rewards, dones, infos, training_episode_infos), + ("obs", "rewards", "dones", "infos", "training_episode_infos"), + ): + assert isinstance(var, dict), ( + f"EnvToPerAgentDataConnector expects {name} " + + "to be a MultiAgentDict." + ) + + env_id = ac_data.env_id + per_agent_data = [] + for agent_id, obs in obs.items(): + input_dict = { + SampleBatch.ENV_ID: env_id, + SampleBatch.REWARDS: rewards[agent_id], + # SampleBatch.DONES may be overridden by data from + # training_episode_infos next. + SampleBatch.DONES: dones[agent_id], + SampleBatch.NEXT_OBS: obs, + } + if SampleBatch.INFOS in self._view_requirements: + input_dict[SampleBatch.INFOS] = infos[agent_id] + if agent_id in training_episode_infos: + input_dict.update(training_episode_infos[agent_id]) + + per_agent_data.append(AgentConnectorDataType(env_id, agent_id, input_dict)) + + return per_agent_data + + def to_config(self): + return EnvToAgentDataConnector.__name__, None + + @staticmethod + def from_config(ctx: ConnectorContext, params: List[Any]): + return EnvToAgentDataConnector(ctx) + + +register_connector(EnvToAgentDataConnector.__name__, EnvToAgentDataConnector) diff --git a/rllib/connectors/agent/lambdas.py b/rllib/connectors/agent/lambdas.py new file mode 100644 index 000000000..6e52014f8 --- /dev/null +++ b/rllib/connectors/agent/lambdas.py @@ -0,0 +1,81 @@ +import numpy as np +import tree # dm_tree +from typing import Any, Callable, Dict, List, Type + +from ray.rllib.connectors.connector import ( + ConnectorContext, + AgentConnector, + register_connector, +) +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.typing import ( + AgentConnectorDataType, + TensorStructType, +) + + +@DeveloperAPI +def register_lambda_agent_connector( + name: str, fn: Callable[[Any], Any] +) -> Type[AgentConnector]: + """A util to register any simple transforming function as an AgentConnector + + The only requirement is that fn should take a single data object and return + a single data object. + + Args: + name: Name of the resulting actor connector. + fn: The function that transforms env / agent data. + + Returns: + A new AgentConnector class that transforms data using fn. + """ + + class LambdaAgentConnector(AgentConnector): + def __call__( + self, ac_data: AgentConnectorDataType + ) -> List[AgentConnectorDataType]: + d = ac_data.data + return [AgentConnectorDataType(ac_data.env_id, ac_data.agent_id, fn(d))] + + def to_config(self): + return name, None + + @staticmethod + def from_config(ctx: ConnectorContext, params: List[Any]): + return LambdaAgentConnector(ctx) + + LambdaAgentConnector.__name__ = name + LambdaAgentConnector.__qualname__ = name + + register_connector(name, LambdaAgentConnector) + + return LambdaAgentConnector + + +@DeveloperAPI +def flatten_data(data: Dict[str, TensorStructType]): + assert ( + type(data) == dict + ), "Single agent data must be of type Dict[str, TensorStructType]" + + flattened = {} + for k, v in data.items(): + if k in [SampleBatch.INFOS, SampleBatch.ACTIONS] or k.startswith("state_out_"): + # Do not flatten infos, actions, and state_out_ columns. + flattened[k] = v + continue + if v is None: + # Keep the same column shape. + flattened[k] = None + continue + flattened[k] = np.array(tree.flatten(v)) + + return flattened + + +# Flatten observation data. 
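+# Each column's nested struct is flattened with dm_tree into a single np.array;
+# infos, actions, and state_out_* columns are passed through unchanged
+# (see flatten_data above).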
+FlattenDataAgentConnector = register_lambda_agent_connector( + "FlattenDataAgentConnector", flatten_data +) diff --git a/rllib/connectors/agent/obs_preproc.py b/rllib/connectors/agent/obs_preproc.py new file mode 100644 index 000000000..03720a3a1 --- /dev/null +++ b/rllib/connectors/agent/obs_preproc.py @@ -0,0 +1,60 @@ +from typing import Any, List + +from ray.rllib.connectors.connector import ( + ConnectorContext, + AgentConnector, + register_connector, +) +from ray.rllib.models.preprocessors import get_preprocessor +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.typing import AgentConnectorDataType + + +# Bridging between current obs preprocessors and connector. +# We should not introduce any new preprocessors. +# TODO(jungong) : migrate and implement preprocessor library in Connector framework. +@DeveloperAPI +class ObsPreprocessorConnector(AgentConnector): + """A connector that wraps around existing RLlib observation preprocessors. + + This includes: + - OneHotPreprocessor for Discrete and Multi-Discrete spaces. + - GenericPixelPreprocessor and AtariRamPreprocessor for Atari spaces. + - TupleFlatteningPreprocessor and DictFlatteningPreprocessor for flattening + arbitrary nested input observations. + - RepeatedValuesPreprocessor for padding observations from RLlib Repeated + observation space. + """ + + def __init__(self, ctx: ConnectorContext): + super().__init__(ctx) + + self._preprocessor = get_preprocessor(ctx.observation_space)( + ctx.observation_space, ctx.config.get("model", {}) + ) + + def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]: + d = ac_data.data + assert ( + type(d) == dict + ), "Single agent data must be of type Dict[str, TensorStructType]" + + if SampleBatch.OBS in d: + d[SampleBatch.OBS] = self._preprocessor.transform(d[SampleBatch.OBS]) + if SampleBatch.NEXT_OBS in d: + d[SampleBatch.NEXT_OBS] = self._preprocessor.transform( + d[SampleBatch.NEXT_OBS] + ) + + return [ac_data] + + def to_config(self): + return ObsPreprocessorConnector.__name__, {} + + @staticmethod + def from_config(ctx: ConnectorContext, params: List[Any]): + return ObsPreprocessorConnector(ctx, **params) + + +register_connector(ObsPreprocessorConnector.__name__, ObsPreprocessorConnector) diff --git a/rllib/connectors/agent/pipeline.py b/rllib/connectors/agent/pipeline.py new file mode 100644 index 000000000..595f02c39 --- /dev/null +++ b/rllib/connectors/agent/pipeline.py @@ -0,0 +1,79 @@ +import gym +from typing import Any, List + +from ray.rllib.connectors.connector import ( + Connector, + ConnectorContext, + ConnectorPipeline, + AgentConnector, + register_connector, + get_connector, +) +from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector +from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.typing import ( + ActionConnectorDataType, + AgentConnectorDataType, + TrainerConfigDict, +) + + +@DeveloperAPI +class AgentConnectorPipeline(AgentConnector, ConnectorPipeline): + def __init__(self, ctx: ConnectorContext, connectors: List[Connector]): + super().__init__(ctx) + self.connectors = connectors + + def is_training(self, is_training: bool): + self.is_training = is_training + for c in self.connectors: + c.is_training(is_training) + + def reset(self, env_id: str): + for c in self.connectors: + c.reset(env_id) + + def on_policy_output(self, output: 
ActionConnectorDataType): + for c in self.connectors: + c.on_policy_output(output) + + def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]: + ret = [ac_data] + for c in self.connectors: + # Run the list of input data through the next agent connect, + # and collect the list of output data. + new_ret = [] + for d in ret: + new_ret += c(d) + ret = new_ret + return ret + + def to_config(self): + return AgentConnectorPipeline.__name__, [c.to_config() for c in self.connectors] + + @staticmethod + def from_config(ctx: ConnectorContext, params: List[Any]): + assert ( + type(params) == list + ), "AgentConnectorPipeline takes a list of connector params." + connectors = [get_connector(ctx, name, subparams) for name, subparams in params] + return AgentConnectorPipeline(ctx, connectors) + + +register_connector(AgentConnectorPipeline.__name__, AgentConnectorPipeline) + + +# TODO(jungong) : finish this. +@DeveloperAPI +def get_agent_connectors_from_config( + config: TrainerConfigDict, obs_space: gym.Space +) -> AgentConnectorPipeline: + connectors = [FlattenDataAgentConnector()] + + if config["clip_rewards"] is True: + connectors.append(ClipRewardAgentConnector(sign=True)) + elif type(config["clip_rewards"]) == float: + connectors.append(ClipRewardAgentConnector(limit=abs(config["clip_rewards"]))) + + return AgentConnectorPipeline(connectors) diff --git a/rllib/connectors/agent/state_buffer.py b/rllib/connectors/agent/state_buffer.py new file mode 100644 index 000000000..1328d5b02 --- /dev/null +++ b/rllib/connectors/agent/state_buffer.py @@ -0,0 +1,99 @@ +from collections import defaultdict +import numpy as np +import tree # dm_tree +from typing import Any, List + +from ray.rllib.connectors.connector import ( + ConnectorContext, + AgentConnector, + register_connector, +) +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space +from ray.rllib.utils.typing import ( + AgentConnectorDataType, + PolicyOutputType, +) + + +@DeveloperAPI +class _AgentState(object): + def __init__(self): + self.t = 0 + self.action = None + self.states = None + + +@DeveloperAPI +class StateBufferConnector(AgentConnector): + def __init__(self, ctx: ConnectorContext): + super().__init__(ctx) + + self._initial_states = ctx.initial_states + self._action_space_struct = get_base_struct_from_space(ctx.action_space) + self._states = defaultdict(lambda: defaultdict(_AgentState)) + + def reset(self, env_id: str): + del self._states[env_id] + + def on_policy_output(self, env_id: str, agent_id: str, output: PolicyOutputType): + # Buffer latest output states for next input __call__. 
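+        # output is a PolicyOutputType tuple: (actions, RNN state batches, fetches).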
+ action, states, _ = output + agent_state = self._states[env_id][agent_id] + agent_state.action = convert_to_numpy(action) + agent_state.states = convert_to_numpy(states) + + def __call__( + self, ctx: ConnectorContext, ac_data: AgentConnectorDataType + ) -> List[AgentConnectorDataType]: + d = ac_data.data + assert ( + type(d) == dict + ), "Single agent data must be of type Dict[str, TensorStructType]" + + env_id = ac_data.env_id + agent_id = ac_data.agent_id + assert env_id and agent_id, "StateBufferConnector requires env_id and agent_id" + + agent_state = self._states[env_id][agent_id] + + d.update( + { + SampleBatch.T: agent_state.t, + SampleBatch.ENV_ID: env_id, + } + ) + + if agent_state.states is not None: + states = agent_state.states + else: + states = self._initial_states + for i, v in enumerate(states): + d["state_out_{}".format(i)] = v + + if agent_state.action is not None: + d[SampleBatch.ACTIONS] = agent_state.action # Last action + else: + # Default zero action. + d[SampleBatch.ACTIONS] = tree.map_structure( + lambda s: np.zeros_like(s.sample(), s.dtype) + if hasattr(s, "dtype") + else np.zeros_like(s.sample()), + self._action_space_struct, + ) + + agent_state.t += 1 + + return [ac_data] + + def to_config(self): + return StateBufferConnector.__name__, None + + @staticmethod + def from_config(ctx: ConnectorContext, params: List[Any]): + return StateBufferConnector(ctx) + + +register_connector(StateBufferConnector.__name__, StateBufferConnector) diff --git a/rllib/connectors/agent/view_requirement.py b/rllib/connectors/agent/view_requirement.py new file mode 100644 index 000000000..2cec1a736 --- /dev/null +++ b/rllib/connectors/agent/view_requirement.py @@ -0,0 +1,136 @@ +from collections import defaultdict +import numpy as np +from typing import Any, List + +from ray.rllib.connectors.connector import ( + ConnectorContext, + AgentConnector, + register_connector, +) +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.typing import ( + AgentConnectorDataType, + AgentConnectorsOutput, +) + + +@DeveloperAPI +class ViewRequirementAgentConnector(AgentConnector): + """This connector does 2 things: + 1. It filters data columns based on view_requirements for training and inference. + 2. It buffers the right amount of history for computing the sample batch for + action computation. + The output of this connector is AgentConnectorsOut, which basically is + a tuple of 2 things: + { + "for_training": {"obs": ...} + "for_action": SampleBatch + } + The "for_training" dict, which contains data for the latest time slice, + can be used to construct a complete episode by Sampler for training purpose. + The "for_action" SampleBatch can be used to directly call the policy. + """ + + def __init__(self, ctx: ConnectorContext): + super().__init__(ctx) + + self._view_requirements = ctx.view_requirements + self._agent_data = defaultdict(lambda: defaultdict(SampleBatch)) + + def reset(self, env_id: str): + if env_id in self._agent_data: + del self._agent_data[env_id] + + def _get_sample_batch_for_action( + self, view_requirements, agent_batch + ) -> SampleBatch: + # TODO(jungong) : actually support buildling input sample batch with all the + # view shift requirements, etc. + # For now, we use some simple logics for demo purpose. 
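+        # I.e., take only the last timestep of every column that is used for
+        # compute_actions; view shift windows (shift_from/shift_to) are not
+        # handled yet.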
+ input_batch = SampleBatch() + for k, v in view_requirements.items(): + if not v.used_for_compute_actions: + continue + data_col = v.data_col or k + if data_col not in agent_batch: + continue + input_batch[k] = agent_batch[data_col][-1:] + input_batch.count = 1 + return input_batch + + def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]: + d = ac_data.data + assert ( + type(d) == dict + ), "Single agent data must be of type Dict[str, TensorStructType]" + + env_id = ac_data.env_id + agent_id = ac_data.agent_id + assert env_id and agent_id, "StateBufferConnector requires env_id and agent_id" + + vr = self._view_requirements + assert vr, "ViewRequirements required by ViewRequirementConnector" + + training_dict = {} + # We construct a proper per-timeslice dict in training mode, + # for Sampler to construct a complete episode for back propagation. + if self.is_training: + # Filter columns that are not needed for traing. + for col, req in vr.items(): + # Not used for training. + if not req.used_for_training: + continue + + # Create the batch of data from the different buffers. + data_col = req.data_col or col + if data_col not in d: + continue + + training_dict[data_col] = d[data_col] + + # Agent batch is our buffer of necessary history for computing + # a SampleBatch for policy forward pass. + # This is used by both training and inference. + agent_batch = self._agent_data[env_id][agent_id] + for col, req in vr.items(): + # Not used for action computation. + if not req.used_for_compute_actions: + continue + + # Create the batch of data from the different buffers. + data_col = req.data_col or col + if data_col not in d: + continue + + # Add batch dim to this data_col. + d_col = np.expand_dims(d[data_col], axis=0) + + if col in agent_batch: + # Stack along batch dim. + agent_batch[data_col] = np.vstack((agent_batch[data_col], d_col)) + else: + agent_batch[data_col] = d_col + # Only keep the useful part of the history. + h = req.shift_from if req.shift_from else -1 + assert h <= 0, "Can use future data to compute action" + agent_batch[data_col] = agent_batch[data_col][h:] + + sample_batch = self._get_sample_batch_for_action(vr, agent_batch) + + return_data = AgentConnectorDataType( + env_id, agent_id, AgentConnectorsOutput(training_dict, sample_batch) + ) + return return_data + + def to_config(self): + return ViewRequirementAgentConnector.__name__, None + + @staticmethod + def from_config(ctx: ConnectorContext, params: List[Any]): + return ViewRequirementAgentConnector(ctx) + + +register_connector( + ViewRequirementAgentConnector.__name__, ViewRequirementAgentConnector +) diff --git a/rllib/connectors/connector.py b/rllib/connectors/connector.py new file mode 100644 index 000000000..492683a20 --- /dev/null +++ b/rllib/connectors/connector.py @@ -0,0 +1,366 @@ +"""This file defines base types and common structures for RLlib connectors. +""" + +import abc +import gym +import logging +from typing import Any, Dict, List, Tuple + +from ray.tune.registry import RLLIB_CONNECTOR, _global_registry +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.typing import ( + ActionConnectorDataType, + AgentConnectorDataType, + TensorType, + TrainerConfigDict, +) + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class ConnectorContext: + """Data bits that may be needed for running connectors. 
+ + Note(jungong) : we need to be really careful with the data fields here. + E.g., everything needs to be serializable, in case we need to fetch them + in a remote setting. + """ + + # TODO(jungong) : figure out how to fetch these in a remote setting. + # Probably from a policy server when initializing a policy client. + + def __init__( + self, + config: TrainerConfigDict = None, + model_initial_states: List[TensorType] = None, + observation_space: gym.Space = None, + action_space: gym.Space = None, + view_requirements: Dict[str, ViewRequirement] = None, + ): + """Construct a ConnectorContext instance. + + Args: + model_initial_states: States that are used for constructing + the initial input dict for RNN models. [] if a model is not recurrent. + action_space_struct: a policy's action space, in python + data format. E.g., python dict instead of DictSpace, python tuple + instead of TupleSpace. + """ + self.config = config + self.initial_states = model_initial_states or [] + self.observation_space = observation_space + self.action_space = action_space + self.view_requirements = view_requirements + + @staticmethod + def from_policy(policy: Policy) -> "ConnectorContext": + """Build ConnectorContext from a given policy. + + Args: + policy: Policy + + Returns: + A ConnectorContext instance. + """ + return ConnectorContext( + policy.config, + policy.get_initial_state(), + policy.observation_space, + policy.action_space, + policy.view_requirements, + ) + + +@DeveloperAPI +class Connector(abc.ABC): + """Connector base class. + + A connector is a step of transformation, of either envrionment data before they + get to a policy, or policy output before it is sent back to the environment. + + Connectors may be training-aware, for example, behave slightly differently + during training and inference. + + All connectors are required to be serializable and implement to_config(). + """ + + def __init__(self, ctx: ConnectorContext): + # This gets flipped to False for inference. + self.is_training = True + + def is_training(self, is_training: bool): + self.is_training = is_training + + def to_config(self) -> Tuple[str, List[Any]]: + """Serialize a connector into a JSON serializable Tuple. + + to_config is required, so that all Connectors are serializable. + + Returns: + A tuple of connector's name and its serialized states. + """ + # Must implement by each connector. + return NotImplementedError + + @staticmethod + def from_config(self, ctx: ConnectorContext, params: List[Any]) -> "Connector": + """De-serialize a JSON params back into a Connector. + + from_config is required, so that all Connectors are serializable. + + Args: + ctx: Context for constructing this connector. + params: Serialized states of the connector to be recovered. + + Returns: + De-serialized connector. + """ + # Must implement by each connector. + return NotImplementedError + + +@DeveloperAPI +class AgentConnector(Connector): + """Connector connecting user environments to RLlib policies. + + An agent connector transforms a single piece of data in AgentConnectorDataType + format into a list of data in the same AgentConnectorDataTypes format. + The API is designed so multi-agent observations can be broken and emitted as + multiple single agent observations. + + AgentConnectorDataTypes can be used to specify arbitrary type of env data, + + Example: + .. code-block:: python + # A dict of multi-agent data from one env step() call. 
+ ac = AgentConnectorDataType( + env_id="env_1", + agent_id=None, + data={ + "agent_1": np.array(...), + "agent_2": np.array(...), + } + ) + + Example: + .. code-block:: python + # Single agent data ready to be preprocessed. + ac = AgentConnectorDataType( + env_id="env_1", + agent_id="agent_1", + data=np.array(...) + ) + + We can adapt a simple stateless function into an agent connector by using + register_lambda_agent_connector: + .. code-block:: python + TimesTwoAgentConnector = register_lambda_agent_connector( + "TimesTwoAgentConnector", lambda data: data * 2 + ) + + More complicated agent connectors can be implemented by extending this + AgentConnector class: + + Example: + .. code-block:: python + class FrameSkippingAgentConnector(AgentConnector): + def __init__(self, n): + self._n = n + self._frame_count = default_dict(str, default_dict(str, int)) + + def reset(self, env_id: str): + del self._frame_count[env_id] + + def __call__( + self, ac_data: AgentConnectorDataType + ) -> List[AgentConnectorDataType]: + assert ac_data.env_id and ac_data.agent_id, ( + "Frame skipping works per agent") + + count = self._frame_count[ac_data.env_id][ac_data.agent_id] + self._frame_count[ac_data.env_id][ac_data.agent_id] = count + 1 + + return [ac_data] if count % self._n == 0 else [] + + As shown, an agent connector may choose to emit an empty list to stop input + observations from being prosessed further. + """ + + def reset(self, env_id: str): + """Reset connector state for a specific environment. + + For example, at the end of an episode. + + Args: + env_id: required. ID of a user environment. Required. + """ + pass + + def on_policy_output(self, output: ActionConnectorDataType): + """Callback on agent connector of policy output. + + This is useful for certain connectors, for example RNN state buffering, + where the agent connect needs to be aware of the output of a policy + forward pass. + + Args: + ctx: Context for running this connector call. + output: Env and agent IDs, plus data output from policy forward pass. + """ + pass + + def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]: + """Transform incoming data from environment before they reach policy. + + Args: + ctx: Context for running this connector call. + data: Env and agent IDs, plus arbitrary data from an environment or + upstream agent connectors. + + Returns: + A list of transformed data in AgentConnectorDataType format. + The return type is a list because an AgentConnector may choose to + derive multiple outputs for a single input data, for example + multi-agent obs -> multiple single agent obs. + Agent connectors may also return an empty list for certain input, + useful for connectors such as frame skipping. + """ + raise NotImplementedError + + +@DeveloperAPI +class ActionConnector(Connector): + """Action connector connects policy outputs including actions, + to user environments. + + An action connector transforms a single piece of policy output in + ActionConnectorDataType format, which is basically PolicyOutputType + plus env and agent IDs. + + Any functions that operates directly on PolicyOutputType can be + easily adpated into an ActionConnector by using register_lambda_action_connector. + + Example: + .. 
code-block:: python + ZeroActionConnector = register_lambda_action_connector( + "ZeroActionsConnector", + lambda actions, states, fetches: ( + np.zeros_like(actions), states, fetches + ) + ) + + More complicated action connectors can also be implemented by sub-classing + this ActionConnector class. + """ + + def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType: + """Transform policy output before they are sent to a user environment. + + Args: + ctx: Context for running this connector call. + ac_data: Env and agent IDs, plus policy output. + + Returns: + The processed action connector data. + """ + raise NotImplementedError + + +@DeveloperAPI +class ConnectorPipeline: + """Utility class for quick manipulation of a connector pipeline.""" + + def remove(self, name: str): + """Remove a connector by + + Args: + name: name of the connector to be removed. + """ + idx = -1 + for idx, c in enumerate(self.connectors): + if c.__class__.__name__ == name: + break + if idx < 0: + raise ValueError(f"Can not find connector {name}") + del self.connectors[idx] + + def insert_before(self, name: str, connector: Connector): + """Insert a new connector before connector + + Args: + name: name of the connector before which a new connector + will get inserted. + connector: a new connector to be inserted. + """ + idx = -1 + for idx, c in enumerate(self.connectors): + if c.__class__.__name__ == name: + break + if idx < 0: + raise ValueError(f"Can not find connector {name}") + self.connectors.insert(idx, connector) + + def insert_after(self, name: str, connector: Connector): + """Insert a new connector after connector + + Args: + name: name of the connector after which a new connector + will get inserted. + connector: a new connector to be inserted. + """ + idx = -1 + for idx, c in enumerate(self.connectors): + if c.__class__.__name__ == name: + break + if idx < 0: + raise ValueError(f"Can not find connector {name}") + self.connectors.insert(idx + 1, connector) + + def prepend(self, connector: Connector): + """Append a new connector at the beginning of a connector pipeline. + + Args: + connector: a new connector to be appended. + """ + self.connectors.insert(0, connector) + + def append(self, connector: Connector): + """Append a new connector at the end of a connector pipeline. + + Args: + connector: a new connector to be appended. + """ + self.connectors.append(connector) + + +@DeveloperAPI +def register_connector(name: str, cls: Connector): + """Register a connector for use with RLlib. + + Args: + name: Name to register. + cls: Callable that creates an env. + """ + if not issubclass(cls, Connector): + raise TypeError("Can only register Connector type.", cls) + _global_registry.register(RLLIB_CONNECTOR, name, cls) + + +@DeveloperAPI +def get_connector(ctx: ConnectorContext, name: str, params: Tuple[Any]) -> Connector: + """Get a connector by its name and serialized config. + + Args: + name: name of the connector. + params: serialized parameters of the connector. + + Returns: + Constructed connector. 
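+
+    Raises:
+        NameError: If no connector is registered under the given name.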
+ """ + if not _global_registry.contains(RLLIB_CONNECTOR, name): + raise NameError("connector not found.", name) + cls = _global_registry.get(RLLIB_CONNECTOR, name) + return cls.from_config(ctx, params) diff --git a/rllib/connectors/tests/test_action.py b/rllib/connectors/tests/test_action.py new file mode 100644 index 000000000..ce082fa2a --- /dev/null +++ b/rllib/connectors/tests/test_action.py @@ -0,0 +1,127 @@ +import gym +import numpy as np +import unittest + +from ray.rllib.connectors.action.clip import ClipActionsConnector +from ray.rllib.connectors.action.lambdas import ( + ConvertToNumpyConnector, + UnbatchActionsConnector, +) +from ray.rllib.connectors.action.normalize import NormalizeActionsConnector +from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline +from ray.rllib.connectors.connector import ( + ConnectorContext, + get_connector, +) +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import ActionConnectorDataType + +torch, _ = try_import_torch() + + +class TestActionConnector(unittest.TestCase): + def test_connector_pipeline(self): + ctx = ConnectorContext() + connectors = [ConvertToNumpyConnector(ctx)] + pipeline = ActionConnectorPipeline(ctx, connectors) + name, params = pipeline.to_config() + restored = get_connector(ctx, name, params) + self.assertTrue(isinstance(restored, ActionConnectorPipeline)) + self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector)) + + def test_convert_to_numpy_connector(self): + ctx = ConnectorContext() + c = ConvertToNumpyConnector(ctx) + + name, params = c.to_config() + + self.assertEqual(name, "ConvertToNumpyConnector") + + restored = get_connector(ctx, name, params) + self.assertTrue(isinstance(restored, ConvertToNumpyConnector)) + + action = torch.Tensor([8, 9]) + states = torch.Tensor([[1, 1, 1], [2, 2, 2]]) + ac_data = ActionConnectorDataType(0, 1, (action, states, {})) + + converted = c(ac_data) + self.assertTrue(isinstance(converted.output[0], np.ndarray)) + self.assertTrue(isinstance(converted.output[1], np.ndarray)) + + def test_unbatch_action_connector(self): + ctx = ConnectorContext() + c = UnbatchActionsConnector(ctx) + + name, params = c.to_config() + + self.assertEqual(name, "UnbatchActionsConnector") + + restored = get_connector(ctx, name, params) + self.assertTrue(isinstance(restored, UnbatchActionsConnector)) + + ac_data = ActionConnectorDataType( + 0, + 1, + ( + { + "a": np.array([1, 2, 3]), + "b": (np.array([4, 5, 6]), np.array([7, 8, 9])), + }, + [], + {}, + ), + ) + + unbatched = c(ac_data) + actions, _, _ = unbatched.output + + self.assertEqual(len(actions), 3) + self.assertEqual(actions[0]["a"], 1) + self.assertTrue((actions[0]["b"] == np.array((4, 7))).all()) + + self.assertEqual(actions[1]["a"], 2) + self.assertTrue((actions[1]["b"] == np.array((5, 8))).all()) + + self.assertEqual(actions[2]["a"], 3) + self.assertTrue((actions[2]["b"] == np.array((6, 9))).all()) + + def test_normalize_action_connector(self): + ctx = ConnectorContext( + action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1]) + ) + c = NormalizeActionsConnector(ctx) + + name, params = c.to_config() + self.assertEqual(name, "NormalizeActionsConnector") + + restored = get_connector(ctx, name, params) + self.assertTrue(isinstance(restored, NormalizeActionsConnector)) + + ac_data = ActionConnectorDataType(0, 1, (0.5, [], {})) + + normalized = c(ac_data) + self.assertEqual(normalized.output[0], 4.5) + + def test_clip_action_connector(self): + ctx = ConnectorContext( + 
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1]) + ) + c = ClipActionsConnector(ctx) + + name, params = c.to_config() + self.assertEqual(name, "ClipActionsConnector") + + restored = get_connector(ctx, name, params) + self.assertTrue(isinstance(restored, ClipActionsConnector)) + + ac_data = ActionConnectorDataType(0, 1, (8.8, [], {})) + + clipped = c(ac_data) + self.assertEqual(clipped.output[0], 6.0) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/connectors/tests/test_agent.py b/rllib/connectors/tests/test_agent.py new file mode 100644 index 000000000..fb55d0cd2 --- /dev/null +++ b/rllib/connectors/tests/test_agent.py @@ -0,0 +1,186 @@ +import gym +import numpy as np +import unittest + +from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector +from ray.rllib.connectors.agent.env_to_agent import EnvToAgentDataConnector +from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector +from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector +from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline +from ray.rllib.connectors.connector import ( + ConnectorContext, + get_connector, +) +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils.typing import ( + AgentConnectorDataType, +) + + +class TestAgentConnector(unittest.TestCase): + def test_connector_pipeline(self): + ctx = ConnectorContext() + connectors = [ClipRewardAgentConnector(ctx, False, 1.0)] + pipeline = AgentConnectorPipeline(ctx, connectors) + name, params = pipeline.to_config() + restored = get_connector(ctx, name, params) + self.assertTrue(isinstance(restored, AgentConnectorPipeline)) + self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector)) + + def test_env_to_per_agent_data_connector(self): + vrs = { + "infos": ViewRequirement( + "infos", + used_for_training=True, + used_for_compute_actions=False, + ) + } + ctx = ConnectorContext(view_requirements=vrs) + + c = EnvToAgentDataConnector(ctx) + + name, params = c.to_config() + restored = get_connector(ctx, name, params) + self.assertTrue(isinstance(restored, EnvToAgentDataConnector)) + + d = AgentConnectorDataType( + 0, + None, + [ + # obs + {1: [8, 8], 2: [9, 9]}, + # rewards + { + 1: 8.8, + 2: 9.9, + }, + # dones + { + 1: False, + 2: False, + }, + # infos + { + 1: {"random": "info"}, + 2: {}, + }, + # training_episode_info + { + 1: {SampleBatch.DONES: True}, + }, + ], + ) + per_agent = c(d) + + self.assertEqual(len(per_agent), 2) + + batch1 = per_agent[0].data + self.assertEqual(batch1[SampleBatch.NEXT_OBS], [8, 8]) + self.assertTrue(batch1[SampleBatch.DONES]) # from training_episode_info + self.assertTrue(SampleBatch.INFOS in batch1) + self.assertEqual(batch1[SampleBatch.INFOS]["random"], "info") + + batch2 = per_agent[1].data + self.assertEqual(batch2[SampleBatch.NEXT_OBS], [9, 9]) + self.assertFalse(batch2[SampleBatch.DONES]) + + def test_obs_preprocessor_connector(self): + obs_space = gym.spaces.Dict( + { + "a": gym.spaces.Box(low=0, high=1, shape=(1,)), + "b": gym.spaces.Tuple( + [gym.spaces.Discrete(2), gym.spaces.MultiDiscrete(nvec=[2, 3])] + ), + } + ) + ctx = ConnectorContext(config={}, observation_space=obs_space) + + c = ObsPreprocessorConnector(ctx) + name, params = c.to_config() + + restored = get_connector(ctx, name, params) + self.assertTrue(isinstance(restored, ObsPreprocessorConnector)) + + obs = 
obs_space.sample() + # Fake deterministic data. + obs["a"][0] = 0.5 + obs["b"] = (1, np.array([0, 2])) + + d = AgentConnectorDataType( + 0, + 1, + { + SampleBatch.OBS: obs, + }, + ) + preprocessed = c(d) + + # obs is completely flattened. + self.assertTrue( + (preprocessed[0].data[SampleBatch.OBS] == [0.5, 0, 1, 1, 0, 0, 0, 1]).all() + ) + + def test_clip_reward_connector(self): + ctx = ConnectorContext() + + c = ClipRewardAgentConnector(ctx, limit=2.0) + name, params = c.to_config() + + self.assertEqual(name, "ClipRewardAgentConnector") + self.assertAlmostEqual(params["limit"], 2.0) + + restored = get_connector(ctx, name, params) + self.assertTrue(isinstance(restored, ClipRewardAgentConnector)) + + d = AgentConnectorDataType( + 0, + 1, + { + SampleBatch.REWARDS: 5.8, + }, + ) + clipped = restored(ac_data=d) + + self.assertEqual(len(clipped), 1) + self.assertEqual(clipped[0].data[SampleBatch.REWARDS], 2.0) + + def test_flatten_data_connector(self): + ctx = ConnectorContext() + + c = FlattenDataAgentConnector(ctx) + + name, params = c.to_config() + restored = get_connector(ctx, name, params) + self.assertTrue(isinstance(restored, FlattenDataAgentConnector)) + + d = AgentConnectorDataType( + 0, + 1, + { + SampleBatch.NEXT_OBS: { + "sensor1": [[1, 1], [2, 2]], + "sensor2": 8.8, + }, + SampleBatch.REWARDS: 5.8, + SampleBatch.ACTIONS: [[1, 1], [2]], + SampleBatch.INFOS: {"random": "info"}, + }, + ) + + flattened = c(d) + self.assertEqual(len(flattened), 1) + + batch = flattened[0].data + self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2, 8.8]).all()) + self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8) + # Not flattened. + self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2) + self.assertEqual(batch[SampleBatch.INFOS]["random"], "info") + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/connectors/tests/test_connector.py b/rllib/connectors/tests/test_connector.py new file mode 100644 index 000000000..89b19fafc --- /dev/null +++ b/rllib/connectors/tests/test_connector.py @@ -0,0 +1,59 @@ +import unittest + +from ray.rllib.connectors.connector import Connector, ConnectorPipeline + + +class TestConnectorPipeline(unittest.TestCase): + class Tom(Connector): + def to_config(): + return "tom" + + class Bob(Connector): + def to_config(): + return "bob" + + class Mary(Connector): + def to_config(): + return "mary" + + class MockConnectorPipeline(ConnectorPipeline): + def __init__(self, ctx, connectors): + # Real connector pipelines should keep a list of + # Connectors. + # Use strings here for simple unit tests. 
+ self.connectors = connectors + + def test_sanity_check(self): + ctx = {} + + m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)]) + m.insert_before("Bob", self.Mary(ctx)) + self.assertEqual(len(m.connectors), 3) + self.assertEqual(m.connectors[1].__class__.__name__, "Mary") + + m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)]) + m.insert_after("Tom", self.Mary(ctx)) + self.assertEqual(len(m.connectors), 3) + self.assertEqual(m.connectors[1].__class__.__name__, "Mary") + + m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)]) + m.prepend(self.Mary(ctx)) + self.assertEqual(len(m.connectors), 3) + self.assertEqual(m.connectors[0].__class__.__name__, "Mary") + + m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)]) + m.append(self.Mary(ctx)) + self.assertEqual(len(m.connectors), 3) + self.assertEqual(m.connectors[2].__class__.__name__, "Mary") + + m.remove("Bob") + self.assertEqual(len(m.connectors), 2) + self.assertEqual(m.connectors[0].__class__.__name__, "Tom") + self.assertEqual(m.connectors[1].__class__.__name__, "Mary") + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/connectors/util.py b/rllib/connectors/util.py new file mode 100644 index 000000000..401fecac3 --- /dev/null +++ b/rllib/connectors/util.py @@ -0,0 +1,9 @@ +from ray.rllib.connectors.connector import ( + Connector, + get_connector, +) +from typing import Dict + + +def get_connectors_from_cfg(config: dict) -> Dict[str, Connector]: + return {k: get_connector(*v) for k, v in config.items()} diff --git a/rllib/offline/json_reader.py b/rllib/offline/json_reader.py index f5c6f5e67..baf35b057 100644 --- a/rllib/offline/json_reader.py +++ b/rllib/offline/json_reader.py @@ -276,6 +276,8 @@ class JsonReader(InputReader): # Clip actions (from any values into env's bounds), if necessary. cfg = self.ioctx.config + # TODO(jungong) : we should not clip_action in input reader. + # Use connector to handle this. if cfg.get("clip_actions") and self.ioctx.worker is not None: if isinstance(batch, SampleBatch): batch[SampleBatch.ACTIONS] = clip_action( diff --git a/rllib/utils/serialization.py b/rllib/utils/serialization.py index 3160188de..ba489bfd9 100644 --- a/rllib/utils/serialization.py +++ b/rllib/utils/serialization.py @@ -8,6 +8,10 @@ import zlib from ray.rllib.utils.annotations import DeveloperAPI +# TODO(jungong) : We need to handle RLlib custom space types, +# FlexDict, Repeated, and Simplex. + + def _serialize_ndarray(array: np.ndarray) -> str: """Pack numpy ndarray into Base64 encoded strings for serialization. diff --git a/rllib/utils/spaces/space_utils.py b/rllib/utils/spaces/space_utils.py index 36f9004e3..ea5c9dd5b 100644 --- a/rllib/utils/spaces/space_utils.py +++ b/rllib/utils/spaces/space_utils.py @@ -182,7 +182,10 @@ def unbatch(batches_struct): """Converts input from (nested) struct of batches to batch of structs. 
Input: Struct of different batches (each batch has size=3): - {"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])} + { + "a": np.array([1, 2, 3]), + "b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0])) + } Output: Batch (list) of structs (each of these structs representing a single action): [ diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py index 4f511025f..f8ad312ad 100644 --- a/rllib/utils/typing.py +++ b/rllib/utils/typing.py @@ -4,6 +4,7 @@ from typing import ( Callable, Dict, List, + NamedTuple, Optional, Tuple, Type, @@ -12,6 +13,8 @@ from typing import ( Union, ) +from ray.rllib.utils.annotations import DeveloperAPI + if TYPE_CHECKING: from ray.rllib.env.env_context import EnvContext from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 @@ -149,5 +152,35 @@ SampleBatchType = Union["SampleBatch", "MultiAgentBatch"] # (possibly nested) dict|tuple of gym.space.Spaces. SpaceStruct = Union[gym.spaces.Space, dict, tuple] +# A list of batches of RNN states. +# Each item in this list has dimension [B, S] (S=state vector size) +StateBatches = List[List[Any]] + +# Format of data output from policy forward pass. +PolicyOutputType = Tuple[TensorStructType, StateBatches, Dict] + +# Data type that is fed into and yielded from agent connectors. +AgentConnectorDataType = DeveloperAPI( # API stability declaration. + NamedTuple( + "AgentConnectorDataType", [("env_id", str), ("agent_id", str), ("data", Any)] + ) +) + +# Data type that is fed into and yielded from agent connectors. +ActionConnectorDataType = DeveloperAPI( # API stability declaration. + NamedTuple( + "ActionConnectorDataType", + [("env_id", str), ("agent_id", str), ("output", PolicyOutputType)], + ) +) + +# Final output data type of agent connectors. +AgentConnectorsOutput = DeveloperAPI( # API stability declaration. + NamedTuple( + "AgentConnectorsOut", + [("for_training", Dict[str, TensorStructType]), ("for_action", "SampleBatch")], + ) +) + # Generic type var. T = TypeVar("T")
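A minimal usage sketch of the connector APIs added above, mirroring the round-trip pattern exercised in rllib/connectors/tests/ (build a pipeline, serialize it with to_config(), restore it with get_connector(), then push data through it). This is an illustration only, not part of the diff:

from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.connector import ConnectorContext, get_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType

ctx = ConnectorContext()
pipeline = AgentConnectorPipeline(
    ctx, [ClipRewardAgentConnector(ctx, sign=False, limit=1.0)]
)

# Serialize the pipeline and restore it from the registry
# (e.g., for shipping connector state to a policy client).
name, params = pipeline.to_config()
restored = get_connector(ctx, name, params)

# Push one reward through; 5.8 gets clipped to the configured limit of 1.0.
data = AgentConnectorDataType(0, 1, {SampleBatch.REWARDS: 5.8})
out = restored(data)
assert out[0].data[SampleBatch.REWARDS] == 1.0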