Source code for ergodic_insurance.hjb_solver

"""Hamilton-Jacobi-Bellman solver for optimal insurance control.

This module implements a Hamilton-Jacobi-Bellman (HJB) partial differential equation
solver for finding optimal insurance strategies through dynamic programming. The solver
handles multi-dimensional state spaces and provides theoretically optimal control policies.

The HJB equation provides globally optimal solutions by solving:

    ∂V/∂t + max_u[L^u V + f(x,u)] = 0

where V is the value function, L^u is the controlled infinitesimal generator,
and f(x,u) is the running cost/reward.

Author: Alex Filiakov
Date: 2025-01-26
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from itertools import product as itertools_product
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
from scipy import interpolate, sparse

# Module-level logger; messages are emitted under this module's dotted name.
logger = logging.getLogger(__name__)

# Module-level named constants for numerical tolerances
# Drift magnitudes at or below this value are treated as "no drift"
# (see HJBSolver._apply_upwind_drift, which skips advection in that case).
_DRIFT_THRESHOLD = 1e-10
# Lower clamp applied to marginal utility before it is inverted, so the
# inverse-derivative helpers never divide by (or exponentiate) zero.
_MARGINAL_UTILITY_FLOOR = 1e-10
# PowerUtility treats |gamma - 1| below this tolerance as the log-utility case.
_GAMMA_TOLERANCE = 1e-10

# Policy improvement strategy thresholds
# NOTE(review): the code that consumes these four values is outside this
# chunk; the trailing comments are the original author's descriptions.
_VECTORIZE_COMBO_THRESHOLD = 5000  # Below: full vectorized batch
_COARSE_STRIDE = 3  # Adaptive: every 3rd point
_REFINE_RADIUS = 2  # Adaptive: +/-2 points around optimum
_DEFAULT_MEMORY_BUDGET_MB = 256  # Max memory for batched evaluation


class NumericalDivergenceError(RuntimeError):
    """Signal that the HJB value function has become non-finite.

    Raised by the solver when NaN or Inf entries are detected in the value
    function. Being a ``RuntimeError`` subclass, it is caught by handlers
    written for ``RuntimeError``.
    """
class TimeSteppingScheme(Enum):
    """Time stepping schemes for PDE integration.

    The member values are plain strings, so a scheme can be looked up from
    its string form with ``TimeSteppingScheme("implicit")``.
    """

    # Forward-in-time update using only the current value function.
    EXPLICIT = "explicit"
    # Fully implicit update (requires a linear solve per step).
    IMPLICIT = "implicit"
    # Crank-Nicolson: the theta = 0.5 blend of explicit and implicit
    # (see HJBSolver._theta_step_1d).
    CRANK_NICOLSON = "crank_nicolson"
class BoundaryCondition(Enum):
    """Types of boundary conditions.

    The boundary enforcement visible in this file
    (``HJBSolver._apply_boundary_conditions``) handles ``NEUMANN`` and
    ``REFLECTING`` identically: the boundary node copies its interior
    neighbour, i.e. a zero first derivative.
    """

    DIRICHLET = "dirichlet"  # Fixed value
    NEUMANN = "neumann"  # Fixed derivative
    ABSORBING = "absorbing"  # Zero second derivative
    REFLECTING = "reflecting"  # Zero first derivative
@dataclass
class StateVariable:
    """Definition of a state variable in the HJB problem.

    Attributes:
        name: Label used in validation error messages.
        min_value: Lower end of the grid (must be strictly below ``max_value``).
        max_value: Upper end of the grid.
        num_points: Number of grid nodes; at least 3 are required.
        boundary_lower: Boundary condition applied at ``min_value``.
        boundary_upper: Boundary condition applied at ``max_value``.
        log_scale: If True the grid is geometrically spaced, which requires
            a strictly positive ``min_value``.
    """

    name: str
    min_value: float
    max_value: float
    num_points: int
    boundary_lower: BoundaryCondition = BoundaryCondition.ABSORBING
    boundary_upper: BoundaryCondition = BoundaryCondition.ABSORBING
    log_scale: bool = False

    def __post_init__(self):
        """Validate state variable configuration.

        Raises:
            ValueError: If the bounds are not strictly increasing, fewer than
                three grid points are requested, or a log-scale grid is
                requested with a non-positive lower bound.
        """
        if self.min_value >= self.max_value:
            raise ValueError(f"min_value must be less than max_value for {self.name}")
        if self.num_points < 3:
            raise ValueError(f"Need at least 3 grid points for {self.name}")
        if self.log_scale and self.min_value <= 0:
            raise ValueError(f"Cannot use log scale with non-positive min_value for {self.name}")

    def get_grid(self) -> np.ndarray:
        """Generate grid points for this variable.

        Returns:
            Array of ``num_points`` grid points running from ``min_value`` to
            ``max_value``; geometrically spaced when ``log_scale`` is set,
            linearly spaced otherwise.
        """
        if self.log_scale:
            # geomspace pins the first/last node to min_value/max_value,
            # whereas logspace(log10(a), log10(b)) reconstructs them through
            # a log10/power round trip and can miss the bounds by round-off.
            # This also matches ControlVariable.get_values.
            return np.geomspace(self.min_value, self.max_value, self.num_points)
        return np.linspace(self.min_value, self.max_value, self.num_points)
@dataclass
class ControlVariable:
    """Description of one control dimension of the HJB problem.

    Attributes:
        name: Label used in validation error messages.
        min_value: Smallest admissible control value.
        max_value: Largest admissible control value.
        num_points: How many candidate values ``get_values`` produces.
        continuous: Flag stored on the variable; not consulted by this class.
        log_scale: Space the candidate values geometrically instead of
            linearly (requires a strictly positive ``min_value``).
    """

    name: str
    min_value: float
    max_value: float
    num_points: int = 50
    continuous: bool = True
    log_scale: bool = False

    def __post_init__(self):
        """Reject inconsistent bounds, too few points, or an invalid log grid.

        Raises:
            ValueError: On the first violated constraint, checked in the
                order bounds, point count, log-scale positivity.
        """
        label = self.name
        lower = self.min_value

        if lower >= self.max_value:
            raise ValueError(f"min_value must be less than max_value for {label}")

        if self.num_points < 2:
            raise ValueError(f"Need at least 2 control points for {label}")

        if self.log_scale and lower <= 0:
            raise ValueError(f"Cannot use log scale with non-positive min_value for {label}")

    def get_values(self) -> np.ndarray:
        """Return the discrete candidate control values.

        Returns:
            ``num_points`` values from ``min_value`` to ``max_value``,
            geometrically spaced when ``log_scale`` is set, else linear.
        """
        spacing = np.geomspace if self.log_scale else np.linspace
        return spacing(self.min_value, self.max_value, self.num_points)
@dataclass
class StateSpace:
    """Multi-dimensional state space for HJB problem.

    Handles arbitrary dimensionality with proper grid management and
    boundary condition enforcement.

    Attributes set in ``__post_init__``:
        ndim: Number of state variables.
        shape: Grid shape, one entry per state variable.
        size: Total number of grid nodes (plain ``int``).
        grids: One 1-D coordinate array per state variable.
        meshgrid: ``np.meshgrid(*grids, indexing="ij")``.
        flat_grids: The meshgrid arrays flattened with ``ravel``.
    """

    state_variables: List[StateVariable]

    def __post_init__(self):
        """Initialize derived attributes."""
        self.ndim = len(self.state_variables)
        self.shape = tuple(sv.num_points for sv in self.state_variables)
        # np.prod yields a NumPy scalar; expose an ordinary Python int.
        self.size = int(np.prod(self.shape))

        # Create grids for each dimension
        self.grids = [sv.get_grid() for sv in self.state_variables]

        # Create meshgrid for full state space
        self.meshgrid = np.meshgrid(*self.grids, indexing="ij")

        # Flatten for linear algebra operations
        self.flat_grids = [mg.ravel() for mg in self.meshgrid]

        logger.info(f"Initialized {self.ndim}D state space with shape {self.shape}")

    def get_boundary_mask(self) -> np.ndarray:
        """Get boolean mask for boundary points.

        Returns:
            Boolean array of shape ``self.shape`` that is True wherever the
            first or last index is taken along at least one dimension.
        """
        mask = np.zeros(self.shape, dtype=bool)
        for dim in range(self.ndim):
            for edge in (0, -1):
                # Select the whole hyper-face at this edge of dimension `dim`.
                index = [slice(None)] * self.ndim
                index[dim] = edge
                mask[tuple(index)] = True
        return mask

    def interpolate_value(self, value_function: np.ndarray, points: np.ndarray) -> np.ndarray:
        """Interpolate value function at arbitrary points.

        Args:
            value_function: Value function on grid
            points: Points to interpolate at (shape: [n_points, n_dims])

        Returns:
            Interpolated values. Points outside the grid are extrapolated
            rather than rejected.
        """
        if self.ndim == 1:
            # interp1d needs at least order + 1 samples. StateVariable allows
            # 3-point grids, which cannot support a cubic spline, so fall
            # back to a quadratic one in that case.
            kind = "cubic" if len(self.grids[0]) >= 4 else "quadratic"
            interp = interpolate.interp1d(
                self.grids[0],
                value_function,
                kind=kind,
                bounds_error=False,
                fill_value="extrapolate",
            )
            return np.array(interp(points[:, 0]))

        # Two or more dimensions: multilinear interpolation on the regular
        # grid (SciPy extrapolates when fill_value is None).
        interp = interpolate.RegularGridInterpolator(
            self.grids, value_function, method="linear", bounds_error=False, fill_value=None
        )
        return np.array(interp(points))
class UtilityFunction(ABC):
    """Interface every utility function used by the HJB solver must satisfy.

    A concrete utility supplies three things: the utility itself, its first
    derivative (marginal utility), and the inverse of that derivative.
    The class cannot be instantiated until all three are implemented.
    """

    @abstractmethod
    def evaluate(self, wealth: np.ndarray) -> np.ndarray:
        """Return the utility of each wealth level.

        Args:
            wealth: Wealth values

        Returns:
            Utility values
        """

    @abstractmethod
    def derivative(self, wealth: np.ndarray) -> np.ndarray:
        """Return the marginal utility (first derivative) at each wealth level.

        Args:
            wealth: Wealth values

        Returns:
            Marginal utility values
        """

    @abstractmethod
    def inverse_derivative(self, marginal_utility: np.ndarray) -> np.ndarray:
        """Map marginal utilities back to the wealth levels producing them.

        Used for finding optimal controls in some formulations.

        Args:
            marginal_utility: Marginal utility values

        Returns:
            Wealth values corresponding to given marginal utilities
        """
class LogUtility(UtilityFunction):
    """Logarithmic utility, ``U(w) = log(w)``, for ergodic optimization.

    Maximizing expected log wealth maximizes the long-term growth rate,
    which is why this utility is used for ergodic analysis.
    """

    def __init__(self, wealth_floor: float = 1e-6):
        """Create a log utility.

        Args:
            wealth_floor: Minimum wealth to prevent log(0)
        """
        self.wealth_floor = wealth_floor

    def _floored(self, wealth: np.ndarray) -> np.ndarray:
        """Clamp wealth from below at ``wealth_floor``."""
        return np.maximum(wealth, self.wealth_floor)

    def evaluate(self, wealth: np.ndarray) -> np.ndarray:
        """Return ``log(w)`` with wealth clamped at the floor."""
        return np.array(np.log(self._floored(wealth)))

    def derivative(self, wealth: np.ndarray) -> np.ndarray:
        """Return the marginal utility ``U'(w) = 1/w`` (wealth clamped at the floor)."""
        return np.array(1.0 / self._floored(wealth))

    def inverse_derivative(self, marginal_utility: np.ndarray) -> np.ndarray:
        """Return ``(U')^(-1)(m) = 1/m`` with ``m`` clamped at the module floor."""
        clamped = np.maximum(marginal_utility, _MARGINAL_UTILITY_FLOOR)
        return np.array(1.0 / clamped)
class PowerUtility(UtilityFunction):
    """Power (CRRA) utility with relative risk aversion γ.

        U(w) = w^(1-γ)/(1-γ)   for γ ≠ 1
        U(w) = log(w)          for γ = 1

    When γ is within the module's tolerance of 1, every method delegates to
    an internal ``LogUtility`` built with the same wealth floor.
    """

    def __init__(self, risk_aversion: float = 2.0, wealth_floor: float = 1e-6):
        """Create a power utility.

        Args:
            risk_aversion: Coefficient of relative risk aversion (γ)
            wealth_floor: Minimum wealth to prevent numerical issues
        """
        self.gamma = risk_aversion
        self.wealth_floor = wealth_floor

        # The log delegate only exists in the γ ≈ 1 case.
        if self._is_log_case:
            self._log_utility = LogUtility(wealth_floor)

    @property
    def _is_log_case(self) -> bool:
        """True when γ is numerically indistinguishable from 1."""
        return abs(self.gamma - 1.0) < _GAMMA_TOLERANCE

    def evaluate(self, wealth: np.ndarray) -> np.ndarray:
        """Return the CRRA utility of each wealth level."""
        if self._is_log_case:
            return self._log_utility.evaluate(wealth)

        exponent = 1 - self.gamma
        floored = np.maximum(wealth, self.wealth_floor)
        return np.array(np.power(floored, exponent) / exponent)

    def derivative(self, wealth: np.ndarray) -> np.ndarray:
        """Return the marginal utility ``U'(w) = w^(-γ)``."""
        if self._is_log_case:
            return self._log_utility.derivative(wealth)

        floored = np.maximum(wealth, self.wealth_floor)
        return np.array(np.power(floored, -self.gamma))

    def inverse_derivative(self, marginal_utility: np.ndarray) -> np.ndarray:
        """Return ``(U')^(-1)(m) = m^(-1/γ)`` with ``m`` clamped at the module floor."""
        if self._is_log_case:
            return self._log_utility.inverse_derivative(marginal_utility)

        clamped = np.maximum(marginal_utility, _MARGINAL_UTILITY_FLOOR)
        return np.array(np.power(clamped, -1.0 / self.gamma))
class ExpectedWealth(UtilityFunction):
    """Linear utility function for risk-neutral wealth maximization.

    U(w) = w

    This represents risk-neutral preferences where the goal is to maximize
    expected wealth.
    """

    def evaluate(self, wealth: np.ndarray) -> np.ndarray:
        """Evaluate linear utility.

        Args:
            wealth: Wealth values

        Returns:
            A new array holding the same values as ``wealth``.
        """
        # Return a fresh ndarray, like the other utilities in this module.
        # Handing back the caller's own object would let in-place edits of
        # the result corrupt the input, and would pass lists/scalars through
        # without the promised ndarray type.
        return np.array(wealth)

    def derivative(self, wealth: np.ndarray) -> np.ndarray:
        """Compute marginal utility: U'(w) = 1.

        Args:
            wealth: Wealth values (only the shape and dtype are used)

        Returns:
            Array of ones shaped like ``wealth``.
        """
        return np.ones_like(wealth)

    def inverse_derivative(self, marginal_utility: np.ndarray) -> np.ndarray:
        """Inverse is undefined for constant marginal utility.

        Raises:
            NotImplementedError: Always, because U'(w) = 1 has no inverse.
        """
        raise NotImplementedError("Inverse derivative undefined for linear utility")
@dataclass
class HJBProblem:
    """Everything needed to pose one HJB optimal control problem.

    Attributes:
        state_space: Discretized state space.
        control_variables: Controls to optimize over.
        utility_function: Utility applied by the problem.
        dynamics: Callable taking three arguments (annotated as two arrays
            and a float) and returning an array.
        running_cost: Callable with the same signature as ``dynamics``.
        terminal_value: Optional terminal condition; defaulted to zero for
            finite-horizon problems when omitted.
        discount_rate: Non-negative discount rate.
        time_horizon: Positive horizon, or None for an infinite horizon.
        diffusion: Optional callback, see the comment on the field.
    """

    state_space: StateSpace
    control_variables: List[ControlVariable]
    utility_function: UtilityFunction
    dynamics: Callable[[np.ndarray, np.ndarray, float], np.ndarray]
    running_cost: Callable[[np.ndarray, np.ndarray, float], np.ndarray]
    terminal_value: Optional[Callable[[np.ndarray], np.ndarray]] = None
    discount_rate: float = 0.0
    time_horizon: Optional[float] = None
    # Optional callback returning σ²(x,u,t) with same shape as dynamics output.
    # When provided, the solver includes the ½σ²·∇²V diffusion term.
    diffusion: Optional[Callable[[np.ndarray, np.ndarray, float], np.ndarray]] = None

    def __post_init__(self):
        """Check the horizon and discount rate, and default the terminal value.

        Raises:
            ValueError: If ``time_horizon`` is given but not positive, or
                ``discount_rate`` is negative.
        """
        horizon = self.time_horizon
        is_finite_horizon = horizon is not None

        if is_finite_horizon and horizon <= 0:
            raise ValueError("Time horizon must be positive")

        if self.discount_rate < 0:
            raise ValueError("Discount rate must be non-negative")

        # A finite horizon needs a terminal condition; fall back to zero.
        if is_finite_horizon and self.terminal_value is None:

            def _zero_terminal(x):
                return np.zeros_like(x[..., 0])

            self.terminal_value = _zero_terminal
@dataclass
class HJBSolverConfig:
    """Configuration for HJB solver.

    Attributes:
        time_step: Time step size.
        max_iterations: Upper bound on outer policy iterations.
        tolerance: Convergence threshold on the max-norm value change.
        scheme: Time stepping scheme.
        use_sparse: Sparse-operator flag (consumer not visible in this chunk).
        verbose: When True, ``solve`` logs progress every 10 iterations.
        inner_max_iterations: Iteration cap for the inner loop.
        inner_tolerance_factor: Inner tolerance as a fraction of ``tolerance``.
        rannacher_steps: Number of implicit half-step pairs for CN startup.
        control_search_strategy: One of "auto", "vectorized", "adaptive",
            "loop", "gradient".
        control_memory_budget_mb: Max memory for batched control evaluation.
    """

    time_step: float = 0.01
    max_iterations: int = 1000
    tolerance: float = 1e-6
    scheme: TimeSteppingScheme = TimeSteppingScheme.EXPLICIT
    use_sparse: bool = True
    verbose: bool = True
    inner_max_iterations: int = 100
    inner_tolerance_factor: float = 0.1  # inner_tol = tolerance * this
    rannacher_steps: int = 2  # Number of implicit half-step pairs for CN startup
    control_search_strategy: str = "auto"  # "auto", "vectorized", "adaptive", "loop", "gradient"
    # Default comes from the module constant so the two cannot drift apart.
    control_memory_budget_mb: int = _DEFAULT_MEMORY_BUDGET_MB
class HJBSolver:
    """Hamilton-Jacobi-Bellman PDE solver for optimal control.

    Implements finite difference methods with upwind schemes for solving
    HJB equations. Supports multi-dimensional state spaces and various
    boundary conditions.
    """

    def __init__(self, problem: HJBProblem, config: HJBSolverConfig):
        """Initialize HJB solver.

        Args:
            problem: HJB problem specification
            config: Solver configuration
        """
        self.problem = problem
        self.config = config

        # Initialize value function and policy
        self.value_function: np.ndarray | None = None
        self.optimal_policy: dict[str, np.ndarray] | None = None

        # Boundary values for Dirichlet BCs (captured from initial/terminal condition)
        self._boundary_values: dict[int, dict[str, np.ndarray]] | None = None

        # Cache for factorized sparse operators: (theta, dt, dim) -> (solve_func, B_or_None)
        self._operator_cache: Dict[
            Tuple[float, float, int], Tuple[Any, Optional[sparse.spmatrix]]
        ] = {}

        # Set up finite difference operators
        self._setup_operators()

        logger.info(f"Initialized HJB solver for {problem.state_space.ndim}D problem")

    def _setup_operators(self):
        """Set up finite difference operators for the PDE.

        Stores ``self.dx``: one array of per-interval grid spacings for each
        state variable.
        """
        # Store per-interval grid spacings as arrays (supports non-uniform grids)
        self.dx = []
        for sv in self.problem.state_space.state_variables:
            grid = sv.get_grid()
            if len(grid) > 1:
                self.dx.append(np.diff(grid))
            else:
                self.dx.append(np.array([1.0]))

        # Will construct operators during solve based on current policy
        self.operators_initialized = True

    def _compute_cfl_number(
        self,
        drift: np.ndarray,
        sigma_sq: np.ndarray | None,
        dt: float,
    ) -> tuple[float, float]:
        """Compute CFL numbers for advection and diffusion stability.

        Args:
            drift: Drift values on the state grid, shape state_shape + (ndim,)
            sigma_sq: Diffusion coefficients on state grid,
                shape state_shape + (ndim,), or None if no diffusion.
            dt: Time step size.

        Returns:
            Tuple of (advection_cfl, diffusion_cfl), each the maximum over
            dimensions.
        """
        advection_cfl = 0.0
        diffusion_cfl = 0.0
        n_dims = min(drift.shape[-1], len(self.problem.state_space.state_variables))

        for dim in range(n_dims):
            # Stability is governed by the smallest grid spacing, which
            # matters for non-uniform grids.
            dx_min = float(np.min(self.dx[dim]))
            if not dx_min > 0:
                continue

            drift_max = float(np.max(np.abs(drift[..., dim])))
            advection_cfl = max(advection_cfl, drift_max * dt / dx_min)

            if sigma_sq is not None:
                # Effective diffusion coefficient is D = 0.5 * sigma_sq
                sigma_sq_max = float(np.max(np.abs(sigma_sq[..., dim])))
                diffusion_cfl = max(diffusion_cfl, 0.5 * sigma_sq_max * dt / (dx_min**2))

        return advection_cfl, diffusion_cfl

    def _build_spatial_operator_1d(
        self,
        drift_1d: np.ndarray,
        sigma_sq_1d: np.ndarray | None,
        dim: int = 0,
    ) -> sparse.spmatrix:
        """Build the 1D spatial operator L as a sparse tridiagonal matrix.

        The operator represents:
            L(V)[i] = -rho*V[i] + drift*dV/dx + 0.5*sigma^2*d2V/dx2

        using upwind first derivatives and non-uniform central second
        derivatives, consistent with the explicit scheme.

        The non-uniform second-derivative stencil is:
            d²V/dx²|_i = 2/(h_b+h_f) * [V[i+1]/h_f - V[i]*(1/h_b+1/h_f) + V[i-1]/h_b]

        where h_b = x[i]-x[i-1] and h_f = x[i+1]-x[i]. This is exact for
        quadratic functions and reduces to (V[i+1]-2V[i]+V[i-1])/h² on
        uniform grids.

        Args:
            drift_1d: Drift values at each grid point, shape (N,).
            sigma_sq_1d: Diffusion coefficient sigma^2 at each grid point,
                shape (N,), or None for pure advection.
            dim: Dimension index (for grid spacing lookup).

        Returns:
            Sparse CSR matrix of shape (N, N). The first and last rows are
            all zero; boundaries are handled by the caller.
        """
        N = len(drift_1d)
        dx_arr = self.dx[dim]
        rho = self.problem.discount_rate

        # Sub-diagonal (V[i-1] coefficients), super-diagonal (V[i+1] coefficients),
        # main diagonal (V[i] coefficients) — all for interior points
        sub = np.zeros(N)
        main = np.zeros(N)
        sup = np.zeros(N)

        # Interior points: i = 1, ..., N-2
        interior = slice(1, N - 1)
        drift_int = drift_1d[interior]
        drift_pos = np.maximum(drift_int, 0.0)
        drift_neg = np.minimum(drift_int, 0.0)

        # dx for backward diff at point i: dx_arr[i-1] (h_b)
        dx_back = dx_arr[0 : N - 2]
        # dx for forward diff at point i: dx_arr[i] (h_f)
        dx_fwd = dx_arr[1 : N - 1]

        # Diffusion coefficient at interior points: D = 0.5 * sigma^2
        if sigma_sq_1d is not None:
            D = 0.5 * sigma_sq_1d[interior]
        else:
            D = np.zeros(N - 2)

        # Non-uniform diffusion stencil coefficients:
        #   d²V/dx²|_i = c_lo*V[i-1] + c_mid*V[i] + c_hi*V[i+1]
        # where:
        #   c_lo  = 2 / (h_b * (h_b + h_f))
        #   c_mid = -2 / (h_b * h_f)
        #   c_hi  = 2 / (h_f * (h_b + h_f))
        h_sum = dx_back + dx_fwd
        diff_sub = D * 2.0 / (dx_back * h_sum)
        diff_main = -D * 2.0 / (dx_back * dx_fwd)
        diff_sup = D * 2.0 / (dx_fwd * h_sum)

        # Operator coefficients (consistent with _apply_upwind_scheme).
        #
        # The HJB PDE is V_t = drift*V_x + ..., which is V_t + (-drift)*V_x = 0.
        # The effective advection coefficient is a = -drift, so the upwind
        # direction is OPPOSITE to standard advection references that assume
        # V_t + a*V_x = 0 with a = drift. Concretely:
        #   drift > 0  =>  a < 0  =>  upwind = forward diff
        #   drift < 0  =>  a > 0  =>  upwind = backward diff
        sub[interior] = -drift_neg / dx_back + diff_sub
        main[interior] = -drift_pos / dx_fwd + drift_neg / dx_back + diff_main - rho
        sup[interior] = drift_pos / dx_fwd + diff_sup

        # Boundary rows are zero (will be handled after solve)
        L = sparse.diags(
            [sub[1:], main, sup[:-1]],
            offsets=[-1, 0, 1],
            shape=(N, N),
            format="csr",
        )
        return L

    def _invalidate_operator_cache(self):
        """Clear the cached sparse matrix factorizations.

        Must be called whenever drift or diffusion coefficients may change
        (i.e., at the start of each policy evaluation cycle).
        """
        self._operator_cache.clear()

    def _theta_step_1d(
        self,
        old_v: np.ndarray,
        cost: np.ndarray,
        drift_1d: np.ndarray,
        sigma_sq_1d: np.ndarray | None,
        dt: float,
        theta: float,
        dim: int = 0,
    ) -> np.ndarray:
        """Perform one theta-scheme time step for a 1D problem.

        Solves: (I - theta*dt*L)*V_new = (I + (1-theta)*dt*L)*V_old + dt*cost

        For theta=1: fully implicit (backward Euler).
        For theta=0.5: Crank-Nicolson.
        For theta=0: explicit (forward Euler).

        The factorized left-hand side (and the right-hand-side matrix) are
        cached under ``(theta, dt, dim)``. The key does not include the drift
        or diffusion values, so ``_invalidate_operator_cache`` must be called
        when those change.

        Args:
            old_v: Current value function, shape (N,).
            cost: Running cost at each grid point, shape (N,).
            drift_1d: Drift at each grid point, shape (N,).
            sigma_sq_1d: Diffusion coefficient sigma^2, shape (N,) or None.
            dt: Time step.
            theta: Implicitness parameter in [0, 1].
            dim: Dimension index.

        Returns:
            Updated value function, shape (N,).
        """
        N = len(old_v)
        cache_key = (theta, dt, dim)

        if cache_key in self._operator_cache:
            # Cache hit: reuse factorized solver and B matrix
            solve_func, B = self._operator_cache[cache_key]
        else:
            # Import the submodule explicitly so the factorization does not
            # depend on scipy.sparse having loaded `linalg` as a side effect.
            from scipy.sparse.linalg import splu

            # Cache miss: build, factorize, and store
            L = self._build_spatial_operator_1d(drift_1d, sigma_sq_1d, dim)
            I_mat = sparse.eye(N, format="csr")

            # LHS: A = I - theta*dt*L
            A = I_mat - theta * dt * L

            # Set boundary rows to identity (boundary values handled by
            # _apply_boundary_conditions after return)
            A = A.tolil()
            A[0, :] = 0
            A[0, 0] = 1.0
            A[N - 1, :] = 0
            A[N - 1, N - 1] = 1.0
            A = A.tocsc()

            # Factorize once via SuperLU
            solve_func = splu(A).solve

            # Build B matrix for Crank-Nicolson (theta < 1)
            if theta < 1.0:
                B = I_mat + (1.0 - theta) * dt * L
            else:
                B = None

            self._operator_cache[cache_key] = (solve_func, B)

        # RHS: B @ old_v + dt * cost
        if B is not None:
            rhs = B @ old_v + dt * cost
        else:
            rhs = old_v + dt * cost

        # Preserve old boundary values in RHS (will be overwritten by BCs)
        rhs[0] = old_v[0]
        rhs[N - 1] = old_v[N - 1]

        new_v: np.ndarray = solve_func(rhs)
        return new_v

    def _build_difference_matrix(
        self, dim: int, boundary_type: BoundaryCondition
    ) -> sparse.spmatrix:
        """Build the second-difference matrix for one dimension.

        Interior rows hold the uniform-grid stencil ``[1, -2, 1] / dx²``,
        where ``dx`` is the mean spacing of the dimension's grid. The first
        and last rows encode the boundary condition:

        - DIRICHLET: identity rows (fixed values are handled separately).
        - NEUMANN / REFLECTING: the ghost node's coefficient is folded into
          the main diagonal, giving ``[-1, 1] / dx²`` and ``[1, -1] / dx²``.
        - ABSORBING: the rows apply ``[1, -2, 1] / dx²`` to the three nodes
          nearest the boundary (zero second derivative).

        Args:
            dim: Dimension index
            boundary_type: Type of boundary condition

        Returns:
            Sparse difference operator of shape (n, n)
        """
        n = self.problem.state_space.state_variables[dim].num_points
        dx = float(np.mean(self.dx[dim]))
        coeff = 1.0 / (dx * dx)

        # Layout follows scipy.sparse.diags: element j of the lower diagonal
        # is matrix entry (j + 1, j) and element j of the upper diagonal is
        # entry (j, j + 1). Only the first n - 1 elements of each
        # off-diagonal are used.
        lower = np.full(n, coeff)
        main = np.full(n, -2.0 * coeff)
        upper = np.full(n, coeff)

        if boundary_type == BoundaryCondition.DIRICHLET:
            # Identity rows at both ends. Entry (0, 1) lives in upper[0] and
            # entry (n-1, n-2) in lower[n-2]; row 1's sub-diagonal
            # (lower[0]) must stay untouched.
            main[0] = 1
            upper[0] = 0
            main[-1] = 1
            lower[n - 2] = 0
        elif boundary_type in (BoundaryCondition.NEUMANN, BoundaryCondition.REFLECTING):
            # Zero first derivative at boundaries (ghost-node reflection):
            # the ghost coefficient is added to the main diagonal of the
            # boundary rows. No off-diagonal entry needs to be removed.
            main[0] += coeff
            main[-1] += coeff
        elif boundary_type == BoundaryCondition.ABSORBING:
            # Absorbing boundary: enforce d²V/dx² = 0 (linear extrapolation)
            # Lower boundary: V[0] - 2*V[1] + V[2] = 0
            # Upper boundary: V[n-3] - 2*V[n-2] + V[n-1] = 0
            # Row 0 needs an entry at column 2 and row n-1 an entry at
            # column n-3; those are added after the tridiagonal part.
            main[0] = coeff
            upper[0] = -2.0 * coeff
            lower[n - 2] = -2.0 * coeff
            main[n - 1] = coeff

        matrix = sparse.diags(
            [lower[: n - 1], main, upper[: n - 1]], offsets=[-1, 0, 1], shape=(n, n)
        )

        if boundary_type == BoundaryCondition.ABSORBING:
            # Add the off-tridiagonal entries for absorbing BCs
            matrix = matrix.tolil()
            matrix[0, 2] = coeff  # V[2] coefficient in lower boundary row
            matrix[n - 1, n - 3] = coeff  # V[n-3] coefficient in upper boundary row
            matrix = matrix.tocsr()

        return matrix

    def _apply_upwind_scheme(self, value: np.ndarray, drift: np.ndarray, dim: int) -> np.ndarray:
        """Apply upwind finite difference for advection term.

        Uses proper boundary-aware slicing (no wraparound) and supports
        non-uniform grid spacing for all dimensions.

        Args:
            value: Value function on state grid
            drift: Drift values at each grid point
            dim: Dimension for differentiation

        Returns:
            Advection term contribution (drift * dV/dx); zero where the
            drift is exactly zero.
        """
        dx_array = self.dx[dim]  # Array of per-interval spacings
        ndim = value.ndim
        n = value.shape[dim]

        result = np.zeros_like(value)

        # Build index slices for interior points
        hi = [slice(None)] * ndim
        lo = [slice(None)] * ndim
        hi[dim] = slice(1, None)  # indices 1..N-1
        lo[dim] = slice(None, -1)  # indices 0..N-2

        # Shape dx_array for broadcasting: (1, ..., N-1, ..., 1)
        dx_shape = [1] * ndim
        dx_shape[dim] = n - 1
        dx_bc = dx_array.reshape(dx_shape)

        # Forward difference: (V[i+1] - V[i]) / dx[i] for i=0..N-2
        # At i=N-1, the forward difference is 0 (boundary)
        fwd_diff = np.zeros_like(value)
        fwd_diff[tuple(lo)] = (value[tuple(hi)] - value[tuple(lo)]) / dx_bc

        # Backward difference: (V[i] - V[i-1]) / dx[i-1] for i=1..N-1
        # At i=0, the backward difference is 0 (boundary)
        bwd_diff = np.zeros_like(value)
        bwd_diff[tuple(hi)] = (value[tuple(hi)] - value[tuple(lo)]) / dx_bc

        # Upwind selection based on the HJB sign convention V_t = drift*V_x + ...
        # The effective advection coefficient is -drift, so:
        #   drift > 0  =>  forward diff  (upwind for negative effective coefficient)
        #   drift < 0  =>  backward diff (upwind for positive effective coefficient)
        mask_pos = drift > 0
        mask_neg = drift < 0

        result[mask_pos] = fwd_diff[mask_pos] * drift[mask_pos]
        result[mask_neg] = bwd_diff[mask_neg] * drift[mask_neg]

        return result

    def _compute_gradient(self) -> np.ndarray:
        """Compute numerical gradient of the value function.

        Uses np.gradient which handles non-uniform grids with second-order
        accurate central differences in the interior and first-order
        accurate one-sided differences at the boundaries.

        Returns:
            Gradient array with shape state_shape + (ndim,); all zeros when
            no value function has been computed yet.
        """
        if self.value_function is None:
            shape = self.problem.state_space.shape + (self.problem.state_space.ndim,)
            return np.zeros(shape)

        grids = self.problem.state_space.grids
        if self.problem.state_space.ndim == 1:
            # np.gradient returns a bare array (not a list) for 1-D input.
            grad_components = [np.gradient(self.value_function, grids[0])]
        else:
            grad_components = np.gradient(self.value_function, *grids)

        return np.stack(grad_components, axis=-1)

    def _compute_second_derivatives(self, value: np.ndarray) -> np.ndarray:
        """Compute second derivatives d²V/dx_i² for each dimension.

        Uses the non-uniform central difference formula for interior points:

            d²V/dx²|_i = 2/(h_b+h_f) * [(V[i+1]-V[i])/h_f - (V[i]-V[i-1])/h_b]

        where h_b = x[i]-x[i-1] and h_f = x[i+1]-x[i]. This is exact for
        quadratic functions on arbitrary grids and reduces to the standard
        (V[i+1]-2V[i]+V[i-1])/h² on uniform grids.

        Boundary values default to zero (no diffusion contribution at
        boundaries).

        Args:
            value: Value function on state grid

        Returns:
            Array with shape state_shape + (ndim,) containing d²V/dx_i²
        """
        ndim = self.problem.state_space.ndim
        components = []

        for dim in range(ndim):
            dx_array = self.dx[dim]
            d2v = np.zeros_like(value)

            hi = [slice(None)] * ndim
            mid = [slice(None)] * ndim
            lo = [slice(None)] * ndim
            hi[dim] = slice(2, None)
            mid[dim] = slice(1, -1)
            lo[dim] = slice(None, -2)

            # Per-interval spacings for interior points i = 1..N-2
            # h_f[i] = dx_array[i]   (forward: x[i+1] - x[i])
            # h_b[i] = dx_array[i-1] (backward: x[i] - x[i-1])
            n = value.shape[dim]
            h_b = dx_array[0 : n - 2]  # backward spacing for interior points
            h_f = dx_array[1 : n - 1]  # forward spacing for interior points

            # Broadcast spacings for multi-dimensional arrays
            bc_shape = [1] * ndim
            bc_shape[dim] = n - 2
            h_b_bc = h_b.reshape(bc_shape)
            h_f_bc = h_f.reshape(bc_shape)

            # Non-uniform central difference:
            # d²V/dx² = 2/(h_b+h_f) * [(V[i+1]-V[i])/h_f - (V[i]-V[i-1])/h_b]
            d2v[tuple(mid)] = (
                2.0
                / (h_b_bc + h_f_bc)
                * (
                    (value[tuple(hi)] - value[tuple(mid)]) / h_f_bc
                    - (value[tuple(mid)] - value[tuple(lo)]) / h_b_bc
                )
            )
            components.append(d2v)

        return np.stack(components, axis=-1)
