# ray/rllib/agents/ppo/tests/test_ppo.py

import copy
import numpy as np
import unittest
import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo
from ray.rllib.agents.ppo.ppo_tf_policy import ppo_surrogate_loss as \
    ppo_surrogate_loss_tf
from ray.rllib.agents.ppo.ppo_torch_policy import ppo_surrogate_loss as \
    ppo_surrogate_loss_torch
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
    Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
    LEARNER_STATS_KEY
from ray.rllib.utils.numpy import fc
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
    check_train_results, framework_iterator

# Fake CartPole episode of n time steps.
FAKE_BATCH = SampleBatch({
    SampleBatch.OBS: np.array(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]],
        dtype=np.float32),
    SampleBatch.ACTIONS: np.array([0, 1, 1]),
    SampleBatch.PREV_ACTIONS: np.array([0, 1, 1]),
    SampleBatch.REWARDS: np.array([1.0, -1.0, .5], dtype=np.float32),
    SampleBatch.PREV_REWARDS: np.array([1.0, -1.0, .5], dtype=np.float32),
    SampleBatch.DONES: np.array([False, False, True]),
    SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
    SampleBatch.ACTION_DIST_INPUTS: np.array(
        [[-2., 0.5], [-3., -0.3], [-0.1, 2.5]], dtype=np.float32),
    SampleBatch.ACTION_LOGP: np.array([-0.5, -0.1, -0.2], dtype=np.float32),
    SampleBatch.EPS_ID: np.array([0, 0, 0]),
    SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
})
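

# Illustrative sketch (not part of the original test file): with gamma=0.99
# and the episode ending at the last step of FAKE_BATCH (DONES=True), the
# discounted-return value targets come out to [0.50005, -0.505, 0.5] -- the
# values asserted in `test_ppo_loss_function` below. The helper name
# `_expected_discounted_returns` is made up for illustration only and is not
# used by the tests.
def _expected_discounted_returns(rewards, gamma=0.99):
    returns = np.zeros_like(rewards)
    running = 0.0
    # Accumulate backwards; no bootstrap value is added since the fake
    # episode terminates at the final timestep.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# _expected_discounted_returns(FAKE_BATCH[SampleBatch.REWARDS])
# -> approx. [0.50005, -0.505, 0.5]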


class MyCallbacks(DefaultCallbacks):
    @staticmethod
    def _check_lr_torch(policy, policy_id):
        for j, opt in enumerate(policy._optimizers):
            for p in opt.param_groups:
                assert p["lr"] == policy.cur_lr, "LR scheduling error!"

    @staticmethod
    def _check_lr_tf(policy, policy_id):
        lr = policy.cur_lr
        sess = policy.get_session()
        if sess:
            lr = sess.run(lr)
            optim_lr = sess.run(policy._optimizer._lr)
        else:
            lr = lr.numpy()
            optim_lr = policy._optimizer.lr.numpy()
        assert lr == optim_lr, "LR scheduling error!"

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        stats = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
            LEARNER_STATS_KEY]
        # Learning rate should go to 0 after 1 iter.
        check(stats["cur_lr"], 5e-5 if trainer.iteration == 1 else 0.0)
        # Entropy coeff goes to 0.05, then 0.0 (per iter).
        check(stats["entropy_coeff"], 0.1 if trainer.iteration == 1 else 0.05)

        trainer.workers.foreach_policy(self._check_lr_torch if trainer.config[
            "framework"] == "torch" else self._check_lr_tf)


class TestPPO(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_ppo_compilation_and_schedule_mixins(self):
        """Test whether a PPOTrainer can be built with all frameworks."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        # For checking lr-schedule correctness.
        config["callbacks"] = MyCallbacks

        config["num_workers"] = 1
        config["num_sgd_iter"] = 2
        # Settings in case we use an LSTM.
        config["model"]["lstm_cell_size"] = 10
        config["model"]["max_seq_len"] = 20
        # Use default-native keras models whenever possible.
        # config["model"]["_use_default_native_models"] = True

        # Setup lr- and entropy schedules for testing.
        config["lr_schedule"] = [[0, config["lr"]], [128, 0.0]]
        # Set entropy_coeff to a faulty value to prove that it'll get
        # overridden by the schedule below (which is expected).
        config["entropy_coeff"] = 100.0
        config["entropy_coeff_schedule"] = [[0, 0.1], [256, 0.0]]

        config["train_batch_size"] = 128
        # Test with compression.
        config["compress_observations"] = True

        num_iterations = 2

        for fw in framework_iterator(config, with_eager_tracing=True):
for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
print("Env={}".format(env))
for lstm in [True, False]:
print("LSTM={}".format(lstm))
config["model"]["use_lstm"] = lstm
config["model"]["lstm_use_prev_action"] = lstm
config["model"]["lstm_use_prev_reward"] = lstm
trainer = ppo.PPOTrainer(config=config, env=env)
policy = trainer.get_policy()
entropy_coeff = trainer.get_policy().entropy_coeff
lr = policy.cur_lr
if fw == "tf":
entropy_coeff, lr = policy.get_session().run(
[entropy_coeff, lr])
check(entropy_coeff, 0.1)
check(lr, config["lr"])
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)
print(results)
check_compute_single_action(
trainer,
include_prev_action_reward=True,
include_state=lstm)
trainer.stop()

    def test_ppo_exploration_setup(self):
        """Tests whether PPO runs with different exploration setups."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        config["num_workers"] = 0  # Run locally.
        config["env_config"] = {"is_slippery": False, "map_name": "4x4"}
        obs = np.array(0)

        # Test against all frameworks.
        for fw in framework_iterator(config):
            # Default Agent should be set up with StochasticSampling.
trainer = ppo.PPOTrainer(config=config, env="FrozenLake-v1")
# explore=False, always expect the same (deterministic) action.
a_ = trainer.compute_single_action(
obs,
explore=False,
prev_action=np.array(2),
prev_reward=np.array(1.0))
# Test whether this is really the argmax action over the logits.
if fw != "tf":
last_out = trainer.get_policy().model.last_output()
if fw == "torch":
check(a_, np.argmax(last_out.detach().cpu().numpy(), 1)[0])
else:
check(a_, np.argmax(last_out.numpy(), 1)[0])
for _ in range(50):
a = trainer.compute_single_action(
obs,
explore=False,
prev_action=np.array(2),
prev_reward=np.array(1.0))
check(a, a_)
# With explore=True (default), expect stochastic actions.
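            # FrozenLake-v1 (4x4) has a Discrete(4) action space. A freshly
            # initialized policy is approximately uniform over the 4 actions,
            # so the sampled actions should average out to roughly
            # (0 + 1 + 2 + 3) / 4 = 1.5, as checked below (atol=0.2).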
            actions = []
            for _ in range(300):
                actions.append(
                    trainer.compute_single_action(
                        obs,
                        prev_action=np.array(2),
                        prev_reward=np.array(1.0)))
            check(np.mean(actions), 1.5, atol=0.2)
            trainer.stop()

    def test_ppo_free_log_std(self):
        """Tests that the free log std option works."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        config["num_workers"] = 0  # Run locally.
        config["gamma"] = 0.99
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"
        config["model"]["free_log_std"] = True
config["model"]["vf_share_layers"] = True
for fw, sess in framework_iterator(config, session=True):
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
policy = trainer.get_policy()
# Check the free log std var is created.
if fw == "torch":
matching = [
v for (n, v) in policy.model.named_parameters()
if "log_std" in n
]
else:
matching = [
v for v in policy.model.trainable_variables()
if "log_std" in str(v)
]
assert len(matching) == 1, matching
log_std_var = matching[0]
def get_value():
if fw == "tf":
return policy.get_session().run(log_std_var)[0]
elif fw == "torch":
return log_std_var.detach().cpu().numpy()[0]
else:
return log_std_var.numpy()[0]
# Check the variable is initially zero.
init_std = get_value()
assert init_std == 0.0, init_std
batch = compute_gae_for_sample_batch(policy, FAKE_BATCH.copy())
if fw == "torch":
batch = policy._lazy_tensor_dict(batch)
policy.learn_on_batch(batch)
# Check the variable is updated.
post_std = get_value()
assert post_std != 0.0, post_std
trainer.stop()
def test_ppo_loss_function(self):
"""Tests the PPO loss function math."""
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
config["num_workers"] = 0 # Run locally.
config["gamma"] = 0.99
config["model"]["fcnet_hiddens"] = [10]
config["model"]["fcnet_activation"] = "linear"
config["model"]["vf_share_layers"] = True

        for fw, sess in framework_iterator(config, session=True):
            trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()

            # Check no free log std var by default.
            if fw == "torch":
                matching = [
                    v for (n, v) in policy.model.named_parameters()
                    if "log_std" in n
                ]
            else:
                matching = [
                    v for v in policy.model.trainable_variables()
                    if "log_std" in str(v)
                ]
            assert len(matching) == 0, matching

            # Post-process (calculate simple (non-GAE) advantages) and attach
            # to train_batch dict.
            # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
            # [0.50005, -0.505, 0.5]
            train_batch = compute_gae_for_sample_batch(policy,
                                                       FAKE_BATCH.copy())
            if fw == "torch":
                train_batch = policy._lazy_tensor_dict(train_batch)

            # Check Advantage values.
            check(train_batch[Postprocessing.VALUE_TARGETS],
                  [0.50005, -0.505, 0.5])
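            # Note: with PPO's default lambda=1.0 these GAE value targets
            # reduce to the plain discounted returns (no bootstrap, since the
            # fake episode terminates) -- cf. the
            # `_expected_discounted_returns` sketch next to FAKE_BATCH above.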

            # Calculate actual PPO loss.
if fw in ["tf2", "tfe"]:
ppo_surrogate_loss_tf(policy, policy.model, Categorical,
train_batch)
elif fw == "torch":
ppo_surrogate_loss_torch(policy, policy.model,
TorchCategorical, train_batch)
vars = policy.model.variables() if fw != "torch" else \
list(policy.model.parameters())
if fw == "tf":
vars = policy.get_session().run(vars)
            expected_shared_out = fc(
                train_batch[SampleBatch.CUR_OBS],
vars[0 if fw != "torch" else 2],
vars[1 if fw != "torch" else 3],
                framework=fw)
            expected_logits = fc(
                expected_shared_out,
                vars[2 if fw != "torch" else 0],
                vars[3 if fw != "torch" else 1],
                framework=fw)
            expected_value_outs = fc(
                expected_shared_out, vars[4], vars[5], framework=fw)

            kl, entropy, pg_loss, vf_loss, overall_loss = \
                self._ppo_loss_helper(
                    policy, policy.model,
                    Categorical if fw != "torch" else TorchCategorical,
                    train_batch,
                    expected_logits, expected_value_outs,
                    sess=sess
                )
            if sess:
                policy_sess = policy.get_session()
                k, e, pl, v, tl = policy_sess.run(
                    [
                        policy._mean_kl_loss,
                        policy._mean_entropy,
                        policy._mean_policy_loss,
                        policy._mean_vf_loss,
                        policy._total_loss,
                    ],
                    feed_dict=policy._get_loss_inputs_dict(
                        train_batch, shuffle=False))
                check(k, kl)
                check(e, entropy)
                check(pl, np.mean(-pg_loss))
                check(v, np.mean(vf_loss), decimals=4)
                check(tl, overall_loss, decimals=4)
            elif fw == "torch":
                check(policy.model.tower_stats["mean_kl_loss"], kl)
                check(policy.model.tower_stats["mean_entropy"], entropy)
                check(policy.model.tower_stats["mean_policy_loss"],
                      np.mean(-pg_loss))
                check(
                    policy.model.tower_stats["mean_vf_loss"],
                    np.mean(vf_loss),
                    decimals=4)
                check(
                    policy.model.tower_stats["total_loss"],
                    overall_loss,
                    decimals=4)
            else:
                check(policy._mean_kl_loss, kl)
                check(policy._mean_entropy, entropy)
                check(policy._mean_policy_loss, np.mean(-pg_loss))
                check(policy._mean_vf_loss, np.mean(vf_loss), decimals=4)
                check(policy._total_loss, overall_loss, decimals=4)
            trainer.stop()

    def _ppo_loss_helper(self,
                         policy,
                         model,
                         dist_class,
                         train_batch,
                         logits,
                         vf_outs,
                         sess=None):
        """
        Calculates the expected PPO loss (components) given Policy,
        Model, distribution, some batch, logits & vf outputs, using numpy.
        """
        # Calculate expected PPO loss results.
        dist = dist_class(logits, policy.model)
        dist_prev = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
                               policy.model)
        expected_logp = dist.logp(train_batch[SampleBatch.ACTIONS])
        if isinstance(model, TorchModelV2):
            train_batch.set_get_interceptor(None)
            expected_rho = np.exp(expected_logp.detach().cpu().numpy() -
                                  train_batch[SampleBatch.ACTION_LOGP])
            # KL(prev vs current action dist)-loss component.
            kl = np.mean(dist_prev.kl(dist).detach().cpu().numpy())
            # Entropy-loss component.
            entropy = np.mean(dist.entropy().detach().cpu().numpy())
        else:
            if sess:
                expected_logp = sess.run(expected_logp)
            expected_rho = np.exp(expected_logp -
                                  train_batch[SampleBatch.ACTION_LOGP])
            # KL(prev vs current action dist)-loss component.
            kl = dist_prev.kl(dist)
            if sess:
                kl = sess.run(kl)
            kl = np.mean(kl)
            # Entropy-loss component.
            entropy = dist.entropy()
            if sess:
                entropy = sess.run(entropy)
            entropy = np.mean(entropy)

        # Policy loss component.
        pg_loss = np.minimum(
            train_batch[Postprocessing.ADVANTAGES] * expected_rho,
            train_batch[Postprocessing.ADVANTAGES] * np.clip(
                expected_rho, 1 - policy.config["clip_param"],
                1 + policy.config["clip_param"]))

        # Value function loss component.
        vf_loss1 = np.power(
            vf_outs - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_clipped = train_batch[SampleBatch.VF_PREDS] + np.clip(
            vf_outs - train_batch[SampleBatch.VF_PREDS],
            -policy.config["vf_clip_param"], policy.config["vf_clip_param"])
        vf_loss2 = np.power(
            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_loss = np.maximum(vf_loss1, vf_loss2)
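        # The value loss is clipped analogously: the new value estimate may
        # deviate from the behavior policy's VF_PREDS by at most
        # vf_clip_param, and the larger (more pessimistic) of the clipped /
        # unclipped squared errors against the value targets is used.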

        # Overall loss.
        if sess:
            policy_sess = policy.get_session()
            kl_coeff, entropy_coeff = policy_sess.run(
                [policy.kl_coeff, policy.entropy_coeff])
        else:
            kl_coeff, entropy_coeff = policy.kl_coeff, policy.entropy_coeff
        overall_loss = np.mean(-pg_loss + kl_coeff * kl +
                               policy.config["vf_loss_coeff"] * vf_loss -
                               entropy_coeff * entropy)
        return kl, entropy, pg_loss, vf_loss, overall_loss
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))