mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
fix handling of non-integral timeout values in signal.receive (#5002)
This commit is contained in:
parent
7bda5edc16
commit
e59e8074dd
2 changed files with 45 additions and 2 deletions
|
@ -2,6 +2,8 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import ray
|
||||
|
@ -13,6 +15,8 @@ import ray.cloudpickle as cloudpickle
|
|||
# in node_manager.cc
|
||||
ACTOR_DIED_STR = "ACTOR_DIED_SIGNAL"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Signal(object):
|
||||
"""Base class for Ray signals."""
|
||||
|
@ -125,10 +129,16 @@ def receive(sources, timeout=None):
|
|||
for s in sources:
|
||||
task_id_to_sources[_get_task_id(s).hex()].append(s)
|
||||
|
||||
if timeout < 1e-3:
|
||||
logger.warning("Timeout too small. Using 1ms minimum")
|
||||
timeout = 1e-3
|
||||
|
||||
timeout_ms = int(1000 * timeout)
|
||||
|
||||
# Construct the redis query.
|
||||
query = "XREAD BLOCK "
|
||||
# Multiply by 1000x since timeout is in sec and redis expects ms.
|
||||
query += str(1000 * timeout)
|
||||
# redis expects ms.
|
||||
query += str(timeout_ms)
|
||||
query += " STREAMS "
|
||||
query += " ".join([task_id for task_id in task_id_to_sources])
|
||||
query += " "
|
||||
|
|
|
@ -353,3 +353,36 @@ def test_serial_tasks_reading_same_signal(ray_start_regular):
|
|||
assert len(result_list) == 1
|
||||
result_list = ray.get(f.remote([a]))
|
||||
assert len(result_list) == 1
|
||||
|
||||
|
||||
def test_non_integral_receive_timeout(ray_start_regular):
|
||||
@ray.remote
|
||||
def send_signal(value):
|
||||
signal.send(UserSignal(value))
|
||||
|
||||
a = send_signal.remote(0)
|
||||
# make sure send_signal had a chance to execute
|
||||
ray.get(a)
|
||||
|
||||
result_list = ray.experimental.signal.receive([a], timeout=0.1)
|
||||
|
||||
assert len(result_list) == 1
|
||||
|
||||
|
||||
def test_small_receive_timeout(ray_start_regular):
|
||||
""" Test that receive handles timeout smaller than the 1ms min
|
||||
"""
|
||||
# 0.1 ms
|
||||
small_timeout = 1e-4
|
||||
|
||||
@ray.remote
|
||||
def send_signal(value):
|
||||
signal.send(UserSignal(value))
|
||||
|
||||
a = send_signal.remote(0)
|
||||
# make sure send_signal had a chance to execute
|
||||
ray.get(a)
|
||||
|
||||
result_list = ray.experimental.signal.receive([a], timeout=small_timeout)
|
||||
|
||||
assert len(result_list) == 1
|
||||
|
|
Loading…
Add table
Reference in a new issue