def solve(self) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
    """Run policy iteration until the value function stops changing.

    Each outer iteration evaluates the current policy, checks the value
    function for NaN/Inf, improves the policy, and then compares the
    max-norm change of the value function against ``config.tolerance``.

    Returns:
        Tuple of (value_function, optimal_policy_dict)

    Raises:
        NumericalDivergenceError: If policy evaluation leaves NaN or Inf
            entries in the value function.
    """
    logger.info("Starting HJB solution with policy iteration")

    problem = self.problem
    space = problem.state_space
    grid_shape = space.shape

    # --- starting value function ---
    if problem.time_horizon is None:
        # Infinite horizon: initialize with zeros or heuristic
        self.value_function = np.zeros(grid_shape)
    else:
        # Finite horizon: start from terminal condition
        flat_states = np.stack(space.flat_grids, axis=-1)
        if problem.terminal_value is None:
            self.value_function = np.zeros(grid_shape)
        else:
            at_terminal = problem.terminal_value(flat_states)
            self.value_function = at_terminal.reshape(grid_shape)

    # --- remember the edge values of every dimension with a Dirichlet side ---
    self._boundary_values = {}
    n_axes = space.ndim
    for axis in range(n_axes):
        variable = space.state_variables[axis]
        sides = (variable.boundary_lower, variable.boundary_upper)
        if BoundaryCondition.DIRICHLET not in sides:
            continue
        first = tuple(0 if k == axis else slice(None) for k in range(n_axes))
        last = tuple(-1 if k == axis else slice(None) for k in range(n_axes))
        self._boundary_values[axis] = {
            "lower": self.value_function[first].copy(),
            "upper": self.value_function[last].copy(),
        }

    # --- initial policy: midpoint of each control range (kept if already set) ---
    if self.optimal_policy is None:
        self.optimal_policy = {
            cv.name: np.full(grid_shape, (cv.min_value + cv.max_value) / 2)
            for cv in problem.control_variables
        }

    # --- policy iteration ---
    for iteration in range(self.config.max_iterations):
        previous = self.value_function.copy()

        # Policy evaluation step
        self._policy_evaluation()

        # Check for NaN/Inf after policy evaluation (#453)
        if not np.isfinite(self.value_function).all():
            n_nan = int(np.isnan(self.value_function).sum())
            n_inf = int(np.isinf(self.value_function).sum())
            raise NumericalDivergenceError(
                f"HJB solver diverged at outer iteration {iteration}: "
                f"value function contains {n_nan} NaN and {n_inf} Inf values. "
                f"Consider reducing time_step (current: {self.config.time_step}) "
                f"or using an implicit scheme."
            )

        # Policy improvement step
        self._policy_improvement()

        # Convergence is judged on the max-norm change of the value function
        value_change = np.max(np.abs(self.value_function - previous))

        if self.config.verbose and iteration % 10 == 0:
            logger.info(f"Iteration {iteration}: value change = {value_change:.6e}")

        if value_change < self.config.tolerance:
            logger.info(f"Converged after {iteration + 1} iterations")
            break

        if iteration == self.config.max_iterations - 1:
            logger.warning("Max iterations reached without convergence")

    return self.value_function, self.optimal_policy
