"""Example on how to use CQL to learn from an offline json file.
|
2021-05-04 10:06:19 -07:00
|
|
|
|
|
|
|
Important node: Make sure that your offline data file contains only
|
|
|
|
a single timestep per line to mimic the way SAC pulls samples from
|
|
|
|
the buffer.
|
|
|
|
|
|
|
|
Generate the offline json file by running an SAC algo until it reaches expert
|
2021-05-18 11:10:46 +02:00
|
|
|
level on your command line. For example:
|
2021-05-04 10:06:19 -07:00
|
|
|
$ cd ray
|
2021-05-18 11:10:46 +02:00
|
|
|
$ rllib train -f rllib/tuned_examples/sac/pendulum-sac.yaml --no-ray-ui
|
2021-05-04 10:06:19 -07:00
|
|
|
|
2021-05-18 11:10:46 +02:00
|
|
|
Also make sure that in the above SAC yaml file (pendulum-sac.yaml),
|
|
|
|
you specify an additional "output" key with any path on your local
|
|
|
|
file system. In that path, the offline json files will be written to.
|
2021-05-04 10:06:19 -07:00
|
|
|
|
2021-05-18 11:10:46 +02:00
|
|
|
Use the generated file(s) as "input" in the CQL config below
|
|
|
|
(`config["input"] = [list of your json files]`), then run this script.
|
2021-05-04 10:06:19 -07:00
|
|
|
"""

import numpy as np
import os

from ray.rllib.algorithms import cql as cql
from ray.rllib.utils.framework import try_import_torch

torch, _ = try_import_torch()


if __name__ == "__main__":
    # See rllib/tuned_examples/cql/pendulum-cql.yaml for comparison.
    config = cql.DEFAULT_CONFIG.copy()
    config["num_workers"] = 0  # Run locally.
    config["horizon"] = 200
    config["soft_horizon"] = True
    config["no_done_at_end"] = True
    config["n_step"] = 3
    config["bc_iters"] = 0
    config["clip_actions"] = False
    config["normalize_actions"] = True
    config["replay_buffer_config"]["learning_starts"] = 256
    config["rollout_fragment_length"] = 1
    # Test without prioritized replay.
    config["replay_buffer_config"]["type"] = "MultiAgentReplayBuffer"
    config["replay_buffer_config"]["capacity"] = int(1e6)
    config["tau"] = 0.005
    config["target_entropy"] = "auto"
    config["q_model_config"] = {
        "fcnet_hiddens": [256, 256],
        "fcnet_activation": "relu",
    }
    config["policy_model_config"] = {
        "fcnet_hiddens": [256, 256],
        "fcnet_activation": "relu",
    }
    config["optimization"] = {
        "actor_learning_rate": 3e-4,
        "critic_learning_rate": 3e-4,
        "entropy_learning_rate": 3e-4,
    }
    config["train_batch_size"] = 256
    config["target_network_update_freq"] = 1
    config["min_train_timesteps_per_iteration"] = 1000
    data_file = "/path/to/my/json_file.json"
    print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
    config["input"] = [data_file]
    config["log_level"] = "INFO"
config["env"] = "Pendulum-v1"
|
2021-05-04 10:06:19 -07:00
|
|
|
|
|
|
|
# Set up evaluation.
|
|
|
|
config["evaluation_num_workers"] = 1
|
|
|
|
config["evaluation_interval"] = 1
|
2021-12-04 13:26:33 +01:00
|
|
|
config["evaluation_duration"] = 10
|
2021-05-04 10:06:19 -07:00
|
|
|
# This should be False b/c iterations are very long and this would
|
|
|
|
# cause evaluation to lag one iter behind training.
|
|
|
|
config["evaluation_parallel_to_training"] = False
|
|
|
|
# Evaluate on actual environment.
|
|
|
|
config["evaluation_config"] = {"input": "sampler"}

    # Check whether we can learn from the given file in `num_iterations`
    # iterations, up to a reward of `min_reward`.
    num_iterations = 5
    min_reward = -300

    # Test for torch framework (tf not implemented yet).
    algo = cql.CQL(config=config)
    learnt = False
    for i in range(num_iterations):
        print(f"Iter {i}")
        eval_results = algo.train().get("evaluation")
        if eval_results:
            print("... R={}".format(eval_results["episode_reward_mean"]))
            # Learn until some reward is reached on an actual live env.
            if eval_results["episode_reward_mean"] >= min_reward:
                learnt = True
                break
    if not learnt:
        raise ValueError(
            "CQL did not reach {} reward from expert "
            "offline data!".format(min_reward)
        )

    # Get policy, model, and replay-buffer.
    pol = algo.get_policy()
    cql_model = pol.model
    from ray.rllib.algorithms.cql.cql import replay_buffer
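    # NOTE: this imports the module-level buffer object that this version of RLlib's
    # CQL fills with the offline samples read from `config["input"]`, so that it can
    # be inspected below.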

    # If you would like to query CQL's learnt Q-function for arbitrary
    # actions, do the following:
    obs_batch = torch.from_numpy(np.random.random(size=(5, 3)))
    action_batch = torch.from_numpy(np.random.random(size=(5, 1)))
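    # NOTE: np.random.random() produces float64 arrays; if the model's weights are
    # float32 (the torch default), you may need to cast these tensors via .float()
    # before querying the Q-network.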
    q_values = cql_model.get_q_values(obs_batch, action_batch)
    # If you are using the "twin_q", there'll be 2 Q-networks and
    # we usually consider the min of the 2 outputs, like so:
    twin_q_values = cql_model.get_twin_q_values(obs_batch, action_batch)
    final_q_values = torch.min(q_values, twin_q_values)
    print(final_q_values)

    # Example on how to do evaluation on the trained Algorithm,
    # using the data from our buffer.
    # Get a sample (MultiAgentBatch).
    multi_agent_batch = replay_buffer.sample(num_items=config["train_batch_size"])
    # All experiences have been buffered for `default_policy`.
    batch = multi_agent_batch.policy_batches["default_policy"]
    obs = torch.from_numpy(batch["obs"])
    # Pass the observations through our model to get the features,
    # which are then passed through the Q-head.
    model_out, _ = cql_model({"obs": obs})
    # The estimated Q-values from the (historic) actions in the batch.
    q_values_old = cql_model.get_q_values(
        model_out, torch.from_numpy(batch["actions"])
    )
    # The estimated Q-values for the new actions computed by our policy.
    actions_new = pol.compute_actions_from_input_dict({"obs": obs})[0]
    q_values_new = cql_model.get_q_values(model_out, torch.from_numpy(actions_new))
    print(f"Q-val batch={q_values_old}")
    print(f"Q-val policy={q_values_new}")

    algo.stop()