2016-12-11 12:25:31 -08:00
from __future__ import absolute_import
from __future__ import division
2016-11-02 00:39:35 -07:00
from __future__ import print_function
2016-12-23 00:43:16 -08:00
import os
2016-02-22 13:55:06 -08:00
import unittest
2016-06-10 14:12:15 -07:00
import ray
2016-02-22 13:55:06 -08:00
import numpy as np
import time
2016-12-23 00:43:16 -08:00
import shutil
2016-09-06 13:28:24 -07:00
import string
2016-07-22 14:15:02 -07:00
import sys
2016-09-06 13:28:24 -07:00
from collections import namedtuple
2016-03-01 01:02:08 -08:00
2016-12-13 17:37:22 -08:00
if sys . version_info > = ( 3 , 0 ) :
from importlib import reload
2016-12-11 12:25:31 -08:00
import ray . test . test_functions as test_functions
2016-06-27 11:35:31 -07:00
import ray . array . remote as ra
import ray . array . distributed as da
2016-04-18 13:05:36 -07:00
2016-09-06 13:28:24 -07:00
def assert_equal(obj1, obj2):
    """Recursively assert that two (possibly nested) objects are equal.

    NumPy values are compared with ``np.testing.assert_equal`` (or ``==`` when
    either side is zero-dimensional), objects with a ``__dict__`` attribute by
    attribute, dicts by key set and then value by value, and lists and tuples
    (including namedtuples) element by element. Anything else uses ``==``.

    Args:
        obj1: The first object.
        obj2: The second object.

    Raises:
        AssertionError: If the objects, or any nested values, differ.
    """
    numpy_module = np.__name__
    if type(obj1).__module__ == numpy_module or type(obj2).__module__ == numpy_module:
        if ((hasattr(obj1, "shape") and obj1.shape == ()) or
                (hasattr(obj2, "shape") and obj2.shape == ())):
            # This is a special case because currently np.testing.assert_equal
            # fails because we do not properly handle different numerical types.
            assert obj1 == obj2, "Objects {} and {} are different.".format(obj1, obj2)
        else:
            np.testing.assert_equal(obj1, obj2)
    elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"):
        # "_pytype_" is ignored: it is added to both key sets and skipped when
        # comparing values (presumably an attribute attached during
        # serialization -- TODO confirm).
        special_keys = ["_pytype_"]
        assert (set(list(obj1.__dict__.keys()) + special_keys) ==
                set(list(obj2.__dict__.keys()) + special_keys)), (
            "Objects {} and {} are different.".format(obj1, obj2))
        for key in obj1.__dict__:
            if key not in special_keys:
                assert_equal(obj1.__dict__[key], obj2.__dict__[key])
    elif type(obj1) is dict or type(obj2) is dict:
        # Fail with an assertion (not an AttributeError) if only one side is a dict.
        assert isinstance(obj1, dict) and isinstance(obj2, dict), (
            "Objects {} and {} are different.".format(obj1, obj2))
        # Compare keys as sets: dict equality does not depend on key order, and
        # on Python 2 keys() returns an order-dependent list.
        assert set(obj1.keys()) == set(obj2.keys()), (
            "Objects {} and {} have different keys.".format(obj1, obj2))
        for key in obj1:
            assert_equal(obj1[key], obj2[key])
    elif type(obj1) is list or type(obj2) is list:
        assert len(obj1) == len(obj2), (
            "Objects {} and {} are lists with different lengths.".format(obj1, obj2))
        for item1, item2 in zip(obj1, obj2):
            assert_equal(item1, item2)
    elif isinstance(obj1, tuple) or isinstance(obj2, tuple):
        # Covers plain tuples as well as namedtuples and other tuple subclasses.
        assert len(obj1) == len(obj2), (
            "Objects {} and {} are tuples with different lengths.".format(obj1, obj2))
        for item1, item2 in zip(obj1, obj2):
            assert_equal(item1, item2)
    else:
        assert obj1 == obj2, "Objects {} and {} are different.".format(obj1, obj2)
2016-12-13 17:37:22 -08:00
# Python 2 has a separate arbitrary-precision ``long`` type that gets its own
# entries; on Python 3 every integer is an ``int``, so plain ints stand in.
if sys.version_info < (3, 0):
    long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])]  # noqa: F821
