[Train] Decorate get_device with PublicAPI (#22024)

* Decorate `get_device` with `PublicAPI`

* Add documentation

* Update api.rst
This commit is contained in:
Balaji Veeramani 2022-02-01 08:18:47 -08:00 committed by GitHub
parent b51b5afaea
commit 7dcb0b6af6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -210,6 +210,7 @@ class _WrappedDataLoader(DataLoader):
yield self._move_to_device(item)
@PublicAPI(stability="beta")
def get_device() -> torch.device:
"""Gets the correct torch device to use for training."""
if torch.cuda.is_available():