[RLlib] Missing type annotations policy templates. (#9846)
commit 9b90f7db67
parent 38408574c4
5 changed files with 287 additions and 135 deletions
@@ -52,7 +52,7 @@ atari-ppo-torch:
    stop:
        time_total_s: 3600
    config:
        use_pytorch: true
        framework: torch
        lambda: 0.95
        kl_coeff: 0.5
        clip_rewards: True
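
The only substantive change in this tuned example is the switch from the deprecated use_pytorch flag to the framework key. For reference, a minimal sketch of the equivalent programmatic config (the trainer class and environment name here are illustrative, not part of this diff):

from ray.rllib.agents.ppo import PPOTrainer

# "framework": "torch" takes over the role of the older "use_pytorch": True.
trainer = PPOTrainer(env="CartPole-v0", config={"framework": "torch"})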
@@ -206,7 +206,8 @@ class TestModules(unittest.TestCase):

        # Get initial state and add a batch dimension.
        init_state = [np.expand_dims(s, 0) for s in init_state]
        seq_lens_init = torch.full(size=(B, ), fill_value=L)
        seq_lens_init = torch.full(
            size=(B, ), fill_value=L, dtype=torch.int32)

        # Torch implementation expects a formatted input_dict instead
        # of a numpy array as input.
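
The change to seq_lens_init adds an explicit dtype, presumably because the downstream RNN utilities expect an integer sequence-length tensor and the dtype inferred by torch.full can vary by torch version. A minimal sketch of the intent (the B and L values are illustrative):

import torch

B, L = 4, 20
seq_lens_init = torch.full(size=(B, ), fill_value=L, dtype=torch.int32)
assert seq_lens_init.dtype == torch.int32  # explicit, independent of inference rules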
@@ -74,7 +74,8 @@ class DynamicTFPolicy(TFPolicy):
                 existing_inputs: Optional[Dict[
                     str, "tf1.placeholder"]] = None,
                 existing_model: Optional[ModelV2] = None,
                 get_batch_divisibility_req: Optional[int] = None,
                 get_batch_divisibility_req: Optional[Callable[
                     [Policy], int]] = None,
                 obs_include_prev_action_reward: bool = True):
        """Initialize a dynamic TF policy.
@@ -1,34 +1,70 @@
import gym
from typing import Callable, Dict, List, Optional, Tuple

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.policy import eager_tf_policy
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.types import ModelGradients, TensorType, TrainerConfigDict


@DeveloperAPI
def build_tf_policy(name,
def build_tf_policy(name: str,
                    *,
                    loss_fn,
                    get_default_config=None,
                    postprocess_fn=None,
                    stats_fn=None,
                    optimizer_fn=None,
                    gradients_fn=None,
                    apply_gradients_fn=None,
                    grad_stats_fn=None,
                    extra_action_fetches_fn=None,
                    extra_learn_fetches_fn=None,
                    validate_spaces=None,
                    before_init=None,
                    before_loss_init=None,
                    after_init=None,
                    make_model=None,
                    action_sampler_fn=None,
                    action_distribution_fn=None,
                    mixins=None,
                    get_batch_divisibility_req=None,
                    obs_include_prev_action_reward=True):
                    loss_fn: Callable[
                        [Policy, ModelV2, type, SampleBatch], TensorType],
                    get_default_config: Optional[
                        Callable[[None], TrainerConfigDict]] = None,
                    postprocess_fn: Optional[Callable[
                        [Policy, SampleBatch, List[SampleBatch],
                         "MultiAgentEpisode"], None]] = None,
                    stats_fn: Optional[Callable[
                        [Policy, SampleBatch], Dict[str, TensorType]]] = None,
                    optimizer_fn: Optional[Callable[
                        [Policy, TrainerConfigDict],
                        "tf.keras.optimizers.Optimizer"]] = None,
                    gradients_fn: Optional[Callable[
                        [Policy, "tf.keras.optimizers.Optimizer",
                         TensorType], ModelGradients]] = None,
                    apply_gradients_fn: Optional[Callable[
                        [Policy, "tf.keras.optimizers.Optimizer",
                         ModelGradients], "tf.Operation"]] = None,
                    grad_stats_fn: Optional[Callable[
                        [Policy, SampleBatch, ModelGradients],
                        Dict[str, TensorType]]] = None,
                    extra_action_fetches_fn: Optional[Callable[
                        [Policy], Dict[str, TensorType]]] = None,
                    extra_learn_fetches_fn: Optional[Callable[
                        [Policy], Dict[str, TensorType]]] = None,
                    validate_spaces: Optional[Callable[
                        [Policy, gym.Space, gym.Space, TrainerConfigDict],
                        None]] = None,
                    before_init: Optional[Callable[
                        [Policy, gym.Space, gym.Space, TrainerConfigDict],
                        None]] = None,
                    before_loss_init: Optional[Callable[
                        [Policy, gym.spaces.Space, gym.spaces.Space,
                         TrainerConfigDict], None]] = None,
                    after_init: Optional[Callable[
                        [Policy, gym.Space, gym.Space, TrainerConfigDict],
                        None]] = None,
                    make_model: Optional[Callable[
                        [Policy, gym.spaces.Space, gym.spaces.Space,
                         TrainerConfigDict], ModelV2]] = None,
                    action_sampler_fn: Optional[Callable[
                        [TensorType, List[TensorType]], Tuple[
                            TensorType, TensorType]]] = None,
                    action_distribution_fn: Optional[Callable[
                        [Policy, ModelV2, TensorType, TensorType, TensorType],
                        Tuple[TensorType, type, List[TensorType]]]] = None,
                    mixins: Optional[List[type]] = None,
                    get_batch_divisibility_req: Optional[Callable[
                        [Policy], int]] = None,
                    obs_include_prev_action_reward: bool = True):
    """Helper function for creating a dynamic tf policy at runtime.

    Functions will be run in this order to initialize the policy:
@@ -48,55 +84,91 @@ def build_tf_policy(name,
    otherwise they will fail in eager mode execution. Variables should only
    be created in make_model (if defined).

    Arguments:
        name (str): name of the policy (e.g., "PPOTFPolicy")
        loss_fn (func): function that returns a loss tensor as arguments
            (policy, model, dist_class, train_batch)
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as Policy.postprocess_trajectory()
        stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and batch input tensors
        optimizer_fn (func): optional function that returns a tf.Optimizer
            given the policy and config
        gradients_fn (func): optional function that returns a list of gradients
            given (policy, optimizer, loss). If not specified, this
            defaults to optimizer.compute_gradients(loss)
        apply_gradients_fn (func): optional function that returns an apply
            gradients op given (policy, optimizer, grads_and_vars)
        grad_stats_fn (func): optional function that returns a dict of
            TF fetches given the policy, batch input, and gradient tensors
        extra_action_fetches_fn (func): optional function that returns
            a dict of TF fetches given the policy object
        extra_learn_fetches_fn (func): optional function that returns a dict of
            extra values to fetch and return when learning on a batch
        validate_spaces (Optional[callable]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check for
            correctness.
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        before_loss_init (func): optional function to run prior to loss
            init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_model (func): optional function that returns a ModelV2 object
            given (policy, obs_space, action_space, config).
            All policy variables should be created in this function. If not
            specified, a default model will be created.
        action_sampler_fn (Optional[callable]): A callable returning a sampled
            action and its log-likelihood given some (obs and state) inputs.
        action_distribution_fn (Optional[callable]): A callable returning
            distribution inputs (parameters), a dist-class to generate an
            action distribution object from, and internal-state outputs (or an
            empty list if not applicable).
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the DynamicTFPolicy class
        get_batch_divisibility_req (func): optional function that returns
            the divisibility requirement for sample batches
        obs_include_prev_action_reward (bool): whether to include the
            previous action and reward in the model input
    Args:
        name (str): Name of the policy (e.g., "PPOTFPolicy").
        loss_fn (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
            Callable for calculating a loss tensor.
        get_default_config (Optional[Callable[[None], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with any
            overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            List[SampleBatch], MultiAgentEpisode], None]]): Optional callable
            for post-processing experience batches (called after the
            super's `postprocess_trajectory` method).
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF tensors to fetch given the policy and batch input tensors. If
            None, will not compute any stats.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "tf.keras.optimizers.Optimizer"]]): Optional callable that returns
            a tf.Optimizer given the policy and config. If None, will call
            the base class' `optimizer()` method instead (which returns a
            tf1.train.AdamOptimizer).
        gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]):
            Optional callable that returns a list of gradients. If None,
            this defaults to optimizer.compute_gradients([loss]).
        apply_gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", ModelGradients],
            "tf.Operation"]]): Optional callable that returns an apply
            gradients op given policy, tf-optimizer, and grads_and_vars. If
            None, will call the base class' `build_apply_op()` method instead.
        grad_stats_fn (Optional[Callable[[Policy, SampleBatch, ModelGradients],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF fetches given the policy, batch input, and gradient tensors. If
            None, will not collect any gradient stats.
        extra_action_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns
            a dict of TF fetches given the policy object. If None, will not
            perform any extra fetches.
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra values to fetch and return when learning on a batch. If None,
            will call the base class' `extra_compute_grad_fetches()` method
            instead.
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check
            the spaces for correctness. If None, no spaces checking will be
            done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of policy init that takes the same arguments as the
            policy constructor. If None, this step will be skipped.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run prior to loss init. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the end of
            policy init. If None, this step will be skipped.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that returns a ModelV2 object.
            All policy variables should be created in this function. If None,
            a default ModelV2 object will be created.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): A callable returning a sampled
            action and its log-likelihood given observation and state inputs.
            If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]]]]): Optional callable
            returning distribution inputs (parameters), a dist-class to
            generate an action distribution object from, and internal-state
            outputs (or an empty list if not applicable). If None, will either
            use `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the so parameterized action distribution.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the DynamicTFPolicy class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.
        obs_include_prev_action_reward (bool): Whether to include the
            previous action and reward in the model input.

    Returns:
        a DynamicTFPolicy instance that uses the specified args
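
As a usage sketch of the newly annotated template (the policy name and loss below are illustrative only, and tensorflow is imported directly here for brevity, whereas RLlib itself resolves TF via try_import_tf()):

import tensorflow as tf

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy

def my_loss_fn(policy, model, dist_class, train_batch):
    # Matches Callable[[Policy, ModelV2, type, SampleBatch], TensorType].
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    return -tf.reduce_mean(action_dist.logp(train_batch[SampleBatch.ACTIONS]))

MyTFPolicy = build_tf_policy(name="MyTFPolicy", loss_fn=my_loss_fn)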
@@ -1,92 +1,170 @@
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.torch_policy import TorchPolicy
import gym
from typing import Callable, Dict, List, Optional, Tuple

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import convert_to_non_torch_type
from ray.rllib.utils.types import TensorType, TrainerConfigDict

torch, _ = try_import_torch()


@DeveloperAPI
def build_torch_policy(name,
def build_torch_policy(name: str,
                       *,
                       loss_fn,
                       get_default_config=None,
                       stats_fn=None,
                       postprocess_fn=None,
                       extra_action_out_fn=None,
                       extra_grad_process_fn=None,
                       extra_learn_fetches_fn=None,
                       optimizer_fn=None,
                       validate_spaces=None,
                       before_init=None,
                       after_init=None,
                       action_sampler_fn=None,
                       action_distribution_fn=None,
                       make_model=None,
                       make_model_and_action_dist=None,
                       apply_gradients_fn=None,
                       mixins=None,
                       get_batch_divisibility_req=None):
                       loss_fn: Callable[
                           [Policy, ModelV2, type, SampleBatch], TensorType],
                       get_default_config: Optional[Callable[
                           [], TrainerConfigDict]] = None,
                       stats_fn: Optional[Callable[
                           [Policy, SampleBatch],
                           Dict[str, TensorType]]] = None,
                       postprocess_fn: Optional[Callable[
                           [Policy, SampleBatch, List[SampleBatch],
                            "MultiAgentEpisode"], None]] = None,
                       extra_action_out_fn: Optional[Callable[
                           [Policy, Dict[str, TensorType], List[TensorType],
                            ModelV2, TorchDistributionWrapper],
                           Dict[str, TensorType]]] = None,
                       extra_grad_process_fn: Optional[Callable[
                           [Policy, "torch.optim.Optimizer", TensorType],
                           Dict[str, TensorType]]] = None,
                       # TODO: (sven) Replace "fetches" with "process".
                       extra_learn_fetches_fn: Optional[Callable[
                           [Policy], Dict[str, TensorType]]] = None,
                       optimizer_fn: Optional[Callable[
                           [Policy, TrainerConfigDict],
                           "torch.optim.Optimizer"]] = None,
                       validate_spaces: Optional[Callable[
                           [Policy, gym.Space, gym.Space, TrainerConfigDict],
                           None]] = None,
                       before_init: Optional[Callable[
                           [Policy, gym.Space, gym.Space, TrainerConfigDict],
                           None]] = None,
                       after_init: Optional[Callable[
                           [Policy, gym.Space, gym.Space, TrainerConfigDict],
                           None]] = None,
                       action_sampler_fn: Optional[Callable[
                           [TensorType, List[TensorType]], Tuple[
                               TensorType, TensorType]]] = None,
                       action_distribution_fn: Optional[Callable[
                           [Policy, ModelV2, TensorType, TensorType,
                            TensorType],
                           Tuple[TensorType, type, List[TensorType]]]] = None,
                       make_model: Optional[Callable[
                           [Policy, gym.spaces.Space, gym.spaces.Space,
                            TrainerConfigDict], ModelV2]] = None,
                       make_model_and_action_dist: Optional[Callable[
                           [Policy, gym.spaces.Space, gym.spaces.Space,
                            TrainerConfigDict],
                           Tuple[ModelV2, TorchDistributionWrapper]]] = None,
                       apply_gradients_fn: Optional[Callable[
                           [Policy, "torch.optim.Optimizer"], None]] = None,
                       mixins: Optional[List[type]] = None,
                       get_batch_divisibility_req: Optional[Callable[
                           [Policy], int]] = None
                       ):
    """Helper function for creating a torch policy class at runtime.

    Arguments:
    Args:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        loss_fn (callable): Callable that returns a loss tensor as arguments
            given (policy, model, dist_class, train_batch).
        get_default_config (Optional[callable]): Optional callable that returns
            the default config to merge with any overrides.
        stats_fn (Optional[callable]): Optional callable that returns a dict of
            values given the policy and batch input tensors.
        postprocess_fn (Optional[callable]): Optional experience postprocessing
            function that takes the same args as
            Policy.postprocess_trajectory().
        extra_action_out_fn (Optional[callable]): Optional callable that
            returns a dict of extra values to include in experiences.
        extra_grad_process_fn (Optional[callable]): Optional callable that is
            called after gradients are computed and returns processing info.
        extra_learn_fetches_fn (func): optional function that returns a dict of
            extra values to fetch from the policy after loss evaluation.
        optimizer_fn (Optional[callable]): Optional callable that returns a
            torch optimizer given the policy and config.
        validate_spaces (Optional[callable]): Optional callable that takes the
        loss_fn (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
            Callable that returns a loss tensor.
        get_default_config (Optional[Callable[[None], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with any
            overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            List[SampleBatch], MultiAgentEpisode], None]]): Optional callable
            for post-processing experience batches (called after the
            super's `postprocess_trajectory` method).
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            values given the policy and batch input tensors. If None,
            will use `TorchPolicy.extra_grad_info()` instead.
        extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType],
            List[TensorType], ModelV2, TorchDistributionWrapper], Dict[str,
            TensorType]]]): Optional callable that returns a dict of extra
            values to include in experiences. If None, no extra computations
            will be performed.
        extra_grad_process_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]):
            Optional callable that is called after gradients are computed and
            returns a processing info dict. If None, will call the
            `TorchPolicy.extra_grad_process()` method instead.
        # TODO: (sven) dissolve naming mismatch between "learn" and "compute.."
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra tensors from the policy after loss evaluation. If None,
            will call the `TorchPolicy.extra_compute_grad_fetches()` method
            instead.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "torch.optim.Optimizer"]]): Optional callable that returns a
            torch optimizer given the policy and config. If None, will call
            the `TorchPolicy.optimizer()` method instead (which returns a
            torch Adam optimizer).
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check for
            correctness.
        before_init (Optional[callable]): Optional callable to run at the
            correctness. If None, no spaces checking will be done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of `Policy.__init__` that takes the same arguments as
            the Policy constructor.
        after_init (Optional[callable]): Optional callable to run at the end of
            the Policy constructor. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the end of
            policy init that takes the same arguments as the policy
            constructor.
        action_sampler_fn (Optional[callable]): Optional callable returning a
            constructor. If None, this step will be skipped.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): Optional callable returning a
            sampled action and its log-likelihood given some (obs and state)
            inputs.
        action_distribution_fn (Optional[callable]): A callable that takes
            inputs. If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType], Tuple[TensorType, type,
            List[TensorType]]]]): A callable that takes
            the Policy, Model, the observation batch, an explore-flag, a
            timestep, and an is_training flag and returns a tuple of
            a) distribution inputs (parameters), b) a dist-class to generate
            an action distribution object from, and c) internal-state outputs
            (empty list if not applicable).
        make_model (Optional[callable]): Optional func that
            takes the same arguments as Policy.__init__ and returns a model
            instance. The distribution class will be determined automatically.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided.
        make_model_and_action_dist (Optional[callable]): Optional func that
            (empty list if not applicable). If None, will either use
            `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the parameterized action distribution.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that takes the same arguments as Policy.__init__ and returns a
            model instance. The distribution class will be determined
            automatically. Note: Only one of `make_model` or
            `make_model_and_action_dist` should be provided. If both are None,
            a default Model will be created.
        make_model_and_action_dist (Optional[Callable[[Policy,
            gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            Tuple[ModelV2, TorchDistributionWrapper]]]): Optional callable that
            takes the same arguments as Policy.__init__ and returns a tuple
            of model instance and torch action distribution class.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided.
        apply_gradients_fn (Optional[callable]): Optional callable that
            should be provided. If both are None, a default Model will be
            created.
        apply_gradients_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer"], None]]): Optional callable that
            takes a grads list and applies these to the Model's parameters.
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the TorchPolicy class.
        get_batch_divisibility_req (Optional[callable]): Optional callable that
            returns the divisibility requirement for sample batches.
            If None, will call the `TorchPolicy.apply_gradients()` method
            instead.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the TorchPolicy class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.

    Returns:
        type: TorchPolicy child class constructed from the specified args.
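
A corresponding usage sketch for the torch template (the policy name and loss are illustrative; torch is resolved the same way the module itself does, via try_import_torch()):

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch

torch, _ = try_import_torch()

def my_loss_fn(policy, model, dist_class, train_batch):
    # Matches Callable[[Policy, ModelV2, type, SampleBatch], TensorType].
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    return -torch.mean(action_dist.logp(train_batch[SampleBatch.ACTIONS]))

MyTorchPolicy = build_torch_policy(name="MyTorchPolicy", loss_fn=my_loss_fn)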