else:
    long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]

# Scalars, strings, empty containers, NumPy scalars and NumPy arrays used as
# round-trip test values.
PRIMITIVE_OBJECTS = (
    # Python scalars and strings (including non-ASCII text and embedded NULs).
    [0, 0.0, 0.9, 1 << 62, "a", string.printable, "\u262F",
     u"hello world", u"\xff\xfe\x9c\x001\x000\x00", None, True, False] +
    # Empty built-in containers.
    [[], (), {}] +
    # NumPy scalars of several widths.
    [np.int8(3), np.int32(4), np.int64(5),
     np.uint8(3), np.uint32(4), np.uint64(5),
     np.float32(1.9), np.float64(1.9)] +
    # NumPy arrays, including a mixed-type array and an object-dtype array.
    [np.zeros([100, 100]), np.random.normal(size=[100, 100]),
     np.array(["hi", 3]), np.array(["hi", 3], dtype=object)] +
    long_extras)

# Deeply nested containers and a dict of large arrays.
COMPLEX_OBJECTS = [
    # A list nested twelve levels deep.
    [[[[[[[[[[[[]]]]]]]]]]]],
    {"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)},
    # Deeply nested dicts keyed by the empty tuple are not included yet:
    # {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {}}}}}}}}}}}}},
    # An empty tuple wrapped in nine one-element tuples.
    ((((((((((),),),),),),),),),),
    {"a": {"b": {"c": {"d": {}}}}},
]
class Foo(object):
    """A user-defined class whose instances carry no attributes."""

    def __init__(self):
        """Create an instance without setting any state."""
2016-09-06 13:28:24 -07:00
class Bar(object):
    """A user-defined object holding every primitive and complex test value.

    The value at position ``i`` of ``PRIMITIVE_OBJECTS + COMPLEX_OBJECTS`` is
    stored in the attribute named ``"field{i}"``.
    """

    def __init__(self):
        all_values = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS
        named_fields = (("field{}".format(n), v) for n, v in enumerate(all_values))
        for attr_name, attr_value in named_fields:
            setattr(self, attr_name, attr_value)
class Baz(object):
    """A user-defined object whose attributes are other user-defined objects."""

    def __init__(self):
        # Foo() is constructed before Bar(), and foo is assigned before bar.
        self.foo, self.bar = Foo(), Bar()

    def method(self, arg):
        """Accept ``arg`` and ignore it; always returns None."""
        return None
class Qux(object):
    """A user-defined object holding a list of other user-defined objects."""

    def __init__(self):
        # Instances are created in the order Foo, Bar, Baz.
        self.objs = [cls() for cls in (Foo, Bar, Baz)]
2016-06-21 17:39:48 -07:00
2016-09-06 13:28:24 -07:00
class SubQux(Qux):
    """A subclass of Qux that adds nothing; initialization is delegated to Qux."""

    def __init__(self):
        super(SubQux, self).__init__()
2016-06-21 17:39:48 -07:00
2016-09-06 13:28:24 -07:00
class CustomError(Exception):
    """A user-defined exception type, used as a custom test object."""
2016-03-10 12:35:31 -08:00
2016-09-06 13:28:24 -07:00
Point = namedtuple("Point", ["x", "y"])
NamedTupleExample = namedtuple("Example", "field1, field2, field3, field4, field5")

# Exceptions, namedtuples and instances of the user-defined classes above.
# Qux() and SubQux() are currently left out.
CUSTOM_OBJECTS = [
    Exception("Test object."),
    CustomError(),
    Point(11, y=22),
    Foo(),
    Bar(),
    Baz(),
    NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3]),
]

BASE_OBJECTS = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS + CUSTOM_OBJECTS

# Every base object wrapped in a one-element list and in a one-element tuple.
LIST_OBJECTS = [[obj] for obj in BASE_OBJECTS]
TUPLE_OBJECTS = [(obj,) for obj in BASE_OBJECTS]

