add check on shape (#21947)

This commit is contained in:
Chen Shen 2022-01-28 12:27:43 -08:00 committed by GitHub
parent 1f58ee3731
commit bfe3e5f4a8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -236,7 +236,7 @@ class Net(nn.Module):
return x
def train_epoch(dataset, model, device, criterion, optimizer):
def train_epoch(dataset, model, device, criterion, optimizer, feature_size):
num_correct = 0
num_total = 0
running_loss = 0.0
@ -249,6 +249,8 @@ def train_epoch(dataset, model, device, criterion, optimizer):
optimizer.zero_grad()
# Forward + backward + optimize
# check the input's shape matches the expectation
assert inputs.size()[1] == feature_size
outputs = model(inputs.float())
loss = criterion(outputs, labels.float())
loss.backward()
@ -340,7 +342,8 @@ def train_func(config):
label_column="label", batch_size=batch_size)
train_running_loss, train_num_correct, train_num_total = train_epoch(
train_torch_dataset, net, device, criterion, optimizer)
train_torch_dataset, net, device, criterion, optimizer,
num_features)
train_acc = train_num_correct / train_num_total
print(f"epoch [{epoch + 1}]: training accuracy: "
f"{train_num_correct} / {train_num_total} = {train_acc:.4f}")