[Train] Minor migration guide update (#20683)

* update docs

* tf
Amog Kamsetty 2021-11-29 12:42:28 -08:00 committed by GitHub
parent 99ed623371
commit c03b937b95
2 changed files with 4 additions and 1 deletion


@@ -115,6 +115,8 @@ train.world_size
 .. autofunction:: ray.train.world_size
 
+.. _train-api-torch-utils:
+
 PyTorch Training Function Utilities
 -----------------------------------
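
For reference, a minimal sketch of a training function that uses these utilities (assuming the ``ray.train.torch.prepare_model`` and ``ray.train.torch.prepare_data_loader`` helpers documented in this section; the toy model and data are placeholders):

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    import ray.train.torch

    def train_func():
        # Toy model and data for illustration only.
        model = nn.Linear(4, 1)
        dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
        loader = DataLoader(dataset, batch_size=8)

        # prepare_model wraps the model in DistributedDataParallel and moves
        # it to the right device; prepare_data_loader adds a DistributedSampler
        # so each worker sees a distinct shard of the data.
        model = ray.train.torch.prepare_model(model)
        loader = ray.train.torch.prepare_data_loader(loader)

        loss_fn = nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for _ in range(2):
            for X, y in loader:
                optimizer.zero_grad()
                loss_fn(model(X), y).backward()
                optimizer.step()

This ``train_func`` would then be passed to ``trainer.run``, exactly as in the migration example below.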


@@ -49,7 +49,8 @@ There are 3 primary API differences between Ray SGD v1 and Ray Train.
 results = trainer.run(train_func_distributed)
 trainer.shutdown()
-Currently, this means that you are now responsible for modifying your code to support distributed training (specifying ``DistributedDataParallel`` for ``torch`` or ``MultiWorkerMirroredStrategy`` for ``tensorflow``) as opposed to having this be automatically handled internally. However, we have plans to provide utilities that you can use to automatically handle these recipes for you.
+If you are using PyTorch, you can use the :ref:`train-api-torch-utils` to automatically prepare your model and data loaders for distributed training.
+If you are using TensorFlow, you need to apply ``MultiWorkerMirroredStrategy`` in your training function yourself; this is not done automatically (see the sketch after this list).
 3. Rather than iteratively calling ``trainer.train()`` or ``trainer.validate()`` for each epoch, in Ray Train the training function defines the full training execution and is run via ``trainer.run(train_func)``.
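
As referenced above, a minimal TensorFlow sketch of points 2 and 3 (assuming the ``Trainer`` API shown in the snippet above, and that Ray Train sets ``TF_CONFIG`` for each worker; the model, data, and worker count are placeholders):

.. code-block:: python

    import numpy as np
    import tensorflow as tf
    from ray.train import Trainer

    def train_func():
        # Built inside the training function on every worker; the strategy
        # discovers its peers through the TF_CONFIG environment variable.
        strategy = tf.distribute.MultiWorkerMirroredStrategy()
        with strategy.scope():
            # Model construction and compilation must happen in scope.
            model = tf.keras.Sequential(
                [tf.keras.layers.Dense(1, input_shape=(4,))]
            )
            model.compile(optimizer="sgd", loss="mse")

        # Toy data; real code would shard a real dataset.
        X = np.random.rand(32, 4).astype("float32")
        y = np.random.rand(32, 1).astype("float32")

        # The full training loop runs here, rather than via repeated
        # trainer.train() / trainer.validate() calls.
        model.fit(X, y, epochs=2, verbose=0)

    trainer = Trainer(backend="tensorflow", num_workers=2)
    trainer.start()
    results = trainer.run(train_func)  # runs train_func on every worker
    trainer.shutdown()

Note that the entire loop (``model.fit``) lives inside ``train_func``, which ``trainer.run`` executes once on each worker.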