# Dicts using hashable primitives as keys, plus dicts holding every base
# object as a value. The check that type(obj).__module__ != "numpy" should be
# unnecessary, but otherwise this seems to fail on Mac OS X on Travis. Ideally
# every hashable object in BASE_OBJECTS (not only PRIMITIVE_OBJECTS) would be
# used as a key.
DICT_OBJECTS = (
    [{obj: obj}
     for obj in PRIMITIVE_OBJECTS
     if obj.__hash__ is not None and type(obj).__module__ != "numpy"] +
    [{0: obj} for obj in BASE_OBJECTS])

RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS
2016-03-12 15:25:45 -08:00
2016-09-09 16:46:18 -07:00
# Check that the correct version of cloudpickle is installed: the handler
# below treats an AttributeError from dumping a namedtuple class as the sign
# of a cloudpickle release too old to serialize namedtuples.
try:
    import cloudpickle
    cloudpickle.dumps(Point)
except AttributeError:
    # GitHub no longer serves the unauthenticated git:// protocol, so the
    # suggested install command uses https.
    cloudpickle_command = ("sudo pip install --upgrade "
                           "git+https://github.com/cloudpipe/cloudpickle.git"
                           "@0d225a4695f1f65ae1cbb2e0bbc145e10167cce4")
    raise Exception("You have an older version of cloudpickle that is not able "
                    "to serialize namedtuples. Try running\n\n{}\n\n"
                    .format(cloudpickle_command))
2016-09-19 17:17:42 -07:00
class SerializationTest(unittest.TestCase):
    """Tests for how Ray serializes objects passed through the object store."""

    def testRecursiveObjects(self):
        ray.init(start_ray_local=True, num_workers=0)
        # Register teardown right away so it also runs when an assertion fails.
        self.addCleanup(ray.worker.cleanup)

        class ClassA(object):
            pass
        ray.register_class(ClassA)
        # Make a list that contains itself.
        lst = []
        lst.append(lst)
        # Make an object that contains itself as a field.
        a1 = ClassA()
        a1.field = a1
        # Make two objects that contain each other as fields.
        a2 = ClassA()
        a3 = ClassA()
        a2.field = a3
        a3.field = a2
        # Make a dictionary that contains itself.
        d1 = {}
        d1["key"] = d1
        # Create a list of recursive objects.
        recursive_objects = [lst, a1, a2, a3, d1]
        # Check that exceptions are thrown when we serialize the recursive objects.
        for obj in recursive_objects:
            self.assertRaises(Exception, ray.put, obj)

    def testPassingArgumentsByValue(self):
        ray.init(start_ray_local=True, num_workers=1)
        self.addCleanup(ray.worker.cleanup)

        @ray.remote
        def f(x):
            return x

        for cls in [Exception, CustomError, Point, Foo, Bar, Baz, NamedTupleExample]:
            ray.register_class(cls)
        # Check that we can pass arguments by value to remote functions and that
        # they are uncorrupted.
        for obj in RAY_TEST_OBJECTS:
            assert_equal(obj, ray.get(f.remote(obj)))
2016-08-15 11:02:54 -07:00
class WorkerTest(unittest.TestCase):
    """Tests for basic object-store round trips on the driver."""

    def testPutGet(self):
        ray.init(start_ray_local=True, num_workers=0)
        # Register teardown right away so it also runs when an assertion fails.
        self.addCleanup(ray.worker.cleanup)

        # Round-trip ints, floats, strings and lists of growing size, in that order.
        value_makers = [
            lambda i: i * 10 ** 6,
            lambda i: i * 10 ** 6 * 1.0,
            lambda i: "h" * i,
            lambda i: [1] * i,
        ]
        for make_value in value_makers:
            for i in range(100):
                value_before = make_value(i)
                objectid = ray.put(value_before)
                value_after = ray.get(objectid)
                self.assertEqual(value_before, value_after)
class APITest(unittest.TestCase):
    """Tests for Ray's driver-facing API: class registration, remote
    functions, get/put/wait, reusable variables, functions run on all workers,
    and event logging."""

    def testRegisterClass(self):
        ray.init(start_ray_local=True, num_workers=0)
        # Register teardown right away so it also runs when an assertion fails.
        self.addCleanup(ray.worker.cleanup)

        # Check that putting an object of a class that has not been registered
        # throws an exception.
        class TempClass(object):
            pass
        self.assertRaises(Exception, ray.put, TempClass())
        # Putting the class object Foo itself is also expected to fail here,
        # presumably because its type (`type`) is only registered further below.
        self.assertRaises(Exception, ray.put, Foo)
        # Check that registering a class that Ray cannot serialize efficiently
        # raises an exception.
        self.assertRaises(Exception, ray.register_class, type(True))
        # Check that registering a class with pickle=True works. type(float) is
        # `type`, so afterwards class objects such as float can be put.
        ray.register_class(type(float), pickle=True)
        self.assertEqual(ray.get(ray.put(float)), float)

    def testKeywordArgs(self):
        reload(test_functions)
        ray.init(start_ray_local=True, num_workers=1)
        self.addCleanup(ray.worker.cleanup)

        x = test_functions.keyword_fct1.remote(1)
        self.assertEqual(ray.get(x), "1 hello")
        x = test_functions.keyword_fct1.remote(1, "hi")
        self.assertEqual(ray.get(x), "1 hi")
        x = test_functions.keyword_fct1.remote(1, b="world")
        self.assertEqual(ray.get(x), "1 world")

        x = test_functions.keyword_fct2.remote(a="w", b="hi")
        self.assertEqual(ray.get(x), "w hi")
        x = test_functions.keyword_fct2.remote(b="hi", a="w")
        self.assertEqual(ray.get(x), "w hi")
        x = test_functions.keyword_fct2.remote(a="w")
        self.assertEqual(ray.get(x), "w world")
        x = test_functions.keyword_fct2.remote(b="hi")
        self.assertEqual(ray.get(x), "hello hi")
        x = test_functions.keyword_fct2.remote("w")
        self.assertEqual(ray.get(x), "w world")
        x = test_functions.keyword_fct2.remote("w", "hi")
        self.assertEqual(ray.get(x), "w hi")

        x = test_functions.keyword_fct3.remote(0, 1, c="w", d="hi")
        self.assertEqual(ray.get(x), "0 1 w hi")
        x = test_functions.keyword_fct3.remote(0, 1, d="hi", c="w")
        self.assertEqual(ray.get(x), "0 1 w hi")
        x = test_functions.keyword_fct3.remote(0, 1, c="w")
        self.assertEqual(ray.get(x), "0 1 w world")
        x = test_functions.keyword_fct3.remote(0, 1, d="hi")
        self.assertEqual(ray.get(x), "0 1 hello hi")
        x = test_functions.keyword_fct3.remote(0, 1)
        self.assertEqual(ray.get(x), "0 1 hello world")

    def testVariableNumberOfArgs(self):
        reload(test_functions)
        ray.init(start_ray_local=True, num_workers=1)
        self.addCleanup(ray.worker.cleanup)

        x = test_functions.varargs_fct1.remote(0, 1, 2)
        self.assertEqual(ray.get(x), "0 1 2")
        x = test_functions.varargs_fct2.remote(0, 1, 2)
        self.assertEqual(ray.get(x), "1 2")

        self.assertTrue(test_functions.kwargs_exception_thrown)
        self.assertTrue(test_functions.varargs_and_kwargs_exception_thrown)

    def testNoArgs(self):
        reload(test_functions)
        ray.init(start_ray_local=True, num_workers=1)
        self.addCleanup(ray.worker.cleanup)

        ray.get(test_functions.no_op.remote())

    def testDefiningRemoteFunctions(self):
        ray.init(start_ray_local=True, num_workers=3)
        self.addCleanup(ray.worker.cleanup)

        # Test that we can define a remote function in the shell.
        @ray.remote
        def f(x):
            return x + 1
        self.assertEqual(ray.get(f.remote(0)), 1)

        # Test that we can redefine the remote function. Workers may briefly
        # still run the old definition, so poll until the new one is used.
        @ray.remote
        def f(x):
            return x + 10
        while True:
            val = ray.get(f.remote(0))
            self.assertIn(val, [1, 10])
            if val == 10:
                break
            else:
                print("Still using old definition of f, trying again.")

        # Test that we can close over plain old data.
        data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, {"a": np.zeros(3)}]

        @ray.remote
        def g():
            return data
        ray.get(g.remote())

        # Test that we can close over modules.
        @ray.remote
        def h():
            return np.zeros([3, 5])
        assert_equal(ray.get(h.remote()), np.zeros([3, 5]))

        @ray.remote
        def j():
            return time.time()
        ray.get(j.remote())

        # Test that we can define remote functions that call other remote functions.
        @ray.remote
        def k(x):
            return x + 1

        @ray.remote
        def l(x):
            return ray.get(k.remote(x))

        @ray.remote
        def m(x):
            return ray.get(l.remote(x))
        self.assertEqual(ray.get(k.remote(1)), 2)
        self.assertEqual(ray.get(l.remote(1)), 2)
        self.assertEqual(ray.get(m.remote(1)), 2)

    def testGetMultiple(self):
        ray.init(start_ray_local=True, num_workers=0)
        self.addCleanup(ray.worker.cleanup)

        object_ids = [ray.put(i) for i in range(10)]
        self.assertEqual(ray.get(object_ids), list(range(10)))

    def testWait(self):
        ray.init(start_ray_local=True, num_workers=1)
        self.addCleanup(ray.worker.cleanup)

        @ray.remote
        def f(delay):
            time.sleep(delay)
            return 1

        objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)]
        ready_ids, remaining_ids = ray.wait(objectids)
        self.assertEqual(len(ready_ids), 1)
        self.assertEqual(len(remaining_ids), 3)
        ready_ids, remaining_ids = ray.wait(objectids, num_returns=4)
        self.assertEqual(set(ready_ids), set(objectids))
        self.assertEqual(remaining_ids, [])

        # NOTE(review): timeout looks to be in milliseconds given the 2 s bound
        # below -- confirm. With a single worker the 0.5 s tasks presumably run
        # one after another, so only three finish in time.
        objectids = [f.remote(0.5), f.remote(0.5), f.remote(0.5), f.remote(0.5)]
        start_time = time.time()
        ready_ids, remaining_ids = ray.wait(objectids, timeout=1750, num_returns=4)
        self.assertLess(time.time() - start_time, 2)
        self.assertEqual(len(ready_ids), 3)
        self.assertEqual(len(remaining_ids), 1)
        # Block until at least one of these is ready before submitting more
        # tasks (the result is not used).
        ray.wait(objectids)

        objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), f.remote(0.5)]
        start_time = time.time()
        ready_ids, remaining_ids = ray.wait(objectids, timeout=5000)
        self.assertLess(time.time() - start_time, 5)
        self.assertEqual(len(ready_ids), 1)
        self.assertEqual(len(remaining_ids), 3)

    def testCachingReusables(self):
        # Test that we can define reusable variables before the driver is connected.
        def foo_initializer():
            return 1

        def bar_initializer():
            return []

        def bar_reinitializer(bar):
            return []

        ray.reusables.foo = ray.Reusable(foo_initializer)
        ray.reusables.bar = ray.Reusable(bar_initializer, bar_reinitializer)

        @ray.remote
        def use_foo():
            return ray.reusables.foo

        @ray.remote
        def use_bar():
            ray.reusables.bar.append(1)
            return ray.reusables.bar

        ray.init(start_ray_local=True, num_workers=2)
        self.addCleanup(ray.worker.cleanup)

        self.assertEqual(ray.get(use_foo.remote()), 1)
        self.assertEqual(ray.get(use_foo.remote()), 1)
        self.assertEqual(ray.get(use_bar.remote()), [1])
        self.assertEqual(ray.get(use_bar.remote()), [1])

    def testCachingFunctionsToRun(self):
        # Test that we export functions to run on all workers before the driver
        # is connected. The name f is deliberately reused for several functions.
        def f(worker_info):
            sys.path.append(1)
        ray.worker.global_worker.run_function_on_all_workers(f)

        def f(worker_info):
            sys.path.append(2)
        ray.worker.global_worker.run_function_on_all_workers(f)

        def g(worker_info):
            sys.path.append(3)
        ray.worker.global_worker.run_function_on_all_workers(g)

        def f(worker_info):
            sys.path.append(4)
        ray.worker.global_worker.run_function_on_all_workers(f)

        ray.init(start_ray_local=True, num_workers=2)
        self.addCleanup(ray.worker.cleanup)

        @ray.remote
        def get_state():
            time.sleep(1)
            return sys.path[-4], sys.path[-3], sys.path[-2], sys.path[-1]

        res1 = get_state.remote()
        res2 = get_state.remote()
        self.assertEqual(ray.get(res1), (1, 2, 3, 4))
        self.assertEqual(ray.get(res2), (1, 2, 3, 4))

        # Clean up the path on the workers.
        def f(worker_info):
            sys.path.pop()
            sys.path.pop()
            sys.path.pop()
            sys.path.pop()
        ray.worker.global_worker.run_function_on_all_workers(f)

    def testRunningFunctionOnAllWorkers(self):
        ray.init(start_ray_local=True, num_workers=1)
        self.addCleanup(ray.worker.cleanup)

        def f(worker_info):
            sys.path.append("fake_directory")
        ray.worker.global_worker.run_function_on_all_workers(f)

        @ray.remote
        def get_path1():
            return sys.path
        self.assertEqual("fake_directory", ray.get(get_path1.remote())[-1])

        def f(worker_info):
            sys.path.pop(-1)
        ray.worker.global_worker.run_function_on_all_workers(f)

        # Create a second remote function to guarantee that when we call
        # get_path2.remote(), the second function to run will have been run on
        # the worker.
        @ray.remote
        def get_path2():
            return sys.path
        self.assertNotIn("fake_directory", ray.get(get_path2.remote()))

    def testPassingInfoToAllWorkers(self):
        ray.init(start_ray_local=True, num_workers=10)
        self.addCleanup(ray.worker.cleanup)

        def f(worker_info):
            sys.path.append(worker_info)
        ray.worker.global_worker.run_function_on_all_workers(f)

        @ray.remote
        def get_path():
            time.sleep(1)
            return sys.path

        # Retrieve the values that we stored in the worker paths.
        paths = ray.get([get_path.remote() for _ in range(10)])
        # Add the driver's path to the list.
        paths.append(sys.path)
        worker_infos = [path[-1] for path in paths]
        for worker_info in worker_infos:
            self.assertEqual(list(worker_info.keys()), ["counter"])
        counters = [worker_info["counter"] for worker_info in worker_infos]
        # We use range(11) because the driver also runs the function.
        self.assertEqual(set(counters), set(range(11)))

        # Clean up the worker paths.
        def f(worker_info):
            sys.path.pop(-1)
        ray.worker.global_worker.run_function_on_all_workers(f)

    def testLoggingAPI(self):
        ray.init(start_ray_local=True, num_workers=1)
        self.addCleanup(ray.worker.cleanup)

        def events():
            # This is a hack for getting the event log. It is not part of the API.
            keys = ray.worker.global_worker.redis_client.keys("event_log:*")
            return [ray.worker.global_worker.redis_client.lrange(key, 0, -1) for key in keys]

        def wait_for_num_events(num_events, timeout=10):
            # Poll until at least num_events event-log keys exist or timeout
            # seconds pass. On timeout only a message is printed; the assertion
            # that follows each call reports the failure.
            start_time = time.time()
            while time.time() - start_time < timeout:
                if len(events()) >= num_events:
                    return
                time.sleep(0.1)
            print("Timing out of wait.")

        @ray.remote
        def test_log_event():
            ray.log_event("event_type1", contents={"key": "val"})

        @ray.remote
        def test_log_span():
            with ray.log_span("event_type2", contents={"key": "val"}):
                pass

        # Make sure that we can call ray.log_event in a remote function.
        ray.get(test_log_event.remote())
        # Wait for the event to appear in the event log.
        wait_for_num_events(1)
        self.assertEqual(len(events()), 1)
        # Make sure that we can call ray.log_span in a remote function.
        ray.get(test_log_span.remote())
        # Wait for the events to appear in the event log.
        wait_for_num_events(2)
        self.assertEqual(len(events()), 2)

        @ray.remote
        def test_log_span_exception():
            with ray.log_span("event_type2", contents={"key": "val"}):
                raise Exception("This failed.")

        # Make sure that logging a span works if an exception is thrown.
        test_log_span_exception.remote()
        # Wait for the events to appear in the event log.
        wait_for_num_events(3)
        self.assertEqual(len(events()), 3)
2016-08-15 11:02:54 -07:00
class PythonModeTest(unittest.TestCase):
    """Tests for running the driver with driver_mode=ray.PYTHON_MODE."""

    def testPythonMode(self):
        reload(test_functions)
        ray.init(start_ray_local=True, driver_mode=ray.PYTHON_MODE)
        # Register teardown right away so it also runs when an assertion fails.
        self.addCleanup(ray.worker.cleanup)

        @ray.remote
        def f():
            return np.ones([3, 4, 5])
        xref = f.remote()
        # Remote functions should return by value.
        assert_equal(xref, np.ones([3, 4, 5]))
        # ray.get should be the identity.
        assert_equal(xref, ray.get(xref))
        y = np.random.normal(size=[11, 12])
        # ray.put should be the identity.
        assert_equal(y, ray.put(y))

        # Make sure objects are immutable; this example is why we need to copy
        # arguments before passing them into remote functions in python mode.
        aref = test_functions.python_mode_f.remote()
        assert_equal(aref, np.array([0, 0]))
        bref = test_functions.python_mode_g.remote(aref)
        # python_mode_g should not mutate aref.
        assert_equal(aref, np.array([0, 0]))
        assert_equal(bref, np.array([1, 0]))

    def testReusableVariablesInPythonMode(self):
        reload(test_functions)
        ray.init(start_ray_local=True, driver_mode=ray.PYTHON_MODE)
        self.addCleanup(ray.worker.cleanup)

        def l_init():
            return []

        def l_reinit(l):
            return []

        ray.reusables.l = ray.Reusable(l_init, l_reinit)

        @ray.remote
        def use_l():
            l = ray.reusables.l
            l.append(1)
            return l

        # Get the local copy of the reusable variable. This should be stateful.
        l = ray.reusables.l
        assert_equal(l, [])
        # Make sure the remote function does what we expect.
        assert_equal(ray.get(use_l.remote()), [1])
        assert_equal(ray.get(use_l.remote()), [1])
        # Make sure the local copy of the reusable variable has not been mutated.
        assert_equal(l, [])
        l = ray.reusables.l
        assert_equal(l, [])
        # Make sure that running a remote function does not reset the state of
        # the local copy of the reusable variable.
        l.append(2)
        assert_equal(ray.get(use_l.remote()), [1])
        assert_equal(l, [2])
