# Source code for shesha.core

"""
Shesha: Self-consistency Metrics for Representational Stability

Core implementations of Shesha variants for measuring geometric stability
of high-dimensional representations.
"""

import numpy as np
from scipy.stats import spearmanr, pearsonr
from scipy.spatial.distance import pdist, cdist
from typing import List, Optional, Union
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from ._utils import bootstrap_ci, bootstrap_ci_two_sample

# Public names exported by ``from shesha.core import *``.
__all__ = [
    # Unsupervised variants
    "feature_split",
    "sample_split", 
    "anchor_stability",
    # Supervised variants
    "variance_ratio",
    "supervised_alignment",
    "class_separation_ratio",
    "lda_stability",
    # Drift metrics
    "rdm_similarity",
    "rdm_drift",
    # Utilities
    "compute_rdm",
]

# Small positive constant shared by the metrics in this module: it floors
# row norms before division, is added to denominators, and serves as the
# threshold below which an RDM's standard deviation (or a mean within-class
# distance) is treated as zero.
EPS = 1e-12


# =============================================================================
# RDM Utilities
# =============================================================================

def compute_rdm(
    X: np.ndarray,
    metric: Literal["cosine", "correlation", "euclidean"] = "cosine",
    normalize: bool = True,
) -> np.ndarray:
    """
    Build the condensed Representational Dissimilarity Matrix (RDM) of ``X``.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    metric : str
        Distance metric handed to ``scipy.spatial.distance.pdist``:
        'cosine', 'correlation', or 'euclidean'.
    normalize : bool
        When True and ``metric == 'cosine'``, each row is divided by its
        L2 norm (floored at ``EPS``) before distances are computed. The
        flag has no effect for the other metrics.

    Returns
    -------
    np.ndarray
        Condensed distance vector (upper triangle of the RDM).
    """
    data = np.asarray(X, dtype=np.float64)
    wants_unit_rows = normalize and metric == "cosine"
    if wants_unit_rows:
        row_norms = np.linalg.norm(data, axis=1, keepdims=True)
        # Floor the norms so an all-zero row does not divide by zero.
        data = data / np.maximum(row_norms, EPS)
    return pdist(data, metric=metric)
