How to calculate mean and standard deviation of images in PyTorch


In this article, we will learn to:
  • Calculate the mean and standard deviation of the image dataset.
First, we load our image dataset. To load a custom image dataset, use torchvision.datasets.ImageFolder().

The images are arranged in the following way:

root/class_1/xxx.png 
root/class_1/xxy.png 
root/class_1/[...]/xxz.png 

example: 
dataset/cat/101.png 
dataset/cat/1002.png 
dataset/cat/[...]/1000.png
 
In our dataset, we have many cat images, arranged as shown in the example above.

# python code to load the image dataset

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

# Root folder that holds one sub-directory per class (e.g. ./dataset/cat/).
data_path = './dataset/'

# Resize, then centre-crop to 256x256 so every image ends up with the same
# spatial size, and finally convert the image to a tensor.
transform_img = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])

image_data = torchvision.datasets.ImageFolder(
    root=data_path, transform=transform_img
)

# One batch spans the whole dataset, so a single iteration of the loader
# yields every image at once. This is only practical for small datasets.
image_data_loader = DataLoader(
    image_data,
    batch_size=len(image_data),
    shuffle=False,  # the comma here was missing and made the call a SyntaxError
    num_workers=0,
)

We have now loaded the image dataset and created a DataLoader for it.

Visualize an image from the image dataset.

Python3
# Python code to visualize an image
import matplotlib.pyplot as plt

# Fetch the first batch produced by the loader.
images, labels = next(iter(image_data_loader))


def display_image(images, index=4):
    """Draw one image from a batch with matplotlib.

    Args:
        images: Batch tensor laid out as [batch, channels, height, width].
        index: Position of the image inside the batch. Defaults to 4, i.e.
            the fifth image, which is what this function always displayed
            before the parameter existed.

    Raises:
        IndexError: If ``index`` is outside the batch (for example when the
            batch holds fewer than five images and the default is used).
    """
    # Convert only the selected image instead of the whole batch, and move
    # the channel axis last because imshow expects [height, width, channels].
    img_plt = images[index].numpy().transpose(1, 2, 0)
    plt.imshow(img_plt)


display_image(images)

Output:

Calculate the mean and standard deviation of the dataset

When the dataset is small, the batch size can be set to the size of the whole dataset. Below is an easy way to calculate the mean and standard deviation when the batch size equals the whole dataset.


Python3
# python code to calculate mean and std

from torch.utils.data import DataLoader

# Loader whose single batch contains the entire dataset.
image_data_loader = DataLoader(
    image_data,
    # batch size is the whole dataset
    batch_size=len(image_data),
    shuffle=False,  # the comma here was missing and made the call a SyntaxError
    num_workers=0,
)

def mean_std(loader):
    """Return the per-channel mean and standard deviation of the first batch.

    Only the first batch produced by ``loader`` is read, so the result
    describes the whole dataset only when that batch contains every image.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs, with images
            laid out as [b, c, h, w].

    Returns:
        A ``(mean, std)`` tuple with one entry per channel.
    """
    first_batch = next(iter(loader))
    pixels, _ = first_batch
    # Reducing over every axis except the channel axis leaves one value
    # per channel.
    reduce_axes = [0, 2, 3]
    channel_mean = pixels.mean(reduce_axes)
    channel_std = pixels.std(reduce_axes)
    return channel_mean, channel_std

# Compute the statistics from the whole-dataset loader and report them.
mean, std = mean_std(loader=image_data_loader)
print("mean and std: \n", mean, std)

Output:
mean and std:
(tensor([0.5125, 0.4667, 0.4110]), tensor([0.2621, 0.2501, 0.2453]))


If our dataset is large and we divide it into batches, we can use the Python code below to determine the mean and standard deviation.


Python3
# python code to calculate mean and std 

import torch
from torch.utils.data import DataLoader

# Number of images handed out per iteration.
batch_size = 2

# Iterate over the dataset in small batches rather than loading it all at once.
loader = DataLoader(dataset=image_data, batch_size=batch_size, num_workers=1)

def batch_mean_and_sd(loader):
    """Compute the per-channel mean and standard deviation batch by batch.

    Running first and second moments (E[x] and E[x^2]) are updated after
    every batch, so the whole dataset never has to be in memory at once.
    The standard deviation is the population value sqrt(E[x^2] - E[x]^2),
    i.e. it divides by the number of pixels, not by that number minus one.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs, with images
            laid out as [b, c, h, w]. The number of channels is taken from
            the data.

    Returns:
        A ``(mean, std)`` tuple of tensors with one entry per channel.

    Raises:
        ValueError: If the loader yields no pixels at all.
    """
    cnt = 0
    fst_moment = None
    snd_moment = None

    for images, _ in loader:
        b, c, h, w = images.shape
        nb_pixels = b * h * w
        if nb_pixels == 0:
            # An empty batch carries no information and would otherwise
            # cause a division by zero on the very first update.
            continue
        sum_ = torch.sum(images, dim=[0, 2, 3])
        sum_of_square = torch.sum(images ** 2, dim=[0, 2, 3])
        if fst_moment is None:
            # Start from zeros (not uninitialised memory) so that
            # 0 * moment can never inject NaN/inf into the running average;
            # zeros_like also picks up the channel count, dtype and device.
            fst_moment = torch.zeros_like(sum_)
            snd_moment = torch.zeros_like(sum_)
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels

    if fst_moment is None:
        raise ValueError("loader yielded no images")

    # Rounding can push the variance a hair below zero for (nearly)
    # constant channels; clamp so sqrt does not return NaN.
    variance = torch.clamp(snd_moment - fst_moment ** 2, min=0)
    mean, std = fst_moment, torch.sqrt(variance)
    return mean, std
  
# Compute the statistics over all batches and report them.
mean, std = batch_mean_and_sd(loader=loader)
print("mean and std: \n", mean, std)


Output:
mean and std:
(tensor([0.5125, 0.4667, 0.4110]), tensor([0.2621, 0.2501, 0.2453]))

You may be interested in the following posts.



Comments