Deep Learning For Finance

TP Corrected: Recurrent Neural Networks for Time Series Prediction

Author

Remi Genet

Published

2025-04-03

!pip install numpy keras jax matplotlib scikit-learn yfinance
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (2.2.2)
Requirement already satisfied: keras in /usr/local/lib/python3.10/dist-packages (3.8.0)
Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (0.5.0)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (3.10.0)
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (1.6.1)
Requirement already satisfied: yfinance in /usr/local/lib/python3.10/dist-packages (0.2.52)

Data Preparation

First, let’s prepare our dataset using yfinance to get historical stock data.

import os
os.environ["KERAS_BACKEND"] = "jax"  # Must be set before importing keras; the default backend is tensorflow
import yfinance as yf
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import keras
from keras import layers
import matplotlib.pyplot as plt

# Download data
ticker = "^GSPC"
data = yf.download(ticker, start="2000-01-01", end="2024-01-01")

# Compute absolute returns (volatility proxy)
returns = np.abs(data['Close'].pct_change())
returns = returns.dropna()

def create_sequences(data, seq_length, horizon):
    """Create sequences for training"""
    X, y = [], []
    for i in range(len(data) - seq_length - horizon + 1):
        X.append(data[i:(i + seq_length)])
        y.append(data[(i + seq_length):(i + seq_length + horizon)])
    return np.array(X), np.squeeze(np.array(y))

# Parameters
sequence_length = 200
prediction_horizon = 1

# Create sequences
X, y = create_sequences(returns.values, sequence_length, prediction_horizon)

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=False
)

early_stopping_callback = lambda : keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00001,
    patience=10,
    mode="min",
    restore_best_weights=True,
    start_from_epoch=6,
)
lr_callback = lambda : keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.25,
    patience=5,
    mode="min",
    min_delta=0.00001,
    min_lr=0.000025,
    verbose=0,
)
callbacks = lambda : [early_stopping_callback(), lr_callback(), keras.callbacks.TerminateOnNaN()]
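To make the windowing logic of create_sequences easier to check by hand, here is a minimal pure-Python sketch of the same indexing on a toy series. The helper name make_windows and the toy values are illustrative assumptions, not part of the TP.

```python
def make_windows(series, seq_length, horizon):
    """Pure-Python mirror of create_sequences, for inspecting the indexing."""
    X, y = [], []
    for i in range(len(series) - seq_length - horizon + 1):
        X.append(series[i:i + seq_length])                        # input window
        y.append(series[i + seq_length:i + seq_length + horizon])  # target(s) right after it
    return X, y

toy = list(range(10))  # hypothetical stand-in for the absolute-return series
X_toy, y_toy = make_windows(toy, seq_length=3, horizon=1)

print(len(X_toy))            # 7 windows: 10 - 3 - 1 + 1
print(X_toy[0], y_toy[0])    # [0, 1, 2] [3]
print(X_toy[-1], y_toy[-1])  # [6, 7, 8] [9]
```

Each target starts immediately after its input window, so no future value leaks into the inputs. With shuffle=False in train_test_split, the test set is the most recent 20% of these windows.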

Question 1: Simple LSTM Model

Implement a single-layer LSTM with 100 units that returns only the last hidden state, followed by a linear Dense layer, using the Keras Sequential API to predict the next value. Compare its performance with a simple linear regression model and with a 3-layer MLP of 100 units per layer with ReLU activation.

# LSTM Model
def create_simple_lstm():
    model = keras.Sequential([
        layers.LSTM(100),
        layers.Dense(prediction_horizon)
    ])
    
    model.compile(optimizer='adam', loss='mse')
    return model

# MLP Model
def create_mlp():
    model = keras.Sequential([
        layers.Dense(100, 'relu'),
        layers.Dense(100, 'relu'),
        layers.Dense(100, 'relu'),
        layers.Dense(prediction_horizon)
    ])
    
    model.compile(optimizer='adam', loss='mse')
    return model


# Train and evaluate LSTM
lstm_model = create_simple_lstm()
lstm_history = lstm_model.fit(
    X_train, y_train,
    epochs=100,
    batch_size=32,
    validation_split=0.2,
    verbose=0,
    callbacks=callbacks(),
)

# Flatten each sequence into a 1D feature vector for the MLP and LinearRegression
X_train_2d = X_train.reshape(X_train.shape[0], -1)
X_test_2d = X_test.reshape(X_test.shape[0], -1)

# Train the MLP
mlp_model = create_mlp()
mlp_history = mlp_model.fit(
    X_train_2d, y_train,
    epochs=100,
    batch_size=32,
    validation_split=0.2,
    verbose=0, 
    callbacks=callbacks(),
)

# Train the Linear Regression baseline
lr_model = LinearRegression()
lr_model.fit(X_train_2d, y_train)

# Predictions
lstm_preds = lstm_model.predict(X_test)
mlp_preds = mlp_model.predict(X_test_2d)
lr_preds = lr_model.predict(X_test_2d)

# Calculate R² scores
lstm_r2 = r2_score(y_test, lstm_preds)
mlp_r2 = r2_score(y_test, mlp_preds)
lr_r2 = r2_score(y_test, lr_preds)

print(f"LSTM R² score: {lstm_r2:.4f}")
print(f"MLP R² score: {mlp_r2:.4f}")
print(f"Linear Regression R² score: {lr_r2:.4f}")
37/37 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step  
37/37 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step 
LSTM R² score: 0.3068
MLP R² score: 0.1371
Linear Regression R² score: 0.2458
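To help read these scores, here is a small sketch of the quantity that r2_score reports for a single output: one minus the ratio of the residual sum of squares to the total sum of squares around the mean of the true values. The function name r2_manual and the toy numbers are assumptions for illustration only.

```python
def r2_manual(y_true, y_pred):
    """R² = 1 - SS_res / SS_tot for a single-output regression."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

y_true = [1.0, 2.0, 3.0, 4.0]

print(r2_manual(y_true, y_true))                # 1.0: perfect prediction
print(r2_manual(y_true, [2.5] * 4))             # 0.0: always predicting the mean
print(r2_manual(y_true, [4.0, 3.0, 2.0, 1.0]))  # -3.0: worse than the mean
```

A score of 0 corresponds to always predicting the test-set mean, so an R² of about 0.31 means the model explains roughly 31% of the variance of next-day absolute returns on the test period.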

Question 2: GRU Implementation

Now implement a similar model using GRU and compare its performance.

def create_simple_gru():
    model = keras.Sequential([
        layers.GRU(100),
        layers.Dense(prediction_horizon)
    ])
    
    model.compile(optimizer='adam', loss='mse')
    return model

# Train and evaluate GRU
gru_model = create_simple_gru()
gru_history = gru_model.fit(
    X_train, y_train,
    epochs=10,  # note: fewer epochs than the 100 used for the LSTM in Question 1
    batch_size=32,
    validation_split=0.2,
    verbose=0,
    callbacks=callbacks(),
)

# Predictions
gru_preds = gru_model.predict(X_test)
gru_r2 = r2_score(y_test, gru_preds)

print(f"GRU R² score: {gru_r2:.4f}")
37/37 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step  
GRU R² score: 0.2700
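When comparing the two architectures, it can help to keep their sizes in mind. The sketch below counts trainable parameters for one recurrent layer, assuming a univariate input (input_dim=1, as in this TP) and, for the GRU, the Keras default reset_after=True, which keeps separate input and recurrent biases. The function names are hypothetical helpers, not Keras API.

```python
def lstm_param_count(input_dim, units):
    # 4 blocks (input, forget, candidate, output), each with an
    # input kernel, a recurrent kernel, and a bias vector
    return 4 * (input_dim * units + units * units + units)

def gru_param_count(input_dim, units, reset_after=True):
    # 3 blocks (update, reset, candidate); with reset_after=True
    # each block carries two bias vectors instead of one
    n_bias = 2 * units if reset_after else units
    return 3 * (input_dim * units + units * units + n_bias)

print(lstm_param_count(1, 100))  # 40800
print(gru_param_count(1, 100))   # 30900
```

Under these assumptions the GRU layer has about a quarter fewer parameters than the LSTM layer, which is worth remembering when reading the R² gap between the two.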

Question 3: Two-Layer Models

Implement and compare two-layer versions of both LSTM and GRU models.

def create_double_lstm():
    model = keras.Sequential([
        layers.LSTM(100, return_sequences=True),
        layers.LSTM(100),
        layers.Dense(prediction_horizon)
    ])
    
    model.compile(optimizer='adam', loss='mse')
    return model

def create_double_gru():
    model = keras.Sequential([
        layers.GRU(100, return_sequences=True),
        layers.GRU(100),
        layers.Dense(prediction_horizon)
    ])
    
    model.compile(optimizer='adam', loss='mse')
    return model

# Train and evaluate double-layer models
double_lstm_model = create_double_lstm()
double_lstm_history = double_lstm_model.fit(
    X_train, y_train,
    epochs=100,
    batch_size=32,
    validation_split=0.2,
    verbose=0,
    callbacks=callbacks(),
)

double_gru_model = create_double_gru()
double_gru_history = double_gru_model.fit(
    X_train, y_train,
    epochs=100,
    batch_size=32,
    validation_split=0.2,
    verbose=0,
    callbacks=callbacks(),
)

# Predictions
double_lstm_preds = double_lstm_model.predict(X_test)
double_gru_preds = double_gru_model.predict(X_test)

double_lstm_r2 = r2_score(y_test, double_lstm_preds)
double_gru_r2 = r2_score(y_test, double_gru_preds)

print(f"Double LSTM R² score: {double_lstm_r2:.4f}")
print(f"Double GRU R² score: {double_gru_r2:.4f}")
37/37 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step 
37/37 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step 
Double LSTM R² score: 0.3084
Double GRU R² score: 0.2510

Question 4: Custom LSTM Implementation

Implement a custom LSTM layer by writing out all the equations. Here’s the mathematical formulation you need to implement:

  1. Input gate: i_t = σ(W_i·[h_{t-1}, x_t] + b_i)
  2. Forget gate: f_t = σ(W_f·[h_{t-1}, x_t] + b_f)
  3. Candidate cell state: c̃_t = tanh(W_c·[h_{t-1}, x_t] + b_c)
  4. Output gate: o_t = σ(W_o·[h_{t-1}, x_t] + b_o)
  5. New cell state: c_t = f_t ∗ c_{t-1} + i_t ∗ c̃_t
  6. Hidden state: h_t = o_t ∗ tanh(c_t)
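The equations apply one weight matrix to the concatenation [h_{t-1}, x_t], whereas implementations usually store an input kernel and a recurrent kernel separately. The two forms are equivalent, as this short derivation for the input gate suggests. The block names U_i (recurrent part) and W_i^{(x)} (input part) are notation assumed here; the same split applies to the f, c, and o blocks.

```latex
W_i \,[h_{t-1}, x_t] + b_i
  = \begin{bmatrix} U_i & W_i^{(x)} \end{bmatrix}
    \begin{bmatrix} h_{t-1} \\ x_t \end{bmatrix} + b_i
  = U_i \, h_{t-1} + W_i^{(x)} x_t + b_i
```

With batches stored as row vectors, the same quantity is written x_t W + h_{t-1} U + b, which is the form of the ops.dot(x_t, self.W_i) + ops.dot(h_prev, self.U_i) + self.b_i expression in the solution code.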

Hint: Split the code in two: implement a method that performs a single time step, and have the call method iterate over the sequence and use it.
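The structure suggested by the hint can be sketched on scalars first, which makes the six equations easy to verify by hand before moving to Keras tensors. This is only an illustrative sketch for a single unit and a scalar input; the function names and the parameter dictionary layout are assumptions.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step_scalar(x_t, h_prev, c_prev, p):
    """One LSTM step for a single unit and a scalar input (equations 1 to 6)."""
    i = sigmoid(p["w_i"] * x_t + p["u_i"] * h_prev + p["b_i"])          # input gate
    f = sigmoid(p["w_f"] * x_t + p["u_f"] * h_prev + p["b_f"])          # forget gate
    c_tilde = math.tanh(p["w_c"] * x_t + p["u_c"] * h_prev + p["b_c"])  # candidate
    o = sigmoid(p["w_o"] * x_t + p["u_o"] * h_prev + p["b_o"])          # output gate
    c_t = f * c_prev + i * c_tilde                                      # new cell state
    h_t = o * math.tanh(c_t)                                            # new hidden state
    return h_t, c_t

def run_lstm_scalar(xs, p):
    """Iterate the single-step function over a sequence, starting from zero states."""
    h, c = 0.0, 0.0
    for x_t in xs:
        h, c = lstm_step_scalar(x_t, h, c, p)
    return h  # last hidden state only, as with return_sequences=False

# With all-zero parameters every gate equals sigmoid(0) = 0.5 and the candidate is 0
names = [f"{kind}_{gate}" for kind in "wub" for gate in "ifco"]
zero_params = {name: 0.0 for name in names}

h1, c1 = lstm_step_scalar(x_t=1.0, h_prev=0.0, c_prev=2.0, p=zero_params)
print(c1)  # 0.5 * 2.0 + 0.5 * 0.0 = 1.0
print(h1)  # 0.5 * tanh(1.0)
```

The Keras version follows the same two-part layout, with matrix products in place of the scalar multiplications.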

from keras import ops
class CustomLSTM(layers.Layer):
    def __init__(self, units, return_sequences=False, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.return_sequences = return_sequences

    def build(self, input_shape):
        input_dim = input_shape[-1]
        
        # Input weights
        self.W_i = self.add_weight(shape=(input_dim, self.units),
                                 initializer='glorot_uniform',
                                 name='W_i')
        self.W_f = self.add_weight(shape=(input_dim, self.units),
                                 initializer='glorot_uniform',
                                 name='W_f')
        self.W_c = self.add_weight(shape=(input_dim, self.units),
                                 initializer='glorot_uniform',
                                 name='W_c')
        self.W_o = self.add_weight(shape=(input_dim, self.units),
                                 initializer='glorot_uniform',
                                 name='W_o')
        
        # Recurrent weights
        self.U_i = self.add_weight(shape=(self.units, self.units),
                                 initializer='orthogonal',
                                 name='U_i')
        self.U_f = self.add_weight(shape=(self.units, self.units),
                                 initializer='orthogonal',
                                 name='U_f')
        self.U_c = self.add_weight(shape=(self.units, self.units),
                                 initializer='orthogonal',
                                 name='U_c')
        self.U_o = self.add_weight(shape=(self.units, self.units),
                                 initializer='orthogonal',
                                 name='U_o')
        
        # Biases
        self.b_i = self.add_weight(shape=(self.units,),
                                 initializer='zeros',
                                 name='b_i')
        self.b_f = self.add_weight(shape=(self.units,),
                                 initializer='ones',  # Initialize forget gate bias to 1
                                 name='b_f')
        self.b_c = self.add_weight(shape=(self.units,),
                                 initializer='zeros',
                                 name='b_c')
        self.b_o = self.add_weight(shape=(self.units,),
                                 initializer='zeros',
                                 name='b_o')
        
        self.built = True

    def lstm_step(self, x_t, h_prev, c_prev):
        # Input gate
        i = ops.sigmoid(
            ops.dot(x_t, self.W_i) + 
            ops.dot(h_prev, self.U_i) + 
            self.b_i
        )
        
        # Forget gate
        f = ops.sigmoid(
            ops.dot(x_t, self.W_f) + 
            ops.dot(h_prev, self.U_f) + 
            self.b_f
        )
        
        # Cell candidate
        c_tilde = ops.tanh(
            ops.dot(x_t, self.W_c) + 
            ops.dot(h_prev, self.U_c) + 
            self.b_c
        )
        
        # Output gate
        o = ops.sigmoid(
            ops.dot(x_t, self.W_o) + 
            ops.dot(h_prev, self.U_o) + 
            self.b_o
        )
        
        # New cell state
        c_t = f * c_prev + i * c_tilde
        
        # New hidden state
        h_t = o * ops.tanh(c_t)
        
        return h_t, c_t

    def call(self, inputs):
        # Read the number of time steps from the static input shape
        _, time_steps, _ = inputs.shape
        
        # Initialize hidden state and cell state
        h_t = ops.zeros((inputs.shape[0], self.units))
        c_t = ops.zeros((inputs.shape[0], self.units))
        
        # Store outputs if return_sequences is True
        if self.return_sequences:
            outputs = []
        
        # Process each timestep
        for t in range(time_steps):
            x_t = inputs[:, t, :]
            h_t, c_t = self.lstm_step(x_t, h_t, c_t)
            
            if self.return_sequences:
                outputs.append(h_t)
        
        # Return full sequence or just final output
        if self.return_sequences:
            return ops.stack(outputs, axis=1)
        return h_t

# Create and train model with custom LSTM
def create_custom_lstm_model():
    model = keras.Sequential([
        CustomLSTM(100),
        layers.Dense(prediction_horizon)
    ])
    
    model.compile(optimizer='adam', loss='mse')
    return model

custom_lstm_model = create_custom_lstm_model()
custom_lstm_history = custom_lstm_model.fit(
    X_train, y_train,
    epochs=50,
    batch_size=32,
    validation_split=0.2,
    verbose=0,
    callbacks=callbacks(),
)

# Predictions
custom_lstm_preds = custom_lstm_model.predict(X_test)
custom_lstm_r2 = r2_score(y_test, custom_lstm_preds)

print(f"Custom LSTM R² score: {custom_lstm_r2:.4f}")

# Plot comparison of all models
models = ['Linear', 'LSTM', 'GRU', 'Double LSTM', 'Double GRU', 'Custom LSTM']
scores = [lr_r2, lstm_r2, gru_r2, double_lstm_r2, double_gru_r2, custom_lstm_r2]

plt.figure(figsize=(10, 6))
plt.bar(models, scores)
plt.title('Model Comparison - R² Scores')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
37/37 ━━━━━━━━━━━━━━━━━━━━ 32s 463ms/step
Custom LSTM R² score: 0.2960

Deep Learning For Finance, Rémi Genet.