Abstract

Medical image segmentation is a critical task in the healthcare field. While deep learning techniques have shown promise in this area, they often require a large number of accurately labeled images. To address this issue, semisupervised learning has emerged as a potential solution by reducing the reliance on precise annotations. Among these approaches, the student-teacher framework has garnered attention, but it is limited by its reliance solely on the teacher model for information. To overcome this limitation, we propose a prototype-based mutual consistency learning (PMCL) framework. This framework utilizes two branches that learn from each other, incorporating a supervision loss and a consistency loss to adapt to minor data perturbations and structural differences. By employing prototype consistency learning, we obtain a reliable consistency loss. Our experiments on three public medical image datasets demonstrate that PMCL outperforms other state-of-the-art methods, indicating its potential in semisupervised medical image segmentation. Our framework has the potential to assist medical professionals in enhancing their diagnoses and delivering improved patient care.

1. Introduction

Automatic and accurate segmentation of tumors, organs, or lesions is a prerequisite for designing computer-aided diagnosis and detection systems. Deep convolutional neural networks have performed well at many medical image segmentation tasks [13]. However, these methods require a large number of high-quality labeled images to achieve good results, and it is laborious and time-consuming for experienced experts to make reliable and accurate annotations. To address this problem, we study semisupervised methods that fully utilize a small number of labeled images together with a large number of unlabeled images.

Semisupervised methods have developed rapidly, especially in the field of medical image segmentation. Temporal ensembling and the Π model [4] were proposed to accomplish semisupervised learning by adding noise to the unlabeled data and then minimizing the difference between the predictions on the original data and the noised data. The mean teacher framework [5] replaced the prediction averaging of temporal ensembling with an exponential moving average (EMA) of model weights. The network consists of a teacher model and a student model: the student model is trained by gradient descent, and the teacher model is obtained from the parameters of the student model. The mean teacher framework has a simple structure and excellent experimental results, so many subsequent methods [6–8] make full use of and extend this framework. Xie et al. [6] added a confidence module to the mean teacher framework to predict the confidence of the model and improve the performance of the network. Li et al. [7] introduced more perturbations to both the data and model of the mean teacher framework to construct the consistency loss. Yu et al. [8] encouraged the model to learn more reliable goals by adding uncertainty awareness to the mean teacher framework. Adversarial learning is also used for semisupervised segmentation [9, 10]. Zhang et al. [10] proposed a deep adversarial network to encourage consistent segmentation predictions for unlabeled data. More recently, there have been some multitask network structures for semisupervised medical image segmentation tasks [11, 12]. Li et al. [11] performed image segmentation and signed distance map regression at the same time and used a discriminator as a regularization term. Luo et al. [12] built a multitask network that derives consistency from the difference between the segmentation task and a level-set function regression task.

However, in the mean teacher framework, the parameters of the student network are updated using a combination of the segmentation and consistency losses, the teacher network's parameters are obtained as an exponential moving average of the student's, and the teacher in turn guides the student through the total loss. We instead want a framework consisting of two student models that learn from each other, combining their learned information to improve network performance. We propose a prototype-based mutual consistency learning framework (PMCL) for medical image segmentation, which is divided into two branches that we can regard as two student models. To make them learn different information, the two student models are slightly different. For unlabeled images, the two branches use prototype learning to obtain segmentation predictions under different disturbances, and we obtain the consistency loss by comparing the predictions of the two branches. The regions where the two predictions differ can be considered complex areas; by applying the consistency loss to the output of each decoder, high-confidence regions can be learned. For labeled images, the two branches obtain different pieces of information through slightly different decoders, and the framework learns more reliable information through a combination of two supervision losses. The two branches learn from each other, allowing the network to be trained end-to-end.

The main contributions of this work are as follows:

(1) We propose a semisupervised 2D medical image segmentation framework, PMCL, in which two networks learn from each other for semisupervised segmentation tasks. The proposed framework can also be applied to other 2D and 3D semisupervised medical image segmentation tasks.

(2) We use prototype consistency learning to generate high-quality pseudolabels for unlabeled images, which are more reliable than those generated by other methods. Using these pseudolabels significantly improves the performance of the network.

(3) Comprehensive experiments on three public medical image datasets demonstrate the superiority of PMCL over other semisupervised methods, and ablation experiments confirm the effectiveness of each submodule of the proposed method.

2. Related Work

We introduce related work on semisupervised medical image segmentation, mutual learning, and prototype consistency learning.

2.1. Semisupervised Medical Image Segmentation

Semisupervised learning plays an increasingly important role in the field of medical image segmentation. Existing approaches can be roughly divided into regularization methods based on data or model disturbances, adversarial learning methods, and consistency methods at the multitask level.

There are many pseudolabel methods [14, 15], which use labeled data to train a model, generate pseudolabels for unlabeled data, and add these to the training set to continue training. The most important task is finding high-quality soft labels. Hung et al. [16] designed a discriminator to provide supervisory signals for semisupervised medical image segmentation; it learns to distinguish ground-truth label maps from predicted probability maps. Combined with a spatial cross-entropy loss, an adversarial loss encourages the segmentation network to generate prediction probability maps that are close to the real label maps in their high-order structure. Temporal ensembling and the Π model [4] were proposed to complete semisupervised learning by minimizing the difference between the predictions on the original unlabeled data and on the noised unlabeled data. Virtual adversarial training (VAT) [17] proposed a regularization method based on a virtual adversarial loss: a new measure of the local smoothness of the conditional label distribution given the input. The virtual adversarial loss is defined as the robustness of the conditional label distribution around each input data point to local perturbations. It replaces random perturbations with adversarial perturbations designed to deceive the trained model, enabling the network to effectively learn a local smoothness prior and become more resilient to various noises.

Mean teacher [5] also uses consistency regularization and is divided into student and teacher models. The student model obtains its parameters through gradient descent, and the teacher model obtains them through an exponential moving average of the student model's parameters. The difference between the two models' parameters can be regarded as part of the network disturbance, which, together with the data disturbance, constitutes the total disturbance. Mean teacher has achieved great success in semisupervised image segmentation, and many subsequent networks [6–8] have modified and extended it. Li et al. [7] added more perturbations to the data and model based on the mean teacher framework. Yu et al. [8] used Monte Carlo dropout to add uncertainty awareness to the mean teacher framework to allow the learning of more reliable information.
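For intuition, the EMA update at the heart of the mean teacher framework can be written in a few lines. The following is a minimal PyTorch sketch under our own assumptions: the decay rate `alpha = 0.99` and the toy model are illustrative, not values taken from this paper or from [5].

```python
# Minimal sketch of the mean-teacher EMA weight update:
# teacher <- alpha * teacher + (1 - alpha) * student.
import copy

import torch
import torch.nn as nn

student = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 2, 1))
teacher = copy.deepcopy(student)
for p in teacher.parameters():
    p.requires_grad_(False)  # the teacher is never updated by gradient descent

@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, alpha: float = 0.99) -> None:
    """Update the teacher weights as an EMA of the student weights."""
    for t_p, s_p in zip(teacher.parameters(), student.parameters()):
        t_p.mul_(alpha).add_(s_p, alpha=1.0 - alpha)

ema_update(teacher, student)  # called once per training step after the student update
```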

Multitask network structures for semisupervised medical image segmentation have recently appeared. SASSnet [11] performs signed distance map regression and image segmentation at the same time and uses a discriminator as a regularization term; the stability and robustness of the segmentation results are ensured by introducing prior information about shape and position. DTC [12] also builds consistency at the task level for semisupervised learning using a multitask network; unlike SASSnet, it uses the representation difference between the two tasks to build consistency.

2.2. Mutual Learning

High-performance deep neural networks generally have a huge number of parameters, so lightweight networks such as MobileNet [18] and ShuffleNet [19] were later proposed. Hinton et al. [20] proposed knowledge distillation, which uses a trained, more complex teacher model to guide a relatively lightweight student model during training; it tries to maintain the accuracy of the original teacher model while reducing the model size and computing resource requirements. In semisupervised medical image segmentation tasks, much work [5–8] has used the student-teacher architecture to improve network performance.

In our work, the entire network adopts a mutual learning framework. In a student-teacher network, the student can only learn from the teacher. Unlike the student-teacher setup, mutual learning consists of two student networks that learn from each other and make progress together. Mutual learning frameworks are widely used in multimodel architectures and have achieved good results at various tasks. Zhang et al. [21] first proposed a deep mutual learning strategy, in which each network is supervised by the sum of its own supervision loss and an interaction loss from the other networks. Wu et al. [22] proposed two decoders for semisupervised medical image segmentation whose outputs generate pseudolabels that guide each other's probability maps; this design makes the outputs of the submodels consistent and of low entropy, which better segments edges and isolated parts of the image. Zhang and Zhang [23] designed two networks with the same structure, ending in segmentation and regression layers, and optimized them to learn useful knowledge through mutual learning. Many other methods [24–26] have also exploited mutual learning.

2.3. Prototype Learning

In our method, we generate predicted labels for unlabeled images by borrowing prototype learning from few-shot segmentation, which aims to learn transferable knowledge from different tasks with just a few samples. In prototype learning, the labeled data in the training set are used as the model's support set, and the prediction object is used as the network's query set. The network must learn to use the support set to predict the labels of the query set.

Many methods, including metric- [27, 28], optimization- [29, 30], and graph-based [31, 32] methods, have been proposed for few-shot learning. Among these, prototype-based methods are widely used in few-shot segmentation, as they reduce computation and perform relatively well. Snell et al. [27] proposed a prototypical network to represent each class with one feature vector in image classification tasks, using the nearest neighbor classifier to predict the category of the query set. Shaban et al. [33] proposed a classical two-branch model for few-shot segmentation tasks, using a conditional branch to extract the prototype features of the support set and a segmentation branch to extract the features of the query set, obtaining a segmentation map through logistic regression. Dong and Xing [34] also used metric learning and prototypical networks to complete few-shot segmentation tasks. SG-One [35] used masked average pooling to generate prototypes for the support set and cosine similarity to establish the relationship between the query set and prototype. Masked average pooling has since been widely used. Wang et al. [36] proposed prototype alignment regularization to make full use of the information of the support set. CANet [37] introduced the attention mechanism in prototype learning, using the middle-level features of the network to compare the query and support sets, and continuously iterating the network to obtain the segmentation results. FWB [38] improved the quality of the prototype by performing the same operations on the support set image as on the query set. AMP [39] considered the support set of the historical state when calculating the prototype and combined prototypes under different feature resolutions. Some methods [40–42] have used superpixels to accomplish few-shot segmentation tasks.

We transfer the prototype learning in few-shot learning to semisupervised learning and use it to generate high-quality pseudolabels for unlabeled images to improve the reliability of network prediction.

3. Proposed Methodology

We present the details of the proposed PMCL method. We first introduce the general semisupervised learning framework to make our method more intuitive and easier to understand, and then present the prototype consistency and mutual consistency learning modules. In this section, the overall loss composition of the framework is explained first; then, the process by which prototype learning generates masks; and finally, the consistency loss computed from these masks.

3.1. Semisupervised Segmentation Framework

Figure 1 shows the PMCL framework, which is trained as follows. The encoders of the two branches have the same structure and share weights, while the two decoders, taken from UNet [13], capture uncertainty information through slight structural differences. A labeled image and an unlabeled image are fed into the two branches. For each branch, the shared backbone encoder first embeds the labeled and unlabeled images into deep features. Then, masked average pooling is used to obtain foreground and background prototypes from the labeled data and the corresponding ground truth, as discussed in Section 3.2. Each pixel of an unlabeled image is then labeled according to the class of its nearest prototype, segmenting the unlabeled image. A mutual learning framework constrains the outputs of the two branches, as detailed in Section 3.3. The consistency loss and the supervised loss constitute the total loss.

In the semisupervised learning setting, we have labeled and unlabeled training samples. We denote the labeled and unlabeled sets as $\mathcal{D}_L = \{(x_i, y_i)\}_{i=1}^{N}$ and $\mathcal{D}_U = \{x_j\}_{j=N+1}^{N+M}$, where $x \in \mathbb{R}^{H \times W}$ is the input image, $y_i \in \{0, 1\}^{H \times W}$ is the ground truth of $x_i$, and $H$ and $W$ are the image height and width, respectively. We can then train our semisupervised medical image segmentation framework by the minimization

$$\min_{\theta_e,\, \theta_{d_1},\, \theta_{d_2}} \; \mathcal{L}_{sup} + \lambda\, \mathcal{L}_{con},$$

where $\mathcal{L}_{sup}$ and $\mathcal{L}_{con}$ are the supervised loss and the consistency loss, $\theta_e$, $\theta_{d_1}$, and $\theta_{d_2}$ are the weights of the encoder, decoder1, and decoder2, and $\lambda$ is a ramp-up weighting coefficient that controls the trade-off between the supervised and consistency losses and prevents the network from learning meaningless consistency targets at the beginning of training.

The total loss of our prototype-based mutual consistency learning network is a weighted combination of the supervised loss $\mathcal{L}_{sup}$ and the consistency loss $\mathcal{L}_{con}$, which are calculated only from labeled and unlabeled images, respectively. The total loss is

$$\mathcal{L}_{total} = \mathcal{L}_{sup} + \lambda\, \mathcal{L}_{con}.$$
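As a concrete illustration, the weighted combination above can be sketched in PyTorch (the framework used in Section 4.2). The function name, its arguments, and the dummy loss values are illustrative assumptions, not the authors' implementation.

```python
# A minimal sketch of the total loss L_total = L_sup + lambda * L_con.
import torch

def total_loss(sup_loss: torch.Tensor, con_loss: torch.Tensor, lam: float) -> torch.Tensor:
    # L_sup comes only from labeled images; L_con only from unlabeled images.
    return sup_loss + lam * con_loss

# Dummy scalar losses so the snippet runs on its own.
loss_sup = torch.tensor(0.7, requires_grad=True)
loss_con = torch.tensor(0.2, requires_grad=True)
loss = total_loss(loss_sup, loss_con, lam=0.1)
loss.backward()  # gradients flow to both branches through both terms
```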

3.2. Prototype Mutual Learning

Previous semisupervised methods have usually directly used an encoder-decoder structure to generate segmentation predictions for unlabeled images, which does not efficiently utilize the information in the labeled images and their labels. We instead generate pseudolabels for unlabeled images efficiently with the prototype learning method from few-shot learning. We use the labeled image and its ground truth as the support set and the unlabeled image as the query set to train the network. Our model is based on a prototypical network [27] that uses the mask annotations of the support set to learn prototypes for the foregrounds and backgrounds of images. To maintain input consistency, we adopt a late fusion strategy that uses a shared feature extractor to generate feature maps for the foregrounds and backgrounds of images [35, 43]. Specifically, given a support set $S = \{(x_k, y_k)\}_{k=1}^{K}$, let $F_k$ be the feature map extracted by the encoder for the labeled image $x_k$, where $k$ indexes the support images. We can obtain the prototype of the foreground by masked average pooling [35]:

$$p_{fg} = \frac{1}{K} \sum_{k=1}^{K} \frac{\sum_{(u,v)} F_k^{(u,v)}\, \mathbb{1}\!\left[y_k^{(u,v)} = 1\right]}{\sum_{(u,v)} \mathbb{1}\!\left[y_k^{(u,v)} = 1\right]},$$

where $(u, v)$ indexes the spatial locations, $\mathbb{1}[\cdot]$ is an indicator function that returns 1 if the condition is true and 0 otherwise, and $y_k^{(u,v)} = 1$ marks the foreground segmentation target. We can similarly obtain the prototype of the background as

$$p_{bg} = \frac{1}{K} \sum_{k=1}^{K} \frac{\sum_{(u,v)} F_k^{(u,v)}\, \mathbb{1}\!\left[y_k^{(u,v)} = 0\right]}{\sum_{(u,v)} \mathbb{1}\!\left[y_k^{(u,v)} = 0\right]}.$$
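A minimal PyTorch sketch of masked average pooling is given below. The tensor shapes and the assumption that the ground-truth masks have already been downsampled to the feature resolution are ours, not details confirmed by the paper.

```python
# Sketch of masked average pooling over K support images.
import torch

def masked_average_pooling(feats: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    """feats: (K, C, H, W) feature maps; masks: (K, H, W) binary masks.
    Returns one prototype vector of shape (C,) averaged over the K images."""
    masks = masks.unsqueeze(1).float()                                 # (K, 1, H, W)
    per_image = (feats * masks).sum(dim=(2, 3)) / (masks.sum(dim=(2, 3)) + 1e-8)
    return per_image.mean(dim=0)                                       # (C,)

# Foreground and background prototypes from the same support masks:
feats = torch.randn(2, 64, 32, 32)          # K=2 support images, 64-channel features
fg_mask = torch.rand(2, 32, 32) > 0.5       # dummy binary ground-truth masks
p_fg = masked_average_pooling(feats, fg_mask)
p_bg = masked_average_pooling(feats, ~fg_mask)
```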

Nonparametric metric learning is used to learn the optimal prototypes and complete the segmentation. Since segmentation can be viewed as a classification of each spatial position, we calculate the distance between the query feature vector at each spatial position and each computed prototype. We introduce a distance function $d(\cdot, \cdot)$ and apply the softmax function over the distances to produce a probability map over classes. Let $\mathcal{P} = \{p_{fg}, p_{bg}\}$ denote the set of prototypes and $F_q$ the feature map extracted from the unlabeled data. For each $p_c \in \mathcal{P}$, we have

$$\tilde{P}_1^{(u,v)}(c) = \frac{\exp\!\left(-\alpha\, d\!\left(F_q^{(u,v)}, p_c\right)\right)}{\sum_{p_{c'} \in \mathcal{P}} \exp\!\left(-\alpha\, d\!\left(F_q^{(u,v)}, p_{c'}\right)\right)},$$

where the distance function $d$ adopts the cosine distance (i.e., $\cos$ in Figure 1) to measure the similarity between the unlabeled feature map $F_q$ and the labeled prototypes $p_c$, and the multiplier $\alpha$ is set to 20, as used in PANet [36].

Then, we can obtain the predicted segmentation mask $\hat{y}_1$ of the unlabeled data as follows:

$$\hat{y}_1^{(u,v)} = \arg\max_{c}\, \tilde{P}_1^{(u,v)}(c).$$
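The metric step and the argmax above can be sketched as follows, with the cosine distance implemented as a negative cosine similarity so that a softmax over $\alpha$ times the similarity matches the probability map above. All names and shapes are illustrative assumptions.

```python
# Sketch of nonparametric prototype classification with scaled cosine similarity.
import torch
import torch.nn.functional as F

def prototype_predict(query_feats: torch.Tensor, prototypes: torch.Tensor,
                      alpha: float = 20.0):
    """query_feats: (C, H, W); prototypes: (N_cls, C).
    Returns the per-pixel class probabilities and the hard pseudolabel mask."""
    c, h, w = query_feats.shape
    q = query_feats.permute(1, 2, 0).reshape(-1, c)                    # (H*W, C)
    # Cosine similarity between every pixel feature and every prototype.
    sim = F.cosine_similarity(q.unsqueeze(1), prototypes.unsqueeze(0), dim=2)
    probs = F.softmax(alpha * sim, dim=1)                              # (H*W, N_cls)
    mask = probs.argmax(dim=1).reshape(h, w)                           # pseudolabel
    return probs.reshape(h, w, -1), mask

# e.g., 64-channel query features and two prototypes (foreground, background):
probs, pseudo_mask = prototype_predict(torch.randn(64, 32, 32), torch.randn(2, 64))
```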

Similarly, we can obtain the probability map $\tilde{P}_2$ and the predicted segmentation mask $\hat{y}_2$ of the unlabeled image in the other branch by performing the above operations. The two branches generate different pseudolabels for unlabeled data under different data perturbations: we add Gaussian noise $\xi$ to the unlabeled input of the second branch. The network can focus on high-confidence areas through the different pseudolabels generated by the two branches and obtain more reliable and robust results through consistency learning. We verify the role of prototype consistency learning in the network through an ablation experiment, as described in Section 4.4.

3.3. Mutual Consistency Learning

In a mutual learning framework, multiple untrained branches learn at the same time to solve tasks together. Each branch is guided by traditional supervised learning loss and consistency loss from other branches.

At the beginning of training, each branch can quickly segment images relatively correctly because of the traditional supervised loss. At this point, the predictions of the same pixels may differ according to initial conditions and network structures. The framework encourages consistent predictions from each branch. The consistency loss from other branches fine-tunes the model to perform better in complex segmentation areas. In the end, mutual learning helps to obtain a more robust and generalized network.

Our mutual learning framework consists of prototype mutual consistency learning for unlabeled data and mutual supervision learning for labeled data.

For prototype mutual consistency learning, the Kullback–Leibler (KL) divergence [21] is used as the consistency loss to measure the difference between the segmentation predictions of the two branches. The consistency loss from $\tilde{P}_1$ to $\tilde{P}_2$ is computed as

$$\mathcal{L}_c^{1 \to 2} = D_{KL}\!\left(\tilde{P}_1 \,\|\, \tilde{P}_2\right) = \sum_{(u,v)} \sum_{c} \tilde{P}_1^{(u,v)}(c) \log \frac{\tilde{P}_1^{(u,v)}(c)}{\tilde{P}_2^{(u,v)}(c)}.$$

We can similarly obtain the consistency loss from $\tilde{P}_2$ to $\tilde{P}_1$ as

$$\mathcal{L}_c^{2 \to 1} = D_{KL}\!\left(\tilde{P}_2 \,\|\, \tilde{P}_1\right) = \sum_{(u,v)} \sum_{c} \tilde{P}_2^{(u,v)}(c) \log \frac{\tilde{P}_2^{(u,v)}(c)}{\tilde{P}_1^{(u,v)}(c)}.$$

In this way, each branch learns to correctly predict segmentations of the training data and to match the probability estimate of its peer. We balance the two consistency losses to obtain the final consistency loss:

$$\mathcal{L}_{con} = \frac{1}{2}\left(\mathcal{L}_c^{1 \to 2} + \mathcal{L}_c^{2 \to 1}\right).$$
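A hedged PyTorch sketch of this bidirectional KL consistency loss follows. The per-pixel averaging and the epsilon for numerical stability are our assumptions, as the reduction is not specified here.

```python
# Sketch of the balanced bidirectional KL consistency loss between two branches.
import torch

def mutual_kl_loss(p1: torch.Tensor, p2: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """p1, p2: (B, N_cls, H, W) probability maps. Returns
    0.5 * (KL(p1 || p2) + KL(p2 || p1)), averaged over pixels."""
    kl_12 = (p1 * ((p1 + eps).log() - (p2 + eps).log())).sum(dim=1).mean()
    kl_21 = (p2 * ((p2 + eps).log() - (p1 + eps).log())).sum(dim=1).mean()
    return 0.5 * (kl_12 + kl_21)

p1 = torch.softmax(torch.randn(4, 2, 32, 32), dim=1)  # branch 1 probabilities
p2 = torch.softmax(torch.randn(4, 2, 32, 32), dim=1)  # branch 2 probabilities
loss_con = mutual_kl_loss(p1, p2)
```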

For mutual supervision learning, decoder1 and decoder2 perform upsampling through bilinear interpolation and deconvolution, respectively. The different decoder structures prompt the model to learn more information. We combine the cross-entropy loss $\mathcal{L}_{ce}$ and the Dice loss $\mathcal{L}_{dice}$ to calculate the supervised loss. The losses of the two branches are calculated as follows:

$$\mathcal{L}_{sup}^{1} = \mathcal{L}_{ce}\!\left(\hat{y}_1^{l}, y\right) + \mathcal{L}_{dice}\!\left(\hat{y}_1^{l}, y\right), \qquad \mathcal{L}_{sup}^{2} = \mathcal{L}_{ce}\!\left(\hat{y}_2^{l}, y\right) + \mathcal{L}_{dice}\!\left(\hat{y}_2^{l}, y\right),$$

where $\hat{y}_1^{l}$ and $\hat{y}_2^{l}$ denote the predictions of the two decoders for a labeled image with ground truth $y$.

To fully utilize the information of both branches and let the model train end-to-end, we combine the two supervised losses:

$$\mathcal{L}_{sup} = \frac{1}{2}\left(\mathcal{L}_{sup}^{1} + \mathcal{L}_{sup}^{2}\right).$$
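The combined supervised loss can be sketched as below. The binary foreground Dice formulation and the equal weighting of cross-entropy and Dice are assumptions consistent with common practice, not details confirmed by the paper.

```python
# Sketch of the combined cross-entropy + Dice supervised loss, averaged over branches.
import torch
import torch.nn.functional as F

def dice_loss(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """probs: (B, N_cls, H, W) softmax probabilities; target: (B, H, W) labels."""
    fg = probs[:, 1]                                   # foreground probability
    tgt = (target == 1).float()
    inter = (fg * tgt).sum(dim=(1, 2))
    union = fg.sum(dim=(1, 2)) + tgt.sum(dim=(1, 2))
    return (1 - (2 * inter + eps) / (union + eps)).mean()

def supervised_loss(logits1, logits2, target):
    """Average of the two branches' CE + Dice losses."""
    losses = []
    for logits in (logits1, logits2):
        probs = torch.softmax(logits, dim=1)
        losses.append(F.cross_entropy(logits, target) + dice_loss(probs, target))
    return 0.5 * (losses[0] + losses[1])

target = torch.randint(0, 2, (4, 32, 32))              # dummy ground-truth labels
loss_sup = supervised_loss(torch.randn(4, 2, 32, 32), torch.randn(4, 2, 32, 32), target)
```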

Hence, the network obtains more reliable information from the labeled data through the mutual learning framework.

4. Experiments and Results

We discuss the implementation and compare the performance of PMCL and other semisupervised medical image segmentation algorithms on three public datasets. We performed ablation experiments to validate each part of our method.

4.1. Datasets and Evaluation Metrics

We evaluated our method on three public polyp segmentation datasets: CVC-ClinicDB [44], CVC-ColonDB [45], and Kvasir-SEG [46]. CVC-ClinicDB contains 612 images of size 384 × 288 pixels. CVC-ColonDB contains 380 images of size 574 × 500 pixels. Kvasir-SEG contains 1000 images, which we scaled to 256 × 256 pixels before training, as they vary in size from 332 × 487 to 1920 × 1072 pixels. In our experiments, we followed the training settings of [3, 47, 48]. The division of the three datasets was the same, with random selections of 80% of the images for training, 10% for validation, and 10% for testing. Each image was normalized to zero mean and unit variance. Among the training images, only 10% or 20% were used as labeled data, and the rest were used as unlabeled data. Table 1 shows the image sizes and the sizes of the training, validation, and test sets of these datasets.
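For concreteness, a sketch of such a split under the 10% labeled setting is shown below. The file names, seed, and use of Python's random module are illustrative assumptions.

```python
# Illustrative 80/10/10 split with a 10% labeled training subset.
import random

images = [f"img_{i:04d}.png" for i in range(612)]   # e.g., the CVC-ClinicDB size
random.seed(0)
random.shuffle(images)
n = len(images)
train = images[: int(0.8 * n)]
val = images[int(0.8 * n): int(0.9 * n)]
test = images[int(0.9 * n):]
labeled = train[: int(0.1 * len(train))]            # 10% of the training images
unlabeled = train[len(labeled):]                    # remaining 90% used unlabeled
```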

We evaluated segmentation performance using the Dice similarity coefficient (DSC), Jaccard index (JI), sensitivity (SE), accuracy (AC), 95% Hausdorff distance (95HD), and average surface distance (ASD). We combine the experimental protocols in [6, 8] to calculate these metrics.

4.2. Implementation Details

All the networks in our experiments were trained using PyTorch on an Nvidia GeForce TITAN X GPU. For all the methods, the encoder and the decoders came from UNet [13]. We adopted the SGD optimizer to train the networks, setting the weight decay to 0.0001 and the momentum to 0.9. We used no pretrained weights. We set the initial learning rate to 0.01 and reduced it by a factor of 10 every 2500 iterations. The input batch size was set to 4, consisting of two labeled images and two unlabeled images. We set the consistency weight factor as a time-dependent Gaussian warming-up function $\lambda(t) = \lambda_{\max} \cdot e^{-5\left(1 - t/t_{\max}\right)^{2}}$, where $t$ and $t_{\max}$ indicate the current and final training step, respectively. Because both branches were trained through mutual learning, we chose the better-performing of the two branches as the final test result.

For the first 1000 iterations, we set the consistency loss to 0, because the network parameters had not converged at the beginning of training and the consistency loss was meaningless. After 1000 iterations, we added the consistency loss to the total loss.
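Putting the ramp-up function and the 1000-iteration cutoff together, a sketch could look like the following. The value of `lambda_max` is an assumption, since the paper does not state it here.

```python
# Sketch of the Gaussian ramp-up for the consistency weight with an early cutoff.
import math

def consistency_weight(t: int, t_max: int, lambda_max: float = 0.1,
                       warmup_cutoff: int = 1000) -> float:
    """lambda(t) = lambda_max * exp(-5 * (1 - t / t_max)^2), zero before the cutoff."""
    if t <= warmup_cutoff:
        return 0.0                      # consistency loss not yet meaningful
    return lambda_max * math.exp(-5.0 * (1.0 - t / t_max) ** 2)

w = consistency_weight(t=2000, t_max=30000)  # e.g., weight at iteration 2000
```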

4.3. Comparison between PMCL and Other Methods

We compared the proposed method with existing methods on CVC-ClinicDB, CVC-ColonDB, and Kvasir-SEG. As shown in Tables 2–4, we implemented several semisupervised segmentation methods for comparison, including mean teacher (MT) [5], deep adversarial network (DAN) [10], entropy minimization (EM) [49], uncertainty-aware mean teacher (UAMT) [8], and interpolation consistency training (ICT) [50]. The fully supervised setting utilized 100% of the labeled data to obtain an upper bound on performance. For fair comparisons, all methods used a UNet [13] backbone network.

Table 2 shows the results of comparative experiments on CVC-ClinicDB under 10% and 20% labeled images, taking the supervised-only method as the baseline. With 10% labeled images, all semisupervised methods improve over the baseline because they learn additional information from the unlabeled images through the regularization loss. The proposed PMCL method shows a steady and obvious improvement over other state-of-the-art semisupervised learning methods on the six metrics: by leveraging 10% labeled images and 90% unlabeled images, DSC increased by 6.64%, 4.89%, 6.56%, 4.9%, and 5.65% compared with [5, 8, 10, 49, 50], respectively. When using 20% labeled images, all semisupervised learning methods improved. Our method still shows a notable performance improvement, as DSC increased by 2.6%, 1.54%, 1.7%, 2.97%, and 0.79% compared with [5, 8, 10, 49, 50], respectively. The proposed PMCL outperforms the other methods on the DSC, JI, SE, and 95HD metrics.

Tables 3 and 4 show the performance of the proposed method and other state-of-the-art methods under 10% and 20% labeled images on CVC-ColonDB and Kvasir-SEG. On CVC-ColonDB, our method achieves the best performance on all six metrics under both 10% and 20% labeled data. On Kvasir-SEG, our method performs best on five metrics under 10% labeled data and on four metrics under 20% labeled data. Across these three datasets, our method improves most when only a small amount of labeled data is used, which indicates that it exploits unlabeled images more efficiently than other semisupervised methods.

Figures 2–4 show the predicted segmentation results of the proposed PMCL and other methods under the 10% labeled image setting on the three datasets. Compared with other semisupervised approaches, the segmentation maps predicted by PMCL have a larger intersection with the ground truth, and the segmentation results are smoother at the edges of the lesions.

Overall, the comparison experiments demonstrate that the PMCL framework outperforms other state-of-the-art methods under different numbers of labeled images, which means that our method can learn rich and effective information from the unlabeled images.

4.4. Ablation Study

To verify the impact of prototype mutual consistency learning and mutual supervision learning on the entire framework, we conducted ablation studies on CVC-ClinicDB. We first designed a variant of the MT framework that replaces its consistency loss with the proposed prototype mutual consistency learning, referred to as Prototype-MT. The proposed method uses mutual learning between the two branches, where both the supervision and consistency losses have two parts. To explore the impact of these two parts on the overall network, we evaluated three framework structures under the 10% labeled data setting: PMCL-B1 uses only the loss of the upper branch; PMCL-B2 uses only the loss of the lower branch; and PMCL combines the losses of both branches.

In Table 5, it can be observed that the performance of Prototype-MT is better than that of MT, which indicates that prototype mutual consistency learning can more effectively utilize unlabeled data and learn more reliable and rich knowledge from it. Moreover, the performance of PMCL significantly exceeds that of PMCL-B1 and PMCL-B2, which means that the two branches obtain better performance through mutual learning. The ablation experiments show that our mutual learning framework can learn rich information from both branches and effectively improve network performance.

5. Conclusion

We investigated common methods for semisupervised medical image segmentation and proposed the PMCL framework. Experiments on three datasets show that when only a small number of labeled images is available, the PMCL framework improves more over the baseline than other methods. This is because when the proportion of labeled data is smaller, reliable supervised information is scarcer and the proportion of unlabeled data is higher, so a semisupervised method must extract more of its information from the unlabeled data. In this regime, semisupervised learning methods differ considerably in how effectively they extract this information, which leads to large differences in the final results.

From the experiments, it can be seen that the PMCL method utilizes unlabeled images more fully than other semisupervised methods to improve network performance. The proposed method makes full use of a mutual learning framework to improve its performance and robustness. We designed prototype mutual consistency learning to obtain a more reliable consistency loss for unlabeled images and mutual supervision learning for labeled images. Experiments demonstrated that our method has potential in semisupervised segmentation tasks.

Data Availability

The data that support the findings of this study are available on request from the corresponding author.

Conflicts of Interest

The authors declare that they have no conflicts of interest.

Acknowledgments

This work was supported by the National Natural Science Foundation of China (grant number 62176181).