赶上最新的AI论文

重新思考表格数据的深度学习模型

重新思考表格数据的深度学习模型

表格式数据

三个要点
✔️ 考察表格式数据上的深度学习方法的现状
✔️ 提出一个基于ResNet和Transformer的基线
✔️ 与现有深度学习方法、GBDT和基线的比较实验

Revisiting Deep Learning Models for Tabular Data
written by Yury GorishniyIvan RubachevValentin KhrulkovArtem Babenko
(Submitted on 22 Jun 2021 (v1), last revised 10 Nov 2021 (this version, v2))
Comments: NeurIPS 2021

Subjects: Machine Learning (cs.LG)

code:  

 

本文所使用的图片要么来自论文、介绍性幻灯片,要么是参考这些图片制作的 。

简介

梯度提升决策树(GBDT)是最受欢迎的表格数据方法之一,但也有很多关于将深度学习用于表格数据的研究。

然而,在表格数据领域,现有的深度学习方法还没有得到充分的比较,部分原因是没有既定的基准(如图像识别的ImageNet或自然语言处理的GLUE)。因此,诸如深度学习方法在表格数据上的有效性以及GBDT和深度学习方法哪个更胜一筹等问题仍然不清楚。

在这篇文章中,我们介绍了一个简单的表格数据的基线方法和一个多样化的任务集,对表格数据上的深度学习方法进行了详细的检查。让我们看看下面的内容。

表格式数据问题的模型

首先,我们介绍了对表格数据问题进行性能比较实验的模型。

MLPs

MLP(多层感知器)由以下公式表示

$MLP(x) = Linear (MLPBlock (...(MLPBlock(x))))$)

$MLPBlock(x) = Dropout(ReLU(Linear(x)))$

储备网

接下来,我们介绍一个基于ResNet的简单基线,它主要用于计算机视觉任务。这可以用以下公式表示

$ResNet(x) = 预测 (ResNetBlock (...(ResNetBlock (Linear(x)))))$)

$ResNetBlock(x) = x + Dropout(Linear(Dropout(ReLU(Linear(BatchNorm(x))))))$)

$Prediction(x) = Linear (ReLU (BatchNorm (x))$

ResNetBlock引入了一个类似于现有ResNet的跳过连接。

FT-变压器

接下来,我们介绍FT-Transformer(特征标记转换器),它是对Transformer架构的修改,适用于表格数据,已成功用于各种任务,包括自然语言处理。粗略的结构如下图所示。

作为一个整体,输入的$x$首先被特征标记器转换为嵌入的$T$,然后添加了[CLS]标记的$T0$被传递到转换器中。然后根据最后一个[CLS]标记所对应的表示进行预测。

关于特征代码器

特征标记模块将输入的$x$转换为R^{k×d}$中的嵌入$T\。

在这种情况下,根据输入的$x$是数字数据($x^{(num)}$)还是分类数据($x^{(cat)}$),会进行不同的处理,如下式所示。

$T^{(num)}_j = b^{(num)}_j + x^{(num)}_j δcdot W^{(num)}_j δin R^d$

$T^{(cat)}_j = b^{(cat)}_j + e^T_j W^{(cat)}_j in R^d$。

$T = 堆栈 [T^{(num)}_1, ... , T^{(num)}_{k^{(num)}} , T^{(cat)}_1 , ...。 , T^{(cat)}_{k^{(cat)} ]/in R^{k×d}$

其中$k$是特征的数量,$e^T_j$是与分类特征相对应的单热向量。

关于Tranformer

变压器过程由$L$变压器层$(F_1,...,F_L)$组成,如下图所示。,F_L)$,如下图所示。

关于预测

使用通过转化器获得的[CLS]标记表示,通过以下公式描述的过程进行预测。

$hat{y} = Linear(ReLU(LayerNorm(T^{[CLS]}_L))))$)

作为注意事项,FT-Transformer与MLP和ResNet相比有一些挑战,如训练所需的大量资源和难以应用于大量的特征。

