Advanced Features

This page covers advanced trainer features available across all trainer types.

Early Stopping

Early stopping automatically terminates training when the model stops improving. This prevents overfitting and saves computational resources.

Configuration

Set early_stopping_patience when creating a trainer:

from fenn.nn.trainers import ClassificationTrainer

trainer = ClassificationTrainer(
    model=model,
    loss_fn=loss_fn,
    optim=optimizer,
    num_classes=10,
    early_stopping_patience=5  # Stop after 5 epochs without improvement
)

How It Works

The trainer maintains a patience counter:

  • The counter increments by 1 each epoch without improvement
  • The counter resets to 0 when the loss improves
  • Training stops when the counter reaches early_stopping_patience

Epoch 1: Loss = 0.500 → Best loss = 0.500, Patience = 0
Epoch 2: Loss = 0.480 → Best loss = 0.480, Patience = 0
Epoch 3: Loss = 0.485 → No improvement, Patience = 1
Epoch 4: Loss = 0.490 → No improvement, Patience = 2
Epoch 5: Loss = 0.475 → Best loss = 0.475, Patience = 0
Epoch 6: Loss = 0.478 → No improvement, Patience = 1
...
Epoch 10: Loss = 0.476 → No improvement, Patience = 5 → STOP
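The counter logic above can be sketched in plain Python. This is an illustrative re-implementation, not fenn's actual code; the loss values mirror the trace above, with the elided epochs 7-9 filled in with hypothetical non-improving values:

```python
def epochs_until_stop(losses, patience):
    """Return the 1-based epoch at which early stopping triggers,
    or None if training runs through every epoch."""
    best_loss = float("inf")
    counter = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:
            best_loss = loss      # improvement: record the new best...
            counter = 0           # ...and reset the patience counter
        else:
            counter += 1          # no improvement: increment the counter
            if counter >= patience:
                return epoch      # patience exhausted, stop here
    return None

# The loss trace from the walkthrough above (epochs 7-9 hypothetical):
losses = [0.500, 0.480, 0.485, 0.490, 0.475, 0.478, 0.481, 0.479, 0.477, 0.476]
print(epochs_until_stop(losses, patience=5))  # 10
```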

Behavior Modes

The early stopping behavior depends on whether you provide a validation loader:

Mode 1: No Validation Loader

Monitors training loss:

trainer = ClassificationTrainer(
    ...,
    early_stopping_patience=5
)
trainer.fit(train_loader, epochs=100)

When to use:

  • Validation data is unavailable
  • Quick experiments
  • Note: Less reliable than validation-based stopping

Mode 2: With Validation Loader

Monitors validation loss:

trainer = ClassificationTrainer(
    ...,
    early_stopping_patience=5
)
trainer.fit(train_loader, epochs=100, val_loader=val_loader)

When to use:

  • You have validation data (recommended)
  • Preventing overfitting is a priority
  • Note: Most reliable approach

Examples

Basic Early Stopping

from fenn.nn.trainers import ClassificationTrainer

trainer = ClassificationTrainer(
    model=model,
    loss_fn=loss_fn,
    optim=optimizer,
    num_classes=10,
    early_stopping_patience=10
)

trainer.fit(
    train_loader=train_loader,
    epochs=200,  # May terminate early
    val_loader=val_loader
)

With Best Model Saving

from fenn.nn.trainers import ClassificationTrainer
from fenn.nn.utils import Checkpoint
from pathlib import Path

checkpoint_config = Checkpoint(
    directory=Path("./checkpoints"),
    save_best=True,
)

trainer = ClassificationTrainer(
    model=model,
    loss_fn=loss_fn,
    optim=optimizer,
    num_classes=10,
    early_stopping_patience=10,
    checkpoint_config=checkpoint_config
)

trainer.fit(
    train_loader=train_loader,
    epochs=200,
    val_loader=val_loader
)

# Model automatically contains best weights
# Checkpoint also saved as checkpoint_best.pt
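The "save best" behavior can be pictured with the following sketch: whenever the monitored loss improves, a snapshot of the model weights is kept, so the best weights survive even if later epochs get worse. This is an illustrative re-implementation of the idea, not fenn's checkpoint code; `model_state` stands in for a real state dict:

```python
import copy

def track_best_weights(model_state, losses):
    """Keep a deep copy of the weights from the epoch with the
    lowest loss. Returns (best_state, best_epoch, best_loss)."""
    best_loss = float("inf")
    best_state = None
    best_epoch = None
    for epoch, loss in enumerate(losses, start=1):
        model_state["epoch"] = epoch  # stand-in for a training update
        if loss < best_loss:
            best_loss = loss
            best_epoch = epoch
            # deepcopy so later training updates don't mutate the snapshot
            best_state = copy.deepcopy(model_state)
    return best_state, best_epoch, best_loss

state, epoch, loss = track_best_weights({"w": 0}, [0.5, 0.3, 0.4])
print(epoch, loss)  # 2 0.3
```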

Aggressive vs. Patient Early Stopping

# Aggressive: Stop quickly if not improving
aggressive_trainer = ClassificationTrainer(
    ...,
    early_stopping_patience=3  # Stop after 3 epochs without improvement
)

# Patient: Give model more time to improve
patient_trainer = ClassificationTrainer(
    ...,
    early_stopping_patience=20  # Stop after 20 epochs without improvement
)

Tips

  • Start with 5-10 epochs: Good default for most problems
  • Increase for noisy validation: If validation loss is volatile, use higher patience
  • Decrease for clear improvement: If validation quickly plateaus, use lower patience
  • Always use validation data: Early stopping based on validation is more reliable

Training State Management

The trainer automatically manages training state through a TrainingState object.

What Gets Tracked

The TrainingState object tracks:

TrainingState(
    epoch=10,                          # Current epoch
    train_loss=0.245,                  # Training loss
    val_loss=0.312,                    # Validation loss (if applicable)
    best_train_loss=0.240,             # Best training loss
    best_val_loss=0.305,               # Best validation loss
    best_epoch=8,                      # Epoch with best validation loss
    patience_counter=2,                # Epochs without improvement
    model_state_dict={...},            # Model weights
    optimizer_state_dict={...},        # Optimizer state
    acc=0.94,                          # Accuracy (classification)
    best_acc=0.95,                     # Best accuracy
    # ... other metrics
)
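Conceptually, the state behaves like the dataclass sketched below. This is an illustrative approximation of the fields listed above, not fenn's actual TrainingState definition; the `update` method shows how the best-loss and patience fields relate:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class TrainingStateSketch:
    """Illustrative approximation of the tracked fields."""
    epoch: int = 0
    train_loss: float = float("inf")
    val_loss: Optional[float] = None
    best_train_loss: float = float("inf")
    best_val_loss: float = float("inf")
    best_epoch: int = 0
    patience_counter: int = 0
    model_state_dict: dict = field(default_factory=dict)
    optimizer_state_dict: dict = field(default_factory=dict)
    acc: Optional[float] = None
    best_acc: float = 0.0

    def record_validation(self, epoch: int, val_loss: float) -> None:
        """Record one epoch's validation loss and update best/patience."""
        self.epoch = epoch
        self.val_loss = val_loss
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.best_epoch = epoch
            self.patience_counter = 0
        else:
            self.patience_counter += 1
```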

Accessing Training State

Training state is automatically saved to checkpoints. To access it:

# After training
trainer.fit(train_loader, epochs=100)

# State is in trainer._state
print(f"Final epoch: {trainer._state.epoch}")
print(f"Best validation loss: {trainer._state.best_val_loss}")
print(f"Best accuracy: {trainer._state.best_acc}")

State Restoration

When loading checkpoints, state is automatically restored:

trainer.load_checkpoint_at_epoch(50)
# Trainer state now matches epoch 50

print(f"Current epoch: {trainer._state.epoch}")  # 50
print(f"Current loss: {trainer._state.val_loss}")  # Loss at epoch 50

Progress Tracking

Training progress is displayed using a rich progress bar:

Without a validation loader:

Epoch 45/100 [████████████░░░░░░] 45% | Train Mean Loss: 0.3245

With a validation loader:

Epoch 45/100 [████████████░░░░░░] 45% | Train Loss: 0.3245 | Val Loss: 0.3512 | Val Acc: 0.9421

Information Displayed

  • Epoch counter: Current epoch / total epochs
  • Progress bar: Visual representation of training progress
  • Time elapsed: Time spent training so far
  • Training loss: Average loss on training set
  • Validation loss: Loss on validation set (if applicable)
  • Validation accuracy: Accuracy on validation set (for classification)

Validation Metrics

Different trainers calculate different validation metrics.

ClassificationTrainer Metrics

During validation, calculates:

Metric      Description                                          Range
Loss        Classification loss function value                   0 to ∞
Accuracy    Proportion of correct predictions                    0 to 1
Precision   Proportion of positive predictions that are correct  0 to 1
Recall      Proportion of actual positives that are found        0 to 1
F1 Score    Harmonic mean of precision and recall                0 to 1

Example output:

Epoch 10/100 - Train Loss: 0.3245 | Val Loss: 0.3102 | Val Acc: 0.9512

RegressionTrainer Metrics

During validation, calculates:

Metric      Description                      Range
Loss        Regression loss function value   0 to ∞
R² Score    Coefficient of determination     -∞ to 1

Example output:

Epoch 10/100 - Train Loss: 0.0245 | Val Loss: 0.0312 | R² Score: 0.8945

Understanding Metrics

Accuracy

  • What it means: Proportion of correct predictions
  • When it matters: Good overall metric for balanced datasets
  • Limitation: Can be misleading with imbalanced data
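To see the limitation concretely (with made-up illustrative numbers): on a 95/5 class split, a model that always predicts the majority class scores 95% accuracy while never finding a single positive.

```python
def accuracy(preds, labels):
    """Fraction of predictions that match the labels."""
    return sum(p == l for p, l in zip(preds, labels)) / len(labels)

# 95 negatives, 5 positives; the model always predicts "negative" (0)
labels = [0] * 95 + [1] * 5
preds = [0] * 100
print(accuracy(preds, labels))  # 0.95 -- looks great, but recall is 0
```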

Precision

  • What it means: Of all positive predictions, how many were correct?
  • Formula: TP / (TP + FP)
  • Use when: False positives are costly (e.g., cancer diagnosis)

Recall

  • What it means: Of all true positives, how many did we find?
  • Formula: TP / (TP + FN)
  • Use when: False negatives are costly (e.g., disease detection)

F1 Score

  • What it means: Balanced harmonic mean of precision and recall
  • Formula: 2 * (precision * recall) / (precision + recall)
  • Use when: You want to balance precision and recall

R² Score

  • What it means: Proportion of variance explained by model
  • Range: -∞ to 1 (1 is perfect, 0 is baseline, <0 is worse than baseline)
  • Use when: Evaluating regression models
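The formulas above translate directly into code. The following sketch computes each metric from counts of true positives (TP), false positives (FP), and false negatives (FN), plus R² from raw predictions; the example counts are made up for illustration:

```python
def precision(tp, fp):
    """TP / (TP + FP): how many positive predictions were correct."""
    return tp / (tp + fp)

def recall(tp, fn):
    """TP / (TP + FN): how many actual positives were found."""
    return tp / (tp + fn)

def f1(p, r):
    """Harmonic mean of precision and recall."""
    return 2 * p * r / (p + r)

def r2_score(y_true, y_pred):
    """R² = 1 - SS_res / SS_tot (1 perfect, 0 baseline, <0 worse)."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot

# Worked example: 8 true positives, 2 false positives, 4 false negatives
p = precision(8, 2)   # 0.8
r = recall(8, 4)      # ~0.667
print(round(f1(p, r), 3))  # 0.727
```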

Complete Example: Early Stopping with All Features

import torch
from torch import nn
from fenn.nn.trainers import ClassificationTrainer
from fenn.nn.utils import Checkpoint
from pathlib import Path

# Setup
model = ResNet50(num_classes=10)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Checkpoint configuration with best model saving
checkpoint_config = Checkpoint(
    directory=Path("./checkpoints"),
    save_best=True,
)

# Create trainer with early stopping
trainer = ClassificationTrainer(
    model=model,
    loss_fn=loss_fn,
    optim=optimizer,
    num_classes=10,
    device="cuda",
    checkpoint_config=checkpoint_config,
    early_stopping_patience=10  # Stop after 10 epochs without improvement
)

# Train with validation every epoch
trainer.fit(
    train_loader=train_dataloader,
    epochs=200,  # May stop early
    val_loader=val_dataloader,
    val_epochs=1
)

# Access final state
print(f"Stopped at epoch: {trainer._state.epoch}")
print(f"Best validation loss: {trainer._state.best_val_loss:.4f}")
print(f"Best accuracy: {trainer._state.best_acc:.4f}")

# Get predictions (using best model)
predictions = trainer.predict(test_dataloader)

See Also