The Chain Rule & Backpropagation
🧠 The Theory
AI/ML Concept: Backpropagation (Single Node)
The idea of multiplying derivatives backward through a chain of nested functions is called Backpropagation.
When data flows forward through our network to make a prediction, it is called the Forward Pass. But when we calculate the error, we have to trace that error backward to see exactly which weights are responsible for it. We propagate the error backwards.
- The Loss function says to the prediction: "Hey $\hat{y}$, you were off by this much!" (The Outer Derivative)
- The Prediction function turns around to the weight and says: "Hey $w$, because the input was this size, your portion of the blame is this!" (The Inner Derivative multiplied by the Outer)
By chaining these derivatives together, the neural network can assign exact mathematical blame to every single weight, no matter how many hidden layers deep it is.
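To make the "blame" conversation concrete, here is a minimal sketch of one backward step for the house example used in The Code (x = 2.0, y = 100.0, starting weight w = 1.0). The helper name `backward_step` is hypothetical; it simply computes the outer and inner derivatives separately and multiplies them, exactly as described above.

```python
def backward_step(x: float, y: float, w: float) -> dict:
    """Hypothetical helper: trace blame through y_hat = w * x and L = (y_hat - y) ** 2."""
    # Forward pass: make a prediction and measure the loss.
    y_hat = w * x
    loss = (y_hat - y) ** 2

    # Outer derivative: how the loss reacts to the prediction, dL/dy_hat = 2 * (y_hat - y).
    outer = 2 * (y_hat - y)
    # Inner derivative: how the prediction reacts to the weight, dy_hat/dw = x.
    inner = x
    # Chain rule: the weight's share of the blame is outer * inner.
    grad_w = outer * inner

    return {"y_hat": y_hat, "loss": loss, "outer": outer, "inner": inner, "grad_w": grad_w}


step = backward_step(x=2.0, y=100.0, w=1.0)
print(step)
# {'y_hat': 2.0, 'loss': 9604.0, 'outer': -196.0, 'inner': 2.0, 'grad_w': -392.0}
```

The gradient is negative, which says that increasing w lowers the loss: the model predicted $2k for a $100k house, so the weight needs to go up.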
The Math
Math: The Chain Rule
In machine learning, our math is a chain of functions nested inside each other.
- We calculate a prediction: $\hat{y} = w \cdot x$
- We plug that prediction into our loss function: $L = (\hat{y} - y)^2$
If we substitute the first equation into the second, our full equation is:
$$L(w) = (w \cdot x - y)^2$$
How do we find the derivative of this nested function with respect to our weight ($w$)? We use The Chain Rule.
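The Chain Rule states that the derivative of nested functions is the product of their individual derivatives:
$$\frac{\partial L}{\partial w} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial w}$$
Let's break that down:
- The Outer Derivative ($\frac{\partial L}{\partial \hat{y}}$): How does the prediction affect the loss? Using the power rule on $(\hat{y} - y)^2$, the derivative is $2(\hat{y} - y)$.
- The Inner Derivative ($\frac{\partial \hat{y}}{\partial w}$): How does the weight affect the prediction? The derivative of $\hat{y} = w \cdot x$ with respect to $w$ is just $x$.
Multiply them together, and you have the exact formula for your gradient:
$$\frac{\partial L}{\partial w} = 2(\hat{y} - y) \cdot x$$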
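One way to sanity-check a chain-rule result is to compare it against a numerical estimate. The sketch below uses a central finite difference, $\frac{L(w+h) - L(w-h)}{2h}$, with a step size h = 1e-5 picked as an assumption; the function names are illustrative, not from the original code. Because $L(w)$ is a quadratic in $w$, the central difference is exact apart from floating-point rounding, so the two columns should agree closely.

```python
def loss_of_w(w: float, x: float, y: float) -> float:
    """Full nested loss L(w) = (w * x - y) ** 2."""
    return (w * x - y) ** 2


def analytic_grad(w: float, x: float, y: float) -> float:
    """Chain-rule gradient: outer 2 * (y_hat - y) times inner x."""
    y_hat = w * x
    return 2 * x * (y_hat - y)


def numerical_grad(w: float, x: float, y: float, h: float = 1e-5) -> float:
    """Central finite-difference approximation of dL/dw (h is an assumed step size)."""
    return (loss_of_w(w + h, x, y) - loss_of_w(w - h, x, y)) / (2 * h)


# Compare both estimates at a few weights for the house example (x = 2.0, y = 100.0).
for w in (1.0, 25.0, 50.0, 80.0):
    a = analytic_grad(w, x=2.0, y=100.0)
    n = numerical_grad(w, x=2.0, y=100.0)
    print(f"w = {w:5.1f}: chain rule = {a:9.4f}, finite difference = {n:9.4f}")
```

Note that the gradient is zero at w = 50, the weight that predicts the true price exactly.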
💡 Insights and Mistakes
Developer's Insight: Asymptotic Convergence
While running the Backpropagation loop, I noticed an interesting pattern with the Loss over time. I had to manually increase the epochs variable to track when the model actually finished learning.
Here is what I observed during execution:
- Epoch 20: Loss = 404.0076
- Epoch 50: Loss = 2.7143 (Still dropping rapidly)
- Epoch 100: Loss = 0.0006 (Slowing down significantly)
- Epoch 112: Loss = 0.00008 (Converged. Barely changing after this point).
The Insight: Why does the learning slow down so drastically at the end? It is mathematically baked into our gradient formula: 2 * x * (y_hat - y).
As the model gets smarter, the prediction y_hat gets closer and closer to the true target y. This means the error (y_hat - y) approaches 0. Since the error multiplies the entire gradient expression, the gradient itself shrinks toward 0.
Because the gradient is shrinking, our step sizes get microscopically small. The model takes massive leaps when it's wrong, but delicately tip-toes as it approaches the exact right answer. This is known as asymptotic convergence!
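For this particular example the slowdown can be made exact. Writing the error as $e = \hat{y} - y$ and substituting $\hat{y} = wx$ into the update $w \leftarrow w - \eta \cdot 2x(\hat{y} - y)$ gives $e_{\text{new}} = (1 - 2\eta x^2)\, e$. With $\eta = 0.01$ and $x = 2$, every step multiplies the error by $0.92$, so the loss shrinks by $0.92^2 = 0.8464$ per epoch and the loss logged at epoch $k$ is $L_k = 9604 \cdot 0.8464^{\,k-1}$. Shrinking by a constant fraction each step (what optimization texts call linear, or geometric, convergence) is why the loss approaches zero but, in exact arithmetic, never reaches it. The sketch below, a plain-Python re-run of the same settings as The Code, checks that ratio and the closed form numerically.

```python
x, y = 2.0, 100.0
w = 1.0
learning_rate = 0.01

# Predicted per-epoch loss shrink factor: (1 - 2 * lr * x**2) ** 2 = 0.92 ** 2 = 0.8464
shrink = (1 - 2 * learning_rate * x ** 2) ** 2
initial_loss = (w * x - y) ** 2  # 9604.0 at w = 1.0

# Re-run 112 epochs, recording the loss before each update (as the original loop prints it).
losses = []
for _ in range(112):
    y_hat = w * x
    loss = (y_hat - y) ** 2
    losses.append(loss)
    w -= learning_rate * 2 * x * (y_hat - y)

# Compare the simulated loss with the closed form at the epochs from the log above.
for epoch in (20, 50, 100, 112):
    closed_form = initial_loss * shrink ** (epoch - 1)
    print(f"Epoch {epoch}: simulated = {losses[epoch - 1]:.8f}, closed form = {closed_form:.8f}")

# The ratio of consecutive losses should be (numerically) constant.
ratios = [losses[k + 1] / losses[k] for k in range(len(losses) - 1)]
print(f"min ratio = {min(ratios):.6f}, max ratio = {max(ratios):.6f}")
```

The simulated and closed-form columns agree to floating-point precision and reproduce the logged values (about 404.0 at epoch 20 and about 2.71 at epoch 50).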
⚙️ The Code
```python
def forward_pass(x: float, w: float) -> float:
    """Return the prediction (y_hat) of our model."""
    return w * x


def calculate_loss(y: float, y_hat: float) -> float:
    """Return the squared-error loss."""
    return (y_hat - y) ** 2


def get_gradient(x: float, y: float, y_hat: float) -> float:
    """Use the Chain Rule to calculate how much to change the weight."""
    return 2 * x * (y_hat - y)


# House that is 2000 SqFt (x = 2.0). True price is $100k (y = 100.0).
x = 2.0
y = 100.0
w = 1.0  # Initial weight (price per SqFt)
learning_rate = 0.01
epochs = 125

for epoch in range(epochs):
    y_hat = forward_pass(x, w)            # Forward Pass
    loss = calculate_loss(y, y_hat)       # Loss Calculation
    gradient = get_gradient(x, y, y_hat)  # Backpropagation (Chain Rule)
    w = w - learning_rate * gradient      # Gradient descent update
    # Note: `loss` was measured before this epoch's update; `w` is the updated weight.
    print(f"Epoch {epoch + 1}: Weight = {w:.4f}, Loss = {loss:.8f}")
```
Code Breakdown
- get_gradient(...): Implements the Chain Rule formula 2 * x * (y_hat - y). Its sign and size tell the weight which way to move and how far: subtracting learning_rate * gradient nudges w in the direction that reduces the error.
- The Loop: Notice the clear separation of the Forward Pass (forward_pass), Loss Calculation (calculate_loss), and Backpropagation (get_gradient). This mirrors the structure of a typical PyTorch training loop, where autograd's loss.backward() computes the gradient instead of a hand-written get_gradient.
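In the Developer's Insight above, the epochs variable had to be raised by hand to see when learning finished. A common alternative is to stop once the loss falls below a tolerance. The sketch below reuses the three functions from the listing (redefined so the block runs on its own); the helper name `train_until_converged`, the `tol` value of 1e-4, and the `max_epochs` cap are illustrative assumptions, not part of the original code.

```python
def forward_pass(x: float, w: float) -> float:
    """Return the prediction (y_hat) of our model."""
    return w * x


def calculate_loss(y: float, y_hat: float) -> float:
    """Return the squared-error loss."""
    return (y_hat - y) ** 2


def get_gradient(x: float, y: float, y_hat: float) -> float:
    """Chain-rule gradient of the loss with respect to w."""
    return 2 * x * (y_hat - y)


def train_until_converged(x: float, y: float, w: float, learning_rate: float = 0.01,
                          tol: float = 1e-4, max_epochs: int = 10_000) -> tuple[int, float]:
    """Hypothetical helper: train until the loss drops below tol (or max_epochs is hit).

    Returns the 1-based epoch at which training stopped and the weight at that point.
    """
    for epoch in range(1, max_epochs + 1):
        y_hat = forward_pass(x, w)
        loss = calculate_loss(y, y_hat)
        if loss < tol:
            return epoch, w  # Converged: the current weight is already good enough.
        w = w - learning_rate * get_gradient(x, y, y_hat)
    return max_epochs, w  # Gave up: tolerance never reached.


stop_epoch, final_w = train_until_converged(x=2.0, y=100.0, w=1.0)
print(f"Stopped at epoch {stop_epoch} with weight {final_w:.4f}")
```

With these settings the run stops at epoch 112, in line with the convergence point in the log above. The `max_epochs` cap is a guard so a badly chosen learning rate cannot loop forever.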