mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[air] update pytorch_training_e2e.py to use iter_torch_batches. (#27241)
update pytorch_training_e2e.py to use iter_torch_batches. Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>
This commit is contained in:
parent
57adde3f7d
commit
c9579fea1c
1 changed files with 3 additions and 8 deletions
|
@ -6,7 +6,6 @@ import pandas as pd
|
|||
|
||||
from torchvision import transforms
|
||||
from torchvision.models import resnet18
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
|
@ -49,15 +48,11 @@ def train_loop_per_worker(config):
|
|||
for epoch in range(config["num_epochs"]):
|
||||
running_loss = 0.0
|
||||
for i, data in enumerate(
|
||||
train_dataset_shard.iter_batches(
|
||||
batch_size=config["batch_size"], batch_format="numpy"
|
||||
)
|
||||
train_dataset_shard.iter_torch_batches(batch_size=config["batch_size"])
|
||||
):
|
||||
# get the inputs; data is a list of [inputs, labels]
|
||||
inputs = torch.as_tensor(data["image"], dtype=torch.float32).to(
|
||||
device="cuda"
|
||||
)
|
||||
labels = torch.as_tensor(data["label"], dtype=torch.int64).to(device="cuda")
|
||||
inputs = data["image"].to(device="cuda")
|
||||
labels = data["label"].to(device="cuda")
|
||||
# zero the parameter gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue