Implementing Custom Loss Functions in PyTorch

Understanding the theory and implementation of custom loss functions in PyTorch using the MNIST dataset

Marco Sanguineti
Towards Data Science

Photo by Markus Winkler on Unsplash

Introduction

In machine learning, the loss function is a critical component that measures the difference between the predicted output and the actual output. It plays a vital role in the training of a model, as it guides the optimization process by indicating the direction in which the model should improve. The choice of loss function depends on the specific task and the data type. In this article, we will delve into the theory and implementation of custom loss functions in PyTorch, using the MNIST dataset for digit classification as an example.

The MNIST dataset is widely used for image classification. It contains 70,000 grayscale images of handwritten digits (60,000 for training and 10,000 for testing), each with a resolution of 28x28 pixels. The task is to classify each image as one of the 10 digits (0–9), and the goal is to train a model that accurately classifies new handwritten digits based on the training examples provided in MNIST.

Photo by Carlos Muza on Unsplash

A typical approach for this task is a softmax classifier (multi-class logistic regression). The softmax function maps the raw outputs of the model to a probability distribution over the 10 classes. The cross-entropy loss, which measures how far the predicted probability distribution is from the actual one, is the usual loss function for this type of model.
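As a small illustration of the softmax step, here is a minimal pure-Python sketch. The logit values are made up for the example; a real model would produce them from an image.

```python
import math

def softmax(logits):
    """Map raw scores to a probability distribution (numerically stable form)."""
    m = max(logits)                          # subtract the max to avoid overflow in exp
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw outputs ("logits") of a 10-class model for one image
logits = [0.5, -1.2, 3.1, 0.0, 0.7, -0.3, 1.5, 0.2, -2.0, 2.4]
probs = softmax(logits)

print(round(sum(probs), 6))       # → 1.0: the probabilities sum to 1
print(probs.index(max(probs)))    # → 2: predicted class = index of the largest score
```

Softmax preserves the ordering of the scores, so the most likely class is always the one with the largest raw output.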

However, in some cases, the cross-entropy loss may not be the best choice for a particular task. For example, consider a scenario where the cost of misclassifying certain classes is much higher than others. In such cases, it is necessary to use a custom loss function that takes into account the relative importance of each class.

In this article, I will show you how to implement a custom loss function for the MNIST dataset in which misclassifying the digit 9 is penalized more heavily than misclassifying the other digits. We will use PyTorch as the framework: we will start with the theory behind the custom loss function and then show its implementation. Finally, we will use the custom loss function to train a small convolutional neural network on the MNIST dataset and evaluate the performance of the model.

Custom Loss function: why

Implementing custom loss functions is important for several reasons:

  1. Problem-specific: The choice of loss function depends on the specific task and the type of data. Custom loss functions can be designed to better suit the characteristics of the problem at hand, resulting in improved model performance.
  2. Class imbalance: In many real-world datasets, the number of samples in each class can be very different. A custom loss function can be designed to take into account the class imbalance and assign different costs to different classes.
  3. Cost-sensitive: In some tasks, the cost of misclassifying certain classes may be much higher than others. A custom loss function can be designed to take into account the relative importance of each class, pushing the model harder to avoid the most expensive mistakes.
  4. Multi-task learning: Custom loss functions can be designed to handle multiple tasks simultaneously. This can be useful in cases where a single model is required to perform multiple related tasks.
  5. Regularization: Custom loss functions can also be used for regularization, which helps to prevent overfitting and improve the generalization of the model.
  6. Adversarial training: Custom loss functions can also be used to train models to be robust against adversarial attacks.
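For the class-imbalance and cost-sensitive cases, PyTorch's built-in nn.CrossEntropyLoss already accepts a per-class weight tensor, which is often worth trying before writing a loss from scratch. A minimal sketch, assuming made-up label counts and random logits in place of a real model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical label counts for an imbalanced 10-class problem (class 9 is rare)
counts = torch.tensor([500., 500., 500., 500., 500., 500., 500., 500., 500., 50.])

# Inverse-frequency weights: n_samples / (n_classes * count), so rarer classes weigh more
weights = counts.sum() / (len(counts) * counts)

# Built-in cost-sensitive cross-entropy: each sample's loss is scaled by weights[target]
criterion = nn.CrossEntropyLoss(weight=weights)

torch.manual_seed(0)
logits = torch.randn(8, 10)           # fake model outputs for a batch of 8 images
target = torch.randint(0, 10, (8,))   # fake labels
loss = criterion(logits, target)

# With the default reduction='mean', PyTorch divides by the sum of the weights used,
# not by the batch size
per_sample = F.cross_entropy(logits, target, reduction='none')
manual = (weights[target] * per_sample).sum() / weights[target].sum()
print(torch.allclose(loss, manual))   # True
```

The normalization detail matters when comparing losses across batches: a weighted mean is divided by the total weight of the batch's targets rather than by the number of samples.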

In summary, custom loss functions let you tailor the training objective to a specific problem, which can lead to better performance and generalization.

Custom Loss function in PyTorch

As described in the introduction, the standard approach for MNIST is a softmax classifier trained with the cross-entropy loss.

Cross-entropy loss calculates the difference between the predicted probability distribution and the actual probability distribution. The predicted probability distribution is obtained by applying the softmax function to the output of the model. The actual probability distribution is a one-hot vector, where the value of the element corresponding to the correct class is 1 and the values of the other elements are 0. The cross-entropy loss is defined as:

L = -∑(y_i * log(p_i))

where y_i is the actual probability of class i and p_i is the predicted probability of class i.
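The formula can be checked numerically. The sketch below uses random logits as a stand-in for a model output and compares the hand-computed sum with PyTorch's F.cross_entropy, which expects raw logits and applies log_softmax internally.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 10)   # raw model output for one image (random stand-in)
target = torch.tensor([3])    # true class

# L = -sum_i y_i * log(p_i), with y the one-hot target and p = softmax(logits)
p = torch.softmax(logits, dim=1)
y = F.one_hot(target, num_classes=10).float()
manual = -(y * torch.log(p)).sum(dim=1).mean()

# Because y is one-hot, only the true class survives: L = -log(p_target)
only_target = -torch.log(p[0, 3])

# The built-in loss takes raw logits, not probabilities
builtin = F.cross_entropy(logits, target)

print(torch.allclose(manual, builtin), torch.allclose(only_target, builtin))  # True True
```

In other words, for a one-hot target the cross-entropy reduces to the negative log-probability that the model assigns to the correct class.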

As noted in the introduction, plain cross-entropy weighs every sample equally, which is a poor fit when misclassifying some classes is much more costly than misclassifying others. In such cases, a custom loss function that accounts for the relative importance of each class is needed.

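In PyTorch, custom loss functions can be implemented by creating a subclass of the nn.Module class and overriding its forward method. The forward method takes the model output and the target as input and returns the value of the loss. As long as forward is built from differentiable tensor operations, autograd computes the gradients automatically, so no backward method has to be written.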
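To make the pattern concrete, here is a minimal sketch of a hand-written loss module. The class name MyMSELoss is hypothetical; it simply re-implements mean squared error so that the result can be compared with nn.MSELoss and the gradient checked by hand.

```python
import torch
import torch.nn as nn

class MyMSELoss(nn.Module):
    """Hypothetical example: a hand-written mean squared error loss."""
    def __init__(self):
        super().__init__()

    def forward(self, prediction, target):
        # Only differentiable tensor operations: autograd derives the backward pass
        return ((prediction - target) ** 2).mean()

pred = torch.tensor([0.5, 2.0, -1.0], requires_grad=True)
target = torch.tensor([1.0, 2.0, 0.0])

loss = MyMSELoss()(pred, target)
loss.backward()

print(loss.item())   # (0.25 + 0 + 1) / 3, same value as nn.MSELoss()(pred, target)
print(pred.grad)     # 2 * (pred - target) / 3
```

No backward method was defined, yet pred.grad is populated: the module only has to describe the forward computation.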

Here is an example of a custom loss function for the MNIST classification task, where the cost of misclassifying the digit 9 is much higher than the other digits:

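```python
import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # reduction='none' keeps one loss value per sample so that the
        # class-9 mask below can select individual samples
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, output, target):
        target = target.long()                    # class indices must be int64
        loss = self.criterion(output, target)     # shape: (batch_size,)
        mask = (target == 9).float()              # 1 for class-9 samples, 0 otherwise
        high_cost = (loss * mask).mean()          # extra penalty from class-9 samples only
        return loss.mean() + high_cost            # class-9 samples effectively count twice
```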

In this example, we first compute the cross-entropy loss of every sample with nn.CrossEntropyLoss(reduction='none'). Next, we create a mask that is 1 for samples that belong to class 9 and 0 for the other samples. Multiplying the per-sample losses by the mask and averaging over the batch gives an extra high-cost term to which only class-9 samples contribute. Finally, we add this term to the mean cross-entropy to obtain the final loss, so each class-9 sample effectively counts twice.
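One subtlety is worth checking: nn.CrossEntropyLoss() uses reduction='mean' by default and returns a single number, so multiplying that number by a mask does not single out class-9 samples; it only rescales the whole batch loss by (1 + fraction of 9s in the batch). Per-sample losses are needed for the mask to act on individual samples. A sketch comparing the two on random data (logits and labels are made up):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(16, 10)           # fake model outputs for a batch of 16
target = torch.randint(0, 10, (16,))   # fake labels
target[:4] = 9                         # make sure the batch contains some 9s
mask = (target == 9).float()

# Default reduction='mean': one scalar for the whole batch
scalar = nn.CrossEntropyLoss()(logits, target)
scalar_version = scalar + (scalar * mask).mean()
# ...which is just the batch loss scaled by (1 + fraction of 9s)
print(torch.allclose(scalar_version, scalar * (1 + mask.mean())))  # True

# reduction='none': one loss per sample, so the mask really selects class-9 samples
per_sample = nn.CrossEntropyLoss(reduction='none')(logits, target)
per_sample_version = per_sample.mean() + (per_sample * mask).mean()
# Equivalent to a weighted mean where class-9 samples get weight 2 and the rest weight 1
weights = 1 + mask
print(torch.allclose(per_sample_version, (weights * per_sample).mean()))  # True
```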

To use the custom loss function, we instantiate it and call it inside the training loop exactly as we would call a built-in criterion such as nn.CrossEntropyLoss. Here is a complete example of how to use the custom loss function for training a model on the MNIST dataset:

```python
import os
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision


class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # Per-sample losses, so the class-9 mask can select individual samples
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, output, target):
        target = target.long()
        # Net returns log-probabilities; CrossEntropyLoss applies log_softmax again,
        # which leaves log-probabilities unchanged, so this equals the NLL loss
        loss = self.criterion(output, target)
        mask = (target == 9).float()          # 1 for class-9 samples, 0 otherwise
        high_cost = (loss * mask).mean()      # extra penalty for class-9 samples
        return loss.mean() + high_cost


# Load the MNIST dataset
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/files/', train=True, download=True,
                               transform=transform),
    batch_size=32, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/files/', train=False, download=True,
                               transform=transform),
    batch_size=32, shuffle=True)


# Define the model, loss function and optimizer
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)                   # 20 channels x 4 x 4 = 320 features
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)        # log-probabilities over the 10 classes

network = Net()
optimizer = optim.SGD(network.parameters(), lr=0.01, momentum=0.5)
criterion = CustomLoss()

# Training loop
n_epochs = 10

train_losses = []
train_counter = []
test_losses = []
test_counter = [i * len(train_loader.dataset) for i in range(n_epochs + 1)]

# Start from an empty results directory
shutil.rmtree('results', ignore_errors=True)
os.makedirs('results')

def train(epoch):
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 1000 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            # Number of training examples seen so far
            train_counter.append(
                (batch_idx * train_loader.batch_size)
                + ((epoch - 1) * len(train_loader.dataset)))
            torch.save(network.state_dict(), 'results/model.pth')
            torch.save(optimizer.state_dict(), 'results/optimizer.pth')

def test():
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # criterion returns a batch mean; weight it by the batch size so that
            # dividing by the dataset size below gives a per-sample average
            test_loss += criterion(output, target).item() * len(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


test()
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()
```

This script trains a convolutional neural network on the MNIST dataset in PyTorch using the custom loss function. Its main blocks are described below.

The first block of code defines the custom loss function, “CustomLoss”, by subclassing PyTorch's nn.Module. Its forward method takes two inputs: the output of the model and the target labels. It first makes sure the targets are long integers, the dtype that nn.CrossEntropyLoss expects for class indices, and then computes the cross-entropy loss of every sample. Next, it builds a mask that is 1 where the target is 9, multiplies the per-sample losses by this mask and averages the result over the batch. Finally, it returns the mean cross-entropy plus this high-cost term.
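To see what the extra term does during training, one can compare gradients. In the sketch below, weighted_ce and its high_cost_class parameter are hypothetical names for a functional version of the per-sample loss described above; the logits are random. Rows of the gradient that belong to class-9 samples come out exactly twice as large as with plain cross-entropy, while the other rows are unchanged.

```python
import torch
import torch.nn.functional as F

def weighted_ce(logits, target, high_cost_class=9):
    # Per-sample cross-entropy plus an extra copy for the high-cost class
    per_sample = F.cross_entropy(logits, target, reduction='none')
    mask = (target == high_cost_class).float()
    return per_sample.mean() + (per_sample * mask).mean()

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)
target = torch.tensor([9, 1, 9, 4])

weighted_ce(logits, target).backward()
g_custom = logits.grad.clone()

logits.grad = None                       # reset before the second backward pass
F.cross_entropy(logits, target).backward()
g_plain = logits.grad.clone()

# Class-9 rows get twice the plain gradient; other rows are unchanged
print(torch.allclose(g_custom[0], 2 * g_plain[0]), torch.allclose(g_custom[1], g_plain[1]))  # True True
```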

The next block of code loads the MNIST dataset using PyTorch’s built-in data-loading utilities. The train_loader loads the training dataset and applies the specified transforms to the images, such as converting the images to tensors and normalizing the pixel values. The test_loader loads the test dataset and applies the same transforms.

The following block of code defines a convolutional neural network (CNN) called “Net” by subclassing PyTorch's nn.Module. The CNN consists of two convolutional layers, two linear layers, and two dropout layers for regularization (Dropout2d after the second convolution and standard dropout after the first linear layer). The forward method applies each convolution followed by 2x2 max pooling and a ReLU activation, flattens the result into a 320-dimensional vector, passes it through the linear layers, and returns the log-softmax of the final output.
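The number 320 in nn.Linear(320, 50) follows from the layer sizes. A quick sketch that pushes one dummy MNIST-sized image through convolution and pooling layers with the same settings as in Net:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv1 = nn.Conv2d(1, 10, kernel_size=5)   # same layer settings as in Net
conv2 = nn.Conv2d(10, 20, kernel_size=5)

x = torch.zeros(1, 1, 28, 28)             # one dummy 28x28 grayscale image
x = F.max_pool2d(conv1(x), 2)             # 28 -> 24 (5x5 conv) -> 12 (2x2 pool)
print(x.shape)                            # torch.Size([1, 10, 12, 12])
x = F.max_pool2d(conv2(x), 2)             # 12 -> 8 (5x5 conv) -> 4 (2x2 pool)
print(x.shape)                            # torch.Size([1, 20, 4, 4])
print(x.flatten(1).shape[1])              # 320 = 20 * 4 * 4, the input size of fc1
```

Changing the kernel sizes or adding a layer would change this number, which is why x.view(-1, 320) has to be kept in sync with the convolutional stack.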

The next block of code creates an instance of the Net class, a stochastic gradient descent optimizer with learning rate 0.01 and momentum 0.5, and an instance of the custom loss function.
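For reference, PyTorch's SGD with momentum (default dampening=0, nesterov=False) keeps a velocity buffer v and updates v ← momentum · v + grad, then w ← w − lr · v. A sketch that replicates this rule by hand on a toy loss w² (the loss is chosen only for illustration) and checks it against optim.SGD:

```python
import torch
import torch.optim as optim

lr, momentum = 0.01, 0.5                      # same values as in the script
w = torch.tensor([1.0], requires_grad=True)
opt = optim.SGD([w], lr=lr, momentum=momentum)

# Manual replica of the momentum rule:  v <- momentum * v + grad ;  w <- w - lr * v
w_manual, v = 1.0, 0.0
for _ in range(3):
    opt.zero_grad()
    loss = (w ** 2).sum()                     # toy loss with gradient 2 * w
    loss.backward()
    opt.step()

    grad = 2 * w_manual
    v = momentum * v + grad
    w_manual -= lr * v

print(w.item(), w_manual)                     # the two trajectories agree (about 0.916592)
```

With momentum 0.5, each step reuses half of the previous update, which smooths the noisy mini-batch gradients without the larger overshoot of momentum values close to 1.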

The final block of code is the training loop, where the model is trained for 10 epochs. In each epoch, the model iterates over the training dataset, passes the images through the network, computes the loss with the custom loss function, backpropagates the gradients, and updates the model’s parameters with the optimizer. Every 1,000 batches it prints the current loss, records it in train_losses and saves the model and optimizer states. The model is also evaluated on the test set before training and after every epoch, and the average test loss is recorded in test_losses. Before training starts, the script recreates a directory called “results”, where these checkpoints are stored.

```python
import matplotlib.pyplot as plt

fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red')
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('Custom loss')
plt.show()
```
Custom loss trend — Image by Author

This code plots the custom loss recorded during training, for both the training set and the test set.

It starts by importing Matplotlib, a plotting library for Python, and then creates a new figure with the plt.figure() function.

The next line of code plots the custom loss for the training set using the plt.plot() function. It uses the train_counter and train_losses variables as the x and y-axis values, respectively. The color of the plot is set to blue using the color parameter.

Then, it plots the custom loss for the test set using the plt.scatter() function. It uses the test_counter and test_losses variables as the x and y-axis values, respectively. The colour of the plot is set to red using the color parameter.

The plt.legend() function adds a legend to the plot, indicating which plot corresponds to the train loss and which corresponds to the test loss. The loc parameter is set to 'upper right' which means the legend will be located at the upper right corner of the plot.

The plt.xlabel() and plt.ylabel() functions add labels to the x and y-axis of the plot, respectively. The x-axis label is set to 'number of training examples seen' and the y-axis label is set to 'Custom loss'.

Finally, the plt.show() function is used to display the plot.

This code will display a plot of the custom loss as a function of the number of training examples seen. The blue line represents the custom loss on the training set, and the red dots represent the custom loss on the test set. The plot shows how the loss evolves during training and helps evaluate the performance of the model.

```python
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

network.eval()                       # make sure dropout is disabled for inference
with torch.no_grad():
    output = network(example_data)

fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title("Prediction: {}".format(
        output.data.max(1, keepdim=True)[1][i].item()))
    plt.xticks([])
    plt.yticks([])

plt.show()
```
Test set samples and prediction — Image by Author

This code is displaying a figure with 6 images from the test set and their corresponding predictions made by the trained network.

It starts by wrapping test_loader, the DataLoader that serves the test set in batches, with enumerate() and then calls next() to fetch the first batch of examples.

The example_data variable contains the images and the example_targets variable contains the corresponding labels.

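Then it runs the network inside the torch.no_grad() context manager, which disables gradient tracking: operations inside the block are not recorded by autograd, so their outputs do not require gradients. The parameters themselves keep requires_grad=True; only the recording of the computation is switched off. This reduces memory usage and speeds up inference, which is all that is needed here since no backward pass follows.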
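A small sketch of what torch.no_grad() changes, using an arbitrary nn.Linear layer and random input in place of the trained network:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)          # any module with trainable parameters
x = torch.randn(3, 4)

y_tracked = layer(x)
with torch.no_grad():
    y_untracked = layer(x)       # same values, but no autograd graph is recorded

print(y_tracked.requires_grad, y_untracked.requires_grad)  # True False
print(layer.weight.requires_grad)                           # True: parameters are unchanged
print(torch.allclose(y_tracked, y_untracked))               # True
```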

The next block of code creates a new figure object with the plt.figure() function and loops over the first 6 examples of that batch. For each example, it creates a subplot in a 2x3 grid with the plt.subplot() function, and plt.tight_layout() adjusts the spacing between the subplots.

Then it uses the plt.imshow() function to display the image in the current subplot. The cmap parameter is set to 'gray' to display the image in grayscale and the interpolation parameter is set to 'none' to display the image without any interpolation.

The plt.title() function adds a title to the current subplot showing the prediction made by the network for the current example. The expression output.data.max(1, keepdim=True)[1] returns, for each row of the output, the index of the largest value, which is the predicted class; [i].item() then extracts the prediction for the current example as a Python integer.
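The indexing can be tried on a tiny made-up output tensor (2 images, 3 classes). Tensor.max with a dimension returns both the maximum values and their indices, and argmax returns only the indices:

```python
import torch

output = torch.tensor([[-2.3, -0.1, -4.0],   # made-up log-probabilities for 2 images
                       [-0.5, -1.5, -1.2]])

values, indices = output.max(1, keepdim=True)   # max along the class dimension
print(indices.squeeze(1).tolist())               # [1, 0]: predicted class per row
print(indices[0].item())                         # 1: .item() gives a plain Python int

# argmax returns the same indices without the values
print(torch.equal(indices, output.argmax(dim=1, keepdim=True)))  # True
```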

The plt.xticks() and plt.yticks() functions are used to remove the x and y-axis ticks from the current subplot, respectively.

Finally, the plt.show() function displays the figure: 6 grayscale test images without interpolation, each titled with the class predicted by the trained network. This is a quick way to inspect the model’s behaviour on the test set and spot potential misclassifications.
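Because the custom loss is meant to make errors on 9s more costly, overall accuracy alone may not show whether it helps. A per-class recall check is one way to look at class 9 specifically. The helper per_class_recall below is hypothetical, and the predictions and labels are made up; in practice they could be collected from the test loop.

```python
import torch

def per_class_recall(pred, target, num_classes=10):
    """Fraction of samples of each class predicted correctly (hypothetical helper).
    Classes absent from target are reported as 0."""
    recall = torch.zeros(num_classes)
    for c in range(num_classes):
        is_c = target == c
        if is_c.any():
            recall[c] = (pred[is_c] == c).float().mean()
    return recall

# Made-up predictions and labels, just to show the call
target = torch.tensor([9, 9, 9, 9, 1, 1, 4, 4])
pred   = torch.tensor([9, 4, 9, 9, 1, 1, 4, 9])
recall = per_class_recall(pred, target)
print(recall[9].item(), recall[1].item(), recall[4].item())  # 0.75 1.0 0.5
```

Comparing the recall for class 9 between a model trained with plain cross-entropy and one trained with the custom loss would show whether the extra penalty has the intended effect.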

Greetings

In this article, we have discussed the theory and implementation of custom loss functions in PyTorch, using the MNIST dataset for digit classification as an example. We have shown how to create a custom loss function by subclassing the nn.Module class and overriding the forward method. We have also provided an example of how to use the custom loss function for training a model on the MNIST dataset. Custom loss functions can be useful in scenarios where the cost of misclassifying certain classes is much higher than others. Care should be taken when implementing custom loss functions, as they can have a significant impact on the performance of the model.

Thank you for your support! Until next time,

Marco


Graduated in Mechanical Engineering, I work in the world of AI, Deep Learning and Software Development. Passionate about Technology, Videogames and AI.