"""Convergence diagnostics for Monte Carlo simulations.
This module provides tools for assessing convergence of Monte Carlo simulations
including Gelman-Rubin R-hat, effective sample size, and Monte Carlo standard error.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from scipy.fft import irfft, rfft
@dataclass
class ConvergenceStats:
"""Container for convergence statistics."""
r_hat: float
ess: float
mcse: float
converged: bool
n_iterations: int
autocorrelation: float
    def __str__(self) -> str:
        """String representation of convergence stats."""
        r_hat_str = "nan" if np.isnan(self.r_hat) else f"{self.r_hat:.3f}"
return (
f"ConvergenceStats(r_hat={r_hat_str}, "
f"ess={self.ess:.0f}, mcse={self.mcse:.4f}, "
f"converged={self.converged})"
)
class ConvergenceDiagnostics:
"""Convergence diagnostics for Monte Carlo simulations.
Provides methods for assessing convergence using multiple chains
and calculating effective sample sizes.
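
    Example:
        A minimal usage sketch with synthetic, seeded chains (illustrative
        only; real simulation output would be used in practice):

        >>> import numpy as np
        >>> diag = ConvergenceDiagnostics(min_ess=500)
        >>> chains = np.random.default_rng(0).normal(size=(4, 1000))
        >>> results = diag.check_convergence(chains)
        >>> sorted(results)
        ['metric_0']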
"""
def __init__(
self,
r_hat_threshold: float = 1.1,
min_ess: int = 1000,
relative_mcse_threshold: float = 0.05,
):
"""Initialize convergence diagnostics.
Args:
r_hat_threshold: Maximum R-hat for convergence (default 1.1)
min_ess: Minimum effective sample size (default 1000)
relative_mcse_threshold: Maximum relative MCSE (default 0.05)
"""
self.r_hat_threshold = r_hat_threshold
self.min_ess = min_ess
self.relative_mcse_threshold = relative_mcse_threshold
def calculate_r_hat(self, chains: np.ndarray) -> float:
"""Calculate Gelman-Rubin R-hat statistic.
Args:
chains: Array of shape (n_chains, n_iterations) or (n_chains, n_iterations, n_metrics)
Returns:
R-hat statistic (values close to 1 indicate convergence)
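
        Example:
            Illustrative sketch with seeded synthetic chains; for well-mixed
            independent draws R-hat should be close to 1, so only a loose
            bound is asserted:

            >>> import numpy as np
            >>> diag = ConvergenceDiagnostics()
            >>> chains = np.random.default_rng(0).normal(size=(4, 500))
            >>> diag.calculate_r_hat(chains) < 1.1
            True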
"""
if chains.ndim == 2:
n_chains, n_iterations = chains.shape
elif chains.ndim == 3:
n_chains, n_iterations, n_metrics = chains.shape
# Calculate R-hat for each metric and return maximum
r_hats = [self.calculate_r_hat(chains[:, :, i]) for i in range(n_metrics)]
return max(r_hats)
else:
raise ValueError("Chains must be 2D or 3D array")
if n_chains < 2:
raise ValueError("Need at least 2 chains for R-hat calculation")
# Calculate between-chain variance
chain_means = np.mean(chains, axis=1)
_grand_mean = np.mean(chain_means)
between_var = n_iterations * np.var(chain_means, ddof=1)
# Calculate within-chain variance
within_vars = np.var(chains, axis=1, ddof=1)
within_var = np.mean(within_vars)
# Calculate pooled variance estimate
var_est = ((n_iterations - 1) * within_var + between_var) / n_iterations
# Calculate R-hat
r_hat = np.sqrt(var_est / within_var) if within_var > 0 else np.inf
return float(r_hat)
def calculate_ess(self, chain: np.ndarray, max_lag: Optional[int] = None) -> float:
"""Calculate effective sample size using Geyer's initial positive sequence.
Uses Geyer's (1992, Theorem 3.1) initial positive sequence estimator:
ESS = N / tau, where tau = 1 + 2 * sum of consecutive ACF pair sums
(rho[2k-1] + rho[2k]) truncated at the first non-positive pair.
        Individual autocorrelation values may be negative while pair sums
        remain positive; this is common for oscillating MCMC chains.
Args:
chain: 1D array of samples
max_lag: Maximum lag for autocorrelation calculation
Returns:
Effective sample size
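
        Example:
            Illustrative sketch on seeded i.i.d. draws, where the estimator
            should report an ESS close to the chain length (only a loose
            bound is asserted):

            >>> import numpy as np
            >>> diag = ConvergenceDiagnostics()
            >>> chain = np.random.default_rng(1).normal(size=2000)
            >>> diag.calculate_ess(chain) > 1000
            True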
"""
n = len(chain)
if n < 4:
return float(n)
if max_lag is None:
max_lag = min(n // 4, 1000)
# Calculate autocorrelations
autocorr = self._calculate_autocorrelation(chain, max_lag)
# Geyer's initial positive sequence estimator (Geyer, 1992, Theorem 3.1)
# Iterate over consecutive pairs rho(2k-1) + rho(2k) and stop at the
# first pair whose sum is non-positive. No pre-truncation at individual
        # negative autocorrelations; only pair sums matter.
sum_autocorr = 1.0 # lag-0 contribution
for i in range(1, len(autocorr) - 1, 2):
pair_sum = autocorr[i] + autocorr[i + 1]
if pair_sum > 0:
sum_autocorr += 2 * pair_sum
else:
break
# Calculate ESS
ess = n / max(sum_autocorr, 1)
return float(min(ess, n)) # ESS cannot exceed actual sample size
def calculate_batch_ess(
self, chains: np.ndarray, method: str = "mean"
) -> Union[float, np.ndarray]:
"""Calculate ESS for multiple chains or metrics.
Args:
chains: Array of shape (n_chains, n_iterations) or (n_chains, n_iterations, n_metrics)
method: How to combine ESS across chains ('mean', 'min', 'all')
Returns:
Combined ESS value(s)
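
        Example:
            Sketch with a toy 3D array (2 chains, 100 iterations, 2 metrics);
            with method="min" one value per metric is returned:

            >>> import numpy as np
            >>> diag = ConvergenceDiagnostics()
            >>> chains = np.random.default_rng(2).normal(size=(2, 100, 2))
            >>> diag.calculate_batch_ess(chains, method="min").shape
            (2,)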
"""
if chains.ndim == 2:
# Multiple chains, single metric
ess_values = [self.calculate_ess(chain) for chain in chains]
elif chains.ndim == 3:
            # Multiple chains, multiple metrics: one ESS list per metric
            n_chains, _n_iterations, n_metrics = chains.shape
            ess_values = [
                [self.calculate_ess(chains[c, :, m]) for c in range(n_chains)]
                for m in range(n_metrics)
            ]
else:
raise ValueError("Chains must be 2D or 3D array")
# Process based on method
if method == "mean":
if chains.ndim == 2:
return float(np.mean(ess_values))
return np.array([np.mean(metric_ess) for metric_ess in ess_values])
if method == "min":
if chains.ndim == 2:
return float(np.min(ess_values))
return np.array([np.min(metric_ess) for metric_ess in ess_values])
if method == "all":
return np.array(ess_values)
raise ValueError(f"Unknown method: {method}")
def calculate_ess_per_second(self, chain: np.ndarray, computation_time: float) -> float:
"""Calculate ESS per second of computation.
Useful for comparing efficiency of different sampling methods.
Args:
chain: 1D array of samples
computation_time: Time in seconds taken to generate the chain
Returns:
ESS per second
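
        Example:
            Illustrative sketch with a seeded chain and an assumed timing of
            2.0 seconds (in practice this comes from a wall-clock measurement):

            >>> import numpy as np
            >>> diag = ConvergenceDiagnostics()
            >>> chain = np.random.default_rng(7).normal(size=1000)
            >>> diag.calculate_ess_per_second(chain, computation_time=2.0) > 0
            True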
"""
ess = self.calculate_ess(chain)
return ess / computation_time if computation_time > 0 else 0.0
def calculate_mcse(self, chain: np.ndarray, ess: Optional[float] = None) -> float:
"""Calculate Monte Carlo standard error.
Args:
chain: 1D array of samples
ess: Effective sample size (calculated if not provided)
Returns:
Monte Carlo standard error
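
        Example:
            For roughly independent draws ESS is close to n, so the MCSE is
            near std(chain) / sqrt(n); only a loose bracket is asserted:

            >>> import numpy as np
            >>> diag = ConvergenceDiagnostics()
            >>> chain = np.random.default_rng(3).normal(size=10000)
            >>> bool(0.005 < diag.calculate_mcse(chain) < 0.02)
            True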
"""
if ess is None:
ess = self.calculate_ess(chain)
# Calculate standard error using ESS
std_dev = np.std(chain, ddof=1)
mcse = std_dev / np.sqrt(ess)
return float(mcse)
def check_convergence(
self, chains: Union[np.ndarray, List[np.ndarray]], metric_names: Optional[List[str]] = None
) -> Dict[str, ConvergenceStats]:
"""Check convergence for multiple chains and metrics.
Args:
chains: Array of shape (n_chains, n_iterations, n_metrics) or list of chains
metric_names: Names of metrics (optional)
Returns:
Dictionary mapping metric names to convergence statistics
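
        Example:
            Illustrative sketch; metric names default to "metric_0",
            "metric_1", ... when not supplied:

            >>> import numpy as np
            >>> diag = ConvergenceDiagnostics()
            >>> chains = np.random.default_rng(4).normal(size=(4, 2000))
            >>> results = diag.check_convergence(chains)
            >>> results["metric_0"].r_hat < 1.1
            True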
"""
# Convert list to array if needed
if isinstance(chains, list):
chains = np.array(chains)
# Handle different array shapes
if chains.ndim == 1:
chains = chains.reshape(1, -1, 1)
        elif chains.ndim == 2:
            if chains.shape[0] < chains.shape[1]:
                # Heuristic: fewer rows than columns, assume (n_chains, n_iterations)
                chains = chains[:, :, np.newaxis]
            else:
                # Otherwise assume (n_iterations, n_metrics): treat as a single
                # chain with one column per metric
                chains = chains[np.newaxis, :, :]
n_chains, n_iterations, n_metrics = chains.shape
if metric_names is None:
metric_names = [f"metric_{i}" for i in range(n_metrics)]
results = {}
for i, name in enumerate(metric_names):
metric_chains = chains[:, :, i]
            # R-hat requires multiple chains; report 1.0 for a single chain so
            # that this criterion does not block convergence
            r_hat = self.calculate_r_hat(metric_chains) if n_chains > 1 else 1.0
# Calculate ESS and MCSE for combined chain
combined_chain = metric_chains.flatten()
ess = self.calculate_ess(combined_chain)
mcse = self.calculate_mcse(combined_chain, ess)
# Calculate autocorrelation
autocorr = self._calculate_autocorrelation(combined_chain, 1)[1]
            # Check convergence criteria. Relative MCSE is undefined for a
            # zero-mean metric, so treat it as infinite in that case.
            mean_val = np.mean(combined_chain)
            relative_mcse = mcse / abs(mean_val) if mean_val != 0 else np.inf
converged = (
r_hat < self.r_hat_threshold
and ess >= self.min_ess
and relative_mcse < self.relative_mcse_threshold
)
results[name] = ConvergenceStats(
r_hat=r_hat,
ess=ess,
mcse=mcse,
converged=converged,
n_iterations=n_iterations * n_chains,
autocorrelation=autocorr,
)
return results
def _calculate_autocorrelation(self, chain: np.ndarray, max_lag: int) -> np.ndarray:
"""Calculate autocorrelation function using FFT.
        Uses the FFT identity acf = ifft(|fft(x)|^2), which costs O(N log N)
        regardless of max_lag (a direct summation loop would be O(N * L)).
Args:
chain: 1D array of samples
max_lag: Maximum lag
Returns:
Array of autocorrelations for lags 0 to max_lag
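
        Example:
            By construction the lag-0 autocorrelation is exactly 1:

            >>> import numpy as np
            >>> diag = ConvergenceDiagnostics()
            >>> acf = diag._calculate_autocorrelation(np.arange(10.0), 3)
            >>> float(acf[0])
            1.0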
"""
n = len(chain)
chain_centered = chain - np.mean(chain)
c0 = np.dot(chain_centered, chain_centered) / n
if c0 == 0:
# Zero variance: lag-0 is 1.0, all others are 0
autocorr = np.zeros(max_lag + 1)
autocorr[0] = 1.0
return autocorr
# FFT-based autocorrelation: zero-pad to avoid circular artifacts
padded = np.zeros(2 * n)
padded[:n] = chain_centered
f = rfft(padded)
acf_raw = irfft(f * np.conj(f), n=2 * n)[:n]
# Normalize: acf_raw[k] = sum of x[t]*x[t+k], divide by n then by c0
# Since c0 = acf_raw[0]/n, normalizing by acf_raw[0] gives rho[k]
result = acf_raw[: min(max_lag + 1, n)] / acf_raw[0]
return np.asarray(result)
@staticmethod
def _spectral_density_at_zero(segment: np.ndarray) -> float:
"""Estimate spectral density at zero using Bartlett kernel.
        Computes S(0) = gamma(0) + 2 * sum_{k=1}^{B} (1 - k / (B + 1)) * gamma(k),
        where B = floor(sqrt(n)) is the bandwidth. This is the standard windowed
        autocovariance estimator per Geweke (1992).
Args:
segment: 1D array of samples
Returns:
Estimated spectral density at zero frequency (non-negative)
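
        Example:
            A constant segment has zero variance, so the estimate is exactly 0:

            >>> import numpy as np
            >>> ConvergenceDiagnostics._spectral_density_at_zero(np.ones(100))
            0.0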
"""
n = len(segment)
centered = segment - np.mean(segment)
gamma_0 = np.dot(centered, centered) / n
if gamma_0 == 0:
return 0.0
bandwidth = max(int(np.sqrt(n)), 1)
s_zero = gamma_0
for k in range(1, min(bandwidth + 1, n)):
gamma_k = np.dot(centered[:-k], centered[k:]) / n
weight = 1 - k / (bandwidth + 1) # Bartlett taper
s_zero += 2 * weight * gamma_k
return float(max(s_zero, 0.0))
def geweke_test(
self, chain: np.ndarray, first_fraction: float = 0.1, last_fraction: float = 0.5
) -> Tuple[float, float]:
"""Perform Geweke convergence test.
Compares means of first and last portions of chain using spectral
density estimates at zero frequency per Geweke (1992). This correctly
accounts for autocorrelation in MCMC and sequential Monte Carlo chains.
Args:
chain: 1D array of samples
first_fraction: Fraction of chain to use for first portion
last_fraction: Fraction of chain to use for last portion
Returns:
Tuple of (z-score, p-value)
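
        Example:
            Illustrative sketch on a stationary seeded chain; the test should
            usually not flag drift, so only a loose bound on |z| is asserted:

            >>> import numpy as np
            >>> diag = ConvergenceDiagnostics()
            >>> chain = np.random.default_rng(5).normal(size=5000)
            >>> z, p = diag.geweke_test(chain)
            >>> bool(abs(z) < 5)
            True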
"""
n = len(chain)
n_first = int(n * first_fraction)
n_last = int(n * last_fraction)
first_portion = chain[:n_first]
last_portion = chain[-n_last:]
mean_first = np.mean(first_portion)
mean_last = np.mean(last_portion)
# Spectral density at zero frequency for variance of mean
var_first = self._spectral_density_at_zero(first_portion) / n_first
var_last = self._spectral_density_at_zero(last_portion) / n_last
        denom = np.sqrt(var_first + var_last)
        if denom == 0:
            # Constant segments: no sampling variability, so any difference in
            # means is unambiguous drift
            z_score = 0.0 if mean_first == mean_last else np.inf
        else:
            z_score = (mean_first - mean_last) / denom
        # Lazy import to avoid scipy issues in worker processes
        from scipy import stats

        p_value = 2 * stats.norm.sf(abs(z_score))
        return float(z_score), float(p_value)
def heidelberger_welch_test(
self, chain: np.ndarray, alpha: float = 0.05
) -> Dict[str, Union[bool, float]]:
"""Perform Heidelberger-Welch stationarity and halfwidth tests.
Args:
chain: 1D array of samples
alpha: Significance level
Returns:
Dictionary with test results
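
        Example:
            Illustrative run on a stationary seeded chain with a nonzero mean
            (so the relative halfwidth is well defined):

            >>> import numpy as np
            >>> diag = ConvergenceDiagnostics()
            >>> chain = 10 + np.random.default_rng(6).normal(size=5000)
            >>> result = diag.heidelberger_welch_test(chain)
            >>> bool(result["stationary"] and result["halfwidth_passed"])
            True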
"""
n = len(chain)
        # Stationarity check (simplified stand-in for the Cramer-von Mises
        # statistic used by the full Heidelberger-Welch test): verify that a
        # sliding-window mean is stable across the chain
        window_size = max(n // 10, 1)
        means = []
        for i in range(window_size, n - window_size):
            window_mean = np.mean(chain[i - window_size : i + window_size])
            means.append(window_mean)
# Check if means are stable
mean_variance = np.var(means)
overall_variance = np.var(chain)
stationarity_ratio = mean_variance / overall_variance if overall_variance > 0 else np.inf
stationary = stationarity_ratio < 0.1 # Heuristic threshold
        # Halfwidth test: compare the (1 - alpha) confidence-interval halfwidth
        # to the magnitude of the mean estimate
        from scipy import stats

        mean_estimate = np.mean(chain)
        mcse = self.calculate_mcse(chain)
        z_crit = stats.norm.ppf(1 - alpha / 2)  # 1.96 for alpha = 0.05
        halfwidth = z_crit * mcse
        relative_halfwidth = halfwidth / abs(mean_estimate) if mean_estimate != 0 else np.inf
        halfwidth_passed = relative_halfwidth < 0.1  # 10% relative precision
return {
"stationary": stationary,
"stationarity_ratio": stationarity_ratio,
"halfwidth_passed": halfwidth_passed,
"relative_halfwidth": relative_halfwidth,
"mean": mean_estimate,
"mcse": mcse,
}