mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
parent
99ed623371
commit
c03b937b95
2 changed files with 4 additions and 1 deletion
@@ -115,6 +115,8 @@ train.world_size

.. autofunction:: ray.train.world_size

.. _train-api-torch-utils:

PyTorch Training Function Utilities
-----------------------------------
@@ -49,7 +49,8 @@ There are 3 primary API differences between Ray SGD v1 and Ray Train.
    results = trainer.run(train_func_distributed)
    trainer.shutdown()

Currently, this means you are responsible for modifying your code to support distributed training (wrapping your model in ``DistributedDataParallel`` for ``torch``, or using ``MultiWorkerMirroredStrategy`` for ``tensorflow``), rather than having this handled automatically. We plan to provide utilities that handle these recipes for you.

If you are using PyTorch, you can use the :ref:`train-api-torch-utils` to automatically prepare your model and data loaders for distributed training.
If you are using TensorFlow, you must add ``MultiWorkerMirroredStrategy`` to your model inside the training function yourself; this is not done automatically.

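As a library-free sketch of where this responsibility now sits, the stub below shows a training function that applies the distributed wrapper itself. ``prepare_distributed`` is a hypothetical stand-in for wrapping a model in PyTorch's ``DistributedDataParallel`` or building it under a ``MultiWorkerMirroredStrategy`` scope; it is not a Ray Train API.

```python
# Library-free sketch: in Ray Train, the user's training function is
# responsible for applying the distributed wrapper itself.
# `prepare_distributed` is a hypothetical stand-in for wrapping a model in
# torch's DistributedDataParallel or building it under TensorFlow's
# MultiWorkerMirroredStrategy scope.

def prepare_distributed(model):
    # Real code would wrap/rebuild the model for multi-worker training here.
    return {"distributed": True, "model": model}

def train_func():
    model = "my_model"                  # placeholder for a real model object
    model = prepare_distributed(model)  # user code now performs this step
    return model

print(train_func())
```

The point of the sketch is only the placement: the wrapping happens inside the training function, in user code, not inside the trainer.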
3. Rather than iteratively calling ``trainer.train()`` or ``trainer.validate()`` for each epoch, in Ray Train the training function defines the full training execution and is run via ``trainer.run(train_func)``.
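The control-flow shift can be sketched without any Ray dependency. ``StubTrainer`` below is a hypothetical stand-in for Ray Train's trainer: rather than the caller invoking ``trainer.train()`` once per epoch, the training function owns the full epoch loop and is executed once via ``run()``.

```python
# Library-free sketch of the control-flow shift. `StubTrainer` is a
# hypothetical stand-in for Ray Train's Trainer: the training function owns
# the full loop over epochs and is executed once via run().

class StubTrainer:
    def run(self, train_func):
        # Ray Train would execute train_func on each distributed worker;
        # this stub simply calls it once locally.
        return train_func()

def train_func():
    losses = []
    for epoch in range(3):
        losses.append(1.0 / (epoch + 1))  # placeholder for a training step
    return losses

results = StubTrainer().run(train_func)
print(results)
```

In the v1 API the epoch loop lived outside the trainer; here the loop has moved inside ``train_func``, which is the structural change this item describes.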