[runtime env] plugin refactor[3/n]: support strong type by @dataclass (#26296)

Guyang Song 2022-07-13 00:40:42 +08:00 committed by GitHub
parent b3878e26d7
commit 781c2a7834
18 changed files with 591 additions and 4 deletions
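
In short, this change lets a runtime_env field be set from a @dataclass instance and read back as a strongly typed object. A rough usage sketch pieced together from the diffs below (the "requests" pin is just an example; the round trip is exercised by the new test file at the end):

from ray.runtime_env import RuntimeEnv
from ray.runtime_env.types.pip import Pip

runtime_env = RuntimeEnv()
# The dataclass is converted to a JSON-able dict (dataclasses.asdict) and schema-validated.
runtime_env.set("pip", Pip(packages=["requests"]))
# The stored dict is converted back into a Pip instance via the vendored dacite.from_dict.
pip = runtime_env.get("pip", data_class=Pip)
assert pip.pip_check is False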

LICENSE
View file

@@ -401,3 +401,26 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------
Code in python/ray/_private/thirdparty/dacite is adapted from https://github.com/konradhalas/dacite/blob/master/dacite
Copyright (c) 2018 Konrad Hałas
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@@ -14,7 +14,9 @@ logger = logging.getLogger(__name__)
 class RuntimeEnvPluginSchemaManager:
     """This manager is used to load plugin json schemas."""

-    default_schema_path = os.path.join(os.path.dirname(__file__), "schemas")
+    default_schema_path = os.path.join(
+        os.path.dirname(__file__), "../../runtime_env/schemas"
+    )
     schemas = {}
     loaded = False

View file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2018 Konrad Hałas
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@@ -0,0 +1,3 @@
from .config import Config
from .core import from_dict
from .exceptions import *

View file

@@ -0,0 +1,12 @@
from dataclasses import dataclass, field
from typing import Dict, Any, Callable, Optional, Type, List


@dataclass
class Config:
    type_hooks: Dict[Type, Callable[[Any], Any]] = field(default_factory=dict)
    cast: List[Type] = field(default_factory=list)
    forward_references: Optional[Dict[str, Any]] = None
    check_types: bool = True
    strict: bool = False
    strict_unions_match: bool = False
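
As an illustration of how Config drives the conversion, a minimal sketch (the Event dataclass and the ISO-timestamp hook are hypothetical; the import path follows the vendored layout named in the LICENSE note above):

from dataclasses import dataclass
from datetime import datetime

from ray._private.thirdparty.dacite import Config, from_dict


@dataclass
class Event:
    name: str
    when: datetime


# type_hooks maps a target field type to a callable that pre-processes the raw value.
event = from_dict(
    Event,
    {"name": "deploy", "when": "2022-07-13T00:40:42"},
    Config(type_hooks={datetime: datetime.fromisoformat}),
)
assert event.when.year == 2022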

View file

@@ -0,0 +1,140 @@
import copy
from dataclasses import is_dataclass
from itertools import zip_longest
from typing import TypeVar, Type, Optional, get_type_hints, Mapping, Any

from .config import Config
from .data import Data
from .dataclasses import get_default_value_for_field, create_instance, DefaultValueNotFoundError, get_fields
from .exceptions import (
    ForwardReferenceError,
    WrongTypeError,
    DaciteError,
    UnionMatchError,
    MissingValueError,
    DaciteFieldError,
    UnexpectedDataError,
    StrictUnionMatchError,
)
from .types import (
    is_instance,
    is_generic_collection,
    is_union,
    extract_generic,
    is_optional,
    transform_value,
    extract_origin_collection,
    is_init_var,
    extract_init_var,
)

T = TypeVar("T")


def from_dict(data_class: Type[T], data: Data, config: Optional[Config] = None) -> T:
    """Create a data class instance from a dictionary.

    :param data_class: a data class type
    :param data: a dictionary of input data
    :param config: a configuration of the creation process
    :return: an instance of a data class
    """
    init_values: Data = {}
    post_init_values: Data = {}
    config = config or Config()
    try:
        data_class_hints = get_type_hints(data_class, globalns=config.forward_references)
    except NameError as error:
        raise ForwardReferenceError(str(error))
    data_class_fields = get_fields(data_class)
    if config.strict:
        extra_fields = set(data.keys()) - {f.name for f in data_class_fields}
        if extra_fields:
            raise UnexpectedDataError(keys=extra_fields)
    for field in data_class_fields:
        field = copy.copy(field)
        field.type = data_class_hints[field.name]
        try:
            try:
                field_data = data[field.name]
                transformed_value = transform_value(
                    type_hooks=config.type_hooks, cast=config.cast, target_type=field.type, value=field_data
                )
                value = _build_value(type_=field.type, data=transformed_value, config=config)
            except DaciteFieldError as error:
                error.update_path(field.name)
                raise
            if config.check_types and not is_instance(value, field.type):
                raise WrongTypeError(field_path=field.name, field_type=field.type, value=value)
        except KeyError:
            try:
                value = get_default_value_for_field(field)
            except DefaultValueNotFoundError:
                if not field.init:
                    continue
                raise MissingValueError(field.name)
        if field.init:
            init_values[field.name] = value
        else:
            post_init_values[field.name] = value
    return create_instance(data_class=data_class, init_values=init_values, post_init_values=post_init_values)


def _build_value(type_: Type, data: Any, config: Config) -> Any:
    if is_init_var(type_):
        type_ = extract_init_var(type_)
    if is_union(type_):
        return _build_value_for_union(union=type_, data=data, config=config)
    elif is_generic_collection(type_) and is_instance(data, extract_origin_collection(type_)):
        return _build_value_for_collection(collection=type_, data=data, config=config)
    elif is_dataclass(type_) and is_instance(data, Data):
        return from_dict(data_class=type_, data=data, config=config)
    return data


def _build_value_for_union(union: Type, data: Any, config: Config) -> Any:
    types = extract_generic(union)
    if is_optional(union) and len(types) == 2:
        return _build_value(type_=types[0], data=data, config=config)
    union_matches = {}
    for inner_type in types:
        try:
            # noinspection PyBroadException
            try:
                data = transform_value(
                    type_hooks=config.type_hooks, cast=config.cast, target_type=inner_type, value=data
                )
            except Exception:  # pylint: disable=broad-except
                continue
            value = _build_value(type_=inner_type, data=data, config=config)
            if is_instance(value, inner_type):
                if config.strict_unions_match:
                    union_matches[inner_type] = value
                else:
                    return value
        except DaciteError:
            pass
    if config.strict_unions_match:
        if len(union_matches) > 1:
            raise StrictUnionMatchError(union_matches)
        return union_matches.popitem()[1]
    if not config.check_types:
        return data
    raise UnionMatchError(field_type=union, value=data)


def _build_value_for_collection(collection: Type, data: Any, config: Config) -> Any:
    data_type = data.__class__
    if is_instance(data, Mapping):
        item_type = extract_generic(collection, defaults=(Any, Any))[1]
        return data_type((key, _build_value(type_=item_type, data=value, config=config)) for key, value in data.items())
    elif is_instance(data, tuple):
        types = extract_generic(collection)
        if len(types) == 2 and types[1] == Ellipsis:
            return data_type(_build_value(type_=types[0], data=item, config=config) for item in data)
        return data_type(
            _build_value(type_=type_, data=item, config=config) for item, type_ in zip_longest(data, types)
        )
    item_type = extract_generic(collection, defaults=(Any,))[0]
    return data_type(_build_value(type_=item_type, data=item, config=config) for item in data)
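
A short usage sketch of this entry point for nested data (the Resource and Spec dataclasses are hypothetical names, not part of this PR):

from dataclasses import dataclass
from typing import List, Optional

from ray._private.thirdparty.dacite import from_dict


@dataclass
class Resource:
    cpu: int
    labels: List[str]


@dataclass
class Spec:
    name: str
    resource: Resource
    comment: Optional[str] = None


# Nested dicts are recursively converted; a missing Optional field falls back
# to its default instead of raising MissingValueError.
spec = from_dict(Spec, {"name": "demo", "resource": {"cpu": 2, "labels": ["a"]}})
assert spec.resource.cpu == 2 and spec.comment is None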

View file

@@ -0,0 +1,3 @@
from typing import Dict, Any
Data = Dict[str, Any]

View file

@@ -0,0 +1,33 @@
from dataclasses import Field, MISSING, _FIELDS, _FIELD, _FIELD_INITVAR  # type: ignore
from typing import Type, Any, TypeVar, List

from .data import Data
from .types import is_optional

T = TypeVar("T", bound=Any)


class DefaultValueNotFoundError(Exception):
    pass


def get_default_value_for_field(field: Field) -> Any:
    if field.default != MISSING:
        return field.default
    elif field.default_factory != MISSING:  # type: ignore
        return field.default_factory()  # type: ignore
    elif is_optional(field.type):
        return None
    raise DefaultValueNotFoundError()


def create_instance(data_class: Type[T], init_values: Data, post_init_values: Data) -> T:
    instance = data_class(**init_values)
    for key, value in post_init_values.items():
        setattr(instance, key, value)
    return instance


def get_fields(data_class: Type[T]) -> List[Field]:
    fields = getattr(data_class, _FIELDS)
    return [f for f in fields.values() if f._field_type is _FIELD or f._field_type is _FIELD_INITVAR]

View file

@@ -0,0 +1,79 @@
from typing import Any, Type, Optional, Set, Dict


def _name(type_: Type) -> str:
    return type_.__name__ if hasattr(type_, "__name__") else str(type_)


class DaciteError(Exception):
    pass


class DaciteFieldError(DaciteError):
    def __init__(self, field_path: Optional[str] = None):
        super().__init__()
        self.field_path = field_path

    def update_path(self, parent_field_path: str) -> None:
        if self.field_path:
            self.field_path = f"{parent_field_path}.{self.field_path}"
        else:
            self.field_path = parent_field_path


class WrongTypeError(DaciteFieldError):
    def __init__(self, field_type: Type, value: Any, field_path: Optional[str] = None) -> None:
        super().__init__(field_path=field_path)
        self.field_type = field_type
        self.value = value

    def __str__(self) -> str:
        return (
            f'wrong value type for field "{self.field_path}" - should be "{_name(self.field_type)}" '
            f'instead of value "{self.value}" of type "{_name(type(self.value))}"'
        )


class MissingValueError(DaciteFieldError):
    def __init__(self, field_path: Optional[str] = None):
        super().__init__(field_path=field_path)

    def __str__(self) -> str:
        return f'missing value for field "{self.field_path}"'


class UnionMatchError(WrongTypeError):
    def __str__(self) -> str:
        return (
            f'can not match type "{_name(type(self.value))}" to any type '
            f'of "{self.field_path}" union: {_name(self.field_type)}'
        )


class StrictUnionMatchError(DaciteFieldError):
    def __init__(self, union_matches: Dict[Type, Any], field_path: Optional[str] = None) -> None:
        super().__init__(field_path=field_path)
        self.union_matches = union_matches

    def __str__(self) -> str:
        conflicting_types = ", ".join(_name(type_) for type_ in self.union_matches)
        return f'can not choose between possible Union matches for field "{self.field_path}": {conflicting_types}'


class ForwardReferenceError(DaciteError):
    def __init__(self, message: str) -> None:
        super().__init__()
        self.message = message

    def __str__(self) -> str:
        return f"can not resolve forward reference: {self.message}"


class UnexpectedDataError(DaciteError):
    def __init__(self, keys: Set[str]) -> None:
        super().__init__()
        self.keys = keys

    def __str__(self) -> str:
        formatted_keys = ", ".join(f'"{key}"' for key in self.keys)
        return f"can not match {formatted_keys} to any data class field"

View file

View file

@@ -0,0 +1,172 @@
from dataclasses import InitVar
from typing import Type, Any, Optional, Union, Collection, TypeVar, Dict, Callable, Mapping, List, Tuple

T = TypeVar("T", bound=Any)


def transform_value(
    type_hooks: Dict[Type, Callable[[Any], Any]], cast: List[Type], target_type: Type, value: Any
) -> Any:
    if target_type in type_hooks:
        value = type_hooks[target_type](value)
    else:
        for cast_type in cast:
            if is_subclass(target_type, cast_type):
                if is_generic_collection(target_type):
                    value = extract_origin_collection(target_type)(value)
                else:
                    value = target_type(value)
                break
    if is_optional(target_type):
        if value is None:
            return None
        target_type = extract_optional(target_type)
        return transform_value(type_hooks, cast, target_type, value)
    if is_generic_collection(target_type) and isinstance(value, extract_origin_collection(target_type)):
        collection_cls = value.__class__
        if issubclass(collection_cls, dict):
            key_cls, item_cls = extract_generic(target_type, defaults=(Any, Any))
            return collection_cls(
                {
                    transform_value(type_hooks, cast, key_cls, key): transform_value(type_hooks, cast, item_cls, item)
                    for key, item in value.items()
                }
            )
        item_cls = extract_generic(target_type, defaults=(Any,))[0]
        return collection_cls(transform_value(type_hooks, cast, item_cls, item) for item in value)
    return value


def extract_origin_collection(collection: Type) -> Type:
    try:
        return collection.__extra__
    except AttributeError:
        return collection.__origin__


def is_optional(type_: Type) -> bool:
    return is_union(type_) and type(None) in extract_generic(type_)


def extract_optional(optional: Type[Optional[T]]) -> T:
    for type_ in extract_generic(optional):
        if type_ is not type(None):
            return type_
    raise ValueError("can not find not-none value")


def is_generic(type_: Type) -> bool:
    return hasattr(type_, "__origin__")


def is_union(type_: Type) -> bool:
    return is_generic(type_) and type_.__origin__ == Union


def is_literal(type_: Type) -> bool:
    try:
        from typing import Literal  # type: ignore

        return is_generic(type_) and type_.__origin__ == Literal
    except ImportError:
        return False


def is_new_type(type_: Type) -> bool:
    return hasattr(type_, "__supertype__")


def extract_new_type(type_: Type) -> Type:
    return type_.__supertype__


def is_init_var(type_: Type) -> bool:
    return isinstance(type_, InitVar) or type_ is InitVar


def extract_init_var(type_: Type) -> Union[Type, Any]:
    try:
        return type_.type
    except AttributeError:
        return Any


def is_instance(value: Any, type_: Type) -> bool:
    if type_ == Any:
        return True
    elif is_union(type_):
        return any(is_instance(value, t) for t in extract_generic(type_))
    elif is_generic_collection(type_):
        origin = extract_origin_collection(type_)
        if not isinstance(value, origin):
            return False
        if not extract_generic(type_):
            return True
        if isinstance(value, tuple):
            tuple_types = extract_generic(type_)
            if len(tuple_types) == 1 and tuple_types[0] == ():
                return len(value) == 0
            elif len(tuple_types) == 2 and tuple_types[1] is ...:
                return all(is_instance(item, tuple_types[0]) for item in value)
            else:
                if len(tuple_types) != len(value):
                    return False
                return all(is_instance(item, item_type) for item, item_type in zip(value, tuple_types))
        if isinstance(value, Mapping):
            key_type, val_type = extract_generic(type_, defaults=(Any, Any))
            for key, val in value.items():
                if not is_instance(key, key_type) or not is_instance(val, val_type):
                    return False
            return True
        return all(is_instance(item, extract_generic(type_, defaults=(Any,))[0]) for item in value)
    elif is_new_type(type_):
        return is_instance(value, extract_new_type(type_))
    elif is_literal(type_):
        return value in extract_generic(type_)
    elif is_init_var(type_):
        return is_instance(value, extract_init_var(type_))
    elif is_type_generic(type_):
        return is_subclass(value, extract_generic(type_)[0])
    else:
        try:
            # As described in PEP 484 - section: "The numeric tower"
            if isinstance(value, (int, float)) and type_ in [float, complex]:
                return True
            return isinstance(value, type_)
        except TypeError:
            return False


def is_generic_collection(type_: Type) -> bool:
    if not is_generic(type_):
        return False
    origin = extract_origin_collection(type_)
    try:
        return bool(origin and issubclass(origin, Collection))
    except (TypeError, AttributeError):
        return False


def extract_generic(type_: Type, defaults: Tuple = ()) -> tuple:
    try:
        if hasattr(type_, "_special") and type_._special:
            return defaults
        return type_.__args__ or defaults  # type: ignore
    except AttributeError:
        return defaults


def is_subclass(sub_type: Type, base_type: Type) -> bool:
    if is_generic_collection(sub_type):
        sub_type = extract_origin_collection(sub_type)
    try:
        return issubclass(sub_type, base_type)
    except TypeError:
        return False


def is_type_generic(type_: Type) -> bool:
    try:
        return type_.__origin__ in (type, Type)
    except AttributeError:
        return False
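
A few hedged examples of what these helpers accept (the import path assumes the vendored module layout introduced in this PR):

from typing import Dict, List, Optional

from ray._private.thirdparty.dacite.types import (
    extract_optional,
    is_instance,
    is_optional,
)

assert is_instance([1, 2], List[int])
assert not is_instance([1, "x"], List[int])  # generic args are checked element-wise
assert is_instance({"a": 1}, Dict[str, int])
assert is_optional(Optional[int])
assert extract_optional(Optional[int]) is int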

View file

@@ -0,0 +1,6 @@
from ray.runtime_env.runtime_env import RuntimeEnv, RuntimeEnvConfig  # noqa: E402,F401

__all__ = [
    "RuntimeEnvConfig",
    "RuntimeEnv",
]

View file

@@ -2,6 +2,7 @@ import json
 import logging
 import os
 from copy import deepcopy
+from dataclasses import asdict, is_dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Union

 from google.protobuf import json_format
@@ -12,6 +13,7 @@ from ray._private.runtime_env.conda import get_uri as get_conda_uri
 from ray._private.runtime_env.pip import get_uri as get_pip_uri
 from ray._private.runtime_env.plugin_schema_manager import RuntimeEnvPluginSchemaManager
 from ray._private.runtime_env.validation import OPTION_TO_VALIDATION_FN
+from ray._private.thirdparty.dacite import from_dict
 from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv
 from ray.core.generated.runtime_env_common_pb2 import (
     RuntimeEnvConfig as ProtoRuntimeEnvConfig,
@@ -431,10 +433,14 @@ class RuntimeEnv(dict):
         return plugin_uris

     def __setitem__(self, key: str, value: Any) -> None:
-        res_value = value
-        RuntimeEnvPluginSchemaManager.validate(key, res_value)
+        if is_dataclass(value):
+            jsonable_type = asdict(value)
+        else:
+            jsonable_type = value
+        RuntimeEnvPluginSchemaManager.validate(key, jsonable_type)
+        res_value = jsonable_type
         if key in RuntimeEnv.known_fields and key in OPTION_TO_VALIDATION_FN:
-            res_value = OPTION_TO_VALIDATION_FN[key](value)
+            res_value = OPTION_TO_VALIDATION_FN[key](jsonable_type)
         if res_value is None:
             return
         return super().__setitem__(key, res_value)
@@ -442,6 +448,14 @@ class RuntimeEnv(dict):
     def set(self, name: str, value: Any) -> None:
         self.__setitem__(name, value)

+    def get(self, name, default=None, data_class=None):
+        if name not in self:
+            return default
+        if not data_class:
+            return self.__getitem__(name)
+        else:
+            return from_dict(data_class=data_class, data=self.__getitem__(name))
+
     @classmethod
     def deserialize(cls, serialized_runtime_env: str) -> "RuntimeEnv":  # noqa: F821
         proto_runtime_env = json_format.Parse(serialized_runtime_env, ProtoRuntimeEnv())

View file

@@ -0,0 +1,8 @@
from dataclasses import dataclass
from typing import List


@dataclass
class Pip:
    packages: List[str]
    pip_check: bool = False
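
For reference, storing this dataclass in a RuntimeEnv serializes it with the stdlib dataclasses.asdict, so the value recorded under the "pip" key is a plain JSON-able dict (a sketch using only the stdlib and the class above):

from dataclasses import asdict

from ray.runtime_env.types.pip import Pip

print(asdict(Pip(packages=["pip-install-test==0.5"])))
# -> {'packages': ['pip-install-test==0.5'], 'pip_check': False}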

View file

@@ -122,6 +122,7 @@ py_test_module_list(
         "test_runtime_env_env_vars.py",
         "test_runtime_env_packaging.py",
         "test_runtime_env_plugin.py",
+        "test_runtime_env_strong_type.py",
         "test_runtime_env_fork_process.py",
         "test_serialization.py",
         "test_shuffle.py",

View file

@@ -0,0 +1,70 @@
import os
import sys

import pytest

import ray
from typing import List
from ray.runtime_env import RuntimeEnv
from ray.runtime_env.types.pip import Pip
from dataclasses import dataclass


@dataclass
class ValueType:
    nfield1: List[str]
    nfield2: bool


@dataclass
class TestPlugin:
    field1: List[ValueType]
    field2: str


def test_convert_from_and_to_dataclass():
    runtime_env = RuntimeEnv()
    test_plugin = TestPlugin(
        field1=[
            ValueType(nfield1=["a", "b", "c"], nfield2=False),
            ValueType(nfield1=["d", "e"], nfield2=True),
        ],
        field2="abc",
    )
    runtime_env.set("test_plugin", test_plugin)
    serialized_runtime_env = runtime_env.serialize()
    assert "test_plugin" in serialized_runtime_env
    runtime_env_2 = RuntimeEnv.deserialize(serialized_runtime_env)
    test_plugin_2 = runtime_env_2.get("test_plugin", data_class=TestPlugin)
    assert len(test_plugin_2.field1) == 2
    assert test_plugin_2.field1[0].nfield1 == ["a", "b", "c"]
    assert test_plugin_2.field1[0].nfield2 is False
    assert test_plugin_2.field1[1].nfield1 == ["d", "e"]
    assert test_plugin_2.field1[1].nfield2 is True
    assert test_plugin_2.field2 == "abc"


def test_pip(start_cluster):
    cluster, address = start_cluster
    ray.init(address)
    runtime_env = RuntimeEnv()
    pip = Pip(packages=["pip-install-test==0.5"])
    runtime_env.set("pip", pip)

    @ray.remote
    class Actor:
        def foo(self):
            import pip_install_test  # noqa

            return "hello"

    a = Actor.options(runtime_env=runtime_env).remote()
    assert ray.get(a.foo.remote()) == "hello"


if __name__ == "__main__":
    if os.environ.get("PARALLEL_CI"):
        sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
    else:
        sys.exit(pytest.main(["-sv", __file__]))