ray/rllib/agents/cql/cql.py

"""CQL (derived from SAC).
"""
from typing import Optional, Type, List
from ray.actor import ActorHandle
from ray.rllib.agents.sac.sac import SACTrainer, \
    DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.policy.policy import Policy
from ray.rllib.utils import merge_dicts
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.policy.policy import LEARNER_STATS_KEY
from ray.rllib.agents.dqn.dqn import calculate_rr_weights

# yapf: disable
# __sphinx_doc_begin__
CQL_DEFAULT_CONFIG = merge_dicts(
    SAC_CONFIG, {
        # You should override this to point to an offline dataset.
        "input": "sampler",
        # Offline RL does not need IS estimators.
        "input_evaluation": [],
        # Number of iterations with behavior cloning (BC) pretraining.
        "bc_iters": 20000,
        # CQL loss temperature.
        "temperature": 1.0,
        # Number of actions to sample for the CQL loss.
        "num_actions": 10,
        # Whether to use the Lagrangian for alpha prime (in the CQL loss).
        "lagrangian": False,
        # Lagrangian threshold.
        "lagrangian_thresh": 5.0,
        # Min-Q weight multiplier.
        "min_q_weight": 5.0,
        # The replay buffer should be (at least) the size of the offline
        # dataset.
        "buffer_size": 1000000,
    })
# __sphinx_doc_end__
# yapf: enable
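
# A minimal usage sketch (illustrative only): to train CQL on offline data,
# override "input" to point at recorded experiences. The path and the config
# overrides below are hypothetical examples, not recommended defaults.
#
#     from ray.rllib.agents.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
#
#     config = CQL_DEFAULT_CONFIG.copy()
#     config["framework"] = "torch"
#     config["input"] = "/tmp/cql-offline-data"  # hypothetical path to JSON batches
#     config["bc_iters"] = 0  # e.g., skip BC pretraining
#     trainer = CQLTrainer(config=config, env="Pendulum-v0")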


def validate_config(config: TrainerConfigDict):
    if config["num_gpus"] > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for CQL!")
    if config["framework"] == "tf":
        raise ValueError("Tensorflow CQL not implemented yet!")


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    if config["framework"] == "torch":
        return CQLTorchPolicy


# Global reference to the local replay buffer, set in `execution_plan` and
# used by `after_init` to pre-load an offline (d4rl) dataset.
replay_buffer = None


class NoOpReplayBuffer:
    """Drop-in replacement for the usual store-to-replay-buffer step.

    Batches are passed through unchanged instead of being stored.
    """

    def __init__(self,
                 *,
                 local_buffer: LocalReplayBuffer = None,
                 actors: List[ActorHandle] = None):
        return

    def __call__(self, batch):
        return batch


def execution_plan(workers, config):
    if config.get("prioritized_replay"):
        prio_args = {
            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
            "prioritized_replay_beta": config["prioritized_replay_beta"],
            "prioritized_replay_eps": config["prioritized_replay_eps"],
        }
    else:
        prio_args = {}

    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        replay_sequence_length=config.get("replay_sequence_length", 1),
        **prio_args)
    # Expose the buffer globally so `after_init` can pre-load offline data.
    global replay_buffer
    replay_buffer = local_replay_buffer

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Sampled batches are passed through without being stored
    # (see NoOpReplayBuffer above).
    store_op = rollouts.for_each(
        NoOpReplayBuffer(local_buffer=local_replay_buffer))

    def update_prio(item):
        samples, info_dict = item
        if config.get("prioritized_replay"):
            prio_dict = {}
            for policy_id, info in info_dict.items():
                # TD-error may be reported at the top level or under the
                # learner stats key.
                td_error = info.get("td_error",
                                    info[LEARNER_STATS_KEY].get("td_error"))
                prio_dict[policy_id] = (samples.policy_batches[policy_id]
                                        .get("batch_indexes"), td_error)
            local_replay_buffer.update_priorities(prio_dict)
        return info_dict

    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .for_each(TrainOneStep(workers)) \
        .for_each(update_prio) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    train_op = Concurrently(
        [store_op, replay_op],
        mode="round_robin",
        output_indexes=[1],
        round_robin_weights=calculate_rr_weights(config))

    return StandardMetricsReporting(train_op, workers, config)


def after_init(trainer):
    # Add the entire offline (d4rl) dataset to the replay buffer
    # (global variable).
    global replay_buffer
    local_worker = trainer.workers.local_worker
    if "d4rl" in trainer.config["input"]:
        dataset = local_worker().input_reader.dataset
        replay_buffer.add_batch(dataset)


CQLTrainer = SACTrainer.with_updates(
    name="CQL",
    default_config=CQL_DEFAULT_CONFIG,
    validate_config=validate_config,
    default_policy=CQLTorchPolicy,
    get_policy_class=get_policy_class,
    after_init=after_init,
    execution_plan=execution_plan,
)
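
# A minimal, illustrative sketch of running CQLTrainer end-to-end. The
# environment name, iteration count, and config overrides are arbitrary
# choices for demonstration, not recommended settings; note that with the
# default "input": "sampler", data is collected online rather than read
# from an offline dataset.
if __name__ == "__main__":
    import ray

    ray.init()
    example_config = CQL_DEFAULT_CONFIG.copy()
    example_config["framework"] = "torch"
    trainer = CQLTrainer(config=example_config, env="Pendulum-v0")
    for _ in range(3):
        # Each call runs one training iteration and returns a results dict.
        print(trainer.train())
    ray.shutdown()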