2021-11-01 21:46:02 +01:00
|
|
|
from ray.rllib.utils.deprecation import Deprecated
|
2021-08-03 18:30:02 -04:00
|
|
|
|
|
|
|
|
2018-12-08 16:28:58 -08:00
|
|
|
def override(cls):
    """Decorator for documenting method overrides.

    The check is purely name-based: the decorated function's ``__name__``
    must appear in ``dir(cls)``. Any attribute of that name satisfies the
    check, not only methods.

    Args:
        cls (type): The superclass that provides the overridden method. If
            this cls does not actually have the method, an error is raised.

    Returns:
        A decorator that validates the decorated method's name against
        ``cls`` and then returns the method itself, unchanged.

    Raises:
        NameError: Raised by the returned decorator (i.e. when it is applied
            to a method) if ``cls`` has no attribute with the method's name.

    Examples:
        >>> from ray.rllib.policy import Policy
        >>> class TorchPolicy(Policy): # doctest: +SKIP
        ...     ...
        ...     # Indicates that `TorchPolicy.loss()` overrides the parent
        ...     # Policy class' own `loss` method. Leads to an error if Policy
        ...     # does not have a `loss` method.
        ...     @override(Policy) # doctest: +SKIP
        ...     def loss(self, model, action_dist, train_batch): # doctest: +SKIP
        ...         ... # doctest: +SKIP
    """

    def check_override(method):
        # Validation happens when the decorator is applied, not when the
        # decorated method is later called.
        if method.__name__ not in dir(cls):
            raise NameError("{} does not override any method of {}".format(method, cls))
        # No wrapping: the original function object is handed back as-is.
        return method

    return check_override
|
2019-01-23 21:27:26 -08:00
|
|
|
|
|
|
|
|
|
|
|
def PublicAPI(obj):
    """Decorator for documenting public APIs.

    Public APIs are classes and methods exposed to end users of RLlib. You
    can expect these APIs to remain stable across RLlib releases.

    Subclasses that inherit from a ``@PublicAPI`` base class can be
    assumed part of the RLlib public API as well (e.g., all trainer classes
    are in public API because Trainer is ``@PublicAPI``).

    In addition, you can assume all trainer configurations are part of their
    public API as well.

    Args:
        obj: The class or function being tagged.

    Returns:
        ``obj`` itself, unmodified. The decorator is a pure marker.

    Examples:
        >>> # Indicates that the `Trainer` class is exposed to end users
        >>> # of RLlib and will remain stable across RLlib releases.
        >>> from ray import tune
        >>> @PublicAPI # doctest: +SKIP
        ... class Trainer(tune.Trainable): # doctest: +SKIP
        ...     ... # doctest: +SKIP
    """
    return obj
|
|
|
|
|
|
|
|
|
|
|
|
def DeveloperAPI(obj):
    """Decorator for documenting developer APIs.

    Developer APIs are classes and methods explicitly exposed to developers
    for the purposes of building custom algorithms or advanced training
    strategies on top of RLlib internals. You can generally expect these APIs
    to be stable sans minor changes (but less stable than public APIs).

    Subclasses that inherit from a ``@DeveloperAPI`` base class can be
    assumed part of the RLlib developer API as well.

    Args:
        obj: The class or function being tagged.

    Returns:
        ``obj`` itself, unmodified. The decorator is a pure marker.

    Examples:
        >>> # Indicates that the `TorchPolicy` class is exposed to developers
        >>> # building on RLlib and will remain (relatively) stable across
        >>> # RLlib releases.
        >>> from ray.rllib.policy import Policy
        >>> @DeveloperAPI # doctest: +SKIP
        ... class TorchPolicy(Policy): # doctest: +SKIP
        ...     ... # doctest: +SKIP
    """
    return obj
|
2021-08-03 18:30:02 -04:00
|
|
|
|
|
|
|
|
2021-11-01 21:46:02 +01:00
|
|
|
def ExperimentalAPI(obj):
    """Decorator for documenting experimental APIs.

    Experimental APIs are classes and methods that are in development and may
    change at any time in their development process. You should not expect
    these APIs to be stable until their tag is changed to `DeveloperAPI` or
    `PublicAPI`.

    Subclasses that inherit from a ``@ExperimentalAPI`` base class can be
    assumed experimental as well.

    Args:
        obj: The class or function being tagged.

    Returns:
        ``obj`` itself, unmodified. The decorator is a pure marker.

    Examples:
        >>> from ray.rllib.policy import Policy
        >>> class TorchPolicy(Policy): # doctest: +SKIP
        ...     ... # doctest: +SKIP
        ...     # Indicates that the `TorchPolicy.loss` method is a new and
        ...     # experimental API and may change frequently in future
        ...     # releases.
        ...     @ExperimentalAPI # doctest: +SKIP
        ...     def loss(self, model, action_dist, train_batch): # doctest: +SKIP
        ...         ... # doctest: +SKIP
    """
    return obj
|
|
|
|
|
|
|
|
|
2021-11-16 14:49:41 +01:00
|
|
|
def OverrideToImplementCustomLogic(obj):
    """Users should override this in their sub-classes to implement custom logic.

    Used in Trainer and Policy to tag methods that need overriding, e.g.
    `Policy.loss()`.

    Besides documenting intent, this sets a marker attribute on ``obj`` so
    that ``is_overridden()`` can tell the tagged (base) implementation apart
    from an untagged replacement.

    Args:
        obj: The function to tag. It must accept attribute assignment.

    Returns:
        ``obj`` itself, with ``__is_overriden__`` set to False.

    Examples:
        >>> from ray.rllib.policy.torch_policy import TorchPolicy
        >>> @override(TorchPolicy) # doctest: +SKIP
        ... @OverrideToImplementCustomLogic # doctest: +SKIP
        ... def loss(self, model, action_dist, train_batch): # doctest: +SKIP
        ...     # implement custom loss function here ...
        ...     # ... w/o calling the corresponding `super().loss()` method.
        ...     ... # doctest: +SKIP
    """
    # NOTE: the attribute name is misspelled ("overriden"), but
    # `is_overridden()` reads exactly this name, so it must not be changed
    # here on its own.
    obj.__is_overriden__ = False
    return obj
|
|
|
|
|
|
|
|
|
|
|
|
def OverrideToImplementCustomLogic_CallToSuperRecommended(obj):
    """Users should override this in their sub-classes to implement custom logic.

    Thereby, it is recommended (but not required) to call the super-class'
    corresponding method.

    Used in Trainer and Policy to tag methods that need overriding, but the
    super class' method should still be called, e.g.
    `Trainer.setup()`.

    Besides documenting intent, this sets a marker attribute on ``obj`` so
    that ``is_overridden()`` can tell the tagged (base) implementation apart
    from an untagged replacement.

    Args:
        obj: The function to tag. It must accept attribute assignment.

    Returns:
        ``obj`` itself, with ``__is_overriden__`` set to False.

    Examples:
        >>> from ray import tune
        >>> @override(tune.Trainable) # doctest: +SKIP
        ... @OverrideToImplementCustomLogic_CallToSuperRecommended # doctest: +SKIP
        ... def setup(self, config): # doctest: +SKIP
        ...     # implement custom setup logic here ...
        ...     super().setup(config) # doctest: +SKIP
        ...     # ... or here (after having called super()'s setup method).
    """
    # NOTE: the attribute name is misspelled ("overriden"), but
    # `is_overridden()` reads exactly this name, so it must not be changed
    # here on its own.
    obj.__is_overriden__ = False
    return obj
|
|
|
|
|
|
|
|
|
2022-05-17 08:16:08 -07:00
|
|
|
def is_overridden(obj):
    """Check whether a function has been overridden.

    Note, this only works for API calls decorated with
    OverrideToImplementCustomLogic or
    OverrideToImplementCustomLogic_CallToSuperRecommended.

    Those decorators set ``__is_overriden__ = False`` on the function they
    tag. An object that lacks the marker is reported as overridden, which
    means any undecorated object also yields True.

    Args:
        obj: The function (or other object) to inspect.

    Returns:
        The value of ``obj.__is_overriden__`` if present (False for tagged
        functions), otherwise True.
    """
    # The misspelled attribute name is intentional: it must match the name
    # written by the tagging decorators.
    return getattr(obj, "__is_overriden__", True)
|
|
|
|
|
|
|
|
|
2021-11-01 21:46:02 +01:00
|
|
|
# Backward compatibility: re-bind the imported `Deprecated` name at module
# level so it remains importable from this module.
Deprecated = Deprecated
|