[RLlib] Policy Server/Client metrics reporting fix (#24783)

This commit is contained in:
Artur Niederfahrenhorst 2022-05-15 17:25:25 +02:00 committed by GitHub
parent 6321c3a85c
commit b1bc435adc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 36 deletions

View file

@ -8,6 +8,7 @@ import logging
import threading
import time
from typing import Union, Optional
from enum import Enum
import ray.cloudpickle as pickle
from ray.rllib.env.external_env import ExternalEnv
@ -36,9 +37,7 @@ except ImportError:
@PublicAPI
class PolicyClient:
"""REST client to interact with a RLlib policy server."""
class Commands(Enum):
# Generic commands (for both modes).
ACTION_SPACE = "ACTION_SPACE"
OBSERVATION_SPACE = "OBSERVATION_SPACE"
@ -55,6 +54,11 @@ class PolicyClient:
LOG_RETURNS = "LOG_RETURNS"
END_EPISODE = "END_EPISODE"
@PublicAPI
class PolicyClient:
"""REST client to interact with an RLlib policy server."""
@PublicAPI
def __init__(
self, address: str, inference_mode: str = "local", update_interval: float = 10.0
@ -102,7 +106,7 @@ class PolicyClient:
return self._send(
{
"episode_id": episode_id,
"command": PolicyClient.START_EPISODE,
"command": Commands.START_EPISODE,
"training_enabled": training_enabled,
}
)["episode_id"]
@ -134,7 +138,7 @@ class PolicyClient:
else:
return self._send(
{
"command": PolicyClient.GET_ACTION,
"command": Commands.GET_ACTION,
"observation": observation,
"episode_id": episode_id,
}
@ -161,7 +165,7 @@ class PolicyClient:
self._send(
{
"command": PolicyClient.LOG_ACTION,
"command": Commands.LOG_ACTION,
"observation": observation,
"action": action,
"episode_id": episode_id,
@ -200,7 +204,7 @@ class PolicyClient:
self._send(
{
"command": PolicyClient.LOG_RETURNS,
"command": Commands.LOG_RETURNS,
"reward": reward,
"info": info,
"episode_id": episode_id,
@ -225,7 +229,7 @@ class PolicyClient:
self._send(
{
"command": PolicyClient.END_EPISODE,
"command": Commands.END_EPISODE,
"observation": observation,
"episode_id": episode_id,
}
@ -252,7 +256,7 @@ class PolicyClient:
logger.info("Querying server for rollout worker settings.")
kwargs = self._send(
{
"command": PolicyClient.GET_WORKER_ARGS,
"command": Commands.GET_WORKER_ARGS,
}
)["worker_args"]
(self.rollout_worker, self.inference_thread) = _create_embedded_rollout_worker(
@ -269,7 +273,7 @@ class PolicyClient:
logger.info("Querying server for new policy weights.")
resp = self._send(
{
"command": PolicyClient.GET_WEIGHTS,
"command": Commands.GET_WEIGHTS,
}
)
weights = resp["weights"]
@ -311,7 +315,7 @@ class _LocalInferenceThread(threading.Thread):
)
self.send_fn(
{
"command": PolicyClient.REPORT_SAMPLES,
"command": Commands.REPORT_SAMPLES,
"samples": samples,
"metrics": metrics,
}

View file

@ -6,11 +6,18 @@ import threading
import time
import traceback
from typing import List
import ray.cloudpickle as pickle
from ray.rllib.env.policy_client import PolicyClient, _create_embedded_rollout_worker
from ray.rllib.env.policy_client import (
_create_embedded_rollout_worker,
Commands,
)
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.evaluation.sampler import SamplerInput
from ray.rllib.utils.typing import SampleBatchType
logger = logging.getLogger(__name__)
@ -77,7 +84,45 @@ class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader):
self.metrics_queue = queue.Queue()
self.idle_timeout = idle_timeout
# Forwards client-reported metrics directly into the local rollout
# worker.
if self.rollout_worker.sampler is not None:
# This is a bit of a hack since it is patching the get_metrics
# function of the sampler.
def get_metrics():
completed = []
while True:
try:
completed.append(self.metrics_queue.get_nowait())
except queue.Empty:
break
return completed
self.rollout_worker.sampler.get_metrics = get_metrics
else:
# If there is no sampler, act as if there were one to collect
# metrics from.
class MetricsDummySampler(SamplerInput):
"""This sampler only maintains a queue to get metrics from."""
def __init__(self, metrics_queue):
"""Initializes a MetricsDummySampler instance.
Args:
metrics_queue: A queue of metrics
"""
self.metrics_queue = metrics_queue
def get_data(self) -> SampleBatchType:
raise NotImplementedError
def get_extra_batches(self) -> List[SampleBatchType]:
raise NotImplementedError
def get_metrics(self) -> List[RolloutMetrics]:
"""Returns metrics computed on a policy client rollout worker."""
completed = []
while True:
try:
@ -86,11 +131,7 @@ class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader):
break
return completed
# Forwards client-reported rewards directly into the local rollout
# worker. This is a bit of a hack since it is patching the get_metrics
# function of the sampler.
if self.rollout_worker.sampler is not None:
self.rollout_worker.sampler.get_metrics = get_metrics
self.rollout_worker.sampler = MetricsDummySampler(self.metrics_queue)
# Create a request handler that receives commands from the clients
# and sends data and metrics into the queues.
@ -153,10 +194,11 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue):
def setup_child_rollout_worker():
nonlocal lock
with lock:
nonlocal child_rollout_worker
nonlocal inference_thread
with lock:
if child_rollout_worker is None:
(
child_rollout_worker,
@ -201,14 +243,14 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue):
response = {}
# Local inference commands:
if command == PolicyClient.GET_WORKER_ARGS:
if command == Commands.GET_WORKER_ARGS:
logger.info("Sending worker creation args to client.")
response["worker_args"] = rollout_worker.creation_args()
elif command == PolicyClient.GET_WEIGHTS:
elif command == Commands.GET_WEIGHTS:
logger.info("Sending worker weights to client.")
response["weights"] = rollout_worker.get_weights()
response["global_vars"] = rollout_worker.get_global_vars()
elif command == PolicyClient.REPORT_SAMPLES:
elif command == Commands.REPORT_SAMPLES:
logger.info(
"Got sample batch of size {} from client.".format(
args["samples"].count
@ -217,23 +259,23 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue):
report_data(args)
# Remote inference commands:
elif command == PolicyClient.START_EPISODE:
elif command == Commands.START_EPISODE:
setup_child_rollout_worker()
assert inference_thread.is_alive()
response["episode_id"] = child_rollout_worker.env.start_episode(
args["episode_id"], args["training_enabled"]
)
elif command == PolicyClient.GET_ACTION:
elif command == Commands.GET_ACTION:
assert inference_thread.is_alive()
response["action"] = child_rollout_worker.env.get_action(
args["episode_id"], args["observation"]
)
elif command == PolicyClient.LOG_ACTION:
elif command == Commands.LOG_ACTION:
assert inference_thread.is_alive()
child_rollout_worker.env.log_action(
args["episode_id"], args["observation"], args["action"]
)
elif command == PolicyClient.LOG_RETURNS:
elif command == Commands.LOG_RETURNS:
assert inference_thread.is_alive()
if args["done"]:
child_rollout_worker.env.log_returns(
@ -243,7 +285,7 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue):
child_rollout_worker.env.log_returns(
args["episode_id"], args["reward"], args["info"]
)
elif command == PolicyClient.END_EPISODE:
elif command == Commands.END_EPISODE:
assert inference_thread.is_alive()
child_rollout_worker.env.end_episode(
args["episode_id"], args["observation"]