Allow ResNet example to run on multiple machines. (#891)

* Allow a redis address to be passed into the ResNet example.

* Update documentation.
This commit is contained in:
Robert Nishihara 2017-08-29 21:37:53 -07:00 committed by Philipp Moritz
parent 164a8f368e
commit 4b76335157
2 changed files with 17 additions and 6 deletions

View file

@ -39,15 +39,20 @@ Then run the training script that matches the dataset you downloaded.
--dataset=cifar100 \
--num_gpus=1
The script will print out the IP address that the log files are stored on. In the single-node case,
you can ignore this and run tensorboard on the current machine.
To run the training script on a cluster with multiple machines, you will need
to also pass in the flag ``--redis-address=<redis_address>``, where
``<redis-address>`` is the address of the Redis server on the head node.
The script will print out the IP address that the log files are stored on. In
the single-node case, you can ignore this and run tensorboard on the current
machine.
.. code-block:: bash
python -m tensorflow.tensorboard --logdir=/tmp/resnet-model
If you are running Ray on multiple nodes, you will need to go to the node at the IP address printed, and
run the command.
If you are running Ray on multiple nodes, you will need to go to the node at the
IP address printed, and run the command.
The core of the script is the actor definition.
@ -76,7 +81,8 @@ The core of the script is the actor definition.
self.model.variables.sess.run(self.model.train_op)
return self.model.variables.get_weights()
The main script first creates one actor for each GPU, or a single actor if `num_gpus` is zero.
The main script first creates one actor for each GPU, or a single actor if
``num_gpus`` is zero.
.. code-block:: python

View file

@ -35,6 +35,8 @@ parser.add_argument("--eval_batch_count", default=50, type=int,
help="Number of batches to evaluate over.")
parser.add_argument("--num_gpus", default=0, type=int,
help="Number of GPUs to use for training.")
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
FLAGS = parser.parse_args()
@ -197,7 +199,10 @@ class ResNetTestActor(object):
def train():
num_gpus = FLAGS.num_gpus
ray.init(num_gpus=num_gpus, redirect_output=True)
if FLAGS.redis_address is None:
ray.init(num_gpus=num_gpus, redirect_output=True)
else:
ray.init(redis_address=FLAGS.redis_address)
train_data = get_data.remote(FLAGS.train_data_path, 50000, FLAGS.dataset)
test_data = get_data.remote(FLAGS.eval_data_path, 10000, FLAGS.dataset)
# Creates an actor for each gpu, or one if only using the cpu. Each actor