Check if the provider is external before getting the config. (#1743)

This commit is contained in:
Christian Barra 2018-03-23 06:59:29 +01:00 committed by Richard Liaw
parent 8704c8618c
commit 13b3df9321
2 changed files with 12 additions and 1 deletions

View file

@ -10,6 +10,8 @@ except ImportError: # py2
def dockerize_if_needed(config): def dockerize_if_needed(config):
if "docker" not in config:
return config
docker_image = config["docker"].get("image") docker_image = config["docker"].get("image")
cname = config["docker"].get("container_name") cname = config["docker"].get("container_name")
if not docker_image: if not docker_image:

View file

@ -19,6 +19,13 @@ def load_aws_config():
ray_aws.__file__), "example-full.yaml") ray_aws.__file__), "example-full.yaml")
def import_external():
"""Mock a normal provider importer."""
def return_it_back(config):
return config
return return_it_back, None
NODE_PROVIDERS = { NODE_PROVIDERS = {
"aws": import_aws, "aws": import_aws,
"gce": None, # TODO: support more node providers "gce": None, # TODO: support more node providers
@ -26,7 +33,7 @@ NODE_PROVIDERS = {
"kubernetes": None, "kubernetes": None,
"docker": None, "docker": None,
"local_cluster": None, "local_cluster": None,
"external": None, # Import an external module "external": import_external # Import an external module
} }
DEFAULT_CONFIGS = { DEFAULT_CONFIGS = {
@ -71,6 +78,8 @@ def get_node_provider(provider_config, cluster_name):
def get_default_config(provider_config): def get_default_config(provider_config):
if provider_config["type"] == "external":
return {}
load_config = DEFAULT_CONFIGS.get(provider_config["type"]) load_config = DEFAULT_CONFIGS.get(provider_config["type"])
if load_config is None: if load_config is None:
raise NotImplementedError( raise NotImplementedError(