mirror of https://github.com/vale981/ray, synced 2025-03-05 10:01:43 -05:00
[tune] Next deprecation cycle (#24076)
Rolling out the next deprecation cycle:
- DeprecationWarnings that were previously emitted via `warnings.warn` or `logger.warn` are now raised errors
- Previously raised DeprecationWarnings are now removed
- Notably, this deprecates the TrialCheckpoint functionality and the associated cloud tests
- Added annotations to the deprecation warnings noting when to fully remove them
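The pattern applied throughout the diff is sketched below with illustrative, hypothetical function names (the actual APIs touched are the ones shown in the hunks, e.g. `TrialCheckpoint`, `tune.durable()`): call sites that previously emitted a soft `warnings.warn(..., DeprecationWarning)` and kept working now raise `DeprecationWarning` outright, and each remaining shim carries a `# Deprecated: Remove in Ray > 1.13` annotation marking when it can be deleted.

```python
import warnings


def new_api():
    """Stand-in for the replacement API."""
    return "result"


def soft_deprecated_api():
    # Previous cycle: warn, then still delegate to the new API.
    warnings.warn(
        "`soft_deprecated_api` is deprecated, use `new_api` instead.",
        DeprecationWarning,
    )
    return new_api()


# Deprecated: Remove in Ray > 1.13
def hard_deprecated_api():
    # This cycle: the soft warning becomes a hard error.
    raise DeprecationWarning(
        "`hard_deprecated_api` is deprecated, use `new_api` instead."
    )
```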
This commit is contained in:
parent 2c772a421f
commit c0ec20dc3a
49 changed files with 170 additions and 1556 deletions
@@ -45,9 +45,3 @@ ExperimentAnalysis (tune.ExperimentAnalysis)
.. autoclass:: ray.tune.ExperimentAnalysis
    :members:

TrialCheckpoint (tune.cloud.TrialCheckpoint)
--------------------------------------------

.. autoclass:: ray.tune.cloud.TrialCheckpoint
    :members:
@@ -75,6 +75,7 @@ def is_non_local_path_uri(uri: str) -> bool:
        return True
    # Keep manual check for prefixes for backwards compatibility with the
    # TrialCheckpoint class. Remove once fully deprecated.
    # Deprecated: Remove in Ray > 1.13
    if any(uri.startswith(p) for p in ALLOWED_REMOTE_PREFIXES):
        return True
    return False
@@ -56,14 +56,6 @@ py_test(
    tags = ["team:ml", "client", "py37", "exclusive", "tests_dir_C"]
)

py_test(
    name = "test_cloud",
    size = "medium",
    srcs = ["tests/test_cloud.py"],
    deps = [":tune_lib"],
    tags = ["team:ml", "exclusive", "tests_dir_C"],
)

py_test(
    name = "test_cluster",
    size = "large",
@@ -15,8 +15,6 @@ from ray.tune.session import (
    get_trial_name,
    get_trial_id,
    get_trial_resources,
    make_checkpoint_dir,
    save_checkpoint,
    checkpoint_dir,
    is_session_enabled,
)
@@ -843,12 +843,11 @@ class ExperimentAnalysis:
        return state


# Deprecated: Remove in Ray > 1.13
@Deprecated
class Analysis(ExperimentAnalysis):
    def __init__(self, *args, **kwargs):
        if log_once("durable_deprecated"):
            logger.warning(
                "DeprecationWarning: The `Analysis` class is being "
                "deprecated. Please use `ExperimentAnalysis` instead."
            )
        super().__init__(*args, **kwargs)
        raise DeprecationWarning(
            "The `Analysis` class is being "
            "deprecated. Please use `ExperimentAnalysis` instead."
        )
@@ -1,22 +1,11 @@
import os
import shutil
import tempfile
import warnings
from typing import Optional

from ray import logger
from ray.ml.checkpoint import (
    Checkpoint,
    _get_local_path,
    _get_external_path,
)
from ray.ml.utils.remote_storage import (
    download_from_uri,
    delete_at_uri,
    upload_to_uri,
    is_non_local_path_uri,
)
from ray.util import log_once
from ray.util.annotations import Deprecated
@ -79,246 +68,8 @@ class _TrialCheckpoint(os.PathLike):
|
|||
f">"
|
||||
)
|
||||
|
||||
def download(
|
||||
self,
|
||||
cloud_path: Optional[str] = None,
|
||||
local_path: Optional[str] = None,
|
||||
overwrite: bool = False,
|
||||
) -> str:
|
||||
"""Download checkpoint from cloud.
|
||||
|
||||
This will fetch the checkpoint directory from cloud storage
|
||||
and save it to ``local_path``.
|
||||
|
||||
If a ``local_path`` argument is provided and ``self.local_path``
|
||||
is unset, it will be set to ``local_path``.
|
||||
|
||||
Args:
|
||||
cloud_path: Cloud path to load checkpoint from.
|
||||
Defaults to ``self.cloud_path``.
|
||||
local_path: Local path to save checkpoint at.
|
||||
Defaults to ``self.local_path``.
|
||||
overwrite: If True, overwrites potential existing local
|
||||
checkpoint. If False, exits if ``self.local_dir`` already
|
||||
exists and has files in it.
|
||||
|
||||
"""
|
||||
cloud_path = cloud_path or self.cloud_path
|
||||
if not cloud_path:
|
||||
raise RuntimeError(
|
||||
"Could not download trial checkpoint: No cloud "
|
||||
"path is set. Fix this by either passing a "
|
||||
"`cloud_path` to your call to `download()` or by "
|
||||
"passing a `cloud_path` into the constructor. The latter "
|
||||
"should automatically be done if you pass the correct "
|
||||
"`tune.SyncConfig`."
|
||||
)
|
||||
|
||||
local_path = local_path or self.local_path
|
||||
|
||||
if not local_path:
|
||||
raise RuntimeError(
|
||||
"Could not download trial checkpoint: No local "
|
||||
"path is set. Fix this by either passing a "
|
||||
"`local_path` to your call to `download()` or by "
|
||||
"passing a `local_path` into the constructor."
|
||||
)
|
||||
|
||||
# Only update local path if unset
|
||||
if not self.local_path:
|
||||
self.local_path = local_path
|
||||
|
||||
if (
|
||||
not overwrite
|
||||
and os.path.exists(local_path)
|
||||
and len(os.listdir(local_path)) > 0
|
||||
):
|
||||
# Local path already exists and we should not overwrite,
|
||||
# so return.
|
||||
return local_path
|
||||
|
||||
# Else: Actually download
|
||||
|
||||
# Delete existing dir
|
||||
shutil.rmtree(local_path, ignore_errors=True)
|
||||
# Re-create
|
||||
os.makedirs(local_path, 0o755, exist_ok=True)
|
||||
|
||||
# Here we trigger the actual download
|
||||
download_from_uri(uri=cloud_path, local_path=local_path)
|
||||
|
||||
# Local dir exists and is not empty
|
||||
return local_path
|
||||
|
||||
def upload(
|
||||
self,
|
||||
cloud_path: Optional[str] = None,
|
||||
local_path: Optional[str] = None,
|
||||
clean_before: bool = False,
|
||||
):
|
||||
"""Upload checkpoint to cloud.
|
||||
|
||||
This will push the checkpoint directory from local storage
|
||||
to ``cloud_path``.
|
||||
|
||||
If a ``cloud_path`` argument is provided and ``self.cloud_path``
|
||||
is unset, it will be set to ``cloud_path``.
|
||||
|
||||
Args:
|
||||
cloud_path: Cloud path to load checkpoint from.
|
||||
Defaults to ``self.cloud_path``.
|
||||
local_path: Local path to save checkpoint at.
|
||||
Defaults to ``self.local_path``.
|
||||
clean_before: If True, deletes potentially existing
|
||||
cloud bucket before storing new data.
|
||||
|
||||
"""
|
||||
local_path = local_path or self.local_path
|
||||
if not local_path:
|
||||
raise RuntimeError(
|
||||
"Could not upload trial checkpoint: No local "
|
||||
"path is set. Fix this by either passing a "
|
||||
"`local_path` to your call to `upload()` or by "
|
||||
"passing a `local_path` into the constructor."
|
||||
)
|
||||
|
||||
cloud_path = cloud_path or self.cloud_path
|
||||
if not cloud_path:
|
||||
raise RuntimeError(
|
||||
"Could not download trial checkpoint: No cloud "
|
||||
"path is set. Fix this by either passing a "
|
||||
"`cloud_path` to your call to `download()` or by "
|
||||
"passing a `cloud_path` into the constructor. The latter "
|
||||
"should automatically be done if you pass the correct "
|
||||
"`tune.SyncConfig`."
|
||||
)
|
||||
|
||||
if not self.cloud_path:
|
||||
self.cloud_path = cloud_path
|
||||
|
||||
if clean_before:
|
||||
logger.info(f"Clearing bucket contents before upload: {cloud_path}")
|
||||
delete_at_uri(cloud_path)
|
||||
|
||||
# Actually upload
|
||||
upload_to_uri(local_path, cloud_path)
|
||||
|
||||
return cloud_path
|
||||
|
||||
def save(self, path: Optional[str] = None, force_download: bool = False):
|
||||
"""Save trial checkpoint to directory or cloud storage.
|
||||
|
||||
If the ``path`` is a local target and the checkpoint already exists
|
||||
on local storage, the local directory is copied. Else, the checkpoint
|
||||
is downloaded from cloud storage.
|
||||
|
||||
If the ``path`` is a cloud target and the checkpoint does not already
|
||||
exist on local storage, it is downloaded from cloud storage before.
|
||||
That way checkpoints can be transferred across cloud storage providers.
|
||||
|
||||
Args:
|
||||
path: Path to save checkpoint at. If empty,
|
||||
the default cloud storage path is saved to the default
|
||||
local directory.
|
||||
force_download: If ``True``, forces (re-)download of
|
||||
the checkpoint. Defaults to ``False``.
|
||||
"""
|
||||
temp_dirs = set()
|
||||
# Per default, save cloud checkpoint
|
||||
if not path:
|
||||
if self.cloud_path and self.local_path:
|
||||
path = self.local_path
|
||||
elif not self.cloud_path:
|
||||
raise RuntimeError(
|
||||
"Cannot save trial checkpoint: No cloud path "
|
||||
"found. If the checkpoint is already on the node, "
|
||||
"you can pass a `path` argument to save it at another "
|
||||
"location."
|
||||
)
|
||||
else:
|
||||
# No self.local_path
|
||||
raise RuntimeError(
|
||||
"Cannot save trial checkpoint: No target path "
|
||||
"specified and no default local directory available. "
|
||||
"Please pass a `path` argument to `save()`."
|
||||
)
|
||||
elif not self.local_path and not self.cloud_path:
|
||||
raise RuntimeError(
|
||||
f"Cannot save trial checkpoint to cloud target "
|
||||
f"`{path}`: No existing local or cloud path was "
|
||||
f"found. This indicates an error when loading "
|
||||
f"the checkpoints. Please report this issue."
|
||||
)
|
||||
|
||||
if is_non_local_path_uri(path):
|
||||
# Storing on cloud
|
||||
if not self.local_path:
|
||||
# No local copy, yet. Download to temp dir
|
||||
local_path = tempfile.mkdtemp(prefix="tune_checkpoint_")
|
||||
temp_dirs.add(local_path)
|
||||
else:
|
||||
local_path = self.local_path
|
||||
|
||||
if self.cloud_path:
|
||||
# Do not update local path as it might be a temp file
|
||||
local_path = self.download(
|
||||
local_path=local_path, overwrite=force_download
|
||||
)
|
||||
|
||||
# Remove pointer to a temporary directory
|
||||
if self.local_path in temp_dirs:
|
||||
self.local_path = None
|
||||
|
||||
# We should now have a checkpoint available locally
|
||||
if not os.path.exists(local_path) or len(os.listdir(local_path)) == 0:
|
||||
raise RuntimeError(
|
||||
f"No checkpoint found in directory `{local_path}` after "
|
||||
f"download - maybe the bucket is empty or downloading "
|
||||
f"failed?"
|
||||
)
|
||||
|
||||
# Only update cloud path if it wasn't set before
|
||||
cloud_path = self.upload(
|
||||
cloud_path=path, local_path=local_path, clean_before=True
|
||||
)
|
||||
|
||||
# Clean up temporary directories
|
||||
for temp_dir in temp_dirs:
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
return cloud_path
|
||||
|
||||
local_path_exists = (
|
||||
self.local_path
|
||||
and os.path.exists(self.local_path)
|
||||
and len(os.listdir(self.local_path)) > 0
|
||||
)
|
||||
|
||||
# Else: path is a local target
|
||||
if self.local_path and local_path_exists and not force_download:
|
||||
# If we have a local copy, use it
|
||||
|
||||
if path == self.local_path:
|
||||
# Nothing to do
|
||||
return self.local_path
|
||||
|
||||
# Both local, just copy tree
|
||||
if os.path.exists(path):
|
||||
shutil.rmtree(path)
|
||||
|
||||
shutil.copytree(self.local_path, path)
|
||||
return path
|
||||
|
||||
# Else: Download
|
||||
try:
|
||||
return self.download(local_path=path, overwrite=force_download)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"Cannot save trial checkpoint to local target as downloading "
|
||||
"from cloud failed. Did you pass the correct `SyncConfig`?"
|
||||
) from e
|
||||
|
||||
|
||||
# Deprecated: Remove in Ray > 1.13
|
||||
@Deprecated
|
||||
class TrialCheckpoint(Checkpoint, _TrialCheckpoint):
|
||||
def __init__(
|
||||
|
@@ -394,13 +145,11 @@ class TrialCheckpoint(Checkpoint, _TrialCheckpoint):
        local_path: Optional[str] = None,
        overwrite: bool = False,
    ) -> str:
        if log_once("trial_checkpoint_download_deprecated"):
            warnings.warn(
                "`checkpoint.download()` is deprecated and will be removed in "
                "the future. Please use `checkpoint.to_directory()` instead.",
                DeprecationWarning,
            )
        return _TrialCheckpoint.download(self, cloud_path, local_path, overwrite)
        # Deprecated: Remove whole class in Ray > 1.13
        raise DeprecationWarning(
            "`checkpoint.download()` is deprecated and will be removed in "
            "the future. Please use `checkpoint.to_directory()` instead."
        )

    def upload(
        self,
@@ -408,20 +157,16 @@ class TrialCheckpoint(Checkpoint, _TrialCheckpoint):
        local_path: Optional[str] = None,
        clean_before: bool = False,
    ):
        if log_once("trial_checkpoint_upload_deprecated"):
            warnings.warn(
                "`checkpoint.upload()` is deprecated and will be removed in "
                "the future. Please use `checkpoint.to_uri()` instead.",
                DeprecationWarning,
            )
        return _TrialCheckpoint.upload(self, cloud_path, local_path, clean_before)
        # Deprecated: Remove whole class in Ray > 1.13
        raise DeprecationWarning(
            "`checkpoint.upload()` is deprecated and will be removed in "
            "the future. Please use `checkpoint.to_uri()` instead."
        )

    def save(self, path: Optional[str] = None, force_download: bool = False):
        if log_once("trial_checkpoint_save_deprecated"):
            warnings.warn(
                "`checkpoint.save()` is deprecated and will be removed in "
                "the future. Please use `checkpoint.to_directory()` or"
                "`checkpoint.to_uri()` instead.",
                DeprecationWarning,
            )
        return _TrialCheckpoint.save(self, path, force_download)
        # Deprecated: Remove whole class in Ray > 1.13
        raise DeprecationWarning(
            "`checkpoint.save()` is deprecated and will be removed in "
            "the future. Please use `checkpoint.to_directory()` or"
            "`checkpoint.to_uri()` instead."
        )
@@ -3,38 +3,35 @@ from typing import Callable, Type, Union
import logging

from ray.tune.trainable import Trainable
from ray.util import log_once

from ray.util.annotations import Deprecated

logger = logging.getLogger(__name__)


# Deprecated: Remove in Ray > 1.13
@Deprecated
class DurableTrainable(Trainable):
    _sync_function_tpl = None

    def __init__(self, *args, **kwargs):
        if log_once("durable_deprecated"):
            logger.warning(
                "DeprecationWarning: The `DurableTrainable` class is being "
                "deprecated. Instead, all Trainables are durable by default "
                "if you provide an `upload_dir`. You'll likely only need to "
                "remove the call to `tune.durable()` or directly inherit from "
                "`Trainable` instead of `DurableTrainable` for class "
                "trainables to make your code forward-compatible."
            )
        super(DurableTrainable, self).__init__(*args, **kwargs)
        raise DeprecationWarning(
            "DeprecationWarning: The `DurableTrainable` class is being "
            "deprecated. Instead, all Trainables are durable by default "
            "if you provide an `upload_dir`. You'll likely only need to "
            "remove the call to `tune.durable()` or directly inherit from "
            "`Trainable` instead of `DurableTrainable` for class "
            "trainables to make your code forward-compatible."
        )


# Deprecated: Remove in Ray > 1.13
@Deprecated
def durable(trainable: Union[str, Type[Trainable], Callable]):
    if log_once("durable_deprecated"):
        logger.warning(
            "DeprecationWarning: `tune.durable()` is being deprecated."
            "Instead, all Trainables are durable by default if "
            "you provide an `upload_dir`. You'll likely only need to remove "
            "the call to `tune.durable()` to make your code "
            "forward-compatible."
        )
    return trainable
    raise DeprecationWarning(
        "DeprecationWarning: `tune.durable()` is being deprecated."
        "Instead, all Trainables are durable by default if "
        "you provide an `upload_dir`. You'll likely only need to remove "
        "the call to `tune.durable()` to make your code "
        "forward-compatible."
    )
@@ -63,7 +63,6 @@ if __name__ == "__main__":
    algo = SigOptSearch(
        space,
        name="SigOpt Example Experiment",
        max_concurrent=1,
        metric="mean_loss",
        mode="min",
    )

@@ -62,7 +62,6 @@ if __name__ == "__main__":
        space,
        name="SigOpt Example Multi Objective Experiment",
        observation_budget=4 if args.smoke_test else 100,
        max_concurrent=1,
        metric=["average", "std", "sharpe"],
        mode=["max", "min", "obs"],
    )

@@ -89,7 +89,6 @@ if __name__ == "__main__":
        connection=conn,
        experiment_id=experiment.id,
        name="SigOpt Example Existing Experiment",
        max_concurrent=1,
        metric=["average", "std"],
        mode=["obs", "min"],
    )
@@ -19,9 +19,3 @@ def set_keras_threads(threads):
    # is heavily parallelized across multiple cores.
    tf.config.threading.set_inter_op_parallelism_threads(threads)
    tf.config.threading.set_intra_op_parallelism_threads(threads)


def TuneKerasCallback(*args, **kwargs):
    raise DeprecationWarning(
        "TuneKerasCallback is now tune.integration.keras.TuneReporterCallback."
    )
@ -3,10 +3,9 @@ import logging
|
|||
|
||||
import ray
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.logger import Logger, LoggerCallback
|
||||
from ray.tune.logger import LoggerCallback
|
||||
from ray.tune.result import TRAINING_ITERATION, TIMESTEPS_TOTAL
|
||||
from ray.tune.trial import Trial
|
||||
from ray.util.annotations import Deprecated
|
||||
from ray.util.ml_utils.mlflow import MLflowLoggerUtil
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -136,22 +135,6 @@ class MLflowLoggerCallback(LoggerCallback):
|
|||
self.mlflow_util.end_run(run_id=run_id, status=status)
|
||||
|
||||
|
||||
@Deprecated
|
||||
class MLflowLogger(Logger):
|
||||
"""MLflow logger using the deprecated Logger API.
|
||||
|
||||
Requires the experiment configuration to have a MLflow Experiment ID
|
||||
or manually set the proper environment variables.
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
raise DeprecationWarning(
|
||||
"The legacy MLflowLogger has been "
|
||||
"deprecated. Use the MLflowLoggerCallback "
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
def mlflow_mixin(func: Callable):
|
||||
"""mlflow_mixin
|
||||
|
||||
|
|
|
@ -198,7 +198,6 @@ def DistributedTrainableCreator(
|
|||
num_workers_per_host: Optional[int] = None,
|
||||
backend: str = "gloo",
|
||||
timeout_s: int = NCCL_TIMEOUT_S,
|
||||
use_gpu=None,
|
||||
) -> Type[_TorchTrainable]:
|
||||
"""Creates a class that executes distributed training.
|
||||
|
||||
|
@ -239,10 +238,6 @@ def DistributedTrainableCreator(
|
|||
train_func, num_workers=2)
|
||||
analysis = tune.run(trainable_cls)
|
||||
"""
|
||||
if use_gpu:
|
||||
raise DeprecationWarning(
|
||||
"use_gpu is deprecated. Use 'num_gpus_per_worker' instead."
|
||||
)
|
||||
detect_checkpoint_function(func, abort=True)
|
||||
if num_workers_per_host:
|
||||
if num_workers % num_workers_per_host:
|
||||
|
|
|
@ -15,6 +15,7 @@ from ray.tune.utils import flatten_dict
|
|||
from ray.tune.trial import Trial
|
||||
|
||||
import yaml
|
||||
from ray.util.annotations import Deprecated
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
@ -398,6 +399,7 @@ class WandbLoggerCallback(LoggerCallback):
|
|||
del self._trial_processes[trial]
|
||||
|
||||
|
||||
@Deprecated
|
||||
class WandbLogger(Logger):
|
||||
"""WandbLogger
|
||||
|
||||
|
@ -444,8 +446,7 @@ class WandbLogger(Logger):
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.tune.logger import DEFAULT_LOGGERS
|
||||
from ray.tune.integration.wandb import WandbLogger
|
||||
from ray.tune.integration.wandb import WandbLoggerCallback
|
||||
tune.run(
|
||||
train_fn,
|
||||
config={
|
||||
|
@ -459,14 +460,14 @@ class WandbLogger(Logger):
|
|||
"log_config": True
|
||||
}
|
||||
},
|
||||
loggers=DEFAULT_LOGGERS + (WandbLogger, ))
|
||||
calllbacks=[WandbLoggerCallback])
|
||||
|
||||
Example for RLlib:
|
||||
|
||||
.. code-block :: python
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.integration.wandb import WandbLogger
|
||||
from ray.tune.integration.wandb import WandbLoggerCallback
|
||||
|
||||
tune.run(
|
||||
"PPO",
|
||||
|
@ -479,40 +480,18 @@ class WandbLogger(Logger):
|
|||
}
|
||||
}
|
||||
},
|
||||
loggers=[WandbLogger])
|
||||
callbacks=[WandbLoggerCallback])
|
||||
|
||||
|
||||
"""
|
||||
|
||||
_experiment_logger_cls = WandbLoggerCallback
|
||||
|
||||
def _init(self):
|
||||
config = self.config.copy()
|
||||
config.pop("callbacks", None) # Remove callbacks
|
||||
|
||||
try:
|
||||
if config.get("logger_config", {}).get("wandb"):
|
||||
logger_config = config.pop("logger_config")
|
||||
wandb_config = logger_config.get("wandb").copy()
|
||||
else:
|
||||
wandb_config = config.pop("wandb").copy()
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Wandb logger specified but no configuration has been passed. "
|
||||
"Make sure to include a `wandb` key in your `config` dict "
|
||||
"containing at least a `project` specification."
|
||||
)
|
||||
|
||||
self._trial_experiment_logger = self._experiment_logger_cls(**wandb_config)
|
||||
self._trial_experiment_logger.setup()
|
||||
self._trial_experiment_logger.log_trial_start(self.trial)
|
||||
|
||||
def on_result(self, result: Dict):
|
||||
self._trial_experiment_logger.log_trial_result(0, self.trial, result)
|
||||
|
||||
def close(self):
|
||||
self._trial_experiment_logger.log_trial_end(self.trial, failed=False)
|
||||
del self._trial_experiment_logger
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"This `Logger` class is deprecated. "
|
||||
"Use the `WandbLoggerCallback` callback instead."
|
||||
)
|
||||
|
||||
|
||||
class WandbTrainableMixin:
|
||||
|
|
|
@ -747,16 +747,6 @@ class TBXLoggerCallback(LoggerCallback):
|
|||
)
|
||||
|
||||
|
||||
# Maintain backwards compatibility.
|
||||
from ray.tune.integration.mlflow import ( # noqa: E402
|
||||
MLflowLogger as _MLflowLogger,
|
||||
)
|
||||
|
||||
MLflowLogger = _MLflowLogger
|
||||
# The capital L is a typo, but needs to remain for backwards compatibility.
|
||||
MLFLowLogger = _MLflowLogger
|
||||
|
||||
|
||||
def pretty_print(result):
|
||||
result = result.copy()
|
||||
result.update(config=None) # drop config from pretty print
|
||||
|
|
|
@ -525,13 +525,11 @@ class Quantized(Sampler):
|
|||
return list(quantized)
|
||||
|
||||
|
||||
# TODO (krfricke): Remove tune.function
|
||||
# Deprecated: Remove in Ray > 1.13
|
||||
def function(func):
|
||||
logger.warning(
|
||||
"DeprecationWarning: wrapping {} with tune.function() is no "
|
||||
"longer needed".format(func)
|
||||
raise DeprecationWarning(
|
||||
"wrapping {} with tune.function() is no longer needed".format(func)
|
||||
)
|
||||
return func
|
||||
|
||||
|
||||
def sample_from(func: Callable[[Dict], Any]):
|
||||
|
|
|
@ -6,7 +6,6 @@ from typing import Dict, Any, List, Optional, Set, Tuple, Union, Callable
|
|||
import pickle
|
||||
import warnings
|
||||
|
||||
from ray.util import log_once
|
||||
from ray.util.annotations import PublicAPI, Deprecated
|
||||
from ray.tune import trial_runner
|
||||
from ray.tune.resources import Resources
|
||||
|
@ -583,6 +582,7 @@ _DistributeResourcesDefault = DistributeResources(add_bundles=False)
|
|||
_DistributeResourcesDistributedDefault = DistributeResources(add_bundles=True)
|
||||
|
||||
|
||||
# Deprecated: Remove in Ray > 1.13
|
||||
@Deprecated
|
||||
def evenly_distribute_cpus_gpus(
|
||||
trial_runner: "trial_runner.TrialRunner",
|
||||
|
@ -621,18 +621,16 @@ def evenly_distribute_cpus_gpus(
|
|||
the function.
|
||||
"""
|
||||
|
||||
if log_once("evenly_distribute_cpus_gpus_deprecated"):
|
||||
warnings.warn(
|
||||
"DeprecationWarning: `evenly_distribute_cpus_gpus` "
|
||||
"and `evenly_distribute_cpus_gpus_distributed` are "
|
||||
"being deprecated. Use `DistributeResources()` and "
|
||||
"`DistributeResources(add_bundles=False)` instead "
|
||||
"for equivalent functionality."
|
||||
)
|
||||
|
||||
return _DistributeResourcesDefault(trial_runner, trial, result, scheduler)
|
||||
raise DeprecationWarning(
|
||||
"DeprecationWarning: `evenly_distribute_cpus_gpus` "
|
||||
"and `evenly_distribute_cpus_gpus_distributed` are "
|
||||
"being deprecated. Use `DistributeResources()` and "
|
||||
"`DistributeResources(add_bundles=False)` instead "
|
||||
"for equivalent functionality."
|
||||
)
|
||||
|
||||
|
||||
# Deprecated: Remove in Ray > 1.13
|
||||
@Deprecated
|
||||
def evenly_distribute_cpus_gpus_distributed(
|
||||
trial_runner: "trial_runner.TrialRunner",
|
||||
|
@ -671,17 +669,12 @@ def evenly_distribute_cpus_gpus_distributed(
|
|||
the function.
|
||||
"""
|
||||
|
||||
if log_once("evenly_distribute_cpus_gpus_deprecated"):
|
||||
warnings.warn(
|
||||
"DeprecationWarning: `evenly_distribute_cpus_gpus` "
|
||||
"and `evenly_distribute_cpus_gpus_distributed` are "
|
||||
"being deprecated. Use `DistributeResources()` and "
|
||||
"`DistributeResources(add_bundles=False)` instead "
|
||||
"for equivalent functionality."
|
||||
)
|
||||
|
||||
return _DistributeResourcesDistributedDefault(
|
||||
trial_runner, trial, result, scheduler
|
||||
raise DeprecationWarning(
|
||||
"DeprecationWarning: `evenly_distribute_cpus_gpus` "
|
||||
"and `evenly_distribute_cpus_gpus_distributed` are "
|
||||
"being deprecated. Use `DistributeResources()` and "
|
||||
"`DistributeResources(add_bundles=False)` instead "
|
||||
"for equivalent functionality."
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -103,28 +103,6 @@ def report(_metric=None, **kwargs):
|
|||
return _session(_metric, **kwargs)
|
||||
|
||||
|
||||
def make_checkpoint_dir(step=None):
|
||||
"""Gets the next checkpoint dir.
|
||||
|
||||
.. versionadded:: 0.8.6
|
||||
|
||||
.. deprecated:: 0.8.7
|
||||
Use tune.checkpoint_dir instead.
|
||||
"""
|
||||
raise DeprecationWarning("Deprecated method. Use `tune.checkpoint_dir` instead.")
|
||||
|
||||
|
||||
def save_checkpoint(checkpoint):
|
||||
"""Register the given checkpoint.
|
||||
|
||||
.. versionadded:: 0.8.6
|
||||
|
||||
.. deprecated:: 0.8.7
|
||||
Use tune.checkpoint_dir instead.
|
||||
"""
|
||||
raise DeprecationWarning("Deprecated method. Use `tune.checkpoint_dir` instead.")
|
||||
|
||||
|
||||
@PublicAPI
|
||||
@contextmanager
|
||||
def checkpoint_dir(step: int):
|
||||
|
|
|
@ -157,38 +157,3 @@ __all__ = [
|
|||
"Repeater",
|
||||
"ConcurrencyLimiter",
|
||||
]
|
||||
|
||||
|
||||
def BayesOptSearch(*args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"""This class has been moved. Please import via
|
||||
`from ray.tune.suggest.bayesopt import BayesOptSearch`"""
|
||||
)
|
||||
|
||||
|
||||
def HyperOptSearch(*args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"""This class has been moved. Please import via
|
||||
`from ray.tune.suggest.hyperopt import HyperOptSearch`"""
|
||||
)
|
||||
|
||||
|
||||
def NevergradSearch(*args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"""This class has been moved. Please import via
|
||||
`from ray.tune.suggest.nevergrad import NevergradSearch`"""
|
||||
)
|
||||
|
||||
|
||||
def SkOptSearch(*args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"""This class has been moved. Please import via
|
||||
`from ray.tune.suggest.skopt import SkOptSearch`"""
|
||||
)
|
||||
|
||||
|
||||
def SigOptSearch(*args, **kwargs):
|
||||
raise DeprecationWarning(
|
||||
"""This class has been moved. Please import via
|
||||
`from ray.tune.suggest.sigopt import SigOptSearch`"""
|
||||
)
|
||||
|
|
|
@ -73,8 +73,6 @@ class AxSearch(Searcher):
|
|||
ax_client: Optional AxClient instance. If this is set, do
|
||||
not pass any values to these parameters: `space`, `metric`,
|
||||
`parameter_constraints`, `outcome_constraints`.
|
||||
use_early_stopped_trials: Deprecated.
|
||||
max_concurrent: Deprecated.
|
||||
**ax_kwargs: Passed to AxClient instance. Ignored if `AxClient` is not
|
||||
None.
|
||||
|
||||
|
@ -133,8 +131,6 @@ class AxSearch(Searcher):
|
|||
parameter_constraints: Optional[List] = None,
|
||||
outcome_constraints: Optional[List] = None,
|
||||
ax_client: Optional[AxClient] = None,
|
||||
use_early_stopped_trials: Optional[bool] = None,
|
||||
max_concurrent: Optional[int] = None,
|
||||
**ax_kwargs
|
||||
):
|
||||
assert (
|
||||
|
@ -149,8 +145,6 @@ class AxSearch(Searcher):
|
|||
super(AxSearch, self).__init__(
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
max_concurrent=max_concurrent,
|
||||
use_early_stopped_trials=use_early_stopped_trials,
|
||||
)
|
||||
|
||||
self._ax = ax_client
|
||||
|
@ -170,8 +164,6 @@ class AxSearch(Searcher):
|
|||
|
||||
self._points_to_evaluate = copy.deepcopy(points_to_evaluate)
|
||||
|
||||
self.max_concurrent = max_concurrent
|
||||
|
||||
self._parameters = []
|
||||
self._live_trial_mapping = {}
|
||||
|
||||
|
|
|
@ -79,8 +79,6 @@ class BayesOptSearch(Searcher):
|
|||
analysis: Optionally, the previous analysis
|
||||
to integrate.
|
||||
verbose: Sets verbosity level for BayesOpt packages.
|
||||
max_concurrent: Deprecated.
|
||||
use_early_stopped_trials: Deprecated.
|
||||
|
||||
Tune automatically converts search spaces to BayesOptSearch's format:
|
||||
|
||||
|
@ -130,8 +128,6 @@ class BayesOptSearch(Searcher):
|
|||
patience: int = 5,
|
||||
skip_duplicate: bool = True,
|
||||
analysis: Optional[ExperimentAnalysis] = None,
|
||||
max_concurrent: Optional[int] = None,
|
||||
use_early_stopped_trials: Optional[bool] = None,
|
||||
):
|
||||
assert byo is not None, (
|
||||
"BayesOpt must be installed!. You can install BayesOpt with"
|
||||
|
@ -139,7 +135,6 @@ class BayesOptSearch(Searcher):
|
|||
)
|
||||
if mode:
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
||||
self.max_concurrent = max_concurrent
|
||||
self._config_counter = defaultdict(int)
|
||||
self._patience = patience
|
||||
# int: Precision at which to hash values.
|
||||
|
@ -150,8 +145,6 @@ class BayesOptSearch(Searcher):
|
|||
super(BayesOptSearch, self).__init__(
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
max_concurrent=max_concurrent,
|
||||
use_early_stopped_trials=use_early_stopped_trials,
|
||||
)
|
||||
|
||||
if utility_kwargs is None:
|
||||
|
|
|
@ -61,8 +61,6 @@ class TuneBOHB(Searcher):
|
|||
Parameters will be sampled from this space which will be used
|
||||
to run trials.
|
||||
bohb_config: configuration for HpBandSter BOHB algorithm
|
||||
max_concurrent: Deprecated. Use
|
||||
``tune.suggest.ConcurrencyLimiter()``.
|
||||
metric: The training result objective value attribute. If None
|
||||
but a mode was passed, the anonymous metric `_metric` will be used
|
||||
per default.
|
||||
|
@ -126,7 +124,6 @@ class TuneBOHB(Searcher):
|
|||
self,
|
||||
space: Optional[Union[Dict, "ConfigSpace.ConfigurationSpace"]] = None,
|
||||
bohb_config: Optional[Dict] = None,
|
||||
max_concurrent: Optional[int] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
points_to_evaluate: Optional[List[Dict]] = None,
|
||||
|
@ -139,7 +136,6 @@ class TuneBOHB(Searcher):
|
|||
`pip install hpbandster ConfigSpace`."""
|
||||
if mode:
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
||||
self._max_concurrent = max_concurrent
|
||||
self.trial_to_params = {}
|
||||
self._metric = metric
|
||||
|
||||
|
@ -159,7 +155,8 @@ class TuneBOHB(Searcher):
|
|||
self._points_to_evaluate = points_to_evaluate
|
||||
|
||||
super(TuneBOHB, self).__init__(
|
||||
metric=self._metric, mode=mode, max_concurrent=max_concurrent
|
||||
metric=self._metric,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
if self._space:
|
||||
|
|
|
@ -78,8 +78,6 @@ class HyperOptSearch(Searcher):
|
|||
results. Defaults to None.
|
||||
gamma: parameter governing the tree parzen
|
||||
estimators suggestion algorithm. Defaults to 0.25.
|
||||
max_concurrent: Deprecated.
|
||||
use_early_stopped_trials: Deprecated.
|
||||
|
||||
Tune automatically converts search spaces to HyperOpt's format:
|
||||
|
||||
|
@ -138,8 +136,6 @@ class HyperOptSearch(Searcher):
|
|||
n_initial_points: int = 20,
|
||||
random_state_seed: Optional[int] = None,
|
||||
gamma: float = 0.25,
|
||||
max_concurrent: Optional[int] = None,
|
||||
use_early_stopped_trials: Optional[bool] = None,
|
||||
):
|
||||
assert (
|
||||
hpo is not None
|
||||
|
@ -149,10 +145,7 @@ class HyperOptSearch(Searcher):
|
|||
super(HyperOptSearch, self).__init__(
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
max_concurrent=max_concurrent,
|
||||
use_early_stopped_trials=use_early_stopped_trials,
|
||||
)
|
||||
self.max_concurrent = max_concurrent
|
||||
# hyperopt internally minimizes, so "max" => -1
|
||||
if mode == "max":
|
||||
self.metric_op = -1.0
|
||||
|
|
|
@ -59,8 +59,6 @@ class NevergradSearch(Searcher):
|
|||
you want to run first to help the algorithm make better suggestions
|
||||
for future parameters. Needs to be a list of dicts containing the
|
||||
configurations.
|
||||
use_early_stopped_trials: Deprecated.
|
||||
max_concurrent: Deprecated.
|
||||
|
||||
Tune automatically converts search spaces to Nevergrad's format:
|
||||
|
||||
|
@ -120,7 +118,6 @@ class NevergradSearch(Searcher):
|
|||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
points_to_evaluate: Optional[List[Dict]] = None,
|
||||
max_concurrent: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert (
|
||||
|
@ -131,9 +128,7 @@ class NevergradSearch(Searcher):
|
|||
if mode:
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
||||
|
||||
super(NevergradSearch, self).__init__(
|
||||
metric=metric, mode=mode, max_concurrent=max_concurrent, **kwargs
|
||||
)
|
||||
super(NevergradSearch, self).__init__(metric=metric, mode=mode, **kwargs)
|
||||
|
||||
self._space = None
|
||||
self._opt_factory = None
|
||||
|
@ -180,7 +175,6 @@ class NevergradSearch(Searcher):
|
|||
)
|
||||
|
||||
self._live_trial_mapping = {}
|
||||
self.max_concurrent = max_concurrent
|
||||
|
||||
if self._nevergrad_opt or self._space:
|
||||
self._setup_nevergrad()
|
||||
|
|
|
@ -289,9 +289,7 @@ class OptunaSearch(Searcher):
|
|||
evaluated_rewards: Optional[List] = None,
|
||||
):
|
||||
assert ot is not None, "Optuna must be installed! Run `pip install optuna`."
|
||||
super(OptunaSearch, self).__init__(
|
||||
metric=metric, mode=mode, max_concurrent=None, use_early_stopped_trials=None
|
||||
)
|
||||
super(OptunaSearch, self).__init__(metric=metric, mode=mode)
|
||||
|
||||
if isinstance(space, dict) and space:
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
||||
|
|
|
@ -90,7 +90,7 @@ class SigOptSearch(Searcher):
|
|||
]
|
||||
algo = SigOptSearch(
|
||||
space, name="SigOpt Example Experiment",
|
||||
max_concurrent=1, metric="mean_loss", mode="min")
|
||||
metric="mean_loss", mode="min")
|
||||
|
||||
|
||||
Example:
|
||||
|
@ -117,7 +117,7 @@ class SigOptSearch(Searcher):
|
|||
]
|
||||
algo = SigOptSearch(
|
||||
space, name="SigOpt Multi Objective Example Experiment",
|
||||
max_concurrent=1, metric=["average", "std"], mode=["max", "min"])
|
||||
metric=["average", "std"], mode=["max", "min"])
|
||||
"""
|
||||
|
||||
OBJECTIVE_MAP = {
|
||||
|
|
|
@ -68,8 +68,6 @@ class SkOptSearch(Searcher):
|
|||
convert_to_python: SkOpt outputs numpy primitives (e.g.
|
||||
``np.int64``) instead of Python types. If this setting is set
|
||||
to ``True``, the values will be converted to Python primitives.
|
||||
max_concurrent: Deprecated.
|
||||
use_early_stopped_trials: Deprecated.
|
||||
|
||||
Tune automatically converts search spaces to SkOpt's format:
|
||||
|
||||
|
@ -127,8 +125,6 @@ class SkOptSearch(Searcher):
|
|||
points_to_evaluate: Optional[List[Dict]] = None,
|
||||
evaluated_rewards: Optional[List] = None,
|
||||
convert_to_python: bool = True,
|
||||
max_concurrent: Optional[int] = None,
|
||||
use_early_stopped_trials: Optional[bool] = None,
|
||||
):
|
||||
assert sko is not None, (
|
||||
"skopt must be installed! "
|
||||
|
@ -138,12 +134,10 @@ class SkOptSearch(Searcher):
|
|||
|
||||
if mode:
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
||||
self.max_concurrent = max_concurrent
|
||||
|
||||
super(SkOptSearch, self).__init__(
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
max_concurrent=max_concurrent,
|
||||
use_early_stopped_trials=use_early_stopped_trials,
|
||||
)
|
||||
|
||||
self._initial_points = []
|
||||
|
|
|
@ -89,21 +89,7 @@ class Searcher:
|
|||
self,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
max_concurrent: Optional[int] = None,
|
||||
use_early_stopped_trials: Optional[bool] = None,
|
||||
):
|
||||
if use_early_stopped_trials is False:
|
||||
raise DeprecationWarning(
|
||||
"Early stopped trials are now always used. If this is a "
|
||||
"problem, file an issue: https://github.com/ray-project/ray."
|
||||
)
|
||||
if max_concurrent is not None:
|
||||
raise DeprecationWarning(
|
||||
"`max_concurrent` is deprecated for this "
|
||||
"search algorithm. Use tune.suggest.ConcurrencyLimiter() "
|
||||
"instead. This will raise an error in future versions of Ray."
|
||||
)
|
||||
|
||||
self._metric = metric
|
||||
self._mode = mode
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ import subprocess
|
|||
import tempfile
|
||||
import time
|
||||
import types
|
||||
import warnings
|
||||
|
||||
from typing import Optional, List, Callable, Union, Tuple
|
||||
|
||||
|
@ -18,7 +17,6 @@ import ray
|
|||
from ray.tune.error import TuneError
|
||||
from ray.tune.utils.file_transfer import sync_dir_between_nodes, delete_on_node
|
||||
from ray.util.annotations import PublicAPI
|
||||
from ray.util.debug import log_once
|
||||
from ray.ml.utils.remote_storage import (
|
||||
S3_PREFIX,
|
||||
GS_PREFIX,
|
||||
|
@ -197,13 +195,12 @@ class FunctionBasedClient(SyncClient):
|
|||
self._sync_down_legacy = _is_legacy_sync_fn(sync_up_func)
|
||||
|
||||
if self._sync_up_legacy or self._sync_down_legacy:
|
||||
if log_once("func_sync_up_legacy"):
|
||||
warnings.warn(
|
||||
"Your sync functions currently only accepts two params "
|
||||
"(a `source` and a `target`). In the future, we will "
|
||||
"pass an additional `exclude` parameter. Please adjust "
|
||||
"your sync function accordingly."
|
||||
)
|
||||
raise DeprecationWarning(
|
||||
"Your sync functions currently only accepts two params "
|
||||
"(a `source` and a `target`). In the future, we will "
|
||||
"pass an additional `exclude` parameter. Please adjust "
|
||||
"your sync function accordingly."
|
||||
)
|
||||
|
||||
self.delete_func = delete_func or noop
|
||||
|
||||
|
|
|
@ -15,7 +15,6 @@ import logging
|
|||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
import warnings
|
||||
|
||||
from inspect import isclass
|
||||
from shlex import quote
|
||||
|
@ -98,13 +97,13 @@ def validate_sync_config(sync_config: "SyncConfig"):
|
|||
sync_config.node_sync_period = -1
|
||||
sync_config.cloud_sync_period = -1
|
||||
|
||||
warnings.warn(
|
||||
# Deprecated: Remove in Ray > 1.13
|
||||
raise DeprecationWarning(
|
||||
"The `node_sync_period` and "
|
||||
"`cloud_sync_period` properties of `tune.SyncConfig` are "
|
||||
"deprecated. Pass the `sync_period` property instead. "
|
||||
"\nFor now, the lower of the two values (if provided) will "
|
||||
f"be used as the sync_period. This value is: {sync_period}",
|
||||
DeprecationWarning,
|
||||
f"be used as the sync_period. This value is: {sync_period}"
|
||||
)
|
||||
|
||||
if sync_config.sync_to_cloud or sync_config.sync_to_driver:
|
||||
|
@ -119,15 +118,15 @@ def validate_sync_config(sync_config: "SyncConfig"):
|
|||
sync_config.sync_to_cloud = None
|
||||
sync_config.sync_to_driver = None
|
||||
|
||||
warnings.warn(
|
||||
# Deprecated: Remove in Ray > 1.13
|
||||
raise DeprecationWarning(
|
||||
"The `sync_to_cloud` and `sync_to_driver` properties of "
|
||||
"`tune.SyncConfig` are deprecated. Pass the `syncer` property "
|
||||
"instead. Presence of an `upload_dir` decides if checkpoints "
|
||||
"are synced to cloud or not. Syncing to driver is "
|
||||
"automatically disabled if an `upload_dir` is given."
|
||||
f"\nFor now, as the upload dir is {help}, the respective "
|
||||
f"syncer is used. This value is: {syncer}",
|
||||
DeprecationWarning,
|
||||
f"syncer is used. This value is: {syncer}"
|
||||
)
|
||||
|
||||
|
||||
|
@ -199,6 +198,7 @@ class SyncConfig:
|
|||
sync_period: int = 300
|
||||
|
||||
# Deprecated arguments
|
||||
# Deprecated: Remove in Ray > 1.13
|
||||
sync_to_cloud: Any = None
|
||||
sync_to_driver: Any = None
|
||||
node_sync_period: int = -1
|
||||
|
|
|
@ -28,7 +28,7 @@ from ray.tune import (
|
|||
from ray.tune.callback import Callback
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.logger import Logger
|
||||
from ray.tune.logger import Logger, LegacyLoggerCallback
|
||||
from ray.tune.ray_trial_executor import noop_logger_creator
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.result import (
|
||||
|
@ -59,7 +59,7 @@ from ray.tune.suggest.suggestion import ConcurrencyLimiter
|
|||
from ray.tune.sync_client import CommandBasedClient
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.utils import flatten_dict, get_pinned_object, pin_in_object_store
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.utils.placement_groups import PlacementGroupFactory
|
||||
|
||||
|
||||
|
@ -122,14 +122,14 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
|||
|
||||
[trial1] = run(
|
||||
_function_trainable,
|
||||
loggers=[FunctionAPILogger],
|
||||
callbacks=[LegacyLoggerCallback([FunctionAPILogger])],
|
||||
raise_on_failed_trial=False,
|
||||
scheduler=MockScheduler(),
|
||||
).trials
|
||||
|
||||
[trial2] = run(
|
||||
class_trainable_name,
|
||||
loggers=[ClassAPILogger],
|
||||
callbacks=[LegacyLoggerCallback([ClassAPILogger])],
|
||||
raise_on_failed_trial=False,
|
||||
scheduler=MockScheduler(),
|
||||
).trials
|
||||
|
@ -180,33 +180,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
|||
|
||||
return function_output, trials
|
||||
|
||||
def testPinObject(self):
|
||||
X = pin_in_object_store("hello")
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
return get_pinned_object(X)
|
||||
|
||||
self.assertEqual(ray.get(f.remote()), "hello")
|
||||
|
||||
def testFetchPinned(self):
|
||||
X = pin_in_object_store("hello")
|
||||
|
||||
def train(config, reporter):
|
||||
get_pinned_object(X)
|
||||
reporter(timesteps_total=100, done=True)
|
||||
|
||||
register_trainable("f1", train)
|
||||
[trial] = run_experiments(
|
||||
{
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
}
|
||||
}
|
||||
)
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 100)
|
||||
|
||||
def testRegisterEnv(self):
|
||||
register_env("foo", lambda: None)
|
||||
self.assertRaises(TypeError, lambda: register_env("foo", 2))
|
||||
|
@ -756,7 +729,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
|||
},
|
||||
verbose=1,
|
||||
local_dir=tmpdir,
|
||||
loggers=None,
|
||||
)
|
||||
trials = tune.run(test, raise_on_failed_trial=False, **config).trials
|
||||
self.assertEqual(Counter(t.status for t in trials)["ERROR"], 5)
|
||||
|
|
|
@ -1,551 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from ray import tune
|
||||
from ray.ml.utils.remote_storage import upload_to_uri, delete_at_uri
|
||||
from ray.tune.cloud import TrialCheckpoint
|
||||
|
||||
|
||||
class TrialCheckpointApiTest(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.local_dir = tempfile.mkdtemp()
|
||||
with open(os.path.join(self.local_dir, "some_file"), "w") as f:
|
||||
f.write("checkpoint")
|
||||
|
||||
self.cloud_dir = "memory:///cloud_dir"
|
||||
|
||||
self._save_checkpoint_at(self.cloud_dir)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
shutil.rmtree(self.local_dir)
|
||||
delete_at_uri(self.cloud_dir)
|
||||
|
||||
def _save_checkpoint_at(self, target):
|
||||
delete_at_uri(target)
|
||||
upload_to_uri(local_path=self.local_dir, uri=target)
|
||||
|
||||
def testConstructTrialCheckpoint(self):
|
||||
# All these constructions should work
|
||||
TrialCheckpoint(None, None)
|
||||
TrialCheckpoint("/tmp", None)
|
||||
TrialCheckpoint(None, "memory:///invalid")
|
||||
TrialCheckpoint("/remote/node/dir", None)
|
||||
|
||||
def ensureCheckpointFile(self):
|
||||
with open(os.path.join(self.local_dir, "checkpoint.txt"), "wt") as f:
|
||||
f.write("checkpoint\n")
|
||||
|
||||
def testDownloadNoDefaults(self):
|
||||
# Case: Nothing is passed
|
||||
checkpoint = TrialCheckpoint()
|
||||
with self.assertRaises(RuntimeError):
|
||||
checkpoint.download()
|
||||
|
||||
# Case: Local dir is passed
|
||||
checkpoint = TrialCheckpoint()
|
||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
||||
checkpoint.download(local_path=self.local_dir)
|
||||
|
||||
# Case: Cloud dir is passed
|
||||
checkpoint = TrialCheckpoint()
|
||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
||||
checkpoint.download(cloud_path=self.cloud_dir)
|
||||
|
||||
# Case: Both are passed
|
||||
checkpoint = TrialCheckpoint()
|
||||
path = checkpoint.download(local_path=self.local_dir, cloud_path=self.cloud_dir)
|
||||
|
||||
self.assertEqual(self.local_dir, path)
|
||||
|
||||
def testDownloadDefaultLocal(self):
|
||||
other_local_dir = "/tmp/invalid"
|
||||
|
||||
# Case: Nothing is passed
|
||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
||||
checkpoint.download()
|
||||
|
||||
# Case: Local dir is passed
|
||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
||||
checkpoint.download(local_path=other_local_dir)
|
||||
|
||||
# Case: Cloud dir is passed
|
||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
||||
path = checkpoint.download(cloud_path=self.cloud_dir)
|
||||
|
||||
self.assertEqual(self.local_dir, path)
|
||||
|
||||
# Case: Both are passed
|
||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
||||
path = checkpoint.download(
|
||||
local_path=other_local_dir, cloud_path=self.cloud_dir
|
||||
)
|
||||
|
||||
self.assertEqual(other_local_dir, path)
|
||||
|
||||
def testDownloadDefaultCloud(self):
|
||||
other_cloud_dir = "memory:///other"
|
||||
|
||||
# Case: Nothing is passed
|
||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
||||
checkpoint.download()
|
||||
|
||||
# Case: Local dir is passed
|
||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
||||
path = checkpoint.download(local_path=self.local_dir)
|
||||
|
||||
self.assertEqual(self.local_dir, path)
|
||||
|
||||
# Case: Cloud dir is passed
|
||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
||||
checkpoint.download(cloud_path=other_cloud_dir)
|
||||
|
||||
# Case: Both are passed
|
||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
||||
path = checkpoint.download(
|
||||
local_path=self.local_dir, cloud_path=other_cloud_dir
|
||||
)
|
||||
|
||||
self.assertEqual(self.local_dir, path)
|
||||
|
||||
def testDownloadDefaultBoth(self):
|
||||
other_local_dir = "/tmp/other"
|
||||
other_cloud_dir = "memory:///other"
|
||||
|
||||
self._save_checkpoint_at(other_cloud_dir)
|
||||
self._save_checkpoint_at(self.cloud_dir)
|
||||
|
||||
# Case: Nothing is passed
|
||||
checkpoint = TrialCheckpoint(
|
||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
||||
)
|
||||
|
||||
path = checkpoint.download()
|
||||
|
||||
self.assertEqual(self.local_dir, path)
|
||||
|
||||
# Case: Local dir is passed
|
||||
checkpoint = TrialCheckpoint(
|
||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
||||
)
|
||||
|
||||
path = checkpoint.download(local_path=other_local_dir)
|
||||
|
||||
self.assertEqual(other_local_dir, path)
|
||||
|
||||
# Case: Both are passed
|
||||
checkpoint = TrialCheckpoint(
|
||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
||||
)
|
||||
|
||||
path = checkpoint.download(
|
||||
local_path=other_local_dir, cloud_path=other_cloud_dir
|
||||
)
|
||||
|
||||
self.assertEqual(other_local_dir, path)
|
||||
|
||||
def testUploadNoDefaults(self):
|
||||
# Case: Nothing is passed
|
||||
checkpoint = TrialCheckpoint()
|
||||
with self.assertRaises(RuntimeError):
|
||||
checkpoint.upload()
|
||||
|
||||
# Case: Local dir is passed
|
||||
checkpoint = TrialCheckpoint()
|
||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
||||
checkpoint.upload(local_path=self.local_dir)
|
||||
|
||||
# Case: Cloud dir is passed
|
||||
checkpoint = TrialCheckpoint()
|
||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
||||
checkpoint.upload(cloud_path=self.cloud_dir)
|
||||
|
||||
# Case: Both are passed
|
||||
checkpoint = TrialCheckpoint()
|
||||
path = checkpoint.upload(local_path=self.local_dir, cloud_path=self.cloud_dir)
|
||||
|
||||
self.assertEqual(self.cloud_dir, path)
|
||||
|
||||
def testUploadDefaultLocal(self):
|
||||
other_local_dir = "/tmp/invalid"
|
||||
|
||||
# Case: Nothing is passed
|
||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
||||
checkpoint.upload()
|
||||
|
||||
# Case: Local dir is passed
|
||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
||||
checkpoint.upload(local_path=other_local_dir)
|
||||
|
||||
# Case: Cloud dir is passed
|
||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
||||
path = checkpoint.upload(cloud_path=self.cloud_dir)
|
||||
|
||||
self.assertEqual(self.cloud_dir, path)
|
||||
|
||||
# Case: Both are passed
|
||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
||||
path = checkpoint.upload(local_path=other_local_dir, cloud_path=self.cloud_dir)
|
||||
|
||||
self.assertEqual(self.cloud_dir, path)
|
||||
|
||||
def testUploadDefaultCloud(self):
|
||||
other_cloud_dir = "memory:///other"
|
||||
|
||||
delete_at_uri(other_cloud_dir)
|
||||
self._save_checkpoint_at(other_cloud_dir)
|
||||
|
||||
# Case: Nothing is passed
|
||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
||||
checkpoint.upload()
|
||||
|
||||
# Case: Local dir is passed
|
||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
||||
path = checkpoint.upload(local_path=self.local_dir)
|
||||
|
||||
self.assertEqual(self.cloud_dir, path)
|
||||
|
||||
# Case: Cloud dir is passed
|
||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
||||
checkpoint.upload(cloud_path=other_cloud_dir)
|
||||
|
||||
# Case: Both are passed
|
||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
||||
path = checkpoint.upload(local_path=self.local_dir, cloud_path=other_cloud_dir)
|
||||
|
||||
self.assertEqual(other_cloud_dir, path)
|
||||
|
||||
def testUploadDefaultBoth(self):
|
||||
other_local_dir = "/tmp/other"
|
||||
other_cloud_dir = "memory:///other"
|
||||
|
||||
delete_at_uri(other_cloud_dir)
|
||||
self._save_checkpoint_at(other_cloud_dir)
|
||||
shutil.copytree(self.local_dir, other_local_dir)
|
||||
|
||||
# Case: Nothing is passed
|
||||
checkpoint = TrialCheckpoint(
|
||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
||||
)
|
||||
|
||||
path = checkpoint.upload()
|
||||
|
||||
self.assertEqual(self.cloud_dir, path)
|
||||
|
||||
# Case: Local dir is passed
|
||||
checkpoint = TrialCheckpoint(
|
||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
||||
)
|
||||
|
||||
path = checkpoint.upload(local_path=other_local_dir)
|
||||
|
||||
self.assertEqual(self.cloud_dir, path)
|
||||
|
||||
# Case: Both are passed
|
||||
checkpoint = TrialCheckpoint(
|
||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
||||
)
|
||||
|
||||
path = checkpoint.upload(local_path=other_local_dir, cloud_path=other_cloud_dir)
|
||||
|
||||
self.assertEqual(other_cloud_dir, path)
|
||||
|
||||
def testSaveLocalTarget(self):
|
||||
state = {}
|
||||
|
||||
def copytree(source, dest):
|
||||
state["copy_source"] = source
|
||||
state["copy_dest"] = dest
|
||||
|
||||
other_local_dir = "/tmp/other"
|
||||
|
||||
# Case: No defaults
|
||||
checkpoint = TrialCheckpoint()
|
||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
||||
checkpoint.save()
|
||||
|
||||
# Case: Default local dir
|
||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
||||
checkpoint.save()
|
||||
|
||||
# Case: Default cloud dir, no local dir passed
|
||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "No target path"):
|
||||
checkpoint.save()
|
||||
|
||||
# Case: Default cloud dir, pass local dir
|
||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
||||
|
||||
path = checkpoint.save(self.local_dir, force_download=True)
|
||||
|
||||
self.assertEqual(self.local_dir, path)
|
||||
|
||||
# Case: Default local dir, pass local dir
|
||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
||||
self.ensureCheckpointFile()
|
||||
|
||||
with patch("shutil.copytree", copytree):
|
||||
path = checkpoint.save(other_local_dir)
|
||||
|
||||
self.assertEqual(other_local_dir, path)
|
||||
self.assertEqual(state["copy_source"], self.local_dir)
|
||||
self.assertEqual(state["copy_dest"], other_local_dir)
|
||||
|
||||
# Case: Both default, no pass
|
||||
checkpoint = TrialCheckpoint(
|
||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
||||
)
|
||||
|
||||
path = checkpoint.save()
|
||||
|
||||
self.assertEqual(self.local_dir, path)
|
||||
|
||||
# Case: Both default, pass other local dir
|
||||
checkpoint = TrialCheckpoint(
|
||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
||||
)
|
||||
|
||||
with patch("shutil.copytree", copytree):
|
||||
path = checkpoint.save(other_local_dir)
|
||||
|
||||
self.assertEqual(other_local_dir, path)
|
||||
self.assertEqual(state["copy_source"], self.local_dir)
|
||||
self.assertEqual(state["copy_dest"], other_local_dir)
|
||||
self.assertEqual(checkpoint.local_path, self.local_dir)
|
||||
|
||||
def testSaveCloudTarget(self):
|
||||
other_cloud_dir = "memory:///other"
|
||||
|
||||
delete_at_uri(other_cloud_dir)
|
||||
self._save_checkpoint_at(other_cloud_dir)
|
||||
|
||||
# Case: No defaults
|
||||
checkpoint = TrialCheckpoint()
|
||||
with self.assertRaisesRegex(RuntimeError, "No existing local"):
|
||||
checkpoint.save(self.cloud_dir)
|
||||
|
||||
# Case: Default local dir
|
||||
# Write a checkpoint here as we assume existing local dir
|
||||
with open(os.path.join(self.local_dir, "checkpoint.txt"), "wt") as f:
|
||||
f.write("Checkpoint\n")
|
||||
|
||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
||||
path = checkpoint.save(self.cloud_dir)
|
||||
|
||||
self.assertEqual(self.cloud_dir, path)
|
||||
|
||||
# Clean up checkpoint
|
||||
os.remove(os.path.join(self.local_dir, "checkpoint.txt"))
|
||||
|
||||
# Case: Default cloud dir, copy to other cloud
|
||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
||||
|
||||
path = checkpoint.save(other_cloud_dir)
|
||||
|
||||
self.assertEqual(other_cloud_dir, path)
|
||||
|
||||
# Case: Default both, copy to other cloud
|
||||
checkpoint = TrialCheckpoint(
|
||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
||||
)
|
||||
|
||||
path = checkpoint.save(other_cloud_dir)
|
||||
|
||||
self.assertEqual(other_cloud_dir, path)
|
||||
|
||||
|
||||
def train(config, checkpoint_dir=None):
|
||||
for i in range(10):
|
||||
with tune.checkpoint_dir(step=0) as cd:
|
||||
with open(os.path.join(cd, "checkpoint.json"), "wt") as f:
|
||||
json.dump({"score": i, "train_id": config["train_id"]}, f)
|
||||
tune.report(score=i)
|
||||
|
||||
|
||||
class TrialCheckpointEndToEndTest(unittest.TestCase):
    def setUp(self) -> None:
        self.local_experiment_dir = tempfile.mkdtemp()

        self.fake_cloud_dir = tempfile.mkdtemp()
        self.cloud_target = "memory:///invalid/sub/path"

        self.second_fake_cloud_dir = tempfile.mkdtemp()
        self.second_cloud_target = "memory:///other/cloud"

    def tearDown(self) -> None:
        shutil.rmtree(self.local_experiment_dir)
        shutil.rmtree(self.fake_cloud_dir)
        shutil.rmtree(self.second_fake_cloud_dir)

    def _delete_at_uri(self, uri: str):
        cloud_local_dir = uri.replace(self.cloud_target, self.fake_cloud_dir)
        cloud_local_dir = cloud_local_dir.replace(
            self.second_cloud_target, self.second_fake_cloud_dir
        )
        shutil.rmtree(cloud_local_dir)

    def _fake_download_from_uri(self, uri: str, local_path: str):
        cloud_local_dir = uri.replace(self.cloud_target, self.fake_cloud_dir)
        cloud_local_dir = cloud_local_dir.replace(
            self.second_cloud_target, self.second_fake_cloud_dir
        )

        shutil.rmtree(local_path, ignore_errors=True)
        shutil.copytree(cloud_local_dir, local_path)

    def _fake_upload_to_uri(self, local_path: str, uri: str):
        cloud_local_dir = uri.replace(self.cloud_target, self.fake_cloud_dir)
        cloud_local_dir = cloud_local_dir.replace(
            self.second_cloud_target, self.second_fake_cloud_dir
        )
        shutil.rmtree(cloud_local_dir, ignore_errors=True)
        shutil.copytree(local_path, cloud_local_dir)
||||
def testCheckpointDownload(self):
|
||||
analysis = tune.run(
|
||||
train,
|
||||
config={"train_id": tune.grid_search([0, 1, 2, 3, 4])},
|
||||
local_dir=self.local_experiment_dir,
|
||||
verbose=2,
|
||||
)
|
||||
|
||||
# Inject the sync config (this is usually done by `tune.run()`)
|
||||
analysis._sync_config = tune.SyncConfig(upload_dir=self.cloud_target)
|
||||
|
||||
# Pretend we have all checkpoints on cloud storage (durable)
|
||||
shutil.rmtree(self.fake_cloud_dir, ignore_errors=True)
|
||||
shutil.copytree(self.local_experiment_dir, self.fake_cloud_dir)
|
||||
|
||||
# Pretend we don't have these on local storage
|
||||
shutil.rmtree(analysis.trials[1].logdir)
|
||||
shutil.rmtree(analysis.trials[2].logdir)
|
||||
shutil.rmtree(analysis.trials[3].logdir)
|
||||
shutil.rmtree(analysis.trials[4].logdir)
|
||||
|
||||
cp0 = analysis.get_best_checkpoint(analysis.trials[0], "score", "max")
|
||||
cp1 = analysis.get_best_checkpoint(analysis.trials[1], "score", "max")
|
||||
cp2 = analysis.get_best_checkpoint(analysis.trials[2], "score", "max")
|
||||
cp3 = analysis.get_best_checkpoint(analysis.trials[3], "score", "max")
|
||||
cp4 = analysis.get_best_checkpoint(analysis.trials[4], "score", "max")
|
||||
|
||||
def _load_cp(cd):
|
||||
with open(os.path.join(cd, "checkpoint.json"), "rt") as f:
|
||||
return json.load(f)
|
||||
|
||||
with patch("ray.tune.cloud.delete_at_uri", self._delete_at_uri), patch(
|
||||
"ray.tune.cloud.download_from_uri", self._fake_download_from_uri
|
||||
), patch(
|
||||
"ray.ml.checkpoint.download_from_uri", self._fake_download_from_uri
|
||||
), patch(
|
||||
"ray.tune.cloud.upload_to_uri", self._fake_upload_to_uri
|
||||
):
|
||||
#######
|
||||
# Case: Checkpoint exists on local dir. Copy to other local dir.
|
||||
other_local_dir = tempfile.mkdtemp()
|
||||
|
||||
cp0.save(other_local_dir)
|
||||
|
||||
self.assertTrue(os.path.exists(cp0.local_path))
|
||||
|
||||
cp_content = _load_cp(other_local_dir)
|
||||
self.assertEqual(cp_content["train_id"], 0)
|
||||
self.assertEqual(cp_content["score"], 9)
|
||||
|
||||
cp_content_2 = _load_cp(cp0.local_path)
|
||||
self.assertEqual(cp_content, cp_content_2)
|
||||
|
||||
# Clean up
|
||||
shutil.rmtree(other_local_dir)
|
||||
|
||||
#######
|
||||
# Case: Checkpoint does not exist on local dir, download from cloud
|
||||
# store in experiment dir.
|
||||
|
||||
# Directory is empty / does not exist before
|
||||
self.assertFalse(os.path.exists(cp1.local_path))
|
||||
|
||||
# Save!
|
||||
cp1.save()
|
||||
|
||||
# Directory is not empty anymore
|
||||
self.assertTrue(os.listdir(cp1.local_path))
|
||||
cp_content = _load_cp(cp1.local_path)
|
||||
self.assertEqual(cp_content["train_id"], 1)
|
||||
self.assertEqual(cp_content["score"], 9)
|
||||
|
||||
#######
|
||||
# Case: Checkpoint does not exist on local dir, download from cloud
|
||||
# store into other local dir.
|
||||
|
||||
# Directory is empty / does not exist before
|
||||
self.assertFalse(os.path.exists(cp2.local_path))
|
||||
|
||||
other_local_dir = tempfile.mkdtemp()
|
||||
# Save!
|
||||
cp2.save(other_local_dir)
|
||||
|
||||
# Directory still does not exist (as we save to other dir)
|
||||
self.assertFalse(os.path.exists(cp2.local_path))
|
||||
cp_content = _load_cp(other_local_dir)
|
||||
self.assertEqual(cp_content["train_id"], 2)
|
||||
self.assertEqual(cp_content["score"], 9)
|
||||
|
||||
# Clean up
|
||||
shutil.rmtree(other_local_dir)
|
||||
|
||||
#######
|
||||
# Case: Checkpoint does not exist on local dir, download from cloud
|
||||
# and store onto other cloud.
|
||||
|
||||
# Local dir does not exist
|
||||
self.assertFalse(os.path.exists(cp3.local_path))
|
||||
# First cloud exists
|
||||
self.assertTrue(os.listdir(self.fake_cloud_dir))
|
||||
# Second cloud does not exist
|
||||
self.assertFalse(os.listdir(self.second_fake_cloud_dir))
|
||||
|
||||
# Trigger save
|
||||
cp3.save(self.second_cloud_target)
|
||||
|
||||
# Local dir now exists
|
||||
self.assertTrue(os.path.exists(cp3.local_path))
|
||||
# First cloud exists
|
||||
self.assertTrue(os.listdir(self.fake_cloud_dir))
|
||||
# Second cloud now exists!
|
||||
self.assertTrue(os.listdir(self.second_fake_cloud_dir))
|
||||
|
||||
cp_content = _load_cp(self.second_fake_cloud_dir)
|
||||
self.assertEqual(cp_content["train_id"], 3)
|
||||
self.assertEqual(cp_content["score"], 9)
|
||||
|
||||
#######
|
||||
# Case: Checkpoint does not exist on local dir, download from cloud
|
||||
# store into local dir. Use new checkpoint abstractions for this.
|
||||
|
||||
temp_dir = cp4.to_directory(tempfile.mkdtemp())
|
||||
cp_content = _load_cp(temp_dir)
|
||||
self.assertEqual(cp_content["train_id"], 4)
|
||||
self.assertEqual(cp_content["score"], 9)
|
||||
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import pytest

    sys.exit(pytest.main(["-v", __file__]))
@ -309,7 +309,7 @@ class ExperimentAnalysisPropertySuite(unittest.TestCase):
        self.assertEqual(ea.best_trial, trials[2])
        self.assertEqual(ea.best_config, trials[2].config)
        self.assertEqual(ea.best_logdir, trials[2].logdir)
        self.assertEqual(ea.best_checkpoint.local_path, trials[2].checkpoint.value)
        self.assertEqual(ea.best_checkpoint._local_path, trials[2].checkpoint.value)
        self.assertTrue(all(ea.best_dataframe["trial_id"] == trials[2].trial_id))
        self.assertEqual(ea.results_df.loc[trials[2].trial_id, "res"], 309)
        self.assertEqual(ea.best_result["res"], 309)
@ -100,117 +100,6 @@ class WandbIntegrationTest(unittest.TestCase):
|
|||
if WANDB_ENV_VAR in os.environ:
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
def testWandbLegacyLoggerConfig(self):
|
||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(
|
||||
trial_config, 0, "trial_0", "trainable", PlacementGroupFactory([{"CPU": 1}])
|
||||
)
|
||||
|
||||
if WANDB_ENV_VAR in os.environ:
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# Needs at least a project
|
||||
with self.assertRaises(ValueError):
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
|
||||
# No API key
|
||||
trial_config["wandb"] = {"project": "test_project"}
|
||||
with self.assertRaises(ValueError):
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
|
||||
# API Key in config
|
||||
trial_config["wandb"] = {"project": "test_project", "api_key": "1234"}
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")
|
||||
|
||||
logger.close()
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# API Key file
|
||||
with tempfile.NamedTemporaryFile("wt") as fp:
|
||||
fp.write("5678")
|
||||
fp.flush()
|
||||
|
||||
trial_config["wandb"] = {"project": "test_project", "api_key_file": fp.name}
|
||||
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")
|
||||
|
||||
logger.close()
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# API Key in env
|
||||
os.environ[WANDB_ENV_VAR] = "9012"
|
||||
trial_config["wandb"] = {"project": "test_project"}
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
logger.close()
|
||||
|
||||
# From now on, the API key is in the env variable.
|
||||
|
||||
# Default configuration
|
||||
trial_config["wandb"] = {"project": "test_project"}
|
||||
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertEqual(logger.trial_process.kwargs["project"], "test_project")
|
||||
self.assertEqual(logger.trial_process.kwargs["id"], trial.trial_id)
|
||||
self.assertEqual(logger.trial_process.kwargs["name"], trial.trial_name)
|
||||
self.assertEqual(logger.trial_process.kwargs["group"], trial.trainable_name)
|
||||
self.assertIn("config", logger.trial_process._exclude)
|
||||
|
||||
logger.close()
|
||||
|
||||
# log config.
|
||||
trial_config["wandb"] = {"project": "test_project", "log_config": True}
|
||||
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertNotIn("config", logger.trial_process._exclude)
|
||||
self.assertNotIn("metric", logger.trial_process._exclude)
|
||||
|
||||
logger.close()
|
||||
|
||||
# Exclude metric.
|
||||
trial_config["wandb"] = {"project": "test_project", "excludes": ["metric"]}
|
||||
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertIn("config", logger.trial_process._exclude)
|
||||
self.assertIn("metric", logger.trial_process._exclude)
|
||||
|
||||
logger.close()
|
||||
|
||||
def testWandbLegacyLoggerReporting(self):
|
||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(
|
||||
trial_config, 0, "trial_0", "trainable", PlacementGroupFactory([{"CPU": 1}])
|
||||
)
|
||||
|
||||
trial_config["wandb"] = {
|
||||
"project": "test_project",
|
||||
"api_key": "1234",
|
||||
"excludes": ["metric2"],
|
||||
}
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
|
||||
r1 = {
|
||||
"metric1": 0.8,
|
||||
"metric2": 1.4,
|
||||
"metric3": np.asarray(32.0),
|
||||
"metric4": np.float32(32.0),
|
||||
"const": "text",
|
||||
"config": trial_config,
|
||||
}
|
||||
|
||||
logger.on_result(r1)
|
||||
|
||||
logged = logger.trial_process.logs.get(timeout=10)
|
||||
self.assertIn("metric1", logged)
|
||||
self.assertNotIn("metric2", logged)
|
||||
self.assertIn("metric3", logged)
|
||||
self.assertIn("metric4", logged)
|
||||
self.assertNotIn("const", logged)
|
||||
self.assertNotIn("config", logged)
|
||||
|
||||
logger.close()
|
||||
|
||||
def testWandbLoggerConfig(self):
|
||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(
|
||||
|
|
|
@ -41,37 +41,26 @@ class TestSyncFunctionality(unittest.TestCase):
        _register_all()  # re-register the evicted objects

    def testSyncConfigDeprecation(self):
        with self.assertWarnsRegex(DeprecationWarning, expected_regex="sync_period"):
            sync_conf = tune.SyncConfig(node_sync_period=4, cloud_sync_period=8)
        self.assertEqual(sync_conf.sync_period, 4)
        with self.assertRaisesRegex(DeprecationWarning, expected_regex="sync_period"):
            tune.SyncConfig(node_sync_period=4, cloud_sync_period=8)

        with self.assertWarnsRegex(DeprecationWarning, expected_regex="sync_period"):
            sync_conf = tune.SyncConfig(node_sync_period=4)
        self.assertEqual(sync_conf.sync_period, 4)
        with self.assertRaisesRegex(DeprecationWarning, expected_regex="sync_period"):
            tune.SyncConfig(node_sync_period=4)

        with self.assertWarnsRegex(DeprecationWarning, expected_regex="sync_period"):
            sync_conf = tune.SyncConfig(cloud_sync_period=8)
        self.assertEqual(sync_conf.sync_period, 8)
        with self.assertRaisesRegex(DeprecationWarning, expected_regex="sync_period"):
            tune.SyncConfig(cloud_sync_period=8)

        with self.assertWarnsRegex(DeprecationWarning, expected_regex="syncer"):
            sync_conf = tune.SyncConfig(
                sync_to_driver="a", sync_to_cloud="b", upload_dir=None
            )
        self.assertEqual(sync_conf.syncer, "a")
        with self.assertRaisesRegex(DeprecationWarning, expected_regex="syncer"):
            tune.SyncConfig(sync_to_driver="a", sync_to_cloud="b", upload_dir=None)

        with self.assertWarnsRegex(DeprecationWarning, expected_regex="syncer"):
            sync_conf = tune.SyncConfig(
                sync_to_driver="a", sync_to_cloud="b", upload_dir="c"
            )
        self.assertEqual(sync_conf.syncer, "b")
        with self.assertRaisesRegex(DeprecationWarning, expected_regex="syncer"):
            tune.SyncConfig(sync_to_driver="a", sync_to_cloud="b", upload_dir="c")

        with self.assertWarnsRegex(DeprecationWarning, expected_regex="syncer"):
            sync_conf = tune.SyncConfig(sync_to_cloud="b", upload_dir=None)
        self.assertEqual(sync_conf.syncer, None)
        with self.assertRaisesRegex(DeprecationWarning, expected_regex="syncer"):
            tune.SyncConfig(sync_to_cloud="b", upload_dir=None)

        with self.assertWarnsRegex(DeprecationWarning, expected_regex="syncer"):
            sync_conf = tune.SyncConfig(sync_to_driver="a", upload_dir="c")
        self.assertEqual(sync_conf.syncer, None)
        with self.assertRaisesRegex(DeprecationWarning, expected_regex="syncer"):
            tune.SyncConfig(sync_to_driver="a", upload_dir="c")

    @patch("ray.tune.sync_client.S3_PREFIX", "test")
    def testCloudProperString(self):

@ -159,7 +148,7 @@ class TestSyncFunctionality(unittest.TestCase):
        tmpdir2 = tempfile.mkdtemp()
        os.mkdir(os.path.join(tmpdir2, "foo"))

        def sync_func(local, remote):
        def sync_func(local, remote, exclude=None):
            for filename in glob.glob(os.path.join(local, "*.json")):
                shutil.copy(filename, remote)

@ -187,7 +176,7 @@ class TestSyncFunctionality(unittest.TestCase):
            time.sleep(1)
            tune.report(score=i)

        def counter(local, remote):
        def counter(local, remote, exclude=None):
            count_file = os.path.join(tmpdir, "count.txt")
            if not os.path.exists(count_file):
                count = 0

@ -219,7 +208,7 @@ class TestSyncFunctionality(unittest.TestCase):
        shutil.rmtree(tmpdir)

    def testClusterSyncFunction(self):
        def sync_func_driver(source, target):
        def sync_func_driver(source, target, exclude=None):
            assert ":" in source, "Source {} not a remote path.".format(source)
            assert ":" not in target, "Target is supposed to be local."
            with open(os.path.join(target, "test.log2"), "w") as f:

@ -255,7 +244,7 @@ class TestSyncFunctionality(unittest.TestCase):
    def testNoSync(self):
        """Sync should not run on a single node."""

        def sync_func(source, target):
        def sync_func(source, target, exclude=None):
            pass

        sync_config = tune.SyncConfig(syncer=sync_func)

@ -409,7 +398,7 @@ class TestSyncFunctionality(unittest.TestCase):
        sync_config = tune.SyncConfig(syncer=None)

        # Create syncer callbacks
        callbacks = create_default_callbacks([], sync_config, loggers=None)
        callbacks = create_default_callbacks([], sync_config)
        syncer_callback = callbacks[-1]

        # Sanity check that we got the syncer callback
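# Illustrative sketch (not part of the original diff): a custom sync function matching
# the three-argument signature these tests now use. The `exclude` parameter (a list of
# glob patterns, or None) must be accepted even if a simple syncer ignores it.
import shutil


def custom_sync(local, remote, exclude=None):
    # Naive full copy from the local trial dir to the remote target.
    shutil.rmtree(remote, ignore_errors=True)
    shutil.copytree(local, remote)


# Hooked up the same way as in the tests above:
# sync_config = tune.SyncConfig(syncer=custom_sync)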
@ -1,59 +1,55 @@
|
|||
import warnings
|
||||
|
||||
from mock import patch
|
||||
import unittest
|
||||
|
||||
|
||||
class TestTrialExecutorInheritance(unittest.TestCase):
|
||||
@patch.object(warnings, "warn")
|
||||
def test_direct_inheritance_not_ok(self, mocked_warn):
|
||||
def test_direct_inheritance_not_ok(self):
|
||||
|
||||
from ray.tune.trial_executor import TrialExecutor
|
||||
|
||||
class _MyTrialExecutor(TrialExecutor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def start_trial(self, trial):
|
||||
return True
|
||||
|
||||
def stop_trial(self, trial):
|
||||
pass
|
||||
|
||||
def restore(self, trial):
|
||||
pass
|
||||
|
||||
def save(self, trial):
|
||||
return None
|
||||
|
||||
def reset_trial(self, trial, new_config, new_experiment_tag):
|
||||
return False
|
||||
|
||||
def debug_string(self):
|
||||
return "This is a debug string."
|
||||
|
||||
def export_trial_if_needed(self):
|
||||
return {}
|
||||
|
||||
def fetch_result(self):
|
||||
return []
|
||||
|
||||
def get_next_available_trial(self):
|
||||
return None
|
||||
|
||||
def get_running_trials(self):
|
||||
return []
|
||||
|
||||
msg = (
|
||||
"_MyTrialExecutor inherits from TrialExecutor, which is being "
|
||||
"deprecated. "
|
||||
"RFC: https://github.com/ray-project/ray/issues/17593. "
|
||||
"Please reach out on the Ray Github if you have any concerns."
|
||||
)
|
||||
mocked_warn.assert_called_once_with(msg, DeprecationWarning)
|
||||
|
||||
@patch.object(warnings, "warn")
|
||||
def test_indirect_inheritance_ok(self, mocked_warn):
|
||||
with self.assertRaisesRegex(DeprecationWarning, msg):
|
||||
|
||||
class _MyTrialExecutor(TrialExecutor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def start_trial(self, trial):
|
||||
return True
|
||||
|
||||
def stop_trial(self, trial):
|
||||
pass
|
||||
|
||||
def restore(self, trial):
|
||||
pass
|
||||
|
||||
def save(self, trial):
|
||||
return None
|
||||
|
||||
def reset_trial(self, trial, new_config, new_experiment_tag):
|
||||
return False
|
||||
|
||||
def debug_string(self):
|
||||
return "This is a debug string."
|
||||
|
||||
def export_trial_if_needed(self):
|
||||
return {}
|
||||
|
||||
def fetch_result(self):
|
||||
return []
|
||||
|
||||
def get_next_available_trial(self):
|
||||
return None
|
||||
|
||||
def get_running_trials(self):
|
||||
return []
|
||||
|
||||
def test_indirect_inheritance_ok(self):
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
|
||||
class _MyRayTrialExecutor(RayTrialExecutor):
|
||||
|
@ -61,5 +57,3 @@ class TestTrialExecutorInheritance(unittest.TestCase):
|
|||
|
||||
class _AnotherMyRayTrialExecutor(_MyRayTrialExecutor):
|
||||
pass
|
||||
|
||||
mocked_warn.assert_not_called()
|
||||
|
|
|
@ -753,7 +753,7 @@ class TrialRunnerTest3(unittest.TestCase):
        ray.init(num_cpus=3)

        # This makes checkpointing take 2 seconds.
        def sync_up(source, target):
        def sync_up(source, target, exclude=None):
            time.sleep(2)
            return True

@ -50,7 +50,6 @@ from ray.tune.utils.util import (
    get_checkpoint_from_remote_node,
    delete_external_checkpoint,
)
from ray.util.debug import log_once
from ray.util.annotations import PublicAPI

logger = logging.getLogger(__name__)

@ -912,11 +911,6 @@ class Trainable:
            A dict that describes training progress.

        """
        if self._implements_method("_train") and log_once("_train"):
            raise DeprecationWarning(
                "Trainable._train is deprecated and is now removed. Override "
                "Trainable.step instead."
            )
        raise NotImplementedError

    def save_checkpoint(self, tmp_checkpoint_dir: str):

@ -957,11 +951,6 @@ class Trainable:
            >>> trainable.save_checkpoint("/tmp/bad_example") # doctest: +SKIP
            "/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
        """
        if self._implements_method("_save") and log_once("_save"):
            raise DeprecationWarning(
                "Trainable._save is deprecated and is now removed. Override "
                "Trainable.save_checkpoint instead."
            )
        raise NotImplementedError

    def load_checkpoint(self, checkpoint: Union[Dict, str]):

@ -1007,11 +996,6 @@ class Trainable:
                returned by `save_checkpoint`. The directory structure
                underneath the `checkpoint_dir` `save_checkpoint` is preserved.
        """
        if self._implements_method("_restore") and log_once("_restore"):
            raise DeprecationWarning(
                "Trainable._restore is deprecated and is now removed. "
                "Override Trainable.load_checkpoint instead."
            )
        raise NotImplementedError

    def setup(self, config: Dict):

@ -1024,11 +1008,6 @@ class Trainable:
            Copy of `self.config`.

        """
        if self._implements_method("_setup") and log_once("_setup"):
            raise DeprecationWarning(
                "Trainable._setup is deprecated and is now removed. Override "
                "Trainable.setup instead."
            )
        pass

    def log_result(self, result: Dict):

@ -1043,11 +1022,6 @@ class Trainable:
        Args:
            result: Training result returned by step().
        """
        if self._implements_method("_log_result") and log_once("_log_result"):
            raise DeprecationWarning(
                "Trainable._log_result is deprecated and is now removed. "
                "Override Trainable.log_result instead."
            )
        self._result_logger.on_result(result)

    def cleanup(self):

@ -1061,11 +1035,6 @@ class Trainable:

        .. versionadded:: 0.8.7
        """
        if self._implements_method("_stop") and log_once("_stop"):
            raise DeprecationWarning(
                "Trainable._stop is deprecated and is now removed. Override "
                "Trainable.cleanup instead."
            )
        pass

    def _export_model(self, export_formats: List[str], export_dir: str):
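# Illustrative sketch (not part of the original diff): a minimal Trainable written
# against the method names the checks above now enforce (setup/step/save_checkpoint/
# load_checkpoint instead of the removed _setup/_train/_save/_restore overrides).
# The class name and "state.json" layout are hypothetical.
import json
import os

from ray import tune


class MyTrainable(tune.Trainable):
    def setup(self, config):
        self.x = config.get("start", 0)

    def step(self):
        self.x += 1
        return {"score": self.x}

    def save_checkpoint(self, tmp_checkpoint_dir):
        with open(os.path.join(tmp_checkpoint_dir, "state.json"), "wt") as f:
            json.dump({"x": self.x}, f)
        return tmp_checkpoint_dir

    def load_checkpoint(self, checkpoint):
        with open(os.path.join(checkpoint, "state.json"), "rt") as f:
            self.x = json.load(f)["x"]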
@ -202,9 +202,8 @@ class Trial:
    Trials start in the PENDING state, and transition to RUNNING once started.
    On error it transitions to ERROR, otherwise TERMINATED on success.

    There are resources allocated to each trial. It's preferred that resources
    are specified using PlacementGroupFactory, rather than through Resources,
    which is being deprecated.
    There are resources allocated to each trial. These should be specified
    using ``PlacementGroupFactory``.

    Attributes:
        trainable_name: Name of the trainable object to be executed.

@ -771,17 +770,7 @@ class Trial:
        if self.custom_dirname:
            generated_dirname = self.custom_dirname
        else:
            if "MAX_LEN_IDENTIFIER" in os.environ:
                logger.error(
                    "The MAX_LEN_IDENTIFIER environment variable is "
                    "deprecated and will be removed in the future. "
                    "Use TUNE_MAX_LEN_IDENTIFIER instead."
                )
            MAX_LEN_IDENTIFIER = int(
                os.environ.get(
                    "TUNE_MAX_LEN_IDENTIFIER", os.environ.get("MAX_LEN_IDENTIFIER", 130)
                )
            )
            MAX_LEN_IDENTIFIER = int(os.environ.get("TUNE_MAX_LEN_IDENTIFIER", "130"))
            generated_dirname = f"{str(self)}_{self.experiment_tag}"
            generated_dirname = generated_dirname[:MAX_LEN_IDENTIFIER]
            generated_dirname += f"_{date_str()}"
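# Illustrative sketch (not part of the original diff): with the MAX_LEN_IDENTIFIER
# fallback removed, trial directory names are capped only via TUNE_MAX_LEN_IDENTIFIER
# (default 130). The variable has to be set before trials are created, e.g.:
import os

os.environ["TUNE_MAX_LEN_IDENTIFIER"] = "80"  # truncate generated trial dirnames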
@ -2,7 +2,6 @@
from abc import abstractmethod
import logging
from typing import Dict, List, Optional, Union
import warnings

from ray.exceptions import RayTaskError
from ray.tune import TuneError

@ -25,13 +24,13 @@ class _WarnOnDirectInheritanceMeta(type):
            )
            and "TrialExecutor" in tuple(base.__name__ for base in bases)
        ):
            deprecation_msg = (
            raise DeprecationWarning(
                f"{name} inherits from TrialExecutor, which is being "
                "deprecated. "
                "RFC: https://github.com/ray-project/ray/issues/17593. "
                "Please reach out on the Ray Github if you have any concerns."
            )
            warnings.warn(deprecation_msg, DeprecationWarning)

        cls = super().__new__(mcls, name, bases, module, **kwargs)
        return cls
@ -314,13 +314,6 @@ class TrialRunner:

        self._metric = metric

        if "TRIALRUNNER_WALLTIME_LIMIT" in os.environ:
            raise ValueError(
                "The TRIALRUNNER_WALLTIME_LIMIT environment variable is "
                "deprecated. "
                "Use `tune.run(time_budget_s=limit)` instead."
            )

        self._total_time = 0
        self._iteration = 0
        self._has_errored = False
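# Illustrative sketch (not part of the original diff): the removed
# TRIALRUNNER_WALLTIME_LIMIT environment variable maps onto the `time_budget_s`
# argument of tune.run(). The objective below is a hypothetical placeholder.
from ray import tune


def cheap_objective(config):
    tune.report(value=config["x"])


# analysis = tune.run(
#     cheap_objective,
#     config={"x": tune.uniform(0, 1)},
#     num_samples=100,
#     time_budget_s=60,  # stop the whole experiment after ~60 seconds of wall time
# )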
@ -162,7 +162,6 @@ def run(
    # == internal only ==
    _experiment_checkpoint_dir: Optional[str] = None,
    # Deprecated args
    queue_trials: Optional[bool] = None,
    loggers: Optional[Sequence[Type[Logger]]] = None,
    _remote: Optional[bool] = None,
) -> ExperimentAnalysis:

@ -349,27 +348,6 @@ def run(
    Raises:
        TuneError: Any trials failed and `raise_on_failed_trial` is True.
    """

    # To be removed in 1.9.
    if queue_trials is not None:
        raise DeprecationWarning(
            "`queue_trials` has been deprecated and is replaced by "
            "the `TUNE_MAX_PENDING_TRIALS_PG` environment variable. "
            "Per default at least one Trial is queued at all times, "
            "so you likely don't need to change anything other than "
            "removing this argument from your call to `tune.run()`"
        )

    # Starting deprecation in ray 1.10.
    if os.environ.get("TUNE_TRIAL_RESULT_WAIT_TIME_S") is not None:
        warnings.warn("`TUNE_TRIAL_RESULT_WAIT_TIME_S` is deprecated.")

    if os.environ.get("TUNE_TRIAL_STARTUP_GRACE_PERIOD") is not None:
        warnings.warn("`TUNE_TRIAL_STARTUP_GRACE_PERIOD` is deprecated.")

    if os.environ.get("TUNE_PLACEMENT_GROUP_WAIT_S") is not None:
        warnings.warn("`TUNE_PLACEMENT_GROUP_WAIT_S` is deprecated.")

    # NO CODE IS TO BE ADDED ABOVE THIS COMMENT
    # remote_run_kwargs must be defined before any other
    # code is ran to ensure that at this point,

@ -439,8 +417,8 @@ def run(
    all_start = time.time()

    if loggers:
        # Raise DeprecationWarning in 1.9, remove in 1.10/1.11
        warnings.warn(
        # Deprecated: Remove in Ray > 1.13
        raise DeprecationWarning(
            "The `loggers` argument is deprecated. Please pass the respective "
            "`LoggerCallback` classes to the `callbacks` argument instead. "
            "See https://docs.ray.io/en/latest/tune/api_docs/logging.html"

@ -642,9 +620,7 @@ def run(
        )

    # Create syncer callbacks
    callbacks = create_default_callbacks(
        callbacks, sync_config, metric=metric, loggers=loggers
    )
    callbacks = create_default_callbacks(callbacks, sync_config, metric=metric)

    runner = TrialRunner(
        search_alg=search_alg,

@ -803,7 +779,6 @@ def run_experiments(
    raise_on_failed_trial: bool = True,
    concurrent: bool = True,
    # Deprecated args.
    queue_trials: Optional[bool] = None,
    callbacks: Optional[Sequence[Callback]] = None,
    _remote: Optional[bool] = None,
):

@ -822,16 +797,6 @@ def run_experiments(
        List of Trial objects, holding data for each executed trial.

    """
    # To be removed in 1.9.
    if queue_trials is not None:
        raise DeprecationWarning(
            "`queue_trials` has been deprecated and is replaced by "
            "the `TUNE_MAX_PENDING_TRIALS_PG` environment variable. "
            "Per default at least one Trial is queued at all times, "
            "so you likely don't need to change anything other than "
            "removing this argument from your call to `tune.run()`"
        )

    if _remote is None:
        _remote = ray.util.client.ray.is_connected()
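# Illustrative sketch (not part of the original diff): the removed `queue_trials`
# argument is superseded by the TUNE_MAX_PENDING_TRIALS_PG environment variable named
# in the error message above; it has to be set before tune.run() is called.
import os

os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "8"  # allow up to 8 pending trials

# tune.run(trainable, num_samples=64)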
@ -2,9 +2,7 @@ from ray.tune.utils.util import (
    deep_update,
    date_str,
    flatten_dict,
    get_pinned_object,
    merge_dicts,
    pin_in_object_store,
    unflattened_lookup,
    UtilMonitor,
    validate_save_restore,

@ -20,9 +18,7 @@ __all__ = [
    "deep_update",
    "date_str",
    "flatten_dict",
    "get_pinned_object",
    "merge_dicts",
    "pin_in_object_store",
    "unflattened_lookup",
    "UtilMonitor",
    "validate_save_restore",
@ -9,11 +9,9 @@ from ray.tune.syncer import SyncConfig, detect_cluster_syncer
from ray.tune.logger import (
    CSVLoggerCallback,
    CSVLogger,
    LoggerCallback,
    JsonLoggerCallback,
    JsonLogger,
    LegacyLoggerCallback,
    Logger,
    TBXLoggerCallback,
    TBXLogger,
)

@ -25,7 +23,6 @@ logger = logging.getLogger(__name__)
def create_default_callbacks(
    callbacks: Optional[List[Callback]],
    sync_config: SyncConfig,
    loggers: Optional[List[Logger]],
    metric: Optional[str] = None,
):
    """Create default callbacks for `tune.run()`.

@ -69,23 +66,6 @@ def create_default_callbacks(
    last_logger_index = None
    syncer_index = None

    # Deprecate: 1.9
    # Create LegacyLoggerCallback for passed Logger classes
    if loggers:
        add_loggers = []
        for trial_logger in loggers:
            if isinstance(trial_logger, LoggerCallback):
                callbacks.append(trial_logger)
            elif isinstance(trial_logger, type) and issubclass(trial_logger, Logger):
                add_loggers.append(trial_logger)
            else:
                raise ValueError(
                    f"Invalid value passed to `loggers` argument of "
                    f"`tune.run()`: {trial_logger}"
                )
        if add_loggers:
            callbacks.append(LegacyLoggerCallback(add_loggers))

    # Check if we have a CSV, JSON and TensorboardX logger
    for i, callback in enumerate(callbacks):
        if isinstance(callback, LegacyLoggerCallback):

@ -147,12 +127,7 @@ def create_default_callbacks(
        and last_logger_index is not None
        and syncer_index < last_logger_index
    ):
        if (
            not has_csv_logger or not has_json_logger or not has_tbx_logger
        ) and not loggers:
            # Only raise the warning if the loggers were passed by the user.
            # (I.e. don't warn if this was automatic behavior and they only
            # passed a customer SyncerCallback).
        if not has_csv_logger or not has_json_logger or not has_tbx_logger:
            raise ValueError(
                "The `SyncerCallback` you passed to `tune.run()` came before "
                "at least one `LoggerCallback`. Syncing should be done "
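# Illustrative sketch (not part of the original diff): with the `loggers` handling gone
# from create_default_callbacks(), per-trial logging is supplied as LoggerCallback
# instances through `callbacks`; the default CSV/JSON/TBX callbacks are still added
# automatically. The class below is a hypothetical example.
from ray.tune.logger import LoggerCallback


class PrintLoggerCallback(LoggerCallback):
    """Tiny logger callback that prints every reported result."""

    def log_trial_result(self, iteration, trial, result):
        print(f"{trial}: iteration {iteration}, result {result}")


# tune.run(trainable, callbacks=[PrintLoggerCallback()])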
@ -243,16 +243,6 @@ def resource_dict_to_pg_factory(spec: Optional[Dict[str, float]]):
        spec = spec._asdict()

    spec = spec.copy()
    extra_custom = spec.pop("extra_custom_resources", {}) or {}

    if any(k.startswith("extra_") and spec[k] for k in spec) or any(
        extra_custom[k] for k in extra_custom
    ):
        raise ValueError(
            "Passing `extra_*` resource requirements to `resources_per_trial` "
            "is deprecated. Please use a `PlacementGroupFactory` object "
            "to define your resource requirements instead."
        )

    cpus = spec.pop("cpu", 0.0)
    gpus = spec.pop("gpu", 0.0)
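# Illustrative sketch (not part of the original diff): a legacy resources dict such as
# {"cpu": 1, "gpu": 1, "extra_gpu": 1} is expressed as placement group bundles instead;
# the first bundle is for the trainable itself, later bundles for its workers.
# The import path below is assumed from the Ray 1.x layout.
from ray.tune.utils.placement_groups import PlacementGroupFactory

pgf = PlacementGroupFactory([{"CPU": 1, "GPU": 1}, {"GPU": 1}])
# tune.run(trainable, resources_per_trial=pgf)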
@ -121,20 +121,6 @@ class UtilMonitor(Thread):
        self.stopped = True


def pin_in_object_store(obj):
    """Deprecated, use ray.put(value) instead."""

    obj_ref = ray.put(obj)
    _pinned_objects.append(obj_ref)
    return obj_ref


def get_pinned_object(pinned_id):
    """Deprecated."""

    return ray.get(pinned_id)


def retry_fn(
    fn: Callable[[], Any],
    exception_type: Type[Exception],

@ -456,7 +442,6 @@ def wait_for_gpu(
        retry: Number of times to check GPU limit. Sleeps `delay_s`
            seconds between checks.
        delay_s: Seconds to wait before check.
        gpu_memory_limit: Deprecated.

    Returns:
        bool: True if free.

@ -476,8 +461,7 @@ def wait_for_gpu(
        tune.run(tune_func, resources_per_trial={"GPU": 1}, num_samples=10)
    """
    GPUtil = _import_gputil()
    if gpu_memory_limit:
        raise ValueError("'gpu_memory_limit' is deprecated. Use 'target_util' instead.")

    if GPUtil is None:
        raise RuntimeError("GPUtil must be installed if calling `wait_for_gpu`.")

@ -700,11 +684,3 @@ def validate_warmstart(
            + " and points_to_evaluate {}".format(points_to_evaluate)
            + " do not match."
        )


if __name__ == "__main__":
    ray.init()
    X = pin_in_object_store("hello")
    print(X)
    result = get_pinned_object(X)
    print(result)
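# Illustrative sketch (not part of the original diff): the deprecated
# pin_in_object_store / get_pinned_object helpers removed above map directly onto
# plain object store calls.
import ray

ray.init(ignore_reinit_error=True)

ref = ray.put({"hello": "world"})  # instead of pin_in_object_store(...)
value = ray.get(ref)               # instead of get_pinned_object(ref)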
@ -1,76 +0,0 @@
|
|||
import pandas as pd
|
||||
from pandas.api.types import is_string_dtype, is_numeric_dtype
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
from ray.tune.utils import flatten_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger.warning("This module will be deprecated in a future version of Tune.")
|
||||
|
||||
|
||||
def _parse_results(res_path):
|
||||
res_dict = {}
|
||||
try:
|
||||
with open(res_path) as f:
|
||||
# Get last line in file
|
||||
for line in f:
|
||||
pass
|
||||
res_dict = flatten_dict(json.loads(line.strip()))
|
||||
except Exception:
|
||||
logger.exception("Importing %s failed...Perhaps empty?" % res_path)
|
||||
return res_dict
|
||||
|
||||
|
||||
def _parse_configs(cfg_path):
|
||||
try:
|
||||
with open(cfg_path) as f:
|
||||
cfg_dict = flatten_dict(json.load(f))
|
||||
except Exception:
|
||||
logger.exception("Config parsing failed.")
|
||||
return cfg_dict
|
||||
|
||||
|
||||
def _resolve(directory, result_fname):
|
||||
try:
|
||||
resultp = osp.join(directory, result_fname)
|
||||
res_dict = _parse_results(resultp)
|
||||
cfgp = osp.join(directory, "params.json")
|
||||
cfg_dict = _parse_configs(cfgp)
|
||||
cfg_dict.update(res_dict)
|
||||
return cfg_dict
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def load_results_to_df(directory, result_name="result.json"):
|
||||
exp_directories = [
|
||||
dirpath
|
||||
for dirpath, dirs, files in os.walk(directory)
|
||||
for f in files
|
||||
if f == result_name
|
||||
]
|
||||
data = [_resolve(d, result_name) for d in exp_directories]
|
||||
data = [d for d in data if d]
|
||||
return pd.DataFrame(data)
|
||||
|
||||
|
||||
def generate_plotly_dim_dict(df, field):
|
||||
dim_dict = {}
|
||||
dim_dict["label"] = field
|
||||
column = df[field]
|
||||
if is_numeric_dtype(column):
|
||||
dim_dict["values"] = column
|
||||
elif is_string_dtype(column):
|
||||
texts = column.unique()
|
||||
dim_dict["values"] = [np.argwhere(texts == x).flatten()[0] for x in column]
|
||||
dim_dict["tickvals"] = list(range(len(texts)))
|
||||
dim_dict["ticktext"] = texts
|
||||
else:
|
||||
raise Exception("Unidentifiable Type")
|
||||
|
||||
return dim_dict
|
|
@ -15,7 +15,7 @@ import argparse
import os

from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.logger import Logger
from ray.tune.logger import Logger, LegacyLoggerCallback

parser = argparse.ArgumentParser()
parser.add_argument(

@ -119,7 +119,11 @@ if __name__ == "__main__":
    }

    results = tune.run(
        args.run, config=config, stop=stop, verbose=2, loggers=[MyPrintLogger]
        args.run,
        config=config,
        stop=stop,
        verbose=2,
        callbacks=[LegacyLoggerCallback(MyPrintLogger)],
    )

    if args.as_test:
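# Illustrative sketch (not part of the original diff): roughly what a legacy Logger
# subclass like the example's MyPrintLogger looks like, and how it is kept working by
# wrapping it in LegacyLoggerCallback as done above. The class body here is assumed,
# not taken from the example file.
from ray.tune.logger import LegacyLoggerCallback, Logger


class MyPrintLoggerSketch(Logger):
    """Hypothetical stand-in for the example's MyPrintLogger."""

    def _init(self):
        self.prefix = self.config.get("prefix", "")

    def on_result(self, result):
        print(self.prefix, result.get("episode_reward_mean"))


# results = tune.run(..., callbacks=[LegacyLoggerCallback([MyPrintLoggerSketch])])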
@ -29,7 +29,6 @@ import os

import ray
from ray import tune
from ray.tune import function
from ray.rllib.examples.env.windy_maze_env import WindyMazeEnv, HierarchicalWindyMazeEnv
from ray.rllib.utils.test_utils import check_learning_achieved

@ -107,7 +106,7 @@ if __name__ == "__main__":
                    {"gamma": 0.0},
                ),
            },
            "policy_mapping_fn": function(policy_mapping_fn),
            "policy_mapping_fn": policy_mapping_fn,
        },
        "framework": args.framework,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.