"""Visualization utilities for sensitivity analysis results.
This module provides publication-ready visualization functions for sensitivity
analysis results, including tornado diagrams, two-way sensitivity heatmaps,
and parameter impact charts.
Example:
Creating a tornado diagram::
from ergodic_insurance.sensitivity_visualization import plot_tornado_diagram
# Assuming tornado_data is a DataFrame from SensitivityAnalyzer
fig = plot_tornado_diagram(
tornado_data,
title="Parameter Sensitivity Analysis",
metric_label="ROE Impact"
)
fig.savefig("tornado_diagram.png", dpi=300, bbox_inches='tight')
Author: Alex Filiakov
Date: 2025-01-29
"""
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
if TYPE_CHECKING:
from ergodic_insurance.sensitivity import (
SensitivityAnalyzer,
SensitivityResult,
TwoWaySensitivityResult,
)
# Apply the module's default look for publication-ready plots.
# NOTE: both calls execute at import time, so importing this module changes the
# active matplotlib style and seaborn palette for the importing process.
plt.style.use("seaborn-v0_8-darkgrid")
sns.set_palette("husl")
[docs]
def plot_tornado_diagram(  # pylint: disable=too-many-locals
    tornado_data: pd.DataFrame,
    title: str = "Sensitivity Analysis - Tornado Diagram",
    metric_label: str = "Impact on Objective",
    figsize: Tuple[float, float] = (10, 6),
    n_params: Optional[int] = None,
    color_positive: str = "#2E7D32",
    color_negative: str = "#C62828",
    show_values: bool = True,
) -> Figure:
    """Create a tornado diagram for sensitivity analysis results.

    Each row of ``tornado_data`` becomes one horizontal bar spanning the
    percentage change of ``low_value`` and ``high_value`` relative to
    ``baseline``. Rows whose baseline is zero (or whose values are not finite)
    have no defined percentage change; they keep their y-axis label but no bar
    is drawn for them.

    Args:
        tornado_data: DataFrame with columns: parameter, impact, direction,
            low_value, high_value, baseline
        title: Plot title
        metric_label: Label for the x-axis
        figsize: Figure size as (width, height)
        n_params: Number of top parameters to show (None for all)
        color_positive: Color for rows whose direction is "positive"
        color_negative: Color for all other rows
        show_values: Whether to show numeric values next to the bar ends

    Returns:
        Matplotlib Figure object
    """
    # Select top n parameters if specified
    if n_params is not None and len(tornado_data) > n_params:
        tornado_data = tornado_data.head(n_params)

    fig, ax = plt.subplots(figsize=figsize)

    n = len(tornado_data)
    y_pos = np.arange(n)

    if n == 0:
        # Nothing to draw; do not touch the columns so that a column-less
        # empty frame is accepted as well.
        low_values = high_values = np.array([])
        low_change = high_change = np.array([])
        directions: List[Any] = []
        labels: List[Any] = []
    else:
        baseline = np.asarray(tornado_data["baseline"].values, dtype=float)
        low_values = np.asarray(tornado_data["low_value"].values, dtype=float)
        high_values = np.asarray(tornado_data["high_value"].values, dtype=float)

        # Percentage change from baseline. A zero baseline would divide by
        # zero, so mark it as NaN and skip those rows when drawing.
        scale = np.abs(baseline)
        scale[scale == 0] = np.nan
        low_change = (low_values - baseline) / scale * 100
        high_change = (high_values - baseline) / scale * 100

        directions = tornado_data["direction"].tolist()
        labels = tornado_data["parameter"].tolist()

    label_offset = 2.0  # horizontal gap between a bar end and its value label
    for i, direction in enumerate(directions):
        start, end = low_change[i], high_change[i]
        if not (np.isfinite(start) and np.isfinite(end)):
            continue

        color = color_positive if direction == "positive" else color_negative
        ax.barh(
            i,
            end - start,
            left=start,
            height=0.6,
            color=color,
            alpha=0.7,
            edgecolor="black",
            linewidth=1,
        )

        if show_values:
            # Labels always go on the outside of the bar. When the high value
            # lies left of the low value the sides are swapped, otherwise the
            # labels would be drawn on top of the bar.
            low_on_left = start <= end
            ax.text(
                start - label_offset if low_on_left else start + label_offset,
                i,
                f"{low_values[i]:.2g}",
                ha="right" if low_on_left else "left",
                va="center",
                fontsize=8,
            )
            ax.text(
                end + label_offset if low_on_left else end - label_offset,
                i,
                f"{high_values[i]:.2g}",
                ha="left" if low_on_left else "right",
                va="center",
                fontsize=8,
            )

    # Add baseline line
    ax.axvline(x=0, color="black", linestyle="--", linewidth=1.5, alpha=0.7)
    # x is in data coordinates, y in axes coordinates (x-axis transform)
    ax.text(
        0,
        -0.7,
        "Baseline",
        ha="center",
        fontsize=10,
        style="italic",
        transform=ax.get_xaxis_transform(),
    )

    # Customize axes
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels)
    ax.set_xlabel(f"{metric_label} (% change from baseline)", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold", pad=25)

    # Add grid
    ax.grid(True, axis="x", alpha=0.3)
    ax.set_axisbelow(True)

    # X limits: cover both ends of every drawn bar (either end may be the
    # extreme one), never narrower than +/-40, plus 15% padding for the labels.
    extents = np.concatenate([low_change, high_change])
    extents = extents[np.isfinite(extents)]
    if extents.size:
        x_min = min(float(extents.min()), -40.0)
        x_max = max(float(extents.max()), 40.0)
        x_padding = (x_max - x_min) * 0.15
        ax.set_xlim(x_min - x_padding, x_max + x_padding)
    else:
        # Default limits when there is nothing finite to show
        ax.set_xlim(-50, 50)

    # Adjust layout with specific margins (left, bottom, right, top)
    fig.tight_layout(rect=(0.05, 0.02, 0.95, 0.98))
    return fig
[docs]
def plot_two_way_sensitivity(
    result: "TwoWaySensitivityResult",
    title: Optional[str] = None,
    cmap: str = "RdYlGn",
    figsize: Tuple[float, float] = (10, 8),
    show_contours: bool = True,
    contour_levels: Optional[int] = 10,
    optimal_point: Optional[Tuple[float, float]] = None,
    fmt: Union[str, Callable[[float], str]] = ".2f",
) -> Figure:
    """Create a heatmap for two-way sensitivity analysis.

    Args:
        result: TwoWaySensitivityResult object; its ``values1``, ``values2``,
            ``metric_grid``, ``metric_name``, ``parameter1`` and ``parameter2``
            attributes are read.
        title: Plot title (auto-generated if None)
        cmap: Colormap name
        figsize: Figure size as (width, height)
        show_contours: Whether to show contour lines. Contours are skipped
            when the grid is smaller than 2x2, because contouring needs at
            least two points along each axis.
        contour_levels: Number of contour levels
        optimal_point: Optional (param1_value, param2_value) to mark
        fmt: Format for contour labels. Can be:
            - New-style format like '.2f' or '.2%'
            - Old-style format like '%.2f'
            - Callable that takes a number and returns a string

    Returns:
        Matplotlib Figure object
    """
    if title is None:
        title = f"{result.metric_name} Sensitivity: {result.parameter1} vs {result.parameter2}"

    fig, ax = plt.subplots(figsize=figsize)

    # "ij" indexing keeps axis 0 of the grid aligned with values1
    X, Y = np.meshgrid(result.values1, result.values2, indexing="ij")
    grid = np.asarray(result.metric_grid)

    # Create heatmap and colorbar
    im = ax.pcolormesh(X, Y, grid, cmap=cmap, shading="auto")
    fig.colorbar(im, ax=ax, label=result.metric_name)

    if show_contours and grid.ndim == 2 and min(grid.shape) >= 2:
        contours = ax.contour(
            X,
            Y,
            grid,
            levels=contour_levels,
            colors="black",
            linewidths=0.5,
            alpha=0.5,
        )

        # clabel accepts an old-style '%' format string or a callable.
        formatter: Union[str, Callable[[float], str]]
        if callable(fmt):
            # Checked first: a callable has no ``startswith``.
            formatter = fmt
        elif isinstance(fmt, str) and fmt.startswith("."):
            if "%" in fmt:
                # '%' presentation has no old-style equivalent, so format
                # through a function. Fall back to one decimal when no
                # precision digits are given.
                digits = fmt[1 : fmt.index("%")]
                decimal_places = int(digits) if digits.isdigit() else 1

                def format_pct(x: float) -> str:
                    """Format value as percentage."""
                    return f"{x:.{decimal_places}%}"

                formatter = format_pct
            else:
                # '.2f' -> '%.2f'
                formatter = f"%{fmt}"
        else:
            # Already an old-style format string (or empty -> default)
            formatter = fmt if fmt else "%.2f"

        ax.clabel(contours, inline=True, fontsize=8, fmt=formatter)

    # Mark optimal point if provided
    if optimal_point is not None:
        ax.plot(optimal_point[0], optimal_point[1], "r*", markersize=15, label="Optimal Point")
        ax.legend()

    # Customize axes
    ax.set_xlabel(result.parameter1, fontsize=12)
    ax.set_ylabel(result.parameter2, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold", pad=20)

    # Add grid
    ax.grid(True, alpha=0.3, linestyle=":")

    fig.tight_layout()
    return fig
[docs]
def plot_parameter_sweep(
    result: "SensitivityResult",
    metrics: Optional[List[str]] = None,
    title: Optional[str] = None,
    figsize: Tuple[float, float] = (12, 8),
    normalize: bool = False,
    mark_baseline: bool = True,
) -> Figure:
    """Plot multiple metrics against parameter variations.

    One subplot is drawn per metric (at most three per row). Each subplot
    carries an arrow in its top-right corner showing the sign of the linear
    trend of the metric over the swept parameter.

    Args:
        result: SensitivityResult object; its ``metrics``, ``variations``,
            ``baseline_value`` and ``parameter`` attributes are read.
        metrics: List of metrics to plot (None for all)
        title: Plot title (auto-generated if None)
        figsize: Figure size as (width, height)
        normalize: Whether to normalize metrics to [0, 1]
        mark_baseline: Whether to mark the baseline value. The marker height
            is linearly interpolated at ``baseline_value``, so it stays on the
            curve even if the baseline is not the middle sample.

    Returns:
        Matplotlib Figure object

    Raises:
        ValueError: If there are no metrics to plot.
    """
    if metrics is None:
        metrics = list(result.metrics.keys())
    if not metrics:
        raise ValueError("plot_parameter_sweep requires at least one metric to plot")

    if title is None:
        title = f"Sensitivity Analysis: {result.parameter}"

    # Determine subplot layout
    n_metrics = len(metrics)
    n_cols = min(3, n_metrics)
    n_rows = (n_metrics + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array, so one flatten covers all cases
    fig, axes_grid = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes_grid.flatten()

    variations = np.asarray(result.variations, dtype=float)
    order = np.argsort(variations)  # np.interp needs increasing x values
    baseline_x = float(result.baseline_value)

    for ax, metric in zip(axes, metrics):
        values = np.asarray(result.metrics[metric], dtype=float)

        # Normalize if requested (skipped for constant series to avoid 0/0)
        if normalize and values.size and values.max() != values.min():
            values = (values - values.min()) / (values.max() - values.min())

        ax.plot(result.variations, values, "o-", linewidth=2, markersize=6)

        if mark_baseline:
            ax.axvline(x=baseline_x, color="red", linestyle="--", alpha=0.5, label="Baseline")
            if variations.size:
                baseline_y = np.interp(baseline_x, variations[order], values[order])
                ax.plot(baseline_x, baseline_y, "r*", markersize=12)

        # Customize subplot
        ax.set_xlabel(result.parameter, fontsize=10)
        ax.set_ylabel(metric.replace("_", " ").title(), fontsize=10)
        ax.grid(True, alpha=0.3)

        # Trend annotation: a linear fit needs two distinct finite x values
        usable = np.isfinite(variations) & np.isfinite(values)
        if usable.sum() >= 2 and np.ptp(variations[usable]) > 0:
            trend = np.polyfit(variations[usable], values[usable], 1)[0]
        else:
            trend = 0.0
        trend_text = "↑" if trend > 0 else "↓" if trend < 0 else "→"
        ax.text(
            0.95,
            0.95,
            trend_text,
            transform=ax.transAxes,
            fontsize=16,
            ha="right",
            va="top",
            alpha=0.5,
        )

    # Hide unused subplots
    for spare_ax in axes[n_metrics:]:
        spare_ax.set_visible(False)

    # Add main title
    fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
    fig.tight_layout()
    return fig
[docs]
def _save_report_figure(fig: Figure, output_path: Any, stem: str, formats: List[str]) -> None:
    """Save ``fig`` as ``<output_path>/<stem>.<fmt>`` once per requested format."""
    for fmt in formats:
        filename = output_path / f"{stem}.{fmt}"
        fig.savefig(filename, dpi=300, bbox_inches="tight")
        print(f"Saved: {filename}")


def create_sensitivity_report(
    analyzer: "SensitivityAnalyzer",
    parameters: List[Union[str, Tuple[str, str]]],
    output_dir: Optional[str] = None,
    metric: str = "optimal_roe",
    formats: Optional[List[str]] = None,
    top_n: int = 3,
) -> Dict[str, Any]:
    """Generate a complete sensitivity analysis report.

    Builds a tornado diagram for ``parameters`` and a parameter-sweep figure
    for each of the first ``top_n`` rows of the tornado data, optionally
    saving every figure to ``output_dir``.

    Args:
        analyzer: SensitivityAnalyzer object; its ``create_tornado_diagram``
            and ``analyze_parameter`` methods are called.
        parameters: List of parameters to analyze
        output_dir: Directory to save figures (None for no saving). Created
            if it does not exist.
        metric: Primary metric for analysis
        formats: File formats to save figures in (defaults to png and pdf)
        top_n: Number of leading tornado rows to produce sweep plots for

    Returns:
        Dictionary with keys ``"figures"``, ``"summary"`` and ``"data"``.
        When the tornado data is empty, ``most_impactful`` and
        ``least_impactful`` in the summary are ``None``.
    """
    from pathlib import Path

    if formats is None:
        formats = ["png", "pdf"]

    output_path = None
    if output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

    report: Dict[str, Any] = {"figures": {}, "summary": {}, "data": {}}

    # Generate tornado diagram
    print("Generating tornado diagram...")
    tornado_data = analyzer.create_tornado_diagram(parameters, metric=metric)
    report["data"]["tornado"] = tornado_data
    has_rows = len(tornado_data) > 0

    metric_title = metric.replace("_", " ").title()
    fig_tornado = plot_tornado_diagram(
        tornado_data,
        title=f"Sensitivity Analysis - {metric_title}",
        metric_label=metric_title,
    )
    report["figures"]["tornado"] = fig_tornado
    if output_path is not None:
        _save_report_figure(fig_tornado, output_path, "tornado_diagram", formats)

    # Generate parameter sweeps for the leading (most impactful) parameters
    top_params = tornado_data.head(max(top_n, 0))["parameter"].values if has_rows else []
    for param in top_params:
        print(f"Analyzing parameter: {param}")

        result = analyzer.analyze_parameter(param)
        report["data"][f"sweep_{param}"] = result

        fig_sweep = plot_parameter_sweep(result, title=f"Parameter Sweep: {param}")
        report["figures"][f"sweep_{param}"] = fig_sweep
        if output_path is not None:
            _save_report_figure(fig_sweep, output_path, f"sweep_{param}", formats)

    # Generate summary statistics (guarding the empty case, where iloc fails)
    summary = report["summary"]
    summary["most_impactful"] = tornado_data.iloc[0]["parameter"] if has_rows else None
    summary["least_impactful"] = tornado_data.iloc[-1]["parameter"] if has_rows else None
    summary["total_parameters"] = len(tornado_data)
    summary["primary_metric"] = metric

    # Calculate relative importances. The new column is added to the frame
    # that is also stored under report["data"]["tornado"].
    total_impact = tornado_data["impact"].sum() if has_rows else 0
    if total_impact > 0:
        tornado_data["relative_importance"] = tornado_data["impact"] / total_impact * 100
        summary["relative_importances"] = tornado_data[
            ["parameter", "relative_importance"]
        ].to_dict("records")

    print("Sensitivity report generation complete!")
    return report
[docs]
def plot_sensitivity_matrix(  # pylint: disable=too-many-locals
    results: Dict[str, "SensitivityResult"],
    metric: str = "optimal_roe",
    figsize: Tuple[float, float] = (12, 10),
    cmap: str = "coolwarm",
    show_values: bool = True,
    variation_points: Optional[List[float]] = None,
) -> Figure:
    """Create a matrix plot showing sensitivity across multiple parameters.

    Each row is one parameter and each column one percentage change of that
    parameter; the cell holds the percentage change of ``metric`` relative to
    its value at the parameter's baseline. Cell values are linearly
    interpolated between the sampled variations; requested changes outside
    the sampled range take the value of the nearest end sample. A row is left
    empty (NaN, annotated "n/a") when the parameter baseline or the baseline
    metric is zero, since no percentage change is defined then.

    Args:
        results: Dictionary of parameter names to SensitivityResult objects
        metric: Metric to display
        figsize: Figure size as (width, height)
        cmap: Colormap name
        show_values: Whether to show numeric values in cells
        variation_points: Parameter changes (in percent) used as columns;
            defaults to -30, -20, -10, 0, 10, 20, 30

    Returns:
        Matplotlib Figure object
    """
    params = list(results.keys())
    n_params = len(params)

    if variation_points is None:
        variation_points = [-30, -20, -10, 0, 10, 20, 30]
    points = np.asarray(variation_points, dtype=float)

    # NaN marks cells that cannot be computed
    matrix = np.full((n_params, len(points)), np.nan)

    for i, param in enumerate(params):
        result = results[param]
        variations = np.asarray(result.variations, dtype=float)
        metric_values = np.asarray(result.metrics[metric], dtype=float)
        baseline_value = float(result.baseline_value)
        if variations.size == 0 or baseline_value == 0:
            continue

        # Signed division keeps the multiplicative meaning of "+10%" for
        # negative baselines; sorting afterwards gives np.interp the
        # increasing x values it requires.
        param_pct = (variations - baseline_value) / baseline_value * 100
        order = np.argsort(param_pct)
        param_pct = param_pct[order]
        metric_values = metric_values[order]

        # Metric at a 0% change, i.e. at the baseline parameter value
        baseline_metric = np.interp(0.0, param_pct, metric_values)
        if not np.isfinite(baseline_metric) or baseline_metric == 0:
            continue

        metric_pct = (metric_values - baseline_metric) / abs(baseline_metric) * 100
        matrix[i, :] = np.interp(points, param_pct, metric_pct)

    fig, ax = plt.subplots(figsize=figsize)

    # Create heatmap and colorbar
    im = ax.imshow(matrix, cmap=cmap, aspect="auto")
    fig.colorbar(im, ax=ax, label=f"{metric} (% change)")

    # Set ticks and labels
    ax.set_xticks(np.arange(len(points)))
    ax.set_yticks(np.arange(n_params))
    ax.set_xticklabels([f"{v:+g}%" for v in variation_points])
    ax.set_yticklabels(params)

    # Add values if requested
    if show_values:
        for i in range(n_params):
            for j in range(len(points)):
                cell = matrix[i, j]
                ax.text(
                    j,
                    i,
                    f"{cell:.1f}" if np.isfinite(cell) else "n/a",
                    ha="center",
                    va="center",
                    color="white" if abs(cell) > 5 else "black",
                    fontsize=8,
                )

    # Customize
    ax.set_xlabel("Parameter Change (%)", fontsize=12)
    ax.set_title(
        f'Sensitivity Matrix: {metric.replace("_", " ").title()}',
        fontsize=14,
        fontweight="bold",
        pad=20,
    )

    # Minor ticks on cell borders so the grid outlines every cell
    ax.set_xticks(np.arange(len(points) + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_params + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="gray", linestyle="-", linewidth=0.5)

    fig.tight_layout()
    return fig