change remote function invocation from func() to func.remote() (#328)

This commit is contained in:
Robert Nishihara 2016-07-31 15:25:19 -07:00 committed by Philipp Moritz
parent 92f1976e94
commit 0e5b858324
19 changed files with 219 additions and 215 deletions

View file

@ -30,15 +30,15 @@ def estimate_pi(n):
# Launch 10 tasks, each of which estimates pi.
results = []
for _ in range(10):
results.append(estimate_pi(100))
results.append(estimate_pi.remote(100))
# Fetch the results of the tasks and print their average.
estimate = np.mean([ray.get(ref) for ref in results])
print "Pi is approximately {}.".format(estimate)
```
Within the for loop, each call to `estimate_pi(100)` sends a message to the
scheduler asking it to schedule the task of running `estimate_pi` with the
Within the for loop, each call to `estimate_pi.remote(100)` sends a message to
the scheduler asking it to schedule the task of running `estimate_pi` with the
argument `100`. This call returns right away without waiting for the actual
estimation of pi to take place. Instead of returning a float, it returns an
**object reference**, which represents the eventual output of the computation

View file

@ -122,12 +122,12 @@ function is called instead of catching them when the task is actually executed
### Remote functions
Whereas in regular Python, calling `add(1, 2)` would return `3`, in Ray, calling
`add(1, 2)` does not actually execute the task. Instead, it adds a task to the
computation graph and immediately returns an object reference to the output of
the computation.
`add.remote(1, 2)` does not actually execute the task. Instead, it adds a task
to the computation graph and immediately returns an object reference to the
output of the computation.
```python
>>> ref = add(1, 2)
>>> ref = add.remote(1, 2)
>>> ray.get(ref) # prints 3
```
@ -141,9 +141,9 @@ When a task is submitted, each argument may be passed in by value or by object
reference. For example, these lines have the same behavior.
```python
>>> add(1, 2)
>>> add(1, ray.put(2))
>>> add(ray.put(1), ray.put(2))
>>> add.remote(1, 2)
>>> add.remote(1, ray.put(2))
>>> add.remote(ray.put(1), ray.put(2))
```
Remote functions never return actual values, they always return object
@ -195,7 +195,7 @@ Then we can write
# Submit ten tasks to the scheduler. This finishes almost immediately.
result_refs = []
for i in range(10):
result_refs.append(sleep(5))
result_refs.append(sleep.remote(5))
# Wait for the results. If we have at least ten workers, this takes 5 seconds.
[ray.get(ref) for ref in result_refs] # prints [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
@ -246,9 +246,9 @@ def dot(a, b):
Then we run
```python
aref = zeros([10, 10])
bref = zeros([10, 10])
cref = dot(aref, bref)
aref = zeros.remote([10, 10])
bref = zeros.remote([10, 10])
cref = dot.remote(aref, bref)
```
The corresponding computation graph looks like this.

View file

@ -72,7 +72,7 @@ of object references, where the first object reference in each pair refers to a
batch of images and the second refers to the corresponding batch of labels.
```python
batches = [load_tarfile_from_s3(bucket, s3_key, size) for s3_key in s3_keys]
batches = [load_tarfile_from_s3.remote(bucket, s3_key, size) for s3_key in s3_keys]
```
By default, this will only fetch objects whose keys have prefix
@ -104,5 +104,5 @@ gradient_refs = []
for i in range(num_workers):
# Choose a random batch and use it to compute the gradient of the loss.
x_ref, y_ref = batches[np.random.randint(len(batches))]
gradient_refs.append(compute_grad(x_ref, y_ref, mean_ref, weights_ref))
gradient_refs.append(compute_grad.remote(x_ref, y_ref, mean_ref, weights_ref))
```

View file

@ -74,7 +74,7 @@ def load_tarfiles_from_s3(bucket, s3_keys, size=[]):
np.ndarray: Contains object references to the chunks of the images (see load_chunk).
"""
return [load_tarfile_from_s3(bucket, s3_key, size) for s3_key in s3_keys]
return [load_tarfile_from_s3.remote(bucket, s3_key, size) for s3_key in s3_keys]
def setup_variables(params, placeholders, assigns, kernelshape, biasshape):
"""Creates the variables for each layer and adds the variables and the components needed to feed them to various lists
@ -239,7 +239,7 @@ def num_images(batches):
Returns:
int: The number of images
"""
shape_refs = [ra.shape(batch) for batch in batches]
shape_refs = [ra.shape.remote(batch) for batch in batches]
return sum([ray.get(shape_ref)[0] for shape_ref in shape_refs])
@ray.remote([List], [np.ndarray])
@ -254,9 +254,9 @@ def compute_mean_image(batches):
"""
if len(batches) == 0:
raise Exception("No images were passed into `compute_mean_image`.")
sum_image_refs = [ra.sum(batch, axis=0) for batch in batches]
sum_image_refs = [ra.sum.remote(batch, axis=0) for batch in batches]
sum_images = [ray.get(ref) for ref in sum_image_refs]
n_images = num_images(batches)
n_images = num_images.remote(batches)
return np.sum(sum_images, axis=0).astype("float64") / ray.get(n_images)
@ray.remote([np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray, np.ndarray, np.ndarray, np.ndarray])
@ -303,7 +303,7 @@ def shuffle_pair(first_batch, second_batch):
Tuple[ObjRef, Objref]: The first batch of shuffled data.
Tuple[ObjRef, Objref]: Two second bach of shuffled data.
"""
images1, labels1, images2, labels2 = shuffle_arrays(first_batch[0], first_batch[1], second_batch[0], second_batch[1])
images1, labels1, images2, labels2 = shuffle_arrays.remote(first_batch[0], first_batch[1], second_batch[0], second_batch[1])
return (images1, labels1), (images2, labels2)
@ray.remote([list, dict], [np.ndarray])

View file

@ -44,10 +44,10 @@ if __name__ == "__main__":
imagenet_data = alexnet.load_tarfiles_from_s3(args.s3_bucket, image_tar_files, [256, 256])
# Convert the parsed filenames to integer labels and create batches.
batches = [(images, alexnet.filenames_to_labels(filenames, filename_label_dict_ref)) for images, filenames in imagenet_data]
batches = [(images, alexnet.filenames_to_labels.remote(filenames, filename_label_dict_ref)) for images, filenames in imagenet_data]
# Compute the mean image.
mean_ref = alexnet.compute_mean_image([images for images, labels in batches])
mean_ref = alexnet.compute_mean_image.remote([images for images, labels in batches])
# The data does not start out shuffled. Images of the same class all appear
# together, so we shuffle it ourselves here. Each shuffle pairs up the batches
@ -71,14 +71,14 @@ if __name__ == "__main__":
# Compute the accuracy on a random training batch.
x_ref, y_ref = batches[np.random.randint(len(batches))]
accuracy = alexnet.compute_accuracy(x_ref, y_ref, weights_ref)
accuracy = alexnet.compute_accuracy.remote(x_ref, y_ref, weights_ref)
# Launch tasks in parallel to compute the gradients for some batches.
gradient_refs = []
for i in range(num_workers - 1):
# Choose a random batch and use it to compute the gradient of the loss.
x_ref, y_ref = batches[np.random.randint(len(batches))]
gradient_refs.append(alexnet.compute_grad(x_ref, y_ref, mean_ref, weights_ref))
gradient_refs.append(alexnet.compute_grad.remote(x_ref, y_ref, mean_ref, weights_ref))
# Print the accuracy on a random training batch.
print "Iteration {}: accuracy = {:.3}%".format(iteration, 100 * ray.get(accuracy))

View file

@ -105,7 +105,7 @@ computation. Instead, it simply submits a number of tasks to the scheduler.
result_refs = []
for _ in range(100):
params = generate_random_params()
results.append((params, train_cnn_and_compute_accuracy(params, epochs)))
results.append((params, train_cnn_and_compute_accuracy.remote(params, epochs)))
```
If we wish to wait until the results have all been retrieved, we can retrieve

View file

@ -37,7 +37,7 @@ if __name__ == "__main__":
dropout = np.random.uniform(0, 1)
stddev = 10 ** np.random.uniform(-5, 5)
params = {"learning_rate": learning_rate, "batch_size": batch_size, "dropout": dropout, "stddev": stddev}
results.append((params, hyperopt.train_cnn_and_compute_accuracy(params, epochs, train_images, train_labels, validation_images, validation_labels)))
results.append((params, hyperopt.train_cnn_and_compute_accuracy.remote(params, epochs, train_images, train_labels, validation_images, validation_labels)))
# Fetch the results of the tasks and print the results.
for i in range(trials):

View file

@ -112,12 +112,12 @@ gradient.
```python
def full_loss(theta):
theta_ref = ray.put(theta)
loss_refs = [loss(theta_ref, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
loss_refs = [loss.remote(theta_ref, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
return sum([ray.get(loss_ref) for loss_ref in loss_refs])
def full_grad(theta):
theta_ref = ray.put(theta)
grad_refs = [grad(theta_ref, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
grad_refs = [grad.remote(theta_ref, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
return sum([ray.get(grad_ref) for grad_ref in grad_refs]).astype("float64") # This conversion is necessary for use with fmin_l_bfgs_b.
```
@ -125,14 +125,14 @@ Note that we turn `theta` into a remote object with the line `theta_ref =
ray.put(theta)` before passing it into the remote functions. If we had written
```python
[loss(theta, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
[loss.remote(theta, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
```
instead of
```python
theta_ref = ray.put(theta)
[loss(theta_ref, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
[loss.remote(theta_ref, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
```
then each task that got sent to the scheduler (one for every element of

View file

@ -79,13 +79,13 @@ if __name__ == "__main__":
# Compute the loss on the entire dataset.
def full_loss(theta):
theta_ref = ray.put(theta)
loss_refs = [loss(theta_ref, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
loss_refs = [loss.remote(theta_ref, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
return sum([ray.get(loss_ref) for loss_ref in loss_refs])
# Compute the gradient of the loss on the entire dataset.
def full_grad(theta):
theta_ref = ray.put(theta)
grad_refs = [grad(theta_ref, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
grad_refs = [grad.remote(theta_ref, xs_ref, ys_ref) for (xs_ref, ys_ref) in batch_refs]
return sum([ray.get(grad_ref) for grad_ref in grad_refs]).astype("float64") # This conversion is necessary for use with fmin_l_bfgs_b.
# From the perspective of scipy.optimize.fmin_l_bfgs_b, full_loss is simply a

View file

@ -54,7 +54,7 @@ model_ref = ray.put(model)
grads, reward_sums = [], []
# Launch tasks to compute gradients from multiple rollouts in parallel.
for i in range(10):
grad_ref, reward_sum_ref = compute_gradient(model_ref)
grad_ref, reward_sum_ref = compute_gradient.remote(model_ref)
grads.append(grad_ref)
reward_sums.append(reward_sum_ref)
```

View file

@ -127,7 +127,7 @@ if __name__ == "__main__":
grads, reward_sums = [], []
# Launch tasks to compute gradients from multiple rollouts in parallel.
for i in range(batch_size):
grad_ref, reward_sum_ref = compute_gradient(model_ref)
grad_ref, reward_sum_ref = compute_gradient.remote(model_ref)
grads.append(grad_ref)
reward_sums.append(reward_sum_ref)
for i in range(batch_size):

View file

@ -87,14 +87,14 @@ def numpy_to_dist(a):
def zeros(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objrefs[index] = ra.zeros(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
result.objrefs[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
@ray.remote([List, str], [DistArray])
def ones(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objrefs[index] = ra.ones(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
result.objrefs[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
@ray.remote([DistArray], [DistArray])
@ -112,9 +112,9 @@ def eye(dim1, dim2=-1, dtype_name="float"):
for (i, j) in np.ndindex(*result.num_blocks):
block_shape = DistArray.compute_block_shape([i, j], shape)
if i == j:
result.objrefs[i, j] = ra.eye(block_shape[0], block_shape[1], dtype_name=dtype_name)
result.objrefs[i, j] = ra.eye.remote(block_shape[0], block_shape[1], dtype_name=dtype_name)
else:
result.objrefs[i, j] = ra.zeros(block_shape, dtype_name=dtype_name)
result.objrefs[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
return result
@ray.remote([DistArray], [DistArray])
@ -124,11 +124,11 @@ def triu(a):
result = DistArray(a.shape)
for (i, j) in np.ndindex(*result.num_blocks):
if i < j:
result.objrefs[i, j] = ra.copy(a.objrefs[i, j])
result.objrefs[i, j] = ra.copy.remote(a.objrefs[i, j])
elif i == j:
result.objrefs[i, j] = ra.triu(a.objrefs[i, j])
result.objrefs[i, j] = ra.triu.remote(a.objrefs[i, j])
else:
result.objrefs[i, j] = ra.zeros_like(a.objrefs[i, j])
result.objrefs[i, j] = ra.zeros_like.remote(a.objrefs[i, j])
return result
@ray.remote([DistArray], [DistArray])
@ -138,11 +138,11 @@ def tril(a):
result = DistArray(a.shape)
for (i, j) in np.ndindex(*result.num_blocks):
if i > j:
result.objrefs[i, j] = ra.copy(a.objrefs[i, j])
result.objrefs[i, j] = ra.copy.remote(a.objrefs[i, j])
elif i == j:
result.objrefs[i, j] = ra.tril(a.objrefs[i, j])
result.objrefs[i, j] = ra.tril.remote(a.objrefs[i, j])
else:
result.objrefs[i, j] = ra.zeros_like(a.objrefs[i, j])
result.objrefs[i, j] = ra.zeros_like.remote(a.objrefs[i, j])
return result
@ray.remote([np.ndarray], [np.ndarray])
@ -168,7 +168,7 @@ def dot(a, b):
result = DistArray(shape)
for (i, j) in np.ndindex(*result.num_blocks):
args = list(a.objrefs[i, :]) + list(b.objrefs[:, j])
result.objrefs[i, j] = blockwise_dot(*args)
result.objrefs[i, j] = blockwise_dot.remote(*args)
return result
@ray.remote([DistArray, List], [DistArray])
@ -208,7 +208,7 @@ def transpose(a):
result = DistArray([a.shape[1], a.shape[0]])
for i in range(result.num_blocks[0]):
for j in range(result.num_blocks[1]):
result.objrefs[i, j] = ra.transpose(a.objrefs[j, i])
result.objrefs[i, j] = ra.transpose.remote(a.objrefs[j, i])
return result
# TODO(rkn): support broadcasting?
@ -218,7 +218,7 @@ def add(x1, x2):
raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
result = DistArray(x1.shape)
for index in np.ndindex(*result.num_blocks):
result.objrefs[index] = ra.add(x1.objrefs[index], x2.objrefs[index])
result.objrefs[index] = ra.add.remote(x1.objrefs[index], x2.objrefs[index])
return result
# TODO(rkn): support broadcasting?
@ -228,5 +228,5 @@ def subtract(x1, x2):
raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
result = DistArray(x1.shape)
for index in np.ndindex(*result.num_blocks):
result.objrefs[index] = ra.subtract(x1.objrefs[index], x2.objrefs[index])
result.objrefs[index] = ra.subtract.remote(x1.objrefs[index], x2.objrefs[index])
return result

View file

@ -33,14 +33,14 @@ def tsqr(a):
current_rs = []
for i in range(num_blocks):
block = a.objrefs[i, 0]
q, r = ra.linalg.qr(block)
q, r = ra.linalg.qr.remote(block)
q_tree[i, 0] = q
current_rs.append(r)
for j in range(1, K):
new_rs = []
for i in range(int(np.ceil(1.0 * len(current_rs) / 2))):
stacked_rs = ra.vstack(*current_rs[(2 * i):(2 * i + 2)])
q, r = ra.linalg.qr(stacked_rs)
stacked_rs = ra.vstack.remote(*current_rs[(2 * i):(2 * i + 2)])
q, r = ra.linalg.qr.remote(stacked_rs)
q_tree[i, j] = q
new_rs.append(r)
current_rs = new_rs
@ -72,7 +72,7 @@ def tsqr(a):
lower = [a.shape[1], 0]
upper = [2 * a.shape[1], BLOCK_SIZE]
ith_index /= 2
q_block_current = ra.dot(q_block_current, ra.subarray(q_tree[ith_index, j], lower, upper))
q_block_current = ra.dot.remote(q_block_current, ra.subarray.remote(q_tree[ith_index, j], lower, upper))
q_result.objrefs[i] = q_block_current
r = current_rs[0]
return q_result, r
@ -106,7 +106,7 @@ def modified_lu(q):
for i in range(b):
L[i, i] = 1
U = np.triu(q_work)[:b, :]
return numpy_to_dist(ray.put(L)), U, S # TODO(rkn): get rid of put
return numpy_to_dist.remote(ray.put(L)), U, S # TODO(rkn): get rid of put
@ray.remote([np.ndarray, np.ndarray, np.ndarray, int], [np.ndarray, np.ndarray])
def tsqr_hr_helper1(u, s, y_top_block, b):
@ -123,11 +123,11 @@ def tsqr_hr_helper2(s, r_temp):
@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray, np.ndarray])
def tsqr_hr(a):
"""Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
q, r_temp = tsqr(a)
y, u, s = modified_lu(q)
q, r_temp = tsqr.remote(a)
y, u, s = modified_lu.remote(q)
y_blocked = ray.get(y)
t, y_top = tsqr_hr_helper1(u, s, y_blocked.objrefs[0, 0], a.shape[1])
r = tsqr_hr_helper2(s, r_temp)
t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objrefs[0, 0], a.shape[1])
r = tsqr_hr_helper2.remote(s, r_temp)
return y, t, y_top, r
@ray.remote([np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray])
@ -149,42 +149,42 @@ def qr(a):
a_work.construct(a.shape, np.copy(a.objrefs))
result_dtype = np.linalg.qr(ray.get(a.objrefs[0, 0]))[0].dtype.name
r_res = ray.get(zeros([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
y_res = ray.get(zeros([m, k], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
r_res = ray.get(zeros.remote([k, n], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
y_res = ray.get(zeros.remote([m, k], result_dtype)) # TODO(rkn): It would be preferable not to get this right after creating it.
Ts = []
for i in range(min(a.num_blocks[0], a.num_blocks[1])): # this differs from the paper, which says "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense when a.num_blocks[1] > a.num_blocks[0]
sub_dist_array = subblocks(a_work, range(i, a_work.num_blocks[0]), [i])
y, t, _, R = tsqr_hr(sub_dist_array)
sub_dist_array = subblocks.remote(a_work, range(i, a_work.num_blocks[0]), [i])
y, t, _, R = tsqr_hr.remote(sub_dist_array)
y_val = ray.get(y)
for j in range(i, a.num_blocks[0]):
y_res.objrefs[j, i] = y_val.objrefs[j - i, 0]
if a.shape[0] > a.shape[1]:
# in this case, R needs to be square
R_shape = ray.get(ra.shape(R))
eye_temp = ra.eye(R_shape[1], R_shape[0], dtype_name=result_dtype)
r_res.objrefs[i, i] = ra.dot(eye_temp, R)
R_shape = ray.get(ra.shape.remote(R))
eye_temp = ra.eye.remote(R_shape[1], R_shape[0], dtype_name=result_dtype)
r_res.objrefs[i, i] = ra.dot.remote(eye_temp, R)
else:
r_res.objrefs[i, i] = R
Ts.append(numpy_to_dist(t))
Ts.append(numpy_to_dist.remote(t))
for c in range(i + 1, a.num_blocks[1]):
W_rcs = []
for r in range(i, a.num_blocks[0]):
y_ri = y_val.objrefs[r - i, 0]
W_rcs.append(qr_helper2(y_ri, a_work.objrefs[r, c]))
W_c = ra.sum_list(*W_rcs)
W_rcs.append(qr_helper2.remote(y_ri, a_work.objrefs[r, c]))
W_c = ra.sum_list.remote(*W_rcs)
for r in range(i, a.num_blocks[0]):
y_ri = y_val.objrefs[r - i, 0]
A_rc = qr_helper1(a_work.objrefs[r, c], y_ri, t, W_c)
A_rc = qr_helper1.remote(a_work.objrefs[r, c], y_ri, t, W_c)
a_work.objrefs[r, c] = A_rc
r_res.objrefs[i, c] = a_work.objrefs[i, c]
# construct q_res from Ys and Ts
q = eye(m, k, dtype_name=result_dtype)
q = eye.remote(m, k, dtype_name=result_dtype)
for i in range(len(Ts))[::-1]:
y_col_block = subblocks(y_res, [], [i])
q = subtract(q, dot(y_col_block, dot(Ts[i], dot(transpose(y_col_block), q))))
y_col_block = subblocks.remote(y_res, [], [i])
q = subtract.remote(q, dot.remote(y_col_block, dot.remote(Ts[i], dot.remote(transpose.remote(y_col_block), q))))
return q, r_res

View file

@ -11,7 +11,7 @@ def normal(shape):
num_blocks = DistArray.compute_num_blocks(shape)
objrefs = np.empty(num_blocks, dtype=object)
for index in np.ndindex(*num_blocks):
objrefs[index] = ra.random.normal(DistArray.compute_block_shape(index, shape))
objrefs[index] = ra.random.normal.remote(DistArray.compute_block_shape(index, shape))
result = DistArray()
result.construct(shape, objrefs)
return result

View file

@ -836,20 +836,24 @@ def remote(arg_types, return_types, worker=global_worker):
start_time = time.time()
result = func(*arguments)
end_time = time.time()
check_return_values(func_call, result) # throws an exception if result is invalid
check_return_values(func_invoker, result) # throws an exception if result is invalid
_logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
return result
func_call.executor = func_executor
func_call.arg_types = arg_types
func_call.return_types = return_types
func_call.is_remote = True
def func_invoker(*args, **kwargs):
"""This is returned by the decorator and used to invoke the function."""
raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))
func_invoker.remote = func_call
func_invoker.executor = func_executor
func_invoker.arg_types = arg_types
func_invoker.return_types = return_types
func_invoker.is_remote = True
func_name = "{}.{}".format(func.__module__, func.__name__)
func_call.func_name = func_name
func_call.func_doc = func.func_doc
func_invoker.func_name = func_name
func_invoker.func_doc = func.func_doc
sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
keyword_defaults = [(k, v.default) for k, v in sig_params]
has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
func_call.has_vararg_param = has_vararg_param
func_invoker.has_vararg_param = has_vararg_param
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
@ -858,7 +862,7 @@ def remote(arg_types, return_types, worker=global_worker):
func_name_global_valid = func.__name__ in func.__globals__
func_name_global_value = func.__globals__.get(func.__name__)
# Set the function globally to make it refer to itself
func.__globals__[func.__name__] = func_call # Allow the function to reference itself as a global variable
func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable
try:
to_export = pickling.dumps((func, arg_types, return_types, func.__module__))
finally:
@ -869,7 +873,7 @@ def remote(arg_types, return_types, worker=global_worker):
ray.lib.export_function(worker.handle, to_export)
elif worker.mode is None:
worker.cached_remote_functions.append(to_export)
return func_call
return func_invoker
return remote_decorator
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):

View file

@ -16,25 +16,25 @@ class RemoteArrayTest(unittest.TestCase):
ray.services.start_ray_local(num_workers=1)
# test eye
ref = ra.eye(3)
ref = ra.eye.remote(3)
val = ray.get(ref)
self.assertTrue(np.alltrue(val == np.eye(3)))
# test zeros
ref = ra.zeros([3, 4, 5])
ref = ra.zeros.remote([3, 4, 5])
val = ray.get(ref)
self.assertTrue(np.alltrue(val == np.zeros([3, 4, 5])))
# test qr - pass by value
val_a = np.random.normal(size=[10, 11])
ref_q, ref_r = ra.linalg.qr(val_a)
ref_q, ref_r = ra.linalg.qr.remote(val_a)
val_q = ray.get(ref_q)
val_r = ray.get(ref_r)
self.assertTrue(np.allclose(np.dot(val_q, val_r), val_a))
# test qr - pass by objref
a = ra.random.normal([10, 13])
ref_q, ref_r = ra.linalg.qr(a)
a = ra.random.normal.remote([10, 13])
ref_q, ref_r = ra.linalg.qr.remote(a)
val_a = ray.get(a)
val_q = ray.get(ref_q)
val_r = ray.get(ref_r)
@ -63,8 +63,8 @@ class DistributedArrayTest(unittest.TestCase):
reload(module)
ray.services.start_ray_local(num_workers=1)
a = ra.ones([da.BLOCK_SIZE, da.BLOCK_SIZE])
b = ra.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])
a = ra.ones.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
b = ra.zeros.remote([da.BLOCK_SIZE, da.BLOCK_SIZE])
x = da.DistArray()
x.construct([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]]))
self.assertTrue(np.alltrue(x.assemble() == np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])])))
@ -77,68 +77,68 @@ class DistributedArrayTest(unittest.TestCase):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../scripts/default_worker.py")
ray.services.start_services_local(num_objstores=2, num_workers_per_objstore=5, worker_path=worker_path)
x = da.zeros([9, 25, 51], "float")
self.assertTrue(np.alltrue(ray.get(da.assemble(x)) == np.zeros([9, 25, 51])))
x = da.zeros.remote([9, 25, 51], "float")
self.assertTrue(np.alltrue(ray.get(da.assemble.remote(x)) == np.zeros([9, 25, 51])))
x = da.ones([11, 25, 49], dtype_name="float")
self.assertTrue(np.alltrue(ray.get(da.assemble(x)) == np.ones([11, 25, 49])))
x = da.ones.remote([11, 25, 49], dtype_name="float")
self.assertTrue(np.alltrue(ray.get(da.assemble.remote(x)) == np.ones([11, 25, 49])))
x = da.random.normal([11, 25, 49])
y = da.copy(x)
self.assertTrue(np.alltrue(ray.get(da.assemble(x)) == ray.get(da.assemble(y))))
x = da.random.normal.remote([11, 25, 49])
y = da.copy.remote(x)
self.assertTrue(np.alltrue(ray.get(da.assemble.remote(x)) == ray.get(da.assemble.remote(y))))
x = da.eye(25, dtype_name="float")
self.assertTrue(np.alltrue(ray.get(da.assemble(x)) == np.eye(25)))
x = da.eye.remote(25, dtype_name="float")
self.assertTrue(np.alltrue(ray.get(da.assemble.remote(x)) == np.eye(25)))
x = da.random.normal([25, 49])
y = da.triu(x)
self.assertTrue(np.alltrue(ray.get(da.assemble(y)) == np.triu(ray.get(da.assemble(x)))))
x = da.random.normal.remote([25, 49])
y = da.triu.remote(x)
self.assertTrue(np.alltrue(ray.get(da.assemble.remote(y)) == np.triu(ray.get(da.assemble.remote(x)))))
x = da.random.normal([25, 49])
y = da.tril(x)
self.assertTrue(np.alltrue(ray.get(da.assemble(y)) == np.tril(ray.get(da.assemble(x)))))
x = da.random.normal.remote([25, 49])
y = da.tril.remote(x)
self.assertTrue(np.alltrue(ray.get(da.assemble.remote(y)) == np.tril(ray.get(da.assemble.remote(x)))))
x = da.random.normal([25, 49])
y = da.random.normal([49, 18])
z = da.dot(x, y)
w = da.assemble(z)
u = da.assemble(x)
v = da.assemble(y)
x = da.random.normal.remote([25, 49])
y = da.random.normal.remote([49, 18])
z = da.dot.remote(x, y)
w = da.assemble.remote(z)
u = da.assemble.remote(x)
v = da.assemble.remote(y)
np.allclose(ray.get(w), np.dot(ray.get(u), ray.get(v)))
self.assertTrue(np.allclose(ray.get(w), np.dot(ray.get(u), ray.get(v))))
# test add
x = da.random.normal([23, 42])
y = da.random.normal([23, 42])
z = da.add(x, y)
self.assertTrue(np.allclose(ray.get(da.assemble(z)), ray.get(da.assemble(x)) + ray.get(da.assemble(y))))
x = da.random.normal.remote([23, 42])
y = da.random.normal.remote([23, 42])
z = da.add.remote(x, y)
self.assertTrue(np.allclose(ray.get(da.assemble.remote(z)), ray.get(da.assemble.remote(x)) + ray.get(da.assemble.remote(y))))
# test subtract
x = da.random.normal([33, 40])
y = da.random.normal([33, 40])
z = da.subtract(x, y)
self.assertTrue(np.allclose(ray.get(da.assemble(z)), ray.get(da.assemble(x)) - ray.get(da.assemble(y))))
x = da.random.normal.remote([33, 40])
y = da.random.normal.remote([33, 40])
z = da.subtract.remote(x, y)
self.assertTrue(np.allclose(ray.get(da.assemble.remote(z)), ray.get(da.assemble.remote(x)) - ray.get(da.assemble.remote(y))))
# test transpose
x = da.random.normal([234, 432])
y = da.transpose(x)
self.assertTrue(np.alltrue(ray.get(da.assemble(x)).T == ray.get(da.assemble(y))))
x = da.random.normal.remote([234, 432])
y = da.transpose.remote(x)
self.assertTrue(np.alltrue(ray.get(da.assemble.remote(x)).T == ray.get(da.assemble.remote(y))))
# test numpy_to_dist
x = da.random.normal([23, 45])
y = da.assemble(x)
z = da.numpy_to_dist(y)
w = da.assemble(z)
self.assertTrue(np.alltrue(ray.get(da.assemble(x)) == ray.get(da.assemble(z))))
x = da.random.normal.remote([23, 45])
y = da.assemble.remote(x)
z = da.numpy_to_dist.remote(y)
w = da.assemble.remote(z)
self.assertTrue(np.alltrue(ray.get(da.assemble.remote(x)) == ray.get(da.assemble.remote(z))))
self.assertTrue(np.alltrue(ray.get(y) == ray.get(w)))
# test da.tsqr
for shape in [[123, da.BLOCK_SIZE], [7, da.BLOCK_SIZE], [da.BLOCK_SIZE, da.BLOCK_SIZE], [da.BLOCK_SIZE, 7], [10 * da.BLOCK_SIZE, da.BLOCK_SIZE]]:
x = da.random.normal(shape)
x = da.random.normal.remote(shape)
K = min(shape)
q, r = da.linalg.tsqr(x)
x_val = ray.get(da.assemble(x))
q_val = ray.get(da.assemble(q))
q, r = da.linalg.tsqr.remote(x)
x_val = ray.get(da.assemble.remote(x))
q_val = ray.get(da.assemble.remote(q))
r_val = ray.get(r)
self.assertTrue(r_val.shape == (K, shape[1]))
self.assertTrue(np.alltrue(r_val == np.triu(r_val)))
@ -150,12 +150,12 @@ class DistributedArrayTest(unittest.TestCase):
print "testing dist_modified_lu with d1 = " + str(d1) + ", d2 = " + str(d2)
assert d1 >= d2
k = min(d1, d2)
m = ra.random.normal([d1, d2])
q, r = ra.linalg.qr(m)
l, u, s = da.linalg.modified_lu(da.numpy_to_dist(q))
m = ra.random.normal.remote([d1, d2])
q, r = ra.linalg.qr.remote(m)
l, u, s = da.linalg.modified_lu.remote(da.numpy_to_dist.remote(q))
q_val = ray.get(q)
r_val = ray.get(r)
l_val = ray.get(da.assemble(l))
l_val = ray.get(da.assemble.remote(l))
u_val = ray.get(u)
s_val = ray.get(s)
s_mat = np.zeros((d1, d2))
@ -171,10 +171,10 @@ class DistributedArrayTest(unittest.TestCase):
# test dist_tsqr_hr
def test_dist_tsqr_hr(d1, d2):
print "testing dist_tsqr_hr with d1 = " + str(d1) + ", d2 = " + str(d2)
a = da.random.normal([d1, d2])
y, t, y_top, r = da.linalg.tsqr_hr(a)
a_val = ray.get(da.assemble(a))
y_val = ray.get(da.assemble(y))
a = da.random.normal.remote([d1, d2])
y, t, y_top, r = da.linalg.tsqr_hr.remote(a)
a_val = ray.get(da.assemble.remote(a))
y_val = ray.get(da.assemble.remote(y))
t_val = ray.get(t)
y_top_val = ray.get(y_top)
r_val = ray.get(r)
@ -189,12 +189,12 @@ class DistributedArrayTest(unittest.TestCase):
def test_dist_qr(d1, d2):
print "testing qr with d1 = {}, and d2 = {}.".format(d1, d2)
a = da.random.normal([d1, d2])
a = da.random.normal.remote([d1, d2])
K = min(d1, d2)
q, r = da.linalg.qr(a)
a_val = ray.get(da.assemble(a))
q_val = ray.get(da.assemble(q))
r_val = ray.get(da.assemble(r))
q, r = da.linalg.qr.remote(a)
a_val = ray.get(da.assemble.remote(a))
q_val = ray.get(da.assemble.remote(q))
r_val = ray.get(da.assemble.remote(r))
self.assertTrue(q_val.shape == (d1, K))
self.assertTrue(r_val.shape == (K, d2))
self.assertTrue(np.allclose(np.dot(q_val.T, q_val), np.eye(K)))

