Why is the loss for an example just the negative of the log-probability that the model assigned to the true class?

Blurb published in December 2025

The goal of this blurb is to understand why many cross-entropy implementations compute the loss for an example as simply the negative of the log-probability that the model assigned to the true class.

The simplification from one-hot encoding

We start from the cross-entropy equation studied in a previous blurb: $$H(P, Q) = -\sum_i P_i \log Q_i$$ where \(P\) is the true distribution and \(Q\) is the model's predicted distribution.

Expanding this for a 3-class example, we get: $$H(P, Q) = -\left( P_0 \log Q_0 + P_1 \log Q_1 + P_2 \log Q_2 \right)$$ In many settings the true distribution is a one-hot encoded vector; language modeling is one of them, since each position has exactly one correct next token. If class 2 is the correct class, then \(P = [0, 0, 1]\), so every term with \(P_i = 0\) vanishes: $$H(P, Q) = -\left( 0 \cdot \log Q_0 + 0 \cdot \log Q_1 + 1 \cdot \log Q_2 \right) = -\log Q_2$$ We are left with just the negative log-probability of the true class.
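A quick numerical check of this simplification, as a plain-Python sketch; the predicted distribution \(Q = [0.2, 0.3, 0.5]\) is an arbitrary choice for illustration. Computing the full sum, zero-weighted terms included, gives the same value as the shortcut \(-\log Q_2\):

```python
import math

def cross_entropy(p, q):
    # Full H(P, Q) = -sum_i P_i * log(Q_i); every Q_i must be > 0 for math.log
    return -sum(p_i * math.log(q_i) for p_i, q_i in zip(p, q))

P = [0.0, 0.0, 1.0]  # one-hot target: class 2 is the true class
Q = [0.2, 0.3, 0.5]  # illustrative model prediction

full = cross_entropy(P, Q)
shortcut = -math.log(Q[2])
print(full, shortcut)  # both equal log(2), about 0.6931
```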

In practice with neural networks

In practice, a neural network that classifies among \(K\) classes outputs \(K\) logits from its final layer: unnormalized real-valued scores that can be negative and need not sum to 1. To make the network model a probability distribution, we apply softmax to the logits, then compute the loss and backpropagate. Softmax converts logits to probabilities: $$\text{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}}$$ Throughout this blurb, we assume logits have shape (batch_size, seq_len, num_classes) for language modeling, where the last dimension is always the class dimension.

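import torch

def softmax(logits):
    # Naive version: exponentiate, then normalize over the class dimension (can overflow for large logits)
    exp_logits = torch.exp(logits)
    return exp_logits / exp_logits.sum(dim=-1, keepdim=True)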

This turns each slice along the last (class) dimension into a valid probability distribution: non-negative entries that sum to 1.
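To see why numerical stability matters, here is a plain-Python sketch (no PyTorch; the logits are made up for illustration) that mirrors the naive softmax above next to a max-shifted variant. With logits around 1000, `math.exp` overflows in the naive version, while the shifted version returns exactly the same probabilities as for the small logits it was shifted from:

```python
import math

def naive_softmax(z):
    exps = [math.exp(v) for v in z]  # math.exp raises OverflowError for v above about 709
    total = sum(exps)
    return [e / total for e in exps]

def shifted_softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]  # largest exponent is exp(0) = 1
    total = sum(exps)
    return [e / total for e in exps]

small = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]  # `small` shifted by 999

# On small logits both versions agree
print(naive_softmax(small))
print(shifted_softmax(small))

# On large logits the naive version overflows, the shifted one does not
try:
    naive_softmax(large)
except OverflowError as err:
    print("naive softmax failed:", err)
print(shifted_softmax(large))  # same probabilities as shifted_softmax(small)
```

In PyTorch the same overflow is silent rather than an exception: `torch.exp` returns `inf` for large float32 inputs, and `inf / inf` produces `nan`.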

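In practice, PyTorch provides log_softmax (torch.nn.functional.log_softmax), which is mathematically equivalent to log(softmax(...)) but computed in a numerically stable manner. Because softmax is unchanged when the same constant is subtracted from every logit, a simple implementation subtracts the max before exponentiating, so the largest exponent is \(e^0 = 1\) and nothing overflows: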

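def log_softmax(logits):
    # Shift so the largest logit in each slice is 0; softmax is invariant to this shift
    logits_stable = logits - logits.max(dim=-1, keepdim=True).values
    # log softmax(z)_i = (z_i - m) - log(sum_j exp(z_j - m))
    return logits_stable - torch.log(torch.exp(logits_stable).sum(dim=-1, keepdim=True))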
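The identity behind this log_softmax, with \(m = \max_j z_j\), can be written out step by step:

$$\log \text{softmax}(z)_i = z_i - \log \sum_j e^{z_j} = z_i - \log\left( e^{m} \sum_j e^{z_j - m} \right) = (z_i - m) - \log \sum_j e^{z_j - m}$$

Every exponent \(z_j - m\) is at most 0 and the largest is exactly 0, so the sum lies between 1 and \(K\): it can neither overflow nor reach \(\log 0\). The code also never takes the log of an individual probability, so a very small probability that would underflow to 0 in softmax (making log(softmax(...)) return -inf) still gets a finite log-probability here.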

Following the equation we derived, $$\text{Loss} = -\log Q_{\text{true class}}$$ the loss for each example (here, each position in each sequence) is the negative log-probability at the index of the true class. gather picks out that single entry along the class dimension:

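def cross_entropy_loss(logits, true_labels):
    # logits: (batch_size, seq_len, num_classes); true_labels: (batch_size, seq_len), dtype int64
    log_probs = log_softmax(logits)
    # Select log Q_{true class} at each position and negate; result: (batch_size, seq_len)
    return -log_probs.gather(dim=-1, index=true_labels.unsqueeze(-1)).squeeze(-1)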
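As a sanity check, the per-position losses can be compared against PyTorch's built-in torch.nn.functional.cross_entropy. This is a sketch assuming PyTorch is installed; the shapes and random seed are arbitrary choices for illustration. Two details differ from the hand-written version: F.cross_entropy expects the class dimension at index 1, so the (batch_size, seq_len, num_classes) logits are flattened to (batch_size * seq_len, num_classes), and it averages over positions by default (reduction="mean"), so reduction="none" is passed to get per-position values.

```python
import torch
import torch.nn.functional as F

def log_softmax(logits):
    # Same max-shifted log-softmax as above
    logits_stable = logits - logits.max(dim=-1, keepdim=True).values
    return logits_stable - torch.log(torch.exp(logits_stable).sum(dim=-1, keepdim=True))

def cross_entropy_loss(logits, true_labels):
    # Per-position loss -log Q_{true class}, shape (batch_size, seq_len)
    log_probs = log_softmax(logits)
    return -log_probs.gather(dim=-1, index=true_labels.unsqueeze(-1)).squeeze(-1)

torch.manual_seed(0)
batch_size, seq_len, num_classes = 2, 4, 5
logits = torch.randn(batch_size, seq_len, num_classes)
true_labels = torch.randint(0, num_classes, (batch_size, seq_len))  # int64 indices, as gather requires

ours = cross_entropy_loss(logits, true_labels)

# Flatten (B, T, C) -> (B*T, C) because F.cross_entropy wants classes at dim 1
reference = F.cross_entropy(
    logits.reshape(-1, num_classes),
    true_labels.reshape(-1),
    reduction="none",
).reshape(batch_size, seq_len)
print(torch.allclose(ours, reference, atol=1e-6))  # True

# With the default reduction="mean", PyTorch returns the average over all positions
mean_reference = F.cross_entropy(logits.reshape(-1, num_classes), true_labels.reshape(-1))
print(torch.allclose(ours.mean(), mean_reference, atol=1e-6))  # True
```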