# Source code for mdr.visualization.plots

"""
Plotting utilities for Macrodata Refinement (MDR).

This module provides functions for creating and saving data visualizations.
"""

import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Union, Optional, Any, Callable
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
from dataclasses import dataclass, field
import os


@dataclass
class PlotConfig:
    """
    Appearance and behaviour settings for a plot.

    Every field is checked with assertions in ``__post_init__``; the first
    violated check raises ``AssertionError`` with a message naming the field.
    """

    # Title text; None means no title is set.
    title: Optional[str] = None
    # Figure size as a (width, height) tuple of two positive floats.
    figsize: Tuple[float, float] = (10.0, 6.0)
    # Figure resolution; must be a positive integer.
    dpi: int = 100
    # Axis label texts; None means the label is left untouched.
    xlabel: Optional[str] = None
    ylabel: Optional[str] = None
    # Axis limits as (low, high) float tuples with low <= high; None means unset.
    xlim: Optional[Tuple[float, float]] = None
    ylim: Optional[Tuple[float, float]] = None
    # Whether a legend / grid is requested.
    legend: bool = True
    grid: bool = True
    # Style, palette and font settings (strings / positive integer size).
    style: str = "seaborn-v0_8-whitegrid"
    palette: str = "viridis"
    font_family: str = "sans-serif"
    font_size: int = 12
    # Background colours of the figure and of the axes area.
    figure_bgcolor: str = "#ffffff"
    axis_bgcolor: str = "#f8f8f8"

    def __post_init__(self) -> None:
        """Validate configuration parameters."""

        def require_text(name: str, optional: bool = False) -> None:
            # String check; optional fields are allowed to be None.
            value = getattr(self, name)
            if optional and value is None:
                return
            assert isinstance(value, str), f"{name} must be a string"

        def require_float_pair(name: str) -> None:
            # Tuple of exactly two float instances (ints are rejected).
            pair = getattr(self, name)
            assert isinstance(pair, tuple), f"{name} must be a tuple"
            assert len(pair) == 2, f"{name} must be a tuple of length 2"
            for index in (0, 1):
                assert isinstance(pair[index], float), \
                    f"{name}[{index}] must be a floating-point number"

        def require_positive_int(name: str) -> None:
            value = getattr(self, name)
            assert isinstance(value, int), f"{name} must be an integer"
            assert value > 0, f"{name} must be positive"

        def require_flag(name: str) -> None:
            assert isinstance(getattr(self, name), bool), f"{name} must be a boolean"

        require_text("title", optional=True)

        require_float_pair("figsize")
        for index in (0, 1):
            assert self.figsize[index] > 0.0, f"figsize[{index}] must be positive"

        require_positive_int("dpi")

        require_text("xlabel", optional=True)
        require_text("ylabel", optional=True)

        for name in ("xlim", "ylim"):
            limits = getattr(self, name)
            if limits is None:
                continue
            require_float_pair(name)
            assert limits[0] <= limits[1], \
                f"{name}[0] must be less than or equal to {name}[1]"

        require_flag("legend")
        require_flag("grid")

        for name in ("style", "palette", "font_family"):
            require_text(name)

        require_positive_int("font_size")

        for name in ("figure_bgcolor", "axis_bgcolor"):
            require_text(name)
