中文字幕 另类精品,亚洲欧美一区二区蜜桃,日本在线精品视频免费,孩交精品乱子片免费

<sup id="3hn2b"></sup>

    1. <sub id="3hn2b"><ol id="3hn2b"></ol></sub><legend id="3hn2b"></legend>

      1. <xmp id="3hn2b"></xmp>

      2. "); //-->

        博客專欄

        EEPW首頁(yè) > 博客 > 知識(shí)蒸餾綜述:代碼整理(1)

        知識(shí)蒸餾綜述:代碼整理(1)

        發(fā)布人:計(jì)算機(jī)視覺(jué)工坊 時(shí)間:2022-01-16 來(lái)源:工程師 發(fā)布文章

        作者 | PPRP 

        來(lái)源 | GiantPandaCV

        編輯 | 極市平臺(tái)

        導(dǎo)讀

        本文收集自RepDistiller中的蒸餾方法,盡可能簡(jiǎn)單解釋蒸餾用到的策略,并提供了實(shí)現(xiàn)源碼。

        1. KD: Knowledge Distillation

        全稱:Distilling the Knowledge in a Neural Network

        鏈接:https://arxiv.org/pdf/1503.02531.pd3f

        發(fā)表:NIPS14

        最經(jīng)典的,也是明確提出知識(shí)蒸餾概念的工作,通過(guò)使用帶溫度的softmax函數(shù)來(lái)軟化教師網(wǎng)絡(luò)的邏輯層輸出作為學(xué)生網(wǎng)絡(luò)的監(jiān)督信息,

        使用KL divergence來(lái)衡量學(xué)生網(wǎng)絡(luò)與教師網(wǎng)絡(luò)的差異,具體流程如下圖所示(來(lái)自Knowledge Distillation A Survey)

        1.jpg

        對(duì)學(xué)生網(wǎng)絡(luò)來(lái)說(shuō),一部分監(jiān)督信息來(lái)自hard label標(biāo)簽,另一部分來(lái)自教師網(wǎng)絡(luò)提供的soft label。代碼實(shí)現(xiàn):

        class DistillKL(nn.Module):
            """Distilling the Knowledge in a Neural Network"""
            def __init__(self, T):
                super(DistillKL, self).__init__()
                self.T = T
            def forward(self, y_s, y_t):
                p_s = F.log_softmax(y_s/self.T, dim=1)
                p_t = F.softmax(y_t/self.T, dim=1)
                loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
                return loss

        核心就是一個(gè)kl_div函數(shù),用于計(jì)算學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)的分布差異。

        2. FitNet: Hints for thin deep nets

        全稱:Fitnets: hints for thin deep nets

        鏈接:https://arxiv.org/pdf/1412.6550.pdf

        發(fā)表:ICLR 15 Poster

        對(duì)中間層進(jìn)行蒸餾的開山之作,通過(guò)將學(xué)生網(wǎng)絡(luò)的feature map擴(kuò)展到與教師網(wǎng)絡(luò)的feature map相同尺寸以后,使用均方誤差MSE Loss來(lái)衡量?jī)烧卟町悺?/p>

        2.jpg

        實(shí)現(xiàn)如下:

        class HintLoss(nn.Module):
            """Fitnets: hints for thin deep nets, ICLR 2015"""
            def __init__(self):
                super(HintLoss, self).__init__()
                self.crit = nn.MSELoss()
            def forward(self, f_s, f_t):
                loss = self.crit(f_s, f_t)
                return loss

        實(shí)現(xiàn)核心就是MSELoss。

        3. AT: Attention Transfer

        全稱:Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer

        鏈接:https://arxiv.org/pdf/1612.03928.pdf

        發(fā)表:ICLR16

        為了提升學(xué)生模型性能提出使用注意力作為知識(shí)載體進(jìn)行遷移,文中提到了兩種注意力,一種是activation-based attention transfer,另一種是gradient-based attention transfer。實(shí)驗(yàn)發(fā)現(xiàn)第一種方法既簡(jiǎn)單效果又好。

        3.jpg

        實(shí)現(xiàn)如下:

        class Attention(nn.Module):
            """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
            via Attention Transfer
            code: https://github.com/szagoruyko/attention-transfer"""
            def __init__(self, p=2):
                super(Attention, self).__init__()
                self.p = p
            def forward(self, g_s, g_t):
                return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
            def at_loss(self, f_s, f_t):
                s_H, t_H = f_s.shape[2], f_t.shape[2]
                if s_H > t_H:
                    f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
                elif s_H < t_H:
                    f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
                else:
                    pass
                return (self.at(f_s) - self.at(f_t)).pow(2).mean()
            def at(self, f):
                return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))

        首先使用avgpool將尺寸調(diào)整一致,然后使用MSE Loss來(lái)衡量?jī)烧卟罹唷?/p>

        4. SP: Similarity-Preserving

        全稱:Similarity-Preserving Knowledge Distillation

        鏈接:https://arxiv.org/pdf/1907.09682.pdf

        發(fā)表:ICCV19SP

        歸屬于基于關(guān)系的知識(shí)蒸餾方法。文章思想是提出相似性保留的知識(shí),使得教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)會(huì)對(duì)相同的樣本產(chǎn)生相似的激活??梢詮南聢D看出處理流程,教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)對(duì)應(yīng)feature map通過(guò)計(jì)算內(nèi)積,得到bsxbs的相似度矩陣,然后使用均方誤差來(lái)衡量?jī)蓚€(gè)相似度矩陣。

        4.jpg

        最終Loss為:

        G代表的就是bsxbs的矩陣。實(shí)現(xiàn)如下:

        class Similarity(nn.Module):
            """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
            def __init__(self):
                super(Similarity, self).__init__()
            def forward(self, g_s, g_t):
                return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
            def similarity_loss(self, f_s, f_t):
                bsz = f_s.shape[0]
                f_s = f_s.view(bsz, -1)
                f_t = f_t.view(bsz, -1)
                G_s = torch.mm(f_s, torch.t(f_s))
                # G_s = G_s / G_s.norm(2)
                G_s = torch.nn.functional.normalize(G_s)
                G_t = torch.mm(f_t, torch.t(f_t))
                # G_t = G_t / G_t.norm(2)
                G_t = torch.nn.functional.normalize(G_t)
                G_diff = G_t - G_s
                loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
                return loss

        5. CC: Correlation Congruence

        全稱:Correlation Congruence for Knowledge Distillation

        鏈接:https://arxiv.org/pdf/1904.01802.pdf

        發(fā)表:ICCV19

        CC也歸屬于基于關(guān)系的知識(shí)蒸餾方法。不應(yīng)該僅僅引導(dǎo)教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)單個(gè)樣本向量之間的差異,還應(yīng)該學(xué)習(xí)兩個(gè)樣本之間的相關(guān)性,而這個(gè)相關(guān)性使用的是Correlation Congruence 教師網(wǎng)絡(luò)雨學(xué)生網(wǎng)絡(luò)相關(guān)性之間的歐氏距離。

        整體Loss如下:

        實(shí)現(xiàn)如下:

        class Correlation(nn.Module):
            """Similarity-preserving loss. My origianl own reimplementation 
            based on the paper before emailing the original authors."""
            def __init__(self):
                super(Correlation, self).__init__()
            def forward(self, f_s, f_t):
                return self.similarity_loss(f_s, f_t)
            def similarity_loss(self, f_s, f_t):
                bsz = f_s.shape[0]
                f_s = f_s.view(bsz, -1)
                f_t = f_t.view(bsz, -1)
                G_s = torch.mm(f_s, torch.t(f_s))
                G_s = G_s / G_s.norm(2)
                G_t = torch.mm(f_t, torch.t(f_t))
                G_t = G_t / G_t.norm(2)
                G_diff = G_t - G_s
                loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
                return loss


        *博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。



        關(guān)鍵詞: AI

        相關(guān)推薦

        技術(shù)專區(qū)

        關(guān)閉