mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] Prototype: Model Trajectory View API, part 0 (#9171)
This commit is contained in:
parent 882f60012f
commit 0d37103f84
6 changed files with 311 additions and 89 deletions
2  rllib/env/policy_server_input.py  (vendored)
@@ -34,7 +34,7 @@ class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader):
        ...     "num_workers": 0,  # Run just 1 server, in the trainer.
        ... }
        >>> while True:
                pg.train()
        >>> pg.train()

        >>> client = PolicyClient("localhost:9900", inference_mode="local")
        >>> eps_id = client.start_episode()
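
For orientation, the doctest above corresponds roughly to the following external-serving pattern. This is a sketch only; the env name, the address/port, the "input_evaluation" setting, and the split into two processes are illustrative assumptions rather than part of this diff.

# Server side (one process): run the trainer, serving the policy over HTTP.
import ray
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput

ray.init()
trainer = PGTrainer(
    env="CartPole-v0",  # Only used to define the obs/action spaces here.
    config={
        "input": lambda ioctx: PolicyServerInput(ioctx, "localhost", 9900),
        "num_workers": 0,  # Run just 1 server, in the trainer.
        "input_evaluation": [],
    })
while True:
    trainer.train()

# Client side (another process): an external env loop asking for actions.
import gym
from ray.rllib.env.policy_client import PolicyClient

env = gym.make("CartPole-v0")
client = PolicyClient("http://localhost:9900", inference_mode="local")
obs = env.reset()
eps_id = client.start_episode(training_enabled=True)
done = False
while not done:
    action = client.get_action(eps_id, obs)
    obs, reward, done, info = env.step(action)
    client.log_returns(eps_id, reward, info=info)
client.end_episode(eps_id, obs)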

rllib/models/modelv2.py
@@ -1,13 +1,17 @@
from collections import OrderedDict
import gym
from typing import Dict

from ray.rllib.models.preprocessors import get_preprocessor, \
    RepeatedValuesPreprocessor
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.trajectory_view import ViewRequirement
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
    TensorType
from ray.rllib.utils.spaces.repeated import Repeated
from ray.rllib.utils.types import ModelConfigDict

tf = try_import_tf()
torch, _ = try_import_torch()
@@ -15,7 +19,7 @@ torch, _ = try_import_torch()

@PublicAPI
class ModelV2:
    """Defines a Keras-style abstract network model for use with RLlib.
    """Defines an abstract neural network model for use with RLlib.

    Custom models should extend either TFModelV2 or TorchModelV2 instead of
    this class directly.
@@ -23,33 +27,41 @@ class ModelV2:
    Data flow:
        obs -> forward() -> model_out
            value_function() -> V(s)

    Attributes:
        obs_space (Space): observation space of the target gym env. This
            may have an `original_space` attribute that specifies how to
            unflatten the tensor into a ragged tensor.
        action_space (Space): action space of the target gym env
        num_outputs (int): number of output units of the model
        model_config (dict): config for the model, documented in ModelCatalog
        name (str): name (scope) for the model
        framework (str): either "tf" or "torch"
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, framework):
        """Initialize the model.
    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: int,
                 model_config: ModelConfigDict,
                 name: str,
                 framework: str):
        """Initializes a ModelV2 object.

        This method should create any variables used by the model.

        Args:
            obs_space (gym.spaces.Space): Observation space of the target gym
                env. This may have an `original_space` attribute that
                specifies how to unflatten the tensor into a ragged tensor.
            action_space (gym.spaces.Space): Action space of the target gym
                env.
            num_outputs (int): Number of output units of the model.
            model_config (ModelConfigDict): Config for the model, documented
                in ModelCatalog.
            name (str): Name (scope) for the model.
            framework (str): Either "tf" or "torch".
        """

        self.obs_space = obs_space
        self.action_space = action_space
        self.num_outputs = num_outputs
        self.model_config = model_config
        self.name = name or "default_model"
        self.framework = framework
        self.obs_space: gym.spaces.Space = obs_space
        self.action_space: gym.spaces.Space = action_space
        self.num_outputs: int = num_outputs
        self.model_config: ModelConfigDict = model_config
        self.name: str = name or "default_model"
        self.framework: str = framework
        self._last_output = None

    @PublicAPI
    def get_initial_state(self):
        """Get the initial recurrent state values for the model.
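
For reference, a custom model written against the annotated constructor above might look like the following minimal sketch. The class name, layer sizes, and registration name are illustrative; only the TorchModelV2/ModelV2 signatures themselves come from RLlib.

import gym
import numpy as np
import torch
import torch.nn as nn

from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.types import ModelConfigDict


class MyFCModel(TorchModelV2, nn.Module):
    """Tiny fully-connected model written against the annotated signature."""

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        in_size = int(np.prod(obs_space.shape))
        self._hidden = nn.Linear(in_size, 64)
        self._logits = nn.Linear(64, num_outputs)
        self._value = nn.Linear(64, 1)
        self._features = None

    def forward(self, input_dict, state, seq_lens):
        # `obs_flat` is the flattened observation RLlib provides.
        self._features = torch.relu(self._hidden(input_dict["obs_flat"]))
        return self._logits(self._features), state

    def value_function(self):
        return self._value(self._features).squeeze(1)


# Register under a custom name; reference it via config["model"]["custom_model"].
ModelCatalog.register_custom_model("my_fc_model", MyFCModel)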
@@ -66,6 +78,7 @@ class ModelV2:
        """
        return []

    @PublicAPI
    def forward(self, input_dict, state, seq_lens):
        """Call the model with the given input tensors and state.

@@ -100,6 +113,7 @@ class ModelV2:
        """
        raise NotImplementedError

    @PublicAPI
    def value_function(self):
        """Returns the value function output for the most recent forward pass.

@@ -112,6 +126,7 @@ class ModelV2:
        """
        raise NotImplementedError

    @PublicAPI
    def custom_loss(self, policy_loss, loss_inputs):
        """Override to customize the loss function used to optimize this model.

@@ -133,6 +148,7 @@ class ModelV2:
        """
        return policy_loss

    @PublicAPI
    def metrics(self):
        """Override to return custom metrics from your model.

@@ -201,6 +217,7 @@ class ModelV2:
        self._last_output = outputs
        return outputs, state

    @PublicAPI
    def from_batch(self, train_batch, is_training=True):
        """Convenience function that calls this model with a tensor batch.
@@ -223,6 +240,34 @@ class ModelV2:
            i += 1
        return self.__call__(input_dict, states, train_batch.get("seq_lens"))

    def get_view_requirements(
            self,
            is_training: bool = False) -> Dict[str, ViewRequirement]:
        """Returns a dict of ViewRequirements for this Model (or None).

        Note: This is an experimental API method.

        A ViewRequirement object tells the caller of this Model which
        data, at which timesteps, this Model needs. This could be a
        sequence of past observations, internal states, previous rewards, or
        other episode data/previous model outputs.

        Args:
            is_training (bool): Whether the returned requirements are for
                training or inference (default).

        Returns:
            Dict[str, ViewRequirement]: The view requirements as a dict
                mapping column names (e.g. "obs") to config dicts containing
                supported fields.
                TODO: (sven) Currently, only `timesteps==0` can be set up.
        """
        # Default implementation for simple RL model:
        # Single requirement: Pass current obs as input.
        return {
            SampleBatch.CUR_OBS: ViewRequirement(timesteps=0),
        }

    def import_from_h5(self, h5_file):
        """Imports weights from an h5 file.
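
A custom model could override this hook along the following lines. This is a sketch within the prototype's current limits (only `timesteps==0`, per the TODO above); the subclass name and the extra "actions" column are illustrative, and `data_col` is set explicitly so the buffer lookup in `get_trajectory_view` (added later in this commit) resolves.

from typing import Dict

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.trajectory_view import ViewRequirement


class ActionAwareModel(MyFCModel):  # MyFCModel from the sketch above.
    def get_view_requirements(
            self, is_training: bool = False) -> Dict[str, ViewRequirement]:
        # Always request the current observation (same as the default).
        reqs = {
            SampleBatch.CUR_OBS: ViewRequirement(
                data_col=SampleBatch.CUR_OBS, timesteps=0),
        }
        if is_training:
            # For training, additionally request the taken actions at t=0,
            # e.g. to compute a custom loss term over them.
            reqs[SampleBatch.ACTIONS] = ViewRequirement(
                data_col=SampleBatch.ACTIONS, timesteps=0)
        return reqs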
@@ -237,14 +282,17 @@ class ModelV2:
        """
        raise NotImplementedError

    @PublicAPI
    def last_output(self):
        """Returns the last output returned from calling the model."""
        return self._last_output

    @PublicAPI
    def context(self):
        """Returns a contextmanager for the current forward pass."""
        return NullContextManager()

    @PublicAPI
    def variables(self, as_dict=False):
        """Returns the list (or a dict) of variables for this model.

@@ -258,6 +306,7 @@ class ModelV2:
        """
        raise NotImplementedError

    @PublicAPI
    def trainable_variables(self, as_dict=False):
        """Returns the list of trainable variables for this model.
@@ -299,17 +348,20 @@ def flatten(obs, framework):


@DeveloperAPI
def restore_original_dimensions(obs, obs_space, tensorlib=tf):
def restore_original_dimensions(obs: TensorType,
                                obs_space: gym.spaces.Space,
                                tensorlib=tf):
    """Unpacks Dict and Tuple space observations into their original form.

    This is needed since we flatten Dict and Tuple observations in transit.
    Before sending them to the model though, we should unflatten them into
    Dicts or Tuples of tensors.
    This is needed since we flatten Dict and Tuple observations in transit
    within a SampleBatch. Before sending them to the model though, we should
    unflatten them into Dicts or Tuples of tensors.

    Arguments:
        obs: The flattened observation tensor.
        obs_space: The flattened obs space. If this has the `original_space`
            attribute, we will unflatten the tensor to that shape.
    Args:
        obs (TensorType): The flattened observation tensor.
        obs_space (gym.spaces.Space): The flattened obs space. If this has the
            `original_space` attribute, we will unflatten the tensor to that
            shape.
        tensorlib: The library used to unflatten (reshape) the array/tensor.

    Returns:
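
A small usage sketch of this helper, assuming the standard Dict-flattening preprocessor attaches `original_space` to its flattened Box space; the space layout and batch contents are illustrative.

import gym
import numpy as np

from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.models.preprocessors import get_preprocessor

# Illustrative Dict observation space; RLlib flattens this in transit.
space = gym.spaces.Dict({
    "pos": gym.spaces.Box(-1.0, 1.0, shape=(2,)),
    "vel": gym.spaces.Box(-1.0, 1.0, shape=(3,)),
})
prep = get_preprocessor(space)(space)
# `prep.observation_space` is the flat Box and carries `original_space`.
flat_batch = np.stack(
    [prep.transform(space.sample()) for _ in range(4)])  # shape (4, 5)
unpacked = restore_original_dimensions(
    flat_batch, prep.observation_space, tensorlib=np)
print(unpacked["pos"].shape, unpacked["vel"].shape)  # (4, 2) (4, 3)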

rllib/policy/policy.py
@@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod
import gym
import numpy as np
from typing import Dict, List, Optional

from ray.rllib.utils import try_import_tree
from ray.rllib.utils.annotations import DeveloperAPI

@@ -9,6 +10,7 @@ from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
    unbatch
from ray.rllib.utils.types import AgentID

torch, _ = try_import_torch()
tree = try_import_tree()

@@ -78,12 +80,12 @@ class Policy(metaclass=ABCMeta):
        """Computes actions for the current policy.

        Args:
            obs_batch (Union[List,np.ndarray]): Batch of observations.
            obs_batch (Union[List, np.ndarray]): Batch of observations.
            state_batches (Optional[list]): List of RNN state input batches,
                if any.
            prev_action_batch (Optional[List,np.ndarray]): Batch of previous
            prev_action_batch (Optional[List, np.ndarray]): Batch of previous
                action values.
            prev_reward_batch (Optional[List,np.ndarray]): Batch of previous
            prev_reward_batch (Optional[List, np.ndarray]): Batch of previous
                rewards.
            info_batch (info): Batch of info objects.
            episodes (list): MultiAgentEpisode for each obs in obs_batch.
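
As a reminder of the conventional entry point documented here, computing actions for a batch of observations looks roughly like the following; the trainer, env, and printed extra-fetch keys are illustrative.

import numpy as np
import ray
from ray.rllib.agents.pg import PGTrainer

ray.init()
trainer = PGTrainer(env="CartPole-v0", config={"num_workers": 0})
policy = trainer.get_policy()

# A batch of two (dummy) CartPole observations; batch dimension first.
obs_batch = np.zeros((2, 4), dtype=np.float32)
actions, state_outs, info = policy.compute_actions(obs_batch)
print(actions.shape)  # -> (2,)
print(state_outs)     # -> [] for a stateless model
print(info.keys())    # extra fetches, e.g. action_prob / action_dist_inputs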
@@ -189,6 +191,40 @@ class Policy(metaclass=ABCMeta):
            return single_action, [s[0] for s in state_out], \
                {k: v[0] for k, v in info.items()}

    def compute_actions_from_trajectories(
            self,
            trajectories: List["Trajectory"],
            other_trajectories: Dict[AgentID, "Trajectory"],
            explore: bool = None,
            timestep: Optional[int] = None,
            **kwargs):
        """Computes actions for the current policy based on trajectory data.

        Note: This is an experimental API method.

        Only used so far by the Sampler iff `_fast_sampling=True` (also only
        supported for torch).

        Args:
            trajectories (List[Trajectory]): A List of Trajectory data used
                to create a view for the Model forward call.
            other_trajectories (Dict[AgentID, Trajectory]): Optional dict
                mapping AgentIDs to Trajectory objects.
            explore (bool): Whether to pick an exploitation or exploration
                action (default: None -> use self.config["explore"]).
            timestep (Optional[int]): The current (sampling) time step.
            kwargs: Forward compatibility placeholder.

        Returns:
            actions (np.ndarray): Batch of output actions, with shape like
                [BATCH_SIZE, ACTION_SHAPE].
            state_outs (list): List of RNN state output batches, if any, with
                shape like [STATE_SIZE, BATCH_SIZE].
            info (dict): Dictionary of extra feature batches, if any, with
                shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        raise NotImplementedError

    @DeveloperAPI
    def compute_log_likelihoods(self,
                                actions,

rllib/policy/torch_policy.py
@@ -1,11 +1,13 @@
import functools
import numpy as np
import time
from typing import Dict, List, Optional

from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.trajectory_view import get_trajectory_view
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch

@@ -13,6 +15,7 @@ from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
    convert_to_torch_tensor
from ray.rllib.utils.tracking_dict import UsageTrackingDict
from ray.rllib.utils.types import AgentID

torch, _ = try_import_torch()
@@ -126,64 +129,112 @@ class TorchPolicy(Policy):
            state_batches = [
                convert_to_torch_tensor(s) for s in (state_batches or [])
            ]
            actions, state_out, extra_fetches, logp = \
                self._compute_action_helper(
                    input_dict, state_batches, seq_lens, explore, timestep)

            if self.action_sampler_fn:
                action_dist = dist_inputs = None
                state_out = []
                actions, logp = self.action_sampler_fn(
                    self,
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=explore,
                    timestep=timestep)
            else:
                # Call the exploration before_compute_actions hook.
                self.exploration.before_compute_actions(
                    explore=explore, timestep=timestep)
                if self.action_distribution_fn:
                    dist_inputs, dist_class, state_out = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.CUR_OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False)
                else:
                    dist_class = self.dist_class
                    dist_inputs, state_out = self.model(
                        input_dict, state_batches, seq_lens)
                if not (isinstance(dist_class, functools.partial)
                        or issubclass(dist_class, TorchDistributionWrapper)):
                    raise ValueError(
                        "`dist_class` ({}) not a TorchDistributionWrapper "
                        "subclass! Make sure your `action_distribution_fn` or "
                        "`make_model_and_action_dist` return a correct "
                        "distribution class.".format(dist_class.__name__))
                action_dist = dist_class(dist_inputs, self.model)

                # Get the exploration action from the forward results.
                actions, logp = \
                    self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

            input_dict[SampleBatch.ACTIONS] = actions

            # Add default and custom fetches.
            extra_fetches = self.extra_action_out(input_dict, state_batches,
                                                  self.model, action_dist)
            # Action-logp and action-prob.
            if logp is not None:
                logp = convert_to_non_torch_type(logp)
                extra_fetches[SampleBatch.ACTION_PROB] = np.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            return convert_to_non_torch_type((actions, state_out,
                                              extra_fetches))

            return convert_to_non_torch_type(
                (actions, state_out, extra_fetches))

    @override(Policy)
    def compute_actions_from_trajectories(
            self,
            trajectories: List["Trajectory"],
            other_trajectories: Dict[AgentID, "Trajectory"],
            explore: bool = None,
            timestep: Optional[int] = None,
            **kwargs):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        with torch.no_grad():
            # Create a view and pass that to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(get_trajectory_view(
                self.model, trajectories, is_training=False))
            # TODO: (sven) support RNNs w/ fast sampling.
            state_batches = []
            seq_lens = None

            actions, state_out, extra_fetches, logp = \
                self._compute_action_helper(
                    input_dict, state_batches, seq_lens, explore, timestep)

            # Leave outputs as is (torch.Tensors): Action-logp and action-prob.
            if logp is not None:
                extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp

            return actions, state_out, extra_fetches

    def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                               explore, timestep):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Returns:
            Tuple:
                - actions, state_out, extra_fetches, logp.
        """
        if self.action_sampler_fn:
            action_dist = dist_inputs = None
            state_out = []
            actions, logp = self.action_sampler_fn(
                self,
                self.model,
                input_dict[SampleBatch.CUR_OBS],
                explore=explore,
                timestep=timestep)
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(
                explore=explore, timestep=timestep)
            if self.action_distribution_fn:
                dist_inputs, dist_class, state_out = \
                    self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        is_training=False)
            else:
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(
                    input_dict, state_batches, seq_lens)

            if not (isinstance(dist_class, functools.partial)
                    or issubclass(dist_class, TorchDistributionWrapper)):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__))
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = \
                self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore)

        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(input_dict, state_batches,
                                              self.model, action_dist)

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        return actions, state_out, extra_fetches, logp

    @override(Policy)
    def compute_log_likelihoods(self,
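
The helper above calls `self.action_distribution_fn(policy, model, obs_batch, explore=..., timestep=..., is_training=...)` and expects back the triple `(dist_inputs, dist_class, state_out)`. A hook with that shape, wired up via `build_torch_policy` and assuming a discrete action space, might look like the following sketch; the policy name and the placeholder loss are illustrative.

from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy


def my_action_distribution_fn(policy, model, obs_batch, *, explore=True,
                              timestep=None, is_training=False):
    # Run the model and hand back the (dist_inputs, dist_class, state_out)
    # triple that `_compute_action_helper` unpacks above.
    dist_inputs, state_out = model({"obs": obs_batch}, [], None)
    return dist_inputs, TorchCategorical, state_out


def my_loss_fn(policy, model, dist_class, train_batch):
    # Placeholder policy-gradient-style loss, for illustration only.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    logp = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    return -(logp * train_batch[SampleBatch.REWARDS]).mean()


MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=my_loss_fn,
    action_distribution_fn=my_action_distribution_fn)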

84  rllib/policy/trajectory_view.py  (new file)
@@ -0,0 +1,84 @@
from dataclasses import dataclass
import numpy as np
from typing import Dict

from ray.rllib.utils.types import TensorType


@dataclass
class ViewRequirement:
    """Single view requirement (for one column in a ModelV2 input_dict).

    Note: This is an experimental class.

    ModelV2 returns a Dict[str, ViewRequirement] upon calling
    `ModelV2.get_view_requirements()`, where the str key represents the column
    name under which the view is available in the `input_dict` and the
    ViewRequirement specifies the actual underlying column name (in the
    original data buffer), the timesteps, and other options to build the view
    for that column.

    Examples:
        >>> # The default ViewRequirement for a Model is:
        >>> req = [ModelV2].get_view_requirements(is_training=False)
        >>> print(req)
        {"obs": ViewRequirement(timesteps=0)}
    """
    # The data column name from the SampleBatch (str key).
    # If None, use the dict key under which this ViewRequirement resides.
    data_col: str = None

    # The relative (or absolute) timestep(s) to be present in the
    # input_dict.
    timesteps: int = 0

    # Switch on absolute timestep mode. Default: False.
    # TODO: (sven)
    # "absolute_timesteps",

    # The fill mode in case t<0 or t>H: One of "zeros", "tile".
    fill_mode: str = "zeros"

    # The repeat-mode (one of "all" or "only_first"). E.g. for training,
    # we only want the first internal state timestep (the NN will
    # calculate all others again anyways).
    repeat_mode: str = "all"

    # Provide all data as time major (default: False).
    # TODO: (sven)
    # "time_major",


def get_trajectory_view(
        model,
        trajectories,
        is_training: bool = False) -> Dict[str, TensorType]:
    """Returns an input_dict for a Model's forward pass given some data.

    Args:
        model (ModelV2): The ModelV2 object for which to generate the view
            (input_dict) from `trajectories`.
        trajectories (List[Trajectory]): The data from which to generate
            an input_dict.
        is_training (bool): Whether the view should be generated for training
            purposes or inference (default).

    Returns:
        Dict[str, TensorType]: The input_dict to be passed into the ModelV2
            for inference/training.
    """
    # Get ModelV2's view requirements.
    view_reqs = model.get_view_requirements(is_training=is_training)
    # Construct the view dict.
    view = {}
    for view_col, view_req in view_reqs.items():
        # Create the batch of data from the different buffers in `trajectories`.
        # TODO: (sven): Here, we actually do create a copy of the data (from a
        # list). The only way to avoid this entirely would be to keep a
        # single(!) np buffer per column across all currently ongoing
        # agents + episodes (which seems very hard to realize).
        view[view_col] = np.array([
            t.buffers[view_req.data_col][t.cursor + view_req.timesteps]
            for t in trajectories
        ])
    return view
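
Since the Trajectory class itself is not part of this diff, the following sketch exercises `get_trajectory_view` with minimal stand-ins that expose only the attributes the function touches (`buffers` and `cursor`); the dummy model and buffer contents are illustrative.

import numpy as np
from types import SimpleNamespace

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.trajectory_view import ViewRequirement, get_trajectory_view


class DummyModel:
    """Stand-in exposing only the hook get_trajectory_view() relies on."""

    def get_view_requirements(self, is_training=False):
        # data_col is set explicitly so the buffer lookup below resolves.
        return {
            SampleBatch.CUR_OBS: ViewRequirement(
                data_col=SampleBatch.CUR_OBS, timesteps=0),
        }


# Two ongoing "trajectories", each with a per-column buffer and a cursor
# pointing at the current timestep (mirroring t.buffers / t.cursor above).
trajectories = [
    SimpleNamespace(
        buffers={SampleBatch.CUR_OBS: np.arange(20.0).reshape(5, 4)},
        cursor=3),
    SimpleNamespace(
        buffers={SampleBatch.CUR_OBS: np.zeros((5, 4))},
        cursor=1),
]

view = get_trajectory_view(DummyModel(), trajectories, is_training=False)
print(view[SampleBatch.CUR_OBS].shape)  # -> (2, 4): one row per trajectory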

rllib/tests/test_supported_multi_agent.py
@@ -20,10 +20,9 @@ def check_support_multiagent(alg, config):
            config=config, env="multi_agent_mountaincar")
    else:
        a = get_agent_class(alg)(config=config, env="multi_agent_cartpole")
    try:
        print(a.train())
    finally:
        a.stop()
    print(a.train())
    a.stop()


class TestSupportedMultiAgent(unittest.TestCase):