diff --git a/scripts/shell.py b/scripts/shell.py old mode 100644 new mode 100755 index d51eca6cf..4d6614353 --- a/scripts/shell.py +++ b/scripts/shell.py @@ -1,5 +1,7 @@ +#!/usr/bin/env python + import os -import argparse +import sys import numpy as np import ray @@ -8,19 +10,20 @@ import ray.array.remote as ra import ray.array.distributed as da import example_functions -DEFAULT_NUM_WORKERS = 10 -DEFAULT_WORKER_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "default_worker.py") +def main(argv): + DEFAULT_NUM_WORKERS = 1 + DEFAULT_WORKER_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "default_worker.py") -parser = argparse.ArgumentParser(description="Parse shell options") -parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address") -parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address") -parser.add_argument("--worker-address", default="127.0.0.1:30001", type=str, help="the worker's address") -parser.add_argument("--attach", action="store_true", help="If true, attach the shell to an already running cluster. If false, start a new cluster.") -parser.add_argument("--worker-path", type=str, help="Path to the worker script") -parser.add_argument("--num-workers", type=int, help="Number of workers to start") + import argparse # No need for this to be global + parser = argparse.ArgumentParser(description="Parse shell options") + parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address") + parser.add_argument("--objstore-address", default="127.0.0.1:20001", type=str, help="the objstore's address") + parser.add_argument("--worker-address", default="127.0.0.1:30001", type=str, help="the worker's address") + parser.add_argument("--attach", action="store_true", help="If true, attach the shell to an already running cluster. If false, start a new cluster.") + parser.add_argument("--worker-path", type=str, help="Path to the worker script") + parser.add_argument("--num-workers", type=int, help="Number of workers to start") -if __name__ == "__main__": - args = parser.parse_args() + args, unknown_args = parser.parse_known_args(argv) if args.attach: assert args.worker_path is None, "when attaching, no new worker can be started" assert args.num_workers is None, "when attaching, no new worker can be started" @@ -29,6 +32,8 @@ if __name__ == "__main__": ray.services.start_ray_local(num_workers=args.num_workers if not args.num_workers is None else DEFAULT_NUM_WORKERS, worker_path=args.worker_path if not args.worker_path is None else DEFAULT_WORKER_PATH, driver_mode=ray.SHELL_MODE) + return unknown_args +if __name__ == "__main__": import IPython - IPython.embed() + IPython.terminal.ipapp.launch_new_instance(argv=main(sys.argv[1:]), user_ns=globals())