"""Adaptive stopping criteria for Monte Carlo simulations.
This module implements adaptive stopping rules based on convergence diagnostics,
allowing simulations to terminate early when convergence criteria are met.
"""
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
import warnings
import numpy as np
from scipy import stats
class StoppingRule(Enum):
    """Names the convergence rule an ``AdaptiveStoppingMonitor`` applies.

    Every member's value is its lower-case string identifier, so a rule can
    be looked up from configuration text with ``StoppingRule("ess")``.
    """

    # Gelman-Rubin potential scale reduction factor below a threshold.
    R_HAT = "r_hat"
    # Effective sample size at or above a minimum.
    ESS = "ess"
    # Relative change of the pooled mean between consecutive checks.
    RELATIVE_CHANGE = "relative_change"
    # Monte Carlo standard error relative to |mean| below a threshold.
    MCSE = "mcse"
    # NOTE(review): GEWEKE and HEIDELBERGER have no dedicated branch in the
    # monitor's rule evaluation in this module; they fall through to its
    # "Unknown stopping rule" path.
    GEWEKE = "geweke"
    HEIDELBERGER = "heidelberger"
    # R-hat, ESS and relative MCSE must all pass.
    COMBINED = "combined"
    # Delegate to a user-supplied callable.
    CUSTOM = "custom"
@dataclass
class StoppingCriteria:
    """Configuration for adaptive stopping criteria."""

    # Rule that decides whether a check counts as converged.
    rule: StoppingRule = StoppingRule.COMBINED
    # R-hat must be strictly below this value; must be > 1.0.
    r_hat_threshold: float = 1.05
    # Effective sample size must be at least this value.
    min_ess: int = 1000
    # RELATIVE_CHANGE rule: |mean_t - mean_{t-1}| / |mean_{t-1}| must be below this.
    relative_tolerance: float = 0.01
    # MCSE / |mean| must be strictly below this value.
    mcse_relative_threshold: float = 0.05
    # Checks before this iteration always report "not converged".
    min_iterations: int = 1000
    # At or beyond this iteration the monitor requests a stop.
    max_iterations: int = 100000
    # Diagnostics are only evaluated when iteration % check_interval == 0.
    check_interval: int = 100
    # Consecutive converged checks required before stopping is recommended.
    patience: int = 3
    # NOTE(review): not read by any code in this module -- confirm intended use.
    confidence_level: float = 0.95

    def __post_init__(self):
        """Validate criteria after initialization.

        Raises:
            ValueError: If ``r_hat_threshold`` is not strictly greater than
                1.0 (NaN included) or ``check_interval`` is below 1. The
                monitor uses ``check_interval`` as a modulus, so 0 would
                otherwise fail later with ZeroDivisionError.
        """
        # Written as "not >" so that NaN is rejected as well.
        if not self.r_hat_threshold > 1.0:
            raise ValueError("R-hat threshold must be > 1.0")
        if self.check_interval < 1:
            raise ValueError("check_interval must be >= 1")
        # stacklevel=3: warn <- __post_init__ <- generated __init__ <- caller,
        # so the warning points at the code constructing the criteria.
        if self.min_ess < 100:
            warnings.warn("Very low ESS threshold may lead to poor estimates", stacklevel=3)
        if self.min_iterations < 100:
            warnings.warn(
                "Very low minimum iterations may lead to premature stopping", stacklevel=3
            )
        # Iterations below min_iterations are never evaluated, and any
        # iteration >= max_iterations forces a stop, so this combination
        # leaves no iteration at which convergence could be assessed.
        if self.max_iterations <= self.min_iterations:
            warnings.warn(
                "max_iterations <= min_iterations; convergence will never be evaluated",
                stacklevel=3,
            )
@dataclass
class ConvergenceStatus:
    """Result of a single convergence check.

    Attributes:
        converged: Whether the stopping rule was satisfied at this check.
        iteration: Iteration number the check was performed at.
        reason: Human-readable explanation of the outcome.
        diagnostics: Diagnostic values the decision was based on.
        should_stop: Whether the caller is advised to stop sampling now.
        estimated_remaining: Estimated iterations still required, if known.
    """

    converged: bool
    iteration: int
    reason: str
    diagnostics: Dict[str, float]
    should_stop: bool
    estimated_remaining: Optional[int] = None

    def __str__(self) -> str:
        """Return a one-line summary of the status and its reason."""
        if self.converged:
            label = "CONVERGED"
        else:
            label = "NOT CONVERGED"
        return f"ConvergenceStatus({label} at iteration {self.iteration}): {self.reason}"
