Base Trainer Class¶
The abstract Trainer base class provides the core functionality for all trainers in FENN. It is not meant to be used directly—instead, use one of the specialized trainer classes:
Overview¶
The base Trainer class handles:
- Device management and model setup
- Checkpoint configuration and loading
- Early stopping logic
- Training state management
- Model summary logging
Core Methods¶
fit()¶
Executes the complete training cycle. Implemented by concrete trainer classes.
def fit(
train_loader: DataLoader,
epochs: int,
val_loader: Optional[DataLoader] = None,
val_epochs: int = 1
) -> None:
"""Train the model with optional validation and early stopping.
Args:
train_loader: DataLoader for training data
epochs: Total number of epochs to train for
val_loader: DataLoader for validation data (optional)
val_epochs: How often to evaluate on validation set (in epochs)
"""
predict()¶
Performs inference on a dataset or batch. Implemented by concrete trainer classes.
def predict(dataloader_or_batch: Union[DataLoader, torch.Tensor]):
"""Predicts the output of the model for a given dataloader or batch.
Args:
dataloader_or_batch: A DataLoader or a torch tensor.
Returns:
list: A list of predictions.
"""
Life Cycle¶
- Initialization: Model is moved to device, checkpoint config is setup, early stopping is configured
- Training:
fit()trains the model for specified epochs - Checkpointing: Checkpoints are saved according to configuration
- Prediction:
predict()can be called on trained models - Resuming: Use
load_checkpoint()methods to resume from saved states
Internal State Management¶
The trainer maintains a TrainingState object that tracks:
- Current epoch
- Training loss per epoch
- Validation loss per epoch
- Model and optimizer states
- Best validation parameters
- Metrics (accuracy, F1, etc.)
This state is automatically saved to checkpoints and restored when loading.
Checkpoint Methods¶
All trainers inherit checkpoint management methods:
def load_checkpoint(checkpoint_path: Union[str, Path]):
"""Load a checkpoint from a given path."""
def load_checkpoint_at_epoch(epoch: int):
"""Load the checkpoint at the given epoch."""
def load_best_checkpoint():
"""Load the best checkpoint."""
See Checkpoint Management for detailed usage.
Device Management¶
# Model is automatically moved to device during initialization
trainer = ClassificationTrainer(
model=model,
loss_fn=loss_fn,
optim=optimizer,
num_classes=10,
device="cuda" # or "cpu", "mps"
)
The trainer automatically: - Moves the model to the specified device - Moves batches to the device during training - Handles variable batch shapes
Early Stopping¶
Configure early stopping via early_stopping_patience:
trainer = ClassificationTrainer(
model=model,
loss_fn=loss_fn,
optim=optimizer,
num_classes=10,
early_stopping_patience=5 # Stop after 5 epochs without improvement
)
See Advanced Features for detailed early stopping behavior.