# =============================================================================
# Unsupervised Variants
# =============================================================================
def feature_split(
    X: Union[np.ndarray, List[np.ndarray]],
    n_splits: int = 30,
    metric: Literal["cosine", "correlation"] = "cosine",
    seed: Optional[int] = None,
    max_samples: Optional[int] = 1600,
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
    return_all_splits: bool = False,
) -> Union[float, dict, List]:
    """
    Feature-Split Shesha: measures internal geometric consistency.

    Partitions feature dimensions into random disjoint halves, computes RDMs
    on each half, and measures their rank correlation. High values indicate
    that geometric structure is distributed across features (redundant
    encoding).

    Parameters
    ----------
    X : np.ndarray or list of np.ndarray
        Data matrix of shape (n_samples, n_features), or a list of such
        matrices for batch evaluation. When a list is passed, returns a list
        of results in the same order (every matrix is scored with the same
        arguments, including the same ``seed``).
    n_splits : int
        Number of random feature partitions to average over.
    metric : str
        Distance metric for RDM computation.
    seed : int, optional
        Random seed for reproducibility.
    max_samples : int, optional
        Subsample to this many samples if exceeded (for efficiency).
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling the
        input data this many times (e.g. 1000 or 10000).
    ci : float, default=0.95
        Confidence level for the interval (only used when n_bootstrap_ci is
        set).
    return_all_splits : bool, default=False
        If True, return a dict with the mean score and per-split correlation
        scores instead of only the mean score.

    Returns
    -------
    float or dict or list
        If X is a single array and n_bootstrap_ci is None: mean Spearman
        correlation in [-1, 1], or NaN when no split produced a usable
        correlation (including inputs with fewer than 4 samples or features).
        If X is a single array and return_all_splits is True: dict with keys
        'mean' and 'split_scores'; for degenerate inputs 'mean' is NaN and
        'split_scores' is empty.
        If X is a single array and n_bootstrap_ci is set: dict with keys
        'mean', 'ci_low', 'ci_high', 'std', 'n_bootstraps', 'ci_level'.
        If X is a list: list of the above, one entry per input matrix.

    Raises
    ------
    ValueError
        If ``return_all_splits`` is combined with ``n_bootstrap_ci``.

    Examples
    --------
    >>> X = np.random.randn(500, 768)  # 500 samples, 768-dim embeddings
    >>> stability = feature_split(X, n_splits=30, seed=320)
    >>> print(f"Feature-split stability: {stability:.3f}")

    >>> # Batch evaluation across multiple representations
    >>> matrices = [np.random.randn(500, 768) for _ in range(5)]
    >>> scores = feature_split(matrices, n_splits=30, seed=320)
    >>> print(scores)  # list of 5 floats

    >>> # With bootstrap confidence interval
    >>> result = feature_split(X, n_splits=30, seed=320, n_bootstrap_ci=1000)
    >>> print(f"{result['mean']:.3f} [{result['ci_low']:.3f}, {result['ci_high']:.3f}]")

    >>> # Return per-split scores for distribution plots
    >>> result = feature_split(X, n_splits=30, seed=320, return_all_splits=True)
    >>> scores = result["split_scores"]
    """
    if isinstance(X, list):
        return [
            feature_split(
                x,
                n_splits=n_splits,
                metric=metric,
                seed=seed,
                max_samples=max_samples,
                n_bootstrap_ci=n_bootstrap_ci,
                ci=ci,
                return_all_splits=return_all_splits,
            )
            for x in X
        ]

    if n_bootstrap_ci is not None and return_all_splits:
        raise ValueError("return_all_splits cannot be used with n_bootstrap_ci.")

    if n_bootstrap_ci is not None:
        return bootstrap_ci(
            feature_split,
            n_bootstrap_ci,
            ci,
            seed,
            np.asarray(X, dtype=np.float64),
            n_splits=n_splits,
            metric=metric,
            seed=seed,
            max_samples=max_samples,
        )

    def _package(scores):
        # Single exit shape for both the degenerate and the normal path, so
        # return_all_splits=True always yields the documented dict.
        mean_score = float(np.mean(scores)) if scores else np.nan
        if return_all_splits:
            return {
                "mean": mean_score,
                "split_scores": [float(score) for score in scores],
            }
        return mean_score

    X = np.asarray(X, dtype=np.float64)
    n_samples, n_features = X.shape

    # Too few features to form two halves, or too few samples for an RDM.
    if n_features < 4 or n_samples < 4:
        return _package([])

    rng = np.random.default_rng(seed)

    # Subsample if needed
    if max_samples is not None and n_samples > max_samples:
        idx = rng.choice(n_samples, max_samples, replace=False)
        X = X[idx]
        n_samples = max_samples

    # L2 normalize for cosine metric
    if metric == "cosine":
        norms = np.linalg.norm(X, axis=1, keepdims=True)
        X = X / np.maximum(norms, EPS)

    # Size of each half; with an odd feature count one feature is left out.
    mid = n_features // 2

    correlations = []
    for _ in range(n_splits):
        # Random partition of features into two disjoint halves
        perm = rng.permutation(n_features)
        feat1, feat2 = perm[:mid], perm[mid:2 * mid]

        rdm1 = pdist(X[:, feat1], metric=metric)
        rdm2 = pdist(X[:, feat2], metric=metric)

        # Handle NaN distances (can occur with zero vectors)
        rdm1 = np.nan_to_num(rdm1, nan=1.0)
        rdm2 = np.nan_to_num(rdm2, nan=1.0)

        # A constant RDM has no rank structure to correlate; skip the split.
        if np.std(rdm1) < EPS or np.std(rdm2) < EPS:
            continue

        rho, _ = spearmanr(rdm1, rdm2)
        if np.isfinite(rho):
            correlations.append(rho)

    return _package(correlations)
