"""Interactive visualization functions using Plotly.
This module provides functions for creating interactive dashboards
and visualizations for Monte Carlo simulations and analysis results.
"""
from typing import Any, Dict, Union
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from .core import COLOR_SEQUENCE, WSJ_COLORS
[docs]
def create_interactive_dashboard(
    results: Union[Dict[str, Any], pd.DataFrame],
    title: str = "Monte Carlo Simulation Dashboard",
    height: int = 600,
    show_distributions: bool = False,
) -> go.Figure:
    """Create interactive Plotly dashboard with WSJ styling.

    Creates a 2x2 dashboard: growth-rate histogram (top left), loss
    exceedance curve (top right), R-hat convergence diagnostics (bottom
    left) and a risk-metrics bar chart (bottom right). Each panel is only
    populated when its input is present; otherwise it is left empty.

    Args:
        results: Dictionary with simulation results, or a DataFrame.
            Recognised dictionary keys:

            * ``"growth_rates"``: array-like of growth rates.
            * ``"losses"``: array-like of loss amounts in dollars (plotted
              in $M). Non-finite values (NaN/inf) are ignored.
            * ``"convergence"``: dict with ``"iterations"`` and ``"r_hat"``
              sequences.
            * ``"metrics"``: dict with ``"var_95"``, ``"var_99"``,
              ``"tvar_99"`` and ``"expected_shortfall"`` in dollars;
              missing or ``None`` entries are shown as 0.

            For a DataFrame, the ``"growth_rates"`` and ``"losses"``
            columns are plotted when present.
        title: Dashboard title.
        height: Dashboard height in pixels.
        show_distributions: Currently not used by this function; accepted
            for API compatibility.

    Returns:
        Plotly figure with interactive dashboard.

    Examples:
        >>> results = {
        ...     "growth_rates": np.random.normal(0.05, 0.02, 1000),
        ...     "losses": np.random.lognormal(10, 2, 1000),
        ...     "metrics": {"var_95": 100000, "var_99": 150000}
        ... }
        >>> fig = create_interactive_dashboard(results)
    """
    if isinstance(results, pd.DataFrame):
        # Map DataFrame columns onto the dictionary keys the panels below
        # read. Only the panels fed by plain array-like data can be driven
        # from columns; the convergence/metrics panels need nested dicts.
        frame = results
        results = {
            column: frame[column].to_numpy()
            for column in ("growth_rates", "losses")
            if column in frame.columns
        }

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "Growth Rate Distribution",
            "Loss Exceedance Curve",
            "Convergence Diagnostics",
            "Risk Metrics",
        ),
        specs=[
            [{"type": "histogram"}, {"type": "scatter"}],
            [{"type": "scatter"}, {"type": "bar"}],
        ],
    )

    # WSJ-style layout
    layout_theme = {
        "plot_bgcolor": "white",
        "paper_bgcolor": "white",
        "font": {"family": "Arial, sans-serif", "size": 11, "color": WSJ_COLORS["black"]},
        "title": {"font": {"size": 16, "color": WSJ_COLORS["black"]}},
        "xaxis": {"gridcolor": WSJ_COLORS["light_gray"], "gridwidth": 0.5},
        "yaxis": {"gridcolor": WSJ_COLORS["light_gray"], "gridwidth": 0.5},
        "colorway": COLOR_SEQUENCE,
    }

    # Growth rate histogram
    if "growth_rates" in results:
        fig.add_trace(
            go.Histogram(
                x=results["growth_rates"],
                nbinsx=50,
                marker_color=WSJ_COLORS["blue"],
                opacity=0.7,
                name="Growth Rate",
            ),
            row=1,
            col=1,
        )

    # Loss exceedance curve
    if "losses" in results:
        losses = np.asarray(results["losses"], dtype=float).ravel()
        # NaN/inf would otherwise be plotted and inflate the denominator of
        # every exceedance probability.
        losses = losses[np.isfinite(losses)]
        sorted_losses = np.sort(losses)[::-1]
        n_losses = sorted_losses.size
        # Empirical P(L >= x) for the i-th largest loss is i / n; max(..., 1)
        # keeps the empty case well-defined.
        exceedance_prob = np.arange(1, n_losses + 1) / max(n_losses, 1)
        fig.add_trace(
            go.Scatter(
                x=sorted_losses / 1e6,
                y=exceedance_prob,
                mode="lines",
                line={"color": WSJ_COLORS["red"], "width": 2},
                name="Exceedance",
            ),
            row=1,
            col=2,
        )
        fig.update_xaxes(title_text="Loss Amount ($M)", row=1, col=2)
        fig.update_yaxes(title_text="Exceedance Probability", type="log", row=1, col=2)

    # Convergence diagnostics
    if "convergence" in results and isinstance(results["convergence"], dict):
        iterations = results["convergence"].get("iterations", [])
        r_hat = results["convergence"].get("r_hat", [])
        fig.add_trace(
            go.Scatter(
                x=iterations,
                y=r_hat,
                mode="lines+markers",
                line={"color": WSJ_COLORS["green"], "width": 2},
                marker={"size": 6},
                name="R-hat",
            ),
            row=2,
            col=1,
        )
        # Reference line at R-hat = 1.1.
        fig.add_hline(
            y=1.1,
            line_dash="dash",
            line_color=WSJ_COLORS["orange"],
            annotation_text="Convergence Threshold",
            row=2,
            col=1,
        )
        fig.update_xaxes(title_text="Iterations", row=2, col=1)
        fig.update_yaxes(title_text="R-hat Statistic", row=2, col=1)

    # Risk metrics bar chart (values converted from dollars to $M)
    if "metrics" in results and isinstance(results["metrics"], dict):
        metrics = results["metrics"]
        metric_specs = (
            ("VaR(95%)", "var_95"),
            ("VaR(99%)", "var_99"),
            ("TVaR(99%)", "tvar_99"),
            ("Expected Shortfall", "expected_shortfall"),
        )
        metric_names = [label for label, _ in metric_specs]
        # ``or 0`` also covers keys that are present but set to None.
        metric_values = [float(metrics.get(key) or 0) / 1e6 for _, key in metric_specs]
        fig.add_trace(
            go.Bar(
                x=metric_names,
                y=metric_values,
                marker_color=COLOR_SEQUENCE[: len(metric_names)],
                text=[f"${v:.1f}M" for v in metric_values],
                textposition="outside",
                name="Risk Metrics",
            ),
            row=2,
            col=2,
        )
        fig.update_yaxes(title_text="Amount ($M)", row=2, col=2)

    fig.update_layout(title_text=title, showlegend=False, height=height, **layout_theme)
    # layout_theme's xaxis/yaxis only style the first subplot; apply the
    # grid styling to every subplot axis.
    fig.update_xaxes(showgrid=True, gridwidth=0.5, gridcolor=WSJ_COLORS["light_gray"])
    fig.update_yaxes(showgrid=True, gridwidth=0.5, gridcolor=WSJ_COLORS["light_gray"])
    return fig
