mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[runtime env] plugin refactor[3/n]: support strong type by @dataclass (#26296)
This commit is contained in:
parent
b3878e26d7
commit
781c2a7834
18 changed files with 591 additions and 4 deletions
23
LICENSE
23
LICENSE
|
@ -401,3 +401,26 @@ distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
Code in python/ray/_private/thirdparty/dacite is adapted from https://github.com/konradhalas/dacite/blob/master/dacite
|
||||||
|
|
||||||
|
Copyright (c) 2018 Konrad Hałas
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
|
@ -14,7 +14,9 @@ logger = logging.getLogger(__name__)
|
||||||
class RuntimeEnvPluginSchemaManager:
|
class RuntimeEnvPluginSchemaManager:
|
||||||
"""This manager is used to load plugin json schemas."""
|
"""This manager is used to load plugin json schemas."""
|
||||||
|
|
||||||
default_schema_path = os.path.join(os.path.dirname(__file__), "schemas")
|
default_schema_path = os.path.join(
|
||||||
|
os.path.dirname(__file__), "../../runtime_env/schemas"
|
||||||
|
)
|
||||||
schemas = {}
|
schemas = {}
|
||||||
loaded = False
|
loaded = False
|
||||||
|
|
||||||
|
|
21
python/ray/_private/thirdparty/dacite/LICENSE
vendored
Normal file
21
python/ray/_private/thirdparty/dacite/LICENSE
vendored
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2018 Konrad Hałas
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
3
python/ray/_private/thirdparty/dacite/__init__.py
vendored
Normal file
3
python/ray/_private/thirdparty/dacite/__init__.py
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from .config import Config
|
||||||
|
from .core import from_dict
|
||||||
|
from .exceptions import *
|
12
python/ray/_private/thirdparty/dacite/config.py
vendored
Normal file
12
python/ray/_private/thirdparty/dacite/config.py
vendored
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Any, Callable, Optional, Type, List
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
type_hooks: Dict[Type, Callable[[Any], Any]] = field(default_factory=dict)
|
||||||
|
cast: List[Type] = field(default_factory=list)
|
||||||
|
forward_references: Optional[Dict[str, Any]] = None
|
||||||
|
check_types: bool = True
|
||||||
|
strict: bool = False
|
||||||
|
strict_unions_match: bool = False
|
140
python/ray/_private/thirdparty/dacite/core.py
vendored
Normal file
140
python/ray/_private/thirdparty/dacite/core.py
vendored
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
import copy
|
||||||
|
from dataclasses import is_dataclass
|
||||||
|
from itertools import zip_longest
|
||||||
|
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any
|
||||||
|
|
||||||
|
from .config import Config
|
||||||
|
from .data import Data
|
||||||
|
from .dataclasses import get_default_value_for_field, create_instance, DefaultValueNotFoundError, get_fields
|
||||||
|
from .exceptions import (
|
||||||
|
ForwardReferenceError,
|
||||||
|
WrongTypeError,
|
||||||
|
DaciteError,
|
||||||
|
UnionMatchError,
|
||||||
|
MissingValueError,
|
||||||
|
DaciteFieldError,
|
||||||
|
UnexpectedDataError,
|
||||||
|
StrictUnionMatchError,
|
||||||
|
)
|
||||||
|
from .types import (
|
||||||
|
is_instance,
|
||||||
|
is_generic_collection,
|
||||||
|
is_union,
|
||||||
|
extract_generic,
|
||||||
|
is_optional,
|
||||||
|
transform_value,
|
||||||
|
extract_origin_collection,
|
||||||
|
is_init_var,
|
||||||
|
extract_init_var,
|
||||||
|
)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None) -> T:
|
||||||
|
"""Create a data class instance from a dictionary.
|
||||||
|
|
||||||
|
:param data_class: a data class type
|
||||||
|
:param data: a dictionary of a input data
|
||||||
|
:param config: a configuration of the creation process
|
||||||
|
:return: an instance of a data class
|
||||||
|
"""
|
||||||
|
init_values: Data = {}
|
||||||
|
post_init_values: Data = {}
|
||||||
|
config = config or Config()
|
||||||
|
try:
|
||||||
|
data_class_hints = get_type_hints(data_class, globalns=config.forward_references)
|
||||||
|
except NameError as error:
|
||||||
|
raise ForwardReferenceError(str(error))
|
||||||
|
data_class_fields = get_fields(data_class)
|
||||||
|
if config.strict:
|
||||||
|
extra_fields = set(data.keys()) - {f.name for f in data_class_fields}
|
||||||
|
if extra_fields:
|
||||||
|
raise UnexpectedDataError(keys=extra_fields)
|
||||||
|
for field in data_class_fields:
|
||||||
|
field = copy.copy(field)
|
||||||
|
field.type = data_class_hints[field.name]
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
field_data = data[field.name]
|
||||||
|
transformed_value = transform_value(
|
||||||
|
type_hooks=config.type_hooks, cast=config.cast, target_type=field.type, value=field_data
|
||||||
|
)
|
||||||
|
value = _build_value(type_=field.type, data=transformed_value, config=config)
|
||||||
|
except DaciteFieldError as error:
|
||||||
|
error.update_path(field.name)
|
||||||
|
raise
|
||||||
|
if config.check_types and not is_instance(value, field.type):
|
||||||
|
raise WrongTypeError(field_path=field.name, field_type=field.type, value=value)
|
||||||
|
except KeyError:
|
||||||
|
try:
|
||||||
|
value = get_default_value_for_field(field)
|
||||||
|
except DefaultValueNotFoundError:
|
||||||
|
if not field.init:
|
||||||
|
continue
|
||||||
|
raise MissingValueError(field.name)
|
||||||
|
if field.init:
|
||||||
|
init_values[field.name] = value
|
||||||
|
else:
|
||||||
|
post_init_values[field.name] = value
|
||||||
|
|
||||||
|
return create_instance(data_class=data_class, init_values=init_values, post_init_values=post_init_values)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_value(type_: Type, data: Any, config: Config) -> Any:
|
||||||
|
if is_init_var(type_):
|
||||||
|
type_ = extract_init_var(type_)
|
||||||
|
if is_union(type_):
|
||||||
|
return _build_value_for_union(union=type_, data=data, config=config)
|
||||||
|
elif is_generic_collection(type_) and is_instance(data, extract_origin_collection(type_)):
|
||||||
|
return _build_value_for_collection(collection=type_, data=data, config=config)
|
||||||
|
elif is_dataclass(type_) and is_instance(data, Data):
|
||||||
|
return from_dict(data_class=type_, data=data, config=config)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _build_value_for_union(union: Type, data: Any, config: Config) -> Any:
|
||||||
|
types = extract_generic(union)
|
||||||
|
if is_optional(union) and len(types) == 2:
|
||||||
|
return _build_value(type_=types[0], data=data, config=config)
|
||||||
|
union_matches = {}
|
||||||
|
for inner_type in types:
|
||||||
|
try:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
data = transform_value(
|
||||||
|
type_hooks=config.type_hooks, cast=config.cast, target_type=inner_type, value=data
|
||||||
|
)
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
continue
|
||||||
|
value = _build_value(type_=inner_type, data=data, config=config)
|
||||||
|
if is_instance(value, inner_type):
|
||||||
|
if config.strict_unions_match:
|
||||||
|
union_matches[inner_type] = value
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
except DaciteError:
|
||||||
|
pass
|
||||||
|
if config.strict_unions_match:
|
||||||
|
if len(union_matches) > 1:
|
||||||
|
raise StrictUnionMatchError(union_matches)
|
||||||
|
return union_matches.popitem()[1]
|
||||||
|
if not config.check_types:
|
||||||
|
return data
|
||||||
|
raise UnionMatchError(field_type=union, value=data)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_value_for_collection(collection: Type, data: Any, config: Config) -> Any:
|
||||||
|
data_type = data.__class__
|
||||||
|
if is_instance(data, Mapping):
|
||||||
|
item_type = extract_generic(collection, defaults=(Any, Any))[1]
|
||||||
|
return data_type((key, _build_value(type_=item_type, data=value, config=config)) for key, value in data.items())
|
||||||
|
elif is_instance(data, tuple):
|
||||||
|
types = extract_generic(collection)
|
||||||
|
if len(types) == 2 and types[1] == Ellipsis:
|
||||||
|
return data_type(_build_value(type_=types[0], data=item, config=config) for item in data)
|
||||||
|
return data_type(
|
||||||
|
_build_value(type_=type_, data=item, config=config) for item, type_ in zip_longest(data, types)
|
||||||
|
)
|
||||||
|
item_type = extract_generic(collection, defaults=(Any,))[0]
|
||||||
|
return data_type(_build_value(type_=item_type, data=item, config=config) for item in data)
|
3
python/ray/_private/thirdparty/dacite/data.py
vendored
Normal file
3
python/ray/_private/thirdparty/dacite/data.py
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
Data = Dict[str, Any]
|
33
python/ray/_private/thirdparty/dacite/dataclasses.py
vendored
Normal file
33
python/ray/_private/thirdparty/dacite/dataclasses.py
vendored
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
from dataclasses import Field, MISSING, _FIELDS, _FIELD, _FIELD_INITVAR # type: ignore
|
||||||
|
from typing import Type, Any, TypeVar, List
|
||||||
|
|
||||||
|
from .data import Data
|
||||||
|
from .types import is_optional
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=Any)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultValueNotFoundError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_value_for_field(field: Field) -> Any:
|
||||||
|
if field.default != MISSING:
|
||||||
|
return field.default
|
||||||
|
elif field.default_factory != MISSING: # type: ignore
|
||||||
|
return field.default_factory() # type: ignore
|
||||||
|
elif is_optional(field.type):
|
||||||
|
return None
|
||||||
|
raise DefaultValueNotFoundError()
|
||||||
|
|
||||||
|
|
||||||
|
def create_instance(data_class: Type[T], init_values: Data, post_init_values: Data) -> T:
|
||||||
|
instance = data_class(**init_values)
|
||||||
|
for key, value in post_init_values.items():
|
||||||
|
setattr(instance, key, value)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
def get_fields(data_class: Type[T]) -> List[Field]:
|
||||||
|
fields = getattr(data_class, _FIELDS)
|
||||||
|
return [f for f in fields.values() if f._field_type is _FIELD or f._field_type is _FIELD_INITVAR]
|
79
python/ray/_private/thirdparty/dacite/exceptions.py
vendored
Normal file
79
python/ray/_private/thirdparty/dacite/exceptions.py
vendored
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
from typing import Any, Type, Optional, Set, Dict
|
||||||
|
|
||||||
|
|
||||||
|
def _name(type_: Type) -> str:
|
||||||
|
return type_.__name__ if hasattr(type_, "__name__") else str(type_)
|
||||||
|
|
||||||
|
|
||||||
|
class DaciteError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DaciteFieldError(DaciteError):
|
||||||
|
def __init__(self, field_path: Optional[str] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.field_path = field_path
|
||||||
|
|
||||||
|
def update_path(self, parent_field_path: str) -> None:
|
||||||
|
if self.field_path:
|
||||||
|
self.field_path = f"{parent_field_path}.{self.field_path}"
|
||||||
|
else:
|
||||||
|
self.field_path = parent_field_path
|
||||||
|
|
||||||
|
|
||||||
|
class WrongTypeError(DaciteFieldError):
|
||||||
|
def __init__(self, field_type: Type, value: Any, field_path: Optional[str] = None) -> None:
|
||||||
|
super().__init__(field_path=field_path)
|
||||||
|
self.field_type = field_type
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return (
|
||||||
|
f'wrong value type for field "{self.field_path}" - should be "{_name(self.field_type)}" '
|
||||||
|
f'instead of value "{self.value}" of type "{_name(type(self.value))}"'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MissingValueError(DaciteFieldError):
|
||||||
|
def __init__(self, field_path: Optional[str] = None):
|
||||||
|
super().__init__(field_path=field_path)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f'missing value for field "{self.field_path}"'
|
||||||
|
|
||||||
|
|
||||||
|
class UnionMatchError(WrongTypeError):
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return (
|
||||||
|
f'can not match type "{_name(type(self.value))}" to any type '
|
||||||
|
f'of "{self.field_path}" union: {_name(self.field_type)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StrictUnionMatchError(DaciteFieldError):
|
||||||
|
def __init__(self, union_matches: Dict[Type, Any], field_path: Optional[str] = None) -> None:
|
||||||
|
super().__init__(field_path=field_path)
|
||||||
|
self.union_matches = union_matches
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
conflicting_types = ", ".join(_name(type_) for type_ in self.union_matches)
|
||||||
|
return f'can not choose between possible Union matches for field "{self.field_path}": {conflicting_types}'
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardReferenceError(DaciteError):
|
||||||
|
def __init__(self, message: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"can not resolve forward reference: {self.message}"
|
||||||
|
|
||||||
|
|
||||||
|
class UnexpectedDataError(DaciteError):
|
||||||
|
def __init__(self, keys: Set[str]) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.keys = keys
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
formatted_keys = ", ".join(f'"{key}"' for key in self.keys)
|
||||||
|
return f"can not match {formatted_keys} to any data class field"
|
0
python/ray/_private/thirdparty/dacite/py.typed
vendored
Normal file
0
python/ray/_private/thirdparty/dacite/py.typed
vendored
Normal file
172
python/ray/_private/thirdparty/dacite/types.py
vendored
Normal file
172
python/ray/_private/thirdparty/dacite/types.py
vendored
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
from dataclasses import InitVar
|
||||||
|
from typing import Type, Any, Optional, Union, Collection, TypeVar, Dict, Callable, Mapping, List, Tuple
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=Any)
|
||||||
|
|
||||||
|
|
||||||
|
def transform_value(
|
||||||
|
type_hooks: Dict[Type, Callable[[Any], Any]], cast: List[Type], target_type: Type, value: Any
|
||||||
|
) -> Any:
|
||||||
|
if target_type in type_hooks:
|
||||||
|
value = type_hooks[target_type](value)
|
||||||
|
else:
|
||||||
|
for cast_type in cast:
|
||||||
|
if is_subclass(target_type, cast_type):
|
||||||
|
if is_generic_collection(target_type):
|
||||||
|
value = extract_origin_collection(target_type)(value)
|
||||||
|
else:
|
||||||
|
value = target_type(value)
|
||||||
|
break
|
||||||
|
if is_optional(target_type):
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
target_type = extract_optional(target_type)
|
||||||
|
return transform_value(type_hooks, cast, target_type, value)
|
||||||
|
if is_generic_collection(target_type) and isinstance(value, extract_origin_collection(target_type)):
|
||||||
|
collection_cls = value.__class__
|
||||||
|
if issubclass(collection_cls, dict):
|
||||||
|
key_cls, item_cls = extract_generic(target_type, defaults=(Any, Any))
|
||||||
|
return collection_cls(
|
||||||
|
{
|
||||||
|
transform_value(type_hooks, cast, key_cls, key): transform_value(type_hooks, cast, item_cls, item)
|
||||||
|
for key, item in value.items()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
item_cls = extract_generic(target_type, defaults=(Any,))[0]
|
||||||
|
return collection_cls(transform_value(type_hooks, cast, item_cls, item) for item in value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def extract_origin_collection(collection: Type) -> Type:
|
||||||
|
try:
|
||||||
|
return collection.__extra__
|
||||||
|
except AttributeError:
|
||||||
|
return collection.__origin__
|
||||||
|
|
||||||
|
|
||||||
|
def is_optional(type_: Type) -> bool:
|
||||||
|
return is_union(type_) and type(None) in extract_generic(type_)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_optional(optional: Type[Optional[T]]) -> T:
|
||||||
|
for type_ in extract_generic(optional):
|
||||||
|
if type_ is not type(None):
|
||||||
|
return type_
|
||||||
|
raise ValueError("can not find not-none value")
|
||||||
|
|
||||||
|
|
||||||
|
def is_generic(type_: Type) -> bool:
|
||||||
|
return hasattr(type_, "__origin__")
|
||||||
|
|
||||||
|
|
||||||
|
def is_union(type_: Type) -> bool:
|
||||||
|
return is_generic(type_) and type_.__origin__ == Union
|
||||||
|
|
||||||
|
|
||||||
|
def is_literal(type_: Type) -> bool:
|
||||||
|
try:
|
||||||
|
from typing import Literal # type: ignore
|
||||||
|
|
||||||
|
return is_generic(type_) and type_.__origin__ == Literal
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_new_type(type_: Type) -> bool:
|
||||||
|
return hasattr(type_, "__supertype__")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_new_type(type_: Type) -> Type:
|
||||||
|
return type_.__supertype__
|
||||||
|
|
||||||
|
|
||||||
|
def is_init_var(type_: Type) -> bool:
|
||||||
|
return isinstance(type_, InitVar) or type_ is InitVar
|
||||||
|
|
||||||
|
|
||||||
|
def extract_init_var(type_: Type) -> Union[Type, Any]:
|
||||||
|
try:
|
||||||
|
return type_.type
|
||||||
|
except AttributeError:
|
||||||
|
return Any
|
||||||
|
|
||||||
|
|
||||||
|
def is_instance(value: Any, type_: Type) -> bool:
|
||||||
|
if type_ == Any:
|
||||||
|
return True
|
||||||
|
elif is_union(type_):
|
||||||
|
return any(is_instance(value, t) for t in extract_generic(type_))
|
||||||
|
elif is_generic_collection(type_):
|
||||||
|
origin = extract_origin_collection(type_)
|
||||||
|
if not isinstance(value, origin):
|
||||||
|
return False
|
||||||
|
if not extract_generic(type_):
|
||||||
|
return True
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
tuple_types = extract_generic(type_)
|
||||||
|
if len(tuple_types) == 1 and tuple_types[0] == ():
|
||||||
|
return len(value) == 0
|
||||||
|
elif len(tuple_types) == 2 and tuple_types[1] is ...:
|
||||||
|
return all(is_instance(item, tuple_types[0]) for item in value)
|
||||||
|
else:
|
||||||
|
if len(tuple_types) != len(value):
|
||||||
|
return False
|
||||||
|
return all(is_instance(item, item_type) for item, item_type in zip(value, tuple_types))
|
||||||
|
if isinstance(value, Mapping):
|
||||||
|
key_type, val_type = extract_generic(type_, defaults=(Any, Any))
|
||||||
|
for key, val in value.items():
|
||||||
|
if not is_instance(key, key_type) or not is_instance(val, val_type):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return all(is_instance(item, extract_generic(type_, defaults=(Any,))[0]) for item in value)
|
||||||
|
elif is_new_type(type_):
|
||||||
|
return is_instance(value, extract_new_type(type_))
|
||||||
|
elif is_literal(type_):
|
||||||
|
return value in extract_generic(type_)
|
||||||
|
elif is_init_var(type_):
|
||||||
|
return is_instance(value, extract_init_var(type_))
|
||||||
|
elif is_type_generic(type_):
|
||||||
|
return is_subclass(value, extract_generic(type_)[0])
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
# As described in PEP 484 - section: "The numeric tower"
|
||||||
|
if isinstance(value, (int, float)) and type_ in [float, complex]:
|
||||||
|
return True
|
||||||
|
return isinstance(value, type_)
|
||||||
|
except TypeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_generic_collection(type_: Type) -> bool:
|
||||||
|
if not is_generic(type_):
|
||||||
|
return False
|
||||||
|
origin = extract_origin_collection(type_)
|
||||||
|
try:
|
||||||
|
return bool(origin and issubclass(origin, Collection))
|
||||||
|
except (TypeError, AttributeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def extract_generic(type_: Type, defaults: Tuple = ()) -> tuple:
|
||||||
|
try:
|
||||||
|
if hasattr(type_, "_special") and type_._special:
|
||||||
|
return defaults
|
||||||
|
return type_.__args__ or defaults # type: ignore
|
||||||
|
except AttributeError:
|
||||||
|
return defaults
|
||||||
|
|
||||||
|
|
||||||
|
def is_subclass(sub_type: Type, base_type: Type) -> bool:
|
||||||
|
if is_generic_collection(sub_type):
|
||||||
|
sub_type = extract_origin_collection(sub_type)
|
||||||
|
try:
|
||||||
|
return issubclass(sub_type, base_type)
|
||||||
|
except TypeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_type_generic(type_: Type) -> bool:
|
||||||
|
try:
|
||||||
|
return type_.__origin__ in (type, Type)
|
||||||
|
except AttributeError:
|
||||||
|
return False
|
6
python/ray/runtime_env/__init__.py
Normal file
6
python/ray/runtime_env/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from ray.runtime_env.runtime_env import RuntimeEnv, RuntimeEnvConfig # noqa: E402,F401
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RuntimeEnvConfig",
|
||||||
|
"RuntimeEnv",
|
||||||
|
]
|
|
@ -2,6 +2,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from dataclasses import asdict, is_dataclass
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
from google.protobuf import json_format
|
from google.protobuf import json_format
|
||||||
|
@ -12,6 +13,7 @@ from ray._private.runtime_env.conda import get_uri as get_conda_uri
|
||||||
from ray._private.runtime_env.pip import get_uri as get_pip_uri
|
from ray._private.runtime_env.pip import get_uri as get_pip_uri
|
||||||
from ray._private.runtime_env.plugin_schema_manager import RuntimeEnvPluginSchemaManager
|
from ray._private.runtime_env.plugin_schema_manager import RuntimeEnvPluginSchemaManager
|
||||||
from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN
|
from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN
|
||||||
|
from ray._private.thirdparty.dacite import from_dict
|
||||||
from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv
|
from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv
|
||||||
from ray.core.generated.runtime_env_common_pb2 import (
|
from ray.core.generated.runtime_env_common_pb2 import (
|
||||||
RuntimeEnvConfig as ProtoRuntimeEnvConfig,
|
RuntimeEnvConfig as ProtoRuntimeEnvConfig,
|
||||||
|
@ -431,10 +433,14 @@ class RuntimeEnv(dict):
|
||||||
return plugin_uris
|
return plugin_uris
|
||||||
|
|
||||||
def __setitem__(self, key: str, value: Any) -> None:
|
def __setitem__(self, key: str, value: Any) -> None:
|
||||||
res_value = value
|
if is_dataclass(value):
|
||||||
RuntimeEnvPluginSchemaManager.validate(key, res_value)
|
jsonable_type = asdict(value)
|
||||||
|
else:
|
||||||
|
jsonable_type = value
|
||||||
|
RuntimeEnvPluginSchemaManager.validate(key, jsonable_type)
|
||||||
|
res_value = jsonable_type
|
||||||
if key in RuntimeEnv.known_fields and key in OPTION_TO_VALIDATION_FN:
|
if key in RuntimeEnv.known_fields and key in OPTION_TO_VALIDATION_FN:
|
||||||
res_value = OPTION_TO_VALIDATION_FN[key](value)
|
res_value = OPTION_TO_VALIDATION_FN[key](jsonable_type)
|
||||||
if res_value is None:
|
if res_value is None:
|
||||||
return
|
return
|
||||||
return super().__setitem__(key, res_value)
|
return super().__setitem__(key, res_value)
|
||||||
|
@ -442,6 +448,14 @@ class RuntimeEnv(dict):
|
||||||
def set(self, name: str, value: Any) -> None:
|
def set(self, name: str, value: Any) -> None:
|
||||||
self.__setitem__(name, value)
|
self.__setitem__(name, value)
|
||||||
|
|
||||||
|
def get(self, name, default=None, data_class=None):
|
||||||
|
if name not in self:
|
||||||
|
return default
|
||||||
|
if not data_class:
|
||||||
|
return self.__getitem__(name)
|
||||||
|
else:
|
||||||
|
return from_dict(data_class=data_class, data=self.__getitem__(name))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv": # noqa: F821
|
def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv": # noqa: F821
|
||||||
proto_runtime_env = json_format.Parse(serialized_runtime_env, ProtoRuntimeEnv())
|
proto_runtime_env = json_format.Parse(serialized_runtime_env, ProtoRuntimeEnv())
|
8
python/ray/runtime_env/types/pip.py
Normal file
8
python/ray/runtime_env/types/pip.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Pip:
|
||||||
|
packages: List[str]
|
||||||
|
pip_check: bool = False
|
|
@ -122,6 +122,7 @@ py_test_module_list(
|
||||||
"test_runtime_env_env_vars.py",
|
"test_runtime_env_env_vars.py",
|
||||||
"test_runtime_env_packaging.py",
|
"test_runtime_env_packaging.py",
|
||||||
"test_runtime_env_plugin.py",
|
"test_runtime_env_plugin.py",
|
||||||
|
"test_runtime_env_strong_type.py",
|
||||||
"test_runtime_env_fork_process.py",
|
"test_runtime_env_fork_process.py",
|
||||||
"test_serialization.py",
|
"test_serialization.py",
|
||||||
"test_shuffle.py",
|
"test_shuffle.py",
|
||||||
|
|
70
python/ray/tests/test_runtime_env_strong_type.py
Normal file
70
python/ray/tests/test_runtime_env_strong_type.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pytest
|
||||||
|
import ray
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from ray.runtime_env import RuntimeEnv
|
||||||
|
from ray.runtime_env.types.pip import Pip
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValueType:
|
||||||
|
nfield1: List[str]
|
||||||
|
nfield2: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TestPlugin:
|
||||||
|
field1: List[ValueType]
|
||||||
|
field2: str
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_from_and_to_dataclass():
|
||||||
|
runtime_env = RuntimeEnv()
|
||||||
|
test_plugin = TestPlugin(
|
||||||
|
field1=[
|
||||||
|
ValueType(nfield1=["a", "b", "c"], nfield2=False),
|
||||||
|
ValueType(nfield1=["d", "e"], nfield2=True),
|
||||||
|
],
|
||||||
|
field2="abc",
|
||||||
|
)
|
||||||
|
runtime_env.set("test_plugin", test_plugin)
|
||||||
|
serialized_runtime_env = runtime_env.serialize()
|
||||||
|
assert "test_plugin" in serialized_runtime_env
|
||||||
|
runtime_env_2 = RuntimeEnv.deserialize(serialized_runtime_env)
|
||||||
|
test_plugin_2 = runtime_env_2.get("test_plugin", data_class=TestPlugin)
|
||||||
|
assert len(test_plugin_2.field1) == 2
|
||||||
|
assert test_plugin_2.field1[0].nfield1 == ["a", "b", "c"]
|
||||||
|
assert test_plugin_2.field1[0].nfield2 is False
|
||||||
|
assert test_plugin_2.field1[1].nfield1 == ["d", "e"]
|
||||||
|
assert test_plugin_2.field1[1].nfield2 is True
|
||||||
|
assert test_plugin_2.field2 == "abc"
|
||||||
|
|
||||||
|
|
||||||
|
def test_pip(start_cluster):
|
||||||
|
cluster, address = start_cluster
|
||||||
|
ray.init(address)
|
||||||
|
|
||||||
|
runtime_env = RuntimeEnv()
|
||||||
|
pip = Pip(packages=["pip-install-test==0.5"])
|
||||||
|
runtime_env.set("pip", pip)
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
class Actor:
|
||||||
|
def foo(self):
|
||||||
|
import pip_install_test # noqa
|
||||||
|
|
||||||
|
return "hello"
|
||||||
|
|
||||||
|
a = Actor.options(runtime_env=runtime_env).remote()
|
||||||
|
assert ray.get(a.foo.remote()) == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
if os.environ.get("PARALLEL_CI"):
|
||||||
|
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
|
||||||
|
else:
|
||||||
|
sys.exit(pytest.main(["-sv", __file__]))
|
Loading…
Add table
Reference in a new issue