mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[Train] Decorate get_device
with PublicAPI
(#22024)
* Decorate `get_device` with `PublicAPI` * Add documentation * Update api.rst
This commit is contained in:
parent
b51b5afaea
commit
7dcb0b6af6
1 changed files with 1 additions and 0 deletions
|
@ -210,6 +210,7 @@ class _WrappedDataLoader(DataLoader):
|
|||
yield self._move_to_device(item)
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
def get_device() -> torch.device:
|
||||
"""Gets the correct torch device to use for training."""
|
||||
if torch.cuda.is_available():
|
||||
|
|
Loading…
Add table
Reference in a new issue