TP: Using Deep Learning Frameworks for General Optimization

Author: Remi Genet

Published: 2025-04-03

In this TP, we explore how deep learning frameworks can be used for general optimization problems, not just for training neural networks. Our running example is portfolio optimization: maximizing the Sharpe ratio of a long-only, fully invested portfolio.

Part 1: Setup and Data Generation

First, we set up the imports and fix the random seed for reproducibility; you will then implement the data-generation function.

import numpy as np
import jax
import jax.numpy as jnp
from jax import grad, jit
from jax.example_libraries import optimizers
from scipy.optimize import minimize
import time
import pandas as pd
import matplotlib.pyplot as plt
from typing import Tuple
from dataclasses import dataclass

# Set random seed for reproducibility
np.random.seed(42)
jax.config.update("jax_enable_x64", True)

Exercise 1: Data Generation

Implement a function that generates a random correlation matrix and a vector of expected returns:

def generate_data(n_assets: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    TODO: Generate random correlation matrix and expected returns
    Hint: Use np.random.randn for generation
    Hint: Correlation matrix should be symmetric and normalized
    """
    pass
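
If you are stuck, here is one possible sketch (not the official correction). It builds a symmetric positive semi-definite matrix from a random factor matrix and rescales it to unit diagonal so it behaves like a correlation matrix; the return distribution (5% mean, 10% dispersion) is an arbitrary choice.

# One possible sketch, not the official correction.
def generate_data(n_assets: int) -> Tuple[np.ndarray, np.ndarray]:
    A = np.random.randn(n_assets, n_assets)
    cov = A @ A.T                                        # symmetric, positive semi-definite
    d = np.sqrt(np.diag(cov))
    corr = cov / np.outer(d, d)                          # rescale to unit diagonal
    returns = 0.05 + 0.10 * np.random.randn(n_assets)    # arbitrary return distribution
    return corr, returns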

Part 2: Traditional Optimization with SciPy

Let’s first implement the traditional approach using SciPy’s optimizer.

Exercise 2: Portfolio Statistics

Implement the function to calculate portfolio statistics:

def portfolio_stats_scipy(weights: np.ndarray, 
                         corr: np.ndarray, 
                         returns: np.ndarray, 
                         risk_free_rate: float = 0.02) -> Tuple[float, float, float]:
    """
    TODO: Calculate portfolio return, volatility, and Sharpe ratio
    Hint: Use numpy for matrix operations
    """
    pass

def negative_sharpe_scipy(weights: np.ndarray, 
                         corr: np.ndarray, 
                         returns: np.ndarray, 
                         risk_free_rate: float = 0.02) -> float:
    """
    TODO: Return negative Sharpe ratio for minimization
    """
    pass
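
One way to fill in these stubs is sketched below (not the official correction). The correlation matrix stands in for the covariance matrix in the volatility term, consistent with the data generated in Exercise 1.

# One possible sketch, not the official correction.
def portfolio_stats_scipy(weights: np.ndarray,
                          corr: np.ndarray,
                          returns: np.ndarray,
                          risk_free_rate: float = 0.02) -> Tuple[float, float, float]:
    port_return = float(weights @ returns)
    port_vol = float(np.sqrt(weights @ corr @ weights))  # corr acts as the risk matrix
    sharpe = (port_return - risk_free_rate) / port_vol
    return port_return, port_vol, sharpe

def negative_sharpe_scipy(weights: np.ndarray,
                          corr: np.ndarray,
                          returns: np.ndarray,
                          risk_free_rate: float = 0.02) -> float:
    # SciPy minimizes, so return the opposite of the quantity we want to maximize
    return -portfolio_stats_scipy(weights, corr, returns, risk_free_rate)[2]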

Exercise 3: SciPy Optimization

Complete the SciPy optimization function:

def optimize_scipy(corr: np.ndarray, returns: np.ndarray) -> Tuple[np.ndarray, float, float]:
    """
    TODO: Implement portfolio optimization using SciPy
    Hint: Use minimize with SLSQP method
    Hint: Don't forget constraints (sum of weights = 1, weights >= 0)
    """
    pass
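
A possible solution is sketched below; the return convention (weights, achieved Sharpe ratio, elapsed time) is an assumption that the comparison sketches later in the TP reuse.

# One possible sketch, not the official correction.
def optimize_scipy(corr: np.ndarray, returns: np.ndarray) -> Tuple[np.ndarray, float, float]:
    n = len(returns)
    w0 = np.ones(n) / n                                   # equal-weight starting point
    constraints = ({'type': 'eq', 'fun': lambda w: np.sum(w) - 1.0},)
    bounds = [(0.0, 1.0)] * n                             # long-only weights
    start = time.time()
    res = minimize(negative_sharpe_scipy, w0,
                   args=(corr, returns),
                   method='SLSQP',
                   bounds=bounds,
                   constraints=constraints)
    elapsed = time.time() - start
    return res.x, -res.fun, elapsed                       # weights, Sharpe ratio, runtime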

Part 3: JAX Implementation

Now let’s implement the same optimization using JAX and automatic differentiation.

Exercise 4: JAX Portfolio Statistics

Implement the JAX version of portfolio statistics:

@jit
def portfolio_stats_jax(weights: jnp.ndarray, 
                       corr: jnp.ndarray, 
                       returns: jnp.ndarray, 
                       risk_free_rate: float = 0.02) -> Tuple[float, float, float]:
    """
    TODO: Implement portfolio statistics using JAX
    Hint: Use jnp instead of np
    """
    pass

@jit
def negative_sharpe_jax(weights: jnp.ndarray, 
                       corr: jnp.ndarray, 
                       returns: jnp.ndarray, 
                       risk_free_rate: float = 0.02) -> float:
    """
    TODO: Implement negative Sharpe ratio using JAX
    """
    pass
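
The JAX version mirrors the NumPy one, replacing np with jnp so the functions can be jit-compiled and differentiated. One possible sketch (not the official correction):

# One possible sketch, not the official correction.
@jit
def portfolio_stats_jax(weights: jnp.ndarray,
                        corr: jnp.ndarray,
                        returns: jnp.ndarray,
                        risk_free_rate: float = 0.02) -> Tuple[float, float, float]:
    port_return = jnp.dot(weights, returns)
    port_vol = jnp.sqrt(jnp.dot(weights, jnp.dot(corr, weights)))
    sharpe = (port_return - risk_free_rate) / port_vol
    return port_return, port_vol, sharpe

@jit
def negative_sharpe_jax(weights: jnp.ndarray,
                        corr: jnp.ndarray,
                        returns: jnp.ndarray,
                        risk_free_rate: float = 0.02) -> float:
    return -portfolio_stats_jax(weights, corr, returns, risk_free_rate)[2]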

Exercise 5: Weight Projection

Implement the simplex projection function:

@jit
def projection_simplex(x: jnp.ndarray) -> jnp.ndarray:
    """
    TODO: Project weights onto probability simplex
    Hint: Ensure non-negative weights that sum to 1
    """
    pass
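
A quick-and-dirty option is to clip negative weights and renormalize; the sketch below instead implements the exact Euclidean projection onto the simplex (the classic sorting-based algorithm), which works under jit because all shapes are static. It is one possible implementation, not the official correction.

# One possible sketch, not the official correction.
@jit
def projection_simplex(x: jnp.ndarray) -> jnp.ndarray:
    # Euclidean projection onto {w : w >= 0, sum(w) = 1}
    n = x.shape[0]
    u = jnp.sort(x)[::-1]                    # sort in decreasing order
    cssv = jnp.cumsum(u) - 1.0
    ind = jnp.arange(1, n + 1)
    cond = u - cssv / ind > 0
    rho = jnp.sum(cond)                      # number of active (non-zero) components
    theta = cssv[rho - 1] / rho              # optimal shift
    return jnp.maximum(x - theta, 0.0)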

Exercise 6: Early Stopping Logic

Implement early stopping and learning rate scheduling:

@dataclass
class EarlyStoppingState:
    """State for early stopping and learning rate scheduling."""
    best_sharpe: float = float('-inf')
    patience_counter: int = 0
    best_weights: jnp.ndarray = None
    lr: float = 0.1
    plateau_counter: int = 0

def update_early_stopping_state(state: EarlyStoppingState, 
                              current_sharpe: float, 
                              current_weights: jnp.ndarray,
                              patience: int = 10,
                              min_improvement: float = 1e-4,
                              lr_reduction_factor: float = 0.5,
                              lr_patience: int = 5,
                              min_lr: float = 1e-6) -> EarlyStoppingState:
    """
    TODO: Implement early stopping and learning rate scheduling logic
    Hint: Update best results if improved
    Hint: Implement patience mechanism
    Hint: Handle learning rate reduction
    """
    pass
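
One way to write the update is sketched below (not the official correction). It returns a fresh EarlyStoppingState: on improvement the best results are stored and both counters reset; otherwise the counters grow and the learning rate is reduced after lr_patience stagnant steps. The patience threshold itself is checked by the training loop in Exercise 7.

# One possible sketch, not the official correction.
def update_early_stopping_state(state: EarlyStoppingState,
                                current_sharpe: float,
                                current_weights: jnp.ndarray,
                                patience: int = 10,
                                min_improvement: float = 1e-4,
                                lr_reduction_factor: float = 0.5,
                                lr_patience: int = 5,
                                min_lr: float = 1e-6) -> EarlyStoppingState:
    if current_sharpe > state.best_sharpe + min_improvement:
        # Improvement: remember the best result and reset both counters
        return EarlyStoppingState(best_sharpe=current_sharpe,
                                  patience_counter=0,
                                  best_weights=current_weights,
                                  lr=state.lr,
                                  plateau_counter=0)
    # No improvement: grow the counters and possibly reduce the learning rate
    new_lr, plateau = state.lr, state.plateau_counter + 1
    if plateau >= lr_patience and state.lr > min_lr:
        new_lr, plateau = max(state.lr * lr_reduction_factor, min_lr), 0
    return EarlyStoppingState(best_sharpe=state.best_sharpe,
                              patience_counter=state.patience_counter + 1,
                              best_weights=state.best_weights,
                              lr=new_lr,
                              plateau_counter=plateau)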

Exercise 7: JAX Optimization

Complete the JAX optimization function:

def optimize_jax(corr: np.ndarray, 
                returns: np.ndarray, 
                n_iterations: int = 1000,
                initial_lr: float = 0.01,
                patience: int = 10,
                min_improvement: float = 1e-4,
                lr_patience: int = 5,
                lr_reduction_factor: float = 0.5,
                min_lr: float = 1e-6,
                b1: float = 0.9,
                b2: float = 0.999,
                eps: float = 1e-8) -> Tuple[np.ndarray, float, float]:
    """
    TODO: Implement portfolio optimization using JAX and Adam
    Hint: Use Adam optimizer from JAX
    Hint: Implement training loop with early stopping
    """
    pass
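
The sketch below is one possible implementation, not the official correction. It runs a "lazy" projected Adam loop: Adam updates a raw weight vector, and each candidate is projected onto the simplex with projection_simplex before the Sharpe ratio is evaluated. For readability, the reduced learning rate tracked by EarlyStoppingState is not fed back into the Adam optimizer; wiring it in is left as an extension.

# One possible sketch, not the official correction.
def optimize_jax(corr: np.ndarray,
                 returns: np.ndarray,
                 n_iterations: int = 1000,
                 initial_lr: float = 0.01,
                 patience: int = 10,
                 min_improvement: float = 1e-4,
                 lr_patience: int = 5,
                 lr_reduction_factor: float = 0.5,
                 min_lr: float = 1e-6,
                 b1: float = 0.9,
                 b2: float = 0.999,
                 eps: float = 1e-8) -> Tuple[np.ndarray, float, float]:
    n = len(returns)
    corr_j, returns_j = jnp.asarray(corr), jnp.asarray(returns)

    opt_init, opt_update, get_params = optimizers.adam(initial_lr, b1=b1, b2=b2, eps=eps)
    opt_state = opt_init(jnp.ones(n) / n)                 # equal-weight starting point
    loss_grad = jit(grad(negative_sharpe_jax))            # gradient w.r.t. the weights

    state = EarlyStoppingState(lr=initial_lr)
    start = time.time()
    for i in range(n_iterations):
        w = projection_simplex(get_params(opt_state))     # feasible candidate
        g = loss_grad(w, corr_j, returns_j)
        opt_state = opt_update(i, g, opt_state)           # Adam step on the raw weights

        sharpe = float(-negative_sharpe_jax(w, corr_j, returns_j))
        state = update_early_stopping_state(state, sharpe, w, patience,
                                            min_improvement, lr_reduction_factor,
                                            lr_patience, min_lr)
        if state.patience_counter >= patience:            # early stopping
            break
    elapsed = time.time() - start
    return np.asarray(state.best_weights), state.best_sharpe, elapsed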

Part 4: Comparison and Visualization

Exercise 8: Comparison Function

Implement the comparison between optimizers:

def compare_optimizers() -> pd.DataFrame:
    """
    TODO: Compare JAX and SciPy optimizers
    Hint: Test with different numbers of assets
    Hint: Compare computation time and Sharpe ratios
    """
    pass
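
A possible comparison loop is sketched below (not the official correction). The list of asset counts is arbitrary, the column names are a convention reused by the plotting sketch in Exercise 9, and both optimizers are assumed to return (weights, Sharpe ratio, elapsed time) as in the earlier sketches.

# One possible sketch, not the official correction.
def compare_optimizers() -> pd.DataFrame:
    rows = []
    for n_assets in [10, 50, 100, 500]:                   # arbitrary example sizes
        corr, returns = generate_data(n_assets)
        _, sharpe_scipy, time_scipy = optimize_scipy(corr, returns)
        _, sharpe_jax, time_jax = optimize_jax(corr, returns)
        rows.append({"n_assets": n_assets,
                     "sharpe_scipy": sharpe_scipy,
                     "sharpe_jax": sharpe_jax,
                     "time_scipy": time_scipy,
                     "time_jax": time_jax})
    return pd.DataFrame(rows)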

Exercise 9: Visualization

Create visualizations to compare the results:

def plot_comparison(results_df: pd.DataFrame):
    """
    TODO: Create plots comparing computation time and Sharpe ratios
    Hint: Use loglog plot for computation time
    Hint: Use semilogx for Sharpe ratios
    """
    pass
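
The sketch below assumes the column names introduced in the Exercise 8 sketch; it is one possible layout, not the official correction. A log-log scale turns a polynomial time-versus-size relation into a straight line, which makes the scaling easy to read off.

# One possible sketch, not the official correction.
def plot_comparison(results_df: pd.DataFrame):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Computation time: log-log so the scaling exponent appears as a slope
    ax1.loglog(results_df["n_assets"], results_df["time_scipy"], "o-", label="SciPy (SLSQP)")
    ax1.loglog(results_df["n_assets"], results_df["time_jax"], "s-", label="JAX (Adam)")
    ax1.set_xlabel("Number of assets")
    ax1.set_ylabel("Computation time (s)")
    ax1.set_title("Computation time vs. problem size")
    ax1.legend()

    # Sharpe ratios: semilog-x to compare solution quality across sizes
    ax2.semilogx(results_df["n_assets"], results_df["sharpe_scipy"], "o-", label="SciPy (SLSQP)")
    ax2.semilogx(results_df["n_assets"], results_df["sharpe_jax"], "s-", label="JAX (Adam)")
    ax2.set_xlabel("Number of assets")
    ax2.set_ylabel("Sharpe ratio")
    ax2.set_title("Achieved Sharpe ratio")
    ax2.legend()

    plt.tight_layout()
    plt.show()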

Questions for Analysis:

  1. How does the computation time scale with the number of assets for each method?
  2. Compare the Sharpe ratios achieved by both methods. Are they similar?
  3. What are the advantages and disadvantages of each approach?
  4. How does the choice of optimizer parameters (learning rate, patience, etc.) affect the results?
  5. Can you think of other applications where this approach could be useful?