[docs]
def create_time_series_dashboard(
    data: pd.DataFrame,
    value_col: str,
    time_col: str = "date",
    title: str = "Time Series Analysis",
    height: int = 600,
    show_forecast: bool = False,
    ma_window: int = 20,
) -> go.Figure:
    """Create interactive time series visualization.

    Plots ``data[value_col]`` against ``data[time_col]`` with a trailing
    moving-average overlay and, optionally, a forecast line with shaded
    bands. The figure uses unified x-hover and an x-axis range slider.

    Args:
        data: DataFrame with time series data.
        value_col: Name of value column; also used (title-cased, with
            underscores replaced by spaces) as the y-axis title.
        time_col: Name of time column; also used (title-cased, with
            underscores replaced by spaces) as the x-axis title.
        title: Plot title.
        height: Plot height in pixels.
        show_forecast: Whether to show forecast bands. The forecast line is
            read from the ``f"{value_col}_forecast"`` column; the shaded
            band is drawn only when ``f"{value_col}_upper"`` and
            ``f"{value_col}_lower"`` columns are also present.
        ma_window: Window length of the rolling-mean overlay. The overlay
            is only added when ``data`` has more than ``ma_window`` rows;
            pass 0 to disable it.

    Returns:
        Plotly figure with time series visualization.
    """
    fig = go.Figure()
    x_values = data[time_col]

    # Main time series
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=data[value_col],
            mode="lines",
            name="Actual",
            line={"color": WSJ_COLORS["blue"], "width": 2},
        )
    )

    # Trailing moving average
    if ma_window > 0 and len(data) > ma_window:
        moving_avg = data[value_col].rolling(window=ma_window).mean()
        fig.add_trace(
            go.Scatter(
                x=x_values,
                y=moving_avg,
                mode="lines",
                name=f"{ma_window}-period MA",
                line={"color": WSJ_COLORS["orange"], "width": 1, "dash": "dash"},
            )
        )

    forecast_col = f"{value_col}_forecast"
    upper_col = f"{value_col}_upper"
    lower_col = f"{value_col}_lower"
    if show_forecast and forecast_col in data.columns:
        fig.add_trace(
            go.Scatter(
                x=x_values,
                y=data[forecast_col],
                mode="lines",
                name="Forecast",
                line={"color": WSJ_COLORS["green"], "width": 2, "dash": "dot"},
            )
        )
        if upper_col in data.columns and lower_col in data.columns:
            # The upper edge must be added immediately before the lower
            # edge: fill="tonexty" shades between a trace and the previous one.
            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=data[upper_col],
                    mode="lines",
                    name="95% CI (upper)",  # label shown in unified hover
                    showlegend=False,
                    line={"width": 0},
                )
            )
            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=data[lower_col],
                    mode="lines",
                    fill="tonexty",
                    fillcolor="rgba(0, 128, 199, 0.2)",
                    name="95% CI",
                    line={"width": 0},
                )
            )

    fig.update_layout(
        title=title,
        xaxis_title=time_col.replace("_", " ").title(),
        yaxis_title=value_col.replace("_", " ").title(),
        height=height,
        hovermode="x unified",
        template="plotly_white",
        font={"family": "Arial, sans-serif"},
    )
    fig.update_xaxes(rangeslider_visible=True)
    return fig
