mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
add check on shape (#21947)
This commit is contained in:
parent
1f58ee3731
commit
bfe3e5f4a8
1 changed files with 5 additions and 2 deletions
|
@ -236,7 +236,7 @@ class Net(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def train_epoch(dataset, model, device, criterion, optimizer):
|
||||
def train_epoch(dataset, model, device, criterion, optimizer, feature_size):
|
||||
num_correct = 0
|
||||
num_total = 0
|
||||
running_loss = 0.0
|
||||
|
@ -249,6 +249,8 @@ def train_epoch(dataset, model, device, criterion, optimizer):
|
|||
optimizer.zero_grad()
|
||||
|
||||
# Forward + backward + optimize
|
||||
# check the input's shape matches the expectation
|
||||
assert inputs.size()[1] == feature_size
|
||||
outputs = model(inputs.float())
|
||||
loss = criterion(outputs, labels.float())
|
||||
loss.backward()
|
||||
|
@ -340,7 +342,8 @@ def train_func(config):
|
|||
label_column="label", batch_size=batch_size)
|
||||
|
||||
train_running_loss, train_num_correct, train_num_total = train_epoch(
|
||||
train_torch_dataset, net, device, criterion, optimizer)
|
||||
train_torch_dataset, net, device, criterion, optimizer,
|
||||
num_features)
|
||||
train_acc = train_num_correct / train_num_total
|
||||
print(f"epoch [{epoch + 1}]: training accuracy: "
|
||||
f"{train_num_correct} / {train_num_total} = {train_acc:.4f}")
|
||||
|
|
Loading…
Add table
Reference in a new issue