[RLlib] Add timeout to filter synchronization. (#25959)

This commit is contained in:
Artur Niederfahrenhorst 2022-06-24 14:37:43 +02:00 committed by GitHub
parent eee866d762
commit bed9083f35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 13 deletions

View file

@ -612,7 +612,12 @@ class Algorithm(Trainable):
if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
# Sync filters on workers.
self._sync_filters_if_needed(self.workers)
self._sync_filters_if_needed(
self.workers,
timeout_seconds=self.config[
"sync_filters_on_rollout_workers_timeout_s"
],
)
# Collect worker metrics and add combine them with `results`.
if self.config["_disable_execution_plan_api"]:
@ -674,7 +679,12 @@ class Algorithm(Trainable):
self.evaluation_workers.sync_weights(
from_worker=self.workers.local_worker()
)
self._sync_filters_if_needed(self.evaluation_workers)
self._sync_filters_if_needed(
self.evaluation_workers,
timeout_seconds=self.config[
"sync_filters_on_rollout_workers_timeout_s"
],
)
if self.config["custom_eval_function"]:
logger.info(
@ -1597,12 +1607,15 @@ class Algorithm(Trainable):
'(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
)
def _sync_filters_if_needed(self, workers: WorkerSet):
def _sync_filters_if_needed(
self, workers: WorkerSet, timeout_seconds: Optional[float] = None
):
if self.config.get("observation_filter", "NoFilter") != "NoFilter":
FilterManager.synchronize(
workers.local_worker().filters,
workers.remote_workers(),
update_remote=self.config["synchronize_filters"],
timeout_seconds=timeout_seconds,
)
logger.debug(
"synchronized filters: {}".format(workers.local_worker().filters)
@ -2273,6 +2286,7 @@ class Algorithm(Trainable):
ignore=self.config["ignore_worker_failures"],
recreate=self.config["recreate_failed_workers"],
)
return results, train_iter_ctx
def _run_one_evaluation(

View file

@ -189,6 +189,7 @@ class AlgorithmConfig:
# TODO: Set this flag still in the config or - much better - in the
# RolloutWorker as a property.
self.in_evaluation = False
self.sync_filters_on_rollout_workers_timeout_s = 60.0
# `self.reporting()`
self.keep_per_episode_custom_metrics = False

View file

@ -64,9 +64,6 @@ class TestTrajectoryViewAPI(unittest.TestCase):
policy = algo.get_policy()
view_req_model = policy.model.view_requirements
view_req_policy = policy.view_requirements
print(_)
print(view_req_policy)
print(view_req_model)
assert len(view_req_model) == 1, view_req_model
assert len(view_req_policy) == 11, view_req_policy
for key in [
@ -321,8 +318,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
normalize_actions=False,
num_envs=1,
)
batch = rollout_worker_w_api.sample()
print(batch)
batch = rollout_worker_w_api.sample() # noqa: F841
def test_counting_by_agent_steps(self):
config = copy.deepcopy(ppo.DEFAULT_CONFIG)

View file

@ -1,5 +1,11 @@
import logging
from typing import Optional
import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.exceptions import GetTimeoutError
logger = logging.getLogger(__name__)
@DeveloperAPI
@ -10,7 +16,12 @@ class FilterManager:
@staticmethod
@DeveloperAPI
def synchronize(local_filters, remotes, update_remote=True):
def synchronize(
local_filters,
remotes,
update_remote=True,
timeout_seconds: Optional[float] = None,
):
"""Aggregates all filters from remote evaluators.
Local copy is updated and then broadcasted to all remote evaluators.
@ -19,14 +30,37 @@ class FilterManager:
local_filters: Filters to be synchronized.
remotes: Remote evaluators with filters.
update_remote: Whether to push updates to remote filters.
timeout_seconds: How long to wait for filter to get or set filters
"""
remote_filters = ray.get(
[r.get_filters.remote(flush_after=True) for r in remotes]
)
try:
remote_filters = ray.get(
[r.get_filters.remote(flush_after=True) for r in remotes],
timeout=timeout_seconds,
)
except GetTimeoutError:
logger.error(
"Failed to get remote filters from a rollout worker in "
"FilterManager. "
"Filtered "
"metrics may be computed, but filtered wrong."
)
for rf in remote_filters:
for k in local_filters:
local_filters[k].apply_changes(rf[k], with_buffer=False)
if update_remote:
copies = {k: v.as_serializable() for k, v in local_filters.items()}
remote_copy = ray.put(copies)
[r.sync_filters.remote(remote_copy) for r in remotes]
try:
ray.get(
[r.sync_filters.remote(remote_copy) for r in remotes],
timeout=timeout_seconds,
)
except GetTimeoutError:
logger.error(
"Failed to set remote filters to a rollout worker in "
"FilterManager. "
"Filtered "
"metrics may be computed, but filtered wrong."
)