mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Introduce basic connectors library. (#25311)
This commit is contained in:
parent
4e887fe776
commit
9b65d5535d
23 changed files with 1621 additions and 1 deletions
|
@ -21,6 +21,7 @@ RLLIB_MODEL = "rllib_model"
|
||||||
RLLIB_PREPROCESSOR = "rllib_preprocessor"
|
RLLIB_PREPROCESSOR = "rllib_preprocessor"
|
||||||
RLLIB_ACTION_DIST = "rllib_action_dist"
|
RLLIB_ACTION_DIST = "rllib_action_dist"
|
||||||
RLLIB_INPUT = "rllib_input"
|
RLLIB_INPUT = "rllib_input"
|
||||||
|
RLLIB_CONNECTOR = "rllib_connector"
|
||||||
TEST = "__test__"
|
TEST = "__test__"
|
||||||
KNOWN_CATEGORIES = [
|
KNOWN_CATEGORIES = [
|
||||||
TRAINABLE_CLASS,
|
TRAINABLE_CLASS,
|
||||||
|
@ -29,6 +30,7 @@ KNOWN_CATEGORIES = [
|
||||||
RLLIB_PREPROCESSOR,
|
RLLIB_PREPROCESSOR,
|
||||||
RLLIB_ACTION_DIST,
|
RLLIB_ACTION_DIST,
|
||||||
RLLIB_INPUT,
|
RLLIB_INPUT,
|
||||||
|
RLLIB_CONNECTOR,
|
||||||
TEST,
|
TEST,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
28
rllib/BUILD
28
rllib/BUILD
|
@ -1292,6 +1292,34 @@ py_test(
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
# Connector tests
|
||||||
|
# rllib/connector/
|
||||||
|
#
|
||||||
|
# Tag: connector
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "test_connector",
|
||||||
|
tags = ["team:ml", "connector"],
|
||||||
|
size = "small",
|
||||||
|
srcs = ["connectors/tests/test_connector.py"]
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "test_action",
|
||||||
|
tags = ["team:ml", "connector"],
|
||||||
|
size = "small",
|
||||||
|
srcs = ["connectors/tests/test_action.py"]
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "test_agent",
|
||||||
|
tags = ["team:ml", "connector"],
|
||||||
|
size = "small",
|
||||||
|
srcs = ["connectors/tests/test_agent.py"]
|
||||||
|
)
|
||||||
|
|
||||||
# --------------------------------------------------------------------
|
# --------------------------------------------------------------------
|
||||||
# Env tests
|
# Env tests
|
||||||
# rllib/env/
|
# rllib/env/
|
||||||
|
|
0
rllib/connectors/__init__.py
Normal file
0
rllib/connectors/__init__.py
Normal file
43
rllib/connectors/action/clip.py
Normal file
43
rllib/connectors/action/clip.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ConnectorContext,
|
||||||
|
ActionConnector,
|
||||||
|
register_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.utils.spaces.space_utils import (
|
||||||
|
clip_action,
|
||||||
|
get_base_struct_from_space,
|
||||||
|
)
|
||||||
|
from ray.rllib.utils.typing import ActionConnectorDataType
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class ClipActionsConnector(ActionConnector):
|
||||||
|
def __init__(self, ctx: ConnectorContext):
|
||||||
|
super().__init__(ctx)
|
||||||
|
|
||||||
|
self._action_space_struct = get_base_struct_from_space(ctx.action_space)
|
||||||
|
|
||||||
|
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
||||||
|
assert isinstance(
|
||||||
|
ac_data.output, tuple
|
||||||
|
), "Action connector requires PolicyOutputType data."
|
||||||
|
|
||||||
|
actions, states, fetches = ac_data.output
|
||||||
|
return ActionConnectorDataType(
|
||||||
|
ac_data.env_id,
|
||||||
|
ac_data.agent_id,
|
||||||
|
(clip_action(actions, self._action_space_struct), states, fetches),
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_config(self):
|
||||||
|
return ClipActionsConnector.__name__, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||||
|
return ClipActionsConnector(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
register_connector(ClipActionsConnector.__name__, ClipActionsConnector)
|
79
rllib/connectors/action/lambdas.py
Normal file
79
rllib/connectors/action/lambdas.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
from typing import Any, Callable, Dict, List, Type
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ConnectorContext,
|
||||||
|
ActionConnector,
|
||||||
|
register_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.utils.numpy import convert_to_numpy
|
||||||
|
from ray.rllib.utils.spaces.space_utils import unbatch
|
||||||
|
from ray.rllib.utils.typing import (
|
||||||
|
ActionConnectorDataType,
|
||||||
|
PolicyOutputType,
|
||||||
|
StateBatches,
|
||||||
|
TensorStructType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
def register_lambda_action_connector(
|
||||||
|
name: str, fn: Callable[[TensorStructType, StateBatches, Dict], PolicyOutputType]
|
||||||
|
) -> Type[ActionConnector]:
|
||||||
|
"""A util to register any function transforming PolicyOutputType as an ActionConnector.
|
||||||
|
|
||||||
|
The only requirement is that fn should take actions, states, and fetches as input,
|
||||||
|
and return transformed actions, states, and fetches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name of the resulting actor connector.
|
||||||
|
fn: The function that transforms PolicyOutputType.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new ActionConnector class that transforms PolicyOutputType using fn.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class LambdaActionConnector(ActionConnector):
|
||||||
|
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
||||||
|
assert isinstance(
|
||||||
|
ac_data.output, tuple
|
||||||
|
), "Action connector requires PolicyOutputType data."
|
||||||
|
|
||||||
|
actions, states, fetches = ac_data.output
|
||||||
|
return ActionConnectorDataType(
|
||||||
|
ac_data.env_id,
|
||||||
|
ac_data.agent_id,
|
||||||
|
fn(actions, states, fetches),
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_config(self):
|
||||||
|
return name, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||||
|
return LambdaActionConnector(ctx)
|
||||||
|
|
||||||
|
LambdaActionConnector.__name__ = name
|
||||||
|
LambdaActionConnector.__qualname__ = name
|
||||||
|
|
||||||
|
register_connector(name, LambdaActionConnector)
|
||||||
|
|
||||||
|
return LambdaActionConnector
|
||||||
|
|
||||||
|
|
||||||
|
# Convert actions and states into numpy arrays if necessary.
|
||||||
|
ConvertToNumpyConnector = register_lambda_action_connector(
|
||||||
|
"ConvertToNumpyConnector",
|
||||||
|
lambda actions, states, fetches: (
|
||||||
|
convert_to_numpy(actions),
|
||||||
|
convert_to_numpy(states),
|
||||||
|
fetches,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Split action-component batches into single action rows.
|
||||||
|
UnbatchActionsConnector = register_lambda_action_connector(
|
||||||
|
"UnbatchActionsConnector",
|
||||||
|
lambda actions, states, fetches: (unbatch(actions), states, fetches),
|
||||||
|
)
|
43
rllib/connectors/action/normalize.py
Normal file
43
rllib/connectors/action/normalize.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ConnectorContext,
|
||||||
|
ActionConnector,
|
||||||
|
register_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.utils.spaces.space_utils import (
|
||||||
|
get_base_struct_from_space,
|
||||||
|
unsquash_action,
|
||||||
|
)
|
||||||
|
from ray.rllib.utils.typing import ActionConnectorDataType
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class NormalizeActionsConnector(ActionConnector):
|
||||||
|
def __init__(self, ctx: ConnectorContext):
|
||||||
|
super().__init__(ctx)
|
||||||
|
|
||||||
|
self._action_space_struct = get_base_struct_from_space(ctx.action_space)
|
||||||
|
|
||||||
|
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
||||||
|
assert isinstance(
|
||||||
|
ac_data.output, tuple
|
||||||
|
), "Action connector requires PolicyOutputType data."
|
||||||
|
|
||||||
|
actions, states, fetches = ac_data.output
|
||||||
|
return ActionConnectorDataType(
|
||||||
|
ac_data.env_id,
|
||||||
|
ac_data.agent_id,
|
||||||
|
(unsquash_action(actions, self._action_space_struct), states, fetches),
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_config(self):
|
||||||
|
return NormalizeActionsConnector.__name__, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||||
|
return NormalizeActionsConnector(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
register_connector(NormalizeActionsConnector.__name__, NormalizeActionsConnector)
|
57
rllib/connectors/action/pipeline.py
Normal file
57
rllib/connectors/action/pipeline.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
import gym
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ActionConnector,
|
||||||
|
Connector,
|
||||||
|
ConnectorContext,
|
||||||
|
ConnectorPipeline,
|
||||||
|
get_connector,
|
||||||
|
register_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.utils.typing import (
|
||||||
|
ActionConnectorDataType,
|
||||||
|
TrainerConfigDict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class ActionConnectorPipeline(ActionConnector, ConnectorPipeline):
|
||||||
|
def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
|
||||||
|
super().__init__(ctx)
|
||||||
|
self.connectors = connectors
|
||||||
|
|
||||||
|
def is_training(self, is_training: bool):
|
||||||
|
self.is_training = is_training
|
||||||
|
for c in self.connectors:
|
||||||
|
c.is_training(is_training)
|
||||||
|
|
||||||
|
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
||||||
|
for c in self.connectors:
|
||||||
|
ac_data = c(ac_data)
|
||||||
|
return ac_data
|
||||||
|
|
||||||
|
def to_config(self):
|
||||||
|
return ActionConnectorPipeline.__name__, [
|
||||||
|
c.to_config() for c in self.connectors
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||||
|
assert (
|
||||||
|
type(params) == list
|
||||||
|
), "ActionConnectorPipeline takes a list of connector params."
|
||||||
|
connectors = [get_connector(ctx, name, subparams) for name, subparams in params]
|
||||||
|
return ActionConnectorPipeline(ctx, connectors)
|
||||||
|
|
||||||
|
|
||||||
|
register_connector(ActionConnectorPipeline.__name__, ActionConnectorPipeline)
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
def get_action_connectors_from_trainer_config(
|
||||||
|
config: TrainerConfigDict, action_space: gym.Space
|
||||||
|
) -> ActionConnectorPipeline:
|
||||||
|
connectors = []
|
||||||
|
return ActionConnectorPipeline(connectors)
|
52
rllib/connectors/agent/clip_reward.py
Normal file
52
rllib/connectors/agent/clip_reward.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
import numpy as np
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ConnectorContext,
|
||||||
|
AgentConnector,
|
||||||
|
register_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.policy.sample_batch import SampleBatch
|
||||||
|
from ray.rllib.utils.typing import AgentConnectorDataType
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class ClipRewardAgentConnector(AgentConnector):
|
||||||
|
def __init__(self, ctx: ConnectorContext, sign=False, limit=None):
|
||||||
|
super().__init__(ctx)
|
||||||
|
assert (
|
||||||
|
not sign or not limit
|
||||||
|
), "should not enable both sign and limit reward clipping."
|
||||||
|
self.sign = sign
|
||||||
|
self.limit = limit
|
||||||
|
|
||||||
|
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||||
|
d = ac_data.data
|
||||||
|
assert (
|
||||||
|
type(d) == dict
|
||||||
|
), "Single agent data must be of type Dict[str, TensorStructType]"
|
||||||
|
|
||||||
|
assert SampleBatch.REWARDS in d, "input data does not have reward column."
|
||||||
|
if self.sign:
|
||||||
|
d[SampleBatch.REWARDS] = np.sign(d[SampleBatch.REWARDS])
|
||||||
|
elif self.limit:
|
||||||
|
d[SampleBatch.REWARDS] = np.clip(
|
||||||
|
d[SampleBatch.REWARDS],
|
||||||
|
a_min=-self.limit,
|
||||||
|
a_max=self.limit,
|
||||||
|
)
|
||||||
|
return [ac_data]
|
||||||
|
|
||||||
|
def to_config(self):
|
||||||
|
return ClipRewardAgentConnector.__name__, {
|
||||||
|
"sign": self.sign,
|
||||||
|
"limit": self.limit,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||||
|
return ClipRewardAgentConnector(ctx, **params)
|
||||||
|
|
||||||
|
|
||||||
|
register_connector(ClipRewardAgentConnector.__name__, ClipRewardAgentConnector)
|
72
rllib/connectors/agent/env_to_agent.py
Normal file
72
rllib/connectors/agent/env_to_agent.py
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ConnectorContext,
|
||||||
|
AgentConnector,
|
||||||
|
register_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.policy.sample_batch import SampleBatch
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.utils.typing import AgentConnectorDataType
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class EnvToAgentDataConnector(AgentConnector):
|
||||||
|
"""Converts per environment multi-agent obs into per agent SampleBatches."""
|
||||||
|
|
||||||
|
def __init__(self, ctx: ConnectorContext):
|
||||||
|
super().__init__(ctx)
|
||||||
|
self._view_requirements = ctx.view_requirements
|
||||||
|
|
||||||
|
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||||
|
if ac_data.agent_id:
|
||||||
|
# data is already for a single agent.
|
||||||
|
return [ac_data]
|
||||||
|
|
||||||
|
assert isinstance(ac_data.data, (tuple, list)) and len(ac_data.data) == 5, (
|
||||||
|
"EnvToPerAgentDataConnector expects a tuple of "
|
||||||
|
+ "(obs, rewards, dones, infos, episode_infos)."
|
||||||
|
)
|
||||||
|
# episode_infos contains additional training related data bits
|
||||||
|
# for each agent, such as SampleBatch.T, SampleBatch.AGENT_INDEX,
|
||||||
|
# SampleBatch.ACTIONS, SampleBatch.DONES (if hitting horizon),
|
||||||
|
# and is usually empty in inference mode.
|
||||||
|
obs, rewards, dones, infos, training_episode_infos = ac_data.data
|
||||||
|
for var, name in zip(
|
||||||
|
(obs, rewards, dones, infos, training_episode_infos),
|
||||||
|
("obs", "rewards", "dones", "infos", "training_episode_infos"),
|
||||||
|
):
|
||||||
|
assert isinstance(var, dict), (
|
||||||
|
f"EnvToPerAgentDataConnector expects {name} "
|
||||||
|
+ "to be a MultiAgentDict."
|
||||||
|
)
|
||||||
|
|
||||||
|
env_id = ac_data.env_id
|
||||||
|
per_agent_data = []
|
||||||
|
for agent_id, obs in obs.items():
|
||||||
|
input_dict = {
|
||||||
|
SampleBatch.ENV_ID: env_id,
|
||||||
|
SampleBatch.REWARDS: rewards[agent_id],
|
||||||
|
# SampleBatch.DONES may be overridden by data from
|
||||||
|
# training_episode_infos next.
|
||||||
|
SampleBatch.DONES: dones[agent_id],
|
||||||
|
SampleBatch.NEXT_OBS: obs,
|
||||||
|
}
|
||||||
|
if SampleBatch.INFOS in self._view_requirements:
|
||||||
|
input_dict[SampleBatch.INFOS] = infos[agent_id]
|
||||||
|
if agent_id in training_episode_infos:
|
||||||
|
input_dict.update(training_episode_infos[agent_id])
|
||||||
|
|
||||||
|
per_agent_data.append(AgentConnectorDataType(env_id, agent_id, input_dict))
|
||||||
|
|
||||||
|
return per_agent_data
|
||||||
|
|
||||||
|
def to_config(self):
|
||||||
|
return EnvToAgentDataConnector.__name__, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||||
|
return EnvToAgentDataConnector(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
register_connector(EnvToAgentDataConnector.__name__, EnvToAgentDataConnector)
|
81
rllib/connectors/agent/lambdas.py
Normal file
81
rllib/connectors/agent/lambdas.py
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
import numpy as np
|
||||||
|
import tree # dm_tree
|
||||||
|
from typing import Any, Callable, Dict, List, Type
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ConnectorContext,
|
||||||
|
AgentConnector,
|
||||||
|
register_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.policy.sample_batch import SampleBatch
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.utils.typing import (
|
||||||
|
AgentConnectorDataType,
|
||||||
|
TensorStructType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
def register_lambda_agent_connector(
|
||||||
|
name: str, fn: Callable[[Any], Any]
|
||||||
|
) -> Type[AgentConnector]:
|
||||||
|
"""A util to register any simple transforming function as an AgentConnector
|
||||||
|
|
||||||
|
The only requirement is that fn should take a single data object and return
|
||||||
|
a single data object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name of the resulting actor connector.
|
||||||
|
fn: The function that transforms env / agent data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new AgentConnector class that transforms data using fn.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class LambdaAgentConnector(AgentConnector):
|
||||||
|
def __call__(
|
||||||
|
self, ac_data: AgentConnectorDataType
|
||||||
|
) -> List[AgentConnectorDataType]:
|
||||||
|
d = ac_data.data
|
||||||
|
return [AgentConnectorDataType(ac_data.env_id, ac_data.agent_id, fn(d))]
|
||||||
|
|
||||||
|
def to_config(self):
|
||||||
|
return name, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||||
|
return LambdaAgentConnector(ctx)
|
||||||
|
|
||||||
|
LambdaAgentConnector.__name__ = name
|
||||||
|
LambdaAgentConnector.__qualname__ = name
|
||||||
|
|
||||||
|
register_connector(name, LambdaAgentConnector)
|
||||||
|
|
||||||
|
return LambdaAgentConnector
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
def flatten_data(data: Dict[str, TensorStructType]):
|
||||||
|
assert (
|
||||||
|
type(data) == dict
|
||||||
|
), "Single agent data must be of type Dict[str, TensorStructType]"
|
||||||
|
|
||||||
|
flattened = {}
|
||||||
|
for k, v in data.items():
|
||||||
|
if k in [SampleBatch.INFOS, SampleBatch.ACTIONS] or k.startswith("state_out_"):
|
||||||
|
# Do not flatten infos, actions, and state_out_ columns.
|
||||||
|
flattened[k] = v
|
||||||
|
continue
|
||||||
|
if v is None:
|
||||||
|
# Keep the same column shape.
|
||||||
|
flattened[k] = None
|
||||||
|
continue
|
||||||
|
flattened[k] = np.array(tree.flatten(v))
|
||||||
|
|
||||||
|
return flattened
|
||||||
|
|
||||||
|
|
||||||
|
# Flatten observation data.
|
||||||
|
FlattenDataAgentConnector = register_lambda_agent_connector(
|
||||||
|
"FlattenDataAgentConnector", flatten_data
|
||||||
|
)
|
60
rllib/connectors/agent/obs_preproc.py
Normal file
60
rllib/connectors/agent/obs_preproc.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ConnectorContext,
|
||||||
|
AgentConnector,
|
||||||
|
register_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.models.preprocessors import get_preprocessor
|
||||||
|
from ray.rllib.policy.sample_batch import SampleBatch
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.utils.typing import AgentConnectorDataType
|
||||||
|
|
||||||
|
|
||||||
|
# Bridging between current obs preprocessors and connector.
|
||||||
|
# We should not introduce any new preprocessors.
|
||||||
|
# TODO(jungong) : migrate and implement preprocessor library in Connector framework.
|
||||||
|
@DeveloperAPI
|
||||||
|
class ObsPreprocessorConnector(AgentConnector):
|
||||||
|
"""A connector that wraps around existing RLlib observation preprocessors.
|
||||||
|
|
||||||
|
This includes:
|
||||||
|
- OneHotPreprocessor for Discrete and Multi-Discrete spaces.
|
||||||
|
- GenericPixelPreprocessor and AtariRamPreprocessor for Atari spaces.
|
||||||
|
- TupleFlatteningPreprocessor and DictFlatteningPreprocessor for flattening
|
||||||
|
arbitrary nested input observations.
|
||||||
|
- RepeatedValuesPreprocessor for padding observations from RLlib Repeated
|
||||||
|
observation space.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ctx: ConnectorContext):
|
||||||
|
super().__init__(ctx)
|
||||||
|
|
||||||
|
self._preprocessor = get_preprocessor(ctx.observation_space)(
|
||||||
|
ctx.observation_space, ctx.config.get("model", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||||
|
d = ac_data.data
|
||||||
|
assert (
|
||||||
|
type(d) == dict
|
||||||
|
), "Single agent data must be of type Dict[str, TensorStructType]"
|
||||||
|
|
||||||
|
if SampleBatch.OBS in d:
|
||||||
|
d[SampleBatch.OBS] = self._preprocessor.transform(d[SampleBatch.OBS])
|
||||||
|
if SampleBatch.NEXT_OBS in d:
|
||||||
|
d[SampleBatch.NEXT_OBS] = self._preprocessor.transform(
|
||||||
|
d[SampleBatch.NEXT_OBS]
|
||||||
|
)
|
||||||
|
|
||||||
|
return [ac_data]
|
||||||
|
|
||||||
|
def to_config(self):
|
||||||
|
return ObsPreprocessorConnector.__name__, {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||||
|
return ObsPreprocessorConnector(ctx, **params)
|
||||||
|
|
||||||
|
|
||||||
|
register_connector(ObsPreprocessorConnector.__name__, ObsPreprocessorConnector)
|
79
rllib/connectors/agent/pipeline.py
Normal file
79
rllib/connectors/agent/pipeline.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
import gym
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
Connector,
|
||||||
|
ConnectorContext,
|
||||||
|
ConnectorPipeline,
|
||||||
|
AgentConnector,
|
||||||
|
register_connector,
|
||||||
|
get_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
|
||||||
|
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.utils.typing import (
|
||||||
|
ActionConnectorDataType,
|
||||||
|
AgentConnectorDataType,
|
||||||
|
TrainerConfigDict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class AgentConnectorPipeline(AgentConnector, ConnectorPipeline):
|
||||||
|
def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
|
||||||
|
super().__init__(ctx)
|
||||||
|
self.connectors = connectors
|
||||||
|
|
||||||
|
def is_training(self, is_training: bool):
|
||||||
|
self.is_training = is_training
|
||||||
|
for c in self.connectors:
|
||||||
|
c.is_training(is_training)
|
||||||
|
|
||||||
|
def reset(self, env_id: str):
|
||||||
|
for c in self.connectors:
|
||||||
|
c.reset(env_id)
|
||||||
|
|
||||||
|
def on_policy_output(self, output: ActionConnectorDataType):
|
||||||
|
for c in self.connectors:
|
||||||
|
c.on_policy_output(output)
|
||||||
|
|
||||||
|
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||||
|
ret = [ac_data]
|
||||||
|
for c in self.connectors:
|
||||||
|
# Run the list of input data through the next agent connect,
|
||||||
|
# and collect the list of output data.
|
||||||
|
new_ret = []
|
||||||
|
for d in ret:
|
||||||
|
new_ret += c(d)
|
||||||
|
ret = new_ret
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def to_config(self):
|
||||||
|
return AgentConnectorPipeline.__name__, [c.to_config() for c in self.connectors]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||||
|
assert (
|
||||||
|
type(params) == list
|
||||||
|
), "AgentConnectorPipeline takes a list of connector params."
|
||||||
|
connectors = [get_connector(ctx, name, subparams) for name, subparams in params]
|
||||||
|
return AgentConnectorPipeline(ctx, connectors)
|
||||||
|
|
||||||
|
|
||||||
|
register_connector(AgentConnectorPipeline.__name__, AgentConnectorPipeline)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(jungong) : finish this.
|
||||||
|
@DeveloperAPI
|
||||||
|
def get_agent_connectors_from_config(
|
||||||
|
config: TrainerConfigDict, obs_space: gym.Space
|
||||||
|
) -> AgentConnectorPipeline:
|
||||||
|
connectors = [FlattenDataAgentConnector()]
|
||||||
|
|
||||||
|
if config["clip_rewards"] is True:
|
||||||
|
connectors.append(ClipRewardAgentConnector(sign=True))
|
||||||
|
elif type(config["clip_rewards"]) == float:
|
||||||
|
connectors.append(ClipRewardAgentConnector(limit=abs(config["clip_rewards"])))
|
||||||
|
|
||||||
|
return AgentConnectorPipeline(connectors)
|
99
rllib/connectors/agent/state_buffer.py
Normal file
99
rllib/connectors/agent/state_buffer.py
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
from collections import defaultdict
|
||||||
|
import numpy as np
|
||||||
|
import tree # dm_tree
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ConnectorContext,
|
||||||
|
AgentConnector,
|
||||||
|
register_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.policy.sample_batch import SampleBatch
|
||||||
|
from ray.rllib.utils.numpy import convert_to_numpy
|
||||||
|
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
||||||
|
from ray.rllib.utils.typing import (
|
||||||
|
AgentConnectorDataType,
|
||||||
|
PolicyOutputType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class _AgentState(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.t = 0
|
||||||
|
self.action = None
|
||||||
|
self.states = None
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class StateBufferConnector(AgentConnector):
|
||||||
|
def __init__(self, ctx: ConnectorContext):
|
||||||
|
super().__init__(ctx)
|
||||||
|
|
||||||
|
self._initial_states = ctx.initial_states
|
||||||
|
self._action_space_struct = get_base_struct_from_space(ctx.action_space)
|
||||||
|
self._states = defaultdict(lambda: defaultdict(_AgentState))
|
||||||
|
|
||||||
|
def reset(self, env_id: str):
|
||||||
|
del self._states[env_id]
|
||||||
|
|
||||||
|
def on_policy_output(self, env_id: str, agent_id: str, output: PolicyOutputType):
|
||||||
|
# Buffer latest output states for next input __call__.
|
||||||
|
action, states, _ = output
|
||||||
|
agent_state = self._states[env_id][agent_id]
|
||||||
|
agent_state.action = convert_to_numpy(action)
|
||||||
|
agent_state.states = convert_to_numpy(states)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, ctx: ConnectorContext, ac_data: AgentConnectorDataType
|
||||||
|
) -> List[AgentConnectorDataType]:
|
||||||
|
d = ac_data.data
|
||||||
|
assert (
|
||||||
|
type(d) == dict
|
||||||
|
), "Single agent data must be of type Dict[str, TensorStructType]"
|
||||||
|
|
||||||
|
env_id = ac_data.env_id
|
||||||
|
agent_id = ac_data.agent_id
|
||||||
|
assert env_id and agent_id, "StateBufferConnector requires env_id and agent_id"
|
||||||
|
|
||||||
|
agent_state = self._states[env_id][agent_id]
|
||||||
|
|
||||||
|
d.update(
|
||||||
|
{
|
||||||
|
SampleBatch.T: agent_state.t,
|
||||||
|
SampleBatch.ENV_ID: env_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if agent_state.states is not None:
|
||||||
|
states = agent_state.states
|
||||||
|
else:
|
||||||
|
states = self._initial_states
|
||||||
|
for i, v in enumerate(states):
|
||||||
|
d["state_out_{}".format(i)] = v
|
||||||
|
|
||||||
|
if agent_state.action is not None:
|
||||||
|
d[SampleBatch.ACTIONS] = agent_state.action # Last action
|
||||||
|
else:
|
||||||
|
# Default zero action.
|
||||||
|
d[SampleBatch.ACTIONS] = tree.map_structure(
|
||||||
|
lambda s: np.zeros_like(s.sample(), s.dtype)
|
||||||
|
if hasattr(s, "dtype")
|
||||||
|
else np.zeros_like(s.sample()),
|
||||||
|
self._action_space_struct,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_state.t += 1
|
||||||
|
|
||||||
|
return [ac_data]
|
||||||
|
|
||||||
|
def to_config(self):
|
||||||
|
return StateBufferConnector.__name__, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||||
|
return StateBufferConnector(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
register_connector(StateBufferConnector.__name__, StateBufferConnector)
|
136
rllib/connectors/agent/view_requirement.py
Normal file
136
rllib/connectors/agent/view_requirement.py
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
from collections import defaultdict
|
||||||
|
import numpy as np
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ConnectorContext,
|
||||||
|
AgentConnector,
|
||||||
|
register_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.policy.sample_batch import SampleBatch
|
||||||
|
from ray.rllib.utils.typing import (
|
||||||
|
AgentConnectorDataType,
|
||||||
|
AgentConnectorsOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class ViewRequirementAgentConnector(AgentConnector):
|
||||||
|
"""This connector does 2 things:
|
||||||
|
1. It filters data columns based on view_requirements for training and inference.
|
||||||
|
2. It buffers the right amount of history for computing the sample batch for
|
||||||
|
action computation.
|
||||||
|
The output of this connector is AgentConnectorsOut, which basically is
|
||||||
|
a tuple of 2 things:
|
||||||
|
{
|
||||||
|
"for_training": {"obs": ...}
|
||||||
|
"for_action": SampleBatch
|
||||||
|
}
|
||||||
|
The "for_training" dict, which contains data for the latest time slice,
|
||||||
|
can be used to construct a complete episode by Sampler for training purpose.
|
||||||
|
The "for_action" SampleBatch can be used to directly call the policy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ctx: ConnectorContext):
|
||||||
|
super().__init__(ctx)
|
||||||
|
|
||||||
|
self._view_requirements = ctx.view_requirements
|
||||||
|
self._agent_data = defaultdict(lambda: defaultdict(SampleBatch))
|
||||||
|
|
||||||
|
def reset(self, env_id: str):
|
||||||
|
if env_id in self._agent_data:
|
||||||
|
del self._agent_data[env_id]
|
||||||
|
|
||||||
|
def _get_sample_batch_for_action(
|
||||||
|
self, view_requirements, agent_batch
|
||||||
|
) -> SampleBatch:
|
||||||
|
# TODO(jungong) : actually support buildling input sample batch with all the
|
||||||
|
# view shift requirements, etc.
|
||||||
|
# For now, we use some simple logics for demo purpose.
|
||||||
|
input_batch = SampleBatch()
|
||||||
|
for k, v in view_requirements.items():
|
||||||
|
if not v.used_for_compute_actions:
|
||||||
|
continue
|
||||||
|
data_col = v.data_col or k
|
||||||
|
if data_col not in agent_batch:
|
||||||
|
continue
|
||||||
|
input_batch[k] = agent_batch[data_col][-1:]
|
||||||
|
input_batch.count = 1
|
||||||
|
return input_batch
|
||||||
|
|
||||||
|
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||||
|
d = ac_data.data
|
||||||
|
assert (
|
||||||
|
type(d) == dict
|
||||||
|
), "Single agent data must be of type Dict[str, TensorStructType]"
|
||||||
|
|
||||||
|
env_id = ac_data.env_id
|
||||||
|
agent_id = ac_data.agent_id
|
||||||
|
assert env_id and agent_id, "StateBufferConnector requires env_id and agent_id"
|
||||||
|
|
||||||
|
vr = self._view_requirements
|
||||||
|
assert vr, "ViewRequirements required by ViewRequirementConnector"
|
||||||
|
|
||||||
|
training_dict = {}
|
||||||
|
# We construct a proper per-timeslice dict in training mode,
|
||||||
|
# for Sampler to construct a complete episode for back propagation.
|
||||||
|
if self.is_training:
|
||||||
|
# Filter columns that are not needed for traing.
|
||||||
|
for col, req in vr.items():
|
||||||
|
# Not used for training.
|
||||||
|
if not req.used_for_training:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create the batch of data from the different buffers.
|
||||||
|
data_col = req.data_col or col
|
||||||
|
if data_col not in d:
|
||||||
|
continue
|
||||||
|
|
||||||
|
training_dict[data_col] = d[data_col]
|
||||||
|
|
||||||
|
# Agent batch is our buffer of necessary history for computing
|
||||||
|
# a SampleBatch for policy forward pass.
|
||||||
|
# This is used by both training and inference.
|
||||||
|
agent_batch = self._agent_data[env_id][agent_id]
|
||||||
|
for col, req in vr.items():
|
||||||
|
# Not used for action computation.
|
||||||
|
if not req.used_for_compute_actions:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create the batch of data from the different buffers.
|
||||||
|
data_col = req.data_col or col
|
||||||
|
if data_col not in d:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Add batch dim to this data_col.
|
||||||
|
d_col = np.expand_dims(d[data_col], axis=0)
|
||||||
|
|
||||||
|
if col in agent_batch:
|
||||||
|
# Stack along batch dim.
|
||||||
|
agent_batch[data_col] = np.vstack((agent_batch[data_col], d_col))
|
||||||
|
else:
|
||||||
|
agent_batch[data_col] = d_col
|
||||||
|
# Only keep the useful part of the history.
|
||||||
|
h = req.shift_from if req.shift_from else -1
|
||||||
|
assert h <= 0, "Can use future data to compute action"
|
||||||
|
agent_batch[data_col] = agent_batch[data_col][h:]
|
||||||
|
|
||||||
|
sample_batch = self._get_sample_batch_for_action(vr, agent_batch)
|
||||||
|
|
||||||
|
return_data = AgentConnectorDataType(
|
||||||
|
env_id, agent_id, AgentConnectorsOutput(training_dict, sample_batch)
|
||||||
|
)
|
||||||
|
return return_data
|
||||||
|
|
||||||
|
def to_config(self):
|
||||||
|
return ViewRequirementAgentConnector.__name__, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||||
|
return ViewRequirementAgentConnector(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
register_connector(
|
||||||
|
ViewRequirementAgentConnector.__name__, ViewRequirementAgentConnector
|
||||||
|
)
|
366
rllib/connectors/connector.py
Normal file
366
rllib/connectors/connector.py
Normal file
|
@ -0,0 +1,366 @@
|
||||||
|
"""This file defines base types and common structures for RLlib connectors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import gym
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from ray.tune.registry import RLLIB_CONNECTOR, _global_registry
|
||||||
|
from ray.rllib.policy.policy import Policy
|
||||||
|
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
from ray.rllib.utils.typing import (
|
||||||
|
ActionConnectorDataType,
|
||||||
|
AgentConnectorDataType,
|
||||||
|
TensorType,
|
||||||
|
TrainerConfigDict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class ConnectorContext:
|
||||||
|
"""Data bits that may be needed for running connectors.
|
||||||
|
|
||||||
|
Note(jungong) : we need to be really careful with the data fields here.
|
||||||
|
E.g., everything needs to be serializable, in case we need to fetch them
|
||||||
|
in a remote setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO(jungong) : figure out how to fetch these in a remote setting.
|
||||||
|
# Probably from a policy server when initializing a policy client.
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: TrainerConfigDict = None,
|
||||||
|
model_initial_states: List[TensorType] = None,
|
||||||
|
observation_space: gym.Space = None,
|
||||||
|
action_space: gym.Space = None,
|
||||||
|
view_requirements: Dict[str, ViewRequirement] = None,
|
||||||
|
):
|
||||||
|
"""Construct a ConnectorContext instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_initial_states: States that are used for constructing
|
||||||
|
the initial input dict for RNN models. [] if a model is not recurrent.
|
||||||
|
action_space_struct: a policy's action space, in python
|
||||||
|
data format. E.g., python dict instead of DictSpace, python tuple
|
||||||
|
instead of TupleSpace.
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self.initial_states = model_initial_states or []
|
||||||
|
self.observation_space = observation_space
|
||||||
|
self.action_space = action_space
|
||||||
|
self.view_requirements = view_requirements
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_policy(policy: Policy) -> "ConnectorContext":
|
||||||
|
"""Build ConnectorContext from a given policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
policy: Policy
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ConnectorContext instance.
|
||||||
|
"""
|
||||||
|
return ConnectorContext(
|
||||||
|
policy.config,
|
||||||
|
policy.get_initial_state(),
|
||||||
|
policy.observation_space,
|
||||||
|
policy.action_space,
|
||||||
|
policy.view_requirements,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class Connector(abc.ABC):
|
||||||
|
"""Connector base class.
|
||||||
|
|
||||||
|
A connector is a step of transformation, of either envrionment data before they
|
||||||
|
get to a policy, or policy output before it is sent back to the environment.
|
||||||
|
|
||||||
|
Connectors may be training-aware, for example, behave slightly differently
|
||||||
|
during training and inference.
|
||||||
|
|
||||||
|
All connectors are required to be serializable and implement to_config().
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ctx: ConnectorContext):
|
||||||
|
# This gets flipped to False for inference.
|
||||||
|
self.is_training = True
|
||||||
|
|
||||||
|
def is_training(self, is_training: bool):
|
||||||
|
self.is_training = is_training
|
||||||
|
|
||||||
|
def to_config(self) -> Tuple[str, List[Any]]:
|
||||||
|
"""Serialize a connector into a JSON serializable Tuple.
|
||||||
|
|
||||||
|
to_config is required, so that all Connectors are serializable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of connector's name and its serialized states.
|
||||||
|
"""
|
||||||
|
# Must implement by each connector.
|
||||||
|
return NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(self, ctx: ConnectorContext, params: List[Any]) -> "Connector":
|
||||||
|
"""De-serialize a JSON params back into a Connector.
|
||||||
|
|
||||||
|
from_config is required, so that all Connectors are serializable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: Context for constructing this connector.
|
||||||
|
params: Serialized states of the connector to be recovered.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
De-serialized connector.
|
||||||
|
"""
|
||||||
|
# Must implement by each connector.
|
||||||
|
return NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class AgentConnector(Connector):
|
||||||
|
"""Connector connecting user environments to RLlib policies.
|
||||||
|
|
||||||
|
An agent connector transforms a single piece of data in AgentConnectorDataType
|
||||||
|
format into a list of data in the same AgentConnectorDataTypes format.
|
||||||
|
The API is designed so multi-agent observations can be broken and emitted as
|
||||||
|
multiple single agent observations.
|
||||||
|
|
||||||
|
AgentConnectorDataTypes can be used to specify arbitrary type of env data,
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
# A dict of multi-agent data from one env step() call.
|
||||||
|
ac = AgentConnectorDataType(
|
||||||
|
env_id="env_1",
|
||||||
|
agent_id=None,
|
||||||
|
data={
|
||||||
|
"agent_1": np.array(...),
|
||||||
|
"agent_2": np.array(...),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
# Single agent data ready to be preprocessed.
|
||||||
|
ac = AgentConnectorDataType(
|
||||||
|
env_id="env_1",
|
||||||
|
agent_id="agent_1",
|
||||||
|
data=np.array(...)
|
||||||
|
)
|
||||||
|
|
||||||
|
We can adapt a simple stateless function into an agent connector by using
|
||||||
|
register_lambda_agent_connector:
|
||||||
|
.. code-block:: python
|
||||||
|
TimesTwoAgentConnector = register_lambda_agent_connector(
|
||||||
|
"TimesTwoAgentConnector", lambda data: data * 2
|
||||||
|
)
|
||||||
|
|
||||||
|
More complicated agent connectors can be implemented by extending this
|
||||||
|
AgentConnector class:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
class FrameSkippingAgentConnector(AgentConnector):
|
||||||
|
def __init__(self, n):
|
||||||
|
self._n = n
|
||||||
|
self._frame_count = default_dict(str, default_dict(str, int))
|
||||||
|
|
||||||
|
def reset(self, env_id: str):
|
||||||
|
del self._frame_count[env_id]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, ac_data: AgentConnectorDataType
|
||||||
|
) -> List[AgentConnectorDataType]:
|
||||||
|
assert ac_data.env_id and ac_data.agent_id, (
|
||||||
|
"Frame skipping works per agent")
|
||||||
|
|
||||||
|
count = self._frame_count[ac_data.env_id][ac_data.agent_id]
|
||||||
|
self._frame_count[ac_data.env_id][ac_data.agent_id] = count + 1
|
||||||
|
|
||||||
|
return [ac_data] if count % self._n == 0 else []
|
||||||
|
|
||||||
|
As shown, an agent connector may choose to emit an empty list to stop input
|
||||||
|
observations from being prosessed further.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def reset(self, env_id: str):
|
||||||
|
"""Reset connector state for a specific environment.
|
||||||
|
|
||||||
|
For example, at the end of an episode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_id: required. ID of a user environment. Required.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_policy_output(self, output: ActionConnectorDataType):
|
||||||
|
"""Callback on agent connector of policy output.
|
||||||
|
|
||||||
|
This is useful for certain connectors, for example RNN state buffering,
|
||||||
|
where the agent connect needs to be aware of the output of a policy
|
||||||
|
forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: Context for running this connector call.
|
||||||
|
output: Env and agent IDs, plus data output from policy forward pass.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||||
|
"""Transform incoming data from environment before they reach policy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: Context for running this connector call.
|
||||||
|
data: Env and agent IDs, plus arbitrary data from an environment or
|
||||||
|
upstream agent connectors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of transformed data in AgentConnectorDataType format.
|
||||||
|
The return type is a list because an AgentConnector may choose to
|
||||||
|
derive multiple outputs for a single input data, for example
|
||||||
|
multi-agent obs -> multiple single agent obs.
|
||||||
|
Agent connectors may also return an empty list for certain input,
|
||||||
|
useful for connectors such as frame skipping.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class ActionConnector(Connector):
|
||||||
|
"""Action connector connects policy outputs including actions,
|
||||||
|
to user environments.
|
||||||
|
|
||||||
|
An action connector transforms a single piece of policy output in
|
||||||
|
ActionConnectorDataType format, which is basically PolicyOutputType
|
||||||
|
plus env and agent IDs.
|
||||||
|
|
||||||
|
Any functions that operates directly on PolicyOutputType can be
|
||||||
|
easily adpated into an ActionConnector by using register_lambda_action_connector.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
ZeroActionConnector = register_lambda_action_connector(
|
||||||
|
"ZeroActionsConnector",
|
||||||
|
lambda actions, states, fetches: (
|
||||||
|
np.zeros_like(actions), states, fetches
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
More complicated action connectors can also be implemented by sub-classing
|
||||||
|
this ActionConnector class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
||||||
|
"""Transform policy output before they are sent to a user environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: Context for running this connector call.
|
||||||
|
ac_data: Env and agent IDs, plus policy output.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed action connector data.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
class ConnectorPipeline:
|
||||||
|
"""Utility class for quick manipulation of a connector pipeline."""
|
||||||
|
|
||||||
|
def remove(self, name: str):
|
||||||
|
"""Remove a connector by <name>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name of the connector to be removed.
|
||||||
|
"""
|
||||||
|
idx = -1
|
||||||
|
for idx, c in enumerate(self.connectors):
|
||||||
|
if c.__class__.__name__ == name:
|
||||||
|
break
|
||||||
|
if idx < 0:
|
||||||
|
raise ValueError(f"Can not find connector {name}")
|
||||||
|
del self.connectors[idx]
|
||||||
|
|
||||||
|
def insert_before(self, name: str, connector: Connector):
|
||||||
|
"""Insert a new connector before connector <name>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name of the connector before which a new connector
|
||||||
|
will get inserted.
|
||||||
|
connector: a new connector to be inserted.
|
||||||
|
"""
|
||||||
|
idx = -1
|
||||||
|
for idx, c in enumerate(self.connectors):
|
||||||
|
if c.__class__.__name__ == name:
|
||||||
|
break
|
||||||
|
if idx < 0:
|
||||||
|
raise ValueError(f"Can not find connector {name}")
|
||||||
|
self.connectors.insert(idx, connector)
|
||||||
|
|
||||||
|
def insert_after(self, name: str, connector: Connector):
|
||||||
|
"""Insert a new connector after connector <name>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name of the connector after which a new connector
|
||||||
|
will get inserted.
|
||||||
|
connector: a new connector to be inserted.
|
||||||
|
"""
|
||||||
|
idx = -1
|
||||||
|
for idx, c in enumerate(self.connectors):
|
||||||
|
if c.__class__.__name__ == name:
|
||||||
|
break
|
||||||
|
if idx < 0:
|
||||||
|
raise ValueError(f"Can not find connector {name}")
|
||||||
|
self.connectors.insert(idx + 1, connector)
|
||||||
|
|
||||||
|
def prepend(self, connector: Connector):
|
||||||
|
"""Append a new connector at the beginning of a connector pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector: a new connector to be appended.
|
||||||
|
"""
|
||||||
|
self.connectors.insert(0, connector)
|
||||||
|
|
||||||
|
def append(self, connector: Connector):
|
||||||
|
"""Append a new connector at the end of a connector pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector: a new connector to be appended.
|
||||||
|
"""
|
||||||
|
self.connectors.append(connector)
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
def register_connector(name: str, cls: Connector):
|
||||||
|
"""Register a connector for use with RLlib.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name to register.
|
||||||
|
cls: Callable that creates an env.
|
||||||
|
"""
|
||||||
|
if not issubclass(cls, Connector):
|
||||||
|
raise TypeError("Can only register Connector type.", cls)
|
||||||
|
_global_registry.register(RLLIB_CONNECTOR, name, cls)
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
def get_connector(ctx: ConnectorContext, name: str, params: Tuple[Any]) -> Connector:
|
||||||
|
"""Get a connector by its name and serialized config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name of the connector.
|
||||||
|
params: serialized parameters of the connector.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Constructed connector.
|
||||||
|
"""
|
||||||
|
if not _global_registry.contains(RLLIB_CONNECTOR, name):
|
||||||
|
raise NameError("connector not found.", name)
|
||||||
|
cls = _global_registry.get(RLLIB_CONNECTOR, name)
|
||||||
|
return cls.from_config(ctx, params)
|
127
rllib/connectors/tests/test_action.py
Normal file
127
rllib/connectors/tests/test_action.py
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from ray.rllib.connectors.action.clip import ClipActionsConnector
|
||||||
|
from ray.rllib.connectors.action.lambdas import (
|
||||||
|
ConvertToNumpyConnector,
|
||||||
|
UnbatchActionsConnector,
|
||||||
|
)
|
||||||
|
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
|
||||||
|
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ConnectorContext,
|
||||||
|
get_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.utils.framework import try_import_torch
|
||||||
|
from ray.rllib.utils.typing import ActionConnectorDataType
|
||||||
|
|
||||||
|
torch, _ = try_import_torch()
|
||||||
|
|
||||||
|
|
||||||
|
class TestActionConnector(unittest.TestCase):
|
||||||
|
def test_connector_pipeline(self):
|
||||||
|
ctx = ConnectorContext()
|
||||||
|
connectors = [ConvertToNumpyConnector(ctx)]
|
||||||
|
pipeline = ActionConnectorPipeline(ctx, connectors)
|
||||||
|
name, params = pipeline.to_config()
|
||||||
|
restored = get_connector(ctx, name, params)
|
||||||
|
self.assertTrue(isinstance(restored, ActionConnectorPipeline))
|
||||||
|
self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))
|
||||||
|
|
||||||
|
def test_convert_to_numpy_connector(self):
|
||||||
|
ctx = ConnectorContext()
|
||||||
|
c = ConvertToNumpyConnector(ctx)
|
||||||
|
|
||||||
|
name, params = c.to_config()
|
||||||
|
|
||||||
|
self.assertEqual(name, "ConvertToNumpyConnector")
|
||||||
|
|
||||||
|
restored = get_connector(ctx, name, params)
|
||||||
|
self.assertTrue(isinstance(restored, ConvertToNumpyConnector))
|
||||||
|
|
||||||
|
action = torch.Tensor([8, 9])
|
||||||
|
states = torch.Tensor([[1, 1, 1], [2, 2, 2]])
|
||||||
|
ac_data = ActionConnectorDataType(0, 1, (action, states, {}))
|
||||||
|
|
||||||
|
converted = c(ac_data)
|
||||||
|
self.assertTrue(isinstance(converted.output[0], np.ndarray))
|
||||||
|
self.assertTrue(isinstance(converted.output[1], np.ndarray))
|
||||||
|
|
||||||
|
def test_unbatch_action_connector(self):
|
||||||
|
ctx = ConnectorContext()
|
||||||
|
c = UnbatchActionsConnector(ctx)
|
||||||
|
|
||||||
|
name, params = c.to_config()
|
||||||
|
|
||||||
|
self.assertEqual(name, "UnbatchActionsConnector")
|
||||||
|
|
||||||
|
restored = get_connector(ctx, name, params)
|
||||||
|
self.assertTrue(isinstance(restored, UnbatchActionsConnector))
|
||||||
|
|
||||||
|
ac_data = ActionConnectorDataType(
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"a": np.array([1, 2, 3]),
|
||||||
|
"b": (np.array([4, 5, 6]), np.array([7, 8, 9])),
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
{},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
unbatched = c(ac_data)
|
||||||
|
actions, _, _ = unbatched.output
|
||||||
|
|
||||||
|
self.assertEqual(len(actions), 3)
|
||||||
|
self.assertEqual(actions[0]["a"], 1)
|
||||||
|
self.assertTrue((actions[0]["b"] == np.array((4, 7))).all())
|
||||||
|
|
||||||
|
self.assertEqual(actions[1]["a"], 2)
|
||||||
|
self.assertTrue((actions[1]["b"] == np.array((5, 8))).all())
|
||||||
|
|
||||||
|
self.assertEqual(actions[2]["a"], 3)
|
||||||
|
self.assertTrue((actions[2]["b"] == np.array((6, 9))).all())
|
||||||
|
|
||||||
|
def test_normalize_action_connector(self):
|
||||||
|
ctx = ConnectorContext(
|
||||||
|
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
|
||||||
|
)
|
||||||
|
c = NormalizeActionsConnector(ctx)
|
||||||
|
|
||||||
|
name, params = c.to_config()
|
||||||
|
self.assertEqual(name, "NormalizeActionsConnector")
|
||||||
|
|
||||||
|
restored = get_connector(ctx, name, params)
|
||||||
|
self.assertTrue(isinstance(restored, NormalizeActionsConnector))
|
||||||
|
|
||||||
|
ac_data = ActionConnectorDataType(0, 1, (0.5, [], {}))
|
||||||
|
|
||||||
|
normalized = c(ac_data)
|
||||||
|
self.assertEqual(normalized.output[0], 4.5)
|
||||||
|
|
||||||
|
def test_clip_action_connector(self):
|
||||||
|
ctx = ConnectorContext(
|
||||||
|
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
|
||||||
|
)
|
||||||
|
c = ClipActionsConnector(ctx)
|
||||||
|
|
||||||
|
name, params = c.to_config()
|
||||||
|
self.assertEqual(name, "ClipActionsConnector")
|
||||||
|
|
||||||
|
restored = get_connector(ctx, name, params)
|
||||||
|
self.assertTrue(isinstance(restored, ClipActionsConnector))
|
||||||
|
|
||||||
|
ac_data = ActionConnectorDataType(0, 1, (8.8, [], {}))
|
||||||
|
|
||||||
|
clipped = c(ac_data)
|
||||||
|
self.assertEqual(clipped.output[0], 6.0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.exit(pytest.main(["-v", __file__]))
|
186
rllib/connectors/tests/test_agent.py
Normal file
186
rllib/connectors/tests/test_agent.py
Normal file
|
@ -0,0 +1,186 @@
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
|
||||||
|
from ray.rllib.connectors.agent.env_to_agent import EnvToAgentDataConnector
|
||||||
|
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
|
||||||
|
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
|
||||||
|
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
ConnectorContext,
|
||||||
|
get_connector,
|
||||||
|
)
|
||||||
|
from ray.rllib.policy.sample_batch import SampleBatch
|
||||||
|
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||||
|
from ray.rllib.utils.typing import (
|
||||||
|
AgentConnectorDataType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentConnector(unittest.TestCase):
|
||||||
|
def test_connector_pipeline(self):
|
||||||
|
ctx = ConnectorContext()
|
||||||
|
connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
|
||||||
|
pipeline = AgentConnectorPipeline(ctx, connectors)
|
||||||
|
name, params = pipeline.to_config()
|
||||||
|
restored = get_connector(ctx, name, params)
|
||||||
|
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
|
||||||
|
self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))
|
||||||
|
|
||||||
|
def test_env_to_per_agent_data_connector(self):
|
||||||
|
vrs = {
|
||||||
|
"infos": ViewRequirement(
|
||||||
|
"infos",
|
||||||
|
used_for_training=True,
|
||||||
|
used_for_compute_actions=False,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
ctx = ConnectorContext(view_requirements=vrs)
|
||||||
|
|
||||||
|
c = EnvToAgentDataConnector(ctx)
|
||||||
|
|
||||||
|
name, params = c.to_config()
|
||||||
|
restored = get_connector(ctx, name, params)
|
||||||
|
self.assertTrue(isinstance(restored, EnvToAgentDataConnector))
|
||||||
|
|
||||||
|
d = AgentConnectorDataType(
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
[
|
||||||
|
# obs
|
||||||
|
{1: [8, 8], 2: [9, 9]},
|
||||||
|
# rewards
|
||||||
|
{
|
||||||
|
1: 8.8,
|
||||||
|
2: 9.9,
|
||||||
|
},
|
||||||
|
# dones
|
||||||
|
{
|
||||||
|
1: False,
|
||||||
|
2: False,
|
||||||
|
},
|
||||||
|
# infos
|
||||||
|
{
|
||||||
|
1: {"random": "info"},
|
||||||
|
2: {},
|
||||||
|
},
|
||||||
|
# training_episode_info
|
||||||
|
{
|
||||||
|
1: {SampleBatch.DONES: True},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
per_agent = c(d)
|
||||||
|
|
||||||
|
self.assertEqual(len(per_agent), 2)
|
||||||
|
|
||||||
|
batch1 = per_agent[0].data
|
||||||
|
self.assertEqual(batch1[SampleBatch.NEXT_OBS], [8, 8])
|
||||||
|
self.assertTrue(batch1[SampleBatch.DONES]) # from training_episode_info
|
||||||
|
self.assertTrue(SampleBatch.INFOS in batch1)
|
||||||
|
self.assertEqual(batch1[SampleBatch.INFOS]["random"], "info")
|
||||||
|
|
||||||
|
batch2 = per_agent[1].data
|
||||||
|
self.assertEqual(batch2[SampleBatch.NEXT_OBS], [9, 9])
|
||||||
|
self.assertFalse(batch2[SampleBatch.DONES])
|
||||||
|
|
||||||
|
def test_obs_preprocessor_connector(self):
|
||||||
|
obs_space = gym.spaces.Dict(
|
||||||
|
{
|
||||||
|
"a": gym.spaces.Box(low=0, high=1, shape=(1,)),
|
||||||
|
"b": gym.spaces.Tuple(
|
||||||
|
[gym.spaces.Discrete(2), gym.spaces.MultiDiscrete(nvec=[2, 3])]
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ctx = ConnectorContext(config={}, observation_space=obs_space)
|
||||||
|
|
||||||
|
c = ObsPreprocessorConnector(ctx)
|
||||||
|
name, params = c.to_config()
|
||||||
|
|
||||||
|
restored = get_connector(ctx, name, params)
|
||||||
|
self.assertTrue(isinstance(restored, ObsPreprocessorConnector))
|
||||||
|
|
||||||
|
obs = obs_space.sample()
|
||||||
|
# Fake deterministic data.
|
||||||
|
obs["a"][0] = 0.5
|
||||||
|
obs["b"] = (1, np.array([0, 2]))
|
||||||
|
|
||||||
|
d = AgentConnectorDataType(
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
{
|
||||||
|
SampleBatch.OBS: obs,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
preprocessed = c(d)
|
||||||
|
|
||||||
|
# obs is completely flattened.
|
||||||
|
self.assertTrue(
|
||||||
|
(preprocessed[0].data[SampleBatch.OBS] == [0.5, 0, 1, 1, 0, 0, 0, 1]).all()
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_clip_reward_connector(self):
|
||||||
|
ctx = ConnectorContext()
|
||||||
|
|
||||||
|
c = ClipRewardAgentConnector(ctx, limit=2.0)
|
||||||
|
name, params = c.to_config()
|
||||||
|
|
||||||
|
self.assertEqual(name, "ClipRewardAgentConnector")
|
||||||
|
self.assertAlmostEqual(params["limit"], 2.0)
|
||||||
|
|
||||||
|
restored = get_connector(ctx, name, params)
|
||||||
|
self.assertTrue(isinstance(restored, ClipRewardAgentConnector))
|
||||||
|
|
||||||
|
d = AgentConnectorDataType(
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
{
|
||||||
|
SampleBatch.REWARDS: 5.8,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
clipped = restored(ac_data=d)
|
||||||
|
|
||||||
|
self.assertEqual(len(clipped), 1)
|
||||||
|
self.assertEqual(clipped[0].data[SampleBatch.REWARDS], 2.0)
|
||||||
|
|
||||||
|
def test_flatten_data_connector(self):
|
||||||
|
ctx = ConnectorContext()
|
||||||
|
|
||||||
|
c = FlattenDataAgentConnector(ctx)
|
||||||
|
|
||||||
|
name, params = c.to_config()
|
||||||
|
restored = get_connector(ctx, name, params)
|
||||||
|
self.assertTrue(isinstance(restored, FlattenDataAgentConnector))
|
||||||
|
|
||||||
|
d = AgentConnectorDataType(
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
{
|
||||||
|
SampleBatch.NEXT_OBS: {
|
||||||
|
"sensor1": [[1, 1], [2, 2]],
|
||||||
|
"sensor2": 8.8,
|
||||||
|
},
|
||||||
|
SampleBatch.REWARDS: 5.8,
|
||||||
|
SampleBatch.ACTIONS: [[1, 1], [2]],
|
||||||
|
SampleBatch.INFOS: {"random": "info"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
flattened = c(d)
|
||||||
|
self.assertEqual(len(flattened), 1)
|
||||||
|
|
||||||
|
batch = flattened[0].data
|
||||||
|
self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2, 8.8]).all())
|
||||||
|
self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
|
||||||
|
# Not flattened.
|
||||||
|
self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
|
||||||
|
self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.exit(pytest.main(["-v", __file__]))
|
59
rllib/connectors/tests/test_connector.py
Normal file
59
rllib/connectors/tests/test_connector.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from ray.rllib.connectors.connector import Connector, ConnectorPipeline
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectorPipeline(unittest.TestCase):
|
||||||
|
class Tom(Connector):
|
||||||
|
def to_config():
|
||||||
|
return "tom"
|
||||||
|
|
||||||
|
class Bob(Connector):
|
||||||
|
def to_config():
|
||||||
|
return "bob"
|
||||||
|
|
||||||
|
class Mary(Connector):
|
||||||
|
def to_config():
|
||||||
|
return "mary"
|
||||||
|
|
||||||
|
class MockConnectorPipeline(ConnectorPipeline):
|
||||||
|
def __init__(self, ctx, connectors):
|
||||||
|
# Real connector pipelines should keep a list of
|
||||||
|
# Connectors.
|
||||||
|
# Use strings here for simple unit tests.
|
||||||
|
self.connectors = connectors
|
||||||
|
|
||||||
|
def test_sanity_check(self):
|
||||||
|
ctx = {}
|
||||||
|
|
||||||
|
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
|
||||||
|
m.insert_before("Bob", self.Mary(ctx))
|
||||||
|
self.assertEqual(len(m.connectors), 3)
|
||||||
|
self.assertEqual(m.connectors[1].__class__.__name__, "Mary")
|
||||||
|
|
||||||
|
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
|
||||||
|
m.insert_after("Tom", self.Mary(ctx))
|
||||||
|
self.assertEqual(len(m.connectors), 3)
|
||||||
|
self.assertEqual(m.connectors[1].__class__.__name__, "Mary")
|
||||||
|
|
||||||
|
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
|
||||||
|
m.prepend(self.Mary(ctx))
|
||||||
|
self.assertEqual(len(m.connectors), 3)
|
||||||
|
self.assertEqual(m.connectors[0].__class__.__name__, "Mary")
|
||||||
|
|
||||||
|
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
|
||||||
|
m.append(self.Mary(ctx))
|
||||||
|
self.assertEqual(len(m.connectors), 3)
|
||||||
|
self.assertEqual(m.connectors[2].__class__.__name__, "Mary")
|
||||||
|
|
||||||
|
m.remove("Bob")
|
||||||
|
self.assertEqual(len(m.connectors), 2)
|
||||||
|
self.assertEqual(m.connectors[0].__class__.__name__, "Tom")
|
||||||
|
self.assertEqual(m.connectors[1].__class__.__name__, "Mary")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.exit(pytest.main(["-v", __file__]))
|
9
rllib/connectors/util.py
Normal file
9
rllib/connectors/util.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
from ray.rllib.connectors.connector import (
|
||||||
|
Connector,
|
||||||
|
get_connector,
|
||||||
|
)
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_connectors_from_cfg(config: dict) -> Dict[str, Connector]:
|
||||||
|
return {k: get_connector(*v) for k, v in config.items()}
|
|
@ -276,6 +276,8 @@ class JsonReader(InputReader):
|
||||||
|
|
||||||
# Clip actions (from any values into env's bounds), if necessary.
|
# Clip actions (from any values into env's bounds), if necessary.
|
||||||
cfg = self.ioctx.config
|
cfg = self.ioctx.config
|
||||||
|
# TODO(jungong) : we should not clip_action in input reader.
|
||||||
|
# Use connector to handle this.
|
||||||
if cfg.get("clip_actions") and self.ioctx.worker is not None:
|
if cfg.get("clip_actions") and self.ioctx.worker is not None:
|
||||||
if isinstance(batch, SampleBatch):
|
if isinstance(batch, SampleBatch):
|
||||||
batch[SampleBatch.ACTIONS] = clip_action(
|
batch[SampleBatch.ACTIONS] = clip_action(
|
||||||
|
|
|
@ -8,6 +8,10 @@ import zlib
|
||||||
from ray.rllib.utils.annotations import DeveloperAPI
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(jungong) : We need to handle RLlib custom space types,
|
||||||
|
# FlexDict, Repeated, and Simplex.
|
||||||
|
|
||||||
|
|
||||||
def _serialize_ndarray(array: np.ndarray) -> str:
|
def _serialize_ndarray(array: np.ndarray) -> str:
|
||||||
"""Pack numpy ndarray into Base64 encoded strings for serialization.
|
"""Pack numpy ndarray into Base64 encoded strings for serialization.
|
||||||
|
|
||||||
|
|
|
@ -182,7 +182,10 @@ def unbatch(batches_struct):
|
||||||
"""Converts input from (nested) struct of batches to batch of structs.
|
"""Converts input from (nested) struct of batches to batch of structs.
|
||||||
|
|
||||||
Input: Struct of different batches (each batch has size=3):
|
Input: Struct of different batches (each batch has size=3):
|
||||||
{"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
|
{
|
||||||
|
"a": np.array([1, 2, 3]),
|
||||||
|
"b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0]))
|
||||||
|
}
|
||||||
Output: Batch (list) of structs (each of these structs representing a
|
Output: Batch (list) of structs (each of these structs representing a
|
||||||
single action):
|
single action):
|
||||||
[
|
[
|
||||||
|
|
|
@ -4,6 +4,7 @@ from typing import (
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
|
NamedTuple,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
@ -12,6 +13,8 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.rllib.env.env_context import EnvContext
|
from ray.rllib.env.env_context import EnvContext
|
||||||
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
|
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
|
||||||
|
@ -149,5 +152,35 @@ SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]
|
||||||
# (possibly nested) dict|tuple of gym.space.Spaces.
|
# (possibly nested) dict|tuple of gym.space.Spaces.
|
||||||
SpaceStruct = Union[gym.spaces.Space, dict, tuple]
|
SpaceStruct = Union[gym.spaces.Space, dict, tuple]
|
||||||
|
|
||||||
|
# A list of batches of RNN states.
|
||||||
|
# Each item in this list has dimension [B, S] (S=state vector size)
|
||||||
|
StateBatches = List[List[Any]]
|
||||||
|
|
||||||
|
# Format of data output from policy forward pass.
|
||||||
|
PolicyOutputType = Tuple[TensorStructType, StateBatches, Dict]
|
||||||
|
|
||||||
|
# Data type that is fed into and yielded from agent connectors.
|
||||||
|
AgentConnectorDataType = DeveloperAPI( # API stability declaration.
|
||||||
|
NamedTuple(
|
||||||
|
"AgentConnectorDataType", [("env_id", str), ("agent_id", str), ("data", Any)]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Data type that is fed into and yielded from agent connectors.
|
||||||
|
ActionConnectorDataType = DeveloperAPI( # API stability declaration.
|
||||||
|
NamedTuple(
|
||||||
|
"ActionConnectorDataType",
|
||||||
|
[("env_id", str), ("agent_id", str), ("output", PolicyOutputType)],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Final output data type of agent connectors.
|
||||||
|
AgentConnectorsOutput = DeveloperAPI( # API stability declaration.
|
||||||
|
NamedTuple(
|
||||||
|
"AgentConnectorsOut",
|
||||||
|
[("for_training", Dict[str, TensorStructType]), ("for_action", "SampleBatch")],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Generic type var.
|
# Generic type var.
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
Loading…
Add table
Reference in a new issue