Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[release] update torch_tune_serve_test to use anyscale connect (#16754)
* [release] update torch_tune_serve_test to use anyscale connect
* use download_results to download model checkpoint
* clean up code to support both OSS and Anyscale
parent 7318a212fb
commit 23088bd7ea
2 changed files with 42 additions and 4 deletions
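The change keys everything on the RAY_ADDRESS environment variable: an address that starts with "anyscale://" switches the test into Anyscale Connect mode, anything else keeps the plain OSS behavior. A minimal, self-contained sketch of that toggle (the addresses below are illustrative, not taken from the diff):

# Illustrative addresses only; the test itself reads RAY_ADDRESS from the environment.
for address in (None, "ray://localhost:10001", "anyscale://my-cluster"):
    anyscale_mode = address is not None and address.startswith("anyscale://")
    print(address, "->", "Anyscale Connect path" if anyscale_mode else "plain OSS path")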
File 1 of 2: release test configuration (YAML; filename not shown in this view).

@@ -26,6 +26,7 @@
     compute_template: gpu_tpl.yaml
 
   run:
+    use_connect: True
     timeout: 1800
     script: python workloads/torch_tune_serve_test.py
 
File 2 of 2: workloads/torch_tune_serve_test.py (all remaining hunks).

@@ -17,6 +17,13 @@ from torch.utils.data import DataLoader, Subset
 from torchvision.datasets import MNIST
 
 
+def _is_anyscale_connect():
+    address = os.environ.get("RAY_ADDRESS")
+    is_anyscale_connect = address is not None and address.startswith(
+        "anyscale://")
+    return is_anyscale_connect
+
+
 def load_mnist_data(train: bool, download: bool):
     transform = transforms.Compose(
         [transforms.ToTensor(),
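The helper is just a string check on RAY_ADDRESS. A behaviorally equivalent one-liner, shown only as a sketch and not the code in the diff (an unset RAY_ADDRESS falls back to "" and likewise returns False):

import os

def is_anyscale_connect() -> bool:
    # Same behavior as _is_anyscale_connect() above: unset RAY_ADDRESS -> "" -> False.
    return os.environ.get("RAY_ADDRESS", "").startswith("anyscale://")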
@@ -85,8 +92,33 @@ def train_mnist(test_mode=False, num_workers=1, use_gpu=False):
         checkpoint_at_end=True)
 
 
-def get_best_model(best_model_checkpoint_path):
-    model_state = torch.load(best_model_checkpoint_path)
+def get_remote_model(remote_model_checkpoint_path):
+    if _is_anyscale_connect():
+        # Download training results to local client.
+        local_dir = "~/ray_results"
+        # TODO(matt): remove the following line when Anyscale Connect
+        # supports tilde expansion.
+        local_dir = os.path.expanduser(local_dir)
+        remote_dir = "/home/ray/ray_results/"
+        ray.client().download_results(
+            local_dir=local_dir, remote_dir=remote_dir)
+
+        # Compute local path.
+        rel_model_checkpoint_path = os.path.relpath(
+            remote_model_checkpoint_path, remote_dir)
+        local_model_checkpoint_path = os.path.join(local_dir,
+                                                   rel_model_checkpoint_path)
+
+        # Load model reference.
+        return get_model(local_model_checkpoint_path)
+    else:
+        get_best_model_remote = ray.remote(get_model)
+        return ray.get(
+            get_best_model_remote.remote(remote_model_checkpoint_path))
+
+
+def get_model(model_checkpoint_path):
+    model_state = torch.load(model_checkpoint_path)
 
     model = ResNet18(None)
     model.conv1 = nn.Conv2d(
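On the Anyscale Connect branch, download_results mirrors the remote ~/ray_results directory onto the client, so the checkpoint path only needs to be re-rooted locally. A worked example of that re-rooting; the checkpoint path below is hypothetical, only remote_dir and local_dir come from the diff:

import os

remote_dir = "/home/ray/ray_results/"
local_dir = os.path.expanduser("~/ray_results")

# Hypothetical checkpoint path reported by Tune on the cluster (made up for illustration).
remote_checkpoint = "/home/ray/ray_results/train_mnist/checkpoint_000005/model.pt"

rel_path = os.path.relpath(remote_checkpoint, remote_dir)  # -> "train_mnist/checkpoint_000005/model.pt"
local_checkpoint = os.path.join(local_dir, rel_path)       # -> same relative path under the local ~/ray_results
print(local_checkpoint)

On the plain OSS branch the checkpoint file only exists on the cluster, so get_model is instead wrapped with ray.remote and executed there, and ray.get returns the loaded model to the driver.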
@@ -184,7 +216,12 @@ if __name__ == "__main__":
 
     start = time.time()
 
-    ray.client("anyscale://").connect()
+    client_builder = ray.client()
+    if (_is_anyscale_connect()):
+        job_name = os.environ.get("RAY_JOB_NAME", "torch_tune_serve_test")
+        client_builder.job_name(job_name)
+    client_builder.connect()
 
     num_workers = 2
     use_gpu = True
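The previously hard-coded ray.client("anyscale://").connect() is replaced by a builder driven by RAY_ADDRESS, with job_name set only when running against Anyscale. A condensed sketch of the same logic, assuming the Ray 1.x client-builder API the test already uses:

import os
import ray

client_builder = ray.client()  # target address comes from RAY_ADDRESS, so one script covers OSS and Anyscale
if os.environ.get("RAY_ADDRESS", "").startswith("anyscale://"):
    # job_name only applies to Anyscale Connect sessions, mirroring the guard in the hunk above.
    client_builder.job_name(os.environ.get("RAY_JOB_NAME", "torch_tune_serve_test"))
client_builder.connect()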
@@ -193,7 +230,7 @@ if __name__ == "__main__":
 
     print("Retrieving best model.")
     best_checkpoint = analysis.best_checkpoint
-    model_id = get_best_model(best_checkpoint)
+    model_id = get_remote_model(best_checkpoint)
 
     print("Setting up Serve.")
     setup_serve(model_id)