A Way To Create A GPT-3 Equivalent For Vision Transformers?

3 main points
✔️ A self-supervised learning (SSL) method for vision transformers.
✔️ Use of the SSL method to train a powerful transformer autoencoder for images, with a simple linear decoder.
✔️ Outperforms current SOTA models on several benchmarks, with improvements of more than 13.53% on some datasets.

SiT: Self-supervised vIsion Transformer
Written by Sara Atito, Muhammad Awais, Josef Kittler
(Submitted on 8 Apr 2021)
Comments: Accepted to arXiv.

Subjects: Computer Vision and Pattern Recognition (cs.CV); Machine Learning (cs.LG)



One of the reasons models like BERT are so successful in NLP is that they can be trained on extremely large corpora of unlabeled data. CNNs and vision transformers, on the other hand, are usually limited to pre-training on large-scale labeled data, and transformers in particular have been shown to require extremely large datasets to perform well. Scaling up labeled datasets is expensive, and several self-supervised learning (SSL) algorithms have been proposed to address this limitation.

Generative SSL approaches are computationally expensive and not always necessary for representation learning. Contrastive learning approaches, which learn representations that are invariant to data augmentations, have been very successful. However, contrastive learning fails to capture contextual information, which is where pretext tasks such as colorization, jigsaw-puzzle solving, and noise prediction have proved useful. Moreover, there is still no effective SSL method that fully leverages the power of vision transformers.

In this paper, we propose a novel method that combines the advantages of contrastive and pretext-task methods. Our method improves the current state of the art on several datasets, with improvements of over 13.53% on some of them.


The Vision Transformer (ViT) learns a bottleneck representation in which the content and context representations are centered around a special 'class' token. This severely limits the transformer's ability to model the data efficiently, so it requires a large number of training samples. Our objective is to reduce the amount of labeled training data needed while still achieving state-of-the-art performance. Our method combines changes to the model, the training procedure, and the training tasks, each of which is described in the following sections.

Self-Supervised Vision Transformer

Our base model is the recently introduced ViT. We replace its 'class' token with two new tokens, the rotation token and the contrastive token, which are used for rotation prediction and contrastive prediction, respectively. Like the class token in ViT, these new tokens are concatenated with the image patch tokens. As shown in the diagram above, positional embeddings are added to both of these tokens as well as to the image patch tokens. Since no class token (and therefore no class label) is needed, the model can be trained on unlabeled data. The roles of the two tokens are discussed in the next section.
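The token layout described above can be sketched in plain Python. This is a minimal illustration under our own assumptions: the names `build_token_sequence` and `split_outputs`, the plain-list vectors, and the routing of the patch outputs to the reconstruction task are illustrative choices, not the SiT code; a real model would use learned tensors inside a deep-learning framework.

```python
import random

ROT_INDEX = 0    # position of the rotation token in the sequence
CONTR_INDEX = 1  # position of the contrastive token in the sequence


def build_token_sequence(patch_tokens, rot_token, contr_token, pos_embeddings):
    """Concatenate [rotation, contrastive, patch_1, ..., patch_n] and add
    one positional embedding to every position (element-wise sum)."""
    tokens = [rot_token, contr_token] + list(patch_tokens)
    if len(pos_embeddings) != len(tokens):
        raise ValueError("need exactly one positional embedding per token")
    return [[t + p for t, p in zip(tok, pos)]
            for tok, pos in zip(tokens, pos_embeddings)]


def split_outputs(encoder_outputs):
    """Hypothetical routing of encoder outputs to the three task heads:
    rotation token -> rotation head, contrastive token -> contrastive head,
    patch outputs -> image reconstruction."""
    return (encoder_outputs[ROT_INDEX],
            encoder_outputs[CONTR_INDEX],
            encoder_outputs[CONTR_INDEX + 1:])


if __name__ == "__main__":
    rng = random.Random(0)
    dim, num_patches = 4, 9  # toy sizes: a 3x3 grid of patches

    def rand_vec():
        return [rng.gauss(0.0, 0.02) for _ in range(dim)]

    patches = [rand_vec() for _ in range(num_patches)]
    seq = build_token_sequence(patches, rand_vec(), rand_vec(),
                               [rand_vec() for _ in range(num_patches + 2)])
    # No encoder here: the sequence itself stands in for the encoder output.
    rot_out, contr_out, patch_out = split_outputs(seq)
    print(len(seq), len(patch_out))  # 11 9
```

The key point is that no class token appears anywhere in the sequence, so no label is required to form the input.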

Self-Supervised Tasks

The transformer is trained on three different tasks: image reconstruction, rotation prediction, and contrastive learning.

1) Image reconstruction. Here the transformer is trained as an autoencoder. CNN-based autoencoders use a series of convolutional and pooling layers that discard valuable information along the way, so an expensive decoder then has to recover it with upsampling and convolutional operations. Our transformer is instead trained to reconstruct distorted image patches, and in doing so it learns the semantic concepts in the image. The local transformations, applied to connected patches, include random drop, random replacement, blurring, grayscale conversion, recoloring, and so on. All of these changes are applied simultaneously to the image. The loss is the L1 loss between the original image x and the image SiT(x') reconstructed from its corrupted version x':

$$\mathcal{L}_{\text{recons}} = \frac{1}{N}\sum_{i=1}^{N} \left\lVert x_i - \mathrm{SiT}(x'_i) \right\rVert_1$$

where N is the number of images in the batch.
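The corruption-and-reconstruction objective can be illustrated with a toy sketch. It implements only two of the listed transformations (random drop, which zeroes patches, and random replace, which copies patches from another image) and applies a single one of them to one connected rectangular block of patches, whereas the method described here combines several transformations. The block-size limit, the zero fill value, and the function names are our own assumptions. The loss is averaged per pixel rather than per image, which only rescales the L1 loss by a constant.

```python
import random


def corrupt_patches(patches, grid_w, donor_patches, rng, max_block=2):
    """Corrupt one connected rectangular block of patches (toy scheme).

    patches: row-major list of patches, each a flat list of pixel values.
    donor_patches: patches of another image, used by "random replace".
    Returns a corrupted copy; the input list is not modified.
    """
    if len(patches) % grid_w != 0:
        raise ValueError("patch count must be a multiple of grid_w")
    grid_h = len(patches) // grid_w
    # Pick a block of connected patches: height bh, width bw, top-left (top, left).
    bh = rng.randint(1, min(max_block, grid_h))
    bw = rng.randint(1, min(max_block, grid_w))
    top = rng.randint(0, grid_h - bh)
    left = rng.randint(0, grid_w - bw)
    op = rng.choice(["drop", "replace"])
    corrupted = [list(p) for p in patches]
    for r in range(top, top + bh):
        for c in range(left, left + bw):
            i = r * grid_w + c
            if op == "drop":
                corrupted[i] = [0.0] * len(patches[i])  # random drop: blank out
            else:
                corrupted[i] = list(donor_patches[i])  # random replace
    return corrupted


def l1_reconstruction_loss(originals, reconstructions):
    """Mean absolute error between original and reconstructed pixel values."""
    total, count = 0.0, 0
    for x, x_hat in zip(originals, reconstructions):
        for a, b in zip(x, x_hat):
            total += abs(a - b)
            count += 1
    return total / count
```

During training, the corrupted patches would be fed to the transformer, and the L1 loss would be computed between its patch outputs and the untouched originals.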

2) Rotation prediction. We take advantage of the transformer's flexibility and combine the reconstruction loss with other, complementary losses. Each input image is randomly rotated by 0, 90, 180, or 270 degrees and fed to the network, which is trained to classify the rotation into one of these four classes. To solve this task, the model has to learn the notion of the objects in the image in order to recognize their orientation. The prediction error is measured with the cross-entropy loss.
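The rotation pretext task can be sketched as follows: sample a class k uniformly from four, rotate the image by 90·k degrees, and score the network's four logits against k with softmax cross-entropy. Images are plain 2D lists here and the function names are hypothetical; the logits would come from the rotation-token output, which this sketch does not model.

```python
import math
import random

ROTATION_DEGREES = (0, 90, 180, 270)  # class k corresponds to 90 * k degrees


def rotate90_ccw(image):
    """Rotate a 2D list of pixels by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*image)][::-1]


def rotate(image, k):
    """Rotate by k * 90 degrees counter-clockwise (k taken modulo 4)."""
    out = [list(row) for row in image]
    for _ in range(k % 4):
        out = rotate90_ccw(out)
    return out


def make_rotation_example(image, rng):
    """Sample a rotation class uniformly; return (rotated image, class label)."""
    k = rng.randrange(len(ROTATION_DEGREES))
    return rotate(image, k), k


def cross_entropy(logits, label):
    """Softmax cross-entropy for one example, using log-sum-exp for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]
```

Because the label is produced by the rotation itself, the task needs no human annotation.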

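3) Contrastive learning. The transformer should be invariant to the geometric transformations and perturbations applied to the input image, i.e., it should produce similar representations for all augmented versions of the same image. Specifically, we use the cosine similarity and try to maximize the similarity of positive pairs (augmented views of the same image) while minimizing the similarity of negative pairs (views of different images). For this we use the normalized temperature-scaled cross-entropy (NT-Xent) loss. For a positive pair (i, j) in a batch of N images (2N augmented views) whose contrastive-token outputs are z:

$$\ell(i, j) = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/T\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/T\right)}$$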
Here, sim(., .) represents the cosine similarity of its L2-normalized inputs, and the temperature T is set to 0.5.
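The contrastive loss described above (NT-Xent, as popularized by SimCLR) can be written out in a few lines. This is a sketch under an assumed batch layout: `z[2m]` and `z[2m + 1]` are the contrastive-token outputs for the two augmented views of image m, and every other embedding in the batch serves as a negative. The temperature defaults to the T = 0.5 stated in the text.

```python
import math


def cosine_similarity(u, v):
    """sim(u, v): dot product of the L2-normalized vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nt_xent_loss(z, temperature=0.5):
    """Average NT-Xent loss over 2N embeddings.

    Assumed layout: z[2m] and z[2m + 1] are the two augmented views of image m,
    so each embedding's positive is the other member of its pair and the
    remaining 2N - 2 embeddings act as negatives.
    """
    if len(z) % 2 != 0:
        raise ValueError("expected an even number of embeddings (pairs of views)")
    total = 0.0
    for i in range(len(z)):
        j = i + 1 if i % 2 == 0 else i - 1  # positive partner of i
        # Scaled similarities to every other embedding (the k != i indicator).
        logits = {k: cosine_similarity(z[i], z[k]) / temperature
                  for k in range(len(z)) if k != i}
        m = max(logits.values())
        log_denominator = m + math.log(sum(math.exp(s - m) for s in logits.values()))
        total += log_denominator - logits[j]  # -log(exp(s_ij) / sum_k exp(s_ik))
    return total / len(z)
```

The loss shrinks when the two views of each image point in the same direction and views of different images point apart, which is exactly the invariance the contrastive token is meant to learn.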

Finally, the three losses are combined in a weighted sum with weights α1, α2, and α3. Tuning these three weights with a grid search would be very expensive, so we use the uncertainty weighting approach and treat (α1, α2, α3) as learned parameters.
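One common formulation of uncertainty weighting (Kendall, Gal and Cipolla, 2018) learns a log-variance s_i per task and minimizes the sum of exp(−s_i)·L_i + s_i, so exp(−s_i) plays the role of the weight α_i and the +s_i term stops the weights from collapsing to zero. Whether SiT uses exactly this variant is an assumption on our part. The toy below holds the three task losses fixed (the loss values are arbitrary), so gradient descent on s_i can be checked against the closed-form optimum s_i = log L_i. In real training, s_i would be optimized jointly with the network weights.

```python
import math


def uncertainty_weighted_total(losses, log_vars):
    """Total loss sum_i exp(-s_i) * L_i + s_i, with s_i = log(sigma_i^2) learned.
    exp(-s_i) acts as the task weight alpha_i."""
    return sum(math.exp(-s) * l + s for l, s in zip(losses, log_vars))


def grad_log_vars(losses, log_vars):
    """Gradient of the total loss with respect to each s_i."""
    return [1.0 - math.exp(-s) * l for l, s in zip(losses, log_vars)]


def fit_log_vars(losses, steps=2000, lr=0.1):
    """Gradient descent on the s_i with the task losses held fixed (toy setting)."""
    log_vars = [0.0] * len(losses)
    for _ in range(steps):
        grads = grad_log_vars(losses, log_vars)
        log_vars = [s - lr * g for s, g in zip(log_vars, grads)]
    return log_vars


def task_weights(log_vars):
    """The effective weights alpha_i = exp(-s_i)."""
    return [math.exp(-s) for s in log_vars]
```

With fixed losses, the learned weight of each task settles at 1 / L_i, so a task whose loss is currently large is down-weighted automatically instead of requiring a hand-tuned α.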


The experiments were conducted on four popular image classification datasets: CIFAR-10, CIFAR-100, Tiny-ImageNet, and STL-10. For linear evaluation, the models are first trained on unlabeled data to learn representations, and then a linear layer with one output per class is trained on top of the learned features. Unlike CNN models, which require a complex decoder, the transformer works well even with a single linear layer as the decoder. We also conduct domain transfer experiments by training on unlabeled images from CIFAR-100 and then tuning the model on CIFAR-10, and vice versa. The results are shown in the table below.
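The linear-evaluation step can be sketched as training a softmax linear classifier on fixed feature vectors. We assume the standard protocol in which the pretrained encoder stays frozen, so only `W` and `b` are updated. The toy feature vectors stand in for SiT outputs, and the names, learning rate, and epoch count are illustrative assumptions.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train_linear_probe(features, labels, num_classes, epochs=200, lr=0.5, seed=0):
    """Train only a linear layer (W, b) on frozen features with SGD and cross-entropy."""
    rng = random.Random(seed)
    dim = len(features[0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    order = list(range(len(features)))
    for _ in range(epochs):
        rng.shuffle(order)
        for n in order:
            x, y = features[n], labels[n]
            p = softmax([sum(w * xi for w, xi in zip(W[c], x)) + b[c]
                         for c in range(num_classes)])
            for c in range(num_classes):
                g = p[c] - (1.0 if c == y else 0.0)  # d(cross-entropy)/d(logit_c)
                for d in range(dim):
                    W[c][d] -= lr * g * x[d]
                b[c] -= lr * g
    return W, b


def predict(W, b, x):
    """Index of the highest-scoring class."""
    scores = [sum(w * xi for w, xi in zip(row, x)) + bc for row, bc in zip(W, b)]
    return max(range(len(scores)), key=scores.__getitem__)
```

Since the features never change, the accuracy of this probe measures how linearly separable the self-supervised representations already are.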

SiT outperforms all the methods it is compared against by a large margin on all datasets. We also conducted experiments on CIFAR-10 and CIFAR-100 to examine how the amount of labeled data affects the model's performance.

In this section, we present examples of original images (top row), corrupted images (middle row), and reconstructed images (bottom row), taken from various sources (the training set, the test set, and the Internet).


This paper has shown how to successfully implement a transformer as an image autoencoder by minimizing the rotation, reconstruction, and contrastive losses. The choice of these three losses and the transformer's natural aptitude for multi-task learning allow the method to outperform state-of-the-art methods by a wide margin. Although this work focused mainly on image classification, it could be extended to other vision tasks such as instance segmentation and object detection.

Thapa Samrat
I am a second-year international student from Nepal, currently studying at the Department of Electronic and Information Engineering at Osaka University. I am interested in machine learning and deep learning, so I write articles about them in my spare time.

If you have any suggestions for improvement of the content of the article,
please contact the AI-SCHOLAR editorial team through the contact form.

Contact Us