Understanding Callbacks
Training Control Through Callbacks
Section 3.9 - What Are Callbacks?
In programming, callbacks are functions passed as arguments to other functions, to be executed at specific points; a minimal plain-Python example follows the list below. In deep learning, callbacks allow us to:

- Monitor training progress
- Save model checkpoints
- Adjust training parameters
- Stop training when needed
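Here is that plain-Python sketch of the idea; the function and hook names (train_one_epoch, on_epoch_end, report) are purely illustrative and not part of any library:

def train_one_epoch(data, on_epoch_end):
    loss = sum(data) / len(data)   # stand-in for the real training work
    on_epoch_end(loss)             # invoke the callback at a defined point

def report(loss):
    print(f"epoch finished, loss = {loss:.3f}")

train_one_epoch([0.9, 0.7, 0.5], on_epoch_end=report)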
Basic Callback Structure
from tensorflow import keras

class CustomCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        # Called at the start of each epoch
        pass

    def on_epoch_end(self, epoch, logs=None):
        # Called at the end of each epoch
        pass

    def on_batch_begin(self, batch, logs=None):
        # Called at the start of each batch
        pass

    def on_batch_end(self, batch, logs=None):
        # Called at the end of each batch
        pass
Section 3.10 - Essential Callbacks
1. ModelCheckpoint
Saves model weights during training:
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    'best_model.h5',
    save_best_only=True,   # Only save when the model improves
    monitor='val_loss',    # Metric to monitor
    mode='min'             # Lower is better
)
Use cases:
- Save the best model during training
- Resume training from checkpoints (see the sketch below)
- Ensemble multiple checkpoints
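As a quick illustration of resuming, here is a minimal sketch that reloads the saved checkpoint and continues training. It assumes the checkpoint_cb defined above saved the full model (the default) to 'best_model.h5' and that X_train/y_train are available:

model = keras.models.load_model('best_model.h5')  # restores architecture, weights and optimizer state
model.fit(
    X_train, y_train,
    validation_split=0.2,
    epochs=50,                 # continue for additional epochs
    callbacks=[checkpoint_cb]
)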
2. EarlyStopping
Stops training when model stops improving:
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',          # Metric to watch
    patience=10,                 # Number of epochs to wait
    restore_best_weights=True    # Restore best model
)
Benefits:
- Prevents overfitting
- Saves computation time
- Automatically selects the best epoch (see the check below)
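To see which epoch was selected, you can inspect the callback object after training. The stopped_epoch and best attributes shown here are part of tf.keras's EarlyStopping; the model and data names match the earlier examples:

history = model.fit(
    X_train, y_train,
    validation_split=0.2,
    epochs=100,
    callbacks=[early_stopping]
)

print("Stopped at epoch:", early_stopping.stopped_epoch)  # 0 if training never stopped early
print("Best monitored value:", early_stopping.best)       # best val_loss observed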
3. ReduceLROnPlateau
Adjusts learning rate when progress stalls:
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,    # Multiply LR by this factor
    patience=5,    # Epochs to wait
    min_lr=1e-6    # Minimum LR allowed
)
Operation:
- Monitors a validation metric
- Reduces the learning rate when progress stalls
- Helps fine-tune convergence (the sketch below makes the reductions visible)
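To watch the reductions happen, a small companion callback (not built into Keras; the name LRPrinter is ours) can print the current learning rate each epoch. This sketch assumes TF 2.x-style tf.keras, where the learning rate is readable through the backend:

from tensorflow.keras import backend as K

class LRPrinter(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Read the optimizer's current learning rate and print it
        lr = K.get_value(self.model.optimizer.lr)
        print(f"epoch {epoch}: learning rate = {lr:.2e}")

model.fit(
    X_train, y_train,
    validation_split=0.2,
    epochs=100,
    callbacks=[reduce_lr, LRPrinter()]
)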
Section 3.11 - Advanced Callbacks
1. Custom Learning Rate Scheduler
import numpy as np
from tensorflow.keras import backend as K

class CosineAnnealingCallback(keras.callbacks.Callback):
    def __init__(self, total_epochs, min_lr=1e-6):
        super().__init__()
        self.total_epochs = total_epochs
        self.min_lr = min_lr
        self.initial_lr = None

    def on_train_begin(self, logs=None):
        # Record the starting learning rate once training begins
        self.initial_lr = K.get_value(self.model.optimizer.lr)

    def on_epoch_begin(self, epoch, logs=None):
        # Cosine annealing formula
        progress = epoch / self.total_epochs
        cosine = 0.5 * (1 + np.cos(np.pi * progress))
        new_lr = self.min_lr + (self.initial_lr - self.min_lr) * cosine
        K.set_value(self.model.optimizer.lr, new_lr)
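Typical usage, assuming a compiled model and the same epoch count passed to both the callback and fit:

cosine_cb = CosineAnnealingCallback(total_epochs=100, min_lr=1e-6)
model.fit(
    X_train, y_train,
    validation_split=0.2,
    epochs=100,
    callbacks=[cosine_cb]
)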
2. Training Progress Logger
class MetricsLogger(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Collect metrics to log to a file or monitoring system
        metrics = {
            'epoch': epoch,
            'loss': logs.get('loss'),
            'val_loss': logs.get('val_loss'),
            'lr': K.get_value(self.model.optimizer.lr)
        }
        self.log_metrics(metrics)

    def log_metrics(self, metrics):
        # Replace with your own file or monitoring-system output
        print(metrics)
3. Gradient Monitor
class GradientMonitor(keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        # get_gradients() is a placeholder for however you gather gradient values
        gradients = self.get_gradients()
        if np.any(np.isnan(gradients)):
            print("Warning: NaN gradients detected")
Section 3.12 - Practical Applications
Complete Training Setup
callbacks = [
    # Save best model
    keras.callbacks.ModelCheckpoint(
        'best_model.h5',
        save_best_only=True,
        monitor='val_loss'
    ),
    # Early stopping
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    ),
    # Learning rate reduction
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5
    ),
    # Custom logging
    MetricsLogger()
]

# Use in training
model.fit(
    X_train, y_train,
    epochs=100,
    validation_split=0.2,
    callbacks=callbacks
)
A few practices to keep in mind when combining callbacks:

- Callback Order:
  - Monitoring callbacks first
  - LR schedulers next
  - Early stopping last
- Resource Management:
  - Use appropriate file paths
  - Clean up old checkpoints
  - Monitor memory usage
- Error Handling:
  - Catch and log exceptions
  - Implement graceful stopping
  - Save progress on interrupts (see the sketch below)
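For example, a minimal sketch of saving progress on a manual interrupt (the file name 'interrupted_model.h5' is illustrative):

try:
    model.fit(
        X_train, y_train,
        epochs=100,
        validation_split=0.2,
        callbacks=callbacks
    )
except KeyboardInterrupt:
    # Save whatever the model has learned so far before exiting
    print("Training interrupted; saving current weights")
    model.save('interrupted_model.h5')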
Section 3.13 - Common Use Cases
1. Research and Development
# Experimental setup
callbacks = [
    # Save frequent checkpoints
    ModelCheckpoint('model_{epoch:02d}.h5', save_freq='epoch'),
    # Detailed logging
    TensorBoard(log_dir='./logs'),
    # Multiple metrics monitoring
    EarlyStopping(monitor='val_loss', patience=10),
    EarlyStopping(monitor='val_accuracy', patience=15)
]
2. Production Training
# Production setup
callbacks = [
    # Save best model only
    ModelCheckpoint('best_model.h5', save_best_only=True),
    # Conservative early stopping
    EarlyStopping(patience=20),
    # Gradual LR reduction
    ReduceLROnPlateau(factor=0.2, patience=10)
]
Remember that callbacks can significantly impact training time and resource usage. Choose and configure them based on your specific needs and constraints.