mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[release] update torch_tune_serve_test to use anyscale connect (#16754)
* [release] update torch_tune_serve_test to use anyscale connect
* use download_results to download model checkpoint
* clean up code to support both OSS and Anyscale
This commit is contained in:
parent
7318a212fb
commit
23088bd7ea
2 changed files with 42 additions and 4 deletions
|
@@ -26,6 +26,7 @@
|
||||||
compute_template: gpu_tpl.yaml
|
compute_template: gpu_tpl.yaml
|
||||||
|
|
||||||
run:
|
run:
|
||||||
|
use_connect: True
|
||||||
timeout: 1800
|
timeout: 1800
|
||||||
script: python workloads/torch_tune_serve_test.py
|
script: python workloads/torch_tune_serve_test.py
|
||||||
|
|
||||||
|
|
|
@@ -17,6 +17,13 @@ from torch.utils.data import DataLoader, Subset
|
||||||
from torchvision.datasets import MNIST
|
from torchvision.datasets import MNIST
|
||||||
|
|
||||||
|
|
||||||
|
def _is_anyscale_connect():
|
||||||
|
address = os.environ.get("RAY_ADDRESS")
|
||||||
|
is_anyscale_connect = address is not None and address.startswith(
|
||||||
|
"anyscale://")
|
||||||
|
return is_anyscale_connect
|
||||||
|
|
||||||
|
|
||||||
def load_mnist_data(train: bool, download: bool):
|
def load_mnist_data(train: bool, download: bool):
|
||||||
transform = transforms.Compose(
|
transform = transforms.Compose(
|
||||||
[transforms.ToTensor(),
|
[transforms.ToTensor(),
|
||||||
|
@@ -85,8 +92,33 @@ def train_mnist(test_mode=False, num_workers=1, use_gpu=False):
|
||||||
checkpoint_at_end=True)
|
checkpoint_at_end=True)
|
||||||
|
|
||||||
|
|
||||||
def get_best_model(best_model_checkpoint_path):
|
def get_remote_model(remote_model_checkpoint_path):
|
||||||
model_state = torch.load(best_model_checkpoint_path)
|
if _is_anyscale_connect():
|
||||||
|
# Download training results to local client.
|
||||||
|
local_dir = "~/ray_results"
|
||||||
|
# TODO(matt): remove the following line when Anyscale Connect
|
||||||
|
# supports tilde expansion.
|
||||||
|
local_dir = os.path.expanduser(local_dir)
|
||||||
|
remote_dir = "/home/ray/ray_results/"
|
||||||
|
ray.client().download_results(
|
||||||
|
local_dir=local_dir, remote_dir=remote_dir)
|
||||||
|
|
||||||
|
# Compute local path.
|
||||||
|
rel_model_checkpoint_path = os.path.relpath(
|
||||||
|
remote_model_checkpoint_path, remote_dir)
|
||||||
|
local_model_checkpoint_path = os.path.join(local_dir,
|
||||||
|
rel_model_checkpoint_path)
|
||||||
|
|
||||||
|
# Load model reference.
|
||||||
|
return get_model(local_model_checkpoint_path)
|
||||||
|
else:
|
||||||
|
get_best_model_remote = ray.remote(get_model)
|
||||||
|
return ray.get(
|
||||||
|
get_best_model_remote.remote(remote_model_checkpoint_path))
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(model_checkpoint_path):
|
||||||
|
model_state = torch.load(model_checkpoint_path)
|
||||||
|
|
||||||
model = ResNet18(None)
|
model = ResNet18(None)
|
||||||
model.conv1 = nn.Conv2d(
|
model.conv1 = nn.Conv2d(
|
||||||
|
@@ -184,7 +216,12 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
ray.client("anyscale://").connect()
|
client_builder = ray.client()
|
||||||
|
if (_is_anyscale_connect()):
|
||||||
|
job_name = os.environ.get("RAY_JOB_NAME", "torch_tune_serve_test")
|
||||||
|
client_builder.job_name(job_name)
|
||||||
|
client_builder.connect()
|
||||||
|
|
||||||
num_workers = 2
|
num_workers = 2
|
||||||
use_gpu = True
|
use_gpu = True
|
||||||
|
|
||||||
|
@@ -193,7 +230,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
print("Retrieving best model.")
|
print("Retrieving best model.")
|
||||||
best_checkpoint = analysis.best_checkpoint
|
best_checkpoint = analysis.best_checkpoint
|
||||||
model_id = get_best_model(best_checkpoint)
|
model_id = get_remote_model(best_checkpoint)
|
||||||
|
|
||||||
print("Setting up Serve.")
|
print("Setting up Serve.")
|
||||||
setup_serve(model_id)
|
setup_serve(model_id)
|
||||||
|
|
Loading…
Add table
Reference in a new issue