ray/rllib/agents/a3c/a3c_torch_policy.py
Sven Mika 428516056a
[RLlib] SAC Torch (incl. Atari learning) (#7984)
* Policy-classes cleanup and torch/tf unification.
- Make Policy abstract.
- Add `action_dist` to call to `extra_action_out_fn` (necessary for PPO torch).
- Move some methods and vars to base Policy
  (from TFPolicy): num_state_tensors, ACTION_PROB, ACTION_LOGP and some more.

* Fix `clip_action` import from Policy (should probably be moved into utils altogether).

* - Move `is_recurrent()` and `num_state_tensors()` into TFPolicy (from DynamicTFPolicy).
- Add config to all Policy c'tor calls (as 3rd arg after obs and action spaces).

* Add `config` to c'tor call to TFPolicy.

* Add missing `config` to c'tor call to TFPolicy in marwil_policy.py.

* Fix test_rollout_worker.py::MockPolicy and BadPolicy classes (Policy base class is now abstract).

* Fix LINT errors in Policy classes.

* Implement StatefulPolicy abstract methods in test cases: test_multi_agent_env.py.

* policy.py LINT errors.

* Create a simple TestPolicy to sub-class from when testing Policies (reduces code in some test cases).

* policy.py
- Remove abstractmethod from `apply_gradients` and `compute_gradients` (these are not required iff `learn_on_batch` implemented).
- Fix docstring of `num_state_tensors`.

* Make QMIX torch Policy a child of TorchPolicy (instead of Policy).

* QMixPolicy add empty implementations of abstract Policy methods.

* Store Policy's config in self.config in base Policy c'tor.

* - Make only compute_actions in base Policy an abstractmethod and provide a `pass`
implementation for all other methods if not defined.
- Fix state_batches=None (most Policies don't have internal states).

* Cartpole tf learning.

* Cartpole tf AND torch learning (in ~ same ts).

* Cartpole tf AND torch learning (in ~ same ts). 2

* Cartpole tf (torch syntax-broken) learning (in ~ same ts). 3

* Cartpole tf AND torch learning (in ~ same ts). 4

* Cartpole tf AND torch learning (in ~ same ts). 5

* Cartpole tf AND torch learning (in ~ same ts). 6

* Cartpole tf AND torch learning (in ~ same ts). Pendulum tf learning.

* WIP.

* WIP.

* SAC torch learning Pendulum.

* WIP.

* SAC torch and tf learning Pendulum and Cartpole after cleanup.

* WIP.

* LINT.

* LINT.

* SAC: Move policy.target_model to policy.device as well.

* Fixes and cleanup.

* Fix data-format of tf keras Conv2d layers (broken for some tf-versions which have data_format="channels_first" as default).

* Fixes and LINT.

* Fixes and LINT.

* Fix and LINT.

* WIP.

* Test fixes and LINT.

* Fixes and LINT.

Co-authored-by: Sven Mika <sven@Svens-MacBook-Pro.local>
2020-04-15 13:25:16 +02:00


import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
    Postprocessing
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


def actor_critic_loss(policy, model, dist_class, train_batch):
    # Forward pass; the value head output is read via value_function().
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()
    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS])
    policy.entropy = dist.entropy().mean()
    # Policy-gradient term: negative advantage-weighted sum of log-probs.
    policy.pi_err = -train_batch[Postprocessing.ADVANTAGES].dot(
        log_probs.reshape(-1))
    # Critic term: MSE between predicted values and value targets.
    policy.value_err = nn.functional.mse_loss(
        values.reshape(-1), train_batch[Postprocessing.VALUE_TARGETS])
    # Total loss; the entropy term is subtracted, so it acts as a bonus.
    overall_err = sum([
        policy.pi_err,
        policy.config["vf_loss_coeff"] * policy.value_err,
        -policy.config["entropy_coeff"] * policy.entropy,
    ])
    return overall_err


def loss_and_entropy_stats(policy, train_batch):
    # Report the loss components stored on the policy by actor_critic_loss.
    return {
        "policy_entropy": policy.entropy.item(),
        "policy_loss": policy.pi_err.item(),
        "vf_loss": policy.value_err.item(),
    }


def add_advantages(policy,
                   sample_batch,
                   other_agent_batches=None,
                   episode=None):
    # A terminal last step has nothing to bootstrap from; otherwise use the
    # critic's value estimate of the final next-observation.
    completed = sample_batch[SampleBatch.DONES][-1]
    if completed:
        last_r = 0.0
    else:
        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
    return compute_advantages(
        sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
        policy.config["use_gae"], policy.config["use_critic"])


def model_value_predictions(policy, input_dict, state_batches, model,
                            action_dist):
    # Store value estimates under SampleBatch.VF_PREDS at action-computation
    # time so they are available for advantage computation.
    return {SampleBatch.VF_PREDS: model.value_function()}


def apply_grad_clipping(policy, optimizer, loss):
    # Clip each parameter group's gradients to a total L2 norm of grad_clip.
    # info["grad_gnorm"] holds the pre-clipping norm of the last group only.
    info = {}
    if policy.config["grad_clip"]:
        for param_group in optimizer.param_groups:
            info["grad_gnorm"] = nn.utils.clip_grad_norm_(
                param_group["params"], policy.config["grad_clip"])
    return info


def torch_optimizer(policy, config):
    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])


class ValueNetworkMixin:
    def _value(self, obs):
        # Value estimate for a single observation: batch of size 1,
        # no RNN state ([]) and sequence length [1].
        _ = self.model({"obs": torch.Tensor([obs]).to(self.device)}, [], [1])
        return self.model.value_function()[0]


A3CTorchPolicy = build_torch_policy(
    name="A3CTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=actor_critic_loss,
    stats_fn=loss_and_entropy_stats,
    postprocess_fn=add_advantages,
    extra_action_out_fn=model_value_predictions,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=torch_optimizer,
    mixins=[ValueNetworkMixin])
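The loss returned by `actor_critic_loss` combines three terms: the policy-gradient error (the negative dot product of advantages and action log-probabilities), the value-function MSE scaled by `vf_loss_coeff`, and an entropy bonus scaled by `entropy_coeff`. The minimal NumPy sketch below reproduces that arithmetic for a discrete action space, assuming a categorical distribution over logits in place of `dist_class`. The helpers `log_softmax` and `a3c_loss_sketch`, the batch values and the coefficients are hypothetical and chosen only for illustration.

```python
import numpy as np


def log_softmax(logits):
    # Numerically stable log-softmax over the last axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def a3c_loss_sketch(logits, actions, advantages, values, value_targets,
                    vf_loss_coeff, entropy_coeff):
    """Hypothetical NumPy re-creation of the actor_critic_loss arithmetic."""
    log_probs_all = log_softmax(logits)
    probs = np.exp(log_probs_all)
    # Log-likelihood of the action actually taken at each step.
    log_probs = log_probs_all[np.arange(len(actions)), actions]
    # Mean entropy of the categorical distribution over the batch.
    entropy = -(probs * log_probs_all).sum(axis=-1).mean()
    # Policy term: negative advantage-weighted SUM (a dot product).
    pi_err = -advantages.dot(log_probs)
    # Critic term: mean squared error against the value targets.
    value_err = np.mean((values - value_targets) ** 2)
    # Entropy is subtracted, so a more random policy lowers the loss.
    total = pi_err + vf_loss_coeff * value_err - entropy_coeff * entropy
    return total, pi_err, value_err, entropy


# Illustrative batch: uniform policy over two actions.
logits = np.zeros((2, 2))
actions = np.array([0, 1])
advantages = np.array([1.0, -1.0])
values = np.array([0.5, 0.0])
value_targets = np.array([1.0, 0.0])
total, pi_err, value_err, entropy = a3c_loss_sketch(
    logits, actions, advantages, values, value_targets,
    vf_loss_coeff=0.5, entropy_coeff=0.01)
```

With a uniform policy the two advantage-weighted log-probabilities cancel, so the total is just the scaled value error minus the scaled entropy (log 2). Because `pi_err` is a sum rather than a mean, its magnitude grows with the batch size, unlike the mean-reduced value and entropy terms.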
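`add_advantages` passes `last_r = 0.0` to `compute_advantages` when the trajectory ends in a terminal state, and otherwise bootstraps from the critic's estimate of the final `NEXT_OBS`. Assuming `use_gae=True`, the advantages follow standard generalized advantage estimation (GAE). The NumPy sketch below shows that computation; `discounted_cumsum` and `gae_sketch` are hypothetical helpers, not RLlib's `compute_advantages` itself, and they omit the `use_gae=False` and `use_critic=False` branches.

```python
import numpy as np


def discounted_cumsum(x, discount):
    # y[t] = x[t] + discount * y[t + 1], accumulated backwards in time.
    out = np.zeros(len(x), dtype=float)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + discount * running
        out[t] = running
    return out


def gae_sketch(rewards, vf_preds, last_r, gamma, lambda_):
    """Hypothetical GAE helper; last_r is 0.0 for terminated trajectories."""
    # Append the bootstrap value so V(s_{t+1}) exists for the last step.
    vpred_t = np.concatenate([vf_preds, [last_r]])
    # One-step TD residuals: r_t + gamma * V(s_{t+1}) - V(s_t).
    deltas = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
    advantages = discounted_cumsum(deltas, gamma * lambda_)
    # Targets for the critic's MSE loss.
    value_targets = advantages + vf_preds
    return advantages, value_targets


rewards = np.array([1.0, 1.0, 1.0])
vf_preds = np.array([0.0, 0.0, 0.0])
# Episode terminated (DONES[-1] is True): no bootstrapping.
adv_done, _ = gae_sketch(rewards, vf_preds, last_r=0.0, gamma=0.9, lambda_=1.0)
# Truncated fragment: bootstrap from an assumed critic estimate of 10.0.
adv_trunc, _ = gae_sketch(rewards, vf_preds, last_r=10.0, gamma=0.9,
                          lambda_=1.0)
```

With `lambda_=1.0` the advantages reduce to discounted returns minus the value predictions: `[2.71, 1.9, 1.0]` for the terminated case and `[10.0, 10.0, 10.0]` once the bootstrap value is included. This is why skipping the bootstrap on a truncated fragment would systematically underestimate the returns of its final steps.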
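`apply_grad_clipping` delegates to `torch.nn.utils.clip_grad_norm_`, which rescales the gradients of a parameter group so that their combined L2 norm does not exceed `grad_clip` and returns the norm measured before clipping. The dependency-light NumPy sketch below mirrors that behavior on plain arrays. `clip_by_global_norm` is a hypothetical name, and the `eps=1e-6` guard in the denominator is an assumption about the exact constant used.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Hypothetical sketch: scale grads so their joint L2 norm <= max_norm."""
    # Treat all gradient arrays as one long vector.
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        grads = [g * clip_coef for g in grads]
    # Like clip_grad_norm_, report the norm measured BEFORE clipping.
    return grads, total_norm


grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]  # joint norm is 5.0
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
clipped_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in clipped))
```

Because the loop in `apply_grad_clipping` writes every group's result to the same `info["grad_gnorm"]` key, only the last group's norm is reported. With the single Adam optimizer over `policy.model.parameters()` built by `torch_optimizer`, there is exactly one parameter group, so nothing is lost in practice.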