mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00

* SGD native AMP initial commit
* SGD native amp second pass
* Update docs
* Update TorchTrainer doc
* Temp fix release test
* Update release/sgd_tests/sgd_gpu/sgd_gpu_app_config.yaml

Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
43 lines
1.1 KiB
Python
43 lines
1.1 KiB
Python
import json
|
|
import os
|
|
import time
|
|
|
|
import ray
|
|
from ray.util.sgd.torch.examples.cifar_pytorch_example import train_cifar
|
|
import traceback
|
|
|
|
if __name__ == "__main__":
    # Release-test driver: runs the CIFAR training example on a single GPU
    # worker twice -- once with fp16=True (reported as "native fp16") and
    # once with fp16="apex" (reported as "apex fp16") -- then writes the
    # elapsed wall-clock time and an overall success flag as JSON to the
    # path named by the TEST_OUTPUT_JSON environment variable.
    ray.init(address=os.environ.get("RAY_ADDRESS", "auto"))
    start_time = time.time()
    success = True

    # If apex cannot be imported the test is marked as failed, but both
    # training runs below are still attempted.
    try:
        from apex import amp  # noqa: F401
    except ImportError:
        traceback.print_exc()
        success = False

    # Each entry: (label used in the failure message, value passed as fp16).
    fp16_modes = (("native fp16", True), ("apex fp16", "apex"))
    for label, fp16 in fp16_modes:
        try:
            train_cifar(
                num_workers=1,
                use_gpu=True,
                num_epochs=5,
                fp16=fp16,
                test_mode=False)
        except Exception as e:
            # Top-level boundary: record the failure and keep going so the
            # remaining run executes and the result file is always written.
            print(f"({label}) The test failed with {e}")
            # Also emit the stack trace, matching the ImportError handler
            # above; the message alone rarely locates a training failure.
            traceback.print_exc()
            success = False

    delta = time.time() - start_time
    # TEST_OUTPUT_JSON must be set; a missing variable raises KeyError here.
    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        json.dump({"train_time": delta, "success": success}, f)