[RLlib] Add @OverrideToImplementCustomLogic decorators to some Trainer class methods. (#24684)

This commit is contained in:
Sven Mika 2022-05-24 11:30:50 +02:00 committed by GitHub
parent 5b9b4fa018
commit 4e99a57bab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -58,6 +58,8 @@ from ray.rllib.utils.annotations import (
DeveloperAPI,
ExperimentalAPI,
override,
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
PublicAPI,
)
from ray.rllib.utils.debug import update_global_seed_if_necessary
@ -288,10 +290,12 @@ class Trainer(Trainable):
config, logger_creator, remote_checkpoint_dir, sync_function_tpl
)
@OverrideToImplementCustomLogic
@classmethod
def get_default_config(cls) -> TrainerConfigDict:
return TrainerConfig().to_dict()
@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(Trainable)
def setup(self, config: PartialTrainerConfigDict):
@ -487,6 +491,7 @@ class Trainer(Trainable):
def _init(self, config: TrainerConfigDict, env_creator: EnvCreator) -> None:
raise NotImplementedError
@OverrideToImplementCustomLogic
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
"""Returns a default Policy class to use, given a config.
@ -845,6 +850,7 @@ class Trainer(Trainable):
# Also return the results here for convenience.
return self.evaluation_metrics
@OverrideToImplementCustomLogic
@DeveloperAPI
def training_iteration(self) -> ResultDict:
"""Default single iteration logic of an algorithm.
@ -1467,9 +1473,12 @@ class Trainer(Trainable):
@override(Trainable)
def cleanup(self) -> None:
# Stop all workers.
if hasattr(self, "workers"):
if hasattr(self, "workers") and self.workers is not None:
self.workers.stop()
if hasattr(self, "evaluation_workers") and self.evaluation_workers is not None:
self.evaluation_workers.stop()
@OverrideToImplementCustomLogic
@classmethod
@override(Trainable)
def default_resource_request(
@ -1777,6 +1786,7 @@ class Trainer(Trainable):
check_if_correct_nn_framework_installed()
resolve_tf_settings()
@OverrideToImplementCustomLogic_CallToSuperRecommended
@DeveloperAPI
def validate_config(self, config: TrainerConfigDict) -> None:
"""Validates a given config dict for this Trainer.