[serve] Implement serve.build (#23232)

The Serve REST API relies on YAML config files to specify and deploy deployments. This change introduces `serve.build()` and `serve build`, which translate Pipelines to YAML files.

Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
This commit is contained in:
Edward Oakes 2022-03-25 13:36:59 -05:00 committed by GitHub
parent be216a0e8c
commit cf7b4e65c2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 343 additions and 114 deletions

View file

@ -1061,6 +1061,7 @@ def import_attr(full_path: str):
"""Given a full import path to a module attr, return the imported attr.
For example, the following are equivalent:
MyClass = import_attr("module.submodule:MyClass")
MyClass = import_attr("module.submodule.MyClass")
from module.submodule import MyClass
@ -1069,9 +1070,19 @@ def import_attr(full_path: str):
"""
if full_path is None:
raise TypeError("import path cannot be None")
if ":" in full_path:
if full_path.count(":") > 1:
raise ValueError(
f'Got invalid import path "{full_path}". An '
"import path may have at most one colon."
)
module_name, attr_name = full_path.split(":")
else:
last_period_idx = full_path.rfind(".")
attr_name = full_path[last_period_idx + 1 :]
module_name = full_path[:last_period_idx]
attr_name = full_path[last_period_idx + 1 :]
module = importlib.import_module(module_name)
return getattr(module, attr_name)

View file

@ -67,6 +67,7 @@ from ray.serve.utils import (
get_current_node_resource_key,
get_random_letters,
get_deployment_import_path,
in_interactive_shell,
logger,
DEFAULT,
)
@ -1792,9 +1793,7 @@ class Application:
Returns:
Dict: The Application's deployments formatted in a dictionary.
"""
return ServeApplicationSchema(
deployments=[deployment_to_schema(d) for d in self._deployments.values()]
).dict()
return serve_application_to_schema(self._deployments.values()).dict()
@classmethod
def from_dict(cls, d: Dict) -> "Application":
@ -1811,8 +1810,7 @@ class Application:
Application: a new application object containing the deployments.
"""
schema = ServeApplicationSchema.parse_obj(d)
return cls([schema_to_deployment(s) for s in schema.deployments])
return cls(schema_to_serve_application(ServeApplicationSchema.parse_obj(d)))
def to_yaml(self, f: Optional[TextIO] = None) -> Optional[str]:
"""Returns this application's deployments as a YAML string.
@ -1834,6 +1832,7 @@ class Application:
Optional[String]: The deployments' YAML string. The output is from
yaml.safe_dump(). Returned only if no file pointer is passed in.
"""
return yaml.safe_dump(
self.to_dict(), stream=f, default_flow_style=False, sort_keys=False
)
@ -1872,7 +1871,7 @@ def run(
*,
host: str = DEFAULT_HTTP_HOST,
port: int = DEFAULT_HTTP_PORT,
) -> RayServeHandle:
) -> Optional[RayServeHandle]:
"""Run a Serve application and return a ServeHandle to the ingress.
Either a DeploymentNode, DeploymentFunctionNode, or a pre-built application
@ -1954,13 +1953,13 @@ def run(
@PublicAPI(stability="alpha")
def build(target: DeploymentNode) -> Application:
def build(target: Union[DeploymentNode, DeploymentFunctionNode]) -> Application:
"""Builds a Serve application into a static application.
Takes in a DeploymentNode and converts it to a Serve application
consisting of one or more deployments. This is intended to be used for
production scenarios and deployed via the Serve REST API or CLI, so there
are some restrictions placed on the deployments:
Takes in a DeploymentNode or DeploymentFunctionNode and converts it to a
Serve application consisting of one or more deployments. This is intended
to be used for production scenarios and deployed via the Serve REST API or
CLI, so there are some restrictions placed on the deployments:
1) All of the deployments must be importable. That is, they cannot be
defined in __main__ or inline defined. The deployments will be
imported in production using the same import path they were here.
@ -1969,9 +1968,19 @@ def build(target: DeploymentNode) -> Application:
The returned Application object can be exported to a dictionary or YAML
config.
"""
# TODO (jiaodong): Resolve circular reference in pipeline codebase and serve
from ray.serve.pipeline.api import build as pipeline_build
if in_interactive_shell():
raise RuntimeError(
"serve.build cannot be called from an interactive shell like "
"IPython or Jupyter because it requires all deployments to be "
"importable to run the app after building."
)
# TODO(edoakes): this should accept host and port, but we don't
# currently support them in the REST API.
raise NotImplementedError()
return Application(pipeline_build(target))
def deployment_to_schema(d: Deployment) -> DeploymentSchema:
@ -1983,6 +1992,7 @@ def deployment_to_schema(d: Deployment) -> DeploymentSchema:
init_args and init_kwargs must also be JSON-serializable or this call will
fail.
"""
from ray.serve.pipeline.json_serde import convert_to_json_safe_obj
if d.ray_actor_options is not None:
ray_actor_options_schema = RayActorOptionsSchema.parse_obj(d.ray_actor_options)
@ -1991,9 +2001,11 @@ def deployment_to_schema(d: Deployment) -> DeploymentSchema:
return DeploymentSchema(
name=d.name,
import_path=get_deployment_import_path(d),
init_args=d.init_args,
init_kwargs=d.init_kwargs,
import_path=get_deployment_import_path(
d, enforce_importable=True, replace_main=True
),
init_args=convert_to_json_safe_obj(d.init_args, err_key="init_args"),
init_kwargs=convert_to_json_safe_obj(d.init_kwargs, err_key="init_kwargs"),
num_replicas=d.num_replicas,
route_prefix=d.route_prefix,
max_concurrent_queries=d.max_concurrent_queries,
@ -2017,12 +2029,12 @@ def schema_to_deployment(s: DeploymentSchema) -> Deployment:
return deployment(
name=s.name,
init_args=convert_from_json_safe_obj(s.init_args),
init_kwargs=convert_from_json_safe_obj(s.init_kwargs),
init_args=convert_from_json_safe_obj(s.init_args, err_key="init_args"),
init_kwargs=convert_from_json_safe_obj(s.init_kwargs, err_key="init_kwargs"),
num_replicas=s.num_replicas,
route_prefix=s.route_prefix,
max_concurrent_queries=s.max_concurrent_queries,
user_config=convert_from_json_safe_obj(s.user_config),
user_config=s.user_config,
_autoscaling_config=s.autoscaling_config,
_graceful_shutdown_wait_loop_s=s.graceful_shutdown_wait_loop_s,
_graceful_shutdown_timeout_s=s.graceful_shutdown_timeout_s,
@ -2035,8 +2047,9 @@ def schema_to_deployment(s: DeploymentSchema) -> Deployment:
def serve_application_to_schema(
deployments: List[Deployment],
) -> ServeApplicationSchema:
schemas = [deployment_to_schema(d) for d in deployments]
return ServeApplicationSchema(deployments=schemas)
return ServeApplicationSchema(
deployments=[deployment_to_schema(d) for d in deployments]
)
def schema_to_serve_application(schema: ServeApplicationSchema) -> List[Deployment]:

View file

@ -31,13 +31,36 @@ from ray.serve.api import RayServeDAGHandle
def convert_to_json_safe_obj(obj: Any, *, err_key: str) -> Any:
# XXX: comment, err msg
"""Converts the provided object into a JSON-safe version of it.
The returned object can safely be `json.dumps`'d to a string.
Uses the Ray Serve encoder to serialize special objects such as
ServeHandles and DAGHandles.
Raises: TypeError if the object contains fields that cannot be
JSON-serialized.
"""
try:
return json.loads(json.dumps(obj, cls=DAGNodeEncoder))
except Exception as e:
raise TypeError(
"All provided fields must be JSON-serializable to build the "
f"Serve app. Failed while serializing {err_key}:\n{e}"
)
def convert_from_json_safe_obj(obj: Any) -> Any:
# XXX: comment, err msg
def convert_from_json_safe_obj(obj: Any, *, err_key: str) -> Any:
"""Converts a JSON-safe object to one that contains Serve special types.
The provided object should have been serialized using
convert_to_json_safe_obj. Any special-cased objects such as ServeHandles
will be recovered on this pass.
"""
try:
return json.loads(json.dumps(obj), object_hook=dagnode_from_json)
except Exception as e:
raise ValueError(f"Failed to convert {err_key} from JSON:\n{e}")
class DAGNodeEncoder(json.JSONEncoder):
@ -81,17 +104,7 @@ class DAGNodeEncoder(json.JSONEncoder):
elif isinstance(obj, DAGNode):
return obj.to_json(DAGNodeEncoder)
else:
# Let the base class default method raise the TypeError
try:
return json.JSONEncoder.default(self, obj)
except Exception as e:
raise TypeError(
"All args and kwargs used in Ray DAG building for serve "
"deployment need to be JSON serializable. Please JSON "
"serialize your args to make your ray application "
"deployment ready."
f"\n Original exception message: {e}"
)
def dagnode_from_json(input_json: Any) -> Union[DAGNode, RayServeHandle, Any]:

View file

@ -97,24 +97,6 @@ def test_non_json_serializable_args():
ray_dag = combine.bind(MyNonJSONClass(1), MyNonJSONClass(2))
# General context
with pytest.raises(
TypeError,
match=(
"All args and kwargs used in Ray DAG building for serve "
"deployment need to be JSON serializable."
),
):
_ = json.dumps(ray_dag, cls=DAGNodeEncoder)
# User actionable item
with pytest.raises(
TypeError,
match=(
"Please JSON serialize your args to make your ray application "
"deployment ready"
),
):
_ = json.dumps(ray_dag, cls=DAGNodeEncoder)
# Original error message
with pytest.raises(
TypeError,
match=r"Object of type .* is not JSON serializable",

View file

@ -5,6 +5,7 @@ import pathlib
import click
import time
import sys
from typing import Optional, Union
import yaml
import ray
@ -22,6 +23,8 @@ from ray.dashboard.modules.serve.sdk import ServeSubmissionClient
from ray.autoscaler._private.cli_logger import cli_logger
from ray.serve.api import (
Application,
DeploymentFunctionNode,
DeploymentNode,
get_deployment_statuses,
serve_application_status_to_schema,
)
@ -378,3 +381,51 @@ def delete(address: str, yes: bool):
cli_logger.newline()
cli_logger.success("\nSent delete request successfully!\n")
cli_logger.newline()
@cli.command(
short_help="Writes a Pipeline's config file.",
help=(
"Imports the DeploymentNode or DeploymentFunctionNode at IMPORT_PATH "
"and generates a structured config for it that can be used by "
"`serve deploy` or the REST API. "
),
)
@click.option(
"--app-dir",
"-d",
default=".",
type=str,
help=APP_DIR_HELP_STR,
)
@click.option(
"--output-path",
"-o",
default=None,
type=str,
help=(
"Local path where the output config will be written in YAML format. "
"If not provided, the config will be printed to STDOUT."
),
)
@click.argument("import_path")
def build(app_dir: str, output_path: Optional[str], import_path: str):
sys.path.insert(0, app_dir)
node: Union[DeploymentNode, DeploymentFunctionNode] = import_attr(import_path)
if not isinstance(node, (DeploymentNode, DeploymentFunctionNode)):
raise TypeError(
f"Expected '{import_path}' to be DeploymentNode or "
f"DeploymentFunctionNode, but got {type(node)}."
)
app = serve.build(node)
if output_path is not None:
if not output_path.endswith(".yaml"):
raise ValueError("FILE_PATH must end with '.yaml'.")
with open(output_path, "w") as f:
app.to_yaml(f)
else:
print(app.to_yaml(), end="")

View file

@ -5,6 +5,7 @@ import sys
import os
import yaml
import requests
import numpy as np
import ray
from ray import serve
@ -42,7 +43,7 @@ class TestApplicationConstruction:
Application([self.f, 5, "hello"])
class TestRun:
class TestServeRun:
@serve.deployment
def f():
return "f reached"
@ -272,6 +273,35 @@ class DecoratedClass:
return "got decorated class"
class TestServeBuild:
@serve.deployment
class A:
pass
def test_build_non_json_serializable_args(self, serve_instance):
with pytest.raises(
TypeError, match="must be JSON-serializable to build.*init_args"
):
serve.build(self.A.bind(np.zeros(100))).to_dict()
def test_build_non_json_serializable_kwargs(self, serve_instance):
with pytest.raises(
TypeError, match="must be JSON-serializable to build.*init_kwargs"
):
serve.build(self.A.bind(kwarg=np.zeros(100))).to_dict()
def test_build_non_importable(self, serve_instance):
def gen_deployment():
@serve.deployment
def f():
pass
return f
with pytest.raises(RuntimeError, match="must be importable"):
serve.build(gen_deployment().bind()).to_dict()
def compare_specified_options(deployments1: Dict, deployments2: Dict):
"""
Helper method that takes 2 deployment dictionaries in the REST API

View file

@ -6,12 +6,13 @@ import sys
import signal
import pytest
import requests
from tempfile import NamedTemporaryFile
import ray
from ray import serve
from ray.tests.conftest import tmp_working_dir # noqa: F401, E501
from ray._private.test_utils import wait_for_condition
from ray.serve.api import Application
from ray.serve.api import Application, RayServeDAGHandle
def ping_endpoint(endpoint: str, params: str = ""):
@ -372,6 +373,45 @@ def test_run_runtime_env(ray_start_stop):
p.wait()
@serve.deployment
def global_f(*args):
return "wonderful world"
@serve.deployment
class NoArgDriver:
def __init__(self, dag: RayServeDAGHandle):
self.dag = dag
async def __call__(self):
return await self.dag.remote()
TestBuildFNode = global_f.bind()
TestBuildDagNode = NoArgDriver.bind(TestBuildFNode)
@pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.")
@pytest.mark.parametrize("node", ["TestBuildFNode", "TestBuildDagNode"])
def test_build(ray_start_stop, node):
with NamedTemporaryFile(mode="w+", suffix=".yaml") as tmp:
# Build an app
subprocess.check_output(
[
"serve",
"build",
f"ray.serve.tests.test_cli.{node}",
"-o",
tmp.name,
]
)
subprocess.check_output(["serve", "deploy", tmp.name])
assert ping_endpoint("") == "wonderful world"
subprocess.check_output(["serve", "delete", "-y"])
assert ping_endpoint("") == "connection error"
@pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.")
@pytest.mark.parametrize("use_command", [True, False])
def test_idempotence_after_controller_death(ray_start_stop, use_command: bool):

View file

@ -1,15 +1,15 @@
import pytest
import os
import sys
from typing import TypeVar
from typing import TypeVar, Union
import numpy as np
import requests
import ray
from ray import serve
from ray.serve.api import RayServeDAGHandle
from ray.experimental.dag.input_node import InputNode
from ray.serve.api import Application, DeploymentNode, RayServeDAGHandle
from ray.serve.pipeline.api import build as pipeline_build
from ray.serve.drivers import DAGDriver
import starlette.requests
@ -19,6 +19,15 @@ RayHandleLike = TypeVar("RayHandleLike")
NESTED_HANDLE_KEY = "nested_handle"
def maybe_build(
node: DeploymentNode, use_build: bool
) -> Union[Application, DeploymentNode]:
if use_build:
return Application.from_dict(serve.build(node).to_dict())
else:
return node
@serve.deployment
class ClassHello:
def __init__(self):
@ -113,11 +122,12 @@ class NoargDriver:
return await self.dag.remote()
def test_single_func_no_input(serve_instance):
@pytest.mark.parametrize("use_build", [False, True])
def test_single_func_no_input(serve_instance, use_build):
dag = fn_hello.bind()
serve_dag = NoargDriver.bind(dag)
handle = serve.run(serve_dag)
handle = serve.run(maybe_build(serve_dag, use_build))
assert ray.get(handle.remote()) == "hello"
assert requests.get("http://127.0.0.1:8000/").text == "hello"
@ -126,7 +136,8 @@ async def json_resolver(request: starlette.requests.Request):
return await request.json()
def test_single_func_deployment_dag(serve_instance):
@pytest.mark.parametrize("use_build", [False, True])
def test_single_func_deployment_dag(serve_instance, use_build):
with InputNode() as dag_input:
dag = combine.bind(dag_input[0], dag_input[1], kwargs_output=1)
serve_dag = DAGDriver.bind(dag, input_schema=json_resolver)
@ -135,7 +146,8 @@ def test_single_func_deployment_dag(serve_instance):
assert requests.post("http://127.0.0.1:8000/", json=[1, 2]).json() == 4
def test_chained_function(serve_instance):
@pytest.mark.parametrize("use_build", [False, True])
def test_chained_function(serve_instance, use_build):
@serve.deployment
def func_1(input):
return input
@ -156,7 +168,8 @@ def test_chained_function(serve_instance):
assert requests.post("http://127.0.0.1:8000/", json=2).json() == 6
def test_simple_class_with_class_method(serve_instance):
@pytest.mark.parametrize("use_build", [False, True])
def test_simple_class_with_class_method(serve_instance, use_build):
with InputNode() as dag_input:
model = Model.bind(2, ratio=0.3)
dag = model.forward.bind(dag_input)
@ -166,7 +179,8 @@ def test_simple_class_with_class_method(serve_instance):
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 0.6
def test_func_class_with_class_method(serve_instance):
@pytest.mark.parametrize("use_build", [False, True])
def test_func_class_with_class_method(serve_instance, use_build):
with InputNode() as dag_input:
m1 = Model.bind(1)
m2 = Model.bind(2)
@ -180,7 +194,8 @@ def test_func_class_with_class_method(serve_instance):
assert requests.post("http://127.0.0.1:8000/", json=[1, 2, 3]).json() == 8
def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
@pytest.mark.parametrize("use_build", [False, True])
def test_multi_instantiation_class_deployment_in_init_args(serve_instance, use_build):
with InputNode() as dag_input:
m1 = Model.bind(2)
m2 = Model.bind(3)
@ -193,7 +208,8 @@ def test_multi_instantiation_class_deployment_in_init_args(serve_instance):
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 5
def test_shared_deployment_handle(serve_instance):
@pytest.mark.parametrize("use_build", [False, True])
def test_shared_deployment_handle(serve_instance, use_build):
with InputNode() as dag_input:
m = Model.bind(2)
combine = Combine.bind(m, m2=m)
@ -205,7 +221,8 @@ def test_shared_deployment_handle(serve_instance):
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 4
def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance):
@pytest.mark.parametrize("use_build", [False, True])
def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance, use_build):
with InputNode() as dag_input:
m1 = Model.bind(2)
m2 = Model.bind(3)
@ -238,13 +255,15 @@ class Echo:
return self._s
def test_single_node_deploy_success(serve_instance):
@pytest.mark.parametrize("use_build", [False, True])
def test_single_node_deploy_success(serve_instance, use_build):
m1 = Adder.bind(1)
handle = serve.run(m1)
handle = serve.run(maybe_build(m1, use_build))
assert ray.get(handle.remote(41)) == 42
def test_single_node_driver_sucess(serve_instance):
@pytest.mark.parametrize("use_build", [False, True])
def test_single_node_driver_sucess(serve_instance, use_build):
m1 = Adder.bind(1)
m2 = Adder.bind(2)
with InputNode() as input_node:
@ -280,7 +299,8 @@ class TakeHandle:
return ray.get(self.handle.remote(inp))
def test_passing_handle(serve_instance):
@pytest.mark.parametrize("use_build", [False, True])
def test_passing_handle(serve_instance, use_build):
child = Adder.bind(1)
parent = TakeHandle.bind(child)
driver = DAGDriver.bind(parent, input_schema=json_resolver)
@ -289,40 +309,46 @@ def test_passing_handle(serve_instance):
assert requests.post("http://127.0.0.1:8000/", json=1).json() == 2
def test_passing_handle_in_obj(serve_instance):
@serve.deployment
class Parent:
@serve.deployment
class DictParent:
def __init__(self, d):
self._d = d
async def __call__(self, key):
return await self._d[key].remote()
@pytest.mark.parametrize("use_build", [False, True])
def test_passing_handle_in_obj(serve_instance, use_build):
child1 = Echo.bind("ed")
child2 = Echo.bind("simon")
parent = Parent.bind({"child1": child1, "child2": child2})
parent = maybe_build(
DictParent.bind({"child1": child1, "child2": child2}), use_build
)
handle = serve.run(parent)
assert ray.get(handle.remote("child1")) == "ed"
assert ray.get(handle.remote("child2")) == "simon"
def test_pass_handle_to_multiple(serve_instance):
@serve.deployment
class Child:
@serve.deployment
class Child:
def __call__(self, *args):
return os.getpid()
@serve.deployment
class Parent:
@serve.deployment
class Parent:
def __init__(self, child):
self._child = child
def __call__(self, *args):
return ray.get(self._child.remote())
@serve.deployment
class GrandParent:
@serve.deployment
class GrandParent:
def __init__(self, child, parent):
self._child = child
self._parent = parent
@ -332,31 +358,43 @@ def test_pass_handle_to_multiple(serve_instance):
assert ray.get(self._child.remote()) == ray.get(self._parent.remote())
return "ok"
@pytest.mark.parametrize("use_build", [False, True])
def test_pass_handle_to_multiple(serve_instance, use_build):
child = Child.bind()
parent = Parent.bind(child)
grandparent = GrandParent.bind(child, parent)
grandparent = maybe_build(GrandParent.bind(child, parent), use_build)
handle = serve.run(grandparent)
assert ray.get(handle.remote()) == "ok"
def test_non_json_serializable_args(serve_instance):
def test_run_non_json_serializable_args(serve_instance):
# Test that we can capture and bind non-json-serializable arguments.
arr1 = np.zeros(100)
arr2 = np.zeros(200)
arr3 = np.zeros(300)
@serve.deployment
class A:
def __init__(self, arr1):
def __init__(self, arr1, *, arr2):
self.arr1 = arr1
self.arr2 = arr2
self.arr3 = arr3
def __call__(self, *args):
return self.arr1, self.arr2
return self.arr1, self.arr2, self.arr3
handle = serve.run(A.bind(arr1))
ret1, ret2 = ray.get(handle.remote())
assert np.array_equal(ret1, arr1) and np.array_equal(ret2, arr2)
handle = serve.run(A.bind(arr1, arr2=arr2))
ret1, ret2, ret3 = ray.get(handle.remote())
assert all(
[
np.array_equal(ret1, arr1),
np.array_equal(ret2, arr2),
np.array_equal(ret3, arr3),
]
)
@serve.deployment
@ -406,8 +444,5 @@ def test_unsupported_remote():
_ = func.bind().remote()
# TODO: check that serve.build raises an exception.
if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))

View file

@ -39,7 +39,44 @@ class DecoratedActor:
return "reached decorated_actor"
def gen_func():
@serve.deployment
def f():
pass
return f
def gen_class():
@serve.deployment
class A:
pass
return A
class TestGetDeploymentImportPath:
def test_invalid_inline_defined(self):
@serve.deployment
def inline_f():
pass
with pytest.raises(RuntimeError, match="must be importable"):
get_deployment_import_path(inline_f, enforce_importable=True)
with pytest.raises(RuntimeError, match="must be importable"):
get_deployment_import_path(gen_func(), enforce_importable=True)
@serve.deployment
class InlineCls:
pass
with pytest.raises(RuntimeError, match="must be importable"):
get_deployment_import_path(InlineCls, enforce_importable=True)
with pytest.raises(RuntimeError, match="must be importable"):
get_deployment_import_path(gen_class(), enforce_importable=True)
def test_get_import_path_basic(self):
d = decorated_f.options()

View file

@ -255,7 +255,9 @@ def msgpack_serialize(obj):
return serialized
def get_deployment_import_path(deployment, replace_main=False):
def get_deployment_import_path(
deployment, replace_main=False, enforce_importable=False
):
"""
Gets the import path for deployment's func_or_class.
@ -275,8 +277,15 @@ def get_deployment_import_path(deployment, replace_main=False):
import_path = f"{body.__module__}.{body.__qualname__}"
if replace_main:
if enforce_importable and "<locals>" in body.__qualname__:
raise RuntimeError(
"Deployment definitions must be importable to build the Serve app, "
f"but deployment '{deployment.name}' is inline defined or returned "
"from another function. Please restructure your code so that "
f"'{import_path}' can be imported (i.e., put it in a module)."
)
if replace_main:
# Replaces __main__ with its file name. E.g. suppose the import path
# is __main__.classname and classname is defined in filename.py.
# Its import path becomes filename.classname.
@ -365,3 +374,11 @@ def require_packages(packages: List[str]):
return wrapped
return decorator
def in_interactive_shell():
# Taken from:
# https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
import __main__ as main
return not hasattr(main, "__file__")