def _apply_plot_config(
    fig: plt.Figure,
    ax: plt.Axes,
    config: PlotConfig
) -> None:
    """
    Transfer the settings of a PlotConfig onto an existing figure and axes.

    Args:
        fig: Matplotlib figure (receives the figure background colour)
        ax: Matplotlib axes (receives title, labels, limits, grid,
            background colour, tick styling and the legend)
        config: Plot configuration
    """
    assert isinstance(config, PlotConfig), "config must be a PlotConfig object"

    # NOTE(review): the style is applied after fig/ax were created by the
    # caller, and plt.style.use presumably alters global state - confirm that
    # this ordering and scope are intended.
    plt.style.use(config.style)

    family = config.font_family
    base_size = config.font_size

    # Title is rendered slightly larger than the other text
    if config.title is not None:
        ax.set_title(config.title, fontsize=base_size + 2, fontfamily=family)

    # Axis labels (x first, then y)
    for text, set_label in ((config.xlabel, ax.set_xlabel), (config.ylabel, ax.set_ylabel)):
        if text is not None:
            set_label(text, fontsize=base_size, fontfamily=family)

    # Axis limits (x first, then y)
    for bounds, set_limits in ((config.xlim, ax.set_xlim), (config.ylim, ax.set_ylim)):
        if bounds is not None:
            set_limits(bounds)

    # Grid on/off
    ax.grid(config.grid)

    # Background colours
    fig.patch.set_facecolor(config.figure_bgcolor)
    ax.set_facecolor(config.axis_bgcolor)

    # Tick label size and colour
    ax.tick_params(labelsize=base_size, labelcolor="black")

    # A legend is only drawn when requested and at least one artist is labelled
    if config.legend:
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(frameon=True, fontsize=base_size, loc="best")
def plot_time_series(
    data: Dict[str, np.ndarray],
    timestamps: Optional[np.ndarray] = None,
    config: Optional[PlotConfig] = None
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot time series data.

    Each entry of ``data`` is drawn as one line labelled with its key. When
    ``timestamps`` is given it is used as the x-axis for every series;
    otherwise each series is plotted against its own index.

    Args:
        data: Dictionary mapping variable names to data arrays
        timestamps: Optional array of timestamps for the x-axis
        config: Plot configuration; when None a default titled
            "Time Series Plot" is used

    Returns:
        Tuple of (figure, axes)

    Raises:
        AssertionError: If an argument has the wrong type, or if timestamps
            is given and any data array differs from it in length
    """
    assert isinstance(data, dict), "data must be a dictionary"
    assert all(isinstance(k, str) for k in data.keys()), "All keys in data must be strings"
    assert all(isinstance(v, np.ndarray) for v in data.values()), "All values in data must be numpy arrays"

    if timestamps is not None:
        assert isinstance(timestamps, np.ndarray), "timestamps must be a numpy ndarray"

        # Every series shares the timestamp axis, so every series is checked
        # (not just the first one). An empty dictionary passes trivially
        # instead of failing on a lookup of its first key.
        expected_length = len(timestamps)
        assert all(len(values) == expected_length for values in data.values()), \
            "timestamps array must have the same length as data arrays"

    # Use default config if not provided
    if config is None:
        config = PlotConfig(
            title="Time Series Plot",
            xlabel="Time" if timestamps is None else None,
            ylabel="Value"
        )

    # Create figure and axes
    fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi)

    # Plot each time series
    for name, values in data.items():
        if timestamps is not None:
            ax.plot(timestamps, values, label=name)
        else:
            ax.plot(values, label=name)

    # Apply configuration
    _apply_plot_config(fig, ax, config)

    return fig, ax
def plot_histogram(
    data: Union[np.ndarray, Dict[str, np.ndarray]],
    bins: int = 30,
    density: bool = False,
    config: Optional[PlotConfig] = None
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Draw one semi-transparent histogram per data series on shared axes.

    NaN entries are removed from each series before plotting; a series that
    has no remaining values is skipped.

    Args:
        data: Array of data or dictionary mapping variable names to data
            arrays (a bare array is plotted under the name "Data")
        bins: Number of histogram bins
        density: Whether to normalize the histogram
        config: Plot configuration

    Returns:
        Tuple of (figure, axes)
    """
    assert isinstance(bins, int), "bins must be an integer"
    assert bins > 0, "bins must be positive"
    assert isinstance(density, bool), "density must be a boolean"

    # A bare array becomes a one-entry mapping
    series = {"Data": data} if isinstance(data, np.ndarray) else data

    assert isinstance(series, dict), "data must be a dictionary or numpy ndarray"
    assert all(isinstance(key, str) for key in series.keys()), "All keys in data must be strings"
    assert all(isinstance(arr, np.ndarray) for arr in series.values()), "All values in data must be numpy arrays"

    # Fall back to a default configuration
    if config is None:
        y_text = "Density" if density else "Frequency"
        config = PlotConfig(title="Histogram", xlabel="Value", ylabel=y_text)

    fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi)

    for name, values in series.items():
        # Drop NaN entries; nothing is drawn for an all-NaN or empty series
        kept = values[np.logical_not(np.isnan(values))]
        if len(kept) == 0:
            continue
        ax.hist(kept, bins=bins, density=density, label=name, alpha=0.7)

    _apply_plot_config(fig, ax, config)

    return fig, ax
