mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00
[runtime env] plugin refactor[3/n]: support strong type by @dataclass (#26296)
This commit is contained in:
parent b3878e26d7
commit 781c2a7834
18 changed files with 591 additions and 4 deletions
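
For context, the net effect of this change: a runtime env plugin value can now be declared as a plain @dataclass, set on a RuntimeEnv, serialized, and read back with its types checked. A minimal sketch of the intended usage, modeled on the test added in this commit (the "test_plugin" key and the field names are illustrative, not a real plugin):

    from dataclasses import dataclass
    from typing import List

    from ray.runtime_env import RuntimeEnv

    @dataclass
    class TestPlugin:  # hypothetical plugin payload
        field1: List[str]
        field2: str

    runtime_env = RuntimeEnv()
    runtime_env.set("test_plugin", TestPlugin(field1=["a", "b"], field2="abc"))
    serialized = runtime_env.serialize()
    # get(..., data_class=...) rebuilds the stored dict into a typed dataclass.
    restored = RuntimeEnv.deserialize(serialized).get("test_plugin", data_class=TestPlugin)
    assert restored.field1 == ["a", "b"] and restored.field2 == "abc"
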
23 LICENSE
@@ -401,3 +401,26 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
+
+--------------------------------------------------------------------------------
+Code in python/ray/_private/thirdparty/dacite is adapted from https://github.com/konradhalas/dacite/blob/master/dacite
+
+Copyright (c) 2018 Konrad Hałas
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
python/ray/_private/runtime_env/plugin_schema_manager.py
@@ -14,7 +14,9 @@ logger = logging.getLogger(__name__)
 class RuntimeEnvPluginSchemaManager:
     """This manager is used to load plugin json schemas."""
 
-    default_schema_path = os.path.join(os.path.dirname(__file__), "schemas")
+    default_schema_path = os.path.join(
+        os.path.dirname(__file__), "../../runtime_env/schemas"
+    )
     schemas = {}
     loaded = False
 
21 python/ray/_private/thirdparty/dacite/LICENSE vendored Normal file
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 Konrad Hałas
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
3 python/ray/_private/thirdparty/dacite/__init__.py vendored Normal file
@@ -0,0 +1,3 @@
+from .config import Config
+from .core import from_dict
+from .exceptions import *
12 python/ray/_private/thirdparty/dacite/config.py vendored Normal file
@@ -0,0 +1,12 @@
+from dataclasses import dataclass, field
+from typing import Dict, Any, Callable, Optional, Type, List
+
+
+@dataclass
+class Config:
+    type_hooks: Dict[Type, Callable[[Any], Any]] = field(default_factory=dict)
+    cast: List[Type] = field(default_factory=list)
+    forward_references: Optional[Dict[str, Any]] = None
+    check_types: bool = True
+    strict: bool = False
+    strict_unions_match: bool = False
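
Config drives how the vendored from_dict coerces and checks values: type_hooks pre-processes raw values for a given target type, cast coerces by construction, and check_types/strict toggle validation. A small sketch of type_hooks, assuming direct use of the vendored package (the Event dataclass is illustrative):

    from dataclasses import dataclass
    from datetime import datetime

    from ray._private.thirdparty.dacite import Config, from_dict

    @dataclass
    class Event:
        name: str
        created_at: datetime

    # The hook runs before type checking, so an ISO string becomes a datetime.
    event = from_dict(
        data_class=Event,
        data={"name": "sync", "created_at": "2022-07-07T00:00:00"},
        config=Config(type_hooks={datetime: datetime.fromisoformat}),
    )
    assert isinstance(event.created_at, datetime)
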
140 python/ray/_private/thirdparty/dacite/core.py vendored Normal file
@@ -0,0 +1,140 @@
+import copy
+from dataclasses import is_dataclass
+from itertools import zip_longest
+from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any
+
+from .config import Config
+from .data import Data
+from .dataclasses import get_default_value_for_field, create_instance, DefaultValueNotFoundError, get_fields
+from .exceptions import (
+    ForwardReferenceError,
+    WrongTypeError,
+    DaciteError,
+    UnionMatchError,
+    MissingValueError,
+    DaciteFieldError,
+    UnexpectedDataError,
+    StrictUnionMatchError,
+)
+from .types import (
+    is_instance,
+    is_generic_collection,
+    is_union,
+    extract_generic,
+    is_optional,
+    transform_value,
+    extract_origin_collection,
+    is_init_var,
+    extract_init_var,
+)
+
+T = TypeVar("T")
+
+
+def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None) -> T:
+    """Create a data class instance from a dictionary.
+
+    :param data_class: a data class type
+    :param data: a dictionary of a input data
+    :param config: a configuration of the creation process
+    :return: an instance of a data class
+    """
+    init_values: Data = {}
+    post_init_values: Data = {}
+    config = config or Config()
+    try:
+        data_class_hints = get_type_hints(data_class, globalns=config.forward_references)
+    except NameError as error:
+        raise ForwardReferenceError(str(error))
+    data_class_fields = get_fields(data_class)
+    if config.strict:
+        extra_fields = set(data.keys()) - {f.name for f in data_class_fields}
+        if extra_fields:
+            raise UnexpectedDataError(keys=extra_fields)
+    for field in data_class_fields:
+        field = copy.copy(field)
+        field.type = data_class_hints[field.name]
+        try:
+            try:
+                field_data = data[field.name]
+                transformed_value = transform_value(
+                    type_hooks=config.type_hooks, cast=config.cast, target_type=field.type, value=field_data
+                )
+                value = _build_value(type_=field.type, data=transformed_value, config=config)
+            except DaciteFieldError as error:
+                error.update_path(field.name)
+                raise
+            if config.check_types and not is_instance(value, field.type):
+                raise WrongTypeError(field_path=field.name, field_type=field.type, value=value)
+        except KeyError:
+            try:
+                value = get_default_value_for_field(field)
+            except DefaultValueNotFoundError:
+                if not field.init:
+                    continue
+                raise MissingValueError(field.name)
+        if field.init:
+            init_values[field.name] = value
+        else:
+            post_init_values[field.name] = value
+
+    return create_instance(data_class=data_class, init_values=init_values, post_init_values=post_init_values)
+
+
+def _build_value(type_: Type, data: Any, config: Config) -> Any:
+    if is_init_var(type_):
+        type_ = extract_init_var(type_)
+    if is_union(type_):
+        return _build_value_for_union(union=type_, data=data, config=config)
+    elif is_generic_collection(type_) and is_instance(data, extract_origin_collection(type_)):
+        return _build_value_for_collection(collection=type_, data=data, config=config)
+    elif is_dataclass(type_) and is_instance(data, Data):
+        return from_dict(data_class=type_, data=data, config=config)
+    return data
+
+
+def _build_value_for_union(union: Type, data: Any, config: Config) -> Any:
+    types = extract_generic(union)
+    if is_optional(union) and len(types) == 2:
+        return _build_value(type_=types[0], data=data, config=config)
+    union_matches = {}
+    for inner_type in types:
+        try:
+            # noinspection PyBroadException
+            try:
+                data = transform_value(
+                    type_hooks=config.type_hooks, cast=config.cast, target_type=inner_type, value=data
+                )
+            except Exception:  # pylint: disable=broad-except
+                continue
+            value = _build_value(type_=inner_type, data=data, config=config)
+            if is_instance(value, inner_type):
+                if config.strict_unions_match:
+                    union_matches[inner_type] = value
+                else:
+                    return value
+        except DaciteError:
+            pass
+    if config.strict_unions_match:
+        if len(union_matches) > 1:
+            raise StrictUnionMatchError(union_matches)
+        return union_matches.popitem()[1]
+    if not config.check_types:
+        return data
+    raise UnionMatchError(field_type=union, value=data)
+
+
+def _build_value_for_collection(collection: Type, data: Any, config: Config) -> Any:
+    data_type = data.__class__
+    if is_instance(data, Mapping):
+        item_type = extract_generic(collection, defaults=(Any, Any))[1]
+        return data_type((key, _build_value(type_=item_type, data=value, config=config)) for key, value in data.items())
+    elif is_instance(data, tuple):
+        types = extract_generic(collection)
+        if len(types) == 2 and types[1] == Ellipsis:
+            return data_type(_build_value(type_=types[0], data=item, config=config) for item in data)
+        return data_type(
+            _build_value(type_=type_, data=item, config=config) for item, type_ in zip_longest(data, types)
+        )
+    item_type = extract_generic(collection, defaults=(Any,))[0]
+    return data_type(_build_value(type_=item_type, data=item, config=config) for item in data)
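
from_dict above recurses: union members are tried in order, generic collections rebuild their items, and nested dicts that match a dataclass field are converted by a recursive from_dict call. A hedged sketch of that nesting behavior (the Address/Person dataclasses are illustrative):

    from dataclasses import dataclass
    from typing import List, Optional

    from ray._private.thirdparty.dacite import from_dict

    @dataclass
    class Address:
        city: str
        zip_code: Optional[str]

    @dataclass
    class Person:
        name: str
        addresses: List[Address]

    # The inner dicts are rebuilt into Address instances recursively.
    person = from_dict(
        data_class=Person,
        data={"name": "Ada", "addresses": [{"city": "Paris", "zip_code": None}]},
    )
    assert isinstance(person.addresses[0], Address)
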
3 python/ray/_private/thirdparty/dacite/data.py vendored Normal file
@@ -0,0 +1,3 @@
+from typing import Dict, Any
+
+Data = Dict[str, Any]
33 python/ray/_private/thirdparty/dacite/dataclasses.py vendored Normal file
@@ -0,0 +1,33 @@
+from dataclasses import Field, MISSING, _FIELDS, _FIELD, _FIELD_INITVAR  # type: ignore
+from typing import Type, Any, TypeVar, List
+
+from .data import Data
+from .types import is_optional
+
+T = TypeVar("T", bound=Any)
+
+
+class DefaultValueNotFoundError(Exception):
+    pass
+
+
+def get_default_value_for_field(field: Field) -> Any:
+    if field.default != MISSING:
+        return field.default
+    elif field.default_factory != MISSING:  # type: ignore
+        return field.default_factory()  # type: ignore
+    elif is_optional(field.type):
+        return None
+    raise DefaultValueNotFoundError()
+
+
+def create_instance(data_class: Type[T], init_values: Data, post_init_values: Data) -> T:
+    instance = data_class(**init_values)
+    for key, value in post_init_values.items():
+        setattr(instance, key, value)
+    return instance
+
+
+def get_fields(data_class: Type[T]) -> List[Field]:
+    fields = getattr(data_class, _FIELDS)
+    return [f for f in fields.values() if f._field_type is _FIELD or f._field_type is _FIELD_INITVAR]
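
get_default_value_for_field above gives from_dict its fallback order for missing keys: an explicit default, then default_factory, then None for Optional fields, otherwise a MissingValueError. A short sketch (the Settings dataclass is illustrative):

    from dataclasses import dataclass, field
    from typing import List, Optional

    from ray._private.thirdparty.dacite import from_dict

    @dataclass
    class Settings:
        note: Optional[str]                            # no default: Optional falls back to None
        retries: int = 3                               # plain default
        tags: List[str] = field(default_factory=list)  # default_factory

    # Every key is missing, so each field resolves through its fallback.
    settings = from_dict(data_class=Settings, data={})
    assert settings.note is None and settings.retries == 3 and settings.tags == []
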
79 python/ray/_private/thirdparty/dacite/exceptions.py vendored Normal file
@@ -0,0 +1,79 @@
+from typing import Any, Type, Optional, Set, Dict
+
+
+def _name(type_: Type) -> str:
+    return type_.__name__ if hasattr(type_, "__name__") else str(type_)
+
+
+class DaciteError(Exception):
+    pass
+
+
+class DaciteFieldError(DaciteError):
+    def __init__(self, field_path: Optional[str] = None):
+        super().__init__()
+        self.field_path = field_path
+
+    def update_path(self, parent_field_path: str) -> None:
+        if self.field_path:
+            self.field_path = f"{parent_field_path}.{self.field_path}"
+        else:
+            self.field_path = parent_field_path
+
+
+class WrongTypeError(DaciteFieldError):
+    def __init__(self, field_type: Type, value: Any, field_path: Optional[str] = None) -> None:
+        super().__init__(field_path=field_path)
+        self.field_type = field_type
+        self.value = value
+
+    def __str__(self) -> str:
+        return (
+            f'wrong value type for field "{self.field_path}" - should be "{_name(self.field_type)}" '
+            f'instead of value "{self.value}" of type "{_name(type(self.value))}"'
+        )
+
+
+class MissingValueError(DaciteFieldError):
+    def __init__(self, field_path: Optional[str] = None):
+        super().__init__(field_path=field_path)
+
+    def __str__(self) -> str:
+        return f'missing value for field "{self.field_path}"'
+
+
+class UnionMatchError(WrongTypeError):
+    def __str__(self) -> str:
+        return (
+            f'can not match type "{_name(type(self.value))}" to any type '
+            f'of "{self.field_path}" union: {_name(self.field_type)}'
+        )
+
+
+class StrictUnionMatchError(DaciteFieldError):
+    def __init__(self, union_matches: Dict[Type, Any], field_path: Optional[str] = None) -> None:
+        super().__init__(field_path=field_path)
+        self.union_matches = union_matches
+
+    def __str__(self) -> str:
+        conflicting_types = ", ".join(_name(type_) for type_ in self.union_matches)
+        return f'can not choose between possible Union matches for field "{self.field_path}": {conflicting_types}'
+
+
+class ForwardReferenceError(DaciteError):
+    def __init__(self, message: str) -> None:
+        super().__init__()
+        self.message = message
+
+    def __str__(self) -> str:
+        return f"can not resolve forward reference: {self.message}"
+
+
+class UnexpectedDataError(DaciteError):
+    def __init__(self, keys: Set[str]) -> None:
+        super().__init__()
+        self.keys = keys
+
+    def __str__(self) -> str:
+        formatted_keys = ", ".join(f'"{key}"' for key in self.keys)
+        return f"can not match {formatted_keys} to any data class field"
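
These exceptions carry a dotted field_path that update_path extends with the parent field name each time an error bubbles out of a nested from_dict call, so a type mismatch deep in a nested dataclass is reported with its full path. A hedged sketch (Inner/Outer are illustrative):

    from dataclasses import dataclass

    from ray._private.thirdparty.dacite import DaciteError, from_dict

    @dataclass
    class Inner:
        count: int

    @dataclass
    class Outer:
        inner: Inner

    try:
        from_dict(data_class=Outer, data={"inner": {"count": "oops"}})
    except DaciteError as error:
        # update_path() prepended "inner", so the message points at "inner.count".
        print(error)
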
0 python/ray/_private/thirdparty/dacite/py.typed vendored Normal file
172 python/ray/_private/thirdparty/dacite/types.py vendored Normal file
@@ -0,0 +1,172 @@
+from dataclasses import InitVar
+from typing import Type, Any, Optional, Union, Collection, TypeVar, Dict, Callable, Mapping, List, Tuple
+
+T = TypeVar("T", bound=Any)
+
+
+def transform_value(
+    type_hooks: Dict[Type, Callable[[Any], Any]], cast: List[Type], target_type: Type, value: Any
+) -> Any:
+    if target_type in type_hooks:
+        value = type_hooks[target_type](value)
+    else:
+        for cast_type in cast:
+            if is_subclass(target_type, cast_type):
+                if is_generic_collection(target_type):
+                    value = extract_origin_collection(target_type)(value)
+                else:
+                    value = target_type(value)
+                break
+    if is_optional(target_type):
+        if value is None:
+            return None
+        target_type = extract_optional(target_type)
+        return transform_value(type_hooks, cast, target_type, value)
+    if is_generic_collection(target_type) and isinstance(value, extract_origin_collection(target_type)):
+        collection_cls = value.__class__
+        if issubclass(collection_cls, dict):
+            key_cls, item_cls = extract_generic(target_type, defaults=(Any, Any))
+            return collection_cls(
+                {
+                    transform_value(type_hooks, cast, key_cls, key): transform_value(type_hooks, cast, item_cls, item)
+                    for key, item in value.items()
+                }
+            )
+        item_cls = extract_generic(target_type, defaults=(Any,))[0]
+        return collection_cls(transform_value(type_hooks, cast, item_cls, item) for item in value)
+    return value
+
+
+def extract_origin_collection(collection: Type) -> Type:
+    try:
+        return collection.__extra__
+    except AttributeError:
+        return collection.__origin__
+
+
+def is_optional(type_: Type) -> bool:
+    return is_union(type_) and type(None) in extract_generic(type_)
+
+
+def extract_optional(optional: Type[Optional[T]]) -> T:
+    for type_ in extract_generic(optional):
+        if type_ is not type(None):
+            return type_
+    raise ValueError("can not find not-none value")
+
+
+def is_generic(type_: Type) -> bool:
+    return hasattr(type_, "__origin__")
+
+
+def is_union(type_: Type) -> bool:
+    return is_generic(type_) and type_.__origin__ == Union
+
+
+def is_literal(type_: Type) -> bool:
+    try:
+        from typing import Literal  # type: ignore
+
+        return is_generic(type_) and type_.__origin__ == Literal
+    except ImportError:
+        return False
+
+
+def is_new_type(type_: Type) -> bool:
+    return hasattr(type_, "__supertype__")
+
+
+def extract_new_type(type_: Type) -> Type:
+    return type_.__supertype__
+
+
+def is_init_var(type_: Type) -> bool:
+    return isinstance(type_, InitVar) or type_ is InitVar
+
+
+def extract_init_var(type_: Type) -> Union[Type, Any]:
+    try:
+        return type_.type
+    except AttributeError:
+        return Any
+
+
+def is_instance(value: Any, type_: Type) -> bool:
+    if type_ == Any:
+        return True
+    elif is_union(type_):
+        return any(is_instance(value, t) for t in extract_generic(type_))
+    elif is_generic_collection(type_):
+        origin = extract_origin_collection(type_)
+        if not isinstance(value, origin):
+            return False
+        if not extract_generic(type_):
+            return True
+        if isinstance(value, tuple):
+            tuple_types = extract_generic(type_)
+            if len(tuple_types) == 1 and tuple_types[0] == ():
+                return len(value) == 0
+            elif len(tuple_types) == 2 and tuple_types[1] is ...:
+                return all(is_instance(item, tuple_types[0]) for item in value)
+            else:
+                if len(tuple_types) != len(value):
+                    return False
+                return all(is_instance(item, item_type) for item, item_type in zip(value, tuple_types))
+        if isinstance(value, Mapping):
+            key_type, val_type = extract_generic(type_, defaults=(Any, Any))
+            for key, val in value.items():
+                if not is_instance(key, key_type) or not is_instance(val, val_type):
+                    return False
+            return True
+        return all(is_instance(item, extract_generic(type_, defaults=(Any,))[0]) for item in value)
+    elif is_new_type(type_):
+        return is_instance(value, extract_new_type(type_))
+    elif is_literal(type_):
+        return value in extract_generic(type_)
+    elif is_init_var(type_):
+        return is_instance(value, extract_init_var(type_))
+    elif is_type_generic(type_):
+        return is_subclass(value, extract_generic(type_)[0])
+    else:
+        try:
+            # As described in PEP 484 - section: "The numeric tower"
+            if isinstance(value, (int, float)) and type_ in [float, complex]:
+                return True
+            return isinstance(value, type_)
+        except TypeError:
+            return False
+
+
+def is_generic_collection(type_: Type) -> bool:
+    if not is_generic(type_):
+        return False
+    origin = extract_origin_collection(type_)
+    try:
+        return bool(origin and issubclass(origin, Collection))
+    except (TypeError, AttributeError):
+        return False
+
+
+def extract_generic(type_: Type, defaults: Tuple = ()) -> tuple:
+    try:
+        if hasattr(type_, "_special") and type_._special:
+            return defaults
+        return type_.__args__ or defaults  # type: ignore
+    except AttributeError:
+        return defaults
+
+
+def is_subclass(sub_type: Type, base_type: Type) -> bool:
+    if is_generic_collection(sub_type):
+        sub_type = extract_origin_collection(sub_type)
+    try:
+        return issubclass(sub_type, base_type)
+    except TypeError:
+        return False
+
+
+def is_type_generic(type_: Type) -> bool:
+    try:
+        return type_.__origin__ in (type, Type)
+    except AttributeError:
+        return False
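
The helpers above implement runtime checks against typing generics, which CPython's plain isinstance cannot do directly. A few illustrative checks, importing the vendored module directly:

    from typing import Dict, List, Optional, Union

    from ray._private.thirdparty.dacite.types import is_instance, is_optional

    # Generic collections are checked item by item, recursively.
    assert is_instance([1, 2], List[int])
    assert not is_instance([1, "a"], List[int])
    assert is_instance({"k": [1]}, Dict[str, List[int]])
    # Unions match if any member matches; Optional is a Union with None.
    assert is_instance(3, Union[int, str])
    assert is_optional(Optional[int])
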
6 python/ray/runtime_env/__init__.py Normal file
@@ -0,0 +1,6 @@
+from ray.runtime_env.runtime_env import RuntimeEnv, RuntimeEnvConfig  # noqa: E402,F401
+
+__all__ = [
+    "RuntimeEnvConfig",
+    "RuntimeEnv",
+]
python/ray/runtime_env/runtime_env.py
@@ -2,6 +2,7 @@ import json
 import logging
 import os
 from copy import deepcopy
+from dataclasses import asdict, is_dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 from google.protobuf import json_format
@@ -12,6 +13,7 @@ from ray._private.runtime_env.conda import get_uri as get_conda_uri
 from ray._private.runtime_env.pip import get_uri as get_pip_uri
 from ray._private.runtime_env.plugin_schema_manager import RuntimeEnvPluginSchemaManager
 from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN
+from ray._private.thirdparty.dacite import from_dict
 from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv
 from ray.core.generated.runtime_env_common_pb2 import (
     RuntimeEnvConfig as ProtoRuntimeEnvConfig,
@@ -431,10 +433,14 @@ class RuntimeEnv(dict):
         return plugin_uris
 
     def __setitem__(self, key: str, value: Any) -> None:
-        res_value = value
-        RuntimeEnvPluginSchemaManager.validate(key, res_value)
+        if is_dataclass(value):
+            jsonable_type = asdict(value)
+        else:
+            jsonable_type = value
+        RuntimeEnvPluginSchemaManager.validate(key, jsonable_type)
+        res_value = jsonable_type
         if key in RuntimeEnv.known_fields and key in OPTION_TO_VALIDATION_FN:
-            res_value = OPTION_TO_VALIDATION_FN[key](value)
+            res_value = OPTION_TO_VALIDATION_FN[key](jsonable_type)
         if res_value is None:
             return
         return super().__setitem__(key, res_value)
@@ -442,6 +448,14 @@
     def set(self, name: str, value: Any) -> None:
         self.__setitem__(name, value)
 
+    def get(self, name, default=None, data_class=None):
+        if name not in self:
+            return default
+        if not data_class:
+            return self.__getitem__(name)
+        else:
+            return from_dict(data_class=data_class, data=self.__getitem__(name))
+
     @classmethod
     def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv":  # noqa: F821
         proto_runtime_env = json_format.Parse(serialized_runtime_env, ProtoRuntimeEnv())
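
Together, these two hunks give RuntimeEnv a symmetric conversion: __setitem__ flattens a dataclass to a JSON-able dict via dataclasses.asdict() before schema validation, and get(..., data_class=...) rebuilds it with the vendored from_dict. A sketch of the observable behavior (the plugin key and dataclass are hypothetical; as in the test added below, keys without a registered schema pass validation untouched):

    from dataclasses import dataclass
    from typing import List

    from ray.runtime_env import RuntimeEnv

    @dataclass
    class MyPlugin:  # hypothetical plugin payload
        packages: List[str]

    runtime_env = RuntimeEnv()
    runtime_env["my_plugin"] = MyPlugin(packages=["a", "b"])
    # Stored as a plain dict, so the RuntimeEnv stays JSON-serializable.
    assert runtime_env["my_plugin"] == {"packages": ["a", "b"]}
    # Rehydrated back into the dataclass on request.
    assert runtime_env.get("my_plugin", data_class=MyPlugin) == MyPlugin(packages=["a", "b"])
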
8 python/ray/runtime_env/types/pip.py Normal file
@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class Pip:
+    packages: List[str]
+    pip_check: bool = False
python/ray/tests/BUILD
@@ -122,6 +122,7 @@ py_test_module_list(
        "test_runtime_env_env_vars.py",
        "test_runtime_env_packaging.py",
        "test_runtime_env_plugin.py",
+       "test_runtime_env_strong_type.py",
        "test_runtime_env_fork_process.py",
        "test_serialization.py",
        "test_shuffle.py",
70 python/ray/tests/test_runtime_env_strong_type.py Normal file
@@ -0,0 +1,70 @@
+import os
+import sys
+import pytest
+import ray
+
+from typing import List
+from ray.runtime_env import RuntimeEnv
+from ray.runtime_env.types.pip import Pip
+from dataclasses import dataclass
+
+
+@dataclass
+class ValueType:
+    nfield1: List[str]
+    nfield2: bool
+
+
+@dataclass
+class TestPlugin:
+    field1: List[ValueType]
+    field2: str
+
+
+def test_convert_from_and_to_dataclass():
+    runtime_env = RuntimeEnv()
+    test_plugin = TestPlugin(
+        field1=[
+            ValueType(nfield1=["a", "b", "c"], nfield2=False),
+            ValueType(nfield1=["d", "e"], nfield2=True),
+        ],
+        field2="abc",
+    )
+    runtime_env.set("test_plugin", test_plugin)
+    serialized_runtime_env = runtime_env.serialize()
+    assert "test_plugin" in serialized_runtime_env
+    runtime_env_2 = RuntimeEnv.deserialize(serialized_runtime_env)
+    test_plugin_2 = runtime_env_2.get("test_plugin", data_class=TestPlugin)
+    assert len(test_plugin_2.field1) == 2
+    assert test_plugin_2.field1[0].nfield1 == ["a", "b", "c"]
+    assert test_plugin_2.field1[0].nfield2 is False
+    assert test_plugin_2.field1[1].nfield1 == ["d", "e"]
+    assert test_plugin_2.field1[1].nfield2 is True
+    assert test_plugin_2.field2 == "abc"
+
+
+def test_pip(start_cluster):
+    cluster, address = start_cluster
+    ray.init(address)
+
+    runtime_env = RuntimeEnv()
+    pip = Pip(packages=["pip-install-test==0.5"])
+    runtime_env.set("pip", pip)
+
+    @ray.remote
+    class Actor:
+        def foo(self):
+            import pip_install_test  # noqa
+
+            return "hello"
+
+    a = Actor.options(runtime_env=runtime_env).remote()
+    assert ray.get(a.foo.remote()) == "hello"
+
+
+if __name__ == "__main__":
+
+    if os.environ.get("PARALLEL_CI"):
+        sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
+    else:
+        sys.exit(pytest.main(["-sv", __file__]))