mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Add @OverrideToImplementCustomLogic
decorators to some Trainer
class methods. (#24684)
This commit is contained in:
parent
5b9b4fa018
commit
4e99a57bab
1 changed files with 11 additions and 1 deletions
|
@ -58,6 +58,8 @@ from ray.rllib.utils.annotations import (
|
|||
DeveloperAPI,
|
||||
ExperimentalAPI,
|
||||
override,
|
||||
OverrideToImplementCustomLogic,
|
||||
OverrideToImplementCustomLogic_CallToSuperRecommended,
|
||||
PublicAPI,
|
||||
)
|
||||
from ray.rllib.utils.debug import update_global_seed_if_necessary
|
||||
|
@ -288,10 +290,12 @@ class Trainer(Trainable):
|
|||
config, logger_creator, remote_checkpoint_dir, sync_function_tpl
|
||||
)
|
||||
|
||||
@OverrideToImplementCustomLogic
|
||||
@classmethod
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
return TrainerConfig().to_dict()
|
||||
|
||||
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
||||
@override(Trainable)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
|
||||
|
@ -487,6 +491,7 @@ class Trainer(Trainable):
|
|||
def _init(self, config: TrainerConfigDict, env_creator: EnvCreator) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@OverrideToImplementCustomLogic
|
||||
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
|
||||
"""Returns a default Policy class to use, given a config.
|
||||
|
||||
|
@ -845,6 +850,7 @@ class Trainer(Trainable):
|
|||
# Also return the results here for convenience.
|
||||
return self.evaluation_metrics
|
||||
|
||||
@OverrideToImplementCustomLogic
|
||||
@DeveloperAPI
|
||||
def training_iteration(self) -> ResultDict:
|
||||
"""Default single iteration logic of an algorithm.
|
||||
|
@ -1467,9 +1473,12 @@ class Trainer(Trainable):
|
|||
@override(Trainable)
|
||||
def cleanup(self) -> None:
|
||||
# Stop all workers.
|
||||
if hasattr(self, "workers"):
|
||||
if hasattr(self, "workers") and self.workers is not None:
|
||||
self.workers.stop()
|
||||
if hasattr(self, "evaluation_workers") and self.evaluation_workers is not None:
|
||||
self.evaluation_workers.stop()
|
||||
|
||||
@OverrideToImplementCustomLogic
|
||||
@classmethod
|
||||
@override(Trainable)
|
||||
def default_resource_request(
|
||||
|
@ -1777,6 +1786,7 @@ class Trainer(Trainable):
|
|||
check_if_correct_nn_framework_installed()
|
||||
resolve_tf_settings()
|
||||
|
||||
@OverrideToImplementCustomLogic_CallToSuperRecommended
|
||||
@DeveloperAPI
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
"""Validates a given config dict for this Trainer.
|
||||
|
|
Loading…
Add table
Reference in a new issue