def plot_boxplot(
    data: Union[np.ndarray, Dict[str, np.ndarray]],
    vert: bool = True,
    showfliers: bool = True,
    config: Optional[PlotConfig] = None
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Draw one box per data series.

    NaN entries are removed from each series first; a series with no
    remaining values gets no box. If no series has values, no boxes are drawn.

    Args:
        data: Array of data or dictionary mapping variable names to data
            arrays (a bare array is plotted under the name "Data")
        vert: Whether to draw the boxes vertically
        showfliers: Whether to show outliers
        config: Plot configuration

    Returns:
        Tuple of (figure, axes)
    """
    assert isinstance(vert, bool), "vert must be a boolean"
    assert isinstance(showfliers, bool), "showfliers must be a boolean"

    # A bare array becomes a one-entry mapping
    series = {"Data": data} if isinstance(data, np.ndarray) else data

    assert isinstance(series, dict), "data must be a dictionary or numpy ndarray"
    assert all(isinstance(key, str) for key in series.keys()), "All keys in data must be strings"
    assert all(isinstance(arr, np.ndarray) for arr in series.values()), "All values in data must be numpy arrays"

    # Fall back to a default configuration; the value axis follows orientation
    if config is None:
        if vert:
            x_text, y_text = "", "Value"
        else:
            x_text, y_text = "Value", ""
        config = PlotConfig(title="Box Plot", xlabel=x_text, ylabel=y_text)

    fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi)

    # Strip NaNs and keep only the series that still hold values
    cleaned = [(name, values[~np.isnan(values)]) for name, values in series.items()]
    kept = [(name, arr) for name, arr in cleaned if len(arr) > 0]

    if kept:
        ax.boxplot(
            [arr for _, arr in kept],
            labels=[name for name, _ in kept],
            vert=vert,
            showfliers=showfliers,
            patch_artist=True
        )

    _apply_plot_config(fig, ax, config)

    return fig, ax
def plot_heatmap(
    data: np.ndarray,
    row_labels: Optional[List[str]] = None,
    col_labels: Optional[List[str]] = None,
    cmap: str = "viridis",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    config: Optional[PlotConfig] = None
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Render a 2D array as a colour-coded image with a colorbar.

    Args:
        data: 2D array of data
        row_labels: Labels for the rows (one per row of data)
        col_labels: Labels for the columns (one per column of data)
        cmap: Colormap name
        vmin: Minimum value for color scaling (must be a float if given)
        vmax: Maximum value for color scaling (must be a float if given)
        config: Plot configuration

    Returns:
        Tuple of (figure, axes)
    """
    assert isinstance(data, np.ndarray), "data must be a numpy ndarray"
    assert data.ndim == 2, "data must be a 2D array"
    assert isinstance(cmap, str), "cmap must be a string"

    # Row labels are checked against axis 0, column labels against axis 1
    label_specs = (
        ("row_labels", "row", row_labels, 0),
        ("col_labels", "column", col_labels, 1),
    )
    for arg_name, noun, names, axis in label_specs:
        if names is None:
            continue
        assert isinstance(names, list), f"{arg_name} must be a list"
        assert all(isinstance(item, str) for item in names), f"All {noun} labels must be strings"
        assert len(names) == data.shape[axis], f"Number of {noun} labels must match data shape"

    for arg_name, bound in (("vmin", vmin), ("vmax", vmax)):
        if bound is not None:
            assert isinstance(bound, float), f"{arg_name} must be a floating-point number"

    if not (vmin is None or vmax is None):
        assert vmin <= vmax, "vmin must be less than or equal to vmax"

    # Fall back to a default configuration
    if config is None:
        config = PlotConfig(title="Heatmap", figsize=(8.0, 6.0))

    fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi)

    # Draw the image and attach its colorbar
    image = ax.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax, aspect="auto")
    fig.colorbar(image, ax=ax)

    # One tick per row / column when labels are supplied
    if row_labels is not None:
        ax.set_yticks(np.arange(len(row_labels)))
        ax.set_yticklabels(row_labels)

    if col_labels is not None:
        ax.set_xticks(np.arange(len(col_labels)))
        ax.set_xticklabels(col_labels, rotation=45, ha="right")

    _apply_plot_config(fig, ax, config)

    return fig, ax
