2020-09-09 17:33:21 +02:00
|
|
|
"""Behavioral Cloning (BC), derived from MARWIL.

Simply uses the MARWIL agent with `beta` forced to 0.0, so that the actions
in the input data are imitated without computing advantages from rewards.
"""
|
|
|
|
from ray.rllib.agents.marwil.marwil import MARWILTrainer, \
|
|
|
|
DEFAULT_CONFIG as MARWIL_CONFIG
|
|
|
|
from ray.rllib.utils.typing import TrainerConfigDict
|
|
|
|
|
|
|
|
# yapf: disable
|
|
|
|
# __sphinx_doc_begin__
|
|
|
|
BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
    MARWIL_CONFIG,
    {
        # With beta=0.0 there is no need to calculate advantages (or to do
        # anything else with the rewards).
        "beta": 0.0,
        # Advantages are only produced during postprocessing, and behavioral
        # cloning does not use them, so postprocessing is switched off.
        "postprocess_inputs": False,
        # No reward estimation.
        "input_evaluation": [],
    },
)
|
|
|
|
# __sphinx_doc_end__
|
|
|
|
# yapf: enable
|
|
|
|
|
|
|
|
|
|
|
|
def validate_config(config: TrainerConfigDict) -> None:
    """Check that a trainer config is usable for behavioral cloning.

    Behavioral cloning is MARWIL with `beta` pinned to 0.0, so any other
    `beta` value is rejected.

    Args:
        config: The trainer config to check; it is indexed with "beta".

    Raises:
        ValueError: If `config["beta"]` is not equal to 0.0.
    """
    beta = config["beta"]
    if beta == 0.0:
        # Nothing to complain about.
        return
    raise ValueError(
        "For behavioral cloning, `beta` parameter must be 0.0!")
|
|
|
|
|
|
|
|
|
|
|
|
# The BC trainer is the MARWIL trainer re-registered under the name "BC",
# with the BC defaults and the beta == 0.0 check above.
# NOTE(review): `with_updates` presumably replaces (rather than chains) any
# `validate_config` hook MARWILTrainer was built with, so MARWIL's own config
# checks (if it has any) would not run for BC -- confirm against marwil.py.
BCTrainer = MARWILTrainer.with_updates(
    name="BC",
    default_config=BC_DEFAULT_CONFIG,
    validate_config=validate_config,
)
|