Transformerで年齢を予測する!!
3つの要点
✔️ 脳のMRIから年齢を予測する研究
✔️ 脳のMRIに適したGlobal-Local Transformerを提案
✔️ 従来の手法より高い精度での年齢予測を実現
Global-Local Transformer for Brain Age Estimation
written by Minghao Chen, Houwen Peng, Jianlong Fu, Haibin Ling
(Submitted on 1 Jul 2021)
Comments: Published on arxiv.
Subjects: Computer Vision and Pattern Recognition (cs.CV)
code:
本記事で使用している画像は論文中のもの、紹介スライドのもの、またはそれを参考に作成したものを使用しております。
背景
発展著しい深層学習ですが、その効果はヘルスケアの領域にも及んでいます。そのなかで、脳のMRIからその人の年齢を予測するBrain age estimationというものがあります。
Brain age estimationで予測された年齢と実際の年齢の差が、その人の脳の健康状態に関係していると言われており、健康状態を測定する指標になっています。そんな深層学習によるBrain age estimationですが、従来のモデルでは、脳のMRI全体のみから特徴が抽出されるため、MRIに含まれる細かな特徴まで考慮できていませんでした。
そこで本論文では、画像全体の特徴に加えて細かな特徴までも抽出するモデル、Global-Local Transformerを提案しています。
Global-Local Transformerと従来の手法を比較したところ、従来手法を超える精度での年齢予測を実現しました。さらに、Global-Local TransformerがMRIのどの部分に注目してBrain age estimationを行っているかを考察しています。
本記事では、Global-Local Transformerの解説、従来の手法との比較実験、モデルの注目箇所を可視化した結果について紹介していきます。
提案手法:GLobal-Local Transformer
提案手法の全体は上図のようになっています。入力データとしては、MRI全体の画像 (上部)とMRIからランダムに切り取ったパッチ画像 (下部)を使用します。提案手法の流れとしては、BackboneとGlobal-Local Transformerに分けられます。
Backbone
Backboneでは、全体画像とパッチ画像に対してCNNによる特徴抽出を行います。CNNのアーキテクチャは、以下の図に示すとおり畳み込み層、Batch Normalization、ReLU、Max poolingを重ねたものになっています。
Global-Local Transformer
Global-Local Transformerでは、Backboneで抽出されたMRIの画像全体とパッチ画像の特徴量を入力として、年齢予測を行います。
Global-Local TransformerとオリジナルのTransformerの異なる点は、Self-Attentionの代わりに、本論文で提案するGlobal-Local Attentionを使用している点です。 また、Global-Local Transformerでは、Layer Nomalizationが使用されておらず、Global-Local Attentionの出力にパッチ画像から抽出された特徴量が結合されています。
本論文の肝であるGlobal-Local Attentionは下図のようになっています。
Global-Local Attentionでは、queryにパッチ画像からの特徴量、keyとvalueに全体画像からの特徴量を使用します。こうすることで、MRIに含まれる細かな特徴とMRI全体の特徴を組み合わせることが出来ます。
また、次節の実験では精度と予測にかかる時間を考慮し、Global-Local Transformerを6回分積み重ねたモデルを使用しています。(2回目以降のGlobal-Local Transformerでは、queryとして、その1つ前のGlobal-Local Transformerの出力を使用しています。)
Brain age estimation
データセット
モデルの評価のために、上の表に示す8つの脳MRIデータセットを使用します。
N_samplesがデータ数、Arge rangeが年齢の範囲、Genderが男女の内訳となっています。
また、本記事では、上から6つのデータセットを合わせたデータセットによる5-fold-validationを行った実験結果のみ紹介します。
実験結果
学習時の損失関数としては、予測した年齢と実際の年齢とのMAEを使用しています。
評価指標としては、予測した年齢と実際の年齢のMAEと相関係数 (Pearson Correlation)、CS (α=5)を使用しています。CS (α=5)は予測した年齢と実際の年齢の差が5歳を下回る割合を表しています。
5-fold-validationの結果は上の表のようになっています。上から8つのモデルが画像認識に使用されているモデルで、その下が提案手法を含む年齢予測のためのモデルになっています。提案手法は、MAE、相関係数、CS (α=5)のすべての評価指標において、従来の手法を上回る結果を実現しています。
Visualization
提案手法が脳MRIのどの部分に注目して、Brain age estimationを行っているかをヒートマップとして可視化したものを下図に示します。
上図から、どの脳MRIでもほとんど同じような領域 (赤い部分)が注目されているのが確認できます。このことから、Brain age estimationを行う際に重要となる情報が、特定の領域に含まれていることがわかります。
さらに、年代別に注目箇所の可視化を行った場合、下図のようになっています。下図において、ヒートマップの下に示されている数字が年代を表しています。
上図から、年代によって注目箇所が異なることがわかります。例えば、0~5歳では前頭葉の部分 (ヒートマップの上部左)あたりに注目していますが、30-35歳では、頭頂葉の部分 (ヒートマップの下部右)あたりに注目箇所が移動しています。
まとめ
今回は、脳のMRIから年齢を予測するモデル、Global-Local Transformerについて紹介しました。MRIに含まれる細かな特徴を抽出する方法をTransformerに組み込むことで、従来の手法を超える年齢予測の精度を実現しました。また、本論文では、この研究に関して「データセットに有病者のデータが含まれていない」、「データセット内で患者の年代に偏りがある」などの課題が挙げられていました。
これらの課題を解決することで、Brain age estimationのさらなる進展に期待したいです。
この記事に関するカテゴリー