[air] update pytorch_training_e2e.py to use iter_torch_batches. (#27241)

update pytorch_training_e2e.py to use iter_torch_batches.

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>
This commit is contained in:
xwjiang2010 2022-08-01 11:23:01 -07:00 committed by GitHub
parent 57adde3f7d
commit c9579fea1c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,7 +6,6 @@ import pandas as pd
from torchvision import transforms
from torchvision.models import resnet18
import torch
import torch.nn as nn
import torch.optim as optim
@ -49,15 +48,11 @@ def train_loop_per_worker(config):
for epoch in range(config["num_epochs"]):
running_loss = 0.0
for i, data in enumerate(
train_dataset_shard.iter_batches(
batch_size=config["batch_size"], batch_format="numpy"
)
train_dataset_shard.iter_torch_batches(batch_size=config["batch_size"])
):
# get the inputs; data is a list of [inputs, labels]
inputs = torch.as_tensor(data["image"], dtype=torch.float32).to(
device="cuda"
)
labels = torch.as_tensor(data["label"], dtype=torch.int64).to(device="cuda")
inputs = data["image"].to(device="cuda")
labels = data["label"].to(device="cuda")
# zero the parameter gradients
optimizer.zero_grad()