
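# (Captured from a web source view: 337 lines, 17 KiB.)
# NOTE(review): this file appears to come from a "Raw / Normal View / History"
# blame page; verify it against the upstream torch_policy_template.py.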

import gym
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import convert_to_non_torch_type
from ray.rllib.utils.typing import TensorType, TrainerConfigDict
torch, _ = try_import_torch()
@DeveloperAPI
def build_torch_policy(
        name: str,
        loss_fn: Callable[[
            Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
        ], Union[TensorType, List[TensorType]]],
        get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
        stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
            str, TensorType]]] = None,
        postprocess_fn: Optional[Callable[[
            Policy, SampleBatch, List[SampleBatch], "MultiAgentEpisode"
        ], SampleBatch]] = None,
        extra_action_out_fn: Optional[Callable[[
            Policy, Dict[str, TensorType], List[TensorType], ModelV2,
            TorchDistributionWrapper
        ], Dict[str, TensorType]]] = None,
        extra_grad_process_fn: Optional[Callable[[
            Policy, "torch.optim.Optimizer", TensorType
        ], Dict[str, TensorType]]] = None,
        # TODO: (sven) Replace "fetches" with "process".
        extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[
            str, TensorType]]] = None,
        optimizer_fn: Optional[Callable[[Policy, TrainerConfigDict],
                                        "torch.optim.Optimizer"]] = None,
        validate_spaces: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        before_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        after_init: Optional[Callable[
            [Policy, gym.Space, gym.Space, TrainerConfigDict], None]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[
            TensorType]], Tuple[TensorType, TensorType]]] = None,
        action_distribution_fn: Optional[Callable[[
            Policy, ModelV2, TensorType, TensorType, TensorType
        ], Tuple[TensorType, type, List[TensorType]]]] = None,
        make_model: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], ModelV2]] = None,
        make_model_and_action_dist: Optional[Callable[[
            Policy, gym.spaces.Space, gym.spaces.Space, TrainerConfigDict
        ], Tuple[ModelV2, TorchDistributionWrapper]]] = None,
        apply_gradients_fn: Optional[Callable[
            [Policy, "torch.optim.Optimizer"], None]] = None,
        mixins: Optional[List[type]] = None,
        training_view_requirements_fn: Optional[Callable[[Policy], Dict[
            str, ViewRequirement]]] = None,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None):
    """Helper function for creating a torch policy class at runtime.

    Args:
        name (str): name of the policy (e.g., "PPOTorchPolicy")
        loss_fn (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
            Callable that returns a loss tensor.
        get_default_config (Optional[Callable[[None], TrainerConfigDict]]):
            Optional callable that returns the default config to merge with
            any overrides. If None, uses only(!) the user-provided
            PartialTrainerConfigDict as dict for this Policy.
        stats_fn (Optional[Callable[[Policy, SampleBatch],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            values given the policy and training batch. If None,
            will use `TorchPolicy.extra_grad_info()` instead. The stats dict
            is used for logging (e.g. in TensorBoard).
        postprocess_fn (Optional[Callable[[Policy, SampleBatch,
            List[SampleBatch], MultiAgentEpisode], SampleBatch]]): Optional
            callable for post-processing experience batches (called after
            the super's `postprocess_trajectory` method). Its return value
            is used as the post-processed batch.
        extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType],
            List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str,
            TensorType]]]): Optional callable that returns a dict of extra
            values to include in experiences. If None, no extra computations
            will be performed.
        extra_grad_process_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]):
            Optional callable that is called after gradients are computed and
            returns a processing info dict. If None, will call the
            `TorchPolicy.extra_grad_process()` method instead.
        extra_learn_fetches_fn (Optional[Callable[[Policy],
            Dict[str, TensorType]]]): Optional callable that returns a dict of
            extra tensors from the policy after loss evaluation. If None,
            will call the `TorchPolicy.extra_compute_grad_fetches()` method
            instead. (TODO: (sven) dissolve naming mismatch between "learn"
            and "compute..").
        optimizer_fn (Optional[Callable[[Policy, TrainerConfigDict],
            "torch.optim.Optimizer"]]): Optional callable that returns a
            torch optimizer given the policy and config. If None, will call
            the `TorchPolicy.optimizer()` method instead (which returns a
            torch Adam optimizer).
        validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable that takes the
            Policy, observation_space, action_space, and config to check for
            correctness. If None, no spaces checking will be done.
        before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the
            beginning of `Policy.__init__` that takes the same arguments as
            the Policy constructor. If None, this step will be skipped.
        after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
            TrainerConfigDict], None]]): Optional callable to run at the end
            of policy init that takes the same arguments as the policy
            constructor. If None, this step will be skipped.
        action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
            Tuple[TensorType, TensorType]]]): Optional callable returning a
            sampled action and its log-likelihood given some (obs and state)
            inputs. If None, will either use `action_distribution_fn` or
            compute actions by calling self.model, then sampling from the
            so parameterized action distribution.
        action_distribution_fn (Optional[Callable[[Policy, ModelV2,
            TensorType, TensorType, TensorType], Tuple[TensorType, type,
            List[TensorType]]]]): A callable that takes
            the Policy, Model, the observation batch, an explore-flag, a
            timestep, and an is_training flag and returns a tuple of
            a) distribution inputs (parameters), b) a dist-class to generate
            an action distribution object from, and c) internal-state outputs
            (empty list if not applicable). If None, will either use
            `action_sampler_fn` or compute actions by calling self.model,
            then sampling from the parameterized action distribution.
        make_model (Optional[Callable[[Policy, gym.spaces.Space,
            gym.spaces.Space, TrainerConfigDict], ModelV2]]): Optional
            callable that takes the same arguments as Policy.__init__ and
            returns a model instance. The distribution class will be
            determined automatically. Note: Only one of `make_model` or
            `make_model_and_action_dist` should be provided. If both are
            None, a default Model will be created.
        make_model_and_action_dist (Optional[Callable[[Policy,
            gym.spaces.Space, gym.spaces.Space, TrainerConfigDict],
            Tuple[ModelV2, TorchDistributionWrapper]]]): Optional callable
            that takes the same arguments as Policy.__init__ and returns a
            tuple of model instance and torch action distribution class.
            Note: Only one of `make_model` or `make_model_and_action_dist`
            should be provided. If both are None, a default Model will be
            created.
        apply_gradients_fn (Optional[Callable[[Policy,
            "torch.optim.Optimizer"], None]]): Optional callable that
            takes a grads list and applies these to the Model's parameters.
            If None, will call the `TorchPolicy.apply_gradients()` method
            instead.
        mixins (Optional[List[type]]): Optional list of any class mixins for
            the returned policy class. These mixins will be applied in order
            and will have higher precedence than the TorchPolicy class.
        training_view_requirements_fn (Optional[Callable[[Policy],
            Dict[str, ViewRequirement]]]): An optional callable that takes
            the Policy and returns additional train view requirements for
            it.
        get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
            Optional callable that returns the divisibility requirement for
            sample batches. If None, will assume a value of 1.

    Returns:
        type: TorchPolicy child class constructed from the specified args.
    """

    # Snapshot of all build args, so `with_updates` can rebuild a variant.
    original_kwargs = locals().copy()
    base = add_mixins(TorchPolicy, mixins)

    class policy_cls(base):
        def __init__(self, obs_space, action_space, config):
            if get_default_config:
                # User-provided settings override the defaults.
                config = dict(get_default_config(), **config)
            self.config = config

            if validate_spaces:
                validate_spaces(self, obs_space, action_space, self.config)

            if before_init:
                before_init(self, obs_space, action_space, self.config)

            # Model is customized (use default action dist class).
            if make_model:
                assert make_model_and_action_dist is None, \
                    "Either `make_model` or `make_model_and_action_dist`" \
                    " must be None!"
                self.model = make_model(self, obs_space, action_space, config)
                dist_class, _ = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
            # Model and action dist class are customized.
            elif make_model_and_action_dist:
                self.model, dist_class = make_model_and_action_dist(
                    self, obs_space, action_space, config)
            # Use default model and default action dist.
            else:
                dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"], framework="torch")
                # NOTE(review): argument list was dropped from the captured
                # view; reconstructed from ModelCatalog.get_model_v2's
                # signature — confirm.
                self.model = ModelCatalog.get_model_v2(
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=logit_dim,
                    model_config=self.config["model"],
                    framework="torch")

            # Make sure, we passed in a correct Model factory.
            assert isinstance(self.model, TorchModelV2), \
                "ERROR: Generated Model must be a TorchModelV2 object!"

            # NOTE(review): this base-class initialization was missing from
            # the captured view; keyword names follow TorchPolicy.__init__ —
            # confirm against torch_policy.py.
            TorchPolicy.__init__(
                self,
                observation_space=obs_space,
                action_space=action_space,
                config=config,
                model=self.model,
                loss=loss_fn,
                action_distribution_class=dist_class,
                action_sampler_fn=action_sampler_fn,
                action_distribution_fn=action_distribution_fn,
                max_seq_len=config["model"]["max_seq_len"],
                get_batch_divisibility_req=get_batch_divisibility_req,
            )

            if callable(training_view_requirements_fn):
                # NOTE(review): body reconstructed; assumes the base Policy
                # exposes a `training_view_requirements` dict — confirm.
                self.training_view_requirements.update(
                    training_view_requirements_fn(self))

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak (issue #6962).
            with torch.no_grad():
                # Call super's postprocess_trajectory first.
                sample_batch = super().postprocess_trajectory(
                    convert_to_non_torch_type(sample_batch),
                    convert_to_non_torch_type(other_agent_batches), episode)
                if postprocess_fn:
                    return postprocess_fn(self, sample_batch,
                                          other_agent_batches, episode)
                return sample_batch

        @override(TorchPolicy)
        def extra_grad_process(self, optimizer, loss):
            """Called after optimizer.zero_grad() and loss.backward() calls.

            Allows for gradient processing before optimizer.step() is called.
            E.g. for gradient clipping.
            """
            if extra_grad_process_fn:
                return extra_grad_process_fn(self, optimizer, loss)
            return TorchPolicy.extra_grad_process(self, optimizer, loss)

        @override(TorchPolicy)
        def extra_compute_grad_fetches(self):
            if extra_learn_fetches_fn:
                fetches = convert_to_non_torch_type(
                    extra_learn_fetches_fn(self))
                # Auto-add empty learner stats dict if needed.
                return dict({LEARNER_STATS_KEY: {}}, **fetches)
            return TorchPolicy.extra_compute_grad_fetches(self)

        @override(TorchPolicy)
        def apply_gradients(self, gradients):
            if apply_gradients_fn:
                apply_gradients_fn(self, gradients)
            else:
                TorchPolicy.apply_gradients(self, gradients)

        @override(TorchPolicy)
        def extra_action_out(self, input_dict, state_batches, model,
                             action_dist):
            with torch.no_grad():
                if extra_action_out_fn:
                    stats_dict = extra_action_out_fn(
                        self, input_dict, state_batches, model, action_dist)
                else:
                    stats_dict = TorchPolicy.extra_action_out(
                        self, input_dict, state_batches, model, action_dist)
                return convert_to_non_torch_type(stats_dict)

        @override(TorchPolicy)
        def optimizer(self):
            if optimizer_fn:
                optimizers = optimizer_fn(self, self.config)
            else:
                optimizers = TorchPolicy.optimizer(self)
            # Normalize to a list so exploration can extend it.
            optimizers = force_list(optimizers)
            if hasattr(self, "exploration"):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)
            return optimizers

        @override(TorchPolicy)
        def extra_grad_info(self, train_batch):
            with torch.no_grad():
                if stats_fn:
                    stats_dict = stats_fn(self, train_batch)
                else:
                    stats_dict = TorchPolicy.extra_grad_info(
                        self, train_batch)
                return convert_to_non_torch_type(stats_dict)

    def with_updates(**overrides):
        """Allows creating a TorchPolicy cls based on settings of another one.

        Keyword Args:
            **overrides: The settings (passed into `build_torch_policy`) that
                should be different from the class that this method is called
                on.

        Returns:
            type: A new TorchPolicy sub-class.

        Examples:
        >> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates(
        ..    name="MySpecialDQNPolicyClass",
        ..    loss_fn=some_new_loss_function,
        .. )
        """
        return build_torch_policy(**dict(original_kwargs, **overrides))

    policy_cls.with_updates = staticmethod(with_updates)
    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls