[SOLVED] How to separate code into train, val and test functions for a PyTorch CNN?


I am training a CNN using PyTorch and have created a training loop. As I am performing optimisation and experimenting with hyper-parameter tuning, I want to separate my training, validation and testing into different functions. I need to be able to record the accuracy and loss for each function in order to plot graphs. For this I want to create a function which returns the accuracy.

I am pretty new to coding and was wondering what the best way to go about this is. My code feels a bit messy at the moment. I need to be able to feed various hyper-parameters into my training function for experimentation. Could anyone offer any advice? Below is what I have so far:

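import torch
import torch.nn as nn
from tqdm import notebook

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss()  # assumed loss for a classification CNN

def train_model(model, optimizer, data_loader, num_epochs, criterion=criterion):
  total_epochs = notebook.tqdm(range(num_epochs))

  for epoch in total_epochs:
    model.train()  # dropout / batch norm in training mode

    # per-epoch accumulators must be reset at the start of every epoch
    train_running_loss = 0.0
    train_correct = 0
    train_total = 0

    for i, (img, label) in enumerate(data_loader['train']):
      #uploading images and labels to GPU
      img = img.to(device)
      label = label.to(device)

      #forward pass
      outputs = model(img)

      #computing loss
      loss = criterion(outputs, label)

      #propagating the loss backwards and updating the weights
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      train_running_loss += loss.item()
      _, predicted = outputs.max(1)
      train_total += label.size(0)
      train_correct += predicted.eq(label).sum().item()

    # average loss per batch and fraction of correct predictions for this epoch
    train_loss = train_running_loss / len(data_loader['train'])
    train_accu = train_correct / train_total
    print('Train Loss: %.3f | Train Accuracy: %.3f'%(train_loss,train_accu))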

I have also experimented with writing a function to record accuracy:

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim = 1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))


First, note that:

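  • Unless you have a specific reason not to, validation (and testing) should be performed on data that is not part of the training set, so you should use a separate DataLoader for each split. Each epoch will take longer, because validation adds a second loop over the validation data.
  • Always call model.eval() before validation/testing, and model.train() before going back to training, so that layers such as dropout and batch normalization behave correctly in each phase.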
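The separate-DataLoader setup can be sketched with torch.utils.data.random_split. This is a minimal sketch, assuming a map-style dataset; the random TensorDataset below is a hypothetical stand-in for the real images, and the 80/10/10 split and batch size are illustrative choices. The loaders go into a dict keyed like the question's data_loader['train'], with 'val' and 'test' added as assumed keys.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split

# Hypothetical stand-in for the real image dataset: 100 RGB 8x8 images, 3 classes.
images = torch.randn(100, 3, 8, 8)
labels = torch.randint(0, 3, (100,))
dataset = TensorDataset(images, labels)

# Disjoint 80/10/10 split (the ratios are an illustrative assumption).
generator = torch.Generator().manual_seed(0)  # makes the split reproducible
train_set, val_set, test_set = random_split(dataset, [80, 10, 10], generator=generator)

# One DataLoader per split; only the training data is shuffled.
data_loader = {
    'train': DataLoader(train_set, batch_size=16, shuffle=True),
    'val': DataLoader(val_set, batch_size=16, shuffle=False),
    'test': DataLoader(test_set, batch_size=16, shuffle=False),
}
```

Fixing the generator seed keeps the same samples in each split across runs, so metrics from different hyper-parameter experiments stay comparable.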

That said, the signature of the validation function is pretty much the same as that of train_model:

# criterion is passed if you want to register the validation loss too
def validate_model(model, eval_loader, criterion):
    ...
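One possible body for validate_model, sketched under a few assumptions: the model is a classifier, criterion behaves like nn.CrossEntropyLoss with its default reduction='mean', and the function returns (average loss, accuracy). The device parameter is an addition that is not in the signature above. Evaluation runs under torch.no_grad() so no autograd graph is built.

```python
import torch
import torch.nn as nn

def validate_model(model, eval_loader, criterion, device='cpu'):
    """Return (average loss per sample, accuracy) of model over eval_loader."""
    model.eval()  # inference behaviour for dropout / batch norm
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for img, label in eval_loader:
            img = img.to(device)
            label = label.to(device)
            outputs = model(img)
            # criterion returns the batch mean (reduction='mean'), so weight it
            # by batch size to get a correct average when the last batch is smaller
            running_loss += criterion(outputs, label).item() * label.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(label).sum().item()
            total += label.size(0)
    return running_loss / total, correct / total
```

The function leaves the model in eval mode, so the training loop should call model.train() again before the next epoch. Because it only reads from eval_loader, the same call works unchanged for a validation loader or a test loader.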

Then, in train_model, after each epoch, you can call the function validate_model and store the returned metrics in some data structure (list, tensor, etc.) that will be used later for plotting.

At the end of the training, you can then use the same validate_model function for testing.
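A minimal sketch of that pattern follows. It assumes data_loader is a dict with 'train', 'val' and 'test' keys (only 'train' appears in the question) and that validate_fn follows the validate_model(model, eval_loader, criterion) signature and returns (loss, accuracy). The fit name and the history dict layout are illustrative choices, not part of the original answer. Hyper-parameters are either plain arguments (num_epochs) or live inside the objects passed in (the learning rate inside optimizer), so each experiment is just another call.

```python
import torch
import torch.nn as nn

def fit(model, optimizer, data_loader, num_epochs, criterion, validate_fn, device='cpu'):
    """Train for num_epochs, calling validate_fn after every epoch.

    validate_fn(model, loader, criterion) is assumed to return (loss, accuracy).
    Returns a dict of per-epoch metric lists, ready for plotting.
    """
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    for epoch in range(num_epochs):
        model.train()  # undo the eval mode set by the previous validation
        running_loss, correct, total = 0.0, 0, 0
        for img, label in data_loader['train']:
            img, label = img.to(device), label.to(device)
            outputs = model(img)
            loss = criterion(outputs, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # weight the batch-mean loss by batch size for a per-sample average
            running_loss += loss.item() * label.size(0)
            correct += outputs.argmax(1).eq(label).sum().item()
            total += label.size(0)
        val_loss, val_acc = validate_fn(model, data_loader['val'], criterion)
        history['train_loss'].append(running_loss / total)
        history['train_acc'].append(correct / total)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
    return history
```

For example, a learning-rate sweep could build a fresh model and torch.optim.SGD(model.parameters(), lr=lr) for each lr, call fit(...), and keep each returned history for plotting (e.g. plt.plot(history['val_acc']) with matplotlib). Once training is finished, the test metrics come from the same validation function: test_loss, test_acc = validate_fn(model, data_loader['test'], criterion).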

Instead of coding the accuracy yourself, you can use the Accuracy metric from TorchMetrics.

Finally, if you feel the need to level up, you can use DL training frameworks like PyTorch Lightning or FastAI. Also take a look at a hyperparameter tuning library such as Ray Tune.

Answered By – aretor

Answer Checked By – Marilyn (BugsFixing Volunteer)
