"""Convergence diagnostics for Monte Carlo simulations.
This module provides tools for assessing convergence of Monte Carlo simulations
including Gelman-Rubin R-hat, effective sample size, and Monte Carlo standard error.
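
Example:
    A minimal end-to-end sketch; the chains below are illustrative draws
    (standard normal around a nonzero mean), not part of this module::

        import numpy as np

        rng = np.random.default_rng(0)
        chains = rng.normal(loc=5.0, size=(4, 2000))  # (n_chains, n_iterations)
        diag = ConvergenceDiagnostics()
        results = diag.check_convergence(chains)
        print(results["metric_0"])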
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
@dataclass
class ConvergenceStats:
"""Container for convergence statistics."""
    r_hat: float  # Gelman-Rubin statistic (values near 1 indicate convergence)
    ess: float  # Effective sample size
    mcse: float  # Monte Carlo standard error
    converged: bool  # True if all convergence criteria are met
    n_iterations: int  # Total number of samples across all chains
    autocorrelation: float  # Lag-1 autocorrelation of the pooled chain
def __str__(self) -> str:
"""String representation of convergence stats."""
return (
f"ConvergenceStats(r_hat={self.r_hat:.3f}, "
f"ess={self.ess:.0f}, mcse={self.mcse:.4f}, "
f"converged={self.converged})"
)
class ConvergenceDiagnostics:
"""Convergence diagnostics for Monte Carlo simulations.
Provides methods for assessing convergence using multiple chains
and calculating effective sample sizes.
"""
def __init__(
self,
r_hat_threshold: float = 1.1,
min_ess: int = 1000,
relative_mcse_threshold: float = 0.05,
):
"""Initialize convergence diagnostics.
Args:
r_hat_threshold: Maximum R-hat for convergence (default 1.1)
min_ess: Minimum effective sample size (default 1000)
relative_mcse_threshold: Maximum relative MCSE (default 0.05)
"""
self.r_hat_threshold = r_hat_threshold
self.min_ess = min_ess
self.relative_mcse_threshold = relative_mcse_threshold
def calculate_r_hat(self, chains: np.ndarray) -> float:
"""Calculate Gelman-Rubin R-hat statistic.
Args:
chains: Array of shape (n_chains, n_iterations) or (n_chains, n_iterations, n_metrics)
Returns:
R-hat statistic (values close to 1 indicate convergence)
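
        Example:
            A minimal sketch; the chains here are illustrative draws from a
            common target, so R-hat should be close to 1::

                rng = np.random.default_rng(1)
                chains = rng.normal(size=(4, 1000))
                diag = ConvergenceDiagnostics()
                r_hat = diag.calculate_r_hat(chains)  # ~1.0 for well-mixed chains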
"""
if chains.ndim == 2:
n_chains, n_iterations = chains.shape
elif chains.ndim == 3:
n_chains, n_iterations, n_metrics = chains.shape
# Calculate R-hat for each metric and return maximum
r_hats = [self.calculate_r_hat(chains[:, :, i]) for i in range(n_metrics)]
return max(r_hats)
else:
raise ValueError("Chains must be 2D or 3D array")
if n_chains < 2:
raise ValueError("Need at least 2 chains for R-hat calculation")
# Calculate between-chain variance
chain_means = np.mean(chains, axis=1)
_grand_mean = np.mean(chain_means)
between_var = n_iterations * np.var(chain_means, ddof=1)
# Calculate within-chain variance
within_vars = np.var(chains, axis=1, ddof=1)
within_var = np.mean(within_vars)
# Calculate pooled variance estimate
var_est = ((n_iterations - 1) * within_var + between_var) / n_iterations
# Calculate R-hat
r_hat = np.sqrt(var_est / within_var) if within_var > 0 else np.inf
return float(r_hat)
def calculate_ess(self, chain: np.ndarray, max_lag: Optional[int] = None) -> float:
"""Calculate effective sample size using autocorrelation.
Uses the formula: ESS = N / (1 + 2 * sum(autocorrelations))
where the sum is truncated at the first negative autocorrelation.
Args:
chain: 1D array of samples
max_lag: Maximum lag for autocorrelation calculation
Returns:
Effective sample size
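
        Example:
            A sketch with an illustrative AR(1) chain (not part of this
            module); positive autocorrelation shrinks the ESS well below n::

                rng = np.random.default_rng(2)
                noise = rng.normal(size=5000)
                chain = np.zeros(5000)
                for t in range(1, 5000):
                    chain[t] = 0.9 * chain[t - 1] + noise[t]
                diag = ConvergenceDiagnostics()
                ess = diag.calculate_ess(chain)  # far below 5000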
"""
n = len(chain)
if n < 4:
return float(n)
if max_lag is None:
max_lag = min(n // 4, 1000)
# Calculate autocorrelations
autocorr = self._calculate_autocorrelation(chain, max_lag)
        # Truncate at the first negative autocorrelation
        first_negative = np.where(autocorr < 0)[0]
        cutoff = first_negative[0] if len(first_negative) > 0 else len(autocorr)
        # Geyer-style initial positive sequence estimator: add adjacent pairs
        # of autocorrelations and stop once a pair sum turns negative
        sum_autocorr = 1.0  # lag-0 autocorrelation is always 1
        for i in range(1, cutoff, 2):
            if i + 1 < cutoff:
                pair_sum = autocorr[i] + autocorr[i + 1]
                if pair_sum > 0:
                    sum_autocorr += 2 * pair_sum
                else:
                    break
            elif autocorr[i] > 0:
                # Odd final term with no partner
                sum_autocorr += 2 * autocorr[i]
# Calculate ESS
ess = n / max(sum_autocorr, 1)
return float(min(ess, n)) # ESS cannot exceed actual sample size
def calculate_batch_ess(
self, chains: np.ndarray, method: str = "mean"
) -> Union[float, np.ndarray]:
"""Calculate ESS for multiple chains or metrics.
Args:
chains: Array of shape (n_chains, n_iterations) or (n_chains, n_iterations, n_metrics)
method: How to combine ESS across chains ('mean', 'min', 'all')
Returns:
Combined ESS value(s)
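
        Example:
            A sketch over illustrative draws; ``method="min"`` yields the most
            conservative per-metric estimate::

                rng = np.random.default_rng(3)
                chains = rng.normal(size=(4, 1000, 2))  # 4 chains, 2 metrics
                diag = ConvergenceDiagnostics()
                ess_per_metric = diag.calculate_batch_ess(chains, method="min")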
"""
if chains.ndim == 2:
# Multiple chains, single metric
ess_values = [self.calculate_ess(chain) for chain in chains]
elif chains.ndim == 3:
# Multiple chains, multiple metrics
n_chains, _n_iterations, n_metrics = chains.shape
ess_values = []
for m in range(n_metrics):
metric_ess = [self.calculate_ess(chains[c, :, m]) for c in range(n_chains)]
ess_values.append(metric_ess) # type: ignore
else:
raise ValueError("Chains must be 2D or 3D array")
# Process based on method
if method == "mean":
if chains.ndim == 2:
return float(np.mean(ess_values))
return np.array([np.mean(metric_ess) for metric_ess in ess_values])
if method == "min":
if chains.ndim == 2:
return float(np.min(ess_values))
return np.array([np.min(metric_ess) for metric_ess in ess_values])
if method == "all":
return np.array(ess_values)
raise ValueError(f"Unknown method: {method}")
def calculate_ess_per_second(self, chain: np.ndarray, computation_time: float) -> float:
"""Calculate ESS per second of computation.
Useful for comparing efficiency of different sampling methods.
Args:
chain: 1D array of samples
computation_time: Time in seconds taken to generate the chain
Returns:
ESS per second
"""
ess = self.calculate_ess(chain)
return ess / computation_time if computation_time > 0 else 0.0
def calculate_mcse(self, chain: np.ndarray, ess: Optional[float] = None) -> float:
"""Calculate Monte Carlo standard error.
Args:
chain: 1D array of samples
ess: Effective sample size (calculated if not provided)
Returns:
Monte Carlo standard error
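
        Example:
            A sketch on illustrative i.i.d. draws, where ESS is close to n and
            the MCSE is roughly std / sqrt(n)::

                rng = np.random.default_rng(4)
                chain = rng.normal(size=10_000)
                diag = ConvergenceDiagnostics()
                mcse = diag.calculate_mcse(chain)  # roughly 0.01 here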
"""
if ess is None:
ess = self.calculate_ess(chain)
# Calculate standard error using ESS
std_dev = np.std(chain, ddof=1)
mcse = std_dev / np.sqrt(ess)
return float(mcse)
def check_convergence(
self, chains: Union[np.ndarray, List[np.ndarray]], metric_names: Optional[List[str]] = None
) -> Dict[str, ConvergenceStats]:
"""Check convergence for multiple chains and metrics.
Args:
chains: Array of shape (n_chains, n_iterations, n_metrics) or list of chains
metric_names: Names of metrics (optional)
Returns:
Dictionary mapping metric names to convergence statistics
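
        Example:
            A sketch with named metrics; the arrays are illustrative draws
            around a nonzero mean so the relative MCSE criterion is meaningful::

                rng = np.random.default_rng(5)
                chains = rng.normal(loc=5.0, size=(4, 2000, 2))
                diag = ConvergenceDiagnostics()
                results = diag.check_convergence(chains, metric_names=["loss", "reward"])
                all_ok = all(s.converged for s in results.values())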
"""
# Convert list to array if needed
if isinstance(chains, list):
chains = np.array(chains)
# Handle different array shapes
if chains.ndim == 1:
chains = chains.reshape(1, -1, 1)
        elif chains.ndim == 2:
            if chains.shape[0] < chains.shape[1]:
                # Assume shape is (n_chains, n_iterations); add a metric axis
                chains = chains.reshape(chains.shape[0], chains.shape[1], 1)
            else:
                # Assume shape is (n_iterations, n_metrics) from a single chain;
                # add a leading chain axis instead of treating metrics as chains
                chains = chains.reshape(1, chains.shape[0], chains.shape[1])
n_chains, n_iterations, n_metrics = chains.shape
if metric_names is None:
metric_names = [f"metric_{i}" for i in range(n_metrics)]
results = {}
for i, name in enumerate(metric_names):
metric_chains = chains[:, :, i]
# Calculate R-hat
r_hat = self.calculate_r_hat(metric_chains) if n_chains > 1 else 1.0
# Calculate ESS and MCSE for combined chain
combined_chain = metric_chains.flatten()
ess = self.calculate_ess(combined_chain)
mcse = self.calculate_mcse(combined_chain, ess)
            # Lag-1 autocorrelation (index 0 is the lag-0 value, which is always 1)
            autocorr = self._calculate_autocorrelation(combined_chain, 1)[1]
# Check convergence criteria
mean_val = np.mean(combined_chain)
relative_mcse = mcse / abs(mean_val) if mean_val != 0 else np.inf
converged = (
r_hat < self.r_hat_threshold
and ess >= self.min_ess
and relative_mcse < self.relative_mcse_threshold
)
results[name] = ConvergenceStats(
r_hat=r_hat,
ess=ess,
mcse=mcse,
converged=converged,
n_iterations=n_iterations * n_chains,
autocorrelation=autocorr,
)
return results
def _calculate_autocorrelation(self, chain: np.ndarray, max_lag: int) -> np.ndarray:
"""Calculate autocorrelation function.
Args:
chain: 1D array of samples
max_lag: Maximum lag
Returns:
Array of autocorrelations for lags 0 to max_lag
"""
n = len(chain)
chain = chain - np.mean(chain)
c0 = np.dot(chain, chain) / n
autocorr = np.zeros(max_lag + 1)
autocorr[0] = 1.0
for lag in range(1, min(max_lag + 1, n)):
c_lag = np.dot(chain[:-lag], chain[lag:]) / n
autocorr[lag] = c_lag / c0 if c0 > 0 else 0
return autocorr
def geweke_test(
self, chain: np.ndarray, first_fraction: float = 0.1, last_fraction: float = 0.5
) -> Tuple[float, float]:
"""Perform Geweke convergence test.
Compares means of first and last portions of chain.
Args:
chain: 1D array of samples
first_fraction: Fraction of chain to use for first portion
last_fraction: Fraction of chain to use for last portion
Returns:
Tuple of (z-score, p-value)
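
        Example:
            A sketch; for a stationary chain the z-score should be small in
            magnitude and the p-value large::

                rng = np.random.default_rng(6)
                chain = rng.normal(size=5000)
                diag = ConvergenceDiagnostics()
                z, p = diag.geweke_test(chain)  # |z| < 2 suggests no drift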
"""
        if first_fraction + last_fraction > 1.0:
            raise ValueError("first_fraction and last_fraction must not overlap")
        n = len(chain)
        n_first = max(int(n * first_fraction), 1)
        n_last = max(int(n * last_fraction), 1)
        first_portion = chain[:n_first]
        last_portion = chain[-n_last:]
        # Calculate means of the two portions
mean_first = np.mean(first_portion)
mean_last = np.mean(last_portion)
# Simple variance estimates (could use spectral density for more accuracy)
var_first = np.var(first_portion, ddof=1) / n_first
var_last = np.var(last_portion, ddof=1) / n_last
# Calculate z-score
z_score = (mean_first - mean_last) / np.sqrt(var_first + var_last)
        # Two-sided p-value from the standard normal survival function
        # (lazy import to avoid scipy issues in worker processes)
        from scipy import stats

        p_value = 2 * stats.norm.sf(abs(z_score))
        return float(z_score), float(p_value)
def heidelberger_welch_test(
self, chain: np.ndarray, alpha: float = 0.05
) -> Dict[str, Union[bool, float]]:
"""Perform Heidelberger-Welch stationarity and halfwidth tests.
Args:
chain: 1D array of samples
alpha: Significance level
Returns:
Dictionary with test results
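
        Example:
            A sketch on an illustrative long, stable chain with nonzero mean;
            both flags should typically pass::

                rng = np.random.default_rng(7)
                chain = 5.0 + rng.normal(size=20_000)
                diag = ConvergenceDiagnostics()
                result = diag.heidelberger_welch_test(chain)
                ok = result["stationary"] and result["halfwidth_passed"]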
"""
n = len(chain)
        # Simplified stationarity check: the full Heidelberger-Welch test uses a
        # Cramer-von Mises statistic; here we test whether rolling-window means
        # are stable relative to the overall variance
        window_size = max(n // 10, 1)
        means = []
        for i in range(window_size, n - window_size):
            window_mean = np.mean(chain[i - window_size : i + window_size])
            means.append(window_mean)
# Check if means are stable
mean_variance = np.var(means)
overall_variance = np.var(chain)
stationarity_ratio = mean_variance / overall_variance if overall_variance > 0 else np.inf
stationary = stationarity_ratio < 0.1 # Heuristic threshold
        # Halfwidth test: does the (1 - alpha) confidence interval halfwidth
        # achieve 10% relative precision?
        # Lazy import to avoid scipy issues in worker processes
        from scipy import stats

        z_crit = stats.norm.ppf(1 - alpha / 2)
        mean_estimate = np.mean(chain)
        mcse = self.calculate_mcse(chain)
        halfwidth = z_crit * mcse
        relative_halfwidth = halfwidth / abs(mean_estimate) if mean_estimate != 0 else np.inf
        halfwidth_passed = relative_halfwidth < 0.1  # 10% relative precision
return {
"stationary": stationary,
"stationarity_ratio": stationarity_ratio,
"halfwidth_passed": halfwidth_passed,
"relative_halfwidth": relative_halfwidth,
"mean": mean_estimate,
"mcse": mcse,
}