Pytorch交叉熵损失(CrossEntropyLoss)函数内部运算解析_crossentropyloss(reduction="mean")-程序员宅基地

技术标签: python  深度学习  pytorch  

  对于交叉熵损失函数的来由有很多资料可以参考,这里就不再赘述。本文主要尝试对交叉熵损失函数的内部运算做深度解析。

1. 函数介绍

  Pytorch官网中对交叉熵损失函数的介绍如下:

CLASS torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100,reduce=None, reduction=‘mean’, label_smoothing=0.0)

  该损失函数计算输入值(input)和目标值(target)之间的交叉熵损失。交叉熵损失函数可用于训练一个 C C C类别的分类问题。参数weight给定时,其为分配给每一个类别的权重的一维张量(Tensor)。当数据集分布不均衡时,这是很有用的。
  函数输入(input)应包含每一个类别的原始、非标准化分数。对于未批量化的输入,输入必须是大小为 ( C ) (C) C的张量, ( m i n i b a t c h , C ) (minibatch,C) minibatchC ( m i n i b a t c h , C , d 1 , d 2 , . . . , d K ) (minibatch,C,d_1 ,d_2 ,... ,d_K) minibatchCd1d2...dK,在K维情况下, K ≥ 1 K \geq1 K1
  函数目标值(target)有两种情况,本文只介绍其中较为有效的一种情况,即target为类索引
   本文以下内容均为target为类索引的情况。

  函数目标值(target)取值为在 [ 0 , C ) [0,C) [0C)之间的类索引, C C C为类别数。参数reduction设为'none'时,交叉熵损失可描述如下:
l ( x , y ) = L = { l 1 , . . . , l N } T , l n = − w y n l o g e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) ⋅ 1 { y n   / = i g n o r e _ i n d e x } (1) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T, \\ \large l_n = -w_{yn}log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}\cdot 1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}\tag{1} l(x,y)=L={ l1,...,lN}T,ln=wynlogc=1Cexp(xn,c)exp(xn,yn)1{ yn/=ignore_index}(1)

  其中, x x x是输入, y y y是目标值, w w w是weight, C C C是类别数, N N N为batch size。在reduction不为'none'时(默认为'mean'),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n   / = i g n o r e _ i n d e x } l n , i f   r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f   r e d u c t i o n = ‘ s u m ’ . (2) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{2} l(x,y)=n=1Nn=1Nwyn1{ yn/=ignore_index}1ln,ifreduction=mean;n=1Nln,ifreduction=sum.(2)

 需要指出的是,在这种情况下的交叉熵损失等价于LogSoftmaxNLLLoss的组合。1

  因此,我们可以从LogSoftmaxNLLLoss来深度解析交叉熵损失函数的内部运算。

2. LogSoftmax函数

  LogSoftmax()函数2公式如下:
L o g S o f t m a x ( x i ) = l o g ( e x p ( x i ) ∑ j e x p ( x j ) ) (3) LogSoftmax(x_i) = log(\frac{exp(x_i)}{\sum_{j}exp(x_j)}) \tag{3} LogSoftmax(xi)=log(jexp(xj)exp(xi))(3)
  即,先对输入值进行Softmax归一化处理,然后对归一化值取对数。这部分对应公式(1)中的 log ⁡ e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) \textcolor{red}{\log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}} logc=1Cexp(xn,c)exp(xn,yn)

  代码示例如下:

>>> import torch.nn as nn
>>> SM = nn.Softmax(dim=1) #Softmax函数
>>> x = torch.tensor([[1.0,3.0,4.0],[7.0,3.0,8.0],[9.0,7.0,5.0]])
>>> x
tensor([[1., 3., 4.],
        [7., 3., 8.],
        [9., 7., 5.]])
 
>>> output_SM = SM(x) #第一步,对x进行Softmax归一化处理
>>> output_SM
#每一行元素相加之和等于1
tensor([[0.0351, 0.2595, 0.7054],
        [0.2676, 0.0049, 0.7275],
        [0.8668, 0.1173, 0.0159]]) 
>>> out_L_SM = torch.log(output_SM) #第二步,对输出取log
>>> out_L_SM
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])
        
#直接使用LogSoftmax函数,一步到位
>>> L_SM = nn.LogSoftmax(dim=1)
>>> out_L_SM_ = L_SM(x)
>>> out_L_SM_
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])

3. NLLLoss函数

  Pytorch中的NLLLoss函数3“名不副实”,虽然名为负对数似然函数,但其内部并没有进行对数计算,而只是对输入值求平均后取负(函数参数reduction为默认值'mean',参数weight为默认值'none'时)。

  官网介绍如下:

CLASS torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction=‘mean’)

  参数reduction值为'none'时:
