import numpy as np
import unittest

import ray
import ray.rllib.agents.pg as pg
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import check, check_compute_single_action, fc, \
    framework_iterator


class TestPG(unittest.TestCase):
    def setUp(self):
        ray.init()

    def tearDown(self):
        ray.shutdown()

    def test_pg_compilation(self):
        """Test whether a PGTrainer can be built with all frameworks."""
        config = pg.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0  # Run locally.
        num_iterations = 2

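        # framework_iterator runs the loop body once per supported framework
        # backend (tf, tf-eager, and torch at the time of writing), adjusting
        # the config accordingly each time around.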
        for _ in framework_iterator(config):
            trainer = pg.PGTrainer(config=config, env="CartPole-v0")
            for i in range(num_iterations):
                trainer.train()
            check_compute_single_action(
                trainer, include_prev_action_reward=True)

    def test_pg_loss_functions(self):
        """Tests the PG loss function math."""
        config = pg.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0  # Run locally.
        config["gamma"] = 0.99
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"

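        # The loss under test is the vanilla policy gradient objective,
        #   L = -mean(log pi(a|s) * advantages),
        # which the manual NumPy computation at the end of this test
        # reproduces step by step.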
        # Fake CartPole episode of n time steps.
        train_batch = {
            SampleBatch.CUR_OBS: np.array([[0.1, 0.2, 0.3, 0.4],
                                           [0.5, 0.6, 0.7, 0.8],
                                           [0.9, 1.0, 1.1, 1.2]]),
            SampleBatch.ACTIONS: np.array([0, 1, 1]),
            SampleBatch.PREV_ACTIONS: np.array([1, 0, 1]),
            SampleBatch.REWARDS: np.array([1.0, 1.0, 1.0]),
            SampleBatch.PREV_REWARDS: np.array([-1.0, -1.0, -1.0]),
            SampleBatch.DONES: np.array([False, False, True])
        }

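        # Compute the loss with each framework and compare the result
        # against a manual NumPy forward pass through the same weights.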
        for fw, sess in framework_iterator(config, session=True):
            dist_cls = (Categorical if fw != "torch" else TorchCategorical)
            trainer = pg.PGTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()
            vars = policy.model.trainable_variables()
            # Static-graph tf returns variable ops here; fetch their current
            # numpy values via the session for the manual forward pass.
            if fw == "tf":
                vars = policy.get_session().run(vars)

            # Post-process (calculate simple (non-GAE) advantages) and attach
            # them to the train_batch dict.
            # A = [0.99^2 * 1.0 + 0.99 * 1.0 + 1.0, 0.99 * 1.0 + 1.0, 1.0] =
            # [2.9701, 1.99, 1.0]
            train_batch = pg.post_process_advantages(policy, train_batch)
            if fw == "torch":
                train_batch = policy._lazy_tensor_dict(train_batch)

            # Check advantage values.
            check(train_batch[Postprocessing.ADVANTAGES], [2.9701, 1.99, 1.0])

            # Actual loss results.
            if fw == "tf":
                results = policy.get_session().run(
                    policy._loss,
                    feed_dict=policy._get_loss_inputs_dict(
                        train_batch, shuffle=False))
            else:
                results = (pg.pg_tf_loss if fw == "tfe" else
                           pg.pg_torch_loss)(
                               policy,
                               policy.model,
                               dist_class=dist_cls,
                               train_batch=train_batch)

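            # NOTE (assumption): trainable_variables() appears to return
            # [hidden-W, hidden-b, logits-W, logits-b] for the tf models,
            # while the torch model lists the logits layer's weight/bias
            # first; hence the swapped vars indices in the torch branch.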
            # Calculate expected results.
            if fw != "torch":
                expected_logits = fc(
                    fc(train_batch[SampleBatch.CUR_OBS],
                       vars[0],
                       vars[1],
                       framework=fw),
                    vars[2],
                    vars[3],
                    framework=fw)
            else:
                expected_logits = fc(
                    fc(train_batch[SampleBatch.CUR_OBS],
                       vars[2],
                       vars[3],
                       framework=fw),
                    vars[0],
                    vars[1],
                    framework=fw)
            expected_logp = dist_cls(expected_logits, policy.model).logp(
                train_batch[SampleBatch.ACTIONS])
            if sess:
                expected_logp = sess.run(expected_logp)
            else:
                expected_logp = expected_logp.numpy()
            expected_loss = -np.mean(
                expected_logp *
                (train_batch[Postprocessing.ADVANTAGES] if fw != "torch" else
                 train_batch[Postprocessing.ADVANTAGES].numpy()))
            check(results, expected_loss, decimals=4)


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))