[NAS論文][Transformer][預(yù)訓(xùn)練模型]精讀NAS-BERT
NAS-BERT: Task-Agnostic and Adaptive-Size BERT Compression with Neural Architecture Search

簡(jiǎn)介:
-碼沒(méi)有開(kāi)源,但是論文寫(xiě)得挺清晰,應(yīng)該可以手工實(shí)現(xiàn)。BERT參數(shù)量太多推理太慢(雖然已經(jīng)支持用tensorRT8.X取得不錯(cuò)的推理效果,BERT-Large推理僅需1.2毫秒),但是精益求精一直是科研人員的追求,所以本文用權(quán)重共享的one-shot的NAS方式對(duì)BERT做NAS搜索。
涉及到的方法包括 block-wise search, progressive shrinking,and performance approximation
講解:
1、搜索空間定義
搜索空間的ops包括深度可分離卷積的卷積核大小[3/5/7],Hidden size大小【128/192/256/384/512】MHA的head數(shù)[2/3/4/6/8],F(xiàn)NN[512/768/1021/1536/2048]、和identity 連接,也就是跳層了,一共26個(gè)op,具體可見(jiàn)下圖:
注意這里的MHA和FNN是二選一的關(guān)系,但是可以比如說(shuō)第一層選MHA第二層選FNN,這樣構(gòu)成一個(gè)基本的Transformer塊,可以說(shuō)這個(gè)方法打破的定式的Transformer塊的搜索又包含了Transformer和BERT的結(jié)構(gòu),不同層間也是鏈?zhǔn)芥溄樱繉又贿x擇一個(gè)op,如下圖
2、超網(wǎng)絡(luò)的訓(xùn)練方式
【 Block-Wise Training + Knowledge Distillation、分塊訓(xùn)練+KD蒸餾】
(1)首先把超網(wǎng)絡(luò)等分成N個(gè)Blocks
(2)以原始的BERT作為T(mén)eacher模型,BERT也同樣分為N個(gè)Blocks
(3)超網(wǎng)絡(luò)(Student)中第n個(gè)塊的輸入是teacher模型第n-1個(gè)塊的輸出,來(lái)和teacher模型的第n個(gè)塊的輸出做均方差來(lái)作為loss,來(lái)預(yù)測(cè)teacher模型中這第n個(gè)block的輸出
(4)超網(wǎng)絡(luò)的訓(xùn)練是單架構(gòu)隨機(jī)采樣訓(xùn)練
(5)由于student 塊的隱藏大小可能與teacher塊中的hidden size不同,能直接利用教師塊隱藏的輸入,和輸出作為學(xué)生塊的訓(xùn)練數(shù)據(jù)。為了解決這個(gè)問(wèn)題,需要在學(xué)生塊的輸入和輸出處使用一個(gè)可學(xué)習(xí)的線(xiàn)性變換層來(lái)轉(zhuǎn)換每個(gè)hidden size,以匹配教師塊的大小,如下圖所示
【 Progressive Shrinking】
搜索空間太大,超網(wǎng)絡(luò)需要有效的訓(xùn)練,可以借助Progressive Shrinking的方式來(lái)加速訓(xùn)練和提高搜索效率,以下簡(jiǎn)稱(chēng)為PS。但是不能簡(jiǎn)單粗暴的剔除架構(gòu),因?yàn)榇蠹軜?gòu)再訓(xùn)練初期難收斂,效果不好,但是并不能代表其表征能力差,所以本文設(shè)置了一個(gè)PS規(guī)則:
其含義,a^t表示超網(wǎng)絡(luò)中最大的架構(gòu),p(?)表示參數(shù)量大小,l(?)表示latency大小,B表示設(shè)置B個(gè)區(qū)間桶,b表示當(dāng)前為第幾個(gè)區(qū)間。如果一個(gè)架構(gòu)a不滿(mǎn)足p_b>p(a)>pb_1并且l_b>l(a)>l_b-1這個(gè)區(qū)間,就剔除。
PS的過(guò)程就是從每個(gè)B桶中抽E個(gè)架構(gòu),過(guò)驗(yàn)證集,剔除R個(gè)最大loss的架構(gòu),重復(fù)這個(gè)過(guò)程直到只有m個(gè)架構(gòu)在每個(gè)桶中
3、Model Selection
建一個(gè)表,包括 latency 、loss、 參數(shù)量 和結(jié)構(gòu)編碼,其中l(wèi)oss和latency是預(yù)測(cè)評(píng)估的方法,評(píng)估方法具體可以看論文,對(duì)于給定的模型大小和推理延遲約束條件,從滿(mǎn)足參數(shù)和延遲約束的表中選擇最低loss的T個(gè)架構(gòu),然后把這個(gè)T個(gè)架構(gòu)過(guò)驗(yàn)證集,選取最好的那個(gè)。
實(shí)驗(yàn)結(jié)果
1、和原始BERT相比在 GLUE Datasets上都有一定的提升:
2、和其他變種BERT相比效果也不錯(cuò):
消融實(shí)驗(yàn)
1、PS是否有效?
如果不用PS方法,需要巨大的驗(yàn)證上的時(shí)間(5min vs 50hours),并且超網(wǎng)絡(luò)訓(xùn)練更難收斂,影響架構(gòu)排序:
2、是PS架構(gòu)還是PS掉node
結(jié)論是PS掉node太過(guò)粗暴,效果不好:
3、二階段蒸餾是否有必有?
本文蒸餾探究了預(yù)訓(xùn)練階段和finetune階段,也就是pre-train KD 和 finetune KD,結(jié)論是:
1、預(yù)訓(xùn)練蒸餾效果比f(wàn)inetune時(shí)候蒸餾好
2、兩階段一起蒸餾效果最好
AI 卷積神經(jīng)網(wǎng)絡(luò) 機(jī)器學(xué)習(xí) 深度學(xué)習(xí) 神經(jīng)網(wǎng)絡(luò)
版權(quán)聲明:本文內(nèi)容由網(wǎng)絡(luò)用戶(hù)投稿,版權(quán)歸原作者所有,本站不擁有其著作權(quán),亦不承擔(dān)相應(yīng)法律責(zé)任。如果您發(fā)現(xiàn)本站中有涉嫌抄襲或描述失實(shí)的內(nèi)容,請(qǐng)聯(lián)系我們jiasou666@gmail.com 處理,核實(shí)后本網(wǎng)站將在24小時(shí)內(nèi)刪除侵權(quán)內(nèi)容。
版權(quán)聲明:本文內(nèi)容由網(wǎng)絡(luò)用戶(hù)投稿,版權(quán)歸原作者所有,本站不擁有其著作權(quán),亦不承擔(dān)相應(yīng)法律責(zé)任。如果您發(fā)現(xiàn)本站中有涉嫌抄襲或描述失實(shí)的內(nèi)容,請(qǐng)聯(lián)系我們jiasou666@gmail.com 處理,核實(shí)后本網(wǎng)站將在24小時(shí)內(nèi)刪除侵權(quán)內(nèi)容。