[release] update torch_tune_serve_test to use anyscale connect (#16754)

* [release] update torch_tune_serve_test to use anyscale connect

* use download_results to download model checkpoint

* clean up code to support both OSS and Anyscale
This commit is contained in:
matthewdeng 2021-07-06 19:02:50 -07:00 committed by GitHub
parent 7318a212fb
commit 23088bd7ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 4 deletions

View file

@ -26,6 +26,7 @@
compute_template: gpu_tpl.yaml
run:
use_connect: True
timeout: 1800
script: python workloads/torch_tune_serve_test.py

View file

@ -17,6 +17,13 @@ from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
def _is_anyscale_connect():
address = os.environ.get("RAY_ADDRESS")
is_anyscale_connect = address is not None and address.startswith(
"anyscale://")
return is_anyscale_connect
def load_mnist_data(train: bool, download: bool):
transform = transforms.Compose(
[transforms.ToTensor(),
@ -85,8 +92,33 @@ def train_mnist(test_mode=False, num_workers=1, use_gpu=False):
checkpoint_at_end=True)
def get_best_model(best_model_checkpoint_path):
model_state = torch.load(best_model_checkpoint_path)
def get_remote_model(remote_model_checkpoint_path):
if _is_anyscale_connect():
# Download training results to local client.
local_dir = "~/ray_results"
# TODO(matt): remove the following line when Anyscale Connect
# supports tilde expansion.
local_dir = os.path.expanduser(local_dir)
remote_dir = "/home/ray/ray_results/"
ray.client().download_results(
local_dir=local_dir, remote_dir=remote_dir)
# Compute local path.
rel_model_checkpoint_path = os.path.relpath(
remote_model_checkpoint_path, remote_dir)
local_model_checkpoint_path = os.path.join(local_dir,
rel_model_checkpoint_path)
# Load model reference.
return get_model(local_model_checkpoint_path)
else:
get_best_model_remote = ray.remote(get_model)
return ray.get(
get_best_model_remote.remote(remote_model_checkpoint_path))
def get_model(model_checkpoint_path):
model_state = torch.load(model_checkpoint_path)
model = ResNet18(None)
model.conv1 = nn.Conv2d(
@ -184,7 +216,12 @@ if __name__ == "__main__":
start = time.time()
ray.client("anyscale://").connect()
client_builder = ray.client()
if (_is_anyscale_connect()):
job_name = os.environ.get("RAY_JOB_NAME", "torch_tune_serve_test")
client_builder.job_name(job_name)
client_builder.connect()
num_workers = 2
use_gpu = True
@ -193,7 +230,7 @@ if __name__ == "__main__":
print("Retrieving best model.")
best_checkpoint = analysis.best_checkpoint
model_id = get_best_model(best_checkpoint)
model_id = get_remote_model(best_checkpoint)
print("Setting up Serve.")
setup_serve(model_id)