fix handling of non-integral timeout values in signal.receive (#5002)

This commit is contained in:
Andrew Berger 2019-06-20 18:33:40 -04:00 committed by Philipp Moritz
parent 7bda5edc16
commit e59e8074dd
2 changed files with 45 additions and 2 deletions

View file

@ -2,6 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
from collections import defaultdict
import ray
@ -13,6 +15,8 @@ import ray.cloudpickle as cloudpickle
# in node_manager.cc
ACTOR_DIED_STR = "ACTOR_DIED_SIGNAL"
logger = logging.getLogger(__name__)
class Signal(object):
"""Base class for Ray signals."""
@ -125,10 +129,16 @@ def receive(sources, timeout=None):
for s in sources:
task_id_to_sources[_get_task_id(s).hex()].append(s)
if timeout < 1e-3:
logger.warning("Timeout too small. Using 1ms minimum")
timeout = 1e-3
timeout_ms = int(1000 * timeout)
# Construct the redis query.
query = "XREAD BLOCK "
# Multiply by 1000x since timeout is in sec and redis expects ms.
query += str(1000 * timeout)
# redis expects ms.
query += str(timeout_ms)
query += " STREAMS "
query += " ".join([task_id for task_id in task_id_to_sources])
query += " "

View file

@ -353,3 +353,36 @@ def test_serial_tasks_reading_same_signal(ray_start_regular):
assert len(result_list) == 1
result_list = ray.get(f.remote([a]))
assert len(result_list) == 1
def test_non_integral_receive_timeout(ray_start_regular):
    """Check that receive accepts a fractional (non-integer) timeout."""

    @ray.remote
    def send_signal(value):
        signal.send(UserSignal(value))

    sender = send_signal.remote(0)
    # Wait on the task first so send_signal has had a chance to run
    # before we try to receive from it.
    ray.get(sender)

    # 0.1 s is deliberately not a whole number of seconds.
    received = ray.experimental.signal.receive([sender], timeout=0.1)
    assert len(received) == 1
def test_small_receive_timeout(ray_start_regular):
    """Check that receive copes with a timeout below the 1 ms minimum."""

    @ray.remote
    def send_signal(value):
        signal.send(UserSignal(value))

    # 0.1 ms, i.e. smaller than the 1 ms minimum.
    small_timeout = 1e-4

    sender = send_signal.remote(0)
    # Wait on the task first so send_signal has had a chance to run
    # before we try to receive from it.
    ray.get(sender)

    received = ray.experimental.signal.receive(
        [sender], timeout=small_timeout)
    assert len(received) == 1