diff --git a/python/ray/worker.py b/python/ray/worker.py index 16f77d4b6..740f134dc 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -506,7 +506,8 @@ class Worker(object): actor_handle_id=None, actor_counter=0, is_actor_checkpoint_method=False, actor_creation_id=None, actor_creation_dummy_object_id=None, - execution_dependencies=None): + execution_dependencies=None, num_return_vals=None, + num_cpus=None, num_gpus=None, resources=None): """Submit a remote task to the scheduler. Tell the scheduler to schedule the execution of the function with ID @@ -528,6 +529,11 @@ class Worker(object): then this argument is the dummy object ID associated with the actor creation task for the corresponding actor. execution_dependencies: The execution dependencies for this task. + num_return_vals: The number of return values this function should + have. + num_cpus: The number of CPUs required by this task. + num_gpus: The number of GPUs required by this task. + resources: The resource requirements for this task. Returns: The return object IDs for this task. @@ -571,12 +577,25 @@ class Worker(object): function_properties = self.function_properties[ self.task_driver_id.id()][function_id.id()] + if num_return_vals is None: + num_return_vals = function_properties.num_return_vals + + if resources is None and num_cpus is None and num_gpus is None: + resources = function_properties.resources + else: + resources = {} if resources is None else resources + if "CPU" in resources or "GPU" in resources: + raise ValueError("The resources dictionary must not " + "contain the keys 'CPU' or 'GPU'") + resources["CPU"] = num_cpus + resources["GPU"] = num_gpus + # Submit the task to local scheduler. task = ray.local_scheduler.Task( self.task_driver_id, ray.local_scheduler.ObjectID(function_id.id()), args_for_local_scheduler, - function_properties.num_return_vals, + num_return_vals, self.current_task_id, self.task_index, actor_creation_id, @@ -586,7 +605,7 @@ class Worker(object): actor_counter, is_actor_checkpoint_method, execution_dependencies, - function_properties.resources) + resources) # Increment the worker's task index to track how many tasks have # been submitted by the current task so far. self.task_index += 1 @@ -725,7 +744,7 @@ class Worker(object): arguments.append(argument) return arguments - def _store_outputs_in_objstore(self, objectids, outputs): + def _store_outputs_in_objstore(self, object_ids, outputs): """Store the outputs of a remote function in the local object store. This stores the values that were returned by a remote function in the @@ -735,18 +754,18 @@ class Worker(object): executes the remote function. Note: - The arguments objectids and outputs should have the same length. + The arguments object_ids and outputs should have the same length. Args: - objectids (List[ObjectID]): The object IDs that were assigned to + object_ids (List[ObjectID]): The object IDs that were assigned to the outputs of the remote function call. outputs (Tuple): The value returned by the remote function. If the remote function was supposed to only return one value, then its output was wrapped in a tuple with one element prior to being passed into this function. """ - for i in range(len(objectids)): - self.put_object(objectids[i], outputs[i]) + for i in range(len(object_ids)): + self.put_object(object_ids[i], outputs[i]) def _process_task(self, task): """Execute a task assigned to this worker. @@ -2337,7 +2356,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): If timeout is set, the function returns either when the requested number of IDs are ready or when the timeout is reached, whichever occurs first. If it is not set, the function simply waits until that number of objects is ready - and returns that exact number of objectids. + and returns that exact number of object_ids. This method returns two lists. The first list consists of object IDs that correspond to objects that are stored in the object store. The second list @@ -2398,7 +2417,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): return ready_ids, remaining_ids -def _submit_task(function_id, args, worker=global_worker): +def _submit_task(function_id, *args, **kwargs): """This is a wrapper around worker.submit_task. We use this wrapper so that in the remote decorator, we can call @@ -2406,7 +2425,7 @@ def _submit_task(function_id, args, worker=global_worker): attempt to serialize remote functions, we don't attempt to serialize the worker object, which cannot be serialized. """ - return worker.submit_task(function_id, args) + return global_worker.submit_task(function_id, *args, **kwargs) def _mode(worker=global_worker): @@ -2566,8 +2585,14 @@ def remote(*args, **kwargs): def func_call(*args, **kwargs): """This runs immediately when a remote function is called.""" + return _submit(args=args, kwargs=kwargs) + + def _submit(args=None, kwargs=None, num_return_vals=None, + num_cpus=None, num_gpus=None, resources=None): + """An experimental alternate way to submit remote functions.""" check_connected() check_main_thread() + kwargs = {} if kwargs is None else kwargs args = signature.extend_args(function_signature, args, kwargs) if _mode() == PYTHON_MODE: @@ -2577,11 +2602,14 @@ def remote(*args, **kwargs): # immutable remote objects. result = func(*copy.deepcopy(args)) return result - objectids = _submit_task(function_id, args) - if len(objectids) == 1: - return objectids[0] - elif len(objectids) > 1: - return objectids + object_ids = _submit_task(function_id, args, + num_return_vals=num_return_vals, + num_cpus=num_cpus, num_gpus=num_gpus, + resources=resources) + if len(object_ids) == 1: + return object_ids[0] + elif len(object_ids) > 1: + return object_ids def func_executor(arguments): """This gets run when the remote function is executed.""" @@ -2594,6 +2622,7 @@ def remote(*args, **kwargs): "Instead of running '{}()', try '{}.remote()'." .format(func_name, func_name)) func_invoker.remote = func_call + func_invoker._submit = _submit func_invoker.executor = func_executor func_invoker.is_remote = True func_name = "{}.{}".format(func.__module__, func.__name__) diff --git a/test/runtest.py b/test/runtest.py index 8af10f50d..757d27bbc 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -685,6 +685,26 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(k2.remote(1)), 2) self.assertEqual(ray.get(m.remote(1)), 2) + def testSubmitAPI(self): + self.init_ray(num_gpus=1, resources={"Custom": 1}, num_workers=1) + + @ray.remote + def f(n): + return list(range(n)) + + @ray.remote + def g(): + return ray.get_gpu_ids() + + assert f._submit([0], num_return_vals=0) is None + assert ray.get(f._submit(args=[1], num_return_vals=1)) == [0] + assert ray.get(f._submit(args=[2], num_return_vals=2)) == [0, 1] + assert ray.get(f._submit(args=[3], num_return_vals=3)) == [0, 1, 2] + assert ray.get(g._submit(args=[], + num_cpus=1, + num_gpus=1, + resources={"Custom": 1})) == [0] + def testGetMultiple(self): self.init_ray() object_ids = [ray.put(i) for i in range(10)]