mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
add check on shape (#21947)
This commit is contained in:
parent
1f58ee3731
commit
bfe3e5f4a8
1 changed files with 5 additions and 2 deletions
|
@ -236,7 +236,7 @@ class Net(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(dataset, model, device, criterion, optimizer):
|
def train_epoch(dataset, model, device, criterion, optimizer, feature_size):
|
||||||
num_correct = 0
|
num_correct = 0
|
||||||
num_total = 0
|
num_total = 0
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
|
@ -249,6 +249,8 @@ def train_epoch(dataset, model, device, criterion, optimizer):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Forward + backward + optimize
|
# Forward + backward + optimize
|
||||||
|
# check the input's shape matches the expectation
|
||||||
|
assert inputs.size()[1] == feature_size
|
||||||
outputs = model(inputs.float())
|
outputs = model(inputs.float())
|
||||||
loss = criterion(outputs, labels.float())
|
loss = criterion(outputs, labels.float())
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
@ -340,7 +342,8 @@ def train_func(config):
|
||||||
label_column="label", batch_size=batch_size)
|
label_column="label", batch_size=batch_size)
|
||||||
|
|
||||||
train_running_loss, train_num_correct, train_num_total = train_epoch(
|
train_running_loss, train_num_correct, train_num_total = train_epoch(
|
||||||
train_torch_dataset, net, device, criterion, optimizer)
|
train_torch_dataset, net, device, criterion, optimizer,
|
||||||
|
num_features)
|
||||||
train_acc = train_num_correct / train_num_total
|
train_acc = train_num_correct / train_num_total
|
||||||
print(f"epoch [{epoch + 1}]: training accuracy: "
|
print(f"epoch [{epoch + 1}]: training accuracy: "
|
||||||
f"{train_num_correct} / {train_num_total} = {train_acc:.4f}")
|
f"{train_num_correct} / {train_num_total} = {train_acc:.4f}")
|
||||||
|
|
Loading…
Add table
Reference in a new issue