[RLlib] Prototype: Model Trajectory View API, part 0 (#9171)

Sven Mika 2020-06-30 05:33:19 +02:00 committed by GitHub
parent 882f60012f
commit 0d37103f84
6 changed files with 311 additions and 89 deletions

View file

@ -34,7 +34,7 @@ class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader):
... "num_workers": 0, # Run just 1 server, in the trainer.
... }
>>> while True:
pg.train()
>>> pg.train()
>>> client = PolicyClient("localhost:9900", inference_mode="local")
>>> eps_id = client.start_episode()
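The client side of this example typically continues along these lines (a hedged sketch; `env`, `obs`, and `reward` come from the user's own environment loop):
>>> obs = env.reset()
>>> action = client.get_action(eps_id, obs)
>>> obs, reward, done, info = env.step(action)
>>> client.log_returns(eps_id, reward)
>>> if done:
...     client.end_episode(eps_id, observation=obs)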

View file

@ -1,13 +1,17 @@
from collections import OrderedDict
import gym
from typing import Dict
from ray.rllib.models.preprocessors import get_preprocessor, \
RepeatedValuesPreprocessor
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.trajectory_view import ViewRequirement
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
TensorType
from ray.rllib.utils.spaces.repeated import Repeated
from ray.rllib.utils.types import ModelConfigDict
tf = try_import_tf()
torch, _ = try_import_torch()
@ -15,7 +19,7 @@ torch, _ = try_import_torch()
@PublicAPI
class ModelV2:
"""Defines a Keras-style abstract network model for use with RLlib.
"""Defines an abstract neural network model for use with RLlib.
Custom models should extend either TFModelV2 or TorchModelV2 instead of
this class directly.
@ -23,33 +27,41 @@ class ModelV2:
Data flow:
obs -> forward() -> model_out
value_function() -> V(s)
Attributes:
obs_space (Space): observation space of the target gym env. This
may have an `original_space` attribute that specifies how to
unflatten the tensor into a ragged tensor.
action_space (Space): action space of the target gym env
num_outputs (int): number of output units of the model
model_config (dict): config for the model, documented in ModelCatalog
name (str): name (scope) for the model
framework (str): either "tf" or "torch"
"""
def __init__(self, obs_space, action_space, num_outputs, model_config,
name, framework):
"""Initialize the model.
def __init__(self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
framework: str):
"""Initializes a ModelV2 object.
This method should create any variables used by the model.
Args:
obs_space (gym.spaces.Space): Observation space of the target gym
env. This may have an `original_space` attribute that
specifies how to unflatten the tensor into a ragged tensor.
action_space (gym.spaces.Space): Action space of the target gym
env.
num_outputs (int): Number of output units of the model.
model_config (ModelConfigDict): Config for the model, documented
in ModelCatalog.
name (str): Name (scope) for the model.
framework (str): Either "tf" or "torch".
"""
self.obs_space = obs_space
self.action_space = action_space
self.num_outputs = num_outputs
self.model_config = model_config
self.name = name or "default_model"
self.framework = framework
self.obs_space: gym.spaces.Space = obs_space
self.action_space: gym.spaces.Space = action_space
self.num_outputs: int = num_outputs
self.model_config: ModelConfigDict = model_config
self.name: str = name or "default_model"
self.framework: str = framework
self._last_output = None
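For orientation, a minimal construction sketch via ModelCatalog (the spaces, output size, and name below are illustrative and not taken from this commit):

import gym
from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS

obs_space = gym.spaces.Box(-1.0, 1.0, (4,))
action_space = gym.spaces.Discrete(2)
# Build a default ModelV2 (fully connected net) for the torch framework.
model = ModelCatalog.get_model_v2(
    obs_space=obs_space,
    action_space=action_space,
    num_outputs=2,
    model_config=MODEL_DEFAULTS.copy(),
    framework="torch",
    name="example_model")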
@PublicAPI
def get_initial_state(self):
"""Get the initial recurrent state values for the model.
@ -66,6 +78,7 @@ class ModelV2:
"""
return []
@PublicAPI
def forward(self, input_dict, state, seq_lens):
"""Call the model with the given input tensors and state.
@ -100,6 +113,7 @@ class ModelV2:
"""
raise NotImplementedError
@PublicAPI
def value_function(self):
"""Returns the value function output for the most recent forward pass.
@ -112,6 +126,7 @@ class ModelV2:
"""
raise NotImplementedError
@PublicAPI
def custom_loss(self, policy_loss, loss_inputs):
"""Override to customize the loss function used to optimize this model.
@ -133,6 +148,7 @@ class ModelV2:
"""
return policy_loss
@PublicAPI
def metrics(self):
"""Override to return custom metrics from your model.
@ -201,6 +217,7 @@ class ModelV2:
self._last_output = outputs
return outputs, state
@PublicAPI
def from_batch(self, train_batch, is_training=True):
"""Convenience function that calls this model with a tensor batch.
@ -223,6 +240,34 @@ class ModelV2:
i += 1
return self.__call__(input_dict, states, train_batch.get("seq_lens"))
def get_view_requirements(
self,
is_training: bool = False) -> Dict[str, ViewRequirement]:
"""Returns a list of ViewRequirements for this Model (or None).
Note: This is an experimental API method.
A ViewRequirement object tells the caller of this Model, which
data at which timesteps are needed by this Model. This could be a
sequence of past observations, internal-states, previous rewards, or
other episode data/previous model outputs.
Args:
is_training (bool): Whether the returned requirements are for
training or inference (default).
Returns:
Dict[str, ViewRequirement]: The view requirements as a dict mapping
column names (e.g. "obs") to ViewRequirement objects describing the
required views.
TODO: (sven) Currently, only `timesteps=0` can be set up.
"""
# Default implementation for simple RL model:
# Single requirement: Pass current obs as input.
return {
SampleBatch.CUR_OBS: ViewRequirement(timesteps=0),
}
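A minimal override sketch: a custom torch model that, in addition to the current observation, also asks for the previous-rewards column (hypothetical class; it assumes a "prev_rewards" buffer exists in the Trajectory data and omits the usual forward()/value_function() overrides):

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.trajectory_view import ViewRequirement


class MyRewardAwareModel(TorchModelV2):
    def get_view_requirements(self, is_training=False):
        # Request the current obs plus the previous reward (both at t=0,
        # the only shift supported so far).
        return {
            SampleBatch.CUR_OBS: ViewRequirement(timesteps=0),
            SampleBatch.PREV_REWARDS: ViewRequirement(timesteps=0),
        }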
def import_from_h5(self, h5_file):
"""Imports weights from an h5 file.
@ -237,14 +282,17 @@ class ModelV2:
"""
raise NotImplementedError
@PublicAPI
def last_output(self):
"""Returns the last output returned from calling the model."""
return self._last_output
@PublicAPI
def context(self):
"""Returns a contextmanager for the current forward pass."""
return NullContextManager()
@PublicAPI
def variables(self, as_dict=False):
"""Returns the list (or a dict) of variables for this model.
@ -258,6 +306,7 @@ class ModelV2:
"""
raise NotImplementedError
@PublicAPI
def trainable_variables(self, as_dict=False):
"""Returns the list of trainable variables for this model.
@ -299,17 +348,20 @@ def flatten(obs, framework):
@DeveloperAPI
def restore_original_dimensions(obs, obs_space, tensorlib=tf):
def restore_original_dimensions(obs: TensorType,
obs_space: gym.spaces.Space,
tensorlib=tf):
"""Unpacks Dict and Tuple space observations into their original form.
This is needed since we flatten Dict and Tuple observations in transit.
Before sending them to the model though, we should unflatten them into
Dicts or Tuples of tensors.
This is needed since we flatten Dict and Tuple observations in transit
within a SampleBatch. Before sending them to the model though, we should
unflatten them into Dicts or Tuples of tensors.
Arguments:
obs: The flattened observation tensor.
obs_space: The flattened obs space. If this has the `original_space`
attribute, we will unflatten the tensor to that shape.
Args:
obs (TensorType): The flattened observation tensor.
obs_space (gym.spaces.Space): The flattened obs space. If this has the
`original_space` attribute, we will unflatten the tensor to that
shape.
tensorlib: The library used to unflatten (reshape) the array/tensor.
Returns:

View file

@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod
import gym
import numpy as np
from typing import Dict, List, Optional
from ray.rllib.utils import try_import_tree
from ray.rllib.utils.annotations import DeveloperAPI
@ -9,6 +10,7 @@ from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
unbatch
from ray.rllib.utils.types import AgentID
torch, _ = try_import_torch()
tree = try_import_tree()
@ -78,12 +80,12 @@ class Policy(metaclass=ABCMeta):
"""Computes actions for the current policy.
Args:
obs_batch (Union[List,np.ndarray]): Batch of observations.
obs_batch (Union[List, np.ndarray]): Batch of observations.
state_batches (Optional[list]): List of RNN state input batches,
if any.
prev_action_batch (Optional[List,np.ndarray]): Batch of previous
prev_action_batch (Optional[List, np.ndarray]): Batch of previous
action values.
prev_reward_batch (Optional[List,np.ndarray]): Batch of previous
prev_reward_batch (Optional[List, np.ndarray]): Batch of previous
rewards.
info_batch (info): Batch of info objects.
episodes (list): MultiAgentEpisode for each obs in obs_batch.
@ -189,6 +191,40 @@ class Policy(metaclass=ABCMeta):
return single_action, [s[0] for s in state_out], \
{k: v[0] for k, v in info.items()}
def compute_actions_from_trajectories(
self,
trajectories: List["Trajectory"],
other_trajectories: Dict[AgentID, "Trajectory"],
explore: bool = None,
timestep: Optional[int] = None,
**kwargs):
"""Computes actions for the current policy based on .
Note: This is an experimental API method.
Only used so far by the Sampler iff `_fast_sampling=True` (also only
supported for torch).
Args:
trajectories (List[Trajectory]): A List of Trajectory data used
to create a view for the Model forward call.
other_trajectories (Dict[AgentID, Trajectory]): Optional dict
mapping AgentIDs to Trajectory objects.
explore (bool): Whether to pick an exploitation or exploration
action (default: None -> use self.config["explore"]).
timestep (Optional[int]): The current (sampling) time step.
kwargs: forward compatibility placeholder
Returns:
actions (np.ndarray): batch of output actions, with shape like
[BATCH_SIZE, ACTION_SHAPE].
state_outs (list): list of RNN state output batches, if any, with
shape like [STATE_SIZE, BATCH_SIZE].
info (dict): dictionary of extra feature batches, if any, with
shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
"""
raise NotImplementedError
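For orientation, a minimal call sketch (hypothetical sampler snippet; `policy` and the per-agent `Trajectory` object are assumed to already exist):

actions, state_outs, fetches = policy.compute_actions_from_trajectories(
    trajectories=[agent_trajectory],
    other_trajectories={},  # no other agents in this sketch
    explore=True,
    timestep=policy.global_timestep)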
@DeveloperAPI
def compute_log_likelihoods(self,
actions,

View file

@ -1,11 +1,13 @@
import functools
import numpy as np
import time
from typing import Dict, List, Optional
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.trajectory_view import get_trajectory_view
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
@ -13,6 +15,7 @@ from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
convert_to_torch_tensor
from ray.rllib.utils.tracking_dict import UsageTrackingDict
from ray.rllib.utils.types import AgentID
torch, _ = try_import_torch()
@ -126,64 +129,112 @@ class TorchPolicy(Policy):
state_batches = [
convert_to_torch_tensor(s) for s in (state_batches or [])
]
actions, state_out, extra_fetches, logp = \
self._compute_action_helper(
input_dict, state_batches, seq_lens, explore, timestep)
if self.action_sampler_fn:
action_dist = dist_inputs = None
state_out = []
actions, logp = self.action_sampler_fn(
self,
self.model,
input_dict[SampleBatch.CUR_OBS],
explore=explore,
timestep=timestep)
else:
# Call the exploration before_compute_actions hook.
self.exploration.before_compute_actions(
explore=explore, timestep=timestep)
if self.action_distribution_fn:
dist_inputs, dist_class, state_out = \
self.action_distribution_fn(
self,
self.model,
input_dict[SampleBatch.CUR_OBS],
explore=explore,
timestep=timestep,
is_training=False)
else:
dist_class = self.dist_class
dist_inputs, state_out = self.model(
input_dict, state_batches, seq_lens)
if not (isinstance(dist_class, functools.partial)
or issubclass(dist_class, TorchDistributionWrapper)):
raise ValueError(
"`dist_class` ({}) not a TorchDistributionWrapper "
"subclass! Make sure your `action_distribution_fn` or "
"`make_model_and_action_dist` return a correct "
"distribution class.".format(dist_class.__name__))
action_dist = dist_class(dist_inputs, self.model)
# Get the exploration action from the forward results.
actions, logp = \
self.exploration.get_exploration_action(
action_distribution=action_dist,
timestep=timestep,
explore=explore)
input_dict[SampleBatch.ACTIONS] = actions
# Add default and custom fetches.
extra_fetches = self.extra_action_out(input_dict, state_batches,
self.model, action_dist)
# Action-logp and action-prob.
if logp is not None:
logp = convert_to_non_torch_type(logp)
extra_fetches[SampleBatch.ACTION_PROB] = np.exp(logp)
extra_fetches[SampleBatch.ACTION_LOGP] = logp
# Action-dist inputs.
if dist_inputs is not None:
extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
return convert_to_non_torch_type((actions, state_out,
extra_fetches))
return convert_to_non_torch_type(
(actions, state_out, extra_fetches))
@override(Policy)
def compute_actions_from_trajectories(
self,
trajectories: List["Trajectory"],
other_trajectories: Dict[AgentID, "Trajectory"],
explore: bool = None,
timestep: Optional[int] = None,
**kwargs):
explore = explore if explore is not None else self.config["explore"]
timestep = timestep if timestep is not None else self.global_timestep
with torch.no_grad():
# Create a view and pass that to Model as `input_dict`.
input_dict = self._lazy_tensor_dict(get_trajectory_view(
self.model, trajectories, is_training=False))
# TODO: (sven) support RNNs w/ fast sampling.
state_batches = []
seq_lens = None
actions, state_out, extra_fetches, logp = \
self._compute_action_helper(
input_dict, state_batches, seq_lens, explore, timestep)
# Leave outputs as is (torch.Tensors): Action-logp and action-prob.
if logp is not None:
extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp)
extra_fetches[SampleBatch.ACTION_LOGP] = logp
return actions, state_out, extra_fetches
def _compute_action_helper(self, input_dict, state_batches, seq_lens,
explore, timestep):
"""Shared forward pass logic (w/ and w/o trajectory view API).
Returns:
Tuple:
- actions, state_out, extra_fetches, logp.
"""
if self.action_sampler_fn:
action_dist = dist_inputs = None
state_out = []
actions, logp = self.action_sampler_fn(
self,
self.model,
input_dict[SampleBatch.CUR_OBS],
explore=explore,
timestep=timestep)
else:
# Call the exploration before_compute_actions hook.
self.exploration.before_compute_actions(
explore=explore, timestep=timestep)
if self.action_distribution_fn:
dist_inputs, dist_class, state_out = \
self.action_distribution_fn(
self,
self.model,
input_dict[SampleBatch.CUR_OBS],
explore=explore,
timestep=timestep,
is_training=False)
else:
dist_class = self.dist_class
dist_inputs, state_out = self.model(
input_dict, state_batches, seq_lens)
if not (isinstance(dist_class, functools.partial)
or issubclass(dist_class, TorchDistributionWrapper)):
raise ValueError(
"`dist_class` ({}) not a TorchDistributionWrapper "
"subclass! Make sure your `action_distribution_fn` or "
"`make_model_and_action_dist` return a correct "
"distribution class.".format(dist_class.__name__))
action_dist = dist_class(dist_inputs, self.model)
# Get the exploration action from the forward results.
actions, logp = \
self.exploration.get_exploration_action(
action_distribution=action_dist,
timestep=timestep,
explore=explore)
input_dict[SampleBatch.ACTIONS] = actions
# Add default and custom fetches.
extra_fetches = self.extra_action_out(input_dict, state_batches,
self.model, action_dist)
# Action-dist inputs.
if dist_inputs is not None:
extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
return actions, state_out, extra_fetches, logp
@override(Policy)
def compute_log_likelihoods(self,

View file

@ -0,0 +1,84 @@
from dataclasses import dataclass
import numpy as np
from typing import Dict
from ray.rllib.utils.types import TensorType
@dataclass
class ViewRequirement:
"""Single view requirement (for one column in a ModelV2 input_dict).
Note: This is an experimental class.
ModelV2 returns a Dict[str, ViewRequirement] upon calling
`ModelV2.get_view_requirements()`, where the str key represents the column
name (C) under which the view is available in the `input_dict` and
ViewRequirement specifies the actual underlying column names (in the
original data buffer), timesteps, and other options to build the view
for C.
Examples:
>>> # The default ViewRequirement for a Model is:
>>> req = [ModelV2].get_view_requirements(is_training=False)
>>> print(req)
{"obs": ViewRequirement(timesteps=0)}
"""
# The data column name from the SampleBatch (str key).
# If None, use the dict key under which this ViewRequirement resides.
data_col: str = None
# The relative (or absolute) timestep(s) to be present in the
# input_dict. Currently, only a single int (e.g. 0) is supported.
timesteps: int = 0
# Switch on absolute timestep mode. Default: False.
# TODO: (sven)
# "absolute_timesteps",
# The fill mode in case t<0 or t>H: One of "zeros", "tile".
fill_mode: str = "zeros"
# The repeat-mode (one of "all" or "only_first"). E.g. for training,
# we only want the first internal state timestep (the NN will
calculate all others again anyway).
repeat_mode: str = "all"
# Provide all data as time major (default: False).
# TODO: (sven)
# "time_major",
def get_trajectory_view(
model,
trajectories,
is_training: bool = False) -> Dict[str, TensorType]:
"""Returns an input_dict for a Model's forward pass given some data.
Args:
model (ModelV2): The ModelV2 object for which to generate the view
(input_dict) from `data`.
trajectories (List[Trajectory]): The data from which to generate
an input_dict.
is_training (bool): Whether the view should be generated for training
purposes or inference (default).
Returns:
Dict[str, TensorType]: The input_dict to be passed into the ModelV2
for inference/training.
"""
# Get ModelV2's view requirements.
view_reqs = model.get_view_requirements(is_training=is_training)
# Construct the view dict.
view = {}
for view_col, view_req in view_reqs.items():
# Create the batch of data from the different buffers in `data`.
# TODO: (sven): Here, we actually do create a copy of the data (from a
# list). The only way to avoid this entirely would be to keep a
# single(!) np buffer per column across all currently ongoing
# agents + episodes (which seems very hard to realize).
view[view_col] = np.array([
t.buffers[view_req.data_col][t.cursor + view_req.timesteps]
for t in trajectories
])
return view
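To make the indexing above concrete, a toy sketch with a stand-in for the Trajectory class (assumed to expose a `buffers` dict and a `cursor` index, as used by `get_trajectory_view`):

import numpy as np

from ray.rllib.policy.trajectory_view import ViewRequirement


class _ToyTrajectory:
    def __init__(self):
        # One flat buffer per column; `cursor` marks the current timestep.
        self.buffers = {"obs": np.arange(10)}
        self.cursor = 3


req = ViewRequirement(data_col="obs", timesteps=0)
traj = _ToyTrajectory()
# Same lookup as in `get_trajectory_view`: buffer[cursor + timestep shift].
value = traj.buffers[req.data_col][traj.cursor + req.timesteps]  # -> 3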

View file

@ -20,10 +20,9 @@ def check_support_multiagent(alg, config):
config=config, env="multi_agent_mountaincar")
else:
a = get_agent_class(alg)(config=config, env="multi_agent_cartpole")
try:
print(a.train())
finally:
a.stop()
print(a.train())
a.stop()
class TestSupportedMultiAgent(unittest.TestCase):