[AIR] Remove ML code from ray.util (#27005)

Removes all ML-related code from `ray.util`:
- `ray.util.xgboost`
- `ray.util.lightgbm`
- `ray.util.horovod`
- `ray.util.ray_lightning`

Moves `ray.util.ml_utils` to other locations.

Closes #23900

Signed-off-by: Amog Kamsetty <amogkamsetty@yahoo.com>
Signed-off-by: Kai Fricke <kai@anyscale.com>
Co-authored-by: Kai Fricke <kai@anyscale.com>

This commit is contained in: parent 6c2bfccad9, commit 862d10c162
113 changed files with 751 additions and 784 deletions
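For downstream users, the practical effect of this commit is an import-path migration: helpers that lived under `ray.util.ml_utils` move into `ray.air._internal` and `ray._private`, and the `ray.util.xgboost` / `ray.util.lightgbm` / `ray.util.horovod` / `ray.util.ray_lightning` wrappers give way to the upstream `xgboost_ray`, `lightgbm_ray`, `horovod`, and `ray_lightning` packages. A minimal before/after sketch, using only import paths that appear in the diff below (grouping them into one snippet is illustrative):

    # Old (pre-Ray-2.0) locations removed by this commit:
    #   from ray.util.ml_utils.dict import merge_dicts, flatten_dict
    #   from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint
    #   from ray.util.xgboost import RayDMatrix, RayParams, train

    # New locations used throughout the diff:
    from ray._private.dict import merge_dicts, flatten_dict
    from ray.air._internal.checkpoint_manager import _TrackedCheckpoint
    from xgboost_ray import RayDMatrix, RayParams, train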
@@ -311,15 +311,21 @@
    - TRAIN_TESTING=1 TUNE_TESTING=1 ./ci/env/install-dependencies.sh
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=tune,-gpu_only,-ray_air python/ray/train/...

- label: ":octopus: Tune/Modin/Dask tests and examples. Python 3.7"
- label: ":octopus: Tune tests and examples. Python 3.7"
  conditions: ["RAY_CI_TUNE_AFFECTED"]
  commands:
    - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
    - TUNE_TESTING=1 PYTHON=3.7 INSTALL_HOROVOD=1 ./ci/env/install-dependencies.sh
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=py37,-client python/ray/tune/...
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=-client python/ray/util/xgboost/...
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only python/ray/util/horovod/...
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only python/ray/util/ray_lightning/...

- label: ":octopus: ML library integrations tests and examples. Python 3.7"
  conditions: ["RAY_CI_TUNE_AFFECTED"]
  commands:
    - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
    - TUNE_TESTING=1 PYTHON=3.7 INSTALL_HOROVOD=1 ./ci/env/install-dependencies.sh
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only python/ray/tests/xgboost/...
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only python/ray/tests/horovod/...
    - bazel test --config=ci $(./ci/run/bazel_export_options) python/ray/tests/ray_lightning/...

# TODO(amogkam): Re-enable Ludwig tests after Ludwig supports Ray 2.0
#- label: ":octopus: Ludwig tests and examples. Python 3.7"

@@ -337,9 +343,8 @@
    - rm -rf ./python/ray/thirdparty_files; rm -rf ./python/ray/pickle5_files; ./ci/ci.sh build
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=client --test_env=RAY_CLIENT_MODE=1 python/ray/util/dask/...
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=client python/ray/tune/...
    - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=client python/ray/util/xgboost/...

- label: ":potable_water: Modin/Dask tests and examples. Python 3.7"
- label: ":potable_water: Dataset library integrations tests and examples. Python 3.7"
  conditions: ["RAY_CI_PYTHON_AFFECTED"]
  commands:
    - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
@@ -982,16 +982,7 @@
" print(\"Best trial final validation accuracy: {}\".format(\n",
" best_result.metrics[\"accuracy\"]))\n",
"\n",
" if ray.util.client.ray.is_connected():\n",
" # If using Ray Client, we want to make sure checkpoint access\n",
" # happens on the server. So we wrap `test_best_model` in a Ray task.\n",
" # We have to make sure it gets executed on the same node that\n",
" # ``tuner.fit()`` is called on.\n",
" from ray.util.ml_utils.node import force_on_current_node\n",
" remote_fn = force_on_current_node(ray.remote(test_best_model))\n",
" ray.get(remote_fn.remote(best_result))\n",
" else:\n",
" test_best_model(best_result)\n",
" test_best_model(best_result)\n",
"\n",
"main(num_samples=2, max_num_epochs=2, gpus_per_trial=0)"
]

@@ -1061,7 +1052,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
"version": "3.8.6"
},
"orphan": true
},
@@ -1163,17 +1163,7 @@
" results = tune_xgboost()\n",
"\n",
" # Load the best model checkpoint.\n",
" if args.server_address:\n",
" # If connecting to a remote server with Ray Client, checkpoint loading\n",
" # should be wrapped in a task so it will execute on the server.\n",
" # We have to make sure it gets executed on the same node that\n",
" # ``tuner.fit`` is called on.\n",
" from ray.util.ml_utils.node import force_on_current_node\n",
"\n",
" remote_fn = force_on_current_node(ray.remote(get_best_model_checkpoint))\n",
" best_bst = ray.get(remote_fn.remote(results))\n",
" else:\n",
" best_bst = get_best_model_checkpoint(results)\n",
" best_bst = get_best_model_checkpoint(results)\n",
"\n",
" # You could now do further predictions with\n",
" # best_bst.predict(...)"

@@ -1385,7 +1375,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
"version": "3.8.6"
},
"orphan": true
},
@@ -179,6 +179,14 @@ py_test(
    deps = [":ml_lib"]
)

py_test(
    name = "test_checkpoint_manager",
    size = "small",
    srcs = ["tests/test_checkpoint_manager.py"],
    tags = ["team:ml", "exclusive"],
    deps = [":ml_lib"]
)

py_test(
    name = "test_data_batch_conversion",
    size = "small",

@@ -211,6 +219,14 @@ py_test(
    deps = [":ml_lib"]
)

py_test(
    name = "test_mlflow",
    size = "medium",
    srcs = ["tests/test_mlflow.py"],
    tags = ["team:ml", "exclusive"],
    deps = [":ml_lib"]
)

py_test(
    name = "test_remote_storage",
    size = "small",
@@ -6,23 +6,19 @@ import logging
import numbers
import os
import shutil
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import ray
from ray.air import Checkpoint, CheckpointConfig
from ray.air.config import MAX
from ray.air._internal.util import is_nan
from ray.util import log_once
from ray.util.annotations import Deprecated, DeveloperAPI
from ray.util.ml_utils.util import is_nan


logger = logging.getLogger(__name__)


@DeveloperAPI
class CheckpointStorage(enum.Enum):
    MEMORY = enum.auto()
    PERSISTENT = enum.auto()

@@ -193,23 +189,6 @@ class _HeapCheckpointWrapper:
        return f"_HeapCheckpoint({repr(self.tracked_checkpoint)})"


# Alias for backwards compatibility

deprecation_message = (
    "`CheckpointStrategy` is deprecated and will be removed in "
    "the future. Please use `ray.air.config.CheckpointStrategy` "
    "instead."
)


@Deprecated(message=deprecation_message)
@dataclass
class CheckpointStrategy(CheckpointConfig):
    def __post_init__(self):
        warnings.warn(deprecation_message)
        super().__post_init__()


class _CheckpointManager:
    """Common checkpoint management and bookkeeping class for Ray Train and Tune.
@@ -5,12 +5,10 @@ from pathlib import Path

from filelock import FileLock

from ray.util.annotations import Deprecated

RAY_LOCKFILE_DIR = "_ray_lockfiles"


@Deprecated
class TempFileLock:
    """FileLock wrapper that uses temporary file locks."""

@@ -3,10 +3,7 @@ import numbers

import numpy as np

from ray.util.annotations import Deprecated


@Deprecated
class SafeFallbackEncoder(json.JSONEncoder):
    def __init__(self, nan_str="null", **kwargs):
        super(SafeFallbackEncoder, self).__init__(**kwargs)

@@ -3,7 +3,7 @@ import os
import urllib.parse
from typing import List, Optional, Tuple

from ray.util.ml_utils.filelock import TempFileLock
from ray.air._internal.filelock import TempFileLock

try:
    import fsspec

@@ -3,10 +3,7 @@ from contextlib import closing

import numpy as np

from ray.util.annotations import Deprecated


@Deprecated
def find_free_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("", 0))

@@ -14,11 +11,9 @@ def find_free_port():
        return s.getsockname()[1]


@Deprecated
def is_nan(value):
    return np.isnan(value)


@Deprecated
def is_nan_or_inf(value):
    return is_nan(value) or np.isinf(value)
@@ -2,10 +2,10 @@ import logging
from typing import Dict, Optional

import ray
from ray.air._internal.mlflow import _MLflowLoggerUtil
from ray.tune.logger import LoggerCallback
from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
from ray.tune.experiment import Trial
from ray.util.ml_utils.mlflow import _MLflowLoggerUtil

logger = logging.getLogger(__name__)

@@ -13,6 +13,7 @@ from typing import Any, Dict, Iterator, Optional, Tuple, Union, TYPE_CHECKING
import ray
from ray import cloudpickle as pickle
from ray.air._internal.checkpointing import load_preprocessor_from_dir
from ray.air._internal.filelock import TempFileLock
from ray.air._internal.remote_storage import (
    download_from_uri,
    fs_hint,

@@ -21,7 +22,6 @@ from ray.air._internal.remote_storage import (
)
from ray.air.constants import PREPROCESSOR_KEY
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.ml_utils.filelock import TempFileLock


if TYPE_CHECKING:

@@ -1,5 +1,5 @@
import pytest
from ray.util.ml_utils.checkpoint_manager import (
from ray.air._internal.checkpoint_manager import (
    _CheckpointManager,
    CheckpointStorage,
    CheckpointConfig,

@@ -3,7 +3,7 @@ import shutil
import tempfile
import unittest

from ray.util.ml_utils.mlflow import _MLflowLoggerUtil
from ray.air._internal.mlflow import _MLflowLoggerUtil


class MLflowTest(unittest.TestCase):

@@ -13,11 +13,11 @@ from typing import Any, Dict, Optional
import yaml

import ray
from ray._private.dict import deep_update
from ray.autoscaler._private.fake_multi_node.node_provider import (
    FAKE_DOCKER_DEFAULT_CLIENT_PORT,
    FAKE_DOCKER_DEFAULT_GCS_PORT,
)
from ray.util.ml_utils.dict import deep_update
from ray.util.queue import Empty, Queue

logger = logging.getLogger(__name__)
python/ray/tests/horovod/BUILD (new file, 8 lines)
@@ -0,0 +1,8 @@
py_test(
    name = "test_horovod",
    size = "medium",
    srcs = ["test_horovod.py"],
    tags = ["team:ml", "exclusive"]
)
python/ray/tests/horovod/horovod_example.py (new file, 242 lines)
@@ -0,0 +1,242 @@
# This file is duplicated in release/ml_user_tests/horovod
import argparse
import os
from filelock import FileLock

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch.utils.data.distributed

import horovod.torch as hvd
from horovod.ray import RayExecutor


def metric_average(val, name):
    tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)


def train_fn(
    data_dir=None,
    seed=42,
    use_cuda=False,
    batch_size=64,
    use_adasum=False,
    lr=0.01,
    momentum=0.5,
    num_epochs=10,
    log_interval=10,
):
    # Horovod: initialize library.
    hvd.init()
    torch.manual_seed(seed)

    if use_cuda:
        # Horovod: pin GPU to local rank.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(seed)

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    data_dir = data_dir or "./data"
    with FileLock(os.path.expanduser("~/.horovod_lock")):
        train_dataset = datasets.MNIST(
            data_dir,
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        )
    # Horovod: use DistributedSampler to partition the training data.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler, **kwargs
    )

    model = Net()

    # By default, Adasum doesn't need scaling up learning rate.
    lr_scaler = hvd.size() if not use_adasum else 1

    if use_cuda:
        # Move model to GPU.
        model.cuda()
        # If using GPU Adasum allreduce, scale learning rate by local_size.
        if use_adasum and hvd.nccl_built():
            lr_scaler = hvd.local_size()

    # Horovod: scale learning rate by lr_scaler.
    optimizer = optim.SGD(model.parameters(), lr=lr * lr_scaler, momentum=momentum)

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        op=hvd.Adasum if use_adasum else hvd.Average,
    )

    for epoch in range(1, num_epochs + 1):
        model.train()
        # Horovod: set epoch to sampler for shuffling.
        train_sampler.set_epoch(epoch)
        for batch_idx, (data, target) in enumerate(train_loader):
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                # Horovod: use train_sampler to determine the number of
                # examples in this worker's partition.
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_sampler),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )


def main(
    num_workers, use_gpu, timeout_s=30, placement_group_timeout_s=100, kwargs=None
):
    kwargs = kwargs or {}
    if use_gpu:
        kwargs["use_cuda"] = True
    settings = RayExecutor.create_settings(
        timeout_s=timeout_s, placement_group_timeout_s=placement_group_timeout_s
    )
    executor = RayExecutor(settings, use_gpu=use_gpu, num_workers=num_workers)
    executor.start()
    executor.run(train_fn, kwargs=kwargs)


if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(
        description="PyTorch MNIST Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=5,
        metavar="N",
        help="number of epochs to train (default: 10)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.01,
        metavar="LR",
        help="learning rate (default: 0.01)",
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.5,
        metavar="M",
        help="SGD momentum (default: 0.5)",
    )
    parser.add_argument(
        "--use-cuda", action="store_true", default=False, help="enables CUDA training"
    )
    parser.add_argument(
        "--seed", type=int, default=42, metavar="S", help="random seed (default: 42)"
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--use-adasum",
        action="store_true",
        default=False,
        help="use adasum algorithm to do reduction",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of Ray workers to use for training.",
    )
    parser.add_argument(
        "--data-dir",
        help="location of the training dataset in the local filesystem ("
        "will be downloaded if needed)",
    )
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        default=None,
        help="Address of Ray cluster.",
    )

    args = parser.parse_args()

    import ray

    if args.address:
        ray.init(args.address)
    else:
        ray.init()

    kwargs = {
        "data_dir": args.data_dir,
        "seed": args.seed,
        "use_cuda": args.use_cuda if args.use_cuda else False,
        "batch_size": args.batch_size,
        "use_adasum": args.use_adasum if args.use_adasum else False,
        "lr": args.lr,
        "momentum": args.momentum,
        "num_epochs": args.num_epochs,
        "log_interval": args.log_interval,
    }

    main(
        num_workers=args.num_workers,
        use_gpu=args.use_cuda if args.use_cuda else False,
        kwargs=kwargs,
    )
@@ -80,7 +80,7 @@ def test_train(ray_start_4_cpus):

@pytest.mark.skipif(not gloo_built(), reason="Gloo is required for Ray integration")
def test_horovod_example(ray_start_4_cpus):
    from ray.util.horovod.horovod_example import main
    from ray.tests.horovod.horovod_example import main

    kwargs = {
        "data_dir": "./data",
python/ray/tests/lightgbm/BUILD (new file, 15 lines)
@@ -0,0 +1,15 @@
py_test(
    name = "simple_example",
    size = "small",
    srcs = ["simple_example.py"],
    tags = ["team:ml", "exclusive"],
)

py_test(
    name = "simple_tune",
    size="small",
    srcs = ["simple_tune.py"],
    tags = ["team:ml", "exclusive"]
)
@@ -1,7 +1,7 @@
from sklearn import datasets
from sklearn.model_selection import train_test_split

from ray.util.lightgbm import RayDMatrix, RayParams, train
from lightgbm_ray import RayDMatrix, RayParams, train


# __lightgbm_begin__

@@ -1,7 +1,7 @@
from sklearn import datasets
from sklearn.model_selection import train_test_split

from ray.util.lightgbm import RayDMatrix, RayParams, train
from lightgbm_ray import RayDMatrix, RayParams, train

# __train_begin__
num_cpus_per_actor = 2
python/ray/tests/ray_lightning/BUILD (new file, 13 lines)
@@ -0,0 +1,13 @@
py_test(
    name = "simple_example",
    size = "small",
    srcs = ["simple_example.py"],
    tags = ["team:ml", "exclusive"],
)

py_test(
    name = "simple_tune",
    size="small",
    srcs = ["simple_tune.py"],
    tags = ["team:ml", "exclusive"]
)
python/ray/tests/ray_lightning/simple_example.py (new file, 83 lines)
@@ -0,0 +1,83 @@
# This file is duplicated in release/ml_user_tests/ray-lightning
import argparse
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

from ray_lightning import RayPlugin


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)
        )

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


def main(num_workers: int = 2, use_gpu: bool = False, max_steps: int = 10):
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train, val = random_split(dataset, [55000, 5000])

    autoencoder = LitAutoEncoder()
    trainer = pl.Trainer(
        plugins=[RayPlugin(num_workers=num_workers, use_gpu=use_gpu)],
        max_steps=max_steps,
    )
    trainer.fit(autoencoder, DataLoader(train), DataLoader(val))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Ray Lightning Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=2,
        help="Number of workers to use for training.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=10,
        help="Maximum number of steps to run for training.",
    )
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="Whether to enable GPU training.",
    )

    args = parser.parse_args()

    main(num_workers=args.num_workers, max_steps=args.max_steps, use_gpu=args.use_gpu)
@@ -7,8 +7,8 @@ from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

from ray.util.ray_lightning import RayPlugin
from ray.util.ray_lightning.tune import TuneReportCallback, get_tune_resources
from ray_lightning import RayPlugin
from ray_lightning.tune import TuneReportCallback, get_tune_resources

num_cpus_per_actor = 1
num_workers = 1

@@ -1,12 +1,11 @@
# --------------------------------------------------------------------
# Tests from the python/ray/util/ray_lightning directory.
# Tests from the python/ray/tests/xgboost directory.
# Please keep these sorted alphabetically.
# --------------------------------------------------------------------
py_test(
    name = "simple_example",
    size = "small",
    srcs = ["simple_example.py"],
    deps = [":lightning_lib"],
    tags = ["team:ml", "exclusive"],
)

@@ -14,13 +13,7 @@ py_test(
    name = "simple_tune",
    size="small",
    srcs = ["simple_tune.py"],
    deps = [":lightning_lib"],
    tags = ["team:ml", "exclusive"]
)

# This is a dummy test dependency that causes the above tests to be
# re-run if any of these files changes.
py_library(
    name = "lightning_lib",
    srcs = glob(["**/*.py"]),
)

@@ -1,7 +1,7 @@
from sklearn import datasets
from sklearn.model_selection import train_test_split

from ray.util.xgboost import RayDMatrix, RayParams, train
from xgboost_ray import RayDMatrix, RayParams, train


# __xgboost_begin__

@@ -2,7 +2,7 @@ from sklearn import datasets
from sklearn.model_selection import train_test_split

from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.util.xgboost import RayDMatrix, RayParams, train
from xgboost_ray import RayDMatrix, RayParams, train

# __train_begin__
num_cpus_per_actor = 1
@@ -1,6 +1,7 @@
from ray._private.usage import usage_lib
from ray.train.backend import BackendConfig
from ray.train.callbacks import TrainingCallback
from ray.train.checkpoint import CheckpointStrategy
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.train_loop_utils import (
    get_dataset_shard,

@@ -12,10 +13,7 @@ from ray.train.train_loop_utils import (
    world_size,
)
from ray.train.trainer import Trainer, TrainingIterator
from ray.air.config import CheckpointConfig

# Deprecated. Alias of CheckpointConfig for backwards compat
from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy

usage_lib.record_library_usage("train")

@@ -3,6 +3,11 @@ from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

from ray.air import Checkpoint, CheckpointConfig
from ray.air._internal.checkpoint_manager import CheckpointStorage
from ray.air._internal.checkpoint_manager import (
    _CheckpointManager as CommonCheckpointManager,
)
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint
from ray.train._internal.session import TrainingResult
from ray.train._internal.utils import construct_path
from ray.train.constants import (

@@ -11,11 +16,6 @@ from ray.train.constants import (
    TUNE_CHECKPOINT_ID,
    TUNE_INSTALLED,
)
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage
from ray.util.ml_utils.checkpoint_manager import (
    _CheckpointManager as CommonCheckpointManager,
)
from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint

if TUNE_INSTALLED:
    from ray import tune

@@ -18,10 +18,11 @@ from typing import (
)

import ray
from ray.air._internal.util import find_free_port
from ray.actor import ActorHandle
from ray.exceptions import RayActorError
from ray.types import ObjectRef
from ray.util.ml_utils.util import find_free_port


T = TypeVar("T")

@@ -11,7 +11,7 @@ from ray.air.result import Result
from ray.train.constants import TRAIN_DATASET_KEY
from ray.util import PublicAPI
from ray.util.annotations import DeveloperAPI
from ray.util.ml_utils.dict import merge_dicts
from ray._private.dict import merge_dicts

if TYPE_CHECKING:
    from ray.data import Dataset

@@ -7,6 +7,8 @@ from typing import Dict, List, Optional, Set, Tuple, Union

import numpy as np

from ray.air._internal.mlflow import _MLflowLoggerUtil
from ray.air._internal.json import SafeFallbackEncoder
from ray.train._internal.results_preprocessors import (
    ExcludedKeysResultsPreprocessor,
    IndexedResultsPreprocessor,

@@ -26,9 +28,7 @@ from ray.train.constants import (
)
from ray.util.annotations import Deprecated
from ray.util.debug import log_once
from ray.util.ml_utils.dict import flatten_dict
from ray.util.ml_utils.json import SafeFallbackEncoder
from ray.util.ml_utils.mlflow import _MLflowLoggerUtil
from ray._private.dict import flatten_dict

logger = logging.getLogger(__name__)
python/ray/train/checkpoint.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from dataclasses import dataclass
import warnings

from ray.air.config import CheckpointConfig
from ray.util.annotations import Deprecated

# Deprecated. Alias of CheckpointConfig for backwards compat
deprecation_message = (
    "`ray.train.checkpoint.CheckpointStrategy` is deprecated and will be removed in "
    "the future. Please use `ray.air.config.CheckpointConfig` "
    "instead."
)


@Deprecated(message=deprecation_message)
@dataclass
class CheckpointStrategy(CheckpointConfig):
    def __post_init__(self):
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
        super().__post_init__()
@@ -10,6 +10,7 @@ from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig, CheckpointConfig
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint
from ray.train import BackendConfig, TrainingIterator
from ray.train._internal.backend_executor import BackendExecutor, TrialInfo
from ray.train._internal.checkpoint import TuneCheckpointManager

@@ -18,7 +19,6 @@ from ray.train._internal.utils import construct_train_func
from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY
from ray.train.trainer import BaseTrainer, GenDataset
from ray.util.annotations import DeveloperAPI
from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor

@@ -5,6 +5,7 @@ from ray.air import session
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import resnet18
from filelock import FileLock
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10

@@ -17,7 +18,6 @@ from ray.train.torch import TorchTrainer
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner
from ray.util.ml_utils.resnet import ResNet18


def train_epoch(dataloader, model, loss_fn, optimizer):

@@ -60,7 +60,7 @@ def validate_epoch(dataloader, model, loss_fn):

def train_func(config):
    epochs = config.pop("epochs", 3)
    model = ResNet18(config)
    model = resnet18()
    model = train.torch.prepare_model(model)

    # Create optimizer.

@@ -11,7 +11,7 @@ from ray.train.trainer import BaseTrainer, GenDataset
from ray.tune import Trainable
from ray.tune.trainable.util import TrainableUtil
from ray.util.annotations import DeveloperAPI
from ray.util.ml_utils.dict import flatten_dict
from ray._private.dict import flatten_dict

if TYPE_CHECKING:
    import xgboost_ray

@@ -17,6 +17,7 @@ from ray.air import session
from ray.air._internal.checkpointing import (
    save_preprocessor_to_dir,
)
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray.air.checkpoint import Checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.train.constants import (

@@ -38,7 +39,6 @@ from ray.train.trainer import GenDataset
from ray.tune.trainable import Trainable
from ray.tune.utils.file_transfer import delete_on_node, sync_dir_between_nodes
from ray.util import PublicAPI, get_node_ip_address
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor

@@ -17,8 +17,8 @@ from ray.tune.registry import get_trainable_cls
from ray.tune.resources import Resources
from ray.tune.syncer import Syncer
from ray.util.annotations import PublicAPI
from ray.util.ml_utils.dict import merge_dicts
from ray.train.rl.rl_checkpoint import RL_TRAINER_CLASS_FILE, RL_CONFIG_FILE
from ray._private.dict import merge_dicts

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor

@@ -10,7 +10,8 @@ import torch
import ray
import ray.train as train
from ray._private.test_utils import wait_for_condition
from ray.train import Trainer, CheckpointConfig
from ray.air import CheckpointConfig
from ray.train import Trainer
from ray.train.backend import BackendConfig, Backend
from ray.train.constants import TRAIN_ENABLE_WORKER_SPREAD_ENV
from ray.train.torch import TorchConfig

@@ -5,9 +5,9 @@ import warnings
from ray.util.annotations import PublicAPI, DeveloperAPI

if TYPE_CHECKING:
    from ray.air._internal.checkpoint_manager import _TrackedCheckpoint
    from ray.tune.experiment import Trial
    from ray.tune.stopper import Stopper
    from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint


class _CallbackMeta(ABCMeta):
@@ -226,16 +226,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    print("Best trial final validation accuracy: {}".format(
        best_result.metrics["accuracy"]))

    if ray.util.client.ray.is_connected():
        # If using Ray Client, we want to make sure checkpoint access
        # happens on the server. So we wrap `test_best_model` in a Ray task.
        # We have to make sure it gets executed on the same node that
        # ``tuner.fit()`` is called on.
        from ray.util.ml_utils.node import force_on_current_node
        remote_fn = force_on_current_node(ray.remote(test_best_model))
        ray.get(remote_fn.remote(best_result.config, best_result.checkpoint))
    else:
        test_best_model(best_result.config, best_result.checkpoint)
    test_best_model(best_result.config, best_result.checkpoint)


# __main_end__

@@ -251,24 +242,12 @@ if __name__ == "__main__":
        "--ray-address",
        help="Address of Ray cluster for seamless distributed execution.",
        required=False)
    parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using "
        "Ray Client.")
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        ray.init(num_cpus=2)
        main(num_samples=1, max_num_epochs=1, gpus_per_trial=0)
    else:
        if args.server_address:
            # Connect to a remote server through Ray Client.
            ray.init(f"ray://{args.server_address}")
        elif args.ray_address:
            # Run directly on the Ray cluster.
            ray.init(args.ray_address)
        ray.init(args.ray_address)
        # Change this to activate training on GPUs
        main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)

@@ -84,20 +84,8 @@ if __name__ == "__main__":
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using Ray Client.",
    )
    args, _ = parser.parse_known_args()

    if args.server_address:
        import ray

        ray.init(f"ray://{args.server_address}")

    # __pbt_begin__
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",

@@ -153,14 +141,4 @@ if __name__ == "__main__":
    results = tuner.fit()
    # __tune_end__

    if args.server_address:
        # If using Ray Client, we want to make sure checkpoint access
        # happens on the server. So we wrap `test_best_model` in a Ray task.
        # We have to make sure it gets executed on the same node that
        # ``tuner.fit()`` is called on.
        from ray.util.ml_utils.node import force_on_current_node

        remote_fn = force_on_current_node(ray.remote(test_best_model))
        ray.get(remote_fn.remote(results))
    else:
        test_best_model(results)
    test_best_model(results)
@@ -264,13 +264,6 @@ if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using Ray Client.",
    )
    parser.add_argument(
        "--class-trainable",
        action="store_true",

@@ -285,10 +278,7 @@ if __name__ == "__main__":
    )
    args, _ = parser.parse_known_args()

    if args.server_address:
        ray.init(f"ray://{args.server_address}")
    else:
        ray.init(num_cpus=8)
    ray.init(num_cpus=8)

    if args.test:
        best_result = tune_xgboost(use_class_trainable=True)

@@ -296,18 +286,7 @@ if __name__ == "__main__":

    best_result = tune_xgboost(use_class_trainable=args.class_trainable)

    # Load the best model checkpoint.
    if args.server_address:
        # If connecting to a remote server with Ray Client, checkpoint loading
        # should be wrapped in a task so it will execute on the server.
        # We have to make sure it gets executed on the same node that
        # ``Tuner.fit()`` is called on.
        from ray.util.ml_utils.node import force_on_current_node

        remote_fn = force_on_current_node(ray.remote(get_best_model_checkpoint))
        best_bst = ray.get(remote_fn.remote(best_result))
    else:
        best_bst = get_best_model_checkpoint(best_result)
    best_bst = get_best_model_checkpoint(best_result)

    # You could now do further predictions with
    # best_bst.predict(...)
@@ -3,11 +3,12 @@ import sklearn.datasets
import sklearn.metrics
import os
import numpy as np
from ray.tune.schedulers import ASHAScheduler
from sklearn.model_selection import train_test_split
import xgboost as xgb

import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.integration.xgboost import (
    TuneReportCheckpointCallback,
    TuneReportCallback,

@@ -108,39 +109,17 @@ if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using Ray Client.",
    )
    parser.add_argument(
        "--use-cv", action="store_true", help="Use `xgb.cv` instead of `xgb.train`."
    )
    args, _ = parser.parse_known_args()

    if args.server_address:
        import ray

        ray.init(f"ray://{args.server_address}")

    best_result = tune_xgboost(args.use_cv)

    # Load the best model checkpoint.
    # Checkpointing is not supported when using `xgb.cv`
    if not args.use_cv:
        if args.server_address:
            # If connecting to a remote server with Ray Client, checkpoint loading
            # should be wrapped in a task so it will execute on the server.
            # We have to make sure it gets executed on the same node that
            # ``tuner.fit()`` is called on.
            from ray.util.ml_utils.node import force_on_current_node

            remote_fn = force_on_current_node(ray.remote(get_best_model_checkpoint))
            best_bst = ray.get(remote_fn.remote(best_result))
        else:
            best_bst = get_best_model_checkpoint(best_result)
        best_bst = get_best_model_checkpoint(best_result)

        # You could now do further predictions with
        # best_bst.predict(...)
@@ -4,7 +4,7 @@ from typing import Callable, Optional

from ray.tune.result import TRAINING_ITERATION
from ray.air.config import CheckpointConfig, MIN, MAX
from ray.util.ml_utils.checkpoint_manager import (
from ray.air._internal.checkpoint_manager import (
    _CheckpointManager as CommonCheckpointManager,
    _TrackedCheckpoint,
    CheckpointStorage,

@@ -14,6 +14,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Set, Union

import ray
from ray.air import Checkpoint
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray.exceptions import GetTimeoutError, RayTaskError
from ray.tune.error import (
    TuneError,

@@ -33,7 +34,6 @@ from ray.tune.utils.resource_updater import _ResourceUpdater
from ray.tune.trainable.util import TrainableUtil
from ray.util import log_once
from ray.util.annotations import DeveloperAPI
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray.util.placement_group import PlacementGroup, remove_placement_group

logger = logging.getLogger(__name__)

@@ -11,6 +11,7 @@ import traceback
import warnings

import ray
from ray.air._internal.checkpoint_manager import CheckpointStorage
from ray.exceptions import RayTaskError
from ray.tune.error import _TuneStopTrialError
from ray.tune.impl.out_of_band_serialize_dataset import out_of_band_serialize_dataset

@@ -46,7 +47,6 @@ from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncode
from ray.tune.web_server import TuneServer
from ray.util.annotations import DeveloperAPI
from ray.util.debug import log_once
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage

MAX_DEBUG_TRIALS = 20

@@ -13,6 +13,7 @@ from typing import Dict, Optional, Sequence, Union, Callable, List
import uuid

import ray
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage
import ray.cloudpickle as cloudpickle
from ray.exceptions import RayActorError, RayTaskError
from ray.tune import TuneError

@@ -43,7 +44,6 @@ from ray.tune.utils import date_str, flatten_dict
from ray.util.annotations import DeveloperAPI
from ray.util.debug import log_once
from ray._private.utils import binary_to_hex, hex_to_binary
from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage

DEBUG_PRINT_INTERVAL = 5
logger = logging.getLogger(__name__)

@@ -4,9 +4,9 @@ import logging
from typing import Callable, Dict, Optional

import ray
from ray.air._internal.mlflow import _MLflowLoggerUtil
from ray.tune.trainable import Trainable
from ray.util.annotations import Deprecated
from ray.util.ml_utils.mlflow import _MLflowLoggerUtil

logger = logging.getLogger(__name__)

@@ -5,9 +5,9 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Iterable

import yaml
from ray.air._internal.json import SafeFallbackEncoder
from ray.tune.callback import Callback
from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.util.ml_utils.json import SafeFallbackEncoder

if TYPE_CHECKING:
    from ray.tune.experiment.trial import Trial  # noqa: F401

@@ -7,6 +7,7 @@ import random
import shutil
from typing import Callable, Dict, List, Optional, Tuple, Union

from ray.air._internal.checkpoint_manager import CheckpointStorage
from ray.tune.execution import trial_runner
from ray.tune.error import TuneError
from ray.tune.result import DEFAULT_METRIC, TRAINING_ITERATION

@@ -18,7 +19,6 @@ from ray.tune.search.variant_generator import format_vars
from ray.tune.experiment import Trial
from ray.util import PublicAPI
from ray.util.debug import log_once
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage

logger = logging.getLogger(__name__)

@@ -17,6 +17,7 @@ import time
from dataclasses import dataclass

import ray
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray.air._internal.remote_storage import (
    fs_hint,
    upload_to_uri,

@@ -29,7 +30,6 @@ from ray.tune.callback import Callback
from ray.tune.result import NODE_IP
from ray.tune.utils.file_transfer import sync_dir_between_nodes
from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint

if TYPE_CHECKING:
    from ray.tune.experiment import Trial
@@ -6,13 +6,13 @@ import tempfile
import unittest
from unittest.mock import patch

from ray.tune.result import TRAINING_ITERATION
from ray.tune.execution.checkpoint_manager import _CheckpointManager
from ray.util.ml_utils.checkpoint_manager import (
from ray.air._internal.checkpoint_manager import (
    _TrackedCheckpoint,
    logger,
    CheckpointStorage,
)
from ray.tune.result import TRAINING_ITERATION
from ray.tune.execution.checkpoint_manager import _CheckpointManager


class CheckpointManagerTest(unittest.TestCase):

@@ -14,7 +14,7 @@ from ray.tune.integration.mlflow import (
from ray.air.callbacks.mlflow import (
    MLflowLoggerCallback,
)
from ray.util.ml_utils.mlflow import _MLflowLoggerUtil
from ray.air._internal.mlflow import _MLflowLoggerUtil


class MockTrial(

@@ -7,6 +7,7 @@ import unittest

import ray
from ray import tune
from ray.air._internal.checkpoint_manager import CheckpointStorage
from ray.rllib import _register_all
from ray.tune import Trainable
from ray.tune.callback import Callback

@@ -27,8 +28,6 @@ from ray.tune.execution.placement_groups import (
)
from unittest.mock import patch

from ray.util.ml_utils.checkpoint_manager import CheckpointStorage


class TrialExecutorInsufficientResourcesTest(unittest.TestCase):
    def setUp(self):

@@ -7,13 +7,13 @@ import pytest
import pandas as pd

import ray
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray import tune
from ray.air.checkpoint import Checkpoint
from ray.tune.registry import get_trainable_cls
from ray.tune.result_grid import ResultGrid
from ray.tune.experiment import Trial
from ray.tune.tests.tune_test_util import create_tune_experiment_checkpoint
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint


@pytest.fixture

@@ -8,6 +8,7 @@ import pytest
from freezegun import freeze_time

import ray.util
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray.tune import TuneError
from ray.tune.result import NODE_IP
from ray.tune.syncer import (

@@ -18,7 +19,6 @@ from ray.tune.syncer import (
)
from ray.tune.utils.callback import create_default_callbacks
from ray.tune.utils.file_transfer import sync_dir_between_nodes
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint


@pytest.fixture

@@ -6,6 +6,7 @@ import tempfile
import unittest

import ray
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage
from ray.rllib import _register_all

from ray.tune import TuneError

@@ -18,7 +19,6 @@ from ray.tune.resources import Resources
from ray.tune.search import BasicVariantGenerator
from ray.tune.tests.tune_test_util import TrialResultObserver
from ray.tune.trainable.util import TrainableUtil
from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage


def create_mock_components():

@@ -9,6 +9,7 @@ from collections import OrderedDict

import ray
from ray import tune
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage
from ray.rllib import _register_all
from ray.tune.logger import DEFAULT_LOGGERS, LoggerCallback, LegacyLoggerCallback
from ray.tune.execution.ray_trial_executor import (

@@ -25,7 +26,6 @@ from ray.tune.execution.trial_runner import TrialRunner
from ray.tune import Callback
from ray.tune.utils.callback import create_default_callbacks
from ray.tune.experiment import Experiment
from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage


class TestCallback(Callback):

@@ -13,6 +13,7 @@ from unittest.mock import MagicMock

import ray
from ray import tune
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage
from ray.tune import Trainable
from ray.tune.execution.ray_trial_executor import RayTrialExecutor
from ray.tune.result import TRAINING_ITERATION

@@ -33,7 +34,6 @@ from ray.tune.experiment import Trial
from ray.tune.resources import Resources

from ray.rllib import _register_all
from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage

_register_all()

@@ -10,6 +10,7 @@ from unittest.mock import MagicMock

import ray
from ray import tune
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage
from ray.tune import Trainable
from ray.tune.experiment import Trial
from ray.tune.execution.trial_runner import TrialRunner

@@ -20,7 +21,6 @@ from ray._private.test_utils import object_memory_usage

# Import psutil after ray so the packaged version is used.
import psutil
from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage

MB = 1024 ** 2
@@ -1,5 +1,6 @@
import unittest

from ray.air._internal.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage
from ray.tune import PlacementGroupFactory
from ray.tune.schedulers.trial_scheduler import TrialScheduler
from ray.tune.experiment import Trial

@@ -8,7 +9,6 @@ from ray.tune.schedulers.resource_changing_scheduler import (
    DistributeResources,
    DistributeResourcesToTopJob,
)
from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage


class MockResourceUpdater:

@@ -51,9 +51,9 @@ from ray.tune.experiment import Trial
from ray.tune.execution.trial_runner import TrialRunner
from ray.tune.utils.callback import create_default_callbacks
from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
from ray.tune.utils.node import force_on_current_node
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.util.annotations import PublicAPI
from ray.util.ml_utils.node import force_on_current_node
from ray.util.queue import Empty, Queue

logger = logging.getLogger(__name__)

@@ -9,8 +9,8 @@ from ray.tune.result_grid import ResultGrid
from ray.tune.trainable import Trainable
from ray.tune.impl.tuner_internal import TunerInternal
from ray.tune.tune_config import TuneConfig
from ray.tune.utils.node import force_on_current_node
from ray.util import PublicAPI
from ray.util.ml_utils.node import force_on_current_node

if TYPE_CHECKING:
    from ray.train.trainer import BaseTrainer

@@ -6,7 +6,7 @@ import tarfile
from typing import Optional, Tuple, Dict, Generator, Union

import ray
from ray.util.ml_utils.filelock import TempFileLock
from ray.air._internal.filelock import TempFileLock


_DEFAULT_CHUNK_SIZE_BYTES = 500 * 1024 * 1024  # 500 MiB
@@ -1,15 +1,11 @@
import ray
from ray.util.annotations import Deprecated


@Deprecated
def get_current_node_resource_key() -> str:
def _get_current_node_resource_key() -> str:
    """Get the Ray resource key for current node.
    It can be used for actor placement.

    If using Ray Client, this will return the resource key for the node that
    is running the client server.

    Returns:
        (str) A string of the format node:<CURRENT-NODE-IP-ADDRESS>
    """

@@ -24,22 +20,18 @@ def get_current_node_resource_key() -> str:
        raise ValueError("Cannot found the node dictionary for current node.")


@Deprecated
def force_on_current_node(task_or_actor=None):
    """Given a task or actor, place it on the current node.

    If using Ray Client, the current node is the client server node.

    Args:
        task_or_actor: A Ray remote function or class to place on the
            current node. If None, returns the options dict to pass to
            another actor.

    Returns:
        The provided task or actor, but with options modified to force
        placement on the current node.
    """
    node_resource_key = get_current_node_resource_key()
    node_resource_key = _get_current_node_resource_key()
    options = {"resources": {node_resource_key: 0.01}}

    if task_or_actor is None:
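For code that previously relied on `ray.util.ml_utils.node`, the `force_on_current_node` helper now lives in `ray.tune.utils.node` (see the `tune.py` and `tuner.py` hunks above). A hedged usage sketch, in which the wrapped task is only a stand-in:

    import ray
    from ray.tune.utils.node import force_on_current_node  # new location per this commit

    ray.init()

    @ray.remote
    def load_checkpoint():  # illustrative placeholder task
        return "runs on the driver / Ray Client server node"

    # Pin the task to the node this script (or the client server) is running on.
    pinned = force_on_current_node(load_checkpoint)
    print(ray.get(pinned.remote()))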
@@ -17,7 +17,12 @@ import psutil
import ray
from ray.air.checkpoint import Checkpoint
from ray.air._internal.remote_storage import delete_at_uri
from ray.util.ml_utils.dict import (  # noqa: F401
from ray.air._internal.json import SafeFallbackEncoder  # noqa
from ray.air._internal.util import (  # noqa: F401
    is_nan,
    is_nan_or_inf,
)
from ray._private.dict import (  # noqa: F401
    merge_dicts,
    deep_update,
    flatten_dict,

@@ -25,11 +30,6 @@ from ray.util.ml_utils.dict import (  # noqa: F401
    unflatten_list_dict,
    unflattened_lookup,
)
from ray.util.ml_utils.json import SafeFallbackEncoder  # noqa
from ray.util.ml_utils.util import (  # noqa: F401
    is_nan,
    is_nan_or_inf,
)

logger = logging.getLogger(__name__)
@ -1,21 +0,0 @@
|
|||
# --------------------------------------------------------------------
|
||||
# Tests from the python/ray/util/horovod directory.
|
||||
# Please keep these sorted alphabetically.
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
py_test(
|
||||
name = "test_horovod",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_horovod.py"],
|
||||
deps = [":horovod_lib"],
|
||||
tags = ["team:ml", "exclusive"]
|
||||
)
|
||||
|
||||
# This is a dummy test dependency that causes the above tests to be
|
||||
# re-run if any of these files changes.
|
||||
py_library(
|
||||
name = "horovod_lib",
|
||||
srcs = glob(["**/*.py"]),
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
raise DeprecationWarning(
|
||||
"ray.util.horovod has been removed as of Ray 2.0. Instead, use the `horovod` "
|
||||
"library directly or the `HorovodTrainer` in Ray AIR ("
|
||||
"https://docs.ray.io/en/master/ray-air/getting-started.html)"
|
||||
)
|
|
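Not part of the diff: a rough migration sketch for the `HorovodTrainer` mentioned in the warning above. The import paths and `ScalingConfig` usage follow the Ray 2.0 AIR API as best understood here; treat the details as assumptions and check the linked docs.

import horovod.torch as hvd
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.horovod import HorovodTrainer

def train_loop_per_worker(config):
    hvd.init()
    # A real loop would build a model and wrap its optimizer with
    # hvd.DistributedOptimizer; this stub only reports the worker rank.
    session.report({"rank": hvd.rank()})

trainer = HorovodTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
result = trainer.fit()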
@@ -1,36 +0,0 @@
# --------------------------------------------------------------------
# Tests from the python/ray/util/lightgbm directory.
# Please keep these sorted alphabetically.
# --------------------------------------------------------------------
py_test(
    name = "simple_example",
    size = "small",
    srcs = ["simple_example.py"],
    deps = [":lgbm_lib"],
    tags = ["team:ml", "exclusive"],
)

py_test(
    name = "simple_tune",
    size="small",
    srcs = ["simple_tune.py"],
    deps = [":lgbm_lib"],
    tags = ["team:ml", "exclusive"]
)

py_test(
    name = "test_client",
    size = "small",
    srcs = ["tests/test_client.py"],
    deps = [":lgbm_lib"],
    tags = ["team:ml", "exclusive", "client"]
)

# This is a dummy test dependency that causes the above tests to be
# re-run if any of these files changes.
py_library(
    name = "lgbm_lib",
    srcs = glob(["**/*.py"]),
)

@@ -1,37 +1,5 @@
import logging

logger = logging.getLogger(__name__)

train = None
predict = None
RayParams = None
RayDMatrix = None
RayFileType = None
RayLGBMClassifier = None
RayLGBMRegressor = None

try:
    from lightgbm_ray import (
        train,
        predict,
        RayParams,
        RayDMatrix,
        RayFileType,
        RayLGBMClassifier,
        RayLGBMRegressor,
    )
except ImportError:
    logger.info(
        "lightgbm_ray is not installed. Please run "
        "`pip install git+https://github.com/ray-project/lightgbm_ray`."
    )

__all__ = [
    "train",
    "predict",
    "RayParams",
    "RayDMatrix",
    "RayFileType",
    "RayLGBMClassifier",
    "RayLGBMRegressor",
]
raise DeprecationWarning(
    "ray.util.lightgbm has been removed as of Ray 2.0. Instead, use the `lightgbm-ray` "
    "library directly or the `LightGBMTrainer` in Ray AIR ("
    "https://docs.ray.io/en/master/ray-air/getting-started.html)"
)
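Not part of the diff: a hedged sketch of the first replacement suggested above, importing from `lightgbm_ray` directly instead of `ray.util.lightgbm`. The call mirrors the lightgbm_ray README; keyword names are assumptions to verify against that project's docs.

from lightgbm_ray import RayDMatrix, RayParams, train
from sklearn.datasets import load_breast_cancer

data, labels = load_breast_cancer(return_X_y=True)
train_set = RayDMatrix(data, labels)

# Distributes boosting over two Ray actors; metrics land in evals_result.
evals_result = {}
booster = train(
    {"objective": "binary", "metric": ["binary_logloss"]},
    train_set,
    evals_result=evals_result,
    valid_sets=[train_set],
    valid_names=["train"],
    ray_params=RayParams(num_actors=2, cpus_per_actor=1),
)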
@@ -1,31 +0,0 @@
import pytest
import sys

import ray
from ray.util.client.ray_client_helpers import ray_start_client_server


@pytest.fixture
def start_client_server():
    with ray_start_client_server() as client:
        yield client


def test_simple_example(start_client_server):
    assert ray.util.client.ray.is_connected()
    from ray.util.lightgbm.simple_example import main

    main()


def test_simple_tune(start_client_server):
    assert ray.util.client.ray.is_connected()
    from ray.util.lightgbm.simple_tune import main

    main()


if __name__ == "__main__":
    import pytest

    sys.exit(pytest.main(["-v", __file__]))

@@ -1,26 +0,0 @@
# --------------------------------------------------------------------
# Tests from the python/ray/util/ml_util/tests directory.
# Please keep these sorted alphabetically.
# --------------------------------------------------------------------
py_test(
    name = "test_checkpoint_manager",
    size = "small",
    srcs = ["tests/test_checkpoint_manager.py"],
    tags = ["team:ml", "exclusive"],
    deps = [":ml_util_lib"]
)

py_test(
    name = "test_mlflow",
    size = "medium",
    srcs = ["tests/test_mlflow.py"],
    tags = ["team:ml", "exclusive"],
    deps = [":ml_util_lib"]
)

# This is a dummy test dependency that causes the above tests to be
# re-run if any of these files changes.
py_library(
    name = "ml_util_lib",
    srcs = glob(["**/*.py"], exclude=["tests/*.py"]),
)

@@ -1,132 +0,0 @@
"""ResNet in PyTorch.
Copied from https://github.com/kuangliu/pytorch-cifar/
blob/ab908327d44bf9b1d22cd333a4466e85083d3f21/models/resnet.py
"""
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, self.expansion * planes, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18(_):
    return ResNet(BasicBlock, [2, 2, 2, 2])


def ResNet34(_):
    return ResNet(BasicBlock, [3, 4, 6, 3])


def ResNet50(_):
    return ResNet(Bottleneck, [3, 4, 6, 3])


def ResNet101(_):
    return ResNet(Bottleneck, [3, 4, 23, 3])


def ResNet152(_):
    return ResNet(Bottleneck, [3, 8, 36, 3])

@@ -1,16 +1,4 @@
import logging

logger = logging.getLogger(__name__)

RayPlugin = None
HorovodRayPlugin = None
RayShardedPlugin = None

try:
    from ray_lightning import RayPlugin, HorovodRayPlugin, RayShardedPlugin
except ImportError:
    logger.info(
        "ray_lightning is not installed. Please run `pip install ray-lightning`."
    )

__all__ = ["RayPlugin", "HorovodRayPlugin", "RayShardedPlugin"]
raise DeprecationWarning(
    "ray.util.ray_lightning has been removed as of Ray 2.0. Instead, use the "
    "`ray_lightning` library directly."
)
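Not part of the diff: a minimal sketch of using `ray_lightning` directly, as the warning above suggests. The `RayPlugin(num_workers=..., use_gpu=...)` constructor and passing it via `plugins=` follow the ray_lightning README; treat the details as assumptions.

import pytorch_lightning as pl
from ray_lightning import RayPlugin

# Two Ray workers, CPU only; the plugin swaps PyTorch Lightning's default
# single-process strategy for Ray-distributed training.
plugin = RayPlugin(num_workers=2, use_gpu=False)
trainer = pl.Trainer(max_epochs=1, plugins=[plugin])
# trainer.fit(model, train_dataloader)  # model / dataloader defined elsewhere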
@@ -1,26 +0,0 @@
import logging

logger = logging.getLogger(__name__)

TuneReportCallback = None
TuneReportCheckpointCallback = None
get_tune_resources = None

try:
    from ray_lightning.tune import (
        TuneReportCallback,
        TuneReportCheckpointCallback,
        get_tune_resources,
    )
except ImportError:
    logger.info(
        "ray_lightning is not installed. Please run "
        "`pip install git+https://github.com/ray-project/"
        "ray_lightning#ray_lightning`."
    )

__all__ = [
    "TuneReportCallback",
    "TuneReportCheckpointCallback",
    "get_tune_resources",
]

@@ -1,36 +0,0 @@
# --------------------------------------------------------------------
# Tests from the python/ray/util/xgboost directory.
# Please keep these sorted alphabetically.
# --------------------------------------------------------------------
py_test(
    name = "simple_example",
    size = "small",
    srcs = ["simple_example.py"],
    deps = [":xgb_lib"],
    tags = ["team:ml", "exclusive"],
)

py_test(
    name = "simple_tune",
    size="small",
    srcs = ["simple_tune.py"],
    deps = [":xgb_lib"],
    tags = ["team:ml", "exclusive"]
)

py_test(
    name = "test_client",
    size = "small",
    srcs = ["tests/test_client.py"],
    deps = [":xgb_lib"],
    tags = ["team:ml", "exclusive", "client"]
)

# This is a dummy test dependency that causes the above tests to be
# re-run if any of these files changes.
py_library(
    name = "xgb_lib",
    srcs = glob(["**/*.py"]),
)

@@ -1,44 +1,5 @@
import logging

logger = logging.getLogger(__name__)

train = None
predict = None
RayParams = None
RayDMatrix = None
RayFileType = None
RayXGBClassifier = None
RayXGBRegressor = None
RayXGBRFClassifier = None
RayXGBRFRegressor = None

try:
    from xgboost_ray import (
        train,
        predict,
        RayParams,
        RayDMatrix,
        RayFileType,
        RayXGBClassifier,
        RayXGBRegressor,
        RayXGBRFClassifier,
        RayXGBRFRegressor,
    )
except ImportError:
    logger.info(
        "xgboost_ray is not installed. Please run "
        "`pip install 'git+https://github.com/ray-project/"
        "xgboost_ray#egg=xgboost_ray'`."
    )

__all__ = [
    "train",
    "predict",
    "RayParams",
    "RayDMatrix",
    "RayFileType",
    "RayXGBClassifier",
    "RayXGBRegressor",
    "RayXGBRFClassifier",
    "RayXGBRFRegressor",
]
raise DeprecationWarning(
    "ray.util.xgboost has been removed as of Ray 2.0. Instead, use the `xgboost-ray` "
    "library directly or the `XGBoostTrainer` in Ray AIR ("
    "https://docs.ray.io/en/master/ray-air/getting-started.html)"
)
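Not part of the diff: the analogous hedged sketch for `xgboost_ray`, again following that project's README; keyword names are assumptions to verify against the xgboost_ray docs.

from xgboost_ray import RayDMatrix, RayParams, train
from sklearn.datasets import load_breast_cancer

data, labels = load_breast_cancer(return_X_y=True)
train_set = RayDMatrix(data, labels)

# Same pattern as plain xgboost.train(), plus ray_params for distribution.
evals_result = {}
booster = train(
    {"objective": "binary:logistic", "eval_metric": ["logloss", "error"]},
    train_set,
    evals_result=evals_result,
    evals=[(train_set, "train")],
    ray_params=RayParams(num_actors=2, cpus_per_actor=1),
)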
@@ -1,31 +0,0 @@
import pytest
import sys

import ray
from ray.util.client.ray_client_helpers import ray_start_client_server


@pytest.fixture
def start_client_server():
    with ray_start_client_server() as client:
        yield client


def test_simple_example(start_client_server):
    assert ray.util.client.ray.is_connected()
    from ray.util.xgboost.simple_example import main

    main()


def test_simple_tune(start_client_server):
    assert ray.util.client.ray.is_connected()
    from ray.util.xgboost.simple_tune import main

    main()


if __name__ == "__main__":
    import pytest

    sys.exit(pytest.main(["-v", __file__]))

@@ -10,13 +10,13 @@ from ray.tune.tuner import Tuner
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
from torchvision.models import resnet18

import ray
from ray import tune
from ray.air.checkpoint import Checkpoint
from ray.tune.schedulers import create_scheduler

from ray.util.ml_utils.resnet import ResNet18

from ray.tune.utils.release_test_util import ProgressCallback

@@ -31,7 +31,7 @@ def train_loop_per_worker(config):

    hvd.init()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ResNet18(None).to(device)
    net = resnet18().to(device)
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=config["lr"],

@@ -13,10 +13,10 @@ import torchvision.transforms as transforms
from filelock import FileLock
from ray import serve, tune, train
from ray.train import Trainer
from ray.util.ml_utils.node import force_on_current_node
from ray.util.ml_utils.resnet import ResNet18
from ray.tune.utils.node import force_on_current_node
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision.models import resnet18


def load_mnist_data(train: bool, download: bool):
@@ -55,7 +55,7 @@ def validate_epoch(dataloader, model, loss_fn):

def training_loop(config):
    # Create model.
    model = ResNet18(config)
    model = resnet18()
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)
    model = train.torch.prepare_model(model)

@@ -129,7 +129,7 @@ def get_model(model_checkpoint_path):
    checkpoint_dict = Trainer.load_checkpoint_from_path(model_checkpoint_path)
    model_state = checkpoint_dict["model_state_dict"]

    model = ResNet18(None)
    model = resnet18()
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)
    model.load_state_dict(model_state)

@@ -15,12 +15,14 @@ Notes: This test seems to be somewhat flaky. This might be due to
race conditions in handling dead actors. This is likely a problem of
the lightgbm_ray implementation and not of this test.
"""
import os

import ray

from lightgbm_ray import RayParams


from ray.util.lightgbm.release_test_util import (
from release_test_util import (
    train_ray,
    FailureState,
    FailureInjection,
@@ -28,7 +30,7 @@ from ray.util.lightgbm.release_test_util import (
)

if __name__ == "__main__":
    ray.init(address="auto")
    ray.init(address="auto", runtime_env={"working_dir": os.path.dirname(__file__)})

    failure_state = FailureState.remote()

@@ -13,10 +13,10 @@ import time
import ray
from lightgbm_ray import RayParams

from ray.util.lightgbm.release_test_util import train_ray
from release_test_util import train_ray

if __name__ == "__main__":
    ray.init(address="auto")
    ray.init(address="auto", runtime_env={"working_dir": os.path.dirname(__file__)})

    ray_params = RayParams(
        elastic_training=False,

@@ -13,15 +13,18 @@ import time
import ray
from lightgbm_ray import RayParams

from ray.util.lightgbm.release_test_util import train_ray
from release_test_util import train_ray

if __name__ == "__main__":
    addr = os.environ.get("RAY_ADDRESS")
    job_name = os.environ.get("RAY_JOB_NAME", "train_small")

    runtime_env = {"working_dir": os.path.dirname(__file__)}

    if addr.startswith("anyscale://"):
        ray.init(address=addr, job_name=job_name)
        ray.init(address=addr, job_name=job_name, runtime_env=runtime_env)
    else:
        ray.init(address="auto")
        ray.init(address="auto", runtime_env=runtime_env)

    output = os.environ["TEST_OUTPUT_JSON"]
    ray_params = RayParams(

@@ -1,58 +0,0 @@
"""Small cluster training

This training run will start 4 workers on 4 nodes (including head node).

Test owner: Yard1 (primary), krfricke

Acceptance criteria: Should run through and report final results.
"""
import json
import os
import time

import ray
from lightgbm_ray import RayParams

from ray.util.lightgbm.release_test_util import train_ray

if __name__ == "__main__":
    addr = os.environ.get("RAY_ADDRESS")
    job_name = os.environ.get("RAY_JOB_NAME", "train_small")
    if addr.startswith("anyscale://"):
        ray.init(address=addr, job_name=job_name)
    else:
        ray.init(address="auto")

    ray_params = RayParams(
        elastic_training=False,
        max_actor_restarts=2,
        num_actors=4,
        cpus_per_actor=4,
        gpus_per_actor=0,
    )

    @ray.remote
    def train():
        train_ray(
            path="/data/classification.parquet",
            num_workers=None,
            num_boost_rounds=100,
            num_files=25,
            regression=False,
            use_gpu=False,
            ray_params=ray_params,
            lightgbm_params=None,
        )

    start = time.time()
    ray.get(train.remote())
    taken = time.time() - start

    result = {
        "time_taken": taken,
    }
    test_output_json = os.environ.get("TEST_OUTPUT_JSON", "/tmp/train_small.json")
    with open(test_output_json, "wt") as f:
        json.dump(result, f)

    print("PASSED.")

@@ -19,7 +19,7 @@ from ray import tune

from lightgbm_ray import RayParams

from ray.util.lightgbm.release_test_util import train_ray
from release_test_util import train_ray


def train_wrapper(config, ray_params):
@@ -42,7 +42,7 @@ if __name__ == "__main__":
        "max_depth": tune.randint(1, 9),
    }

    ray.init(address="auto")
    ray.init(address="auto", runtime_env={"working_dir": os.path.dirname(__file__)})

    ray_params = RayParams(
        elastic_training=False,

@@ -19,7 +19,7 @@ from ray import tune

from lightgbm_ray import RayParams

from ray.util.lightgbm.release_test_util import train_ray
from release_test_util import train_ray


def train_wrapper(config, ray_params):
@@ -42,7 +42,7 @@ if __name__ == "__main__":
        "max_depth": tune.randint(1, 9),
    }

    ray.init(address="auto")
    ray.init(address="auto", runtime_env={"working_dir": os.path.dirname(__file__)})

    ray_params = RayParams(
        elastic_training=False,

@@ -19,7 +19,7 @@ from ray import tune

from lightgbm_ray import RayParams

from ray.util.lightgbm.release_test_util import train_ray
from release_test_util import train_ray


def train_wrapper(config, ray_params):
@@ -42,7 +42,7 @@ if __name__ == "__main__":
        "max_depth": tune.randint(1, 9),
    }

    ray.init(address="auto")
    ray.init(address="auto", runtime_env={"working_dir": os.path.dirname(__file__)})

    ray_params = RayParams(
        elastic_training=False,

@@ -1,3 +1,4 @@
# This file is duplicated in ray/tests/horovod
import argparse
import os
from filelock import FileLock

@@ -3,17 +3,20 @@ import os
import time

import ray
from ray.util.horovod.horovod_example import main
from horovod_example import main

if __name__ == "__main__":
    start = time.time()

    addr = os.environ.get("RAY_ADDRESS")
    job_name = os.environ.get("RAY_JOB_NAME", "horovod_user_test")
    if addr is not None and addr.startswith("anyscale://"):
        ray.init(address=addr, job_name=job_name)

    runtime_env = {"working_dir": os.path.dirname(__file__)}

    if addr.startswith("anyscale://"):
        ray.init(address=addr, job_name=job_name, runtime_env=runtime_env)
    else:
        ray.init(address="auto")
        ray.init(address="auto", runtime_env=runtime_env)

    main(
        num_workers=6,

@@ -3,7 +3,7 @@ import os
import time

import ray
from ray.util.ray_lightning.simple_example import main
from simple_example import main

if __name__ == "__main__":
    start = time.time()
@@ -16,9 +16,12 @@ if __name__ == "__main__":
    # See https://github.com/pytorch/pytorch/issues/68893 for more details.
    # Passing in runtime_env to ray.init() will also set it for all the
    # workers.
    runtime_env = {"env_vars": {"NCCL_SOCKET_IFNAME": "ens3"}}
    runtime_env = {
        "env_vars": {"NCCL_SOCKET_IFNAME": "ens3"},
        "working_dir": os.path.dirname(__file__),
    }

    if addr is not None and addr.startswith("anyscale://"):
    if addr.startswith("anyscale://"):
        ray.init(address=addr, job_name=job_name, runtime_env=runtime_env)
    else:
        ray.init(address="auto", runtime_env=runtime_env)

@@ -1,3 +1,4 @@
# This file is duplicated in ray/tests/ray_lightning
import argparse
import os
import torch
@@ -8,7 +9,7 @@ from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

from ray.util.ray_lightning import RayPlugin
from ray_lightning import RayPlugin


class LitAutoEncoder(pl.LightningModule):
Some files were not shown because too many files have changed in this diff.