Computation Backends & Keras 3
Understanding the ecosystem of deep learning frameworks and how Keras 3 abstracts hardware acceleration through backend engines.
The Stack: From Python to Silicon
Section 1.18 - Core Computation Engines
TensorFlow (Google)
- Developer: Google Brain Team (2015)
- Key Trait: Static computation graphs (define-and-run); see the sketch below
- Strengths:
- Production-grade deployment (TF Serving, TFLite)
- Tight TPU integration
- Weakness: Less flexible for research prototyping
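A minimal sketch of define-and-run, assuming TensorFlow 2.x (scaled_sum is an illustrative function, not a library API): tf.function traces the Python function into a static graph on the first call, and later calls reuse that graph.
import tensorflow as tf

@tf.function  # traced once into a static graph, then reused
def scaled_sum(x, w):
    return tf.reduce_sum(x * w)

x = tf.constant([1.0, 2.0, 3.0])
w = tf.constant([0.5, 0.5, 0.5])
print(scaled_sum(x, w))  # runs the compiled graph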
PyTorch (Meta)
- Developer: Facebook AI Research (2016)
- Key Trait: Dynamic computation graphs (define-by-run); see the sketch below
- Strengths:
- Pythonic debugging experience
- Dominant in academic research
- Weakness: Historically weaker mobile/edge support
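A minimal sketch of define-by-run, assuming a standard PyTorch install: the graph is recorded as ordinary Python executes, so data-dependent control flow and print-style debugging work directly.
import torch

x = torch.randn(3, requires_grad=True)

# The branch taken depends on the data; autograd records whichever path ran.
if x.sum() > 0:
    y = (2 * x).sum()
else:
    y = (x ** 2).sum()

y.backward()
print(x.grad)  # gradients for the branch that actually executed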
JAX (Google)
- Developer: Google Research (2018)
- Key Trait: Functional programming + composable transforms
- Strengths:
- Automatic vectorization (vmap)
- Native support for higher-order gradients (see the sketch below)
- Weakness: Steeper learning curve
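A minimal sketch of composable transforms, assuming JAX is installed (f is an illustrative scalar function): grad composed with itself yields a second derivative, and vmap vectorizes it over a batch without an explicit loop.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x ** 2

d2f = jax.grad(jax.grad(f))   # higher-order gradient: second derivative of f
batched_d2f = jax.vmap(d2f)   # automatic vectorization over a batch of inputs

xs = jnp.linspace(0.0, 1.0, 5)
print(batched_d2f(xs))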
Section 1.19 - Keras 3: Unified Abstraction Layer
Key Innovation
Keras 3 acts as a backend-agnostic interface:
# Same code runs on TensorFlow, PyTorch, or JAX
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras; default is "tensorflow"

import keras
from keras import layers

model = keras.Sequential([
    layers.Dense(64, activation='relu'),
    layers.Dense(10)
])
model.compile()
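To confirm which engine is actually driving the computation, Keras 3 exposes the current backend name; a quick check, assuming the environment variable was set as above:
import keras
print(keras.config.backend())  # -> "jax"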
Architecture
┌──────────────────────────┐
│ Keras API (Python) │ ← You code here
├──────────────────────────┤
│ Backend Adapter │ ← Converts Keras ops to backend primitives
├───────┬────────┬─────────┤
│ TF │ PyTorch│ JAX │ ← Backend engines
├───────┴────────┴─────────┤
│ XLA/CUDA/C++/ROCm │ ← Hardware-specific optimization
└──────────────────────────┘
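The "Backend Adapter" row corresponds to the keras.ops namespace in Keras 3: operations are written once against it and dispatched to the selected engine. A minimal sketch:
import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" / "torch"

from keras import ops

x = ops.ones((2, 3))
y = ops.matmul(x, ops.transpose(x))  # routed to the JAX primitive here,
print(y)                             # or to the TensorFlow/PyTorch equivalent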
Section 1.20 - The Performance Layer
Under the Hood
All frameworks ultimately delegate computation to:
- Optimized C/C++ Kernels:
- BLAS (e.g., Intel MKL, OpenBLAS) for linear algebra
- Custom ops for neural networks (e.g., convolution)
- GPU Acceleration:
- CUDA (NVIDIA) / ROCm (AMD) for parallel computation
- Kernel fusion via XLA (TensorFlow/JAX) or TorchScript (PyTorch); see the sketch below
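A small JAX sketch of the kind of elementwise chain that XLA's fuser can collapse into far fewer kernels (gelu_ish is an illustrative function, not a library API):
import jax
import jax.numpy as jnp

def gelu_ish(x):
    # chain of elementwise ops: multiply, add, tanh, power
    return 0.5 * x * (1.0 + jnp.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

fused = jax.jit(gelu_ish)     # compiled by XLA on first call; elementwise ops fused
x = jnp.ones((1024, 1024))
print(fused(x).shape)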
Example Stack Trace
keras.layers.Dense(...)        # Python
        ↓
keras.ops.matmul()             # Backend-agnostic op
        ↓
tf.linalg.matmul()             # TensorFlow implementation
        ↓
Eigen::Tensor contraction      # C++/CUDA kernel
Section 1.21 - Why Abstraction Matters
- Portability: Same model code runs on CPU/GPU/TPU
- Vendor Independence: Avoid lock-in to any ecosystem
- Performance: Leverage decades of HPC optimization