def sample_split(
    X: np.ndarray,
    n_splits: int = 30,
    subsample_fraction: float = 0.4,
    metric: Literal["cosine", "correlation"] = "cosine",
    seed: Optional[int] = None,
    max_samples: Optional[int] = 1500,
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
    return_all_splits: bool = False,
) -> Union[float, dict]:
    """
    Sample-Split Shesha (Bootstrap RDM): measures robustness to input variation.

    Creates random subsamples of data points, computes RDMs on each, and
    measures their correlation. Assesses whether distance structure
    generalizes across different subsets of the data.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    n_splits : int
        Number of bootstrap iterations.
    subsample_fraction : float
        Fraction of samples to use in each subsample.
    metric : str
        Distance metric for RDM computation.
    seed : int, optional
        Random seed for reproducibility.
    max_samples : int, optional
        Subsample to this many samples if exceeded.
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling the
        input data this many times.
    ci : float, default=0.95
        Confidence level for the interval.
    return_all_splits : bool, default=False
        If True, return a dict with the mean score and per-split correlation
        scores instead of only the mean score.

    Returns
    -------
    float or dict
        If n_bootstrap_ci is None: mean Spearman correlation. Range: [-1, 1].
        NaN when there are fewer than 10 samples, when a subsample would hold
        fewer than 5 samples, or when no iteration produced a usable
        correlation.
        If return_all_splits is True: dict with keys 'mean' and
        'split_scores'; in the degenerate cases above 'mean' is NaN and
        'split_scores' is empty.
        If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
        'std', 'n_bootstraps', 'ci_level'.

    Raises
    ------
    ValueError
        If ``return_all_splits`` is combined with ``n_bootstrap_ci``.

    Examples
    --------
    >>> X = np.random.randn(1000, 384)
    >>> stability = sample_split(X, n_splits=50, seed=320)
    >>> result = sample_split(X, n_splits=50, seed=320, return_all_splits=True)
    >>> scores = result["split_scores"]
    """
    if n_bootstrap_ci is not None and return_all_splits:
        raise ValueError("return_all_splits cannot be used with n_bootstrap_ci.")

    if n_bootstrap_ci is not None:
        return bootstrap_ci(
            sample_split,
            n_bootstrap_ci,
            ci,
            seed,
            np.asarray(X, dtype=np.float64),
            n_splits=n_splits,
            subsample_fraction=subsample_fraction,
            metric=metric,
            seed=seed,
            max_samples=max_samples,
        )

    def _package(scores):
        # Single exit shape for both the degenerate and the normal path, so
        # return_all_splits=True always yields the documented dict.
        mean_score = float(np.mean(scores)) if scores else np.nan
        if return_all_splits:
            return {
                "mean": mean_score,
                "split_scores": [float(score) for score in scores],
            }
        return mean_score

    X = np.asarray(X, dtype=np.float64)
    n_samples = X.shape[0]

    if n_samples < 10:
        return _package([])

    rng = np.random.default_rng(seed)

    # Subsample if needed
    if max_samples is not None and n_samples > max_samples:
        idx = rng.choice(n_samples, max_samples, replace=False)
        X = X[idx]
        n_samples = max_samples

    # Number of points in each subsample.
    m = int(n_samples * subsample_fraction)
    if m < 5:
        return _package([])

    correlations = []
    for _ in range(n_splits):
        # Two independently drawn subsamples (they may overlap).
        idx1 = rng.choice(n_samples, m, replace=False)
        idx2 = rng.choice(n_samples, m, replace=False)

        rdm1 = pdist(X[idx1], metric=metric)
        rdm2 = pdist(X[idx2], metric=metric)

        # A constant RDM has no rank structure to correlate; skip it.
        if np.std(rdm1) < EPS or np.std(rdm2) < EPS:
            continue

        rho, _ = spearmanr(rdm1, rdm2)
        if np.isfinite(rho):
            correlations.append(rho)

    return _package(correlations)
def anchor_stability(
    X: np.ndarray,
    n_splits: int = 30,
    n_anchors: int = 100,
    n_per_split: int = 200,
    metric: Literal["cosine", "euclidean"] = "cosine",
    rank_normalize: bool = True,
    seed: Optional[int] = None,
    max_samples: Optional[int] = 1500,
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
) -> Union[float, dict]:
    """
    Anchor-based Shesha: measures stability of distance profiles from fixed anchors.

    Selects fixed anchor points, then measures consistency of distance
    profiles from anchors to random data splits. More robust to sampling
    variation than pure bootstrap approaches.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    n_splits : int
        Number of random splits.
    n_anchors : int
        Number of fixed anchor points. Shrunk automatically (never below 10)
        when the data cannot supply ``n_anchors + 2 * n_per_split`` samples.
    n_per_split : int
        Number of samples per split. Shrunk automatically (never below 20)
        under the same condition.
    metric : str
        Distance metric.
    rank_normalize : bool
        If True, rank-normalize distances within each anchor before
        correlating.
    seed : int, optional
        Random seed.
    max_samples : int, optional
        Subsample to this many samples if exceeded.
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling the
        input data this many times.
    ci : float, default=0.95
        Confidence level for the interval.

    Returns
    -------
    float or dict
        If n_bootstrap_ci is None: mean correlation of anchor distance
        profiles, or NaN when there are too few samples even after shrinking
        or no split produced a finite correlation.
        If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
        'std', 'n_bootstraps', 'ci_level'.
    """
    if n_bootstrap_ci is not None:
        return bootstrap_ci(
            anchor_stability,
            n_bootstrap_ci,
            ci,
            seed,
            np.asarray(X, dtype=np.float64),
            n_splits=n_splits,
            n_anchors=n_anchors,
            n_per_split=n_per_split,
            metric=metric,
            rank_normalize=rank_normalize,
            seed=seed,
            max_samples=max_samples,
        )

    # Imported once per call rather than once per split.
    from scipy.stats import rankdata

    X = np.asarray(X, dtype=np.float64)
    n_samples = X.shape[0]
    rng = np.random.default_rng(seed)

    # Subsample if needed
    if max_samples is not None and n_samples > max_samples:
        idx = rng.choice(n_samples, max_samples, replace=False)
        X = X[idx]
        n_samples = max_samples

    # Need enough samples for anchors + two splits
    min_required = n_anchors + 2 * n_per_split
    if n_samples < min_required:
        # Reduce sizes proportionally, keeping a 10% margin.
        scale = n_samples / min_required * 0.9
        n_anchors = max(10, int(n_anchors * scale))
        n_per_split = max(20, int(n_per_split * scale))
        if n_samples < n_anchors + 2 * n_per_split:
            return np.nan

    # Select fixed anchors; everything else is available for the splits.
    # From here on n_samples >= n_anchors + 2 * n_per_split, so the remaining
    # pool is always large enough for two disjoint splits.
    anchor_idx = rng.choice(n_samples, n_anchors, replace=False)
    anchors = X[anchor_idx]
    remaining_idx = np.setdiff1d(np.arange(n_samples), anchor_idx)

    correlations = []
    for _ in range(n_splits):
        # Two disjoint splits from remaining samples
        perm = rng.permutation(remaining_idx)
        split1_idx = perm[:n_per_split]
        split2_idx = perm[n_per_split:2 * n_per_split]

        # Distance matrices: anchors x split_samples
        D1 = cdist(anchors, X[split1_idx], metric=metric)
        D2 = cdist(anchors, X[split2_idx], metric=metric)

        if rank_normalize:
            # Rank within each anchor's row of distances.
            D1 = rankdata(D1, axis=1)
            D2 = rankdata(D2, axis=1)

        # Flatten and correlate
        rho, _ = spearmanr(D1.ravel(), D2.ravel())
        if np.isfinite(rho):
            correlations.append(rho)

    return float(np.mean(correlations)) if correlations else np.nan
# =============================================================================
# Supervised Variants
# =============================================================================
def variance_ratio(
    X: np.ndarray,
    y: np.ndarray,
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
    seed: Optional[int] = None,
) -> Union[float, dict]:
    """
    Variance Ratio Shesha: share of total variance that lies between classes.

    Divides the between-class sum of squares (class sizes times squared
    distance of each class mean from the global mean) by the total sum of
    squares around the global mean. Equivalent to the R-squared of predicting
    coordinates from class membership.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    y : np.ndarray
        Class labels of shape (n_samples,).
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling the
        input data this many times.
    ci : float, default=0.95
        Confidence level for the interval.
    seed : int, optional
        Random seed for bootstrap reproducibility.

    Returns
    -------
    float or dict
        If n_bootstrap_ci is None: between-class variance / total variance.
        Range: [0, 1]. NaN when fewer than two distinct labels are present.
        If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
        'std', 'n_bootstraps', 'ci_level'.

    Examples
    --------
    >>> X = np.random.randn(500, 768)
    >>> y = np.random.randint(0, 10, 500)
    >>> vr = variance_ratio(X, y)
    """
    if n_bootstrap_ci is not None:
        return bootstrap_ci(
            variance_ratio,
            n_bootstrap_ci,
            ci,
            seed,
            np.asarray(X, dtype=np.float64),
            np.asarray(y),
        )

    data = np.asarray(X, dtype=np.float64)
    labels = np.asarray(y)

    distinct = np.unique(labels)
    if len(distinct) < 2:
        return np.nan

    grand_mean = np.mean(data, axis=0)
    # EPS keeps the ratio defined when every sample equals the global mean.
    total_ss = np.sum((data - grand_mean) ** 2) + EPS

    between_ss = 0.0
    for label in distinct:
        members = labels == label
        size = np.sum(members)
        # A label that matches no row (e.g. NaN, which never compares equal)
        # contributes nothing.
        if size != 0:
            centroid = np.mean(data[members], axis=0)
            between_ss += size * np.sum((centroid - grand_mean) ** 2)

    return float(between_ss / total_ss)
