Understanding Callbacks
Training Control Through Callbacks
Section 3.9 - What Are Callbacks?
In programming, callbacks are functions passed as arguments to other functions, to be executed at specific points. In deep learning, callbacks allow us to:
- Monitor training progress
- Save model checkpoints
- Adjust training parameters
- Stop training when needed
Basic Callback Structure
from tensorflow import keras   # assumes the TensorFlow-bundled Keras

class CustomCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        # Called at the start of each epoch
        pass

    def on_epoch_end(self, epoch, logs=None):
        # Called at the end of each epoch
        pass

    def on_batch_begin(self, batch, logs=None):
        # Called at the start of each batch
        pass

    def on_batch_end(self, batch, logs=None):
        # Called at the end of each batch
        pass
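Once defined, a custom callback is simply passed to fit() through the callbacks argument. A minimal end-to-end sketch, assuming a small toy model and random data (the model, data, and callback names here are illustrative, not from the original text):

import numpy as np
from tensorflow import keras

# Toy data: 100 samples, 10 features, binary labels
X = np.random.rand(100, 10)
y = np.random.randint(0, 2, size=(100,))

model = keras.Sequential([
    keras.Input(shape=(10,)),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy')

class PrintLossCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # logs carries the metrics computed for this epoch
        print(f"Epoch {epoch}: loss = {logs['loss']:.4f}")

model.fit(X, y, epochs=3, callbacks=[PrintLossCallback()], verbose=0)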
Section 3.10 - Essential Callbacks
1. ModelCheckpoint
Saves model weights during training:
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    'best_model.h5',
    save_best_only=True,   # Only save when the monitored metric improves
    monitor='val_loss',    # Metric to monitor
    mode='min'             # Lower is better
)

Use cases:
- Save the best model during training
- Resume training from checkpoints (see the sketch below)
- Ensemble multiple checkpoints
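Resuming from a checkpoint is straightforward because ModelCheckpoint (with the default save_weights_only=False) writes the full model. A minimal sketch, assuming the file name from the callback above and that the previous run stopped around epoch 50 (the epoch number is illustrative):

# Restore architecture, weights, and optimizer state from the checkpoint
model = keras.models.load_model('best_model.h5')

# Continue training; initial_epoch keeps the epoch numbering consistent
model.fit(
    X_train, y_train,
    validation_split=0.2,
    initial_epoch=50,   # illustrative: where the previous run stopped
    epochs=100,
    callbacks=[checkpoint_cb]
)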
2. EarlyStopping
Stops training when model stops improving:
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',          # Metric to watch
    patience=10,                 # Number of epochs to wait without improvement
    restore_best_weights=True    # Restore the best weights when stopping
)

Benefits:
- Prevents overfitting
- Saves computation time
- Automatically selects the best epoch (see the sketch below)
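After fit() returns, the callback instance records whether and when it fired. A short sketch, assuming early_stopping was passed to model.fit along with validation data:

history = model.fit(
    X_train, y_train,
    validation_split=0.2,
    epochs=100,
    callbacks=[early_stopping]
)

# stopped_epoch is 0 if training ran all epochs, otherwise the epoch at
# which early stopping triggered
print("Stopped at epoch:", early_stopping.stopped_epoch)

# With restore_best_weights=True the model now holds the weights from the
# epoch with the lowest val_loss, not from the final epoch
print("Best val_loss:", min(history.history['val_loss']))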
3. ReduceLROnPlateau
Adjusts learning rate when progress stalls:
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,      # Multiply the LR by this factor
    patience=5,      # Epochs to wait before reducing
    min_lr=1e-6      # Minimum LR allowed
)

Operation:
- Monitors a validation metric
- Reduces the learning rate when progress stalls
- Helps fine-tune convergence (see the decay sketch below)
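Each time a plateau is detected, the current learning rate is multiplied by factor, but it is never pushed below min_lr. A back-of-the-envelope sketch of that decay, assuming an illustrative starting rate of 1e-3 and the settings above:

initial_lr = 1e-3   # illustrative starting learning rate
factor, min_lr = 0.5, 1e-6

lr = initial_lr
for reduction in range(1, 11):
    lr = max(lr * factor, min_lr)
    print(f"after reduction {reduction:2d}: lr = {lr:.2e}")
# The floor is reached on the 10th reduction (1e-3 * 0.5**10 is about 9.8e-7),
# and with patience=5 that takes at least 50 stalled epochs.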
Section 3.11 - Advanced Callbacks
1. Custom Learning Rate Scheduler
import numpy as np
from tensorflow.keras import backend as K

class CosineAnnealingCallback(keras.callbacks.Callback):
    def __init__(self, total_epochs, min_lr=1e-6):
        super().__init__()
        self.total_epochs = total_epochs
        self.min_lr = min_lr

    def on_train_begin(self, logs=None):
        # Capture the optimizer's starting LR as the annealing maximum
        self.initial_lr = float(K.get_value(self.model.optimizer.lr))

    def on_epoch_begin(self, epoch, logs=None):
        # Cosine annealing: decay smoothly from initial_lr down to min_lr
        progress = epoch / self.total_epochs
        cosine = 0.5 * (1 + np.cos(np.pi * progress))
        new_lr = self.min_lr + (self.initial_lr - self.min_lr) * cosine
        K.set_value(self.model.optimizer.lr, new_lr)
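Because the callback reads its annealing maximum from the optimizer when training starts, total_epochs should match the epochs argument given to fit. A minimal usage sketch, assuming model, X_train, and y_train are defined as in the earlier examples:

cosine_cb = CosineAnnealingCallback(total_epochs=50, min_lr=1e-6)

model.fit(
    X_train, y_train,
    validation_split=0.2,
    epochs=50,             # should match total_epochs above
    callbacks=[cosine_cb]
)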
2. Training Progress Logger
import json

class MetricsLogger(keras.callbacks.Callback):
    def __init__(self, log_path='metrics_log.jsonl'):
        super().__init__()
        self.log_path = log_path   # destination file (example path)

    def on_epoch_end(self, epoch, logs=None):
        # Log metrics to a file, one JSON record per epoch
        logs = logs or {}
        metrics = {
            'epoch': epoch,
            'loss': logs.get('loss'),
            'val_loss': logs.get('val_loss'),
            'lr': float(K.get_value(self.model.optimizer.lr))
        }
        with open(self.log_path, 'a') as f:
            f.write(json.dumps(metrics) + '\n')
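The resulting file can be read back after training for plotting or run-to-run comparison. A small sketch, assuming the default metrics_log.jsonl path used above:

import json

with open('metrics_log.jsonl') as f:
    records = [json.loads(line) for line in f]

val_losses = [r['val_loss'] for r in records if r['val_loss'] is not None]
print("Best epoch by val_loss:", val_losses.index(min(val_losses)))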
3. Gradient Monitor
class GradientMonitor(keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        # Callbacks do not see gradients directly; checking the updated
        # weights for NaNs is a practical proxy for NaN/exploding gradients
        for weights in self.model.get_weights():
            if np.any(np.isnan(weights)):
                print(f"Warning: NaN weights detected at batch {batch}")
                break
Section 3.12 - Practical Applications
Complete Training Setup
callbacks = [
    # Save best model
    keras.callbacks.ModelCheckpoint(
        'best_model.h5',
        save_best_only=True,
        monitor='val_loss'
    ),
    # Early stopping
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    ),
    # Learning rate reduction
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5
    ),
    # Custom logging
    MetricsLogger()
]

# Use in training
model.fit(
    X_train, y_train,
    epochs=100,
    validation_split=0.2,
    callbacks=callbacks
)

A few practical guidelines when combining callbacks:
- Callback Order:
  - Monitoring callbacks first
  - LR schedulers next
  - Early stopping last
- Resource Management:
  - Use appropriate file paths
  - Clean up old checkpoints
  - Monitor memory usage
- Error Handling:
  - Catch and log exceptions
  - Implement graceful stopping
  - Save progress on interrupts (see the sketch below)
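One simple way to save progress on interrupts is to wrap fit() in a try/except so that a manual interruption still leaves a usable model on disk. A minimal sketch, assuming model and the training arrays from the setup above (the output file name is illustrative):

try:
    model.fit(
        X_train, y_train,
        epochs=100,
        validation_split=0.2,
        callbacks=callbacks
    )
except KeyboardInterrupt:
    # Ctrl+C during training: persist the current state before exiting
    model.save('interrupted_model.h5')
    print("Training interrupted; current weights saved to interrupted_model.h5")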
Section 3.13 - Common Use Cases
1. Research and Development
# Experimental setup
from tensorflow.keras.callbacks import (ModelCheckpoint, TensorBoard,
                                        EarlyStopping, ReduceLROnPlateau)

callbacks = [
    # Save a checkpoint every epoch
    ModelCheckpoint('model_{epoch:02d}.h5',
                    save_freq='epoch'),
    # Detailed logging
    TensorBoard(log_dir='./logs'),
    # Multiple metrics monitoring
    EarlyStopping(monitor='val_loss', patience=10),
    EarlyStopping(monitor='val_accuracy', patience=15)
]

2. Production Training
# Production setup
callbacks = [
    # Save the best model only
    ModelCheckpoint('best_model.h5',
                    save_best_only=True),
    # Conservative early stopping
    EarlyStopping(patience=20),
    # Gradual LR reduction
    ReduceLROnPlateau(factor=0.2,
                      patience=10)
]

Remember that callbacks can significantly impact training time and resource usage. Choose and configure them based on your specific needs and constraints.