Pytorch交叉熵损失(CrossEntropyLoss)函数内部运算解析_crossentropyloss(reduction="mean")-程序员宅基地

技术标签: python  深度学习  pytorch  

  对于交叉熵损失函数的来由有很多资料可以参考,这里就不再赘述。本文主要尝试对交叉熵损失函数的内部运算做深度解析。

1. 函数介绍

  Pytorch官网中对交叉熵损失函数的介绍如下:

CLASS torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100,reduce=None, reduction=‘mean’, label_smoothing=0.0)

  该损失函数计算输入值(input)和目标值(target)之间的交叉熵损失。交叉熵损失函数可用于训练一个 C C C类别的分类问题。参数weight给定时,其为分配给每一个类别的权重的一维张量(Tensor)。当数据集分布不均衡时,这是很有用的。
  函数输入(input)应包含每一个类别的原始、非标准化分数。对于未批量化的输入,输入必须是大小为 ( C ) (C) C的张量, ( m i n i b a t c h , C ) (minibatch,C) minibatchC ( m i n i b a t c h , C , d 1 , d 2 , . . . , d K ) (minibatch,C,d_1 ,d_2 ,... ,d_K) minibatchCd1d2...dK,在K维情况下, K ≥ 1 K \geq1 K1
  函数目标值(target)有两种情况,本文只介绍其中较为有效的一种情况,即target为类索引
   本文以下内容均为target为类索引的情况。

  函数目标值(target)取值为在 [ 0 , C ) [0,C) [0C)之间的类索引, C C C为类别数。参数reduction设为'none'时,交叉熵损失可描述如下:
l ( x , y ) = L = { l 1 , . . . , l N } T , l n = − w y n l o g e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) ⋅ 1 { y n   / = i g n o r e _ i n d e x } (1) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T, \\ \large l_n = -w_{yn}log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}\cdot 1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}\tag{1} l(x,y)=L={ l1,...,lN}T,ln=wynlogc=1Cexp(xn,c)exp(xn,yn)1{ yn/=ignore_index}(1)

  其中, x x x是输入, y y y是目标值, w w w是weight, C C C是类别数, N N N为batch size。在reduction不为'none'时(默认为'mean'),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n   / = i g n o r e _ i n d e x } l n , i f   r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f   r e d u c t i o n = ‘ s u m ’ . (2) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{2} l(x,y)=n=1Nn=1Nwyn1{ yn/=ignore_index}1ln,ifreduction=mean;n=1Nln,ifreduction=sum.(2)

 需要指出的是,在这种情况下的交叉熵损失等价于LogSoftmaxNLLLoss的组合。1

  因此,我们可以从LogSoftmaxNLLLoss来深度解析交叉熵损失函数的内部运算。

2. LogSoftmax函数

  LogSoftmax()函数2公式如下:
L o g S o f t m a x ( x i ) = l o g ( e x p ( x i ) ∑ j e x p ( x j ) ) (3) LogSoftmax(x_i) = log(\frac{exp(x_i)}{\sum_{j}exp(x_j)}) \tag{3} LogSoftmax(xi)=log(jexp(xj)exp(xi))(3)
  即,先对输入值进行Softmax归一化处理,然后对归一化值取对数。这部分对应公式(1)中的 log ⁡ e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) \textcolor{red}{\log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}} logc=1Cexp(xn,c)exp(xn,yn)

  代码示例如下:

>>> import torch.nn as nn
>>> SM = nn.Softmax(dim=1) #Softmax函数
>>> x = torch.tensor([[1.0,3.0,4.0],[7.0,3.0,8.0],[9.0,7.0,5.0]])
>>> x
tensor([[1., 3., 4.],
        [7., 3., 8.],
        [9., 7., 5.]])
 
>>> output_SM = SM(x) #第一步,对x进行Softmax归一化处理
>>> output_SM
#每一行元素相加之和等于1
tensor([[0.0351, 0.2595, 0.7054],
        [0.2676, 0.0049, 0.7275],
        [0.8668, 0.1173, 0.0159]]) 
>>> out_L_SM = torch.log(output_SM) #第二步,对输出取log
>>> out_L_SM
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])
        
#直接使用LogSoftmax函数,一步到位
>>> L_SM = nn.LogSoftmax(dim=1)
>>> out_L_SM_ = L_SM(x)
>>> out_L_SM_
tensor([[-3.3490, -1.3490, -0.3490],
        [-1.3182, -5.3182, -0.3182],
        [-0.1429, -2.1429, -4.1429]])

3. NLLLoss函数

  Pytorch中的NLLLoss函数3“名不副实”,虽然名为负对数似然函数,但其内部并没有进行对数计算,而只是对输入值求平均后取负(函数参数reduction为默认值'mean',参数weight为默认值'none'时)。

  官网介绍如下:

CLASS torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction=‘mean’)

  参数reduction值为'none'时:
l ( x , y ) = L = { l 1 , . . . , l N } T ,   l n = − w y n x n , y n , w c = w e i g h t [ c ] ⋅ 1 { c   / = i g n o r e _ i n d e x } , (4) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T,\ l_n = -w_{yn}x_{n,yn}, w_c = weight[c]\cdot1\left \{ c\mathrlap{\,/}{=}ignore\_index\right \},\tag{4} l(x,y)=L={ l1,...,lN}T, ln=wynxn,yn,wc=weight[c]1{ c/=ignore_index},(4)
  其中, x x x为输入, y y y为目标值, w w w为weight, N N N为batch size。
  参数reduction值不为'none'时(默认为'mean'),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n l n , i f   r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f   r e d u c t i o n = ‘ s u m ’ . (5) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{5} l(x,y)=n=1Nn=1Nwyn1ln,ifreduction=mean;n=1Nln,ifreduction=sum.(5)
  可以看出,当reduction'mean'时,即是对 l n l_n ln求加权平均值。weight参数默认为1,因此默认情况下,即是对 l n l_n ln求平均值。又 l n = − w y n x n , y n l_n = -w_{yn}x_{n,yn} ln=wynxn,yn,所以weight为默认值1时, l n = − x n , y n l_n=-x_{n,yn} ln=xn,yn。故此时,即是 x x x求平均后取负。 这部分对于公式(2)中的 ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n   / = i g n o r e _ i n d e x } l n \textcolor{red}{\sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n} n=1Nn=1Nwyn1{ yn/=ignore_index}1ln

  实例代码验证如下:

>>> import torch
>>> NLLLoss = torch.nn.NLLLoss() #Pytorch负对数似然损失函数
>>> input = torch.randn(3,3)
>>>input
tensor([[1.4550, 2.3858, 1.1724],
        [0.4952, 1.5870, 0.9594],
        [1.4170, 0.4525, 0.2519]])
        
>>>target = torch.tensor([1,0,2]) #类索引目标值
>>> loss = NLLLoss(input, target)
>>> loss
tensor(-1.0443)

  平均取负有: V a l u e = − 1 3 ( 2.3858 + 0.4952 + 0.2519 ) = − 1.0443 Value = -\frac{1}{3}\left ( 2.3858+0.4952+0.2519 \right ) =-1.0443 Value=31(2.3858+0.4952+0.2519)=1.0443
  显然,平均取负结果和NLLLoss运算结果相同。

注:笔者窃以为,公式(5)中上式可写为 ∑ n = 1 N l n ∑ n = 1 N w y n \frac{\sum_{n=1}^{N}l_n}{\sum_{n=1}^{N}w_{yn}} n=1Nwynn=1Nln,如此则更容易理解。公式(2)同理。

4. 小结

  本文通过将CrossEntropyLoss拆解为LogSoftmaxNLLLoss两步,对交叉熵损失内部计算做了深度的解析,以更清晰地理解交叉熵损失函数。需要指出的是,本文所介绍的内容,只是对于CrossEntropyLoss的target为类索引的情况,CrossEntropyLoss的target还可以是每个类别的概率(Probabilities for each class),这种情况有所不同。


  学习总结,以作分享,如有不妥,敬请指出。