def supervised_alignment(
    X: np.ndarray,
    y: np.ndarray,
    metric: Literal["cosine", "correlation"] = "correlation",
    seed: Optional[int] = None,
    max_samples: Optional[int] = 300,
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
) -> Union[float, dict]:
    """
    Supervised RDM Alignment: correlation between model RDM and ideal label RDM.

    Measures how well the representation's distance structure aligns with
    task-defined similarity (same class = similar, different class =
    dissimilar).

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    y : np.ndarray
        Class labels of shape (n_samples,). Numeric, boolean, and
        non-numeric (e.g. string) labels are accepted.
    metric : str
        Distance metric for model RDM.
    seed : int, optional
        Random seed for subsampling.
    max_samples : int, optional
        Subsample to this many samples (RDM computation is O(n^2)).
        ``None`` disables subsampling.
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling the
        input data this many times.
    ci : float, default=0.95
        Confidence level for the interval.

    Returns
    -------
    float or dict
        If n_bootstrap_ci is None: Spearman correlation. Range: [-1, 1].
        NaN when the correlation is not finite.
        If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
        'std', 'n_bootstraps', 'ci_level'.
    """
    if n_bootstrap_ci is not None:
        return bootstrap_ci(
            supervised_alignment,
            n_bootstrap_ci,
            ci,
            seed,
            np.asarray(X, dtype=np.float64),
            np.asarray(y),
            metric=metric,
            seed=seed,
            max_samples=max_samples,
        )

    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y)
    rng = np.random.default_rng(seed)

    if max_samples is not None and len(X) > max_samples:
        idx = rng.choice(len(X), max_samples, replace=False)
        X, y = X[idx], y[idx]

    # Center each feature (applied for every metric).
    X = X - np.mean(X, axis=0)

    # Model RDM
    model_rdm = pdist(X, metric=metric)

    # pdist converts its input to float, so non-numeric labels (strings,
    # objects) are first mapped to integer codes. Numeric labels are passed
    # through untouched to keep their existing behaviour.
    labels = y.reshape(-1)
    if labels.dtype.kind not in "biuf":
        _, labels = np.unique(labels, return_inverse=True)

    # Ideal RDM from labels: Hamming distance on a single column is 0 for
    # equal labels and 1 otherwise.
    ideal_rdm = pdist(labels.reshape(-1, 1), metric="hamming")

    rho, _ = spearmanr(model_rdm, ideal_rdm)
    return float(rho) if np.isfinite(rho) else np.nan
def class_separation_ratio(
    X: np.ndarray,
    y: np.ndarray,
    n_bootstrap: int = 50,
    subsample_frac: float = 0.5,
    metric: Literal["cosine", "euclidean"] = "euclidean",
    seed: Optional[int] = None,
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
) -> Union[float, dict]:
    """
    Class Separation Ratio: ratio of between-class to within-class distances.

    Measures how well-separated classes are in the representation space.
    Uses bootstrap subsampling for computational efficiency and stability.
    Related to Fisher's discriminant ratio but operates in distance space.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    y : np.ndarray
        Class labels of shape (n_samples,).
    n_bootstrap : int
        Number of bootstrap iterations for stability.
    subsample_frac : float
        Fraction of samples to use per bootstrap (0.0-1.0).
    metric : str
        Distance metric: 'cosine' or 'euclidean'. Any value other than
        'euclidean' is treated as 'cosine'.
    seed : int, optional
        Random seed for reproducibility.
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling the
        input data this many times.
    ci : float, default=0.95
        Confidence level for the interval.

    Returns
    -------
    float or dict
        If n_bootstrap_ci is None: mean separation ratio. Range: [0, inf).
        NaN when fewer than two classes are present or no subsample yielded
        both within-class and between-class pairs.
        If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
        'std', 'n_bootstraps', 'ci_level'.

    Examples
    --------
    >>> # Well-separated classes
    >>> X = np.vstack([np.random.randn(100, 10),
    ...                np.random.randn(100, 10) + 5])
    >>> y = np.array([0]*100 + [1]*100)
    >>> ratio = class_separation_ratio(X, y)
    >>> print(f"Separation: {ratio:.2f}")  # High value

    Notes
    -----
    Higher values indicate representations where same-class samples are
    closer together than different-class samples, suggesting good
    discriminability.
    """
    if n_bootstrap_ci is not None:
        return bootstrap_ci(
            class_separation_ratio,
            n_bootstrap_ci,
            ci,
            seed,
            np.asarray(X, dtype=np.float64),
            np.asarray(y),
            n_bootstrap=n_bootstrap,
            subsample_frac=subsample_frac,
            metric=metric,
            seed=seed,
        )

    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y)

    if len(np.unique(y)) < 2:
        return np.nan

    rng = np.random.default_rng(seed)

    if metric == "cosine":
        # Normalize for cosine distance
        norms = np.linalg.norm(X, axis=1, keepdims=True)
        X = X / np.maximum(norms, EPS)

    dist_metric = "euclidean" if metric == "euclidean" else "cosine"

    n_samples = int(len(X) * subsample_frac)
    # Row/column indices of every unordered pair (i < j) in a subsample. The
    # ordering matches pdist's condensed output, and the subsample size is
    # fixed, so this is computed once.
    pair_i, pair_j = np.triu_indices(max(n_samples, 0), k=1)

    ratios = []
    for _ in range(n_bootstrap):
        # Subsample
        idx = rng.choice(len(X), n_samples, replace=False)
        X_sub, y_sub = X[idx], y[idx]

        # A pair is within-class when its two labels compare equal.
        same_class = y_sub[pair_i] == y_sub[pair_j]

        # Need at least one within-class and one between-class pair.
        if same_class.size == 0 or same_class.all() or not same_class.any():
            continue

        dists = pdist(X_sub, metric=dist_metric)

        mean_within = np.mean(dists[same_class])
        mean_between = np.mean(dists[~same_class])

        # Skip subsamples whose within-class distances are effectively zero
        # (or NaN) rather than dividing by them.
        if mean_within > EPS:
            ratios.append(mean_between / mean_within)

    return float(np.mean(ratios)) if ratios else np.nan
def _lda_direction(X: np.ndarray, y: np.ndarray, classes: np.ndarray) -> np.ndarray:
    """Return the unit-norm two-class LDA direction S_w^{-1} (mu_1 - mu_0)."""
    X_0 = X[y == classes[0]]
    X_1 = X[y == classes[1]]
    mean_0 = np.mean(X_0, axis=0)
    mean_1 = np.mean(X_1, axis=0)

    # Pooled within-class scatter, scaled by the total sample count.
    X_0_centered = X_0 - mean_0
    X_1_centered = X_1 - mean_1
    S_w = (X_0_centered.T @ X_0_centered + X_1_centered.T @ X_1_centered) / len(X)

    # Small ridge term for numerical stability.
    S_w += np.eye(X.shape[1]) * 1e-6

    w = np.linalg.solve(S_w, mean_1 - mean_0)
    return w / (np.linalg.norm(w) + EPS)


def lda_stability(
    X: np.ndarray,
    y: np.ndarray,
    n_bootstrap: int = 50,
    subsample_frac: float = 0.5,
    seed: Optional[int] = None,
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
) -> Union[float, dict]:
    """
    LDA Subspace Stability: consistency of linear discriminant direction.

    Measures whether the optimal linear decision boundary is robust to
    sampling variation. Computes LDA on full dataset and bootstrapped
    subsamples, then measures alignment of discriminant vectors.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    y : np.ndarray
        Binary class labels of shape (n_samples,). Must have exactly 2
        classes.
    n_bootstrap : int
        Number of bootstrap iterations.
    subsample_frac : float
        Fraction of samples to use per bootstrap (0.0-1.0). Each bootstrap
        draws ``int(n_samples * subsample_frac) // 2`` points per class, with
        replacement.
    seed : int, optional
        Random seed for reproducibility.
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling the
        input data this many times.
    ci : float, default=0.95
        Confidence level for the interval.

    Returns
    -------
    float or dict
        If n_bootstrap_ci is None: mean absolute cosine similarity.
        Range: [0, 1]. NaN when the discriminant cannot be solved for the
        full data, when ``subsample_frac`` is too small to draw any sample,
        or when no bootstrap iteration succeeded.
        If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
        'std', 'n_bootstraps', 'ci_level'.

    Raises
    ------
    ValueError
        If ``y`` does not contain exactly two distinct classes.

    Examples
    --------
    >>> # Create well-separated binary classification data
    >>> X = np.vstack([np.random.randn(100, 10),
    ...                np.random.randn(100, 10) + 3])
    >>> y = np.array([0]*100 + [1]*100)
    >>> stability = lda_stability(X, y)
    >>> print(f"LDA Stability: {stability:.3f}")  # Should be high

    Notes
    -----
    Low values suggest the discriminant subspace is unstable, potentially
    indicating overfitting to source domain structure. This metric is
    particularly useful for predicting transfer learning performance.

    Only works for binary classification. For multi-class, consider using
    class_separation_ratio instead.
    """
    if n_bootstrap_ci is not None:
        return bootstrap_ci(
            lda_stability,
            n_bootstrap_ci,
            ci,
            seed,
            np.asarray(X, dtype=np.float64),
            np.asarray(y),
            n_bootstrap=n_bootstrap,
            subsample_frac=subsample_frac,
            seed=seed,
        )

    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y)

    # Check for binary classification
    classes = np.unique(y)
    if len(classes) != 2:
        raise ValueError(f"LDA stability requires exactly 2 classes, got {len(classes)}")

    rng = np.random.default_rng(seed)

    # Reference discriminant computed on the full data set.
    try:
        w_full = _lda_direction(X, y, classes)
    except np.linalg.LinAlgError:
        return np.nan

    # Row indices of each class, looked up once rather than per iteration.
    members_0 = np.where(y == classes[0])[0]
    members_1 = np.where(y == classes[1])[0]

    per_class = int(len(X) * subsample_frac) // 2
    if per_class == 0:
        # An empty bootstrap sample has no class means; the score is undefined.
        return np.nan

    similarities = []
    for _ in range(n_bootstrap):
        # Stratified resample: the same number of draws from each class.
        idx_0 = rng.choice(members_0, per_class, replace=True)
        idx_1 = rng.choice(members_1, per_class, replace=True)
        idx = np.concatenate([idx_0, idx_1])

        try:
            w_boot = _lda_direction(X[idx], y[idx], classes)
        except np.linalg.LinAlgError:
            continue

        # Absolute cosine similarity (sign ambiguity in discriminant)
        similarities.append(np.abs(np.dot(w_full, w_boot)))

    return float(np.mean(similarities)) if similarities else np.nan
# =============================================================================
# Drift Metrics
# =============================================================================
def rdm_similarity(
    X: np.ndarray,
    Y: np.ndarray,
    method: Literal["spearman", "pearson"] = "spearman",
    metric: Literal["cosine", "correlation", "euclidean"] = "cosine",
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
    seed: Optional[int] = None,
) -> Union[float, dict]:
    """
    Compute RDM similarity between two representations.

    Measures how similar the pairwise distance structures are between two
    representations. Useful for measuring representational drift, comparing
    models, or tracking changes during training.

    Parameters
    ----------
    X : np.ndarray
        First representation matrix of shape (n_samples, n_features_x).
    Y : np.ndarray
        Second representation matrix of shape (n_samples, n_features_y).
        Must have the same number of samples as X.
    method : str
        Correlation method: 'spearman' (rank-based, default) or 'pearson'.
    metric : str
        Distance metric for RDM computation: 'cosine', 'correlation', or
        'euclidean'.
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling the
        input data this many times.
    ci : float, default=0.95
        Confidence level for the interval.
    seed : int, optional
        Random seed for bootstrap reproducibility.

    Returns
    -------
    float or dict
        If n_bootstrap_ci is None: correlation between RDMs. Range: [-1, 1].
        NaN when fewer than 3 samples are given; 0.0 when either RDM is
        constant or the correlation is not finite.
        If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
        'std', 'n_bootstraps', 'ci_level'.

    Raises
    ------
    ValueError
        If ``method`` is not 'spearman' or 'pearson', or if X and Y have
        different numbers of samples.

    Examples
    --------
    >>> # Compare representations before and after training
    >>> X_before = model_before.encode(data)
    >>> X_after = model_after.encode(data)
    >>> similarity = rdm_similarity(X_before, X_after)
    >>> print(f"RDM similarity: {similarity:.3f}")

    >>> # Compare two different models
    >>> X_model1 = model1.encode(data)
    >>> X_model2 = model2.encode(data)
    >>> similarity = rdm_similarity(X_model1, X_model2, method='pearson')

    Notes
    -----
    - Spearman (default) is more robust to outliers and non-linear
      relationships
    - Pearson captures linear relationships in distance magnitudes
    - The representations can have different feature dimensions (only sample
      count must match)
    """
    # Validate up front so a misspelt method fails immediately instead of
    # being masked by the early returns below or surfacing mid-bootstrap.
    if method not in ("spearman", "pearson"):
        raise ValueError(f"Unknown method: {method}. Use 'spearman' or 'pearson'")

    if n_bootstrap_ci is not None:
        return bootstrap_ci_two_sample(
            rdm_similarity,
            n_bootstrap_ci,
            ci,
            seed,
            np.asarray(X, dtype=np.float64),
            np.asarray(Y, dtype=np.float64),
            method=method,
            metric=metric,
        )

    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)

    if X.shape[0] != Y.shape[0]:
        raise ValueError(f"Sample counts must match: X has {X.shape[0]}, Y has {Y.shape[0]}")

    if X.shape[0] < 3:
        return np.nan

    # Compute RDMs
    rdm_x = pdist(X, metric=metric)
    rdm_y = pdist(Y, metric=metric)

    # Handle NaN values
    rdm_x = np.nan_to_num(rdm_x, nan=1.0)
    rdm_y = np.nan_to_num(rdm_y, nan=1.0)

    # A constant RDM carries no structure to correlate.
    if np.std(rdm_x) < EPS or np.std(rdm_y) < EPS:
        return 0.0

    # Compute correlation
    if method == "spearman":
        rho, _ = spearmanr(rdm_x, rdm_y)
    else:
        rho, _ = pearsonr(rdm_x, rdm_y)

    return float(rho) if np.isfinite(rho) else 0.0
