Merge branch 'master' of github.com:ray-project/ray into chunkedclienttask

This commit is contained in:
Chris Wong 2022-05-10 09:14:16 -07:00
commit 5847582593
54 changed files with 1305 additions and 472 deletions

View file

@ -254,7 +254,7 @@
- ./ci/env/install-minimal.sh
- ./ci/env/env_info.sh
- python ./ci/env/check_minimal_install.py
- bazel test --test_output=streamed --config=ci $(./ci/run/bazel_export_options)
- bazel test --test_output=streamed --config=ci --test_env=RAY_MINIMAL=1 $(./ci/run/bazel_export_options)
python/ray/tests/test_basic
- bazel test --test_output=streamed --config=ci $(./ci/run/bazel_export_options)
python/ray/tests/test_basic_2

View file

@ -539,6 +539,7 @@ cc_library(
":gcs_service_rpc",
":gcs_table_storage_lib",
":node_manager_rpc",
":observable_store_client",
":pubsub_lib",
":raylet_client_lib",
":scheduler",
@ -1853,6 +1854,7 @@ cc_library(
deps = [
":gcs",
":gcs_in_memory_store_client",
":observable_store_client",
":pubsub_lib",
":ray_common",
":redis_store_client",
@ -2281,6 +2283,24 @@ cc_library(
],
)
# Library target for the observable store client.
# Compiles observable_store_client.cc and exports its header together with
# the generic store_client.h and gcs/callback.h headers.
cc_library(
name = "observable_store_client",
srcs = [
"src/ray/gcs/store_client/observable_store_client.cc",
],
hdrs = [
"src/ray/gcs/callback.h",
"src/ray/gcs/store_client/observable_store_client.h",
"src/ray/gcs/store_client/store_client.h",
],
copts = COPTS,
# Headers are exposed with the leading "src" path component removed.
strip_include_prefix = "src",
deps = [
":ray_common",
":ray_util",
],
)
cc_library(
name = "store_client_test_lib",
hdrs = [
@ -2327,6 +2347,20 @@ cc_test(
],
)
# Small-sized test target tagged for the core team; links the in-memory
# store client, the observable store client and the shared store client
# test library, with gtest_main supplying the entry point.
# NOTE(review): the target is named observable_store_client_test but its
# only source is in_memory_store_client_test.cc -- confirm this is intended
# and not a copy/paste leftover from the in-memory test target.
cc_test(
name = "observable_store_client_test",
size = "small",
srcs = ["src/ray/gcs/store_client/test/in_memory_store_client_test.cc"],
copts = COPTS,
tags = ["team:core"],
deps = [
":gcs_in_memory_store_client",
":observable_store_client",
":store_client_test_lib",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "gcs",
srcs = glob(

View file

@ -76,7 +76,7 @@ if [[ "$platform" == "linux" ]]; then
"$PYTHON_EXE" -u -c "import ray; print(ray.__commit__)" | grep "$TRAVIS_COMMIT" || (echo "ray.__commit__ not set properly!" && exit 1)
# Install the dependencies to run the tests.
"$PIP_CMD" install -q aiohttp aiosignal frozenlist grpcio pytest==5.4.3 requests
"$PIP_CMD" install -q aiohttp aiosignal frozenlist grpcio pytest==5.4.3 requests proxy.py
# Run a simple test script to make sure that the wheel works.
for SCRIPT in "${TEST_SCRIPTS[@]}"; do
@ -117,11 +117,11 @@ elif [[ "$platform" == "macosx" ]]; then
"$PIP_CMD" install -q "$PYTHON_WHEEL"
# Install the dependencies to run the tests.
"$PIP_CMD" install -q aiohttp aiosignal frozenlist grpcio pytest==5.4.3 requests
"$PIP_CMD" install -q aiohttp aiosignal frozenlist grpcio pytest==5.4.3 requests proxy.py
# Run a simple test script to make sure that the wheel works.
for SCRIPT in "${TEST_SCRIPTS[@]}"; do
retry "$PYTHON_EXE" "$SCRIPT"
PATH="$(dirname "$PYTHON_EXE"):$PATH" retry "$PYTHON_EXE" "$SCRIPT"
done
done
elif [ "${platform}" = windows ]; then

View file

@ -85,7 +85,7 @@ class DashboardAgent(object):
logger.info("Parent pid is %s", self.ppid)
# Setup raylet channel
options = (("grpc.enable_http_proxy", 0),)
options = ray_constants.GLOBAL_GRPC_OPTIONS
self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel(
f"{self.ip}:{self.node_manager_port}", options, asynchronous=True
)

View file

@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
aiogrpc.init_grpc_aio()
GRPC_CHANNEL_OPTIONS = (
("grpc.enable_http_proxy", 0),
*ray_constants.GLOBAL_GRPC_OPTIONS,
("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
)

View file

@ -10,6 +10,7 @@ except ImportError:
from grpc.experimental import aio as aiogrpc
from ray._private.gcs_pubsub import GcsAioActorSubscriber
import ray.ray_constants as ray_constants
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray.dashboard.optional_utils import rest_response
@ -88,7 +89,7 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
address = "{}:{}".format(
node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
)
options = (("grpc.enable_http_proxy", 0),)
options = ray_constants.GLOBAL_GRPC_OPTIONS
channel = ray._private.utils.init_grpc_channel(
address, options, asynchronous=True
)
@ -207,7 +208,7 @@ class ActorHead(dashboard_utils.DashboardHeadModule):
except KeyError:
return rest_response(success=False, message="Bad Request")
try:
options = (("grpc.enable_http_proxy", 0),)
options = ray_constants.GLOBAL_GRPC_OPTIONS
channel = ray._private.utils.init_grpc_channel(
f"{ip_address}:{port}", options=options, asynchronous=True
)

View file

@ -43,7 +43,7 @@ class EventAgent(dashboard_utils.DashboardAgentModule):
)
if dashboard_rpc_address:
logger.info("Report events to %s", dashboard_rpc_address)
options = (("grpc.enable_http_proxy", 0),)
options = ray_constants.GLOBAL_GRPC_OPTIONS
channel = utils.init_grpc_channel(
dashboard_rpc_address, options=options, asynchronous=True
)

View file

@ -126,7 +126,7 @@ class LogHeadV1(dashboard_utils.DashboardHeadModule):
node_id, ports = change.new
ip = DataSource.node_id_to_ip[node_id]
options = (("grpc.enable_http_proxy", 0),)
options = ray_constants.GLOBAL_GRPC_OPTIONS
channel = init_grpc_channel(
f"{ip}:{ports[1]}", options=options, asynchronous=True
)

View file

@ -78,7 +78,7 @@ class NodeHead(dashboard_utils.DashboardHeadModule):
address = "{}:{}".format(
node_info["nodeManagerAddress"], int(node_info["nodeManagerPort"])
)
options = (("grpc.enable_http_proxy", 0),)
options = ray_constants.GLOBAL_GRPC_OPTIONS
channel = ray._private.utils.init_grpc_channel(
address, options, asynchronous=True
)

View file

@ -11,6 +11,7 @@ import ray.experimental.internal_kv as internal_kv
import ray._private.services
import ray._private.utils
from ray.ray_constants import (
GLOBAL_GRPC_OPTIONS,
DEBUG_AUTOSCALING_STATUS,
DEBUG_AUTOSCALING_STATUS_LEGACY,
DEBUG_AUTOSCALING_ERROR,
@ -48,7 +49,7 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
if change.new:
node_id, ports = change.new
ip = DataSource.node_id_to_ip[node_id]
options = (("grpc.enable_http_proxy", 0),)
options = GLOBAL_GRPC_OPTIONS
channel = ray._private.utils.init_grpc_channel(
f"{ip}:{ports[1]}", options=options, asynchronous=True
)

View file

@ -17,29 +17,6 @@ def disable_aiohttp_cache():
os.environ.pop("RAY_DASHBOARD_NO_CACHE", None)
@pytest.fixture
def set_http_proxy():
http_proxy = os.environ.get("http_proxy", None)
https_proxy = os.environ.get("https_proxy", None)
# set http proxy
os.environ["http_proxy"] = "www.example.com:990"
os.environ["https_proxy"] = "www.example.com:990"
yield
# reset http proxy
if http_proxy:
os.environ["http_proxy"] = http_proxy
else:
del os.environ["http_proxy"]
if https_proxy:
os.environ["https_proxy"] = https_proxy
else:
del os.environ["https_proxy"]
@pytest.fixture
def small_event_line_limit():
os.environ["EVENT_READ_LINE_LENGTH_LIMIT"] = "1024"

View file

@ -596,37 +596,55 @@ def test_immutable_types():
@pytest.mark.skipif(
os.environ.get("RAY_MINIMAL") == "1",
reason="This test is not supposed to work for minimal installation.",
os.environ.get("RAY_MINIMAL") == "1" or os.environ.get("RAY_DEFAULT") == "1",
reason="This test is not supposed to work for minimal or default installation.",
)
def test_http_proxy(enable_test_module, set_http_proxy, shutdown_only):
address_info = ray.init(num_cpus=1, include_dashboard=True)
assert wait_until_server_available(address_info["webui_url"]) is True
def test_http_proxy(enable_test_module, start_http_proxy, shutdown_only):
# C++ config `grpc_enable_http_proxy` only initializes once, so we have to
# run driver as a separate process to make sure the correct config value
# is initialized.
script = """
import ray
import time
import requests
from ray._private.test_utils import (
format_web_url,
wait_until_server_available,
)
import logging
webui_url = address_info["webui_url"]
webui_url = format_web_url(webui_url)
logger = logging.getLogger(__name__)
timeout_seconds = 10
start_time = time.time()
while True:
time.sleep(1)
address_info = ray.init(num_cpus=1, include_dashboard=True)
assert wait_until_server_available(address_info["webui_url"]) is True
webui_url = address_info["webui_url"]
webui_url = format_web_url(webui_url)
timeout_seconds = 10
start_time = time.time()
while True:
time.sleep(1)
try:
response = requests.get(
webui_url + "/test/dump", proxies={"http": None, "https": None}
)
response.raise_for_status()
try:
response = requests.get(
webui_url + "/test/dump", proxies={"http": None, "https": None}
)
response.raise_for_status()
try:
response.json()
assert response.ok
except Exception as ex:
logger.info("failed response: %s", response.text)
raise ex
break
except (AssertionError, requests.exceptions.ConnectionError) as e:
logger.info("Retry because of %s", e)
finally:
if time.time() > start_time + timeout_seconds:
raise Exception("Timed out while testing.")
response.json()
assert response.ok
except Exception as ex:
logger.info("failed response: %s", response.text)
raise ex
break
except (AssertionError, requests.exceptions.ConnectionError) as e:
logger.info("Retry because of %s", e)
finally:
if time.time() > start_time + timeout_seconds:
raise Exception("Timed out while testing.")
"""
env = start_http_proxy
run_string_as_driver(script, dict(os.environ, **env))
@pytest.mark.skipif(

View file

@ -243,7 +243,7 @@
"more machines.\n",
"\n",
"To ingest 500GiB of data, we'll set up a Ray Cluster.\n",
"The provided :download:`big_data_ingestion.yaml <../big_data_ingestion.yaml>`\n",
"The provided [big_data_ingestion.yaml](https://raw.githubusercontent.com/ray-project/ray/master/doc/source/data/big_data_ingestion.yaml)\n",
"cluster config can be used to set up an AWS cluster with 70 CPU nodes and\n",
"16 GPU nodes. Use the following command to bring up the Ray cluster.\n",
"\n",
@ -362,7 +362,7 @@
"# -> throughput: 8.56GiB/s\n",
"```\n",
"\n",
"Note: The pipeline can also be submitted using :ref:`Ray Job Submission <jobs-overview>`,\n",
"Note: The pipeline can also be submitted using [Ray Job Submission](https://docs.ray.io/en/latest/cluster/job-submission.html),\n",
"which is in beta starting with Ray 1.12. Try it out!"
]
}

View file

@ -15,7 +15,7 @@
"\n",
"## Cluster Setup\n",
"\n",
"First, we'll set up our Ray Cluster. The provided `dask_xgboost.yaml`\n",
"First, we'll set up our Ray Cluster. The provided [dask_xgboost.yaml](https://raw.githubusercontent.com/ray-project/ray/master/doc/source/ray-core/examples/dask_xgboost/dask_xgboost.yaml)\n",
"cluster config can be used to set up an AWS cluster with 64 CPUs.\n",
"\n",
"The following steps assume you are in a directory with both\n",
@ -477,4 +477,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View file

@ -14,7 +14,7 @@
"\n",
"## Cluster Setup\n",
"\n",
"First, we'll set up our Ray Cluster. The provided ``modin_xgboost.yaml``\n",
"First, we'll set up our Ray Cluster. The provided [modin_xgboost.yaml](https://raw.githubusercontent.com/ray-project/ray/master/doc/source/ray-core/examples/modin_xgboost/modin_xgboost.yaml)\n",
"cluster config can be used to set up an AWS cluster with 64 CPUs.\n",
"\n",
"The following steps assume you are in a directory with both\n",

View file

@ -7,6 +7,7 @@ import time
import grpc
import ray
from ray import ray_constants
from ray.core.generated.common_pb2 import ErrorType
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated import gcs_service_pb2
@ -74,7 +75,7 @@ _GRPC_KEEPALIVE_TIMEOUT_MS = 60 * 1000
# grpc.keepalive_permit_without_calls=0: No keepalive without inflight calls.
# grpc.use_local_subchannel_pool=0: Subchannels are shared.
_GRPC_OPTIONS = [
("grpc.enable_http_proxy", 0),
*ray_constants.GLOBAL_GRPC_OPTIONS,
("grpc.max_send_message_length", _MAX_MESSAGE_LENGTH),
("grpc.max_receive_message_length", _MAX_MESSAGE_LENGTH),
("grpc.keepalive_time_ms", _GRPC_KEEPALIVE_TIME_MS),

View file

@ -145,7 +145,7 @@ class Monitor:
retry_on_failure: bool = True,
):
gcs_address = address
options = (("grpc.enable_http_proxy", 0),)
options = ray_constants.GLOBAL_GRPC_OPTIONS
gcs_channel = ray._private.utils.init_grpc_channel(gcs_address, options)
# TODO: Use gcs client for this
self.gcs_node_resources_stub = (

View file

@ -7,6 +7,7 @@ import grpc
import ray
from typing import Dict, List
from ray import ray_constants
from ray.core.generated.gcs_service_pb2 import (
GetAllActorInfoRequest,
@ -106,7 +107,7 @@ class StateDataSourceClient:
def register_raylet_client(self, node_id: str, address: str, port: int):
full_addr = f"{address}:{port}"
options = (("grpc.enable_http_proxy", 0),)
options = ray_constants.GLOBAL_GRPC_OPTIONS
channel = ray._private.utils.init_grpc_channel(
full_addr, options, asynchronous=True
)
@ -116,7 +117,7 @@ class StateDataSourceClient:
self._raylet_stubs.pop(node_id)
def register_agent_client(self, node_id, address: str, port: int):
options = (("grpc.enable_http_proxy", 0),)
options = ray_constants.GLOBAL_GRPC_OPTIONS
channel = ray._private.utils.init_grpc_channel(
f"{address}:{port}", options=options, asynchronous=True
)

View file

@ -332,6 +332,14 @@ CALL_STACK_LINE_DELIMITER = " | "
# NOTE: This is equal to the C++ limit of (RAY_CONFIG::max_grpc_message_size)
GRPC_CPP_MAX_MESSAGE_SIZE = 100 * 1024 * 1024
# Process-wide gRPC channel options.
# HTTP proxy support is opt-in: the flag becomes 1 only when the environment
# variable RAY_grpc_enable_http_proxy is "1" or "true" (compared
# case-insensitively); an unset variable or any other value yields 0.
GRPC_ENABLE_HTTP_PROXY = int(
    os.environ.get("RAY_grpc_enable_http_proxy", "0").lower() in ("1", "true")
)
# Tuple of (option name, value) pairs built from the flag above.
GLOBAL_GRPC_OPTIONS = (("grpc.enable_http_proxy", GRPC_ENABLE_HTTP_PROXY),)
# Internal kv namespaces
KV_NAMESPACE_DASHBOARD = b"dashboard"
KV_NAMESPACE_SESSION = b"session"

View file

@ -58,6 +58,8 @@ from ray.serve.context import (
get_internal_replica_context,
ReplicaContext,
)
from ray.serve.pipeline.api import build as pipeline_build
from ray.serve.pipeline.api import get_and_validate_ingress_deployment
from ray._private.usage import usage_lib
logger = logging.getLogger(__file__)
@ -592,9 +594,6 @@ def run(
RayServeHandle: A regular ray serve handle that can be called by user
to execute the serve DAG.
"""
# TODO (jiaodong): Resolve circular reference in pipeline codebase and serve
from ray.serve.pipeline.api import build as pipeline_build
from ray.serve.pipeline.api import get_and_validate_ingress_deployment
client = start(detached=True, http_options={"host": host, "port": port})
@ -668,8 +667,6 @@ def build(target: Union[ClassNode, FunctionNode]) -> Application:
The returned Application object can be exported to a dictionary or YAML
config.
"""
# TODO (jiaodong): Resolve circular reference in pipeline codebase and serve
from ray.serve.pipeline.api import build as pipeline_build
if in_interactive_shell():
raise RuntimeError(

View file

@ -46,6 +46,12 @@ class DeploymentStatusInfo:
)
HEALTH_CHECK_CONCURRENCY_GROUP = "health_check"
REPLICA_DEFAULT_ACTOR_OPTIONS = {
"concurrency_groups": {HEALTH_CHECK_CONCURRENCY_GROUP: 1}
}
class DeploymentInfo:
def __init__(
self,
@ -95,14 +101,14 @@ class DeploymentInfo:
or self.serialized_deployment_def is not None
)
if self.replica_config.import_path is not None:
self._cached_actor_def = ray.remote(
self._cached_actor_def = ray.remote(**REPLICA_DEFAULT_ACTOR_OPTIONS)(
create_replica_wrapper(
self.actor_name,
import_path=self.replica_config.import_path,
)
)
else:
self._cached_actor_def = ray.remote(
self._cached_actor_def = ray.remote(**REPLICA_DEFAULT_ACTOR_OPTIONS)(
create_replica_wrapper(
self.actor_name,
serialized_deployment_def=self.serialized_deployment_def,

View file

@ -18,7 +18,7 @@ from ray.util import metrics
from ray._private.async_compat import sync_to_async
from ray.serve.autoscaling_metrics import start_metrics_pusher
from ray.serve.common import ReplicaTag
from ray.serve.common import HEALTH_CHECK_CONCURRENCY_GROUP, ReplicaTag
from ray.serve.config import DeploymentConfig
from ray.serve.constants import (
HEALTH_CHECK_METHOD,
@ -219,6 +219,7 @@ def create_replica_wrapper(
if self.replica is not None:
return await self.replica.prepare_for_shutdown()
@ray.method(concurrency_group=HEALTH_CHECK_CONCURRENCY_GROUP)
async def check_health(self):
await self.replica.check_health()

View file

@ -235,6 +235,33 @@ def test_uvicorn_duplicate_headers(serve_instance):
assert resp.headers["content-length"] == "9"
def test_healthcheck_timeout(serve_instance):
# Regression test for the issue linked below: a request that stays blocked
# for longer than the configured health check timeout (2s here, while the
# caller waits 10s) must still complete once it is unblocked.
# https://github.com/ray-project/ray/issues/24554
signal = SignalActor.remote()
# Health checks run every 1s with a 2s timeout; graceful shutdown wait is 0s.
@serve.deployment(
_health_check_timeout_s=2,
_health_check_period_s=1,
_graceful_shutdown_timeout_s=0,
)
class A:
def check_health(self):
return True
def __call__(self):
# Block the request until the test sends the signal.
ray.get(signal.wait.remote())
A.deploy()
handle = A.get_handle()
ref = handle.remote()
# without the proper fix, the ref will fail with actor died error.
# The expected outcome while the request is blocked is a plain timeout.
with pytest.raises(GetTimeoutError):
ray.get(ref, timeout=10)
# Unblock the request and check that it now finishes without raising.
signal.send.remote()
ray.get(ref)
if __name__ == "__main__":
import sys

View file

@ -276,9 +276,13 @@ def call_ray_start(request):
"--max-worker-port=0 --port 0",
)
command_args = parameter.split(" ")
out = ray._private.utils.decode(
subprocess.check_output(command_args, stderr=subprocess.STDOUT)
)
try:
out = ray._private.utils.decode(
subprocess.check_output(command_args, stderr=subprocess.STDOUT)
)
except Exception as e:
print(type(e), e)
raise
# Get the redis address from the output.
redis_substring_prefix = "--address='"
address_location = out.find(redis_substring_prefix) + len(redis_substring_prefix)
@ -867,3 +871,27 @@ def create_ray_logs_for_failed_test(rep):
test_name = rep.nodeid.replace(os.sep, "::")
output_file = os.path.join(archive_dir, f"{test_name}_{time.time():.4f}")
shutil.make_archive(output_file, "zip", logs_dir)
# Parametrized fixture: each dependent test runs once with request.param
# True (a local proxy process is started) and once with False (no proxy
# process; the proxy variables point at http://example.com).
# Yields a dict of environment variables for the caller to apply; this
# fixture does not modify os.environ itself.
@pytest.fixture(params=[True, False])
def start_http_proxy(request):
env = {}
proxy = None
try:
if request.param:
# the `proxy` command is from the proxy.py package.
proxy = subprocess.Popen(
["proxy", "--port", "8899", "--log-level", "ERROR"]
)
# Only set when a proxy process was started on port 8899.
env["RAY_grpc_enable_http_proxy"] = "1"
proxy_url = "http://localhost:8899"
else:
proxy_url = "http://example.com"
# Both proxy variables are set in either case.
env["http_proxy"] = proxy_url
env["https_proxy"] = proxy_url
yield env
finally:
# Stop the proxy process, if one was started, and wait for it to exit.
if proxy:
proxy.terminate()
proxy.wait()

View file

@ -22,17 +22,29 @@ logger = logging.getLogger(__name__)
# https://github.com/ray-project/ray/issues/6662
@pytest.mark.skipif(
os.environ.get("RAY_MINIMAL") == "1",
reason="This test is not supposed to work for minimal installation.",
)
@pytest.mark.skipif(client_test_enabled(), reason="interferes with grpc")
def test_ignore_http_proxy(shutdown_only):
ray.init(num_cpus=1)
os.environ["http_proxy"] = "http://example.com"
os.environ["https_proxy"] = "http://example.com"
def test_http_proxy(start_http_proxy, shutdown_only):
# C++ config `grpc_enable_http_proxy` only initializes once, so we have to
# run driver as a separate process to make sure the correct config value
# is initialized.
script = """
import ray
@ray.remote
def f():
return 1
ray.init(num_cpus=1)
assert ray.get(f.remote()) == 1
@ray.remote
def f():
return 1
assert ray.get(f.remote()) == 1
"""
env = start_http_proxy
run_string_as_driver(script, dict(os.environ, **env))
# https://github.com/ray-project/ray/issues/16025
@ -310,6 +322,66 @@ def test_options():
"zzz": 42,
}
# test options for other Ray libraries.
namespace = "namespace"
class mock_options:
def __init__(self, **options):
self.options = {"_metadata": {namespace: options}}
def keys(self):
return ("_metadata",)
def __getitem__(self, key):
return self.options[key]
def __call__(self, f):
f._default_options.update(self.options)
return f
@mock_options(a=1, b=2)
@ray.remote(num_gpus=2)
def foo():
pass
assert foo._default_options == {
"_metadata": {"namespace": {"a": 1, "b": 2}},
"num_gpus": 2,
}
f2 = foo.options(num_cpus=1, num_gpus=1, **mock_options(a=11, c=3))
# TODO(suquark): The current implementation of `.options()` is so bad that we
# cannot even access its options from outside. Here we hack the closures to
# achieve our goal. Need further efforts to clean up the tech debt.
assert f2.remote.__closure__[1].cell_contents == {
"_metadata": {"namespace": {"a": 11, "b": 2, "c": 3}},
"num_cpus": 1,
"num_gpus": 1,
}
class mock_options2(mock_options):
def __init__(self, **options):
self.options = {"_metadata": {namespace + "2": options}}
f3 = foo.options(num_cpus=1, num_gpus=1, **mock_options2(a=11, c=3))
assert f3.remote.__closure__[1].cell_contents == {
"_metadata": {"namespace": {"a": 1, "b": 2}, "namespace2": {"a": 11, "c": 3}},
"num_cpus": 1,
"num_gpus": 1,
}
with pytest.raises(TypeError):
# Ensure only a single "**option" per ".options()".
# Otherwise it would be confusing.
foo.options(
num_cpus=1,
num_gpus=1,
**mock_options(a=11, c=3),
**mock_options2(a=11, c=3),
)
# https://github.com/ray-project/ray/issues/17842
def test_disable_cuda_devices():

View file

@ -400,7 +400,7 @@ def test_gcs_drain(ray_start_cluster_head, error_pubsub):
"""
# Prepare requests.
gcs_server_addr = cluster.gcs_address
options = (("grpc.enable_http_proxy", 0),)
options = ray_constants.GLOBAL_GRPC_OPTIONS
channel = grpc.insecure_channel(gcs_server_addr, options)
stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(channel)
r = gcs_service_pb2.DrainNodeRequest()

View file

@ -404,7 +404,7 @@ async def test_state_data_source_client(ray_start_cluster):
worker = cluster.add_node(num_cpus=2)
GRPC_CHANNEL_OPTIONS = (
("grpc.enable_http_proxy", 0),
*ray_constants.GLOBAL_GRPC_OPTIONS,
("grpc.max_send_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
("grpc.max_receive_message_length", ray_constants.GRPC_CPP_MAX_MESSAGE_SIZE),
)

View file

@ -1,7 +1,8 @@
from __future__ import print_function
import datetime
from typing import Dict, List, Optional, Union
import numbers
from typing import Any, Dict, List, Optional, Union
import collections
import os
@ -129,6 +130,8 @@ class TuneReporterBase(ProgressReporter):
max_error_rows: Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_column_length: Maximum column length (in characters). Column
headers and values longer than this will be abbreviated.
max_report_frequency: Maximum report frequency in seconds.
Defaults to 5s.
infer_limit: Maximum number of metrics to automatically infer
@ -169,11 +172,13 @@ class TuneReporterBase(ProgressReporter):
def __init__(
self,
*,
metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
total_samples: Optional[int] = None,
max_progress_rows: int = 20,
max_error_rows: int = 20,
max_column_length: int = 20,
max_report_frequency: int = 5,
infer_limit: int = 3,
print_intermediate_tables: Optional[bool] = None,
@ -188,6 +193,7 @@ class TuneReporterBase(ProgressReporter):
self._parameter_columns = parameter_columns or []
self._max_progress_rows = max_progress_rows
self._max_error_rows = max_error_rows
self._max_column_length = max_column_length
self._infer_limit = infer_limit
if print_intermediate_tables is None:
@ -360,6 +366,7 @@ class TuneReporterBase(ProgressReporter):
force_table=self._print_intermediate_tables,
fmt=fmt,
max_rows=max_progress,
max_column_length=self._max_column_length,
done=done,
metric=self._metric,
mode=self._mode,
@ -461,6 +468,8 @@ class JupyterNotebookReporter(TuneReporterBase, RemoteReporterMixin):
max_error_rows: Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_column_length: Maximum column length (in characters). Column
headers and values longer than this will be abbreviated.
max_report_frequency: Maximum report frequency in seconds.
Defaults to 5s.
infer_limit: Maximum number of metrics to automatically infer
@ -480,12 +489,14 @@ class JupyterNotebookReporter(TuneReporterBase, RemoteReporterMixin):
def __init__(
self,
*,
overwrite: bool = True,
metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
total_samples: Optional[int] = None,
max_progress_rows: int = 20,
max_error_rows: int = 20,
max_column_length: int = 20,
max_report_frequency: int = 5,
infer_limit: int = 3,
print_intermediate_tables: Optional[bool] = None,
@ -494,17 +505,18 @@ class JupyterNotebookReporter(TuneReporterBase, RemoteReporterMixin):
sort_by_metric: bool = False,
):
super(JupyterNotebookReporter, self).__init__(
metric_columns,
parameter_columns,
total_samples,
max_progress_rows,
max_error_rows,
max_report_frequency,
infer_limit,
print_intermediate_tables,
metric,
mode,
sort_by_metric,
metric_columns=metric_columns,
parameter_columns=parameter_columns,
total_samples=total_samples,
max_progress_rows=max_progress_rows,
max_error_rows=max_error_rows,
max_column_length=max_column_length,
max_report_frequency=max_report_frequency,
infer_limit=infer_limit,
print_intermediate_tables=print_intermediate_tables,
metric=metric,
mode=mode,
sort_by_metric=sort_by_metric,
)
if not IS_NOTEBOOK:
@ -564,6 +576,8 @@ class CLIReporter(TuneReporterBase):
max_error_rows: Maximum number of rows to print in the
error table. The error table lists the error file, if any,
corresponding to each trial. Defaults to 20.
max_column_length: Maximum column length (in characters). Column
headers and values longer than this will be abbreviated.
max_report_frequency: Maximum report frequency in seconds.
Defaults to 5s.
infer_limit: Maximum number of metrics to automatically infer
@ -583,11 +597,13 @@ class CLIReporter(TuneReporterBase):
def __init__(
self,
*,
metric_columns: Optional[Union[List[str], Dict[str, str]]] = None,
parameter_columns: Optional[Union[List[str], Dict[str, str]]] = None,
total_samples: Optional[int] = None,
max_progress_rows: int = 20,
max_error_rows: int = 20,
max_column_length: int = 20,
max_report_frequency: int = 5,
infer_limit: int = 3,
print_intermediate_tables: Optional[bool] = None,
@ -597,17 +613,18 @@ class CLIReporter(TuneReporterBase):
):
super(CLIReporter, self).__init__(
metric_columns,
parameter_columns,
total_samples,
max_progress_rows,
max_error_rows,
max_report_frequency,
infer_limit,
print_intermediate_tables,
metric,
mode,
sort_by_metric,
metric_columns=metric_columns,
parameter_columns=parameter_columns,
total_samples=total_samples,
max_progress_rows=max_progress_rows,
max_error_rows=max_error_rows,
max_column_length=max_column_length,
max_report_frequency=max_report_frequency,
infer_limit=infer_limit,
print_intermediate_tables=print_intermediate_tables,
metric=metric,
mode=mode,
sort_by_metric=sort_by_metric,
)
def report(self, trials: List[Trial], done: bool, *sys_info: Dict):
@ -683,6 +700,7 @@ def trial_progress_str(
force_table: bool = False,
fmt: str = "psql",
max_rows: Optional[int] = None,
max_column_length: int = 20,
done: bool = False,
metric: Optional[str] = None,
mode: Optional[str] = None,
@ -711,6 +729,7 @@ def trial_progress_str(
fmt: Output format (see tablefmt in tabulate API).
max_rows: Maximum number of rows in the trial table. Defaults to
unlimited.
max_column_length: Maximum column length (in characters).
done: True indicates that the tuning run finished.
metric: Metric used to sort trials.
mode: One of [min, max]. Determines whether objective is
@ -747,19 +766,48 @@ def trial_progress_str(
if force_table or (has_verbosity(Verbosity.V2_TRIAL_NORM) and done):
messages += trial_progress_table(
trials,
metric_columns,
parameter_columns,
fmt,
max_rows,
metric,
mode,
sort_by_metric,
trials=trials,
metric_columns=metric_columns,
parameter_columns=parameter_columns,
fmt=fmt,
max_rows=max_rows,
metric=metric,
mode=mode,
sort_by_metric=sort_by_metric,
max_column_length=max_column_length,
)
return delim.join(messages)
def _max_len(value: Any, max_len: int = 20, add_addr: bool = False) -> Any:
"""Abbreviate a string representation of an object to `max_len` characters.
For numbers, booleans and None, the original value will be returned for
correct rendering in the table formatting tool.
Args:
value: Object to be represented as a string.
max_len: Maximum return string length.
add_addr: If True, will add part of the object address to the end of the
string, e.g. to identify different instances of the same class. If
False, three dots (``...``) will be used instead.
"""
if value is None or isinstance(value, (int, float, numbers.Number, bool)):
return value
string = str(value)
if len(string) <= max_len:
return string
if add_addr and not isinstance(value, (int, float, bool)):
result = f"{string[: (max_len - 5)]}_{hex(id(value))[-4:]}"
return result
result = f"{string[: (max_len - 3)]}..."
return result
def trial_progress_table(
trials: List[Trial],
metric_columns: Union[List[str], Dict[str, str]],
@ -769,6 +817,7 @@ def trial_progress_table(
metric: Optional[str] = None,
mode: Optional[str] = None,
sort_by_metric: bool = False,
max_column_length: int = 20,
):
messages = []
num_trials = len(trials)
@ -840,17 +889,29 @@ def trial_progress_table(
# Build trial rows.
trial_table = [
_get_trial_info(trial, parameter_keys, metric_keys) for trial in trials
_get_trial_info(
trial, parameter_keys, metric_keys, max_column_length=max_column_length
)
for trial in trials
]
# Format column headings
if isinstance(metric_columns, Mapping):
formatted_metric_columns = [metric_columns[k] for k in metric_keys]
formatted_metric_columns = [
_max_len(metric_columns[k], max_len=max_column_length, add_addr=False)
for k in metric_keys
]
else:
formatted_metric_columns = metric_keys
if isinstance(parameter_columns, Mapping):
formatted_parameter_columns = [parameter_columns[k] for k in parameter_keys]
formatted_parameter_columns = [
_max_len(parameter_columns[k], max_len=max_column_length, add_addr=False)
for k in parameter_keys
]
else:
formatted_parameter_columns = parameter_keys
formatted_parameter_columns = [
_max_len(k, max_len=max_column_length, add_addr=False)
for k in parameter_keys
]
columns = (
["Trial name", "status", "loc"]
+ formatted_parameter_columns
@ -972,7 +1033,9 @@ def _get_trial_location(trial: Trial, result: dict) -> Location:
return location
def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]):
def _get_trial_info(
trial: Trial, parameters: List[str], metrics: List[str], max_column_length: int = 20
):
"""Returns the following information about a trial:
name | status | loc | params... | metrics...
@ -981,16 +1044,27 @@ def _get_trial_info(trial: Trial, parameters: List[str], metrics: List[str]):
trial: Trial to get information for.
parameters: Names of trial parameters to include.
metrics: Names of metrics to include.
max_column_length: Maximum column length (in characters).
"""
result = trial.last_result
config = trial.config
location = _get_trial_location(trial, result)
trial_info = [str(trial), trial.status, str(location)]
trial_info += [
unflattened_lookup(param, config, default=None) for param in parameters
_max_len(
unflattened_lookup(param, config, default=None),
max_len=max_column_length,
add_addr=True,
)
for param in parameters
]
trial_info += [
unflattened_lookup(metric, result, default=None) for metric in metrics
_max_len(
unflattened_lookup(metric, result, default=None),
max_len=max_column_length,
add_addr=True,
)
for metric in metrics
]
return trial_info

View file

@ -694,6 +694,25 @@ class ProgressReporterTest(unittest.TestCase):
tune.run(lambda config: 2, num_samples=1, progress_reporter=CustomReporter())
def testMaxLen(self):
# Builds five mocked, terminated trials whose parameter key (160 chars)
# and metric value (1000 chars) are far longer than the default column
# limit, then renders the progress table and checks the line lengths.
trials = []
for i in range(5):
t = Mock()
t.status = "TERMINATED"
t.trial_id = "%05d" % i
t.local_dir = "/foo"
t.location = "here"
t.config = {"verylong" * 20: i}
t.evaluated_params = {"verylong" * 20: i}
t.last_result = {"some_metric": "evenlonger" * 100}
# str(trial) is used for the trial name column.
t.__str__ = lambda self: self.trial_id
trials.append(t)
progress_str = trial_progress_str(
trials, metric_columns=["some_metric"], force_table=True
)
# NOTE(review): `any` passes as soon as a single output line is at most
# 90 characters long, so it does not prove that the long columns were
# abbreviated -- confirm whether `all` was intended here.
assert any(len(row) <= 90 for row in progress_str.split("\n"))
if __name__ == "__main__":
import sys

View file

@ -1,4 +1,5 @@
from typing import List, Any, Union, Dict, Callable, Tuple, Optional
from collections import deque, defaultdict
import ray
from ray.workflow import workflow_context
@ -12,7 +13,6 @@ from ray.workflow.common import (
StepType,
)
from ray.workflow import workflow_storage
from ray.workflow.step_function import WorkflowStepFunction
class WorkflowStepNotRecoverableError(Exception):
@ -32,136 +32,111 @@ class WorkflowNotResumableError(Exception):
super().__init__(self.message)
@WorkflowStepFunction
def _recover_workflow_step(
input_workflows: List[Any],
input_workflow_refs: List[WorkflowRef],
*args,
**kwargs,
):
"""A workflow step that recovers the output of an unfinished step.
Args:
args: The positional arguments for the step function.
kwargs: The keyword args for the step function.
input_workflows: The workflows in the argument of the (original) step.
They are resolved into physical objects (i.e. the output of the
workflows) here. They come from other recover workflows we
construct recursively.
Returns:
The output of the recovered step.
"""
reader = workflow_storage.get_workflow_storage()
step_id = workflow_context.get_current_step_id()
func: Callable = reader.load_step_func_body(step_id)
return func(*args, **kwargs)
def _reconstruct_wait_step(
reader: workflow_storage.WorkflowStorage,
step_id: StepID,
result: workflow_storage.StepInspectResult,
input_map: Dict[StepID, Any],
):
input_workflows = []
step_options = result.step_options
wait_options = step_options.ray_options.get("wait_options", {})
for i, _step_id in enumerate(result.workflows):
# Check whether the step has been loaded or not to avoid
# duplication
if _step_id in input_map:
r = input_map[_step_id]
else:
r = _construct_resume_workflow_from_step(reader, _step_id, input_map)
input_map[_step_id] = r
if isinstance(r, Workflow):
input_workflows.append(r)
else:
assert isinstance(r, StepID)
# TODO (Alex): We should consider caching these outputs too.
output = reader.load_step_output(r)
# Simulate a workflow with a workflow reference so it could be
# used directly by 'workflow.wait'.
static_ref = WorkflowStaticRef(step_id=r, ref=ray.put(output))
wf = Workflow.from_ref(static_ref)
input_workflows.append(wf)
from ray import workflow
wait_step = workflow.wait(input_workflows, **wait_options)
# override step id
wait_step._step_id = step_id
return wait_step
def _construct_resume_workflow_from_step(
reader: workflow_storage.WorkflowStorage,
step_id: StepID,
input_map: Dict[StepID, Any],
) -> Union[Workflow, StepID]:
workflow_id: str, step_id: StepID
) -> Union[Workflow, Any]:
"""Try to construct a workflow (step) that recovers the workflow step.
If the workflow step already has an output checkpointing file, we return
the workflow step id instead.
Args:
reader: The storage reader for inspecting the step.
workflow_id: The ID of the workflow.
step_id: The ID of the step we want to recover.
input_map: This is a context storing the input which has been loaded.
This context is important for dedupe
Returns:
A workflow that recovers the step, or a ID of a step
that contains the output checkpoint file.
A workflow that recovers the step, or the output of the step
if it has been checkpointed.
"""
result: workflow_storage.StepInspectResult = reader.inspect_step(step_id)
if result.output_object_valid:
# we already have the output
return step_id
if isinstance(result.output_step_id, str):
return _construct_resume_workflow_from_step(
reader, result.output_step_id, input_map
)
# output does not exists or not valid. try to reconstruct it.
if not result.is_recoverable():
raise WorkflowStepNotRecoverableError(step_id)
reader = workflow_storage.WorkflowStorage(workflow_id)
step_options = result.step_options
# Process the wait step as a special case.
if step_options.step_type == StepType.WAIT:
return _reconstruct_wait_step(reader, step_id, result, input_map)
# Step 1: construct dependency of the DAG (BFS)
inpsect_results = {}
dependency_map = defaultdict(list)
num_in_edges = {}
dag_visit_queue = deque([step_id])
while dag_visit_queue:
s: StepID = dag_visit_queue.popleft()
if s in inpsect_results:
continue
r = reader.inspect_step(s)
inpsect_results[s] = r
if not r.is_recoverable():
raise WorkflowStepNotRecoverableError(s)
if r.output_object_valid:
deps = []
elif isinstance(r.output_step_id, str):
deps = [r.output_step_id]
else:
deps = r.workflows
for w in deps:
dependency_map[w].append(s)
num_in_edges[s] = len(deps)
dag_visit_queue.extend(deps)
# Step 2: topological sort to determine the execution order (Kahn's algorithm)
execution_queue: List[StepID] = []
start_nodes = deque(k for k, v in num_in_edges.items() if v == 0)
while start_nodes:
n = start_nodes.popleft()
execution_queue.append(n)
for m in dependency_map[n]:
num_in_edges[m] -= 1
assert num_in_edges[m] >= 0, (m, n)
if num_in_edges[m] == 0:
start_nodes.append(m)
# Step 3: recover the workflow by the order of the execution queue
with serialization.objectref_cache():
input_workflows = []
for i, _step_id in enumerate(result.workflows):
# Check whether the step has been loaded or not to avoid
# duplication
if _step_id in input_map:
r = input_map[_step_id]
else:
r = _construct_resume_workflow_from_step(reader, _step_id, input_map)
input_map[_step_id] = r
if isinstance(r, Workflow):
input_workflows.append(r)
else:
assert isinstance(r, StepID)
# TODO (Alex): We should consider caching these outputs too.
input_workflows.append(reader.load_step_output(r))
workflow_refs = list(map(WorkflowRef, result.workflow_refs))
# "input_map" is a context storing the input which has been loaded.
# This context is important for deduplicating step inputs.
input_map: Dict[StepID, Any] = {}
args, kwargs = reader.load_step_args(step_id, input_workflows, workflow_refs)
# Note: we must uppack args and kwargs, so the refs in the args/kwargs can get
# resolved consistently like in Ray.
recovery_workflow: Workflow = _recover_workflow_step.step(
input_workflows,
workflow_refs,
*args,
**kwargs,
)
recovery_workflow._step_id = step_id
# override step_options
recovery_workflow.data.step_options = step_options
return recovery_workflow
for _step_id in execution_queue:
result = inpsect_results[_step_id]
if result.output_object_valid:
input_map[_step_id] = reader.load_step_output(_step_id)
continue
if isinstance(result.output_step_id, str):
input_map[_step_id] = input_map[result.output_step_id]
continue
# Process the wait step as a special case.
if result.step_options.step_type == StepType.WAIT:
wait_input_workflows = []
for w in result.workflows:
output = input_map[w]
if isinstance(output, Workflow):
wait_input_workflows.append(output)
else:
# Simulate a workflow with a workflow reference so it could be
# used directly by 'workflow.wait'.
static_ref = WorkflowStaticRef(step_id=w, ref=ray.put(output))
wait_input_workflows.append(Workflow.from_ref(static_ref))
recovery_workflow = ray.workflow.wait(
wait_input_workflows,
**result.step_options.ray_options.get("wait_options", {}),
)
else:
args, kwargs = reader.load_step_args(
_step_id,
workflows=[input_map[w] for w in result.workflows],
workflow_refs=list(map(WorkflowRef, result.workflow_refs)),
)
func: Callable = reader.load_step_func_body(_step_id)
# TODO(suquark): Use an alternative function when "workflow.step"
# is fully deprecated.
recovery_workflow = ray.workflow.step(func).step(*args, **kwargs)
# override step_options
recovery_workflow._step_id = _step_id
recovery_workflow.data.step_options = result.step_options
input_map[_step_id] = recovery_workflow
# Step 4: return the output of the requested step
return input_map[step_id]
@ray.remote(num_returns=2)
@ -183,21 +158,19 @@ def _resume_workflow_step_executor(
except Exception:
pass
try:
wf_store = workflow_storage.WorkflowStorage(workflow_id)
r = _construct_resume_workflow_from_step(wf_store, step_id, {})
r = _construct_resume_workflow_from_step(workflow_id, step_id)
except Exception as e:
raise WorkflowNotResumableError(workflow_id) from e
if isinstance(r, Workflow):
with workflow_context.workflow_step_context(
workflow_id, last_step_of_workflow=True
):
from ray.workflow.step_executor import execute_workflow
if not isinstance(r, Workflow):
return r, None
with workflow_context.workflow_step_context(
workflow_id, last_step_of_workflow=True
):
from ray.workflow.step_executor import execute_workflow
result = execute_workflow(job_id, r)
return result.persisted_output, result.volatile_output
assert isinstance(r, StepID)
return wf_store.load_step_output(r), None
result = execute_workflow(job_id, r)
return result.persisted_output, result.volatile_output
def resume_workflow_step(

View file

@ -94,6 +94,7 @@ tqdm
async-exit-stack
async-generator
cryptography>=3.0.0
proxy.py
# For doc tests
myst-parser==0.15.2
myst-nb==0.13.1

View file

@ -95,7 +95,7 @@ def maybe_fetch_buildkite_token() -> str:
return buildkite_token
print("Missing BUILDKITE_TOKEN, retrieving from AWS secrets store")
os.environ["BUILDKITE_TOKEN"] = boto3.client(
buildkite_token = boto3.client(
"secretsmanager", region_name="us-west-2"
).get_secret_value(
SecretId="arn:aws:secretsmanager:us-west-2:029272617770:secret:"
@ -103,6 +103,8 @@ def maybe_fetch_buildkite_token() -> str:
)[
"SecretString"
]
os.environ["BUILDKITE_TOKEN"] = buildkite_token
return buildkite_token
def get_results_from_build_collection(

View file

@ -1,6 +1,11 @@
from ray.rllib.agents.alpha_star.alpha_star import DEFAULT_CONFIG, AlphaStarTrainer
from ray.rllib.agents.alpha_star.alpha_star import (
AlphaStarConfig,
AlphaStarTrainer,
DEFAULT_CONFIG,
)
__all__ = [
"DEFAULT_CONFIG",
"AlphaStarConfig",
"AlphaStarTrainer",
"DEFAULT_CONFIG",
]

View file

@ -3,8 +3,8 @@ A multi-agent, distributed multi-GPU, league-capable asynch. PPO
================================================================
"""
import gym
from typing import Optional, Type
import tree
from typing import Any, Dict, Optional, Type
import ray
@ -19,6 +19,7 @@ from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentRepla
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
@ -42,36 +43,60 @@ from ray.rllib.utils.typing import (
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.util.timer import _Timer
# fmt: off
# __sphinx_doc_begin__
# Adds the following updates to the `IMPALATrainer` config in
# rllib/agents/impala/impala.py.
DEFAULT_CONFIG = Trainer.merge_trainer_configs(
appo.DEFAULT_CONFIG, # See keys in appo.py, which are also supported.
{
# TODO: Unify the buffer API, then clean up our existing
# implementations of different buffers.
# This is num batches held at any time for each policy.
"replay_buffer_capacity": 20,
# e.g. ratio=0.2 -> 20% of samples in each train batch are
# old (replayed) ones.
"replay_buffer_replay_ratio": 0.5,
class AlphaStarConfig(appo.APPOConfig):
"""Defines a configuration class from which an AlphaStarTrainer can be built.
# Timeout to use for `ray.wait()` when waiting for samplers to have placed
# new data into the buffers. If no samples are ready within the timeout,
# the buffers used for mixin-sampling will return only older samples.
"sample_wait_timeout": 0.01,
# Timeout to use for `ray.wait()` when waiting for the policy learner actors
# to have performed an update and returned learning stats. If no learner
# actors have produced any learning results in the meantime, their
# learner-stats in the results will be empty for that iteration.
"learn_wait_timeout": 0.1,
Example:
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
>>> config = AlphaStarConfig().training(lr=0.0003, train_batch_size=512)\
... .resources(num_gpus=4)\
... .rollouts(num_rollout_workers=64)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
Example:
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
>>> from ray import tune
>>> config = AlphaStarConfig()
>>> # Print out some default values.
>>> print(config.vtrace)
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.0001, 0.0003]), grad_clip=20.0)
>>> # Set the config object's env.
>>> config.environment(env="CartPole-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(
... "AlphaStar",
... stop={"episode_reward_mean": 200},
... config=config.to_dict(),
... )
"""
def __init__(self, trainer_class=None):
"""Initializes a AlphaStarConfig instance."""
super().__init__(trainer_class=trainer_class or AlphaStarTrainer)
# fmt: off
# __sphinx_doc_begin__
# AlphaStar specific settings:
self.replay_buffer_capacity = 20
self.replay_buffer_replay_ratio = 0.5
self.sample_wait_timeout = 0.01
self.learn_wait_timeout = 0.1
# League-building parameters.
# The LeagueBuilder class to be used for league building logic.
"league_builder_config": {
self.league_builder_config = {
# Specify the sub-class of the `LeagueBuilder` API to use.
"type": AlphaStarLeagueBuilder,
# Any number of constructor kwargs to pass to this class:
# The number of random policies to add to the league. This must be an
# even number (including 0) as these will be evenly distributed
# amongst league- and main- exploiters.
@ -100,28 +125,80 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# Only for ME matches: Prob to play against learning
# main (vs a snapshot main).
"prob_main_exploiter_playing_against_learning_main": 0.5,
},
}
self.max_num_policies_to_train = None
# The maximum number of trainable policies for this Trainer.
# Each trainable policy will exist as a independent remote actor, co-located
# with a replay buffer. This is besides its existence inside
# the RolloutWorkers for training and evaluation.
# Set to None for automatically inferring this value from the number of
# trainable policies found in the `multiagent` config.
"max_num_policies_to_train": None,
# Override some of APPOConfig's default values with AlphaStar-specific
# values.
self.vtrace_drop_last_ts = False
self.min_time_s_per_reporting = 2
# __sphinx_doc_end__
# fmt: on
# By default, don't drop last timestep.
# TODO: We should do the same for IMPALA and APPO at some point.
"vtrace_drop_last_ts": False,
@override(appo.APPOConfig)
def training(
self,
*,
replay_buffer_capacity: Optional[int] = None,
replay_buffer_replay_ratio: Optional[float] = None,
sample_wait_timeout: Optional[float] = None,
learn_wait_timeout: Optional[float] = None,
league_builder_config: Optional[Dict[str, Any]] = None,
max_num_policies_to_train: Optional[int] = None,
**kwargs,
) -> "AlphaStarConfig":
"""Sets the training related configuration.
# Reporting interval.
"min_time_s_per_reporting": 2,
},
_allow_unknown_configs=True,
)
Args:
replay_buffer_capacity: This is num batches held at any time for each
policy.
replay_buffer_replay_ratio: For example, ratio=0.2 -> 20% of samples in
each train batch are old (replayed) ones.
sample_wait_timeout: Timeout to use for `ray.wait()` when waiting for
samplers to have placed new data into the buffers. If no samples are
ready within the timeout, the buffers used for mixin-sampling will
return only older samples.
learn_wait_timeout: Timeout to use for `ray.wait()` when waiting for the
policy learner actors to have performed an update and returned learning
stats. If no learner actors have produced any learning results in the
meantime, their learner-stats in the results will be empty for that
iteration.
league_builder_config: League-building config dict.
The dict must contain a `type` key indicating the LeagueBuilder class
to be used for league building logic. All other keys (that are not
`type`) will be used as constructor kwargs on the given class to
construct the LeagueBuilder instance. See the
`ray.rllib.agents.alpha_star.league_builder::AlphaStarLeagueBuilder`
(used by default by this algo) as an example.
max_num_policies_to_train: The maximum number of trainable policies for this
Trainer. Each trainable policy will exist as a independent remote actor,
co-located with a replay buffer. This is besides its existence inside
the RolloutWorkers for training and evaluation. Set to None for
automatically inferring this value from the number of trainable
policies found in the `multiagent` config.
# __sphinx_doc_end__
# fmt: on
Returns:
This updated TrainerConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
# TODO: Unify the buffer API, then clean up our existing
# implementations of different buffers.
if replay_buffer_capacity is not None:
self.replay_buffer_capacity = replay_buffer_capacity
if replay_buffer_replay_ratio is not None:
self.replay_buffer_replay_ratio = replay_buffer_replay_ratio
if sample_wait_timeout is not None:
self.sample_wait_timeout = sample_wait_timeout
if learn_wait_timeout is not None:
self.learn_wait_timeout = learn_wait_timeout
if league_builder_config is not None:
self.league_builder_config = league_builder_config
if max_num_policies_to_train is not None:
self.max_num_policies_to_train = max_num_policies_to_train
return self
class AlphaStarTrainer(appo.APPOTrainer):
@ -204,7 +281,7 @@ class AlphaStarTrainer(appo.APPOTrainer):
@classmethod
@override(appo.APPOTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
return AlphaStarConfig().to_dict()
@override(appo.APPOTrainer)
def validate_config(self, config: TrainerConfigDict):
@ -499,3 +576,20 @@ class AlphaStarTrainer(appo.APPOTrainer):
state_copy = state.copy()
self.league_builder.__setstate__(state.pop("league_builder", {}))
super().__setstate__(state_copy)
# Deprecated: Use ray.rllib.agents.alpha_star.alpha_star.AlphaStarConfig instead!
class _deprecated_default_config(dict):
    """Dict holding the defaults of `AlphaStarConfig`, kept for the old API.

    The dict is filled once, at construction time, with
    `AlphaStarConfig().to_dict()`. Only item access via `[]` is routed through
    the `Deprecated` decorator; other dict methods are inherited unchanged.
    """

    def __init__(self):
        # Snapshot the old-style (plain dict) form of the default config.
        super().__init__(AlphaStarConfig().to_dict())

    # NOTE(review): with `error=False` this presumably only warns instead of
    # raising - confirm against `ray.rllib.utils.deprecation.Deprecated`.
    @Deprecated(
        old="ray.rllib.agents.alpha_star.alpha_star.DEFAULT_CONFIG",
        new="ray.rllib.agents.alpha_star.alpha_star.AlphaStarConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        """Return the value stored under `item` (plain dict lookup)."""
        return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()

View file

@ -80,9 +80,15 @@ class AlphaStarLeagueBuilder(LeagueBuilder):
):
"""Initializes a AlphaStarLeagueBuilder instance.
The following match types are possible:
LE: A learning (not snapshot) league_exploiter vs any snapshot policy.
ME: A learning (not snapshot) main exploiter vs any main.
M: Main self-play (main vs main).
Args:
trainer: The Trainer object by which this league builder is used.
Trainer calls `build_league()` after each training step.
Trainer calls `build_league()` after each training step to reconfigure
the league structure (e.g. to add/remove policies).
trainer_config: The (not yet validated) config dict to be
used on the Trainer. Child classes of `LeagueBuilder`
should preprocess this to add e.g. multiagent settings

View file

@ -26,38 +26,33 @@ class TestAlphaStar(unittest.TestCase):
def test_alpha_star_compilation(self):
"""Test whether a AlphaStarTrainer can be built with all frameworks."""
config = {
"env": "connect_four",
"gamma": 1.0,
"num_workers": 4,
"num_envs_per_worker": 5,
"model": {
"fcnet_hiddens": [256, 256, 256],
},
"vf_loss_coeff": 0.01,
"entropy_coeff": 0.004,
"league_builder_config": {
"win_rate_threshold_for_new_snapshot": 0.8,
"num_random_policies": 2,
"num_learning_league_exploiters": 1,
"num_learning_main_exploiters": 1,
},
"grad_clip": 10.0,
"replay_buffer_capacity": 10,
"replay_buffer_replay_ratio": 0.0,
# Two GPUs -> 2 policies per GPU.
"num_gpus": 4,
"_fake_gpus": True,
# Test with KL loss, just to cover that extra code.
"use_kl_loss": True,
}
config = (
alpha_star.AlphaStarConfig()
.environment(env="connect_four")
.training(
gamma=1.0,
model={"fcnet_hiddens": [256, 256, 256]},
vf_loss_coeff=0.01,
entropy_coeff=0.004,
league_builder_config={
"win_rate_threshold_for_new_snapshot": 0.8,
"num_random_policies": 2,
"num_learning_league_exploiters": 1,
"num_learning_main_exploiters": 1,
},
grad_clip=10.0,
replay_buffer_capacity=10,
replay_buffer_replay_ratio=0.0,
use_kl_loss=True,
)
.rollouts(num_rollout_workers=4, num_envs_per_worker=5)
.resources(num_gpus=4, _fake_gpus=True)
)
num_iterations = 2
for _ in framework_iterator(config, with_eager_tracing=True):
_config = config.copy()
trainer = alpha_star.AlphaStarTrainer(config=_config)
trainer = config.build()
for i in range(num_iterations):
results = trainer.train()
print(results)

View file

@ -4,7 +4,7 @@
## Overview
[Dreamer](https://arxiv.org/abs/1912.01603) is a model-based off-policy RL algorithm that learns by imagining and works well in visual-based enviornments. Like all model-based algorithms, Dreamer learns the environment's transiton dynamics via a latent-space model called [PlaNet](https://ai.googleblog.com/2019/02/introducing-planet-deep-planning.html). PlaNet learns to encode visual space into latent vectors, which can be used as pseudo-observations in Dreamer.
[Dreamer](https://arxiv.org/abs/1912.01603) is a model-based off-policy RL algorithm that learns by imagining and works well in visual-based environments. Like all model-based algorithms, Dreamer learns the environment's transition dynamics via a latent-space model called [PlaNet](https://ai.googleblog.com/2019/02/introducing-planet-deep-planning.html). PlaNet learns to encode visual space into latent vectors, which can be used as pseudo-observations in Dreamer.
Dreamer is a gradient-based RL algorithm. This means that the agent imagines ahead using its learned transition dynamics model (PlaNet) to discover new rewards and states. Because imagining ahead is fully differentiable, the RL objective (maximizing the sum of rewards) is fully differentiable and does not need to be optimized indirectly such as policy gradient methods. This feature of gradient-based learning, in conjunction with PlaNet, enables the agent to learn in a latent space and achieves much better sample complexity and performance than other visual-based agents.
@ -14,6 +14,6 @@ For more details, there is a Ray/RLlib [blogpost](https://medium.com/distributed
Dreamer.
**[Detailed Documentation](https://docs.ray.io/en/master/rllib-algorithms.html#dqn)**
**[Detailed Documentation](https://docs.ray.io/en/latest/rllib/rllib-algorithms.html#dreamer)**
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/simple_q.py)**
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/dreamer/dreamer.py)**

View file

@ -1,6 +1,11 @@
from ray.rllib.agents.dreamer.dreamer import DREAMERTrainer, DEFAULT_CONFIG
from ray.rllib.agents.dreamer.dreamer import (
DREAMERConfig,
DREAMERTrainer,
DEFAULT_CONFIG,
)
__all__ = [
"DREAMERConfig",
"DREAMERTrainer",
"DEFAULT_CONFIG",
]

View file

@ -1,97 +1,209 @@
import logging
import random
import numpy as np
import random
from typing import Optional
from ray.rllib.agents import with_common_config
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.agents.dreamer.dreamer_torch_policy import DreamerTorchPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.agents.dreamer.dreamer_model import DreamerModel
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
SampleBatchType,
TrainerConfigDict,
ResultDict,
)
logger = logging.getLogger(__name__)
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# PlaNET Model LR
"td_model_lr": 6e-4,
# Actor LR
"actor_lr": 8e-5,
# Critic LR
"critic_lr": 8e-5,
# Grad Clipping
"grad_clip": 100.0,
# Discount
"discount": 0.99,
# Lambda
"lambda": 0.95,
# Clipping is done inherently via policy tanh.
"clip_actions": False,
# Training iterations per data collection from real env
"dreamer_train_iters": 100,
# Horizon for Enviornment (1000 for Mujoco/DMC)
"horizon": 1000,
# Number of episodes to sample for Loss Calculation
"batch_size": 50,
# Length of each episode to sample for Loss Calculation
"batch_length": 50,
# Imagination Horizon for Training Actor and Critic
"imagine_horizon": 15,
# Free Nats
"free_nats": 3.0,
# KL Coeff for the Model Loss
"kl_coeff": 1.0,
# Distributed Dreamer not implemented yet
"num_workers": 0,
# Prefill Timesteps
"prefill_timesteps": 5000,
# This should be kept at 1 to preserve sample efficiency
"num_envs_per_worker": 1,
# Exploration Gaussian
"explore_noise": 0.3,
# Batch mode
"batch_mode": "complete_episodes",
# Custom Model
"dreamer_model": {
"custom_model": DreamerModel,
# RSSM/PlaNET parameters
"deter_size": 200,
"stoch_size": 30,
# CNN Decoder Encoder
"depth_size": 32,
# General Network Parameters
"hidden_size": 400,
# Action STD
"action_init_std": 5.0,
},
"env_config": {
# Repeats action send by policy for frame_skip times in env
"frame_skip": 2,
},
class DREAMERConfig(TrainerConfig):
"""Defines a configuration class from which a DREAMERTrainer can be built.
# Use `execution_plan` instead of `training_iteration`.
"_disable_execution_plan_api": False,
})
# __sphinx_doc_end__
# fmt: on
Example:
>>> from ray.rllib.agents.dreamer import DREAMERConfig
>>> config = DREAMERConfig().training(gamma=0.9, lr=0.01)\
... .resources(num_gpus=0)\
... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
Example:
>>> from ray import tune
>>> from ray.rllib.agents.dreamer import DREAMERConfig
>>> config = DREAMERConfig()
>>> # Print out some default values.
>>> print(config.clip_param)
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.001, 0.0001]), clip_param=0.2)
>>> # Set the config object's env.
>>> config.environment(env="CartPole-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(
... "DREAMER",
... stop={"episode_reward_mean": 200},
... config=config.to_dict(),
... )
"""
def __init__(self):
"""Initializes a DREAMERConfig instance."""
super().__init__(trainer_class=DREAMERTrainer)
# fmt: off
# __sphinx_doc_begin__
# Dreamer specific settings:
self.td_model_lr = 6e-4
self.actor_lr = 8e-5
self.critic_lr = 8e-5
self.grad_clip = 100.0
self.lambda_ = 0.95
self.dreamer_train_iters = 100
self.batch_size = 50
self.batch_length = 50
self.imagine_horizon = 15
self.free_nats = 3.0
self.kl_coeff = 1.0
self.prefill_timesteps = 5000
self.explore_noise = 0.3
self.dreamer_model = {
"custom_model": DreamerModel,
# RSSM/PlaNET parameters
"deter_size": 200,
"stoch_size": 30,
# CNN Decoder Encoder
"depth_size": 32,
# General Network Parameters
"hidden_size": 400,
# Action STD
"action_init_std": 5.0,
}
# Override some of TrainerConfig's default values with Dreamer-specific values.
# .rollouts()
self.num_workers = 0
self.num_envs_per_worker = 1
self.horizon = 1000
self.batch_mode = "complete_episodes"
self.clip_actions = False
# .training()
self.gamma = 0.99
# .environment()
self.env_config = {
# Repeats action send by policy for frame_skip times in env
"frame_skip": 2,
}
# __sphinx_doc_end__
# fmt: on
@override(TrainerConfig)
def training(
self,
*,
td_model_lr: Optional[float] = None,
actor_lr: Optional[float] = None,
critic_lr: Optional[float] = None,
grad_clip: Optional[float] = None,
lambda_: Optional[float] = None,
dreamer_train_iters: Optional[int] = None,
batch_size: Optional[int] = None,
batch_length: Optional[int] = None,
imagine_horizon: Optional[int] = None,
free_nats: Optional[float] = None,
kl_coeff: Optional[float] = None,
prefill_timesteps: Optional[int] = None,
explore_noise: Optional[float] = None,
dreamer_model: Optional[dict] = None,
**kwargs,
) -> "DREAMERConfig":
"""
Args:
td_model_lr: PlaNET (transition dynamics) model learning rate.
actor_lr: Actor model learning rate.
critic_lr: Critic model learning rate.
grad_clip: If specified, clip the global norm of gradients by this amount.
lambda_: The GAE (lambda) parameter.
dreamer_train_iters: Training iterations per data collection from real env.
batch_size: Number of episodes to sample for loss calculation.
batch_length: Length of each episode to sample for loss calculation.
imagine_horizon: Imagination horizon for training Actor and Critic.
free_nats: Free nats.
kl_coeff: KL coefficient for the model Loss.
prefill_timesteps: Prefill timesteps.
explore_noise: Exploration Gaussian noise.
dreamer_model: Custom model config.
Returns:
This updated DREAMERConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if td_model_lr is not None:
self.td_model_lr = td_model_lr
if actor_lr is not None:
self.actor_lr = actor_lr
if critic_lr is not None:
self.critic_lr = critic_lr
if grad_clip is not None:
self.grad_clip = grad_clip
if lambda_ is not None:
self.lambda_ = lambda_
if dreamer_train_iters is not None:
self.dreamer_train_iters = dreamer_train_iters
if batch_size is not None:
self.batch_size = batch_size
if batch_length is not None:
self.batch_length = batch_length
if imagine_horizon is not None:
self.imagine_horizon = imagine_horizon
if free_nats is not None:
self.free_nats = free_nats
if kl_coeff is not None:
self.kl_coeff = kl_coeff
if prefill_timesteps is not None:
self.prefill_timesteps = prefill_timesteps
if explore_noise is not None:
self.explore_noise = explore_noise
if dreamer_model is not None:
self.dreamer_model = dreamer_model
return self
def _postprocess_gif(gif: np.ndarray):
    """Convert a 5-D float clip batch into a single uint8 clip for Tensorboard.

    Values are multiplied by 255, clipped into [0, 255] and cast to uint8.
    The first axis is then folded into the last one, so the entries along
    the first axis end up next to each other along the last axis.

    Args:
        gif: Array with exactly five axes, unpacked as (B, T, C, H, W)
            (presumably batch, time, channels, height, width).

    Returns:
        A uint8 array of shape (1, T, C, H, B * W).
    """
    pixels = np.clip(255 * gif, 0, 255).astype(np.uint8)
    batch, steps, channels, height, width = pixels.shape
    # Put the batch axis directly in front of the width axis so that the
    # reshape below concatenates the batch entries along the width.
    batch_next_to_width = pixels.transpose((1, 2, 3, 0, 4))
    target_shape = (1, steps, channels, height, batch * width)
    return batch_next_to_width.reshape(target_shape)
class EpisodicBuffer(object):
def __init__(self, max_length: int = 1000, length: int = 50):
"""Data structure that stores episodes and samples chunks
of size length from episodes
"""Stores episodes and samples chunks of size ``length`` from episodes.
Args:
max_length: Maximum episodes it can store
length: Episode chunking lengh in sample()
length: Episode chunking length in sample()
"""
# Stores all episodes into a list: List[SampleBatchType]
@ -101,8 +213,7 @@ class EpisodicBuffer(object):
self.length = length
def add(self, batch: SampleBatchType):
"""Splits a SampleBatch into episodes and adds episodes
to the episode buffer
"""Splits a SampleBatch into episodes and adds episodes to the episode buffer.
Args:
batch: SampleBatch to be added
@ -151,7 +262,6 @@ class DreamerIteration:
self.batch_size = batch_size
def __call__(self, samples):
# Dreamer training loop.
for n in range(self.dreamer_train_iters):
print(f"sub-iteration={n}/{self.dreamer_train_iters}")
@ -161,7 +271,7 @@ class DreamerIteration:
fetches = self.worker.learn_on_batch(batch)
# Custom Logging
policy_fetches = self.policy_stats(fetches)
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
if "log_gif" in policy_fetches:
gif = policy_fetches["log_gif"]
policy_fetches["log_gif"] = self.postprocess_gif(gif)
@ -180,20 +290,14 @@ class DreamerIteration:
return res
def postprocess_gif(self, gif: np.ndarray):
gif = np.clip(255 * gif, 0, 255).astype(np.uint8)
B, T, C, H, W = gif.shape
frames = gif.transpose((1, 2, 3, 0, 4)).reshape((1, T, C, H, B * W))
return frames
def policy_stats(self, fetches):
return fetches[DEFAULT_POLICY_ID]["learner_stats"]
return _postprocess_gif(gif=gif)
class DREAMERTrainer(Trainer):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
return DREAMERConfig().to_dict()
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
@ -211,6 +315,11 @@ class DREAMERTrainer(Trainer):
raise ValueError("Distributed Dreamer not supported yet!")
if config["clip_actions"]:
raise ValueError("Clipping is done inherently via policy tanh!")
if config["dreamer_train_iters"] <= 0:
raise ValueError(
"`dreamer_train_iters` must be a positive integer. "
f"Received {config['dreamer_train_iters']} instead."
)
if config["action_repeat"] > 1:
config["horizon"] = config["horizon"] / config["action_repeat"]
@ -218,6 +327,22 @@ class DREAMERTrainer(Trainer):
def get_default_policy_class(self, config: TrainerConfigDict):
return DreamerTorchPolicy
@override(Trainer)
def setup(self, config: PartialTrainerConfigDict):
    """Set up the trainer and, on the `training_iteration` path, the replay buffer.

    When `_disable_execution_plan_api` is True, creates the local
    `EpisodicBuffer` and prefills it by sampling from the local worker until
    `prefill_timesteps` timesteps have been collected.

    Args:
        config: The (possibly partial) trainer config passed by the caller.
    """
    super().setup(config)
    # `training_iteration` implementation: Setup buffer in `setup`, not
    # in `execution_plan` (deprecated).
    if self.config["_disable_execution_plan_api"] is True:
        # Read `batch_length` from `self.config` like every other setting in
        # this method: the `config` argument is only a partial dict and may
        # not contain the key (presumably `super().setup()` merges it with
        # the defaults into `self.config` -- verify against Trainer.setup).
        self.local_replay_buffer = EpisodicBuffer(
            length=self.config["batch_length"]
        )

        # Prefill episode buffer with initial exploration (uniform sampling)
        while (
            total_sampled_timesteps(self.workers.local_worker())
            < self.config["prefill_timesteps"]
        ):
            samples = self.workers.local_worker().sample()
            self.local_replay_buffer.add(samples)
@staticmethod
@override(Trainer)
def execution_plan(workers, config, **kwargs):
@ -250,3 +375,61 @@ class DREAMERTrainer(Trainer):
)
)
return rollouts
@override(Trainer)
def training_iteration(self) -> ResultDict:
    """Run one Dreamer training iteration.

    Collects a fresh rollout from the workers, runs `dreamer_train_iters`
    learning sub-iterations on chunks sampled from the local replay buffer,
    post-processes the optional `log_gif` stat, updates the sampled-steps
    counter and finally stores the fresh rollout in the replay buffer.

    Returns:
        The fetches of the last `learn_on_batch` call (an empty dict if no
        sub-iteration ran).
    """
    local_worker = self.workers.local_worker()

    # Number of sub-iterations for Dreamer
    dreamer_train_iters = self.config["dreamer_train_iters"]
    batch_size = self.config["batch_size"]
    action_repeat = self.config["action_repeat"]

    # Collect SampleBatches from rollout workers.
    # Kept under its own name so the training loop below cannot clobber it:
    # this is the data that must end up in the replay buffer.
    new_sample_batch = synchronous_parallel_sample(worker_set=self.workers)

    fetches = {}

    # Dreamer training loop.
    # Run multiple sub-iterations for each training iteration.
    for n in range(dreamer_train_iters):
        print(f"sub-iteration={n}/{dreamer_train_iters}")
        train_batch = self.local_replay_buffer.sample(batch_size)
        fetches = local_worker.learn_on_batch(train_batch)

    if fetches:
        # Custom Logging
        policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
        if "log_gif" in policy_fetches:
            gif = policy_fetches["log_gif"]
            # NOTE(review): a module-level `_postprocess_gif(gif=...)` helper
            # is visible in this file; confirm that the trainer class also
            # exposes a `_postprocess_gif` method, otherwise this raises
            # AttributeError.
            policy_fetches["log_gif"] = self._postprocess_gif(gif)

    self._counters[STEPS_SAMPLED_COUNTER] = (
        self.local_replay_buffer.timesteps * action_repeat
    )

    # Add the newly collected rollout (not the last replayed training chunk)
    # to the episodic buffer.
    self.local_replay_buffer.add(new_sample_batch)
    return fetches
def _compile_step_results(self, *args, **kwargs):
    """Compile step results, reporting the sampled-steps counter as `timesteps_total`.

    All arguments are forwarded unchanged to the parent implementation.
    """
    compiled = super()._compile_step_results(*args, **kwargs)
    # Override the total with this trainer's own sampled-steps counter.
    sampled_steps = self._counters[STEPS_SAMPLED_COUNTER]
    compiled["timesteps_total"] = sampled_steps
    return compiled
# Deprecated: Use ray.rllib.agents.dreamer.DREAMERConfig instead!
class _deprecated_default_config(dict):
    """Dict of the default Dreamer settings whose item access is marked deprecated."""

    def __init__(self):
        # Snapshot the defaults from the config class at construction time.
        default_settings = DREAMERConfig().to_dict()
        super().__init__(default_settings)

    @Deprecated(
        old="ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG",
        new="ray.rllib.agents.dreamer.dreamer.DREAMERConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        """Return the stored value for `item` (access goes through `Deprecated`)."""
        value = super().__getitem__(item)
        return value


DEFAULT_CONFIG = _deprecated_default_config()

View file

@ -7,12 +7,13 @@ import ray
from ray.rllib.agents.dreamer.utils import FreezeParameters
from ray.rllib.evaluation.episode import Episode
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.typing import AgentID
from ray.rllib.utils.typing import AgentID, TensorType
torch, nn = try_import_torch()
if torch:
@ -23,30 +24,30 @@ logger = logging.getLogger(__name__)
# This is the computation graph for workers (inner adaptation steps)
def compute_dreamer_loss(
obs,
action,
reward,
model,
imagine_horizon,
discount=0.99,
lambda_=0.95,
kl_coeff=1.0,
free_nats=3.0,
log=False,
obs: TensorType,
action: TensorType,
reward: TensorType,
model: TorchModelV2,
imagine_horizon: int,
gamma: float = 0.99,
lambda_: float = 0.95,
kl_coeff: float = 1.0,
free_nats: float = 3.0,
log: bool = False,
):
"""Constructs loss for the Dreamer objective
"""Constructs loss for the Dreamer objective.
Args:
obs (TensorType): Observations (o_t)
action (TensorType): Actions (a_(t-1))
reward (TensorType): Rewards (r_(t-1))
model (TorchModelV2): DreamerModel, encompassing all other models
imagine_horizon (int): Imagine horizon for actor and critic loss
discount (float): Discount
lambda_ (float): Lambda, like in GAE
kl_coeff (float): KL Coefficient for Divergence loss in model loss
free_nats (float): Threshold for minimum divergence in model loss
log (bool): If log, generate gifs
obs: Observations (o_t).
action: Actions (a_(t-1)).
reward: Rewards (r_(t-1)).
model: DreamerModel, encompassing all other models.
imagine_horizon: Imagine horizon for actor and critic loss.
gamma: Discount factor gamma.
lambda_: Lambda, like in GAE.
kl_coeff: KL Coefficient for Divergence loss in model loss.
free_nats: Threshold for minimum divergence in model loss.
log: If log, generate gifs.
"""
encoder_weights = list(model.encoder.parameters())
decoder_weights = list(model.decoder.parameters())
@ -84,7 +85,7 @@ def compute_dreamer_loss(
with FreezeParameters(model_weights + critic_weights):
reward = model.reward(imag_feat).mean
value = model.value(imag_feat).mean
pcont = discount * torch.ones_like(reward)
pcont = gamma * torch.ones_like(reward)
returns = lambda_return(reward[:-1], value[:-1], pcont[:-1], value[-1], lambda_)
discount_shape = pcont[:1].size()
discount = torch.cumprod(
@ -168,7 +169,7 @@ def dreamer_loss(policy, model, dist_class, train_batch):
train_batch["rewards"],
policy.model,
policy.config["imagine_horizon"],
policy.config["discount"],
policy.config["gamma"],
policy.config["lambda"],
policy.config["kl_coeff"],
policy.config["free_nats"],

View file

@ -18,23 +18,24 @@ class TestDreamer(unittest.TestCase):
def test_dreamer_compilation(self):
"""Test whether an DreamerTrainer can be built with all frameworks."""
config = dreamer.DEFAULT_CONFIG.copy()
config["env_config"] = {
"observation_space": Box(-1.0, 1.0, (3, 64, 64)),
"action_space": Box(-1.0, 1.0, (3,)),
}
config = dreamer.DREAMERConfig()
config.environment(
env=RandomEnv,
env_config={
"observation_space": Box(-1.0, 1.0, (3, 64, 64)),
"action_space": Box(-1.0, 1.0, (3,)),
},
)
# Num episode chunks per batch.
config["batch_size"] = 2
# Length (ts) of an episode chunk in a batch.
config["batch_length"] = 20
# Sub-iterations per .train() call.
config["dreamer_train_iters"] = 4
config.training(batch_size=2, batch_length=20, dreamer_train_iters=4)
num_iterations = 1
# Test against all frameworks.
for _ in framework_iterator(config, frameworks="torch"):
trainer = dreamer.DREAMERTrainer(config=config, env=RandomEnv)
trainer = config.build()
for i in range(num_iterations):
results = trainer.train()
print(results)

View file

@ -59,7 +59,7 @@ logger = logging.getLogger(__name__)
class ImpalaConfig(TrainerConfig):
"""Defines an ARSTrainer configuration class from which an ImpalaTrainer can be built.
"""Defines a configuration class from which an ImpalaTrainer can be built.
Example:
>>> from ray.rllib.agents.impala import ImpalaConfig
@ -136,13 +136,10 @@ class ImpalaConfig(TrainerConfig):
self.num_gpus = 1
self.lr = 0.0005
self.min_time_s_per_reporting = 10
# IMPALA and APPO are not on the new training_iteration API yet.
self._disable_execution_plan_api = False
# __sphinx_doc_end__
# fmt: on
# Deprecated value.
self._disable_execution_plan_api = True
self.num_data_loader_buffers = DEPRECATED_VALUE
@override(TrainerConfig)

View file

@ -159,6 +159,14 @@ RAY_CONFIG(int64_t, max_grpc_message_size, 250 * 1024 * 1024)
// of retries is non zero.
RAY_CONFIG(int64_t, grpc_server_retry_timeout_milliseconds, 1000)
// Whether to allow HTTP proxy on GRPC clients. Disable HTTP proxy by default since it
// disrupts local connections. Note that this config item only controls GrpcClient in
// `src/ray/rpc/grpc_client.h`. Python GRPC clients are not directly controlled by this.
// NOTE (kfstorm): DO NOT set this config item via `_system_config`, use
// `RAY_grpc_enable_http_proxy` environment variable instead so that it can be passed to
// non-C++ children processes such as dashboard agent.
RAY_CONFIG(bool, grpc_enable_http_proxy, false)
// The min number of retries for direct actor creation tasks. The actual number
// of creation retries will be MAX(actor_creation_min_retries, max_restarts).
RAY_CONFIG(uint64_t, actor_creation_min_retries, 3)

View file

@ -357,9 +357,8 @@ struct SyncerServerTest {
std::shared_ptr<grpc::Channel> MakeChannel(std::string port) {
grpc::ChannelArguments argument;
// Disable http proxy since it disrupts local connections. TODO(ekl) we should make
// this configurable, or selectively set it for known local connections only.
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY,
::RayConfig::instance().grpc_enable_http_proxy() ? 1 : 0);
argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());

View file

@ -116,9 +116,8 @@ int main(int argc, char *argv[]) {
}
if (leader_port != ".") {
grpc::ChannelArguments argument;
// Disable http proxy since it disrupts local connections. TODO(ekl) we should make
// this configurable, or selectively set it for known local connections only.
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY,
::RayConfig::instance().grpc_enable_http_proxy() ? 1 : 0);
argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());

View file

@ -29,6 +29,7 @@
#include "ray/gcs/gcs_server/gcs_worker_manager.h"
#include "ray/gcs/gcs_server/stats_handler_impl.h"
#include "ray/gcs/gcs_server/store_client_kv.h"
#include "ray/gcs/store_client/observable_store_client.h"
#include "ray/pubsub/publisher.h"
namespace ray {
@ -496,8 +497,9 @@ void GcsServer::InitKVManager() {
if (storage_type_ == "redis") {
instance = std::make_unique<RedisInternalKV>(GetRedisClientOptions());
} else if (storage_type_ == "memory") {
instance = std::make_unique<StoreClientInternalKV>(
std::make_unique<InMemoryStoreClient>(main_service_));
instance =
std::make_unique<StoreClientInternalKV>(std::make_unique<ObservableStoreClient>(
std::make_unique<InMemoryStoreClient>(main_service_)));
}
kv_manager_ = std::make_unique<GcsInternalKVManager>(std::move(instance));

View file

@ -19,6 +19,7 @@
#include "ray/common/asio/instrumented_io_context.h"
#include "ray/gcs/store_client/in_memory_store_client.h"
#include "ray/gcs/store_client/observable_store_client.h"
#include "ray/gcs/store_client/redis_store_client.h"
#include "src/ray/protobuf/gcs.pb.h"
@ -366,7 +367,8 @@ class RedisGcsTableStorage : public GcsTableStorage {
class InMemoryGcsTableStorage : public GcsTableStorage {
public:
explicit InMemoryGcsTableStorage(instrumented_io_context &main_io_service)
: GcsTableStorage(std::make_shared<InMemoryStoreClient>(main_io_service)) {}
: GcsTableStorage(std::make_shared<ObservableStoreClient>(
std::make_unique<InMemoryStoreClient>(main_io_service))) {}
};
} // namespace gcs

View file

@ -0,0 +1,163 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ray/gcs/store_client/observable_store_client.h"
#include "absl/time/time.h"
#include "ray/stats/metric_defs.h"
namespace ray {
namespace gcs {
using namespace ray::stats;
/// Forwards a Put to the wrapped store client, counting the call and recording
/// the time until the delegate invokes the completion callback.
///
/// \param table_name Forwarded unchanged to the delegate.
/// \param key Forwarded unchanged to the delegate.
/// \param data Forwarded unchanged to the delegate.
/// \param overwrite Forwarded unchanged to the delegate.
/// \param callback Optional; invoked with the delegate's result after the
/// latency has been recorded.
/// \return The status returned by the delegate.
Status ObservableStoreClient::AsyncPut(const std::string &table_name,
                                       const std::string &key,
                                       const std::string &data,
                                       bool overwrite,
                                       std::function<void(bool)> callback) {
  auto start = absl::GetCurrentTimeNanos();
  STATS_gcs_storage_operation_count.Record(1, "Put");
  return delegate_->AsyncPut(
      table_name,
      key,
      data,
      overwrite,
      [start, callback = std::move(callback)](auto result) {
        auto end = absl::GetCurrentTimeNanos();
        // Convert with floating-point precision. Dividing one absl::Duration by
        // another is a truncating integer division, which would report every
        // sub-millisecond operation as 0 ms.
        STATS_gcs_storage_operation_latency_ms.Record(
            absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Put");
        if (callback) {
          callback(std::move(result));
        }
      });
}
/// Forwards a Get to the wrapped store client, counting the call and recording
/// the time until the delegate invokes the completion callback.
///
/// \param table_name Forwarded unchanged to the delegate.
/// \param key Forwarded unchanged to the delegate.
/// \param callback Optional; invoked with the delegate's status and result
/// after the latency has been recorded.
/// \return The status returned by the delegate.
Status ObservableStoreClient::AsyncGet(
    const std::string &table_name,
    const std::string &key,
    const OptionalItemCallback<std::string> &callback) {
  auto start = absl::GetCurrentTimeNanos();
  STATS_gcs_storage_operation_count.Record(1, "Get");
  return delegate_->AsyncGet(
      table_name, key, [start, callback](auto status, auto result) {
        auto end = absl::GetCurrentTimeNanos();
        // Convert with floating-point precision. Dividing one absl::Duration by
        // another is a truncating integer division, which would report every
        // sub-millisecond operation as 0 ms.
        STATS_gcs_storage_operation_latency_ms.Record(
            absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Get");
        if (callback) {
          callback(status, std::move(result));
        }
      });
}
/// Forwards a GetAll to the wrapped store client, counting the call and
/// recording the time until the delegate invokes the completion callback.
///
/// \param table_name Forwarded unchanged to the delegate.
/// \param callback Optional; invoked with the delegate's result after the
/// latency has been recorded.
/// \return The status returned by the delegate.
Status ObservableStoreClient::AsyncGetAll(
    const std::string &table_name,
    const MapCallback<std::string, std::string> &callback) {
  auto start = absl::GetCurrentTimeNanos();
  STATS_gcs_storage_operation_count.Record(1, "GetAll");
  return delegate_->AsyncGetAll(table_name, [start, callback](auto result) {
    auto end = absl::GetCurrentTimeNanos();
    // Convert with floating-point precision. Dividing one absl::Duration by
    // another is a truncating integer division, which would report every
    // sub-millisecond operation as 0 ms.
    STATS_gcs_storage_operation_latency_ms.Record(
        absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "GetAll");
    if (callback) {
      callback(std::move(result));
    }
  });
}
/// Forwards a MultiGet to the wrapped store client, counting the call and
/// recording the time until the delegate invokes the completion callback.
///
/// \param table_name Forwarded unchanged to the delegate.
/// \param keys Forwarded unchanged to the delegate.
/// \param callback Optional; invoked with the delegate's result after the
/// latency has been recorded.
/// \return The status returned by the delegate.
Status ObservableStoreClient::AsyncMultiGet(
    const std::string &table_name,
    const std::vector<std::string> &keys,
    const MapCallback<std::string, std::string> &callback) {
  auto start = absl::GetCurrentTimeNanos();
  STATS_gcs_storage_operation_count.Record(1, "MultiGet");
  return delegate_->AsyncMultiGet(table_name, keys, [start, callback](auto result) {
    auto end = absl::GetCurrentTimeNanos();
    // Convert with floating-point precision. Dividing one absl::Duration by
    // another is a truncating integer division, which would report every
    // sub-millisecond operation as 0 ms.
    STATS_gcs_storage_operation_latency_ms.Record(
        absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "MultiGet");
    if (callback) {
      callback(std::move(result));
    }
  });
}
/// Forwards a Delete to the wrapped store client, counting the call and
/// recording the time until the delegate invokes the completion callback.
///
/// \param table_name Forwarded unchanged to the delegate.
/// \param key Forwarded unchanged to the delegate.
/// \param callback Optional; invoked with the delegate's result after the
/// latency has been recorded.
/// \return The status returned by the delegate.
Status ObservableStoreClient::AsyncDelete(const std::string &table_name,
                                          const std::string &key,
                                          std::function<void(bool)> callback) {
  auto start = absl::GetCurrentTimeNanos();
  STATS_gcs_storage_operation_count.Record(1, "Delete");
  return delegate_->AsyncDelete(
      table_name, key, [start, callback = std::move(callback)](auto result) {
        auto end = absl::GetCurrentTimeNanos();
        // Convert with floating-point precision. Dividing one absl::Duration by
        // another is a truncating integer division, which would report every
        // sub-millisecond operation as 0 ms.
        STATS_gcs_storage_operation_latency_ms.Record(
            absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Delete");
        if (callback) {
          callback(std::move(result));
        }
      });
}
/// Forwards a BatchDelete to the wrapped store client, counting the call and
/// recording the time until the delegate invokes the completion callback.
///
/// \param table_name Forwarded unchanged to the delegate.
/// \param keys Forwarded unchanged to the delegate.
/// \param callback Optional; invoked with the delegate's result after the
/// latency has been recorded.
/// \return The status returned by the delegate.
Status ObservableStoreClient::AsyncBatchDelete(const std::string &table_name,
                                               const std::vector<std::string> &keys,
                                               std::function<void(int64_t)> callback) {
  auto start = absl::GetCurrentTimeNanos();
  STATS_gcs_storage_operation_count.Record(1, "BatchDelete");
  return delegate_->AsyncBatchDelete(
      table_name, keys, [start, callback = std::move(callback)](auto result) {
        auto end = absl::GetCurrentTimeNanos();
        // Convert with floating-point precision. Dividing one absl::Duration by
        // another is a truncating integer division, which would report every
        // sub-millisecond operation as 0 ms.
        STATS_gcs_storage_operation_latency_ms.Record(
            absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)),
            "BatchDelete");
        if (callback) {
          callback(std::move(result));
        }
      });
}
/// Returns the next job ID from the wrapped store client. This call is a plain
/// pass-through: no operation count or latency metric is recorded for it.
int ObservableStoreClient::GetNextJobID() { return delegate_->GetNextJobID(); }
/// Forwards a GetKeys to the wrapped store client, counting the call and
/// recording the time until the delegate invokes the completion callback.
///
/// \param table_name Forwarded unchanged to the delegate.
/// \param prefix Forwarded unchanged to the delegate.
/// \param callback Optional; invoked with the delegate's result after the
/// latency has been recorded.
/// \return The status returned by the delegate.
Status ObservableStoreClient::AsyncGetKeys(
    const std::string &table_name,
    const std::string &prefix,
    std::function<void(std::vector<std::string>)> callback) {
  auto start = absl::GetCurrentTimeNanos();
  STATS_gcs_storage_operation_count.Record(1, "GetKeys");
  return delegate_->AsyncGetKeys(
      table_name, prefix, [start, callback = std::move(callback)](auto result) {
        auto end = absl::GetCurrentTimeNanos();
        // Convert with floating-point precision. Dividing one absl::Duration by
        // another is a truncating integer division, which would report every
        // sub-millisecond operation as 0 ms.
        STATS_gcs_storage_operation_latency_ms.Record(
            absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "GetKeys");
        if (callback) {
          callback(std::move(result));
        }
      });
}
/// Forwards an Exists check to the wrapped store client, counting the call and
/// recording the time until the delegate invokes the completion callback.
///
/// \param table_name Forwarded unchanged to the delegate.
/// \param key Forwarded unchanged to the delegate.
/// \param callback Optional; invoked with the delegate's result after the
/// latency has been recorded.
/// \return The status returned by the delegate.
Status ObservableStoreClient::AsyncExists(const std::string &table_name,
                                          const std::string &key,
                                          std::function<void(bool)> callback) {
  auto start = absl::GetCurrentTimeNanos();
  STATS_gcs_storage_operation_count.Record(1, "Exists");
  return delegate_->AsyncExists(
      table_name, key, [start, callback = std::move(callback)](auto result) {
        auto end = absl::GetCurrentTimeNanos();
        // Convert with floating-point precision. Dividing one absl::Duration by
        // another is a truncating integer division, which would report every
        // sub-millisecond operation as 0 ms.
        STATS_gcs_storage_operation_latency_ms.Record(
            absl::ToDoubleMilliseconds(absl::Nanoseconds(end - start)), "Exists");
        if (callback) {
          callback(std::move(result));
        }
      });
}
} // namespace gcs
} // namespace ray

View file

@ -0,0 +1,70 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "ray/gcs/store_client/store_client.h"
namespace ray {
namespace gcs {
/// Wraps a StoreClient instance and observes metrics for the operations that
/// are forwarded to it.
///
/// Every method overrides the corresponding StoreClient method and is
/// implemented by delegating to the wrapped client.
class ObservableStoreClient : public StoreClient {
 public:
  /// \param delegate The store client to wrap; ownership is transferred to this
  /// object.
  explicit ObservableStoreClient(std::unique_ptr<StoreClient> delegate)
      : delegate_(std::move(delegate)) {}

  Status AsyncPut(const std::string &table_name,
                  const std::string &key,
                  const std::string &data,
                  bool overwrite,
                  std::function<void(bool)> callback) override;

  Status AsyncGet(const std::string &table_name,
                  const std::string &key,
                  const OptionalItemCallback<std::string> &callback) override;

  Status AsyncGetAll(const std::string &table_name,
                     const MapCallback<std::string, std::string> &callback) override;

  Status AsyncMultiGet(const std::string &table_name,
                       const std::vector<std::string> &keys,
                       const MapCallback<std::string, std::string> &callback) override;

  Status AsyncDelete(const std::string &table_name,
                     const std::string &key,
                     std::function<void(bool)> callback) override;

  Status AsyncBatchDelete(const std::string &table_name,
                          const std::vector<std::string> &keys,
                          std::function<void(int64_t)> callback) override;

  int GetNextJobID() override;

  Status AsyncGetKeys(const std::string &table_name,
                      const std::string &prefix,
                      std::function<void(std::vector<std::string>)> callback) override;

  Status AsyncExists(const std::string &table_name,
                     const std::string &key,
                     std::function<void(bool)> callback) override;

 private:
  /// The wrapped store client that actually performs each operation.
  std::unique_ptr<StoreClient> delegate_;
};
} // namespace gcs
} // namespace ray

View file

@ -34,7 +34,6 @@ TEST_F(InMemoryStoreClientTest, AsyncPutAndAsyncGetTest) { TestAsyncPutAndAsyncG
TEST_F(InMemoryStoreClientTest, AsyncGetAllAndBatchDeleteTest) {
TestAsyncGetAllAndBatchDelete();
}
} // namespace gcs
} // namespace ray

View file

@ -0,0 +1,47 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ray/gcs/store_client/observable_store_client.h"
#include "ray/gcs/store_client/in_memory_store_client.h"
#include "ray/gcs/store_client/test/store_client_test_base.h"
namespace ray {
namespace gcs {
/// Test fixture that runs the StoreClientTestBase scenarios against an
/// ObservableStoreClient wrapping an in-memory store.
class ObservableStoreClientTest : public StoreClientTestBase {
 public:
  void InitStoreClient() override {
    // Build the in-memory backend first, then hand ownership to the wrapper.
    auto in_memory_client =
        std::make_unique<InMemoryStoreClient>(*(io_service_pool_->Get()));
    store_client_ =
        std::make_shared<ObservableStoreClient>(std::move(in_memory_client));
  }

  // Nothing to tear down here.
  void DisconnectStoreClient() override {}
};
// Runs the fixture's put/get scenario through the observable wrapper.
TEST_F(ObservableStoreClientTest, AsyncPutAndAsyncGetTest) { TestAsyncPutAndAsyncGet(); }

// Runs the fixture's get-all/batch-delete scenario through the observable wrapper.
TEST_F(ObservableStoreClientTest, AsyncGetAllAndBatchDeleteTest) {
  TestAsyncGetAllAndBatchDelete();
}
} // namespace gcs
} // namespace ray
// Test binary entry point: initializes googletest with the command-line
// arguments and returns the result of running all registered tests.
int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}

View file

@ -51,9 +51,8 @@ inline std::shared_ptr<grpc::Channel> BuildChannel(
std::optional<grpc::ChannelArguments> arguments = std::nullopt) {
if (!arguments.has_value()) {
arguments = grpc::ChannelArguments();
// Disable http proxy since it disrupts local connections. TODO(ekl) we should make
// this configurable, or selectively set it for known local connections only.
arguments->SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
arguments->SetInt(GRPC_ARG_ENABLE_HTTP_PROXY,
::RayConfig::instance().grpc_enable_http_proxy() ? 1 : 0);
arguments->SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
arguments->SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());
}
@ -112,7 +111,8 @@ class GrpcClient {
quota.SetMaxThreads(num_threads);
grpc::ChannelArguments argument;
argument.SetResourceQuota(quota);
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY, 0);
argument.SetInt(GRPC_ARG_ENABLE_HTTP_PROXY,
::RayConfig::instance().grpc_enable_http_proxy() ? 1 : 0);
argument.SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size());
argument.SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size());

View file

@ -184,6 +184,18 @@ DEFINE_stats(gcs_new_resource_creation_latency_ms,
({0.1, 1, 10, 100, 1000, 10000}, ),
ray::stats::HISTOGRAM);
/// GCS Storage
// Histogram of GCS storage operation latency; the `_ms` suffix indicates
// milliseconds. Tagged with "Operation" so that each kind of call (e.g. "Put",
// "Get") is reported separately. The listed values are presumably the
// histogram bucket boundaries -- confirm against the DEFINE_stats macro.
DEFINE_stats(gcs_storage_operation_latency_ms,
             "Time to invoke an operation on Gcs storage",
             ("Operation"),
             ({0.1, 1, 10, 100, 1000, 10000}, ),
             ray::stats::HISTOGRAM);

// Count of operations invoked on GCS storage, tagged with "Operation".
DEFINE_stats(gcs_storage_operation_count,
             "Number of operations invoked on Gcs storage",
             ("Operation"),
             (),
             ray::stats::COUNT);
/// Placement Group
// The end to end placement group creation latency.
// The time from placement group creation request has received

View file

@ -84,6 +84,10 @@ DECLARE_stats(spill_manager_objects_bytes);
DECLARE_stats(spill_manager_request_total);
DECLARE_stats(spill_manager_throughput_mb);
/// GCS Storage
DECLARE_stats(gcs_storage_operation_latency_ms);
DECLARE_stats(gcs_storage_operation_count);
/// GCS Resource Manager
DECLARE_stats(gcs_new_resource_creation_latency_ms);