# Source code for ergodic_insurance.visualization.interactive_plots

"""Interactive visualization functions using Plotly.

This module provides functions for creating interactive dashboards
and visualizations for Monte Carlo simulations and analysis results.

Public functions (each returns a ``plotly.graph_objects.Figure``):

- ``create_interactive_dashboard``: 2x2 dashboard with growth-rate
  distribution, loss exceedance curve, convergence diagnostics and
  risk-metric panels.
- ``create_time_series_dashboard``: line chart of one column over time,
  with a moving average for longer series, optional forecast and
  confidence-band traces, and a range slider.
- ``create_correlation_heatmap``: heatmap of pairwise column correlations.
- ``create_risk_dashboard``: 3x2 risk analytics dashboard.

Colors are taken from ``WSJ_COLORS`` and ``COLOR_SEQUENCE`` defined in the
sibling ``core`` module.
"""

# Public API of this module.
__all__ = [
    "create_interactive_dashboard",
    "create_time_series_dashboard",
    "create_correlation_heatmap",
    "create_risk_dashboard",
]

from typing import Any, Dict, Union

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from .core import COLOR_SEQUENCE, WSJ_COLORS


def _results_from_dataframe(frame: pd.DataFrame) -> Dict[str, Any]:
    """Convert a results DataFrame into the dict layout the dashboard reads.

    The frame is kept under ``"data"`` and a few summary statistics under
    ``"summary"``. Columns named ``"growth_rates"`` or ``"losses"`` are also
    copied to top-level keys of the same name so the matching dashboard
    panels get drawn.

    Args:
        frame: Simulation results as a DataFrame.

    Returns:
        Dictionary in the format accepted by ``create_interactive_dashboard``.
    """
    columns = frame.columns
    converted: Dict[str, Any] = {
        "data": frame,
        "summary": {
            # Index directly: the membership test already guarantees the
            # column exists, so no empty-Series fallback is needed.
            "mean_assets": frame["assets"].mean() if "assets" in columns else 0,
            "mean_losses": frame["losses"].mean() if "losses" in columns else 0,
            "years": frame["year"].nunique() if "year" in columns else 1,
        },
    }
    # No panel reads "data" or "summary"; without these top-level keys a
    # DataFrame input would always yield an empty dashboard.
    for key in ("growth_rates", "losses"):
        if key in columns:
            converted[key] = frame[key].to_numpy()
    return converted


def create_interactive_dashboard(
    results: Union[Dict[str, Any], pd.DataFrame],
    title: str = "Monte Carlo Simulation Dashboard",
    height: int = 600,
    show_distributions: bool = False,
) -> go.Figure:
    """Create interactive Plotly dashboard with WSJ styling.

    Creates a 2x2 dashboard with panels for the growth-rate distribution,
    the loss exceedance curve, convergence diagnostics and risk metrics.
    A panel is populated only when its input key is present in ``results``.

    Args:
        results: Dictionary with simulation results or a DataFrame.
            Dictionary keys read: ``"growth_rates"`` (array-like, histogram),
            ``"losses"`` (array-like, exceedance curve; NaN values are
            ignored), ``"convergence"`` (dict with ``"iterations"`` and
            ``"r_hat"``) and ``"metrics"`` (dict with ``"var_95"``,
            ``"var_99"``, ``"tvar_99"`` and ``"expected_shortfall"``; missing
            entries plot as 0). For a DataFrame, its ``"growth_rates"`` and
            ``"losses"`` columns feed the corresponding panels.
        title: Dashboard title
        height: Dashboard height in pixels
        show_distributions: Not used by the current implementation; kept
            for API compatibility.

    Returns:
        Plotly figure with interactive dashboard

    Examples:
        >>> results = {
        ...     "growth_rates": np.random.normal(0.05, 0.02, 1000),
        ...     "losses": np.random.lognormal(10, 2, 1000),
        ...     "metrics": {"var_95": 100000, "var_99": 150000}
        ... }
        >>> fig = create_interactive_dashboard(results)
    """
    # Handle DataFrame input
    if isinstance(results, pd.DataFrame):
        results = _results_from_dataframe(results)

    # Create subplots
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "Growth Rate Distribution",
            "Loss Exceedance Curve",
            "Convergence Diagnostics",
            "Risk Metrics",
        ),
        specs=[
            [{"type": "histogram"}, {"type": "scatter"}],
            [{"type": "scatter"}, {"type": "bar"}],
        ],
    )

    # WSJ-style layout
    layout_theme = {
        "plot_bgcolor": "white",
        "paper_bgcolor": "white",
        "font": {"family": "Arial, sans-serif", "size": 11, "color": WSJ_COLORS["black"]},
        "title": {"font": {"size": 16, "color": WSJ_COLORS["black"]}},
        "xaxis": {"gridcolor": WSJ_COLORS["light_gray"], "gridwidth": 0.5},
        "yaxis": {"gridcolor": WSJ_COLORS["light_gray"], "gridwidth": 0.5},
        "colorway": COLOR_SEQUENCE,
    }

    # Growth rate histogram
    if "growth_rates" in results:
        fig.add_trace(
            go.Histogram(
                x=results["growth_rates"],
                nbinsx=50,
                marker_color=WSJ_COLORS["blue"],
                opacity=0.7,
                name="Growth Rate",
            ),
            row=1,
            col=1,
        )

    # Loss exceedance curve: empirical exceedance probability, largest loss first.
    if "losses" in results:
        losses_data = np.asarray(results["losses"], dtype=float)
        # NaNs would otherwise sort ahead of every real loss after the
        # reversal and inflate the sample size used for the probabilities.
        losses_data = losses_data[~np.isnan(losses_data)]
        sorted_losses = np.sort(losses_data)[::-1]
        exceedance_prob = np.arange(1, sorted_losses.size + 1) / sorted_losses.size

        fig.add_trace(
            go.Scatter(
                x=sorted_losses / 1e6,
                y=exceedance_prob,
                mode="lines",
                line={"color": WSJ_COLORS["red"], "width": 2},
                name="Exceedance",
            ),
            row=1,
            col=2,
        )
        fig.update_xaxes(title_text="Loss Amount ($M)", row=1, col=2)
        fig.update_yaxes(title_text="Exceedance Probability", type="log", row=1, col=2)

    # Convergence diagnostics
    if "convergence" in results and isinstance(results["convergence"], dict):
        iterations = results["convergence"].get("iterations", [])
        r_hat = results["convergence"].get("r_hat", [])

        fig.add_trace(
            go.Scatter(
                x=iterations,
                y=r_hat,
                mode="lines+markers",
                line={"color": WSJ_COLORS["green"], "width": 2},
                marker={"size": 6},
                name="R-hat",
            ),
            row=2,
            col=1,
        )

        # Add convergence threshold line
        fig.add_hline(
            y=1.1,
            line_dash="dash",
            line_color=WSJ_COLORS["orange"],
            annotation_text="Convergence Threshold",
            row=2,
            col=1,
        )
        fig.update_xaxes(title_text="Iterations", row=2, col=1)
        fig.update_yaxes(title_text="R-hat Statistic", row=2, col=1)

    # Risk metrics bar chart (values converted to $M)
    if "metrics" in results and isinstance(results["metrics"], dict):
        metrics = results["metrics"]
        metric_names = ["VaR(95%)", "VaR(99%)", "TVaR(99%)", "Expected Shortfall"]
        metric_values = [
            metrics.get("var_95", 0) / 1e6,
            metrics.get("var_99", 0) / 1e6,
            metrics.get("tvar_99", 0) / 1e6,
            metrics.get("expected_shortfall", 0) / 1e6,
        ]

        fig.add_trace(
            go.Bar(
                x=metric_names,
                y=metric_values,
                marker_color=COLOR_SEQUENCE[: len(metric_names)],
                text=[f"${v:.1f}M" for v in metric_values],
                textposition="outside",
                name="Risk Metrics",
            ),
            row=2,
            col=2,
        )
        fig.update_yaxes(title_text="Amount ($M)", row=2, col=2)

    # Update layout
    fig.update_layout(title_text=title, showlegend=False, height=height, **layout_theme)

    # Update all axes
    fig.update_xaxes(showgrid=True, gridwidth=0.5, gridcolor=WSJ_COLORS["light_gray"])
    fig.update_yaxes(showgrid=True, gridwidth=0.5, gridcolor=WSJ_COLORS["light_gray"])

    return fig


def create_time_series_dashboard(
    data: pd.DataFrame,
    value_col: str,
    time_col: str = "date",
    title: str = "Time Series Analysis",
    height: int = 600,
    show_forecast: bool = False,
    ma_window: int = 20,
) -> go.Figure:
    """Create interactive time series visualization.

    Creates an interactive time series plot with a moving-average overlay,
    optional forecast line and confidence band, and a range slider.

    Args:
        data: DataFrame with time series data
        value_col: Name of value column
        time_col: Name of time column; also used (prettified) as the
            x-axis title
        title: Plot title
        height: Plot height in pixels
        show_forecast: Whether to draw the ``{value_col}_forecast`` column
            and, when ``{value_col}_upper`` and ``{value_col}_lower`` are both
            present, a shaded band between them
        ma_window: Rolling-mean window length. The moving average is drawn
            only when ``ma_window > 0`` and the series is longer than it.

    Returns:
        Plotly figure with time series visualization
    """
    fig = go.Figure()
    time_values = data[time_col]
    columns = data.columns

    # Main time series
    fig.add_trace(
        go.Scatter(
            x=time_values,
            y=data[value_col],
            mode="lines",
            name="Actual",
            line={"color": WSJ_COLORS["blue"], "width": 2},
        )
    )

    # Moving average; skipped for short series where it would be mostly NaN.
    if ma_window > 0 and len(data) > ma_window:
        moving_avg = data[value_col].rolling(window=ma_window).mean()
        fig.add_trace(
            go.Scatter(
                x=time_values,
                y=moving_avg,
                mode="lines",
                name=f"{ma_window}-period MA",
                line={"color": WSJ_COLORS["orange"], "width": 1, "dash": "dash"},
            )
        )

    forecast_col = f"{value_col}_forecast"
    upper_col = f"{value_col}_upper"
    lower_col = f"{value_col}_lower"

    # Add forecast if requested
    if show_forecast and forecast_col in columns:
        fig.add_trace(
            go.Scatter(
                x=time_values,
                y=data[forecast_col],
                mode="lines",
                name="Forecast",
                line={"color": WSJ_COLORS["green"], "width": 2, "dash": "dot"},
            )
        )

        # Add confidence bands if available
        if upper_col in columns and lower_col in columns:
            # The upper bound must be added immediately before the lower bound:
            # fill="tonexty" shades down to the previously added trace.
            fig.add_trace(
                go.Scatter(
                    x=time_values,
                    y=data[upper_col],
                    mode="lines",
                    showlegend=False,
                    line={"width": 0},
                )
            )
            fig.add_trace(
                go.Scatter(
                    x=time_values,
                    y=data[lower_col],
                    mode="lines",
                    fill="tonexty",
                    fillcolor="rgba(0, 128, 199, 0.2)",
                    name="95% CI",
                    line={"width": 0},
                )
            )

    # Update layout; axis titles are derived from the column names
    # ("date" -> "Date", as before).
    fig.update_layout(
        title=title,
        xaxis_title=time_col.replace("_", " ").title(),
        yaxis_title=value_col.replace("_", " ").title(),
        height=height,
        hovermode="x unified",
        template="plotly_white",
        font={"family": "Arial, sans-serif"},
    )

    # Add range slider
    fig.update_xaxes(rangeslider_visible=True)

    return fig


def create_correlation_heatmap(
    data: pd.DataFrame,
    title: str = "Correlation Matrix",
    height: int = 600,
    show_values: bool = True,
) -> go.Figure:
    """Create interactive correlation heatmap.

    Creates an interactive heatmap of pairwise correlations between the
    numeric (and boolean) columns of ``data``. Colors run from red (-1)
    through white (0) to blue (+1).

    Args:
        data: DataFrame with variables to correlate; non-numeric columns
            are ignored
        title: Plot title
        height: Plot height in pixels
        show_values: Whether to print each correlation (2 decimals) in its cell

    Returns:
        Plotly figure with correlation heatmap
    """
    # pandas >= 2.0 raises on non-numeric columns in DataFrame.corr(), while
    # older versions dropped them silently; selecting them up front gives the
    # same result on every version.
    numeric_data = data.select_dtypes(include=["number", "bool"])
    corr_matrix = numeric_data.corr()

    # Create heatmap
    fig = go.Figure(
        data=go.Heatmap(
            z=corr_matrix.values,
            x=corr_matrix.columns,
            y=corr_matrix.index,
            colorscale=[
                [0, WSJ_COLORS["red"]],
                [0.5, "white"],
                [1, WSJ_COLORS["blue"]],
            ],
            zmin=-1,
            zmax=1,
            text=corr_matrix.values if show_values else None,
            texttemplate="%{text:.2f}" if show_values else None,
            textfont={"size": 10},
            colorbar={"title": "Correlation", "tickmode": "linear", "tick0": -1, "dtick": 0.5},
        )
    )

    # Update layout
    fig.update_layout(
        title=title,
        height=height,
        xaxis={"side": "bottom"},
        yaxis={"side": "left"},
        template="plotly_white",
        font={"family": "Arial, sans-serif"},
    )

    return fig


