ray/release/sgd_tests/sgd_gpu/sgd_gpu_test.py
Antoni Baum 2fb10e6730
[SGD] Add support for native Torch AMP in SGD (#16382)
* SGD native AMP initial commit

* SGD native amp second pass

* Update docs

* Update TorchTrainer doc

* Temp fix release test

* Update release/sgd_tests/sgd_gpu/sgd_gpu_app_config.yaml

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
2021-06-15 17:48:21 -07:00

43 lines
1.1 KiB
Python

import json
import os
import time
import ray
from ray.util.sgd.torch.examples.cifar_pytorch_example import train_cifar
import traceback
if __name__ == "__main__":
ray.init(address=os.environ.get("RAY_ADDRESS", "auto"))
start_time = time.time()
success = True
try:
from apex import amp # noqa: F401
except ImportError:
traceback.print_exc()
success = False
try:
train_cifar(
num_workers=1,
use_gpu=True,
num_epochs=5,
fp16=True,
test_mode=False)
except Exception as e:
print(f"(native fp16) The test failed with {e}")
success = False
try:
train_cifar(
num_workers=1,
use_gpu=True,
num_epochs=5,
fp16="apex",
test_mode=False)
except Exception as e:
print(f"(apex fp16) The test failed with {e}")
success = False
delta = time.time() - start_time
with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
f.write(json.dumps({"train_time": delta, "success": success}))