ray/rllib/agents/trainer_template.py
Sven Mika 428516056a
[RLlib] SAC Torch (incl. Atari learning) (#7984)
* Policy-classes cleanup and torch/tf unification.
- Make Policy abstract.
- Add `action_dist` to call to `extra_action_out_fn` (necessary for PPO torch).
- Move some methods and vars to base Policy
  (from TFPolicy): num_state_tensors, ACTION_PROB, ACTION_LOGP and some more.

* Fix `clip_action` import from Policy (should probably be moved into utils altogether).

* - Move `is_recurrent()` and `num_state_tensors()` into TFPolicy (from DynamicTFPolicy).
- Add config to all Policy c'tor calls (as 3rd arg after obs and action spaces).

* Add `config` to c'tor call to TFPolicy.

* Add missing `config` to c'tor call to TFPolicy in marvil_policy.py.

* Fix test_rollout_worker.py::MockPolicy and BadPolicy classes (Policy base class is now abstract).

* Fix LINT errors in Policy classes.

* Implement StatefulPolicy abstract methods in test cases: test_multi_agent_env.py.

* policy.py LINT errors.

* Create a simple TestPolicy to sub-class from when testing Policies (reduces code in some test cases).

* policy.py
- Remove abstractmethod from `apply_gradients` and `compute_gradients` (these are not required if `learn_on_batch` is implemented).
- Fix docstring of `num_state_tensors`.

* Make QMIX torch Policy a child of TorchPolicy (instead of Policy).

* QMixPolicy add empty implementations of abstract Policy methods.

* Store Policy's config in self.config in base Policy c'tor.

* - Make only `compute_actions` an abstractmethod in the base Policy and provide
pass implementations for all other methods if not defined.
- Fix state_batches=None (most Policies don't have internal states).

* Cartpole tf learning.

* Cartpole tf AND torch learning (in ~ same ts).

* Cartpole tf AND torch learning (in ~ same ts). 2

* Cartpole tf (torch syntax-broken) learning (in ~ same ts). 3

* Cartpole tf AND torch learning (in ~ same ts). 4

* Cartpole tf AND torch learning (in ~ same ts). 5

* Cartpole tf AND torch learning (in ~ same ts). 6

* Cartpole tf AND torch learning (in ~ same ts). Pendulum tf learning.

* WIP.

* WIP.

* SAC torch learning Pendulum.

* WIP.

* SAC torch and tf learning Pendulum and Cartpole after cleanup.

* WIP.

* LINT.

* LINT.

* SAC: Move policy.target_model to policy.device as well.

* Fixes and cleanup.

* Fix data-format of tf keras Conv2d layers (broken for some tf-versions which have data_format="channels_first" as default).

* Fixes and LINT.

* Fixes and LINT.

* Fix and LINT.

* WIP.

* Test fixes and LINT.

* Fixes and LINT.

Co-authored-by: Sven Mika <sven@Svens-MacBook-Pro.local>
2020-04-15 13:25:16 +02:00

import logging
import os
import time

from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override, DeveloperAPI

logger = logging.getLogger(__name__)


@DeveloperAPI
def build_trainer(name,
                  default_policy,
                  default_config=None,
                  validate_config=None,
                  get_initial_state=None,
                  get_policy_class=None,
                  before_init=None,
                  make_workers=None,
                  make_policy_optimizer=None,
                  after_init=None,
                  before_train_step=None,
                  after_optimizer_step=None,
                  after_train_result=None,
                  collect_metrics_fn=None,
                  before_evaluate_fn=None,
                  mixins=None,
                  execution_plan=None):
"""Helper function for defining a custom trainer.
Functions will be run in this order to initialize the trainer:
1. Config setup: validate_config, get_initial_state, get_policy
2. Worker setup: before_init, make_workers, make_policy_optimizer
3. Post setup: after_init
Arguments:
name (str): name of the trainer (e.g., "PPO")
default_policy (cls): the default Policy class to use
default_config (dict): The default config dict of the algorithm,
otherwise uses the Trainer default config.
validate_config (func): optional callback that checks a given config
for correctness. It may mutate the config as needed.
get_initial_state (func): optional function that returns the initial
state dict given the trainer instance as an argument. The state
dict must be serializable so that it can be checkpointed, and will
be available as the `trainer.state` variable.
get_policy_class (func): optional callback that takes a config and
returns the policy class to override the default with
before_init (func): optional function to run at the start of trainer
init that takes the trainer instance as argument
make_workers (func): override the method that creates rollout workers.
This takes in (trainer, env_creator, policy, config) as args.
make_policy_optimizer (func): optional function that returns a
PolicyOptimizer instance given (WorkerSet, config)
after_init (func): optional function to run at the end of trainer init
that takes the trainer instance as argument
before_train_step (func): optional callback to run before each train()
call. It takes the trainer instance as an argument.
after_optimizer_step (func): optional callback to run after each
step() call to the policy optimizer. It takes the trainer instance
and the policy gradient fetches as arguments.
after_train_result (func): optional callback to run at the end of each
train() call. It takes the trainer instance and result dict as
arguments, and may mutate the result dict as needed.
collect_metrics_fn (func): override the method used to collect metrics.
It takes the trainer instance as argumnt.
before_evaluate_fn (func): callback to run before evaluation. This
takes the trainer instance as argument.
mixins (list): list of any class mixins for the returned trainer class.
These mixins will be applied in order and will have higher
precedence than the Trainer class
execution_plan (func): Experimental distributed execution
API. This overrides `make_policy_optimizer`.
Returns:
a Trainer instance that uses the specified args.
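
    Examples:
        A minimal usage sketch. `MyTFPolicy` and `my_validate_config` below
        are illustrative placeholders, not objects shipped with RLlib:

        >>> def my_validate_config(config):
        ...     # Hypothetical sanity check on the incoming config dict.
        ...     if config["train_batch_size"] <= 0:
        ...         raise ValueError("train_batch_size must be positive")
        >>>
        >>> MyTrainer = build_trainer(
        ...     name="MyAlgo",
        ...     default_policy=MyTFPolicy,
        ...     validate_config=my_validate_config)
        >>> trainer = MyTrainer(config={"num_workers": 2}, env="CartPole-v0")
        >>> result = trainer.train()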
"""
original_kwargs = locals().copy()
base = add_mixins(Trainer, mixins)
class trainer_cls(base):
_name = name
_default_config = default_config or COMMON_CONFIG
_policy = default_policy
def __init__(self, config=None, env=None, logger_creator=None):
Trainer.__init__(self, config, env, logger_creator)
def _init(self, config, env_creator):
if validate_config:
validate_config(config)
if get_initial_state:
self.state = get_initial_state(self)
else:
self.state = {}
if get_policy_class is None:
self._policy = default_policy
else:
self._policy = get_policy_class(config)
if before_init:
before_init(self)
use_exec_api = (execution_plan
and (self.config["use_exec_api"]
or "RLLIB_EXEC_API" in os.environ))
# Creating all workers (excluding evaluation workers).
if make_workers and not use_exec_api:
self.workers = make_workers(self, env_creator, self._policy,
config)
else:
self.workers = self._make_workers(env_creator, self._policy,
config,
self.config["num_workers"])
self.train_exec_impl = None
self.optimizer = None
self.execution_plan = execution_plan
if use_exec_api:
logger.warning(
"The experimental distributed execution API is enabled "
"for this algorithm. Disable this by setting "
"'use_exec_api': False.")
self.train_exec_impl = execution_plan(self.workers, config)
elif make_policy_optimizer:
self.optimizer = make_policy_optimizer(self.workers, config)
else:
optimizer_config = dict(
config["optimizer"],
**{"train_batch_size": config["train_batch_size"]})
self.optimizer = SyncSamplesOptimizer(self.workers,
**optimizer_config)
if after_init:
after_init(self)

        @override(Trainer)
        def _train(self):
            if self.train_exec_impl:
                return self._train_exec_impl()

            if before_train_step:
                before_train_step(self)
            prev_steps = self.optimizer.num_steps_sampled

            start = time.time()
            optimizer_steps_this_iter = 0
            while True:
                fetches = self.optimizer.step()
                optimizer_steps_this_iter += 1
                if after_optimizer_step:
                    after_optimizer_step(self, fetches)
                if (time.time() - start >= self.config["min_iter_time_s"]
                        and self.optimizer.num_steps_sampled - prev_steps >=
                        self.config["timesteps_per_iteration"]):
                    break

            if collect_metrics_fn:
                res = collect_metrics_fn(self)
            else:
                res = self.collect_metrics()
            res.update(
                optimizer_steps_this_iter=optimizer_steps_this_iter,
                timesteps_this_iter=self.optimizer.num_steps_sampled -
                prev_steps,
                info=res.get("info", {}))

            if after_train_result:
                after_train_result(self, res)
            return res

        def _train_exec_impl(self):
            if before_train_step:
                logger.warning("Ignoring before_train_step callback")

            res = next(self.train_exec_impl)

            if after_train_result:
                logger.warning("Ignoring after_train_result callback")
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            if self.train_exec_impl:
                state["train_exec_impl"] = (
                    self.train_exec_impl.shared_metrics.get().save())
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.state = state["trainer_state"].copy()
            if self.train_exec_impl:
                self.train_exec_impl.shared_metrics.get().restore(
                    state["train_exec_impl"])

    def with_updates(**overrides):
        """Build a copy of this trainer with the specified overrides.

        Arguments:
            overrides (dict): use this to override any of the arguments
                originally passed to build_trainer() for this trainer.
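
        Examples:
            A hedged sketch; `MyTrainer` stands for a trainer class previously
            returned by build_trainer() and `MyOtherPolicy` for some other
            Policy class (both names are placeholders):

            >>> MyOtherTrainer = MyTrainer.with_updates(
            ...     name="MyOtherAlgo",
            ...     default_policy=MyOtherPolicy)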
"""
return build_trainer(**dict(original_kwargs, **overrides))
trainer_cls.with_updates = staticmethod(with_updates)
trainer_cls.__name__ = name
trainer_cls.__qualname__ = name
return trainer_cls