如何将代码分离为 pytorch cnn 的 train、val 和 test 函数?

乘风 pytorch 538

原文标题How to seperate code into train, val and test functions for pytorch cnn?

我正在使用 pytorch 训练 cnn 并创建了一个训练循环。当我正在执行优化和试验超参数调整时,我想将我的训练、验证和测试分成不同的功能。为了绘制图表,我需要能够记录每个函数的准确性和损失。为此,我想创建一个返回准确性的函数。

我对编码很陌生,并且想知道最好的方法。我觉得我的代码现在有点乱。我需要能够输入各种超参数,以便在我的训练功能中进行实验。有人可以提供任何建议吗?以下是我目前所能做到的:

def train_model(model, optimizer, data_loader,  num_epochs, criterion=criterion):
  total_epochs = notebook.tqdm(range(num_epochs))

  for epoch in total_epochs:
    model.train()

    train_correct = 0.0
    train_running_loss=0.0
    train_total=0.0

    for i, (img, label) in enumerate(data_loader['train']):
      #uploading images and labels to GPU
      img = img.to(device)
      label = label.to(device)

      #training model
      outputs = model(img)

      #computing losss
      loss = criterion(outputs, label)

      #propagating the loss backwards
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      train_running_loss += loss.item()
    
      _, predicted = outputs.max(1)
      train_total += label.size(0)
      train_correct += predicted.eq(label).sum().item()
      
    train_loss=train_running_loss/len(data_loader['train'])
    train_accu=100.*correct/total

    print('Train Loss: %.3f | Train Accuracy: %.3f'%(train_loss,train_accu))

我还尝试过制作一个记录准确性的函数:

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim = 1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

原文链接:https://stackoverflow.com//questions/71512214/how-to-seperate-code-into-train-val-and-test-functions-for-pytorch-cnn

回复

我来回复
  • aretor的头像
    aretor 评论

    首先,请注意:

    • 除非您有特定的动机,否则验证(和测试)应该在与训练集不同的数据集上执行,因此您应该使用不同的 DataLoader。由于每个时期都有一个额外的 for 循环,计算时间会增加。
    • 在验证/测试之前始终调用 model.eval()。

    也就是说,验证函数的签名与train_model的签名非常相似

    # criterion is passed if you want to register the validation loss too
    def validate_model(model, eval_loader, criterion):
       ...
    

    然后,在train_model中,在每个 epoch 之后,您可以调用函数validate_model并将返回的指标存储在一些数据结构(listtensor等)中,稍后将用于绘图。

    在训练结束时,您可以使用相同的validate_model函数进行测试。

    您可以使用AccuracyfromTorchMetrics,而不是自己编码准确性

    最后,如果你觉得需要升级,可以使用 DL 训练框架,如 PyTorch Lightning 或 FastAI。还可以查看一些超参数调优库,例如 Ray Tune。

    2年前 0条评论