[RLlib] Introduce basic connectors library. (#25311)

Jun Gong 2022-06-07 10:18:14 -07:00 committed by GitHub
parent 4e887fe776
commit 9b65d5535d
23 changed files with 1621 additions and 1 deletion

View file

@@ -21,6 +21,7 @@ RLLIB_MODEL = "rllib_model"
RLLIB_PREPROCESSOR = "rllib_preprocessor"
RLLIB_ACTION_DIST = "rllib_action_dist"
RLLIB_INPUT = "rllib_input"
RLLIB_CONNECTOR = "rllib_connector"
TEST = "__test__"
KNOWN_CATEGORIES = [
TRAINABLE_CLASS,
@@ -29,6 +30,7 @@ KNOWN_CATEGORIES = [
RLLIB_PREPROCESSOR,
RLLIB_ACTION_DIST,
RLLIB_INPUT,
RLLIB_CONNECTOR,
TEST,
]

View file

@@ -1292,6 +1292,34 @@ py_test(
]
)
# --------------------------------------------------------------------
# Connector tests
# rllib/connectors/
#
# Tag: connector
# --------------------------------------------------------------------
py_test(
name = "test_connector",
tags = ["team:ml", "connector"],
size = "small",
srcs = ["connectors/tests/test_connector.py"]
)
py_test(
name = "test_action",
tags = ["team:ml", "connector"],
size = "small",
srcs = ["connectors/tests/test_action.py"]
)
py_test(
name = "test_agent",
tags = ["team:ml", "connector"],
size = "small",
srcs = ["connectors/tests/test_agent.py"]
)
# --------------------------------------------------------------------
# Env tests
# rllib/env/

View file

View file

@@ -0,0 +1,43 @@
from typing import Any, List
from ray.rllib.connectors.connector import (
ConnectorContext,
ActionConnector,
register_connector,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.spaces.space_utils import (
clip_action,
get_base_struct_from_space,
)
from ray.rllib.utils.typing import ActionConnectorDataType
@DeveloperAPI
class ClipActionsConnector(ActionConnector):
def __init__(self, ctx: ConnectorContext):
super().__init__(ctx)
self._action_space_struct = get_base_struct_from_space(ctx.action_space)
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
assert isinstance(
ac_data.output, tuple
), "Action connector requires PolicyOutputType data."
actions, states, fetches = ac_data.output
return ActionConnectorDataType(
ac_data.env_id,
ac_data.agent_id,
(clip_action(actions, self._action_space_struct), states, fetches),
)
def to_config(self):
return ClipActionsConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
return ClipActionsConnector(ctx)
register_connector(ClipActionsConnector.__name__, ClipActionsConnector)

View file

@@ -0,0 +1,79 @@
from typing import Any, Callable, Dict, List, Type
from ray.rllib.connectors.connector import (
ConnectorContext,
ActionConnector,
register_connector,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import unbatch
from ray.rllib.utils.typing import (
ActionConnectorDataType,
PolicyOutputType,
StateBatches,
TensorStructType,
)
@DeveloperAPI
def register_lambda_action_connector(
name: str, fn: Callable[[TensorStructType, StateBatches, Dict], PolicyOutputType]
) -> Type[ActionConnector]:
"""A util to register any function transforming PolicyOutputType as an ActionConnector.
The only requirement is that fn should take actions, states, and fetches as input,
and return transformed actions, states, and fetches.
Args:
name: Name of the resulting action connector.
fn: The function that transforms PolicyOutputType.
Returns:
A new ActionConnector class that transforms PolicyOutputType using fn.
"""
class LambdaActionConnector(ActionConnector):
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
assert isinstance(
ac_data.output, tuple
), "Action connector requires PolicyOutputType data."
actions, states, fetches = ac_data.output
return ActionConnectorDataType(
ac_data.env_id,
ac_data.agent_id,
fn(actions, states, fetches),
)
def to_config(self):
return name, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
return LambdaActionConnector(ctx)
LambdaActionConnector.__name__ = name
LambdaActionConnector.__qualname__ = name
register_connector(name, LambdaActionConnector)
return LambdaActionConnector
# Convert actions and states into numpy arrays if necessary.
ConvertToNumpyConnector = register_lambda_action_connector(
"ConvertToNumpyConnector",
lambda actions, states, fetches: (
convert_to_numpy(actions),
convert_to_numpy(states),
fetches,
),
)
# Split action-component batches into single action rows.
UnbatchActionsConnector = register_lambda_action_connector(
"UnbatchActionsConnector",
lambda actions, states, fetches: (unbatch(actions), states, fetches),
)
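A quick usage sketch (hypothetical connector name; assumes a plain ConnectorContext suffices): any stateless transform of (actions, states, fetches) can be registered and applied this way.

import numpy as np

from ray.rllib.connectors.connector import ConnectorContext
from ray.rllib.utils.typing import ActionConnectorDataType

# Hypothetical example: scale all actions by 0.1.
ScaleActionsConnector = register_lambda_action_connector(
    "ScaleActionsConnector",
    lambda actions, states, fetches: (actions * 0.1, states, fetches),
)

ctx = ConnectorContext()
c = ScaleActionsConnector(ctx)
ac_data = ActionConnectorDataType(
    "env_1", "agent_1", (np.array([1.0, 2.0]), [], {})
)
scaled = c(ac_data)  # scaled.output[0] -> array([0.1, 0.2])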

View file

@@ -0,0 +1,43 @@
from typing import Any, List
from ray.rllib.connectors.connector import (
ConnectorContext,
ActionConnector,
register_connector,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.spaces.space_utils import (
get_base_struct_from_space,
unsquash_action,
)
from ray.rllib.utils.typing import ActionConnectorDataType
@DeveloperAPI
class NormalizeActionsConnector(ActionConnector):
def __init__(self, ctx: ConnectorContext):
super().__init__(ctx)
self._action_space_struct = get_base_struct_from_space(ctx.action_space)
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
assert isinstance(
ac_data.output, tuple
), "Action connector requires PolicyOutputType data."
actions, states, fetches = ac_data.output
return ActionConnectorDataType(
ac_data.env_id,
ac_data.agent_id,
(unsquash_action(actions, self._action_space_struct), states, fetches),
)
def to_config(self):
return NormalizeActionsConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
return NormalizeActionsConnector(ctx)
register_connector(NormalizeActionsConnector.__name__, NormalizeActionsConnector)

View file

@@ -0,0 +1,57 @@
import gym
from typing import Any, List
from ray.rllib.connectors.connector import (
ActionConnector,
Connector,
ConnectorContext,
ConnectorPipeline,
get_connector,
register_connector,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import (
ActionConnectorDataType,
TrainerConfigDict,
)
@DeveloperAPI
class ActionConnectorPipeline(ActionConnector, ConnectorPipeline):
def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
super().__init__(ctx)
self.connectors = connectors
def is_training(self, is_training: bool):
self._is_training = is_training
for c in self.connectors:
c.is_training(is_training)
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
for c in self.connectors:
ac_data = c(ac_data)
return ac_data
def to_config(self):
return ActionConnectorPipeline.__name__, [
c.to_config() for c in self.connectors
]
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
assert (
type(params) == list
), "ActionConnectorPipeline takes a list of connector params."
connectors = [get_connector(ctx, name, subparams) for name, subparams in params]
return ActionConnectorPipeline(ctx, connectors)
register_connector(ActionConnectorPipeline.__name__, ActionConnectorPipeline)
@DeveloperAPI
def get_action_connectors_from_trainer_config(
config: TrainerConfigDict, action_space: gym.Space
) -> ActionConnectorPipeline:
ctx = ConnectorContext(config=config, action_space=action_space)
connectors = []
return ActionConnectorPipeline(ctx, connectors)

View file

@@ -0,0 +1,52 @@
import numpy as np
from typing import Any, List
from ray.rllib.connectors.connector import (
ConnectorContext,
AgentConnector,
register_connector,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType
@DeveloperAPI
class ClipRewardAgentConnector(AgentConnector):
def __init__(self, ctx: ConnectorContext, sign=False, limit=None):
super().__init__(ctx)
assert (
not sign or not limit
), "should not enable both sign and limit reward clipping."
self.sign = sign
self.limit = limit
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
d = ac_data.data
assert (
type(d) == dict
), "Single agent data must be of type Dict[str, TensorStructType]"
assert SampleBatch.REWARDS in d, "input data does not have reward column."
if self.sign:
d[SampleBatch.REWARDS] = np.sign(d[SampleBatch.REWARDS])
elif self.limit:
d[SampleBatch.REWARDS] = np.clip(
d[SampleBatch.REWARDS],
a_min=-self.limit,
a_max=self.limit,
)
return [ac_data]
def to_config(self):
return ClipRewardAgentConnector.__name__, {
"sign": self.sign,
"limit": self.limit,
}
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
return ClipRewardAgentConnector(ctx, **params)
register_connector(ClipRewardAgentConnector.__name__, ClipRewardAgentConnector)
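A short sketch (assuming a plain ConnectorContext) contrasting the two mutually exclusive clipping modes:

from ray.rllib.connectors.connector import ConnectorContext
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType

ctx = ConnectorContext()

# limit=2.0 clips rewards into [-2.0, 2.0].
d = AgentConnectorDataType("env_1", "agent_1", {SampleBatch.REWARDS: -5.8})
clipped = ClipRewardAgentConnector(ctx, limit=2.0)(d)
assert clipped[0].data[SampleBatch.REWARDS] == -2.0

# sign=True keeps only the sign of the reward.
d = AgentConnectorDataType("env_1", "agent_1", {SampleBatch.REWARDS: -5.8})
signed = ClipRewardAgentConnector(ctx, sign=True)(d)
assert signed[0].data[SampleBatch.REWARDS] == -1.0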

View file

@@ -0,0 +1,72 @@
from typing import Any, List
from ray.rllib.connectors.connector import (
ConnectorContext,
AgentConnector,
register_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import AgentConnectorDataType
@DeveloperAPI
class EnvToAgentDataConnector(AgentConnector):
"""Converts per environment multi-agent obs into per agent SampleBatches."""
def __init__(self, ctx: ConnectorContext):
super().__init__(ctx)
self._view_requirements = ctx.view_requirements
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
if ac_data.agent_id:
# data is already for a single agent.
return [ac_data]
assert isinstance(ac_data.data, (tuple, list)) and len(ac_data.data) == 5, (
"EnvToPerAgentDataConnector expects a tuple of "
+ "(obs, rewards, dones, infos, episode_infos)."
)
# episode_infos contains additional training related data bits
# for each agent, such as SampleBatch.T, SampleBatch.AGENT_INDEX,
# SampleBatch.ACTIONS, SampleBatch.DONES (if hitting horizon),
# and is usually empty in inference mode.
obs, rewards, dones, infos, training_episode_infos = ac_data.data
for var, name in zip(
(obs, rewards, dones, infos, training_episode_infos),
("obs", "rewards", "dones", "infos", "training_episode_infos"),
):
assert isinstance(var, dict), (
f"EnvToPerAgentDataConnector expects {name} "
+ "to be a MultiAgentDict."
)
env_id = ac_data.env_id
per_agent_data = []
for agent_id, agent_obs in obs.items():
input_dict = {
SampleBatch.ENV_ID: env_id,
SampleBatch.REWARDS: rewards[agent_id],
# SampleBatch.DONES may be overridden by data from
# training_episode_infos next.
SampleBatch.DONES: dones[agent_id],
SampleBatch.NEXT_OBS: agent_obs,
}
if SampleBatch.INFOS in self._view_requirements:
input_dict[SampleBatch.INFOS] = infos[agent_id]
if agent_id in training_episode_infos:
input_dict.update(training_episode_infos[agent_id])
per_agent_data.append(AgentConnectorDataType(env_id, agent_id, input_dict))
return per_agent_data
def to_config(self):
return EnvToAgentDataConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
return EnvToAgentDataConnector(ctx)
register_connector(EnvToAgentDataConnector.__name__, EnvToAgentDataConnector)

View file

@@ -0,0 +1,81 @@
import numpy as np
import tree # dm_tree
from typing import Any, Callable, Dict, List, Type
from ray.rllib.connectors.connector import (
ConnectorContext,
AgentConnector,
register_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import (
AgentConnectorDataType,
TensorStructType,
)
@DeveloperAPI
def register_lambda_agent_connector(
name: str, fn: Callable[[Any], Any]
) -> Type[AgentConnector]:
"""A util to register any simple transforming function as an AgentConnector
The only requirement is that fn should take a single data object and return
a single data object.
Args:
name: Name of the resulting agent connector.
fn: The function that transforms env / agent data.
Returns:
A new AgentConnector class that transforms data using fn.
"""
class LambdaAgentConnector(AgentConnector):
def __call__(
self, ac_data: AgentConnectorDataType
) -> List[AgentConnectorDataType]:
d = ac_data.data
return [AgentConnectorDataType(ac_data.env_id, ac_data.agent_id, fn(d))]
def to_config(self):
return name, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
return LambdaAgentConnector(ctx)
LambdaAgentConnector.__name__ = name
LambdaAgentConnector.__qualname__ = name
register_connector(name, LambdaAgentConnector)
return LambdaAgentConnector
@DeveloperAPI
def flatten_data(data: Dict[str, TensorStructType]):
assert (
type(data) == dict
), "Single agent data must be of type Dict[str, TensorStructType]"
flattened = {}
for k, v in data.items():
if k in [SampleBatch.INFOS, SampleBatch.ACTIONS] or k.startswith("state_out_"):
# Do not flatten infos, actions, and state_out_ columns.
flattened[k] = v
continue
if v is None:
# Keep the same column shape.
flattened[k] = None
continue
flattened[k] = np.array(tree.flatten(v))
return flattened
# Flatten observation data.
FlattenDataAgentConnector = register_lambda_agent_connector(
"FlattenDataAgentConnector", flatten_data
)
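For illustration, a minimal sketch (assuming a plain ConnectorContext) of the flattening rules above: nested observation structs are flattened into a single array, while infos, actions, and state_out_ columns pass through untouched.

from ray.rllib.connectors.connector import ConnectorContext

ctx = ConnectorContext()
c = FlattenDataAgentConnector(ctx)
d = AgentConnectorDataType(
    "env_1",
    "agent_1",
    {
        SampleBatch.NEXT_OBS: {"pos": [1, 2], "vel": 3.0},  # gets flattened
        SampleBatch.INFOS: {"random": "info"},  # passed through as-is
    },
)
flattened = c(d)
# flattened[0].data[SampleBatch.NEXT_OBS] -> array([1., 2., 3.])
# flattened[0].data[SampleBatch.INFOS] is still the original dict.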

View file

@@ -0,0 +1,60 @@
from typing import Any, List
from ray.rllib.connectors.connector import (
ConnectorContext,
AgentConnector,
register_connector,
)
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import AgentConnectorDataType
# Bridging between current obs preprocessors and connectors.
# We should not introduce any new preprocessors.
# TODO(jungong) : migrate and implement preprocessor library in Connector framework.
@DeveloperAPI
class ObsPreprocessorConnector(AgentConnector):
"""A connector that wraps around existing RLlib observation preprocessors.
This includes:
- OneHotPreprocessor for Discrete and Multi-Discrete spaces.
- GenericPixelPreprocessor and AtariRamPreprocessor for Atari spaces.
- TupleFlatteningPreprocessor and DictFlatteningPreprocessor for flattening
arbitrary nested input observations.
- RepeatedValuesPreprocessor for padding observations from RLlib Repeated
observation space.
"""
def __init__(self, ctx: ConnectorContext):
super().__init__(ctx)
self._preprocessor = get_preprocessor(ctx.observation_space)(
ctx.observation_space, ctx.config.get("model", {})
)
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
d = ac_data.data
assert (
type(d) == dict
), "Single agent data must be of type Dict[str, TensorStructType]"
if SampleBatch.OBS in d:
d[SampleBatch.OBS] = self._preprocessor.transform(d[SampleBatch.OBS])
if SampleBatch.NEXT_OBS in d:
d[SampleBatch.NEXT_OBS] = self._preprocessor.transform(
d[SampleBatch.NEXT_OBS]
)
return [ac_data]
def to_config(self):
return ObsPreprocessorConnector.__name__, {}
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
return ObsPreprocessorConnector(ctx, **params)
register_connector(ObsPreprocessorConnector.__name__, ObsPreprocessorConnector)
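A usage sketch (assuming an empty model config, so preprocessor defaults apply) wrapping the one-hot preprocessor for a Discrete observation space:

import gym

from ray.rllib.connectors.connector import ConnectorContext
from ray.rllib.utils.typing import AgentConnectorDataType

ctx = ConnectorContext(config={}, observation_space=gym.spaces.Discrete(3))
c = ObsPreprocessorConnector(ctx)
d = AgentConnectorDataType("env_1", "agent_1", {SampleBatch.OBS: 1})
processed = c(d)
# processed[0].data[SampleBatch.OBS] -> array([0., 1., 0.])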

View file

@@ -0,0 +1,79 @@
import gym
from typing import Any, List
from ray.rllib.connectors.connector import (
Connector,
ConnectorContext,
ConnectorPipeline,
AgentConnector,
register_connector,
get_connector,
)
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import (
ActionConnectorDataType,
AgentConnectorDataType,
TrainerConfigDict,
)
@DeveloperAPI
class AgentConnectorPipeline(AgentConnector, ConnectorPipeline):
def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
super().__init__(ctx)
self.connectors = connectors
def is_training(self, is_training: bool):
self._is_training = is_training
for c in self.connectors:
c.is_training(is_training)
def reset(self, env_id: str):
for c in self.connectors:
c.reset(env_id)
def on_policy_output(self, output: ActionConnectorDataType):
for c in self.connectors:
c.on_policy_output(output)
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
ret = [ac_data]
for c in self.connectors:
# Run the list of input data through the next agent connect,
# and collect the list of output data.
new_ret = []
for d in ret:
new_ret += c(d)
ret = new_ret
return ret
def to_config(self):
return AgentConnectorPipeline.__name__, [c.to_config() for c in self.connectors]
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
assert (
type(params) == list
), "AgentConnectorPipeline takes a list of connector params."
connectors = [get_connector(ctx, name, subparams) for name, subparams in params]
return AgentConnectorPipeline(ctx, connectors)
register_connector(AgentConnectorPipeline.__name__, AgentConnectorPipeline)
# TODO(jungong) : finish this.
@DeveloperAPI
def get_agent_connectors_from_config(
config: TrainerConfigDict, obs_space: gym.Space
) -> AgentConnectorPipeline:
ctx = ConnectorContext(config=config, observation_space=obs_space)
connectors = [FlattenDataAgentConnector(ctx)]
if config["clip_rewards"] is True:
connectors.append(ClipRewardAgentConnector(ctx, sign=True))
elif type(config["clip_rewards"]) == float:
connectors.append(
ClipRewardAgentConnector(ctx, limit=abs(config["clip_rewards"]))
)
return AgentConnectorPipeline(ctx, connectors)

View file

@@ -0,0 +1,99 @@
from collections import defaultdict
import numpy as np
import tree # dm_tree
from typing import Any, List
from ray.rllib.connectors.connector import (
ConnectorContext,
AgentConnector,
register_connector,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import (
ActionConnectorDataType,
AgentConnectorDataType,
)
@DeveloperAPI
class _AgentState(object):
def __init__(self):
self.t = 0
self.action = None
self.states = None
@DeveloperAPI
class StateBufferConnector(AgentConnector):
def __init__(self, ctx: ConnectorContext):
super().__init__(ctx)
self._initial_states = ctx.initial_states
self._action_space_struct = get_base_struct_from_space(ctx.action_space)
self._states = defaultdict(lambda: defaultdict(_AgentState))
def reset(self, env_id: str):
del self._states[env_id]
def on_policy_output(self, output: ActionConnectorDataType):
# Buffer latest output states for next input __call__.
action, states, _ = output.output
agent_state = self._states[output.env_id][output.agent_id]
agent_state.action = convert_to_numpy(action)
agent_state.states = convert_to_numpy(states)
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
d = ac_data.data
assert (
type(d) == dict
), "Single agent data must be of type Dict[str, TensorStructType]"
env_id = ac_data.env_id
agent_id = ac_data.agent_id
assert env_id and agent_id, "StateBufferConnector requires env_id and agent_id"
agent_state = self._states[env_id][agent_id]
d.update(
{
SampleBatch.T: agent_state.t,
SampleBatch.ENV_ID: env_id,
}
)
if agent_state.states is not None:
states = agent_state.states
else:
states = self._initial_states
for i, v in enumerate(states):
d["state_out_{}".format(i)] = v
if agent_state.action is not None:
d[SampleBatch.ACTIONS] = agent_state.action # Last action
else:
# Default zero action.
d[SampleBatch.ACTIONS] = tree.map_structure(
lambda s: np.zeros_like(s.sample(), s.dtype)
if hasattr(s, "dtype")
else np.zeros_like(s.sample()),
self._action_space_struct,
)
agent_state.t += 1
return [ac_data]
def to_config(self):
return StateBufferConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
return StateBufferConnector(ctx)
register_connector(StateBufferConnector.__name__, StateBufferConnector)
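A sketch of the buffering flow (assuming a plain ConnectorContext with a Box action space and no RNN states): the connector injects T, ENV_ID, state_out_* and the previous action into each incoming data dict.

import gym
import numpy as np

from ray.rllib.utils.typing import ActionConnectorDataType

ctx = ConnectorContext(action_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(2,)))
c = StateBufferConnector(ctx)

# First call: nothing buffered yet -> a default zero action is injected.
d = c(AgentConnectorDataType("env_1", "agent_1", {}))[0].data
# d[SampleBatch.ACTIONS] -> array([0., 0.])

# Buffer a policy output; the next __call__ sees it as the last action.
c.on_policy_output(
    ActionConnectorDataType("env_1", "agent_1", (np.array([0.5, -0.5]), [], {}))
)
d = c(AgentConnectorDataType("env_1", "agent_1", {}))[0].data
# d[SampleBatch.ACTIONS] -> array([ 0.5, -0.5])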

View file

@@ -0,0 +1,136 @@
from collections import defaultdict
import numpy as np
from typing import Any, List
from ray.rllib.connectors.connector import (
ConnectorContext,
AgentConnector,
register_connector,
)
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import (
AgentConnectorDataType,
AgentConnectorsOutput,
)
@DeveloperAPI
class ViewRequirementAgentConnector(AgentConnector):
"""This connector does 2 things:
1. It filters data columns based on view_requirements for training and inference.
2. It buffers the right amount of history for computing the sample batch for
action computation.
The output of this connector is AgentConnectorsOutput, which is basically
a tuple of two things:
{
"for_training": {"obs": ...}
"for_action": SampleBatch
}
The "for_training" dict, which contains data for the latest time slice,
can be used by the Sampler to construct a complete episode for training purposes.
The "for_action" SampleBatch can be used to directly call the policy.
"""
def __init__(self, ctx: ConnectorContext):
super().__init__(ctx)
self._view_requirements = ctx.view_requirements
self._agent_data = defaultdict(lambda: defaultdict(SampleBatch))
def reset(self, env_id: str):
if env_id in self._agent_data:
del self._agent_data[env_id]
def _get_sample_batch_for_action(
self, view_requirements, agent_batch
) -> SampleBatch:
# TODO(jungong) : actually support building the input sample batch with all
# the view shift requirements, etc.
# For now, we use some simple logic for demo purposes.
input_batch = SampleBatch()
for k, v in view_requirements.items():
if not v.used_for_compute_actions:
continue
data_col = v.data_col or k
if data_col not in agent_batch:
continue
input_batch[k] = agent_batch[data_col][-1:]
input_batch.count = 1
return input_batch
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
d = ac_data.data
assert (
type(d) == dict
), "Single agent data must be of type Dict[str, TensorStructType]"
env_id = ac_data.env_id
agent_id = ac_data.agent_id
assert env_id and agent_id, "StateBufferConnector requires env_id and agent_id"
vr = self._view_requirements
assert vr, "ViewRequirements required by ViewRequirementConnector"
training_dict = {}
# We construct a proper per-timeslice dict in training mode,
# for Sampler to construct a complete episode for back propagation.
if self._is_training:
# Filter out columns that are not needed for training.
for col, req in vr.items():
# Not used for training.
if not req.used_for_training:
continue
# Create the batch of data from the different buffers.
data_col = req.data_col or col
if data_col not in d:
continue
training_dict[data_col] = d[data_col]
# Agent batch is our buffer of necessary history for computing
# a SampleBatch for policy forward pass.
# This is used by both training and inference.
agent_batch = self._agent_data[env_id][agent_id]
for col, req in vr.items():
# Not used for action computation.
if not req.used_for_compute_actions:
continue
# Create the batch of data from the different buffers.
data_col = req.data_col or col
if data_col not in d:
continue
# Add batch dim to this data_col.
d_col = np.expand_dims(d[data_col], axis=0)
if data_col in agent_batch:
# Stack along batch dim.
agent_batch[data_col] = np.vstack((agent_batch[data_col], d_col))
else:
agent_batch[data_col] = d_col
# Only keep the useful part of the history.
h = req.shift_from if req.shift_from else -1
assert h <= 0, "Can use future data to compute action"
agent_batch[data_col] = agent_batch[data_col][h:]
sample_batch = self._get_sample_batch_for_action(vr, agent_batch)
return_data = AgentConnectorDataType(
env_id, agent_id, AgentConnectorsOutput(training_dict, sample_batch)
)
return [return_data]
def to_config(self):
return ViewRequirementAgentConnector.__name__, None
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]):
return ViewRequirementAgentConnector(ctx)
register_connector(
ViewRequirementAgentConnector.__name__, ViewRequirementAgentConnector
)
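A sketch of driving this connector and consuming its AgentConnectorsOutput (the ViewRequirement setup here is hypothetical and minimal):

import numpy as np

from ray.rllib.connectors.connector import ConnectorContext
from ray.rllib.policy.view_requirement import ViewRequirement

vrs = {
    SampleBatch.OBS: ViewRequirement(
        used_for_training=True, used_for_compute_actions=True
    )
}
ctx = ConnectorContext(view_requirements=vrs)
c = ViewRequirementAgentConnector(ctx)
d = AgentConnectorDataType(
    "env_1", "agent_1", {SampleBatch.OBS: np.array([1.0, 2.0])}
)
out = c(d)[0].data
# out.for_training: per-timeslice dict for episode construction.
# out.for_action: a SampleBatch ready for a policy forward pass.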

View file

@@ -0,0 +1,366 @@
"""This file defines base types and common structures for RLlib connectors.
"""
import abc
import gym
import logging
from typing import Any, Dict, List, Tuple
from ray.tune.registry import RLLIB_CONNECTOR, _global_registry
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import (
ActionConnectorDataType,
AgentConnectorDataType,
TensorType,
TrainerConfigDict,
)
logger = logging.getLogger(__name__)
@DeveloperAPI
class ConnectorContext:
"""Data bits that may be needed for running connectors.
Note(jungong) : we need to be really careful with the data fields here.
E.g., everything needs to be serializable, in case we need to fetch them
in a remote setting.
"""
# TODO(jungong) : figure out how to fetch these in a remote setting.
# Probably from a policy server when initializing a policy client.
def __init__(
self,
config: TrainerConfigDict = None,
model_initial_states: List[TensorType] = None,
observation_space: gym.Space = None,
action_space: gym.Space = None,
view_requirements: Dict[str, ViewRequirement] = None,
):
"""Construct a ConnectorContext instance.
Args:
config: Trainer config dict for the policy.
model_initial_states: States that are used for constructing
the initial input dict for RNN models. [] if a model is not recurrent.
observation_space: The observation space of a policy.
action_space: The action space of a policy.
view_requirements: The view requirements of a policy.
"""
self.config = config
self.initial_states = model_initial_states or []
self.observation_space = observation_space
self.action_space = action_space
self.view_requirements = view_requirements
@staticmethod
def from_policy(policy: Policy) -> "ConnectorContext":
"""Build ConnectorContext from a given policy.
Args:
policy: Policy
Returns:
A ConnectorContext instance.
"""
return ConnectorContext(
policy.config,
policy.get_initial_state(),
policy.observation_space,
policy.action_space,
policy.view_requirements,
)
@DeveloperAPI
class Connector(abc.ABC):
"""Connector base class.
A connector is a single step of transformation, applied either to environment
data before it reaches a policy, or to policy output before it is sent back
to the environment.
Connectors may be training-aware, i.e., behave slightly differently
during training and inference.
All connectors are required to be serializable and implement to_config().
"""
def __init__(self, ctx: ConnectorContext):
# This gets flipped to False for inference.
self._is_training = True
def is_training(self, is_training: bool):
self._is_training = is_training
def to_config(self) -> Tuple[str, List[Any]]:
"""Serialize a connector into a JSON serializable Tuple.
to_config is required, so that all Connectors are serializable.
Returns:
A tuple of connector's name and its serialized states.
"""
# Must be implemented by each connector.
raise NotImplementedError
@staticmethod
def from_config(ctx: ConnectorContext, params: List[Any]) -> "Connector":
"""De-serialize a JSON params back into a Connector.
from_config is required, so that all Connectors are serializable.
Args:
ctx: Context for constructing this connector.
params: Serialized states of the connector to be recovered.
Returns:
De-serialized connector.
"""
# Must be implemented by each connector.
raise NotImplementedError
@DeveloperAPI
class AgentConnector(Connector):
"""Connector connecting user environments to RLlib policies.
An agent connector transforms a single piece of data in AgentConnectorDataType
format into a list of data in the same AgentConnectorDataType format.
The API is designed so multi-agent observations can be broken up and emitted as
multiple single-agent observations.
AgentConnectorDataType can be used to carry arbitrary types of env data.
Example:
.. code-block:: python
# A dict of multi-agent data from one env step() call.
ac = AgentConnectorDataType(
env_id="env_1",
agent_id=None,
data={
"agent_1": np.array(...),
"agent_2": np.array(...),
}
)
Example:
.. code-block:: python
# Single agent data ready to be preprocessed.
ac = AgentConnectorDataType(
env_id="env_1",
agent_id="agent_1",
data=np.array(...)
)
We can adapt a simple stateless function into an agent connector by using
register_lambda_agent_connector:
.. code-block:: python
TimesTwoAgentConnector = register_lambda_agent_connector(
"TimesTwoAgentConnector", lambda data: data * 2
)
More complicated agent connectors can be implemented by extending this
AgentConnector class:
Example:
.. code-block:: python
class FrameSkippingAgentConnector(AgentConnector):
def __init__(self, n):
self._n = n
self._frame_count = defaultdict(lambda: defaultdict(int))
def reset(self, env_id: str):
del self._frame_count[env_id]
def __call__(
self, ac_data: AgentConnectorDataType
) -> List[AgentConnectorDataType]:
assert ac_data.env_id and ac_data.agent_id, (
"Frame skipping works per agent")
count = self._frame_count[ac_data.env_id][ac_data.agent_id]
self._frame_count[ac_data.env_id][ac_data.agent_id] = count + 1
return [ac_data] if count % self._n == 0 else []
As shown, an agent connector may choose to emit an empty list to stop input
observations from being processed further.
"""
def reset(self, env_id: str):
"""Reset connector state for a specific environment.
For example, at the end of an episode.
Args:
env_id: ID of a user environment. Required.
"""
pass
def on_policy_output(self, output: ActionConnectorDataType):
"""Callback on agent connector of policy output.
This is useful for certain connectors, for example RNN state buffering,
where the agent connector needs to be aware of the output of a policy
forward pass.
Args:
output: Env and agent IDs, plus data output from policy forward pass.
"""
pass
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
"""Transform incoming data from environment before they reach policy.
Args:
ac_data: Env and agent IDs, plus arbitrary data from an environment or
upstream agent connectors.
Returns:
A list of transformed data in AgentConnectorDataType format.
The return type is a list because an AgentConnector may choose to
derive multiple outputs for a single input data, for example
multi-agent obs -> multiple single agent obs.
Agent connectors may also return an empty list for certain input,
useful for connectors such as frame skipping.
"""
raise NotImplementedError
@DeveloperAPI
class ActionConnector(Connector):
"""Action connector connects policy outputs including actions,
to user environments.
An action connector transforms a single piece of policy output in
ActionConnectorDataType format, which is basically PolicyOutputType
plus env and agent IDs.
Any function that operates directly on PolicyOutputType can easily be
adapted into an ActionConnector by using register_lambda_action_connector.
Example:
.. code-block:: python
ZeroActionsConnector = register_lambda_action_connector(
"ZeroActionsConnector",
lambda actions, states, fetches: (
np.zeros_like(actions), states, fetches
)
)
More complicated action connectors can also be implemented by sub-classing
this ActionConnector class.
"""
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
"""Transform policy output before they are sent to a user environment.
Args:
ac_data: Env and agent IDs, plus policy output.
Returns:
The processed action connector data.
"""
raise NotImplementedError
@DeveloperAPI
class ConnectorPipeline:
"""Utility class for quick manipulation of a connector pipeline."""
def remove(self, name: str):
"""Remove a connector by <name>
Args:
name: name of the connector to be removed.
"""
idx = -1
for i, c in enumerate(self.connectors):
if c.__class__.__name__ == name:
idx = i
break
if idx < 0:
raise ValueError(f"Can not find connector {name}")
del self.connectors[idx]
def insert_before(self, name: str, connector: Connector):
"""Insert a new connector before connector <name>
Args:
name: name of the connector before which a new connector
will get inserted.
connector: a new connector to be inserted.
"""
idx = -1
for i, c in enumerate(self.connectors):
if c.__class__.__name__ == name:
idx = i
break
if idx < 0:
raise ValueError(f"Can not find connector {name}")
self.connectors.insert(idx, connector)
def insert_after(self, name: str, connector: Connector):
"""Insert a new connector after connector <name>
Args:
name: name of the connector after which a new connector
will get inserted.
connector: a new connector to be inserted.
"""
idx = -1
for i, c in enumerate(self.connectors):
if c.__class__.__name__ == name:
idx = i
break
if idx < 0:
raise ValueError(f"Can not find connector {name}")
self.connectors.insert(idx + 1, connector)
def prepend(self, connector: Connector):
"""Append a new connector at the beginning of a connector pipeline.
Args:
connector: a new connector to be appended.
"""
self.connectors.insert(0, connector)
def append(self, connector: Connector):
"""Append a new connector at the end of a connector pipeline.
Args:
connector: a new connector to be appended.
"""
self.connectors.append(connector)
@DeveloperAPI
def register_connector(name: str, cls: Connector):
"""Register a connector for use with RLlib.
Args:
name: Name to register.
cls: The Connector class to register.
"""
if not issubclass(cls, Connector):
raise TypeError("Can only register Connector type.", cls)
_global_registry.register(RLLIB_CONNECTOR, name, cls)
@DeveloperAPI
def get_connector(ctx: ConnectorContext, name: str, params: Tuple[Any]) -> Connector:
"""Get a connector by its name and serialized config.
Args:
ctx: Context for constructing the connector.
name: Name of the connector.
params: Serialized parameters of the connector.
Returns:
Constructed connector.
"""
if not _global_registry.contains(RLLIB_CONNECTOR, name):
raise NameError("connector not found.", name)
cls = _global_registry.get(RLLIB_CONNECTOR, name)
return cls.from_config(ctx, params)
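To illustrate the to_config() / from_config() contract end to end, a minimal round-trip sketch using one of the connectors introduced in this PR:

from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector

ctx = ConnectorContext()
c = ConvertToNumpyConnector(ctx)
name, params = c.to_config()  # ("ConvertToNumpyConnector", None)
restored = get_connector(ctx, name, params)
assert isinstance(restored, ConvertToNumpyConnector)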

View file

@@ -0,0 +1,127 @@
import gym
import numpy as np
import unittest
from ray.rllib.connectors.action.clip import ClipActionsConnector
from ray.rllib.connectors.action.lambdas import (
ConvertToNumpyConnector,
UnbatchActionsConnector,
)
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
from ray.rllib.connectors.connector import (
ConnectorContext,
get_connector,
)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ActionConnectorDataType
torch, _ = try_import_torch()
class TestActionConnector(unittest.TestCase):
def test_connector_pipeline(self):
ctx = ConnectorContext()
connectors = [ConvertToNumpyConnector(ctx)]
pipeline = ActionConnectorPipeline(ctx, connectors)
name, params = pipeline.to_config()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ActionConnectorPipeline))
self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))
def test_convert_to_numpy_connector(self):
ctx = ConnectorContext()
c = ConvertToNumpyConnector(ctx)
name, params = c.to_config()
self.assertEqual(name, "ConvertToNumpyConnector")
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ConvertToNumpyConnector))
action = torch.Tensor([8, 9])
states = torch.Tensor([[1, 1, 1], [2, 2, 2]])
ac_data = ActionConnectorDataType(0, 1, (action, states, {}))
converted = c(ac_data)
self.assertTrue(isinstance(converted.output[0], np.ndarray))
self.assertTrue(isinstance(converted.output[1], np.ndarray))
def test_unbatch_action_connector(self):
ctx = ConnectorContext()
c = UnbatchActionsConnector(ctx)
name, params = c.to_config()
self.assertEqual(name, "UnbatchActionsConnector")
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, UnbatchActionsConnector))
ac_data = ActionConnectorDataType(
0,
1,
(
{
"a": np.array([1, 2, 3]),
"b": (np.array([4, 5, 6]), np.array([7, 8, 9])),
},
[],
{},
),
)
unbatched = c(ac_data)
actions, _, _ = unbatched.output
self.assertEqual(len(actions), 3)
self.assertEqual(actions[0]["a"], 1)
self.assertTrue((actions[0]["b"] == np.array((4, 7))).all())
self.assertEqual(actions[1]["a"], 2)
self.assertTrue((actions[1]["b"] == np.array((5, 8))).all())
self.assertEqual(actions[2]["a"], 3)
self.assertTrue((actions[2]["b"] == np.array((6, 9))).all())
def test_normalize_action_connector(self):
ctx = ConnectorContext(
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
)
c = NormalizeActionsConnector(ctx)
name, params = c.to_config()
self.assertEqual(name, "NormalizeActionsConnector")
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, NormalizeActionsConnector))
ac_data = ActionConnectorDataType(0, 1, (0.5, [], {}))
normalized = c(ac_data)
self.assertEqual(normalized.output[0], 4.5)
def test_clip_action_connector(self):
ctx = ConnectorContext(
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
)
c = ClipActionsConnector(ctx)
name, params = c.to_config()
self.assertEqual(name, "ClipActionsConnector")
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ClipActionsConnector))
ac_data = ActionConnectorDataType(0, 1, (8.8, [], {}))
clipped = c(ac_data)
self.assertEqual(clipped.output[0], 6.0)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@@ -0,0 +1,186 @@
import gym
import numpy as np
import unittest
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
from ray.rllib.connectors.agent.env_to_agent import EnvToAgentDataConnector
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
from ray.rllib.connectors.connector import (
ConnectorContext,
get_connector,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.typing import (
AgentConnectorDataType,
)
class TestAgentConnector(unittest.TestCase):
def test_connector_pipeline(self):
ctx = ConnectorContext()
connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
pipeline = AgentConnectorPipeline(ctx, connectors)
name, params = pipeline.to_config()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))
def test_env_to_per_agent_data_connector(self):
vrs = {
"infos": ViewRequirement(
"infos",
used_for_training=True,
used_for_compute_actions=False,
)
}
ctx = ConnectorContext(view_requirements=vrs)
c = EnvToAgentDataConnector(ctx)
name, params = c.to_config()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, EnvToAgentDataConnector))
d = AgentConnectorDataType(
0,
None,
[
# obs
{1: [8, 8], 2: [9, 9]},
# rewards
{
1: 8.8,
2: 9.9,
},
# dones
{
1: False,
2: False,
},
# infos
{
1: {"random": "info"},
2: {},
},
# training_episode_info
{
1: {SampleBatch.DONES: True},
},
],
)
per_agent = c(d)
self.assertEqual(len(per_agent), 2)
batch1 = per_agent[0].data
self.assertEqual(batch1[SampleBatch.NEXT_OBS], [8, 8])
self.assertTrue(batch1[SampleBatch.DONES]) # from training_episode_info
self.assertTrue(SampleBatch.INFOS in batch1)
self.assertEqual(batch1[SampleBatch.INFOS]["random"], "info")
batch2 = per_agent[1].data
self.assertEqual(batch2[SampleBatch.NEXT_OBS], [9, 9])
self.assertFalse(batch2[SampleBatch.DONES])
def test_obs_preprocessor_connector(self):
obs_space = gym.spaces.Dict(
{
"a": gym.spaces.Box(low=0, high=1, shape=(1,)),
"b": gym.spaces.Tuple(
[gym.spaces.Discrete(2), gym.spaces.MultiDiscrete(nvec=[2, 3])]
),
}
)
ctx = ConnectorContext(config={}, observation_space=obs_space)
c = ObsPreprocessorConnector(ctx)
name, params = c.to_config()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ObsPreprocessorConnector))
obs = obs_space.sample()
# Fake deterministic data.
obs["a"][0] = 0.5
obs["b"] = (1, np.array([0, 2]))
d = AgentConnectorDataType(
0,
1,
{
SampleBatch.OBS: obs,
},
)
preprocessed = c(d)
# obs is completely flattened.
self.assertTrue(
(preprocessed[0].data[SampleBatch.OBS] == [0.5, 0, 1, 1, 0, 0, 0, 1]).all()
)
def test_clip_reward_connector(self):
ctx = ConnectorContext()
c = ClipRewardAgentConnector(ctx, limit=2.0)
name, params = c.to_config()
self.assertEqual(name, "ClipRewardAgentConnector")
self.assertAlmostEqual(params["limit"], 2.0)
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, ClipRewardAgentConnector))
d = AgentConnectorDataType(
0,
1,
{
SampleBatch.REWARDS: 5.8,
},
)
clipped = restored(ac_data=d)
self.assertEqual(len(clipped), 1)
self.assertEqual(clipped[0].data[SampleBatch.REWARDS], 2.0)
def test_flatten_data_connector(self):
ctx = ConnectorContext()
c = FlattenDataAgentConnector(ctx)
name, params = c.to_config()
restored = get_connector(ctx, name, params)
self.assertTrue(isinstance(restored, FlattenDataAgentConnector))
d = AgentConnectorDataType(
0,
1,
{
SampleBatch.NEXT_OBS: {
"sensor1": [[1, 1], [2, 2]],
"sensor2": 8.8,
},
SampleBatch.REWARDS: 5.8,
SampleBatch.ACTIONS: [[1, 1], [2]],
SampleBatch.INFOS: {"random": "info"},
},
)
flattened = c(d)
self.assertEqual(len(flattened), 1)
batch = flattened[0].data
self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2, 8.8]).all())
self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
# Not flattened.
self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@@ -0,0 +1,59 @@
import unittest
from ray.rllib.connectors.connector import Connector, ConnectorPipeline
class TestConnectorPipeline(unittest.TestCase):
class Tom(Connector):
def to_config(self):
return "tom"
class Bob(Connector):
def to_config(self):
return "bob"
class Mary(Connector):
def to_config(self):
return "mary"
class MockConnectorPipeline(ConnectorPipeline):
def __init__(self, ctx, connectors):
# Real connector pipelines construct their connectors from
# a ConnectorContext. Here we simply keep the given list
# of mock connectors for unit testing.
self.connectors = connectors
def test_sanity_check(self):
ctx = {}
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
m.insert_before("Bob", self.Mary(ctx))
self.assertEqual(len(m.connectors), 3)
self.assertEqual(m.connectors[1].__class__.__name__, "Mary")
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
m.insert_after("Tom", self.Mary(ctx))
self.assertEqual(len(m.connectors), 3)
self.assertEqual(m.connectors[1].__class__.__name__, "Mary")
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
m.prepend(self.Mary(ctx))
self.assertEqual(len(m.connectors), 3)
self.assertEqual(m.connectors[0].__class__.__name__, "Mary")
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
m.append(self.Mary(ctx))
self.assertEqual(len(m.connectors), 3)
self.assertEqual(m.connectors[2].__class__.__name__, "Mary")
m.remove("Bob")
self.assertEqual(len(m.connectors), 2)
self.assertEqual(m.connectors[0].__class__.__name__, "Tom")
self.assertEqual(m.connectors[1].__class__.__name__, "Mary")
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

rllib/connectors/util.py Normal file
View file

@@ -0,0 +1,9 @@
from ray.rllib.connectors.connector import (
Connector,
get_connector,
)
from typing import Dict
def get_connectors_from_cfg(config: dict) -> Dict[str, Connector]:
return {k: get_connector(*v) for k, v in config.items()}
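A sketch of the config format this helper expects (hypothetical keys; each value is the (ctx, name, params) tuple that get_connector takes; importing the connector modules registers the classes):

from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline  # noqa: F401
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector  # noqa: F401
from ray.rllib.connectors.connector import ConnectorContext

ctx = ConnectorContext()
config = {
    "action": (ctx, "ActionConnectorPipeline", []),
    "clip_reward": (ctx, "ClipRewardAgentConnector", {"sign": True, "limit": None}),
}
connectors = get_connectors_from_cfg(config)
# connectors["clip_reward"] is a restored ClipRewardAgentConnector.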

View file

@@ -276,6 +276,8 @@ class JsonReader(InputReader):
# Clip actions (from any values into env's bounds), if necessary.
cfg = self.ioctx.config
# TODO(jungong) : we should not clip_action in input reader.
# Use connector to handle this.
if cfg.get("clip_actions") and self.ioctx.worker is not None:
if isinstance(batch, SampleBatch):
batch[SampleBatch.ACTIONS] = clip_action(

View file

@@ -8,6 +8,10 @@ import zlib
from ray.rllib.utils.annotations import DeveloperAPI
# TODO(jungong) : We need to handle RLlib custom space types,
# FlexDict, Repeated, and Simplex.
def _serialize_ndarray(array: np.ndarray) -> str:
"""Pack numpy ndarray into Base64 encoded strings for serialization.

View file

@@ -182,7 +182,10 @@ def unbatch(batches_struct):
"""Converts input from (nested) struct of batches to batch of structs.
Input: Struct of different batches (each batch has size=3):
{"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
{
"a": np.array([1, 2, 3]),
"b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0]))
}
Output: Batch (list) of structs (each of these structs representing a
single action):
[

View file

@@ -4,6 +4,7 @@ from typing import (
Callable,
Dict,
List,
NamedTuple,
Optional,
Tuple,
Type,
@@ -12,6 +13,8 @@ from typing import (
Union,
)
from ray.rllib.utils.annotations import DeveloperAPI
if TYPE_CHECKING:
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
@@ -149,5 +152,35 @@ SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]
# (possibly nested) dict|tuple of gym.space.Spaces.
SpaceStruct = Union[gym.spaces.Space, dict, tuple]
# A list of batches of RNN states.
# Each item in this list has dimension [B, S] (S=state vector size)
StateBatches = List[List[Any]]
# Format of data output from policy forward pass.
PolicyOutputType = Tuple[TensorStructType, StateBatches, Dict]
# Data type that is fed into and yielded from agent connectors.
AgentConnectorDataType = DeveloperAPI( # API stability declaration.
NamedTuple(
"AgentConnectorDataType", [("env_id", str), ("agent_id", str), ("data", Any)]
)
)
# Data type that is fed into and yielded from action connectors.
ActionConnectorDataType = DeveloperAPI( # API stability declaration.
NamedTuple(
"ActionConnectorDataType",
[("env_id", str), ("agent_id", str), ("output", PolicyOutputType)],
)
)
# Final output data type of agent connectors.
AgentConnectorsOutput = DeveloperAPI( # API stability declaration.
NamedTuple(
"AgentConnectorsOut",
[("for_training", Dict[str, TensorStructType]), ("for_action", "SampleBatch")],
)
)
# Generic type var.
T = TypeVar("T")