View file

@ -16,7 +16,7 @@ class MicroBenchmarkTest(unittest.TestCase):
elapsed_times = []
for _ in range(1000):
start_time = time.time()
test_functions.empty_function()
test_functions.empty_function.remote()
end_time = time.time()
elapsed_times.append(end_time - start_time)
elapsed_times = np.sort(elapsed_times)
@ -32,7 +32,7 @@ class MicroBenchmarkTest(unittest.TestCase):
elapsed_times = []
for _ in range(1000):
start_time = time.time()
test_functions.trivial_function()
test_functions.trivial_function.remote()
end_time = time.time()
elapsed_times.append(end_time - start_time)
elapsed_times = np.sort(elapsed_times)
@ -48,7 +48,7 @@ class MicroBenchmarkTest(unittest.TestCase):
elapsed_times = []
for _ in range(1000):
start_time = time.time()
x = test_functions.trivial_function()
x = test_functions.trivial_function.remote()
ray.get(x)
end_time = time.time()
elapsed_times.append(end_time - start_time)

View file

@ -182,11 +182,11 @@ class APITest(unittest.TestCase):
reload(test_functions)
ray.services.start_ray_local(num_workers=3, driver_mode=ray.SILENT_MODE)
ref = test_functions.test_alias_f()
ref = test_functions.test_alias_f.remote()
self.assertTrue(np.alltrue(ray.get(ref) == np.ones([3, 4, 5])))
ref = test_functions.test_alias_g()
ref = test_functions.test_alias_g.remote()
self.assertTrue(np.alltrue(ray.get(ref) == np.ones([3, 4, 5])))
ref = test_functions.test_alias_h()
ref = test_functions.test_alias_h.remote()
self.assertTrue(np.alltrue(ray.get(ref) == np.ones([3, 4, 5])))
ray.services.cleanup()
@ -195,35 +195,35 @@ class APITest(unittest.TestCase):
reload(test_functions)
ray.services.start_ray_local(num_workers=1)
x = test_functions.keyword_fct1(1)
x = test_functions.keyword_fct1.remote(1)
self.assertEqual(ray.get(x), "1 hello")
x = test_functions.keyword_fct1(1, "hi")
x = test_functions.keyword_fct1.remote(1, "hi")
self.assertEqual(ray.get(x), "1 hi")
x = test_functions.keyword_fct1(1, b="world")
x = test_functions.keyword_fct1.remote(1, b="world")
self.assertEqual(ray.get(x), "1 world")
x = test_functions.keyword_fct2(a="w", b="hi")
x = test_functions.keyword_fct2.remote(a="w", b="hi")
self.assertEqual(ray.get(x), "w hi")
x = test_functions.keyword_fct2(b="hi", a="w")
x = test_functions.keyword_fct2.remote(b="hi", a="w")
self.assertEqual(ray.get(x), "w hi")
x = test_functions.keyword_fct2(a="w")
x = test_functions.keyword_fct2.remote(a="w")
self.assertEqual(ray.get(x), "w world")
x = test_functions.keyword_fct2(b="hi")
x = test_functions.keyword_fct2.remote(b="hi")
self.assertEqual(ray.get(x), "hello hi")
x = test_functions.keyword_fct2("w")
x = test_functions.keyword_fct2.remote("w")
self.assertEqual(ray.get(x), "w world")
x = test_functions.keyword_fct2("w", "hi")
x = test_functions.keyword_fct2.remote("w", "hi")
self.assertEqual(ray.get(x), "w hi")
x = test_functions.keyword_fct3(0, 1, c="w", d="hi")
x = test_functions.keyword_fct3.remote(0, 1, c="w", d="hi")
self.assertEqual(ray.get(x), "0 1 w hi")
x = test_functions.keyword_fct3(0, 1, d="hi", c="w")
x = test_functions.keyword_fct3.remote(0, 1, d="hi", c="w")
self.assertEqual(ray.get(x), "0 1 w hi")
x = test_functions.keyword_fct3(0, 1, c="w")
x = test_functions.keyword_fct3.remote(0, 1, c="w")
self.assertEqual(ray.get(x), "0 1 w world")
x = test_functions.keyword_fct3(0, 1, d="hi")
x = test_functions.keyword_fct3.remote(0, 1, d="hi")
self.assertEqual(ray.get(x), "0 1 hello hi")
x = test_functions.keyword_fct3(0, 1)
x = test_functions.keyword_fct3.remote(0, 1)
self.assertEqual(ray.get(x), "0 1 hello world")
ray.services.cleanup()
@ -232,9 +232,9 @@ class APITest(unittest.TestCase):
reload(test_functions)
ray.services.start_ray_local(num_workers=1)
x = test_functions.varargs_fct1(0, 1, 2)
x = test_functions.varargs_fct1.remote(0, 1, 2)
self.assertEqual(ray.get(x), "0 1 2")
x = test_functions.varargs_fct2(0, 1, 2)
x = test_functions.varargs_fct2.remote(0, 1, 2)
self.assertEqual(ray.get(x), "1 2")
self.assertTrue(test_functions.kwargs_exception_thrown)
@ -246,14 +246,14 @@ class APITest(unittest.TestCase):
reload(test_functions)
ray.services.start_ray_local(num_workers=1, driver_mode=ray.SILENT_MODE)
test_functions.no_op()
test_functions.no_op.remote()
time.sleep(0.2)
task_info = ray.task_info()
self.assertEqual(len(task_info["failed_tasks"]), 0)
self.assertEqual(len(task_info["running_tasks"]), 0)
self.assertEqual(task_info["num_succeeded"], 1)
test_functions.no_op_fail()
test_functions.no_op_fail.remote()
time.sleep(0.2)
task_info = ray.task_info()
self.assertEqual(len(task_info["failed_tasks"]), 1)
@ -269,8 +269,8 @@ class APITest(unittest.TestCase):
# Make sure that these functions throw exceptions because there return
# values do not type check.
test_functions.test_return1()
test_functions.test_return2()
test_functions.test_return1.remote()
test_functions.test_return2.remote()
time.sleep(0.2)
task_info = ray.task_info()
self.assertEqual(len(task_info["failed_tasks"]), 2)
@ -286,30 +286,30 @@ class APITest(unittest.TestCase):
@ray.remote([int], [int])
def f(x):
return x + 1
self.assertEqual(ray.get(f(0)), 1)
self.assertEqual(ray.get(f.remote(0)), 1)
# Test that we can redefine the remote function.
@ray.remote([int], [int])
def f(x):
return x + 10
self.assertEqual(ray.get(f(0)), 10)
self.assertEqual(ray.get(f.remote(0)), 10)
# Test that we can close over plain old data.
data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 2L], 2L, {"a": np.zeros(3)}]
@ray.remote([], [list])
def g():
return data
ray.get(g())
ray.get(g.remote())
# Test that we can close over modules.
@ray.remote([], [np.ndarray])
def h():
return np.zeros([3, 5])
self.assertTrue(np.alltrue(ray.get(h()) == np.zeros([3, 5])))
self.assertTrue(np.alltrue(ray.get(h.remote()) == np.zeros([3, 5])))
@ray.remote([], [float])
def j():
return time.time()
ray.get(j())
ray.get(j.remote())
# Test that we can define remote functions that call other remote functions.
@ray.remote([int], [int])
@ -317,13 +317,13 @@ class APITest(unittest.TestCase):
return x + 1
@ray.remote([int], [int])
def l(x):
return k(x)
return k.remote(x)
@ray.remote([int], [int])
def m(x):
return ray.get(l(x))
self.assertEqual(ray.get(k(1)), 2)
self.assertEqual(ray.get(l(1)), 2)
self.assertEqual(ray.get(m(1)), 2)
return ray.get(l.remote(x))
self.assertEqual(ray.get(k.remote(1)), 2)
self.assertEqual(ray.get(l.remote(1)), 2)
self.assertEqual(ray.get(m.remote(1)), 2)
ray.services.cleanup()
@ -348,10 +348,10 @@ class APITest(unittest.TestCase):
ray.services.start_ray_local(num_workers=2)
self.assertEqual(ray.get(use_foo()), 1)
self.assertEqual(ray.get(use_foo()), 1)
self.assertEqual(ray.get(use_bar()), [1])
self.assertEqual(ray.get(use_bar()), [1])
self.assertEqual(ray.get(use_foo.remote()), 1)
self.assertEqual(ray.get(use_foo.remote()), 1)
self.assertEqual(ray.get(use_bar.remote()), [1])
self.assertEqual(ray.get(use_bar.remote()), [1])
ray.services.cleanup()
@ -360,9 +360,9 @@ class TaskStatusTest(unittest.TestCase):
reload(test_functions)
ray.services.start_ray_local(num_workers=3, driver_mode=ray.SILENT_MODE)
test_functions.test_alias_f()
test_functions.throw_exception_fct1()
test_functions.throw_exception_fct1()
test_functions.test_alias_f.remote()
test_functions.throw_exception_fct1.remote()
test_functions.throw_exception_fct1.remote()
time.sleep(1)
result = ray.task_info()
self.assertEqual(len(result["failed_tasks"]), 2)
@ -374,7 +374,7 @@ class TaskStatusTest(unittest.TestCase):
self.assertTrue(task["operationid"] not in task_ids)
task_ids.add(task["operationid"])
x = test_functions.throw_exception_fct2()
x = test_functions.throw_exception_fct2.remote()
try:
ray.get(x)
except Exception as e:
@ -382,7 +382,7 @@ class TaskStatusTest(unittest.TestCase):
else:
self.assertTrue(False) # ray.get should throw an exception
x, y, z = test_functions.throw_exception_fct3(1.0)
x, y, z = test_functions.throw_exception_fct3.remote(1.0)
for ref in [x, y, z]:
try:
ray.get(ref)
@ -411,7 +411,7 @@ class ReferenceCountingTest(unittest.TestCase):
reload(module)
ray.services.start_ray_local(num_workers=1)
x = test_functions.test_alias_f()
x = test_functions.test_alias_f.remote()
ray.get(x)
time.sleep(0.1)
objref_val = x.val
@ -420,7 +420,7 @@ class ReferenceCountingTest(unittest.TestCase):
del x
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], -1) # -1 indicates deallocated
y = test_functions.test_alias_h()
y = test_functions.test_alias_h.remote()
ray.get(y)
time.sleep(0.1)
objref_val = y.val
@ -429,7 +429,7 @@ class ReferenceCountingTest(unittest.TestCase):
del y
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val:(objref_val + 3)], [-1, -1, -1])
z = da.zeros([da.BLOCK_SIZE, 2 * da.BLOCK_SIZE])
z = da.zeros.remote([da.BLOCK_SIZE, 2 * da.BLOCK_SIZE])
time.sleep(0.1)
objref_val = z.val
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val:(objref_val + 3)], [1, 1, 1])
@ -438,9 +438,9 @@ class ReferenceCountingTest(unittest.TestCase):
time.sleep(0.1)
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val:(objref_val + 3)], [-1, -1, -1])
x = ra.zeros([10, 10])
y = ra.zeros([10, 10])
z = ra.dot(x, y)
x = ra.zeros.remote([10, 10])
y = ra.zeros.remote([10, 10])
z = ra.dot.remote(x, y)
objref_val = x.val
time.sleep(0.1)
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val:(objref_val + 3)], [1, 1, 1])
@ -502,7 +502,7 @@ class PythonModeTest(unittest.TestCase):
reload(test_functions)
ray.services.start_ray_local(driver_mode=ray.PYTHON_MODE)
xref = test_functions.test_alias_h()
xref = test_functions.test_alias_h.remote()
self.assertTrue(np.alltrue(xref == np.ones([3, 4, 5]))) # remote functions should return by value
self.assertTrue(np.alltrue(xref == ray.get(xref))) # ray.get should be the identity
y = np.random.normal(size=[11, 12])
@ -510,9 +510,9 @@ class PythonModeTest(unittest.TestCase):
# make sure objects are immutable, this example is why we need to copy
# arguments before passing them into remote functions in python mode
aref = test_functions.python_mode_f()
aref = test_functions.python_mode_f.remote()
self.assertTrue(np.alltrue(aref == np.array([0, 0])))
bref = test_functions.python_mode_g(aref)
bref = test_functions.python_mode_g.remote(aref)
self.assertTrue(np.alltrue(aref == np.array([0, 0]))) # python_mode_g should not mutate aref
self.assertTrue(np.alltrue(bref == np.array([1, 0])))
@ -528,8 +528,8 @@ class PythonCExtensionTest(unittest.TestCase):
@ray.remote([], [int])
def f():
return sys.getrefcount(obj)
first_count = ray.get(f())
second_count = ray.get(f())
first_count = ray.get(f.remote())
second_count = ray.get(f.remote())
self.assertEqual(first_count, second_count)
ray.services.cleanup()
@ -552,9 +552,9 @@ class ReusablesTest(unittest.TestCase):
@ray.remote([], [int])
def use_foo():
return ray.reusables.foo
self.assertEqual(ray.get(use_foo()), 1)
self.assertEqual(ray.get(use_foo()), 1)
self.assertEqual(ray.get(use_foo()), 1)
self.assertEqual(ray.get(use_foo.remote()), 1)
self.assertEqual(ray.get(use_foo.remote()), 1)
self.assertEqual(ray.get(use_foo.remote()), 1)
# Test that we can add a variable to the key-value store, mutate it, and reset it.
@ -567,9 +567,9 @@ class ReusablesTest(unittest.TestCase):
def use_bar():
ray.reusables.bar.append(4)
return ray.reusables.bar
self.assertEqual(ray.get(use_bar()), [1, 2, 3, 4])
self.assertEqual(ray.get(use_bar()), [1, 2, 3, 4])
self.assertEqual(ray.get(use_bar()), [1, 2, 3, 4])
self.assertEqual(ray.get(use_bar.remote()), [1, 2, 3, 4])
self.assertEqual(ray.get(use_bar.remote()), [1, 2, 3, 4])
self.assertEqual(ray.get(use_bar.remote()), [1, 2, 3, 4])
# Test that we can use the reinitializer.
@ -587,10 +587,10 @@ class ReusablesTest(unittest.TestCase):
baz = ray.reusables.baz
baz[i] = 1
return baz
self.assertTrue(np.alltrue(ray.get(use_baz(0)) == np.array([1, 0, 0, 0])))
self.assertTrue(np.alltrue(ray.get(use_baz(1)) == np.array([0, 1, 0, 0])))
self.assertTrue(np.alltrue(ray.get(use_baz(2)) == np.array([0, 0, 1, 0])))
self.assertTrue(np.alltrue(ray.get(use_baz(3)) == np.array([0, 0, 0, 1])))
self.assertTrue(np.alltrue(ray.get(use_baz.remote(0)) == np.array([1, 0, 0, 0])))
self.assertTrue(np.alltrue(ray.get(use_baz.remote(1)) == np.array([0, 1, 0, 0])))
self.assertTrue(np.alltrue(ray.get(use_baz.remote(2)) == np.array([0, 0, 1, 0])))
self.assertTrue(np.alltrue(ray.get(use_baz.remote(3)) == np.array([0, 0, 0, 1])))
# Make sure the reinitializer is actually getting called. Note that this is
# not the correct usage of a reinitializer because it does not reset qux to
@ -606,9 +606,9 @@ class ReusablesTest(unittest.TestCase):
@ray.remote([], [int])
def use_qux():
return ray.reusables.qux
self.assertEqual(ray.get(use_qux()), 0)
self.assertEqual(ray.get(use_qux()), 1)
self.assertEqual(ray.get(use_qux()), 2)
self.assertEqual(ray.get(use_qux.remote()), 0)
self.assertEqual(ray.get(use_qux.remote()), 1)
self.assertEqual(ray.get(use_qux.remote()), 2)
ray.services.cleanup()

View file

@ -16,11 +16,11 @@ def test_alias_f():
@ray.remote([], [np.ndarray])
def test_alias_g():
return test_alias_f()
return test_alias_f.remote()
@ray.remote([], [np.ndarray])
def test_alias_h():
return test_alias_g()
return test_alias_g.remote()
# Test timing