def _reshape_cost(self, cost):
    """Collapse a running-cost array onto the state grid.

    A cost with more than one dimension is reduced along axis 1 (the mean
    when there are several columns, the single column otherwise); the
    result is then reshaped to ``self.problem.state_space.shape``.

    Args:
        cost: Running cost as returned by ``problem.running_cost``.

    Returns:
        The cost reshaped to the state-grid shape.
    """
    if hasattr(cost, "ndim") and cost.ndim > 1:
        if cost.shape[1] > 1:
            cost = np.mean(cost, axis=1)  # Average across the extra columns
        else:
            cost = cost[:, 0]  # Single column: drop the trailing axis
    return cost.reshape(self.problem.state_space.shape)


def _apply_upwind_drift(self, new_v, drift, dt):
    """Add the upwind-differenced drift term to ``new_v``.

    The advection is computed from ``self.value_function`` (not from
    ``new_v``), one state dimension at a time.

    Args:
        new_v: Value function being assembled for the current step.
        drift: Drift field whose last axis indexes the state dimension.
        dt: Time step multiplying the advection term.

    Returns:
        ``new_v`` plus ``dt`` times the summed advection; ``new_v`` itself
        when the drift is negligible everywhere or no value function exists.
    """
    # Nothing to do when every drift entry is below the threshold
    if not np.any(np.abs(drift) > _DRIFT_THRESHOLD):
        return new_v

    if self.value_function is None:
        return new_v

    for dim in range(drift.shape[-1]):
        # Ignore drift columns that have no matching state variable
        if dim >= len(self.problem.state_space.state_variables):
            continue
        drift_component = drift[..., dim]
        advection = self._apply_upwind_scheme(self.value_function, drift_component, dim)
        new_v = new_v + dt * advection

    return new_v


def _apply_diffusion_term(
    self,
    value: np.ndarray,
    sigma_sq: np.ndarray,
) -> np.ndarray:
    """Compute diffusion contribution ½σ²∇²V.

    Args:
        value: Value function on state grid
        sigma_sq: Diffusion coefficients σ²(x,u) with shape state_shape + (ndim,)

    Returns:
        Diffusion term (scalar field, same shape as value)
    """
    d2v = self._compute_second_derivatives(value)
    diffusion = np.zeros_like(value)
    # Only dimensions present in both arrays (and in the state space) contribute
    n_dims = min(sigma_sq.shape[-1], d2v.shape[-1])
    for dim in range(n_dims):
        if dim >= len(self.problem.state_space.state_variables):
            continue
        diffusion += 0.5 * sigma_sq[..., dim] * d2v[..., dim]
    return diffusion


def _apply_boundary_conditions(self, value: np.ndarray) -> np.ndarray:
    """Enforce boundary conditions on the value function.

    Applies the prescribed boundary condition for each dimension:
    - ABSORBING: linear extrapolation so d²V/dx² = 0 at boundary
    - DIRICHLET: reset boundary to the values stored in
      ``self._boundary_values`` (left untouched if none were stored)
    - NEUMANN / REFLECTING: copy adjacent interior value so dV/dx = 0

    The ABSORBING branch reads the two grid points next to the boundary,
    so it needs at least three points along that dimension.

    Args:
        value: Value function on state grid.

    Returns:
        A copy of ``value`` with boundary conditions enforced.
    """
    ndim = self.problem.state_space.ndim
    result = value.copy()

    for dim in range(ndim):
        sv = self.problem.state_space.state_variables[dim]

        # --- lower boundary ---
        lo: List[Any] = [slice(None)] * ndim
        p1: List[Any] = [slice(None)] * ndim
        p2: List[Any] = [slice(None)] * ndim
        lo[dim] = 0
        p1[dim] = 1
        p2[dim] = 2

        if sv.boundary_lower == BoundaryCondition.ABSORBING:
            # V[0] = 2*V[1] - V[2]  →  d²V/dx² = 0
            result[tuple(lo)] = 2.0 * result[tuple(p1)] - result[tuple(p2)]
        elif sv.boundary_lower == BoundaryCondition.DIRICHLET:
            # Reset to the value stored for this face
            if self._boundary_values is not None and dim in self._boundary_values:
                result[tuple(lo)] = self._boundary_values[dim]["lower"]
        elif sv.boundary_lower in (
            BoundaryCondition.NEUMANN,
            BoundaryCondition.REFLECTING,
        ):
            # dV/dx = 0  →  V[0] = V[1]
            result[tuple(lo)] = result[tuple(p1)]

        # --- upper boundary ---
        hi: List[Any] = [slice(None)] * ndim
        m1: List[Any] = [slice(None)] * ndim
        m2: List[Any] = [slice(None)] * ndim
        hi[dim] = -1
        m1[dim] = -2
        m2[dim] = -3

        if sv.boundary_upper == BoundaryCondition.ABSORBING:
            # V[-1] = 2*V[-2] - V[-3]  →  d²V/dx² = 0
            result[tuple(hi)] = 2.0 * result[tuple(m1)] - result[tuple(m2)]
        elif sv.boundary_upper == BoundaryCondition.DIRICHLET:
            if self._boundary_values is not None and dim in self._boundary_values:
                result[tuple(hi)] = self._boundary_values[dim]["upper"]
        elif sv.boundary_upper in (
            BoundaryCondition.NEUMANN,
            BoundaryCondition.REFLECTING,
        ):
            # dV/dx = 0  →  V[-1] = V[-2]
            result[tuple(hi)] = result[tuple(m1)]

    return result


def _update_value_finite_horizon(self, old_v, cost, drift, dt, sigma_sq=None):
    """Advance the value function by one explicit step (finite horizon).

    Every term is evaluated at the previous iterate, so this is an explicit
    Euler update: ``old_v + dt * (cost - ρ·old_v + advection + diffusion)``.

    Args:
        old_v: Value function before the step.
        cost: Running cost on the state grid.
        drift: Drift field whose last axis indexes the state dimension.
        dt: Time step.
        sigma_sq: Optional diffusion coefficients; no diffusion when ``None``.

    Returns:
        The updated value function (boundary conditions not yet applied).
    """
    # Explicit Euler update built from the previous iterate
    new_v = old_v + dt * cost

    # Add discount term if applicable
    if self.problem.discount_rate > 0:
        new_v -= dt * self.problem.discount_rate * old_v

    # Add drift term using upwind differencing
    new_v = self._apply_upwind_drift(new_v, drift, dt)

    # Add diffusion term: ½σ²∇²V
    if sigma_sq is not None:
        new_v += dt * self._apply_diffusion_term(old_v, sigma_sq)

    return new_v