l ( x , y ) = L = { l 1 , . . . , l N } T ,   l n = − w y n x n , y n , w c = w e i g h t [ c ] ⋅ 1 { c   / = i g n o r e _ i n d e x } , (4) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T,\ l_n = -w_{yn}x_{n,yn}, w_c = weight[c]\cdot1\left \{ c\mathrlap{\,/}{=}ignore\_index\right \},\tag{4} l(x,y)=L={ l1,...,lN}T, ln=wynxn,yn,wc=weight[c]1{ c/=ignore_index},(4)
  其中, x x x为输入, y y y为目标值, w w w为weight, N N N为batch size。
  参数reduction值不为'none'时(默认为'mean'),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n l n , i f   r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f   r e d u c t i o n = ‘ s u m ’ . (5) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{5} l(x,y)=n=1Nn=1Nwyn1ln,ifreduction=mean;n=1Nln,ifreduction=sum.(5)
  可以看出,当reduction'mean'时,即是对 l n l_n ln求加权平均值。weight参数默认为1,因此默认情况下,即是对 l n l_n ln求平均值。又 l n = − w y n x n , y n l_n = -w_{yn}x_{n,yn} ln=wynxn,yn,所以weight为默认值1时, l n = − x n , y n l_n=-x_{n,yn} ln=xn,yn。故此时,即是 x x x求平均后取负。 这部分对于公式(2)中的 ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n   / = i g n o r e _ i n d e x } l n \textcolor{red}{\sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n} n=1Nn=1Nwyn1{ yn/=ignore_index}1ln

  实例代码验证如下:

>>> import torch
>>> NLLLoss = torch.nn.NLLLoss() #Pytorch负对数似然损失函数
>>> input = torch.randn(3,3)
>>>input
tensor([[1.4550, 2.3858, 1.1724],
        [0.4952, 1.5870, 0.9594],
        [1.4170, 0.4525, 0.2519]])
        
>>>target = torch.tensor([1,0,2]) #类索引目标值
>>> loss = NLLLoss(input, target)
>>> loss
tensor(-1.0443)

  平均取负有: V a l u e = − 1 3 ( 2.3858 + 0.4952 + 0.2519 ) = − 1.0443 Value = -\frac{1}{3}\left ( 2.3858+0.4952+0.2519 \right ) =-1.0443 Value=31(2.3858+0.4952+0.2519)=1.0443
  显然,平均取负结果和NLLLoss运算结果相同。

注:笔者窃以为,公式(5)中上式可写为 ∑ n = 1 N l n ∑ n = 1 N w y n \frac{\sum_{n=1}^{N}l_n}{\sum_{n=1}^{N}w_{yn}} n=1Nwynn=1Nln,如此则更容易理解。公式(2)同理。

4. 小结

  本文通过将CrossEntropyLoss拆解为LogSoftmaxNLLLoss两步,对交叉熵损失内部计算做了深度的解析,以更清晰地理解交叉熵损失函数。需要指出的是,本文所介绍的内容,只是对于CrossEntropyLoss的target为类索引的情况,CrossEntropyLoss的target还可以是每个类别的概率(Probabilities for each class),这种情况有所不同。


  学习总结,以作分享,如有不妥,敬请指出。


Reference


  1. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss

  2. https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html?highlight=logsoftmax#torch.nn.LogSoftmax

  3. https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html?highlight=nllloss#torch.nn.NLLLoss

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/XIAOSHUCONG/article/details/123527257

智能推荐

cocos2dx之创建悬浮节点_cocos creator scrollview下的节点悬浮-程序员宅基地

文章浏览阅读2.7k次。悬浮节点(NotificationNode)。我们做游戏时,有时候会有一些悬浮图标,永远要显示在屏幕上,如果每次切换场景都重新创建一次则很麻烦。如果设置了NotificationNode那么切换了场景后,这个节点不会消失。_cocos creator scrollview下的节点悬浮

502 bad gateway报错可能的原因之一_java打包报错网关错误502-程序员宅基地

文章浏览阅读1.3k次。记录一次状态码502 Bad gateway第一次遇到请求返回502状态码之前遇到的基本都是500(服务器内部错误),比如java代码运行时,被throw出来的错误。这种错误,打个断点就很好排查。解决过程疯狂网上找资料。网上有一些说法是,当使用Nginx代理时,发生了502报错,很有可能是请求头header过大而导致的。所以,在Nginx配置内添加一下增加缓存的的代码就能解决proxy_buffer_size 64k;proxy_buffers 4 32k;proxy_busy_buffe_java打包报错网关错误502

element-UI+VUE 实现el-table双击单元格编辑(智能操作!不沙雕!看我就对了)_cell-dblclick 其他单格也能触发-程序员宅基地

文章浏览阅读1.1w次,点赞19次,收藏81次。浏览了很多智慧的结晶,要么操作傻瓜,要么过于复杂(不必要的),还有的虽然实现了操作,但逻辑上让我难受。所以自己实操成功后整合一篇。本篇博客涉及到的点有:(后面详解)el_table双击单元格实现编辑操作 el-input回车操作enter与失焦事件blur冲突(会触发两次导致操作异常) 如果是组件之间操作,tableData是从父组件通过props接过来的,在本子页面中定义了另一个空数..._cell-dblclick 其他单格也能触发

