from pathlib import Path
import os
import unittest

import ray
import ray.rllib.agents.cql as cql
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import check_compute_single_action, \
    framework_iterator

torch, _ = try_import_torch()


class TestCQL(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_cql_compilation(self):
        """Test whether a CQLTrainer can be built with all frameworks."""

        # Learns from a historic-data file.
        # To generate this data, first run:
        # $ ./train.py --run=SAC --env=Pendulum-v0 \
        #   --stop='{"timesteps_total": 50000}' \
        #   --config='{"output": "/tmp/out"}'
        rllib_dir = Path(__file__).parent.parent.parent.parent
        print("rllib dir={}".format(rllib_dir))
        data_file = os.path.join(rllib_dir, "tests/data/pendulum/small.json")
        print("data_file={} exists={}".format(
            data_file, os.path.isfile(data_file)))

        config = cql.CQL_DEFAULT_CONFIG.copy()
        config["env"] = "Pendulum-v0"
        config["input"] = [data_file]

        config["num_workers"] = 0  # Run locally.
        config["twin_q"] = True
        config["clip_actions"] = False
        config["normalize_actions"] = True
        config["learning_starts"] = 0
        config["rollout_fragment_length"] = 1
        config["train_batch_size"] = 10

        # Switch on off-policy evaluation.
        config["input_evaluation"] = ["is"]

        num_iterations = 2

        # Test only for the torch framework (tf not implemented yet).
        for _ in framework_iterator(config, frameworks="torch"):
            trainer = cql.CQLTrainer(config=config)
            for i in range(num_iterations):
                print(trainer.train())

            check_compute_single_action(trainer)

            # Get policy, model, and replay-buffer.
            pol = trainer.get_policy()
            cql_model = pol.model
            from ray.rllib.agents.cql.cql import replay_buffer
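            # (The cql module keeps a module-level handle to the Trainer's
            # local replay buffer, which is what we sample from below.)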

            # Example on how to do evaluation on the trained Trainer
            # using the data from our buffer.
            # Get a sample (MultiAgentBatch -> SampleBatch).
            batch = replay_buffer.replay().policy_batches["default_policy"]
            obs = torch.from_numpy(batch["obs"])
            # Pass the observations through our model to get the features,
            # which we then pass through the Q-head.
            model_out, _ = cql_model({"obs": obs})
            # The estimated Q-values from the (historic) actions in the batch.
            q_values_old = cql_model.get_q_values(
                model_out, torch.from_numpy(batch["actions"]))
            # The estimated Q-values for the new actions computed
            # by our trainer policy.
            actions_new = pol.compute_actions_from_input_dict({"obs": obs})[0]
            q_values_new = cql_model.get_q_values(
                model_out, torch.from_numpy(actions_new))
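            # (CQL's conservative objective penalizes Q-estimates for
            # out-of-distribution actions, so with enough training the
            # policy's Q-values should not overshoot those of the logged
            # dataset actions.)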
            print(f"Q-val batch={q_values_old}")
            print(f"Q-val policy={q_values_new}")

            trainer.stop()


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))