"""Batch processing and scenario comparison visualizations.
This module provides visualization functions for batch simulation results,
scenario comparisons, and sensitivity analyses.
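
Typical usage, a sketch assuming batch_results is a list of BatchResult
objects from a completed batch run:

>>> from ergodic_insurance.batch_processor import AggregatedResults
>>> aggregated = AggregatedResults(batch_results)
>>> fig = plot_scenario_comparison(aggregated)
>>> fig_conv = plot_scenario_convergence(batch_results)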
"""
from typing import Any, List, Optional, Tuple
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from .core import COLOR_SEQUENCE, WSJ_COLORS, set_wsj_style
def plot_scenario_comparison( # pylint: disable=too-many-locals
aggregated_results: Any,
metrics: Optional[List[str]] = None,
figsize: Tuple[float, float] = (14, 8),
save_path: Optional[str] = None,
) -> Figure:
"""Create comprehensive scenario comparison visualization.
Compares multiple scenarios across different metrics with bar charts
highlighting the best performer for each metric.
Args:
aggregated_results: AggregatedResults object from batch processing
metrics: List of metrics to compare (default: key metrics)
figsize: Figure size (width, height)
save_path: Path to save figure
Returns:
Matplotlib figure with scenario comparisons
Examples:
>>> from ergodic_insurance.batch_processor import AggregatedResults
>>> results = AggregatedResults(batch_results)
>>> fig = plot_scenario_comparison(results, metrics=["mean_growth_rate"])
"""
from ..batch_processor import AggregatedResults
if not isinstance(aggregated_results, AggregatedResults):
raise ValueError("Input must be AggregatedResults from batch processing")
# Get successful results only
df = aggregated_results.summary_statistics
if df.empty:
print("No successful scenarios to visualize")
return plt.figure()
# Default metrics if not specified
if metrics is None:
metrics = ["ruin_probability", "mean_growth_rate", "mean_final_assets", "var_99"]
metrics = [m for m in metrics if m in df.columns]
# Create subplot grid
n_metrics = len(metrics)
n_cols = 2
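# Ceiling division: an odd metric count still gets enough rows, and the
# leftover empty axis is deleted after plotting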
n_rows = (n_metrics + 1) // 2
set_wsj_style()
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
# plt.subplots returns a bare Axes or a 1-D/2-D array depending on the grid
# shape; normalize to a flat array so indexing below is uniform (the old
# `[axes]` wrapper was wrong for a single metric, since the grid still has
# two columns and axes is already an array)
axes = np.atleast_1d(axes).flatten()
# Plot each metric
for i, metric in enumerate(metrics):
ax = axes[i]
# Create bar plot
scenarios = df["scenario"]
values = df[metric]
bars = ax.bar(range(len(scenarios)), values, color=WSJ_COLORS["blue"], alpha=0.8)
# Highlight the best performer; use positional argmin/argmax so the bar
# index stays valid even if the DataFrame index is not a RangeIndex
if metric == "ruin_probability":  # Lower is better
best_idx = int(np.argmin(values.to_numpy()))
else:  # Higher is better
best_idx = int(np.argmax(values.to_numpy()))
bars[best_idx].set_color(WSJ_COLORS["green"])
bars[best_idx].set_alpha(1.0)
# Format
ax.set_xlabel("Scenario")
ax.set_ylabel(metric.replace("_", " ").title())
ax.set_title(f"{metric.replace('_', ' ').title()} Comparison")
ax.set_xticks(range(len(scenarios)))
ax.set_xticklabels(scenarios, rotation=45, ha="right")
ax.grid(True, alpha=0.3)
# Add value labels above each bar
for value_bar, val in zip(bars, values):
height = value_bar.get_height()
format_str = f"{val:.2%}" if "probability" in metric else f"{val:.2g}"
ax.text(
value_bar.get_x() + value_bar.get_width() / 2,
height,
format_str,
ha="center",
va="bottom",
fontsize=8,
)
# Remove empty subplots
for i in range(len(metrics), len(axes)):
fig.delaxes(axes[i])
plt.suptitle("Scenario Comparison Analysis", fontsize=16, fontweight="bold")
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches="tight")
return fig
def plot_sensitivity_heatmap( # pylint: disable=too-many-locals
aggregated_results: Any,
metric: str = "mean_growth_rate",
figsize: Tuple[float, float] = (10, 8),
save_path: Optional[str] = None,
) -> Figure:
"""Create sensitivity analysis heatmap.
Visualizes sensitivity of outcomes to parameter changes using
a horizontal bar chart color-coded by impact direction.
Args:
aggregated_results: AggregatedResults with sensitivity analysis
metric: Metric to visualize
figsize: Figure size
save_path: Path to save figure
Returns:
Matplotlib figure with sensitivity analysis
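
Examples:
A usage sketch, assuming results is an AggregatedResults whose
sensitivity_analysis frame was populated by the batch run:

>>> fig = plot_sensitivity_heatmap(results, metric="ruin_probability")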
"""
from ..batch_processor import AggregatedResults
if not isinstance(aggregated_results, AggregatedResults):
raise ValueError("Input must be AggregatedResults from batch processing")
sensitivity_df = aggregated_results.sensitivity_analysis
if sensitivity_df is None or sensitivity_df.empty:
print("No sensitivity analysis data available")
return plt.figure()
# Resolve the sensitivity column for the requested metric, falling back
# to the first available "_change_pct" column
metric_col = f"{metric}_change_pct"
if metric_col not in sensitivity_df.columns:
available = [c for c in sensitivity_df.columns if "_change_pct" in c]
if available:
metric_col = available[0]
print(f"Using {metric_col} instead of requested metric")
else:
print("No sensitivity metrics found")
return plt.figure()
set_wsj_style()
fig, ax = plt.subplots(figsize=figsize)
# Plot sensitivities as a horizontal bar chart
scenarios = sensitivity_df["scenario"]
values = sensitivity_df[metric_col]
# Color-code by impact direction: green for positive, red for negative
colors = [WSJ_COLORS["green"] if val >= 0 else WSJ_COLORS["red"] for val in values]
ax.barh(scenarios, values, color=colors, alpha=0.8)
ax.set_xlabel(f"% Change in {metric.replace('_', ' ').title()}")
ax.set_ylabel("Scenario")
ax.set_title(f"Sensitivity Analysis: {metric.replace('_', ' ').title()}")
ax.axvline(x=0, color=WSJ_COLORS["black"], linestyle="-", linewidth=1)
ax.grid(True, alpha=0.3)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches="tight")
return fig
def plot_parameter_sweep_3d(
aggregated_results: Any,
param1: str,
param2: str,
metric: str = "mean_growth_rate",
height: int = 600,
save_path: Optional[str] = None,
) -> go.Figure:
"""Create 3D surface plot for parameter sweep results.
Visualizes how a metric varies across two parameter dimensions
using an interactive 3D scatter plot.
Args:
aggregated_results: AggregatedResults from grid search
param1: First parameter name
param2: Second parameter name
metric: Metric to plot on z-axis
height: Figure height in pixels
save_path: Path to save figure
Returns:
Plotly figure with 3D parameter sweep
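
Examples:
A usage sketch; the parameter names here are hypothetical and must
match keys recorded in each result's parameter_overrides metadata:

>>> fig = plot_parameter_sweep_3d(
...     results, "premium_rate", "retention", metric="ruin_probability"
... )
>>> fig.show()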
"""
from ..batch_processor import AggregatedResults
if not isinstance(aggregated_results, AggregatedResults):
raise ValueError("Input must be AggregatedResults from batch processing")
# Extract parameter values and metric from results
param1_values = []
param2_values = []
metric_values = []
for result in aggregated_results.batch_results:
if result.simulation_results:
overrides = result.metadata.get("parameter_overrides", {})
if param1 in overrides and param2 in overrides:
param1_values.append(overrides[param1])
param2_values.append(overrides[param2])
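# Map the requested metric to a value; anything without a dedicated
# field falls back to the per-run metrics dict (NaN when absent)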
if metric == "mean_growth_rate":
metric_values.append(np.mean(result.simulation_results.growth_rates))
elif metric == "ruin_probability":
metric_values.append(result.simulation_results.ruin_probability)
elif metric == "mean_final_assets":
metric_values.append(np.mean(result.simulation_results.final_assets))
else:
metric_values.append(result.simulation_results.metrics.get(metric, np.nan))
if not param1_values:
print("No parameter sweep data found")
return go.Figure()
# Create 3D scatter plot
fig = go.Figure(
data=[
go.Scatter3d(
x=param1_values,
y=param2_values,
z=metric_values,
mode="markers",
marker={
"size": 8,
"color": metric_values,
"colorscale": "Viridis",
"showscale": True,
"colorbar": {"title": metric.replace("_", " ").title()},
},
text=[
f"{param1}: {p1:.3g}<br>{param2}: {p2:.3g}<br>{metric}: {m:.3g}"
for p1, p2, m in zip(param1_values, param2_values, metric_values)
],
hovertemplate="%{text}<extra></extra>",
)
]
)
fig.update_layout(
title=f"Parameter Sweep: {metric.replace('_', ' ').title()}",
scene={
"xaxis_title": param1.replace("_", " ").title(),
"yaxis_title": param2.replace("_", " ").title(),
"zaxis_title": metric.replace("_", " ").title(),
},
height=height,
template="plotly_white",
font={"family": "Arial, sans-serif"},
)
if save_path:
fig.write_html(save_path)
return fig
def plot_scenario_convergence(
batch_results: List[Any],
metric: str = "mean_growth_rate",
figsize: Tuple[float, float] = (12, 6),
save_path: Optional[str] = None,
) -> Figure:
"""Plot convergence of metric across scenarios.
Shows how a metric converges as more scenarios are processed,
with execution time distribution.
Args:
batch_results: List of BatchResult objects
metric: Metric to track
figsize: Figure size
save_path: Path to save figure
Returns:
Matplotlib figure with convergence analysis
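
Examples:
A usage sketch, assuming batch_results is the list of BatchResult
objects returned by the batch processor:

>>> fig = plot_scenario_convergence(batch_results, metric="ruin_probability")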
"""
set_wsj_style()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
# Extract metric values in order
scenarios = []
values = []
times = []
for result in batch_results:
if result.simulation_results:
scenarios.append(result.scenario_name)
times.append(result.execution_time)
if metric == "mean_growth_rate":
values.append(np.mean(result.simulation_results.growth_rates))
elif metric == "ruin_probability":
values.append(result.simulation_results.ruin_probability)
elif metric == "mean_final_assets":
values.append(np.mean(result.simulation_results.final_assets))
else:
values.append(result.simulation_results.metrics.get(metric, np.nan))
if not values:
print("No data to plot")
return fig
# Plot 1: Running average
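# Cumulative sum divided by the 1-based count gives the running mean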
running_avg = np.cumsum(values) / np.arange(1, len(values) + 1)
ax1.plot(running_avg, color=WSJ_COLORS["blue"], linewidth=2)
ax1.fill_between(range(len(running_avg)), running_avg, alpha=0.3, color=WSJ_COLORS["blue"])
ax1.set_xlabel("Scenario Number")
ax1.set_ylabel(f"Running Average {metric.replace('_', ' ').title()}")
ax1.set_title("Metric Convergence")
ax1.grid(True, alpha=0.3)
# Shade a ±5% band around the final running average as a rough
# convergence reference
final_avg = running_avg[-1]
ax1.axhline(final_avg, color=WSJ_COLORS["red"], linestyle="--", alpha=0.7)
ax1.fill_between(
range(len(running_avg)),
[final_avg * 0.95] * len(running_avg),
[final_avg * 1.05] * len(running_avg),
alpha=0.2,
color=WSJ_COLORS["red"],
)
# Plot 2: Execution time distribution
ax2.hist(times, bins=20, color=WSJ_COLORS["green"], alpha=0.7, edgecolor="black")
ax2.set_xlabel("Execution Time (seconds)")
ax2.set_ylabel("Count")
ax2.set_title("Scenario Execution Times")
ax2.axvline(
np.mean(times),
color=WSJ_COLORS["red"],
linestyle="--",
label=f"Mean: {np.mean(times):.1f}s",
)
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.suptitle(
f"Batch Processing Analysis ({len(values)} scenarios)", fontsize=14, fontweight="bold"
)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches="tight")
return fig
def plot_parallel_scenarios( # pylint: disable=too-many-branches
batch_results: List[Any],
metrics: List[str],
figsize: Tuple[float, float] = (12, 8),
normalize: bool = True,
) -> Figure:
"""Create parallel coordinates plot for scenario comparison.
Visualizes multiple scenarios across multiple metrics using
parallel coordinates for comprehensive comparison.
Args:
batch_results: List of BatchResult objects
metrics: List of metrics to include
figsize: Figure size
normalize: Whether to normalize metrics to [0, 1]
Returns:
Matplotlib figure with parallel coordinates
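
Examples:
A usage sketch, assuming batch_results holds completed BatchResult
objects; metrics missing from a result are plotted as NaN:

>>> fig = plot_parallel_scenarios(
...     batch_results, metrics=["mean_growth_rate", "ruin_probability"]
... )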
"""
set_wsj_style()
fig, ax = plt.subplots(figsize=figsize)
# Extract data
data = []
scenario_names = []
for result in batch_results:
if result.simulation_results:
scenario_names.append(result.scenario_name)
row = []
for metric in metrics:
if metric == "mean_growth_rate":
value = np.mean(result.simulation_results.growth_rates)
elif metric == "ruin_probability":
value = result.simulation_results.ruin_probability
elif metric == "mean_final_assets":
value = np.mean(result.simulation_results.final_assets)
else:
value = result.simulation_results.metrics.get(metric, np.nan)
row.append(value)
data.append(row)
if not data:
print("No data to plot")
return fig
data_array = np.array(data)
# Min-max normalize each metric to [0, 1] if requested; constant columns
# are left unchanged since scaling is undefined for them
if normalize:
for i in range(data_array.shape[1]):
col_min = np.nanmin(data_array[:, i])
col_max = np.nanmax(data_array[:, i])
if col_max > col_min:
data_array[:, i] = (data_array[:, i] - col_min) / (col_max - col_min)
# Create parallel coordinates
x = np.arange(len(metrics))
for i, scenario in enumerate(data_array):
color = COLOR_SEQUENCE[i % len(COLOR_SEQUENCE)]
ax.plot(x, scenario, "o-", color=color, alpha=0.7, label=scenario_names[i])
# Styling
ax.set_xticks(x)
ax.set_xticklabels([m.replace("_", "\n") for m in metrics], rotation=0)
ax.set_ylabel("Normalized Value" if normalize else "Value")
ax.set_title("Parallel Scenarios Comparison")
ax.grid(True, alpha=0.3)
# Add legend
if len(scenario_names) <= 10:
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
return fig