mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Add timeout to filter synchronization. (#25959)
This commit is contained in:
parent
eee866d762
commit
bed9083f35
4 changed files with 58 additions and 13 deletions
|
@ -612,7 +612,12 @@ class Algorithm(Trainable):
|
|||
|
||||
if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
|
||||
# Sync filters on workers.
|
||||
self._sync_filters_if_needed(self.workers)
|
||||
self._sync_filters_if_needed(
|
||||
self.workers,
|
||||
timeout_seconds=self.config[
|
||||
"sync_filters_on_rollout_workers_timeout_s"
|
||||
],
|
||||
)
|
||||
|
||||
# Collect worker metrics and combine them with `results`.
|
||||
if self.config["_disable_execution_plan_api"]:
|
||||
|
@ -674,7 +679,12 @@ class Algorithm(Trainable):
|
|||
self.evaluation_workers.sync_weights(
|
||||
from_worker=self.workers.local_worker()
|
||||
)
|
||||
self._sync_filters_if_needed(self.evaluation_workers)
|
||||
self._sync_filters_if_needed(
|
||||
self.evaluation_workers,
|
||||
timeout_seconds=self.config[
|
||||
"sync_filters_on_rollout_workers_timeout_s"
|
||||
],
|
||||
)
|
||||
|
||||
if self.config["custom_eval_function"]:
|
||||
logger.info(
|
||||
|
@ -1597,12 +1607,15 @@ class Algorithm(Trainable):
|
|||
'(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
|
||||
)
|
||||
|
||||
def _sync_filters_if_needed(self, workers: WorkerSet):
|
||||
def _sync_filters_if_needed(
|
||||
self, workers: WorkerSet, timeout_seconds: Optional[float] = None
|
||||
):
|
||||
if self.config.get("observation_filter", "NoFilter") != "NoFilter":
|
||||
FilterManager.synchronize(
|
||||
workers.local_worker().filters,
|
||||
workers.remote_workers(),
|
||||
update_remote=self.config["synchronize_filters"],
|
||||
timeout_seconds=timeout_seconds,
|
||||
)
|
||||
logger.debug(
|
||||
"synchronized filters: {}".format(workers.local_worker().filters)
|
||||
|
@ -2273,6 +2286,7 @@ class Algorithm(Trainable):
|
|||
ignore=self.config["ignore_worker_failures"],
|
||||
recreate=self.config["recreate_failed_workers"],
|
||||
)
|
||||
|
||||
return results, train_iter_ctx
|
||||
|
||||
def _run_one_evaluation(
|
||||
|
|
|
@ -189,6 +189,7 @@ class AlgorithmConfig:
|
|||
# TODO: Set this flag still in the config or - much better - in the
|
||||
# RolloutWorker as a property.
|
||||
self.in_evaluation = False
|
||||
self.sync_filters_on_rollout_workers_timeout_s = 60.0
|
||||
|
||||
# `self.reporting()`
|
||||
self.keep_per_episode_custom_metrics = False
|
||||
|
|
|
@ -64,9 +64,6 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
|||
policy = algo.get_policy()
|
||||
view_req_model = policy.model.view_requirements
|
||||
view_req_policy = policy.view_requirements
|
||||
print(_)
|
||||
print(view_req_policy)
|
||||
print(view_req_model)
|
||||
assert len(view_req_model) == 1, view_req_model
|
||||
assert len(view_req_policy) == 11, view_req_policy
|
||||
for key in [
|
||||
|
@ -321,8 +318,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
|||
normalize_actions=False,
|
||||
num_envs=1,
|
||||
)
|
||||
batch = rollout_worker_w_api.sample()
|
||||
print(batch)
|
||||
batch = rollout_worker_w_api.sample() # noqa: F841
|
||||
|
||||
def test_counting_by_agent_steps(self):
|
||||
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.exceptions import GetTimeoutError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
|
@ -10,7 +16,12 @@ class FilterManager:
|
|||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
def synchronize(local_filters, remotes, update_remote=True):
|
||||
def synchronize(
|
||||
local_filters,
|
||||
remotes,
|
||||
update_remote=True,
|
||||
timeout_seconds: Optional[float] = None,
|
||||
):
|
||||
"""Aggregates all filters from remote evaluators.
|
||||
|
||||
Local copy is updated and then broadcasted to all remote evaluators.
|
||||
|
@ -19,14 +30,37 @@ class FilterManager:
|
|||
local_filters: Filters to be synchronized.
|
||||
remotes: Remote evaluators with filters.
|
||||
update_remote: Whether to push updates to remote filters.
|
||||
timeout_seconds: Maximum time, in seconds, to wait for the remotes to get or set filters.
|
||||
"""
|
||||
remote_filters = ray.get(
|
||||
[r.get_filters.remote(flush_after=True) for r in remotes]
|
||||
)
|
||||
try:
|
||||
remote_filters = ray.get(
|
||||
[r.get_filters.remote(flush_after=True) for r in remotes],
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
except GetTimeoutError:
|
||||
logger.error(
|
||||
"Failed to get remote filters from a rollout worker in "
|
||||
"FilterManager. "
|
||||
"Filtered "
|
||||
"metrics may be computed, but filtered wrong."
|
||||
)
|
||||
|
||||
for rf in remote_filters:
|
||||
for k in local_filters:
|
||||
local_filters[k].apply_changes(rf[k], with_buffer=False)
|
||||
if update_remote:
|
||||
copies = {k: v.as_serializable() for k, v in local_filters.items()}
|
||||
remote_copy = ray.put(copies)
|
||||
[r.sync_filters.remote(remote_copy) for r in remotes]
|
||||
|
||||
try:
|
||||
ray.get(
|
||||
[r.sync_filters.remote(remote_copy) for r in remotes],
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
except GetTimeoutError:
|
||||
logger.error(
|
||||
"Failed to set remote filters to a rollout worker in "
|
||||
"FilterManager. "
|
||||
"Filtered "
|
||||
"metrics may be computed, but filtered wrong."
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue