"""Hamilton-Jacobi-Bellman solver for optimal insurance control.
This module implements a Hamilton-Jacobi-Bellman (HJB) partial differential equation
solver for finding optimal insurance strategies through dynamic programming. The solver
handles multi-dimensional state spaces and provides theoretically optimal control policies.
Globally optimal policies are obtained by solving the HJB equation:
∂V/∂t + max_u[L^u V + f(x,u)] = 0
where V is the value function, L^u is the controlled infinitesimal generator,
and f(x,u) is the running cost/reward.
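For discounted infinite-horizon problems, the stationary counterpart
ρ V = max_u[L^u V + f(x,u)]
is solved instead, where ρ is the discount rate.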
Author: Alex Filiakov
Date: 2025-01-26
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from itertools import product
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
from scipy import interpolate, sparse
logger = logging.getLogger(__name__)
class TimeSteppingScheme(Enum):
"""Time stepping schemes for PDE integration."""
EXPLICIT = "explicit"
IMPLICIT = "implicit"
CRANK_NICOLSON = "crank_nicolson"
class BoundaryCondition(Enum):
"""Types of boundary conditions."""
DIRICHLET = "dirichlet" # Fixed value
NEUMANN = "neumann" # Fixed derivative
ABSORBING = "absorbing" # Zero second derivative
REFLECTING = "reflecting" # Zero first derivative
@dataclass
class StateVariable:
"""Definition of a state variable in the HJB problem."""
name: str
min_value: float
max_value: float
num_points: int
boundary_lower: BoundaryCondition = BoundaryCondition.ABSORBING
boundary_upper: BoundaryCondition = BoundaryCondition.ABSORBING
log_scale: bool = False
def __post_init__(self):
"""Validate state variable configuration."""
if self.min_value >= self.max_value:
raise ValueError(f"min_value must be less than max_value for {self.name}")
if self.num_points < 3:
raise ValueError(f"Need at least 3 grid points for {self.name}")
if self.log_scale and self.min_value <= 0:
raise ValueError(f"Cannot use log scale with non-positive min_value for {self.name}")
def get_grid(self) -> np.ndarray:
"""Generate grid points for this variable.
Returns:
Array of grid points
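
        Example:
            Illustrative three-point log-scale grid:

            >>> sv = StateVariable("wealth", 1.0, 100.0, num_points=3, log_scale=True)
            >>> sv.get_grid()
            array([  1.,  10., 100.])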
"""
if self.log_scale:
return np.logspace(np.log10(self.min_value), np.log10(self.max_value), self.num_points)
return np.linspace(self.min_value, self.max_value, self.num_points)
@dataclass
class ControlVariable:
"""Definition of a control variable in the HJB problem."""
name: str
min_value: float
max_value: float
num_points: int = 50
continuous: bool = True
def __post_init__(self):
"""Validate control variable configuration."""
if self.min_value >= self.max_value:
raise ValueError(f"min_value must be less than max_value for {self.name}")
if self.num_points < 2:
raise ValueError(f"Need at least 2 control points for {self.name}")
def get_values(self) -> np.ndarray:
"""Get discrete control values for optimization.
Returns:
Array of control values
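
        Example:
            Illustrative five-point control grid:

            >>> cv = ControlVariable("retention", 0.0, 1.0, num_points=5)
            >>> cv.get_values()
            array([0.  , 0.25, 0.5 , 0.75, 1.  ])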
"""
return np.linspace(self.min_value, self.max_value, self.num_points)
@dataclass
class StateSpace:
"""Multi-dimensional state space for HJB problem.
    Handles arbitrary dimensionality, constructs per-dimension grids, and
    identifies boundary points for boundary-condition enforcement.
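
    Example:
        An illustrative two-dimensional space (wealth on a log grid, time linear):

        >>> space = StateSpace([
        ...     StateVariable("wealth", 1e5, 1e7, num_points=50, log_scale=True),
        ...     StateVariable("time", 0.0, 10.0, num_points=20),
        ... ])
        >>> space.shape
        (50, 20)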
"""
state_variables: List[StateVariable]
def __post_init__(self):
"""Initialize derived attributes."""
self.ndim = len(self.state_variables)
self.shape = tuple(sv.num_points for sv in self.state_variables)
self.size = np.prod(self.shape)
# Create grids for each dimension
self.grids = [sv.get_grid() for sv in self.state_variables]
# Create meshgrid for full state space
self.meshgrid = np.meshgrid(*self.grids, indexing="ij")
# Flatten for linear algebra operations
self.flat_grids = [mg.ravel() for mg in self.meshgrid]
logger.info(f"Initialized {self.ndim}D state space with shape {self.shape}")
def get_boundary_mask(self) -> np.ndarray:
"""Get boolean mask for boundary points.
Returns:
Boolean array where True indicates boundary points
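
        Example:
            In 1D, only the two endpoints are boundary points:

            >>> space = StateSpace([StateVariable("w", 0.0, 1.0, num_points=5)])
            >>> space.get_boundary_mask()
            array([ True, False, False, False,  True])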
"""
mask = np.zeros(self.shape, dtype=bool)
for dim, _sv in enumerate(self.state_variables):
# Create slice for this dimension's boundaries
slices_lower: list[slice | int] = [slice(None)] * self.ndim
slices_lower[dim] = 0
mask[tuple(slices_lower)] = True
slices_upper: list[slice | int] = [slice(None)] * self.ndim
slices_upper[dim] = -1
mask[tuple(slices_upper)] = True
return mask
def interpolate_value(self, value_function: np.ndarray, points: np.ndarray) -> np.ndarray:
"""Interpolate value function at arbitrary points.
Args:
value_function: Value function on grid
points: Points to interpolate at (shape: [n_points, n_dims])
Returns:
Interpolated values
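
        Example:
            Cubic interpolation reproduces a quadratic exactly (rounded to
            guard against floating-point noise):

            >>> space = StateSpace([StateVariable("w", 0.0, 4.0, num_points=5)])
            >>> vf = space.grids[0] ** 2
            >>> np.round(space.interpolate_value(vf, np.array([[1.5]])), 6)
            array([2.25])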
"""
if self.ndim == 1:
interp = interpolate.interp1d(
self.grids[0],
value_function,
kind="cubic",
bounds_error=False,
fill_value="extrapolate",
)
return np.array(interp(points[:, 0]))
        # For two or more dimensions, use multilinear interpolation;
        # fill_value=None allows extrapolation beyond the grid
        interp = interpolate.RegularGridInterpolator(
            self.grids, value_function, method="linear", bounds_error=False, fill_value=None
        )
        return np.array(interp(points))
class UtilityFunction(ABC):
"""Abstract base class for utility functions.
Defines the interface for utility functions used in the HJB equation.
Concrete implementations should provide both the utility value and its derivative.
"""
@abstractmethod
def evaluate(self, wealth: np.ndarray) -> np.ndarray:
"""Evaluate utility at given wealth levels.
Args:
wealth: Wealth values
Returns:
Utility values
"""
pass # pylint: disable=unnecessary-pass
@abstractmethod
def derivative(self, wealth: np.ndarray) -> np.ndarray:
"""Compute marginal utility (first derivative).
Args:
wealth: Wealth values
Returns:
Marginal utility values
"""
pass # pylint: disable=unnecessary-pass
@abstractmethod
def inverse_derivative(self, marginal_utility: np.ndarray) -> np.ndarray:
"""Compute inverse of marginal utility.
Used for finding optimal controls in some formulations.
Args:
marginal_utility: Marginal utility values
Returns:
Wealth values corresponding to given marginal utilities
"""
pass # pylint: disable=unnecessary-pass
class LogUtility(UtilityFunction):
"""Logarithmic utility function for ergodic optimization.
U(w) = log(w)
This utility function maximizes the long-term growth rate and is
particularly suitable for ergodic analysis.
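
    Example:
        >>> u = LogUtility()
        >>> u.evaluate(np.array([1.0]))
        array([0.])
        >>> u.derivative(np.array([1.0, 2.0, 4.0]))
        array([1.  , 0.5 , 0.25])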
"""
def __init__(self, wealth_floor: float = 1e-6):
"""Initialize log utility.
Args:
wealth_floor: Minimum wealth to prevent log(0)
"""
self.wealth_floor = wealth_floor
def evaluate(self, wealth: np.ndarray) -> np.ndarray:
"""Evaluate log utility."""
safe_wealth = np.maximum(wealth, self.wealth_floor)
return np.array(np.log(safe_wealth))
def derivative(self, wealth: np.ndarray) -> np.ndarray:
"""Compute marginal utility: U'(w) = 1/w."""
safe_wealth = np.maximum(wealth, self.wealth_floor)
return np.array(1.0 / safe_wealth)
def inverse_derivative(self, marginal_utility: np.ndarray) -> np.ndarray:
"""Compute inverse: (U')^(-1)(m) = 1/m."""
safe_marginal = np.maximum(marginal_utility, 1e-10)
return np.array(1.0 / safe_marginal)
class PowerUtility(UtilityFunction):
"""Power (CRRA) utility function with risk aversion parameter.
U(w) = w^(1-γ)/(1-γ) for γ ≠ 1
U(w) = log(w) for γ = 1
where γ is the coefficient of relative risk aversion.
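
    Example:
        With γ = 2, U(w) = -1/w:

        >>> u = PowerUtility(risk_aversion=2.0)
        >>> u.evaluate(np.array([1.0, 2.0]))
        array([-1. , -0.5])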
"""
def __init__(self, risk_aversion: float = 2.0, wealth_floor: float = 1e-6):
"""Initialize power utility.
Args:
risk_aversion: Coefficient of relative risk aversion (γ)
wealth_floor: Minimum wealth to prevent numerical issues
"""
self.gamma = risk_aversion
self.wealth_floor = wealth_floor
# Use log utility if gamma is close to 1
if abs(self.gamma - 1.0) < 1e-10:
self._log_utility = LogUtility(wealth_floor)
def evaluate(self, wealth: np.ndarray) -> np.ndarray:
"""Evaluate power utility."""
if abs(self.gamma - 1.0) < 1e-10:
return self._log_utility.evaluate(wealth)
safe_wealth = np.maximum(wealth, self.wealth_floor)
return np.array(np.power(safe_wealth, 1 - self.gamma) / (1 - self.gamma))
def derivative(self, wealth: np.ndarray) -> np.ndarray:
"""Compute marginal utility: U'(w) = w^(-γ)."""
if abs(self.gamma - 1.0) < 1e-10:
return self._log_utility.derivative(wealth)
safe_wealth = np.maximum(wealth, self.wealth_floor)
return np.array(np.power(safe_wealth, -self.gamma))
def inverse_derivative(self, marginal_utility: np.ndarray) -> np.ndarray:
"""Compute inverse: (U')^(-1)(m) = m^(-1/γ)."""
if abs(self.gamma - 1.0) < 1e-10:
return self._log_utility.inverse_derivative(marginal_utility)
safe_marginal = np.maximum(marginal_utility, 1e-10)
return np.array(np.power(safe_marginal, -1.0 / self.gamma))
class ExpectedWealth(UtilityFunction):
"""Linear utility function for risk-neutral wealth maximization.
U(w) = w
This represents risk-neutral preferences where the goal is to
maximize expected wealth.
"""
def evaluate(self, wealth: np.ndarray) -> np.ndarray:
"""Evaluate linear utility."""
        return np.asarray(wealth)
def derivative(self, wealth: np.ndarray) -> np.ndarray:
"""Compute marginal utility: U'(w) = 1."""
return np.ones_like(wealth)
def inverse_derivative(self, marginal_utility: np.ndarray) -> np.ndarray:
"""Inverse is undefined for constant marginal utility."""
raise NotImplementedError("Inverse derivative undefined for linear utility")
@dataclass
class HJBProblem:
"""Complete specification of an HJB optimal control problem."""
state_space: StateSpace
control_variables: List[ControlVariable]
utility_function: UtilityFunction
dynamics: Callable[[np.ndarray, np.ndarray, float], np.ndarray]
running_cost: Callable[[np.ndarray, np.ndarray, float], np.ndarray]
terminal_value: Optional[Callable[[np.ndarray], np.ndarray]] = None
discount_rate: float = 0.0
time_horizon: Optional[float] = None
def __post_init__(self):
"""Validate problem specification."""
if self.time_horizon is not None and self.time_horizon <= 0:
raise ValueError("Time horizon must be positive")
if self.discount_rate < 0:
raise ValueError("Discount rate must be non-negative")
        # Finite horizon problems need a terminal value; default to zero if none given
        if self.time_horizon is not None and self.terminal_value is None:
self.terminal_value = lambda x: np.zeros_like(x[..., 0])
@dataclass
class HJBSolverConfig:
"""Configuration for HJB solver."""
time_step: float = 0.01
max_iterations: int = 1000
tolerance: float = 1e-6
scheme: TimeSteppingScheme = TimeSteppingScheme.IMPLICIT
use_sparse: bool = True
verbose: bool = True
class HJBSolver:
"""Hamilton-Jacobi-Bellman PDE solver for optimal control.
Implements finite difference methods with upwind schemes for solving
HJB equations. Supports multi-dimensional state spaces and various
boundary conditions.
"""
def __init__(self, problem: HJBProblem, config: HJBSolverConfig):
"""Initialize HJB solver.
Args:
problem: HJB problem specification
config: Solver configuration
"""
self.problem = problem
self.config = config
# Initialize value function and policy
self.value_function: np.ndarray | None = None
self.optimal_policy: dict[str, np.ndarray] | None = None
# Set up finite difference operators
self._setup_operators()
logger.info(f"Initialized HJB solver for {problem.state_space.ndim}D problem")
def _setup_operators(self):
"""Set up finite difference operators for the PDE."""
# Store per-interval grid spacings as arrays (supports non-uniform grids)
self.dx = []
for sv in self.problem.state_space.state_variables:
grid = sv.get_grid()
if len(grid) > 1:
self.dx.append(np.diff(grid))
else:
self.dx.append(np.array([1.0]))
# Will construct operators during solve based on current policy
self.operators_initialized = True
def _build_difference_matrix(
self, dim: int, boundary_type: BoundaryCondition
) -> sparse.spmatrix:
"""Build finite difference matrix for one dimension.
Args:
dim: Dimension index
boundary_type: Type of boundary condition
Returns:
Sparse difference operator
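
        Note:
            Interior rows use the standard central stencil
            (V[i-1] - 2*V[i] + V[i+1]) / dx**2 with the mean grid spacing,
            so the diffusion operator is second-order accurate only on
            (near-)uniform grids.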
"""
n = self.problem.state_space.state_variables[dim].num_points
        # Mean spacing: the diffusion stencil below assumes a (near-)uniform grid
        dx = float(np.mean(self.dx[dim]))
        # First derivative (upwind) terms are built dynamically from the drift
        # direction in _apply_upwind_scheme
        # Second derivative (diffusion)
        diagonals = np.ones((3, n))
        diagonals[0] *= 1.0 / (dx * dx)  # Lower diagonal
        diagonals[1] *= -2.0 / (dx * dx)  # Main diagonal
        diagonals[2] *= 1.0 / (dx * dx)  # Upper diagonal
        # Apply boundary conditions. sparse.diags convention: lower-diagonal
        # element i maps to matrix[i+1, i]; upper-diagonal element i maps to
        # matrix[i, i+1].
        if boundary_type in (BoundaryCondition.DIRICHLET, BoundaryCondition.ABSORBING):
            # Boundary value held fixed: identity rows at both ends
            # (Dirichlet values themselves are imposed separately)
            diagonals[1, 0] = 1
            diagonals[2, 0] = 0  # Zero out matrix[0, 1]
            diagonals[1, -1] = 1
            diagonals[0, n - 2] = 0  # Zero out matrix[n-1, n-2]
        elif boundary_type in (BoundaryCondition.NEUMANN, BoundaryCondition.REFLECTING):
            # Zero derivative at boundary: fold the ghost-point coefficient
            # into the main diagonal (V[-1] = V[0] and V[n] = V[n-1])
            diagonals[1, 0] += 1.0 / (dx * dx)
            diagonals[1, -1] += 1.0 / (dx * dx)
        # Off-diagonals must have length n - 1 for sparse.diags
        matrix = sparse.diags(
            [diagonals[0, : n - 1], diagonals[1], diagonals[2, : n - 1]],
            offsets=[-1, 0, 1],
            shape=(n, n),
        )
        return matrix
def _apply_upwind_scheme(self, value: np.ndarray, drift: np.ndarray, dim: int) -> np.ndarray:
"""Apply upwind finite difference for advection term.
Uses proper boundary-aware slicing (no wraparound) and supports
non-uniform grid spacing for all dimensions.
Args:
value: Value function on state grid
drift: Drift values at each grid point
dim: Dimension for differentiation
Returns:
Advection term contribution (drift * dV/dx)
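
        Note:
            Positive drift selects the forward difference
            (V[i+1] - V[i]) / dx[i]; negative drift selects the backward
            difference (V[i] - V[i-1]) / dx[i-1]. Differencing toward the
            neighbor the drift points at is the standard monotone upwinding
            for dynamic-programming (Kushner-Dupuis) discretizations.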
"""
dx_array = self.dx[dim] # Array of per-interval spacings
ndim = value.ndim
n = value.shape[dim]
result = np.zeros_like(value)
# Build index slices for interior points
hi = [slice(None)] * ndim
lo = [slice(None)] * ndim
hi[dim] = slice(1, None) # indices 1..N-1
lo[dim] = slice(None, -1) # indices 0..N-2
# Shape dx_array for broadcasting: (1, ..., N-1, ..., 1)
dx_shape = [1] * ndim
dx_shape[dim] = n - 1
dx_bc = dx_array.reshape(dx_shape)
# Forward difference: (V[i+1] - V[i]) / dx[i] for i=0..N-2
# At i=N-1, the forward difference is 0 (boundary)
fwd_diff = np.zeros_like(value)
fwd_diff[tuple(lo)] = (value[tuple(hi)] - value[tuple(lo)]) / dx_bc
# Backward difference: (V[i] - V[i-1]) / dx[i-1] for i=1..N-1
# At i=0, the backward difference is 0 (boundary)
bwd_diff = np.zeros_like(value)
bwd_diff[tuple(hi)] = (value[tuple(hi)] - value[tuple(lo)]) / dx_bc
# Upwind selection: forward for positive drift, backward for negative
mask_pos = drift > 0
mask_neg = drift < 0
result[mask_pos] = fwd_diff[mask_pos] * drift[mask_pos]
result[mask_neg] = bwd_diff[mask_neg] * drift[mask_neg]
return result
def _compute_gradient(self) -> np.ndarray:
"""Compute numerical gradient of the value function.
Uses np.gradient which handles non-uniform grids with second-order
accurate central differences in the interior and first-order accurate
one-sided differences at the boundaries.
Returns:
Gradient array with shape state_shape + (ndim,)
"""
if self.value_function is None:
shape = self.problem.state_space.shape + (self.problem.state_space.ndim,)
return np.zeros(shape)
grids = self.problem.state_space.grids
if self.problem.state_space.ndim == 1:
grad_components = [np.gradient(self.value_function, grids[0])]
else:
grad_components = np.gradient(self.value_function, *grids)
return np.stack(grad_components, axis=-1)
def solve(self) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
"""Solve the HJB equation using policy iteration.
Returns:
Tuple of (value_function, optimal_policy_dict)
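
        Example:
            A minimal 1D sketch; the drift and running cost below are
            placeholders, not a calibrated insurance model:

            >>> wealth = StateVariable("wealth", 1e5, 1e7, num_points=41, log_scale=True)
            >>> retention = ControlVariable("retention", 0.0, 1.0, num_points=11)
            >>> problem = HJBProblem(
            ...     state_space=StateSpace([wealth]),
            ...     control_variables=[retention],
            ...     utility_function=LogUtility(),
            ...     dynamics=lambda x, u, t: 0.05 * x * (1.0 - 0.5 * u),
            ...     running_cost=lambda x, u, t: np.log(np.maximum(x[:, 0], 1e-6)),
            ...     discount_rate=0.05,
            ... )
            >>> config = HJBSolverConfig(max_iterations=100, verbose=False)
            >>> solver = HJBSolver(problem, config)
            >>> value, policy = solver.solve()  # doctest: +SKIP
            >>> sorted(policy)  # doctest: +SKIP
            ['retention']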
"""
logger.info("Starting HJB solution with policy iteration")
# Initialize value function
if self.problem.time_horizon is not None:
# Finite horizon: start from terminal condition
state_points = np.stack(self.problem.state_space.flat_grids, axis=-1)
if self.problem.terminal_value is not None:
terminal_values = self.problem.terminal_value(state_points)
self.value_function = terminal_values.reshape(self.problem.state_space.shape)
else:
self.value_function = np.zeros(self.problem.state_space.shape)
else:
            # Infinite horizon: initialize value function with zeros
self.value_function = np.zeros(self.problem.state_space.shape)
# Initialize policy with mid-range controls
if self.optimal_policy is None:
self.optimal_policy = {}
for cv in self.problem.control_variables:
self.optimal_policy[cv.name] = np.full(
self.problem.state_space.shape, (cv.min_value + cv.max_value) / 2
)
# Policy iteration
for iteration in range(self.config.max_iterations):
old_value = self.value_function.copy()
# Policy evaluation step
self._policy_evaluation()
# Policy improvement step
self._policy_improvement()
# Check convergence
value_change = np.max(np.abs(self.value_function - old_value))
if self.config.verbose and iteration % 10 == 0:
logger.info(f"Iteration {iteration}: value change = {value_change:.6e}")
if value_change < self.config.tolerance:
logger.info(f"Converged after {iteration + 1} iterations")
break
if iteration == self.config.max_iterations - 1:
logger.warning("Max iterations reached without convergence")
return self.value_function, self.optimal_policy
def _reshape_cost(self, cost):
"""Helper method to reshape cost array."""
if hasattr(cost, "ndim") and cost.ndim > 1:
# If cost is multi-dimensional, take the first column or mean
if cost.shape[1] > 1:
cost = np.mean(cost, axis=1) # Average across extra dimensions
else:
cost = cost[:, 0] # Take first column
return cost.reshape(self.problem.state_space.shape)
def _apply_upwind_drift(self, new_v, drift, dt):
"""Apply upwind differencing for drift term across all dimensions."""
if not np.any(np.abs(drift) > 1e-10):
return new_v
if self.value_function is None:
return new_v
for dim in range(drift.shape[-1]):
if dim >= len(self.problem.state_space.state_variables):
continue
drift_component = drift[..., dim]
advection = self._apply_upwind_scheme(self.value_function, drift_component, dim)
new_v = new_v + dt * advection
return new_v
def _update_value_finite_horizon(self, old_v, cost, drift, dt):
"""Update value function for finite horizon problems."""
        # Explicit Euler step, marching backward in time from the terminal condition
new_v = old_v + dt * cost
# Add discount term if applicable
if self.problem.discount_rate > 0:
new_v -= dt * self.problem.discount_rate * old_v
# Add drift term using upwind differencing
return self._apply_upwind_drift(new_v, drift, dt)
def _policy_evaluation(self):
"""Evaluate current policy by solving linear PDE."""
# For now, implement a simple iterative scheme
# In production, would use sparse linear solver
if self.value_function is None or self.optimal_policy is None:
return
dt = self.config.time_step
for _ in range(100): # Inner iterations for policy evaluation
# Type guard ensures value_function is not None after the check above
assert self.value_function is not None # For mypy
old_v = self.value_function.copy()
# Get state and control grids
state_points = np.stack(self.problem.state_space.flat_grids, axis=-1)
control_array = np.stack(
[self.optimal_policy[cv.name].ravel() for cv in self.problem.control_variables],
axis=-1,
)
# Compute dynamics and running cost
drift = self.problem.dynamics(state_points, control_array, 0.0)
cost = self.problem.running_cost(state_points, control_array, 0.0)
# Reshape
drift = drift.reshape(self.problem.state_space.shape + (-1,))
cost = self._reshape_cost(cost)
# Apply finite differences with upwind scheme
# For finite horizon problems, integrate backwards from terminal condition
if self.problem.time_horizon is not None:
new_v = self._update_value_finite_horizon(old_v, cost, drift, dt)
else:
# For infinite horizon: 0 = -ρV + f(x,u) + drift·∇V
# Time-step: V_new = V_old + dt * (-ρ*V_old + cost + drift·∇V)
advection = np.zeros_like(old_v)
for dim in range(drift.shape[-1]):
if dim >= len(self.problem.state_space.state_variables):
continue
drift_component = drift[..., dim]
advection += self._apply_upwind_scheme(old_v, drift_component, dim)
new_v = old_v + dt * (-self.problem.discount_rate * old_v + cost + advection)
# Apply boundary conditions (skip for now to preserve terminal condition)
self.value_function = new_v
# Check inner convergence
if np.max(np.abs(new_v - old_v)) < self.config.tolerance / 10:
break
def _policy_improvement(self):
"""Improve policy by maximizing Hamiltonian H(x,u) = f(x,u) + drift(x,u)·∇V(x).
Vectorized over all state points for each control combination to avoid
the combinatorial explosion of evaluating each state-control pair individually.
"""
if self.value_function is None or self.optimal_policy is None:
return
state_points = np.stack(self.problem.state_space.flat_grids, axis=-1)
n_states = state_points.shape[0]
n_controls = len(self.problem.control_variables)
# Compute value function gradient at all state points
grad_V = self._compute_gradient()
grad_V_flat = grad_V.reshape(n_states, -1)
# Initialize tracking arrays
best_values = np.full(n_states, -np.inf)
best_controls = np.zeros((n_states, n_controls))
# Get discrete control samples for each control variable
control_samples = [cv.get_values() for cv in self.problem.control_variables]
# Iterate over control combinations, vectorized over all states
for control_combo in product(*control_samples):
control_array = np.array(control_combo)
control_broadcast = np.tile(control_array, (n_states, 1))
# Evaluate dynamics and running cost at all states simultaneously
drift = self.problem.dynamics(state_points, control_broadcast, 0.0)
cost = self.problem.running_cost(state_points, control_broadcast, 0.0)
# Reduce cost to 1D (one value per state point)
cost = np.asarray(cost)
if cost.ndim > 1:
if cost.shape[-1] > 1:
cost = np.mean(cost, axis=-1)
else:
cost = cost[..., 0]
cost = cost.flatten()
# Flatten drift for dot product with gradient
drift_flat = np.asarray(drift).reshape(n_states, -1)
# Match drift and gradient dimensions
n_dims = min(drift_flat.shape[1], grad_V_flat.shape[1])
# Full Hamiltonian: H = f(x,u) + drift(x,u) · ∇V(x)
hamiltonian = cost + np.sum(drift_flat[:, :n_dims] * grad_V_flat[:, :n_dims], axis=1)
# Update best control where this combo improves the Hamiltonian
improved = hamiltonian > best_values
best_values[improved] = hamiltonian[improved]
best_controls[improved] = control_array
# Write optimal controls back to policy arrays
for j, cv in enumerate(self.problem.control_variables):
self.optimal_policy[cv.name] = best_controls[:, j].reshape(
self.problem.state_space.shape
)
def compute_convergence_metrics(self) -> Dict[str, Any]:
"""Compute metrics for assessing solution quality.
Returns:
            Dictionary with keys "max_residual", "mean_residual",
            "value_function_range", and "policy_stats" (per-control min/max/mean)
"""
if self.value_function is None:
return {"error": "No solution computed yet"}
# Compute residual of HJB equation
state_points = np.stack(self.problem.state_space.flat_grids, axis=-1)
# optimal_policy should be non-None at this point since value_function is non-None
assert self.optimal_policy is not None
control_array = np.stack(
[self.optimal_policy[cv.name].ravel() for cv in self.problem.control_variables], axis=-1
)
# Evaluate HJB residual
_drift = self.problem.dynamics(state_points, control_array, 0.0)
cost = self.problem.running_cost(state_points, control_array, 0.0)
# Approximate time derivative (backward difference)
_dt = self.config.time_step
v_flat = self.value_function.ravel()
# Simplified residual (would compute full PDE residual in production)
cost_flat = cost.ravel() if hasattr(cost, "ravel") else cost
residual = np.abs(-self.problem.discount_rate * v_flat + cost_flat)
return {
"max_residual": float(np.max(residual)),
"mean_residual": float(np.mean(residual)),
"value_function_range": (
float(np.min(self.value_function)),
float(np.max(self.value_function)),
),
"policy_stats": {
cv.name: {
"min": float(np.min(self.optimal_policy[cv.name])),
"max": float(np.max(self.optimal_policy[cv.name])),
"mean": float(np.mean(self.optimal_policy[cv.name])),
}
for cv in self.problem.control_variables
},
}
def create_custom_utility(
evaluate_func: Callable[[np.ndarray], np.ndarray],
derivative_func: Callable[[np.ndarray], np.ndarray],
inverse_derivative_func: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> UtilityFunction:
"""Factory function for creating custom utility functions.
This function allows users to create custom utility functions by providing
the evaluation and derivative functions. This is the recommended way to
add new utility functions beyond the built-in ones.
Args:
evaluate_func: Function that evaluates U(w)
derivative_func: Function that computes U'(w)
inverse_derivative_func: Optional function for (U')^(-1)(m)
Returns:
Custom utility function instance
Example:
>>> # Create exponential utility: U(w) = 1 - exp(-α*w)
>>> def exp_eval(w):
... alpha = 0.01
... return 1 - np.exp(-alpha * w)
>>> def exp_deriv(w):
... alpha = 0.01
... return alpha * np.exp(-alpha * w)
>>> exp_utility = create_custom_utility(exp_eval, exp_deriv)
"""
class CustomUtility(UtilityFunction):
"""Dynamically created custom utility function."""
def evaluate(self, wealth: np.ndarray) -> np.ndarray:
return evaluate_func(wealth)
def derivative(self, wealth: np.ndarray) -> np.ndarray:
return derivative_func(wealth)
def inverse_derivative(self, marginal_utility: np.ndarray) -> np.ndarray:
if inverse_derivative_func is None:
raise NotImplementedError("Inverse derivative not provided for custom utility")
return inverse_derivative_func(marginal_utility)
return CustomUtility()