diff --git a/LICENSE b/LICENSE index 2484261d5..c523e66a8 100644 --- a/LICENSE +++ b/LICENSE @@ -401,3 +401,26 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +-------------------------------------------------------------------------------- +Code in python/ray/_private/thirdparty/dacite is adapted from https://github.com/konradhalas/dacite/blob/master/dacite + +Copyright (c) 2018 Konrad Hałas + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/python/ray/_private/runtime_env/plugin_schema_manager.py b/python/ray/_private/runtime_env/plugin_schema_manager.py index 1ff52f0e2..fffee0551 100644 --- a/python/ray/_private/runtime_env/plugin_schema_manager.py +++ b/python/ray/_private/runtime_env/plugin_schema_manager.py @@ -14,7 +14,9 @@ logger = logging.getLogger(__name__) class RuntimeEnvPluginSchemaManager: """This manager is used to load plugin json schemas.""" - default_schema_path = os.path.join(os.path.dirname(__file__), "schemas") + default_schema_path = os.path.join( + os.path.dirname(__file__), "../../runtime_env/schemas" + ) schemas = {} loaded = False diff --git a/python/ray/_private/thirdparty/dacite/LICENSE b/python/ray/_private/thirdparty/dacite/LICENSE new file mode 100644 index 000000000..4be5be762 --- /dev/null +++ b/python/ray/_private/thirdparty/dacite/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Konrad Hałas + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/python/ray/_private/thirdparty/dacite/__init__.py b/python/ray/_private/thirdparty/dacite/__init__.py new file mode 100644 index 000000000..318cd9b60 --- /dev/null +++ b/python/ray/_private/thirdparty/dacite/__init__.py @@ -0,0 +1,3 @@ +from .config import Config +from .core import from_dict +from .exceptions import * diff --git a/python/ray/_private/thirdparty/dacite/config.py b/python/ray/_private/thirdparty/dacite/config.py new file mode 100644 index 000000000..1766a68e2 --- /dev/null +++ b/python/ray/_private/thirdparty/dacite/config.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass, field +from typing import Dict, Any, Callable, Optional, Type, List + + +@dataclass +class Config: + type_hooks: Dict[Type, Callable[[Any], Any]] = field(default_factory=dict) + cast: List[Type] = field(default_factory=list) + forward_references: Optional[Dict[str, Any]] = None + check_types: bool = True + strict: bool = False + strict_unions_match: bool = False diff --git a/python/ray/_private/thirdparty/dacite/core.py b/python/ray/_private/thirdparty/dacite/core.py new file mode 100644 index 000000000..eccfe5e96 --- /dev/null +++ b/python/ray/_private/thirdparty/dacite/core.py @@ -0,0 +1,140 @@ +import copy +from dataclasses import is_dataclass +from itertools import zip_longest +from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any + +from .config import Config +from .data import Data +from .dataclasses import get_default_value_for_field, create_instance, DefaultValueNotFoundError, get_fields +from .exceptions import ( + ForwardReferenceError, + WrongTypeError, + DaciteError, + UnionMatchError, + MissingValueError, + DaciteFieldError, + UnexpectedDataError, + StrictUnionMatchError, +) +from .types import ( + is_instance, + is_generic_collection, + is_union, + extract_generic, + is_optional, + transform_value, + extract_origin_collection, + is_init_var, + extract_init_var, +) + +T = TypeVar("T") + + +def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None) -> T: + """Create a data class instance from a dictionary. + + :param data_class: a data class type + :param data: a dictionary of a input data + :param config: a configuration of the creation process + :return: an instance of a data class + """ + init_values: Data = {} + post_init_values: Data = {} + config = config or Config() + try: + data_class_hints = get_type_hints(data_class, globalns=config.forward_references) + except NameError as error: + raise ForwardReferenceError(str(error)) + data_class_fields = get_fields(data_class) + if config.strict: + extra_fields = set(data.keys()) - {f.name for f in data_class_fields} + if extra_fields: + raise UnexpectedDataError(keys=extra_fields) + for field in data_class_fields: + field = copy.copy(field) + field.type = data_class_hints[field.name] + try: + try: + field_data = data[field.name] + transformed_value = transform_value( + type_hooks=config.type_hooks, cast=config.cast, target_type=field.type, value=field_data + ) + value = _build_value(type_=field.type, data=transformed_value, config=config) + except DaciteFieldError as error: + error.update_path(field.name) + raise + if config.check_types and not is_instance(value, field.type): + raise WrongTypeError(field_path=field.name, field_type=field.type, value=value) + except KeyError: + try: + value = get_default_value_for_field(field) + except DefaultValueNotFoundError: + if not field.init: + continue + raise MissingValueError(field.name) + if field.init: + init_values[field.name] = value + else: + post_init_values[field.name] = value + + return create_instance(data_class=data_class, init_values=init_values, post_init_values=post_init_values) + + +def _build_value(type_: Type, data: Any, config: Config) -> Any: + if is_init_var(type_): + type_ = extract_init_var(type_) + if is_union(type_): + return _build_value_for_union(union=type_, data=data, config=config) + elif is_generic_collection(type_) and is_instance(data, extract_origin_collection(type_)): + return _build_value_for_collection(collection=type_, data=data, config=config) + elif is_dataclass(type_) and is_instance(data, Data): + return from_dict(data_class=type_, data=data, config=config) + return data + + +def _build_value_for_union(union: Type, data: Any, config: Config) -> Any: + types = extract_generic(union) + if is_optional(union) and len(types) == 2: + return _build_value(type_=types[0], data=data, config=config) + union_matches = {} + for inner_type in types: + try: + # noinspection PyBroadException + try: + data = transform_value( + type_hooks=config.type_hooks, cast=config.cast, target_type=inner_type, value=data + ) + except Exception: # pylint: disable=broad-except + continue + value = _build_value(type_=inner_type, data=data, config=config) + if is_instance(value, inner_type): + if config.strict_unions_match: + union_matches[inner_type] = value + else: + return value + except DaciteError: + pass + if config.strict_unions_match: + if len(union_matches) > 1: + raise StrictUnionMatchError(union_matches) + return union_matches.popitem()[1] + if not config.check_types: + return data + raise UnionMatchError(field_type=union, value=data) + + +def _build_value_for_collection(collection: Type, data: Any, config: Config) -> Any: + data_type = data.__class__ + if is_instance(data, Mapping): + item_type = extract_generic(collection, defaults=(Any, Any))[1] + return data_type((key, _build_value(type_=item_type, data=value, config=config)) for key, value in data.items()) + elif is_instance(data, tuple): + types = extract_generic(collection) + if len(types) == 2 and types[1] == Ellipsis: + return data_type(_build_value(type_=types[0], data=item, config=config) for item in data) + return data_type( + _build_value(type_=type_, data=item, config=config) for item, type_ in zip_longest(data, types) + ) + item_type = extract_generic(collection, defaults=(Any,))[0] + return data_type(_build_value(type_=item_type, data=item, config=config) for item in data) diff --git a/python/ray/_private/thirdparty/dacite/data.py b/python/ray/_private/thirdparty/dacite/data.py new file mode 100644 index 000000000..560cdf8be --- /dev/null +++ b/python/ray/_private/thirdparty/dacite/data.py @@ -0,0 +1,3 @@ +from typing import Dict, Any + +Data = Dict[str, Any] diff --git a/python/ray/_private/thirdparty/dacite/dataclasses.py b/python/ray/_private/thirdparty/dacite/dataclasses.py new file mode 100644 index 000000000..b5db7ac01 --- /dev/null +++ b/python/ray/_private/thirdparty/dacite/dataclasses.py @@ -0,0 +1,33 @@ +from dataclasses import Field, MISSING, _FIELDS, _FIELD, _FIELD_INITVAR # type: ignore +from typing import Type, Any, TypeVar, List + +from .data import Data +from .types import is_optional + +T = TypeVar("T", bound=Any) + + +class DefaultValueNotFoundError(Exception): + pass + + +def get_default_value_for_field(field: Field) -> Any: + if field.default != MISSING: + return field.default + elif field.default_factory != MISSING: # type: ignore + return field.default_factory() # type: ignore + elif is_optional(field.type): + return None + raise DefaultValueNotFoundError() + + +def create_instance(data_class: Type[T], init_values: Data, post_init_values: Data) -> T: + instance = data_class(**init_values) + for key, value in post_init_values.items(): + setattr(instance, key, value) + return instance + + +def get_fields(data_class: Type[T]) -> List[Field]: + fields = getattr(data_class, _FIELDS) + return [f for f in fields.values() if f._field_type is _FIELD or f._field_type is _FIELD_INITVAR] diff --git a/python/ray/_private/thirdparty/dacite/exceptions.py b/python/ray/_private/thirdparty/dacite/exceptions.py new file mode 100644 index 000000000..871e73f1c --- /dev/null +++ b/python/ray/_private/thirdparty/dacite/exceptions.py @@ -0,0 +1,79 @@ +from typing import Any, Type, Optional, Set, Dict + + +def _name(type_: Type) -> str: + return type_.__name__ if hasattr(type_, "__name__") else str(type_) + + +class DaciteError(Exception): + pass + + +class DaciteFieldError(DaciteError): + def __init__(self, field_path: Optional[str] = None): + super().__init__() + self.field_path = field_path + + def update_path(self, parent_field_path: str) -> None: + if self.field_path: + self.field_path = f"{parent_field_path}.{self.field_path}" + else: + self.field_path = parent_field_path + + +class WrongTypeError(DaciteFieldError): + def __init__(self, field_type: Type, value: Any, field_path: Optional[str] = None) -> None: + super().__init__(field_path=field_path) + self.field_type = field_type + self.value = value + + def __str__(self) -> str: + return ( + f'wrong value type for field "{self.field_path}" - should be "{_name(self.field_type)}" ' + f'instead of value "{self.value}" of type "{_name(type(self.value))}"' + ) + + +class MissingValueError(DaciteFieldError): + def __init__(self, field_path: Optional[str] = None): + super().__init__(field_path=field_path) + + def __str__(self) -> str: + return f'missing value for field "{self.field_path}"' + + +class UnionMatchError(WrongTypeError): + def __str__(self) -> str: + return ( + f'can not match type "{_name(type(self.value))}" to any type ' + f'of "{self.field_path}" union: {_name(self.field_type)}' + ) + + +class StrictUnionMatchError(DaciteFieldError): + def __init__(self, union_matches: Dict[Type, Any], field_path: Optional[str] = None) -> None: + super().__init__(field_path=field_path) + self.union_matches = union_matches + + def __str__(self) -> str: + conflicting_types = ", ".join(_name(type_) for type_ in self.union_matches) + return f'can not choose between possible Union matches for field "{self.field_path}": {conflicting_types}' + + +class ForwardReferenceError(DaciteError): + def __init__(self, message: str) -> None: + super().__init__() + self.message = message + + def __str__(self) -> str: + return f"can not resolve forward reference: {self.message}" + + +class UnexpectedDataError(DaciteError): + def __init__(self, keys: Set[str]) -> None: + super().__init__() + self.keys = keys + + def __str__(self) -> str: + formatted_keys = ", ".join(f'"{key}"' for key in self.keys) + return f"can not match {formatted_keys} to any data class field" diff --git a/python/ray/_private/thirdparty/dacite/py.typed b/python/ray/_private/thirdparty/dacite/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/_private/thirdparty/dacite/types.py b/python/ray/_private/thirdparty/dacite/types.py new file mode 100644 index 000000000..1d4dfea4f --- /dev/null +++ b/python/ray/_private/thirdparty/dacite/types.py @@ -0,0 +1,172 @@ +from dataclasses import InitVar +from typing import Type, Any, Optional, Union, Collection, TypeVar, Dict, Callable, Mapping, List, Tuple + +T = TypeVar("T", bound=Any) + + +def transform_value( + type_hooks: Dict[Type, Callable[[Any], Any]], cast: List[Type], target_type: Type, value: Any +) -> Any: + if target_type in type_hooks: + value = type_hooks[target_type](value) + else: + for cast_type in cast: + if is_subclass(target_type, cast_type): + if is_generic_collection(target_type): + value = extract_origin_collection(target_type)(value) + else: + value = target_type(value) + break + if is_optional(target_type): + if value is None: + return None + target_type = extract_optional(target_type) + return transform_value(type_hooks, cast, target_type, value) + if is_generic_collection(target_type) and isinstance(value, extract_origin_collection(target_type)): + collection_cls = value.__class__ + if issubclass(collection_cls, dict): + key_cls, item_cls = extract_generic(target_type, defaults=(Any, Any)) + return collection_cls( + { + transform_value(type_hooks, cast, key_cls, key): transform_value(type_hooks, cast, item_cls, item) + for key, item in value.items() + } + ) + item_cls = extract_generic(target_type, defaults=(Any,))[0] + return collection_cls(transform_value(type_hooks, cast, item_cls, item) for item in value) + return value + + +def extract_origin_collection(collection: Type) -> Type: + try: + return collection.__extra__ + except AttributeError: + return collection.__origin__ + + +def is_optional(type_: Type) -> bool: + return is_union(type_) and type(None) in extract_generic(type_) + + +def extract_optional(optional: Type[Optional[T]]) -> T: + for type_ in extract_generic(optional): + if type_ is not type(None): + return type_ + raise ValueError("can not find not-none value") + + +def is_generic(type_: Type) -> bool: + return hasattr(type_, "__origin__") + + +def is_union(type_: Type) -> bool: + return is_generic(type_) and type_.__origin__ == Union + + +def is_literal(type_: Type) -> bool: + try: + from typing import Literal # type: ignore + + return is_generic(type_) and type_.__origin__ == Literal + except ImportError: + return False + + +def is_new_type(type_: Type) -> bool: + return hasattr(type_, "__supertype__") + + +def extract_new_type(type_: Type) -> Type: + return type_.__supertype__ + + +def is_init_var(type_: Type) -> bool: + return isinstance(type_, InitVar) or type_ is InitVar + + +def extract_init_var(type_: Type) -> Union[Type, Any]: + try: + return type_.type + except AttributeError: + return Any + + +def is_instance(value: Any, type_: Type) -> bool: + if type_ == Any: + return True + elif is_union(type_): + return any(is_instance(value, t) for t in extract_generic(type_)) + elif is_generic_collection(type_): + origin = extract_origin_collection(type_) + if not isinstance(value, origin): + return False + if not extract_generic(type_): + return True + if isinstance(value, tuple): + tuple_types = extract_generic(type_) + if len(tuple_types) == 1 and tuple_types[0] == (): + return len(value) == 0 + elif len(tuple_types) == 2 and tuple_types[1] is ...: + return all(is_instance(item, tuple_types[0]) for item in value) + else: + if len(tuple_types) != len(value): + return False + return all(is_instance(item, item_type) for item, item_type in zip(value, tuple_types)) + if isinstance(value, Mapping): + key_type, val_type = extract_generic(type_, defaults=(Any, Any)) + for key, val in value.items(): + if not is_instance(key, key_type) or not is_instance(val, val_type): + return False + return True + return all(is_instance(item, extract_generic(type_, defaults=(Any,))[0]) for item in value) + elif is_new_type(type_): + return is_instance(value, extract_new_type(type_)) + elif is_literal(type_): + return value in extract_generic(type_) + elif is_init_var(type_): + return is_instance(value, extract_init_var(type_)) + elif is_type_generic(type_): + return is_subclass(value, extract_generic(type_)[0]) + else: + try: + # As described in PEP 484 - section: "The numeric tower" + if isinstance(value, (int, float)) and type_ in [float, complex]: + return True + return isinstance(value, type_) + except TypeError: + return False + + +def is_generic_collection(type_: Type) -> bool: + if not is_generic(type_): + return False + origin = extract_origin_collection(type_) + try: + return bool(origin and issubclass(origin, Collection)) + except (TypeError, AttributeError): + return False + + +def extract_generic(type_: Type, defaults: Tuple = ()) -> tuple: + try: + if hasattr(type_, "_special") and type_._special: + return defaults + return type_.__args__ or defaults # type: ignore + except AttributeError: + return defaults + + +def is_subclass(sub_type: Type, base_type: Type) -> bool: + if is_generic_collection(sub_type): + sub_type = extract_origin_collection(sub_type) + try: + return issubclass(sub_type, base_type) + except TypeError: + return False + + +def is_type_generic(type_: Type) -> bool: + try: + return type_.__origin__ in (type, Type) + except AttributeError: + return False diff --git a/python/ray/runtime_env/__init__.py b/python/ray/runtime_env/__init__.py new file mode 100644 index 000000000..f3cd30f70 --- /dev/null +++ b/python/ray/runtime_env/__init__.py @@ -0,0 +1,6 @@ +from ray.runtime_env.runtime_env import RuntimeEnv, RuntimeEnvConfig # noqa: E402,F401 + +__all__ = [ + "RuntimeEnvConfig", + "RuntimeEnv", +] diff --git a/python/ray/runtime_env.py b/python/ray/runtime_env/runtime_env.py similarity index 97% rename from python/ray/runtime_env.py rename to python/ray/runtime_env/runtime_env.py index b7a899750..fa41725a7 100644 --- a/python/ray/runtime_env.py +++ b/python/ray/runtime_env/runtime_env.py @@ -2,6 +2,7 @@ import json import logging import os from copy import deepcopy +from dataclasses import asdict, is_dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Union from google.protobuf import json_format @@ -12,6 +13,7 @@ from ray._private.runtime_env.conda import get_uri as get_conda_uri from ray._private.runtime_env.pip import get_uri as get_pip_uri from ray._private.runtime_env.plugin_schema_manager import RuntimeEnvPluginSchemaManager from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN +from ray._private.thirdparty.dacite import from_dict from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv from ray.core.generated.runtime_env_common_pb2 import ( RuntimeEnvConfig as ProtoRuntimeEnvConfig, @@ -431,10 +433,14 @@ class RuntimeEnv(dict): return plugin_uris def __setitem__(self, key: str, value: Any) -> None: - res_value = value - RuntimeEnvPluginSchemaManager.validate(key, res_value) + if is_dataclass(value): + jsonable_type = asdict(value) + else: + jsonable_type = value + RuntimeEnvPluginSchemaManager.validate(key, jsonable_type) + res_value = jsonable_type if key in RuntimeEnv.known_fields and key in OPTION_TO_VALIDATION_FN: - res_value = OPTION_TO_VALIDATION_FN[key](value) + res_value = OPTION_TO_VALIDATION_FN[key](jsonable_type) if res_value is None: return return super().__setitem__(key, res_value) @@ -442,6 +448,14 @@ class RuntimeEnv(dict): def set(self, name: str, value: Any) -> None: self.__setitem__(name, value) + def get(self, name, default=None, data_class=None): + if name not in self: + return default + if not data_class: + return self.__getitem__(name) + else: + return from_dict(data_class=data_class, data=self.__getitem__(name)) + @classmethod def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv": # noqa: F821 proto_runtime_env = json_format.Parse(serialized_runtime_env, ProtoRuntimeEnv()) diff --git a/python/ray/_private/runtime_env/schemas/pip_schema.json b/python/ray/runtime_env/schemas/pip_schema.json similarity index 100% rename from python/ray/_private/runtime_env/schemas/pip_schema.json rename to python/ray/runtime_env/schemas/pip_schema.json diff --git a/python/ray/_private/runtime_env/schemas/working_dir_schema.json b/python/ray/runtime_env/schemas/working_dir_schema.json similarity index 100% rename from python/ray/_private/runtime_env/schemas/working_dir_schema.json rename to python/ray/runtime_env/schemas/working_dir_schema.json diff --git a/python/ray/runtime_env/types/pip.py b/python/ray/runtime_env/types/pip.py new file mode 100644 index 000000000..dd44acd5a --- /dev/null +++ b/python/ray/runtime_env/types/pip.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass +from typing import List + + +@dataclass +class Pip: + packages: List[str] + pip_check: bool = False diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 54a77782f..9a8bb02c3 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -122,6 +122,7 @@ py_test_module_list( "test_runtime_env_env_vars.py", "test_runtime_env_packaging.py", "test_runtime_env_plugin.py", + "test_runtime_env_strong_type.py", "test_runtime_env_fork_process.py", "test_serialization.py", "test_shuffle.py", diff --git a/python/ray/tests/test_runtime_env_strong_type.py b/python/ray/tests/test_runtime_env_strong_type.py new file mode 100644 index 000000000..7ccfff521 --- /dev/null +++ b/python/ray/tests/test_runtime_env_strong_type.py @@ -0,0 +1,70 @@ +import os +import sys +import pytest +import ray + +from typing import List +from ray.runtime_env import RuntimeEnv +from ray.runtime_env.types.pip import Pip +from dataclasses import dataclass + + +@dataclass +class ValueType: + nfield1: List[str] + nfield2: bool + + +@dataclass +class TestPlugin: + field1: List[ValueType] + field2: str + + +def test_convert_from_and_to_dataclass(): + runtime_env = RuntimeEnv() + test_plugin = TestPlugin( + field1=[ + ValueType(nfield1=["a", "b", "c"], nfield2=False), + ValueType(nfield1=["d", "e"], nfield2=True), + ], + field2="abc", + ) + runtime_env.set("test_plugin", test_plugin) + serialized_runtime_env = runtime_env.serialize() + assert "test_plugin" in serialized_runtime_env + runtime_env_2 = RuntimeEnv.deserialize(serialized_runtime_env) + test_plugin_2 = runtime_env_2.get("test_plugin", data_class=TestPlugin) + assert len(test_plugin_2.field1) == 2 + assert test_plugin_2.field1[0].nfield1 == ["a", "b", "c"] + assert test_plugin_2.field1[0].nfield2 is False + assert test_plugin_2.field1[1].nfield1 == ["d", "e"] + assert test_plugin_2.field1[1].nfield2 is True + assert test_plugin_2.field2 == "abc" + + +def test_pip(start_cluster): + cluster, address = start_cluster + ray.init(address) + + runtime_env = RuntimeEnv() + pip = Pip(packages=["pip-install-test==0.5"]) + runtime_env.set("pip", pip) + + @ray.remote + class Actor: + def foo(self): + import pip_install_test # noqa + + return "hello" + + a = Actor.options(runtime_env=runtime_env).remote() + assert ray.get(a.foo.remote()) == "hello" + + +if __name__ == "__main__": + + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__]))