def plot_scatter(
    x: np.ndarray,
    y: np.ndarray,
    labels: Optional[np.ndarray] = None,
    sizes: Optional[np.ndarray] = None,
    alpha: float = 0.7,
    config: Optional[PlotConfig] = None
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Draw a scatter plot, optionally coloured by category.

    When ``labels`` is given, each distinct label value is drawn as its own
    group, coloured from the colormap named by ``config.palette`` and entered
    in the legend under ``str(label)``.

    Args:
        x: X-coordinates
        y: Y-coordinates
        labels: Labels or categories for the points
        sizes: Sizes for the points
        alpha: Transparency for the points (float between 0 and 1)
        config: Plot configuration

    Returns:
        Tuple of (figure, axes)
    """
    assert isinstance(x, np.ndarray), "x must be a numpy ndarray"
    assert isinstance(y, np.ndarray), "y must be a numpy ndarray"
    assert len(x) == len(y), "x and y must have the same length"
    assert isinstance(alpha, float), "alpha must be a floating-point number"
    assert 0.0 <= alpha <= 1.0, "alpha must be between 0 and 1"

    n_points = len(x)

    if labels is not None:
        assert isinstance(labels, np.ndarray), "labels must be a numpy ndarray"
        assert len(labels) == n_points, "labels must have the same length as x and y"

    if sizes is not None:
        assert isinstance(sizes, np.ndarray), "sizes must be a numpy ndarray"
        assert len(sizes) == n_points, "sizes must have the same length as x and y"

    # Fall back to a default configuration
    if config is None:
        config = PlotConfig(title="Scatter Plot", xlabel="X", ylabel="Y")

    fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi)

    # s=None is scatter's own default, so it stands in for "no sizes given"
    if labels is None:
        # Single group: all points share the default colour
        ax.scatter(x, y, s=sizes, alpha=alpha)
    else:
        groups = np.unique(labels)
        n_groups = len(groups)
        colormap = plt.get_cmap(config.palette)

        # Colours are sampled at i / n_groups across the colormap
        for position, group in enumerate(groups):
            members = labels == group
            ax.scatter(
                x[members],
                y[members],
                color=colormap(position / n_groups),
                s=None if sizes is None else sizes[members],
                alpha=alpha,
                label=str(group)
            )

    _apply_plot_config(fig, ax, config)

    return fig, ax
def plot_correlation_matrix(
    data: Union[np.ndarray, Dict[str, np.ndarray], pd.DataFrame],
    method: str = "pearson",
    cmap: str = "coolwarm",
    config: Optional[PlotConfig] = None
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Compute pairwise correlations between columns and draw them as an
    annotated heatmap on a fixed -1..1 colour scale.

    Args:
        data: 2D array of data, dictionary mapping variable names to data
            arrays (all of equal length), or pandas DataFrame
        method: Correlation method ('pearson', 'kendall', 'spearman')
        cmap: Colormap name
        config: Plot configuration

    Returns:
        Tuple of (figure, axes)

    Raises:
        ValueError: If data is none of the supported container types
    """
    assert isinstance(method, str), "method must be a string"
    assert method in ["pearson", "kendall", "spearman"], \
        "method must be one of ['pearson', 'kendall', 'spearman']"
    assert isinstance(cmap, str), "cmap must be a string"

    # Normalise every accepted input type to a DataFrame
    if isinstance(data, pd.DataFrame):
        frame = data
    elif isinstance(data, np.ndarray):
        assert data.ndim == 2, "data must be a 2D array or dictionary of arrays"
        # Columns get pandas' default names
        frame = pd.DataFrame(data)
    elif isinstance(data, dict):
        assert all(isinstance(key, str) for key in data.keys()), "All keys in data must be strings"
        assert all(isinstance(arr, np.ndarray) for arr in data.values()), "All values in data must be numpy arrays"

        # Every column must be as long as the first one
        reference_length = len(next(iter(data.values())))
        assert all(len(arr) == reference_length for arr in data.values()), \
            "All arrays in data must have the same length"

        frame = pd.DataFrame(data)
    else:
        raise ValueError("data must be a numpy ndarray, dictionary of arrays, or pandas DataFrame")

    # Fall back to a default configuration
    if config is None:
        config = PlotConfig(
            title=f"Correlation Matrix ({method.capitalize()})",
            figsize=(8.0, 6.0)
        )

    fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi)

    correlations = frame.corr(method=method)

    sns.heatmap(
        correlations,
        ax=ax,
        cmap=cmap,
        vmin=-1.0,
        vmax=1.0,
        center=0,
        annot=True,
        fmt=".2f",
        square=True,
        linewidths=0.5
    )

    _apply_plot_config(fig, ax, config)

    return fig, ax