2016-08-15 11:02:54 -07:00
class ReusablesTest(unittest.TestCase):
    """Tests for ray.Reusable variables on workers and on the driver."""

    def testReusables(self):
        ray.init(start_ray_local=True, num_workers=1)
        # Register teardown right away so it also runs when an assertion fails.
        self.addCleanup(ray.worker.cleanup)

        # Test that we can add a variable to the key-value store.
        def foo_initializer():
            return 1

        def foo_reinitializer(foo):
            return foo

        ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer)
        self.assertEqual(ray.reusables.foo, 1)

        @ray.remote
        def use_foo():
            return ray.reusables.foo
        self.assertEqual(ray.get(use_foo.remote()), 1)
        self.assertEqual(ray.get(use_foo.remote()), 1)
        self.assertEqual(ray.get(use_foo.remote()), 1)

        # Test that we can add a variable to the key-value store, mutate it,
        # and reset it.
        def bar_initializer():
            return [1, 2, 3]

        ray.reusables.bar = ray.Reusable(bar_initializer)

        @ray.remote
        def use_bar():
            ray.reusables.bar.append(4)
            return ray.reusables.bar
        self.assertEqual(ray.get(use_bar.remote()), [1, 2, 3, 4])
        self.assertEqual(ray.get(use_bar.remote()), [1, 2, 3, 4])
        self.assertEqual(ray.get(use_bar.remote()), [1, 2, 3, 4])

        # Test that we can use the reinitializer.
        def baz_initializer():
            return np.zeros([4])

        def baz_reinitializer(baz):
            # Zero the array in place and hand the same object back.
            baz[:] = 0
            return baz

        ray.reusables.baz = ray.Reusable(baz_initializer, baz_reinitializer)

        @ray.remote
        def use_baz(i):
            baz = ray.reusables.baz
            baz[i] = 1
            return baz
        assert_equal(ray.get(use_baz.remote(0)), np.array([1, 0, 0, 0]))
        assert_equal(ray.get(use_baz.remote(1)), np.array([0, 1, 0, 0]))
        assert_equal(ray.get(use_baz.remote(2)), np.array([0, 0, 1, 0]))
        assert_equal(ray.get(use_baz.remote(3)), np.array([0, 0, 0, 1]))

        # Make sure the reinitializer is actually getting called. Note that this
        # is not the correct usage of a reinitializer because it does not reset
        # qux to its original state. This is just for testing.
        def qux_initializer():
            return 0

        def qux_reinitializer(x):
            return x + 1

        ray.reusables.qux = ray.Reusable(qux_initializer, qux_reinitializer)

        @ray.remote
        def use_qux():
            return ray.reusables.qux
        self.assertEqual(ray.get(use_qux.remote()), 0)
        self.assertEqual(ray.get(use_qux.remote()), 1)
        self.assertEqual(ray.get(use_qux.remote()), 2)

    def testUsingReusablesOnDriver(self):
        ray.init(start_ray_local=True, num_workers=1)
        self.addCleanup(ray.worker.cleanup)

        # Test that we can add a variable to the key-value store.
        def foo_initializer():
            return []

        def foo_reinitializer(foo):
            return []

        ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer)

        @ray.remote
        def use_foo():
            foo = ray.reusables.foo
            foo.append(1)
            return foo

        # Check that running a remote function does not reset the reusable
        # variable on the driver.
        foo = ray.reusables.foo
        self.assertEqual(foo, [])
        foo.append(2)
        self.assertEqual(foo, [2])
        foo.append(3)
        self.assertEqual(foo, [2, 3])
        self.assertEqual(ray.get(use_foo.remote()), [1])
        self.assertEqual(ray.get(use_foo.remote()), [1])
        self.assertEqual(ray.get(use_foo.remote()), [1])
        # Check that the copy of foo on the driver has not changed.
        self.assertEqual(foo, [2, 3])
        foo = ray.reusables.foo
        self.assertEqual(foo, [2, 3])
