Careful Analysis Of Distributional Shifts!
3 main points
✔️ Proposes a framework for distributional shifts
✔️ Defines three important distribution shifts
✔️ Comprehensive comparative evaluation of various methods
A Fine-Grained Analysis on Distribution Shift
written by Olivia Wiles, Sven Gowal, Florian Stimberg, Sylvestre Alvise-Rebuffi, Ira Ktena, Krishnamurthy Dvijotham, Taylan Cemgil
(Submitted on 21 Oct 2021 (v1), last revised 25 Nov 2021 (this version, v2))
Comments: ICLR2022.
Subjects: Machine Learning (cs.LG); Computer Vision and Pattern Recognition (cs.CV)
code:
The images used in this article are from the paper, the introductory slides, or were created based on them.
Introduction
For machine learning models to be widely deployed, they must be robust to changes in the data distribution. For example, a model trained on images from one group of hospitals may perform poorly on images from another group of hospitals. Improving robustness to distribution shifts, and understanding how robust a given model is to them, are therefore important problems. Domain generalization is an active research area that addresses them.
However, little work has been done to define the distributional shifts that can occur in practice or to evaluate the robustness of algorithms to multiple different distributional shifts.
To address this gap, the paper presented in this article introduces a framework for fine-grained analysis of distribution shifts. It defines three shifts with real-world relevance: spurious correlation, low-data drift, and unseen data shift. It also considers two additional conditions (label noise and dataset size) and evaluates 19 existing methods on both real and synthetic data. In recognition of these contributions, the paper was accepted as an oral presentation at ICLR 2022.
Proposed method
A Framework for Evaluating Generalization
First, let the input be $x$ and its attributes be $y^1, y^2, \dots, y^K$ (written $y^{1:K}$). One of the attributes is the label, denoted $y^l$. For a medical image, for example, $y^l$ may indicate benign/malignant, while another attribute $y^i$ ($i \neq l$) may record the hospital where the image was taken. Let $p$ be the joint distribution of $x$ and $y^{1:K}$. The learning objective is to find a classifier $f$ that minimizes the risk $R(f) = \mathbb{E}_{(x, y^l) \sim p}[L(y^l, f(x))]$, where $L$ is a loss function. In practice we only have a finite sample of $n$ pairs $(x_i, y^l_i) \sim p$. We therefore minimize the empirical risk $\hat{R}(f; p) = \frac{1}{n} \sum_{i=1}^{n} L(y^l_i, f(x_i))$ instead.
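The empirical risk $\hat{R}(f; p)$ is an average of per-sample losses. The article leaves $L$ generic, so the sketch below assumes cross-entropy. The toy classifier `toy_classifier` is a hypothetical stand-in for $f$.

```python
import math
from typing import Callable, List, Sequence

def empirical_risk(f: Callable[[float], List[float]],
                   xs: Sequence[float],
                   labels: Sequence[int]) -> float:
    """R_hat(f; p) = (1/n) * sum_i L(y^l_i, f(x_i)), with L = cross-entropy (assumed)."""
    n = len(xs)
    total = 0.0
    for x, y in zip(xs, labels):
        probs = f(x)                   # predicted class probabilities f(x_i)
        total += -math.log(probs[y])   # cross-entropy of the true label y^l_i
    return total / n

def toy_classifier(x: float) -> List[float]:
    """Hypothetical classifier: favors class 0 for negative inputs, class 1 otherwise."""
    return [0.8, 0.2] if x < 0 else [0.2, 0.8]

# Four samples drawn from some p; the last label disagrees with the classifier.
risk = empirical_risk(toy_classifier, [-1.0, -2.0, 1.0, 2.0], [0, 0, 1, 0])
```

Calling the same function on a sample from $p_{train}$ and on a sample from $p_{test}$ gives the two empirical risks that the next paragraph contrasts.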
When distribution shift may occur, the data distributions used for training and testing, $p_{train}$ and $p_{test}$, differ. For example, $p_{train}$ and $p_{test}$ may contain images taken at different hospitals or with different equipment. The model is trained to minimize $\hat{R}(f; p_{train})$, but what we actually want is a low empirical risk $\hat{R}(f; p_{test})$ at test time.
Although $p_{train}$ and $p_{test}$ differ, both are related to the true distribution $p$ (the joint distribution of $x$ and $y^{1:K}$). To express this relationship, we factorize the distribution using latent factors $z$. We assume that $x$ is generated from $z$ and depends on the attributes only through $z$, i.e. $p(x | z, y^{1:K}) = p(x | z)$.
We can then factorize the true distribution as $p(y^{1:K}, x) = p(y^{1:K}) \int p(x|z)\,p(z|y^{1:K})\,dz = p(y^{1:K})\,p(x|y^{1:K})$.
That is, the true distribution is the product of the marginal distribution $p(y^{1:K})$ of the attributes and the conditional generative model $p(x|y^{1:K})$.
Based on this, the paper makes one key assumption: distribution shifts are caused by changes in the marginal distribution of the attributes.
In other words, a distribution shift occurs when $p(y^{1:K})$, $p_{train}(y^{1:K})$, and $p_{test}(y^{1:K})$ differ. Meanwhile, the conditional generative model $p(x|y^{1:K})$ is unchanged and shared by all distributions.
That is, $p_{test}(y^{1:K}, x) = p_{test}(y^{1:K}) \int p(x|z)p(z|y^{1:K})dz$ and $p_{train}(y^{1:K}, x) = p_{train}(y^{1:K}) \int p(x|z)p(z|y^{1:K})dz$ hold.
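The shared-conditional assumption can be checked on a small discrete example. The model below is hypothetical, and all probability tables are made up:

- two binary attributes;
- a latent $z \in \{0, 1, 2\}$;
- an observation $x \in \{a, b\}$.

The two joint distributions use different attribute marginals. Both share $p(x|y^{1:K}) = \sum_z p(x|z)\,p(z|y^{1:K})$, the discrete counterpart of the integral above.

```python
from itertools import product

ATTRS = list(product([0, 1], [0, 1]))  # all values of (y^1, y^2)
X_VALUES = ("a", "b")

# Shared conditional tables (hypothetical numbers).
P_Z_GIVEN_Y = {
    (0, 0): {0: 0.8, 1: 0.2, 2: 0.0},
    (0, 1): {0: 0.1, 1: 0.8, 2: 0.1},
    (1, 0): {0: 0.1, 1: 0.8, 2: 0.1},
    (1, 1): {0: 0.0, 1: 0.2, 2: 0.8},
}
P_X_GIVEN_Z = {0: {"a": 0.9, "b": 0.1},
               1: {"a": 0.5, "b": 0.5},
               2: {"a": 0.1, "b": 0.9}}

def p_x_given_y(x, y):
    """p(x | y^{1:K}) = sum_z p(x | z) p(z | y^{1:K})."""
    return sum(P_X_GIVEN_Z[z][x] * pz for z, pz in P_Z_GIVEN_Y[y].items())

def joint(marginal):
    """p(y^{1:K}, x) = marginal(y^{1:K}) * p(x | y^{1:K})."""
    return {(y, x): marginal[y] * p_x_given_y(x, y) for y in marginal for x in X_VALUES}

p_train_y = {(0, 0): 0.45, (1, 1): 0.45, (0, 1): 0.05, (1, 0): 0.05}  # correlated attributes
p_test_y = {y: 0.25 for y in ATTRS}                                   # uniform attributes

train_joint = joint(p_train_y)
test_joint = joint(p_test_y)
```

The two joints differ only through the attribute marginal. Dividing either joint by its marginal recovers the same $p(x|y^{1:K})$.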
On distributional shifts
Following this framework, the paper considers three typical types of distribution shift that can occur in the real world. The figure below illustrates each shift on the dSprites dataset.
In this case, the attribute $y^1$ is a color (red, green, blue) and $y^2$ is a shape (heart, oval, rectangle).
The test distribution $p_{test}$
For the test distribution $p_{test}$, the attributes $y^{1:K}$ are assumed to be uniformly distributed, i.e. $p_{test}(y^{1:K}) = \frac{1}{\prod_i |A^i|}$, where $A^i$ is the set of values that attribute $y^i$ can take.
In this setting every combination of attribute values is equally likely, with no bias toward any attribute, as shown in Figure (d) above.
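For the dSprites example, the uniform test distribution assigns $1/(3 \cdot 3) = 1/9$ to every (color, shape) pair. The sketch below assumes each $A^i$ is represented as a Python list inside a dictionary; that layout is an illustrative choice.

```python
from fractions import Fraction
from itertools import product
from math import prod

# A^1 (color) and A^2 (shape) from the dSprites example above.
ATTRIBUTE_VALUES = {
    "color": ["red", "green", "blue"],
    "shape": ["heart", "oval", "rectangle"],
}

def uniform_test_distribution(attribute_values):
    """p_test(y^{1:K}) = 1 / prod_i |A^i| for every combination of attribute values."""
    mass = Fraction(1, prod(len(values) for values in attribute_values.values()))
    return {combo: mass for combo in product(*attribute_values.values())}

p_test = uniform_test_distribution(ATTRIBUTE_VALUES)
print(p_test[("red", "heart")])  # 1/9
```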
Spurious correlation
First, consider the case where attributes are correlated under $p_{train}$ but not under $p_{test}$. In the framework above, such a spurious correlation arises when two attributes $y^a$ and $y^b$ are correlated (not independent) in the training distribution.
Specifically, $p_{train}(y^a | y^1, \dots, y^b, \dots, y^K) > p_{train}(y^a | y^1, \dots, y^{b-1}, y^{b+1}, \dots, y^K)$. A spurious correlation is particularly problematic when one of the two correlated attributes is the label.
Consider the example in figure (a) above with the shape ($y^2$) as the label. The model may learn to predict shape from color, e.g. $y^2 = \text{heart}$ if $y^1 = \text{red}$ and $y^2 = \text{oval}$ if $y^1 = \text{green}$. Such a model fails to generalize to $p_{test}$, where the attributes are uncorrelated.
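The inequality above can be checked numerically. In the sketch below, the color/shape pairing and the correlation strength `rho` are hypothetical. With $K = 2$, conditioning on all attributes except $y^b$ leaves no conditioning at all, so the right-hand side of the inequality reduces to the marginal $p_{train}(y^a)$.

```python
from fractions import Fraction

COLORS = ["red", "green", "blue"]
SHAPES = ["heart", "oval", "rectangle"]
PAIRED = dict(zip(COLORS, SHAPES))  # hypothetical pairing: red-heart, green-oval, blue-rectangle

def spurious_train_distribution(rho):
    """p_train(color, shape): colors uniform; given a color, its paired shape has probability rho."""
    p = {}
    for c in COLORS:
        for s in SHAPES:
            cond = rho if s == PAIRED[c] else (1 - rho) / (len(SHAPES) - 1)
            p[(c, s)] = Fraction(1, len(COLORS)) * cond
    return p

def shape_given_color(p, shape, color):
    """p_train(y^a = shape | y^b = color)."""
    return p[(color, shape)] / sum(v for (c, _), v in p.items() if c == color)

def shape_marginal(p, shape):
    """p_train(y^a = shape), i.e. the conditional with y^b removed when K = 2."""
    return sum(v for (_, s), v in p.items() if s == shape)

p_train = spurious_train_distribution(Fraction(9, 10))
print(shape_given_color(p_train, "heart", "red"), shape_marginal(p_train, "heart"))  # 9/10 1/3
```

Setting `rho = Fraction(1, 3)` removes the correlation. Every (color, shape) pair then has probability $1/9$, which matches the uniform $p_{test}$.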
Low-data drift
Low-data drift occurs when attribute values are unevenly represented in $p_{train}$ but not in $p_{test}$ (figure (b) above). This shift arises when dataset collection is biased with respect to attribute values. In the framework above, it corresponds to $p_{train}(y^a = v) \ll p_{test}(y^a = v)$ for some value $v$.
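To make the condition $p_{train}(y^a = v) \ll p_{test}(y^a = v)$ concrete, the sketch below builds a hypothetical training marginal in which hearts receive 1% of the mass. It then compares that marginal with the uniform test marginal. The helper name and the 1% figure are assumptions for illustration.

```python
from fractions import Fraction

SHAPES = ["heart", "oval", "rectangle"]

def low_data_train_marginal(values, rare_value, rare_mass):
    """Hypothetical p_train(y^a): rare_value gets rare_mass; the others split the rest evenly."""
    others = [v for v in values if v != rare_value]
    p = {v: (1 - rare_mass) / len(others) for v in others}
    p[rare_value] = rare_mass
    return p

p_train = low_data_train_marginal(SHAPES, "heart", Fraction(1, 100))
p_test = {s: Fraction(1, len(SHAPES)) for s in SHAPES}

ratio = p_train["heart"] / p_test["heart"]    # 3/100: hearts are heavily under-represented
expected_hearts = 10_000 * p_train["heart"]   # expected count in a 10,000-image training set
```

The model still sees some hearts during training, just far fewer than it will be evaluated on.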
Unseen data shift
A special case of low-data drift arises when data with a particular attribute value is missing entirely during training. In the framework above, this corresponds to $p_{train}(y^a = v) = 0$ while $p_{test}(y^a = v) > 0$.
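In practice, unseen values can be detected by comparing the values observed in the training set with the support of $p_{test}$. The sketch below uses made-up toy labels, and the helper name is hypothetical.

```python
from collections import Counter

def unseen_values(train_values, test_support):
    """Values v with empirical p_train(y^a = v) = 0 but p_test(y^a = v) > 0."""
    counts = Counter(train_values)
    return sorted(v for v in test_support if counts[v] == 0)

train_shapes = ["oval", "rectangle", "oval", "rectangle", "oval"]  # toy training attributes
test_support = {"heart", "oval", "rectangle"}                       # values with p_test > 0
print(unseen_values(train_shapes, test_support))  # ['heart']
```

Reweighting cannot fix this case: resampling only changes how often existing samples are drawn. No heart images exist to upweight.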
On more complex distribution shifts
The marginal distribution of the label, $p(y^l)$, can be decomposed into two terms: the probability of an attribute value $p(y^a)$ and the conditional probability $p(y^l | y^a)$. Together they give $p(y^l) = \sum_{y^a} p(y^l | y^a) p(y^a)$.
A spurious correlation changes $p(y^l | y^a)$, while low-data drift and unseen data shift change $p(y^a)$. More complex distribution shifts can therefore be described by combining these three shifts as building blocks.
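The decomposition can be computed directly. In the hypothetical two-color, two-shape example below, `p_label_given_attr` plays the role of the spurious-correlation component and `p_attr` the low-data-drift component. Both tables are made up for illustration.

```python
from fractions import Fraction

def label_marginal(p_label_given_attr, p_attr):
    """p(y^l) = sum over y^a of p(y^l | y^a) p(y^a)."""
    out = {}
    for a, pa in p_attr.items():
        for label, pl in p_label_given_attr[a].items():
            out[label] = out.get(label, 0) + pl * pa
    return out

# Spurious-correlation component: shape strongly depends on color.
p_label_given_attr = {
    "red": {"heart": Fraction(9, 10), "oval": Fraction(1, 10)},
    "green": {"heart": Fraction(1, 10), "oval": Fraction(9, 10)},
}
# Low-data-drift component: red objects are rare.
p_attr = {"red": Fraction(1, 10), "green": Fraction(9, 10)}

print(label_marginal(p_label_given_attr, p_attr))
# {'heart': Fraction(9, 50), 'oval': Fraction(41, 50)}
```

Changing either table alone changes $p(y^l)$, which is why the two components can be combined to describe more complex shifts.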
Additional conditions
In addition to these distribution shifts, two further conditions may arise in real-world settings:
Label noise
Label noise arises when annotators disagree or make errors. It is modeled as the observed attribute (e.g. the label) being corrupted by noise: $\hat{y}^i \sim c(y^i)$. Here $\hat{y}^i$ is the corrupted label, $y^i$ is the true label, and $c$ is the corruption process.
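One simple choice of corruption process $c$ is symmetric label noise: with probability `noise_rate`, the true label is replaced by a different class chosen uniformly at random. This particular $c$ and the helper name are assumptions for illustration; the article does not specify the noise model.

```python
import random

def corrupt_labels(labels, num_classes, noise_rate, seed=0):
    """Hypothetical symmetric corruption c(y): with probability noise_rate,
    replace y by a different class chosen uniformly at random."""
    rng = random.Random(seed)
    noisy = []
    for y in labels:
        if rng.random() < noise_rate:
            others = [k for k in range(num_classes) if k != y]
            noisy.append(rng.choice(others))  # corrupted label y_hat != y
        else:
            noisy.append(y)                   # label kept intact
    return noisy

labels = [0, 1, 2] * 1000
noisy = corrupt_labels(labels, num_classes=3, noise_rate=0.2)
flip_rate = sum(a != b for a, b in zip(labels, noisy)) / len(labels)  # close to 0.2
```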
Dataset size
Limiting the size of the training dataset can change the model's performance.
The original paper includes both conditions in its experiments when evaluating the methods.
Methods for increasing robustness
During training, only $p_{train}$ is available. Increasing robustness to distribution shift then means reducing the risk under the true distribution $p$ and the test distribution $p_{test}$. The following methods can be used to achieve this.
1. Weighted resampling (WRS)
The training set is resampled using the importance weights $W(y^{1:K}) = p(y^{1:K})/p_{train}(y^{1:K})$.
In this case, the $i$-th data point $(y^{1:K}_i,x_i)$ is selected with probability $W(y^{1:K}_i)/\sum^n_{i'=1} W(y^{1:K}_{i'})$ instead of $1/n$.
Since we do not always have access to the true distribution $p(y^{1:K})$ in practice, we often assume that all combinations of attributes occur uniformly at random.
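The selection rule above can be sketched in a few lines. In this sketch, $p_{train}$ is estimated from attribute counts, and the target $p$ defaults to uniform. The helper name and the toy data are hypothetical, and `random.choices` draws the resampled indices.

```python
import random
from collections import Counter

def resampling_probabilities(attributes, target=None):
    """Selection probability W(y_i) / sum_j W(y_j) with W(y) = p(y) / p_train(y).

    p_train is estimated from attribute counts. If no target p is given, p(y) is
    assumed uniform over attribute combinations (here: those seen in training).
    """
    n = len(attributes)
    counts = Counter(attributes)
    if target is None:
        target = {a: 1 / len(counts) for a in counts}
    weights = [target[a] / (counts[a] / n) for a in attributes]
    total = sum(weights)
    return [w / total for w in weights]

# Toy training set: 8 red hearts and 2 blue ovals.
attrs = [("red", "heart")] * 8 + [("blue", "oval")] * 2
probs = resampling_probabilities(attrs)

rng = random.Random(0)
batch = rng.choices(range(len(attrs)), weights=probs, k=1000)  # resampled indices
blue_share = sum(attrs[i] == ("blue", "oval") for i in batch) / len(batch)
```

Each red heart is drawn with probability 0.0625 and each blue oval with probability 0.25, so both combinations make up about half of a resampled batch. A uniform target over all combinations, including unseen ones, would yield the same probabilities, because unseen combinations have no samples to select.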
2. Heuristic Data Augmentation
Because weighted resampling may reuse the same samples many times, heuristic data augmentation is applied to reduce overfitting.
3. Learned Data Augmentation
The true distribution is the product of the marginal distribution $p(y^{1:K})$ of the attributes and the conditional generative model $p(x|y^{1:K})$. We can therefore learn an approximate conditional generative model $\hat{p}(x|y^{1:K})$ from the training data and use it to sample new synthetic data.
A supervised classifier is then trained on data drawn from the augmented distribution $p_{aug} = (1 - \alpha)\, p_{train} + \alpha\, \hat{p}(x|y^{1:K})\, p(y^{1:K})$, where $\alpha$ controls the fraction of synthetic data.
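Sampling from the mixture $p_{aug}$ works in two steps. With probability $\alpha$, draw attributes from $p(y^{1:K})$ and then a synthetic input from $\hat{p}(x|y^{1:K})$; otherwise, draw a real training pair. The sketch below makes several assumptions:

- `fake_generator` is a hypothetical placeholder for a trained conditional generative model;
- $p(y^{1:K})$ is taken to be uniform;
- the helper names are illustrative.

```python
import random

def sample_augmented(train_set, attribute_values, generator, alpha, k, seed=0):
    """Draw k (y, x) pairs from p_aug = (1 - alpha) p_train + alpha * p_hat(x|y) p(y).

    train_set:        list of (y, x) pairs, i.e. the empirical p_train.
    attribute_values: attribute tuples; p(y) is assumed uniform over them.
    generator:        hypothetical stand-in for the learned model p_hat(x|y).
    """
    rng = random.Random(seed)
    batch = []
    for _ in range(k):
        if rng.random() < alpha:
            y = rng.choice(attribute_values)      # y ~ p(y)
            batch.append((y, generator(y, rng)))  # x ~ p_hat(x | y)
        else:
            batch.append(rng.choice(train_set))   # (y, x) ~ p_train
    return batch

def fake_generator(y, rng):
    """Placeholder for a trained conditional generative model (returns a tag, not an image)."""
    return f"synthetic-{y[0]}-{y[1]}-{rng.randrange(1000)}"

train_set = [(("red", "heart"), "img0"), (("red", "heart"), "img1"), (("blue", "oval"), "img2")]
all_attrs = [(c, s) for c in ("red", "green", "blue") for s in ("heart", "oval", "rectangle")]
batch = sample_augmented(train_set, all_attrs, fake_generator, alpha=0.3, k=5000)
synthetic_share = sum(x.startswith("synthetic") for _, x in batch) / len(batch)
```

With a uniform $p(y^{1:K})$, synthetic samples can cover attribute combinations, such as green rectangles, that never occur in `train_set`.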
4. Representation Learning
Another possible factorization is $p_{train}(y^{1:K}, x) = p_{train}(x) \int p(z|x)\,p_{train}(y^{1:K}|z)\,dz$.
Based on this, we can learn a representation $p(z|x)$ from the training data without supervision. We then train a classifier head $p_{train}(y^l|z)$ that predicts the label from the latent variable $z$. If the representation is learned well, the model can generalize to $p_{test}$ and $p$ without being affected by a particular attribute distribution.
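The two-stage recipe can be sketched as a linear probe: first a representation $p(z|x)$ learned without labels, then a classifier head $p_{train}(y^l|z)$. Both stages below are assumptions, not the paper's setup:

- the encoder is a fixed, hypothetical function standing in for a learned representation (for example, one from beta-VAE or ImageNet pre-training);
- the head is logistic regression trained by plain gradient descent.

```python
import math

def encoder(x):
    """Hypothetical frozen encoder standing in for p(z|x) learned without labels."""
    return [x[0] + x[1], x[0] - x[1]]

def train_linear_head(zs, labels, lr=0.5, steps=500):
    """Fit p_train(y^l = 1 | z) = sigmoid(w . z + b) by gradient descent on the log loss."""
    w = [0.0] * len(zs[0])
    b = 0.0
    n = len(zs)
    for _ in range(steps):
        gw = [0.0] * len(w)
        gb = 0.0
        for z, y in zip(zs, labels):
            p = 1 / (1 + math.exp(-(sum(wi * zi for wi, zi in zip(w, z)) + b)))
            err = p - y                    # gradient of the log loss w.r.t. the logit
            for j, zj in enumerate(z):
                gw[j] += err * zj / n
            gb += err / n
        w = [wi - lr * gi for wi, gi in zip(w, gw)]
        b -= lr * gb
    return w, b

def predict(w, b, z):
    """Predicted label: 1 if the head's logit is positive, else 0."""
    return 1 if sum(wi * zi for wi, zi in zip(w, z)) + b > 0 else 0

xs = [(1.0, 1.0), (2.0, 0.5), (-1.0, -1.0), (-0.5, -2.0)]  # toy inputs
labels = [1, 1, 0, 0]
zs = [encoder(x) for x in xs]          # stage 1: encode with the frozen representation
w, b = train_linear_head(zs, labels)   # stage 2: fit only the classifier head
```

Only the head sees labels, so whatever attribute bias exists in $p_{train}$ can affect the head but not the representation.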
Experiment setup
The experiments evaluate 19 algorithms designed to improve robustness to distribution shift.
Architecture
The architecture of the model used in the experiments is as follows.
- ResNet18,ResNet50,ResNet101
- ViT
- MLP
Weighted resampling is also applied during training, oversampling from the low-probability regions of $p_{train}$.
Heuristic data augmentation
To improve robustness, the following data augmentation methods are analyzed:
- Standard ImageNet augmentation
- AugMix without JSD
- RandAugment
- AutoAugment
Learned data augmentation
We approximate the conditional generative model $p(x|y^{1:K})$ and use the generated image as data augmentation.
We use CycleGAN for approximation.
Domain generalization
Domain generalization methods aim to recover representations $z$ that are independent of the attributes. The following methods are evaluated:
- IRM
- DeepCORAL
- domain MixUp
- DANN
- SagNet
Adaptive approaches
The following adaptive approaches are evaluated:
- JTT
- BN-Adapt
Representation learning
The following representation learning methods are evaluated:
- beta-VAE
- Pre-training on ImageNet (using additional data for $D_{train}$)
Dataset and Model Selection
For our experiments, we use six image classification datasets.
- DSPRITES
- MPI3D
- SMALLNORB
- SHAPES3D
- CAMELYON17
- IWILDCAM
ResNet18 is used for the simple synthetic datasets (DSPRITES, MPI3D, SHAPES3D, SMALLNORB) and ResNet50 for the complex real-world datasets (CAMELYON17, IWILDCAM). Each experiment is run with five random seeds.
Experimental results
The results for spurious correlation, low-data drift, and unseen data shift are as follows.
The results under label noise and dataset-size constraints are as follows.
Overall, the results obtained can be summarized as follows.
- No method always shows the best performance.
- Pre-training can be a powerful tool across a variety of distribution shifts and datasets.
- Heuristic data augmentation does not always improve the results.
- The learned data augmentation is effective across a variety of conditions and distribution shifts.
- The performance improvement by Domain generalization was limited.
- The optimal algorithm may vary depending on the detailed conditions.
- The attributes considered have a direct impact on the results.
For a more detailed description of the experimental results, please refer to the original paper.
The paper also offers the following practical tips:
- If heuristic data augmentation promotes invariance, take advantage of it.
- When heuristic data augmentation is not useful, use learned data augmentation.
- Use pre-training.
- Improvements through complex methods are limited.
Finally, the paper draws the following conclusions from the experiments:
- The best method cannot be determined in advance from the dataset alone.
- When knowledge about the distribution shift is available, it should be the focus.
- It is extremely important to evaluate the method under a variety of conditions.
Summary
In this article, we presented a paper that proposes a comprehensive framework for distribution shifts together with a detailed experimental analysis of various methods. The framework and benchmark could be a useful tool for evaluating methods that target distribution shift.