def _policy_evaluation(self):
    """Evaluate current policy by solving linear PDE.

    Supports explicit, implicit, and Crank-Nicolson time-stepping schemes.
    For explicit scheme, performs CFL stability check and auto-adapts dt.
    Implicit and Crank-Nicolson steps are only taken for 1D state spaces;
    higher-dimensional problems fall back to the explicit update.

    Iterates at most ``config.inner_max_iterations`` times and stops early
    once the sup-norm change falls below
    ``config.tolerance * config.inner_tolerance_factor``.

    Raises:
        NumericalDivergenceError: If a step produces NaN or Inf values.
    """
    if self.value_function is None or self.optimal_policy is None:
        return

    # Policy may have changed since last evaluation; invalidate cached operators
    self._invalidate_operator_cache()

    dt = self.config.time_step
    scheme = self.config.scheme

    for inner_iter in range(self.config.inner_max_iterations):
        # Type guard ensures value_function is not None after the check above
        assert self.value_function is not None  # For mypy
        old_v = self.value_function.copy()

        # Get state and control grids
        state_points = np.stack(self.problem.state_space.flat_grids, axis=-1)
        control_array = np.stack(
            [self.optimal_policy[cv.name].ravel() for cv in self.problem.control_variables],
            axis=-1,
        )

        # Compute dynamics and running cost
        drift = self.problem.dynamics(state_points, control_array, 0.0)
        cost = self.problem.running_cost(state_points, control_array, 0.0)

        # Reshape onto the state grid
        drift = drift.reshape(self.problem.state_space.shape + (-1,))
        cost = self._reshape_cost(cost)

        # Compute diffusion coefficient if specified
        sigma_sq = None
        if self.problem.diffusion is not None:
            sigma_sq_raw = self.problem.diffusion(state_points, control_array, 0.0)
            sigma_sq = sigma_sq_raw.reshape(self.problem.state_space.shape + (-1,))

        # CFL stability check for explicit scheme (#452); first inner step only
        if scheme == TimeSteppingScheme.EXPLICIT and inner_iter == 0:
            adv_cfl, diff_cfl = self._compute_cfl_number(drift, sigma_sq, dt)
            if adv_cfl > 1.0 or diff_cfl > 1.0:
                # Compute safe dt from the summed advection, diffusion and
                # discount rates, using the smallest spacing per dimension.
                max_rate = 0.0
                n_dims = min(drift.shape[-1], len(self.problem.state_space.state_variables))
                for dim in range(n_dims):
                    dx_arr = self.dx[dim]
                    dx_min = float(np.min(dx_arr))
                    drift_max = float(np.max(np.abs(drift[..., dim])))
                    if dx_min > 0:
                        max_rate += drift_max / dx_min
                    if sigma_sq is not None and dx_min > 0:
                        max_rate += (
                            0.5 * float(np.max(np.abs(sigma_sq[..., dim]))) / (dx_min**2)
                        )
                max_rate += self.problem.discount_rate
                if max_rate > 0:
                    dt_safe = 0.9 / max_rate
                    logger.warning(
                        f"CFL condition violated (advection CFL={adv_cfl:.2f}, "
                        f"diffusion CFL={diff_cfl:.2f}). "
                        f"Auto-reducing dt from {dt:.4e} to {dt_safe:.4e}."
                    )
                    dt = dt_safe

        # Time-stepping: branch on scheme (#451)
        use_implicit = scheme in (
            TimeSteppingScheme.IMPLICIT,
            TimeSteppingScheme.CRANK_NICOLSON,
        )
        can_use_implicit = use_implicit and self.problem.state_space.ndim == 1

        if use_implicit and not can_use_implicit:
            if inner_iter == 0:  # warn once per evaluation, not per step
                logger.warning(
                    f"Implicit/CN schemes not yet supported for "
                    f"{self.problem.state_space.ndim}D problems; "
                    f"falling back to explicit."
                )

        if can_use_implicit:
            # Implicit or Crank-Nicolson 1D step
            drift_1d = drift[:, 0] if drift.ndim > 1 else drift
            sigma_sq_1d = (
                sigma_sq[:, 0] if sigma_sq is not None and sigma_sq.ndim > 1 else sigma_sq
            )
            cost_1d = cost.ravel()

            if scheme == TimeSteppingScheme.CRANK_NICOLSON:
                if inner_iter < self.config.rannacher_steps:
                    # Rannacher startup: two implicit half-steps
                    half_v = self._theta_step_1d(
                        old_v.ravel(),
                        cost_1d,
                        drift_1d,
                        sigma_sq_1d,
                        dt / 2.0,
                        theta=1.0,
                    )
                    half_v = self._apply_boundary_conditions(
                        half_v.reshape(old_v.shape)
                    ).ravel()
                    new_v = self._theta_step_1d(
                        half_v,
                        cost_1d,
                        drift_1d,
                        sigma_sq_1d,
                        dt / 2.0,
                        theta=1.0,
                    )
                else:
                    # Crank-Nicolson step (theta=0.5)
                    new_v = self._theta_step_1d(
                        old_v.ravel(),
                        cost_1d,
                        drift_1d,
                        sigma_sq_1d,
                        dt,
                        theta=0.5,
                    )
            else:
                # Fully implicit step (theta=1)
                new_v = self._theta_step_1d(
                    old_v.ravel(),
                    cost_1d,
                    drift_1d,
                    sigma_sq_1d,
                    dt,
                    theta=1.0,
                )
            new_v = new_v.reshape(old_v.shape)
        else:
            # Explicit scheme (or fallback for multi-D)
            if self.problem.time_horizon is not None:
                new_v = self._update_value_finite_horizon(old_v, cost, drift, dt, sigma_sq)
            else:
                advection = np.zeros_like(old_v)
                for dim in range(drift.shape[-1]):
                    if dim >= len(self.problem.state_space.state_variables):
                        continue
                    drift_component = drift[..., dim]
                    advection += self._apply_upwind_scheme(old_v, drift_component, dim)

                rhs = -self.problem.discount_rate * old_v + cost + advection
                if sigma_sq is not None:
                    rhs += self._apply_diffusion_term(old_v, sigma_sq)
                new_v = old_v + dt * rhs

        # Enforce boundary conditions after each time step
        new_v = self._apply_boundary_conditions(new_v)

        # Check for NaN/Inf after each inner step (#453)
        if not np.all(np.isfinite(new_v)):
            n_nan = int(np.sum(np.isnan(new_v)))
            n_inf = int(np.sum(np.isinf(new_v)))
            raise NumericalDivergenceError(
                f"HJB solver diverged during policy evaluation "
                f"(inner iteration {inner_iter}): "
                f"{n_nan} NaN and {n_inf} Inf values detected. "
                f"Value function range before step: "
                f"[{float(np.nanmin(old_v)):.4e}, {float(np.nanmax(old_v)):.4e}]."
            )

        self.value_function = new_v

        # Check inner convergence
        if (
            np.max(np.abs(new_v - old_v))
            < self.config.tolerance * self.config.inner_tolerance_factor
        ):
            break