[docs]
def create_correlation_heatmap(
    data: pd.DataFrame,
    title: str = "Correlation Matrix",
    height: int = 600,
    show_values: bool = True,
) -> go.Figure:
    """Create interactive correlation heatmap.

    Computes ``DataFrame.corr()`` (Pearson by default) over the numeric and
    boolean columns of ``data``. Other columns, such as strings or
    datetimes, are ignored. The result is drawn as a heatmap on a fixed
    [-1, 1] scale running red (-1) through white (0) to blue (+1).

    Args:
        data: DataFrame with variables to correlate.
        title: Plot title.
        height: Plot height in pixels.
        show_values: Whether to print each correlation (2 decimals) in its
            cell.

    Returns:
        Plotly figure with correlation heatmap.
    """
    # pandas >= 2.0 makes DataFrame.corr() raise on non-numeric columns
    # (numeric_only now defaults to False), whereas older versions silently
    # dropped them. Selecting number/bool columns first gives the same
    # result on every pandas version.
    numeric = data.select_dtypes(include=["number", "bool"])
    corr_matrix = numeric.corr()

    fig = go.Figure(
        data=go.Heatmap(
            z=corr_matrix.values,
            x=corr_matrix.columns,
            y=corr_matrix.index,
            colorscale=[
                [0, WSJ_COLORS["red"]],
                [0.5, "white"],
                [1, WSJ_COLORS["blue"]],
            ],
            zmin=-1,
            zmax=1,
            text=corr_matrix.values if show_values else None,
            texttemplate="%{text:.2f}" if show_values else None,
            textfont={"size": 10},
            colorbar={"title": "Correlation", "tickmode": "linear", "tick0": -1, "dtick": 0.5},
        )
    )

    fig.update_layout(
        title=title,
        height=height,
        xaxis={"side": "bottom"},
        yaxis={"side": "left"},
        template="plotly_white",
        font={"family": "Arial, sans-serif"},
    )
    return fig
