From b1bc435adce63649d7af1722ca8514df942cc397 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Sun, 15 May 2022 17:25:25 +0200 Subject: [PATCH] [RLlib] Policy Server/Client metrics reporting fix (#24783) --- rllib/env/policy_client.py | 26 +++++---- rllib/env/policy_server_input.py | 92 +++++++++++++++++++++++--------- 2 files changed, 82 insertions(+), 36 deletions(-) diff --git a/rllib/env/policy_client.py b/rllib/env/policy_client.py index 31478eb60..e7fb3f273 100644 --- a/rllib/env/policy_client.py +++ b/rllib/env/policy_client.py @@ -8,6 +8,7 @@ import logging import threading import time from typing import Union, Optional +from enum import Enum import ray.cloudpickle as pickle from ray.rllib.env.external_env import ExternalEnv @@ -36,9 +37,7 @@ except ImportError: @PublicAPI -class PolicyClient: - """REST client to interact with a RLlib policy server.""" - +class Commands(Enum): # Generic commands (for both modes). ACTION_SPACE = "ACTION_SPACE" OBSERVATION_SPACE = "OBSERVATION_SPACE" @@ -55,6 +54,11 @@ class PolicyClient: LOG_RETURNS = "LOG_RETURNS" END_EPISODE = "END_EPISODE" + +@PublicAPI +class PolicyClient: + """REST client to interact with an RLlib policy server.""" + @PublicAPI def __init__( self, address: str, inference_mode: str = "local", update_interval: float = 10.0 @@ -102,7 +106,7 @@ class PolicyClient: return self._send( { "episode_id": episode_id, - "command": PolicyClient.START_EPISODE, + "command": Commands.START_EPISODE, "training_enabled": training_enabled, } )["episode_id"] @@ -134,7 +138,7 @@ class PolicyClient: else: return self._send( { - "command": PolicyClient.GET_ACTION, + "command": Commands.GET_ACTION, "observation": observation, "episode_id": episode_id, } @@ -161,7 +165,7 @@ class PolicyClient: self._send( { - "command": PolicyClient.LOG_ACTION, + "command": Commands.LOG_ACTION, "observation": observation, "action": action, "episode_id": episode_id, @@ -200,7 +204,7 @@ class PolicyClient: self._send( { - "command": PolicyClient.LOG_RETURNS, + "command": Commands.LOG_RETURNS, "reward": reward, "info": info, "episode_id": episode_id, @@ -225,7 +229,7 @@ class PolicyClient: self._send( { - "command": PolicyClient.END_EPISODE, + "command": Commands.END_EPISODE, "observation": observation, "episode_id": episode_id, } @@ -252,7 +256,7 @@ class PolicyClient: logger.info("Querying server for rollout worker settings.") kwargs = self._send( { - "command": PolicyClient.GET_WORKER_ARGS, + "command": Commands.GET_WORKER_ARGS, } )["worker_args"] (self.rollout_worker, self.inference_thread) = _create_embedded_rollout_worker( @@ -269,7 +273,7 @@ class PolicyClient: logger.info("Querying server for new policy weights.") resp = self._send( { - "command": PolicyClient.GET_WEIGHTS, + "command": Commands.GET_WEIGHTS, } ) weights = resp["weights"] @@ -311,7 +315,7 @@ class _LocalInferenceThread(threading.Thread): ) self.send_fn( { - "command": PolicyClient.REPORT_SAMPLES, + "command": Commands.REPORT_SAMPLES, "samples": samples, "metrics": metrics, } diff --git a/rllib/env/policy_server_input.py b/rllib/env/policy_server_input.py index 2ae632a86..53f0a3923 100644 --- a/rllib/env/policy_server_input.py +++ b/rllib/env/policy_server_input.py @@ -6,11 +6,18 @@ import threading import time import traceback +from typing import List import ray.cloudpickle as pickle -from ray.rllib.env.policy_client import PolicyClient, _create_embedded_rollout_worker +from ray.rllib.env.policy_client import ( + _create_embedded_rollout_worker, + Commands, +) from ray.rllib.offline.input_reader import InputReader from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.evaluation.metrics import RolloutMetrics +from ray.rllib.evaluation.sampler import SamplerInput +from ray.rllib.utils.typing import SampleBatchType logger = logging.getLogger(__name__) @@ -77,22 +84,56 @@ class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader): self.metrics_queue = queue.Queue() self.idle_timeout = idle_timeout - def get_metrics(): - completed = [] - while True: - try: - completed.append(self.metrics_queue.get_nowait()) - except queue.Empty: - break - return completed - - # Forwards client-reported rewards directly into the local rollout - # worker. This is a bit of a hack since it is patching the get_metrics - # function of the sampler. + # Forwards client-reported metrics directly into the local rollout + # worker. if self.rollout_worker.sampler is not None: - self.rollout_worker.sampler.get_metrics = get_metrics + # This is a bit of a hack since it is patching the get_metrics + # function of the sampler. - # Create a request handler that receives commands from the clients + def get_metrics(): + completed = [] + while True: + try: + completed.append(self.metrics_queue.get_nowait()) + except queue.Empty: + break + + return completed + + self.rollout_worker.sampler.get_metrics = get_metrics + else: + # If there is no sampler, act like if there would be one to collect + # metrics from + class MetricsDummySampler(SamplerInput): + """This sampler only maintains a queue to get metrics from.""" + + def __init__(self, metrics_queue): + """Initializes an AsyncSampler instance. + + Args: + metrics_queue: A queue of metrics + """ + self.metrics_queue = metrics_queue + + def get_data(self) -> SampleBatchType: + raise NotImplementedError + + def get_extra_batches(self) -> List[SampleBatchType]: + raise NotImplementedError + + def get_metrics(self) -> List[RolloutMetrics]: + """Returns metrics computed on a policy client rollout worker.""" + completed = [] + while True: + try: + completed.append(self.metrics_queue.get_nowait()) + except queue.Empty: + break + return completed + + self.rollout_worker.sampler = MetricsDummySampler(self.metrics_queue) + + # Create a request handler that receives commands from the clients # and sends data and metrics into the queues. handler = _make_handler( self.rollout_worker, self.samples_queue, self.metrics_queue @@ -153,10 +194,11 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue): def setup_child_rollout_worker(): nonlocal lock - nonlocal child_rollout_worker - nonlocal inference_thread with lock: + nonlocal child_rollout_worker + nonlocal inference_thread + if child_rollout_worker is None: ( child_rollout_worker, @@ -201,14 +243,14 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue): response = {} # Local inference commands: - if command == PolicyClient.GET_WORKER_ARGS: + if command == Commands.GET_WORKER_ARGS: logger.info("Sending worker creation args to client.") response["worker_args"] = rollout_worker.creation_args() - elif command == PolicyClient.GET_WEIGHTS: + elif command == Commands.GET_WEIGHTS: logger.info("Sending worker weights to client.") response["weights"] = rollout_worker.get_weights() response["global_vars"] = rollout_worker.get_global_vars() - elif command == PolicyClient.REPORT_SAMPLES: + elif command == Commands.REPORT_SAMPLES: logger.info( "Got sample batch of size {} from client.".format( args["samples"].count @@ -217,23 +259,23 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue): report_data(args) # Remote inference commands: - elif command == PolicyClient.START_EPISODE: + elif command == Commands.START_EPISODE: setup_child_rollout_worker() assert inference_thread.is_alive() response["episode_id"] = child_rollout_worker.env.start_episode( args["episode_id"], args["training_enabled"] ) - elif command == PolicyClient.GET_ACTION: + elif command == Commands.GET_ACTION: assert inference_thread.is_alive() response["action"] = child_rollout_worker.env.get_action( args["episode_id"], args["observation"] ) - elif command == PolicyClient.LOG_ACTION: + elif command == Commands.LOG_ACTION: assert inference_thread.is_alive() child_rollout_worker.env.log_action( args["episode_id"], args["observation"], args["action"] ) - elif command == PolicyClient.LOG_RETURNS: + elif command == Commands.LOG_RETURNS: assert inference_thread.is_alive() if args["done"]: child_rollout_worker.env.log_returns( @@ -243,7 +285,7 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue): child_rollout_worker.env.log_returns( args["episode_id"], args["reward"], args["info"] ) - elif command == PolicyClient.END_EPISODE: + elif command == Commands.END_EPISODE: assert inference_thread.is_alive() child_rollout_worker.env.end_episode( args["episode_id"], args["observation"]