def _precompute_upwind_diffs(self) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Precompute forward and backward finite differences of the value function.

    Returns a list (one entry per dimension) of (fwd_diff_flat, bwd_diff_flat)
    arrays, each of shape (n_states,). These encode the same upwind scheme
    used by ``_apply_upwind_scheme`` so that the advection term for an
    arbitrary drift field can be assembled as::

        advection = sum_dim( max(drift, 0)*fwd + min(drift, 0)*bwd )

    without recomputing the differences for every control candidate.
    """
    assert self.value_function is not None
    ndim = self.problem.state_space.ndim
    diffs: List[Tuple[np.ndarray, np.ndarray]] = []

    for dim in range(ndim):
        dx_array = self.dx[dim]
        n = self.value_function.shape[dim]

        # Broadcast dx along the correct axis
        dx_shape = [1] * ndim
        dx_shape[dim] = n - 1
        dx_bc = dx_array.reshape(dx_shape)

        hi = [slice(None)] * ndim
        lo = [slice(None)] * ndim
        hi[dim] = slice(1, None)  # indices 1..N-1
        lo[dim] = slice(None, -1)  # indices 0..N-2

        diff_vals = (self.value_function[tuple(hi)] - self.value_function[tuple(lo)]) / dx_bc

        # Forward difference: defined at indices 0..N-2, zero at N-1
        fwd = np.zeros_like(self.value_function)
        fwd[tuple(lo)] = diff_vals

        # Backward difference: defined at indices 1..N-1, zero at 0
        bwd = np.zeros_like(self.value_function)
        bwd[tuple(hi)] = diff_vals

        diffs.append((fwd.ravel(), bwd.ravel()))

    return diffs


@staticmethod
def _build_control_combos(control_samples: List[np.ndarray]) -> np.ndarray:
    """Build all control combinations as an (n_combos, n_controls) array.

    Uses np.meshgrid instead of itertools.product for efficient
    Cartesian product construction.
    """
    grids = np.meshgrid(*control_samples, indexing="ij")
    return np.column_stack([g.ravel() for g in grids])


def _compute_chunk_size(self, n_combos: int, n_states: int, n_controls: int) -> int:
    """Determine batch size from memory budget.

    Each combo-state pair requires storage for controls, drift, cost,
    and advection arrays (~4 * n_dims * 8 bytes per pair).

    Args:
        n_combos: Number of control combinations to evaluate.
        n_states: Number of grid states.
        n_controls: Number of control variables.

    Returns:
        Number of combos to evaluate per batch. Always at least 1, so the
        value is safe to use as a ``range`` step even when ``n_combos`` is 0.
    """
    ndim = self.problem.state_space.ndim
    budget_bytes = self.config.control_memory_budget_mb * 1024 * 1024
    # Estimate bytes per combo: states * (controls + drift + cost + advection) * 8
    bytes_per_combo = n_states * (n_controls + ndim + 1 + 1) * 8
    if bytes_per_combo == 0:
        chunk = n_combos
    else:
        # int() keeps the result usable as a slice bound if the budget is a float
        chunk = min(max(1, int(budget_bytes // bytes_per_combo)), n_combos)
    # A zero chunk would make range(start, stop, 0) raise ValueError
    return max(1, chunk)


def _evaluate_and_update_best(
    self,
    combos: np.ndarray,
    state_points: np.ndarray,
    n_states: int,
    n_controls: int,
    ndim: int,
    fwd_arr: np.ndarray,
    bwd_arr: np.ndarray,
    d2V_flat: Optional[np.ndarray],
    best_values: np.ndarray,
    best_controls: np.ndarray,
) -> None:
    """Evaluate Hamiltonian for a batch of control combos and update best.

    Processes combos in chunks determined by the memory budget, vectorizing
    over both states and combos within each chunk. ``best_values`` and
    ``best_controls`` are updated in place wherever a combo gives a strictly
    larger Hamiltonian than the value already stored.
    """
    chunk_size = self._compute_chunk_size(len(combos), n_states, n_controls)

    for start in range(0, len(combos), chunk_size):
        chunk = combos[start : start + chunk_size]
        n_chunk = len(chunk)

        # Tile state_points for all combos in chunk: (n_chunk * n_states, state_dim)
        states_tiled = np.tile(state_points, (n_chunk, 1))
        # Repeat each combo for all states: (n_chunk * n_states, n_controls)
        controls_tiled = np.repeat(chunk, n_states, axis=0)

        # Evaluate dynamics and running cost in one vectorized call
        drift = self.problem.dynamics(states_tiled, controls_tiled, 0.0)
        cost = self.problem.running_cost(states_tiled, controls_tiled, 0.0)

        # Reduce cost to 1D
        cost = np.asarray(cost)
        if cost.ndim > 1:
            if cost.shape[-1] > 1:
                cost = np.mean(cost, axis=-1)
            else:
                cost = cost[..., 0]
        cost = cost.flatten()

        # Reshape to (n_chunk, n_states)
        cost_2d = cost.reshape(n_chunk, n_states)

        # Compute advection: upwind scheme
        drift_flat = np.asarray(drift).reshape(n_chunk * n_states, -1)
        n_dims = min(drift_flat.shape[1], ndim)
        drift_pos = np.maximum(drift_flat[:, :n_dims], 0.0)
        drift_neg = np.minimum(drift_flat[:, :n_dims], 0.0)

        # Tile fwd/bwd arrays for all combos in chunk
        fwd_tiled = np.tile(fwd_arr[:, :n_dims], (n_chunk, 1))
        bwd_tiled = np.tile(bwd_arr[:, :n_dims], (n_chunk, 1))

        advection = np.sum(drift_pos * fwd_tiled + drift_neg * bwd_tiled, axis=1)
        advection_2d = advection.reshape(n_chunk, n_states)

        hamiltonian_2d = cost_2d + advection_2d

        # Add diffusion term
        if self.problem.diffusion is not None and d2V_flat is not None:
            sigma_sq = self.problem.diffusion(states_tiled, controls_tiled, 0.0)
            sigma_sq = np.asarray(sigma_sq).reshape(n_chunk * n_states, -1)
            n_diff = min(sigma_sq.shape[1], d2V_flat.shape[1], n_dims)
            diff_term = 0.5 * np.sum(
                sigma_sq[:, :n_diff] * np.tile(d2V_flat[:, :n_diff], (n_chunk, 1)),
                axis=1,
            )
            hamiltonian_2d += diff_term.reshape(n_chunk, n_states)

        # Find best combo per state within this chunk
        chunk_best_idx = np.argmax(hamiltonian_2d, axis=0)  # (n_states,)
        chunk_best_vals = hamiltonian_2d[chunk_best_idx, np.arange(n_states)]

        # Update global best
        improved = chunk_best_vals > best_values
        best_values[improved] = chunk_best_vals[improved]
        best_controls[improved] = chunk[chunk_best_idx[improved]]


def _policy_improvement_loop(
    self,
    state_points: np.ndarray,
    n_states: int,
    n_controls: int,
    ndim: int,
    fwd_arr: np.ndarray,
    bwd_arr: np.ndarray,
    d2V_flat: Optional[np.ndarray],
    best_values: np.ndarray,
    best_controls: np.ndarray,
    control_samples: List[np.ndarray],
) -> None:
    """Legacy loop-based policy improvement (original implementation).

    Evaluates the Hamiltonian for one control combination at a time and
    updates ``best_values`` / ``best_controls`` in place.
    """
    for control_combo in itertools_product(*control_samples):
        control_array = np.array(control_combo)
        # Same control applied at every state
        control_broadcast = np.tile(control_array, (n_states, 1))

        drift = self.problem.dynamics(state_points, control_broadcast, 0.0)
        cost = self.problem.running_cost(state_points, control_broadcast, 0.0)

        # Reduce cost to 1D
        cost = np.asarray(cost)
        if cost.ndim > 1:
            if cost.shape[-1] > 1:
                cost = np.mean(cost, axis=-1)
            else:
                cost = cost[..., 0]
        cost = cost.flatten()

        # Upwind advection from the precomputed differences
        drift_flat = np.asarray(drift).reshape(n_states, -1)
        n_dims = min(drift_flat.shape[1], ndim)
        drift_pos = np.maximum(drift_flat[:, :n_dims], 0.0)
        drift_neg = np.minimum(drift_flat[:, :n_dims], 0.0)
        advection_flat = np.sum(
            drift_pos * fwd_arr[:, :n_dims] + drift_neg * bwd_arr[:, :n_dims],
            axis=1,
        )

        hamiltonian = cost + advection_flat

        if self.problem.diffusion is not None and d2V_flat is not None:
            sigma_sq = self.problem.diffusion(state_points, control_broadcast, 0.0)
            sigma_sq = np.asarray(sigma_sq).reshape(n_states, -1)
            n_diff = min(sigma_sq.shape[1], d2V_flat.shape[1], n_dims)
            hamiltonian += 0.5 * np.sum(sigma_sq[:, :n_diff] * d2V_flat[:, :n_diff], axis=1)

        improved = hamiltonian > best_values
        best_values[improved] = hamiltonian[improved]
        best_controls[improved] = control_array


def _policy_improvement_vectorized(
    self,
    state_points: np.ndarray,
    n_states: int,
    n_controls: int,
    ndim: int,
    fwd_arr: np.ndarray,
    bwd_arr: np.ndarray,
    d2V_flat: Optional[np.ndarray],
    best_values: np.ndarray,
    best_controls: np.ndarray,
    control_samples: List[np.ndarray],
) -> None:
    """Fully vectorized policy improvement over all control combos."""
    combos = self._build_control_combos(control_samples)
    self._evaluate_and_update_best(
        combos,
        state_points,
        n_states,
        n_controls,
        ndim,
        fwd_arr,
        bwd_arr,
        d2V_flat,
        best_values,
        best_controls,
    )


def _policy_improvement_adaptive(
    self,
    state_points: np.ndarray,
    n_states: int,
    n_controls: int,
    ndim: int,
    fwd_arr: np.ndarray,
    bwd_arr: np.ndarray,
    d2V_flat: Optional[np.ndarray],
    best_values: np.ndarray,
    best_controls: np.ndarray,
    control_samples: List[np.ndarray],
) -> None:
    """Two-pass adaptive policy improvement: coarse search then local refinement."""
    # Pass 1: Coarse grid (every _COARSE_STRIDE-th point per control)
    coarse_samples = [s[::_COARSE_STRIDE] for s in control_samples]
    coarse_combos = self._build_control_combos(coarse_samples)

    self._evaluate_and_update_best(
        coarse_combos,
        state_points,
        n_states,
        n_controls,
        ndim,
        fwd_arr,
        bwd_arr,
        d2V_flat,
        best_values,
        best_controls,
    )

    # Pass 2: Refine around coarse optima
    # Find unique coarse optima (each row of best_controls is a combo)
    unique_optima = np.unique(best_controls, axis=0)

    # For each unique optimum, build a refined grid of nearby combos
    refined_combos_list = []
    for optimum in unique_optima:
        per_control_refined = []
        for j, full_grid in enumerate(control_samples):
            # Find the closest index in the full grid
            closest_idx = int(np.argmin(np.abs(full_grid - optimum[j])))
            lo = max(0, closest_idx - _REFINE_RADIUS)
            hi = min(len(full_grid), closest_idx + _REFINE_RADIUS + 1)
            per_control_refined.append(full_grid[lo:hi])
        local_combos = self._build_control_combos(per_control_refined)
        refined_combos_list.append(local_combos)

    if refined_combos_list:
        all_refined = np.vstack(refined_combos_list)
        # Deduplicate
        all_refined = np.unique(all_refined, axis=0)

        self._evaluate_and_update_best(
            all_refined,
            state_points,
            n_states,
            n_controls,
            ndim,
            fwd_arr,
            bwd_arr,
            d2V_flat,
            best_values,
            best_controls,
        )


def _policy_improvement(self):
    """Improve policy by maximizing the Hamiltonian.

    H(x,u) = f(x,u) + drift(x,u)·∇V(x) + ½σ²(x,u)·∇²V(x)

    Dispatches to vectorized, adaptive, or loop-based strategy based on
    ``config.control_search_strategy``. The precomputation of upwind
    finite differences is shared across all strategies.

    Raises:
        NotImplementedError: If the configured strategy is ``"gradient"``.
    """
    if self.value_function is None or self.optimal_policy is None:
        return

    state_points = np.stack(self.problem.state_space.flat_grids, axis=-1)
    n_states = state_points.shape[0]
    n_controls = len(self.problem.control_variables)
    ndim = self.problem.state_space.ndim

    # Initialize tracking arrays
    best_values = np.full(n_states, -np.inf)
    best_controls = np.zeros((n_states, n_controls))

    # Compute second derivatives for diffusion term (independent of control)
    d2V_flat = None
    if self.problem.diffusion is not None:
        d2V = self._compute_second_derivatives(self.value_function)
        # Check for NaN in second derivatives (#453)
        if not np.all(np.isfinite(d2V)):
            logger.warning(
                "NaN/Inf detected in second derivatives during policy improvement; "
                "skipping policy update."
            )
            return
        d2V_flat = d2V.reshape(n_states, -1)

    # Precompute upwind finite differences of V (control-independent).
    # Each entry is (fwd_flat, bwd_flat) with shape (n_states,).
    upwind_diffs = self._precompute_upwind_diffs()

    # Stack into (n_states, ndim) for vectorized advection computation
    fwd_arr = np.column_stack([f for f, _ in upwind_diffs])  # (n_states, ndim)
    bwd_arr = np.column_stack([b for _, b in upwind_diffs])  # (n_states, ndim)

    # Get discrete control samples for each control variable
    control_samples = [cv.get_values() for cv in self.problem.control_variables]

    # Shared arguments for all strategies
    args = (
        state_points,
        n_states,
        n_controls,
        ndim,
        fwd_arr,
        bwd_arr,
        d2V_flat,
        best_values,
        best_controls,
        control_samples,
    )

    # Determine strategy
    strategy = self.config.control_search_strategy

    if strategy == "gradient":
        raise NotImplementedError(
            "Gradient-based control search is reserved for future implementation."
        )

    if strategy == "auto":
        # Small Cartesian products are evaluated in full; large ones adaptively
        n_combos = 1
        for s in control_samples:
            n_combos *= len(s)
        strategy = "vectorized" if n_combos <= _VECTORIZE_COMBO_THRESHOLD else "adaptive"

    if strategy == "vectorized":
        self._policy_improvement_vectorized(*args)
    elif strategy == "adaptive":
        self._policy_improvement_adaptive(*args)
    else:
        # "loop" or any unrecognized value falls back to legacy
        self._policy_improvement_loop(*args)

    # Write optimal controls back to policy arrays
    for j, cv in enumerate(self.problem.control_variables):
        self.optimal_policy[cv.name] = best_controls[:, j].reshape(
            self.problem.state_space.shape
        )
def extract_feedback_control(self, state: np.ndarray) -> Dict[str, float]:
    """Extract feedback control law at given state.

    Each control's policy array is linearly interpolated on the state grid.
    Out-of-bounds queries do not raise (``bounds_error=False``); with
    ``fill_value=None`` the interpolator extrapolates instead.

    Args:
        state: Current state values, one entry per state dimension. Any
            array-like (ndarray, list, tuple, or a scalar for a 1D problem)
            is accepted and flattened into a single query point.

    Returns:
        Dictionary of control variable names to optimal values

    Raises:
        RuntimeError: If no policy has been computed yet.
    """
    if self.optimal_policy is None:
        raise RuntimeError("Must solve HJB equation before extracting controls")

    # Coerce once so plain sequences and scalars work as well as ndarrays;
    # the query point is the same for every control variable.
    query = np.asarray(state, dtype=float).reshape(1, -1)
    grids = self.problem.state_space.grids

    # Interpolate policy at given state
    controls: Dict[str, float] = {}
    for cv in self.problem.control_variables:
        policy_func = interpolate.RegularGridInterpolator(
            grids,
            self.optimal_policy[cv.name],
            method="linear",
            bounds_error=False,
            fill_value=None,
        )
        controls[cv.name] = float(policy_func(query)[0])

    return controls
def compute_convergence_metrics(self) -> Dict[str, Any]:
    """Compute metrics for assessing solution quality.

    The HJB residual ``|−ρV + f(x,u) + drift·∇V + ½σ²∇²V|`` is evaluated on
    the state grid under the current policy, with the advection term built
    by the same upwind scheme used in the solver.

    Returns:
        Dictionary of convergence metrics with keys ``has_nan_inf``,
        ``max_residual``, ``mean_residual``, ``value_function_range`` and
        ``policy_stats``; or ``{"error": ...}`` when no value function or
        policy is available yet.
    """
    # Explicit guard instead of an assert: asserts vanish under ``python -O``
    if self.value_function is None or self.optimal_policy is None:
        return {"error": "No solution computed yet"}

    state_space = self.problem.state_space

    # Compute residual of HJB equation
    state_points = np.stack(state_space.flat_grids, axis=-1)
    control_array = np.stack(
        [self.optimal_policy[cv.name].ravel() for cv in self.problem.control_variables],
        axis=-1,
    )

    drift = self.problem.dynamics(state_points, control_array, 0.0)
    cost = self.problem.running_cost(state_points, control_array, 0.0)

    v_flat = self.value_function.ravel()

    # Reduce the cost to one value per state. A cost with extra columns is
    # averaged over them, matching what ``_reshape_cost`` does during policy
    # evaluation; a plain ``ravel`` would yield the wrong length.
    cost_arr = np.asarray(cost)
    if cost_arr.ndim > 1 and cost_arr.size != v_flat.size:
        cost_arr = cost_arr.reshape(v_flat.size, -1).mean(axis=1)
    cost_flat = cost_arr.ravel()

    # Compute drift * grad_V using upwind scheme (consistent with PDE)
    drift_reshaped = drift.reshape(state_space.shape + (-1,))
    advection = np.zeros(state_space.shape)
    n_dims = min(drift_reshaped.shape[-1], len(state_space.state_variables))
    for dim in range(n_dims):
        drift_component = drift_reshaped[..., dim]
        advection += self._apply_upwind_scheme(self.value_function, drift_component, dim)
    advection_flat = advection.ravel()

    # Include diffusion in residual if present
    diffusion_flat = np.zeros_like(v_flat)
    if self.problem.diffusion is not None:
        sigma_sq = self.problem.diffusion(state_points, control_array, 0.0)
        sigma_sq_reshaped = sigma_sq.reshape(state_space.shape + (-1,))
        diffusion_term = self._apply_diffusion_term(self.value_function, sigma_sq_reshaped)
        diffusion_flat = diffusion_term.ravel()

    residual = np.abs(
        -self.problem.discount_rate * v_flat + cost_flat + advection_flat + diffusion_flat
    )

    # Check for NaN/Inf (#453)
    has_nan_inf = not np.all(np.isfinite(self.value_function))

    return {
        "has_nan_inf": has_nan_inf,
        "max_residual": float(np.max(residual)),
        "mean_residual": float(np.mean(residual)),
        "value_function_range": (
            float(np.min(self.value_function)),
            float(np.max(self.value_function)),
        ),
        "policy_stats": {
            cv.name: {
                "min": float(np.min(self.optimal_policy[cv.name])),
                "max": float(np.max(self.optimal_policy[cv.name])),
                "mean": float(np.mean(self.optimal_policy[cv.name])),
            }
            for cv in self.problem.control_variables
        },
    }
def create_custom_utility(
    evaluate_func: Callable[[np.ndarray], np.ndarray],
    derivative_func: Callable[[np.ndarray], np.ndarray],
    inverse_derivative_func: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> UtilityFunction:
    """Build a utility function object from plain callables.

    The returned object is an instance of a ``UtilityFunction`` subclass
    defined on the fly; each of its methods simply forwards to the callable
    supplied here.

    Args:
        evaluate_func: Callable used for ``evaluate``, i.e. U(w).
        derivative_func: Callable used for ``derivative``, i.e. U'(w).
        inverse_derivative_func: Optional callable used for
            ``inverse_derivative``, i.e. (U')^(-1)(m). When omitted, calling
            ``inverse_derivative`` raises ``NotImplementedError``.

    Returns:
        Custom utility function instance

    Example:
        >>> # Exponential utility: U(w) = 1 - exp(-α*w)
        >>> def exp_eval(w):
        ...     alpha = 0.01
        ...     return 1 - np.exp(-alpha * w)
        >>> def exp_deriv(w):
        ...     alpha = 0.01
        ...     return alpha * np.exp(-alpha * w)
        >>> exp_utility = create_custom_utility(exp_eval, exp_deriv)
    """

    class CustomUtility(UtilityFunction):
        """Utility function backed by user-supplied callables."""

        def evaluate(self, wealth: np.ndarray) -> np.ndarray:
            utility_values = evaluate_func(wealth)
            return utility_values

        def derivative(self, wealth: np.ndarray) -> np.ndarray:
            marginal_values = derivative_func(wealth)
            return marginal_values

        def inverse_derivative(self, marginal_utility: np.ndarray) -> np.ndarray:
            # Happy path first; the inverse is optional at construction time
            if inverse_derivative_func is not None:
                return inverse_derivative_func(marginal_utility)
            raise NotImplementedError("Inverse derivative not provided for custom utility")

    utility = CustomUtility()
    return utility