"""CQL (derived from SAC).
|
||
|
"""
|
||
|
from typing import Optional, Type

from ray.rllib.agents.sac.sac import SACTrainer, \
    DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.policy.policy import Policy
from ray.rllib.utils import merge_dicts
# yapf: disable
# __sphinx_doc_begin__
CQL_DEFAULT_CONFIG = merge_dicts(
    SAC_CONFIG, {
        # You should override this to point to an offline dataset.
        "input": "sampler",
        # Number of iterations with behavior cloning pretraining.
        "bc_iters": 20000,
        # CQL loss temperature.
        "temperature": 1.0,
        # Number of actions to sample for the CQL loss.
        "num_actions": 10,
        # Whether to use the Lagrangian for alpha prime (in the CQL loss).
        "lagrangian": False,
        # Lagrangian threshold.
        "lagrangian_thresh": 5.0,
        # Min-Q weight multiplier.
        "min_q_weight": 5.0,
    })
# __sphinx_doc_end__
# yapf: enable
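# Illustrative only (not part of the original module): these CQL-specific
# settings can be overridden the same way SAC_CONFIG is extended above, by
# merging a partial dict into CQL_DEFAULT_CONFIG. The values below are
# arbitrary examples, not recommendations:
#
#     my_cql_config = merge_dicts(CQL_DEFAULT_CONFIG, {
#         "bc_iters": 5000,
#         "lagrangian": True,
#         "lagrangian_thresh": 10.0,
#     })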
def validate_config(config: TrainerConfigDict):
    if config["framework"] == "tf":
        raise ValueError("Tensorflow CQL not implemented yet!")


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    if config["framework"] == "torch":
        return CQLTorchPolicy


CQLTrainer = SACTrainer.with_updates(
    name="CQL",
    default_config=CQL_DEFAULT_CONFIG,
    validate_config=validate_config,
    default_policy=CQLTorchPolicy,
    get_policy_class=get_policy_class,
)
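# A minimal usage sketch (not part of the original module). It assumes the
# standard Trainer constructor of this RLlib version; the environment name and
# the offline dataset path are placeholders you would replace with your own.
if __name__ == "__main__":
    import ray

    ray.init()
    example_config = merge_dicts(CQL_DEFAULT_CONFIG, {
        # Only the torch policy is implemented for CQL (see validate_config).
        "framework": "torch",
        # Placeholder: point this at offline JSON experience data.
        "input": "/tmp/cql-offline-data",
    })
    trainer = CQLTrainer(config=example_config, env="Pendulum-v0")
    for i in range(3):
        result = trainer.train()
        print("iter", i, "->", result.get("episode_reward_mean"))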