Reference


  1. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss

  2. https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html?highlight=logsoftmax#torch.nn.LogSoftmax

  3. https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html?highlight=nllloss#torch.nn.NLLLoss

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/XIAOSHUCONG/article/details/123527257

智能推荐

苹果https java_apple登录 后端java实现最终版-程序员宅基地

文章浏览阅读298次。import com.alibaba.fastjson.JSONArray;import com.alibaba.fastjson.JSONObject;import com.auth0.jwk.Jwk;import com.helijia.appuser.modules.user.vo.AppleCredential;import com.helijia.common.api.model.Api..._com.auth0.jwk.jwk

NLP学习记录(六)最大熵模型MaxEnt_顺序潜在最大熵强化学习(maxent rl)-程序员宅基地

文章浏览阅读4.7k次。原理在叧掌握关于未知分布的部分信息的情况下,符合已知知识的概率分布可能有夗个,但使熵值最大的概率分布最真实地反映了事件的的分布情况,因为熵定义了随机变量的不确定性,弼熵值最大时,随机变量最不确定,最难预测其行为。最大熵模型介绍我们通过一个简单的例子来介绍最大熵概念。假设我们模拟一个翻译专家的决策过程,关于英文单词in到法语单词的翻译。我们的翻译决策模型p给每一个单词或短语分配一..._顺序潜在最大熵强化学习(maxent rl)

计算机毕业设计ssm科研成果管理系统p57gs系统+程序+源码+lw+远程部署-程序员宅基地

文章浏览阅读107次。计算机毕业设计ssm科研成果管理系统p57gs系统+程序+源码+lw+远程部署。springboot基于springboot的影视资讯管理系统。ssm基于SSM高校教师个人主页网站的设计与实现。ssm基于JAVA的求职招聘网站的设计与实现。springboot校园头条新闻管理系统。ssm基于SSM框架的毕业生离校管理系统。ssm预装箱式净水站可视化信息管理系统。ssm基于SSM的网络饮品销售管理系统。

Caused by: org.xml.sax.SAXParseException; lineNumber: 38; columnNumber: 9; cvc-complex-type.2.3: 元素_saxparseexception; linenumber: 35; columnnumber: 9-程序员宅基地

文章浏览阅读1.6w次。不知道大家有没有遇到过与我类似的报错情况,今天发生了此错误后就黏贴复制了报错信息“Caused by: org.xml.sax.SAXParseException; lineNumber: 38; columnNumber: 9; cvc-complex-type.2.3: 元素 'beans' 必须不含字符 [子级], 因为该类型的内容类型为“仅元素”。”然后就是一顿的百度啊, 可一直都没有找到..._saxparseexception; linenumber: 35; columnnumber: 9; cvc-complex-type.2.3:

计算机科学与技术创新创业意见,计算机科学与技术学院大学生创新创业工作会议成功举行...-程序员宅基地

文章浏览阅读156次。(通讯员 粟坤萍 2018-04-19)4月19日,湖北师范大学计算机科学与技术学院于教育大楼学院会议室1110成功召开大学生创新创业工作会议。参与本次会议的人员有党总支副书记黄海军老师,创新创业学院吴杉老师,计算机科学与技术学院创新创业活动指导老师,15、16、17级各班班主任及学生代表。首先吴杉老师介绍了“互联网+”全国大学生创新创业大赛的相关工作进度,动员各级班主任充分做好“大学生创新创业大..._湖北师范 吴杉

【Android逆向】爬虫进阶实战应用必知必会-程序员宅基地

文章浏览阅读1.1w次,点赞69次,收藏76次。安卓逆向技术是一门深奥且充满挑战的领域。通过本文的介绍,我们了解了安卓逆向的基本概念、常用工具、进阶技术以及实战案例分析。然而,逆向工程的世界仍然在不断发展和变化,新的技术和方法不断涌现。展望未来,随着安卓系统的不断更新和加固,逆向工程将面临更大的挑战。同时,随着人工智能和机器学习技术的发展,我们也许能够看到更智能、更高效的逆向工具和方法的出现。由于篇幅限制,本文仅对安卓逆向技术进行了介绍和案例分析。

