SoFunction
Updated on 2024-11-21

Training FCN on the VOC segmentation dataset with PyTorch, explained

Semantic segmentation segments an image by classifying every pixel in it. Its main application areas are medical imaging and autonomous driving.

Like other areas of computer vision, image segmentation has moved from traditional algorithms to deep learning. Traditional segmentation algorithms, such as threshold segmentation, watershed and edge detection, share the weakness of other traditional image-processing methods: they are not robust enough. In simple scenes that do not change, however, traditional image processing is still widely used.

FCN, a 2014 paper, is the pioneering work on deep-learning semantic segmentation and laid the conceptual foundation of the field.

Fully Convolutional Networks for Semantic Segmentation

Submitted on 14 Nov 2014

/abs/1411.4038

I. Introduction to FCN theory

The screenshot above, taken from the original paper, shows the overall architecture of FCN. In essence, the image passes through a series of convolutions and is then upsampled back to the original size, producing a class probability for each pixel.

The figure above shows the FCN network in more detail. The backbone is VGG16, whose fully connected layers are rewritten as convolutions (conv6 and conv7): a convolution whose kernel is the same size as its input feature map is equivalent to a fully connected layer. Overall, the network has the following key points:

1. Fully convolutional: this solves the pixel-wise prediction problem. Replacing the final fully connected layers of the base network (e.g., VGG16) with convolutional layers lets the network accept input images of arbitrary size and produce an output whose size corresponds to the input;

2. Transposed convolution: an upsampling step that restores the image size so that every pixel can be predicted;

3. Skip architecture: fuses high-level and low-level features. The pooling layers downsample the image, and although transposed convolution restores the size, it is not the inverse of convolution, so spatial information is inevitably lost. Skip connections fuse fine-grained information from shallow layers with coarse, semantic information from deep layers, which makes the segmentation more precise.
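
Key point 1 rests on the fact that a fully connected layer applied to a flattened feature map computes the same thing as a convolution whose kernel covers the whole map. A minimal check in PyTorch, with toy sizes (8 channels, a 7x7 map, 16 outputs) chosen only for speed rather than VGG16's real 512x7x7 to 4096:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, H, W, OUT = 8, 7, 7, 16          # toy stand-in for VGG16's pool5 map and fc6
x = torch.randn(2, C, H, W)

fc = nn.Linear(C * H * W, OUT)                 # fully connected head on the flattened map
conv = nn.Conv2d(C, OUT, kernel_size=(H, W))   # kernel as large as the feature map

# Reuse the FC weights as the conv kernel: (OUT, C*H*W) -> (OUT, C, H, W)
with torch.no_grad():
    conv.weight.copy_(fc.weight.view(OUT, C, H, W))
    conv.bias.copy_(fc.bias)

y_fc = fc(x.flatten(1))      # (2, OUT)
y_conv = conv(x)             # (2, OUT, 1, 1)
print(torch.allclose(y_fc, y_conv.flatten(1), atol=1e-5))  # True

# A larger input no longer breaks the layer: it yields a spatial grid of scores
print(conv(torch.randn(1, C, 10, 12)).shape)  # torch.Size([1, 16, 4, 6])
```

The Linear layer would reject the 10x12 input because the flattened length changes, while the convolution simply slides over it; this is what allows FCN to take images of arbitrary size.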

FCN-32s: no skip connections; the 1/32-resolution prediction is upsampled 32x back to the original size in a single step.

FCN-16s: one skip connection. The 1/32 prediction is upsampled to 1/16, added to VGG's 1/16 output (pool4), and the result is then upsampled to the original image size.

FCN-8s: two skip connections. The 1/32 prediction is upsampled to 1/16 and added to VGG's 1/16 output (pool4); that sum is upsampled to 1/8 and added to VGG's 1/8 output (pool3); the result is then upsampled to the original image size.
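
The FCN-8s fusion can be sketched as follows. This is an illustration, not the paper's code and not the implementation used later in this article: the class name FCN8sHead is hypothetical, the channel counts assume VGG16's pool3/pool4/pool5 widths (the paper takes the 1/32 scores from conv7), and bilinear interpolation stands in for the paper's learned upsampling layers. Each stage is first turned into per-class scores by a 1x1 convolution, then the coarser score map is upsampled and added to the finer one:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FCN8sHead(nn.Module):
    """Hypothetical FCN-8s-style fusion head (illustration only)."""

    def __init__(self, num_classes=21, c3=256, c4=512, c5=512):
        super().__init__()
        # 1x1 "score" convolutions turn each feature map into per-class scores
        self.score3 = nn.Conv2d(c3, num_classes, kernel_size=1)
        self.score4 = nn.Conv2d(c4, num_classes, kernel_size=1)
        self.score5 = nn.Conv2d(c5, num_classes, kernel_size=1)

    @staticmethod
    def _up(x, size):
        # Bilinear resize used here in place of a learned transposed convolution
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)

    def forward(self, f3, f4, f5, out_size):
        s = self.score5(f5)                                # 1/32 resolution
        s = self._up(s, f4.shape[-2:]) + self.score4(f4)   # 2x -> 1/16, first skip (FCN-16s)
        s = self._up(s, f3.shape[-2:]) + self.score3(f3)   # 2x -> 1/8, second skip (FCN-8s)
        return self._up(s, out_size)                       # 8x -> input size

head = FCN8sHead()
f3 = torch.randn(1, 256, 28, 28)   # pool3-sized map for a 224x224 input (1/8)
f4 = torch.randn(1, 512, 14, 14)   # pool4-sized map (1/16)
f5 = torch.randn(1, 512, 7, 7)     # pool5-sized map (1/32)
out = head(f3, f4, f5, out_size=(224, 224))
print(out.shape)  # torch.Size([1, 21, 224, 224])
```

Stopping after the first skip and upsampling 16x gives the FCN-16s variant; dropping both skips gives FCN-32s.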

II. Training process

Training a deep learning model in PyTorch can be organized into three main files: voc_seg_data.py, fcn8s_net.py and a training script. They implement data loading and batching, the network model definition and the training loop, respectively.

2.1 Introduction to the voc dataset

Download address: Pascal VOC Dataset Mirror

The image names are listed in the split files under /ImageSets/Segmentation/ (train.txt and val.txt).

The images are all in the ./data/VOC2012/JPEGImages folder; append .jpg to each name read from the list.

The labels are all in the ./data/VOC2012/SegmentationClass folder; append .png to each name read from the list.
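
The path-building steps above can be sketched as a small helper. read_split is a hypothetical name, and the demo creates a throwaway directory with made-up image IDs instead of the real dataset:

```python
import os
import tempfile

def read_split(root, split):
    """Hypothetical helper: return (image_paths, label_paths) for a VOC-style split."""
    txt = os.path.join(root, "ImageSets", "Segmentation", split + ".txt")
    with open(txt) as f:
        names = f.read().split()           # one image ID per line
    imgs = [os.path.join(root, "JPEGImages", n + ".jpg") for n in names]
    labels = [os.path.join(root, "SegmentationClass", n + ".png") for n in names]
    return imgs, labels

# Fake layout with invented IDs, just to show the resulting paths
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "ImageSets", "Segmentation"))
    with open(os.path.join(root, "ImageSets", "Segmentation", "train.txt"), "w") as f:
        f.write("sample_001\nsample_002\n")
    imgs, labels = read_split(root, "train")

print([os.path.basename(p) for p in imgs])    # ['sample_001.jpg', 'sample_002.jpg']
print([os.path.basename(p) for p in labels])  # ['sample_001.png', 'sample_002.png']
```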

voc_seg_data.py

import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os
from PIL import Image
class VOC_SEG(Dataset):
    def __init__(self, root, width, height, train=True, transforms=None):
        # Uniform crop size for all images (width, height)
        self.width = width
        self.height = height
        # Category names in the VOC dataset
        self.classes = ['background','aeroplane','bicycle','bird','boat',
           'bottle','bus','car','cat','chair','cow','diningtable',
           'dog','horse','motorbike','person','potted plant',
           'sheep','sofa','train','tv/monitor']
        # Label colors corresponding to each category
        self.colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]
        if transforms is None:
            # Default: convert to tensor, then normalize with ImageNet mean/std (matches pretrained VGG16)
            normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            transforms = T.Compose([
                T.ToTensor(),
                normalize
            ])
        self.transforms = transforms
        # Lookup table: packed pixel value (RGB) -> category label (0, 1, 2, ...)
        self.cm2lbl = np.zeros(256**3)
        for i, cm in enumerate(self.colormap):
            self.cm2lbl[(cm[0]*256+cm[1])*256+cm[2]] = i
        if train:
            txt_fname = root+"/ImageSets/Segmentation/train.txt"
        else:
            txt_fname = root+"/ImageSets/Segmentation/val.txt"
        with open(txt_fname, 'r') as f:
            images = f.read().split()
        imgs = [os.path.join(root, "JPEGImages", item+".jpg") for item in images]
        labels = [os.path.join(root, "SegmentationClass", item+".png") for item in images]
        # An image and its label have the same size in VOC, so both lists stay aligned
        self.imgs = self._filter(imgs)
        self.labels = self._filter(labels)
        self.fnum = len(imgs) - len(self.imgs)  # number of images dropped as too small
        if train:
            print("Training set: loaded " + str(len(self.imgs)) + " images and labels, filtered out " + str(self.fnum) + " images.")
        else:
            print("Test set: loaded " + str(len(self.imgs)) + " images and labels, filtered out " + str(self.fnum) + " images.")
    def _crop(self, data, label):
        """
        Crop function: cuts from the upper-left corner of the picture. The cropped image has width `width` and height `height`.
        data and label are PIL Image objects.
        """
        box = (0, 0, self.width, self.height)
        data = data.crop(box)
        label = label.crop(box)
        return data, label
    def _image2label(self, im):
        # Convert an RGB label image into an (H, W) map of category indices
        data = np.array(im, dtype="int32")
        idx = (data[:,:,0]*256+data[:,:,1])*256+data[:,:,2]
        return np.array(self.cm2lbl[idx], dtype="int64")
    def _image_transforms(self, data, label):
        data, label = self._crop(data, label)
        data = self.transforms(data)
        label = self._image2label(label)
        label = torch.from_numpy(label)
        return data, label
    def _filter(self, imgs):
        # Keep only images at least as large as the crop size
        img = []
        for im in imgs:
            with Image.open(im) as pic:
                w, h = pic.size
            if h >= self.height and w >= self.width:
                img.append(im)
        return img
    def __getitem__(self, index: int):
        img_path = self.imgs[index]
        label_path = self.labels[index]
        img = Image.open(img_path).convert("RGB")
        label = Image.open(label_path).convert("RGB")
        img, label = self._image_transforms(img, label)
        return img, label
    def __len__(self):
        return len(self.imgs)
if __name__=="__main__":
    root = "./VOCdevkit/VOC2012"
    height = 224
    width = 224
    voc_train = VOC_SEG(root, width, height, train=True)
    voc_test = VOC_SEG(root, width, height, train=False)
    # train_data = DataLoader(voc_train, batch_size=8, shuffle=True)
    # valid_data = DataLoader(voc_test, batch_size=8)
    for data, label in voc_train:
        print(data.shape)   # torch.Size([3, 224, 224])
        print(label.shape)  # torch.Size([224, 224])
        break
  • To keep things simple, I put helper functions such as _crop() and _filter(), and variables such as colormap, inside the class. It is actually better to write a separate data preprocessing file, so that inference and testing after training can call the same processing functions directly.
  • Data processing produces data and label. data is the image as a tensor; label is also a tensor, in which each pixel's RGB color has been replaced by an integer category index. During training, the cross-entropy loss (nn.CrossEntropyLoss) takes these indices directly, with no one-hot encoding needed, just like when training a classification network.
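
The RGB-to-index conversion in _image2label can be tried on a tiny hand-made mask. This sketch uses only the first three colormap entries; unlike the class above, it stores the lookup table as uint8 (about 16 MB instead of the roughly 128 MB of a float64 np.zeros(256**3)), an assumption that is safe because VOC has only 21 classes:

```python
import numpy as np

# First three VOC colors from the colormap: background, aeroplane, bicycle
colormap = [[0, 0, 0], [128, 0, 0], [0, 128, 0]]

# Lookup table indexed by the packed RGB value (r*256 + g)*256 + b
cm2lbl = np.zeros(256 ** 3, dtype=np.uint8)
for i, (r, g, b) in enumerate(colormap):
    cm2lbl[(r * 256 + g) * 256 + b] = i

def image2label(rgb):
    """rgb: (H, W, 3) uint8 array -> (H, W) int64 map of class indices."""
    data = rgb.astype(np.int64)
    idx = (data[..., 0] * 256 + data[..., 1]) * 256 + data[..., 2]
    return cm2lbl[idx].astype(np.int64)

mask = np.array([[[0, 0, 0], [128, 0, 0]],
                 [[0, 128, 0], [128, 0, 0]]], dtype=np.uint8)
print(image2label(mask))
# [[0 1]
#  [2 1]]
```

Because the table starts as all zeros, any color that is not in the colormap, such as the outline VOC draws around objects, ends up as class 0 (background); the class above behaves the same way.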

2.2 Network definitions

fcn8s_net.py

import torch
import torch.nn as nn
from torchvision import models
class FCN8s(nn.Module):
    def __init__(self, num_classes=21):
        super(FCN8s, self).__init__()
        # Load VGG16 network parameters from the pre-trained model
        # (torchvision >= 0.13 prefers weights=models.VGG16_Weights.DEFAULT; pretrained=True is deprecated there)
        net = models.vgg16(pretrained=True)
        self.features = net.features  # Use only the five convolutional stages (feature extraction layers) of VGG16: (3, 224, 224) -----> (512, 7, 7)
        # self.conv6 = nn.Conv2d(512,512,kernel_size=1,stride=1,padding=0,dilation=1)
        # self.conv7 = nn.Conv2d(512,512,kernel_size=1,stride=1,padding=0,dilation=1)
        # (512,7,7)
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512,512,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)  # x2
        self.bn1 = nn.BatchNorm2d(512)
        # (512, 14, 14)
        self.deconv2 = nn.ConvTranspose2d(512,256,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)  # x2
        self.bn2 = nn.BatchNorm2d(256)
        # (256, 28, 28)
        self.deconv3 = nn.ConvTranspose2d(256,128,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)  # x2
        self.bn3 = nn.BatchNorm2d(128)
        # (128, 56, 56)
        self.deconv4 = nn.ConvTranspose2d(128,64,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)   # x2
        self.bn4 = nn.BatchNorm2d(64)
        # (64, 112, 112)
        self.deconv5 = nn.ConvTranspose2d(64,32,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)    # x2
        self.bn5 = nn.BatchNorm2d(32)
        # (32, 224, 224)
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)
        # (num_classes, 224, 224)
    def forward(self, input):
        x = input
        for i in range(len(self.features)):
            x = self.features[i](x)
            if i == 16:
                x3 = x  # maxpooling3's feature map (1/8)
            if i == 23:
                x4 = x  # maxpooling4's feature map (1/16)
            if i == 30:
                x5 = x  # maxpooling5's feature map (1/32)
        # Five transposed convolutions, each scaling the size up by 2 (the reverse of VGG16), with two skip connections
        score = self.relu(self.deconv1(x5))   # out_size = 2*in_size (1/16)
        score = self.bn1(score + x4)          # skip connection 1: add pool4 features
        score = self.relu(self.deconv2(score)) # out_size = 2*in_size (1/8)
        score = self.bn2(score + x3)          # skip connection 2: add pool3 features
        score = self.bn3(self.relu(self.deconv3(score)))  # out_size = 2*in_size (1/4)
        score = self.bn4(self.relu(self.deconv4(score)))  # out_size = 2*in_size (1/2)
        score = self.bn5(self.relu(self.deconv5(score)))  # out_size = 2*in_size (1)
        score = self.classifier(score)        # size unchanged; output channels equal the number of categories
        return score
if __name__ == "__main__":
    model = FCN8s()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    print(model)
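
The indices 16, 23 and 30 in forward() are the positions of the third, fourth and fifth max-pooling layers inside vgg16().features. Assuming torchvision's standard VGG16 layout (a 3x3 Conv2d followed by a ReLU for every number in the configuration, one MaxPool2d per "M", no batch norm), they can be derived without downloading any weights:

```python
# VGG16 configuration: numbers are conv output channels, "M" is 2x2 max pooling
cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
       512, 512, 512, "M", 512, 512, 512, "M"]

layers = []
for v in cfg:
    if v == "M":
        layers.append("MaxPool2d")
    else:
        layers += ["Conv2d", "ReLU"]   # plain vgg16 has no BatchNorm layers

pool_idx = [i for i, name in enumerate(layers) if name == "MaxPool2d"]
print(len(layers))   # 31
print(pool_idx)      # [4, 9, 16, 23, 30]
print([224 // 2 ** (k + 1) for k in range(5)])  # [112, 56, 28, 14, 7]
```

The last line gives the spatial size after each pooling for a 224x224 input, which is why pool3, pool4 and pool5 correspond to 1/8, 1/16 and 1/32 of the input.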

FCN implementations found online vary, but the overall structure is always convolution + transposed convolution + skip connections. In fact, any implementation that performs feature extraction (extracting abstract features), then transposed convolution (restoring the original image size), then per-pixel classification is sufficient.

This experiment uses the five convolutional stages of VGG16 as the feature extraction network, followed by five transposed convolutions (2x each) that restore the original image size, and then one more convolutional layer that adjusts the number of feature-map channels to the number of categories (21). A softmax over the channels then classifies each pixel.
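
The sizes in the comments of fcn8s_net.py follow from PyTorch's transposed-convolution output formula, H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1. With kernel_size=3, stride=2, padding=1, output_padding=1 this is exactly 2 * H_in. The sketch below checks the arithmetic, runs one such layer plus the 1x1 classifier, and shows that the per-pixel loss takes integer class targets directly (the batch size of 2 and the 32-channel input are arbitrary choices):

```python
import torch
import torch.nn as nn

def deconv_out(size, kernel=3, stride=2, padding=1, dilation=1, output_padding=1):
    """Output length of nn.ConvTranspose2d along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

sizes = [7]                        # pool5 output for a 224x224 input
for _ in range(5):                 # five 2x transposed convolutions
    sizes.append(deconv_out(sizes[-1]))
print(sizes)  # [7, 14, 28, 56, 112, 224]

# One real layer with the same settings, followed by the 1x1 classifier
up = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
classifier = nn.Conv2d(32, 21, kernel_size=1)
logits = classifier(up(torch.randn(2, 32, 112, 112)))
print(logits.shape)  # torch.Size([2, 21, 224, 224])

# Per-pixel classification: targets are class indices, not one-hot vectors
target = torch.randint(0, 21, (2, 224, 224))
loss = nn.CrossEntropyLoss()(logits, target)   # log-softmax over dim=1 is applied inside
pred = logits.softmax(dim=1).argmax(dim=1)     # predicted label map
print(pred.shape)  # torch.Size([2, 224, 224])
```

Because nn.CrossEntropyLoss already includes the softmax, the network itself outputs raw scores; an explicit softmax is only needed when per-pixel probabilities are wanted at inference time.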

2.3 Training

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from voc_seg_data import VOC_SEG
from fcn8s_net import FCN8s
# Calculate the confusion matrix (rows: ground truth, columns: prediction)
def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist
# Calculate Acc and mIoU from the confusion matrix
def label_accuracy_score(label_trues, label_preds, n_class):
    """Returns accuracy score evaluation result.
      - overall accuracy
      - mean accuracy
      - mean IU
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / hist.sum()
    with np.errstate(divide='ignore', invalid='ignore'):
        acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = np.diag(hist) / (
            hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
        )
    mean_iu = np.nanmean(iu)
    return acc, acc_cls, mean_iu
def main():
    # 1. load dataset
    root = "./VOCdevkit/VOC2012"
    batch_size = 32
    height = 224
    width = 224
    voc_train = VOC_SEG(root, width, height, train=True)
    voc_test = VOC_SEG(root, width, height, train=False)
    train_dataloader = DataLoader(voc_train, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(voc_test, batch_size=batch_size, shuffle=False)
    # 2. load model
    num_class = 21
    model = FCN8s(num_classes=num_class)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # 3. prepare hyperparameters
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.7)
    num_epochs = 50
    # 4. train
    val_acc_list = []
    out_dir = "./checkpoints/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    for epoch in range(num_epochs):
        print('\nEpoch: %d' % (epoch + 1))
        model.train()
        sum_loss = 0.0
        length = len(train_dataloader)
        for batch_idx, (images, labels) in enumerate(train_dataloader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)  # (batch_size, num_class, height, width)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            predicted = torch.argmax(outputs, 1)
            label_pred = predicted.cpu().numpy()
            label_true = labels.cpu().numpy()
            acc, acc_cls, mean_iu = label_accuracy_score(label_true, label_pred, num_class)
            print('[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% | Acc_cls: %.03f%% | Mean_iu: %.3f'
                % (epoch + 1, (batch_idx + 1 + epoch * length), sum_loss / (batch_idx + 1),
                100. * acc, 100. * acc_cls, mean_iu))
        # Evaluate on the validation set after each epoch
        print('Waiting Val...')
        mean_iu_epoch = 0.0
        mean_acc = 0.0
        mean_acc_cls = 0.0
        model.eval()
        with torch.no_grad():
            for batch_idx, (images, labels) in enumerate(val_dataloader):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = torch.argmax(outputs, 1)
                label_pred = predicted.cpu().numpy()
                label_true = labels.cpu().numpy()
                acc, acc_cls, mean_iu = label_accuracy_score(label_true, label_pred, num_class)
                mean_iu_epoch += mean_iu
                mean_acc += acc
                mean_acc_cls += acc_cls
            print('Acc_epoch: %.3f%% | Acc_cls_epoch: %.03f%% | Mean_iu_epoch: %.3f'
                % ((100. * mean_acc / len(val_dataloader)), (100. * mean_acc_cls / len(val_dataloader)), mean_iu_epoch / len(val_dataloader)))
            val_acc_list.append(mean_iu_epoch / len(val_dataloader))
        # Checkpoint file names below are placeholders (the original names were lost)
        torch.save(model.state_dict(), out_dir + "last.pth")
        if mean_iu_epoch / len(val_dataloader) == max(val_acc_list):
            torch.save(model.state_dict(), out_dir + "best.pth")
            print("save epoch {} model".format(epoch))
if __name__ == "__main__":
    main()

The overall training procedure works as shown, and readers can change the evaluation metrics and related code as needed. Acc is the main metric reported here: the number of correctly classified pixels divided by the total number of pixels (note that the script keeps the checkpoint with the best validation mean IoU). The final training results are as follows:


Acc reached 0.8 on the training set and 0.77 on the validation set. Because some helper functions (such as _fast_hist) were copied from elsewhere, the other metrics are not discussed here for now.
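
To make the difference between the metrics concrete, here is a hand-checkable example that uses the same confusion-matrix formulas as _fast_hist and label_accuracy_score (the 2x3 label maps are made up). Pixel accuracy counts 4 correct pixels out of 6, while mean IoU averages 1/3, 2/3 and 1/2 over the three classes:

```python
import numpy as np

def fast_hist(label_true, label_pred, n_class):
    # Rows = ground-truth class, columns = predicted class
    mask = (label_true >= 0) & (label_true < n_class)
    return np.bincount(n_class * label_true[mask] + label_pred[mask],
                       minlength=n_class ** 2).reshape(n_class, n_class)

label_true = np.array([[0, 0, 1],
                       [1, 2, 2]])
label_pred = np.array([[0, 1, 1],
                       [1, 2, 0]])
hist = fast_hist(label_true.flatten(), label_pred.flatten(), n_class=3)
print(hist)
# [[1 1 0]
#  [0 2 0]
#  [1 0 1]]

acc = np.diag(hist).sum() / hist.sum()                  # correct pixels / all pixels
acc_cls = np.nanmean(np.diag(hist) / hist.sum(axis=1))   # mean per-class accuracy
iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
mean_iu = np.nanmean(iu)                                 # mean intersection over union
print(f"{acc:.3f} {acc_cls:.3f} {mean_iu:.3f}")  # 0.667 0.667 0.500
```

Mean IoU also penalizes false positives (the wrongly predicted pixel in column 0 lowers class 0's IoU), which is why it is usually lower than pixel accuracy on VOC, where background dominates.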
