中文字幕 另类精品,亚洲欧美一区二区蜜桃,日本在线精品视频免费,孩交精品乱子片免费

<sup id="3hn2b"></sup>

    1. <sub id="3hn2b"><ol id="3hn2b"></ol></sub><legend id="3hn2b"></legend>

      1. <xmp id="3hn2b"></xmp>

      2. "); //-->

        博客專欄

        EEPW首頁 > 博客 > 知識蒸餾綜述:代碼整理(2)

        知識蒸餾綜述:代碼整理(2)

        發(fā)布人:計算機視覺工坊 時間:2022-01-16 來源:工程師 發(fā)布文章

        6. VID: Variational Information Distillation

        全稱:Variational Information Distillation for Knowledge Transfer

        鏈接:https://arxiv.org/pdf/1904.05835.pdf

        發(fā)表:CVPR19

        5.jpg

        利用互信息(Mutual Information)來衡量學生網絡和教師網絡差異?;バ畔⒖梢员硎境鰞蓚€變量的互相依賴程度,其值越大,表示變量之間的依賴程度越高?;バ畔⒂嬎闳缦拢?/p>

        互信息是教師模型的熵減去在已知學生模型條件下教師模型的熵。目標是最大化互信息,因為互信息越大說明H(t|s)越小,即學生網絡確定的情況下,教師網絡的熵會變小,證明學生網絡已經學習的比較充分。整體loss如下:

        由于p(t|s)很難計算,可以使用變分分布q(t|s)去接近真實分布。

        其中q(t|s)是使用方差可學習的高斯分布模擬(公式中的log_scale):

        實現如下:

        class VIDLoss(nn.Module):
            """Variational Information Distillation for Knowledge Transfer (CVPR 2019),
            code from author: https://github.com/ssahn0215/variational-information-distillation"""
            def __init__(self,
                         num_input_channels,
                         num_mid_channel,
                         num_target_channels,
                         init_pred_var=5.0,
                         eps=1e-5):
                super(VIDLoss, self).__init__()
                def conv1x1(in_channels, out_channels, stride=1):
                    return nn.Conv2d(
                        in_channels, out_channels,
                        kernel_size=1, padding=0,
                        bias=False, stride=stride)
                self.regressor = nn.Sequential(
                    conv1x1(num_input_channels, num_mid_channel),
                    nn.ReLU(),
                    conv1x1(num_mid_channel, num_mid_channel),
                    nn.ReLU(),
                    conv1x1(num_mid_channel, num_target_channels),
                )
                self.log_scale = torch.nn.Parameter(
                    np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
                    )
                self.eps = eps
            def forward(self, input, target):
                # pool for dimentsion match
                s_H, t_H = input.shape[2], target.shape[2]
                if s_H > t_H:
                    input = F.adaptive_avg_pool2d(input, (t_H, t_H))
                elif s_H < t_H:
                    target = F.adaptive_avg_pool2d(target, (s_H, s_H))
                else:
                    pass
                pred_mean = self.regressor(input)
                pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps
                pred_var = pred_var.view(1, -1, 1, 1)
                neg_log_prob = 0.5*(
                    (pred_mean-target)**2/pred_var+torch.log(pred_var)
                    )
                loss = torch.mean(neg_log_prob)
                return loss

        7. RKD: Relation Knowledge Distillation

        全稱:Relational Knowledge Disitllation

        鏈接:http://arxiv.org/pdf/1904.05068

        發(fā)表:CVPR19

        RKD也是基于關系的知識蒸餾方法,RKD提出了兩種損失函數,二階的距離損失和三階的角度損失。

        Distance-wise Loss

        Angle-wise Loss

        實現如下:

        class RKDLoss(nn.Module):
            """Relational Knowledge Disitllation, CVPR2019"""
            def __init__(self, w_d=25, w_a=50):
                super(RKDLoss, self).__init__()
                self.w_d = w_d
                self.w_a = w_a
            def forward(self, f_s, f_t):
                student = f_s.view(f_s.shape[0], -1)
                teacher = f_t.view(f_t.shape[0], -1)
                # RKD distance loss
                with torch.no_grad():
                    t_d = self.pdist(teacher, squared=False)
                    mean_td = t_d[t_d > 0].mean()
                    t_d = t_d / mean_td
                d = self.pdist(student, squared=False)
                mean_d = d[d > 0].mean()
                d = d / mean_d
                loss_d = F.smooth_l1_loss(d, t_d)
                # RKD Angle loss
                with torch.no_grad():
                    td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
                    norm_td = F.normalize(td, p=2, dim=2)
                    t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
                sd = (student.unsqueeze(0) - student.unsqueeze(1))
                norm_sd = F.normalize(sd, p=2, dim=2)
                s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
                loss_a = F.smooth_l1_loss(s_angle, t_angle)
                loss = self.w_d * loss_d + self.w_a * loss_a
                return loss
            @staticmethod
            def pdist(e, squared=False, eps=1e-12):
                e_square = e.pow(2).sum(dim=1)
                prod = e @ e.t()
                res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
                if not squared:
                    res = res.sqrt()
                res = res.clone()
                res[range(len(e)), range(len(e))] = 0
                return res

        8. PKT:Probabilistic Knowledge Transfer

        全稱:Probabilistic Knowledge Transfer for deep representation learning鏈接:https://arxiv.org/abs/1803.10837發(fā)表:CoRR18

        提出一種概率知識轉移方法,引入了互信息來進行建模。該方法具有可跨模態(tài)知識轉移、無需考慮任務類型、可將手工特征融入網絡等有點。

        6.jpg

        實現如下:

        class PKT(nn.Module):
            """Probabilistic Knowledge Transfer for deep representation learning
            Code from author: https://github.com/passalis/probabilistic_kt"""
            def __init__(self):
                super(PKT, self).__init__()
            def forward(self, f_s, f_t):
                return self.cosine_similarity_loss(f_s, f_t)
            @staticmethod
            def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
                # Normalize each vector by its norm
                output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
                output_net = output_net / (output_net_norm + eps)
                output_net[output_net != output_net] = 0
                target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
                target_net = target_net / (target_net_norm + eps)
                target_net[target_net != target_net] = 0
                # Calculate the cosine similarity
                model_similarity = torch.mm(output_net, output_net.transpose(0, 1))
                target_similarity = torch.mm(target_net, target_net.transpose(0, 1))
                # Scale cosine similarity to 0..1
                model_similarity = (model_similarity + 1.0) / 2.0
                target_similarity = (target_similarity + 1.0) / 2.0
                # Transform them into probabilities
                model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
                target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)
                # Calculate the KL-divergence
                loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))
                return loss

        9. AB: Activation Boundaries

        全稱:Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons

        鏈接:https://arxiv.org/pdf/1811.03233.pdf

        發(fā)表:AAAI18

        目標:讓教師網絡層的神經元的激活邊界盡量和學生網絡的一樣。所謂的激活邊界指的是分離超平面(針對的是RELU這種激活函數),其決定了神經元的激活與失活。AB提出的激活轉移損失,讓教師網絡與學生網絡之間的分離邊界盡可能一致。

        7.jpg

        實現如下:

        class ABLoss(nn.Module):
            """Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
            code: https://github.com/bhheo/AB_distillation
            """
            def __init__(self, feat_num, margin=1.0):
                super(ABLoss, self).__init__()
                self.w = [2**(i-feat_num+1) for i in range(feat_num)]
                self.margin = margin
            def forward(self, g_s, g_t):
                bsz = g_s[0].shape[0]
                losses = [self.criterion_alternative_l2(s, t) for s, t in zip(g_s, g_t)]
                losses = [w * l for w, l in zip(self.w, losses)] 
                # loss = sum(losses) / bsz
                # loss = loss / 1000 * 3
                losses = [l / bsz for l in losses]
                losses = [l / 1000 * 3 for l in losses]
                return losses
            def criterion_alternative_l2(self, source, target):
                loss = ((source + self.margin) ** 2 * ((source > -self.margin) & (target <= 0)).float() +
                        (source - self.margin) ** 2 * ((source <= self.margin) & (target > 0)).float())
                return torch.abs(loss).sum()

        10. FT: Factor Transfer

        全稱:Paraphrasing Complex Network: Network Compression via Factor Transfer

        鏈接:https://arxiv.org/pdf/1802.04977.pdf

        發(fā)表:NIPS18

        提出的是factor transfer的方法。所謂的factor,其實是對模型最后的數據結果進行一個編解碼的過程,提取出的一個factor矩陣,用教師網絡的factor來指導學生網絡的factor。

        8.jpg

        FT計算公式為:

        實現如下:

        class FactorTransfer(nn.Module):
            """Paraphrasing Complex Network: Network Compression via Factor Transfer, NeurIPS 2018"""
            def __init__(self, p1=2, p2=1):
                super(FactorTransfer, self).__init__()
                self.p1 = p1
                self.p2 = p2
            def forward(self, f_s, f_t):
                return self.factor_loss(f_s, f_t)
            def factor_loss(self, f_s, f_t):
                s_H, t_H = f_s.shape[2], f_t.shape[2]
                if s_H > t_H:
                    f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
                elif s_H < t_H:
                    f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
                else:
                    pass
                if self.p2 == 1:
                    return (self.factor(f_s) - self.factor(f_t)).abs().mean()
                else:
                    return (self.factor(f_s) - self.factor(f_t)).pow(self.p2).mean()
            def factor(self, f):
                return F.normalize(f.pow(self.p1).mean(1).view(f.size(0), -1))


        *博客內容為網友個人發(fā)布,僅代表博主個人觀點,如有侵權請聯(lián)系工作人員刪除。



        關鍵詞: AI

        相關推薦

        技術專區(qū)

        關閉