def rdm_drift(
    X: np.ndarray,
    Y: np.ndarray,
    method: Literal["spearman", "pearson"] = "spearman",
    metric: Literal["cosine", "correlation", "euclidean"] = "cosine",
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
    seed: Optional[int] = None,
) -> Union[float, dict]:
    """
    Compute representational drift between two representations.

    Drift is ``1 - rdm_similarity(X, Y)``: the larger the value, the more the
    pairwise-distance structure has changed. Handy for tracking how far a
    representation has moved over time or after an intervention
    (fine-tuning, perturbation, etc.).

    Parameters
    ----------
    X : np.ndarray
        First (baseline/before) representation of shape
        (n_samples, n_features_x).
    Y : np.ndarray
        Second (comparison/after) representation of shape
        (n_samples, n_features_y). Must have the same number of samples as X.
    method : str
        Correlation method: 'spearman' (rank-based, default) or 'pearson'.
    metric : str
        Distance metric for RDM computation.
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling the
        input data this many times.
    ci : float, default=0.95
        Confidence level for the interval.
    seed : int, optional
        Random seed for bootstrap reproducibility.

    Returns
    -------
    float or dict
        If n_bootstrap_ci is None: drift score. Range: [0, 2]. NaN when the
        underlying similarity is NaN.
        If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
        'std', 'n_bootstraps', 'ci_level'.

    Examples
    --------
    >>> # Track drift during training
    >>> X_epoch0 = model.encode(data)
    >>> for epoch in range(10):
    ...     train_one_epoch(model)
    ...     X_current = model.encode(data)
    ...     drift = rdm_drift(X_epoch0, X_current)
    ...     print(f"Epoch {epoch+1}: drift = {drift:.3f}")

    >>> # Measure drift due to noise perturbation
    >>> X_clean = model.encode(clean_data)
    >>> X_noisy = model.encode(noisy_data)
    >>> drift = rdm_drift(X_clean, X_noisy)
    >>> print(f"Noise-induced drift: {drift:.3f}")

    See Also
    --------
    rdm_similarity : The inverse metric (similarity instead of drift)
    """
    if n_bootstrap_ci is None:
        score = rdm_similarity(X, Y, method=method, metric=metric)
        # Propagate an undefined similarity as NaN rather than 1 - NaN.
        return np.nan if np.isnan(score) else 1.0 - score

    return bootstrap_ci_two_sample(
        rdm_drift,
        n_bootstrap_ci,
        ci,
        seed,
        np.asarray(X, dtype=np.float64),
        np.asarray(Y, dtype=np.float64),
        method=method,
        metric=metric,
    )
# =============================================================================
# Convenience function
# =============================================================================
def shesha(
    X: np.ndarray,
    y: Optional[np.ndarray] = None,
    variant: Literal["feature_split", "sample_split", "anchor", "variance", "supervised"] = "feature_split",
    **kwargs,
) -> Union[float, dict, list]:
    """
    Unified interface for computing Shesha stability metrics.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    y : np.ndarray, optional
        Class labels (required for supervised variants).
    variant : str
        Which Shesha variant to compute:
        - 'feature_split': Unsupervised, partitions features
        - 'sample_split': Unsupervised, bootstrap resampling
        - 'anchor': Unsupervised, anchor-based stability
        - 'variance': Supervised, variance ratio
        - 'supervised': Supervised, RDM alignment
    **kwargs
        Additional arguments passed to the specific variant function.

    Returns
    -------
    float or dict or list
        Whatever the selected variant function returns: a stability score,
        or the richer dict/list results those functions produce for options
        such as ``n_bootstrap_ci``.

    Raises
    ------
    ValueError
        If ``variant`` is unknown, or a supervised variant is requested
        without ``y``.

    Examples
    --------
    >>> # Unsupervised
    >>> stability = shesha(X, variant='feature_split', n_splits=30, seed=320)

    >>> # Supervised
    >>> alignment = shesha(X, y, variant='supervised')
    """
    if variant == "feature_split":
        return feature_split(X, **kwargs)
    if variant == "sample_split":
        return sample_split(X, **kwargs)
    if variant == "anchor":
        return anchor_stability(X, **kwargs)
    if variant == "variance":
        if y is None:
            raise ValueError("Labels required for variance_ratio")
        # Forward kwargs like every other branch (e.g. n_bootstrap_ci, seed).
        return variance_ratio(X, y, **kwargs)
    if variant == "supervised":
        if y is None:
            raise ValueError("Labels required for supervised_alignment")
        return supervised_alignment(X, y, **kwargs)
    raise ValueError(f"Unknown variant: {variant}")