mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
Merge branch 'master' of github.com:ray-project/ray into chunkedclienttask
This commit is contained in:
commit
5847582593
54 changed files with 1305 additions and 472 deletions
|
@ -254,7 +254,7 @@
|
|||
- ./ci/env/install-minimal.sh
|
||||
- ./ci/env/env_info.sh
|
||||
- python ./ci/env/check_minimal_install.py
|
||||
- bazel test --test_output=streamed --config=ci $(./ci/run/bazel_export_options)
|
||||
- bazel test --test_output=streamed --config=ci --test_env=RAY_MINIMAL=1 $(./ci/run/bazel_export_options)
|
||||
python/ray/tests/test_basic
|
||||
- bazel test --test_output=streamed --config=ci $(./ci/run/bazel_export_options)
|
||||
python/ray/tests/test_basic_2
|
||||
|
|
34
BUILD.bazel
34
BUILD.bazel
|
@ -539,6 +539,7 @@ cc_library(
|
|||
":gcs_service_rpc",
|
||||
":gcs_table_storage_lib",
|
||||
":node_manager_rpc",
|
||||
":observable_store_client",
|
||||
":pubsub_lib",
|
||||
":raylet_client_lib",
|
||||
":scheduler",
|
||||
|
@ -1853,6 +1854,7 @@ cc_library(
|
|||
deps = [
|
||||
":gcs",
|
||||
":gcs_in_memory_store_client",
|
||||
":observable_store_client",
|
||||
":pubsub_lib",
|
||||
":ray_common",
|
||||
":redis_store_client",
|
||||
|
@ -2281,6 +2283,24 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "observable_store_client",
|
||||
srcs = [
|
||||
"src/ray/gcs/store_client/observable_store_client.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"src/ray/gcs/callback.h",
|
||||
"src/ray/gcs/store_client/observable_store_client.h",
|
||||
"src/ray/gcs/store_client/store_client.h",
|
||||
],
|
||||
copts = COPTS,
|
||||
strip_include_prefix = "src",
|
||||
deps = [
|
||||
":ray_common",
|
||||
":ray_util",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "store_client_test_lib",
|
||||
hdrs = [
|
||||
|
@ -2327,6 +2347,20 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "observable_store_client_test",
|
||||
size = "small",
|
||||
srcs = ["src/ray/gcs/store_client/test/in_memory_store_client_test.cc"],
|
||||
copts = COPTS,
|
||||
tags = ["team:core"],
|
||||
deps = [
|
||||
":gcs_in_memory_store_client",
|
||||
":observable_store_client",
|
||||
":store_client_test_lib",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gcs",
|
||||
srcs = glob(
|
||||
|
|
|
@ -76,7 +76,7 @@ if [[ "$platform" == "linux" ]]; then
|
|||
"$PYTHON_EXE" -u -c "import ray; print(ray.__commit__)" | grep "$TRAVIS_COMMIT" || (echo "ray.__commit__ not set properly!" && exit 1)
|
||||
|
||||
# Install the dependencies to run the tests.
|
||||
"$PIP_CMD" install -q aiohttp aiosignal frozenlist grpcio pytest==5.4.3 requests
|
||||
"$PIP_CMD" install -q aiohttp aiosignal frozenlist grpcio pytest==5.4.3 requests proxy.py
|
||||
|
||||
# Run a simple test script to make sure that the wheel works.
|
||||
for SCRIPT in "${TEST_SCRIPTS[@]}"; do
|
||||
|
@ -117,11 +117,11 @@ elif [[ "$platform" == "macosx" ]]; then
|
|||
"$PIP_CMD" install -q "$PYTHON_WHEEL"
|
||||
|
||||
# Install the dependencies to run the tests.
|
||||
"$PIP_CMD" install -q aiohttp aiosignal frozenlist grpcio pytest==5.4.3 requests
|
||||
"$PIP_CMD" install -q aiohttp aiosignal frozenlist grpcio pytest==5.4.3 requests proxy.py
|
||||
|
||||
# Run a simple test script to make sure that the wheel works.
|
||||
for SCRIPT in "${TEST_SCRIPTS[@]}"; do
|
||||
retry "$PYTHON_EXE" "$SCRIPT"
|
||||
PATH="$(dirname "$PYTHON_EXE"):$PATH" retry "$PYTHON_EXE" "$SCRIPT"
|
||||
done
|
||||
done
|
||||
elif [ "${platform}" = windows ]; then
|
||||
|
|
|
@ -85,7 +85,7 @@ class DashboardAgent(object):
|
|||
logger.info("Parent pid is %s", self.ppid)
|
||||
|
||||
# Setup raylet channel
|
||||
options = (("grpc.enable_http_proxy", 0),)
|
||||
options = ray_constants.GLOBAL_GRPC_OPTIONS
|
||||
self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel(
|
||||
f"{self.ip}:{self.node_manager_port}", options, asynchronous=True
|
||||
)
|
||||
|
|
|
@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
aiogrpc.init_grpc_aio()
|
||||
GRPC_CHANNEL_OPTIONS = (
|
||||
("grpc.enable_http_proxy", 0),
|
||||
*ray_constants.GLOBAL_GRPC_OPTIONS,
|
||||
("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
|
||||
("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
|
||||
)
|
||||
|
|
|
@ -10,6 +10,7 @@ except ImportError:
|
|||
from grpc.experimental import aio as aiogrpc
|
||||
|
||||
from ray._private.gcs_pubsub import GcsAioActorSubscriber
|
||||
import ray.ray_constants as ray_constants
|
||||
import ray.dashboard.utils as dashboard_utils
|
||||
import ray.dashboard.optional_utils as dashboard_optional_utils
|
||||
from ray.dashboard.optional_utils import rest_response
|
||||
|
@ -88,7 +89,7 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
|
|||
address = "{}:{}".format(
|
||||
node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
|
||||
)
|
||||
options = (("grpc.enable_http_proxy", 0),)
|
||||
options = ray_constants.GLOBAL_GRPC_OPTIONS
|
||||
channel = ray._private.utils.init_grpc_channel(
|
||||
address, options, asynchronous=True
|
||||
)
|
||||
|
@ -207,7 +208,7 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
|
|||
except KeyError:
|
||||
return rest_response(success=False, message="Bad Request")
|
||||
try:
|
||||
options = (("grpc.enable_http_proxy", 0),)
|
||||
options = ray_constants.GLOBAL_GRPC_OPTIONS
|
||||
channel = ray._private.utils.init_grpc_channel(
|
||||
f"{ip_address}:{port}", options=options, asynchronous=True
|
||||
)
|
||||
|
|
|
@ -43,7 +43,7 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
|
|||
)
|
||||
if dashboard_rpc_address:
|
||||
logger.info("Report events to %s", dashboard_rpc_address)
|
||||
options = (("grpc.enable_http_proxy", 0),)
|
||||
options = ray_constants.GLOBAL_GRPC_OPTIONS
|
||||
channel = utils.init_grpc_channel(
|
||||
dashboard_rpc_address, options=options, asynchronous=True
|
||||
)
|
||||
|
|
|
@ -126,7 +126,7 @@ class LogHeadV1(dashboard_utils.DashboardHeadModule):
|
|||
node_id, ports = change.new
|
||||
ip = DataSource.node_id_to_ip[node_id]
|
||||
|
||||
options = (("grpc.enable_http_proxy", 0),)
|
||||
options = ray_constants.GLOBAL_GRPC_OPTIONS
|
||||
channel = init_grpc_channel(
|
||||
f"{ip}:{ports[1]}", options=options, asynchronous=True
|
||||
)
|
||||
|
|
|
@ -78,7 +78,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
|
|||
address = "{}:{}".format(
|
||||
node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
|
||||
)
|
||||
options = (("grpc.enable_http_proxy", 0),)
|
||||
options = ray_constants.GLOBAL_GRPC_OPTIONS
|
||||
channel = ray._private.utils.init_grpc_channel(
|
||||
address, options, asynchronous=True
|
||||
)
|
||||
|
|
|
@ -11,6 +11,7 @@ import ray.experimental.internal_kv as internal_kv
|
|||
import ray._private.services
|
||||
import ray._private.utils
|
||||
from ray.ray_constants import (
|
||||
GLOBAL_GRPC_OPTIONS,
|
||||
DEBUG_AUTOSCALING_STATUS,
|
||||
DEBUG_AUTOSCALING_STATUS_LEGACY,
|
||||
DEBUG_AUTOSCALING_ERROR,
|
||||
|
@ -48,7 +49,7 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
|
|||
if change.new:
|
||||
node_id, ports = change.new
|
||||
ip = DataSource.node_id_to_ip[node_id]
|
||||
options = (("grpc.enable_http_proxy", 0),)
|
||||
options = GLOBAL_GRPC_OPTIONS
|
||||
channel = ray._private.utils.init_grpc_channel(
|
||||
f"{ip}:{ports[1]}", options=options, asynchronous=True
|
||||
)
|
||||
|
|
|
@ -17,29 +17,6 @@ def disable_aiohttp_cache():
|
|||
os.environ.pop("RAY_DASHBOARD_NO_CACHE", None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def set_http_proxy():
|
||||
http_proxy = os.environ.get("http_proxy", None)
|
||||
https_proxy = os.environ.get("https_proxy", None)
|
||||
|
||||
# set http proxy
|
||||
os.environ["http_proxy"] = "www.example.com:990"
|
||||
os.environ["https_proxy"] = "www.example.com:990"
|
||||
|
||||
yield
|
||||
|
||||
# reset http proxy
|
||||
if http_proxy:
|
||||
os.environ["http_proxy"] = http_proxy
|
||||
else:
|
||||
del os.environ["http_proxy"]
|
||||
|
||||
if https_proxy:
|
||||
os.environ["https_proxy"] = https_proxy
|
||||
else:
|
||||
del os.environ["https_proxy"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def small_event_line_limit():
|
||||
os.environ["EVENT_READ_LINE_LENGTH_LIMIT"] = "1024"
|
||||
|
|
|
@ -596,37 +596,55 @@ def test_immutable_types():
|
|||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("RAY_MINIMAL") == "1",
|
||||
reason="This test is not supposed to work for minimal installation.",
|
||||
os.environ.get("RAY_MINIMAL") == "1" or os.environ.get("RAY_DEFAULT") == "1",
|
||||
reason="This test is not supposed to work for minimal or default installation.",
|
||||
)
|
||||
def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
|
||||
address_info = ray.init(num_cpus=1, include_dashboard=True)
|
||||
assert wait_until_server_available(address_info["webui_url"]) is True
|
||||
def test_http_proxy(enable_test_module, start_http_proxy, shutdown_only):
|
||||
# C++ config `grpc_enable_http_proxy` only initializes once, so we have to
|
||||
# run driver as a separate process to make sure the correct config value
|
||||
# is initialized.
|
||||
script = """
|
||||
import ray
|
||||
import time
|
||||
import requests
|
||||
from ray._private.test_utils import (
|
||||
format_web_url,
|
||||
wait_until_server_available,
|
||||
)
|
||||
import logging
|
||||
|
||||
webui_url = address_info["webui_url"]
|
||||
webui_url = format_web_url(webui_url)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
timeout_seconds = 10
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
address_info = ray.init(num_cpus=1, include_dashboard=True)
|
||||
assert wait_until_server_available(address_info["webui_url"]) is True
|
||||
|
||||
webui_url = address_info["webui_url"]
|
||||
webui_url = format_web_url(webui_url)
|
||||
|
||||
timeout_seconds = 10
|
||||
start_time = time.time()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
try:
|
||||
response = requests.get(
|
||||
webui_url + "/test/dump", proxies={"http": None, "https": None}
|
||||
)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
response = requests.get(
|
||||
webui_url + "/test/dump", proxies={"http": None, "https": None}
|
||||
)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
response.json()
|
||||
assert response.ok
|
||||
except Exception as ex:
|
||||
logger.info("failed response: %s", response.text)
|
||||
raise ex
|
||||
break
|
||||
except (AssertionError, requests.exceptions.ConnectionError) as e:
|
||||
logger.info("Retry because of %s", e)
|
||||
finally:
|
||||
if time.time() > start_time + timeout_seconds:
|
||||
raise Exception("Timed out while testing.")
|
||||
response.json()
|
||||
assert response.ok
|
||||
except Exception as ex:
|
||||
logger.info("failed response: %s", response.text)
|
||||
raise ex
|
||||
break
|
||||
except (AssertionError, requests.exceptions.ConnectionError) as e:
|
||||
logger.info("Retry because of %s", e)
|
||||
finally:
|
||||
if time.time() > start_time + timeout_seconds:
|
||||
raise Exception("Timed out while testing.")
|
||||
"""
|
||||
env = start_http_proxy
|
||||
run_string_as_driver(script, dict(os.environ, **env))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
|
|
@ -243,7 +243,7 @@
|
|||
"more machines.\n",
|
||||
"\n",
|
||||
"To ingest 500GiB of data, we'll set up a Ray Cluster.\n",
|
||||
"The provided :download:`big_data_ingestion.yaml <../big_data_ingestion.yaml>`\n",
|
||||
"The provided [big_data_ingestion.yaml](https://raw.githubusercontent.com/ray-project/ray/master/doc/source/data/big_data_ingestion.yaml)\n",
|
||||
"cluster config can be used to set up an AWS cluster with 70 CPU nodes and\n",
|
||||
"16 GPU nodes. Using following command to bring up the Ray cluster.\n",
|
||||
"\n",
|
||||
|
@ -362,7 +362,7 @@
|
|||
"# -> throughput: 8.56GiB/s\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Note: The pipeline can also be submitted using :ref:`Ray Job Submission <jobs-overview>`,\n",
|
||||
"Note: The pipeline can also be submitted using [Ray Job Submission](https://docs.ray.io/en/latest/cluster/job-submission.html) ,\n",
|
||||
"which is in beta starting with Ray 1.12. Try it out!"
|
||||
]
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
"\n",
|
||||
"## Cluster Setup\n",
|
||||
"\n",
|
||||
"First, we'll set up our Ray Cluster. The provided `dask_xgboost.yaml`\n",
|
||||
"First, we'll set up our Ray Cluster. The provided [dask_xgboost.yaml](https://raw.githubusercontent.com/ray-project/ray/master/doc/source/ray-core/examples/dask_xgboost/dask_xgboost.yaml)\n",
|
||||
"cluster config can be used to set up an AWS cluster with 64 CPUs.\n",
|
||||
"\n",
|
||||
"The following steps assume you are in a directory with both\n",
|
||||
|
@ -477,4 +477,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
"\n",
|
||||
"## Cluster Setup\n",
|
||||
"\n",
|
||||
"First, we'll set up our Ray Cluster. The provided ``modin_xgboost.yaml``\n",
|
||||
"First, we'll set up our Ray Cluster. The provided [modin_xgboost.yaml](https://raw.githubusercontent.com/ray-project/ray/master/doc/source/ray-core/examples/modin_xgboost/modin_xgboost.yaml)\n",
|
||||
"cluster config can be used to set up an AWS cluster with 64 CPUs.\n",
|
||||
"\n",
|
||||
"The following steps assume you are in a directory with both\n",
|
||||
|
|
|
@ -7,6 +7,7 @@ import time
|
|||
import grpc
|
||||
|
||||
import ray
|
||||
from ray import ray_constants
|
||||
from ray.core.generated.common_pb2 import ErrorType
|
||||
from ray.core.generated import gcs_service_pb2_grpc
|
||||
from ray.core.generated import gcs_service_pb2
|
||||
|
@ -74,7 +75,7 @@ _GRPC_KEEPALIVE_TIMEOUT_MS = 60 * 1000
|
|||
# grpc.keepalive_permit_without_calls=0: No keepalive without inflight calls.
|
||||
# grpc.use_local_subchannel_pool=0: Subchannels are shared.
|
||||
_GRPC_OPTIONS = [
|
||||
("grpc.enable_http_proxy", 0),
|
||||
*ray_constants.GLOBAL_GRPC_OPTIONS,
|
||||
("grpc.max_send_message_length", _MAX_MESSAGE_LENGTH),
|
||||
("grpc.max_receive_message_length", _MAX_MESSAGE_LENGTH),
|
||||
("grpc.keepalive_time_ms", _GRPC_KEEPALIVE_TIME_MS),
|
||||
|
|
|
@ -145,7 +145,7 @@ class Monitor:
|
|||
retry_on_failure: bool = True,
|
||||
):
|
||||
gcs_address = address
|
||||
options = (("grpc.enable_http_proxy", 0),)
|
||||
options = ray_constants.GLOBAL_GRPC_OPTIONS
|
||||
gcs_channel = ray._private.utils.init_grpc_channel(gcs_address, options)
|
||||
# TODO: Use gcs client for this
|
||||
self.gcs_node_resources_stub = (
|
||||
|
|
|
@ -7,6 +7,7 @@ import grpc
|
|||
import ray
|
||||
|
||||
from typing import Dict, List
|
||||
from ray import ray_constants
|
||||
|
||||
from ray.core.generated.gcs_service_pb2 import (
|
||||
GetAllActorInfoRequest,
|
||||
|
@ -106,7 +107,7 @@ class StateDataSourceClient:
|
|||
|
||||
def register_raylet_client(self, node_id: str, address: str, port: int):
|
||||
full_addr = f"{address}:{port}"
|
||||
options = (("grpc.enable_http_proxy", 0),)
|
||||
options = ray_constants.GLOBAL_GRPC_OPTIONS
|
||||
channel = ray._private.utils.init_grpc_channel(
|
||||
full_addr, options, asynchronous=True
|
||||
)
|
||||
|
@ -116,7 +117,7 @@ class StateDataSourceClient:
|
|||
self._raylet_stubs.pop(node_id)
|
||||
|
||||
def register_agent_client(self, node_id, address: str, port: int):
|
||||
options = (("grpc.enable_http_proxy", 0),)
|
||||
options = ray_constants.GLOBAL_GRPC_OPTIONS
|
||||
channel = ray._private.utils.init_grpc_channel(
|
||||
f"{address}:{port}", options=options, asynchronous=True
|
||||
)
|
||||
|
|
|
@ -332,6 +332,14 @@ CALL_STACK_LINE_DELIMITER = " | "
|
|||
# NOTE: This is equal to the C++ limit of (RAY_CONFIG::max_grpc_message_size)
|
||||
GRPC_CPP_MAX_MESSAGE_SIZE = 100 * 1024 * 1024
|
||||
|
||||
# GRPC options
|
||||
GRPC_ENABLE_HTTP_PROXY = (
|
||||
1
|
||||
if os.environ.get("RAY_grpc_enable_http_proxy", "0").lower() in ("1", "true")
|
||||
else 0
|
||||
)
|
||||
GLOBAL_GRPC_OPTIONS = (("grpc.enable_http_proxy", GRPC_ENABLE_HTTP_PROXY),)
|
||||
|
||||
# Internal kv namespaces
|
||||
KV_NAMESPACE_DASHBOARD = b"dashboard"
|
||||
KV_NAMESPACE_SESSION = b"session"
|
||||
|
|
|
@ -58,6 +58,8 @@ from ray.serve.context import (
|
|||
get_internal_replica_context,
|
||||
ReplicaContext,
|
||||
)
|
||||
from ray.serve.pipeline.api import build as pipeline_build
|
||||
from ray.serve.pipeline.api import get_and_validate_ingress_deployment
|
||||
from ray._private.usage import usage_lib
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
@ -592,9 +594,6 @@ def run(
|
|||
RayServeHandle: A regular ray serve handle that can be called by user
|
||||
to execute the serve DAG.
|
||||
"""
|
||||
# TODO (jiaodong): Resolve circular reference in pipeline codebase and serve
|
||||
from ray.serve.pipeline.api import build as pipeline_build
|
||||
from ray.serve.pipeline.api import get_and_validate_ingress_deployment
|
||||
|
||||
client = start(detached=True, http_options={"host": host, "port": port})
|
||||
|
||||
|
@ -668,8 +667,6 @@ def build(target: Union[ClassNode, FunctionNode]) -> Application:
|
|||
The returned Application object can be exported to a dictionary or YAML
|
||||
config.
|
||||
"""
|
||||
# TODO (jiaodong): Resolve circular reference in pipeline codebase and serve
|
||||
from ray.serve.pipeline.api import build as pipeline_build
|
||||
|
||||
if in_interactive_shell():
|
||||
raise RuntimeError(
|
||||
|
|
|
@ -46,6 +46,12 @@ class DeploymentStatusInfo:
|
|||
)
|
||||
|
||||
|
||||
HEALTH_CHECK_CONCURRENCY_GROUP = "health_check"
|
||||
REPLICA_DEFAULT_ACTOR_OPTIONS = {
|
||||
"concurrency_groups": {HEALTH_CHECK_CONCURRENCY_GROUP: 1}
|
||||
}
|
||||
|
||||
|
||||
class DeploymentInfo:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -95,14 +101,14 @@ class DeploymentInfo:
|
|||
or self.serialized_deployment_def is not None
|
||||
)
|
||||
if self.replica_config.import_path is not None:
|
||||
self._cached_actor_def = ray.remote(
|
||||
self._cached_actor_def = ray.remote(**REPLICA_DEFAULT_ACTOR_OPTIONS)(
|
||||
create_replica_wrapper(
|
||||
self.actor_name,
|
||||
import_path=self.replica_config.import_path,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._cached_actor_def = ray.remote(
|
||||
self._cached_actor_def = ray.remote(**REPLICA_DEFAULT_ACTOR_OPTIONS)(
|
||||
create_replica_wrapper(
|
||||
self.actor_name,
|
||||
serialized_deployment_def=self.serialized_deployment_def,
|
||||
|
|
|
@ -18,7 +18,7 @@ from ray.util import metrics
|
|||
from ray._private.async_compat import sync_to_async
|
||||
|
||||
from ray.serve.autoscaling_metrics import start_metrics_pusher
|
||||
from ray.serve.common import ReplicaTag
|
||||
from ray.serve.common import HEALTH_CHECK_CONCURRENCY_GROUP, ReplicaTag
|
||||
from ray.serve.config import DeploymentConfig
|
||||
from ray.serve.constants import (
|
||||
HEALTH_CHECK_METHOD,
|
||||
|
@ -219,6 +219,7 @@ def create_replica_wrapper(
|
|||
if self.replica is not None:
|
||||
return await self.replica.prepare_for_shutdown()
|
||||
|
||||
@ray.method(concurrency_group=HEALTH_CHECK_CONCURRENCY_GROUP)
|
||||
async def check_health(self):
|
||||
await self.replica.check_health()
|
||||
|
||||
|
|
|
@ -235,6 +235,33 @@ def test_uvicorn_duplicate_headers(serve_instance):
|
|||
assert resp.headers["content-length"] == "9"
|
||||
|
||||
|
||||
def test_healthcheck_timeout(serve_instance):
|
||||
# https://github.com/ray-project/ray/issues/24554
|
||||
|
||||
signal = SignalActor.remote()
|
||||
|
||||
@serve.deployment(
|
||||
_health_check_timeout_s=2,
|
||||
_health_check_period_s=1,
|
||||
_graceful_shutdown_timeout_s=0,
|
||||
)
|
||||
class A:
|
||||
def check_health(self):
|
||||
return True
|
||||
|
||||
def __call__(self):
|
||||
ray.get(signal.wait.remote())
|
||||
|
||||
A.deploy()
|
||||
handle = A.get_handle()
|
||||
ref = handle.remote()
|
||||
# without the proper fix, the ref will fail with actor died error.
|
||||
with pytest.raises(GetTimeoutError):
|
||||
ray.get(ref, timeout=10)
|
||||
signal.send.remote()
|
||||
ray.get(ref)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
|
|
@ -276,9 +276,13 @@ def call_ray_start(request):
|
|||
"--max-worker-port=0 --port 0",
|
||||
)
|
||||
command_args = parameter.split(" ")
|
||||
out = ray._private.utils.decode(
|
||||
subprocess.check_output(command_args, stderr=subprocess.STDOUT)
|
||||
)
|
||||
try:
|
||||
out = ray._private.utils.decode(
|
||||
subprocess.check_output(command_args, stderr=subprocess.STDOUT)
|
||||
)
|
||||
except Exception as e:
|
||||
print(type(e), e)
|
||||
raise
|
||||
# Get the redis address from the output.
|
||||
redis_substring_prefix = "--address='"
|
||||
address_location = out.find(redis_substring_prefix) + len(redis_substring_prefix)
|
||||
|
@ -867,3 +871,27 @@ def create_ray_logs_for_failed_test(rep):
|
|||
test_name = rep.nodeid.replace(os.sep, "::")
|
||||
output_file = os.path.join(archive_dir, f"{test_name}_{time.time():.4f}")
|
||||
shutil.make_archive(output_file, "zip", logs_dir)
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def start_http_proxy(request):
|
||||
env = {}
|
||||
|
||||
proxy = None
|
||||
try:
|
||||
if request.param:
|
||||
# the `proxy` command is from the proxy.py package.
|
||||
proxy = subprocess.Popen(
|
||||
["proxy", "--port", "8899", "--log-level", "ERROR"]
|
||||
)
|
||||
env["RAY_grpc_enable_http_proxy"] = "1"
|
||||
proxy_url = "http://localhost:8899"
|
||||
else:
|
||||
proxy_url = "http://example.com"
|
||||
env["http_proxy"] = proxy_url
|
||||
env["https_proxy"] = proxy_url
|
||||
yield env
|
||||
finally:
|
||||
if proxy:
|
||||
proxy.terminate()
|
||||
proxy.wait()
|
||||
|
|
|
@ -22,17 +22,29 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
# https://github.com/ray-project/ray/issues/6662
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("RAY_MINIMAL") == "1",
|
||||
reason="This test is not supposed to work for minimal installation.",
|
||||
)
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="interferes with grpc")
|
||||
def test_ignore_http_proxy(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
os.environ["http_proxy"] = "http://example.com"
|
||||
os.environ["https_proxy"] = "http://example.com"
|
||||
def test_http_proxy(start_http_proxy, shutdown_only):
|
||||
# C++ config `grpc_enable_http_proxy` only initializes once, so we have to
|
||||
# run driver as a separate process to make sure the correct config value
|
||||
# is initialized.
|
||||
script = """
|
||||
import ray
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
return 1
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
assert ray.get(f.remote()) == 1
|
||||
@ray.remote
|
||||
def f():
|
||||
return 1
|
||||
|
||||
assert ray.get(f.remote()) == 1
|
||||
"""
|
||||
|
||||
env = start_http_proxy
|
||||
run_string_as_driver(script, dict(os.environ, **env))
|
||||
|
||||
|
||||
# https://github.com/ray-project/ray/issues/16025
|
||||
|
@ -310,6 +322,66 @@ def test_options():
|
|||
"zzz": 42,
|
||||
}
|
||||
|
||||
# test options for other Ray libraries.
|
||||
namespace = "namespace"
|
||||
|
||||
class mock_options:
|
||||
def __init__(self, **options):
|
||||
self.options = {"_metadata": {namespace: options}}
|
||||
|
||||
def keys(self):
|
||||
return ("_metadata",)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.options[key]
|
||||
|
||||
def __call__(self, f):
|
||||
f._default_options.update(self.options)
|
||||
return f
|
||||
|
||||
@mock_options(a=1, b=2)
|
||||
@ray.remote(num_gpus=2)
|
||||
def foo():
|
||||
pass
|
||||
|
||||
assert foo._default_options == {
|
||||
"_metadata": {"namespace": {"a": 1, "b": 2}},
|
||||
"num_gpus": 2,
|
||||
}
|
||||
|
||||
f2 = foo.options(num_cpus=1, num_gpus=1, **mock_options(a=11, c=3))
|
||||
|
||||
# TODO(suquark): The current implementation of `.options()` is so bad that we
|
||||
# cannot even access its options from outside. Here we hack the closures to
|
||||
# achieve our goal. Need futher efforts to clean up the tech debt.
|
||||
assert f2.remote.__closure__[1].cell_contents == {
|
||||
"_metadata": {"namespace": {"a": 11, "b": 2, "c": 3}},
|
||||
"num_cpus": 1,
|
||||
"num_gpus": 1,
|
||||
}
|
||||
|
||||
class mock_options2(mock_options):
|
||||
def __init__(self, **options):
|
||||
self.options = {"_metadata": {namespace + "2": options}}
|
||||
|
||||
f3 = foo.options(num_cpus=1, num_gpus=1, **mock_options2(a=11, c=3))
|
||||
|
||||
assert f3.remote.__closure__[1].cell_contents == {
|
||||
"_metadata": {"namespace": {"a": 1, "b": 2}, "namespace2": {"a": 11, "c": 3}},
|
||||
"num_cpus": 1,
|
||||
"num_gpus": 1,
|
||||
}
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Ensure only a single "**option" per ".options()".
|
||||
# Otherwise it would be confusing.
|
||||
foo.options(
|
||||
num_cpus=1,
|
||||
num_gpus=1,
|
||||
**mock_options(a=11, c=3),
|
||||
**mock_options2(a=11, c=3),
|
||||
)
|
||||
|
||||
|
||||
# https://github.com/ray-project/ray/issues/17842
|
||||
def test_disable_cuda_devices():
|
||||
|
|
|
@ -400,7 +400,7 @@ def test_gcs_drain(ray_start_cluster_head, error_pubsub):
|
|||
"""
|
||||
# Prepare requests.
|
||||
gcs_server_addr = cluster.gcs_address
|
||||
options = (("grpc.enable_http_proxy", 0),)
|
||||
options = ray_constants.GLOBAL_GRPC_OPTIONS
|
||||
channel = grpc.insecure_channel(gcs_server_addr, options)
|
||||
stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(channel)
|
||||
r = gcs_service_pb2.DrainNodeRequest()
|
||||
|
|
|
@ -404,7 +404,7 @@ async def test_state_data_source_client(ray_start_cluster):
|
|||
worker = cluster.add_node(num_cpus=2)
|
||||
|
||||
GRPC_CHANNEL_OPTIONS = (
|
||||
("grpc.enable_http_proxy", 0),
|
||||
*ray_constants.GLOBAL_GRPC_OPTIONS,
|
||||
("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
|
||||
("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
|
||||
)
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import datetime
|
||||
from typing import Dict, List, Optional, Union
|
||||
import numbers
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import collections
|
||||
import os
|
||||
|
@ -129,6 +130,8 @@ class TuneReporterBase(ProgressReporter):
|
|||
max_error_rows: Maximum number of rows to print in the
|
||||
error table. The error table lists the error file, if any,
|
||||
corresponding to each trial. Defaults to 20.
|
||||
max_column_length: Maximum column length (in characters). Column
|
||||
headers and values longer than this will be abbreviated.
|
||||
max_report_frequency: Maximum report frequency in seconds.
|
||||
Defaults to 5s.
|
||||
infer_limit: Maximum number of metrics to automatically infer
|
||||
|
@ -169,11 +172,13 @@ class TuneReporterBase(ProgressReporter):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
|
||||
parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
|
||||
total_samples: Optional[int] = None,
|
||||
max_progress_rows: int = 20,
|
||||
max_error_rows: int = 20,
|
||||
max_column_length: int = 20,
|
||||
max_report_frequency: int = 5,
|
||||
infer_limit: int = 3,
|
||||
print_intermediate_tables: Optional[bool] = None,
|
||||
|
@ -188,6 +193,7 @@ class TuneReporterBase(ProgressReporter):
|
|||
self._parameter_columns = parameter_columns or []
|
||||
self._max_progress_rows = max_progress_rows
|
||||
self._max_error_rows = max_error_rows
|
||||
self._max_column_length = max_column_length
|
||||
self._infer_limit = infer_limit
|
||||
|
||||
if print_intermediate_tables is None:
|
||||
|
@ -360,6 +366,7 @@ class TuneReporterBase(ProgressReporter):
|
|||
force_table=self._print_intermediate_tables,
|
||||
fmt=fmt,
|
||||
max_rows=max_progress,
|
||||
max_column_length=self._max_column_length,
|
||||
done=done,
|
||||
metric=self._metric,
|
||||
mode=self._mode,
|
||||
|
@ -461,6 +468,8 @@ class JupyterNotebookReporter(TuneReporterBase, RemoteReporterMixin):
|
|||
max_error_rows: Maximum number of rows to print in the
|
||||
error table. The error table lists the error file, if any,
|
||||
corresponding to each trial. Defaults to 20.
|
||||
max_column_length: Maximum column length (in characters). Column
|
||||
headers and values longer than this will be abbreviated.
|
||||
max_report_frequency: Maximum report frequency in seconds.
|
||||
Defaults to 5s.
|
||||
infer_limit: Maximum number of metrics to automatically infer
|
||||
|
@ -480,12 +489,14 @@ class JupyterNotebookReporter(TuneReporterBase, RemoteReporterMixin):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
overwrite: bool = True,
|
||||
metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
|
||||
parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
|
||||
total_samples: Optional[int] = None,
|
||||
max_progress_rows: int = 20,
|
||||
max_error_rows: int = 20,
|
||||
max_column_length: int = 20,
|
||||
max_report_frequency: int = 5,
|
||||
infer_limit: int = 3,
|
||||
print_intermediate_tables: Optional[bool] = None,
|
||||
|
@ -494,17 +505,18 @@ class JupyterNotebookReporter(TuneReporterBase, RemoteReporterMixin):
|
|||
sort_by_metric: bool = False,
|
||||
):
|
||||
super(JupyterNotebookReporter, self).__init__(
|
||||
metric_columns,
|
||||
parameter_columns,
|
||||
total_samples,
|
||||
max_progress_rows,
|
||||
max_error_rows,
|
||||
max_report_frequency,
|
||||
infer_limit,
|
||||
print_intermediate_tables,
|
||||
metric,
|
||||
mode,
|
||||
sort_by_metric,
|
||||
metric_columns=metric_columns,
|
||||
parameter_columns=parameter_columns,
|
||||
total_samples=total_samples,
|
||||
max_progress_rows=max_progress_rows,
|
||||
max_error_rows=max_error_rows,
|
||||
max_column_length=max_column_length,
|
||||
max_report_frequency=max_report_frequency,
|
||||
infer_limit=infer_limit,
|
||||
print_intermediate_tables=print_intermediate_tables,
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
sort_by_metric=sort_by_metric,
|
||||
)
|
||||
|
||||
if not IS_NOTEBOOK:
|
||||
|
@ -564,6 +576,8 @@ class CLIReporter(TuneReporterBase):
|
|||
max_error_rows: Maximum number of rows to print in the
|
||||
error table. The error table lists the error file, if any,
|
||||
corresponding to each trial. Defaults to 20.
|
||||
max_column_length: Maximum column length (in characters). Column
|
||||
headers and values longer than this will be abbreviated.
|
||||
max_report_frequency: Maximum report frequency in seconds.
|
||||
Defaults to 5s.
|
||||
infer_limit: Maximum number of metrics to automatically infer
|
||||
|
@ -583,11 +597,13 @@ class CLIReporter(TuneReporterBase):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
|
||||
parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
|
||||
total_samples: Optional[int] = None,
|
||||
max_progress_rows: int = 20,
|
||||
max_error_rows: int = 20,
|
||||
max_column_length: int = 20,
|
||||
max_report_frequency: int = 5,
|
||||
infer_limit: int = 3,
|
||||
print_intermediate_tables: Optional[bool] = None,
|
||||
|
@ -597,17 +613,18 @@ class CLIReporter(TuneReporterBase):
|
|||
):
|
||||
|
||||
super(CLIReporter, self).__init__(
|
||||
metric_columns,
|
||||
parameter_columns,
|
||||
total_samples,
|
||||
max_progress_rows,
|
||||
max_error_rows,
|
||||
max_report_frequency,
|
||||
infer_limit,
|
||||
print_intermediate_tables,
|
||||
metric,
|
||||
mode,
|
||||
sort_by_metric,
|
||||
metric_columns=metric_columns,
|
||||
parameter_columns=parameter_columns,
|
||||
total_samples=total_samples,
|
||||
max_progress_rows=max_progress_rows,
|
||||
max_error_rows=max_error_rows,
|
||||
max_column_length=max_column_length,
|
||||
max_report_frequency=max_report_frequency,
|
||||
infer_limit=infer_limit,
|
||||
print_intermediate_tables=print_intermediate_tables,
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
sort_by_metric=sort_by_metric,
|
||||
)
|
||||
|
||||
def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
|
||||
|
@ -683,6 +700,7 @@ def trial_progress_str(
|
|||
force_table: bool = False,
|
||||
fmt: str = "psql",
|
||||
max_rows: Optional[int] = None,
|
||||
max_column_length: int = 20,
|
||||
done: bool = False,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
|
@ -711,6 +729,7 @@ def trial_progress_str(
|
|||
fmt: Output format (see tablefmt in tabulate API).
|
||||
max_rows: Maximum number of rows in the trial table. Defaults to
|
||||
unlimited.
|
||||
max_column_length: Maximum column length (in characters).
|
||||
done: True indicates that the tuning run finished.
|
||||
metric: Metric used to sort trials.
|
||||
mode: One of [min, max]. Determines whether objective is
|
||||
|
@ -747,19 +766,48 @@ def trial_progress_str(
|
|||
|
||||
if force_table or (has_verbosity(Verbosity.V2_TRIAL_NORM) and done):
|
||||
messages += trial_progress_table(
|
||||
trials,
|
||||
metric_columns,
|
||||
parameter_columns,
|
||||
fmt,
|
||||
max_rows,
|
||||
metric,
|
||||
mode,
|
||||
sort_by_metric,
|
||||
trials=trials,
|
||||
metric_columns=metric_columns,
|
||||
parameter_columns=parameter_columns,
|
||||
fmt=fmt,
|
||||
max_rows=max_rows,
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
sort_by_metric=sort_by_metric,
|
||||
max_column_length=max_column_length,
|
||||
)
|
||||
|
||||
return delim.join(messages)
|
||||
|
||||
|
||||
def _max_len(value: Any, max_len: int = 20, add_addr: bool = False) -> Any:
|
||||
"""Abbreviate a string representation of an object to `max_len` characters.
|
||||
|
||||
For numbers, booleans and None, the original value will be returned for
|
||||
correct rendering in the table formatting tool.
|
||||
|
||||
Args:
|
||||
value: Object to be represented as a string.
|
||||
max_len: Maximum return string length.
|
||||
add_addr: If True, will add part of the object address to the end of the
|
||||
string, e.g. to identify different instances of the same class. If
|
||||
False, three dots (``...``) will be used instead.
|
||||
"""
|
||||
if value is None or isinstance(value, (int, float, numbers.Number, bool)):
|
||||
return value
|
||||
|
||||
string = str(value)
|
||||
if len(string) <= max_len:
|
||||
return string
|
||||
|
||||
if add_addr and not isinstance(value, (int, float, bool)):
|
||||
result = f"{string[: (max_len - 5)]}_{hex(id(value))[-4:]}"
|
||||
return result
|
||||
|
||||
result = f"{string[: (max_len - 3)]}..."
|
||||
return result
|
||||
|
||||
|
||||
def trial_progress_table(
|
||||
trials: List[Trial],
|
||||
metric_columns: Union[List[str], Dict[str, str]],
|
||||
|
@ -769,6 +817,7 @@ def trial_progress_table(
|
|||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
sort_by_metric: bool = False,
|
||||
max_column_length: int = 20,
|
||||
):
|
||||
messages = []
|
||||
num_trials = len(trials)
|
||||
|
@ -840,17 +889,29 @@ def trial_progress_table(
|
|||
|
||||
# Build trial rows.
|
||||
trial_table = [
|
||||
_get_trial_info(trial, parameter_keys, metric_keys) for trial in trials
|
||||
_get_trial_info(
|
||||
trial, parameter_keys, metric_keys, max_column_length=max_column_length
|
||||
)
|
||||
for trial in trials
|
||||
]
|
||||
# Format column headings
|
||||
if isinstance(metric_columns, Mapping):
|
||||
formatted_metric_columns = [metric_columns[k] for k in metric_keys]
|
||||
formatted_metric_columns = [
|
||||
_max_len(metric_columns[k], max_len=max_column_length, add_addr=False)
|
||||
for k in metric_keys
|
||||
]
|
||||
else:
|
||||
formatted_metric_columns = metric_keys
|
||||
if isinstance(parameter_columns, Mapping):
|
||||
formatted_parameter_columns = [parameter_columns[k] for k in parameter_keys]
|
||||
formatted_parameter_columns = [
|
||||
_max_len(parameter_columns[k], max_len=max_column_length, add_addr=False)
|
||||
for k in parameter_keys
|
||||
]
|
||||
else:
|
||||
formatted_parameter_columns = parameter_keys
|
||||
formatted_parameter_columns = [
|
||||
_max_len(k, max_len=max_column_length, add_addr=False)
|
||||
for k in parameter_keys
|
||||
]
|
||||
columns = (
|
||||
["Trial name", "status", "loc"]
|
||||
+ formatted_parameter_columns
|
||||
|
@ -972,7 +1033,9 @@ def _get_trial_location(trial: Trial, result: dict) -> Location:
|
|||
return location
|
||||
|
||||
|
||||
def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]):
|
||||
def _get_trial_info(
|
||||
trial: Trial, parameters: List[str], metrics: List[str], max_column_length: int = 20
|
||||
):
|
||||
"""Returns the following information about a trial:
|
||||
|
||||
name | status | loc | params... | metrics...
|
||||
|
@ -981,16 +1044,27 @@ def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]):
|
|||
trial: Trial to get information for.
|
||||
parameters: Names of trial parameters to include.
|
||||
metrics: Names of metrics to include.
|
||||
max_column_length: Maximum column length (in characters).
|
||||
"""
|
||||
result = trial.last_result
|
||||
config = trial.config
|
||||
location = _get_trial_location(trial, result)
|
||||
trial_info = [str(trial), trial.status, str(location)]
|
||||
trial_info += [
|
||||
unflattened_lookup(param, config, default=None) for param in parameters
|
||||
_max_len(
|
||||
unflattened_lookup(param, config, default=None),
|
||||
max_len=max_column_length,
|
||||
add_addr=True,
|
||||
)
|
||||
for param in parameters
|
||||
]
|
||||
trial_info += [
|
||||
unflattened_lookup(metric, result, default=None) for metric in metrics
|
||||
_max_len(
|
||||
unflattened_lookup(metric, result, default=None),
|
||||
max_len=max_column_length,
|
||||
add_addr=True,
|
||||
)
|
||||
for metric in metrics
|
||||
]
|
||||
return trial_info
|
||||
|
||||
|
|
|
@ -694,6 +694,25 @@ class ProgressReporterTest(unittest.TestCase):
|
|||
|
||||
tune.run(lambda config: 2, num_samples=1, progress_reporter=CustomReporter())
|
||||
|
||||
def testMaxLen(self):
|
||||
trials = []
|
||||
for i in range(5):
|
||||
t = Mock()
|
||||
t.status = "TERMINATED"
|
||||
t.trial_id = "%05d" % i
|
||||
t.local_dir = "/foo"
|
||||
t.location = "here"
|
||||
t.config = {"verylong" * 20: i}
|
||||
t.evaluated_params = {"verylong" * 20: i}
|
||||
t.last_result = {"some_metric": "evenlonger" * 100}
|
||||
t.__str__ = lambda self: self.trial_id
|
||||
trials.append(t)
|
||||
|
||||
progress_str = trial_progress_str(
|
||||
trials, metric_columns=["some_metric"], force_table=True
|
||||
)
|
||||
assert any(len(row) <= 90 for row in progress_str.split("\n"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from typing import List, Any, Union, Dict, Callable, Tuple, Optional
|
||||
from collections import deque, defaultdict
|
||||
|
||||
import ray
|
||||
from ray.workflow import workflow_context
|
||||
|
@ -12,7 +13,6 @@ from ray.workflow.common import (
|
|||
StepType,
|
||||
)
|
||||
from ray.workflow import workflow_storage
|
||||
from ray.workflow.step_function import WorkflowStepFunction
|
||||
|
||||
|
||||
class WorkflowStepNotRecoverableError(Exception):
|
||||
|
@ -32,136 +32,111 @@ class WorkflowNotResumableError(Exception):
|
|||
super().__init__(self.message)
|
||||
|
||||
|
||||
@WorkflowStepFunction
|
||||
def _recover_workflow_step(
|
||||
input_workflows: List[Any],
|
||||
input_workflow_refs: List[WorkflowRef],
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""A workflow step that recovers the output of an unfinished step.
|
||||
|
||||
Args:
|
||||
args: The positional arguments for the step function.
|
||||
kwargs: The keyword args for the step function.
|
||||
input_workflows: The workflows in the argument of the (original) step.
|
||||
They are resolved into physical objects (i.e. the output of the
|
||||
workflows) here. They come from other recover workflows we
|
||||
construct recursively.
|
||||
|
||||
Returns:
|
||||
The output of the recovered step.
|
||||
"""
|
||||
reader = workflow_storage.get_workflow_storage()
|
||||
step_id = workflow_context.get_current_step_id()
|
||||
func: Callable = reader.load_step_func_body(step_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def _reconstruct_wait_step(
|
||||
reader: workflow_storage.WorkflowStorage,
|
||||
step_id: StepID,
|
||||
result: workflow_storage.StepInspectResult,
|
||||
input_map: Dict[StepID, Any],
|
||||
):
|
||||
input_workflows = []
|
||||
step_options = result.step_options
|
||||
wait_options = step_options.ray_options.get("wait_options", {})
|
||||
for i, _step_id in enumerate(result.workflows):
|
||||
# Check whether the step has been loaded or not to avoid
|
||||
# duplication
|
||||
if _step_id in input_map:
|
||||
r = input_map[_step_id]
|
||||
else:
|
||||
r = _construct_resume_workflow_from_step(reader, _step_id, input_map)
|
||||
input_map[_step_id] = r
|
||||
if isinstance(r, Workflow):
|
||||
input_workflows.append(r)
|
||||
else:
|
||||
assert isinstance(r, StepID)
|
||||
# TODO (Alex): We should consider caching these outputs too.
|
||||
output = reader.load_step_output(r)
|
||||
# Simulate a workflow with a workflow reference so it could be
|
||||
# used directly by 'workflow.wait'.
|
||||
static_ref = WorkflowStaticRef(step_id=r, ref=ray.put(output))
|
||||
wf = Workflow.from_ref(static_ref)
|
||||
input_workflows.append(wf)
|
||||
|
||||
from ray import workflow
|
||||
|
||||
wait_step = workflow.wait(input_workflows, **wait_options)
|
||||
# override step id
|
||||
wait_step._step_id = step_id
|
||||
return wait_step
|
||||
|
||||
|
||||
def _construct_resume_workflow_from_step(
|
||||
reader: workflow_storage.WorkflowStorage,
|
||||
step_id: StepID,
|
||||
input_map: Dict[StepID, Any],
|
||||
) -> Union[Workflow, StepID]:
|
||||
workflow_id: str, step_id: StepID
|
||||
) -> Union[Workflow, Any]:
|
||||
"""Try to construct a workflow (step) that recovers the workflow step.
|
||||
If the workflow step already has an output checkpointing file, we return
|
||||
the workflow step id instead.
|
||||
|
||||
Args:
|
||||
reader: The storage reader for inspecting the step.
|
||||
workflow_id: The ID of the workflow.
|
||||
step_id: The ID of the step we want to recover.
|
||||
input_map: This is a context storing the input which has been loaded.
|
||||
This context is important for dedupe
|
||||
|
||||
Returns:
|
||||
A workflow that recovers the step, or a ID of a step
|
||||
that contains the output checkpoint file.
|
||||
A workflow that recovers the step, or the output of the step
|
||||
if it has been checkpointed.
|
||||
"""
|
||||
result: workflow_storage.StepInspectResult = reader.inspect_step(step_id)
|
||||
if result.output_object_valid:
|
||||
# we already have the output
|
||||
return step_id
|
||||
if isinstance(result.output_step_id, str):
|
||||
return _construct_resume_workflow_from_step(
|
||||
reader, result.output_step_id, input_map
|
||||
)
|
||||
# output does not exists or not valid. try to reconstruct it.
|
||||
if not result.is_recoverable():
|
||||
raise WorkflowStepNotRecoverableError(step_id)
|
||||
reader = workflow_storage.WorkflowStorage(workflow_id)
|
||||
|
||||
step_options = result.step_options
|
||||
# Process the wait step as a special case.
|
||||
if step_options.step_type == StepType.WAIT:
|
||||
return _reconstruct_wait_step(reader, step_id, result, input_map)
|
||||
# Step 1: construct dependency of the DAG (BFS)
|
||||
inpsect_results = {}
|
||||
dependency_map = defaultdict(list)
|
||||
num_in_edges = {}
|
||||
|
||||
dag_visit_queue = deque([step_id])
|
||||
while dag_visit_queue:
|
||||
s: StepID = dag_visit_queue.popleft()
|
||||
if s in inpsect_results:
|
||||
continue
|
||||
r = reader.inspect_step(s)
|
||||
inpsect_results[s] = r
|
||||
if not r.is_recoverable():
|
||||
raise WorkflowStepNotRecoverableError(s)
|
||||
if r.output_object_valid:
|
||||
deps = []
|
||||
elif isinstance(r.output_step_id, str):
|
||||
deps = [r.output_step_id]
|
||||
else:
|
||||
deps = r.workflows
|
||||
for w in deps:
|
||||
dependency_map[w].append(s)
|
||||
num_in_edges[s] = len(deps)
|
||||
dag_visit_queue.extend(deps)
|
||||
|
||||
# Step 2: topological sort to determine the execution order (Kahn's algorithm)
|
||||
execution_queue: List[StepID] = []
|
||||
|
||||
start_nodes = deque(k for k, v in num_in_edges.items() if v == 0)
|
||||
while start_nodes:
|
||||
n = start_nodes.popleft()
|
||||
execution_queue.append(n)
|
||||
for m in dependency_map[n]:
|
||||
num_in_edges[m] -= 1
|
||||
assert num_in_edges[m] >= 0, (m, n)
|
||||
if num_in_edges[m] == 0:
|
||||
start_nodes.append(m)
|
||||
|
||||
# Step 3: recover the workflow by the order of the execution queue
|
||||
with serialization.objectref_cache():
|
||||
input_workflows = []
|
||||
for i, _step_id in enumerate(result.workflows):
|
||||
# Check whether the step has been loaded or not to avoid
|
||||
# duplication
|
||||
if _step_id in input_map:
|
||||
r = input_map[_step_id]
|
||||
else:
|
||||
r = _construct_resume_workflow_from_step(reader, _step_id, input_map)
|
||||
input_map[_step_id] = r
|
||||
if isinstance(r, Workflow):
|
||||
input_workflows.append(r)
|
||||
else:
|
||||
assert isinstance(r, StepID)
|
||||
# TODO (Alex): We should consider caching these outputs too.
|
||||
input_workflows.append(reader.load_step_output(r))
|
||||
workflow_refs = list(map(WorkflowRef, result.workflow_refs))
|
||||
# "input_map" is a context storing the input which has been loaded.
|
||||
# This context is important for deduplicate step inputs.
|
||||
input_map: Dict[StepID, Any] = {}
|
||||
|
||||
args, kwargs = reader.load_step_args(step_id, input_workflows, workflow_refs)
|
||||
# Note: we must uppack args and kwargs, so the refs in the args/kwargs can get
|
||||
# resolved consistently like in Ray.
|
||||
recovery_workflow: Workflow = _recover_workflow_step.step(
|
||||
input_workflows,
|
||||
workflow_refs,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
recovery_workflow._step_id = step_id
|
||||
# override step_options
|
||||
recovery_workflow.data.step_options = step_options
|
||||
return recovery_workflow
|
||||
for _step_id in execution_queue:
|
||||
result = inpsect_results[_step_id]
|
||||
if result.output_object_valid:
|
||||
input_map[_step_id] = reader.load_step_output(_step_id)
|
||||
continue
|
||||
if isinstance(result.output_step_id, str):
|
||||
input_map[_step_id] = input_map[result.output_step_id]
|
||||
continue
|
||||
|
||||
# Process the wait step as a special case.
|
||||
if result.step_options.step_type == StepType.WAIT:
|
||||
wait_input_workflows = []
|
||||
for w in result.workflows:
|
||||
output = input_map[w]
|
||||
if isinstance(output, Workflow):
|
||||
wait_input_workflows.append(output)
|
||||
else:
|
||||
# Simulate a workflow with a workflow reference so it could be
|
||||
# used directly by 'workflow.wait'.
|
||||
static_ref = WorkflowStaticRef(step_id=w, ref=ray.put(output))
|
||||
wait_input_workflows.append(Workflow.from_ref(static_ref))
|
||||
recovery_workflow = ray.workflow.wait(
|
||||
wait_input_workflows,
|
||||
**result.step_options.ray_options.get("wait_options", {}),
|
||||
)
|
||||
else:
|
||||
args, kwargs = reader.load_step_args(
|
||||
_step_id,
|
||||
workflows=[input_map[w] for w in result.workflows],
|
||||
workflow_refs=list(map(WorkflowRef, result.workflow_refs)),
|
||||
)
|
||||
func: Callable = reader.load_step_func_body(_step_id)
|
||||
# TODO(suquark): Use an alternative function when "workflow.step"
|
||||
# is fully deprecated.
|
||||
recovery_workflow = ray.workflow.step(func).step(*args, **kwargs)
|
||||
|
||||
# override step_options
|
||||
recovery_workflow._step_id = _step_id
|
||||
recovery_workflow.data.step_options = result.step_options
|
||||
|
||||
input_map[_step_id] = recovery_workflow
|
||||
|
||||
# Step 4: return the output of the requested step
|
||||
return input_map[step_id]
|
||||
|
||||
|
||||
@ray.remote(num_returns=2)
|
||||
|
@ -183,21 +158,19 @@ def _resume_workflow_step_executor(
|
|||
except Exception:
|
||||
pass
|
||||
try:
|
||||
wf_store = workflow_storage.WorkflowStorage(workflow_id)
|
||||
r = _construct_resume_workflow_from_step(wf_store, step_id, {})
|
||||
r = _construct_resume_workflow_from_step(workflow_id, step_id)
|
||||
except Exception as e:
|
||||
raise WorkflowNotResumableError(workflow_id) from e
|
||||
|
||||
if isinstance(r, Workflow):
|
||||
with workflow_context.workflow_step_context(
|
||||
workflow_id, last_step_of_workflow=True
|
||||
):
|
||||
from ray.workflow.step_executor import execute_workflow
|
||||
if not isinstance(r, Workflow):
|
||||
return r, None
|
||||
with workflow_context.workflow_step_context(
|
||||
workflow_id, last_step_of_workflow=True
|
||||
):
|
||||
from ray.workflow.step_executor import execute_workflow
|
||||
|
||||
result = execute_workflow(job_id, r)
|
||||
return result.persisted_output, result.volatile_output
|
||||
assert isinstance(r, StepID)
|
||||
return wf_store.load_step_output(r), None
|
||||
result = execute_workflow(job_id, r)
|
||||
return result.persisted_output, result.volatile_output
|
||||
|
||||
|
||||
def resume_workflow_step(
|
||||
|
|
|
@ -94,6 +94,7 @@ tqdm
|
|||
async-exit-stack
|
||||
async-generator
|
||||
cryptography>=3.0.0
|
||||
proxy.py
|
||||
# For doc tests
|
||||
myst-parser==0.15.2
|
||||
myst-nb==0.13.1
|
||||
|
|
|
@ -95,7 +95,7 @@ def maybe_fetch_buildkite_token() -> str:
|
|||
return buildkite_token
|
||||
|
||||
print("Missing BUILDKITE_TOKEN, retrieving from AWS secrets store")
|
||||
os.environ["BUILDKITE_TOKEN"] = boto3.client(
|
||||
buildkite_token = boto3.client(
|
||||
"secretsmanager", region_name="us-west-2"
|
||||
).get_secret_value(
|
||||
SecretId="arn:aws:secretsmanager:us-west-2:029272617770:secret:"
|
||||
|
@ -103,6 +103,8 @@ def maybe_fetch_buildkite_token() -> str:
|
|||
)[
|
||||
"SecretString"
|
||||
]
|
||||
os.environ["BUILDKITE_TOKEN"] = buildkite_token
|
||||
return buildkite_token
|
||||
|
||||
|
||||
def get_results_from_build_collection(
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
from ray.rllib.agents.alpha_star.alpha_star import DEFAULT_CONFIG, AlphaStarTrainer
|
||||
from ray.rllib.agents.alpha_star.alpha_star import (
|
||||
AlphaStarConfig,
|
||||
AlphaStarTrainer,
|
||||
DEFAULT_CONFIG,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_CONFIG",
|
||||
"AlphaStarConfig",
|
||||
"AlphaStarTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
]
|
||||
|
|
|
@ -3,8 +3,8 @@ A multi-agent, distributed multi-GPU, league-capable asynch. PPO
|
|||
================================================================
|
||||
"""
|
||||
import gym
|
||||
from typing import Optional, Type
|
||||
import tree
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
|
||||
import ray
|
||||
|
@ -19,6 +19,7 @@ from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentRepla
|
|||
from ray.rllib.policy.policy import Policy, PolicySpec
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.deprecation import Deprecated
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
from ray.rllib.utils.metrics import (
|
||||
LAST_TARGET_UPDATE_TS,
|
||||
|
@ -42,36 +43,60 @@ from ray.rllib.utils.typing import (
|
|||
from ray.tune.utils.placement_groups import PlacementGroupFactory
|
||||
from ray.util.timer import _Timer
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
# Adds the following updates to the `IMPALATrainer` config in
|
||||
# rllib/agents/impala/impala.py.
|
||||
DEFAULT_CONFIG = Trainer.merge_trainer_configs(
|
||||
appo.DEFAULT_CONFIG, # See keys in appo.py, which are also supported.
|
||||
{
|
||||
# TODO: Unify the buffer API, then clean up our existing
|
||||
# implementations of different buffers.
|
||||
# This is num batches held at any time for each policy.
|
||||
"replay_buffer_capacity": 20,
|
||||
# e.g. ratio=0.2 -> 20% of samples in each train batch are
|
||||
# old (replayed) ones.
|
||||
"replay_buffer_replay_ratio": 0.5,
|
||||
class AlphaStarConfig(appo.APPOConfig):
|
||||
"""Defines a configuration class from which an AlphaStarTrainer can be built.
|
||||
|
||||
# Timeout to use for `ray.wait()` when waiting for samplers to have placed
|
||||
# new data into the buffers. If no samples are ready within the timeout,
|
||||
# the buffers used for mixin-sampling will return only older samples.
|
||||
"sample_wait_timeout": 0.01,
|
||||
# Timeout to use for `ray.wait()` when waiting for the policy learner actors
|
||||
# to have performed an update and returned learning stats. If no learner
|
||||
# actors have produced any learning results in the meantime, their
|
||||
# learner-stats in the results will be empty for that iteration.
|
||||
"learn_wait_timeout": 0.1,
|
||||
Example:
|
||||
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
|
||||
>>> config = AlphaStarConfig().training(lr=0.0003, train_batch_size=512)\
|
||||
... .resources(num_gpus=4)\
|
||||
... .rollouts(num_rollout_workers=64)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
|
||||
>>> from ray import tune
|
||||
>>> config = AlphaStarConfig()
|
||||
>>> # Print out some default values.
|
||||
>>> print(config.vtrace)
|
||||
>>> # Update the config object.
|
||||
>>> config.training(lr=tune.grid_search([0.0001, 0.0003]), grad_clip=20.0)
|
||||
>>> # Set the config object's env.
|
||||
>>> config.environment(env="CartPole-v1")
|
||||
>>> # Use to_dict() to get the old-style python config dict
|
||||
>>> # when running with tune.
|
||||
>>> tune.run(
|
||||
... "AlphaStar",
|
||||
... stop={"episode_reward_mean": 200},
|
||||
... config=config.to_dict(),
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(self, trainer_class=None):
|
||||
"""Initializes a AlphaStarConfig instance."""
|
||||
super().__init__(trainer_class=trainer_class or AlphaStarTrainer)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
# AlphaStar specific settings:
|
||||
self.replay_buffer_capacity = 20
|
||||
self.replay_buffer_replay_ratio = 0.5
|
||||
self.sample_wait_timeout = 0.01
|
||||
self.learn_wait_timeout = 0.1
|
||||
|
||||
# League-building parameters.
|
||||
# The LeagueBuilder class to be used for league building logic.
|
||||
"league_builder_config": {
|
||||
self.league_builder_config = {
|
||||
# Specify the sub-class of the `LeagueBuilder` API to use.
|
||||
"type": AlphaStarLeagueBuilder,
|
||||
|
||||
# Any any number of constructor kwargs to pass to this class:
|
||||
|
||||
# The number of random policies to add to the league. This must be an
|
||||
# even number (including 0) as these will be evenly distributed
|
||||
# amongst league- and main- exploiters.
|
||||
|
@ -100,28 +125,80 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
|
|||
# Only for ME matches: Prob to play against learning
|
||||
# main (vs a snapshot main).
|
||||
"prob_main_exploiter_playing_against_learning_main": 0.5,
|
||||
},
|
||||
}
|
||||
self.max_num_policies_to_train = None
|
||||
|
||||
# The maximum number of trainable policies for this Trainer.
|
||||
# Each trainable policy will exist as a independent remote actor, co-located
|
||||
# with a replay buffer. This is besides its existence inside
|
||||
# the RolloutWorkers for training and evaluation.
|
||||
# Set to None for automatically inferring this value from the number of
|
||||
# trainable policies found in the `multiagent` config.
|
||||
"max_num_policies_to_train": None,
|
||||
# Override some of APPOConfig's default values with AlphaStar-specific
|
||||
# values.
|
||||
self.vtrace_drop_last_ts = False
|
||||
self.min_time_s_per_reporting = 2
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
# By default, don't drop last timestep.
|
||||
# TODO: We should do the same for IMPALA and APPO at some point.
|
||||
"vtrace_drop_last_ts": False,
|
||||
@override(appo.APPOConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
replay_buffer_capacity: Optional[int] = None,
|
||||
replay_buffer_replay_ratio: Optional[float] = None,
|
||||
sample_wait_timeout: Optional[float] = None,
|
||||
learn_wait_timeout: Optional[float] = None,
|
||||
league_builder_config: Optional[Dict[str, Any]] = None,
|
||||
max_num_policies_to_train: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> "AlphaStarConfig":
|
||||
"""Sets the training related configuration.
|
||||
|
||||
# Reporting interval.
|
||||
"min_time_s_per_reporting": 2,
|
||||
},
|
||||
_allow_unknown_configs=True,
|
||||
)
|
||||
Args:
|
||||
replay_buffer_capacity: This is num batches held at any time for each
|
||||
policy.
|
||||
replay_buffer_replay_ratio: For example, ratio=0.2 -> 20% of samples in
|
||||
each train batch are old (replayed) ones.
|
||||
sample_wait_timeout: Timeout to use for `ray.wait()` when waiting for
|
||||
samplers to have placed new data into the buffers. If no samples are
|
||||
ready within the timeout, the buffers used for mixin-sampling will
|
||||
return only older samples.
|
||||
learn_wait_timeout: Timeout to use for `ray.wait()` when waiting for the
|
||||
policy learner actors to have performed an update and returned learning
|
||||
stats. If no learner actors have produced any learning results in the
|
||||
meantime, their learner-stats in the results will be empty for that
|
||||
iteration.
|
||||
league_builder_config: League-building config dict.
|
||||
The dict Must contain a `type` key indicating the LeagueBuilder class
|
||||
to be used for league building logic. All other keys (that are not
|
||||
`type`) will be used as constructor kwargs on the given class to
|
||||
construct the LeagueBuilder instance. See the
|
||||
`ray.rllib.agents.alpha_star.league_builder::AlphaStarLeagueBuilder`
|
||||
(used by default by this algo) as an example.
|
||||
max_num_policies_to_train: The maximum number of trainable policies for this
|
||||
Trainer. Each trainable policy will exist as a independent remote actor,
|
||||
co-located with a replay buffer. This is besides its existence inside
|
||||
the RolloutWorkers for training and evaluation. Set to None for
|
||||
automatically inferring this value from the number of trainable
|
||||
policies found in the `multiagent` config.
|
||||
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
"""
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
||||
# TODO: Unify the buffer API, then clean up our existing
|
||||
# implementations of different buffers.
|
||||
if replay_buffer_capacity is not None:
|
||||
self.replay_buffer_capacity = replay_buffer_capacity
|
||||
if replay_buffer_replay_ratio is not None:
|
||||
self.replay_buffer_replay_ratio = replay_buffer_replay_ratio
|
||||
if sample_wait_timeout is not None:
|
||||
self.sample_wait_timeout = sample_wait_timeout
|
||||
if learn_wait_timeout is not None:
|
||||
self.learn_wait_timeout = learn_wait_timeout
|
||||
if league_builder_config is not None:
|
||||
self.league_builder_config = league_builder_config
|
||||
if max_num_policies_to_train is not None:
|
||||
self.max_num_policies_to_train = max_num_policies_to_train
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class AlphaStarTrainer(appo.APPOTrainer):
|
||||
|
@ -204,7 +281,7 @@ class AlphaStarTrainer(appo.APPOTrainer):
|
|||
@classmethod
|
||||
@override(appo.APPOTrainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
return DEFAULT_CONFIG
|
||||
return AlphaStarConfig().to_dict()
|
||||
|
||||
@override(appo.APPOTrainer)
|
||||
def validate_config(self, config: TrainerConfigDict):
|
||||
|
@ -499,3 +576,20 @@ class AlphaStarTrainer(appo.APPOTrainer):
|
|||
state_copy = state.copy()
|
||||
self.league_builder.__setstate__(state.pop("league_builder", {}))
|
||||
super().__setstate__(state_copy)
|
||||
|
||||
|
||||
# Deprecated: Use ray.rllib.agents.ppo.PPOConfig instead!
|
||||
class _deprecated_default_config(dict):
|
||||
def __init__(self):
|
||||
super().__init__(AlphaStarConfig().to_dict())
|
||||
|
||||
@Deprecated(
|
||||
old="ray.rllib.agents.alpha_star.alpha_star.DEFAULT_CONFIG",
|
||||
new="ray.rllib.agents.alpha_star.alpha_star.AlphaStarConfig(...)",
|
||||
error=False,
|
||||
)
|
||||
def __getitem__(self, item):
|
||||
return super().__getitem__(item)
|
||||
|
||||
|
||||
DEFAULT_CONFIG = _deprecated_default_config()
|
||||
|
|
|
@ -80,9 +80,15 @@ class AlphaStarLeagueBuilder(LeagueBuilder):
|
|||
):
|
||||
"""Initializes a AlphaStarLeagueBuilder instance.
|
||||
|
||||
The following match types are possible:
|
||||
LE: A learning (not snapshot) league_exploiter vs any snapshot policy.
|
||||
ME: A learning (not snapshot) main exploiter vs any main.
|
||||
M: Main self-play (main vs main).
|
||||
|
||||
Args:
|
||||
trainer: The Trainer object by which this league builder is used.
|
||||
Trainer calls `build_league()` after each training step.
|
||||
Trainer calls `build_league()` after each training step to reconfigure
|
||||
the league structure (e.g. to add/remove policies).
|
||||
trainer_config: The (not yet validated) config dict to be
|
||||
used on the Trainer. Child classes of `LeagueBuilder`
|
||||
should preprocess this to add e.g. multiagent settings
|
||||
|
|
|
@ -26,38 +26,33 @@ class TestAlphaStar(unittest.TestCase):
|
|||
|
||||
def test_alpha_star_compilation(self):
|
||||
"""Test whether a AlphaStarTrainer can be built with all frameworks."""
|
||||
|
||||
config = {
|
||||
"env": "connect_four",
|
||||
"gamma": 1.0,
|
||||
"num_workers": 4,
|
||||
"num_envs_per_worker": 5,
|
||||
"model": {
|
||||
"fcnet_hiddens": [256, 256, 256],
|
||||
},
|
||||
"vf_loss_coeff": 0.01,
|
||||
"entropy_coeff": 0.004,
|
||||
"league_builder_config": {
|
||||
"win_rate_threshold_for_new_snapshot": 0.8,
|
||||
"num_random_policies": 2,
|
||||
"num_learning_league_exploiters": 1,
|
||||
"num_learning_main_exploiters": 1,
|
||||
},
|
||||
"grad_clip": 10.0,
|
||||
"replay_buffer_capacity": 10,
|
||||
"replay_buffer_replay_ratio": 0.0,
|
||||
# Two GPUs -> 2 policies per GPU.
|
||||
"num_gpus": 4,
|
||||
"_fake_gpus": True,
|
||||
# Test with KL loss, just to cover that extra code.
|
||||
"use_kl_loss": True,
|
||||
}
|
||||
config = (
|
||||
alpha_star.AlphaStarConfig()
|
||||
.environment(env="connect_four")
|
||||
.training(
|
||||
gamma=1.0,
|
||||
model={"fcnet_hiddens": [256, 256, 256]},
|
||||
vf_loss_coeff=0.01,
|
||||
entropy_coeff=0.004,
|
||||
league_builder_config={
|
||||
"win_rate_threshold_for_new_snapshot": 0.8,
|
||||
"num_random_policies": 2,
|
||||
"num_learning_league_exploiters": 1,
|
||||
"num_learning_main_exploiters": 1,
|
||||
},
|
||||
grad_clip=10.0,
|
||||
replay_buffer_capacity=10,
|
||||
replay_buffer_replay_ratio=0.0,
|
||||
use_kl_loss=True,
|
||||
)
|
||||
.rollouts(num_rollout_workers=4, num_envs_per_worker=5)
|
||||
.resources(num_gpus=4, _fake_gpus=True)
|
||||
)
|
||||
|
||||
num_iterations = 2
|
||||
|
||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||
_config = config.copy()
|
||||
trainer = alpha_star.AlphaStarTrainer(config=_config)
|
||||
trainer = config.build()
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
print(results)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
## Overview
|
||||
|
||||
[Dreamer](https://arxiv.org/abs/1912.01603) is a model-based off-policy RL algorithm that learns by imagining and works well in visual-based enviornments. Like all model-based algorithms, Dreamer learns the environment's transiton dynamics via a latent-space model called [PlaNet](https://ai.googleblog.com/2019/02/introducing-planet-deep-planning.html). PlaNet learns to encode visual space into latent vectors, which can be used as pseudo-observations in Dreamer.
|
||||
[Dreamer](https://arxiv.org/abs/1912.01603) is a model-based off-policy RL algorithm that learns by imagining and works well in visual-based environments. Like all model-based algorithms, Dreamer learns the environment's transiton dynamics via a latent-space model called [PlaNet](https://ai.googleblog.com/2019/02/introducing-planet-deep-planning.html). PlaNet learns to encode visual space into latent vectors, which can be used as pseudo-observations in Dreamer.
|
||||
|
||||
Dreamer is a gradient-based RL algorithm. This means that the agent imagines ahead using its learned transition dynamics model (PlaNet) to discover new rewards and states. Because imagining ahead is fully differentiable, the RL objective (maximizing the sum of rewards) is fully differentiable and does not need to be optimized indirectly such as policy gradient methods. This feature of gradient-based learning, in conjunction with PlaNet, enables the agent to learn in a latent space and achieves much better sample complexity and performance than other visual-based agents.
|
||||
|
||||
|
@ -14,6 +14,6 @@ For more details, there is a Ray/RLlib [blogpost](https://medium.com/distributed
|
|||
|
||||
Dreamer.
|
||||
|
||||
**[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#dqn)**
|
||||
**[Detailed Documentation](https://docs.ray.io/en/latest/rllib/rllib-algorithms.html#dreamer)**
|
||||
|
||||
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/simple_q.py)**
|
||||
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/dreamer/dreamer.py)**
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
from ray.rllib.agents.dreamer.dreamer import DREAMERTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.dreamer.dreamer import (
|
||||
DREAMERConfig,
|
||||
DREAMERTrainer,
|
||||
DEFAULT_CONFIG,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DREAMERConfig",
|
||||
"DREAMERTrainer",
|
||||
"DEFAULT_CONFIG",
|
||||
]
|
||||
|
|
|
@ -1,97 +1,209 @@
|
|||
import logging
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
from ray.rllib.agents import with_common_config
|
||||
from ray.rllib.agents.trainer_config import TrainerConfig
|
||||
from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.agents.dreamer.dreamer_model import DreamerModel
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts
|
||||
from ray.rllib.execution.rollout_ops import (
|
||||
ParallelRollouts,
|
||||
synchronous_parallel_sample,
|
||||
)
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.deprecation import Deprecated
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
|
||||
from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
|
||||
from ray.rllib.utils.typing import (
|
||||
PartialTrainerConfigDict,
|
||||
SampleBatchType,
|
||||
TrainerConfigDict,
|
||||
ResultDict,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# PlaNET Model LR
|
||||
"td_model_lr": 6e-4,
|
||||
# Actor LR
|
||||
"actor_lr": 8e-5,
|
||||
# Critic LR
|
||||
"critic_lr": 8e-5,
|
||||
# Grad Clipping
|
||||
"grad_clip": 100.0,
|
||||
# Discount
|
||||
"discount": 0.99,
|
||||
# Lambda
|
||||
"lambda": 0.95,
|
||||
# Clipping is done inherently via policy tanh.
|
||||
"clip_actions": False,
|
||||
# Training iterations per data collection from real env
|
||||
"dreamer_train_iters": 100,
|
||||
# Horizon for Enviornment (1000 for Mujoco/DMC)
|
||||
"horizon": 1000,
|
||||
# Number of episodes to sample for Loss Calculation
|
||||
"batch_size": 50,
|
||||
# Length of each episode to sample for Loss Calculation
|
||||
"batch_length": 50,
|
||||
# Imagination Horizon for Training Actor and Critic
|
||||
"imagine_horizon": 15,
|
||||
# Free Nats
|
||||
"free_nats": 3.0,
|
||||
# KL Coeff for the Model Loss
|
||||
"kl_coeff": 1.0,
|
||||
# Distributed Dreamer not implemented yet
|
||||
"num_workers": 0,
|
||||
# Prefill Timesteps
|
||||
"prefill_timesteps": 5000,
|
||||
# This should be kept at 1 to preserve sample efficiency
|
||||
"num_envs_per_worker": 1,
|
||||
# Exploration Gaussian
|
||||
"explore_noise": 0.3,
|
||||
# Batch mode
|
||||
"batch_mode": "complete_episodes",
|
||||
# Custom Model
|
||||
"dreamer_model": {
|
||||
"custom_model": DreamerModel,
|
||||
# RSSM/PlaNET parameters
|
||||
"deter_size": 200,
|
||||
"stoch_size": 30,
|
||||
# CNN Decoder Encoder
|
||||
"depth_size": 32,
|
||||
# General Network Parameters
|
||||
"hidden_size": 400,
|
||||
# Action STD
|
||||
"action_init_std": 5.0,
|
||||
},
|
||||
|
||||
"env_config": {
|
||||
# Repeats action send by policy for frame_skip times in env
|
||||
"frame_skip": 2,
|
||||
},
|
||||
class DREAMERConfig(TrainerConfig):
|
||||
"""Defines a PPOTrainer configuration class from which a PPOTrainer can be built.
|
||||
|
||||
# Use `execution_plan` instead of `training_iteration`.
|
||||
"_disable_execution_plan_api": False,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
Example:
|
||||
>>> from ray.rllib.agents.dreamer import DREAMERConfig
|
||||
>>> config = DREAMERConfig().training(gamma=0.9, lr=0.01)\
|
||||
... .resources(num_gpus=0)\
|
||||
... .rollouts(num_rollout_workers=4)
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Trainer object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build(env="CartPole-v1")
|
||||
>>> trainer.train()
|
||||
|
||||
Example:
|
||||
>>> from ray import tune
|
||||
>>> from ray.rllib.agents.dreamer import DREAMERConfig
|
||||
>>> config = DREAMERConfig()
|
||||
>>> # Print out some default values.
|
||||
>>> print(config.clip_param)
|
||||
>>> # Update the config object.
|
||||
>>> config.training(lr=tune.grid_search([0.001, 0.0001]), clip_param=0.2)
|
||||
>>> # Set the config object's env.
|
||||
>>> config.environment(env="CartPole-v1")
|
||||
>>> # Use to_dict() to get the old-style python config dict
|
||||
>>> # when running with tune.
|
||||
>>> tune.run(
|
||||
... "DREAMER",
|
||||
... stop={"episode_reward_mean": 200},
|
||||
... config=config.to_dict(),
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes a PPOConfig instance."""
|
||||
super().__init__(trainer_class=DREAMERTrainer)
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
# Dreamer specific settings:
|
||||
self.td_model_lr = 6e-4
|
||||
self.actor_lr = 8e-5
|
||||
self.critic_lr = 8e-5
|
||||
self.grad_clip = 100.0
|
||||
self.lambda_ = 0.95
|
||||
self.dreamer_train_iters = 100
|
||||
self.batch_size = 50
|
||||
self.batch_length = 50
|
||||
self.imagine_horizon = 15
|
||||
self.free_nats = 3.0
|
||||
self.kl_coeff = 1.0
|
||||
self.prefill_timesteps = 5000
|
||||
self.explore_noise = 0.3
|
||||
self.dreamer_model = {
|
||||
"custom_model": DreamerModel,
|
||||
# RSSM/PlaNET parameters
|
||||
"deter_size": 200,
|
||||
"stoch_size": 30,
|
||||
# CNN Decoder Encoder
|
||||
"depth_size": 32,
|
||||
# General Network Parameters
|
||||
"hidden_size": 400,
|
||||
# Action STD
|
||||
"action_init_std": 5.0,
|
||||
}
|
||||
|
||||
# Override some of TrainerConfig's default values with PPO-specific values.
|
||||
# .rollouts()
|
||||
self.num_workers = 0
|
||||
self.num_envs_per_worker = 1
|
||||
self.horizon = 1000
|
||||
self.batch_mode = "complete_episodes"
|
||||
self.clip_actions = False
|
||||
|
||||
# .training()
|
||||
self.gamma = 0.99
|
||||
|
||||
# .environment()
|
||||
self.env_config = {
|
||||
# Repeats action send by policy for frame_skip times in env
|
||||
"frame_skip": 2,
|
||||
}
|
||||
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
@override(TrainerConfig)
|
||||
def training(
|
||||
self,
|
||||
*,
|
||||
td_model_lr: Optional[float] = None,
|
||||
actor_lr: Optional[float] = None,
|
||||
critic_lr: Optional[float] = None,
|
||||
grad_clip: Optional[float] = None,
|
||||
lambda_: Optional[float] = None,
|
||||
dreamer_train_iters: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
batch_length: Optional[int] = None,
|
||||
imagine_horizon: Optional[int] = None,
|
||||
free_nats: Optional[float] = None,
|
||||
kl_coeff: Optional[float] = None,
|
||||
prefill_timesteps: Optional[int] = None,
|
||||
explore_noise: Optional[float] = None,
|
||||
dreamer_model: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> "DREAMERConfig":
|
||||
"""
|
||||
|
||||
Args:
|
||||
td_model_lr: PlaNET (transition dynamics) model learning rate.
|
||||
actor_lr: Actor model learning rate.
|
||||
critic_lr: Critic model learning rate.
|
||||
grad_clip: If specified, clip the global norm of gradients by this amount.
|
||||
lambda_: The GAE (lambda) parameter.
|
||||
dreamer_train_iters: Training iterations per data collection from real env.
|
||||
batch_size: Number of episodes to sample for loss calculation.
|
||||
batch_length: Length of each episode to sample for loss calculation.
|
||||
imagine_horizon: Imagination horizon for training Actor and Critic.
|
||||
free_nats: Free nats.
|
||||
kl_coeff: KL coefficient for the model Loss.
|
||||
prefill_timesteps: Prefill timesteps.
|
||||
explore_noise: Exploration Gaussian noise.
|
||||
dreamer_model: Custom model config.
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
# Pass kwargs onto super's `training()` method.
|
||||
super().training(**kwargs)
|
||||
|
||||
if td_model_lr is not None:
|
||||
self.td_model_lr = td_model_lr
|
||||
if actor_lr is not None:
|
||||
self.actor_lr = actor_lr
|
||||
if critic_lr is not None:
|
||||
self.critic_lr = critic_lr
|
||||
if grad_clip is not None:
|
||||
self.grad_clip = grad_clip
|
||||
if lambda_ is not None:
|
||||
self.lambda_ = lambda_
|
||||
if dreamer_train_iters is not None:
|
||||
self.dreamer_train_iters = dreamer_train_iters
|
||||
if batch_size is not None:
|
||||
self.batch_size = batch_size
|
||||
if batch_length is not None:
|
||||
self.batch_length = batch_length
|
||||
if imagine_horizon is not None:
|
||||
self.imagine_horizon = imagine_horizon
|
||||
if free_nats is not None:
|
||||
self.free_nats = free_nats
|
||||
if kl_coeff is not None:
|
||||
self.kl_coeff = kl_coeff
|
||||
if prefill_timesteps is not None:
|
||||
self.prefill_timesteps = prefill_timesteps
|
||||
if explore_noise is not None:
|
||||
self.explore_noise = explore_noise
|
||||
if dreamer_model is not None:
|
||||
self.dreamer_model = dreamer_model
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def _postprocess_gif(gif: np.ndarray):
|
||||
"""Process provided gif to a format that can be logged to Tensorboard."""
|
||||
gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
|
||||
B, T, C, H, W = gif.shape
|
||||
frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
|
||||
return frames
|
||||
|
||||
|
||||
class EpisodicBuffer(object):
|
||||
def __init__(self, max_length: int = 1000, length: int = 50):
|
||||
"""Data structure that stores episodes and samples chunks
|
||||
of size length from episodes
|
||||
"""Stores episodes and samples chunks of size ``length`` from episodes.
|
||||
|
||||
Args:
|
||||
max_length: Maximum episodes it can store
|
||||
length: Episode chunking lengh in sample()
|
||||
length: Episode chunking length in sample()
|
||||
"""
|
||||
|
||||
# Stores all episodes into a list: List[SampleBatchType]
|
||||
|
@ -101,8 +213,7 @@ class EpisodicBuffer(object):
|
|||
self.length = length
|
||||
|
||||
def add(self, batch: SampleBatchType):
|
||||
"""Splits a SampleBatch into episodes and adds episodes
|
||||
to the episode buffer
|
||||
"""Splits a SampleBatch into episodes and adds episodes to the episode buffer.
|
||||
|
||||
Args:
|
||||
batch: SampleBatch to be added
|
||||
|
@ -151,7 +262,6 @@ class DreamerIteration:
|
|||
self.batch_size = batch_size
|
||||
|
||||
def __call__(self, samples):
|
||||
|
||||
# Dreamer training loop.
|
||||
for n in range(self.dreamer_train_iters):
|
||||
print(f"sub-iteration={n}/{self.dreamer_train_iters}")
|
||||
|
@ -161,7 +271,7 @@ class DreamerIteration:
|
|||
fetches = self.worker.learn_on_batch(batch)
|
||||
|
||||
# Custom Logging
|
||||
policy_fetches = self.policy_stats(fetches)
|
||||
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
|
||||
if "log_gif" in policy_fetches:
|
||||
gif = policy_fetches["log_gif"]
|
||||
policy_fetches["log_gif"] = self.postprocess_gif(gif)
|
||||
|
@ -180,20 +290,14 @@ class DreamerIteration:
|
|||
return res
|
||||
|
||||
def postprocess_gif(self, gif: np.ndarray):
|
||||
gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
|
||||
B, T, C, H, W = gif.shape
|
||||
frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
|
||||
return frames
|
||||
|
||||
def policy_stats(self, fetches):
|
||||
return fetches[DEFAULT_POLICY_ID]["learner_stats"]
|
||||
return _postprocess_gif(gif=gif)
|
||||
|
||||
|
||||
class DREAMERTrainer(Trainer):
|
||||
@classmethod
|
||||
@override(Trainer)
|
||||
def get_default_config(cls) -> TrainerConfigDict:
|
||||
return DEFAULT_CONFIG
|
||||
return DREAMERConfig().to_dict()
|
||||
|
||||
@override(Trainer)
|
||||
def validate_config(self, config: TrainerConfigDict) -> None:
|
||||
|
@ -211,6 +315,11 @@ class DREAMERTrainer(Trainer):
|
|||
raise ValueError("Distributed Dreamer not supported yet!")
|
||||
if config["clip_actions"]:
|
||||
raise ValueError("Clipping is done inherently via policy tanh!")
|
||||
if config["dreamer_train_iters"] <= 0:
|
||||
raise ValueError(
|
||||
"`dreamer_train_iters` must be a positive integer. "
|
||||
f"Received {config['dreamer_train_iters']} instead."
|
||||
)
|
||||
if config["action_repeat"] > 1:
|
||||
config["horizon"] = config["horizon"] / config["action_repeat"]
|
||||
|
||||
|
@ -218,6 +327,22 @@ class DREAMERTrainer(Trainer):
|
|||
def get_default_policy_class(self, config: TrainerConfigDict):
|
||||
return DreamerTorchPolicy
|
||||
|
||||
@override(Trainer)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
super().setup(config)
|
||||
# `training_iteration` implementation: Setup buffer in `setup`, not
|
||||
# in `execution_plan` (deprecated).
|
||||
if self.config["_disable_execution_plan_api"] is True:
|
||||
self.local_replay_buffer = EpisodicBuffer(length=config["batch_length"])
|
||||
|
||||
# Prefill episode buffer with initial exploration (uniform sampling)
|
||||
while (
|
||||
total_sampled_timesteps(self.workers.local_worker())
|
||||
< self.config["prefill_timesteps"]
|
||||
):
|
||||
samples = self.workers.local_worker().sample()
|
||||
self.local_replay_buffer.add(samples)
|
||||
|
||||
@staticmethod
|
||||
@override(Trainer)
|
||||
def execution_plan(workers, config, **kwargs):
|
||||
|
@ -250,3 +375,61 @@ class DREAMERTrainer(Trainer):
|
|||
)
|
||||
)
|
||||
return rollouts
|
||||
|
||||
@override(Trainer)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
local_worker = self.workers.local_worker()
|
||||
|
||||
# Number of sub-iterations for Dreamer
|
||||
dreamer_train_iters = self.config["dreamer_train_iters"]
|
||||
batch_size = self.config["batch_size"]
|
||||
action_repeat = self.config["action_repeat"]
|
||||
|
||||
# Collect SampleBatches from rollout workers.
|
||||
batch = synchronous_parallel_sample(worker_set=self.workers)
|
||||
|
||||
fetches = {}
|
||||
|
||||
# Dreamer training loop.
|
||||
# Run multiple sub-iterations for each training iteration.
|
||||
for n in range(dreamer_train_iters):
|
||||
print(f"sub-iteration={n}/{dreamer_train_iters}")
|
||||
batch = self.local_replay_buffer.sample(batch_size)
|
||||
fetches = local_worker.learn_on_batch(batch)
|
||||
|
||||
if fetches:
|
||||
# Custom Logging
|
||||
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
|
||||
if "log_gif" in policy_fetches:
|
||||
gif = policy_fetches["log_gif"]
|
||||
policy_fetches["log_gif"] = self._postprocess_gif(gif)
|
||||
|
||||
self._counters[STEPS_SAMPLED_COUNTER] = (
|
||||
self.local_replay_buffer.timesteps * action_repeat
|
||||
)
|
||||
|
||||
self.local_replay_buffer.add(batch)
|
||||
|
||||
return fetches
|
||||
|
||||
def _compile_step_results(self, *args, **kwargs):
|
||||
results = super()._compile_step_results(*args, **kwargs)
|
||||
results["timesteps_total"] = self._counters[STEPS_SAMPLED_COUNTER]
|
||||
return results
|
||||
|
||||
|
||||
# Deprecated: Use ray.rllib.agents.dreamer.DREAMERConfig instead!
|
||||
class _deprecated_default_config(dict):
|
||||
def __init__(self):
|
||||
super().__init__(DREAMERConfig().to_dict())
|
||||
|
||||
@Deprecated(
|
||||
old="ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG",
|
||||
new="ray.rllib.agents.dreamer.dreamer.DREAMERConfig(...)",
|
||||
error=False,
|
||||
)
|
||||
def __getitem__(self, item):
|
||||
return super().__getitem__(item)
|
||||
|
||||
|
||||
DEFAULT_CONFIG = _deprecated_default_config()
|
||||
|
|
|
@ -7,12 +7,13 @@ import ray
|
|||
from ray.rllib.agents.dreamer.utils import FreezeParameters
|
||||
from ray.rllib.evaluation.episode import Episode
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.policy_template import build_policy_class
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_utils import apply_grad_clipping
|
||||
from ray.rllib.utils.typing import AgentID
|
||||
from ray.rllib.utils.typing import AgentID, TensorType
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
if torch:
|
||||
|
@ -23,30 +24,30 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# This is the computation graph for workers (inner adaptation steps)
|
||||
def compute_dreamer_loss(
|
||||
obs,
|
||||
action,
|
||||
reward,
|
||||
model,
|
||||
imagine_horizon,
|
||||
discount=0.99,
|
||||
lambda_=0.95,
|
||||
kl_coeff=1.0,
|
||||
free_nats=3.0,
|
||||
log=False,
|
||||
obs: TensorType,
|
||||
action: TensorType,
|
||||
reward: TensorType,
|
||||
model: TorchModelV2,
|
||||
imagine_horizon: int,
|
||||
gamma: float = 0.99,
|
||||
lambda_: float = 0.95,
|
||||
kl_coeff: float = 1.0,
|
||||
free_nats: float = 3.0,
|
||||
log: bool = False,
|
||||
):
|
||||
"""Constructs loss for the Dreamer objective
|
||||
"""Constructs loss for the Dreamer objective.
|
||||
|
||||
Args:
|
||||
obs (TensorType): Observations (o_t)
|
||||
action (TensorType): Actions (a_(t-1))
|
||||
reward (TensorType): Rewards (r_(t-1))
|
||||
model (TorchModelV2): DreamerModel, encompassing all other models
|
||||
imagine_horizon (int): Imagine horizon for actor and critic loss
|
||||
discount (float): Discount
|
||||
lambda_ (float): Lambda, like in GAE
|
||||
kl_coeff (float): KL Coefficient for Divergence loss in model loss
|
||||
free_nats (float): Threshold for minimum divergence in model loss
|
||||
log (bool): If log, generate gifs
|
||||
obs: Observations (o_t).
|
||||
action: Actions (a_(t-1)).
|
||||
reward: Rewards (r_(t-1)).
|
||||
model: DreamerModel, encompassing all other models.
|
||||
imagine_horizon: Imagine horizon for actor and critic loss.
|
||||
gamma: Discount factor gamma.
|
||||
lambda_: Lambda, like in GAE.
|
||||
kl_coeff: KL Coefficient for Divergence loss in model loss.
|
||||
free_nats: Threshold for minimum divergence in model loss.
|
||||
log: If log, generate gifs.
|
||||
"""
|
||||
encoder_weights = list(model.encoder.parameters())
|
||||
decoder_weights = list(model.decoder.parameters())
|
||||
|
@ -84,7 +85,7 @@ def compute_dreamer_loss(
|
|||
with FreezeParameters(model_weights + critic_weights):
|
||||
reward = model.reward(imag_feat).mean
|
||||
value = model.value(imag_feat).mean
|
||||
pcont = discount * torch.ones_like(reward)
|
||||
pcont = gamma * torch.ones_like(reward)
|
||||
returns = lambda_return(reward[:-1], value[:-1], pcont[:-1], value[-1], lambda_)
|
||||
discount_shape = pcont[:1].size()
|
||||
discount = torch.cumprod(
|
||||
|
@ -168,7 +169,7 @@ def dreamer_loss(policy, model, dist_class, train_batch):
|
|||
train_batch["rewards"],
|
||||
policy.model,
|
||||
policy.config["imagine_horizon"],
|
||||
policy.config["discount"],
|
||||
policy.config["gamma"],
|
||||
policy.config["lambda"],
|
||||
policy.config["kl_coeff"],
|
||||
policy.config["free_nats"],
|
||||
|
|
|
@ -18,23 +18,24 @@ class TestDreamer(unittest.TestCase):
|
|||
|
||||
def test_dreamer_compilation(self):
|
||||
"""Test whether an DreamerTrainer can be built with all frameworks."""
|
||||
config = dreamer.DEFAULT_CONFIG.copy()
|
||||
config["env_config"] = {
|
||||
"observation_space": Box(-1.0, 1.0, (3, 64, 64)),
|
||||
"action_space": Box(-1.0, 1.0, (3,)),
|
||||
}
|
||||
config = dreamer.DREAMERConfig()
|
||||
config.environment(
|
||||
env=RandomEnv,
|
||||
env_config={
|
||||
"observation_space": Box(-1.0, 1.0, (3, 64, 64)),
|
||||
"action_space": Box(-1.0, 1.0, (3,)),
|
||||
},
|
||||
)
|
||||
# Num episode chunks per batch.
|
||||
config["batch_size"] = 2
|
||||
# Length (ts) of an episode chunk in a batch.
|
||||
config["batch_length"] = 20
|
||||
# Sub-iterations per .train() call.
|
||||
config["dreamer_train_iters"] = 4
|
||||
config.training(batch_size=2, batch_length=20, dreamer_train_iters=4)
|
||||
|
||||
num_iterations = 1
|
||||
|
||||
# Test against all frameworks.
|
||||
for _ in framework_iterator(config, frameworks="torch"):
|
||||
trainer = dreamer.DREAMERTrainer(config=config, env=RandomEnv)
|
||||
trainer = config.build()
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
print(results)
|
||||
|
|
|
@ -59,7 +59,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ImpalaConfig(TrainerConfig):
|
||||
"""Defines an ARSTrainer configuration class from which an ImpalaTrainer can be built.
|
||||
"""Defines a configuration class from which an ImpalaTrainer can be built.
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.agents.impala import ImpalaConfig
|
||||
|
@ -136,13 +136,10 @@ class ImpalaConfig(TrainerConfig):
|
|||
self.num_gpus = 1
|
||||
self.lr = 0.0005
|
||||
self.min_time_s_per_reporting = 10
|
||||
# IMPALA and APPO are not on the new training_iteration API yet.
|
||||
self._disable_execution_plan_api = False
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
# Deprecated value.
|
||||
self._disable_execution_plan_api = True
|
||||
self.num_data_loader_buffers = DEPRECATED_VALUE
|
||||
|
||||
@override(TrainerConfig)
|
||||
|
|
|
@ -159,6 +159,14 @@ RAY_CONFIG(int64_t, max_grpc_message_size, 250 * 1024 * 1024)
|
|||
// of retries is non zero.
|
||||
RAY_CONFIG(int64_t, grpc_server_retry_timeout_milliseconds, 1000)
|
||||
|
||||
// Whether to allow HTTP proxy on GRPC clients. Disable HTTP proxy by default since it
|
||||
// disrupts local connections. Note that this config item only controls GrpcClient in
|
||||
// `src/ray/rpc/grpc_client.h`. Python GRPC clients are not directly controlled by this.
|
||||
// NOTE (kfstorm): DO NOT set this config item via `_system_config`, use
|
||||
// `RAY_grpc_enable_http_proxy` environment variable instead so that it can be passed to
|
||||
// non-C++ children processes such as dashboard agent.
|
||||
RAY_CONFIG(bool, grpc_enable_http_proxy, false)
|
||||
|
||||
// The min number of retries for direct actor creation tasks. The actual number
|
||||
// of creation retries will be MAX(actor_creation_min_retries, max_restarts).
|
||||
RAY_CONFIG(uint64_t, actor_creation_min_retries, 3)
|
||||
|
|
|
@ -357,9 +357,8 @@ struct SyncerServerTest {
|
|||
|
||||
std::shared_ptr<grpc::Channel> MakeChannel(std::string port) {
|
||||
grpc::ChannelArguments argument;
|
||||
// Disable http proxy since it disrupts local connections. TODO(ekl) we should make
|
||||
// this configurable, or selectively set it for known local connections only.
|
||||
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
|
||||
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY,
|
||||
::RayConfig::instance().grpc_enable_http_proxy() ? 1 : 0);
|
||||
argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
|
||||
argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());
|
||||
|
||||
|
|
|
@ -116,9 +116,8 @@ int main(int argc, char *argv[]) {
|
|||
}
|
||||
if (leader_port != ".") {
|
||||
grpc::ChannelArguments argument;
|
||||
// Disable http proxy since it disrupts local connections. TODO(ekl) we should make
|
||||
// this configurable, or selectively set it for known local connections only.
|
||||
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
|
||||
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY,
|
||||
::RayConfig::instance().grpc_enable_http_proxy() ? 1 : 0);
|
||||
argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
|
||||
argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "ray/gcs/gcs_server/gcs_worker_manager.h"
|
||||
#include "ray/gcs/gcs_server/stats_handler_impl.h"
|
||||
#include "ray/gcs/gcs_server/store_client_kv.h"
|
||||
#include "ray/gcs/store_client/observable_store_client.h"
|
||||
#include "ray/pubsub/publisher.h"
|
||||
|
||||
namespace ray {
|
||||
|
@ -496,8 +497,9 @@ void GcsServer::InitKVManager() {
|
|||
if (storage_type_ == "redis") {
|
||||
instance = std::make_unique<RedisInternalKV>(GetRedisClientOptions());
|
||||
} else if (storage_type_ == "memory") {
|
||||
instance = std::make_unique<StoreClientInternalKV>(
|
||||
std::make_unique<InMemoryStoreClient>(main_service_));
|
||||
instance =
|
||||
std::make_unique<StoreClientInternalKV>(std::make_unique<ObservableStoreClient>(
|
||||
std::make_unique<InMemoryStoreClient>(main_service_)));
|
||||
}
|
||||
|
||||
kv_manager_ = std::make_unique<GcsInternalKVManager>(std::move(instance));
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include "ray/common/asio/instrumented_io_context.h"
|
||||
#include "ray/gcs/store_client/in_memory_store_client.h"
|
||||
#include "ray/gcs/store_client/observable_store_client.h"
|
||||
#include "ray/gcs/store_client/redis_store_client.h"
|
||||
#include "src/ray/protobuf/gcs.pb.h"
|
||||
|
||||
|
@ -366,7 +367,8 @@ class RedisGcsTableStorage : public GcsTableStorage {
|
|||
class InMemoryGcsTableStorage : public GcsTableStorage {
|
||||
public:
|
||||
explicit InMemoryGcsTableStorage(instrumented_io_context &main_io_service)
|
||||
: GcsTableStorage(std::make_shared<InMemoryStoreClient>(main_io_service)) {}
|
||||
: GcsTableStorage(std::make_shared<ObservableStoreClient>(
|
||||
std::make_unique<InMemoryStoreClient>(main_io_service))) {}
|
||||
};
|
||||
|
||||
} // namespace gcs
|
||||
|
|
163
src/ray/gcs/store_client/observable_store_client.cc
Normal file
163
src/ray/gcs/store_client/observable_store_client.cc
Normal file
|
@ -0,0 +1,163 @@
|
|||
// Copyright 2017 The Ray Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ray/gcs/store_client/observable_store_client.h"
|
||||
|
||||
#include "absl/time/time.h"
|
||||
#include "ray/stats/metric_defs.h"
|
||||
|
||||
namespace ray {
|
||||
namespace gcs {
|
||||
|
||||
using namespace ray::stats;
|
||||
|
||||
Status ObservableStoreClient::AsyncPut(const std::string &table_name,
|
||||
const std::string &key,
|
||||
const std::string &data,
|
||||
bool overwrite,
|
||||
std::function<void(bool)> callback) {
|
||||
auto start = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_count.Record(1, "Put");
|
||||
return delegate_->AsyncPut(table_name,
|
||||
key,
|
||||
data,
|
||||
overwrite,
|
||||
[start, callback = std::move(callback)](auto result) {
|
||||
auto end = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_latency_ms.Record(
|
||||
absl::Nanoseconds(end - start) / absl::Milliseconds(1),
|
||||
"Put");
|
||||
if (callback) {
|
||||
callback(std::move(result));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Status ObservableStoreClient::AsyncGet(
|
||||
const std::string &table_name,
|
||||
const std::string &key,
|
||||
const OptionalItemCallback<std::string> &callback) {
|
||||
auto start = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_count.Record(1, "Get");
|
||||
return delegate_->AsyncGet(
|
||||
table_name, key, [start, callback](auto status, auto result) {
|
||||
auto end = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_latency_ms.Record(
|
||||
absl::Nanoseconds(end - start) / absl::Milliseconds(1), "Get");
|
||||
if (callback) {
|
||||
callback(status, std::move(result));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Status ObservableStoreClient::AsyncGetAll(
|
||||
const std::string &table_name,
|
||||
const MapCallback<std::string, std::string> &callback) {
|
||||
auto start = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_count.Record(1, "GetAll");
|
||||
return delegate_->AsyncGetAll(table_name, [start, callback](auto result) {
|
||||
auto end = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_latency_ms.Record(
|
||||
absl::Nanoseconds(end - start) / absl::Milliseconds(1), "GetAll");
|
||||
if (callback) {
|
||||
callback(std::move(result));
|
||||
}
|
||||
});
|
||||
}
|
||||
Status ObservableStoreClient::AsyncMultiGet(
|
||||
const std::string &table_name,
|
||||
const std::vector<std::string> &keys,
|
||||
const MapCallback<std::string, std::string> &callback) {
|
||||
auto start = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_count.Record(1, "MultiGet");
|
||||
return delegate_->AsyncMultiGet(table_name, keys, [start, callback](auto result) {
|
||||
auto end = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_latency_ms.Record(
|
||||
absl::Nanoseconds(end - start) / absl::Milliseconds(1), "MultiGet");
|
||||
if (callback) {
|
||||
callback(std::move(result));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Status ObservableStoreClient::AsyncDelete(const std::string &table_name,
|
||||
const std::string &key,
|
||||
std::function<void(bool)> callback) {
|
||||
auto start = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_count.Record(1, "Delete");
|
||||
return delegate_->AsyncDelete(
|
||||
table_name, key, [start, callback = std::move(callback)](auto result) {
|
||||
auto end = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_latency_ms.Record(
|
||||
absl::Nanoseconds(end - start) / absl::Milliseconds(1), "Delete");
|
||||
if (callback) {
|
||||
callback(std::move(result));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Status ObservableStoreClient::AsyncBatchDelete(const std::string &table_name,
|
||||
const std::vector<std::string> &keys,
|
||||
std::function<void(int64_t)> callback) {
|
||||
auto start = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_count.Record(1, "BatchDelete");
|
||||
return delegate_->AsyncBatchDelete(
|
||||
table_name, keys, [start, callback = std::move(callback)](auto result) {
|
||||
auto end = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_latency_ms.Record(
|
||||
absl::Nanoseconds(end - start) / absl::Milliseconds(1), "BatchDelete");
|
||||
if (callback) {
|
||||
callback(std::move(result));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
int ObservableStoreClient::GetNextJobID() { return delegate_->GetNextJobID(); }
|
||||
|
||||
Status ObservableStoreClient::AsyncGetKeys(
|
||||
const std::string &table_name,
|
||||
const std::string &prefix,
|
||||
std::function<void(std::vector<std::string>)> callback) {
|
||||
auto start = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_count.Record(1, "GetKeys");
|
||||
return delegate_->AsyncGetKeys(
|
||||
table_name, prefix, [start, callback = std::move(callback)](auto result) {
|
||||
auto end = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_latency_ms.Record(
|
||||
absl::Nanoseconds(end - start) / absl::Milliseconds(1), "GetKeys");
|
||||
if (callback) {
|
||||
callback(std::move(result));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Status ObservableStoreClient::AsyncExists(const std::string &table_name,
|
||||
const std::string &key,
|
||||
std::function<void(bool)> callback) {
|
||||
auto start = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_count.Record(1, "Exists");
|
||||
return delegate_->AsyncExists(
|
||||
table_name, key, [start, callback = std::move(callback)](auto result) {
|
||||
auto end = absl::GetCurrentTimeNanos();
|
||||
STATS_gcs_storage_operation_latency_ms.Record(
|
||||
absl::Nanoseconds(end - start) / absl::Milliseconds(1), "Exists");
|
||||
if (callback) {
|
||||
callback(std::move(result));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
} // namespace ray
|
70
src/ray/gcs/store_client/observable_store_client.h
Normal file
70
src/ray/gcs/store_client/observable_store_client.h
Normal file
|
@ -0,0 +1,70 @@
|
|||
// Copyright 2017 The Ray Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ray/gcs/store_client/store_client.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace gcs {
|
||||
|
||||
/// Wraps around a StoreClient instance and observe the metrics.
|
||||
class ObservableStoreClient : public StoreClient {
|
||||
public:
|
||||
explicit ObservableStoreClient(std::unique_ptr<StoreClient> delegate)
|
||||
: delegate_(std::move(delegate)) {}
|
||||
|
||||
Status AsyncPut(const std::string &table_name,
|
||||
const std::string &key,
|
||||
const std::string &data,
|
||||
bool overwrite,
|
||||
std::function<void(bool)> callback) override;
|
||||
|
||||
Status AsyncGet(const std::string &table_name,
|
||||
const std::string &key,
|
||||
const OptionalItemCallback<std::string> &callback) override;
|
||||
|
||||
Status AsyncGetAll(const std::string &table_name,
|
||||
const MapCallback<std::string, std::string> &callback) override;
|
||||
|
||||
Status AsyncMultiGet(const std::string &table_name,
|
||||
const std::vector<std::string> &keys,
|
||||
const MapCallback<std::string, std::string> &callback) override;
|
||||
|
||||
Status AsyncDelete(const std::string &table_name,
|
||||
const std::string &key,
|
||||
std::function<void(bool)> callback) override;
|
||||
|
||||
Status AsyncBatchDelete(const std::string &table_name,
|
||||
const std::vector<std::string> &keys,
|
||||
std::function<void(int64_t)> callback) override;
|
||||
|
||||
int GetNextJobID() override;
|
||||
|
||||
Status AsyncGetKeys(const std::string &table_name,
|
||||
const std::string &prefix,
|
||||
std::function<void(std::vector<std::string>)> callback) override;
|
||||
|
||||
Status AsyncExists(const std::string &table_name,
|
||||
const std::string &key,
|
||||
std::function<void(bool)> callback) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<StoreClient> delegate_;
|
||||
};
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
} // namespace ray
|
|
@ -34,7 +34,6 @@ TEST_F(InMemoryStoreClientTest, AsyncPutAndAsyncGetTest) { TestAsyncPutAndAsyncG
|
|||
TEST_F(InMemoryStoreClientTest, AsyncGetAllAndBatchDeleteTest) {
|
||||
TestAsyncGetAllAndBatchDelete();
|
||||
}
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
} // namespace ray
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright 2017 The Ray Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ray/gcs/store_client/observable_store_client.h"
|
||||
|
||||
#include "ray/gcs/store_client/in_memory_store_client.h"
|
||||
#include "ray/gcs/store_client/test/store_client_test_base.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace gcs {
|
||||
|
||||
class ObservableStoreClientTest : public StoreClientTestBase {
|
||||
public:
|
||||
void InitStoreClient() override {
|
||||
store_client_ = std::make_shared<ObservableStoreClient>(
|
||||
std::make_unique<InMemoryStoreClient>(*(io_service_pool_->Get())));
|
||||
}
|
||||
|
||||
void DisconnectStoreClient() override {}
|
||||
};
|
||||
|
||||
TEST_F(ObservableStoreClientTest, AsyncPutAndAsyncGetTest) { TestAsyncPutAndAsyncGet(); }
|
||||
|
||||
TEST_F(ObservableStoreClientTest, AsyncGetAllAndBatchDeleteTest) {
|
||||
TestAsyncGetAllAndBatchDelete();
|
||||
}
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
|
@ -51,9 +51,8 @@ inline std::shared_ptr<grpc::Channel> BuildChannel(
|
|||
std::optional<grpc::ChannelArguments> arguments = std::nullopt) {
|
||||
if (!arguments.has_value()) {
|
||||
arguments = grpc::ChannelArguments();
|
||||
// Disable http proxy since it disrupts local connections. TODO(ekl) we should make
|
||||
// this configurable, or selectively set it for known local connections only.
|
||||
arguments->SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
|
||||
arguments->SetInt(GRPC_ARG_ENABLE_HTTP_PROXY,
|
||||
::RayConfig::instance().grpc_enable_http_proxy() ? 1 : 0);
|
||||
arguments->SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
|
||||
arguments->SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());
|
||||
}
|
||||
|
@ -112,7 +111,8 @@ class GrpcClient {
|
|||
quota.SetMaxThreads(num_threads);
|
||||
grpc::ChannelArguments argument;
|
||||
argument.SetResourceQuota(quota);
|
||||
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
|
||||
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY,
|
||||
::RayConfig::instance().grpc_enable_http_proxy() ? 1 : 0);
|
||||
argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
|
||||
argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());
|
||||
|
||||
|
|
|
@ -184,6 +184,18 @@ DEFINE_stats(gcs_new_resource_creation_latency_ms,
|
|||
({0.1, 1, 10, 100, 1000, 10000}, ),
|
||||
ray::stats::HISTOGRAM);
|
||||
|
||||
/// GCS Storage
|
||||
DEFINE_stats(gcs_storage_operation_latency_ms,
|
||||
"Time to invoke an operation on Gcs storage",
|
||||
("Operation"),
|
||||
({0.1, 1, 10, 100, 1000, 10000}, ),
|
||||
ray::stats::HISTOGRAM);
|
||||
DEFINE_stats(gcs_storage_operation_count,
|
||||
"Number of operations invoked on Gcs storage",
|
||||
("Operation"),
|
||||
(),
|
||||
ray::stats::COUNT);
|
||||
|
||||
/// Placement Group
|
||||
// The end to end placement group creation latency.
|
||||
// The time from placement group creation request has received
|
||||
|
|
|
@ -84,6 +84,10 @@ DECLARE_stats(spill_manager_objects_bytes);
|
|||
DECLARE_stats(spill_manager_request_total);
|
||||
DECLARE_stats(spill_manager_throughput_mb);
|
||||
|
||||
/// GCS Storage
|
||||
DECLARE_stats(gcs_storage_operation_latency_ms);
|
||||
DECLARE_stats(gcs_storage_operation_count);
|
||||
|
||||
/// GCS Resource Manager
|
||||
DECLARE_stats(gcs_new_resource_creation_latency_ms);
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue