Understanding the Training Loop
The Training Loop: How Models Learn
Section 3.1 - What Happens During Training?
When we call model.fit() in Keras, we start a process that repeatedly updates the model’s weights to minimize the loss function. Let’s look at what happens under the hood.
The Basic Loop Structure
At its core, training follows this pattern:
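# Conceptual implementation of the training loop
import tensorflow as tf
for epoch in range(n_epochs):
    for batch_idx in range(n_batches):
        X_batch, y_batch = get_batch(batch_idx, batch_size)
        # Record the forward pass so TensorFlow can differentiate it
        with tf.GradientTape() as tape:
            y_pred = model(X_batch, training=True)
            loss = loss_function(y_batch, y_pred)
        # Backward pass: gradient of the loss with respect to each trainable weight
        gradients = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))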
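To make the same pattern concrete without TensorFlow, here is a minimal plain-Python sketch. It fits a one-feature linear model y ≈ w * x + b with mini-batch gradient descent. The toy data, `n_epochs`, `batch_size`, and `learning_rate` values are illustrative assumptions. The gradients are derived by hand for mean squared error rather than taken from a GradientTape.

```python
import random

def train_linear(X, y, n_epochs=200, batch_size=4, learning_rate=0.05, seed=0):
    """Fit y ~ w * x + b with mini-batch gradient descent on mean squared error."""
    rng = random.Random(seed)
    w, b = rng.uniform(-1, 1), 0.0
    n = len(X)
    n_batches = (n + batch_size - 1) // batch_size  # ceiling division
    for epoch in range(n_epochs):
        for batch_idx in range(n_batches):
            # 1. Batch selection
            start = batch_idx * batch_size
            X_batch, y_batch = X[start:start + batch_size], y[start:start + batch_size]
            m = len(X_batch)
            # 2. Forward pass
            y_pred = [w * x + b for x in X_batch]
            # 3. Loss (MSE), computed here only for monitoring
            loss = sum((p - t) ** 2 for p, t in zip(y_pred, y_batch)) / m
            # 4. Gradients of the MSE with respect to w and b (chain rule by hand)
            grad_w = sum(2 * (p - t) * x for p, t, x in zip(y_pred, y_batch, X_batch)) / m
            grad_b = sum(2 * (p - t) for p, t in zip(y_pred, y_batch)) / m
            # 5. Weight update (plain gradient descent)
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b
    return w, b, loss

# Toy data generated from y = 2x + 1 (no noise)
X = [i / 10 for i in range(20)]
y = [2 * x + 1 for x in X]
w, b, final_loss = train_linear(X, y)
print(f"w={w:.3f}, b={b:.3f}, last batch loss={final_loss:.6f}")
```

Each numbered comment maps onto one step of the conceptual loop. A framework simply automates step 4 and hides step 5 behind the optimizer.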
Section 3.2 - Step-by-Step Breakdown
1. Batch Selection
def get_batch(idx, batch_size):
    # X and y are the full training arrays
    start_idx = idx * batch_size
    end_idx = start_idx + batch_size
    return X[start_idx:end_idx], y[start_idx:end_idx]
This step:
- Selects a subset of the training data
- Provides manageable chunks for processing
- Enables (mini-batch) stochastic gradient descent
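While full-batch gradient descent uses all of the data for every update, mini-batch training offers:
- Often better generalization, which is commonly attributed to the noise in mini-batch gradients
- Lower memory requirements, since only one batch is processed at a time
- Faster iterations, since each update needs only a fraction of the data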
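One detail of `get_batch` is worth checking: Python slicing clamps an out-of-range end index. So when the dataset size is not a multiple of the batch size, the last batch is simply smaller. The sketch below uses plain lists and a hypothetical `iterate_batches` helper. Its optional per-epoch shuffle is a common practice and an assumption here, since the `get_batch` shown above does not shuffle.

```python
import math
import random

def iterate_batches(X, y, batch_size, shuffle=False, seed=None):
    """Yield (X_batch, y_batch) pairs covering the dataset once (one epoch)."""
    indices = list(range(len(X)))
    if shuffle:
        # Reorder samples so batch composition differs between epochs
        random.Random(seed).shuffle(indices)
    n_batches = math.ceil(len(X) / batch_size)
    for batch_idx in range(n_batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size  # may exceed len(X); slicing clamps it
        batch_indices = indices[start_idx:end_idx]
        yield [X[i] for i in batch_indices], [y[i] for i in batch_indices]

X = list(range(10))
y = [2 * v for v in X]
sizes = [len(xb) for xb, _ in iterate_batches(X, y, batch_size=4)]
print(sizes)  # [4, 4, 2]
```

Shuffling indices rather than the arrays themselves keeps each input paired with its label.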
2. Forward Pass
# Inside the model's call method
def call(self, inputs):
    # Layer 1
    x = self.dense1(inputs)
    x = self.activation1(x)
    # Layer 2
    x = self.dense2(x)
    x = self.activation2(x)
    # Output layer
    return self.output_layer(x)
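During this phase:
- Data flows through the network, layer by layer
- Each layer applies its computation (for example, a dense transform followed by an activation)
- Intermediate activations are kept (recorded by GradientTape in TensorFlow) for backpropagation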
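The forward pass can be mimicked in plain Python for a tiny network. The sketch below assumes a dense network with one ReLU hidden layer and a linear output, using hand-picked weights (all values are illustrative). It keeps every intermediate activation in a list, which is the kind of information backpropagation needs later.

```python
def dense(x, W, b):
    """Compute W @ x + b for a single sample, with W given as a list of rows."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i for row, b_i in zip(W, b)]

def relu(v):
    return [max(0.0, z) for z in v]

def forward(x, params):
    """Run the input through each (W, b) layer; return the output and cached activations."""
    activations = [x]  # cache the input and every layer's output
    for i, (W, b) in enumerate(params):
        z = dense(activations[-1], W, b)
        # ReLU on hidden layers, identity on the output layer
        a = relu(z) if i < len(params) - 1 else z
        activations.append(a)
    return activations[-1], activations

params = [
    ([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]),  # hidden layer: 2 inputs -> 2 units
    ([[1.0, 2.0]], [0.5]),                    # output layer: 2 units -> 1 output
]
output, cache = forward([2.0, 1.0], params)
print(output)  # [4.5]
```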
3. Loss Computation
import tensorflow as tf
def compute_loss(y_true, y_pred):
    # Example: mean squared error
    return tf.reduce_mean(tf.square(y_true - y_pred))
The loss function:
- Measures prediction error as a single scalar
- Provides the optimization target
- Guides the weight updates through its gradients
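As a worked example of mean squared error, here is the same formula in plain Python on three made-up predictions. The errors are -0.5, 0, and 1, their squares are 0.25, 0, and 1, and the mean is 1.25 / 3 ≈ 0.4167.

```python
def mean_squared_error(y_true, y_pred):
    """Average of squared differences, like tf.reduce_mean(tf.square(y_true - y_pred)) on 1-D data."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

loss = mean_squared_error([1.0, 2.0, 3.0], [1.5, 2.0, 2.0])
print(round(loss, 4))  # 0.4167
```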
4. Gradient Computation
def compute_gradients(tape, loss, weights):
    # Automatic differentiation
    gradients = tape.gradient(loss, weights)
    return gradients
During backpropagation:
- Gradients flow backward through the network
- The chain rule is applied automatically
- Each weight’s contribution to the loss is calculated
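To see what the automatically applied chain rule computes, take a single weight in the model y_pred = w * x + b with squared-error loss L = (y_pred - y)^2. By the chain rule, dL/dw = dL/dy_pred * dy_pred/dw = 2 * (y_pred - y) * x.

The plain-Python sketch below uses illustrative values. It checks that hand-derived gradient against a central finite-difference estimate, which is a common way to sanity-check gradients.

```python
def loss_fn(w, b, x, y):
    y_pred = w * x + b
    return (y_pred - y) ** 2

def analytic_grad_w(w, b, x, y):
    # Chain rule: dL/dy_pred = 2 * (y_pred - y) and dy_pred/dw = x
    y_pred = w * x + b
    return 2 * (y_pred - y) * x

def numeric_grad_w(w, b, x, y, eps=1e-6):
    # Central finite-difference approximation of dL/dw
    return (loss_fn(w + eps, b, x, y) - loss_fn(w - eps, b, x, y)) / (2 * eps)

w, b, x, y = 0.5, 0.25, 2.0, 2.0
analytic = analytic_grad_w(w, b, x, y)  # 2 * (1.25 - 2.0) * 2.0 = -3.0
numeric = numeric_grad_w(w, b, x, y)    # should agree closely with the analytic value
print(analytic, numeric)
```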
5. Weight Updates
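def apply_updates(gradients, weights, learning_rate=0.01):
    # Basic gradient descent update: w <- w - learning_rate * g
    # (optimizer.apply_gradients would apply the optimizer's own rule instead)
    for g, w in zip(gradients, weights):
        w.assign_sub(learning_rate * g)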
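The update step:
- Moves each weight in the direction opposite to its gradient
- Scales each step by the learning rate
- Follows the optimizer’s update rule (plain gradient descent above; optimizers such as Adam also track running averages of past gradients)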
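The effect of the learning rate in the update w <- w - learning_rate * g is easy to see on the one-parameter loss L(w) = w^2, whose gradient is 2w. Each step multiplies w by (1 - 2 * learning_rate). A small rate therefore shrinks w toward the minimum at 0, while a rate above 1 makes |w| grow with alternating sign. The values below are illustrative.

```python
def gradient_descent(w0, learning_rate, steps):
    """Minimize L(w) = w**2 with plain gradient descent; return the trajectory."""
    w = w0
    history = [w]
    for _ in range(steps):
        grad = 2 * w  # dL/dw for L(w) = w**2
        w = w - learning_rate * grad
        history.append(w)
    return history

small = gradient_descent(1.0, learning_rate=0.1, steps=20)
large = gradient_descent(1.0, learning_rate=1.1, steps=20)
print(f"lr=0.1 -> w={small[-1]:.5f}")  # shrinks toward 0 (factor 0.8 per step)
print(f"lr=1.1 -> w={large[-1]:.1f}")  # grows in magnitude (factor -1.2 per step)
```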
Section 3.3 - Memory Management During Training
Forward Pass Storage
During the forward pass, we need to store:
1. Layer inputs for gradient computation
2. Intermediate activations
3. Final outputs for loss calculation
Memory usage scales with:
- Batch size
- Model depth
- Layer sizes
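A back-of-the-envelope estimate makes this scaling concrete. Assume a plain stack of dense layers whose activations are stored as 32-bit floats (4 bytes per value). The activations kept for one batch then take roughly batch_size * (sum of layer widths) * 4 bytes.

The layer sizes below are hypothetical. Real frameworks add their own overhead and may keep more or fewer intermediates, so treat the result as a rough estimate only.

```python
def activation_bytes(batch_size, layer_sizes, bytes_per_value=4):
    """Rough memory for storing every layer's activations (input included) for one batch."""
    values_per_sample = sum(layer_sizes)
    return batch_size * values_per_sample * bytes_per_value

layer_sizes = [784, 256, 128, 10]  # hypothetical MLP: input, two hidden layers, output
for batch_size in (32, 64, 128):
    print(batch_size, activation_bytes(batch_size, layer_sizes))
# Doubling the batch size doubles the estimate; deeper or wider layers grow it too.
```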
Gradient Computation Requirements
The backward pass requires:
1. Access to forward pass activations
2. Memory for gradient computations
3. Temporary storage for intermediate results
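Unlike activations, the gradient buffers do not depend on batch size, because each trainable weight gets a gradient of the same shape. For a hypothetical dense stack with layer sizes 784, 256, 128, and 10, the number of gradient values equals the parameter count. That count is the sum of in_features * out_features + out_features over the layers.

The sketch assumes float32 storage. Optimizer state, such as Adam's moment estimates, would add more memory on top.

```python
def dense_param_count(layer_sizes):
    """Weights plus biases for a stack of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

layer_sizes = [784, 256, 128, 10]  # hypothetical MLP
n_params = dense_param_count(layer_sizes)
gradient_bytes = n_params * 4  # one float32 gradient per parameter
print(n_params, gradient_bytes)  # 235146 940584
```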
Section 3.4 - Training Loop Variations
Basic Training
# Standard training loop
for epoch in range(epochs):
    for batch in train_dataset:
        train_step(model, batch)
Training with Validation
# Training with validation checks
for epoch in range(epochs):
    # Training
    for batch in train_dataset:
        train_step(model, batch)
    # Validation
    for batch in val_dataset:
        validate_step(model, batch)
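The key difference between the two inner loops is that validation only measures the loss and never updates the weights. The sketch below illustrates that split in plain Python with a one-parameter model y_pred = w * x. Its `train_step` and `validate_step` are hypothetical stand-ins that take and return a plain number rather than a model, and the data and learning rate are illustrative.

```python
def train_step(w, batch, learning_rate=0.1):
    """One gradient-descent update on MSE for the model y_pred = w * x."""
    xs, ys = batch
    grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
    return w - learning_rate * grad

def validate_step(w, batch):
    """Return the MSE on a batch without changing w."""
    xs, ys = batch
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

# Illustrative data following y = 3x
train_dataset = [([1.0, 2.0], [3.0, 6.0]), ([0.5, 1.5], [1.5, 4.5])]
val_dataset = [([3.0], [9.0])]

w = 0.0
val_history = []
for epoch in range(5):
    for batch in train_dataset:
        w = train_step(w, batch)
    # Validation: average loss over validation batches; weights are untouched
    val_loss = sum(validate_step(w, batch) for batch in val_dataset) / len(val_dataset)
    val_history.append(val_loss)
    print(f"epoch {epoch}: w={w:.4f}, val_loss={val_loss:.6f}")
```

Tracking `val_history` across epochs is what makes checks such as early stopping possible.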
Training with Multiple Losses
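# Multiple loss components
import tensorflow as tf
for epoch in range(epochs):
    for batch in train_dataset:
        with tf.GradientTape() as tape:
            # Main task loss
            main_loss = compute_main_loss(model, batch)
            # Regularization loss
            reg_loss = compute_regularization(model)
            # Combined loss
            total_loss = main_loss + reg_loss
        # A single update driven by the gradient of the combined loss
        gradients = tape.gradient(total_loss, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))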
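As a concrete instance of `main_loss + reg_loss`, the sketch below assumes the regularization term is an L2 penalty, l2_lambda * sum(w^2), added to a mean squared error. `compute_total_loss`, `l2_penalty`, and `l2_lambda` are hypothetical names chosen for illustration, and the numbers are made up.

```python
def mse(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def l2_penalty(weights, l2_lambda):
    # L2 regularization: penalize large weights by their squared magnitude
    return l2_lambda * sum(w ** 2 for w in weights)

def compute_total_loss(y_true, y_pred, weights, l2_lambda=0.01):
    main_loss = mse(y_true, y_pred)
    reg_loss = l2_penalty(weights, l2_lambda)
    return main_loss + reg_loss, main_loss, reg_loss

total, main, reg = compute_total_loss([1.0, 0.0], [0.5, 0.0], weights=[0.5, -1.0])
print(total, main, reg)
```

In Keras, penalties attached through a layer's `kernel_regularizer` argument are collected in `model.losses`, which a custom loop typically adds to the main loss.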
The training loop’s structure can be customized for:
- Multiple inputs/outputs
- Custom regularization
- Complex loss functions
- Gradient accumulation
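Gradient accumulation sums gradients over several micro-batches and then applies a single update, which emulates a larger batch when memory is limited. The plain-Python sketch below assumes MSE on the one-parameter model y_pred = w * x and equal-sized micro-batches. Under those assumptions, the averaged accumulated gradient equals the full-batch gradient. The data and `accum_steps` are illustrative.

```python
def grad_mse(w, xs, ys):
    """dL/dw of the mean squared error for y_pred = w * x over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5
accum_steps = 2
micro = len(xs) // accum_steps  # equal-sized micro-batches

accumulated = 0.0
for step in range(accum_steps):
    start = step * micro
    # Gradient on one micro-batch, added to the running sum (no update yet)
    accumulated += grad_mse(w, xs[start:start + micro], ys[start:start + micro])
avg_grad = accumulated / accum_steps  # average so the scale matches one large batch

full_grad = grad_mse(w, xs, ys)
print(avg_grad, full_grad)  # -22.5 -22.5

learning_rate = 0.01
w -= learning_rate * avg_grad  # a single update after all micro-batches
```

With unequal micro-batch sizes, each micro-batch gradient would need to be weighted by its size to match the full-batch gradient.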