def create_risk_dashboard(
    risk_metrics: Dict[str, Any],
    title: str = "Risk Analytics Dashboard",
    height: int = 800,
) -> go.Figure:
    """Build a six-panel (3 rows x 2 columns) risk analytics figure.

    Every panel is optional: it is filled only when its key appears in
    ``risk_metrics``; otherwise the panel stays empty.

    Args:
        risk_metrics: Panel inputs. Keys read:
            ``"var_distribution"`` (values for a 30-bin histogram),
            ``"expected_shortfall"`` (mapping category -> value, bar chart),
            ``"risk_contribution"`` (mapping label -> share, pie chart),
            ``"stress_tests"`` (mapping scenario -> impact; impacts >= 0
            are drawn green, all others red),
            ``"var_breaches"`` (mapping with ``"dates"`` and ``"values"``,
            scatter markers) and
            ``"trends"`` (mapping metric name -> mapping with ``"dates"``
            and ``"values"``, one line per metric).
        title: Dashboard title
        height: Dashboard height in pixels

    Returns:
        Plotly figure with the risk dashboard (legend hidden)
    """
    panel_titles = (
        "Value at Risk Distribution",
        "Expected Shortfall Analysis",
        "Risk Contribution by Factor",
        "Stress Test Results",
        "Historical VaR Breaches",
        "Risk Metric Trends",
    )
    grid_specs = [
        [{"type": "histogram"}, {"type": "bar"}],
        [{"type": "pie"}, {"type": "bar"}],
        [{"type": "scatter"}, {"type": "scatter"}],
    ]
    fig = make_subplots(rows=3, cols=2, subplot_titles=panel_titles, specs=grid_specs)

    # Row 1, left: histogram of VaR samples.
    if "var_distribution" in risk_metrics:
        var_hist = go.Histogram(
            x=risk_metrics["var_distribution"],
            nbinsx=30,
            marker_color=WSJ_COLORS["blue"],
            opacity=0.7,
            name="VaR",
        )
        fig.add_trace(var_hist, row=1, col=1)

    # Row 1, right: expected shortfall per category.
    if "expected_shortfall" in risk_metrics:
        shortfall = risk_metrics["expected_shortfall"]
        es_bars = go.Bar(
            x=list(shortfall.keys()),
            y=list(shortfall.values()),
            marker_color=WSJ_COLORS["red"],
            name="ES",
        )
        fig.add_trace(es_bars, row=1, col=2)

    # Row 2, left: share of risk contributed by each factor.
    if "risk_contribution" in risk_metrics:
        contribution = risk_metrics["risk_contribution"]
        slice_labels = list(contribution.keys())
        contribution_pie = go.Pie(
            labels=slice_labels,
            values=list(contribution.values()),
            marker={"colors": COLOR_SEQUENCE[: len(slice_labels)]},
        )
        fig.add_trace(contribution_pie, row=2, col=1)

    # Row 2, right: stress-test impacts, colored by sign.
    if "stress_tests" in risk_metrics:
        stress = risk_metrics["stress_tests"]
        scenario_names = list(stress.keys())
        impact_values = list(stress.values())
        bar_colors = [
            WSJ_COLORS["green" if impact >= 0 else "red"] for impact in impact_values
        ]
        stress_bars = go.Bar(
            x=scenario_names,
            y=impact_values,
            marker_color=bar_colors,
            name="Impact",
        )
        fig.add_trace(stress_bars, row=2, col=2)

    # Row 3, left: dates on which VaR was breached.
    if "var_breaches" in risk_metrics:
        breach_info = risk_metrics["var_breaches"]
        breach_dates = breach_info["dates"]
        breach_values = breach_info["values"]
        breach_markers = go.Scatter(
            x=breach_dates,
            y=breach_values,
            mode="markers",
            marker={"color": WSJ_COLORS["red"], "size": 10},
            name="Breaches",
        )
        fig.add_trace(breach_markers, row=3, col=1)

    # Row 3, right: one line per tracked metric.
    for metric_label, series in risk_metrics.get("trends", {}).items():
        trend_line = go.Scatter(
            x=series["dates"],
            y=series["values"],
            mode="lines",
            name=metric_label,
        )
        fig.add_trace(trend_line, row=3, col=2)

    fig.update_layout(
        title_text=title,
        showlegend=False,
        height=height,
        template="plotly_white",
        font={"family": "Arial, sans-serif"},
    )

    return fig