Source code for ergodic_insurance.visualization_infra.figure_factory

"""Figure factory for creating standardized plots with consistent styling.

This module provides a factory class for creating various types of plots
with automatic styling, spacing, and formatting applied consistently.
"""

from typing import Any, Dict, List, Optional, Tuple, Union

from matplotlib.axes import Axes
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd

from .style_manager import StyleManager, Theme


[docs] class FigureFactory: """Factory for creating standardized figures with consistent styling. This class provides methods to create various types of plots with automatic application of themes, consistent formatting, and proper spacing. Example: >>> factory = FigureFactory(theme=Theme.PRESENTATION) >>> fig, ax = factory.create_line_plot( ... x_data=[1, 2, 3, 4], ... y_data=[10, 20, 15, 25], ... title="Revenue Growth", ... x_label="Quarter", ... y_label="Revenue ($M)" ... ) >>> # Create multiple subplots >>> fig, axes = factory.create_subplots( ... rows=2, cols=2, ... size_type="large", ... subplot_titles=["Q1", "Q2", "Q3", "Q4"] ... ) """ def __init__( self, style_manager: Optional[StyleManager] = None, theme: Theme = Theme.DEFAULT, auto_apply: bool = True, ): """Initialize figure factory. Args: style_manager: Custom style manager (creates default if None) theme: Theme to use for all figures auto_apply: Whether to automatically apply styling """ self.style_manager = style_manager or StyleManager(theme=theme) self.auto_apply = auto_apply if self.auto_apply: self.style_manager.apply_style()
[docs] def create_figure( self, size_type: str = "medium", orientation: str = "landscape", dpi_type: str = "screen", title: Optional[str] = None, ) -> Tuple[Figure, Axes]: """Create a basic figure with styling applied. Args: size_type: Size preset (small, medium, large, blog, technical, presentation) orientation: Figure orientation (landscape or portrait) dpi_type: DPI type (screen, web, print) title: Optional figure title Returns: Tuple of (figure, axes) """ size = self.style_manager.get_figure_size(size_type, orientation) dpi = self.style_manager.get_dpi(dpi_type) fig, ax = plt.subplots(figsize=size, dpi=dpi) if title: fig.suptitle(title, fontweight="bold") self._apply_axis_styling(ax) return fig, ax
[docs] def create_subplots( self, rows: int = 1, cols: int = 1, size_type: str = "large", dpi_type: str = "screen", title: Optional[str] = None, subplot_titles: Optional[List[str]] = None, **kwargs, ) -> Tuple[Figure, Union[Axes, np.ndarray]]: """Create subplots with consistent styling. Args: rows: Number of subplot rows cols: Number of subplot columns size_type: Size preset dpi_type: DPI type title: Main figure title subplot_titles: Titles for each subplot **kwargs: Additional arguments for plt.subplots Returns: Tuple of (figure, axes array) """ size = self.style_manager.get_figure_size(size_type) dpi = self.style_manager.get_dpi(dpi_type) fig, axes = plt.subplots(rows, cols, figsize=size, dpi=dpi, **kwargs) if title: fig.suptitle( title, fontweight="bold", fontsize=self.style_manager.get_fonts().size_title + 2 ) # Apply styling to all axes axes_list: List[Axes] = list(axes.flatten()) if isinstance(axes, np.ndarray) else [axes] for i, ax in enumerate(axes_list): self._apply_axis_styling(ax) if subplot_titles and i < len(subplot_titles): ax.set_title(subplot_titles[i]) plt.tight_layout() return fig, axes
[docs] def create_line_plot( # pylint: disable=too-many-locals self, x_data: Union[List, np.ndarray, pd.Series], y_data: Union[List, np.ndarray, pd.Series, Dict[str, Union[List, np.ndarray]]], title: Optional[str] = None, x_label: Optional[str] = None, y_label: Optional[str] = None, labels: Optional[List[str]] = None, size_type: str = "medium", dpi_type: str = "screen", show_legend: bool = True, show_grid: bool = True, markers: bool = False, **kwargs, ) -> Tuple[Figure, Axes]: """Create a line plot with automatic formatting. Args: x_data: X-axis data y_data: Y-axis data (can be multiple series as dict) title: Plot title x_label: X-axis label y_label: Y-axis label labels: Series labels for legend size_type: Figure size preset dpi_type: DPI type show_legend: Whether to show legend show_grid: Whether to show grid markers: Whether to add markers to lines **kwargs: Additional arguments for plot Returns: Tuple of (figure, axes) """ fig, ax = self.create_figure(size_type=size_type, dpi_type=dpi_type, title=title) colors = self.style_manager.get_colors() # Handle multiple series if isinstance(y_data, dict): for i, (series_label, data) in enumerate(y_data.items()): color = colors.series[i % len(colors.series)] marker = "o" if markers else None ax.plot(x_data, data, label=series_label, color=color, marker=marker, **kwargs) else: # Single series marker = "o" if markers else None label: Optional[str] = labels[0] if labels else None ax.plot(x_data, y_data, color=colors.primary, label=label, marker=marker, **kwargs) # Labels and formatting if x_label: ax.set_xlabel(x_label) if y_label: ax.set_ylabel(y_label) if show_legend and (isinstance(y_data, dict) or labels): ax.legend(loc="best", frameon=True) ax.grid(show_grid, alpha=self.style_manager.get_grid_config().grid_alpha) plt.tight_layout() return fig, ax
[docs] def create_bar_plot( # pylint: disable=too-many-locals,too-many-branches self, categories: Union[List, np.ndarray], values: Union[List, np.ndarray, Dict[str, Union[List, np.ndarray]]], title: Optional[str] = None, x_label: Optional[str] = None, y_label: Optional[str] = None, labels: Optional[List[str]] = None, size_type: str = "medium", dpi_type: str = "screen", orientation: str = "vertical", show_values: bool = False, value_format: str = ".1f", **kwargs, ) -> Tuple[Figure, Axes]: """Create a bar plot with automatic formatting. Args: categories: Category labels values: Values to plot (can be multiple series as dict) title: Plot title x_label: X-axis label y_label: Y-axis label labels: Series labels for legend size_type: Figure size preset dpi_type: DPI type orientation: Bar orientation (vertical or horizontal) show_values: Whether to show value labels on bars value_format: Format string for value labels **kwargs: Additional arguments for bar plot Returns: Tuple of (figure, axes) """ fig, ax = self.create_figure(size_type=size_type, dpi_type=dpi_type, title=title) colors = self.style_manager.get_colors() # Handle multiple series if isinstance(values, dict): n_series = len(values) width = 0.8 / n_series x_pos = np.arange(len(categories)) for i, (label, data) in enumerate(values.items()): offset = (i - n_series / 2 + 0.5) * width color = colors.series[i % len(colors.series)] if orientation == "vertical": bars = ax.bar(x_pos + offset, data, width, label=label, color=color, **kwargs) else: bars = ax.barh(x_pos + offset, data, width, label=label, color=color, **kwargs) if show_values: self._add_value_labels(ax, bars, orientation, value_format) if orientation == "vertical": ax.set_xticks(x_pos) ax.set_xticklabels(categories) else: ax.set_yticks(x_pos) ax.set_yticklabels(categories) ax.legend(loc="best", frameon=True) else: # Single series if orientation == "vertical": bars = ax.bar(categories, values, color=colors.primary, **kwargs) else: bars = ax.barh(categories, values, color=colors.primary, **kwargs) if show_values: self._add_value_labels(ax, bars, orientation, value_format) # Labels if x_label: ax.set_xlabel(x_label) if y_label: ax.set_ylabel(y_label) ax.grid( True, alpha=self.style_manager.get_grid_config().grid_alpha, axis="y" if orientation == "vertical" else "x", ) plt.tight_layout() return fig, ax
[docs] def create_scatter_plot( self, x_data: Union[List, np.ndarray], y_data: Union[List, np.ndarray], title: Optional[str] = None, x_label: Optional[str] = None, y_label: Optional[str] = None, size_type: str = "medium", dpi_type: str = "screen", colors: Optional[Union[List, np.ndarray]] = None, sizes: Optional[Union[List, np.ndarray]] = None, labels: Optional[List[str]] = None, show_colorbar: bool = False, **kwargs, ) -> Tuple[Figure, Axes]: """Create a scatter plot with automatic formatting. Args: x_data: X-axis data y_data: Y-axis data title: Plot title x_label: X-axis label y_label: Y-axis label size_type: Figure size preset dpi_type: DPI type colors: Optional colors for points (for continuous coloring) sizes: Optional sizes for points labels: Optional labels for points show_colorbar: Whether to show colorbar when colors provided **kwargs: Additional arguments for scatter Returns: Tuple of (figure, axes) """ fig, ax = self.create_figure(size_type=size_type, dpi_type=dpi_type, title=title) theme_colors = self.style_manager.get_colors() # Default sizes if not provided if sizes is None: sizes = np.array([50] * len(x_data)) # Create scatter plot if colors is not None: scatter = ax.scatter(x_data, y_data, c=colors, s=sizes, cmap="viridis", **kwargs) if show_colorbar: plt.colorbar(scatter, ax=ax) else: scatter = ax.scatter(x_data, y_data, s=sizes, color=theme_colors.primary, **kwargs) # Labels if x_label: ax.set_xlabel(x_label) if y_label: ax.set_ylabel(y_label) ax.grid(True, alpha=self.style_manager.get_grid_config().grid_alpha) plt.tight_layout() return fig, ax
[docs] def create_histogram( # pylint: disable=too-many-locals self, data: Union[List, np.ndarray, pd.Series], title: Optional[str] = None, x_label: Optional[str] = None, y_label: str = "Frequency", bins: Union[int, str] = "auto", size_type: str = "medium", dpi_type: str = "screen", show_statistics: bool = False, show_kde: bool = False, **kwargs, ) -> Tuple[Figure, Axes]: """Create a histogram with automatic formatting. Args: data: Data to plot title: Plot title x_label: X-axis label y_label: Y-axis label bins: Number of bins or method size_type: Figure size preset dpi_type: DPI type show_statistics: Whether to show mean/median lines show_kde: Whether to overlay KDE **kwargs: Additional arguments for hist Returns: Tuple of (figure, axes) """ fig, ax = self.create_figure(size_type=size_type, dpi_type=dpi_type, title=title) colors = self.style_manager.get_colors() # Create histogram _n, _bins_out, _patches = ax.hist( data, bins=bins, color=colors.primary, alpha=0.7, edgecolor="black", linewidth=0.5, **kwargs, ) # Add statistics if requested if show_statistics: mean_val = np.mean(data) median_val = np.median(data) ax.axvline( mean_val, color=colors.warning, linestyle="--", linewidth=2, label=f"Mean: {mean_val:.2f}", ) ax.axvline( median_val, color=colors.success, linestyle="--", linewidth=2, label=f"Median: {median_val:.2f}", ) ax.legend() # Add KDE if requested if show_kde: from scipy import stats kde = stats.gaussian_kde(data) if isinstance(data, (list, np.ndarray, pd.Series)): data_array = np.array(data) if isinstance(data, list) else data x_range = np.linspace(float(np.min(data_array)), float(np.max(data_array)), 100) else: # Should never reach here but keeping for type safety x_range = np.linspace(0, 100, 100) # type: ignore[unreachable] kde_values = kde(x_range) # Scale KDE to match histogram ax2 = ax.twinx() ax2.plot(x_range, kde_values, color=colors.secondary, linewidth=2, label="KDE") ax2.set_ylabel("Density") ax2.tick_params(axis="y", labelcolor=colors.secondary) # Labels if x_label: ax.set_xlabel(x_label) ax.set_ylabel(y_label) ax.grid(True, alpha=self.style_manager.get_grid_config().grid_alpha, axis="y") plt.tight_layout() return fig, ax
[docs] def create_heatmap( self, data: Union[np.ndarray, pd.DataFrame], title: Optional[str] = None, x_labels: Optional[List[str]] = None, y_labels: Optional[List[str]] = None, x_label: Optional[str] = None, y_label: Optional[str] = None, size_type: str = "medium", dpi_type: str = "screen", cmap: str = "RdBu_r", show_values: bool = True, value_format: str = ".2f", **kwargs, ) -> Tuple[Figure, Axes]: """Create a heatmap with automatic formatting. Args: data: 2D data array or DataFrame title: Plot title x_labels: Labels for x-axis y_labels: Labels for y-axis x_label: X-axis title y_label: Y-axis title size_type: Figure size preset dpi_type: DPI type cmap: Colormap name show_values: Whether to show values in cells value_format: Format string for cell values **kwargs: Additional arguments for imshow Returns: Tuple of (figure, axes) """ fig, ax = self.create_figure(size_type=size_type, dpi_type=dpi_type, title=title) # Handle DataFrame if isinstance(data, pd.DataFrame): if x_labels is None: x_labels = list(data.columns) if y_labels is None: y_labels = list(data.index) data = data.values # Create heatmap im = ax.imshow(data, cmap=cmap, aspect="auto", **kwargs) plt.colorbar(im, ax=ax) # Set tick labels if x_labels: ax.set_xticks(np.arange(len(x_labels))) ax.set_xticklabels(x_labels) plt.setp(ax.get_xticklabels(), rotation=45, ha="right") if y_labels: ax.set_yticks(np.arange(len(y_labels))) ax.set_yticklabels(y_labels) # Add values if requested if show_values: for i in range(data.shape[0]): for j in range(data.shape[1]): _text = ax.text( j, i, f"{data[i, j]:{value_format}}", ha="center", va="center", color="black", ) # Labels if x_label: ax.set_xlabel(x_label) if y_label: ax.set_ylabel(y_label) plt.tight_layout() return fig, ax
[docs] def create_box_plot( # pylint: disable=too-many-locals,too-many-branches self, data: Union[List[List], Dict[str, List], pd.DataFrame], title: Optional[str] = None, x_label: Optional[str] = None, y_label: Optional[str] = None, labels: Optional[List[str]] = None, size_type: str = "medium", dpi_type: str = "screen", orientation: str = "vertical", show_means: bool = True, **kwargs, ) -> Tuple[Figure, Axes]: """Create a box plot with automatic formatting. Args: data: Data for box plot (list of lists, dict, or DataFrame) title: Plot title x_label: X-axis label y_label: Y-axis label labels: Labels for each box size_type: Figure size preset dpi_type: DPI type orientation: Plot orientation (vertical or horizontal) show_means: Whether to show mean markers **kwargs: Additional arguments for boxplot Returns: Tuple of (figure, axes) """ fig, ax = self.create_figure(size_type=size_type, dpi_type=dpi_type, title=title) colors = self.style_manager.get_colors() # Prepare data if isinstance(data, dict): plot_data = list(data.values()) if labels is None: labels = list(data.keys()) elif isinstance(data, pd.DataFrame): plot_data = [data[col].dropna().values.tolist() for col in data.columns] if labels is None: labels = list(data.columns) else: plot_data = data # Create box plot (using updated matplotlib API) # Filter out 'vert' from kwargs if it exists to avoid duplication filtered_kwargs = {k: v for k, v in kwargs.items() if k != "vert"} # Use vert parameter with warning suppression for compatibility import warnings with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=PendingDeprecationWarning) bp = ax.boxplot( plot_data, vert=(orientation == "vertical"), tick_labels=labels, # Use tick_labels instead of labels showmeans=show_means, patch_artist=True, **filtered_kwargs, ) # Style the boxes for i, (box, median) in enumerate(zip(bp["boxes"], bp["medians"])): box.set_facecolor(colors.series[i % len(colors.series)]) box.set_alpha(0.7) median.set_color(colors.text) median.set_linewidth(2) # Style whiskers and caps for whisker in bp["whiskers"]: whisker.set_color(colors.neutral) whisker.set_linewidth(1) for cap in bp["caps"]: cap.set_color(colors.neutral) cap.set_linewidth(1) # Style outliers for flier in bp["fliers"]: flier.set_marker("o") flier.set_markersize(4) flier.set_markeredgecolor(colors.warning) flier.set_markerfacecolor(colors.warning) flier.set_alpha(0.5) # Style means if shown if show_means and "means" in bp: for mean in bp["means"]: mean.set_marker("D") mean.set_markersize(6) mean.set_markeredgecolor(colors.success) mean.set_markerfacecolor(colors.success) # Labels if x_label: ax.set_xlabel(x_label) if y_label: ax.set_ylabel(y_label) ax.grid( True, alpha=self.style_manager.get_grid_config().grid_alpha, axis="y" if orientation == "vertical" else "x", ) plt.tight_layout() return fig, ax
[docs] def format_axis_currency( self, ax: Axes, axis: str = "y", abbreviate: bool = True, decimals: int = 0, ) -> None: """Format axis labels as currency. Args: ax: Matplotlib axes axis: Which axis to format (x or y) abbreviate: Whether to abbreviate large numbers decimals: Number of decimal places """ def currency_formatter(x, pos): if abbreviate: if abs(x) >= 1e9: return f"${x/1e9:.{decimals}f}B" if abs(x) >= 1e6: return f"${x/1e6:.{decimals}f}M" if abs(x) >= 1e3: return f"${x/1e3:.{decimals}f}K" return f"${x:,.{decimals}f}" formatter = mticker.FuncFormatter(currency_formatter) if axis == "y": ax.yaxis.set_major_formatter(formatter) else: ax.xaxis.set_major_formatter(formatter)
[docs] def format_axis_percentage( self, ax: Axes, axis: str = "y", decimals: int = 0, ) -> None: """Format axis labels as percentages. Args: ax: Matplotlib axes axis: Which axis to format (x or y) decimals: Number of decimal places """ def percentage_formatter(x, pos): return f"{x*100:.{decimals}f}%" formatter = mticker.FuncFormatter(percentage_formatter) if axis == "y": ax.yaxis.set_major_formatter(formatter) else: ax.xaxis.set_major_formatter(formatter)
[docs] def add_annotations( self, ax: Axes, x: float, y: float, text: str, arrow: bool = True, offset: Tuple[float, float] = (10, 10), **kwargs, ) -> None: """Add styled annotation to plot. Args: ax: Matplotlib axes x: X coordinate y: Y coordinate text: Annotation text arrow: Whether to show arrow offset: Text offset from point **kwargs: Additional arguments for annotate """ colors = self.style_manager.get_colors() if arrow: ax.annotate( text, xy=(x, y), xytext=offset, textcoords="offset points", arrowprops={ "arrowstyle": "->", "color": colors.neutral, "connectionstyle": "arc3,rad=0.2", }, fontsize=self.style_manager.get_fonts().size_base - 1, color=colors.text, **kwargs, ) else: ax.text( x, y, text, fontsize=self.style_manager.get_fonts().size_base - 1, color=colors.text, **kwargs, )
[docs] def save_figure( self, fig: Figure, filename: str, output_type: str = "web", **kwargs, ) -> None: """Save figure with appropriate DPI settings. Args: fig: Figure to save filename: Output filename output_type: Output type (screen, web, print) **kwargs: Additional arguments for savefig """ dpi = self.style_manager.get_dpi(output_type) fig.savefig(filename, dpi=dpi, bbox_inches="tight", **kwargs)
def _apply_axis_styling(self, ax: Axes) -> None: """Apply consistent styling to axes. Args: ax: Axes to style """ colors = self.style_manager.get_colors() grid_config = self.style_manager.get_grid_config() # Apply grid settings ax.grid( grid_config.show_grid, alpha=grid_config.grid_alpha, linewidth=grid_config.grid_linewidth, color=colors.grid, ) # Apply spine visibility ax.spines["top"].set_visible(grid_config.spine_top) ax.spines["right"].set_visible(grid_config.spine_right) ax.spines["bottom"].set_visible(grid_config.spine_bottom) ax.spines["left"].set_visible(grid_config.spine_left) # Apply spine colors and widths for spine in ax.spines.values(): spine.set_edgecolor(colors.neutral) spine.set_linewidth(grid_config.spine_linewidth) # Apply tick parameters ax.tick_params( axis="both", which="major", width=grid_config.tick_major_width, length=5, color=colors.neutral, ) ax.tick_params( axis="both", which="minor", width=grid_config.tick_minor_width, length=3, color=colors.neutral, ) def _add_value_labels( self, ax: Axes, bars: Any, orientation: str, format_str: str, ) -> None: """Add value labels to bars. Args: ax: Axes containing bars bars: Bar container orientation: Bar orientation format_str: Format string for values """ for bar_element in bars: if orientation == "vertical": height = bar_element.get_height() ax.text( bar_element.get_x() + bar_element.get_width() / 2.0, height, f"{height:{format_str}}", ha="center", va="bottom", fontsize=self.style_manager.get_fonts().size_base - 2, ) else: width = bar_element.get_width() ax.text( width, bar_element.get_y() + bar_element.get_height() / 2.0, f"{width:{format_str}}", ha="left", va="center", fontsize=self.style_manager.get_fonts().size_base - 2, )