class AdaptiveStoppingMonitor:
    """Monitor for adaptive stopping based on convergence criteria.

    Each call to ``check_convergence`` that lands on a check interval (and
    lies between ``min_iterations`` and ``max_iterations``) does five things:
    it computes or accepts diagnostics, appends them to the monitor's
    history, attempts burn-in detection, evaluates the configured stopping
    rule, and estimates the remaining iterations. A stop is recommended only
    after ``patience`` consecutive converged checks.
    """

    def __init__(
        self, criteria: Optional[StoppingCriteria] = None, custom_rule: Optional[Callable] = None
    ):
        """Initialize adaptive stopping monitor.

        Args:
            criteria: Stopping criteria configuration; ``StoppingCriteria()``
                defaults are used when omitted.
            custom_rule: Callable taking the diagnostics dict and returning
                ``(converged, reason)``; only consulted when ``criteria.rule``
                is ``StoppingRule.CUSTOM``.
        """
        self.criteria = criteria or StoppingCriteria()
        self.custom_rule = custom_rule
        # History tracking: one entry per evaluated check. The diagnostic
        # lists only grow when the corresponding key is present.
        self.r_hat_history: List[float] = []
        self.ess_history: List[float] = []
        self.mean_history: List[float] = []
        self.variance_history: List[float] = []
        self.iteration_history: List[int] = []
        # Convergence tracking
        self.consecutive_convergence = 0
        self.burn_in_detected = False
        self.burn_in_iteration = 0
        # Rate estimation: fitted R-hat decay rate (per check) and the
        # projected iteration at which R-hat reaches its threshold.
        self.convergence_rate: Optional[float] = None
        self.estimated_total_iterations: Optional[int] = None

    def check_convergence(
        self, iteration: int, chains: np.ndarray, diagnostics: Optional[Dict[str, float]] = None
    ) -> ConvergenceStatus:
        """Check if convergence criteria are met.

        Args:
            iteration: Current iteration number.
            chains: Chain values (1-D for a single chain, otherwise
                chains along the first axis and iterations along the second).
            diagnostics: Pre-calculated diagnostics; computed from ``chains``
                when omitted.

        Returns:
            ConvergenceStatus describing the check. Before ``min_iterations``
            and off the check interval nothing is evaluated. At or beyond
            ``max_iterations`` a stop is requested without evaluation.
        """
        # Don't check before minimum iterations
        if iteration < self.criteria.min_iterations:
            return ConvergenceStatus(
                converged=False,
                iteration=iteration,
                reason=f"Below minimum iterations ({self.criteria.min_iterations})",
                diagnostics={},
                should_stop=False,
                estimated_remaining=self.criteria.min_iterations - iteration,
            )
        # Check if at maximum iterations
        if iteration >= self.criteria.max_iterations:
            return ConvergenceStatus(
                converged=False,
                iteration=iteration,
                reason=f"Maximum iterations reached ({self.criteria.max_iterations})",
                diagnostics=diagnostics or {},
                should_stop=True,
            )
        # Only check at intervals for efficiency
        if iteration % self.criteria.check_interval != 0:
            return ConvergenceStatus(
                converged=False,
                iteration=iteration,
                reason="Not at check interval",
                diagnostics={},
                should_stop=False,
            )
        if diagnostics is None:
            diagnostics = self._calculate_diagnostics(chains)
        self._update_history(iteration, diagnostics)
        # The detected burn-in is only recorded (burn_in_iteration); the
        # diagnostics above are still computed on the full chains.
        if not self.burn_in_detected:
            self._detect_burn_in(chains, iteration)
        converged, reason = self._check_stopping_rule(diagnostics)
        if converged:
            self.consecutive_convergence += 1
        else:
            self.consecutive_convergence = 0
        # Require `patience` consecutive converged checks before stopping.
        should_stop = converged and self.consecutive_convergence >= self.criteria.patience
        estimated_remaining = self._estimate_remaining_iterations(iteration, converged, diagnostics)
        return ConvergenceStatus(
            converged=converged,
            iteration=iteration,
            reason=reason,
            diagnostics=diagnostics,
            should_stop=should_stop,
            estimated_remaining=estimated_remaining,
        )

    def detect_adaptive_burn_in(self, chains: np.ndarray, method: str = "geweke") -> int:
        """Detect burn-in period adaptively.

        Args:
            chains: Chain values; a 1-D array is treated as a single chain,
                otherwise chains are along the first axis and iterations
                along the second.
            method: Detection method.
                - ``"geweke"``: per chain, pick the first of 19 candidate cut
                  points in ``(0, n_iterations // 2]`` after which the Geweke
                  test does not reject stationarity (p > 0.05). The median
                  over chains is returned.
                - ``"variance"``: pick the first point where the change in a
                  sliding-window variance falls in its lowest 10%.

        Returns:
            Estimated burn-in length in iterations; ``n_iterations // 10``
            when the chosen method finds no candidate.

        Raises:
            ValueError: If ``method`` is not recognised.
        """
        if chains.ndim == 1:
            chains = chains.reshape(1, -1)
        n_chains, n_iterations = chains.shape[:2]
        fallback = int(n_iterations // 10)
        if method == "geweke":
            burn_in_estimates = []
            # Candidate cut points are the same for every chain.
            test_points = np.linspace(0, n_iterations // 2, 20).astype(int)
            for chain_idx in range(n_chains):
                chain = chains[chain_idx].flatten()
                for test_point in test_points[1:]:  # skip the trivial cut at 0
                    test_chain = chain[test_point:]
                    if len(test_chain) < 100:
                        continue
                    _z_score, p_value = self._geweke_test(test_chain)
                    if p_value > 0.05:  # stationarity not rejected
                        burn_in_estimates.append(test_point)
                        break
            if burn_in_estimates:
                return int(np.median(burn_in_estimates))
            return fallback
        if method == "variance":
            window_size = max(10, n_iterations // 100)
            variances = [
                np.var(chains[:, i - window_size : i + window_size])
                for i in range(window_size, n_iterations - window_size)
            ]
            # np.diff needs at least two windows; with fewer the diff is
            # empty and np.percentile cannot be evaluated.
            if len(variances) >= 2:
                var_change = np.abs(np.diff(variances))
                threshold = np.percentile(var_change, 10)
                # "<=" so exactly-flat stretches (change == 0) also qualify.
                stable_points = np.flatnonzero(var_change <= threshold)
                if stable_points.size > 0:
                    return int(stable_points[0] + window_size)
            return fallback
        raise ValueError(f"Unknown burn-in detection method: {method}")

    def estimate_convergence_rate(
        self,
        diagnostic_history: List[float],
        target_value: float = 1.0,
        relative_tolerance: float = 0.01,
    ) -> Tuple[float, int]:
        """Estimate convergence rate and steps remaining to reach the target.

        Fits ``value_k = a * exp(-rate * k) + target_value`` by linear
        regression of ``log(value_k - target_value)`` on the history index
        ``k``, using only finite entries above the target. It then
        extrapolates from the latest value until the excess over the target
        is at most ``relative_tolerance * |target_value|``, or
        ``relative_tolerance`` itself when ``target_value`` is 0.

        Args:
            diagnostic_history: Diagnostic values, one per check, oldest first.
            target_value: Asymptote the diagnostic is assumed to decay to.
            relative_tolerance: Allowed excess over the target, relative to
                ``|target_value|``. Must be positive.

        Returns:
            Tuple ``(rate, steps)``: the fitted decay rate per history entry
            and the number of further history entries (checks) needed.
            ``steps`` is 0 when the latest value is already within tolerance.
            Returns ``(0.0, -1)`` when fewer than three values are given,
            fewer than two usable points exist, or the fitted rate is not
            positive. Returns ``(rate, -1)`` when the latest value is not
            finite.

        Raises:
            ValueError: If ``relative_tolerance`` is not positive.
        """
        if not relative_tolerance > 0:
            raise ValueError("relative_tolerance must be positive")
        if len(diagnostic_history) < 3:
            return 0.0, -1
        values = np.asarray(diagnostic_history, dtype=float)
        steps = np.arange(len(values))
        excess = values - target_value
        # log() needs strictly positive, finite inputs (R-hat can be inf).
        usable = np.isfinite(excess) & (excess > 0)
        if np.count_nonzero(usable) < 2:
            return 0.0, -1
        slope, _intercept = np.polyfit(steps[usable], np.log(excess[usable]), 1)
        rate = float(-slope)
        if not rate > 0:
            return 0.0, -1
        tolerance = relative_tolerance * abs(target_value) if target_value != 0 else relative_tolerance
        current_excess = excess[-1]
        if not np.isfinite(current_excess):
            return rate, -1
        if current_excess <= tolerance:
            return rate, 0
        # Solve current_excess * exp(-rate * k) = tolerance for k.
        return rate, int(np.ceil(np.log(current_excess / tolerance) / rate))

    def get_stopping_summary(self) -> Dict[str, Any]:
        """Get summary of stopping monitor state.

        Returns:
            Dictionary with check count, convergence/burn-in state, rate
            estimates, the key criteria, and the latest R-hat/ESS/mean when
            any have been recorded.
        """
        summary = {
            "iterations_checked": len(self.iteration_history),
            "consecutive_convergence": self.consecutive_convergence,
            "burn_in_detected": self.burn_in_detected,
            "burn_in_iteration": self.burn_in_iteration,
            "convergence_rate": self.convergence_rate,
            "estimated_total_iterations": self.estimated_total_iterations,
            "criteria": {
                "rule": self.criteria.rule.value,
                "r_hat_threshold": self.criteria.r_hat_threshold,
                "min_ess": self.criteria.min_ess,
                "patience": self.criteria.patience,
            },
        }
        if self.r_hat_history:
            summary["latest_r_hat"] = self.r_hat_history[-1]
        if self.ess_history:
            summary["latest_ess"] = self.ess_history[-1]
        if self.mean_history:
            summary["latest_mean"] = self.mean_history[-1]
        return summary

    # Private helper methods

    def _calculate_diagnostics(self, chains: np.ndarray) -> Dict[str, float]:
        """Compute R-hat, ESS, mean, variance and (relative) MCSE.

        R-hat is 1.0 for a single chain. ESS, mean and variance are computed
        on all chains concatenated into one sequence.
        """
        if chains.ndim == 1:
            chains = chains.reshape(1, -1)
        diagnostics = {}
        if chains.shape[0] > 1:
            diagnostics["r_hat"] = self._calculate_r_hat(chains)
        else:
            diagnostics["r_hat"] = 1.0
        pooled_chain = chains.flatten()
        diagnostics["ess"] = self._calculate_ess(pooled_chain)
        diagnostics["mean"] = np.mean(pooled_chain)
        diagnostics["variance"] = np.var(pooled_chain, ddof=1)
        if diagnostics["ess"] > 0:
            diagnostics["mcse"] = np.sqrt(diagnostics["variance"] / diagnostics["ess"])
            # Relative MCSE is undefined for a zero mean; report inf so the
            # MCSE criterion fails rather than dividing by zero.
            diagnostics["mcse_relative"] = (
                diagnostics["mcse"] / abs(diagnostics["mean"])
                if diagnostics["mean"] != 0
                else np.inf
            )
        else:
            diagnostics["mcse"] = np.inf
            diagnostics["mcse_relative"] = np.inf
        return diagnostics

    def _calculate_r_hat(self, chains: np.ndarray) -> float:
        """Calculate the (non-split) Gelman-Rubin R-hat statistic.

        Returns 1.0 when every chain is constant at the same value, and
        ``inf`` when the chains are constant but disagree.
        """
        n_iterations = chains.shape[1]
        chain_means = np.mean(chains, axis=1)
        between_var = n_iterations * np.var(chain_means, ddof=1)
        within_var = np.mean(np.var(chains, axis=1, ddof=1))
        if within_var > 0:
            var_est = ((n_iterations - 1) * within_var + between_var) / n_iterations
            return float(np.sqrt(var_est / within_var))
        # No within-chain spread: identical constant chains agree perfectly.
        return 1.0 if between_var == 0 else float(np.inf)

    def _calculate_ess(self, chain: np.ndarray) -> float:
        """Calculate effective sample size of a single sequence.

        Autocorrelations up to ``min(n // 4, 1000)`` lags are truncated at
        the first negative value. They are then summed in consecutive pairs
        while each pair sum stays positive. The result is capped at ``n``.
        """
        n = len(chain)
        if n < 4:
            return float(n)
        max_lag = min(n // 4, 1000)
        acf = self._calculate_acf(chain, max_lag)
        first_negative = np.where(acf < 0)[0]
        cutoff = first_negative[0] if len(first_negative) > 0 else len(acf)
        sum_acf = 1.0
        for i in range(1, cutoff, 2):
            if i + 1 < cutoff:
                pair_sum = acf[i] + acf[i + 1]
                if pair_sum > 0:
                    sum_acf += 2 * pair_sum
                else:
                    break
        ess = n / max(sum_acf, 1)
        return float(min(ess, n))

    def _calculate_acf(self, chain: np.ndarray, max_lag: int) -> np.ndarray:
        """Return autocorrelations for lags ``0..max_lag`` (biased estimator).

        Lags at or beyond ``len(chain)`` are left as 0. All lags > 0 are 0
        for a constant chain.
        """
        n = len(chain)
        chain_centered = chain - np.mean(chain)
        c0 = np.dot(chain_centered, chain_centered) / n
        acf = np.zeros(max_lag + 1)
        acf[0] = 1.0
        for lag in range(1, min(max_lag + 1, n)):
            c_lag = np.dot(chain_centered[:-lag], chain_centered[lag:]) / n
            acf[lag] = c_lag / c0 if c0 > 0 else 0
        return acf

    def _geweke_test(
        self, chain: np.ndarray, first_frac: float = 0.1, last_frac: float = 0.5
    ) -> Tuple[float, float]:
        """Compare the means of the first and last portions of ``chain``.

        Uses the naive (independent-sample) standard error of each portion's
        mean rather than a spectral-density estimate. For positively
        autocorrelated chains this underestimates the standard error.

        Returns:
            ``(z_score, two_sided_p_value)``. Returns ``(nan, nan)`` when
            either portion has fewer than two samples. When both portions
            have zero variance, returns ``(0.0, 1.0)`` if their means agree,
            otherwise ``(+/-inf, 0.0)``.
        """
        n = len(chain)
        n_first = int(n * first_frac)
        n_last = int(n * last_frac)
        # ddof=1 variances need at least two samples per portion.
        if n_first < 2 or n_last < 2:
            return float("nan"), float("nan")
        first_portion = chain[:n_first]
        last_portion = chain[-n_last:]
        mean_diff = np.mean(first_portion) - np.mean(last_portion)
        std_err = np.sqrt(
            np.var(first_portion, ddof=1) / n_first + np.var(last_portion, ddof=1) / n_last
        )
        if std_err == 0:
            if mean_diff == 0:
                return 0.0, 1.0
            return float(np.copysign(np.inf, mean_diff)), 0.0
        z_score = float(mean_diff / std_err)
        # sf() keeps precision in the far tail where 1 - cdf() rounds to 0.
        p_value = float(2 * stats.norm.sf(abs(z_score)))
        return z_score, p_value

    def _update_history(self, iteration: int, diagnostics: Dict[str, float]):
        """Append this check's iteration and any known diagnostics to history."""
        self.iteration_history.append(iteration)
        if "r_hat" in diagnostics:
            self.r_hat_history.append(diagnostics["r_hat"])
        if "ess" in diagnostics:
            self.ess_history.append(diagnostics["ess"])
        if "mean" in diagnostics:
            self.mean_history.append(diagnostics["mean"])
        if "variance" in diagnostics:
            self.variance_history.append(diagnostics["variance"])

    def _detect_burn_in(self, chains: np.ndarray, iteration: int):
        """Record a Geweke-based burn-in estimate once it is plausible.

        Skipped before iteration 500. The estimate is accepted only if it is
        below half of the current iteration.
        """
        if iteration < 500:  # Too early to detect
            return
        burn_in = self.detect_adaptive_burn_in(chains, method="geweke")
        if burn_in < iteration // 2:
            self.burn_in_detected = True
            self.burn_in_iteration = burn_in

    def _check_stopping_rule(  # pylint: disable=too-many-branches
        self, diagnostics: Dict[str, float]
    ) -> Tuple[bool, str]:
        """Evaluate the configured stopping rule against ``diagnostics``.

        Missing diagnostics default to failing values: R-hat and relative
        MCSE default to ``inf``, and ESS defaults to 0.

        Returns:
            ``(converged, reason)``, where ``reason`` states the compared
            values and thresholds.
        """
        rule = self.criteria.rule
        if rule == StoppingRule.R_HAT:
            r_hat = diagnostics.get("r_hat", np.inf)
            converged = r_hat < self.criteria.r_hat_threshold
            reason = f"R-hat = {r_hat:.4f} (threshold: {self.criteria.r_hat_threshold})"
        elif rule == StoppingRule.ESS:
            ess = diagnostics.get("ess", 0)
            converged = ess >= self.criteria.min_ess
            reason = f"ESS = {ess:.0f} (minimum: {self.criteria.min_ess})"
        elif rule == StoppingRule.MCSE:
            mcse_rel = diagnostics.get("mcse_relative", np.inf)
            converged = mcse_rel < self.criteria.mcse_relative_threshold
            reason = f"Relative MCSE = {mcse_rel:.4f} (threshold: {self.criteria.mcse_relative_threshold})"
        elif rule == StoppingRule.RELATIVE_CHANGE:
            # Compares the two most recent recorded means.
            converged = False
            if len(self.mean_history) < 2:
                reason = "Relative change needs at least two checks"
            elif self.mean_history[-2] == 0:
                reason = "Relative change undefined (previous mean is 0)"
            else:
                recent_mean = self.mean_history[-1]
                previous_mean = self.mean_history[-2]
                rel_change = abs(recent_mean - previous_mean) / abs(previous_mean)
                converged = rel_change < self.criteria.relative_tolerance
                reason = f"Relative change = {rel_change:.4f} (tolerance: {self.criteria.relative_tolerance})"
        elif rule == StoppingRule.COMBINED:
            r_hat = diagnostics.get("r_hat", np.inf)
            ess = diagnostics.get("ess", 0)
            mcse_rel = diagnostics.get("mcse_relative", np.inf)
            checks = [
                (r_hat < self.criteria.r_hat_threshold, f"R-hat={r_hat:.3f}"),
                (ess >= self.criteria.min_ess, f"ESS={ess:.0f}"),
                (mcse_rel < self.criteria.mcse_relative_threshold, f"MCSE_rel={mcse_rel:.3f}"),
            ]
            converged = all(ok for ok, _ in checks)
            if converged:
                reason = "All criteria met: " + ", ".join(label for _, label in checks)
            else:
                reason = "Failed: " + ", ".join(label for ok, label in checks if not ok)
        elif rule == StoppingRule.CUSTOM:
            if self.custom_rule is not None:
                converged, reason = self.custom_rule(diagnostics)
            else:
                converged = False
                reason = "No custom rule provided"
        else:
            converged = False
            reason = f"Unknown stopping rule: {rule}"
        return converged, reason

    def _estimate_remaining_iterations(
        self, current_iteration: int, converged: bool, diagnostics: Dict[str, float]
    ) -> Optional[int]:
        """Estimate remaining iterations to convergence.

        First tries an exponential-decay fit of R-hat towards 1.0. It then
        falls back to linear extrapolation of ESS growth. Both work in units
        of checks and are converted to iterations via ``check_interval``.

        Returns:
            0 if already converged, an iteration estimate, or ``None`` when
            no estimate is possible.
        """
        if converged:
            return 0
        interval = self.criteria.check_interval
        if len(self.r_hat_history) >= 3:
            # R-hat decays towards 1; it passes the threshold once its excess
            # over 1 is within (threshold - 1), i.e. a relative tolerance of
            # (threshold - 1) around the target 1.0.
            rate, checks_to_target = self.estimate_convergence_rate(
                self.r_hat_history,
                target_value=1.0,
                relative_tolerance=self.criteria.r_hat_threshold - 1.0,
            )
            if checks_to_target > 0:
                # History entries are one per check, not one per iteration.
                iterations_to_target = checks_to_target * interval
                self.convergence_rate = rate
                self.estimated_total_iterations = current_iteration + iterations_to_target
                return iterations_to_target
        if len(self.ess_history) >= 2:
            current_ess = self.ess_history[-1]
            ess_rate = (current_ess - self.ess_history[-2]) / interval
            if ess_rate > 0:
                ess_needed = self.criteria.min_ess - current_ess
                return max(0, int(np.ceil(ess_needed / ess_rate)))
        return None