Transformers Are Replacing CNNs: Transformers For Medical Image Segmentation.
3 main points
✔️ The first CNN-free model for 3D medical image segmentation
✔️ Segmentation accuracy that is better than or competitive with CNNs on three different datasets
✔️ Much better transfer learning ability than CNNs
Convolution-Free Medical Image Segmentation using Transformers
written by Davood Karimi, Serge Vasylechko, Ali Gholipour
(Submitted on 26 Feb 2021)
Comments: Accepted to arXiv.
Subjects: Image and Video Processing (eess.IV); Computer Vision and Pattern Recognition (cs.CV)
code:
First of all
Deep neural networks (DNNs) have had a huge impact on medicine. DNNs can efficiently perform a variety of tasks, such as classifying images as malignant or benign and detecting and segmenting malignant regions. Although manual work by experts is still considered the most reliable approach, DNN models are faster, more scalable, and cheaper once the initial development cost has been paid. Behind the success of DNNs in medical image processing are convolutional neural networks (CNNs). Their strong inductive bias toward images (locality and weight sharing) makes them successful in a wide range of vision tasks. Although much has changed in the architectures, loss functions, and training methods of vision models, their fundamental building block, the convolutional layer, has remained the same.
Transformers have performed extremely well on many NLP tasks. With the introduction of the Vision Transformer (ViT), self-attention has proven effective for computer vision tasks as well. This raises the question of whether transformers could also improve the state of the art in medical vision tasks.
In this paper, we introduce a CNN-free, transformer-based model for 3D medical image segmentation. The model is competitive with or better than CNN-based models and can be fine-tuned on datasets with only 20 to 200 labeled images. In addition, it has better transfer-learning capabilities than state-of-the-art CNNs.
Architecture
The diagram above shows the architecture of the convolution-free network. First, a 3D block is extracted from the image and divided into $n^3$ patches. Let the block $B \in \mathbb{R}^{W \times W \times W \times c}$ be divided into $n^3$ non-overlapping patches $\{p_i \in \mathbb{R}^{w \times w \times w \times c}\}$, where $w = W/n$ and $c$ is the number of image channels (e.g., 3 for RGB). $n$ is chosen experimentally to be 3 or 5, which gives 27 or 125 patches. Next, each patch is flattened into a vector of dimension $w^3 c$, and these vectors are mapped into a $D$-dimensional space by a learned linear mapping $E$. Positional encodings $E_{pos}$ are added to the resulting vectors. This forms a sequence of $N = n^3$ embeddings, each of dimension $D$:
$$X_0 = [E p_1; \dots; E p_N] + E_{pos}$$
Note that the positional encodings here are learnable parameters rather than fixed (e.g., sinusoidal) encodings.
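The patch-splitting and embedding step can be sketched in NumPy as follows. This is a minimal illustration, not the authors' code: the sizes ($W = 24$, $n = 3$, $c = 1$, $D = 64$) are assumptions chosen only so the shapes are easy to follow, and the "learned" matrices $E$ and $E_{pos}$ are random placeholders standing in for trained parameters.

```python
import numpy as np

def extract_patches(block: np.ndarray, n: int) -> np.ndarray:
    """Split a (W, W, W, c) block into n**3 non-overlapping (w, w, w, c) patches,
    each flattened to length w**3 * c. Returns an array of shape (n**3, w**3 * c)."""
    W, c = block.shape[0], block.shape[3]
    assert block.shape[:3] == (W, W, W) and W % n == 0
    w = W // n
    # Split each spatial axis into (patch index, offset inside the patch).
    x = block.reshape(n, w, n, w, n, w, c)
    # Move the three patch indices to the front, offsets and channels to the back.
    x = x.transpose(0, 2, 4, 1, 3, 5, 6)
    return x.reshape(n ** 3, w ** 3 * c)

def embed(patches: np.ndarray, E: np.ndarray, E_pos: np.ndarray) -> np.ndarray:
    """X0 = [E p_1; ...; E p_N] + E_pos, with E of shape (D, w^3 c) and E_pos of shape (N, D)."""
    return patches @ E.T + E_pos

# Illustrative sizes (assumptions, not values taken from the paper).
rng = np.random.default_rng(0)
W, n, c, D = 24, 3, 1, 64
w = W // n
block = rng.standard_normal((W, W, W, c))
patches = extract_patches(block, n)
E = rng.standard_normal((D, w ** 3 * c)) * 0.01     # placeholder for the learned linear map
E_pos = rng.standard_normal((n ** 3, D)) * 0.01     # placeholder for learnable positional encodings
X0 = embed(patches, E, E_pos)
print(patches.shape, X0.shape)  # (27, 512) (27, 64)
```

Because $E_{pos}$ is just another parameter array of shape $(N, D)$, making it learnable costs nothing beyond adding it to the optimizer.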
The transformer encoder has K stages, each consisting of a multi-head self-attention (MSA) layer followed by a two-layer fully connected feed-forward network (FFN), with layer normalization and residual connections. The encoder is essentially that of the standard transformer. The queries (Q), keys (K), and values (V) are computed from the input sequence by learned linear projections, and self-attention is calculated as
$$\mathrm{SA}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{D_h}}\right)V$$
where $D_h$, the hidden dimension of each attention head, sets the scaling factor.
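One encoder stage, using scaled dot-product attention $\mathrm{softmax}(QK^\top/\sqrt{D_h})V$ split across several heads, might look like the NumPy sketch below. Several details are assumptions for illustration rather than the paper's exact choices: pre-norm placement of layer normalization, layer norm without learnable gain and bias, a ReLU activation in the FFN, no dropout, and random weights in place of trained ones.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token vector (no learnable gain/bias in this sketch).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def msa(x, Wq, Wk, Wv, Wo, heads):
    """Multi-head self-attention on an (N, D) sequence."""
    N, D = x.shape
    Dh = D // heads  # per-head hidden dimension

    def split(t):  # (N, D) -> (heads, N, Dh)
        return t.reshape(N, heads, Dh).transpose(1, 0, 2)

    Q, K, V = split(x @ Wq), split(x @ Wk), split(x @ Wv)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(Dh))  # (heads, N, N) attention weights
    out = (A @ V).transpose(1, 0, 2).reshape(N, D)       # concatenate the heads
    return out @ Wo

def encoder_stage(x, p, heads):
    # Residual MSA block, then residual two-layer FFN block (pre-norm, assumed).
    x = x + msa(layer_norm(x), p["Wq"], p["Wk"], p["Wv"], p["Wo"], heads)
    h = np.maximum(0.0, layer_norm(x) @ p["W1"] + p["b1"])  # ReLU hidden layer
    return x + h @ p["W2"] + p["b2"]

# Illustrative sizes (assumptions).
rng = np.random.default_rng(0)
N, D, heads, hidden, K_stages = 27, 64, 4, 128, 2

def init_stage(scale=0.05):
    shapes = {"Wq": (D, D), "Wk": (D, D), "Wv": (D, D), "Wo": (D, D),
              "W1": (D, hidden), "W2": (hidden, D)}
    p = {name: rng.standard_normal(s) * scale for name, s in shapes.items()}
    p["b1"], p["b2"] = np.zeros(hidden), np.zeros(D)
    return p

stages = [init_stage() for _ in range(K_stages)]
X0 = rng.standard_normal((N, D))  # stands in for the embedded patch sequence
X = X0
for p in stages:
    X = encoder_stage(X, p, heads)
print(X.shape)  # (27, 64)
```

Without positional encodings, each stage is permutation-equivariant (shuffling the tokens just shuffles the output), which is why the positional encodings added to $X_0$ are needed to tell the model where each patch sits.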
Finally, after the last FFN, the output sequence is projected into a space of dimension $N n_{class}$ by a fully connected layer, where $n_{class}$ is the number of classes (2 for binary segmentation). The result is reshaped into the segmentation mask $Y' \in \mathbb{R}^{n \times n \times n \times n_{class}}$ (the prediction is made for the block as a whole, not pixel by pixel).
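A sketch of this output head follows. It assumes one plausible reading of the projection: a per-token linear layer from $D$ to $n_{class}$, which yields $N n_{class}$ values in total that reshape into the $n \times n \times n \times n_{class}$ mask. The softmax corresponds to the layer added for fine-tuning; the names `W_out` and `b_out` are hypothetical.

```python
import numpy as np

def segmentation_head(XK, W_out, b_out, n):
    """Map encoder output XK of shape (N, D) to class probabilities of shape (n, n, n, n_class).

    Assumes a per-token linear layer (D -> n_class), so the total output size is N * n_class."""
    logits = XK @ W_out + b_out                     # (N, n_class)
    logits = logits.reshape(n, n, n, -1)            # N = n**3 tokens -> n x n x n grid
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    return probs / probs.sum(axis=-1, keepdims=True)  # softmax over classes

# Illustrative sizes (assumptions); random weights stand in for trained ones.
rng = np.random.default_rng(1)
n, D, n_class = 3, 64, 2
XK = rng.standard_normal((n ** 3, D))
W_out = rng.standard_normal((D, n_class)) * 0.1
b_out = np.zeros(n_class)
Y_hat = segmentation_head(XK, W_out, b_out, n)
mask = Y_hat.argmax(axis=-1)  # hard label per grid cell
print(Y_hat.shape, mask.shape)  # (3, 3, 3, 2) (3, 3, 3)
```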
This mask covers only the central patch of the block, so the process is repeated, sliding the block across the image, until the entire 3D image has been segmented.
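Covering a whole volume with center-patch predictions can be sketched as a sliding-window loop. This is an illustration under stated assumptions: `predict_center` is a hypothetical stand-in for the trained model that returns a label for every voxel of the central $w \times w \times w$ patch, the image has a single channel, every image dimension is a multiple of $w$, and the border is zero-padded so edge patches still get a full $W^3$ context block.

```python
import numpy as np

def sliding_window_segment(image, W, w, predict_center):
    """Tile a 3D image (X, Y, Z) with non-overlapping w^3 center patches.

    For each tile, a W^3 context block centered on it is passed to predict_center,
    which must return a (w, w, w) label array for the central patch.
    Assumes every image dimension is a multiple of w and (W - w) is even."""
    m = (W - w) // 2                            # context margin on each side
    padded = np.pad(image, m, mode="constant")  # zero padding at the borders
    out = np.zeros(image.shape, dtype=np.int64)
    X, Y, Z = image.shape
    for i in range(0, X, w):
        for j in range(0, Y, w):
            for k in range(0, Z, w):
                # Thanks to the margin, the block starting at (i, j, k) in padded
                # coordinates is centered on the original tile [i:i+w, j:j+w, k:k+w].
                block = padded[i:i + W, j:j + W, k:k + W]
                out[i:i + w, j:j + w, k:k + w] = predict_center(block)
    return out

# Illustrative sizes (assumptions).
rng = np.random.default_rng(0)
W, w = 24, 8
m = (W - w) // 2

def threshold_center(block):
    # Hypothetical stand-in for the trained model: threshold the central patch.
    return (block[m:m + w, m:m + w, m:m + w] > 0.5).astype(np.int64)

image = rng.random((16, 24, 8))
seg = sliding_window_segment(image, W, w, threshold_center)
print(seg.shape)  # (16, 24, 8)
```

With this toy predictor the stitched result equals a voxel-wise threshold of the image, which is a quick way to check that every tile lands in the right place.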
Experiment and Evaluation
The model was compared with 3D UNet++, a state-of-the-art CNN-based model for medical image segmentation, primarily in terms of the Dice similarity coefficient (DSC).
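For reference, the DSC between a predicted mask $A$ and a ground-truth mask $B$ is $2|A \cap B| / (|A| + |B|)$, ranging from 0 (no overlap) to 1 (perfect overlap). A minimal NumPy version for binary masks is sketched below; returning 1.0 when both masks are empty is one common convention, assumed here.

```python
import numpy as np

def dice(pred, target):
    """Dice similarity coefficient 2|A ∩ B| / (|A| + |B|) for binary masks of any shape."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    denom = pred.sum() + target.sum()
    if denom == 0:
        return 1.0  # both masks empty: treated as perfect agreement (a common convention)
    return 2.0 * np.logical_and(pred, target).sum() / denom

print(dice([1, 1, 0, 0], [1, 0, 1, 0]))  # 0.5
```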
Pre-Training
To further improve accuracy when only a small number of labeled training instances is available, we first pre-train the model on a larger unlabeled dataset using denoising and inpainting (image reconstruction) as pretext tasks. For denoising, Gaussian noise with an SNR of 10 dB is added to the center patch of each image block. For inpainting, the values of the center patch are set to 0, and the model must reconstruct them. Both tasks are trained to minimize the L2 distance between the original and reconstructed images. The model is pre-trained without a softmax layer; for fine-tuning, a softmax layer is added to predict the segmentation masks. We also found that fine-tuning the entire network is more effective than fine-tuning only the last layer.
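The two corruptions and the reconstruction loss can be sketched as follows. Assumptions to note: SNR is computed as the mean squared intensity of the center patch divided by the noise variance, the "L2 distance" is implemented as a mean squared error, and the block and patch sizes are illustrative only.

```python
import numpy as np

def center_slice(W, w):
    """Index selecting the central w^3 patch of a W^3 block (first three axes)."""
    m = (W - w) // 2
    return (slice(m, m + w),) * 3

def add_noise_center(block, w, snr_db=10.0, rng=None):
    """Denoising pretext task: add Gaussian noise to the center patch at snr_db (dB).

    SNR is taken here as mean(signal**2) / noise_variance (an assumption)."""
    rng = np.random.default_rng() if rng is None else rng
    out = block.astype(np.float64, copy=True)
    sl = center_slice(block.shape[0], w)
    signal = out[sl]
    noise_power = np.mean(signal ** 2) / (10 ** (snr_db / 10))
    out[sl] = signal + rng.normal(0.0, np.sqrt(noise_power), size=signal.shape)
    return out

def mask_center(block, w):
    """Inpainting pretext task: set the center patch to zero."""
    out = block.astype(np.float64, copy=True)
    out[center_slice(block.shape[0], w)] = 0.0
    return out

def l2_loss(reconstruction, target):
    """Mean squared (L2) reconstruction error."""
    return np.mean((reconstruction - target) ** 2)

# Illustrative sizes (assumptions).
rng = np.random.default_rng(0)
W, w = 24, 8
clean = rng.random((W, W, W))
noisy = add_noise_center(clean, w, snr_db=10.0, rng=rng)
masked = mask_center(clean, w)
sl = center_slice(W, w)
# A perfect reconstruction of the center patch gives zero loss.
print(l2_loss(clean[sl], clean[sl]))  # 0.0
```

In both tasks the network sees the corrupted block and is trained to output the clean center patch, so no labels are needed at this stage.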
Evaluation
The model was benchmarked on three different datasets for images of the brain cortical plate, hippocampus, and pancreas.
The diagram above shows the experimental results on the different datasets. In almost all cases, our model outperforms UNet++ across the different metrics. The lower diagram shows how our model and UNet++ perform with extremely few labeled training instances (5, 10, and 15) for the cortical plate (left) and the pancreas (right). In both cases, our proposed model adapts better, and pre-training with inpainting works better than pre-training with denoising. We also observed that training with learnable positional encodings works better than training with fixed positional encodings. The diagram below shows the results of some ablation studies.
Prediction result image
Pancreas segmentation is considered a difficult task. Even so, the model is able to predict the segmentation mask reasonably well.
Summary
Some tasks, such as manual segmentation of the brain's cortical plate, are highly complex and can take hours to complete even for experts. A model like this one is clearly advantageous in such cases. Given how few labeled training instances are typically available in medical computer-vision tasks, models like this that can learn from few examples are a necessity. Finally, it would be interesting to see how the 3D Point Transformer model performs on medical image segmentation. Although it has not yet been tested on medical images, it has shown very impressive results on other 3D segmentation tasks.