[RLLib] Record framework and algorithm used by an RLlib run. (#26956)

Automatically record framework and algorithm used by RLlib jobs.
For better planning.
This commit is contained in:
Jun Gong 2022-07-25 16:16:36 -07:00 committed by GitHub
parent 22f0439c17
commit ca5e0dcaf4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 0 deletions

View file

@ -244,6 +244,9 @@ def _put_library_usage(library_usage: str):
class TagKey(Enum):
_TEST1 = auto()
_TEST2 = auto()
RLLIB_FRAMEWORK = auto()
RLLIB_ALGORITHM = auto()
RLLIB_NUM_WORKERS = auto()
def record_extra_usage_tag(key: TagKey, value: str):

View file

@ -29,10 +29,12 @@ import pkg_resources
from packaging import version
import ray
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from ray.actor import ActorHandle
from ray.exceptions import GetTimeoutError, RayActorError, RayError
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.registry import ALGORITHMS as ALL_ALGORITHMS
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.utils import _gym_env_creator
from ray.rllib.evaluation.episode import Episode
@ -334,6 +336,8 @@ class Algorithm(Trainable):
update_global_seed_if_necessary(self.config["framework"], self.config["seed"])
self.validate_config(self.config)
self._record_usage(self.config)
self.callbacks = self.config["callbacks"]()
log_level = self.config.get("log_level")
if log_level in ["WARN", "ERROR"]:
@ -2562,6 +2566,20 @@ class Algorithm(Trainable):
def __repr__(self):
return type(self).__name__
def _record_usage(self, config):
"""Record the framework and algorithm used.
Args:
config: Algorithm config dict.
"""
record_extra_usage_tag(TagKey.RLLIB_FRAMEWORK, config["framework"])
record_extra_usage_tag(TagKey.RLLIB_NUM_WORKERS, str(config["num_workers"]))
alg = self.__class__.__name__
# We do not want to collect user defined algorithm names.
if alg not in ALL_ALGORITHMS:
alg = "USER_DEFINED"
record_extra_usage_tag(TagKey.RLLIB_ALGORITHM, alg)
@Deprecated(new="Trainer.compute_single_action()", error=False)
def compute_action(self, *args, **kwargs):
return self.compute_single_action(*args, **kwargs)