【opencv3的鼠标事件选取ROI区域操作】_event == cv_event_lbuttonup-程序员宅基地

文章浏览阅读3k次。选取图像中的ROI区域:#include<opencv2\opencv.hpp>#include <stdio.h> using namespace cv;using namespace std;Mat org, dst, img, tmp; void on_mouse(int event, int x, int y, int flags, void *)..._event == cv_event_lbuttonup

串口输出数据每次都不同的问题排查_firmata软串口输出的问题-程序员宅基地

文章浏览阅读870次。对于不同来源程序移植所需要注意的外部晶振配置_firmata软串口输出的问题

轻量级密码综述_什么是轻量级密码-程序员宅基地

文章浏览阅读472次。轻量级密码在设计时要考虑其应用目标平台,已提出的轻量级密码有面向硬件实现的设计、面向软件实现的设计和综合考虑软硬件实现的混合设计。或者在现有典型分组密码的基础上,对密码算法的组件进行轻量化的改进;还有于2011年发布的Piccolo、Lblock、KLEIN、LED、EPCBC算法,2012年发布的PRINCE、TWICE算法,2014年发布的LEA算法,2015年发布的SIMECK、SIMON算法,2016年发布的QTL算法,2017年发布的Magpie算法,2018年发布的Surge、SFN算法。_什么是轻量级密码

随便推点

棒料切割机设计(论文+CAD图纸+实习报告)_小型棒料切割机的结构设计-程序员宅基地

文章浏览阅读433次。铸棒线割机在连续的铸造中工作,它的工作是由PLC控制电磁阀,使电磁阀控制气缸,并由气缸驱动与其连接的部件,实现对铸棒的准确定长切割,切割后自动返回初始位置。连续铸造是一种先进的铸造方法,其原理是将熔融的金属,不断浇入一种叫做结晶器的特殊金属型中,凝固(结壳)了的铸件,连续不断地从结晶器的另一端拉出,它可获得任意长或特定的长度的铸件。三、在设计过程中,纵横行走装置采用了直线导轨,既提高了运动系统的运动精度,又很大程度的减小了摩擦力,达到了节能的效果。5.3气动原理图的设计 ……1.1设计要求 ……_小型棒料切割机的结构设计

ygz_slam_ros测试_tabs are prohibited in yaml! in function icvymlski-程序员宅基地

文章浏览阅读933次。###将ygz_slam_ros单独拿出来编译#按照安装依赖https://gaas.gitbook.io/guide/software-realization-build-your-own-autonomous-drone/wu-ren-ji-zi-dong-jia-shi-xi-lie-part-3-zai-wu-gps-huan-jing-xia-tong-guo-slam-shi-x..._tabs are prohibited in yaml! in function icvymlskipspaces

解决ThinkPad 笔记本电脑无法连接手动添加的隐藏网络问题-提示“无法连接这个网络”_thinkpad e460安装网卡失败,没法上网-程序员宅基地

文章浏览阅读3.8k次。本人ThinkPad E460,Win10操作系统,因为工作内容,需要在特定的网络中开发,要连接隐藏网络。公司分配的是台式机,给了个2.4g的网卡,5g没到货时,用着是在憋屈,所以就用自己的笔记本连接隐藏5g网。然后按照正常步骤进行:点击网络图标 - 网络和Internet设置 - 网络和共享中心 - 设置新的网络 - 手动连接到无线网络 ...当点击连接时就给反馈:“无法连接这个..._thinkpad e460安装网卡失败,没法上网

YOLOV8安卓端部署_yolov8部署到手机-程序员宅基地

文章浏览阅读1.7k次,点赞5次,收藏31次。之前部署的yolov5-ncnn不支持调用本地摄像头进行在线推理,多少还是感觉遗憾。说实话yolov8-ncnn的部署属实有点割韭菜的嫌疑,这篇博客教你从0部署yolov8到安卓手机。_yolov8部署到手机

举例说明计算机图形学的主要应用领域,计算机图形学-程序员宅基地

文章浏览阅读2.6k次。1、 举例说明计算机图形学的主要应用领域(至少说明5个应用领域)计算机及辅助设计与制造、可视化、图形实时绘制与自然景物仿真、计算机动画、用户接口、计算机艺术2、 分别解释直线生成算法DDA法、中点画线法和Bresenham法的基本原理。 DDA法:设过端点P0(x0 ,y0)、P1(x1 ,y1)的直线段为L(P0 ,P1),则直线段L的斜率L的起点P0的横坐标x0向L的终点P1的横坐标x1步进..._列举有关计算机图形学的应用

python中在一个类中调用另一个类的方法_python中类方法如何调用-程序员宅基地

文章浏览阅读5.3w次,点赞17次,收藏67次。通过实例化一个对象,使一个类能调用另一个类的方法主题代码主题描述老张开车去东北这件事类人实例变量:名字name实例方法:去go_to车实例方法:run代码class Person: def __inti__(self,name): self.name = name def go_to(self,position,type): ''' :par..._python中类方法如何调用

推荐文章

热门文章

相关标签