mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[tune] Next deprecation cycle (#24076)
Rolling out next deprecation cycle:

- DeprecationWarnings that were `warnings.warn` or `logger.warn` before are now raised errors
- Raised Deprecation warnings are now removed
- Notably, this involves deprecating the TrialCheckpoint functionality and associated cloud tests
- Added annotations to deprecation warning for when to fully remove
This commit is contained in:
parent 2c772a421f
commit c0ec20dc3a
49 changed files with 170 additions and 1556 deletions
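Note on the pattern applied throughout this commit: previously soft deprecation warnings now raise. The following is an illustrative sketch only (it uses `tune.durable()`, one of the entry points touched below, but the function names `durable_before`/`durable_after` and the abbreviated message are not taken from the diff):

import warnings

def durable_before(trainable):
    # Old behavior: emit a warning and keep working.
    warnings.warn(
        "`tune.durable()` is being deprecated; Trainables are durable by "
        "default when an `upload_dir` is provided.",
        DeprecationWarning,
    )
    return trainable

def durable_after(trainable):
    # New behavior in this cycle: the same call site now raises,
    # so callers must migrate instead of relying on the shim.
    raise DeprecationWarning(
        "`tune.durable()` is being deprecated; Trainables are durable by "
        "default when an `upload_dir` is provided."
    )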
@@ -45,9 +45,3 @@ ExperimentAnalysis (tune.ExperimentAnalysis)
 
 .. autoclass:: ray.tune.ExperimentAnalysis
     :members:
-
-TrialCheckpoint (tune.cloud.TrialCheckpoint)
---------------------------------------------
-
-.. autoclass:: ray.tune.cloud.TrialCheckpoint
-    :members:
@@ -75,6 +75,7 @@ def is_non_local_path_uri(uri: str) -> bool:
         return True
     # Keep manual check for prefixes for backwards compatibility with the
     # TrialCheckpoint class. Remove once fully deprecated.
+    # Deprecated: Remove in Ray > 1.13
    if any(uri.startswith(p) for p in ALLOWED_REMOTE_PREFIXES):
        return True
    return False
@@ -56,14 +56,6 @@ py_test(
     tags = ["team:ml", "client", "py37", "exclusive", "tests_dir_C"]
 )
 
-py_test(
-    name = "test_cloud",
-    size = "medium",
-    srcs = ["tests/test_cloud.py"],
-    deps = [":tune_lib"],
-    tags = ["team:ml", "exclusive", "tests_dir_C"],
-)
-
 py_test(
     name = "test_cluster",
     size = "large",
@@ -15,8 +15,6 @@ from ray.tune.session import (
     get_trial_name,
     get_trial_id,
     get_trial_resources,
-    make_checkpoint_dir,
-    save_checkpoint,
     checkpoint_dir,
     is_session_enabled,
 )
@@ -843,12 +843,11 @@ class ExperimentAnalysis:
         return state
 
 
+# Deprecated: Remove in Ray > 1.13
 @Deprecated
 class Analysis(ExperimentAnalysis):
     def __init__(self, *args, **kwargs):
-        if log_once("durable_deprecated"):
-            logger.warning(
-                "DeprecationWarning: The `Analysis` class is being "
-                "deprecated. Please use `ExperimentAnalysis` instead."
-            )
-        super().__init__(*args, **kwargs)
+        raise DeprecationWarning(
+            "The `Analysis` class is being "
+            "deprecated. Please use `ExperimentAnalysis` instead."
+        )
@@ -1,22 +1,11 @@
 import os
-import shutil
-import tempfile
-import warnings
 from typing import Optional
 
-from ray import logger
 from ray.ml.checkpoint import (
     Checkpoint,
     _get_local_path,
     _get_external_path,
 )
-from ray.ml.utils.remote_storage import (
-    download_from_uri,
-    delete_at_uri,
-    upload_to_uri,
-    is_non_local_path_uri,
-)
-from ray.util import log_once
 from ray.util.annotations import Deprecated
 
 
@ -79,246 +68,8 @@ class _TrialCheckpoint(os.PathLike):
|
||||||
f">"
|
f">"
|
||||||
)
|
)
|
||||||
|
|
||||||
def download(
|
|
||||||
self,
|
|
||||||
cloud_path: Optional[str] = None,
|
|
||||||
local_path: Optional[str] = None,
|
|
||||||
overwrite: bool = False,
|
|
||||||
) -> str:
|
|
||||||
"""Download checkpoint from cloud.
|
|
||||||
|
|
||||||
This will fetch the checkpoint directory from cloud storage
|
|
||||||
and save it to ``local_path``.
|
|
||||||
|
|
||||||
If a ``local_path`` argument is provided and ``self.local_path``
|
|
||||||
is unset, it will be set to ``local_path``.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cloud_path: Cloud path to load checkpoint from.
|
|
||||||
Defaults to ``self.cloud_path``.
|
|
||||||
local_path: Local path to save checkpoint at.
|
|
||||||
Defaults to ``self.local_path``.
|
|
||||||
overwrite: If True, overwrites potential existing local
|
|
||||||
checkpoint. If False, exits if ``self.local_dir`` already
|
|
||||||
exists and has files in it.
|
|
||||||
|
|
||||||
"""
|
|
||||||
cloud_path = cloud_path or self.cloud_path
|
|
||||||
if not cloud_path:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Could not download trial checkpoint: No cloud "
|
|
||||||
"path is set. Fix this by either passing a "
|
|
||||||
"`cloud_path` to your call to `download()` or by "
|
|
||||||
"passing a `cloud_path` into the constructor. The latter "
|
|
||||||
"should automatically be done if you pass the correct "
|
|
||||||
"`tune.SyncConfig`."
|
|
||||||
)
|
|
||||||
|
|
||||||
local_path = local_path or self.local_path
|
|
||||||
|
|
||||||
if not local_path:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Could not download trial checkpoint: No local "
|
|
||||||
"path is set. Fix this by either passing a "
|
|
||||||
"`local_path` to your call to `download()` or by "
|
|
||||||
"passing a `local_path` into the constructor."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only update local path if unset
|
|
||||||
if not self.local_path:
|
|
||||||
self.local_path = local_path
|
|
||||||
|
|
||||||
if (
|
|
||||||
not overwrite
|
|
||||||
and os.path.exists(local_path)
|
|
||||||
and len(os.listdir(local_path)) > 0
|
|
||||||
):
|
|
||||||
# Local path already exists and we should not overwrite,
|
|
||||||
# so return.
|
|
||||||
return local_path
|
|
||||||
|
|
||||||
# Else: Actually download
|
|
||||||
|
|
||||||
# Delete existing dir
|
|
||||||
shutil.rmtree(local_path, ignore_errors=True)
|
|
||||||
# Re-create
|
|
||||||
os.makedirs(local_path, 0o755, exist_ok=True)
|
|
||||||
|
|
||||||
# Here we trigger the actual download
|
|
||||||
download_from_uri(uri=cloud_path, local_path=local_path)
|
|
||||||
|
|
||||||
# Local dir exists and is not empty
|
|
||||||
return local_path
|
|
||||||
|
|
||||||
def upload(
|
|
||||||
self,
|
|
||||||
cloud_path: Optional[str] = None,
|
|
||||||
local_path: Optional[str] = None,
|
|
||||||
clean_before: bool = False,
|
|
||||||
):
|
|
||||||
"""Upload checkpoint to cloud.
|
|
||||||
|
|
||||||
This will push the checkpoint directory from local storage
|
|
||||||
to ``cloud_path``.
|
|
||||||
|
|
||||||
If a ``cloud_path`` argument is provided and ``self.cloud_path``
|
|
||||||
is unset, it will be set to ``cloud_path``.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cloud_path: Cloud path to load checkpoint from.
|
|
||||||
Defaults to ``self.cloud_path``.
|
|
||||||
local_path: Local path to save checkpoint at.
|
|
||||||
Defaults to ``self.local_path``.
|
|
||||||
clean_before: If True, deletes potentially existing
|
|
||||||
cloud bucket before storing new data.
|
|
||||||
|
|
||||||
"""
|
|
||||||
local_path = local_path or self.local_path
|
|
||||||
if not local_path:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Could not upload trial checkpoint: No local "
|
|
||||||
"path is set. Fix this by either passing a "
|
|
||||||
"`local_path` to your call to `upload()` or by "
|
|
||||||
"passing a `local_path` into the constructor."
|
|
||||||
)
|
|
||||||
|
|
||||||
cloud_path = cloud_path or self.cloud_path
|
|
||||||
if not cloud_path:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Could not download trial checkpoint: No cloud "
|
|
||||||
"path is set. Fix this by either passing a "
|
|
||||||
"`cloud_path` to your call to `download()` or by "
|
|
||||||
"passing a `cloud_path` into the constructor. The latter "
|
|
||||||
"should automatically be done if you pass the correct "
|
|
||||||
"`tune.SyncConfig`."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.cloud_path:
|
|
||||||
self.cloud_path = cloud_path
|
|
||||||
|
|
||||||
if clean_before:
|
|
||||||
logger.info(f"Clearing bucket contents before upload: {cloud_path}")
|
|
||||||
delete_at_uri(cloud_path)
|
|
||||||
|
|
||||||
# Actually upload
|
|
||||||
upload_to_uri(local_path, cloud_path)
|
|
||||||
|
|
||||||
return cloud_path
|
|
||||||
|
|
||||||
def save(self, path: Optional[str] = None, force_download: bool = False):
|
|
||||||
"""Save trial checkpoint to directory or cloud storage.
|
|
||||||
|
|
||||||
If the ``path`` is a local target and the checkpoint already exists
|
|
||||||
on local storage, the local directory is copied. Else, the checkpoint
|
|
||||||
is downloaded from cloud storage.
|
|
||||||
|
|
||||||
If the ``path`` is a cloud target and the checkpoint does not already
|
|
||||||
exist on local storage, it is downloaded from cloud storage before.
|
|
||||||
That way checkpoints can be transferred across cloud storage providers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Path to save checkpoint at. If empty,
|
|
||||||
the default cloud storage path is saved to the default
|
|
||||||
local directory.
|
|
||||||
force_download: If ``True``, forces (re-)download of
|
|
||||||
the checkpoint. Defaults to ``False``.
|
|
||||||
"""
|
|
||||||
temp_dirs = set()
|
|
||||||
# Per default, save cloud checkpoint
|
|
||||||
if not path:
|
|
||||||
if self.cloud_path and self.local_path:
|
|
||||||
path = self.local_path
|
|
||||||
elif not self.cloud_path:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Cannot save trial checkpoint: No cloud path "
|
|
||||||
"found. If the checkpoint is already on the node, "
|
|
||||||
"you can pass a `path` argument to save it at another "
|
|
||||||
"location."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# No self.local_path
|
|
||||||
raise RuntimeError(
|
|
||||||
"Cannot save trial checkpoint: No target path "
|
|
||||||
"specified and no default local directory available. "
|
|
||||||
"Please pass a `path` argument to `save()`."
|
|
||||||
)
|
|
||||||
elif not self.local_path and not self.cloud_path:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot save trial checkpoint to cloud target "
|
|
||||||
f"`{path}`: No existing local or cloud path was "
|
|
||||||
f"found. This indicates an error when loading "
|
|
||||||
f"the checkpoints. Please report this issue."
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_non_local_path_uri(path):
|
|
||||||
# Storing on cloud
|
|
||||||
if not self.local_path:
|
|
||||||
# No local copy, yet. Download to temp dir
|
|
||||||
local_path = tempfile.mkdtemp(prefix="tune_checkpoint_")
|
|
||||||
temp_dirs.add(local_path)
|
|
||||||
else:
|
|
||||||
local_path = self.local_path
|
|
||||||
|
|
||||||
if self.cloud_path:
|
|
||||||
# Do not update local path as it might be a temp file
|
|
||||||
local_path = self.download(
|
|
||||||
local_path=local_path, overwrite=force_download
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove pointer to a temporary directory
|
|
||||||
if self.local_path in temp_dirs:
|
|
||||||
self.local_path = None
|
|
||||||
|
|
||||||
# We should now have a checkpoint available locally
|
|
||||||
if not os.path.exists(local_path) or len(os.listdir(local_path)) == 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"No checkpoint found in directory `{local_path}` after "
|
|
||||||
f"download - maybe the bucket is empty or downloading "
|
|
||||||
f"failed?"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only update cloud path if it wasn't set before
|
|
||||||
cloud_path = self.upload(
|
|
||||||
cloud_path=path, local_path=local_path, clean_before=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clean up temporary directories
|
|
||||||
for temp_dir in temp_dirs:
|
|
||||||
shutil.rmtree(temp_dir)
|
|
||||||
|
|
||||||
return cloud_path
|
|
||||||
|
|
||||||
local_path_exists = (
|
|
||||||
self.local_path
|
|
||||||
and os.path.exists(self.local_path)
|
|
||||||
and len(os.listdir(self.local_path)) > 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Else: path is a local target
|
|
||||||
if self.local_path and local_path_exists and not force_download:
|
|
||||||
# If we have a local copy, use it
|
|
||||||
|
|
||||||
if path == self.local_path:
|
|
||||||
# Nothing to do
|
|
||||||
return self.local_path
|
|
||||||
|
|
||||||
# Both local, just copy tree
|
|
||||||
if os.path.exists(path):
|
|
||||||
shutil.rmtree(path)
|
|
||||||
|
|
||||||
shutil.copytree(self.local_path, path)
|
|
||||||
return path
|
|
||||||
|
|
||||||
# Else: Download
|
|
||||||
try:
|
|
||||||
return self.download(local_path=path, overwrite=force_download)
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Cannot save trial checkpoint to local target as downloading "
|
|
||||||
"from cloud failed. Did you pass the correct `SyncConfig`?"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
|
# Deprecated: Remove in Ray > 1.13
|
||||||
@Deprecated
|
@Deprecated
|
||||||
class TrialCheckpoint(Checkpoint, _TrialCheckpoint):
|
class TrialCheckpoint(Checkpoint, _TrialCheckpoint):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@@ -394,13 +145,11 @@ class TrialCheckpoint(Checkpoint, _TrialCheckpoint):
         local_path: Optional[str] = None,
         overwrite: bool = False,
     ) -> str:
-        if log_once("trial_checkpoint_download_deprecated"):
-            warnings.warn(
-                "`checkpoint.download()` is deprecated and will be removed in "
-                "the future. Please use `checkpoint.to_directory()` instead.",
-                DeprecationWarning,
-            )
-        return _TrialCheckpoint.download(self, cloud_path, local_path, overwrite)
+        # Deprecated: Remove whole class in Ray > 1.13
+        raise DeprecationWarning(
+            "`checkpoint.download()` is deprecated and will be removed in "
+            "the future. Please use `checkpoint.to_directory()` instead."
+        )
 
     def upload(
         self,
@@ -408,20 +157,16 @@ class TrialCheckpoint(Checkpoint, _TrialCheckpoint):
         local_path: Optional[str] = None,
         clean_before: bool = False,
     ):
-        if log_once("trial_checkpoint_upload_deprecated"):
-            warnings.warn(
-                "`checkpoint.upload()` is deprecated and will be removed in "
-                "the future. Please use `checkpoint.to_uri()` instead.",
-                DeprecationWarning,
-            )
-        return _TrialCheckpoint.upload(self, cloud_path, local_path, clean_before)
+        # Deprecated: Remove whole class in Ray > 1.13
+        raise DeprecationWarning(
+            "`checkpoint.upload()` is deprecated and will be removed in "
+            "the future. Please use `checkpoint.to_uri()` instead."
+        )
 
     def save(self, path: Optional[str] = None, force_download: bool = False):
-        if log_once("trial_checkpoint_save_deprecated"):
-            warnings.warn(
-                "`checkpoint.save()` is deprecated and will be removed in "
-                "the future. Please use `checkpoint.to_directory()` or"
-                "`checkpoint.to_uri()` instead.",
-                DeprecationWarning,
-            )
-        return _TrialCheckpoint.save(self, path, force_download)
+        # Deprecated: Remove whole class in Ray > 1.13
+        raise DeprecationWarning(
+            "`checkpoint.save()` is deprecated and will be removed in "
+            "the future. Please use `checkpoint.to_directory()` or"
+            "`checkpoint.to_uri()` instead."
+        )
@@ -3,38 +3,35 @@ from typing import Callable, Type, Union
 import logging
 
 from ray.tune.trainable import Trainable
-from ray.util import log_once
 
 from ray.util.annotations import Deprecated
 
 logger = logging.getLogger(__name__)
 
 
+# Deprecated: Remove in Ray > 1.13
 @Deprecated
 class DurableTrainable(Trainable):
     _sync_function_tpl = None
 
     def __init__(self, *args, **kwargs):
-        if log_once("durable_deprecated"):
-            logger.warning(
-                "DeprecationWarning: The `DurableTrainable` class is being "
-                "deprecated. Instead, all Trainables are durable by default "
-                "if you provide an `upload_dir`. You'll likely only need to "
-                "remove the call to `tune.durable()` or directly inherit from "
-                "`Trainable` instead of `DurableTrainable` for class "
-                "trainables to make your code forward-compatible."
-            )
-        super(DurableTrainable, self).__init__(*args, **kwargs)
+        raise DeprecationWarning(
+            "DeprecationWarning: The `DurableTrainable` class is being "
+            "deprecated. Instead, all Trainables are durable by default "
+            "if you provide an `upload_dir`. You'll likely only need to "
+            "remove the call to `tune.durable()` or directly inherit from "
+            "`Trainable` instead of `DurableTrainable` for class "
+            "trainables to make your code forward-compatible."
+        )
 
 
+# Deprecated: Remove in Ray > 1.13
 @Deprecated
 def durable(trainable: Union[str, Type[Trainable], Callable]):
-    if log_once("durable_deprecated"):
-        logger.warning(
-            "DeprecationWarning: `tune.durable()` is being deprecated."
-            "Instead, all Trainables are durable by default if "
-            "you provide an `upload_dir`. You'll likely only need to remove "
-            "the call to `tune.durable()` to make your code "
-            "forward-compatible."
-        )
-    return trainable
+    raise DeprecationWarning(
+        "DeprecationWarning: `tune.durable()` is being deprecated."
+        "Instead, all Trainables are durable by default if "
+        "you provide an `upload_dir`. You'll likely only need to remove "
+        "the call to `tune.durable()` to make your code "
+        "forward-compatible."
+    )
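Migration note (a hedged sketch, not part of the diff): since `DurableTrainable` and `tune.durable()` now raise, class trainables should inherit from `Trainable` directly, and durability comes from passing an `upload_dir` through `tune.SyncConfig`. The class name, metric, stop condition, and bucket path below are placeholders.

from ray import tune

class MyTrainable(tune.Trainable):  # inherit from Trainable, not DurableTrainable
    def step(self):
        return {"score": 1.0}

tune.run(
    MyTrainable,
    stop={"training_iteration": 3},
    # Durability is implied by the upload_dir; no tune.durable() wrapper needed.
    sync_config=tune.SyncConfig(upload_dir="s3://my-bucket/tune-results"),
)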
@@ -63,7 +63,6 @@ if __name__ == "__main__":
     algo = SigOptSearch(
         space,
         name="SigOpt Example Experiment",
-        max_concurrent=1,
         metric="mean_loss",
         mode="min",
     )
@@ -62,7 +62,6 @@ if __name__ == "__main__":
         space,
         name="SigOpt Example Multi Objective Experiment",
         observation_budget=4 if args.smoke_test else 100,
-        max_concurrent=1,
         metric=["average", "std", "sharpe"],
         mode=["max", "min", "obs"],
     )
@@ -89,7 +89,6 @@ if __name__ == "__main__":
         connection=conn,
         experiment_id=experiment.id,
         name="SigOpt Example Existing Experiment",
-        max_concurrent=1,
         metric=["average", "std"],
         mode=["obs", "min"],
     )
@@ -19,9 +19,3 @@ def set_keras_threads(threads):
     # is heavily parallelized across multiple cores.
     tf.config.threading.set_inter_op_parallelism_threads(threads)
     tf.config.threading.set_intra_op_parallelism_threads(threads)
-
-
-def TuneKerasCallback(*args, **kwargs):
-    raise DeprecationWarning(
-        "TuneKerasCallback is now tune.integration.keras.TuneReporterCallback."
-    )
@@ -3,10 +3,9 @@ import logging
 
 import ray
 from ray.tune.trainable import Trainable
-from ray.tune.logger import Logger, LoggerCallback
+from ray.tune.logger import LoggerCallback
 from ray.tune.result import TRAINING_ITERATION, TIMESTEPS_TOTAL
 from ray.tune.trial import Trial
-from ray.util.annotations import Deprecated
 from ray.util.ml_utils.mlflow import MLflowLoggerUtil
 
 logger = logging.getLogger(__name__)
@@ -136,22 +135,6 @@ class MLflowLoggerCallback(LoggerCallback):
         self.mlflow_util.end_run(run_id=run_id, status=status)
 
 
-@Deprecated
-class MLflowLogger(Logger):
-    """MLflow logger using the deprecated Logger API.
-
-    Requires the experiment configuration to have a MLflow Experiment ID
-    or manually set the proper environment variables.
-    """
-
-    def _init(self):
-        raise DeprecationWarning(
-            "The legacy MLflowLogger has been "
-            "deprecated. Use the MLflowLoggerCallback "
-            "instead."
-        )
-
-
 def mlflow_mixin(func: Callable):
     """mlflow_mixin
 
@@ -198,7 +198,6 @@ def DistributedTrainableCreator(
     num_workers_per_host: Optional[int] = None,
     backend: str = "gloo",
     timeout_s: int = NCCL_TIMEOUT_S,
-    use_gpu=None,
 ) -> Type[_TorchTrainable]:
     """Creates a class that executes distributed training.
 
@@ -239,10 +238,6 @@ def DistributedTrainableCreator(
                 train_func, num_workers=2)
             analysis = tune.run(trainable_cls)
     """
-    if use_gpu:
-        raise DeprecationWarning(
-            "use_gpu is deprecated. Use 'num_gpus_per_worker' instead."
-        )
     detect_checkpoint_function(func, abort=True)
     if num_workers_per_host:
         if num_workers % num_workers_per_host:
@@ -15,6 +15,7 @@ from ray.tune.utils import flatten_dict
 from ray.tune.trial import Trial
 
 import yaml
+from ray.util.annotations import Deprecated
 
 try:
     import wandb
@@ -398,6 +399,7 @@ class WandbLoggerCallback(LoggerCallback):
         del self._trial_processes[trial]
 
 
+@Deprecated
 class WandbLogger(Logger):
     """WandbLogger
 
@@ -444,8 +446,7 @@ class WandbLogger(Logger):
 
     .. code-block:: python
 
-        from ray.tune.logger import DEFAULT_LOGGERS
-        from ray.tune.integration.wandb import WandbLogger
+        from ray.tune.integration.wandb import WandbLoggerCallback
         tune.run(
             train_fn,
             config={
@@ -459,14 +460,14 @@ class WandbLogger(Logger):
                     "log_config": True
                 }
             },
-            loggers=DEFAULT_LOGGERS + (WandbLogger, ))
+            calllbacks=[WandbLoggerCallback])
 
     Example for RLlib:
 
     .. code-block :: python
 
         from ray import tune
-        from ray.tune.integration.wandb import WandbLogger
+        from ray.tune.integration.wandb import WandbLoggerCallback
 
         tune.run(
             "PPO",
@@ -479,40 +480,18 @@ class WandbLogger(Logger):
                     }
                 }
             },
-            loggers=[WandbLogger])
+            callbacks=[WandbLoggerCallback])
 
 
    """
 
    _experiment_logger_cls = WandbLoggerCallback
 
-    def _init(self):
-        config = self.config.copy()
-        config.pop("callbacks", None)  # Remove callbacks
-
-        try:
-            if config.get("logger_config", {}).get("wandb"):
-                logger_config = config.pop("logger_config")
-                wandb_config = logger_config.get("wandb").copy()
-            else:
-                wandb_config = config.pop("wandb").copy()
-        except KeyError:
-            raise ValueError(
-                "Wandb logger specified but no configuration has been passed. "
-                "Make sure to include a `wandb` key in your `config` dict "
-                "containing at least a `project` specification."
-            )
-
-        self._trial_experiment_logger = self._experiment_logger_cls(**wandb_config)
-        self._trial_experiment_logger.setup()
-        self._trial_experiment_logger.log_trial_start(self.trial)
-
-    def on_result(self, result: Dict):
-        self._trial_experiment_logger.log_trial_result(0, self.trial, result)
-
-    def close(self):
-        self._trial_experiment_logger.log_trial_end(self.trial, failed=False)
-        del self._trial_experiment_logger
+    def __init__(self, *args, **kwargs):
+        raise DeprecationWarning(
+            "This `Logger` class is deprecated. "
+            "Use the `WandbLoggerCallback` callback instead."
+        )
 
 
 class WandbTrainableMixin:
@@ -747,16 +747,6 @@ class TBXLoggerCallback(LoggerCallback):
         )
 
 
-# Maintain backwards compatibility.
-from ray.tune.integration.mlflow import (  # noqa: E402
-    MLflowLogger as _MLflowLogger,
-)
-
-MLflowLogger = _MLflowLogger
-# The capital L is a typo, but needs to remain for backwards compatibility.
-MLFLowLogger = _MLflowLogger
-
-
 def pretty_print(result):
     result = result.copy()
     result.update(config=None)  # drop config from pretty print
@@ -525,13 +525,11 @@ class Quantized(Sampler):
         return list(quantized)
 
 
-# TODO (krfricke): Remove tune.function
+# Deprecated: Remove in Ray > 1.13
 def function(func):
-    logger.warning(
-        "DeprecationWarning: wrapping {} with tune.function() is no "
-        "longer needed".format(func)
+    raise DeprecationWarning(
+        "wrapping {} with tune.function() is no longer needed".format(func)
     )
-    return func
 
 
 def sample_from(func: Callable[[Dict], Any]):
@@ -6,7 +6,6 @@ from typing import Dict, Any, List, Optional, Set, Tuple, Union, Callable
 import pickle
 import warnings
 
-from ray.util import log_once
 from ray.util.annotations import PublicAPI, Deprecated
 from ray.tune import trial_runner
 from ray.tune.resources import Resources
@@ -583,6 +582,7 @@ _DistributeResourcesDefault = DistributeResources(add_bundles=False)
 _DistributeResourcesDistributedDefault = DistributeResources(add_bundles=True)
 
 
+# Deprecated: Remove in Ray > 1.13
 @Deprecated
 def evenly_distribute_cpus_gpus(
     trial_runner: "trial_runner.TrialRunner",
@@ -621,18 +621,16 @@ def evenly_distribute_cpus_gpus(
         the function.
     """
 
-    if log_once("evenly_distribute_cpus_gpus_deprecated"):
-        warnings.warn(
-            "DeprecationWarning: `evenly_distribute_cpus_gpus` "
-            "and `evenly_distribute_cpus_gpus_distributed` are "
-            "being deprecated. Use `DistributeResources()` and "
-            "`DistributeResources(add_bundles=False)` instead "
-            "for equivalent functionality."
-        )
-
-    return _DistributeResourcesDefault(trial_runner, trial, result, scheduler)
+    raise DeprecationWarning(
+        "DeprecationWarning: `evenly_distribute_cpus_gpus` "
+        "and `evenly_distribute_cpus_gpus_distributed` are "
+        "being deprecated. Use `DistributeResources()` and "
+        "`DistributeResources(add_bundles=False)` instead "
+        "for equivalent functionality."
+    )
 
 
|
# Deprecated: Remove in Ray > 1.13
|
||||||
@Deprecated
|
@Deprecated
|
||||||
def evenly_distribute_cpus_gpus_distributed(
|
def evenly_distribute_cpus_gpus_distributed(
|
||||||
trial_runner: "trial_runner.TrialRunner",
|
trial_runner: "trial_runner.TrialRunner",
|
||||||
|
@ -671,17 +669,12 @@ def evenly_distribute_cpus_gpus_distributed(
|
||||||
the function.
|
the function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if log_once("evenly_distribute_cpus_gpus_deprecated"):
|
raise DeprecationWarning(
|
||||||
warnings.warn(
|
"DeprecationWarning: `evenly_distribute_cpus_gpus` "
|
||||||
"DeprecationWarning: `evenly_distribute_cpus_gpus` "
|
"and `evenly_distribute_cpus_gpus_distributed` are "
|
||||||
"and `evenly_distribute_cpus_gpus_distributed` are "
|
"being deprecated. Use `DistributeResources()` and "
|
||||||
"being deprecated. Use `DistributeResources()` and "
|
"`DistributeResources(add_bundles=False)` instead "
|
||||||
"`DistributeResources(add_bundles=False)` instead "
|
"for equivalent functionality."
|
||||||
"for equivalent functionality."
|
|
||||||
)
|
|
||||||
|
|
||||||
return _DistributeResourcesDistributedDefault(
|
|
||||||
trial_runner, trial, result, scheduler
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@@ -103,28 +103,6 @@ def report(_metric=None, **kwargs):
     return _session(_metric, **kwargs)
 
 
-def make_checkpoint_dir(step=None):
-    """Gets the next checkpoint dir.
-
-    .. versionadded:: 0.8.6
-
-    .. deprecated:: 0.8.7
-        Use tune.checkpoint_dir instead.
-    """
-    raise DeprecationWarning("Deprecated method. Use `tune.checkpoint_dir` instead.")
-
-
-def save_checkpoint(checkpoint):
-    """Register the given checkpoint.
-
-    .. versionadded:: 0.8.6
-
-    .. deprecated:: 0.8.7
-        Use tune.checkpoint_dir instead.
-    """
-    raise DeprecationWarning("Deprecated method. Use `tune.checkpoint_dir` instead.")
-
-
 @PublicAPI
 @contextmanager
 def checkpoint_dir(step: int):
@@ -157,38 +157,3 @@ __all__ = [
     "Repeater",
     "ConcurrencyLimiter",
 ]
-
-
-def BayesOptSearch(*args, **kwargs):
-    raise DeprecationWarning(
-        """This class has been moved. Please import via
-        `from ray.tune.suggest.bayesopt import BayesOptSearch`"""
-    )
-
-
-def HyperOptSearch(*args, **kwargs):
-    raise DeprecationWarning(
-        """This class has been moved. Please import via
-        `from ray.tune.suggest.hyperopt import HyperOptSearch`"""
-    )
-
-
-def NevergradSearch(*args, **kwargs):
-    raise DeprecationWarning(
-        """This class has been moved. Please import via
-        `from ray.tune.suggest.nevergrad import NevergradSearch`"""
-    )
-
-
-def SkOptSearch(*args, **kwargs):
-    raise DeprecationWarning(
-        """This class has been moved. Please import via
-        `from ray.tune.suggest.skopt import SkOptSearch`"""
-    )
-
-
-def SigOptSearch(*args, **kwargs):
-    raise DeprecationWarning(
-        """This class has been moved. Please import via
-        `from ray.tune.suggest.sigopt import SigOptSearch`"""
-    )
@ -73,8 +73,6 @@ class AxSearch(Searcher):
|
||||||
ax_client: Optional AxClient instance. If this is set, do
|
ax_client: Optional AxClient instance. If this is set, do
|
||||||
not pass any values to these parameters: `space`, `metric`,
|
not pass any values to these parameters: `space`, `metric`,
|
||||||
`parameter_constraints`, `outcome_constraints`.
|
`parameter_constraints`, `outcome_constraints`.
|
||||||
use_early_stopped_trials: Deprecated.
|
|
||||||
max_concurrent: Deprecated.
|
|
||||||
**ax_kwargs: Passed to AxClient instance. Ignored if `AxClient` is not
|
**ax_kwargs: Passed to AxClient instance. Ignored if `AxClient` is not
|
||||||
None.
|
None.
|
||||||
|
|
||||||
|
@ -133,8 +131,6 @@ class AxSearch(Searcher):
|
||||||
parameter_constraints: Optional[List] = None,
|
parameter_constraints: Optional[List] = None,
|
||||||
outcome_constraints: Optional[List] = None,
|
outcome_constraints: Optional[List] = None,
|
||||||
ax_client: Optional[AxClient] = None,
|
ax_client: Optional[AxClient] = None,
|
||||||
use_early_stopped_trials: Optional[bool] = None,
|
|
||||||
max_concurrent: Optional[int] = None,
|
|
||||||
**ax_kwargs
|
**ax_kwargs
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
|
@ -149,8 +145,6 @@ class AxSearch(Searcher):
|
||||||
super(AxSearch, self).__init__(
|
super(AxSearch, self).__init__(
|
||||||
metric=metric,
|
metric=metric,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
max_concurrent=max_concurrent,
|
|
||||||
use_early_stopped_trials=use_early_stopped_trials,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._ax = ax_client
|
self._ax = ax_client
|
||||||
|
@ -170,8 +164,6 @@ class AxSearch(Searcher):
|
||||||
|
|
||||||
self._points_to_evaluate = copy.deepcopy(points_to_evaluate)
|
self._points_to_evaluate = copy.deepcopy(points_to_evaluate)
|
||||||
|
|
||||||
self.max_concurrent = max_concurrent
|
|
||||||
|
|
||||||
self._parameters = []
|
self._parameters = []
|
||||||
self._live_trial_mapping = {}
|
self._live_trial_mapping = {}
|
||||||
|
|
||||||
|
|
|
@ -79,8 +79,6 @@ class BayesOptSearch(Searcher):
|
||||||
analysis: Optionally, the previous analysis
|
analysis: Optionally, the previous analysis
|
||||||
to integrate.
|
to integrate.
|
||||||
verbose: Sets verbosity level for BayesOpt packages.
|
verbose: Sets verbosity level for BayesOpt packages.
|
||||||
max_concurrent: Deprecated.
|
|
||||||
use_early_stopped_trials: Deprecated.
|
|
||||||
|
|
||||||
Tune automatically converts search spaces to BayesOptSearch's format:
|
Tune automatically converts search spaces to BayesOptSearch's format:
|
||||||
|
|
||||||
|
@ -130,8 +128,6 @@ class BayesOptSearch(Searcher):
|
||||||
patience: int = 5,
|
patience: int = 5,
|
||||||
skip_duplicate: bool = True,
|
skip_duplicate: bool = True,
|
||||||
analysis: Optional[ExperimentAnalysis] = None,
|
analysis: Optional[ExperimentAnalysis] = None,
|
||||||
max_concurrent: Optional[int] = None,
|
|
||||||
use_early_stopped_trials: Optional[bool] = None,
|
|
||||||
):
|
):
|
||||||
assert byo is not None, (
|
assert byo is not None, (
|
||||||
"BayesOpt must be installed!. You can install BayesOpt with"
|
"BayesOpt must be installed!. You can install BayesOpt with"
|
||||||
|
@ -139,7 +135,6 @@ class BayesOptSearch(Searcher):
|
||||||
)
|
)
|
||||||
if mode:
|
if mode:
|
||||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
||||||
self.max_concurrent = max_concurrent
|
|
||||||
self._config_counter = defaultdict(int)
|
self._config_counter = defaultdict(int)
|
||||||
self._patience = patience
|
self._patience = patience
|
||||||
# int: Precision at which to hash values.
|
# int: Precision at which to hash values.
|
||||||
|
@ -150,8 +145,6 @@ class BayesOptSearch(Searcher):
|
||||||
super(BayesOptSearch, self).__init__(
|
super(BayesOptSearch, self).__init__(
|
||||||
metric=metric,
|
metric=metric,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
max_concurrent=max_concurrent,
|
|
||||||
use_early_stopped_trials=use_early_stopped_trials,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if utility_kwargs is None:
|
if utility_kwargs is None:
|
||||||
|
|
|
@ -61,8 +61,6 @@ class TuneBOHB(Searcher):
|
||||||
Parameters will be sampled from this space which will be used
|
Parameters will be sampled from this space which will be used
|
||||||
to run trials.
|
to run trials.
|
||||||
bohb_config: configuration for HpBandSter BOHB algorithm
|
bohb_config: configuration for HpBandSter BOHB algorithm
|
||||||
max_concurrent: Deprecated. Use
|
|
||||||
``tune.suggest.ConcurrencyLimiter()``.
|
|
||||||
metric: The training result objective value attribute. If None
|
metric: The training result objective value attribute. If None
|
||||||
but a mode was passed, the anonymous metric `_metric` will be used
|
but a mode was passed, the anonymous metric `_metric` will be used
|
||||||
per default.
|
per default.
|
||||||
|
@ -126,7 +124,6 @@ class TuneBOHB(Searcher):
|
||||||
self,
|
self,
|
||||||
space: Optional[Union[Dict, "ConfigSpace.ConfigurationSpace"]] = None,
|
space: Optional[Union[Dict, "ConfigSpace.ConfigurationSpace"]] = None,
|
||||||
bohb_config: Optional[Dict] = None,
|
bohb_config: Optional[Dict] = None,
|
||||||
max_concurrent: Optional[int] = None,
|
|
||||||
metric: Optional[str] = None,
|
metric: Optional[str] = None,
|
||||||
mode: Optional[str] = None,
|
mode: Optional[str] = None,
|
||||||
points_to_evaluate: Optional[List[Dict]] = None,
|
points_to_evaluate: Optional[List[Dict]] = None,
|
||||||
|
@ -139,7 +136,6 @@ class TuneBOHB(Searcher):
|
||||||
`pip install hpbandster ConfigSpace`."""
|
`pip install hpbandster ConfigSpace`."""
|
||||||
if mode:
|
if mode:
|
||||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
||||||
self._max_concurrent = max_concurrent
|
|
||||||
self.trial_to_params = {}
|
self.trial_to_params = {}
|
||||||
self._metric = metric
|
self._metric = metric
|
||||||
|
|
||||||
|
@ -159,7 +155,8 @@ class TuneBOHB(Searcher):
|
||||||
self._points_to_evaluate = points_to_evaluate
|
self._points_to_evaluate = points_to_evaluate
|
||||||
|
|
||||||
super(TuneBOHB, self).__init__(
|
super(TuneBOHB, self).__init__(
|
||||||
metric=self._metric, mode=mode, max_concurrent=max_concurrent
|
metric=self._metric,
|
||||||
|
mode=mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._space:
|
if self._space:
|
||||||
|
|
|
@ -78,8 +78,6 @@ class HyperOptSearch(Searcher):
|
||||||
results. Defaults to None.
|
results. Defaults to None.
|
||||||
gamma: parameter governing the tree parzen
|
gamma: parameter governing the tree parzen
|
||||||
estimators suggestion algorithm. Defaults to 0.25.
|
estimators suggestion algorithm. Defaults to 0.25.
|
||||||
max_concurrent: Deprecated.
|
|
||||||
use_early_stopped_trials: Deprecated.
|
|
||||||
|
|
||||||
Tune automatically converts search spaces to HyperOpt's format:
|
Tune automatically converts search spaces to HyperOpt's format:
|
||||||
|
|
||||||
|
@ -138,8 +136,6 @@ class HyperOptSearch(Searcher):
|
||||||
n_initial_points: int = 20,
|
n_initial_points: int = 20,
|
||||||
random_state_seed: Optional[int] = None,
|
random_state_seed: Optional[int] = None,
|
||||||
gamma: float = 0.25,
|
gamma: float = 0.25,
|
||||||
max_concurrent: Optional[int] = None,
|
|
||||||
use_early_stopped_trials: Optional[bool] = None,
|
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
hpo is not None
|
hpo is not None
|
||||||
|
@ -149,10 +145,7 @@ class HyperOptSearch(Searcher):
|
||||||
super(HyperOptSearch, self).__init__(
|
super(HyperOptSearch, self).__init__(
|
||||||
metric=metric,
|
metric=metric,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
max_concurrent=max_concurrent,
|
|
||||||
use_early_stopped_trials=use_early_stopped_trials,
|
|
||||||
)
|
)
|
||||||
self.max_concurrent = max_concurrent
|
|
||||||
# hyperopt internally minimizes, so "max" => -1
|
# hyperopt internally minimizes, so "max" => -1
|
||||||
if mode == "max":
|
if mode == "max":
|
||||||
self.metric_op = -1.0
|
self.metric_op = -1.0
|
||||||
|
|
|
@ -59,8 +59,6 @@ class NevergradSearch(Searcher):
|
||||||
you want to run first to help the algorithm make better suggestions
|
you want to run first to help the algorithm make better suggestions
|
||||||
for future parameters. Needs to be a list of dicts containing the
|
for future parameters. Needs to be a list of dicts containing the
|
||||||
configurations.
|
configurations.
|
||||||
use_early_stopped_trials: Deprecated.
|
|
||||||
max_concurrent: Deprecated.
|
|
||||||
|
|
||||||
Tune automatically converts search spaces to Nevergrad's format:
|
Tune automatically converts search spaces to Nevergrad's format:
|
||||||
|
|
||||||
|
@ -120,7 +118,6 @@ class NevergradSearch(Searcher):
|
||||||
metric: Optional[str] = None,
|
metric: Optional[str] = None,
|
||||||
mode: Optional[str] = None,
|
mode: Optional[str] = None,
|
||||||
points_to_evaluate: Optional[List[Dict]] = None,
|
points_to_evaluate: Optional[List[Dict]] = None,
|
||||||
max_concurrent: Optional[int] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
|
@ -131,9 +128,7 @@ class NevergradSearch(Searcher):
|
||||||
if mode:
|
if mode:
|
||||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
||||||
|
|
||||||
super(NevergradSearch, self).__init__(
|
super(NevergradSearch, self).__init__(metric=metric, mode=mode, **kwargs)
|
||||||
metric=metric, mode=mode, max_concurrent=max_concurrent, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
self._space = None
|
self._space = None
|
||||||
self._opt_factory = None
|
self._opt_factory = None
|
||||||
|
@ -180,7 +175,6 @@ class NevergradSearch(Searcher):
|
||||||
)
|
)
|
||||||
|
|
||||||
self._live_trial_mapping = {}
|
self._live_trial_mapping = {}
|
||||||
self.max_concurrent = max_concurrent
|
|
||||||
|
|
||||||
if self._nevergrad_opt or self._space:
|
if self._nevergrad_opt or self._space:
|
||||||
self._setup_nevergrad()
|
self._setup_nevergrad()
|
||||||
|
|
|
@ -289,9 +289,7 @@ class OptunaSearch(Searcher):
|
||||||
evaluated_rewards: Optional[List] = None,
|
evaluated_rewards: Optional[List] = None,
|
||||||
):
|
):
|
||||||
assert ot is not None, "Optuna must be installed! Run `pip install optuna`."
|
assert ot is not None, "Optuna must be installed! Run `pip install optuna`."
|
||||||
super(OptunaSearch, self).__init__(
|
super(OptunaSearch, self).__init__(metric=metric, mode=mode)
|
||||||
metric=metric, mode=mode, max_concurrent=None, use_early_stopped_trials=None
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(space, dict) and space:
|
if isinstance(space, dict) and space:
|
||||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
||||||
|
|
|
@ -90,7 +90,7 @@ class SigOptSearch(Searcher):
|
||||||
]
|
]
|
||||||
algo = SigOptSearch(
|
algo = SigOptSearch(
|
||||||
space, name="SigOpt Example Experiment",
|
space, name="SigOpt Example Experiment",
|
||||||
max_concurrent=1, metric="mean_loss", mode="min")
|
metric="mean_loss", mode="min")
|
||||||
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
@ -117,7 +117,7 @@ class SigOptSearch(Searcher):
|
||||||
]
|
]
|
||||||
algo = SigOptSearch(
|
algo = SigOptSearch(
|
||||||
space, name="SigOpt Multi Objective Example Experiment",
|
space, name="SigOpt Multi Objective Example Experiment",
|
||||||
max_concurrent=1, metric=["average", "std"], mode=["max", "min"])
|
metric=["average", "std"], mode=["max", "min"])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
OBJECTIVE_MAP = {
|
OBJECTIVE_MAP = {
|
||||||
|
|
|
@ -68,8 +68,6 @@ class SkOptSearch(Searcher):
|
||||||
convert_to_python: SkOpt outputs numpy primitives (e.g.
|
convert_to_python: SkOpt outputs numpy primitives (e.g.
|
||||||
``np.int64``) instead of Python types. If this setting is set
|
``np.int64``) instead of Python types. If this setting is set
|
||||||
to ``True``, the values will be converted to Python primitives.
|
to ``True``, the values will be converted to Python primitives.
|
||||||
max_concurrent: Deprecated.
|
|
||||||
use_early_stopped_trials: Deprecated.
|
|
||||||
|
|
||||||
Tune automatically converts search spaces to SkOpt's format:
|
Tune automatically converts search spaces to SkOpt's format:
|
||||||
|
|
||||||
|
@ -127,8 +125,6 @@ class SkOptSearch(Searcher):
|
||||||
points_to_evaluate: Optional[List[Dict]] = None,
|
points_to_evaluate: Optional[List[Dict]] = None,
|
||||||
evaluated_rewards: Optional[List] = None,
|
evaluated_rewards: Optional[List] = None,
|
||||||
convert_to_python: bool = True,
|
convert_to_python: bool = True,
|
||||||
max_concurrent: Optional[int] = None,
|
|
||||||
use_early_stopped_trials: Optional[bool] = None,
|
|
||||||
):
|
):
|
||||||
assert sko is not None, (
|
assert sko is not None, (
|
||||||
"skopt must be installed! "
|
"skopt must be installed! "
|
||||||
|
@ -138,12 +134,10 @@ class SkOptSearch(Searcher):
|
||||||
|
|
||||||
if mode:
|
if mode:
|
||||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
||||||
self.max_concurrent = max_concurrent
|
|
||||||
super(SkOptSearch, self).__init__(
|
super(SkOptSearch, self).__init__(
|
||||||
metric=metric,
|
metric=metric,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
max_concurrent=max_concurrent,
|
|
||||||
use_early_stopped_trials=use_early_stopped_trials,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._initial_points = []
|
self._initial_points = []
|
||||||
|
|
|
@@ -89,21 +89,7 @@ class Searcher:
         self,
         metric: Optional[str] = None,
         mode: Optional[str] = None,
-        max_concurrent: Optional[int] = None,
-        use_early_stopped_trials: Optional[bool] = None,
     ):
-        if use_early_stopped_trials is False:
-            raise DeprecationWarning(
-                "Early stopped trials are now always used. If this is a "
-                "problem, file an issue: https://github.com/ray-project/ray."
-            )
-        if max_concurrent is not None:
-            raise DeprecationWarning(
-                "`max_concurrent` is deprecated for this "
-                "search algorithm. Use tune.suggest.ConcurrencyLimiter() "
-                "instead. This will raise an error in future versions of Ray."
-            )
-
         self._metric = metric
         self._mode = mode
 
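Migration note (illustrative sketch, not taken from the diff): the removed `max_concurrent` and `use_early_stopped_trials` constructor arguments are superseded by wrapping a searcher in `ConcurrencyLimiter`. The objective function, search space, and sample count below are placeholders.

from ray import tune
from ray.tune.suggest import ConcurrencyLimiter
from ray.tune.suggest.hyperopt import HyperOptSearch

def objective(config):
    # Placeholder objective; report the metric the searcher optimizes.
    tune.report(mean_loss=(config["x"] - 2) ** 2)

searcher = HyperOptSearch(metric="mean_loss", mode="min")
# Concurrency is now controlled by the wrapper, not by max_concurrent on the searcher.
algo = ConcurrencyLimiter(searcher, max_concurrent=4)

tune.run(
    objective,
    config={"x": tune.uniform(0, 10)},
    search_alg=algo,
    num_samples=20,
)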
@@ -8,7 +8,6 @@ import subprocess
 import tempfile
 import time
 import types
-import warnings
 
 from typing import Optional, List, Callable, Union, Tuple
 
@@ -18,7 +17,6 @@ import ray
 from ray.tune.error import TuneError
 from ray.tune.utils.file_transfer import sync_dir_between_nodes, delete_on_node
 from ray.util.annotations import PublicAPI
-from ray.util.debug import log_once
 from ray.ml.utils.remote_storage import (
     S3_PREFIX,
     GS_PREFIX,
@@ -197,13 +195,12 @@ class FunctionBasedClient(SyncClient):
         self._sync_down_legacy = _is_legacy_sync_fn(sync_up_func)
 
         if self._sync_up_legacy or self._sync_down_legacy:
-            if log_once("func_sync_up_legacy"):
-                warnings.warn(
-                    "Your sync functions currently only accepts two params "
-                    "(a `source` and a `target`). In the future, we will "
-                    "pass an additional `exclude` parameter. Please adjust "
-                    "your sync function accordingly."
-                )
+            raise DeprecationWarning(
+                "Your sync functions currently only accepts two params "
+                "(a `source` and a `target`). In the future, we will "
+                "pass an additional `exclude` parameter. Please adjust "
+                "your sync function accordingly."
+            )
 
         self.delete_func = delete_func or noop
 
@@ -15,7 +15,6 @@ import logging
 import os
 import time
 from dataclasses import dataclass
-import warnings
 
 from inspect import isclass
 from shlex import quote
@@ -98,13 +97,13 @@ def validate_sync_config(sync_config: "SyncConfig"):
         sync_config.node_sync_period = -1
         sync_config.cloud_sync_period = -1
 
-        warnings.warn(
+        # Deprecated: Remove in Ray > 1.13
+        raise DeprecationWarning(
             "The `node_sync_period` and "
             "`cloud_sync_period` properties of `tune.SyncConfig` are "
             "deprecated. Pass the `sync_period` property instead. "
             "\nFor now, the lower of the two values (if provided) will "
-            f"be used as the sync_period. This value is: {sync_period}",
-            DeprecationWarning,
+            f"be used as the sync_period. This value is: {sync_period}"
         )
 
         if sync_config.sync_to_cloud or sync_config.sync_to_driver:
@@ -119,15 +118,15 @@ def validate_sync_config(sync_config: "SyncConfig"):
         sync_config.sync_to_cloud = None
         sync_config.sync_to_driver = None
 
-        warnings.warn(
+        # Deprecated: Remove in Ray > 1.13
+        raise DeprecationWarning(
             "The `sync_to_cloud` and `sync_to_driver` properties of "
             "`tune.SyncConfig` are deprecated. Pass the `syncer` property "
             "instead. Presence of an `upload_dir` decides if checkpoints "
             "are synced to cloud or not. Syncing to driver is "
             "automatically disabled if an `upload_dir` is given."
             f"\nFor now, as the upload dir is {help}, the respective "
-            f"syncer is used. This value is: {syncer}",
-            DeprecationWarning,
+            f"syncer is used. This value is: {syncer}"
         )
 
 
@@ -199,6 +198,7 @@ class SyncConfig:
     sync_period: int = 300
 
     # Deprecated arguments
+    # Deprecated: Remove in Ray > 1.13
     sync_to_cloud: Any = None
     sync_to_driver: Any = None
     node_sync_period: int = -1
@@ -28,7 +28,7 @@ from ray.tune import (
 from ray.tune.callback import Callback
 from ray.tune.experiment import Experiment
 from ray.tune.function_runner import wrap_function
-from ray.tune.logger import Logger
+from ray.tune.logger import Logger, LegacyLoggerCallback
 from ray.tune.ray_trial_executor import noop_logger_creator
 from ray.tune.resources import Resources
 from ray.tune.result import (
@@ -59,7 +59,7 @@ from ray.tune.suggest.suggestion import ConcurrencyLimiter
 from ray.tune.sync_client import CommandBasedClient
 from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner
-from ray.tune.utils import flatten_dict, get_pinned_object, pin_in_object_store
+from ray.tune.utils import flatten_dict
 from ray.tune.utils.placement_groups import PlacementGroupFactory
 
 
@@ -122,14 +122,14 @@ class TrainableFunctionApiTest(unittest.TestCase):
 
         [trial1] = run(
             _function_trainable,
-            loggers=[FunctionAPILogger],
+            callbacks=[LegacyLoggerCallback([FunctionAPILogger])],
             raise_on_failed_trial=False,
             scheduler=MockScheduler(),
         ).trials
 
         [trial2] = run(
             class_trainable_name,
-            loggers=[ClassAPILogger],
+            callbacks=[LegacyLoggerCallback([ClassAPILogger])],
             raise_on_failed_trial=False,
             scheduler=MockScheduler(),
         ).trials
@@ -180,33 +180,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
 
         return function_output, trials
 
-    def testPinObject(self):
-        X = pin_in_object_store("hello")
-
-        @ray.remote
-        def f():
-            return get_pinned_object(X)
-
-        self.assertEqual(ray.get(f.remote()), "hello")
-
-    def testFetchPinned(self):
-        X = pin_in_object_store("hello")
-
-        def train(config, reporter):
-            get_pinned_object(X)
-            reporter(timesteps_total=100, done=True)
-
-        register_trainable("f1", train)
-        [trial] = run_experiments(
-            {
-                "foo": {
-                    "run": "f1",
-                }
-            }
-        )
-        self.assertEqual(trial.status, Trial.TERMINATED)
-        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 100)
-
     def testRegisterEnv(self):
         register_env("foo", lambda: None)
         self.assertRaises(TypeError, lambda: register_env("foo", 2))
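
The deleted tests covered `pin_in_object_store` / `get_pinned_object`, which were thin wrappers around the object store (their own docstrings pointed at `ray.put`). A minimal sketch of the plain Ray replacement, assuming a running Ray instance:

```python
import ray

ray.init(ignore_reinit_error=True)

obj_ref = ray.put("hello")  # replaces pin_in_object_store("hello")


@ray.remote
def f():
    # The ObjectRef is captured by the closure, as in the removed test.
    return ray.get(obj_ref)  # replaces get_pinned_object(X)


assert ray.get(f.remote()) == "hello"
```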
@ -756,7 +729,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||||
},
|
},
|
||||||
verbose=1,
|
verbose=1,
|
||||||
local_dir=tmpdir,
|
local_dir=tmpdir,
|
||||||
loggers=None,
|
|
||||||
)
|
)
|
||||||
trials = tune.run(test, raise_on_failed_trial=False, **config).trials
|
trials = tune.run(test, raise_on_failed_trial=False, **config).trials
|
||||||
self.assertEqual(Counter(t.status for t in trials)["ERROR"], 5)
|
self.assertEqual(Counter(t.status for t in trials)["ERROR"], 5)
|
||||||
|
|
|
@ -1,551 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
import unittest
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from ray import tune
|
|
||||||
from ray.ml.utils.remote_storage import upload_to_uri, delete_at_uri
|
|
||||||
from ray.tune.cloud import TrialCheckpoint
|
|
||||||
|
|
||||||
|
|
||||||
class TrialCheckpointApiTest(unittest.TestCase):
|
|
||||||
def setUp(self) -> None:
|
|
||||||
self.local_dir = tempfile.mkdtemp()
|
|
||||||
with open(os.path.join(self.local_dir, "some_file"), "w") as f:
|
|
||||||
f.write("checkpoint")
|
|
||||||
|
|
||||||
self.cloud_dir = "memory:///cloud_dir"
|
|
||||||
|
|
||||||
self._save_checkpoint_at(self.cloud_dir)
|
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
|
||||||
shutil.rmtree(self.local_dir)
|
|
||||||
delete_at_uri(self.cloud_dir)
|
|
||||||
|
|
||||||
def _save_checkpoint_at(self, target):
|
|
||||||
delete_at_uri(target)
|
|
||||||
upload_to_uri(local_path=self.local_dir, uri=target)
|
|
||||||
|
|
||||||
def testConstructTrialCheckpoint(self):
|
|
||||||
# All these constructions should work
|
|
||||||
TrialCheckpoint(None, None)
|
|
||||||
TrialCheckpoint("/tmp", None)
|
|
||||||
TrialCheckpoint(None, "memory:///invalid")
|
|
||||||
TrialCheckpoint("/remote/node/dir", None)
|
|
||||||
|
|
||||||
def ensureCheckpointFile(self):
|
|
||||||
with open(os.path.join(self.local_dir, "checkpoint.txt"), "wt") as f:
|
|
||||||
f.write("checkpoint\n")
|
|
||||||
|
|
||||||
def testDownloadNoDefaults(self):
|
|
||||||
# Case: Nothing is passed
|
|
||||||
checkpoint = TrialCheckpoint()
|
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
checkpoint.download()
|
|
||||||
|
|
||||||
# Case: Local dir is passed
|
|
||||||
checkpoint = TrialCheckpoint()
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
|
||||||
checkpoint.download(local_path=self.local_dir)
|
|
||||||
|
|
||||||
# Case: Cloud dir is passed
|
|
||||||
checkpoint = TrialCheckpoint()
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
|
||||||
checkpoint.download(cloud_path=self.cloud_dir)
|
|
||||||
|
|
||||||
# Case: Both are passed
|
|
||||||
checkpoint = TrialCheckpoint()
|
|
||||||
path = checkpoint.download(local_path=self.local_dir, cloud_path=self.cloud_dir)
|
|
||||||
|
|
||||||
self.assertEqual(self.local_dir, path)
|
|
||||||
|
|
||||||
def testDownloadDefaultLocal(self):
|
|
||||||
other_local_dir = "/tmp/invalid"
|
|
||||||
|
|
||||||
# Case: Nothing is passed
|
|
||||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
|
||||||
checkpoint.download()
|
|
||||||
|
|
||||||
# Case: Local dir is passed
|
|
||||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
|
||||||
checkpoint.download(local_path=other_local_dir)
|
|
||||||
|
|
||||||
# Case: Cloud dir is passed
|
|
||||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
|
||||||
path = checkpoint.download(cloud_path=self.cloud_dir)
|
|
||||||
|
|
||||||
self.assertEqual(self.local_dir, path)
|
|
||||||
|
|
||||||
# Case: Both are passed
|
|
||||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
|
||||||
path = checkpoint.download(
|
|
||||||
local_path=other_local_dir, cloud_path=self.cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(other_local_dir, path)
|
|
||||||
|
|
||||||
def testDownloadDefaultCloud(self):
|
|
||||||
other_cloud_dir = "memory:///other"
|
|
||||||
|
|
||||||
# Case: Nothing is passed
|
|
||||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
|
||||||
checkpoint.download()
|
|
||||||
|
|
||||||
# Case: Local dir is passed
|
|
||||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
|
||||||
path = checkpoint.download(local_path=self.local_dir)
|
|
||||||
|
|
||||||
self.assertEqual(self.local_dir, path)
|
|
||||||
|
|
||||||
# Case: Cloud dir is passed
|
|
||||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
|
||||||
checkpoint.download(cloud_path=other_cloud_dir)
|
|
||||||
|
|
||||||
# Case: Both are passed
|
|
||||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
|
||||||
path = checkpoint.download(
|
|
||||||
local_path=self.local_dir, cloud_path=other_cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(self.local_dir, path)
|
|
||||||
|
|
||||||
def testDownloadDefaultBoth(self):
|
|
||||||
other_local_dir = "/tmp/other"
|
|
||||||
other_cloud_dir = "memory:///other"
|
|
||||||
|
|
||||||
self._save_checkpoint_at(other_cloud_dir)
|
|
||||||
self._save_checkpoint_at(self.cloud_dir)
|
|
||||||
|
|
||||||
# Case: Nothing is passed
|
|
||||||
checkpoint = TrialCheckpoint(
|
|
||||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
path = checkpoint.download()
|
|
||||||
|
|
||||||
self.assertEqual(self.local_dir, path)
|
|
||||||
|
|
||||||
# Case: Local dir is passed
|
|
||||||
checkpoint = TrialCheckpoint(
|
|
||||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
path = checkpoint.download(local_path=other_local_dir)
|
|
||||||
|
|
||||||
self.assertEqual(other_local_dir, path)
|
|
||||||
|
|
||||||
# Case: Both are passed
|
|
||||||
checkpoint = TrialCheckpoint(
|
|
||||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
path = checkpoint.download(
|
|
||||||
local_path=other_local_dir, cloud_path=other_cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(other_local_dir, path)
|
|
||||||
|
|
||||||
def testUploadNoDefaults(self):
|
|
||||||
# Case: Nothing is passed
|
|
||||||
checkpoint = TrialCheckpoint()
|
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
checkpoint.upload()
|
|
||||||
|
|
||||||
# Case: Local dir is passed
|
|
||||||
checkpoint = TrialCheckpoint()
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
|
||||||
checkpoint.upload(local_path=self.local_dir)
|
|
||||||
|
|
||||||
# Case: Cloud dir is passed
|
|
||||||
checkpoint = TrialCheckpoint()
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
|
||||||
checkpoint.upload(cloud_path=self.cloud_dir)
|
|
||||||
|
|
||||||
# Case: Both are passed
|
|
||||||
checkpoint = TrialCheckpoint()
|
|
||||||
path = checkpoint.upload(local_path=self.local_dir, cloud_path=self.cloud_dir)
|
|
||||||
|
|
||||||
self.assertEqual(self.cloud_dir, path)
|
|
||||||
|
|
||||||
def testUploadDefaultLocal(self):
|
|
||||||
other_local_dir = "/tmp/invalid"
|
|
||||||
|
|
||||||
# Case: Nothing is passed
|
|
||||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
|
||||||
checkpoint.upload()
|
|
||||||
|
|
||||||
# Case: Local dir is passed
|
|
||||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
|
||||||
checkpoint.upload(local_path=other_local_dir)
|
|
||||||
|
|
||||||
# Case: Cloud dir is passed
|
|
||||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
|
||||||
path = checkpoint.upload(cloud_path=self.cloud_dir)
|
|
||||||
|
|
||||||
self.assertEqual(self.cloud_dir, path)
|
|
||||||
|
|
||||||
# Case: Both are passed
|
|
||||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
|
||||||
path = checkpoint.upload(local_path=other_local_dir, cloud_path=self.cloud_dir)
|
|
||||||
|
|
||||||
self.assertEqual(self.cloud_dir, path)
|
|
||||||
|
|
||||||
def testUploadDefaultCloud(self):
|
|
||||||
other_cloud_dir = "memory:///other"
|
|
||||||
|
|
||||||
delete_at_uri(other_cloud_dir)
|
|
||||||
self._save_checkpoint_at(other_cloud_dir)
|
|
||||||
|
|
||||||
# Case: Nothing is passed
|
|
||||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
|
||||||
checkpoint.upload()
|
|
||||||
|
|
||||||
# Case: Local dir is passed
|
|
||||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
|
||||||
path = checkpoint.upload(local_path=self.local_dir)
|
|
||||||
|
|
||||||
self.assertEqual(self.cloud_dir, path)
|
|
||||||
|
|
||||||
# Case: Cloud dir is passed
|
|
||||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No local path"):
|
|
||||||
checkpoint.upload(cloud_path=other_cloud_dir)
|
|
||||||
|
|
||||||
# Case: Both are passed
|
|
||||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
|
||||||
path = checkpoint.upload(local_path=self.local_dir, cloud_path=other_cloud_dir)
|
|
||||||
|
|
||||||
self.assertEqual(other_cloud_dir, path)
|
|
||||||
|
|
||||||
def testUploadDefaultBoth(self):
|
|
||||||
other_local_dir = "/tmp/other"
|
|
||||||
other_cloud_dir = "memory:///other"
|
|
||||||
|
|
||||||
delete_at_uri(other_cloud_dir)
|
|
||||||
self._save_checkpoint_at(other_cloud_dir)
|
|
||||||
shutil.copytree(self.local_dir, other_local_dir)
|
|
||||||
|
|
||||||
# Case: Nothing is passed
|
|
||||||
checkpoint = TrialCheckpoint(
|
|
||||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
path = checkpoint.upload()
|
|
||||||
|
|
||||||
self.assertEqual(self.cloud_dir, path)
|
|
||||||
|
|
||||||
# Case: Local dir is passed
|
|
||||||
checkpoint = TrialCheckpoint(
|
|
||||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
path = checkpoint.upload(local_path=other_local_dir)
|
|
||||||
|
|
||||||
self.assertEqual(self.cloud_dir, path)
|
|
||||||
|
|
||||||
# Case: Both are passed
|
|
||||||
checkpoint = TrialCheckpoint(
|
|
||||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
path = checkpoint.upload(local_path=other_local_dir, cloud_path=other_cloud_dir)
|
|
||||||
|
|
||||||
self.assertEqual(other_cloud_dir, path)
|
|
||||||
|
|
||||||
def testSaveLocalTarget(self):
|
|
||||||
state = {}
|
|
||||||
|
|
||||||
def copytree(source, dest):
|
|
||||||
state["copy_source"] = source
|
|
||||||
state["copy_dest"] = dest
|
|
||||||
|
|
||||||
other_local_dir = "/tmp/other"
|
|
||||||
|
|
||||||
# Case: No defaults
|
|
||||||
checkpoint = TrialCheckpoint()
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
|
||||||
checkpoint.save()
|
|
||||||
|
|
||||||
# Case: Default local dir
|
|
||||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No cloud path"):
|
|
||||||
checkpoint.save()
|
|
||||||
|
|
||||||
# Case: Default cloud dir, no local dir passed
|
|
||||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No target path"):
|
|
||||||
checkpoint.save()
|
|
||||||
|
|
||||||
# Case: Default cloud dir, pass local dir
|
|
||||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
|
||||||
|
|
||||||
path = checkpoint.save(self.local_dir, force_download=True)
|
|
||||||
|
|
||||||
self.assertEqual(self.local_dir, path)
|
|
||||||
|
|
||||||
# Case: Default local dir, pass local dir
|
|
||||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
|
||||||
self.ensureCheckpointFile()
|
|
||||||
|
|
||||||
with patch("shutil.copytree", copytree):
|
|
||||||
path = checkpoint.save(other_local_dir)
|
|
||||||
|
|
||||||
self.assertEqual(other_local_dir, path)
|
|
||||||
self.assertEqual(state["copy_source"], self.local_dir)
|
|
||||||
self.assertEqual(state["copy_dest"], other_local_dir)
|
|
||||||
|
|
||||||
# Case: Both default, no pass
|
|
||||||
checkpoint = TrialCheckpoint(
|
|
||||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
path = checkpoint.save()
|
|
||||||
|
|
||||||
self.assertEqual(self.local_dir, path)
|
|
||||||
|
|
||||||
# Case: Both default, pass other local dir
|
|
||||||
checkpoint = TrialCheckpoint(
|
|
||||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("shutil.copytree", copytree):
|
|
||||||
path = checkpoint.save(other_local_dir)
|
|
||||||
|
|
||||||
self.assertEqual(other_local_dir, path)
|
|
||||||
self.assertEqual(state["copy_source"], self.local_dir)
|
|
||||||
self.assertEqual(state["copy_dest"], other_local_dir)
|
|
||||||
self.assertEqual(checkpoint.local_path, self.local_dir)
|
|
||||||
|
|
||||||
def testSaveCloudTarget(self):
|
|
||||||
other_cloud_dir = "memory:///other"
|
|
||||||
|
|
||||||
delete_at_uri(other_cloud_dir)
|
|
||||||
self._save_checkpoint_at(other_cloud_dir)
|
|
||||||
|
|
||||||
# Case: No defaults
|
|
||||||
checkpoint = TrialCheckpoint()
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "No existing local"):
|
|
||||||
checkpoint.save(self.cloud_dir)
|
|
||||||
|
|
||||||
# Case: Default local dir
|
|
||||||
# Write a checkpoint here as we assume existing local dir
|
|
||||||
with open(os.path.join(self.local_dir, "checkpoint.txt"), "wt") as f:
|
|
||||||
f.write("Checkpoint\n")
|
|
||||||
|
|
||||||
checkpoint = TrialCheckpoint(local_path=self.local_dir)
|
|
||||||
path = checkpoint.save(self.cloud_dir)
|
|
||||||
|
|
||||||
self.assertEqual(self.cloud_dir, path)
|
|
||||||
|
|
||||||
# Clean up checkpoint
|
|
||||||
os.remove(os.path.join(self.local_dir, "checkpoint.txt"))
|
|
||||||
|
|
||||||
# Case: Default cloud dir, copy to other cloud
|
|
||||||
checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
|
|
||||||
|
|
||||||
path = checkpoint.save(other_cloud_dir)
|
|
||||||
|
|
||||||
self.assertEqual(other_cloud_dir, path)
|
|
||||||
|
|
||||||
# Case: Default both, copy to other cloud
|
|
||||||
checkpoint = TrialCheckpoint(
|
|
||||||
local_path=self.local_dir, cloud_path=self.cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
path = checkpoint.save(other_cloud_dir)
|
|
||||||
|
|
||||||
self.assertEqual(other_cloud_dir, path)
|
|
||||||
|
|
||||||
|
|
||||||
def train(config, checkpoint_dir=None):
|
|
||||||
for i in range(10):
|
|
||||||
with tune.checkpoint_dir(step=0) as cd:
|
|
||||||
with open(os.path.join(cd, "checkpoint.json"), "wt") as f:
|
|
||||||
json.dump({"score": i, "train_id": config["train_id"]}, f)
|
|
||||||
tune.report(score=i)
|
|
||||||
|
|
||||||
|
|
||||||
class TrialCheckpointEndToEndTest(unittest.TestCase):
|
|
||||||
def setUp(self) -> None:
|
|
||||||
self.local_experiment_dir = tempfile.mkdtemp()
|
|
||||||
|
|
||||||
self.fake_cloud_dir = tempfile.mkdtemp()
|
|
||||||
self.cloud_target = "memory:///invalid/sub/path"
|
|
||||||
|
|
||||||
self.second_fake_cloud_dir = tempfile.mkdtemp()
|
|
||||||
self.second_cloud_target = "memory:///other/cloud"
|
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
|
||||||
shutil.rmtree(self.local_experiment_dir)
|
|
||||||
shutil.rmtree(self.fake_cloud_dir)
|
|
||||||
shutil.rmtree(self.second_fake_cloud_dir)
|
|
||||||
|
|
||||||
def _delete_at_uri(self, uri: str):
|
|
||||||
cloud_local_dir = uri.replace(self.cloud_target, self.fake_cloud_dir)
|
|
||||||
cloud_local_dir = cloud_local_dir.replace(
|
|
||||||
self.second_cloud_target, self.second_fake_cloud_dir
|
|
||||||
)
|
|
||||||
shutil.rmtree(cloud_local_dir)
|
|
||||||
|
|
||||||
def _fake_download_from_uri(self, uri: str, local_path: str):
|
|
||||||
cloud_local_dir = uri.replace(self.cloud_target, self.fake_cloud_dir)
|
|
||||||
cloud_local_dir = cloud_local_dir.replace(
|
|
||||||
self.second_cloud_target, self.second_fake_cloud_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
shutil.rmtree(local_path, ignore_errors=True)
|
|
||||||
shutil.copytree(cloud_local_dir, local_path)
|
|
||||||
|
|
||||||
def _fake_upload_to_uri(self, local_path: str, uri: str):
|
|
||||||
cloud_local_dir = uri.replace(self.cloud_target, self.fake_cloud_dir)
|
|
||||||
cloud_local_dir = cloud_local_dir.replace(
|
|
||||||
self.second_cloud_target, self.second_fake_cloud_dir
|
|
||||||
)
|
|
||||||
shutil.rmtree(cloud_local_dir, ignore_errors=True)
|
|
||||||
shutil.copytree(local_path, cloud_local_dir)
|
|
||||||
|
|
||||||
def testCheckpointDownload(self):
|
|
||||||
analysis = tune.run(
|
|
||||||
train,
|
|
||||||
config={"train_id": tune.grid_search([0, 1, 2, 3, 4])},
|
|
||||||
local_dir=self.local_experiment_dir,
|
|
||||||
verbose=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Inject the sync config (this is usually done by `tune.run()`)
|
|
||||||
analysis._sync_config = tune.SyncConfig(upload_dir=self.cloud_target)
|
|
||||||
|
|
||||||
# Pretend we have all checkpoints on cloud storage (durable)
|
|
||||||
shutil.rmtree(self.fake_cloud_dir, ignore_errors=True)
|
|
||||||
shutil.copytree(self.local_experiment_dir, self.fake_cloud_dir)
|
|
||||||
|
|
||||||
# Pretend we don't have these on local storage
|
|
||||||
shutil.rmtree(analysis.trials[1].logdir)
|
|
||||||
shutil.rmtree(analysis.trials[2].logdir)
|
|
||||||
shutil.rmtree(analysis.trials[3].logdir)
|
|
||||||
shutil.rmtree(analysis.trials[4].logdir)
|
|
||||||
|
|
||||||
cp0 = analysis.get_best_checkpoint(analysis.trials[0], "score", "max")
|
|
||||||
cp1 = analysis.get_best_checkpoint(analysis.trials[1], "score", "max")
|
|
||||||
cp2 = analysis.get_best_checkpoint(analysis.trials[2], "score", "max")
|
|
||||||
cp3 = analysis.get_best_checkpoint(analysis.trials[3], "score", "max")
|
|
||||||
cp4 = analysis.get_best_checkpoint(analysis.trials[4], "score", "max")
|
|
||||||
|
|
||||||
def _load_cp(cd):
|
|
||||||
with open(os.path.join(cd, "checkpoint.json"), "rt") as f:
|
|
||||||
return json.load(f)
|
|
||||||
|
|
||||||
with patch("ray.tune.cloud.delete_at_uri", self._delete_at_uri), patch(
|
|
||||||
"ray.tune.cloud.download_from_uri", self._fake_download_from_uri
|
|
||||||
), patch(
|
|
||||||
"ray.ml.checkpoint.download_from_uri", self._fake_download_from_uri
|
|
||||||
), patch(
|
|
||||||
"ray.tune.cloud.upload_to_uri", self._fake_upload_to_uri
|
|
||||||
):
|
|
||||||
#######
|
|
||||||
# Case: Checkpoint exists on local dir. Copy to other local dir.
|
|
||||||
other_local_dir = tempfile.mkdtemp()
|
|
||||||
|
|
||||||
cp0.save(other_local_dir)
|
|
||||||
|
|
||||||
self.assertTrue(os.path.exists(cp0.local_path))
|
|
||||||
|
|
||||||
cp_content = _load_cp(other_local_dir)
|
|
||||||
self.assertEqual(cp_content["train_id"], 0)
|
|
||||||
self.assertEqual(cp_content["score"], 9)
|
|
||||||
|
|
||||||
cp_content_2 = _load_cp(cp0.local_path)
|
|
||||||
self.assertEqual(cp_content, cp_content_2)
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
shutil.rmtree(other_local_dir)
|
|
||||||
|
|
||||||
#######
|
|
||||||
# Case: Checkpoint does not exist on local dir, download from cloud
|
|
||||||
# store in experiment dir.
|
|
||||||
|
|
||||||
# Directory is empty / does not exist before
|
|
||||||
self.assertFalse(os.path.exists(cp1.local_path))
|
|
||||||
|
|
||||||
# Save!
|
|
||||||
cp1.save()
|
|
||||||
|
|
||||||
# Directory is not empty anymore
|
|
||||||
self.assertTrue(os.listdir(cp1.local_path))
|
|
||||||
cp_content = _load_cp(cp1.local_path)
|
|
||||||
self.assertEqual(cp_content["train_id"], 1)
|
|
||||||
self.assertEqual(cp_content["score"], 9)
|
|
||||||
|
|
||||||
#######
|
|
||||||
# Case: Checkpoint does not exist on local dir, download from cloud
|
|
||||||
# store into other local dir.
|
|
||||||
|
|
||||||
# Directory is empty / does not exist before
|
|
||||||
self.assertFalse(os.path.exists(cp2.local_path))
|
|
||||||
|
|
||||||
other_local_dir = tempfile.mkdtemp()
|
|
||||||
# Save!
|
|
||||||
cp2.save(other_local_dir)
|
|
||||||
|
|
||||||
# Directory still does not exist (as we save to other dir)
|
|
||||||
self.assertFalse(os.path.exists(cp2.local_path))
|
|
||||||
cp_content = _load_cp(other_local_dir)
|
|
||||||
self.assertEqual(cp_content["train_id"], 2)
|
|
||||||
self.assertEqual(cp_content["score"], 9)
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
shutil.rmtree(other_local_dir)
|
|
||||||
|
|
||||||
#######
|
|
||||||
# Case: Checkpoint does not exist on local dir, download from cloud
|
|
||||||
# and store onto other cloud.
|
|
||||||
|
|
||||||
# Local dir does not exist
|
|
||||||
self.assertFalse(os.path.exists(cp3.local_path))
|
|
||||||
# First cloud exists
|
|
||||||
self.assertTrue(os.listdir(self.fake_cloud_dir))
|
|
||||||
# Second cloud does not exist
|
|
||||||
self.assertFalse(os.listdir(self.second_fake_cloud_dir))
|
|
||||||
|
|
||||||
# Trigger save
|
|
||||||
cp3.save(self.second_cloud_target)
|
|
||||||
|
|
||||||
# Local dir now exists
|
|
||||||
self.assertTrue(os.path.exists(cp3.local_path))
|
|
||||||
# First cloud exists
|
|
||||||
self.assertTrue(os.listdir(self.fake_cloud_dir))
|
|
||||||
# Second cloud now exists!
|
|
||||||
self.assertTrue(os.listdir(self.second_fake_cloud_dir))
|
|
||||||
|
|
||||||
cp_content = _load_cp(self.second_fake_cloud_dir)
|
|
||||||
self.assertEqual(cp_content["train_id"], 3)
|
|
||||||
self.assertEqual(cp_content["score"], 9)
|
|
||||||
|
|
||||||
#######
|
|
||||||
# Case: Checkpoint does not exist on local dir, download from cloud
|
|
||||||
# store into local dir. Use new checkpoint abstractions for this.
|
|
||||||
|
|
||||||
temp_dir = cp4.to_directory(tempfile.mkdtemp())
|
|
||||||
cp_content = _load_cp(temp_dir)
|
|
||||||
self.assertEqual(cp_content["train_id"], 4)
|
|
||||||
self.assertEqual(cp_content["score"], 9)
|
|
||||||
|
|
||||||
shutil.rmtree(temp_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
sys.exit(pytest.main(["-v", __file__]))
|
|
|
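
With `TrialCheckpoint.download()` / `.upload()` removed along with the test file above, checkpoints returned by `ExperimentAnalysis` are handled through the generic checkpoint object and its `to_directory()` method, as the deleted end-to-end test already did. A hedged sketch of that replacement pattern; `analysis` is assumed to be an existing `tune.run()` result:

```python
# Fetch the best checkpoint of a trial and materialize it locally.
best_checkpoint = analysis.get_best_checkpoint(analysis.trials[0], "score", "max")

# to_directory() writes the checkpoint contents to a local directory,
# downloading from cloud storage first if necessary.
local_dir = best_checkpoint.to_directory("/tmp/best_checkpoint")
```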
@@ -309,7 +309,7 @@ class ExperimentAnalysisPropertySuite(unittest.TestCase):
         self.assertEqual(ea.best_trial, trials[2])
         self.assertEqual(ea.best_config, trials[2].config)
         self.assertEqual(ea.best_logdir, trials[2].logdir)
-        self.assertEqual(ea.best_checkpoint.local_path, trials[2].checkpoint.value)
+        self.assertEqual(ea.best_checkpoint._local_path, trials[2].checkpoint.value)
         self.assertTrue(all(ea.best_dataframe["trial_id"] == trials[2].trial_id))
         self.assertEqual(ea.results_df.loc[trials[2].trial_id, "res"], 309)
         self.assertEqual(ea.best_result["res"], 309)
@ -100,117 +100,6 @@ class WandbIntegrationTest(unittest.TestCase):
|
||||||
if WANDB_ENV_VAR in os.environ:
|
if WANDB_ENV_VAR in os.environ:
|
||||||
del os.environ[WANDB_ENV_VAR]
|
del os.environ[WANDB_ENV_VAR]
|
||||||
|
|
||||||
def testWandbLegacyLoggerConfig(self):
|
|
||||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
|
||||||
trial = Trial(
|
|
||||||
trial_config, 0, "trial_0", "trainable", PlacementGroupFactory([{"CPU": 1}])
|
|
||||||
)
|
|
||||||
|
|
||||||
if WANDB_ENV_VAR in os.environ:
|
|
||||||
del os.environ[WANDB_ENV_VAR]
|
|
||||||
|
|
||||||
# Needs at least a project
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
|
||||||
|
|
||||||
# No API key
|
|
||||||
trial_config["wandb"] = {"project": "test_project"}
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
|
||||||
|
|
||||||
# API Key in config
|
|
||||||
trial_config["wandb"] = {"project": "test_project", "api_key": "1234"}
|
|
||||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
|
||||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")
|
|
||||||
|
|
||||||
logger.close()
|
|
||||||
del os.environ[WANDB_ENV_VAR]
|
|
||||||
|
|
||||||
# API Key file
|
|
||||||
with tempfile.NamedTemporaryFile("wt") as fp:
|
|
||||||
fp.write("5678")
|
|
||||||
fp.flush()
|
|
||||||
|
|
||||||
trial_config["wandb"] = {"project": "test_project", "api_key_file": fp.name}
|
|
||||||
|
|
||||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
|
||||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")
|
|
||||||
|
|
||||||
logger.close()
|
|
||||||
del os.environ[WANDB_ENV_VAR]
|
|
||||||
|
|
||||||
# API Key in env
|
|
||||||
os.environ[WANDB_ENV_VAR] = "9012"
|
|
||||||
trial_config["wandb"] = {"project": "test_project"}
|
|
||||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
|
||||||
logger.close()
|
|
||||||
|
|
||||||
# From now on, the API key is in the env variable.
|
|
||||||
|
|
||||||
# Default configuration
|
|
||||||
trial_config["wandb"] = {"project": "test_project"}
|
|
||||||
|
|
||||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
|
||||||
self.assertEqual(logger.trial_process.kwargs["project"], "test_project")
|
|
||||||
self.assertEqual(logger.trial_process.kwargs["id"], trial.trial_id)
|
|
||||||
self.assertEqual(logger.trial_process.kwargs["name"], trial.trial_name)
|
|
||||||
self.assertEqual(logger.trial_process.kwargs["group"], trial.trainable_name)
|
|
||||||
self.assertIn("config", logger.trial_process._exclude)
|
|
||||||
|
|
||||||
logger.close()
|
|
||||||
|
|
||||||
# log config.
|
|
||||||
trial_config["wandb"] = {"project": "test_project", "log_config": True}
|
|
||||||
|
|
||||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
|
||||||
self.assertNotIn("config", logger.trial_process._exclude)
|
|
||||||
self.assertNotIn("metric", logger.trial_process._exclude)
|
|
||||||
|
|
||||||
logger.close()
|
|
||||||
|
|
||||||
# Exclude metric.
|
|
||||||
trial_config["wandb"] = {"project": "test_project", "excludes": ["metric"]}
|
|
||||||
|
|
||||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
|
||||||
self.assertIn("config", logger.trial_process._exclude)
|
|
||||||
self.assertIn("metric", logger.trial_process._exclude)
|
|
||||||
|
|
||||||
logger.close()
|
|
||||||
|
|
||||||
def testWandbLegacyLoggerReporting(self):
|
|
||||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
|
||||||
trial = Trial(
|
|
||||||
trial_config, 0, "trial_0", "trainable", PlacementGroupFactory([{"CPU": 1}])
|
|
||||||
)
|
|
||||||
|
|
||||||
trial_config["wandb"] = {
|
|
||||||
"project": "test_project",
|
|
||||||
"api_key": "1234",
|
|
||||||
"excludes": ["metric2"],
|
|
||||||
}
|
|
||||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
|
||||||
|
|
||||||
r1 = {
|
|
||||||
"metric1": 0.8,
|
|
||||||
"metric2": 1.4,
|
|
||||||
"metric3": np.asarray(32.0),
|
|
||||||
"metric4": np.float32(32.0),
|
|
||||||
"const": "text",
|
|
||||||
"config": trial_config,
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.on_result(r1)
|
|
||||||
|
|
||||||
logged = logger.trial_process.logs.get(timeout=10)
|
|
||||||
self.assertIn("metric1", logged)
|
|
||||||
self.assertNotIn("metric2", logged)
|
|
||||||
self.assertIn("metric3", logged)
|
|
||||||
self.assertIn("metric4", logged)
|
|
||||||
self.assertNotIn("const", logged)
|
|
||||||
self.assertNotIn("config", logged)
|
|
||||||
|
|
||||||
logger.close()
|
|
||||||
|
|
||||||
def testWandbLoggerConfig(self):
|
def testWandbLoggerConfig(self):
|
||||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
trial_config = {"par1": 4, "par2": 9.12345678}
|
||||||
trial = Trial(
|
trial = Trial(
|
||||||
|
|
|
@@ -41,37 +41,26 @@ class TestSyncFunctionality(unittest.TestCase):
         _register_all()  # re-register the evicted objects
 
     def testSyncConfigDeprecation(self):
-        with self.assertWarnsRegex(DeprecationWarning, expected_regex="sync_period"):
-            sync_conf = tune.SyncConfig(node_sync_period=4, cloud_sync_period=8)
-        self.assertEqual(sync_conf.sync_period, 4)
+        with self.assertRaisesRegex(DeprecationWarning, expected_regex="sync_period"):
+            tune.SyncConfig(node_sync_period=4, cloud_sync_period=8)
 
-        with self.assertWarnsRegex(DeprecationWarning, expected_regex="sync_period"):
-            sync_conf = tune.SyncConfig(node_sync_period=4)
-        self.assertEqual(sync_conf.sync_period, 4)
+        with self.assertRaisesRegex(DeprecationWarning, expected_regex="sync_period"):
+            tune.SyncConfig(node_sync_period=4)
 
-        with self.assertWarnsRegex(DeprecationWarning, expected_regex="sync_period"):
-            sync_conf = tune.SyncConfig(cloud_sync_period=8)
-        self.assertEqual(sync_conf.sync_period, 8)
+        with self.assertRaisesRegex(DeprecationWarning, expected_regex="sync_period"):
+            tune.SyncConfig(cloud_sync_period=8)
 
-        with self.assertWarnsRegex(DeprecationWarning, expected_regex="syncer"):
-            sync_conf = tune.SyncConfig(
-                sync_to_driver="a", sync_to_cloud="b", upload_dir=None
-            )
-        self.assertEqual(sync_conf.syncer, "a")
+        with self.assertRaisesRegex(DeprecationWarning, expected_regex="syncer"):
+            tune.SyncConfig(sync_to_driver="a", sync_to_cloud="b", upload_dir=None)
 
-        with self.assertWarnsRegex(DeprecationWarning, expected_regex="syncer"):
-            sync_conf = tune.SyncConfig(
-                sync_to_driver="a", sync_to_cloud="b", upload_dir="c"
-            )
-        self.assertEqual(sync_conf.syncer, "b")
+        with self.assertRaisesRegex(DeprecationWarning, expected_regex="syncer"):
+            tune.SyncConfig(sync_to_driver="a", sync_to_cloud="b", upload_dir="c")
 
-        with self.assertWarnsRegex(DeprecationWarning, expected_regex="syncer"):
-            sync_conf = tune.SyncConfig(sync_to_cloud="b", upload_dir=None)
-        self.assertEqual(sync_conf.syncer, None)
+        with self.assertRaisesRegex(DeprecationWarning, expected_regex="syncer"):
+            tune.SyncConfig(sync_to_cloud="b", upload_dir=None)
 
-        with self.assertWarnsRegex(DeprecationWarning, expected_regex="syncer"):
-            sync_conf = tune.SyncConfig(sync_to_driver="a", upload_dir="c")
-        self.assertEqual(sync_conf.syncer, None)
+        with self.assertRaisesRegex(DeprecationWarning, expected_regex="syncer"):
+            tune.SyncConfig(sync_to_driver="a", upload_dir="c")
 
     @patch("ray.tune.sync_client.S3_PREFIX", "test")
     def testCloudProperString(self):
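
Constructing `SyncConfig` with the legacy arguments now raises instead of merely warning. A hedged sketch of the non-deprecated fields used elsewhere in this diff; the bucket URI is a placeholder and `syncer="auto"` is the assumed default:

```python
from ray import tune

sync_config = tune.SyncConfig(
    upload_dir="s3://my-bucket/tune-results",  # remote storage target
    syncer="auto",                             # or a custom syncer / sync function
    sync_period=300,                           # replaces node_sync_period / cloud_sync_period
)

# tune.run(..., sync_config=sync_config)
```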
@@ -159,7 +148,7 @@ class TestSyncFunctionality(unittest.TestCase):
         tmpdir2 = tempfile.mkdtemp()
         os.mkdir(os.path.join(tmpdir2, "foo"))
 
-        def sync_func(local, remote):
+        def sync_func(local, remote, exclude=None):
             for filename in glob.glob(os.path.join(local, "*.json")):
                 shutil.copy(filename, remote)
 
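
Custom sync functions in these tests now accept an optional `exclude` argument (patterns to skip). A hedged sketch of a function-based syncer wired into `SyncConfig`; the exact semantics of `exclude` (glob patterns) are an assumption here:

```python
import fnmatch
import os
import shutil

from ray import tune


def my_sync_func(source, target, exclude=None):
    exclude = exclude or []
    os.makedirs(target, exist_ok=True)
    for name in os.listdir(source):
        path = os.path.join(source, name)
        if not os.path.isfile(path):
            continue
        if any(fnmatch.fnmatch(name, pattern) for pattern in exclude):
            continue  # skip excluded files
        shutil.copy(path, target)


sync_config = tune.SyncConfig(syncer=my_sync_func)
```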
@@ -187,7 +176,7 @@ class TestSyncFunctionality(unittest.TestCase):
             time.sleep(1)
             tune.report(score=i)
 
-        def counter(local, remote):
+        def counter(local, remote, exclude=None):
            count_file = os.path.join(tmpdir, "count.txt")
            if not os.path.exists(count_file):
                count = 0
@@ -219,7 +208,7 @@ class TestSyncFunctionality(unittest.TestCase):
         shutil.rmtree(tmpdir)
 
     def testClusterSyncFunction(self):
-        def sync_func_driver(source, target):
+        def sync_func_driver(source, target, exclude=None):
            assert ":" in source, "Source {} not a remote path.".format(source)
            assert ":" not in target, "Target is supposed to be local."
            with open(os.path.join(target, "test.log2"), "w") as f:
@@ -255,7 +244,7 @@ class TestSyncFunctionality(unittest.TestCase):
     def testNoSync(self):
         """Sync should not run on a single node."""
 
-        def sync_func(source, target):
+        def sync_func(source, target, exclude=None):
             pass
 
         sync_config = tune.SyncConfig(syncer=sync_func)
@@ -409,7 +398,7 @@ class TestSyncFunctionality(unittest.TestCase):
         sync_config = tune.SyncConfig(syncer=None)
 
         # Create syncer callbacks
-        callbacks = create_default_callbacks([], sync_config, loggers=None)
+        callbacks = create_default_callbacks([], sync_config)
        syncer_callback = callbacks[-1]
 
         # Sanity check that we got the syncer callback
@ -1,59 +1,55 @@
|
||||||
import warnings
|
|
||||||
|
|
||||||
from mock import patch
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
class TestTrialExecutorInheritance(unittest.TestCase):
|
class TestTrialExecutorInheritance(unittest.TestCase):
|
||||||
@patch.object(warnings, "warn")
|
def test_direct_inheritance_not_ok(self):
|
||||||
def test_direct_inheritance_not_ok(self, mocked_warn):
|
|
||||||
|
|
||||||
from ray.tune.trial_executor import TrialExecutor
|
from ray.tune.trial_executor import TrialExecutor
|
||||||
|
|
||||||
class _MyTrialExecutor(TrialExecutor):
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def start_trial(self, trial):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def stop_trial(self, trial):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def restore(self, trial):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def save(self, trial):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def reset_trial(self, trial, new_config, new_experiment_tag):
|
|
||||||
return False
|
|
||||||
|
|
||||||
def debug_string(self):
|
|
||||||
return "This is a debug string."
|
|
||||||
|
|
||||||
def export_trial_if_needed(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def fetch_result(self):
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_next_available_trial(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_running_trials(self):
|
|
||||||
return []
|
|
||||||
|
|
||||||
msg = (
|
msg = (
|
||||||
"_MyTrialExecutor inherits from TrialExecutor, which is being "
|
"_MyTrialExecutor inherits from TrialExecutor, which is being "
|
||||||
"deprecated. "
|
"deprecated. "
|
||||||
"RFC: https://github.com/ray-project/ray/issues/17593. "
|
"RFC: https://github.com/ray-project/ray/issues/17593. "
|
||||||
"Please reach out on the Ray Github if you have any concerns."
|
"Please reach out on the Ray Github if you have any concerns."
|
||||||
)
|
)
|
||||||
mocked_warn.assert_called_once_with(msg, DeprecationWarning)
|
|
||||||
|
|
||||||
@patch.object(warnings, "warn")
|
with self.assertRaisesRegex(DeprecationWarning, msg):
|
||||||
def test_indirect_inheritance_ok(self, mocked_warn):
|
|
||||||
|
class _MyTrialExecutor(TrialExecutor):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def start_trial(self, trial):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def stop_trial(self, trial):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def restore(self, trial):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save(self, trial):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def reset_trial(self, trial, new_config, new_experiment_tag):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def debug_string(self):
|
||||||
|
return "This is a debug string."
|
||||||
|
|
||||||
|
def export_trial_if_needed(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def fetch_result(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_next_available_trial(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_running_trials(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def test_indirect_inheritance_ok(self):
|
||||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||||
|
|
||||||
class _MyRayTrialExecutor(RayTrialExecutor):
|
class _MyRayTrialExecutor(RayTrialExecutor):
|
||||||
|
@ -61,5 +57,3 @@ class TestTrialExecutorInheritance(unittest.TestCase):
|
||||||
|
|
||||||
class _AnotherMyRayTrialExecutor(_MyRayTrialExecutor):
|
class _AnotherMyRayTrialExecutor(_MyRayTrialExecutor):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
mocked_warn.assert_not_called()
|
|
||||||
|
|
|
@@ -753,7 +753,7 @@ class TrialRunnerTest3(unittest.TestCase):
         ray.init(num_cpus=3)
 
         # This makes checkpointing take 2 seconds.
-        def sync_up(source, target):
+        def sync_up(source, target, exclude=None):
             time.sleep(2)
             return True
 
@@ -50,7 +50,6 @@ from ray.tune.utils.util import (
     get_checkpoint_from_remote_node,
     delete_external_checkpoint,
 )
-from ray.util.debug import log_once
 from ray.util.annotations import PublicAPI
 
 logger = logging.getLogger(__name__)
@@ -912,11 +911,6 @@ class Trainable:
             A dict that describes training progress.
 
         """
-        if self._implements_method("_train") and log_once("_train"):
-            raise DeprecationWarning(
-                "Trainable._train is deprecated and is now removed. Override "
-                "Trainable.step instead."
-            )
         raise NotImplementedError
 
     def save_checkpoint(self, tmp_checkpoint_dir: str):
@@ -957,11 +951,6 @@ class Trainable:
             >>> trainable.save_checkpoint("/tmp/bad_example") # doctest: +SKIP
             "/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
         """
-        if self._implements_method("_save") and log_once("_save"):
-            raise DeprecationWarning(
-                "Trainable._save is deprecated and is now removed. Override "
-                "Trainable.save_checkpoint instead."
-            )
         raise NotImplementedError
 
     def load_checkpoint(self, checkpoint: Union[Dict, str]):
@@ -1007,11 +996,6 @@ class Trainable:
                 returned by `save_checkpoint`. The directory structure
                 underneath the `checkpoint_dir` `save_checkpoint` is preserved.
         """
-        if self._implements_method("_restore") and log_once("_restore"):
-            raise DeprecationWarning(
-                "Trainable._restore is deprecated and is now removed. "
-                "Override Trainable.load_checkpoint instead."
-            )
         raise NotImplementedError
 
     def setup(self, config: Dict):
@@ -1024,11 +1008,6 @@ class Trainable:
             Copy of `self.config`.
 
         """
-        if self._implements_method("_setup") and log_once("_setup"):
-            raise DeprecationWarning(
-                "Trainable._setup is deprecated and is now removed. Override "
-                "Trainable.setup instead."
-            )
         pass
 
     def log_result(self, result: Dict):
@@ -1043,11 +1022,6 @@ class Trainable:
         Args:
             result: Training result returned by step().
         """
-        if self._implements_method("_log_result") and log_once("_log_result"):
-            raise DeprecationWarning(
-                "Trainable._log_result is deprecated and is now removed. "
-                "Override Trainable.log_result instead."
-            )
         self._result_logger.on_result(result)
 
     def cleanup(self):
@@ -1061,11 +1035,6 @@ class Trainable:
 
         .. versionadded:: 0.8.7
         """
-        if self._implements_method("_stop") and log_once("_stop"):
-            raise DeprecationWarning(
-                "Trainable._stop is deprecated and is now removed. Override "
-                "Trainable.cleanup instead."
-            )
         pass
 
     def _export_model(self, export_formats: List[str], export_dir: str):
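
With the `_train` / `_save` / `_restore` / `_setup` / `_log_result` / `_stop` guards removed above, only the renamed hooks remain on `Trainable`. A minimal sketch of the current class-based API (the checkpoint file name is just an example):

```python
import os

from ray import tune


class MyTrainable(tune.Trainable):
    def setup(self, config):                         # was _setup
        self.x = config.get("start", 0)

    def step(self):                                  # was _train
        self.x += 1
        return {"score": self.x}

    def save_checkpoint(self, tmp_checkpoint_dir):   # was _save
        with open(os.path.join(tmp_checkpoint_dir, "x.txt"), "w") as f:
            f.write(str(self.x))
        return tmp_checkpoint_dir

    def load_checkpoint(self, checkpoint):           # was _restore
        with open(os.path.join(checkpoint, "x.txt")) as f:
            self.x = int(f.read())

    def cleanup(self):                               # was _stop
        pass
```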
@@ -202,9 +202,8 @@ class Trial:
     Trials start in the PENDING state, and transition to RUNNING once started.
     On error it transitions to ERROR, otherwise TERMINATED on success.
 
-    There are resources allocated to each trial. It's preferred that resources
-    are specified using PlacementGroupFactory, rather than through Resources,
-    which is being deprecated.
+    There are resources allocated to each trial. These should be specified
+    using ``PlacementGroupFactory``.
 
     Attributes:
         trainable_name: Name of the trainable object to be executed.
@@ -771,17 +770,7 @@ class Trial:
         if self.custom_dirname:
             generated_dirname = self.custom_dirname
         else:
-            if "MAX_LEN_IDENTIFIER" in os.environ:
-                logger.error(
-                    "The MAX_LEN_IDENTIFIER environment variable is "
-                    "deprecated and will be removed in the future. "
-                    "Use TUNE_MAX_LEN_IDENTIFIER instead."
-                )
-            MAX_LEN_IDENTIFIER = int(
-                os.environ.get(
-                    "TUNE_MAX_LEN_IDENTIFIER", os.environ.get("MAX_LEN_IDENTIFIER", 130)
-                )
-            )
+            MAX_LEN_IDENTIFIER = int(os.environ.get("TUNE_MAX_LEN_IDENTIFIER", "130"))
             generated_dirname = f"{str(self)}_{self.experiment_tag}"
             generated_dirname = generated_dirname[:MAX_LEN_IDENTIFIER]
             generated_dirname += f"_{date_str()}"
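
After this hunk only the `TUNE_`-prefixed variable is consulted; the old `MAX_LEN_IDENTIFIER` spelling is ignored. A one-line sketch:

```python
import os

os.environ["TUNE_MAX_LEN_IDENTIFIER"] = "80"  # cap generated trial directory names at 80 chars
```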
@@ -2,7 +2,6 @@
 from abc import abstractmethod
 import logging
 from typing import Dict, List, Optional, Union
-import warnings
 
 from ray.exceptions import RayTaskError
 from ray.tune import TuneError
@@ -25,13 +24,13 @@ class _WarnOnDirectInheritanceMeta(type):
             )
             and "TrialExecutor" in tuple(base.__name__ for base in bases)
         ):
-            deprecation_msg = (
+            raise DeprecationWarning(
                 f"{name} inherits from TrialExecutor, which is being "
                 "deprecated. "
                 "RFC: https://github.com/ray-project/ray/issues/17593. "
                 "Please reach out on the Ray Github if you have any concerns."
             )
-            warnings.warn(deprecation_msg, DeprecationWarning)
         cls = super().__new__(mcls, name, bases, module, **kwargs)
         return cls
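
Because the metaclass now raises, a direct `TrialExecutor` subclass fails at class-definition time, while going through `RayTrialExecutor` stays supported (this is also what the updated inheritance test asserts). A sketch:

```python
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.trial_executor import TrialExecutor

try:
    class MyExecutor(TrialExecutor):  # rejected by the metaclass
        pass
except DeprecationWarning as exc:
    print(exc)


class MyRayExecutor(RayTrialExecutor):  # indirect inheritance is still fine
    pass
```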
@@ -314,13 +314,6 @@ class TrialRunner:
 
         self._metric = metric
 
-        if "TRIALRUNNER_WALLTIME_LIMIT" in os.environ:
-            raise ValueError(
-                "The TRIALRUNNER_WALLTIME_LIMIT environment variable is "
-                "deprecated. "
-                "Use `tune.run(time_budget_s=limit)` instead."
-            )
-
         self._total_time = 0
         self._iteration = 0
         self._has_errored = False
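
The `TRIALRUNNER_WALLTIME_LIMIT` check is gone entirely; the supported way to bound an experiment's wall time is the `time_budget_s` argument. A sketch with a trivial placeholder trainable:

```python
from ray import tune


def trainable(config):
    tune.report(score=1)


# Stop the whole run once it has used one hour of wall time.
tune.run(trainable, time_budget_s=3600)
```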
@@ -162,7 +162,6 @@ def run(
     # == internal only ==
     _experiment_checkpoint_dir: Optional[str] = None,
     # Deprecated args
-    queue_trials: Optional[bool] = None,
     loggers: Optional[Sequence[Type[Logger]]] = None,
     _remote: Optional[bool] = None,
 ) -> ExperimentAnalysis:
@@ -349,27 +348,6 @@ def run(
     Raises:
         TuneError: Any trials failed and `raise_on_failed_trial` is True.
     """
-
-    # To be removed in 1.9.
-    if queue_trials is not None:
-        raise DeprecationWarning(
-            "`queue_trials` has been deprecated and is replaced by "
-            "the `TUNE_MAX_PENDING_TRIALS_PG` environment variable. "
-            "Per default at least one Trial is queued at all times, "
-            "so you likely don't need to change anything other than "
-            "removing this argument from your call to `tune.run()`"
-        )
-
-    # Starting deprecation in ray 1.10.
-    if os.environ.get("TUNE_TRIAL_RESULT_WAIT_TIME_S") is not None:
-        warnings.warn("`TUNE_TRIAL_RESULT_WAIT_TIME_S` is deprecated.")
-
-    if os.environ.get("TUNE_TRIAL_STARTUP_GRACE_PERIOD") is not None:
-        warnings.warn("`TUNE_TRIAL_STARTUP_GRACE_PERIOD` is deprecated.")
-
-    if os.environ.get("TUNE_PLACEMENT_GROUP_WAIT_S") is not None:
-        warnings.warn("`TUNE_PLACEMENT_GROUP_WAIT_S` is deprecated.")
-
     # NO CODE IS TO BE ADDED ABOVE THIS COMMENT
     # remote_run_kwargs must be defined before any other
     # code is ran to ensure that at this point,
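
`queue_trials` is no longer accepted by `tune.run()`; per the removed message, the number of pending trials is controlled through an environment variable instead. A sketch:

```python
import os

# Allow up to 8 trials to be queued while waiting for resources
# (consulted when tune.run() schedules trials).
os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "8"
```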
@@ -439,8 +417,8 @@ def run(
     all_start = time.time()
 
     if loggers:
-        # Raise DeprecationWarning in 1.9, remove in 1.10/1.11
-        warnings.warn(
+        # Deprecated: Remove in Ray > 1.13
+        raise DeprecationWarning(
             "The `loggers` argument is deprecated. Please pass the respective "
             "`LoggerCallback` classes to the `callbacks` argument instead. "
             "See https://docs.ray.io/en/latest/tune/api_docs/logging.html"
@@ -642,9 +620,7 @@ def run(
     )
 
     # Create syncer callbacks
-    callbacks = create_default_callbacks(
-        callbacks, sync_config, metric=metric, loggers=loggers
-    )
+    callbacks = create_default_callbacks(callbacks, sync_config, metric=metric)
 
     runner = TrialRunner(
         search_alg=search_alg,
@@ -803,7 +779,6 @@ def run_experiments(
     raise_on_failed_trial: bool = True,
     concurrent: bool = True,
     # Deprecated args.
-    queue_trials: Optional[bool] = None,
     callbacks: Optional[Sequence[Callback]] = None,
     _remote: Optional[bool] = None,
 ):
@@ -822,16 +797,6 @@ def run_experiments(
         List of Trial objects, holding data for each executed trial.
 
     """
-    # To be removed in 1.9.
-    if queue_trials is not None:
-        raise DeprecationWarning(
-            "`queue_trials` has been deprecated and is replaced by "
-            "the `TUNE_MAX_PENDING_TRIALS_PG` environment variable. "
-            "Per default at least one Trial is queued at all times, "
-            "so you likely don't need to change anything other than "
-            "removing this argument from your call to `tune.run()`"
-        )
 
     if _remote is None:
         _remote = ray.util.client.ray.is_connected()
@@ -2,9 +2,7 @@ from ray.tune.utils.util import (
     deep_update,
     date_str,
     flatten_dict,
-    get_pinned_object,
     merge_dicts,
-    pin_in_object_store,
     unflattened_lookup,
     UtilMonitor,
     validate_save_restore,
@@ -20,9 +18,7 @@ __all__ = [
     "deep_update",
     "date_str",
     "flatten_dict",
-    "get_pinned_object",
     "merge_dicts",
-    "pin_in_object_store",
     "unflattened_lookup",
     "UtilMonitor",
     "validate_save_restore",
@@ -9,11 +9,9 @@ from ray.tune.syncer import SyncConfig, detect_cluster_syncer
 from ray.tune.logger import (
     CSVLoggerCallback,
     CSVLogger,
-    LoggerCallback,
     JsonLoggerCallback,
     JsonLogger,
     LegacyLoggerCallback,
-    Logger,
     TBXLoggerCallback,
     TBXLogger,
 )
@@ -25,7 +23,6 @@ logger = logging.getLogger(__name__)
 def create_default_callbacks(
     callbacks: Optional[List[Callback]],
     sync_config: SyncConfig,
-    loggers: Optional[List[Logger]],
     metric: Optional[str] = None,
 ):
     """Create default callbacks for `tune.run()`.
@@ -69,23 +66,6 @@ def create_default_callbacks(
     last_logger_index = None
     syncer_index = None
 
-    # Deprecate: 1.9
-    # Create LegacyLoggerCallback for passed Logger classes
-    if loggers:
-        add_loggers = []
-        for trial_logger in loggers:
-            if isinstance(trial_logger, LoggerCallback):
-                callbacks.append(trial_logger)
-            elif isinstance(trial_logger, type) and issubclass(trial_logger, Logger):
-                add_loggers.append(trial_logger)
-            else:
-                raise ValueError(
-                    f"Invalid value passed to `loggers` argument of "
-                    f"`tune.run()`: {trial_logger}"
-                )
-        if add_loggers:
-            callbacks.append(LegacyLoggerCallback(add_loggers))
-
     # Check if we have a CSV, JSON and TensorboardX logger
     for i, callback in enumerate(callbacks):
         if isinstance(callback, LegacyLoggerCallback):
@@ -147,12 +127,7 @@ def create_default_callbacks(
         and last_logger_index is not None
         and syncer_index < last_logger_index
     ):
-        if (
-            not has_csv_logger or not has_json_logger or not has_tbx_logger
-        ) and not loggers:
-            # Only raise the warning if the loggers were passed by the user.
-            # (I.e. don't warn if this was automatic behavior and they only
-            # passed a customer SyncerCallback).
+        if not has_csv_logger or not has_json_logger or not has_tbx_logger:
             raise ValueError(
                 "The `SyncerCallback` you passed to `tune.run()` came before "
                 "at least one `LoggerCallback`. Syncing should be done "
@@ -243,16 +243,6 @@ def resource_dict_to_pg_factory(spec: Optional[Dict[str, float]]):
         spec = spec._asdict()
 
     spec = spec.copy()
-    extra_custom = spec.pop("extra_custom_resources", {}) or {}
-
-    if any(k.startswith("extra_") and spec[k] for k in spec) or any(
-        extra_custom[k] for k in extra_custom
-    ):
-        raise ValueError(
-            "Passing `extra_*` resource requirements to `resources_per_trial` "
-            "is deprecated. Please use a `PlacementGroupFactory` object "
-            "to define your resource requirements instead."
-        )
-
     cpus = spec.pop("cpu", 0.0)
     gpus = spec.pop("gpu", 0.0)
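
`extra_*` keys in `resources_per_trial` dictionaries are no longer translated; extra bundles are expressed directly with a `PlacementGroupFactory`. A hedged sketch (the bundle contents and the trivial trainable are examples, not part of this commit):

```python
from ray import tune
from ray.tune.utils.placement_groups import PlacementGroupFactory


def trainable(config):
    tune.report(score=1)


tune.run(
    trainable,
    resources_per_trial=PlacementGroupFactory(
        [
            {"CPU": 1, "GPU": 0},  # bundle for the trainable itself
            {"CPU": 2},            # extra bundle, e.g. for workers it spawns
        ]
    ),
)
```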
@@ -121,20 +121,6 @@ class UtilMonitor(Thread):
         self.stopped = True
 
 
-def pin_in_object_store(obj):
-    """Deprecated, use ray.put(value) instead."""
-
-    obj_ref = ray.put(obj)
-    _pinned_objects.append(obj_ref)
-    return obj_ref
-
-
-def get_pinned_object(pinned_id):
-    """Deprecated."""
-
-    return ray.get(pinned_id)
-
-
 def retry_fn(
     fn: Callable[[], Any],
     exception_type: Type[Exception],
@@ -456,7 +442,6 @@ def wait_for_gpu(
         retry: Number of times to check GPU limit. Sleeps `delay_s`
             seconds between checks.
         delay_s: Seconds to wait before check.
-        gpu_memory_limit: Deprecated.
 
     Returns:
         bool: True if free.
@ -476,8 +461,7 @@ def wait_for_gpu(
|
||||||
tune.run(tune_func, resources_per_trial={"GPU": 1}, num_samples=10)
|
tune.run(tune_func, resources_per_trial={"GPU": 1}, num_samples=10)
|
||||||
"""
|
"""
|
||||||
GPUtil = _import_gputil()
|
GPUtil = _import_gputil()
|
||||||
if gpu_memory_limit:
|
|
||||||
raise ValueError("'gpu_memory_limit' is deprecated. Use 'target_util' instead.")
|
|
||||||
if GPUtil is None:
|
if GPUtil is None:
|
||||||
raise RuntimeError("GPUtil must be installed if calling `wait_for_gpu`.")
|
raise RuntimeError("GPUtil must be installed if calling `wait_for_gpu`.")
|
||||||
|
|
||||||
|
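With the `gpu_memory_limit` argument removed from both the docstring and the guard above, callers pass `target_util` instead. A small illustrative sketch (the trainable body is a placeholder, mirroring the docstring's own example):

from ray import tune
from ray.tune.utils import wait_for_gpu

def tune_func(config):
    # Block until the first visible GPU drops below ~1% utilization.
    wait_for_gpu(target_util=0.01)
    tune.report(score=1.0)  # placeholder training step

tune.run(tune_func, resources_per_trial={"GPU": 1}, num_samples=10)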
@@ -700,11 +684,3 @@ def validate_warmstart(
                 + " and points_to_evaluate {}".format(points_to_evaluate)
                 + " do not match."
             )
-
-
-if __name__ == "__main__":
-    ray.init()
-    X = pin_in_object_store("hello")
-    print(X)
-    result = get_pinned_object(X)
-    print(result)
@@ -1,76 +0,0 @@
-import pandas as pd
-from pandas.api.types import is_string_dtype, is_numeric_dtype
-import logging
-import os
-import os.path as osp
-import numpy as np
-import json
-
-from ray.tune.utils import flatten_dict
-
-logger = logging.getLogger(__name__)
-
-logger.warning("This module will be deprecated in a future version of Tune.")
-
-
-def _parse_results(res_path):
-    res_dict = {}
-    try:
-        with open(res_path) as f:
-            # Get last line in file
-            for line in f:
-                pass
-        res_dict = flatten_dict(json.loads(line.strip()))
-    except Exception:
-        logger.exception("Importing %s failed...Perhaps empty?" % res_path)
-    return res_dict
-
-
-def _parse_configs(cfg_path):
-    try:
-        with open(cfg_path) as f:
-            cfg_dict = flatten_dict(json.load(f))
-    except Exception:
-        logger.exception("Config parsing failed.")
-    return cfg_dict
-
-
-def _resolve(directory, result_fname):
-    try:
-        resultp = osp.join(directory, result_fname)
-        res_dict = _parse_results(resultp)
-        cfgp = osp.join(directory, "params.json")
-        cfg_dict = _parse_configs(cfgp)
-        cfg_dict.update(res_dict)
-        return cfg_dict
-    except Exception:
-        return None
-
-
-def load_results_to_df(directory, result_name="result.json"):
-    exp_directories = [
-        dirpath
-        for dirpath, dirs, files in os.walk(directory)
-        for f in files
-        if f == result_name
-    ]
-    data = [_resolve(d, result_name) for d in exp_directories]
-    data = [d for d in data if d]
-    return pd.DataFrame(data)
-
-
-def generate_plotly_dim_dict(df, field):
-    dim_dict = {}
-    dim_dict["label"] = field
-    column = df[field]
-    if is_numeric_dtype(column):
-        dim_dict["values"] = column
-    elif is_string_dtype(column):
-        texts = column.unique()
-        dim_dict["values"] = [np.argwhere(texts == x).flatten()[0] for x in column]
-        dim_dict["tickvals"] = list(range(len(texts)))
-        dim_dict["ticktext"] = texts
-    else:
-        raise Exception("Unidentifiable Type")
-
-    return dim_dict
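The module removed above built a pandas DataFrame out of per-trial `result.json` and `params.json` files (plus Plotly dimension dicts). There is no one-for-one replacement in this diff; a hedged sketch of the closest supported path (the results directory below is illustrative) uses `ExperimentAnalysis`:

from ray.tune import ExperimentAnalysis

# Path is illustrative; point this at an existing experiment directory.
analysis = ExperimentAnalysis("~/ray_results/my_experiment")
df = analysis.dataframe()  # one row per trial: flattened config plus latest result
print(df.head())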
@@ -15,7 +15,7 @@ import argparse
 import os

 from ray.rllib.utils.test_utils import check_learning_achieved
-from ray.tune.logger import Logger
+from ray.tune.logger import Logger, LegacyLoggerCallback

 parser = argparse.ArgumentParser()
 parser.add_argument(

@@ -119,7 +119,11 @@ if __name__ == "__main__":
     }

     results = tune.run(
-        args.run, config=config, stop=stop, verbose=2, loggers=[MyPrintLogger]
+        args.run,
+        config=config,
+        stop=stop,
+        verbose=2,
+        callbacks=[LegacyLoggerCallback(MyPrintLogger)],
     )

     if args.as_test:
@@ -29,7 +29,6 @@ import os

 import ray
 from ray import tune
-from ray.tune import function
 from ray.rllib.examples.env.windy_maze_env import WindyMazeEnv, HierarchicalWindyMazeEnv
 from ray.rllib.utils.test_utils import check_learning_achieved


@@ -107,7 +106,7 @@ if __name__ == "__main__":
                     {"gamma": 0.0},
                 ),
             },
-            "policy_mapping_fn": function(policy_mapping_fn),
+            "policy_mapping_fn": policy_mapping_fn,
         },
         "framework": args.framework,
         # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
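The `tune.function` wrapper is gone along with its import; plain Python callables now go straight into the config dict. A toy, purely illustrative sketch of the pattern (not the mapping function used by this example):

def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # Toy mapping: send every agent to a single policy.
    return "default_policy"

config = {
    "multiagent": {
        "policy_mapping_fn": policy_mapping_fn,  # no tune.function() wrapper needed
    },
}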