PyTorch中的CrossEntropyLoss与交叉熵计算不一致

本文主要对于交叉熵的手动计算和PyTorch中的CrossEntropyLoss模块计算结果不一致的问题展开讨论,查阅了PyTorch的官方文档,最终发现是CrossEntropyLoss在计算交叉熵之前会对输入的概率分布进行一次SoftMax操作导致的。

在强化学习中,策略学习常用到一个损失函数为\(l=-\ln\pi_\theta(a|s)\cdot g\),其中\(\pi_\theta\)在状态\(s\)下是关于动作的一个概率分布,而动作\(a\)是经验中记录的,在状态\(s\)下选择的确定动作。因此有: \[ -\ln\pi_\theta(a|s) = -\sum_{a'\in A}p(a')\cdot \ln q(a') \] \[ p(a') = \left\{ \begin{array}{lr} 1 &&& a'=a\\ 0 &&& otherwise \end{array} \right. \]

\[ q(a') = \pi_\theta(a'|s) \]

因此,该损失函数便被转换为了计算两个概率分布之间交叉熵的计算形式。从而可以想到,在使用PyTorch计算损失函数时,使用内置的torch.nn.functional.cross_entropy函数(下文简称为F.cross_entropy函数)来进行计算。然而,在实践过程中,发现使用该函数计算出的结果与手动计算出的结果不一致,由此展开了一系列调研。

不一致的计算结果

首先,我们使用Python分别以手动计算和使用F.cross_entropy函数计算两组数据的交叉熵,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
from torch.nn import functional as F

x = torch.FloatTensor([[0.4, 0.6],
[0.7, 0.3]])
y = torch.LongTensor([0, 1])

loss_1 = -torch.log(x.gather(1, y.view(-1,1)))
loss_2 = F.cross_entropy(x, y, reduction = "none")

print("手动计算出的交叉熵:\n{}".format(loss_1.squeeze(1)))
print()
print("CrossEntropyLoss计算出的交叉熵:\n{}".format(loss_2))

运行上述代码,得到的计算结果如下:

1
2
3
4
5
手动计算出的交叉熵:
tensor([0.9163, 1.2040])

CrossEntropyLoss计算出的交叉熵:
tensor([0.7981, 0.9130])

从运行结果可以看出,两者的计算结果并不一致,由此,我查阅了PyTorch官方文档来了解F.cross_entropy的实现。

PyTorch 文档描述

文档中关于F.cross_entropy函数的描述中并没有包含具体的计算过程,只对输入数据和输出结果的维度对应关系作了解释1。不过在该函数的介绍中有这么一句话:

See CrossEntropyLoss for details.

于是便转而查找CrossEntropyLoss的文档2,终于找到了Pytorch计算交叉熵的计算过程:

可以看到,官方文档关于交叉熵的计算方式阐述的十分清楚。概括来说就是,F.cross_entropy至少需要两个参数,一个是预测出的概率分布,一个是目标真实类别索引。比较重要的一点在于,F.cross_entropy这个函数是没有要求的,即不要求和等于1,也不要求每一项大于0。这是因为该函数在计算交叉熵之间,会先对输入的概率分布进行一次SoftMax操作。

在计算交叉熵之前进行SoftMax操作提高对输入的宽容度,但是若在构建神经网络时,已经在输出之前进行了SoftMax操作就会导致两次SoftMax使loss的计算失真,也就是上一节的计算结果不一致。

实验验证

根据PyTorch的官方文档,如果我们在手动计算交叉熵之间,加一个SoftMax操作,那么便可以得到和函数F.cross_entropy一样的计算结果,使用如下代码进行验证:

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
from torch.nn import functional as F

x = torch.FloatTensor([[0.4, 0.6],
[0.7, 0.3]])
y = torch.LongTensor([0, 1])

loss_1 = -torch.log(F.softmax(x, dim=-1).gather(1, y.view(-1,1)))
loss_2 = F.cross_entropy(x, y, reduction = "none")

print("手动计算出的交叉熵:\n{}".format(loss_1.squeeze(1)))
print()
print("CrossEntropyLoss计算出的交叉熵:\n{}".format(loss_2))

运行上述代码,得到的输出结果如下:

1
2
3
4
5
手动计算出的交叉熵:
tensor([0.7981, 0.9130])

CrossEntropyLoss计算出的交叉熵:
tensor([0.7981, 0.9130])

Reference


  1. torch.nn.functional.cross_entropy — PyTorch 1.13 documentation↩︎

  2. CrossEntropyLoss — PyTorch 1.13 documentation↩︎