diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 2fc77c9b3..ef6330deb 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -156,7 +156,7 @@ class TrialRunner(object): # have been lost def _process_events(self): - [result_id], _ = ray.wait(self._running.keys()) + [result_id], _ = ray.wait(list(self._running.keys())) trial = self._running[result_id] del self._running[result_id] try: diff --git a/python/ray/worker.py b/python/ray/worker.py index dff99df55..a1255bd3e 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -2223,6 +2223,22 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): A list of object IDs that are ready and a list of the remaining object IDs. """ + + if isinstance(object_ids, ray.local_scheduler.ObjectID): + raise TypeError( + "wait() expected a list of ObjectID, got a single ObjectID") + + if not isinstance(object_ids, list): + raise TypeError("wait() expected a list of ObjectID, got {}".format( + type(object_ids))) + + if worker.mode != PYTHON_MODE: + for object_id in object_ids: + if not isinstance(object_id, ray.local_scheduler.ObjectID): + raise TypeError( + "wait() expected a list of ObjectID, " + "got list containing {}".format(type(object_id))) + check_connected(worker) with log_span("ray:wait", worker=worker): check_main_thread() diff --git a/test/runtest.py b/test/runtest.py index 5f4505781..b1b3c4e63 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -749,6 +749,15 @@ class APITest(unittest.TestCase): self.assertEqual(ready_ids, []) self.assertEqual(remaining_ids, []) + # Verify that incorrect usage raises a TypeError. + x = ray.put(1) + with self.assertRaises(TypeError): + ray.wait(x) + with self.assertRaises(TypeError): + ray.wait(1) + with self.assertRaises(TypeError): + ray.wait([1]) + def testMultipleWaitsAndGets(self): # It is important to use three workers here, so that the three tasks # launched in this experiment can run at the same time.