2016-12-23 00:43:16 -08:00
class UtilsTest(unittest.TestCase):
    """Tests for helpers in ray.experimental."""

    def testCopyingDirectory(self):
        # The functionality being tested here is really multi-node
        # functionality, but this test just uses a single node.
        ray.init(start_ray_local=True, num_workers=1)
        # Register teardown right away so it also runs when an assertion fails.
        self.addCleanup(ray.worker.cleanup)

        source_text = "hello world"
        temp_dir1 = os.path.join(os.path.dirname(__file__), "temp_dir1")
        source_dir = os.path.join(temp_dir1, "dir")
        source_file = os.path.join(source_dir, "file.txt")
        temp_dir2 = os.path.join(os.path.dirname(__file__), "temp_dir2")
        target_dir = os.path.join(temp_dir2, "dir")
        target_file = os.path.join(target_dir, "file.txt")

        def remove_temporary_files():
            for temp_dir in (temp_dir1, temp_dir2):
                if os.path.exists(temp_dir):
                    shutil.rmtree(temp_dir)

        # Remove the relevant files if they are left over from a previous run
        # of this test, and remove them again when this test finishes, pass or
        # fail. Cleanups run last-in first-out, so the files are removed before
        # ray.worker.cleanup runs.
        remove_temporary_files()
        self.addCleanup(remove_temporary_files)

        # Create the source files.
        os.mkdir(temp_dir1)
        os.mkdir(source_dir)
        with open(source_file, "w") as f:
            f.write(source_text)
        # Copy the source directory to the target directory.
        ray.experimental.copy_directory(source_dir, target_dir)
        # NOTE(review): fixed delay, presumably because the copy completes
        # asynchronously -- confirm.
        time.sleep(0.5)
        # Check that the target files exist and are the same as the source files.
        self.assertTrue(os.path.exists(target_dir))
        self.assertTrue(os.path.exists(target_file))
        with open(target_file, "r") as f:
            self.assertEqual(f.read(), source_text)
2016-06-24 19:43:24 -07:00
if __name__ == "__main__":
    # Run every TestCase in this module; verbosity=2 prints each test's name
    # and result.
    unittest.main(module="__main__", verbosity=2)