Loss Functions for NER¶
Overview¶
Named Entity Recognition faces a fundamental challenge: severe class imbalance. In typical NER datasets, the majority of tokens are labeled as 'O' (non-entity), while entity tokens are relatively rare. This imbalance can cause standard cross-entropy loss to produce models that simply predict 'O' for most tokens, achieving high accuracy but poor entity recognition.
Mistral NER provides several sophisticated loss functions designed to address this challenge, each with unique characteristics and use cases.
Available Loss Functions¶
Focal Loss (Default)¶
Focal Loss is our default choice, specifically designed to address class imbalance by down-weighting easy examples and focusing on hard ones.
Mathematical Formula
The focal loss is defined as:

\[
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
\]

Where:
- \(p_t\) is the model's estimated probability for the correct class
- \(\alpha_t\) is the weighting factor for class \(t\)
- \(\gamma\) is the focusing parameter (default: 2.0)
How It Works¶
The term \((1 - p_t)^\gamma\) acts as a modulating factor:
- When a sample is correctly classified with high confidence (\(p_t \to 1\)), the factor approaches 0, reducing the loss contribution
- When a sample is misclassified (\(p_t \to 0\)), the factor approaches 1, maintaining the loss contribution
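The same idea as a minimal PyTorch sketch (illustrative only, not the project's implementation; the function name, the flattened input shapes, and the optional `alpha` handling are assumptions):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0, alpha=None, ignore_index=-100):
    """Focal loss sketch: per-token cross-entropy scaled by (1 - p_t) ** gamma.

    Expects flattened inputs: logits of shape (N, num_labels), targets of shape (N,).
    """
    # Unweighted per-token cross-entropy equals -log(p_t)
    ce = F.cross_entropy(logits, targets, reduction="none", ignore_index=ignore_index)
    pt = torch.exp(-ce)                       # recover p_t from -log(p_t)
    loss = (1.0 - pt) ** gamma * ce           # easy examples (p_t -> 1) contribute ~0
    if alpha is not None:                     # optional per-class weights alpha_t
        loss = alpha[targets.clamp(min=0)] * loss
    mask = targets != ignore_index            # drop padding / special tokens
    return loss[mask].mean()
```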
Configuration¶
```yaml
training:
  loss_type: "focal"
  focal_gamma: 2.0   # Higher values focus more on hard examples
  focal_alpha: null  # Auto-computed from class frequencies
```
When to Use¶
- ✅ Recommended for most NER tasks with class imbalance
- ✅ When you have many easy negative examples (non-entities)
- ✅ When rare classes are important
Example Effect¶
```python
# Example: Effect of focal loss on different predictions

# Easy correct prediction: p_t = 0.95
# Standard CE Loss: -log(0.95) = 0.051
# Focal Loss (γ=2): -(1-0.95)² × log(0.95) = 0.00013

# Hard incorrect prediction: p_t = 0.3
# Standard CE Loss: -log(0.3) = 1.20
# Focal Loss (γ=2): -(1-0.3)² × log(0.3) = 0.59
```
Weighted Focal Loss¶
An extension of focal loss that explicitly incorporates class weights.
```yaml
training:
  loss_type: "focal"
  use_class_weights: true
  class_weight_type: "inverse"  # or "inverse_sqrt", "effective"
```
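How such weights can be derived from per-class token counts, as a rough sketch (the exact formulas and smoothing used by Mistral NER may differ; the function below is illustrative):

```python
import numpy as np

def class_weights(counts, scheme="inverse", beta=0.9999, smoothing=1.0):
    """Illustrative class-weight schemes computed from per-class token counts."""
    counts = np.asarray(counts, dtype=np.float64) + smoothing  # smoothing avoids extreme weights
    if scheme == "inverse":
        w = counts.sum() / counts                  # rare classes get large weights
    elif scheme == "inverse_sqrt":
        w = np.sqrt(counts.sum() / counts)         # milder version of "inverse"
    elif scheme == "effective":
        w = (1.0 - beta) / (1.0 - beta ** counts)  # effective-number weighting (Cui et al., 2019)
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    return w / w.mean()                            # normalize so the average weight is 1
```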
Batch Balanced Focal Loss¶
Dynamically adjusts weights based on the batch composition, ensuring rare classes get proper gradients even in imbalanced batches.
Best for Extreme Imbalance
This loss is particularly effective when some entity types appear very rarely, ensuring they still contribute meaningfully to training.
```yaml
training:
  loss_type: "batch_balanced_focal"
  batch_balance_beta: 0.999  # Re-weighting factor
  focal_gamma: 2.0
```
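One way this per-batch re-weighting can be realized, sketched with effective-number weights computed from the labels actually present in the batch (an assumption about the mechanism, not the project's exact code):

```python
import torch

def batch_balanced_weights(targets, num_labels, beta=0.999, ignore_index=-100):
    """Per-batch class weights from the effective number of samples seen in this batch."""
    valid = targets[targets != ignore_index]
    counts = torch.bincount(valid, minlength=num_labels).float()
    # Effective-number weighting; clamp avoids division by zero for absent classes
    weights = (1.0 - beta) / (1.0 - beta ** counts.clamp(min=1.0))
    return weights * num_labels / weights.sum()    # normalize to mean 1
```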
Class Balanced Loss¶
Based on the "Class-Balanced Loss Based on Effective Number of Samples" paper, this loss re-weights classes based on their effective number of samples.
Effective Number Formula
\[
E_n = \frac{1 - \beta^{n}}{1 - \beta}
\]

Where \(n\) is the number of samples for the class and \(\beta\) is a hyperparameter (default: 0.9999). Class weights are then set proportional to \(1 / E_n\), so frequent classes are down-weighted.
```yaml
training:
  loss_type: "class_balanced"
  class_balance_beta: 0.9999
  loss_base_type: "focal"  # Can use focal or cross_entropy as base
```
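A quick back-of-the-envelope check of the effective-number idea (plain Python, nothing project-specific): with \(\beta = 0.9999\) the effective number saturates around 10,000, so a class with a million tokens is weighted only about 100x less than one with a hundred tokens, rather than the 10,000x implied by raw inverse-frequency weighting.

```python
# Effective number E_n = (1 - beta**n) / (1 - beta), the basis of class-balanced weights
beta = 0.9999
for n in (100, 10_000, 1_000_000):
    e_n = (1 - beta ** n) / (1 - beta)
    print(n, round(e_n, 1))  # -> ~99.5, ~6321.4, ~10000.0
```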
Label Smoothing Loss¶
Prevents overconfidence by smoothing the target distribution.
Use with Caution
Recent research (2024) suggests label smoothing can degrade out-of-distribution detection. Consider using focal loss instead for NER tasks.
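For reference, label smoothing mixes the one-hot target with a uniform distribution over the \(K\) labels, \(y' = (1 - \epsilon)\,y + \epsilon / K\). A minimal sketch using PyTorch's built-in support (the epsilon value is illustrative, and the parameter name here is not necessarily the project's config key):

```python
import torch.nn.functional as F

def label_smoothing_loss(logits, targets, epsilon=0.1, ignore_index=-100):
    """Cross-entropy against targets smoothed towards the uniform distribution."""
    # PyTorch >= 1.10 exposes label smoothing directly on F.cross_entropy
    return F.cross_entropy(logits, targets, label_smoothing=epsilon,
                           ignore_index=ignore_index)
```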
Standard Cross-Entropy¶
The baseline loss function, included for comparison.
Choosing the Right Loss Function¶
Use this decision flowchart to select the appropriate loss function:
```mermaid
graph TD
    A[Start] --> B{Severe Class<br/>Imbalance?}
    B -->|Yes| C{Batch<br/>Variation?}
    B -->|No| D[Cross-Entropy<br/>with Weights]
    C -->|High| E[Batch Balanced<br/>Focal Loss]
    C -->|Low| F{Dataset<br/>Size?}
    F -->|Large| G[Focal Loss<br/>γ=2.0]
    F -->|Small| H[Class Balanced<br/>Loss]

    style E fill:#9f9,stroke:#333,stroke-width:2px
    style G fill:#9f9,stroke:#333,stroke-width:2px
```
Performance Comparison¶
Based on our experiments with CoNLL-2003:
| Loss Function | F1 Score | Precision | Recall | Training Time |
|---|---|---|---|---|
| Focal Loss (γ=2.0) | 91.2% | 90.8% | 91.6% | 1.0x |
| Batch Balanced Focal | 90.8% | 91.2% | 90.4% | 1.1x |
| Class Balanced | 90.5% | 90.1% | 90.9% | 1.0x |
| Weighted Cross-Entropy | 89.2% | 88.7% | 89.7% | 0.9x |
| Standard Cross-Entropy | 86.5% | 87.8% | 85.2% | 0.9x |
| Label Smoothing | 88.1% | 89.5% | 86.7% | 1.0x |
Advanced Configuration¶
Auto-Computing Class Weights¶
The system can automatically compute optimal class weights from your training data:
```yaml
training:
  loss_type: "focal"
  focal_alpha: null             # Will be auto-computed
  use_class_weights: true
  class_weight_type: "inverse"  # Options: inverse, inverse_sqrt, effective
  class_weight_smoothing: 1.0   # Smoothing factor to prevent extreme weights
```
Custom Loss Parameters¶
You can fine-tune loss parameters for your specific use case:
```yaml
training:
  loss_type: "focal"
  focal_gamma: 3.0  # Increase for a sharper focus on difficult examples
  focal_alpha: [1.0, 2.5, 2.5, 3.0, 3.0, 2.0, 2.0, 2.5, 2.5]  # Per-class weights, one per label
```
Implementation Details¶
Integration with Training¶
The custom trainer automatically:
- Computes class frequencies from the training dataset
- Calculates appropriate class weights
- Initializes the selected loss function
- Overrides the `compute_loss` method
```python
# Example: How loss is integrated (simplified)
class NERTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Compute class frequencies from the training data
        frequencies = compute_class_frequencies(self.train_dataset)

        # Create the configured loss function (config and loss_params come from the training config)
        self.loss_fn = create_loss_function(
            loss_type=config.loss_type,
            num_labels=config.num_labels,
            class_frequencies=frequencies,
            **loss_params,
        )

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        loss = self.loss_fn(outputs.logits, labels)
        return (loss, outputs) if return_outputs else loss
```
Memory Efficiency¶
All loss functions are implemented with memory efficiency in mind:
- Flatten tensors only when necessary (see the sketch after this list)
- Handle sparse labels efficiently
- Support gradient accumulation
- Work with mixed precision training
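As an illustration of the flattening point above, a generic helper (assumed shapes: logits of (batch, seq_len, num_labels) and labels of (batch, seq_len); not the project's actual code):

```python
import torch

def flatten_for_loss(logits, labels, ignore_index=-100):
    """Reshape token-classification outputs for a per-token loss and drop
    positions labeled ignore_index (padding and non-first subword pieces)."""
    num_labels = logits.size(-1)
    flat_logits = logits.reshape(-1, num_labels)  # (batch * seq_len, num_labels)
    flat_labels = labels.reshape(-1)              # (batch * seq_len,)
    keep = flat_labels != ignore_index
    return flat_logits[keep], flat_labels[keep]
```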
Best Practices¶
1. Start with Focal Loss¶
For most NER tasks, start with focal loss with default parameters:
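```yaml
training:
  loss_type: "focal"
  focal_gamma: 2.0   # Default focusing parameter
  focal_alpha: null  # Auto-computed from class frequencies
```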
2. Monitor Per-Class Metrics¶
Always monitor per-class F1 scores to ensure rare classes are learning:
```python
# The evaluation automatically reports per-class metrics
# Check logs for entries like:
#   Class B-PER:  P=0.92  R=0.89  F1=0.90
#   Class I-MISC: P=0.78  R=0.72  F1=0.75
```
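To compute the same per-entity breakdown outside the built-in evaluation, seqeval's classification_report works on decoded label sequences (illustrative snippet with made-up labels):

```python
from seqeval.metrics import classification_report

# Gold and predicted label sequences, one list per sentence (made-up example data)
y_true = [["B-PER", "I-PER", "O", "B-LOC"]]
y_pred = [["B-PER", "I-PER", "O", "O"]]
print(classification_report(y_true, y_pred))
```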
3. Adjust Based on Results¶
- If rare classes have low recall → Increase focal_gamma or use batch balanced loss
- If common classes have low precision → Reduce class weights
- If model is overconfident → Consider label smoothing (with caution)
4. Combine with Other Techniques¶
Loss functions work best when combined with:
- Data augmentation for rare classes
- Proper learning rate scheduling
- Gradient accumulation for stable training
Examples¶
Example 1: Extreme Imbalance (99% 'O' tokens)¶
```yaml
training:
  loss_type: "batch_balanced_focal"
  batch_balance_beta: 0.999
  focal_gamma: 3.0   # Higher gamma for extreme imbalance
  focal_alpha: null
```
Example 2: Multi-lingual NER with Varying Class Distributions¶
```yaml
training:
  loss_type: "class_balanced"
  class_balance_beta: 0.9999
  loss_base_type: "focal"
  focal_gamma: 2.0
```
Example 3: High-Precision Requirements¶
```yaml
training:
  loss_type: "focal"
  focal_gamma: 2.0
  # Manual weights favoring precision
  focal_alpha: [1.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0]
```
Troubleshooting¶
Issue: Model predicts mostly 'O' labels¶
Solution: Increase focal_gamma or switch to batch balanced focal loss:
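```yaml
training:
  loss_type: "focal"
  focal_gamma: 3.0   # Raised from the default of 2.0
  # or switch to loss_type: "batch_balanced_focal" as shown below
```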
Issue: Rare classes have zero predictions¶
Solution: Use batch balanced focal loss with aggressive re-weighting:
```yaml
training:
  loss_type: "batch_balanced_focal"
  batch_balance_beta: 0.99  # Lower beta for stronger re-weighting
```
Issue: Training is unstable¶
Solution: Reduce learning rate and use weight smoothing:
```yaml
training:
  learning_rate: 1e-5           # Reduced
  loss_type: "focal"
  class_weight_smoothing: 10.0  # Higher smoothing
```
References¶
- Focal Loss for Dense Object Detection - Lin et al., 2017
- Class-Balanced Loss Based on Effective Number of Samples - Cui et al., 2019
- When Does Label Smoothing Help? - Müller et al., 2019
Next Steps
After selecting your loss function, explore our Datasets Guide to understand how different dataset characteristics interact with loss functions.