2022-01-29 18:41:57 -08:00
|
|
|
from ray.rllib.agents.marwil.marwil import (
|
|
|
|
MARWILTrainer,
|
|
|
|
DEFAULT_CONFIG as MARWIL_CONFIG,
|
|
|
|
)
|
2021-12-04 22:05:26 +01:00
|
|
|
from ray.rllib.utils.annotations import override
|
2020-09-09 17:33:21 +02:00
|
|
|
from ray.rllib.utils.typing import TrainerConfigDict
|
|
|
|
|
2022-02-08 16:29:25 -08:00
|
|
|
# fmt: off
|
2020-09-09 17:33:21 +02:00
|
|
|
# __sphinx_doc_begin__
|
|
|
|
# Default config for behavioral cloning: MARWIL's defaults with the
# reward-dependent machinery turned off.
BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
    MARWIL_CONFIG,
    dict(
        # With beta at 0.0 no advantages need to be calculated; BC does not
        # use the rewards for anything else either.
        beta=0.0,
        # Advantages are only computed during postprocessing and are not
        # needed for behavioral cloning, so postprocessing is skipped.
        postprocess_inputs=False,
        # No reward estimation.
        input_evaluation=[],
    ),
)
|
|
|
|
# __sphinx_doc_end__
|
2022-02-08 16:29:25 -08:00
|
|
|
# fmt: on
|
2020-09-09 17:33:21 +02:00
|
|
|
|
|
|
|
|
2021-12-04 22:05:26 +01:00
|
|
|
class BCTrainer(MARWILTrainer):
    """Behavioral Cloning (BC) trainer, built on top of MARWIL.

    BC is MARWIL run with ``beta == 0.0``. The default config sets
    ``beta`` to 0.0. ``validate_config`` rejects any config where it
    differs.
    """

    @classmethod
    @override(MARWILTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the BC default config (MARWIL defaults plus BC overrides)."""
        return BC_DEFAULT_CONFIG

    @override(MARWILTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        """Validate ``config`` for use with behavioral cloning.

        MARWIL's own validation runs first. Then ``beta`` must equal 0.0.

        Args:
            config: The trainer config dict to check.

        Raises:
            ValueError: If ``config["beta"]`` is not 0.0.
        """
        # Let the parent class perform its checks before the BC-specific one.
        super().validate_config(config)

        beta = config["beta"]
        if beta == 0.0:
            return
        raise ValueError("For behavioral cloning, `beta` parameter must be 0.0!")
|