A New Self-Supervised Learning Algorithm From Facebook AI: Barlow Twins
3 main points
✔️ A novel self-supervised learning algorithm for visual tasks
✔️ Works well with smaller batch sizes and higher-dimensional outputs than previous methods
✔️ Performs competitively with SOTA models
Barlow Twins: Self-Supervised Learning via Redundancy Reduction
written by Jure Zbontar, Li Jing, Ishan Misra, Yann LeCun, Stéphane Deny
(Submitted on 4 Mar 2021)
Comments: Accepted to arXiv.
Subjects: Computer Vision and Pattern Recognition (cs.CV); Artificial Intelligence (cs.AI); Machine Learning (cs.LG); Neurons and Cognition (q-bio.NC)
Self-supervised learning (SSL) is becoming competitive with supervised learning on large computer vision benchmarks such as ImageNet. Many recent SSL methods in computer vision rely on data augmentation: the aim is to learn representations that are invariant to distortions of the input. This is done by training with an objective that maximizes the similarity between the representations of differently distorted versions of the same sample. However, the easiest way for the model to satisfy this objective is to output a constant representation regardless of the input, a trivial "collapsed" solution that is useless. Siamese-network methods avoid this collapse in different ways: SimCLR uses negative samples in a contrastive loss, SwAV and SeLa use clustering, and BYOL and SimSiam introduce asymmetry between the two branches of the network pair.
In this paper, we introduce Barlow Twins, a self-supervised learning algorithm based on redundancy reduction, a principle proposed by the neuroscientist H. Barlow. Our method drives the cross-correlation matrix between the output representations of two distorted versions of the same samples toward the identity matrix. Unlike other methods, it works well with small batch sizes and benefits from high-dimensional representations. It outperforms previous methods on ImageNet in the low-data regime and is on par with current state-of-the-art methods on linear classification and transfer tasks.
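Why does driving the cross-correlation matrix toward the identity avoid the collapsed solution described above? A constant output has zero variance, so after mean-centering it cannot correlate with anything, and the diagonal entries can never reach 1. The NumPy sketch below illustrates this; `cross_correlation` and `barlow_twins_loss` are our own illustrative helpers (not the authors' code), the small `eps` that avoids division by zero is an assumption of the sketch, and λ = 0.005 is only an illustrative value.

```python
import numpy as np

def cross_correlation(z_a, z_b, eps=1e-12):
    """D x D cross-correlation between two N x D batches (eps avoids 0/0)."""
    z_a = z_a - z_a.mean(axis=0)  # mean-center along the batch dimension
    z_b = z_b - z_b.mean(axis=0)
    norms = np.outer(np.linalg.norm(z_a, axis=0), np.linalg.norm(z_b, axis=0))
    return (z_a.T @ z_b) / (norms + eps)

def barlow_twins_loss(c, lambd=0.005):
    """Invariance term plus lambda-weighted redundancy-reduction term."""
    off_diag = ~np.eye(c.shape[0], dtype=bool)
    return float(((1.0 - np.diag(c)) ** 2).sum() + lambd * (c[off_diag] ** 2).sum())

N = 64
# A collapsed encoder: every input is mapped to the same 4-dimensional vector.
collapsed_a = np.tile([1.0, -2.0, 0.5, 3.0], (N, 1))
collapsed_b = collapsed_a.copy()

# A naive invariance objective (mean squared distance between the two views)
# is already perfectly satisfied by the collapsed encoder...
naive_loss = float(((collapsed_a - collapsed_b) ** 2).mean())
print(naive_loss)  # 0.0

# ...but the centered outputs are all zero, so C is the zero matrix and each of
# the D = 4 diagonal entries contributes (1 - 0)^2 to the invariance term.
collapsed_loss = barlow_twins_loss(cross_correlation(collapsed_a, collapsed_b))
print(collapsed_loss)  # 4.0

# Informative, nearly invariant embeddings (random stand-ins) score much better.
rng = np.random.default_rng(0)
z_a = rng.normal(size=(N, 4))
z_b = z_a + 0.05 * rng.normal(size=(N, 4))
informative_loss = barlow_twins_loss(cross_correlation(z_a, z_b))
print(informative_loss < collapsed_loss)  # True
```

The point is not the exact numbers but the asymmetry: invariance alone is trivially satisfied by collapse, whereas a correlation-based target cannot be.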
Barlow Twins Method
Implementing the Barlow Twins algorithm is straightforward. First, we sample a random batch of images X. Each image is distorted twice using augmentations sampled from a distribution T, producing two distorted batches Y^A and Y^B. Both batches are passed through the same deep neural network f_θ with learnable parameters θ, giving two output batches Z^A and Z^B, each mean-centered along the batch dimension. We then compute the cross-correlation matrix C between Z^A and Z^B along the batch dimension (indexed by b):

$$C_{ij} = \frac{\sum_b z^A_{b,i}\, z^B_{b,j}}{\sqrt{\sum_b \left(z^A_{b,i}\right)^2}\,\sqrt{\sum_b \left(z^B_{b,j}\right)^2}}$$

where i and j index the dimensions of the network's output.
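C is a square matrix whose size equals the dimensionality of the network's output, with entries between -1 (perfect anti-correlation) and 1 (perfect correlation). The loss function is:

$$\mathcal{L}_{BT} = \underbrace{\sum_i \left(1 - C_{ii}\right)^2}_{\text{invariance term}} + \lambda \underbrace{\sum_i \sum_{j \neq i} C_{ij}^2}_{\text{redundancy reduction term}}$$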
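As a minimal sketch of how the cross-correlation matrix can be computed, the NumPy function below mean-centers both batches of embeddings along the batch dimension, sums the products over the batch, and divides by the norms of the centered feature columns. Random arrays stand in for the network outputs Z^A and Z^B; `cross_correlation` is our own illustrative helper, and the `eps` guard is an assumption added for numerical safety.

```python
import numpy as np

def cross_correlation(z_a: np.ndarray, z_b: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Cross-correlation matrix C (D x D) between two N x D batches of embeddings.

    C[i, j] correlates feature i of z_a with feature j of z_b across the batch.
    """
    # Mean-center each feature along the batch dimension (axis 0).
    z_a = z_a - z_a.mean(axis=0)
    z_b = z_b - z_b.mean(axis=0)
    # Sum over the batch of products z_a[b, i] * z_b[b, j] -> shape (D, D).
    numerator = z_a.T @ z_b
    # Norm of each centered feature column over the batch, shape (D,).
    norm_a = np.sqrt((z_a ** 2).sum(axis=0))
    norm_b = np.sqrt((z_b ** 2).sum(axis=0))
    # eps guards against division by zero for a constant feature.
    return numerator / (np.outer(norm_a, norm_b) + eps)

rng = np.random.default_rng(0)
N, D = 256, 8                               # batch size, embedding dimension (illustrative)
z_a = rng.normal(size=(N, D))               # stand-in for f_theta(Y^A)
z_b = z_a + 0.1 * rng.normal(size=(N, D))   # correlated stand-in for f_theta(Y^B)
c = cross_correlation(z_a, z_b)
print(c.shape)                   # (8, 8)
print(np.round(np.diag(c), 2))   # diagonal entries close to 1
```

By the Cauchy-Schwarz inequality every entry lies in [-1, 1], and the matrix is D × D regardless of the batch size N.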
The first, invariance term pushes the diagonal entries of C toward 1, which makes the representation invariant to the distortions applied to the input. The second, redundancy reduction term pushes the off-diagonal entries toward 0, which decorrelates the output units and so reduces the redundancy between them. λ is a positive constant that trades off the two terms; a larger value gives more weight to redundancy reduction. The paper also provides PyTorch-style pseudocode for a full training step.
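The loss itself fits in a few lines. The NumPy sketch below is an illustrative translation, not the authors' PyTorch pseudocode: `barlow_twins_loss` and `barlow_twins_loss_matrix_form` are our own helpers that take an already computed cross-correlation matrix, and λ = 0.005 is only an illustrative value. The examples show that the identity matrix is the optimum, while a matrix with perfect invariance but correlated features is still penalized by the redundancy term.

```python
import numpy as np

def barlow_twins_loss(c: np.ndarray, lambd: float) -> float:
    """Barlow Twins loss from a D x D cross-correlation matrix c."""
    # Invariance term: squared distance of each diagonal entry from 1.
    invariance = ((1.0 - np.diag(c)) ** 2).sum()
    # Redundancy-reduction term: squared off-diagonal entries.
    off_diag = ~np.eye(c.shape[0], dtype=bool)
    redundancy = (c[off_diag] ** 2).sum()
    return float(invariance + lambd * redundancy)

def barlow_twins_loss_matrix_form(c: np.ndarray, lambd: float) -> float:
    """Equivalent form: (C - I)^2 summed, with off-diagonal entries weighted by lambda."""
    eye = np.eye(c.shape[0])
    weights = np.where(eye == 1.0, 1.0, lambd)
    return float((weights * (c - eye) ** 2).sum())

lambd = 0.005  # illustrative trade-off value

# The identity matrix is the optimum: perfect invariance, no redundancy.
print(barlow_twins_loss(np.eye(4), lambd))  # 0.0

# Perfectly invariant but redundant features: every pair is correlated at 0.5,
# so only the 12 off-diagonal entries are penalized.
redundant = np.full((4, 4), 0.5)
np.fill_diagonal(redundant, 1.0)
print(barlow_twins_loss(redundant, lambd))      # 12 * 0.5**2 * lambd
print(barlow_twins_loss(redundant, lambd=1.0))  # a larger lambda penalizes it more
```

The matrix form computes the same value; it is often convenient in array frameworks because it avoids separate diagonal and off-diagonal passes.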
Important Ablation Studies
The following ablations use a model trained with Barlow Twins on the ImageNet dataset and evaluate it with linear evaluation, i.e. a linear classifier trained on top of the frozen representations.
Loss Function Ablations: We altered our baseline loss function in several ways and found that the original loss gives the best performance.
Our loss normalizes the embeddings along the batch dimension, whereas losses that measure cosine similarity normalize along the feature dimension. We therefore tried normalizing along the feature dimension, so that the embeddings lie on the unit sphere, but this gave inferior results. Removing batch normalization (BN) from the two hidden layers of the projector MLP only slightly reduced performance, whereas computing the loss on an unnormalized covariance matrix instead of the cross-correlation matrix was detrimental. Finally, replacing the loss with a cross-entropy loss with temperature did not help either.
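To make the normalization ablation concrete, the NumPy sketch below contrasts the two options on a random N × D batch: standardizing each feature across the batch (the batch-dimension normalization the cross-correlation matrix relies on) versus scaling each embedding to unit L2 norm so that it lies on the unit sphere (feature-dimension normalization, as used for cosine similarity). `normalize_batch_dim` and `normalize_feature_dim` are our own illustrative helpers.

```python
import numpy as np

def normalize_batch_dim(z: np.ndarray) -> np.ndarray:
    """Standardize each feature (column) across the batch: zero mean, unit std."""
    return (z - z.mean(axis=0)) / z.std(axis=0)

def normalize_feature_dim(z: np.ndarray) -> np.ndarray:
    """Scale each sample (row) to unit L2 norm, i.e. project onto the unit sphere."""
    return z / np.linalg.norm(z, axis=1, keepdims=True)

rng = np.random.default_rng(0)
z = rng.normal(loc=2.0, scale=3.0, size=(128, 16))  # N = 128 samples, D = 16 features

zb = normalize_batch_dim(z)
zf = normalize_feature_dim(z)

# Batch-dim normalization: every column has mean 0 and std 1 over the batch.
print(np.allclose(zb.mean(axis=0), 0.0), np.allclose(zb.std(axis=0), 1.0))  # True True
# Feature-dim normalization: every row has unit length; columns are not standardized.
print(np.allclose(np.linalg.norm(zf, axis=1), 1.0))  # True

# With batch-dim normalization, zb.T @ zb / N is a D x D correlation matrix
# whose diagonal is exactly 1, the form the Barlow Twins loss operates on.
corr = zb.T @ zb / zb.shape[0]
print(np.allclose(np.diag(corr), 1.0))  # True
```

Note that the feature-normalized rows still share a common offset (their columns do not have zero mean), which is one intuition for why the two normalizations behave differently.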
Batch Size: Some methods based on the InfoNCE loss (such as SimCLR) perform poorly with small batch sizes, so it was important to test our method's robustness to batch size. As the diagram above shows, our model performs well even with batch sizes as small as 256.
Need for Augmentations: As the diagram above shows, data augmentation is crucial for good performance. In other words, the quality of our model's representations depends more strongly on the type of augmentation used than that of BYOL, which is less sensitive to the choice of augmentations.
Dimensionality: For other SSL methods, performance quickly saturates as the output dimensionality increases, whereas for Barlow Twins, increasing the projector dimensionality keeps improving performance substantially. Moreover, our model works better with a deeper projector network, up to the three layers tested.
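For readers who want to see what is being varied in this ablation, here is a minimal NumPy forward pass of a projector: a stack of linear layers in which every layer except the last is followed by batch normalization and a ReLU. In the paper's default configuration the projector has three linear layers with 8192 output units each, applied to the ResNet-50 features; this sketch instead uses tiny sizes, random initialization, and BN without learnable scale and shift so that it runs instantly, and it performs no training. `make_projector`, `batch_norm`, and `projector_forward` are hypothetical helper names.

```python
import numpy as np

def make_projector(in_dim, layer_dims, rng):
    """Randomly initialized (weight, bias) pairs for a stack of linear layers."""
    dims = [in_dim] + list(layer_dims)
    return [
        (rng.normal(scale=1.0 / np.sqrt(d_in), size=(d_in, d_out)), np.zeros(d_out))
        for d_in, d_out in zip(dims[:-1], dims[1:])
    ]

def batch_norm(h, eps=1e-5):
    """Training-mode batch normalization over the batch axis, without scale/shift."""
    return (h - h.mean(axis=0)) / np.sqrt(h.var(axis=0) + eps)

def projector_forward(x, layers):
    """Linear -> BN -> ReLU for every layer except the last, which stays linear."""
    h = x
    for k, (w, b) in enumerate(layers):
        h = h @ w + b
        if k < len(layers) - 1:
            h = np.maximum(batch_norm(h), 0.0)
    return h

rng = np.random.default_rng(0)
features = rng.normal(size=(32, 64))               # stand-in for backbone features
layers = make_projector(64, [128, 128, 128], rng)  # paper default: 3 layers of 8192
z = projector_forward(features, layers)
print(z.shape)  # (32, 128)
```

Changing the list passed to `make_projector` changes both the depth and the output dimensionality D, which in turn sets the size of the D × D cross-correlation matrix.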
Additional asymmetry: BYOL and SimSiam rely on additional mechanisms to introduce asymmetry between the two networks. Although our method does not need asymmetry to avoid collapse, we tested whether additional asymmetry would help. We added a predictor network with 2 dense layers (the first followed by batch normalization and a ReLU) to one branch and/or a stop-gradient on the other branch. Neither improved performance, and using both together was detrimental.
Results and Evaluation
The network (a ResNet-50 backbone) was pre-trained on the ImageNet ILSVRC-2012 dataset without using the labels and evaluated on a variety of tasks such as image classification and object detection. The following augmentations were used: random cropping (always applied), resizing (always applied), horizontal flipping, color jittering, conversion to grayscale, Gaussian blurring, and solarization.
The weights of the pre-trained ResNet-50 were kept fixed and a linear classifier was trained on top of them. The table above compares the top-1 and top-5 accuracy of the model with those of other self-supervised models. The following table shows transfer results obtained by training linear classifiers on the frozen features for the Places-205 (top-1 accuracy), VOC07 (mAP), and iNat18 (top-1 accuracy) datasets.
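Linear evaluation reports top-1 and top-5 accuracy. As a small illustration of how these metrics are computed from a classifier's outputs, here is a NumPy sketch with made-up logits for 4 samples and 6 classes; `top_k_accuracy` is our own helper, not an API from the paper or a library.

```python
import numpy as np

def top_k_accuracy(logits: np.ndarray, labels: np.ndarray, k: int) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    # Indices of the k largest logits per row (their order within the top k is irrelevant).
    top_k = np.argpartition(logits, -k, axis=1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(hits.mean())

# Toy logits from a linear classifier (made-up values).
logits = np.array([
    [0.1, 2.0, 0.3, 0.0, -1.0, 0.5],   # true class 1 is ranked 1st
    [1.5, 0.2, 1.4, 1.3, 0.0, 0.1],    # true class 2 is ranked 2nd
    [0.0, 0.1, 0.2, 3.0, 2.5, 2.0],    # true class 0 is ranked 6th
    [0.3, 0.2, 0.1, 0.0, 0.9, 1.0],    # true class 4 is ranked 2nd
])
labels = np.array([1, 2, 0, 4])
print(top_k_accuracy(logits, labels, k=1))  # 0.25
print(top_k_accuracy(logits, labels, k=5))  # 0.75
```

`np.argpartition` avoids a full sort, which matters when the number of classes is large (1000 for ImageNet).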
In all cases, the model performs on par with or better than the current SOTA models.
Our method also transfers well to object detection and instance segmentation. The table above shows the results on the VOC07+12 object detection benchmark using Faster R-CNN, and on COCO object detection and instance segmentation using Mask R-CNN.
The Barlow Twins method is on par with current state-of-the-art methods and has some distinctive properties. One of the most interesting is how well it performs with high-dimensional projector outputs. It would be worthwhile to see how the method performs with even higher dimensions, beyond 16,000, once hardware limitations are overcome. This work lays a foundation for future work that could refine the method further and lead to better SSL algorithms.
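A back-of-the-envelope sketch of why very high dimensions are costly: the cross-correlation matrix alone has D² entries. Assuming float32 storage (4 bytes per entry), the snippet below computes its size for a few values of D; `cross_correlation_bytes` is a hypothetical helper, and the estimate deliberately ignores projector weights, activations, and gradients, which add considerably more.

```python
def cross_correlation_bytes(d: int, bytes_per_entry: int = 4) -> int:
    """Memory for one D x D cross-correlation matrix (float32 by default)."""
    return d * d * bytes_per_entry

for d in (8192, 16384, 32768):
    mib = cross_correlation_bytes(d) / 2**20
    print(f"D = {d}: {d * d:,} entries, {mib:,.0f} MiB")
```

Doubling D quadruples the size, and a single 8192 → 8192 linear layer in the projector also has 8192² (about 67 million) weights, so both the loss and the projector grow quadratically with the output dimensionality.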