ray/rllib/agents/cql/tests/test_cql.py

import numpy as np
from pathlib import Path
import os
import unittest

import ray
import ray.rllib.agents.cql as cql
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_compute_single_action, \
    check_train_results, framework_iterator

# try_import_tf/try_import_torch return None placeholders if the respective
# framework is not installed, so this module can be imported either way.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


class TestCQL(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_cql_compilation(self):
        """Test whether a CQLTrainer can be built with all frameworks."""
        # Learns from a historic-data file.
        # To generate this data, first run:
        # $ ./train.py --run=SAC --env=Pendulum-v1 \
        #   --stop='{"timesteps_total": 50000}' \
        #   --config='{"output": "/tmp/out"}'
        rllib_dir = Path(__file__).parent.parent.parent.parent
        print("rllib dir={}".format(rllib_dir))
        data_file = os.path.join(rllib_dir, "tests/data/pendulum/small.json")
        print("data_file={} exists={}".format(data_file,
                                              os.path.isfile(data_file)))

        config = cql.CQL_DEFAULT_CONFIG.copy()
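
        # CQL_DEFAULT_CONFIG extends SAC's default config with CQL-specific
        # settings (e.g. `bc_iters`); below, we only override what is needed
        # for learning from the offline Pendulum data.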
config["env"] = "Pendulum-v1"
config["input"] = [data_file]
# In the files, we use here for testing, actions have already
# been normalized.
# This is usually the case when the file was generated by another
# RLlib algorithm (e.g. PPO or SAC).
config["actions_in_input_normalized"] = False
config["clip_actions"] = True
config["train_batch_size"] = 2000
config["num_workers"] = 0 # Run locally.
config["twin_q"] = True
config["learning_starts"] = 0
config["bc_iters"] = 2 # 2 BC iters, 2 CQL iters.
config["rollout_fragment_length"] = 1

        # Switch on off-policy evaluation (importance sampling, "is") over
        # the offline input.
        config["input_evaluation"] = ["is"]

        config["evaluation_interval"] = 2
        config["evaluation_duration"] = 10
        config["evaluation_config"]["input"] = "sampler"
        config["evaluation_parallel_to_training"] = False
        config["evaluation_num_workers"] = 2

        num_iterations = 4
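
        # With evaluation_interval=2, every other `trainer.train()` call
        # should also return an "evaluation" dict, filled by the two
        # evaluation workers sampling live Pendulum-v1 episodes
        # ("sampler" input).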
        # Test for tf/torch frameworks.
        for fw in framework_iterator(config, with_eager_tracing=True):
            trainer = cql.CQLTrainer(config=config)
            for i in range(num_iterations):
                results = trainer.train()
                check_train_results(results)
                print(results)
                eval_results = results.get("evaluation")
                if eval_results:
                    print(f"iter={trainer.iteration} "
                          f"R={eval_results['episode_reward_mean']}")

            check_compute_single_action(trainer)
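
            # A quick, manual sanity check of the learned Q-function follows:
            # pull a batch from CQL's local replay buffer and compare the
            # Q-values of the historic (dataset) actions with those of the
            # actions the trained policy would pick for the same observations.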

            # Get policy and model.
            pol = trainer.get_policy()
            cql_model = pol.model
            if fw == "tf":
                pol.get_session().__enter__()

            # Example of how to evaluate the trained Trainer using the data
            # from CQL's global replay buffer.
            # Get a sample (MultiAgentBatch -> SampleBatch).
            batch = trainer.local_replay_buffer.replay().policy_batches[
                "default_policy"]
if fw == "torch":
obs = torch.from_numpy(batch["obs"])
else:
obs = batch["obs"]
batch["actions"] = batch["actions"].astype(np.float32)
# Pass the observations through our model to get the
# features, which then to pass through the Q-head.
model_out, _ = cql_model({"obs": obs})
            # The estimated Q-values from the (historic) actions in the batch.
            if fw == "torch":
                q_values_old = cql_model.get_q_values(
                    model_out, torch.from_numpy(batch["actions"]))
            else:
                q_values_old = cql_model.get_q_values(
                    tf.convert_to_tensor(model_out), batch["actions"])

            # The estimated Q-values for the new actions computed
            # by our trainer policy.
            actions_new = pol.compute_actions_from_input_dict({"obs": obs})[0]
            if fw == "torch":
                q_values_new = cql_model.get_q_values(
                    model_out, torch.from_numpy(actions_new))
            else:
                q_values_new = cql_model.get_q_values(model_out, actions_new)

            if fw == "tf":
                q_values_old, q_values_new = pol.get_session().run(
                    [q_values_old, q_values_new])

            print(f"Q-val batch={q_values_old}")
            print(f"Q-val policy={q_values_new}")
if fw == "tf":
pol.get_session().__exit__(None, None, None)
trainer.stop()
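

# A usage note (paths assume you run from the `rllib` directory of a Ray
# checkout; adjust as needed):
#   $ pytest -v agents/cql/tests/test_cql.py
# or, for just this case:
#   $ pytest -v agents/cql/tests/test_cql.py::TestCQL::test_cql_compilation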
if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))