mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[RLlib] Introduce basic connectors library. (#25311)
This commit is contained in:
parent
4e887fe776
commit
9b65d5535d
23 changed files with 1621 additions and 1 deletions
|
@ -21,6 +21,7 @@ RLLIB_MODEL = "rllib_model"
|
|||
RLLIB_PREPROCESSOR = "rllib_preprocessor"
|
||||
RLLIB_ACTION_DIST = "rllib_action_dist"
|
||||
RLLIB_INPUT = "rllib_input"
|
||||
RLLIB_CONNECTOR = "rllib_connector"
|
||||
TEST = "__test__"
|
||||
KNOWN_CATEGORIES = [
|
||||
TRAINABLE_CLASS,
|
||||
|
@ -29,6 +30,7 @@ KNOWN_CATEGORIES = [
|
|||
RLLIB_PREPROCESSOR,
|
||||
RLLIB_ACTION_DIST,
|
||||
RLLIB_INPUT,
|
||||
RLLIB_CONNECTOR,
|
||||
TEST,
|
||||
]
|
||||
|
||||
|
|
28
rllib/BUILD
28
rllib/BUILD
|
@ -1292,6 +1292,34 @@ py_test(
|
|||
]
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# Connector tests
|
||||
# rllib/connector/
|
||||
#
|
||||
# Tag: connector
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
py_test(
|
||||
name = "test_connector",
|
||||
tags = ["team:ml", "connector"],
|
||||
size = "small",
|
||||
srcs = ["connectors/tests/test_connector.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_action",
|
||||
tags = ["team:ml", "connector"],
|
||||
size = "small",
|
||||
srcs = ["connectors/tests/test_action.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_agent",
|
||||
tags = ["team:ml", "connector"],
|
||||
size = "small",
|
||||
srcs = ["connectors/tests/test_agent.py"]
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# Env tests
|
||||
# rllib/env/
|
||||
|
|
0
rllib/connectors/__init__.py
Normal file
0
rllib/connectors/__init__.py
Normal file
43
rllib/connectors/action/clip.py
Normal file
43
rllib/connectors/action/clip.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from typing import Any, List
|
||||
|
||||
from ray.rllib.connectors.connector import (
|
||||
ConnectorContext,
|
||||
ActionConnector,
|
||||
register_connector,
|
||||
)
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.spaces.space_utils import (
|
||||
clip_action,
|
||||
get_base_struct_from_space,
|
||||
)
|
||||
from ray.rllib.utils.typing import ActionConnectorDataType
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ClipActionsConnector(ActionConnector):
|
||||
def __init__(self, ctx: ConnectorContext):
|
||||
super().__init__(ctx)
|
||||
|
||||
self._action_space_struct = get_base_struct_from_space(ctx.action_space)
|
||||
|
||||
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
||||
assert isinstance(
|
||||
ac_data.output, tuple
|
||||
), "Action connector requires PolicyOutputType data."
|
||||
|
||||
actions, states, fetches = ac_data.output
|
||||
return ActionConnectorDataType(
|
||||
ac_data.env_id,
|
||||
ac_data.agent_id,
|
||||
(clip_action(actions, self._action_space_struct), states, fetches),
|
||||
)
|
||||
|
||||
def to_config(self):
|
||||
return ClipActionsConnector.__name__, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
return ClipActionsConnector(ctx)
|
||||
|
||||
|
||||
register_connector(ClipActionsConnector.__name__, ClipActionsConnector)
|
79
rllib/connectors/action/lambdas.py
Normal file
79
rllib/connectors/action/lambdas.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
from typing import Any, Callable, Dict, List, Type
|
||||
|
||||
from ray.rllib.connectors.connector import (
|
||||
ConnectorContext,
|
||||
ActionConnector,
|
||||
register_connector,
|
||||
)
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
from ray.rllib.utils.spaces.space_utils import unbatch
|
||||
from ray.rllib.utils.typing import (
|
||||
ActionConnectorDataType,
|
||||
PolicyOutputType,
|
||||
StateBatches,
|
||||
TensorStructType,
|
||||
)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def register_lambda_action_connector(
|
||||
name: str, fn: Callable[[TensorStructType, StateBatches, Dict], PolicyOutputType]
|
||||
) -> Type[ActionConnector]:
|
||||
"""A util to register any function transforming PolicyOutputType as an ActionConnector.
|
||||
|
||||
The only requirement is that fn should take actions, states, and fetches as input,
|
||||
and return transformed actions, states, and fetches.
|
||||
|
||||
Args:
|
||||
name: Name of the resulting actor connector.
|
||||
fn: The function that transforms PolicyOutputType.
|
||||
|
||||
Returns:
|
||||
A new ActionConnector class that transforms PolicyOutputType using fn.
|
||||
"""
|
||||
|
||||
class LambdaActionConnector(ActionConnector):
|
||||
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
||||
assert isinstance(
|
||||
ac_data.output, tuple
|
||||
), "Action connector requires PolicyOutputType data."
|
||||
|
||||
actions, states, fetches = ac_data.output
|
||||
return ActionConnectorDataType(
|
||||
ac_data.env_id,
|
||||
ac_data.agent_id,
|
||||
fn(actions, states, fetches),
|
||||
)
|
||||
|
||||
def to_config(self):
|
||||
return name, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
return LambdaActionConnector(ctx)
|
||||
|
||||
LambdaActionConnector.__name__ = name
|
||||
LambdaActionConnector.__qualname__ = name
|
||||
|
||||
register_connector(name, LambdaActionConnector)
|
||||
|
||||
return LambdaActionConnector
|
||||
|
||||
|
||||
# Convert actions and states into numpy arrays if necessary.
|
||||
ConvertToNumpyConnector = register_lambda_action_connector(
|
||||
"ConvertToNumpyConnector",
|
||||
lambda actions, states, fetches: (
|
||||
convert_to_numpy(actions),
|
||||
convert_to_numpy(states),
|
||||
fetches,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Split action-component batches into single action rows.
|
||||
UnbatchActionsConnector = register_lambda_action_connector(
|
||||
"UnbatchActionsConnector",
|
||||
lambda actions, states, fetches: (unbatch(actions), states, fetches),
|
||||
)
|
43
rllib/connectors/action/normalize.py
Normal file
43
rllib/connectors/action/normalize.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from typing import Any, List
|
||||
|
||||
from ray.rllib.connectors.connector import (
|
||||
ConnectorContext,
|
||||
ActionConnector,
|
||||
register_connector,
|
||||
)
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.spaces.space_utils import (
|
||||
get_base_struct_from_space,
|
||||
unsquash_action,
|
||||
)
|
||||
from ray.rllib.utils.typing import ActionConnectorDataType
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class NormalizeActionsConnector(ActionConnector):
|
||||
def __init__(self, ctx: ConnectorContext):
|
||||
super().__init__(ctx)
|
||||
|
||||
self._action_space_struct = get_base_struct_from_space(ctx.action_space)
|
||||
|
||||
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
||||
assert isinstance(
|
||||
ac_data.output, tuple
|
||||
), "Action connector requires PolicyOutputType data."
|
||||
|
||||
actions, states, fetches = ac_data.output
|
||||
return ActionConnectorDataType(
|
||||
ac_data.env_id,
|
||||
ac_data.agent_id,
|
||||
(unsquash_action(actions, self._action_space_struct), states, fetches),
|
||||
)
|
||||
|
||||
def to_config(self):
|
||||
return NormalizeActionsConnector.__name__, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
return NormalizeActionsConnector(ctx)
|
||||
|
||||
|
||||
register_connector(NormalizeActionsConnector.__name__, NormalizeActionsConnector)
|
57
rllib/connectors/action/pipeline.py
Normal file
57
rllib/connectors/action/pipeline.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
import gym
|
||||
from typing import Any, List
|
||||
|
||||
from ray.rllib.connectors.connector import (
|
||||
ActionConnector,
|
||||
Connector,
|
||||
ConnectorContext,
|
||||
ConnectorPipeline,
|
||||
get_connector,
|
||||
register_connector,
|
||||
)
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.typing import (
|
||||
ActionConnectorDataType,
|
||||
TrainerConfigDict,
|
||||
)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ActionConnectorPipeline(ActionConnector, ConnectorPipeline):
|
||||
def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
|
||||
super().__init__(ctx)
|
||||
self.connectors = connectors
|
||||
|
||||
def is_training(self, is_training: bool):
|
||||
self.is_training = is_training
|
||||
for c in self.connectors:
|
||||
c.is_training(is_training)
|
||||
|
||||
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
||||
for c in self.connectors:
|
||||
ac_data = c(ac_data)
|
||||
return ac_data
|
||||
|
||||
def to_config(self):
|
||||
return ActionConnectorPipeline.__name__, [
|
||||
c.to_config() for c in self.connectors
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
assert (
|
||||
type(params) == list
|
||||
), "ActionConnectorPipeline takes a list of connector params."
|
||||
connectors = [get_connector(ctx, name, subparams) for name, subparams in params]
|
||||
return ActionConnectorPipeline(ctx, connectors)
|
||||
|
||||
|
||||
register_connector(ActionConnectorPipeline.__name__, ActionConnectorPipeline)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def get_action_connectors_from_trainer_config(
|
||||
config: TrainerConfigDict, action_space: gym.Space
|
||||
) -> ActionConnectorPipeline:
|
||||
connectors = []
|
||||
return ActionConnectorPipeline(connectors)
|
52
rllib/connectors/agent/clip_reward.py
Normal file
52
rllib/connectors/agent/clip_reward.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import numpy as np
|
||||
from typing import Any, List
|
||||
|
||||
from ray.rllib.connectors.connector import (
|
||||
ConnectorContext,
|
||||
AgentConnector,
|
||||
register_connector,
|
||||
)
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.typing import AgentConnectorDataType
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ClipRewardAgentConnector(AgentConnector):
|
||||
def __init__(self, ctx: ConnectorContext, sign=False, limit=None):
|
||||
super().__init__(ctx)
|
||||
assert (
|
||||
not sign or not limit
|
||||
), "should not enable both sign and limit reward clipping."
|
||||
self.sign = sign
|
||||
self.limit = limit
|
||||
|
||||
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||
d = ac_data.data
|
||||
assert (
|
||||
type(d) == dict
|
||||
), "Single agent data must be of type Dict[str, TensorStructType]"
|
||||
|
||||
assert SampleBatch.REWARDS in d, "input data does not have reward column."
|
||||
if self.sign:
|
||||
d[SampleBatch.REWARDS] = np.sign(d[SampleBatch.REWARDS])
|
||||
elif self.limit:
|
||||
d[SampleBatch.REWARDS] = np.clip(
|
||||
d[SampleBatch.REWARDS],
|
||||
a_min=-self.limit,
|
||||
a_max=self.limit,
|
||||
)
|
||||
return [ac_data]
|
||||
|
||||
def to_config(self):
|
||||
return ClipRewardAgentConnector.__name__, {
|
||||
"sign": self.sign,
|
||||
"limit": self.limit,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
return ClipRewardAgentConnector(ctx, **params)
|
||||
|
||||
|
||||
register_connector(ClipRewardAgentConnector.__name__, ClipRewardAgentConnector)
|
72
rllib/connectors/agent/env_to_agent.py
Normal file
72
rllib/connectors/agent/env_to_agent.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
from typing import Any, List
|
||||
|
||||
from ray.rllib.connectors.connector import (
|
||||
ConnectorContext,
|
||||
AgentConnector,
|
||||
register_connector,
|
||||
)
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.typing import AgentConnectorDataType
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class EnvToAgentDataConnector(AgentConnector):
|
||||
"""Converts per environment multi-agent obs into per agent SampleBatches."""
|
||||
|
||||
def __init__(self, ctx: ConnectorContext):
|
||||
super().__init__(ctx)
|
||||
self._view_requirements = ctx.view_requirements
|
||||
|
||||
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||
if ac_data.agent_id:
|
||||
# data is already for a single agent.
|
||||
return [ac_data]
|
||||
|
||||
assert isinstance(ac_data.data, (tuple, list)) and len(ac_data.data) == 5, (
|
||||
"EnvToPerAgentDataConnector expects a tuple of "
|
||||
+ "(obs, rewards, dones, infos, episode_infos)."
|
||||
)
|
||||
# episode_infos contains additional training related data bits
|
||||
# for each agent, such as SampleBatch.T, SampleBatch.AGENT_INDEX,
|
||||
# SampleBatch.ACTIONS, SampleBatch.DONES (if hitting horizon),
|
||||
# and is usually empty in inference mode.
|
||||
obs, rewards, dones, infos, training_episode_infos = ac_data.data
|
||||
for var, name in zip(
|
||||
(obs, rewards, dones, infos, training_episode_infos),
|
||||
("obs", "rewards", "dones", "infos", "training_episode_infos"),
|
||||
):
|
||||
assert isinstance(var, dict), (
|
||||
f"EnvToPerAgentDataConnector expects {name} "
|
||||
+ "to be a MultiAgentDict."
|
||||
)
|
||||
|
||||
env_id = ac_data.env_id
|
||||
per_agent_data = []
|
||||
for agent_id, obs in obs.items():
|
||||
input_dict = {
|
||||
SampleBatch.ENV_ID: env_id,
|
||||
SampleBatch.REWARDS: rewards[agent_id],
|
||||
# SampleBatch.DONES may be overridden by data from
|
||||
# training_episode_infos next.
|
||||
SampleBatch.DONES: dones[agent_id],
|
||||
SampleBatch.NEXT_OBS: obs,
|
||||
}
|
||||
if SampleBatch.INFOS in self._view_requirements:
|
||||
input_dict[SampleBatch.INFOS] = infos[agent_id]
|
||||
if agent_id in training_episode_infos:
|
||||
input_dict.update(training_episode_infos[agent_id])
|
||||
|
||||
per_agent_data.append(AgentConnectorDataType(env_id, agent_id, input_dict))
|
||||
|
||||
return per_agent_data
|
||||
|
||||
def to_config(self):
|
||||
return EnvToAgentDataConnector.__name__, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
return EnvToAgentDataConnector(ctx)
|
||||
|
||||
|
||||
register_connector(EnvToAgentDataConnector.__name__, EnvToAgentDataConnector)
|
81
rllib/connectors/agent/lambdas.py
Normal file
81
rllib/connectors/agent/lambdas.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
import numpy as np
|
||||
import tree # dm_tree
|
||||
from typing import Any, Callable, Dict, List, Type
|
||||
|
||||
from ray.rllib.connectors.connector import (
|
||||
ConnectorContext,
|
||||
AgentConnector,
|
||||
register_connector,
|
||||
)
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.typing import (
|
||||
AgentConnectorDataType,
|
||||
TensorStructType,
|
||||
)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def register_lambda_agent_connector(
|
||||
name: str, fn: Callable[[Any], Any]
|
||||
) -> Type[AgentConnector]:
|
||||
"""A util to register any simple transforming function as an AgentConnector
|
||||
|
||||
The only requirement is that fn should take a single data object and return
|
||||
a single data object.
|
||||
|
||||
Args:
|
||||
name: Name of the resulting actor connector.
|
||||
fn: The function that transforms env / agent data.
|
||||
|
||||
Returns:
|
||||
A new AgentConnector class that transforms data using fn.
|
||||
"""
|
||||
|
||||
class LambdaAgentConnector(AgentConnector):
|
||||
def __call__(
|
||||
self, ac_data: AgentConnectorDataType
|
||||
) -> List[AgentConnectorDataType]:
|
||||
d = ac_data.data
|
||||
return [AgentConnectorDataType(ac_data.env_id, ac_data.agent_id, fn(d))]
|
||||
|
||||
def to_config(self):
|
||||
return name, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
return LambdaAgentConnector(ctx)
|
||||
|
||||
LambdaAgentConnector.__name__ = name
|
||||
LambdaAgentConnector.__qualname__ = name
|
||||
|
||||
register_connector(name, LambdaAgentConnector)
|
||||
|
||||
return LambdaAgentConnector
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def flatten_data(data: Dict[str, TensorStructType]):
|
||||
assert (
|
||||
type(data) == dict
|
||||
), "Single agent data must be of type Dict[str, TensorStructType]"
|
||||
|
||||
flattened = {}
|
||||
for k, v in data.items():
|
||||
if k in [SampleBatch.INFOS, SampleBatch.ACTIONS] or k.startswith("state_out_"):
|
||||
# Do not flatten infos, actions, and state_out_ columns.
|
||||
flattened[k] = v
|
||||
continue
|
||||
if v is None:
|
||||
# Keep the same column shape.
|
||||
flattened[k] = None
|
||||
continue
|
||||
flattened[k] = np.array(tree.flatten(v))
|
||||
|
||||
return flattened
|
||||
|
||||
|
||||
# Flatten observation data.
|
||||
FlattenDataAgentConnector = register_lambda_agent_connector(
|
||||
"FlattenDataAgentConnector", flatten_data
|
||||
)
|
60
rllib/connectors/agent/obs_preproc.py
Normal file
60
rllib/connectors/agent/obs_preproc.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
from typing import Any, List
|
||||
|
||||
from ray.rllib.connectors.connector import (
|
||||
ConnectorContext,
|
||||
AgentConnector,
|
||||
register_connector,
|
||||
)
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.typing import AgentConnectorDataType
|
||||
|
||||
|
||||
# Bridging between current obs preprocessors and connector.
|
||||
# We should not introduce any new preprocessors.
|
||||
# TODO(jungong) : migrate and implement preprocessor library in Connector framework.
|
||||
@DeveloperAPI
|
||||
class ObsPreprocessorConnector(AgentConnector):
|
||||
"""A connector that wraps around existing RLlib observation preprocessors.
|
||||
|
||||
This includes:
|
||||
- OneHotPreprocessor for Discrete and Multi-Discrete spaces.
|
||||
- GenericPixelPreprocessor and AtariRamPreprocessor for Atari spaces.
|
||||
- TupleFlatteningPreprocessor and DictFlatteningPreprocessor for flattening
|
||||
arbitrary nested input observations.
|
||||
- RepeatedValuesPreprocessor for padding observations from RLlib Repeated
|
||||
observation space.
|
||||
"""
|
||||
|
||||
def __init__(self, ctx: ConnectorContext):
|
||||
super().__init__(ctx)
|
||||
|
||||
self._preprocessor = get_preprocessor(ctx.observation_space)(
|
||||
ctx.observation_space, ctx.config.get("model", {})
|
||||
)
|
||||
|
||||
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||
d = ac_data.data
|
||||
assert (
|
||||
type(d) == dict
|
||||
), "Single agent data must be of type Dict[str, TensorStructType]"
|
||||
|
||||
if SampleBatch.OBS in d:
|
||||
d[SampleBatch.OBS] = self._preprocessor.transform(d[SampleBatch.OBS])
|
||||
if SampleBatch.NEXT_OBS in d:
|
||||
d[SampleBatch.NEXT_OBS] = self._preprocessor.transform(
|
||||
d[SampleBatch.NEXT_OBS]
|
||||
)
|
||||
|
||||
return [ac_data]
|
||||
|
||||
def to_config(self):
|
||||
return ObsPreprocessorConnector.__name__, {}
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
return ObsPreprocessorConnector(ctx, **params)
|
||||
|
||||
|
||||
register_connector(ObsPreprocessorConnector.__name__, ObsPreprocessorConnector)
|
79
rllib/connectors/agent/pipeline.py
Normal file
79
rllib/connectors/agent/pipeline.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
import gym
|
||||
from typing import Any, List
|
||||
|
||||
from ray.rllib.connectors.connector import (
|
||||
Connector,
|
||||
ConnectorContext,
|
||||
ConnectorPipeline,
|
||||
AgentConnector,
|
||||
register_connector,
|
||||
get_connector,
|
||||
)
|
||||
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
|
||||
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.typing import (
|
||||
ActionConnectorDataType,
|
||||
AgentConnectorDataType,
|
||||
TrainerConfigDict,
|
||||
)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class AgentConnectorPipeline(AgentConnector, ConnectorPipeline):
|
||||
def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
|
||||
super().__init__(ctx)
|
||||
self.connectors = connectors
|
||||
|
||||
def is_training(self, is_training: bool):
|
||||
self.is_training = is_training
|
||||
for c in self.connectors:
|
||||
c.is_training(is_training)
|
||||
|
||||
def reset(self, env_id: str):
|
||||
for c in self.connectors:
|
||||
c.reset(env_id)
|
||||
|
||||
def on_policy_output(self, output: ActionConnectorDataType):
|
||||
for c in self.connectors:
|
||||
c.on_policy_output(output)
|
||||
|
||||
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||
ret = [ac_data]
|
||||
for c in self.connectors:
|
||||
# Run the list of input data through the next agent connect,
|
||||
# and collect the list of output data.
|
||||
new_ret = []
|
||||
for d in ret:
|
||||
new_ret += c(d)
|
||||
ret = new_ret
|
||||
return ret
|
||||
|
||||
def to_config(self):
|
||||
return AgentConnectorPipeline.__name__, [c.to_config() for c in self.connectors]
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
assert (
|
||||
type(params) == list
|
||||
), "AgentConnectorPipeline takes a list of connector params."
|
||||
connectors = [get_connector(ctx, name, subparams) for name, subparams in params]
|
||||
return AgentConnectorPipeline(ctx, connectors)
|
||||
|
||||
|
||||
register_connector(AgentConnectorPipeline.__name__, AgentConnectorPipeline)
|
||||
|
||||
|
||||
# TODO(jungong) : finish this.
|
||||
@DeveloperAPI
|
||||
def get_agent_connectors_from_config(
|
||||
config: TrainerConfigDict, obs_space: gym.Space
|
||||
) -> AgentConnectorPipeline:
|
||||
connectors = [FlattenDataAgentConnector()]
|
||||
|
||||
if config["clip_rewards"] is True:
|
||||
connectors.append(ClipRewardAgentConnector(sign=True))
|
||||
elif type(config["clip_rewards"]) == float:
|
||||
connectors.append(ClipRewardAgentConnector(limit=abs(config["clip_rewards"])))
|
||||
|
||||
return AgentConnectorPipeline(connectors)
|
99
rllib/connectors/agent/state_buffer.py
Normal file
99
rllib/connectors/agent/state_buffer.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
from collections import defaultdict
|
||||
import numpy as np
|
||||
import tree # dm_tree
|
||||
from typing import Any, List
|
||||
|
||||
from ray.rllib.connectors.connector import (
|
||||
ConnectorContext,
|
||||
AgentConnector,
|
||||
register_connector,
|
||||
)
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.numpy import convert_to_numpy
|
||||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
||||
from ray.rllib.utils.typing import (
|
||||
AgentConnectorDataType,
|
||||
PolicyOutputType,
|
||||
)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class _AgentState(object):
|
||||
def __init__(self):
|
||||
self.t = 0
|
||||
self.action = None
|
||||
self.states = None
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class StateBufferConnector(AgentConnector):
|
||||
def __init__(self, ctx: ConnectorContext):
|
||||
super().__init__(ctx)
|
||||
|
||||
self._initial_states = ctx.initial_states
|
||||
self._action_space_struct = get_base_struct_from_space(ctx.action_space)
|
||||
self._states = defaultdict(lambda: defaultdict(_AgentState))
|
||||
|
||||
def reset(self, env_id: str):
|
||||
del self._states[env_id]
|
||||
|
||||
def on_policy_output(self, env_id: str, agent_id: str, output: PolicyOutputType):
|
||||
# Buffer latest output states for next input __call__.
|
||||
action, states, _ = output
|
||||
agent_state = self._states[env_id][agent_id]
|
||||
agent_state.action = convert_to_numpy(action)
|
||||
agent_state.states = convert_to_numpy(states)
|
||||
|
||||
def __call__(
|
||||
self, ctx: ConnectorContext, ac_data: AgentConnectorDataType
|
||||
) -> List[AgentConnectorDataType]:
|
||||
d = ac_data.data
|
||||
assert (
|
||||
type(d) == dict
|
||||
), "Single agent data must be of type Dict[str, TensorStructType]"
|
||||
|
||||
env_id = ac_data.env_id
|
||||
agent_id = ac_data.agent_id
|
||||
assert env_id and agent_id, "StateBufferConnector requires env_id and agent_id"
|
||||
|
||||
agent_state = self._states[env_id][agent_id]
|
||||
|
||||
d.update(
|
||||
{
|
||||
SampleBatch.T: agent_state.t,
|
||||
SampleBatch.ENV_ID: env_id,
|
||||
}
|
||||
)
|
||||
|
||||
if agent_state.states is not None:
|
||||
states = agent_state.states
|
||||
else:
|
||||
states = self._initial_states
|
||||
for i, v in enumerate(states):
|
||||
d["state_out_{}".format(i)] = v
|
||||
|
||||
if agent_state.action is not None:
|
||||
d[SampleBatch.ACTIONS] = agent_state.action # Last action
|
||||
else:
|
||||
# Default zero action.
|
||||
d[SampleBatch.ACTIONS] = tree.map_structure(
|
||||
lambda s: np.zeros_like(s.sample(), s.dtype)
|
||||
if hasattr(s, "dtype")
|
||||
else np.zeros_like(s.sample()),
|
||||
self._action_space_struct,
|
||||
)
|
||||
|
||||
agent_state.t += 1
|
||||
|
||||
return [ac_data]
|
||||
|
||||
def to_config(self):
|
||||
return StateBufferConnector.__name__, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
return StateBufferConnector(ctx)
|
||||
|
||||
|
||||
register_connector(StateBufferConnector.__name__, StateBufferConnector)
|
136
rllib/connectors/agent/view_requirement.py
Normal file
136
rllib/connectors/agent/view_requirement.py
Normal file
|
@ -0,0 +1,136 @@
|
|||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from typing import Any, List
|
||||
|
||||
from ray.rllib.connectors.connector import (
|
||||
ConnectorContext,
|
||||
AgentConnector,
|
||||
register_connector,
|
||||
)
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.typing import (
|
||||
AgentConnectorDataType,
|
||||
AgentConnectorsOutput,
|
||||
)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ViewRequirementAgentConnector(AgentConnector):
|
||||
"""This connector does 2 things:
|
||||
1. It filters data columns based on view_requirements for training and inference.
|
||||
2. It buffers the right amount of history for computing the sample batch for
|
||||
action computation.
|
||||
The output of this connector is AgentConnectorsOut, which basically is
|
||||
a tuple of 2 things:
|
||||
{
|
||||
"for_training": {"obs": ...}
|
||||
"for_action": SampleBatch
|
||||
}
|
||||
The "for_training" dict, which contains data for the latest time slice,
|
||||
can be used to construct a complete episode by Sampler for training purpose.
|
||||
The "for_action" SampleBatch can be used to directly call the policy.
|
||||
"""
|
||||
|
||||
def __init__(self, ctx: ConnectorContext):
|
||||
super().__init__(ctx)
|
||||
|
||||
self._view_requirements = ctx.view_requirements
|
||||
self._agent_data = defaultdict(lambda: defaultdict(SampleBatch))
|
||||
|
||||
def reset(self, env_id: str):
|
||||
if env_id in self._agent_data:
|
||||
del self._agent_data[env_id]
|
||||
|
||||
def _get_sample_batch_for_action(
|
||||
self, view_requirements, agent_batch
|
||||
) -> SampleBatch:
|
||||
# TODO(jungong) : actually support buildling input sample batch with all the
|
||||
# view shift requirements, etc.
|
||||
# For now, we use some simple logics for demo purpose.
|
||||
input_batch = SampleBatch()
|
||||
for k, v in view_requirements.items():
|
||||
if not v.used_for_compute_actions:
|
||||
continue
|
||||
data_col = v.data_col or k
|
||||
if data_col not in agent_batch:
|
||||
continue
|
||||
input_batch[k] = agent_batch[data_col][-1:]
|
||||
input_batch.count = 1
|
||||
return input_batch
|
||||
|
||||
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||
d = ac_data.data
|
||||
assert (
|
||||
type(d) == dict
|
||||
), "Single agent data must be of type Dict[str, TensorStructType]"
|
||||
|
||||
env_id = ac_data.env_id
|
||||
agent_id = ac_data.agent_id
|
||||
assert env_id and agent_id, "StateBufferConnector requires env_id and agent_id"
|
||||
|
||||
vr = self._view_requirements
|
||||
assert vr, "ViewRequirements required by ViewRequirementConnector"
|
||||
|
||||
training_dict = {}
|
||||
# We construct a proper per-timeslice dict in training mode,
|
||||
# for Sampler to construct a complete episode for back propagation.
|
||||
if self.is_training:
|
||||
# Filter columns that are not needed for traing.
|
||||
for col, req in vr.items():
|
||||
# Not used for training.
|
||||
if not req.used_for_training:
|
||||
continue
|
||||
|
||||
# Create the batch of data from the different buffers.
|
||||
data_col = req.data_col or col
|
||||
if data_col not in d:
|
||||
continue
|
||||
|
||||
training_dict[data_col] = d[data_col]
|
||||
|
||||
# Agent batch is our buffer of necessary history for computing
|
||||
# a SampleBatch for policy forward pass.
|
||||
# This is used by both training and inference.
|
||||
agent_batch = self._agent_data[env_id][agent_id]
|
||||
for col, req in vr.items():
|
||||
# Not used for action computation.
|
||||
if not req.used_for_compute_actions:
|
||||
continue
|
||||
|
||||
# Create the batch of data from the different buffers.
|
||||
data_col = req.data_col or col
|
||||
if data_col not in d:
|
||||
continue
|
||||
|
||||
# Add batch dim to this data_col.
|
||||
d_col = np.expand_dims(d[data_col], axis=0)
|
||||
|
||||
if col in agent_batch:
|
||||
# Stack along batch dim.
|
||||
agent_batch[data_col] = np.vstack((agent_batch[data_col], d_col))
|
||||
else:
|
||||
agent_batch[data_col] = d_col
|
||||
# Only keep the useful part of the history.
|
||||
h = req.shift_from if req.shift_from else -1
|
||||
assert h <= 0, "Can use future data to compute action"
|
||||
agent_batch[data_col] = agent_batch[data_col][h:]
|
||||
|
||||
sample_batch = self._get_sample_batch_for_action(vr, agent_batch)
|
||||
|
||||
return_data = AgentConnectorDataType(
|
||||
env_id, agent_id, AgentConnectorsOutput(training_dict, sample_batch)
|
||||
)
|
||||
return return_data
|
||||
|
||||
def to_config(self):
|
||||
return ViewRequirementAgentConnector.__name__, None
|
||||
|
||||
@staticmethod
|
||||
def from_config(ctx: ConnectorContext, params: List[Any]):
|
||||
return ViewRequirementAgentConnector(ctx)
|
||||
|
||||
|
||||
register_connector(
|
||||
ViewRequirementAgentConnector.__name__, ViewRequirementAgentConnector
|
||||
)
|
366
rllib/connectors/connector.py
Normal file
366
rllib/connectors/connector.py
Normal file
|
@ -0,0 +1,366 @@
|
|||
"""This file defines base types and common structures for RLlib connectors.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import gym
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from ray.tune.registry import RLLIB_CONNECTOR, _global_registry
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.typing import (
|
||||
ActionConnectorDataType,
|
||||
AgentConnectorDataType,
|
||||
TensorType,
|
||||
TrainerConfigDict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ConnectorContext:
|
||||
"""Data bits that may be needed for running connectors.
|
||||
|
||||
Note(jungong) : we need to be really careful with the data fields here.
|
||||
E.g., everything needs to be serializable, in case we need to fetch them
|
||||
in a remote setting.
|
||||
"""
|
||||
|
||||
# TODO(jungong) : figure out how to fetch these in a remote setting.
|
||||
# Probably from a policy server when initializing a policy client.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TrainerConfigDict = None,
|
||||
model_initial_states: List[TensorType] = None,
|
||||
observation_space: gym.Space = None,
|
||||
action_space: gym.Space = None,
|
||||
view_requirements: Dict[str, ViewRequirement] = None,
|
||||
):
|
||||
"""Construct a ConnectorContext instance.
|
||||
|
||||
Args:
|
||||
model_initial_states: States that are used for constructing
|
||||
the initial input dict for RNN models. [] if a model is not recurrent.
|
||||
action_space_struct: a policy's action space, in python
|
||||
data format. E.g., python dict instead of DictSpace, python tuple
|
||||
instead of TupleSpace.
|
||||
"""
|
||||
self.config = config
|
||||
self.initial_states = model_initial_states or []
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.view_requirements = view_requirements
|
||||
|
||||
@staticmethod
|
||||
def from_policy(policy: Policy) -> "ConnectorContext":
|
||||
"""Build ConnectorContext from a given policy.
|
||||
|
||||
Args:
|
||||
policy: Policy
|
||||
|
||||
Returns:
|
||||
A ConnectorContext instance.
|
||||
"""
|
||||
return ConnectorContext(
|
||||
policy.config,
|
||||
policy.get_initial_state(),
|
||||
policy.observation_space,
|
||||
policy.action_space,
|
||||
policy.view_requirements,
|
||||
)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class Connector(abc.ABC):
|
||||
"""Connector base class.
|
||||
|
||||
A connector is a step of transformation, of either envrionment data before they
|
||||
get to a policy, or policy output before it is sent back to the environment.
|
||||
|
||||
Connectors may be training-aware, for example, behave slightly differently
|
||||
during training and inference.
|
||||
|
||||
All connectors are required to be serializable and implement to_config().
|
||||
"""
|
||||
|
||||
def __init__(self, ctx: ConnectorContext):
|
||||
# This gets flipped to False for inference.
|
||||
self.is_training = True
|
||||
|
||||
def is_training(self, is_training: bool):
|
||||
self.is_training = is_training
|
||||
|
||||
def to_config(self) -> Tuple[str, List[Any]]:
|
||||
"""Serialize a connector into a JSON serializable Tuple.
|
||||
|
||||
to_config is required, so that all Connectors are serializable.
|
||||
|
||||
Returns:
|
||||
A tuple of connector's name and its serialized states.
|
||||
"""
|
||||
# Must implement by each connector.
|
||||
return NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def from_config(self, ctx: ConnectorContext, params: List[Any]) -> "Connector":
|
||||
"""De-serialize a JSON params back into a Connector.
|
||||
|
||||
from_config is required, so that all Connectors are serializable.
|
||||
|
||||
Args:
|
||||
ctx: Context for constructing this connector.
|
||||
params: Serialized states of the connector to be recovered.
|
||||
|
||||
Returns:
|
||||
De-serialized connector.
|
||||
"""
|
||||
# Must implement by each connector.
|
||||
return NotImplementedError
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class AgentConnector(Connector):
|
||||
"""Connector connecting user environments to RLlib policies.
|
||||
|
||||
An agent connector transforms a single piece of data in AgentConnectorDataType
|
||||
format into a list of data in the same AgentConnectorDataTypes format.
|
||||
The API is designed so multi-agent observations can be broken and emitted as
|
||||
multiple single agent observations.
|
||||
|
||||
AgentConnectorDataTypes can be used to specify arbitrary type of env data,
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
# A dict of multi-agent data from one env step() call.
|
||||
ac = AgentConnectorDataType(
|
||||
env_id="env_1",
|
||||
agent_id=None,
|
||||
data={
|
||||
"agent_1": np.array(...),
|
||||
"agent_2": np.array(...),
|
||||
}
|
||||
)
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
# Single agent data ready to be preprocessed.
|
||||
ac = AgentConnectorDataType(
|
||||
env_id="env_1",
|
||||
agent_id="agent_1",
|
||||
data=np.array(...)
|
||||
)
|
||||
|
||||
We can adapt a simple stateless function into an agent connector by using
|
||||
register_lambda_agent_connector:
|
||||
.. code-block:: python
|
||||
TimesTwoAgentConnector = register_lambda_agent_connector(
|
||||
"TimesTwoAgentConnector", lambda data: data * 2
|
||||
)
|
||||
|
||||
More complicated agent connectors can be implemented by extending this
|
||||
AgentConnector class:
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
class FrameSkippingAgentConnector(AgentConnector):
|
||||
def __init__(self, n):
|
||||
self._n = n
|
||||
self._frame_count = default_dict(str, default_dict(str, int))
|
||||
|
||||
def reset(self, env_id: str):
|
||||
del self._frame_count[env_id]
|
||||
|
||||
def __call__(
|
||||
self, ac_data: AgentConnectorDataType
|
||||
) -> List[AgentConnectorDataType]:
|
||||
assert ac_data.env_id and ac_data.agent_id, (
|
||||
"Frame skipping works per agent")
|
||||
|
||||
count = self._frame_count[ac_data.env_id][ac_data.agent_id]
|
||||
self._frame_count[ac_data.env_id][ac_data.agent_id] = count + 1
|
||||
|
||||
return [ac_data] if count % self._n == 0 else []
|
||||
|
||||
As shown, an agent connector may choose to emit an empty list to stop input
|
||||
observations from being prosessed further.
|
||||
"""
|
||||
|
||||
def reset(self, env_id: str):
|
||||
"""Reset connector state for a specific environment.
|
||||
|
||||
For example, at the end of an episode.
|
||||
|
||||
Args:
|
||||
env_id: required. ID of a user environment. Required.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_policy_output(self, output: ActionConnectorDataType):
|
||||
"""Callback on agent connector of policy output.
|
||||
|
||||
This is useful for certain connectors, for example RNN state buffering,
|
||||
where the agent connect needs to be aware of the output of a policy
|
||||
forward pass.
|
||||
|
||||
Args:
|
||||
ctx: Context for running this connector call.
|
||||
output: Env and agent IDs, plus data output from policy forward pass.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __call__(self, ac_data: AgentConnectorDataType) -> List[AgentConnectorDataType]:
|
||||
"""Transform incoming data from environment before they reach policy.
|
||||
|
||||
Args:
|
||||
ctx: Context for running this connector call.
|
||||
data: Env and agent IDs, plus arbitrary data from an environment or
|
||||
upstream agent connectors.
|
||||
|
||||
Returns:
|
||||
A list of transformed data in AgentConnectorDataType format.
|
||||
The return type is a list because an AgentConnector may choose to
|
||||
derive multiple outputs for a single input data, for example
|
||||
multi-agent obs -> multiple single agent obs.
|
||||
Agent connectors may also return an empty list for certain input,
|
||||
useful for connectors such as frame skipping.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ActionConnector(Connector):
|
||||
"""Action connector connects policy outputs including actions,
|
||||
to user environments.
|
||||
|
||||
An action connector transforms a single piece of policy output in
|
||||
ActionConnectorDataType format, which is basically PolicyOutputType
|
||||
plus env and agent IDs.
|
||||
|
||||
Any functions that operates directly on PolicyOutputType can be
|
||||
easily adpated into an ActionConnector by using register_lambda_action_connector.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
ZeroActionConnector = register_lambda_action_connector(
|
||||
"ZeroActionsConnector",
|
||||
lambda actions, states, fetches: (
|
||||
np.zeros_like(actions), states, fetches
|
||||
)
|
||||
)
|
||||
|
||||
More complicated action connectors can also be implemented by sub-classing
|
||||
this ActionConnector class.
|
||||
"""
|
||||
|
||||
def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
|
||||
"""Transform policy output before they are sent to a user environment.
|
||||
|
||||
Args:
|
||||
ctx: Context for running this connector call.
|
||||
ac_data: Env and agent IDs, plus policy output.
|
||||
|
||||
Returns:
|
||||
The processed action connector data.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ConnectorPipeline:
|
||||
"""Utility class for quick manipulation of a connector pipeline."""
|
||||
|
||||
def remove(self, name: str):
|
||||
"""Remove a connector by <name>
|
||||
|
||||
Args:
|
||||
name: name of the connector to be removed.
|
||||
"""
|
||||
idx = -1
|
||||
for idx, c in enumerate(self.connectors):
|
||||
if c.__class__.__name__ == name:
|
||||
break
|
||||
if idx < 0:
|
||||
raise ValueError(f"Can not find connector {name}")
|
||||
del self.connectors[idx]
|
||||
|
||||
def insert_before(self, name: str, connector: Connector):
|
||||
"""Insert a new connector before connector <name>
|
||||
|
||||
Args:
|
||||
name: name of the connector before which a new connector
|
||||
will get inserted.
|
||||
connector: a new connector to be inserted.
|
||||
"""
|
||||
idx = -1
|
||||
for idx, c in enumerate(self.connectors):
|
||||
if c.__class__.__name__ == name:
|
||||
break
|
||||
if idx < 0:
|
||||
raise ValueError(f"Can not find connector {name}")
|
||||
self.connectors.insert(idx, connector)
|
||||
|
||||
def insert_after(self, name: str, connector: Connector):
|
||||
"""Insert a new connector after connector <name>
|
||||
|
||||
Args:
|
||||
name: name of the connector after which a new connector
|
||||
will get inserted.
|
||||
connector: a new connector to be inserted.
|
||||
"""
|
||||
idx = -1
|
||||
for idx, c in enumerate(self.connectors):
|
||||
if c.__class__.__name__ == name:
|
||||
break
|
||||
if idx < 0:
|
||||
raise ValueError(f"Can not find connector {name}")
|
||||
self.connectors.insert(idx + 1, connector)
|
||||
|
||||
def prepend(self, connector: Connector):
|
||||
"""Append a new connector at the beginning of a connector pipeline.
|
||||
|
||||
Args:
|
||||
connector: a new connector to be appended.
|
||||
"""
|
||||
self.connectors.insert(0, connector)
|
||||
|
||||
def append(self, connector: Connector):
|
||||
"""Append a new connector at the end of a connector pipeline.
|
||||
|
||||
Args:
|
||||
connector: a new connector to be appended.
|
||||
"""
|
||||
self.connectors.append(connector)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def register_connector(name: str, cls: Connector):
|
||||
"""Register a connector for use with RLlib.
|
||||
|
||||
Args:
|
||||
name: Name to register.
|
||||
cls: Callable that creates an env.
|
||||
"""
|
||||
if not issubclass(cls, Connector):
|
||||
raise TypeError("Can only register Connector type.", cls)
|
||||
_global_registry.register(RLLIB_CONNECTOR, name, cls)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def get_connector(ctx: ConnectorContext, name: str, params: Tuple[Any]) -> Connector:
|
||||
"""Get a connector by its name and serialized config.
|
||||
|
||||
Args:
|
||||
name: name of the connector.
|
||||
params: serialized parameters of the connector.
|
||||
|
||||
Returns:
|
||||
Constructed connector.
|
||||
"""
|
||||
if not _global_registry.contains(RLLIB_CONNECTOR, name):
|
||||
raise NameError("connector not found.", name)
|
||||
cls = _global_registry.get(RLLIB_CONNECTOR, name)
|
||||
return cls.from_config(ctx, params)
|
127
rllib/connectors/tests/test_action.py
Normal file
127
rllib/connectors/tests/test_action.py
Normal file
|
@ -0,0 +1,127 @@
|
|||
import gym
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from ray.rllib.connectors.action.clip import ClipActionsConnector
|
||||
from ray.rllib.connectors.action.lambdas import (
|
||||
ConvertToNumpyConnector,
|
||||
UnbatchActionsConnector,
|
||||
)
|
||||
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
|
||||
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
|
||||
from ray.rllib.connectors.connector import (
|
||||
ConnectorContext,
|
||||
get_connector,
|
||||
)
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.typing import ActionConnectorDataType
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
|
||||
class TestActionConnector(unittest.TestCase):
|
||||
def test_connector_pipeline(self):
|
||||
ctx = ConnectorContext()
|
||||
connectors = [ConvertToNumpyConnector(ctx)]
|
||||
pipeline = ActionConnectorPipeline(ctx, connectors)
|
||||
name, params = pipeline.to_config()
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, ActionConnectorPipeline))
|
||||
self.assertTrue(isinstance(restored.connectors[0], ConvertToNumpyConnector))
|
||||
|
||||
def test_convert_to_numpy_connector(self):
|
||||
ctx = ConnectorContext()
|
||||
c = ConvertToNumpyConnector(ctx)
|
||||
|
||||
name, params = c.to_config()
|
||||
|
||||
self.assertEqual(name, "ConvertToNumpyConnector")
|
||||
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, ConvertToNumpyConnector))
|
||||
|
||||
action = torch.Tensor([8, 9])
|
||||
states = torch.Tensor([[1, 1, 1], [2, 2, 2]])
|
||||
ac_data = ActionConnectorDataType(0, 1, (action, states, {}))
|
||||
|
||||
converted = c(ac_data)
|
||||
self.assertTrue(isinstance(converted.output[0], np.ndarray))
|
||||
self.assertTrue(isinstance(converted.output[1], np.ndarray))
|
||||
|
||||
def test_unbatch_action_connector(self):
|
||||
ctx = ConnectorContext()
|
||||
c = UnbatchActionsConnector(ctx)
|
||||
|
||||
name, params = c.to_config()
|
||||
|
||||
self.assertEqual(name, "UnbatchActionsConnector")
|
||||
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, UnbatchActionsConnector))
|
||||
|
||||
ac_data = ActionConnectorDataType(
|
||||
0,
|
||||
1,
|
||||
(
|
||||
{
|
||||
"a": np.array([1, 2, 3]),
|
||||
"b": (np.array([4, 5, 6]), np.array([7, 8, 9])),
|
||||
},
|
||||
[],
|
||||
{},
|
||||
),
|
||||
)
|
||||
|
||||
unbatched = c(ac_data)
|
||||
actions, _, _ = unbatched.output
|
||||
|
||||
self.assertEqual(len(actions), 3)
|
||||
self.assertEqual(actions[0]["a"], 1)
|
||||
self.assertTrue((actions[0]["b"] == np.array((4, 7))).all())
|
||||
|
||||
self.assertEqual(actions[1]["a"], 2)
|
||||
self.assertTrue((actions[1]["b"] == np.array((5, 8))).all())
|
||||
|
||||
self.assertEqual(actions[2]["a"], 3)
|
||||
self.assertTrue((actions[2]["b"] == np.array((6, 9))).all())
|
||||
|
||||
def test_normalize_action_connector(self):
|
||||
ctx = ConnectorContext(
|
||||
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
|
||||
)
|
||||
c = NormalizeActionsConnector(ctx)
|
||||
|
||||
name, params = c.to_config()
|
||||
self.assertEqual(name, "NormalizeActionsConnector")
|
||||
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, NormalizeActionsConnector))
|
||||
|
||||
ac_data = ActionConnectorDataType(0, 1, (0.5, [], {}))
|
||||
|
||||
normalized = c(ac_data)
|
||||
self.assertEqual(normalized.output[0], 4.5)
|
||||
|
||||
def test_clip_action_connector(self):
|
||||
ctx = ConnectorContext(
|
||||
action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
|
||||
)
|
||||
c = ClipActionsConnector(ctx)
|
||||
|
||||
name, params = c.to_config()
|
||||
self.assertEqual(name, "ClipActionsConnector")
|
||||
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, ClipActionsConnector))
|
||||
|
||||
ac_data = ActionConnectorDataType(0, 1, (8.8, [], {}))
|
||||
|
||||
clipped = c(ac_data)
|
||||
self.assertEqual(clipped.output[0], 6.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
186
rllib/connectors/tests/test_agent.py
Normal file
186
rllib/connectors/tests/test_agent.py
Normal file
|
@ -0,0 +1,186 @@
|
|||
import gym
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
|
||||
from ray.rllib.connectors.agent.env_to_agent import EnvToAgentDataConnector
|
||||
from ray.rllib.connectors.agent.lambdas import FlattenDataAgentConnector
|
||||
from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
|
||||
from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
|
||||
from ray.rllib.connectors.connector import (
|
||||
ConnectorContext,
|
||||
get_connector,
|
||||
)
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.typing import (
|
||||
AgentConnectorDataType,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentConnector(unittest.TestCase):
|
||||
def test_connector_pipeline(self):
|
||||
ctx = ConnectorContext()
|
||||
connectors = [ClipRewardAgentConnector(ctx, False, 1.0)]
|
||||
pipeline = AgentConnectorPipeline(ctx, connectors)
|
||||
name, params = pipeline.to_config()
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, AgentConnectorPipeline))
|
||||
self.assertTrue(isinstance(restored.connectors[0], ClipRewardAgentConnector))
|
||||
|
||||
def test_env_to_per_agent_data_connector(self):
|
||||
vrs = {
|
||||
"infos": ViewRequirement(
|
||||
"infos",
|
||||
used_for_training=True,
|
||||
used_for_compute_actions=False,
|
||||
)
|
||||
}
|
||||
ctx = ConnectorContext(view_requirements=vrs)
|
||||
|
||||
c = EnvToAgentDataConnector(ctx)
|
||||
|
||||
name, params = c.to_config()
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, EnvToAgentDataConnector))
|
||||
|
||||
d = AgentConnectorDataType(
|
||||
0,
|
||||
None,
|
||||
[
|
||||
# obs
|
||||
{1: [8, 8], 2: [9, 9]},
|
||||
# rewards
|
||||
{
|
||||
1: 8.8,
|
||||
2: 9.9,
|
||||
},
|
||||
# dones
|
||||
{
|
||||
1: False,
|
||||
2: False,
|
||||
},
|
||||
# infos
|
||||
{
|
||||
1: {"random": "info"},
|
||||
2: {},
|
||||
},
|
||||
# training_episode_info
|
||||
{
|
||||
1: {SampleBatch.DONES: True},
|
||||
},
|
||||
],
|
||||
)
|
||||
per_agent = c(d)
|
||||
|
||||
self.assertEqual(len(per_agent), 2)
|
||||
|
||||
batch1 = per_agent[0].data
|
||||
self.assertEqual(batch1[SampleBatch.NEXT_OBS], [8, 8])
|
||||
self.assertTrue(batch1[SampleBatch.DONES]) # from training_episode_info
|
||||
self.assertTrue(SampleBatch.INFOS in batch1)
|
||||
self.assertEqual(batch1[SampleBatch.INFOS]["random"], "info")
|
||||
|
||||
batch2 = per_agent[1].data
|
||||
self.assertEqual(batch2[SampleBatch.NEXT_OBS], [9, 9])
|
||||
self.assertFalse(batch2[SampleBatch.DONES])
|
||||
|
||||
def test_obs_preprocessor_connector(self):
|
||||
obs_space = gym.spaces.Dict(
|
||||
{
|
||||
"a": gym.spaces.Box(low=0, high=1, shape=(1,)),
|
||||
"b": gym.spaces.Tuple(
|
||||
[gym.spaces.Discrete(2), gym.spaces.MultiDiscrete(nvec=[2, 3])]
|
||||
),
|
||||
}
|
||||
)
|
||||
ctx = ConnectorContext(config={}, observation_space=obs_space)
|
||||
|
||||
c = ObsPreprocessorConnector(ctx)
|
||||
name, params = c.to_config()
|
||||
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, ObsPreprocessorConnector))
|
||||
|
||||
obs = obs_space.sample()
|
||||
# Fake deterministic data.
|
||||
obs["a"][0] = 0.5
|
||||
obs["b"] = (1, np.array([0, 2]))
|
||||
|
||||
d = AgentConnectorDataType(
|
||||
0,
|
||||
1,
|
||||
{
|
||||
SampleBatch.OBS: obs,
|
||||
},
|
||||
)
|
||||
preprocessed = c(d)
|
||||
|
||||
# obs is completely flattened.
|
||||
self.assertTrue(
|
||||
(preprocessed[0].data[SampleBatch.OBS] == [0.5, 0, 1, 1, 0, 0, 0, 1]).all()
|
||||
)
|
||||
|
||||
def test_clip_reward_connector(self):
|
||||
ctx = ConnectorContext()
|
||||
|
||||
c = ClipRewardAgentConnector(ctx, limit=2.0)
|
||||
name, params = c.to_config()
|
||||
|
||||
self.assertEqual(name, "ClipRewardAgentConnector")
|
||||
self.assertAlmostEqual(params["limit"], 2.0)
|
||||
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, ClipRewardAgentConnector))
|
||||
|
||||
d = AgentConnectorDataType(
|
||||
0,
|
||||
1,
|
||||
{
|
||||
SampleBatch.REWARDS: 5.8,
|
||||
},
|
||||
)
|
||||
clipped = restored(ac_data=d)
|
||||
|
||||
self.assertEqual(len(clipped), 1)
|
||||
self.assertEqual(clipped[0].data[SampleBatch.REWARDS], 2.0)
|
||||
|
||||
def test_flatten_data_connector(self):
|
||||
ctx = ConnectorContext()
|
||||
|
||||
c = FlattenDataAgentConnector(ctx)
|
||||
|
||||
name, params = c.to_config()
|
||||
restored = get_connector(ctx, name, params)
|
||||
self.assertTrue(isinstance(restored, FlattenDataAgentConnector))
|
||||
|
||||
d = AgentConnectorDataType(
|
||||
0,
|
||||
1,
|
||||
{
|
||||
SampleBatch.NEXT_OBS: {
|
||||
"sensor1": [[1, 1], [2, 2]],
|
||||
"sensor2": 8.8,
|
||||
},
|
||||
SampleBatch.REWARDS: 5.8,
|
||||
SampleBatch.ACTIONS: [[1, 1], [2]],
|
||||
SampleBatch.INFOS: {"random": "info"},
|
||||
},
|
||||
)
|
||||
|
||||
flattened = c(d)
|
||||
self.assertEqual(len(flattened), 1)
|
||||
|
||||
batch = flattened[0].data
|
||||
self.assertTrue((batch[SampleBatch.NEXT_OBS] == [1, 1, 2, 2, 8.8]).all())
|
||||
self.assertEqual(batch[SampleBatch.REWARDS][0], 5.8)
|
||||
# Not flattened.
|
||||
self.assertEqual(len(batch[SampleBatch.ACTIONS]), 2)
|
||||
self.assertEqual(batch[SampleBatch.INFOS]["random"], "info")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
59
rllib/connectors/tests/test_connector.py
Normal file
59
rllib/connectors/tests/test_connector.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import unittest
|
||||
|
||||
from ray.rllib.connectors.connector import Connector, ConnectorPipeline
|
||||
|
||||
|
||||
class TestConnectorPipeline(unittest.TestCase):
|
||||
class Tom(Connector):
|
||||
def to_config():
|
||||
return "tom"
|
||||
|
||||
class Bob(Connector):
|
||||
def to_config():
|
||||
return "bob"
|
||||
|
||||
class Mary(Connector):
|
||||
def to_config():
|
||||
return "mary"
|
||||
|
||||
class MockConnectorPipeline(ConnectorPipeline):
|
||||
def __init__(self, ctx, connectors):
|
||||
# Real connector pipelines should keep a list of
|
||||
# Connectors.
|
||||
# Use strings here for simple unit tests.
|
||||
self.connectors = connectors
|
||||
|
||||
def test_sanity_check(self):
|
||||
ctx = {}
|
||||
|
||||
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
|
||||
m.insert_before("Bob", self.Mary(ctx))
|
||||
self.assertEqual(len(m.connectors), 3)
|
||||
self.assertEqual(m.connectors[1].__class__.__name__, "Mary")
|
||||
|
||||
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
|
||||
m.insert_after("Tom", self.Mary(ctx))
|
||||
self.assertEqual(len(m.connectors), 3)
|
||||
self.assertEqual(m.connectors[1].__class__.__name__, "Mary")
|
||||
|
||||
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
|
||||
m.prepend(self.Mary(ctx))
|
||||
self.assertEqual(len(m.connectors), 3)
|
||||
self.assertEqual(m.connectors[0].__class__.__name__, "Mary")
|
||||
|
||||
m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
|
||||
m.append(self.Mary(ctx))
|
||||
self.assertEqual(len(m.connectors), 3)
|
||||
self.assertEqual(m.connectors[2].__class__.__name__, "Mary")
|
||||
|
||||
m.remove("Bob")
|
||||
self.assertEqual(len(m.connectors), 2)
|
||||
self.assertEqual(m.connectors[0].__class__.__name__, "Tom")
|
||||
self.assertEqual(m.connectors[1].__class__.__name__, "Mary")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
9
rllib/connectors/util.py
Normal file
9
rllib/connectors/util.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
from ray.rllib.connectors.connector import (
|
||||
Connector,
|
||||
get_connector,
|
||||
)
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def get_connectors_from_cfg(config: dict) -> Dict[str, Connector]:
|
||||
return {k: get_connector(*v) for k, v in config.items()}
|
|
@ -276,6 +276,8 @@ class JsonReader(InputReader):
|
|||
|
||||
# Clip actions (from any values into env's bounds), if necessary.
|
||||
cfg = self.ioctx.config
|
||||
# TODO(jungong) : we should not clip_action in input reader.
|
||||
# Use connector to handle this.
|
||||
if cfg.get("clip_actions") and self.ioctx.worker is not None:
|
||||
if isinstance(batch, SampleBatch):
|
||||
batch[SampleBatch.ACTIONS] = clip_action(
|
||||
|
|
|
@ -8,6 +8,10 @@ import zlib
|
|||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
|
||||
# TODO(jungong) : We need to handle RLlib custom space types,
|
||||
# FlexDict, Repeated, and Simplex.
|
||||
|
||||
|
||||
def _serialize_ndarray(array: np.ndarray) -> str:
|
||||
"""Pack numpy ndarray into Base64 encoded strings for serialization.
|
||||
|
||||
|
|
|
@ -182,7 +182,10 @@ def unbatch(batches_struct):
|
|||
"""Converts input from (nested) struct of batches to batch of structs.
|
||||
|
||||
Input: Struct of different batches (each batch has size=3):
|
||||
{"a": [1, 2, 3], "b": ([4, 5, 6], [7.0, 8.0, 9.0])}
|
||||
{
|
||||
"a": np.array([1, 2, 3]),
|
||||
"b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0]))
|
||||
}
|
||||
Output: Batch (list) of structs (each of these structs representing a
|
||||
single action):
|
||||
[
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import (
|
|||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
|
@ -12,6 +13,8 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
|
||||
|
@ -149,5 +152,35 @@ SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]
|
|||
# (possibly nested) dict|tuple of gym.space.Spaces.
|
||||
SpaceStruct = Union[gym.spaces.Space, dict, tuple]
|
||||
|
||||
# A list of batches of RNN states.
|
||||
# Each item in this list has dimension [B, S] (S=state vector size)
|
||||
StateBatches = List[List[Any]]
|
||||
|
||||
# Format of data output from policy forward pass.
|
||||
PolicyOutputType = Tuple[TensorStructType, StateBatches, Dict]
|
||||
|
||||
# Data type that is fed into and yielded from agent connectors.
|
||||
AgentConnectorDataType = DeveloperAPI( # API stability declaration.
|
||||
NamedTuple(
|
||||
"AgentConnectorDataType", [("env_id", str), ("agent_id", str), ("data", Any)]
|
||||
)
|
||||
)
|
||||
|
||||
# Data type that is fed into and yielded from agent connectors.
|
||||
ActionConnectorDataType = DeveloperAPI( # API stability declaration.
|
||||
NamedTuple(
|
||||
"ActionConnectorDataType",
|
||||
[("env_id", str), ("agent_id", str), ("output", PolicyOutputType)],
|
||||
)
|
||||
)
|
||||
|
||||
# Final output data type of agent connectors.
|
||||
AgentConnectorsOutput = DeveloperAPI( # API stability declaration.
|
||||
NamedTuple(
|
||||
"AgentConnectorsOut",
|
||||
[("for_training", Dict[str, TensorStructType]), ("for_action", "SampleBatch")],
|
||||
)
|
||||
)
|
||||
|
||||
# Generic type var.
|
||||
T = TypeVar("T")
|
||||
|
|
Loading…
Add table
Reference in a new issue