def plot_validation_results(
    results: Dict[str, Dict[str, Any]],
    config: Optional[PlotConfig] = None
) -> Tuple[plt.Figure, List[plt.Axes]]:
    """
    Plot validation results as a 2x2 dashboard.

    The panels are: a valid/invalid bar chart, a valid/invalid pie chart, a
    horizontal bar chart of error counts per variable (top 10 only), and a
    text summary. An empty ``results`` dictionary produces an empty dashboard
    with zero counts instead of failing.

    Args:
        results: Dictionary mapping variable names to validation results;
            each result must provide the keys "is_valid" and "error_messages"
        config: Plot configuration; only figsize, dpi, title and font_size
            are used here

    Returns:
        Tuple of (figure, axes), where axes is the flattened array of the
        four subplot axes
    """
    assert isinstance(results, dict), "results must be a dictionary"

    # Use default config if not provided
    if config is None:
        config = PlotConfig(
            title="Validation Results",
            figsize=(12.0, 8.0)
        )

    # Create figure and axes for multiple subplots
    fig, axes = plt.subplots(2, 2, figsize=config.figsize, dpi=config.dpi)
    axes = axes.flatten()

    # Get variables and their validation status
    variable_names = list(results.keys())
    valid_status = [results[var]["is_valid"] for var in variable_names]

    total_vars = len(variable_names)
    valid_count = sum(valid_status)
    invalid_count = total_vars - valid_count

    # 1. Bar chart of valid vs. invalid variables
    ax1 = axes[0]
    ax1.bar(
        ["Valid", "Invalid"],
        [valid_count, invalid_count],
        color=["green", "red"]
    )
    ax1.set_title("Validation Status")
    ax1.set_ylabel("Count")

    # 2. Pie chart of valid vs. invalid variables
    ax2 = axes[1]
    if total_vars > 0:
        ax2.pie(
            [valid_count, invalid_count],
            labels=["Valid", "Invalid"],
            colors=["green", "red"],
            autopct="%1.1f%%",
            startangle=90
        )
    else:
        # Wedge fractions are undefined when both counts are zero, so show a
        # placeholder instead of drawing a pie
        ax2.axis("off")
        ax2.text(
            0.5, 0.5, "No variables",
            ha="center", va="center",
            fontsize=config.font_size
        )
    ax2.set_title("Validation Status")

    # 3. Bar chart of error counts by variable
    ax3 = axes[2]
    error_counts = [len(results[var]["error_messages"]) for var in variable_names]

    # Sort by error count, largest first
    sorted_indices = np.argsort(error_counts)[::-1]

    # Limit to top 10 if there are many variables
    top_indices = sorted_indices[:10]
    sorted_names = [variable_names[i] for i in top_indices]
    sorted_counts = [error_counts[i] for i in top_indices]

    ax3.barh(
        sorted_names,
        sorted_counts,
        color="orange"
    )
    ax3.set_title("Error Counts by Variable")
    ax3.set_xlabel("Number of Errors")

    # 4. Statistics summary
    ax4 = axes[3]
    ax4.axis("off")

    # Calculate some overall statistics; every ratio is guarded against an
    # empty results dictionary
    total_errors = sum(error_counts)
    if total_vars > 0:
        avg_errors = total_errors / total_vars
        valid_pct = valid_count / total_vars * 100
        invalid_pct = invalid_count / total_vars * 100
    else:
        avg_errors = 0
        valid_pct = 0.0
        invalid_pct = 0.0

    stats_text = (
        f"Total Variables: {total_vars}\n"
        f"Valid Variables: {valid_count} ({valid_pct:.1f}%)\n"
        f"Invalid Variables: {invalid_count} ({invalid_pct:.1f}%)\n"
        f"Total Errors: {total_errors}\n"
        f"Avg. Errors per Variable: {avg_errors:.2f}"
    )

    ax4.text(
        0.5, 0.5, stats_text,
        ha="center", va="center",
        fontsize=config.font_size
    )
    ax4.set_title("Summary Statistics")

    # Set overall title
    fig.suptitle(config.title, fontsize=config.font_size + 4)

    # Adjust layout, leaving room at the top for the overall title
    fig.tight_layout()
    fig.subplots_adjust(top=0.9)

    return fig, axes
