
Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

380 lines
17 KiB
Raw Normal View History

from typing import Callable, Dict, List, Optional, Tuple, Type, Union, TYPE_CHECKING

import gym

from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.policy import eager_tf_policy
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import (
    AgentID,
    ModelGradients,
    TensorType,
    TrainerConfigDict,
)

if TYPE_CHECKING:
    # Only needed for (string) type annotations; guarded to avoid an
    # import at runtime.
    from ray.rllib.evaluation import Episode
# Soft-import TensorFlow through RLlib's helper. The three names are presumably
# the tf1.x-compat module, the tf module, and the TF major version (possibly
# None when TensorFlow is not installed) -- TODO confirm against try_import_tf.
# In this file `tf` is referenced only inside string type annotations.
tf1, tf, tfv = try_import_tf()
@DeveloperAPI
def build_tf_policy(
    name: str,
    loss_fn: Callable[
        [Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
        Union[TensorType, List[TensorType]],
    ],
    get_default_config: Optional[Callable[[None], TrainerConfigDict]] = None,
    postprocess_fn: Optional[
        Callable[
            [
                Policy,
                SampleBatch,
                Optional[Dict[AgentID, SampleBatch]],
                Optional["Episode"],
            ],
            SampleBatch,
        ]
    ] = None,
    stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None,
    optimizer_fn: Optional[
        Callable[[Policy, TrainerConfigDict], "tf.keras.optimizers.Optimizer"]
    ] = None,
    compute_gradients_fn: Optional[
        Callable[[Policy, "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]
    ] = None,
    apply_gradients_fn: Optional[
        Callable[
            [Policy, "tf.keras.optimizers.Optimizer", ModelGradients], "tf.Operation"
        ]
    ] = None,
    grad_stats_fn: Optional[
        Callable[[Policy, SampleBatch, ModelGradients], Dict[str, TensorType]]
    ] = None,
    extra_action_out_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
    extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
    validate_spaces: Optional[
        Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]
    ] = None,
    before_init: Optional[
        Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]
    ] = None,
    before_loss_init: Optional[
        Callable[[Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], None]
    ] = None,
    after_init: Optional[
        Callable[[Policy, gym.Space, gym.Space, TrainerConfigDict], None]
    ] = None,
    make_model: Optional[
        Callable[
            [Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict], ModelV2
        ]
    ] = None,
    action_sampler_fn: Optional[
        Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]
    ] = None,
    action_distribution_fn: Optional[
        Callable[
            [Policy, ModelV2, TensorType, TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]],
        ]
    ] = None,
    mixins: Optional[List[type]] = None,
    get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
    # Deprecated args.
    extra_action_fetches_fn=None,  # Use `extra_action_out_fn`.
    gradients_fn=None,  # Use `compute_gradients_fn`.
    obs_include_prev_action_reward=DEPRECATED_VALUE,  # Ignored; warns only.
) -> Type[DynamicTFPolicy]:
    """Helper function for creating a dynamic tf policy at runtime.

    Functions will be run in this order to initialize the policy:
        1. Placeholder setup: postprocess_fn
        2. Loss init: loss_fn, stats_fn
        3. Optimizer init: optimizer_fn, compute_gradients_fn,
           apply_gradients_fn, grad_stats_fn

    This means that you can e.g., depend on any policy attributes created in
    the running of `loss_fn` in later functions such as `stats_fn`.

    In eager mode, the following functions will be run repeatedly on each
    eager execution: loss_fn, stats_fn, compute_gradients_fn,
    apply_gradients_fn, and grad_stats_fn.

    This means that these functions should not define any variables internally,
    otherwise they will fail in eager mode execution. Variables should only
    be created in make_model (if defined).

    Args:
        name (str): Name of the policy (e.g., "PPOTFPolicy"). Also used as
            the returned class' `__name__` and `__qualname__`.
        loss_fn (Callable[[
            Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
            Union[TensorType, List[TensorType]]]): Callable for calculating a
            loss tensor (or a list of loss tensors).
        get_default_config (Optional[Callable[[None], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with any
            overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            Optional[Dict[AgentID, SampleBatch]], Optional[Episode]],
            SampleBatch]]): Optional callable for post-processing experience
            batches (called after the parent class' `postprocess_trajectory`
            method). Its return value is used as the post-processed batch.
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF tensors to fetch given the policy and batch input tensors. If
            None, will not compute any stats.
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "tf.keras.optimizers.Optimizer"]]): Optional callable that returns
            a tf.Optimizer (or a list of them) given the policy and config.
            If None, will call the base class' `optimizer()` method instead.
        compute_gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]):
            Optional callable that returns a list of gradients. If None,
            this defaults to optimizer.compute_gradients([loss]).
        apply_gradients_fn (Optional[Callable[[Policy,
            "tf.keras.optimizers.Optimizer", ModelGradients],
            "tf.Operation"]]): Optional callable that returns an apply
            gradients op given policy, tf-optimizer, and grads_and_vars. If
            None, will call the base class' `build_apply_op()` method instead.
        grad_stats_fn (Optional[Callable[[Policy, SampleBatch, ModelGradients],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            TF fetches given the policy, batch input, and gradient tensors. If
            None, will not collect any gradient stats.
        extra_action_out_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns
            a dict of TF fetches given the policy object. If None, will not
            perform any extra fetches.
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra values to fetch and return when learning on a batch. An
            empty `LEARNER_STATS_KEY` entry is added if the returned dict has
            none. If None, will call the base class'
            `extra_compute_grad_fetches()` method instead.
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check
            the spaces for correctness. If None, no spaces checking will be
            performed.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of policy init that takes the same arguments as the
            policy constructor. If None, this step will be skipped.
        before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], None]]): Optional callable to
            run prior to loss init. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the end of
            policy init. If None, this step will be skipped.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional callable
            that returns a ModelV2 object.
            All policy variables should be created in this function. If None,
            a default ModelV2 object will be created.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): A callable returning a sampled
            action and its log-likelihood given observation and state inputs.
            If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
            TensorType, TensorType],
            Tuple[TensorType, type, List[TensorType]]]]): Optional callable
            returning distribution inputs (parameters), a dist-class to
            generate an action distribution object from, and internal-state
            outputs (or an empty list if not applicable). If None, will either
            use `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the so parameterized action distribution.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the DynamicTFPolicy class.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.
        extra_action_fetches_fn: Deprecated alias of `extra_action_out_fn`.
        gradients_fn: Deprecated alias of `compute_gradients_fn`.
        obs_include_prev_action_reward: Deprecated and ignored; passing any
            non-default value only emits a deprecation warning.

    Returns:
        Type[DynamicTFPolicy]: A child class of DynamicTFPolicy based on the
            specified args. The class also exposes the static helpers
            `with_updates(**overrides)` and `as_eager()`.
    """
    # Snapshot the call arguments (before any deprecated-alias rewriting
    # below) so `with_updates` / `as_eager` can rebuild from the same settings.
    original_kwargs = locals().copy()
    base = add_mixins(DynamicTFPolicy, mixins)

    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(old="obs_include_prev_action_reward", error=False)

    # Map deprecated aliases onto their replacements. The nested class below
    # closes over these (rebound) names.
    if extra_action_fetches_fn is not None:
        deprecation_warning(
            old="extra_action_fetches_fn", new="extra_action_out_fn", error=False
        )
        extra_action_out_fn = extra_action_fetches_fn

    if gradients_fn is not None:
        deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=False)
        compute_gradients_fn = gradients_fn

    class policy_cls(base):
        def __init__(
            self,
            obs_space,
            action_space,
            config,
            existing_model=None,
            existing_inputs=None,
        ):
            # User-provided config values take precedence over the defaults.
            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space, config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)

                # Tower copies (`_is_tower`) get no user extra action outputs.
                if extra_action_out_fn is None or policy._is_tower:
                    extra_action_fetches = {}
                else:
                    extra_action_fetches = extra_action_out_fn(policy)

                # Merge into any fetches that already exist on the policy
                # (e.g. set by a mixin); otherwise create the attribute so
                # `extra_compute_action_fetches` can always read it.
                if hasattr(policy, "_extra_action_fetches"):
                    policy._extra_action_fetches.update(extra_action_fetches)
                else:
                    policy._extra_action_fetches = extra_action_fetches

            DynamicTFPolicy.__init__(
                self,
                obs_space=obs_space,
                action_space=action_space,
                config=config,
                loss_fn=loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                make_model=make_model,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if after_init:
                after_init(self, obs_space, action_space, config)

            # Got to reset global_timestep again after this fake run-through.
            self.global_timestep = 0

        @override(Policy)
        def postprocess_trajectory(
            self, sample_batch, other_agent_batches=None, episode=None
        ):
            # Call super's postprocess_trajectory first, then the user hook
            # (whose return value becomes the post-processed batch).
            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches, episode)
            return sample_batch

        @override(TFPolicy)
        def optimizer(self):
            if optimizer_fn:
                optimizers = optimizer_fn(self, self.config)
            else:
                optimizers = base.optimizer(self)
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(optimizers)

            # No optimizers produced -> Return None.
            if not optimizers:
                return None
            # New API: Allow more than one optimizer to be returned.
            # -> Return list.
            elif self.config["_tf_policy_handles_more_than_one_loss"]:
                return optimizers
            # Old API: Return a single LocalOptimizer.
            else:
                return optimizers[0]

        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            optimizers = force_list(optimizer)
            losses = force_list(loss)

            if compute_gradients_fn:
                # New API: Allow more than one optimizer -> Return a list of
                # lists of gradients.
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    return compute_gradients_fn(self, optimizers, losses)
                # Old API: Return a single List of gradients.
                else:
                    return compute_gradients_fn(self, optimizers[0], losses[0])
            else:
                return base.gradients(self, optimizers, losses)

        @override(TFPolicy)
        def build_apply_op(self, optimizer, grads_and_vars):
            if apply_gradients_fn:
                return apply_gradients_fn(self, optimizer, grads_and_vars)
            else:
                return base.build_apply_op(self, optimizer, grads_and_vars)

        @override(TFPolicy)
        def extra_compute_action_fetches(self):
            # Base fetches plus those collected in `before_loss_init_wrapper`.
            return dict(
                base.extra_compute_action_fetches(self), **self._extra_action_fetches
            )

        @override(TFPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                # TODO: (sven) in torch, extra_learn_fetches do not exist.
                #  Hence, things like td_error are returned by the stats_fn
                #  and end up under the LEARNER_STATS_KEY. We should
                #  change tf to do this as well. However, this will conflict
                #  with the handling of LEARNER_STATS_KEY inside the multi-GPU
                #  train op.
                # Auto-add empty learner stats dict if needed (a
                # LEARNER_STATS_KEY returned by the fn overrides it).
                return dict({LEARNER_STATS_KEY: {}}, **extra_learn_fetches_fn(self))
            else:
                return base.extra_compute_grad_fetches(self)

    def with_updates(**overrides):
        """Allows creating a TFPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_tf_policy`) that
                should be different from the class that this method is called
                on.

        Returns:
            type: A new TFPolicy sub-class.

        Examples:
            >> MySpecialDQNPolicyClass = DQNTFPolicy.with_updates(
            ..    name="MySpecialDQNPolicyClass",
            ..    loss_fn=some_new_loss_function,
            .. )
        """
        return build_tf_policy(**dict(original_kwargs, **overrides))

    def as_eager():
        """Builds an eager-mode policy class from the same build arguments."""
        return eager_tf_policy._build_eager_tf_policy(**original_kwargs)

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.as_eager = staticmethod(as_eager)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls