mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLLib] Record framework and algorithm used by an RLlib run. (#26956)
Automatically record framework and algorithm used by RLlib jobs. For better planning.
This commit is contained in:
parent
22f0439c17
commit
ca5e0dcaf4
2 changed files with 21 additions and 0 deletions
|
@ -244,6 +244,9 @@ def _put_library_usage(library_usage: str):
|
|||
class TagKey(Enum):
|
||||
_TEST1 = auto()
|
||||
_TEST2 = auto()
|
||||
RLLIB_FRAMEWORK = auto()
|
||||
RLLIB_ALGORITHM = auto()
|
||||
RLLIB_NUM_WORKERS = auto()
|
||||
|
||||
|
||||
def record_extra_usage_tag(key: TagKey, value: str):
|
||||
|
|
|
@ -29,10 +29,12 @@ import pkg_resources
|
|||
from packaging import version
|
||||
|
||||
import ray
|
||||
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
|
||||
from ray.actor import ActorHandle
|
||||
from ray.exceptions import GetTimeoutError, RayActorError, RayError
|
||||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
|
||||
from ray.rllib.algorithms.callbacks import DefaultCallbacks
|
||||
from ray.rllib.algorithms.registry import ALGORITHMS as ALL_ALGORITHMS
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.utils import _gym_env_creator
|
||||
from ray.rllib.evaluation.episode import Episode
|
||||
|
@ -334,6 +336,8 @@ class Algorithm(Trainable):
|
|||
update_global_seed_if_necessary(self.config["framework"], self.config["seed"])
|
||||
|
||||
self.validate_config(self.config)
|
||||
self._record_usage(self.config)
|
||||
|
||||
self.callbacks = self.config["callbacks"]()
|
||||
log_level = self.config.get("log_level")
|
||||
if log_level in ["WARN", "ERROR"]:
|
||||
|
@ -2562,6 +2566,20 @@ class Algorithm(Trainable):
|
|||
def __repr__(self):
|
||||
return type(self).__name__
|
||||
|
||||
def _record_usage(self, config):
|
||||
"""Record the framework and algorithm used.
|
||||
|
||||
Args:
|
||||
config: Algorithm config dict.
|
||||
"""
|
||||
record_extra_usage_tag(TagKey.RLLIB_FRAMEWORK, config["framework"])
|
||||
record_extra_usage_tag(TagKey.RLLIB_NUM_WORKERS, str(config["num_workers"]))
|
||||
alg = self.__class__.__name__
|
||||
# We do not want to collect user defined algorithm names.
|
||||
if alg not in ALL_ALGORITHMS:
|
||||
alg = "USER_DEFINED"
|
||||
record_extra_usage_tag(TagKey.RLLIB_ALGORITHM, alg)
|
||||
|
||||
@Deprecated(new="Trainer.compute_single_action()", error=False)
|
||||
def compute_action(self, *args, **kwargs):
|
||||
return self.compute_single_action(*args, **kwargs)
|
||||
|
|
Loading…
Add table
Reference in a new issue