mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Policy Server/Client metrics reporting fix (#24783)
This commit is contained in:
parent
6321c3a85c
commit
b1bc435adc
2 changed files with 82 additions and 36 deletions
26
rllib/env/policy_client.py
vendored
26
rllib/env/policy_client.py
vendored
|
@ -8,6 +8,7 @@ import logging
|
|||
import threading
|
||||
import time
|
||||
from typing import Union, Optional
|
||||
from enum import Enum
|
||||
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
|
@ -36,9 +37,7 @@ except ImportError:
|
|||
|
||||
|
||||
@PublicAPI
|
||||
class PolicyClient:
|
||||
"""REST client to interact with a RLlib policy server."""
|
||||
|
||||
class Commands(Enum):
|
||||
# Generic commands (for both modes).
|
||||
ACTION_SPACE = "ACTION_SPACE"
|
||||
OBSERVATION_SPACE = "OBSERVATION_SPACE"
|
||||
|
@ -55,6 +54,11 @@ class PolicyClient:
|
|||
LOG_RETURNS = "LOG_RETURNS"
|
||||
END_EPISODE = "END_EPISODE"
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class PolicyClient:
|
||||
"""REST client to interact with an RLlib policy server."""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(
|
||||
self, address: str, inference_mode: str = "local", update_interval: float = 10.0
|
||||
|
@ -102,7 +106,7 @@ class PolicyClient:
|
|||
return self._send(
|
||||
{
|
||||
"episode_id": episode_id,
|
||||
"command": PolicyClient.START_EPISODE,
|
||||
"command": Commands.START_EPISODE,
|
||||
"training_enabled": training_enabled,
|
||||
}
|
||||
)["episode_id"]
|
||||
|
@ -134,7 +138,7 @@ class PolicyClient:
|
|||
else:
|
||||
return self._send(
|
||||
{
|
||||
"command": PolicyClient.GET_ACTION,
|
||||
"command": Commands.GET_ACTION,
|
||||
"observation": observation,
|
||||
"episode_id": episode_id,
|
||||
}
|
||||
|
@ -161,7 +165,7 @@ class PolicyClient:
|
|||
|
||||
self._send(
|
||||
{
|
||||
"command": PolicyClient.LOG_ACTION,
|
||||
"command": Commands.LOG_ACTION,
|
||||
"observation": observation,
|
||||
"action": action,
|
||||
"episode_id": episode_id,
|
||||
|
@ -200,7 +204,7 @@ class PolicyClient:
|
|||
|
||||
self._send(
|
||||
{
|
||||
"command": PolicyClient.LOG_RETURNS,
|
||||
"command": Commands.LOG_RETURNS,
|
||||
"reward": reward,
|
||||
"info": info,
|
||||
"episode_id": episode_id,
|
||||
|
@ -225,7 +229,7 @@ class PolicyClient:
|
|||
|
||||
self._send(
|
||||
{
|
||||
"command": PolicyClient.END_EPISODE,
|
||||
"command": Commands.END_EPISODE,
|
||||
"observation": observation,
|
||||
"episode_id": episode_id,
|
||||
}
|
||||
|
@ -252,7 +256,7 @@ class PolicyClient:
|
|||
logger.info("Querying server for rollout worker settings.")
|
||||
kwargs = self._send(
|
||||
{
|
||||
"command": PolicyClient.GET_WORKER_ARGS,
|
||||
"command": Commands.GET_WORKER_ARGS,
|
||||
}
|
||||
)["worker_args"]
|
||||
(self.rollout_worker, self.inference_thread) = _create_embedded_rollout_worker(
|
||||
|
@ -269,7 +273,7 @@ class PolicyClient:
|
|||
logger.info("Querying server for new policy weights.")
|
||||
resp = self._send(
|
||||
{
|
||||
"command": PolicyClient.GET_WEIGHTS,
|
||||
"command": Commands.GET_WEIGHTS,
|
||||
}
|
||||
)
|
||||
weights = resp["weights"]
|
||||
|
@ -311,7 +315,7 @@ class _LocalInferenceThread(threading.Thread):
|
|||
)
|
||||
self.send_fn(
|
||||
{
|
||||
"command": PolicyClient.REPORT_SAMPLES,
|
||||
"command": Commands.REPORT_SAMPLES,
|
||||
"samples": samples,
|
||||
"metrics": metrics,
|
||||
}
|
||||
|
|
92
rllib/env/policy_server_input.py
vendored
92
rllib/env/policy_server_input.py
vendored
|
@ -6,11 +6,18 @@ import threading
|
|||
import time
|
||||
import traceback
|
||||
|
||||
from typing import List
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.rllib.env.policy_client import PolicyClient, _create_embedded_rollout_worker
|
||||
from ray.rllib.env.policy_client import (
|
||||
_create_embedded_rollout_worker,
|
||||
Commands,
|
||||
)
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.evaluation.metrics import RolloutMetrics
|
||||
from ray.rllib.evaluation.sampler import SamplerInput
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -77,22 +84,56 @@ class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader):
|
|||
self.metrics_queue = queue.Queue()
|
||||
self.idle_timeout = idle_timeout
|
||||
|
||||
def get_metrics():
|
||||
completed = []
|
||||
while True:
|
||||
try:
|
||||
completed.append(self.metrics_queue.get_nowait())
|
||||
except queue.Empty:
|
||||
break
|
||||
return completed
|
||||
|
||||
# Forwards client-reported rewards directly into the local rollout
|
||||
# worker. This is a bit of a hack since it is patching the get_metrics
|
||||
# function of the sampler.
|
||||
# Forwards client-reported metrics directly into the local rollout
|
||||
# worker.
|
||||
if self.rollout_worker.sampler is not None:
|
||||
self.rollout_worker.sampler.get_metrics = get_metrics
|
||||
# This is a bit of a hack since it is patching the get_metrics
|
||||
# function of the sampler.
|
||||
|
||||
# Create a request handler that receives commands from the clients
|
||||
def get_metrics():
|
||||
completed = []
|
||||
while True:
|
||||
try:
|
||||
completed.append(self.metrics_queue.get_nowait())
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
return completed
|
||||
|
||||
self.rollout_worker.sampler.get_metrics = get_metrics
|
||||
else:
|
||||
# If there is no sampler, act as if there were one to collect
|
||||
# metrics from
|
||||
class MetricsDummySampler(SamplerInput):
|
||||
"""This sampler only maintains a queue to get metrics from."""
|
||||
|
||||
def __init__(self, metrics_queue):
|
||||
"""Initializes a MetricsDummySampler instance.
|
||||
|
||||
Args:
|
||||
metrics_queue: A queue of metrics
|
||||
"""
|
||||
self.metrics_queue = metrics_queue
|
||||
|
||||
def get_data(self) -> SampleBatchType:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_extra_batches(self) -> List[SampleBatchType]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_metrics(self) -> List[RolloutMetrics]:
|
||||
"""Returns metrics computed on a policy client rollout worker."""
|
||||
completed = []
|
||||
while True:
|
||||
try:
|
||||
completed.append(self.metrics_queue.get_nowait())
|
||||
except queue.Empty:
|
||||
break
|
||||
return completed
|
||||
|
||||
self.rollout_worker.sampler = MetricsDummySampler(self.metrics_queue)
|
||||
|
||||
# Create a request handler that receives commands from the clients
|
||||
# and sends data and metrics into the queues.
|
||||
handler = _make_handler(
|
||||
self.rollout_worker, self.samples_queue, self.metrics_queue
|
||||
|
@ -153,10 +194,11 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue):
|
|||
|
||||
def setup_child_rollout_worker():
|
||||
nonlocal lock
|
||||
nonlocal child_rollout_worker
|
||||
nonlocal inference_thread
|
||||
|
||||
with lock:
|
||||
nonlocal child_rollout_worker
|
||||
nonlocal inference_thread
|
||||
|
||||
if child_rollout_worker is None:
|
||||
(
|
||||
child_rollout_worker,
|
||||
|
@ -201,14 +243,14 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue):
|
|||
response = {}
|
||||
|
||||
# Local inference commands:
|
||||
if command == PolicyClient.GET_WORKER_ARGS:
|
||||
if command == Commands.GET_WORKER_ARGS:
|
||||
logger.info("Sending worker creation args to client.")
|
||||
response["worker_args"] = rollout_worker.creation_args()
|
||||
elif command == PolicyClient.GET_WEIGHTS:
|
||||
elif command == Commands.GET_WEIGHTS:
|
||||
logger.info("Sending worker weights to client.")
|
||||
response["weights"] = rollout_worker.get_weights()
|
||||
response["global_vars"] = rollout_worker.get_global_vars()
|
||||
elif command == PolicyClient.REPORT_SAMPLES:
|
||||
elif command == Commands.REPORT_SAMPLES:
|
||||
logger.info(
|
||||
"Got sample batch of size {} from client.".format(
|
||||
args["samples"].count
|
||||
|
@ -217,23 +259,23 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue):
|
|||
report_data(args)
|
||||
|
||||
# Remote inference commands:
|
||||
elif command == PolicyClient.START_EPISODE:
|
||||
elif command == Commands.START_EPISODE:
|
||||
setup_child_rollout_worker()
|
||||
assert inference_thread.is_alive()
|
||||
response["episode_id"] = child_rollout_worker.env.start_episode(
|
||||
args["episode_id"], args["training_enabled"]
|
||||
)
|
||||
elif command == PolicyClient.GET_ACTION:
|
||||
elif command == Commands.GET_ACTION:
|
||||
assert inference_thread.is_alive()
|
||||
response["action"] = child_rollout_worker.env.get_action(
|
||||
args["episode_id"], args["observation"]
|
||||
)
|
||||
elif command == PolicyClient.LOG_ACTION:
|
||||
elif command == Commands.LOG_ACTION:
|
||||
assert inference_thread.is_alive()
|
||||
child_rollout_worker.env.log_action(
|
||||
args["episode_id"], args["observation"], args["action"]
|
||||
)
|
||||
elif command == PolicyClient.LOG_RETURNS:
|
||||
elif command == Commands.LOG_RETURNS:
|
||||
assert inference_thread.is_alive()
|
||||
if args["done"]:
|
||||
child_rollout_worker.env.log_returns(
|
||||
|
@ -243,7 +285,7 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue):
|
|||
child_rollout_worker.env.log_returns(
|
||||
args["episode_id"], args["reward"], args["info"]
|
||||
)
|
||||
elif command == PolicyClient.END_EPISODE:
|
||||
elif command == Commands.END_EPISODE:
|
||||
assert inference_thread.is_alive()
|
||||
child_rollout_worker.env.end_episode(
|
||||
args["episode_id"], args["observation"]
|
||||
|
|
Loading…
Add table
Reference in a new issue