mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Implement repr, hash, and richcompare for ObjectIDs. (#33)
* Implement repr, hash, and richcompare for ObjectIDs. * Addressing comments. * Partially fix example applications.
This commit is contained in:
parent
9d1e750e8f
commit
336a904404
8 changed files with 127 additions and 87 deletions
|
@ -8,8 +8,6 @@ import alexnet
|
|||
|
||||
# Arguments to specify where the imagenet data is stored.
|
||||
parser = argparse.ArgumentParser(description="Run the AlexNet example.")
|
||||
parser.add_argument("--node-ip-address", default=None, type=str, help="The IP address of this node.")
|
||||
parser.add_argument("--scheduler-address", default=None, type=str, help="The address of the scheduler.")
|
||||
parser.add_argument("--s3-bucket", required=True, type=str, help="Name of the bucket that contains the image data.")
|
||||
parser.add_argument("--key-prefix", default="ILSVRC2012_img_train/n015", type=str, help="Prefix for files to fetch.")
|
||||
parser.add_argument("--label-file", default="train.txt", type=str, help="File containing labels.")
|
||||
|
@ -17,13 +15,7 @@ parser.add_argument("--label-file", default="train.txt", type=str, help="File co
|
|||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
# If node_ip_address and scheduler_address are provided, then this command
|
||||
# will connect the driver to the existing scheduler. If not, it will start
|
||||
# a local scheduler and connect to it.
|
||||
ray.init(start_ray_local=(args.node_ip_address is None),
|
||||
node_ip_address=args.node_ip_address,
|
||||
scheduler_address=args.scheduler_address,
|
||||
num_workers=(10 if args.node_ip_address is None else None))
|
||||
ray.init(start_ray_local=True, num_workers=10)
|
||||
|
||||
# Note we do not do sess.run(tf.initialize_all_variables()) because that would
|
||||
# result in a different initialization on each worker. Instead, we initialize
|
||||
|
|
|
@ -10,21 +10,13 @@ from tensorflow.examples.tutorials.mnist import input_data
|
|||
import hyperopt
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run the hyperparameter optimization example.")
|
||||
parser.add_argument("--node-ip-address", default=None, type=str, help="The IP address of this node.")
|
||||
parser.add_argument("--scheduler-address", default=None, type=str, help="The address of the scheduler.")
|
||||
parser.add_argument("--trials", default=2, type=int, help="The number of random trials to do.")
|
||||
parser.add_argument("--steps", default=10, type=int, help="The number of steps of training to do per network.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
# If node_ip_address and scheduler_address are provided, then this command
|
||||
# will connect the driver to the existing scheduler. If not, it will start
|
||||
# a local scheduler and connect to it.
|
||||
ray.init(start_ray_local=(args.node_ip_address is None),
|
||||
node_ip_address=args.node_ip_address,
|
||||
scheduler_address=args.scheduler_address,
|
||||
num_workers=(10 if args.node_ip_address is None else None))
|
||||
ray.init(start_ray_local=True, num_workers=10)
|
||||
|
||||
# The number of sets of random hyperparameters to try.
|
||||
trials = args.trials
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import ray
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import scipy.optimize
|
||||
|
@ -7,20 +6,8 @@ import tensorflow as tf
|
|||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run the L-BFGS example.")
|
||||
parser.add_argument("--node-ip-address", default=None, type=str, help="The IP address of this node.")
|
||||
parser.add_argument("--scheduler-address", default=None, type=str, help="The address of the scheduler.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
# If node_ip_address and scheduler_address are provided, then this command
|
||||
# will connect the driver to the existing scheduler. If not, it will start
|
||||
# a local scheduler and connect to it.
|
||||
ray.init(start_ray_local=(args.node_ip_address is None),
|
||||
node_ip_address=args.node_ip_address,
|
||||
scheduler_address=args.scheduler_address,
|
||||
num_workers=(10 if args.node_ip_address is None else None))
|
||||
ray.init(start_ray_local=True, num_workers=10)
|
||||
|
||||
# Define the dimensions of the data and of the model.
|
||||
image_dimension = 784
|
||||
|
|
|
@ -4,14 +4,9 @@
|
|||
import numpy as np
|
||||
import cPickle as pickle
|
||||
import ray
|
||||
import argparse
|
||||
|
||||
import gym
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run the Pong example.")
|
||||
parser.add_argument("--node-ip-address", default=None, type=str, help="The IP address of this node.")
|
||||
parser.add_argument("--scheduler-address", default=None, type=str, help="The address of the scheduler.")
|
||||
|
||||
# hyperparameters
|
||||
H = 200 # number of hidden layer neurons
|
||||
batch_size = 10 # every how many episodes to do a param update?
|
||||
|
@ -113,15 +108,7 @@ def compute_gradient(model):
|
|||
return policy_backward(eph, epx, epdlogp, model), reward_sum
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
# If node_ip_address and scheduler_address are provided, then this command
|
||||
# will connect the driver to the existing scheduler. If not, it will start
|
||||
# a local scheduler and connect to it.
|
||||
ray.init(start_ray_local=(args.node_ip_address is None),
|
||||
node_ip_address=args.node_ip_address,
|
||||
scheduler_address=args.scheduler_address,
|
||||
num_workers=(10 if args.node_ip_address is None else None))
|
||||
ray.init(start_ray_local=True, num_workers=10)
|
||||
|
||||
# Run the reinforcement learning
|
||||
running_reward = None
|
||||
|
|
|
@ -113,7 +113,7 @@ class RayGetError(Exception):
|
|||
|
||||
def __str__(self):
|
||||
"""Format a RayGetError as a string."""
|
||||
return "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}".format(self.objectid.id(), colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
|
||||
return "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}".format(self.objectid, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
|
||||
|
||||
class RayGetArgumentError(Exception):
|
||||
"""An exception used when a task's argument was produced by a failed task.
|
||||
|
@ -136,7 +136,7 @@ class RayGetArgumentError(Exception):
|
|||
|
||||
def __str__(self):
|
||||
"""Format a RayGetArgumentError as a string."""
|
||||
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid.id(), self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
|
||||
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
|
||||
|
||||
|
||||
class Reusable(object):
|
||||
|
@ -1008,6 +1008,8 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
|||
object_id_strs = [object_id.id() for object_id in object_ids]
|
||||
timeout = timeout if timeout is not None else 2 ** 36
|
||||
ready_ids, remaining_ids = worker.plasma_client.wait(object_id_strs, timeout, num_returns)
|
||||
ready_ids = [photon.ObjectID(object_id) for object_id in ready_ids]
|
||||
remaining_ids = [photon.ObjectID(object_id) for object_id in remaining_ids]
|
||||
return ready_ids, remaining_ids
|
||||
|
||||
def format_error_message(exception_message):
|
||||
|
|
|
@ -51,6 +51,65 @@ static PyObject *PyObjectID_id(PyObject *self) {
|
|||
sizeof(object_id));
|
||||
}
|
||||
|
||||
static PyObject *PyObjectID_richcompare(PyObjectID *self,
|
||||
PyObject *other,
|
||||
int op) {
|
||||
PyObject *result = NULL;
|
||||
if (Py_TYPE(self)->tp_richcompare != Py_TYPE(other)->tp_richcompare) {
|
||||
result = Py_NotImplemented;
|
||||
} else {
|
||||
PyObjectID *other_id = (PyObjectID *) other;
|
||||
switch (op) {
|
||||
case Py_LT:
|
||||
result = Py_NotImplemented;
|
||||
break;
|
||||
case Py_LE:
|
||||
result = Py_NotImplemented;
|
||||
break;
|
||||
case Py_EQ:
|
||||
result = object_ids_equal(self->object_id, other_id->object_id)
|
||||
? Py_True
|
||||
: Py_False;
|
||||
break;
|
||||
case Py_NE:
|
||||
result = !object_ids_equal(self->object_id, other_id->object_id)
|
||||
? Py_True
|
||||
: Py_False;
|
||||
break;
|
||||
case Py_GT:
|
||||
result = Py_NotImplemented;
|
||||
break;
|
||||
case Py_GE:
|
||||
result = Py_NotImplemented;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Py_XINCREF(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
static long PyObjectID_hash(PyObjectID *self) {
|
||||
PyObject *tuple = PyTuple_New(UNIQUE_ID_SIZE);
|
||||
for (int i = 0; i < UNIQUE_ID_SIZE; ++i) {
|
||||
PyTuple_SetItem(tuple, i, PyInt_FromLong(self->object_id.id[i]));
|
||||
}
|
||||
long hash = PyObject_Hash(tuple);
|
||||
Py_XDECREF(tuple);
|
||||
return hash;
|
||||
}
|
||||
|
||||
static PyObject *PyObjectID_repr(PyObjectID *self) {
|
||||
int hex_length = 2 * UNIQUE_ID_SIZE + 1;
|
||||
char hex_id[hex_length];
|
||||
sha1_to_hex(self->object_id.id, hex_id);
|
||||
UT_string *repr;
|
||||
utstring_new(repr);
|
||||
utstring_printf(repr, "ObjectID(%s)", hex_id);
|
||||
PyObject *result = PyString_FromString(utstring_body(repr));
|
||||
utstring_free(repr);
|
||||
return result;
|
||||
}
|
||||
|
||||
static PyObject *PyObjectID___reduce__(PyObjectID *self) {
|
||||
PyErr_SetString(CommonError, "ObjectID objects cannot be serialized.");
|
||||
return NULL;
|
||||
|
@ -70,44 +129,44 @@ static PyMemberDef PyObjectID_members[] = {
|
|||
};
|
||||
|
||||
PyTypeObject PyObjectIDType = {
|
||||
PyObject_HEAD_INIT(NULL) 0, /* ob_size */
|
||||
"common.ObjectID", /* tp_name */
|
||||
sizeof(PyObjectID), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
0, /* tp_dealloc */
|
||||
0, /* tp_print */
|
||||
0, /* tp_getattr */
|
||||
0, /* tp_setattr */
|
||||
0, /* tp_compare */
|
||||
0, /* tp_repr */
|
||||
0, /* tp_as_number */
|
||||
0, /* tp_as_sequence */
|
||||
0, /* tp_as_mapping */
|
||||
0, /* tp_hash */
|
||||
0, /* tp_call */
|
||||
0, /* tp_str */
|
||||
0, /* tp_getattro */
|
||||
0, /* tp_setattro */
|
||||
0, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
||||
"ObjectID object", /* tp_doc */
|
||||
0, /* tp_traverse */
|
||||
0, /* tp_clear */
|
||||
0, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
0, /* tp_iter */
|
||||
0, /* tp_iternext */
|
||||
PyObjectID_methods, /* tp_methods */
|
||||
PyObjectID_members, /* tp_members */
|
||||
0, /* tp_getset */
|
||||
0, /* tp_base */
|
||||
0, /* tp_dict */
|
||||
0, /* tp_descr_get */
|
||||
0, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
(initproc) PyObjectID_init, /* tp_init */
|
||||
0, /* tp_alloc */
|
||||
PyType_GenericNew, /* tp_new */
|
||||
PyObject_HEAD_INIT(NULL) 0, /* ob_size */
|
||||
"common.ObjectID", /* tp_name */
|
||||
sizeof(PyObjectID), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
0, /* tp_dealloc */
|
||||
0, /* tp_print */
|
||||
0, /* tp_getattr */
|
||||
0, /* tp_setattr */
|
||||
0, /* tp_compare */
|
||||
(reprfunc) PyObjectID_repr, /* tp_repr */
|
||||
0, /* tp_as_number */
|
||||
0, /* tp_as_sequence */
|
||||
0, /* tp_as_mapping */
|
||||
(hashfunc) PyObjectID_hash, /* tp_hash */
|
||||
0, /* tp_call */
|
||||
0, /* tp_str */
|
||||
0, /* tp_getattro */
|
||||
0, /* tp_setattro */
|
||||
0, /* tp_as_buffer */
|
||||
Py_TPFLAGS_DEFAULT, /* tp_flags */
|
||||
"ObjectID object", /* tp_doc */
|
||||
0, /* tp_traverse */
|
||||
0, /* tp_clear */
|
||||
(richcmpfunc) PyObjectID_richcompare, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
0, /* tp_iter */
|
||||
0, /* tp_iternext */
|
||||
PyObjectID_methods, /* tp_methods */
|
||||
PyObjectID_members, /* tp_members */
|
||||
0, /* tp_getset */
|
||||
0, /* tp_base */
|
||||
0, /* tp_dict */
|
||||
0, /* tp_descr_get */
|
||||
0, /* tp_descr_set */
|
||||
0, /* tp_dictoffset */
|
||||
(initproc) PyObjectID_init, /* tp_init */
|
||||
0, /* tp_alloc */
|
||||
PyType_GenericNew, /* tp_new */
|
||||
};
|
||||
|
||||
/* Define the PyTask class. */
|
||||
|
|
|
@ -70,6 +70,27 @@ class TestObjectID(unittest.TestCase):
|
|||
self.assertRaises(Exception, lambda : pickling.dumps(g))
|
||||
self.assertRaises(Exception, lambda : pickling.dumps(h))
|
||||
|
||||
def test_equality_comparisons(self):
|
||||
x1 = common.ObjectID(20 * "a")
|
||||
x2 = common.ObjectID(20 * "a")
|
||||
y1 = common.ObjectID(20 * "b")
|
||||
y2 = common.ObjectID(20 * "b")
|
||||
self.assertEqual(x1, x2)
|
||||
self.assertEqual(y1, y2)
|
||||
self.assertNotEqual(x1, y1)
|
||||
|
||||
object_ids1 = [common.ObjectID(20 * chr(i)) for i in range(256)]
|
||||
object_ids2 = [common.ObjectID(20 * chr(i)) for i in range(256)]
|
||||
self.assertEqual(len(set(object_ids1)), 256)
|
||||
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
|
||||
self.assertEqual(set(object_ids1), set(object_ids2))
|
||||
|
||||
def test_hashability(self):
|
||||
x = common.ObjectID(20 * "a")
|
||||
y = common.ObjectID(20 * "b")
|
||||
{x: y}
|
||||
set([x, y])
|
||||
|
||||
class TestTask(unittest.TestCase):
|
||||
|
||||
def test_create_task(self):
|
||||
|
|
|
@ -326,7 +326,7 @@ class APITest(unittest.TestCase):
|
|||
self.assertEqual(len(ready_ids), 1)
|
||||
self.assertEqual(len(remaining_ids), 3)
|
||||
ready_ids, remaining_ids = ray.wait(objectids, num_returns=4)
|
||||
self.assertEqual(set(ready_ids), set([object_id.id() for object_id in objectids]))
|
||||
self.assertEqual(set(ready_ids), set(objectids))
|
||||
self.assertEqual(remaining_ids, [])
|
||||
|
||||
objectids = [f.remote(0.5), f.remote(0.5), f.remote(0.5), f.remote(0.5)]
|
||||
|
|
Loading…
Add table
Reference in a new issue