PyTorch's CrossEntropyLoss Is Inconsistent with the Usual Definition of Cross-Entropy

This post was automatically translated from Chinese by an LLM. If you find any translation errors, please leave a comment to help me improve the translation. Thanks!

This article discusses an inconsistency between cross-entropy computed by hand and the result returned by PyTorch's CrossEntropyLoss module. Consulting the official PyTorch documentation shows that the discrepancy is caused by the softmax operation that CrossEntropyLoss applies to its input before computing the cross-entropy.

In reinforcement learning, the loss function commonly used in policy learning is \(l=-\ln\pi_\theta(a|s)\cdot g\), where \(\pi_\theta\) is a probability distribution over actions given state \(s\), and \(a\) is the action selected in state \(s\). Therefore, we have:

\[ -\ln\pi_\theta(a|s) = -\sum_{a'\in A}p(a')\cdot \ln q(a') \]

\[ p(a') = \begin{cases} 1 & a'=a\\ 0 & \text{otherwise} \end{cases} \]

\[ q(a') = \pi_\theta(a'|s) \]
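Because \(p\) is one-hot, the sum collapses to a single log term. A minimal numerical sketch (a hypothetical three-action example, not from the original post):

```python
import torch

# Hypothetical three-action example: q(a') = pi_theta(a'|s), chosen action a.
pi = torch.tensor([0.2, 0.3, 0.5])
a = 2

# Full cross-entropy against the one-hot target distribution p.
p = torch.zeros(3)
p[a] = 1.0
ce = -(p * torch.log(pi)).sum()

# The sum collapses to the single term -log pi(a|s).
assert torch.isclose(ce, -torch.log(pi[a]))
```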

Thus the loss reduces to the cross-entropy between two probability distributions, so we can compute it with PyTorch's built-in torch.nn.functional.cross_entropy function (referred to as F.cross_entropy below). In practice, however, the values it returned did not match the manually computed ones, which prompted the following investigation.

First, we computed the cross-entropy of two samples by hand in Python and compared it with the result of F.cross_entropy, as shown in the code below:

import torch
from torch.nn import functional as F

x = torch.FloatTensor([[0.4, 0.6],
                       [0.7, 0.3]])
y = torch.LongTensor([0, 1])

# Manual cross-entropy: pick out the probability of the target class
# and take its negative log.
loss_1 = -torch.log(x.gather(1, y.view(-1, 1)))
loss_2 = F.cross_entropy(x, y, reduction="none")

print("Manually calculated cross-entropy:\n{}".format(loss_1.squeeze(1)))
print()
print("CrossEntropyLoss calculated cross-entropy:\n{}".format(loss_2))

The results of the above code are as follows:

Manually calculated cross-entropy:
tensor([0.9163, 1.2040])

CrossEntropyLoss calculated cross-entropy:
tensor([0.7981, 0.9130])

The two results clearly disagree, so we consulted the official PyTorch documentation to understand how F.cross_entropy is implemented.

The documentation of F.cross_entropy does not spell out the computation; it only explains the correspondence between the input and output dimensions [1]. However, the description contains one sentence:

See CrossEntropyLoss for details.

So we turned to the documentation of CrossEntropyLoss [2], where we finally found how PyTorch actually computes cross-entropy.
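Paraphrasing the documentation: for a single sample with raw scores \(x\) and target class \(c\), the loss is

\[ \text{loss}(x, c) = -\log\left(\frac{\exp(x[c])}{\sum_j \exp(x[j])}\right) = -x[c] + \log\sum_j \exp(x[j]) \]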

The documentation makes the computation clear. In short, F.cross_entropy takes at least two arguments: the predicted scores and the index of the true class. The important point is that F.cross_entropy does not require the input to sum to 1 or to have entries greater than 0, because the function applies a softmax to the input before computing the cross-entropy.

Applying softmax before computing the cross-entropy makes the function tolerant of arbitrary real-valued inputs (logits). But if the network has already applied a softmax to its output, the function applies softmax a second time and the loss is distorted; this is exactly why the results in the previous section disagree.
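A short sketch of the failure mode (assuming, hypothetically, a network that outputs probabilities instead of logits):

```python
import torch
from torch.nn import functional as F

logits = torch.tensor([[0.4, 0.6],
                       [0.7, 0.3]])
y = torch.tensor([0, 1])

# Correct usage: pass raw scores; cross_entropy applies softmax internally.
loss_ok = F.cross_entropy(logits, y, reduction="none")

# Buggy usage: softmax was already applied, so it gets applied twice
# and the resulting loss is distorted.
probs = F.softmax(logits, dim=-1)
loss_double = F.cross_entropy(probs, y, reduction="none")

print(loss_ok)      # tensor([0.7981, 0.9130])
print(loss_double)  # different, flattened values
```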

Accordingly, if we insert a softmax into the manual computation, we should obtain the same result as F.cross_entropy. The following code verifies this:

import torch
from torch.nn import functional as F

x = torch.FloatTensor([[0.4, 0.6],
                       [0.7, 0.3]])
y = torch.LongTensor([0, 1])

# Apply softmax first, then take the negative log of the target-class
# probability — matching what F.cross_entropy does internally.
loss_1 = -torch.log(F.softmax(x, dim=-1).gather(1, y.view(-1, 1)))
loss_2 = F.cross_entropy(x, y, reduction="none")

print("Manually calculated cross-entropy:\n{}".format(loss_1.squeeze(1)))
print()
print("CrossEntropyLoss calculated cross-entropy:\n{}".format(loss_2))

The output of the above code is as follows:

Manually calculated cross-entropy:
tensor([0.7981, 0.9130])

CrossEntropyLoss calculated cross-entropy:
tensor([0.7981, 0.9130])
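As an aside (a sketch going beyond the original post): F.cross_entropy is equivalent to F.log_softmax followed by F.nll_loss, which is more numerically stable than composing an explicit softmax with a log as above.

```python
import torch
from torch.nn import functional as F

x = torch.tensor([[0.4, 0.6],
                  [0.7, 0.3]])
y = torch.tensor([0, 1])

# cross_entropy == log_softmax + nll_loss; log_softmax avoids computing
# log(softmax(x)) in two lossy floating-point steps.
loss_a = F.cross_entropy(x, y, reduction="none")
loss_b = F.nll_loss(F.log_softmax(x, dim=-1), y, reduction="none")

assert torch.allclose(loss_a, loss_b)
```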

Reference

  1. torch.nn.functional.cross_entropy — PyTorch 1.13 documentation
  2. CrossEntropyLoss — PyTorch 1.13 documentation