(这可以通过使用更有效的变形器变体来改善)。

其他型号

在专门为表格数据设计的现有模型中,以下是用于比较实验的模型

实验结果

在本节中,我们比较了不同架构的性能。我们不采用独立于模型的方法,如预学习或数据增强。

数据集

在实验中,我们使用了11个不同的数据集,如下图所示。

(RMSE代表均方根误差,Acc代表准确度。)

为了评估性能,我们用不同的随机种子进行了15次实验。在集合设置中,15个模型被分为5组,每组5个,并使用每组内预测的平均值(其他设置见原始论文)。

深度学习模型的比较

深度学习模型的实验结果如下

表格右侧的排名显示了每个模型的平均排名。

在本文介绍的基线中,FT-Transformer(FT-T)在大多数任务中表现最好,因此是一个非常强大的模型。

ResNet也被认为是一个有效的基线,尽管它很简单,但显示的结果仅次于FT-Transformer。在其他方法中,NODE显示了最好的结果。集合设置的结果也显示在下面。

在集合设置中,ResNet和FT-Transformer显示了更好的结果。

深度学习模型和GBDT的比较

下一步是将深度学习模型与GBDT进行比较。然而,我们将忽略速度和硬件要求,比较每种方法所能达到的最佳性能。

由于GBDT包括一个集合技术,深度学习模型也使用集合设置的结果进行比较。结果如下

默认的超参数显示了默认的超参数设置,调整的超参数显示了超参数调整后的模型结果。

FT-Transformer在默认乐器和调音乐器上的表现同样出色,这表明无需调音也可以建立一个优秀的合奏模型。

我们还发现,GBDT在一些任务中的调谐设置表现更好,不能排除深度学习模型总是优于GBDT的情况。

(如果只比较任务的数量,似乎深度学习模型更胜一筹,但这可能只是因为该基准偏向于更适合深度学习模型的问题)。

然而,FT-Transformer在所有的任务中都表现良好,与其他方法相比,它是一个更通用的表格数据模型。

一般来说,我们不能说总有一种方法是DL模型和GBDT之间的最佳解决方案。也可以说,未来对深度学习方法的研究应该集中在GBDT优于深度学习方法的数据集上。

FT-Transformer和ResNet的比较

为了比较FT-Transformer和ResNet的普遍性,我们将使用一个合成任务进行实验。

具体来说,对于代表30个随机构建的决策树的平均预测的$f_{GBDT}$和具有三个随机初始化隐藏层的MLP的$f_{DL}$,我们创建了以下合成数据。

$x ~ N(0, I_k), y = α\cdot f_{GBDT}(x) + (1 - α) \cdot f_{DL}(x)$

这个合成任务被认为在$alpha$较大时更适合于GBDT,而在$alpha$较小时更适合于深度学习方法。

该合成任务中不同方法的比较如下

从图中可以看出,在面向GBDT的设置中,ResNet的性能明显恶化,而FT-Transformer则普遍表现良好,揭示了FT-Transformer的普适性。

最后,ResNet和FT-Transformer的学习时间比较如下。

一般来说,FT-Transformer需要更多的时间来训练,特别是在具有大量特征(YA)的数据集上,FT-Transformer的高计算成本是未来需要改进的一个重要问题。

摘要

在本文中,我们在表格数据上实验了深度学习模型,引入了一个简单的基线,并与现有的方法进行了比较。

它包括关于表格数据深度学习方法的各种信息,如基于ResNet的模型,它是表格数据深度学习方法的一个有用的基线,FT-Transformer是Transformer在表格数据中的应用,以及与当前深度学习方法和GBDT的比较。它包含了很多关于该方法的信息,官方代码也是可用的。如果你有兴趣,请你也看看原始论文。

  • メルマガ登録(ver
  • ライター
  • エンジニア_大募集!!

如果您对文章内容有任何改进建议等,请通过 "联系我们 "表格与爱学网编辑部联系。
如果您能通过咨询表与我们联系,我们将非常感激。

联系我们