diff --git a/release/golden_notebook_tests/golden_notebook_tests.yaml b/release/golden_notebook_tests/golden_notebook_tests.yaml
index ae17c4d48..1cdf7e960 100644
--- a/release/golden_notebook_tests/golden_notebook_tests.yaml
+++ b/release/golden_notebook_tests/golden_notebook_tests.yaml
@@ -26,6 +26,7 @@
     compute_template: gpu_tpl.yaml
 
   run:
+    use_connect: True
     timeout: 1800
     script: python workloads/torch_tune_serve_test.py
 
diff --git a/release/golden_notebook_tests/workloads/torch_tune_serve_test.py b/release/golden_notebook_tests/workloads/torch_tune_serve_test.py
index da52404b5..22c266573 100644
--- a/release/golden_notebook_tests/workloads/torch_tune_serve_test.py
+++ b/release/golden_notebook_tests/workloads/torch_tune_serve_test.py
@@ -17,6 +17,13 @@ from torch.utils.data import DataLoader, Subset
 from torchvision.datasets import MNIST
 
 
+def _is_anyscale_connect():
+    address = os.environ.get("RAY_ADDRESS")
+    is_anyscale_connect = address is not None and address.startswith(
+        "anyscale://")
+    return is_anyscale_connect
+
+
 def load_mnist_data(train: bool, download: bool):
     transform = transforms.Compose(
         [transforms.ToTensor(),
@@ -85,8 +92,33 @@ def train_mnist(test_mode=False, num_workers=1, use_gpu=False):
         checkpoint_at_end=True)
 
 
-def get_best_model(best_model_checkpoint_path):
-    model_state = torch.load(best_model_checkpoint_path)
+def get_remote_model(remote_model_checkpoint_path):
+    if _is_anyscale_connect():
+        # Download training results to local client.
+        local_dir = "~/ray_results"
+        # TODO(matt): remove the following line when Anyscale Connect
+        # supports tilde expansion.
+        local_dir = os.path.expanduser(local_dir)
+        remote_dir = "/home/ray/ray_results/"
+        ray.client().download_results(
+            local_dir=local_dir, remote_dir=remote_dir)
+
+        # Compute local path.
+        rel_model_checkpoint_path = os.path.relpath(
+            remote_model_checkpoint_path, remote_dir)
+        local_model_checkpoint_path = os.path.join(local_dir,
+                                                   rel_model_checkpoint_path)
+
+        # Load model reference.
+        return get_model(local_model_checkpoint_path)
+    else:
+        get_best_model_remote = ray.remote(get_model)
+        return ray.get(
+            get_best_model_remote.remote(remote_model_checkpoint_path))
+
+
+def get_model(model_checkpoint_path):
+    model_state = torch.load(model_checkpoint_path)
 
     model = ResNet18(None)
     model.conv1 = nn.Conv2d(
@@ -184,7 +216,12 @@ if __name__ == "__main__":
 
     start = time.time()
 
-    ray.client("anyscale://").connect()
+    client_builder = ray.client()
+    if _is_anyscale_connect():
+        job_name = os.environ.get("RAY_JOB_NAME", "torch_tune_serve_test")
+        client_builder.job_name(job_name)
+    client_builder.connect()
+
     num_workers = 2
     use_gpu = True
 
@@ -193,7 +230,7 @@ if __name__ == "__main__":
 
     print("Retrieving best model.")
     best_checkpoint = analysis.best_checkpoint
-    model_id = get_best_model(best_checkpoint)
+    model_id = get_remote_model(best_checkpoint)
 
     print("Setting up Serve.")
     setup_serve(model_id)
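
For reviewers who want to exercise the new `_is_anyscale_connect()` helper outside the release-test harness, a minimal sketch follows. It is illustrative only and not part of the patch; the import path and test name are assumptions that depend on how the workloads directory is made importable locally.

    # Illustrative sketch only -- not part of the patch above.
    # Assumes torch_tune_serve_test.py is importable from the current directory.
    import os
    from unittest import mock

    from torch_tune_serve_test import _is_anyscale_connect


    def test_is_anyscale_connect():
        # Anyscale Connect is detected via the "anyscale://" address prefix.
        with mock.patch.dict(os.environ, {"RAY_ADDRESS": "anyscale://my-cluster"}):
            assert _is_anyscale_connect()

        # A plain Ray Client address should not be treated as Anyscale Connect.
        with mock.patch.dict(os.environ, {"RAY_ADDRESS": "ray://localhost:10001"}):
            assert not _is_anyscale_connect()

        # With RAY_ADDRESS unset, the helper returns False.
        with mock.patch.dict(os.environ, clear=True):
            assert not _is_anyscale_connect()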