[docs]
def create_risk_dashboard(
    risk_metrics: Dict[str, Any],
    title: str = "Risk Analytics Dashboard",
    height: int = 800,
) -> go.Figure:
    """Build a six-panel risk analytics dashboard.

    The figure is a 3x2 grid. Each panel is filled only when the matching
    key exists in ``risk_metrics``; otherwise it stays empty:

    * ``"var_distribution"``: array-like, drawn as a 30-bin histogram.
    * ``"expected_shortfall"``: mapping of category -> value, drawn as bars.
    * ``"risk_contribution"``: mapping of factor -> value, drawn as a pie.
    * ``"stress_tests"``: mapping of scenario -> impact, drawn as bars
      (green for impacts >= 0, red otherwise).
    * ``"var_breaches"``: mapping with ``"dates"`` and ``"values"``
      sequences, drawn as markers.
    * ``"trends"``: mapping of metric name -> mapping with ``"dates"`` and
      ``"values"``, one line per metric.

    Args:
        risk_metrics: Dictionary containing the optional inputs listed above.
        title: Dashboard title.
        height: Dashboard height in pixels.

    Returns:
        Plotly figure with risk dashboard.
    """

    def _keys_and_values(mapping):
        # Split a mapping into parallel lists of labels and values.
        return list(mapping.keys()), list(mapping.values())

    panel_types = (("histogram", "bar"), ("pie", "bar"), ("scatter", "scatter"))
    fig = make_subplots(
        rows=3,
        cols=2,
        subplot_titles=(
            "Value at Risk Distribution",
            "Expected Shortfall Analysis",
            "Risk Contribution by Factor",
            "Stress Test Results",
            "Historical VaR Breaches",
            "Risk Metric Trends",
        ),
        specs=[[{"type": left}, {"type": right}] for left, right in panel_types],
    )

    # Row 1, left: histogram of VaR samples.
    if "var_distribution" in risk_metrics:
        var_trace = go.Histogram(
            x=risk_metrics["var_distribution"],
            nbinsx=30,
            marker_color=WSJ_COLORS["blue"],
            opacity=0.7,
            name="VaR",
        )
        fig.add_trace(var_trace, row=1, col=1)

    # Row 1, right: expected shortfall per category.
    if "expected_shortfall" in risk_metrics:
        es_labels, es_values = _keys_and_values(risk_metrics["expected_shortfall"])
        es_trace = go.Bar(x=es_labels, y=es_values, marker_color=WSJ_COLORS["red"], name="ES")
        fig.add_trace(es_trace, row=1, col=2)

    # Row 2, left: pie of risk contribution by factor.
    if "risk_contribution" in risk_metrics:
        factors, shares = _keys_and_values(risk_metrics["risk_contribution"])
        pie_trace = go.Pie(
            labels=factors,
            values=shares,
            marker={"colors": COLOR_SEQUENCE[: len(factors)]},
        )
        fig.add_trace(pie_trace, row=2, col=1)

    # Row 2, right: stress-test impacts, coloured by sign.
    if "stress_tests" in risk_metrics:
        scenarios, impacts = _keys_and_values(risk_metrics["stress_tests"])
        bar_colors = [WSJ_COLORS["green" if impact >= 0 else "red"] for impact in impacts]
        stress_trace = go.Bar(x=scenarios, y=impacts, marker_color=bar_colors, name="Impact")
        fig.add_trace(stress_trace, row=2, col=2)

    # Row 3, left: individual VaR breaches as markers.
    if "var_breaches" in risk_metrics:
        breach_info = risk_metrics["var_breaches"]
        breach_trace = go.Scatter(
            x=breach_info["dates"],
            y=breach_info["values"],
            mode="markers",
            marker={"color": WSJ_COLORS["red"], "size": 10},
            name="Breaches",
        )
        fig.add_trace(breach_trace, row=3, col=1)

    # Row 3, right: one line per tracked metric.
    if "trends" in risk_metrics:
        for series_name, series in risk_metrics["trends"].items():
            line_trace = go.Scatter(
                x=series["dates"],
                y=series["values"],
                mode="lines",
                name=series_name,
            )
            fig.add_trace(line_trace, row=3, col=2)

    fig.update_layout(
        title_text=title,
        showlegend=False,
        height=height,
        template="plotly_white",
        font={"family": "Arial, sans-serif"},
    )
    return fig