TP: Using Deep Learning Frameworks for General Optimization
Author: Remi Genet
Published: 2025-02-18
In this TP, we’ll explore how deep learning frameworks can be used for general optimization problems, not just training neural networks. We’ll use portfolio optimization as our example.
Part 1: Setup and Data Generation
First, let’s write the necessary imports and then create our data generation function.
import numpy as np
import jax
import jax.numpy as jnp
from jax import grad, jit
from jax.example_libraries import optimizers
from scipy.optimize import minimize
import time
import pandas as pd
import matplotlib.pyplot as plt
from typing import Tuple
from dataclasses import dataclass

# Set random seed for reproducibility
np.random.seed(42)
jax.config.update("jax_enable_x64", True)
Exercise 1: Data Generation
Implement a function that generates a random correlation matrix and a vector of expected returns:
def generate_data(n_assets: int) -> Tuple[np.ndarray, np.ndarray]:
    """TODO: Generate random correlation matrix and expected returns

    Hint: Use np.random.randn for generation
    Hint: Correlation matrix should be symmetric and normalized
    """
    pass
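If you are unsure where to start, here is one possible sketch that follows the hints. The name generate_data_example and the 10%/20% return parameters are arbitrary choices made for illustration, not the reference correction:

def generate_data_example(n_assets: int) -> Tuple[np.ndarray, np.ndarray]:
    # Random factor matrix -> symmetric positive semi-definite matrix
    A = np.random.randn(n_assets, n_assets)
    cov = A @ A.T
    # Normalize to unit diagonal so the result is a valid correlation matrix
    d = np.sqrt(np.diag(cov))
    corr = cov / np.outer(d, d)
    # Random expected returns, e.g. centered around 10% with 20% dispersion (arbitrary)
    returns = 0.10 + 0.20 * np.random.randn(n_assets)
    return corr, returns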
Part 2: Traditional Optimization with SciPy
Let’s first implement the traditional approach using SciPy’s optimizer.
Exercise 2: Portfolio Statistics
Implement the functions that compute the portfolio statistics and the negative Sharpe ratio:
def portfolio_stats_scipy(weights: np.ndarray,
                          corr: np.ndarray,
                          returns: np.ndarray,
                          risk_free_rate: float = 0.02) -> Tuple[float, float, float]:
    """TODO: Calculate portfolio return, volatility, and Sharpe ratio

    Hint: Use numpy for matrix operations
    """
    pass


def negative_sharpe_scipy(weights: np.ndarray,
                          corr: np.ndarray,
                          returns: np.ndarray,
                          risk_free_rate: float = 0.02) -> float:
    """TODO: Return negative Sharpe ratio for minimization
    """
    pass
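As a reference for the shape of the computation, here is a minimal sketch, assuming the imports from Part 1 and treating corr as the covariance-like matrix that the function signatures suggest. The *_example names are placeholders so the sketch does not clash with your own solution:

def portfolio_stats_scipy_example(weights, corr, returns, risk_free_rate=0.02):
    # Portfolio return: weighted sum of expected returns
    port_return = float(np.dot(weights, returns))
    # Portfolio volatility: sqrt(w' C w)
    port_vol = float(np.sqrt(weights @ corr @ weights))
    # Sharpe ratio: excess return per unit of risk
    sharpe = (port_return - risk_free_rate) / port_vol
    return port_return, port_vol, sharpe


def negative_sharpe_scipy_example(weights, corr, returns, risk_free_rate=0.02):
    # SciPy minimizes, so we flip the sign of the Sharpe ratio
    return -portfolio_stats_scipy_example(weights, corr, returns, risk_free_rate)[2]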
Exercise 3: SciPy Optimization
Complete the SciPy optimization function:
def optimize_scipy(corr: np.ndarray,
                   returns: np.ndarray) -> Tuple[np.ndarray, float, float]:
    """TODO: Implement portfolio optimization using SciPy

    Hint: Use minimize with SLSQP method
    Hint: Don't forget constraints (sum of weights = 1, weights >= 0)
    """
    pass
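One way to wire everything together with SciPy is sketched below. It reuses negative_sharpe_scipy_example from the previous sketch, and returning the weights, the Sharpe ratio, and the elapsed time is only one reasonable reading of the Tuple[np.ndarray, float, float] signature; adapt it to whatever convention you settle on:

def optimize_scipy_example(corr, returns):
    n_assets = len(returns)
    x0 = np.ones(n_assets) / n_assets                                  # start from equal weights
    constraints = ({'type': 'eq', 'fun': lambda w: np.sum(w) - 1.0},)  # weights sum to 1
    bounds = [(0.0, 1.0)] * n_assets                                   # long-only weights

    start = time.time()
    result = minimize(negative_sharpe_scipy_example, x0,
                      args=(corr, returns),
                      method='SLSQP',
                      bounds=bounds,
                      constraints=constraints)
    elapsed = time.time() - start

    sharpe = -result.fun  # undo the sign flip used for minimization
    return result.x, sharpe, elapsed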
Part 3: JAX Implementation
Now let’s implement the same optimization using JAX and automatic differentiation.
Exercise 4: JAX Portfolio Statistics
Implement the JAX versions of the portfolio statistics and negative Sharpe ratio functions:
@jit
def portfolio_stats_jax(weights: jnp.ndarray,
                        corr: jnp.ndarray,
                        returns: jnp.ndarray,
                        risk_free_rate: float = 0.02) -> Tuple[float, float, float]:
    """TODO: Implement portfolio statistics using JAX

    Hint: Use jnp instead of np
    """
    pass


@jit
def negative_sharpe_jax(weights: jnp.ndarray,
                        corr: jnp.ndarray,
                        returns: jnp.ndarray,
                        risk_free_rate: float = 0.02) -> float:
    """TODO: Implement negative Sharpe ratio using JAX
    """
    pass
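The JAX version is essentially the NumPy sketch with jnp and @jit, so that JAX can trace and differentiate it. A possible version, assuming the same conventions as the earlier sketches:

@jit
def portfolio_stats_jax_example(weights, corr, returns, risk_free_rate=0.02):
    # Same computation as the NumPy version, but traceable and differentiable by JAX
    port_return = jnp.dot(weights, returns)
    port_vol = jnp.sqrt(weights @ corr @ weights)
    sharpe = (port_return - risk_free_rate) / port_vol
    return port_return, port_vol, sharpe


@jit
def negative_sharpe_jax_example(weights, corr, returns, risk_free_rate=0.02):
    # Loss to minimize with gradient descent
    return -portfolio_stats_jax_example(weights, corr, returns, risk_free_rate)[2]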
Exercise 5: Weight Projection
Implement the simplex projection function:
@jit
def projection_simplex(x: jnp.ndarray) -> jnp.ndarray:
    """TODO: Project weights onto probability simplex

    Hint: Ensure non-negative weights that sum to 1
    """
    pass
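A simple clip-and-renormalize already satisfies the constraints, but the exact Euclidean projection onto the simplex can also be written with jnp operations. The sketch below implements the classic sort-based algorithm; treat it as one possibility, not the required answer:

@jit
def projection_simplex_example(x):
    # Euclidean projection onto {w : w >= 0, sum(w) = 1} (sort-based algorithm)
    n = x.shape[0]
    u = jnp.sort(x)[::-1]                       # sort in decreasing order
    css = jnp.cumsum(u)
    k = jnp.arange(1, n + 1)
    cond = u - (css - 1.0) / k > 0
    rho = jnp.max(jnp.where(cond, k, 0))        # largest index satisfying the condition
    theta = (css[rho - 1] - 1.0) / rho          # shift that makes the projection sum to 1
    return jnp.maximum(x - theta, 0.0)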
Exercise 6: Early Stopping Logic
Implement early stopping and learning rate scheduling:
@dataclass
class EarlyStoppingState:
    """State for early stopping and learning rate scheduling."""
    best_sharpe: float = float('-inf')
    patience_counter: int = 0
    best_weights: jnp.ndarray = None
    lr: float = 0.1
    plateau_counter: int = 0


def update_early_stopping_state(state: EarlyStoppingState,
                                current_sharpe: float,
                                current_weights: jnp.ndarray,
                                patience: int = 10,
                                min_improvement: float = 1e-4,
                                lr_reduction_factor: float = 0.5,
                                lr_patience: int = 5,
                                min_lr: float = 1e-6) -> EarlyStoppingState:
    """TODO: Implement early stopping and learning rate scheduling logic

    Hint: Update best results if improved
    Hint: Implement patience mechanism
    Hint: Handle learning rate reduction
    """
    pass
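One possible reading of the hints is sketched below: the function returns a fresh EarlyStoppingState, while the actual stopping decision (patience_counter >= patience) is left to the training loop. This is only an illustration of the bookkeeping, not the reference correction:

def update_early_stopping_state_example(state, current_sharpe, current_weights,
                                        patience=10, min_improvement=1e-4,
                                        lr_reduction_factor=0.5, lr_patience=5,
                                        min_lr=1e-6):
    if current_sharpe > state.best_sharpe + min_improvement:
        # Improvement: store the best result and reset both counters
        return EarlyStoppingState(best_sharpe=current_sharpe,
                                  patience_counter=0,
                                  best_weights=current_weights,
                                  lr=state.lr,
                                  plateau_counter=0)
    # No improvement: increase counters and possibly reduce the learning rate
    new_lr = state.lr
    plateau = state.plateau_counter + 1
    if plateau >= lr_patience and state.lr > min_lr:
        new_lr = max(state.lr * lr_reduction_factor, min_lr)
        plateau = 0
    # The patience threshold itself is checked in the training loop, not here
    return EarlyStoppingState(best_sharpe=state.best_sharpe,
                              patience_counter=state.patience_counter + 1,
                              best_weights=state.best_weights,
                              lr=new_lr,
                              plateau_counter=plateau)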
Exercise 7: JAX Optimization
Complete the JAX optimization function:
def optimize_jax(corr: np.ndarray,
                 returns: np.ndarray,
                 n_iterations: int = 1000,
                 initial_lr: float = 0.01,
                 patience: int = 10,
                 min_improvement: float = 1e-4,
                 lr_patience: int = 5,
                 lr_reduction_factor: float = 0.5,
                 min_lr: float = 1e-6,
                 b1: float = 0.9,
                 b2: float = 0.999,
                 eps: float = 1e-8) -> Tuple[np.ndarray, float, float]:
    """TODO: Implement portfolio optimization using JAX and Adam

    Hint: Use Adam optimizer from JAX
    Hint: Implement training loop with early stopping
    """
    pass
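A rough sketch of the training loop is shown below, assuming the *_example helpers from the previous sketches. It optimizes unconstrained parameters with Adam and maps them onto the simplex through the projection; note that the learning rate tracked in the early-stopping state is not fed back into Adam here, so wiring the schedule into the optimizer (for example by rebuilding the Adam triple when the rate changes) is left as part of the exercise:

def optimize_jax_example(corr, returns, n_iterations=1000, initial_lr=0.01,
                         patience=10, min_improvement=1e-4, lr_patience=5,
                         lr_reduction_factor=0.5, min_lr=1e-6,
                         b1=0.9, b2=0.999, eps=1e-8):
    corr_j, returns_j = jnp.asarray(corr), jnp.asarray(returns)
    n_assets = len(returns)

    # Adam from jax.example_libraries.optimizers: returns (init, update, get_params)
    opt_init, opt_update, get_params = optimizers.adam(initial_lr, b1=b1, b2=b2, eps=eps)
    opt_state = opt_init(jnp.ones(n_assets) / n_assets)

    def loss_fn(raw_w):
        # Optimize unconstrained parameters; the projection keeps the portfolio feasible
        return negative_sharpe_jax_example(projection_simplex_example(raw_w), corr_j, returns_j)

    loss_grad = jit(grad(loss_fn))
    state = EarlyStoppingState(lr=initial_lr)

    start = time.time()
    for i in range(n_iterations):
        g = loss_grad(get_params(opt_state))
        opt_state = opt_update(i, g, opt_state)

        # Evaluate the feasible (projected) portfolio for monitoring and early stopping
        weights = projection_simplex_example(get_params(opt_state))
        _, _, sharpe = portfolio_stats_jax_example(weights, corr_j, returns_j)
        state = update_early_stopping_state_example(
            state, float(sharpe), weights,
            patience=patience, min_improvement=min_improvement,
            lr_reduction_factor=lr_reduction_factor,
            lr_patience=lr_patience, min_lr=min_lr)
        if state.patience_counter >= patience:
            break  # early stopping: no improvement for `patience` iterations
    elapsed = time.time() - start

    return np.asarray(state.best_weights), state.best_sharpe, elapsed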
Part 4: Comparison and Visualization
Exercise 8: Comparison Function
Implement the comparison between optimizers:
def compare_optimizers() -> pd.DataFrame:
    """TODO: Compare JAX and SciPy optimizers

    Hint: Test with different numbers of assets
    Hint: Compare computation time and Sharpe ratios
    """
    pass
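A minimal comparison harness could look like this; the asset counts and the column names are arbitrary choices made for the sketch, and it assumes the *_example functions defined above:

def compare_optimizers_example(asset_counts=(5, 10, 25, 50, 100)):
    rows = []
    for n_assets in asset_counts:
        corr, returns = generate_data_example(n_assets)

        # Run both optimizers on the same data
        w_scipy, sharpe_scipy, t_scipy = optimize_scipy_example(corr, returns)
        w_jax, sharpe_jax, t_jax = optimize_jax_example(corr, returns)

        rows.append({'n_assets': n_assets,
                     'scipy_time': t_scipy, 'scipy_sharpe': sharpe_scipy,
                     'jax_time': t_jax, 'jax_sharpe': sharpe_jax})
    return pd.DataFrame(rows)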
Exercise 9: Visualization
Create visualizations to compare the results:
def plot_comparison(results_df: pd.DataFrame):
    """TODO: Create plots comparing computation time and Sharpe ratios

    Hint: Use loglog plot for computation time
    Hint: Use semilogx for Sharpe ratios
    """
    pass
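A possible plotting layout, assuming the column names used in the comparison sketch above:

def plot_comparison_example(results_df):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Computation time on a log-log scale
    ax1.loglog(results_df['n_assets'], results_df['scipy_time'], 'o-', label='SciPy (SLSQP)')
    ax1.loglog(results_df['n_assets'], results_df['jax_time'], 's-', label='JAX (Adam)')
    ax1.set_xlabel('Number of assets')
    ax1.set_ylabel('Computation time (s)')
    ax1.legend()

    # Sharpe ratios on a semilog-x scale
    ax2.semilogx(results_df['n_assets'], results_df['scipy_sharpe'], 'o-', label='SciPy (SLSQP)')
    ax2.semilogx(results_df['n_assets'], results_df['jax_sharpe'], 's-', label='JAX (Adam)')
    ax2.set_xlabel('Number of assets')
    ax2.set_ylabel('Sharpe ratio')
    ax2.legend()

    plt.tight_layout()
    plt.show()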
Questions for Analysis:
How does the computation time scale with the number of assets for each method?
Compare the Sharpe ratios achieved by both methods. Are they similar?
What are the advantages and disadvantages of each approach?
How does the choice of optimizer parameters (learning rate, patience, etc.) affect the results?
Can you think of other applications where this approach could be useful?