随便推点

Python数据可视化之环形饼图_数据可视化绘制饼图或圆环图-程序员宅基地

文章浏览阅读1.1k次。制作饼图还需要下载pyecharts库,Echarts 是一个由百度开源的数据可视化,凭借着良好的交互性,精巧的图表设计,得到了众多开发者的认可。随着学习python的热潮不断增加,Python数据可视化也不停的被使用,那我今天就介绍一下Python数据可视化中的饼图。在我们的生活和学习中,编程是一项非常有用的技能,能够丰富我们的视野,为各行各业的领域提供了新的角度。环形饼图的制作并不难,主要是在于数据的打包和分组这里会有点问题,属性的标签可以去 这个网站进行修改。图中的zip压缩函数,并分组打包。_数据可视化绘制饼图或圆环图

SpringMVC开发技术~5~基于注解的控制器_jsp/servlet到controller到基于注解的控制器-程序员宅基地

文章浏览阅读325次。1 Spring MVC注解类型Controller和RequestMapping注释类型是SpringMVC API最重要的两个注释类型。基于注解的控制器的几个优点:一个控制器类可以控制几个动作,而一个实现了Controller接口的控制器只能处理一个动作。这就允许将相关操作写在一个控制器类内,从而减少应用类的数量基于注解的控制器的请求映射不需要存储在配置文件中,而是使用RequestM..._jsp/servlet到controller到基于注解的控制器

利用波特图来满足动态控制行为的要求-程序员宅基地

文章浏览阅读260次,点赞3次,收藏4次。相位裕量可以从增益图中的交越频率处读取(参见图2)。使用的开关频率、选择的外部元件(例如电感和输出电容),以及各自的工作条件(例如输入电压、输出电压和负载电流)都会产生巨大影响。图2所示为波特图中控制环路的增益曲线,其中提供了两条重要信息。对于图2所示的控制环路,这个所谓的交越频率出现在约80 kHz处。通过使用波特图,您可以查看控制环路的速度,特别是其调节稳定性。图2. 显示控制环路增益的波特图(约80 kHz时,达到0 dB交越点)。图3. 控制环路的相位曲线,相位裕量为60°。

Glibc Error: `_obstack@GLIBC_2.2.5‘ can‘t be versioned to common symbol ‘_obstack_compat‘_`_obstack@glibc_2.2.5' can't be versioned to commo-程序员宅基地

文章浏览阅读1.8k次。Error: `_obstack@GLIBC_2.2.5’ can’t be versioned to common symbol '_obstack_compat’原因:https://www.lordaro.co.uk/posts/2018-08-26-compiling-glibc.htmlThis was another issue relating to the newer binutils install. Turns out that all was needed was to initi_`_obstack@glibc_2.2.5' can't be versioned to common symbol '_obstack_compat

基于javaweb+mysql的电影院售票购票电影票管理系统(前台、后台)_电影售票系统javaweb-程序员宅基地

文章浏览阅读3k次。基于javaweb+mysql的电影院售票购票电影票管理系统(前台、后台)运行环境Java≥8、MySQL≥5.7开发工具eclipse/idea/myeclipse/sts等均可配置运行适用课程设计,大作业,毕业设计,项目练习,学习演示等功能说明前台用户:查看电影列表、查看排版、选座购票、查看个人信息后台管理员:管理电影排版,活动,会员,退票,影院,统计等前台:后台:技术框架_电影售票系统javaweb

分分钟拯救监控知识体系-程序员宅基地

文章浏览阅读95次。分分钟拯救监控知识体系本文出自:http://liangweilinux.blog.51cto.com0 监控目标我们先来了解什么是监控,监控的重要性以及监控的目标,当然每个人所在的行业不同、公司不同、业务不同、岗位不同、对监控的理解也不同,但是我们需要注意,监控是需要站在公司的业务角度去考虑,而不是针对某个监控技术的使用。监控目标1.对系统不间断实时监控:实际上是对系统不间..._不属于监控目标范畴的是 实时反馈系统当前状态

推荐文章

热门文章

相关标签