def plot_refinement_comparison(
    original_data: np.ndarray,
    refined_data: np.ndarray,
    timestamps: Optional[np.ndarray] = None,
    config: Optional[PlotConfig] = None
) -> Tuple[plt.Figure, List[plt.Axes]]:
    """
    Plot a comparison of original and refined data as a 2x2 dashboard.

    The panels are: both series over time, overlaid histograms, a scatter
    of original vs. refined values with a reference diagonal, and the
    residuals (refined - original). NaN entries are excluded from the
    histogram, scatter and residual panels. Residuals are plotted at the
    position of the sample they belong to (its timestamp, or its index in
    the input when no timestamps are given), so dropped NaN samples do not
    shift the remaining points.

    Args:
        original_data: Original data array
        refined_data: Refined data array (same shape as original_data)
        timestamps: Optional array of timestamps for the x-axis
        config: Plot configuration; only figsize, dpi, title and font_size
            are used here

    Returns:
        Tuple of (figure, axes), where axes is the flattened array of the
        four subplot axes
    """
    assert isinstance(original_data, np.ndarray), "original_data must be a numpy ndarray"
    assert isinstance(refined_data, np.ndarray), "refined_data must be a numpy ndarray"
    assert original_data.shape == refined_data.shape, "original_data and refined_data must have the same shape"

    if timestamps is not None:
        assert isinstance(timestamps, np.ndarray), "timestamps must be a numpy ndarray"
        assert len(timestamps) == len(original_data), "timestamps must have the same length as data arrays"

    # Use default config if not provided
    if config is None:
        config = PlotConfig(
            title="Data Refinement Comparison",
            figsize=(12.0, 8.0)
        )

    # Create figure and axes for multiple subplots
    fig, axes = plt.subplots(2, 2, figsize=config.figsize, dpi=config.dpi)
    axes = axes.flatten()

    # 1. Time series plot of original and refined data
    ax1 = axes[0]
    if timestamps is not None:
        ax1.plot(timestamps, original_data, label="Original", alpha=0.7)
        ax1.plot(timestamps, refined_data, label="Refined", alpha=0.7)
    else:
        ax1.plot(original_data, label="Original", alpha=0.7)
        ax1.plot(refined_data, label="Refined", alpha=0.7)

    ax1.set_title("Time Series Comparison")
    ax1.set_xlabel("Time" if timestamps is None else "")
    ax1.set_ylabel("Value")
    ax1.legend()
    ax1.grid(True)

    # 2. Histogram of original and refined data
    ax2 = axes[1]

    # Filter out NaN values (each array independently)
    original_valid = original_data[~np.isnan(original_data)]
    refined_valid = refined_data[~np.isnan(refined_data)]

    ax2.hist(
        original_valid,
        bins=30,
        alpha=0.5,
        label="Original"
    )
    ax2.hist(
        refined_valid,
        bins=30,
        alpha=0.5,
        label="Refined"
    )

    ax2.set_title("Distribution Comparison")
    ax2.set_xlabel("Value")
    ax2.set_ylabel("Frequency")
    ax2.legend()
    ax2.grid(True)

    # 3. Scatter plot of original vs. refined data
    ax3 = axes[2]

    # Keep only positions where both arrays hold a non-NaN value
    mask = ~np.isnan(original_data) & ~np.isnan(refined_data)
    has_valid_pairs = bool(np.any(mask))

    if has_valid_pairs:
        paired_original = original_data[mask]
        paired_refined = refined_data[mask]

        ax3.scatter(
            paired_original,
            paired_refined,
            alpha=0.5,
            edgecolor="k",
            linewidth=0.5
        )

        # Add diagonal line for reference (points on it were left unchanged)
        min_val = min(paired_original.min(), paired_refined.min())
        max_val = max(paired_original.max(), paired_refined.max())
        ax3.plot([min_val, max_val], [min_val, max_val], "k--", alpha=0.7)

    ax3.set_title("Original vs. Refined Values")
    ax3.set_xlabel("Original Value")
    ax3.set_ylabel("Refined Value")
    ax3.grid(True)

    # 4. Residuals plot
    ax4 = axes[3]

    if has_valid_pairs:
        residuals = refined_data[mask] - original_data[mask]

        # Place each residual at the position of its own sample; counting
        # 0..len(residuals)-1 would shift points after dropped NaN samples
        if timestamps is not None:
            positions = timestamps[mask]
        else:
            positions = np.flatnonzero(mask)

        ax4.plot(positions, residuals, "o-", alpha=0.5, markersize=3)

        # Add horizontal line at zero
        ax4.axhline(y=0, color="k", linestyle="--", alpha=0.7)

    ax4.set_title("Refinement Residuals")
    ax4.set_xlabel("Time" if timestamps is not None else "Index")
    ax4.set_ylabel("Refined - Original")
    ax4.grid(True)

    # Set overall title
    fig.suptitle(config.title, fontsize=config.font_size + 4)

    # Adjust layout, leaving room at the top for the overall title
    fig.tight_layout()
    fig.subplots_adjust(top=0.9)

    return fig, axes
def save_plot(
    fig: plt.Figure,
    filepath: str,
    dpi: Optional[int] = None,
    format: Optional[str] = None,
    transparent: bool = False
) -> None:
    """
    Save a figure to a file, creating the target directory if needed.

    Args:
        fig: Matplotlib figure
        filepath: Path to the output file
        dpi: Resolution in dots per inch (passed through as None if omitted)
        format: File format (auto-detected from extension if None)
        transparent: Whether to use a transparent background

    Raises:
        AssertionError: If an argument has the wrong type or dpi is not positive
    """
    assert isinstance(filepath, str), "filepath must be a string"

    if dpi is not None:
        assert isinstance(dpi, int), "dpi must be an integer"
        assert dpi > 0, "dpi must be positive"

    if format is not None:
        assert isinstance(format, str), "format must be a string"

    assert isinstance(transparent, bool), "transparent must be a boolean"

    # Create directory if it doesn't exist. exist_ok avoids failing when the
    # directory appears between a separate existence check and the creation.
    directory = os.path.dirname(filepath)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Save the figure
    fig.savefig(
        filepath,
        dpi=dpi,
        format=format,
        bbox_inches="tight",
        transparent=transparent
    )