Source code for shesha.sim

"""
Shesha Similarity: Representational Similarity Metrics

This module provides metrics for measuring similarity between representations,
complementing the stability metrics in shesha.core. While stability measures
intrinsic geometric robustness, similarity measures extrinsic alignment.

Key distinction from the paper:
- Similarity is an *extrinsic* property (how one representation aligns with another)
- Stability is an *intrinsic* property (how robust a representation's geometry is)
- These are empirically uncorrelated (ρ ≈ 0.01)
"""

import numpy as np
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr, pearsonr
from scipy.linalg import orthogonal_procrustes
from typing import Optional

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal


# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "cka",
    "cka_linear",
    "cka_debiased",
    "procrustes_similarity",
    "rdm_similarity",
]

# Small constant added to denominators to guard against division by zero.
EPS = 1e-12


# =============================================================================
# CKA (Centered Kernel Alignment)
# =============================================================================

def cka_linear(X: np.ndarray, Y: np.ndarray) -> float:
    """
    Linear Centered Kernel Alignment (CKA) - Standard version.

    Measures similarity between two representations using linear kernels.
    This is the standard (non-debiased) version which is simpler and more
    numerically stable, recommended for most use cases.

    Parameters
    ----------
    X : np.ndarray
        First representation matrix of shape (n_samples, n_features_x).
    Y : np.ndarray
        Second representation matrix of shape (n_samples, n_features_y).
        Must have same number of samples as X.

    Returns
    -------
    float
        CKA similarity score in [0, 1]. Higher values indicate more similar
        representational structure. 1.0 means identical structure up to an
        orthogonal transformation and isotropic scaling. Returns 0.0 when
        either representation has no variance after centering.

    Raises
    ------
    ValueError
        If X and Y do not have the same number of samples.

    Examples
    --------
    >>> import numpy as np
    >>> from shesha.similarity import cka_linear
    >>>
    >>> # Two representations of the same data
    >>> X = np.random.randn(100, 50)
    >>> Y = np.random.randn(100, 30)
    >>>
    >>> similarity = cka_linear(X, Y)
    >>> print(f"CKA: {similarity:.3f}")
    >>>
    >>> # Self-similarity should be 1.0, at any overall scale
    >>> self_sim = cka_linear(1e-6 * X, 1e-6 * X)
    >>> print(f"Self-similarity: {self_sim:.3f}")  # Should be ~1.0

    Notes
    -----
    CKA is invariant to:
    - Orthogonal transformations
    - Isotropic scaling

    Each centered matrix is rescaled to unit Frobenius norm before the
    kernel products are formed. This does not change the score, but it keeps
    the invariance exact for very small or very large inputs, where an
    absolute epsilon in the denominator would otherwise dominate.

    CKA measures the similarity of Gram matrices (X @ X.T and Y @ Y.T),
    which capture the pairwise similarities between samples in each
    representation space.

    References
    ----------
    Kornblith, S., Norouzi, M., Lee, H., & Hinton, G. (2019).
    Similarity of neural network representations revisited. ICML 2019.
    """
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)

    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have same number of samples: "
            f"X has {X.shape[0]}, Y has {Y.shape[0]}"
        )

    # Relative threshold below which a centered matrix counts as having no
    # variance. Rounding in the column means can leave a residue of order
    # machine-epsilon * |A| for constant columns. That residue must not be
    # blown up to unit norm and compared as if it were signal.
    rel_tol = 1e-12

    normalized = []
    for A in (X, Y):
        centered = A - A.mean(axis=0, keepdims=True)
        norm = np.linalg.norm(centered)
        # Also catches an all-zero (or empty) input, where both norms are 0.
        if norm <= rel_tol * np.linalg.norm(A):
            return 0.0
        # CKA is invariant to isotropic scaling, so this leaves the score
        # unchanged while keeping every product below O(1).
        normalized.append(centered / norm)
    X, Y = normalized

    # HSIC(X, Y) = ||X^T Y||_F^2 (up to a constant that cancels below)
    num = np.linalg.norm(X.T @ Y, 'fro') ** 2
    # CKA = HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y)); the denominator is
    # strictly positive for non-zero unit-norm X and Y.
    den = np.linalg.norm(X.T @ X, 'fro') * np.linalg.norm(Y.T @ Y, 'fro')
    return float(num / den)
def cka_debiased(X: np.ndarray, Y: np.ndarray) -> float:
    """
    Debiased Centered Kernel Alignment (CKA).

    Unbiased estimator of CKA that corrects for finite sample effects.
    More accurate for small sample sizes but computationally more expensive
    (it forms two n_samples x n_samples Gram matrices).

    Parameters
    ----------
    X : np.ndarray
        First representation matrix of shape (n_samples, n_features_x).
    Y : np.ndarray
        Second representation matrix of shape (n_samples, n_features_y).
        Must have same number of samples as X.

    Returns
    -------
    float
        Debiased CKA similarity score. Higher values indicate more similar
        representational structure. Identical representations score 1.0.
        Unlike the standard estimator, the cross term is unbiased and can
        be slightly negative for unrelated representations. Returns 0.0 if
        either self-HSIC estimate is not positive.

    Raises
    ------
    ValueError
        If X and Y do not have the same number of samples.

    Examples
    --------
    >>> import numpy as np
    >>> from shesha.similarity import cka_debiased
    >>>
    >>> # For small sample sizes, debiased version is more accurate
    >>> X = np.random.randn(50, 20)
    >>> Y = np.random.randn(50, 15)
    >>>
    >>> # Compare standard vs debiased
    >>> from shesha.similarity import cka_linear
    >>> std_cka = cka_linear(X, Y)
    >>> debiased_cka = cka_debiased(X, Y)
    >>>
    >>> print(f"Standard: {std_cka:.3f}")
    >>> print(f"Debiased: {debiased_cka:.3f}")

    Notes
    -----
    For n < 4, falls back to standard CKA (``cka_linear``), because the
    debiasing correction divides by n * (n - 3).

    The estimator is the unbiased HSIC estimator (Song et al., 2012) used by
    Kornblith et al. (2019). It is computed on the linear Gram matrices with
    their diagonals set to zero:

        HSIC = [tr(K L) + (1'K1)(1'L1) / ((n-1)(n-2)) - 2 1'K L 1 / (n-2)]
               / (n (n-3))

    Recommended when:
    - Sample size is small (n < 100)
    - Exact statistical properties are important
    - Computing statistical significance

    References
    ----------
    Kornblith, S., Norouzi, M., Lee, H., & Hinton, G. (2019).
    Similarity of neural network representations revisited. ICML 2019.
    Song, L., Smola, A., Gretton, A., Bedo, J., & Borgwardt, K. (2012).
    Feature selection via dependence maximization. JMLR 13.
    """
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)

    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have same number of samples: "
            f"X has {X.shape[0]}, Y has {Y.shape[0]}"
        )

    n = X.shape[0]

    # The unbiased estimator is undefined for n < 4 (division by n * (n - 3)).
    if n < 4:
        return cka_linear(X, Y)

    # Center the features (subtract column means).
    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)

    # Linear-kernel Gram matrices with the diagonal removed. The unbiased
    # estimator uses only off-diagonal (i != j) kernel entries, including in
    # the trace term.
    K = X @ X.T
    L = Y @ Y.T
    np.fill_diagonal(K, 0.0)
    np.fill_diagonal(L, 0.0)

    def unbiased_hsic(A, B):
        """Unbiased HSIC of two symmetric, zero-diagonal Gram matrices."""
        return (
            np.sum(A * B)
            + A.sum() * B.sum() / ((n - 1) * (n - 2))
            - 2.0 * np.dot(A.sum(axis=1), B.sum(axis=1)) / (n - 2)
        ) / (n * (n - 3))

    hsic_xy = unbiased_hsic(K, L)
    hsic_xx = unbiased_hsic(K, K)
    hsic_yy = unbiased_hsic(L, L)

    # Unbiased self-estimates can come out non-positive for degenerate
    # inputs. Avoid dividing by zero or taking the sqrt of a negative number.
    if hsic_xx <= 0 or hsic_yy <= 0:
        return 0.0

    return float(hsic_xy / np.sqrt(hsic_xx * hsic_yy))
def cka(
    X: np.ndarray,
    Y: np.ndarray,
    debiased: bool = False
) -> float:
    """
    Centered Kernel Alignment (CKA) - Unified interface.

    Convenience function that selects between standard and debiased CKA.

    Parameters
    ----------
    X : np.ndarray
        First representation matrix of shape (n_samples, n_features_x).
    Y : np.ndarray
        Second representation matrix of shape (n_samples, n_features_y).
    debiased : bool, default=False
        If True, use the debiased estimator (``cka_debiased``), which is
        recommended for small sample sizes. Otherwise use ``cka_linear``.

    Returns
    -------
    float
        CKA similarity score. The standard estimator lies in [0, 1]. The
        debiased estimator can be slightly negative for unrelated
        representations.

    Raises
    ------
    ValueError
        Propagated from the selected estimator if X and Y have a different
        number of samples.

    Examples
    --------
    >>> import numpy as np
    >>> from shesha.similarity import cka
    >>>
    >>> X = np.random.randn(100, 50)
    >>> Y = np.random.randn(100, 30)
    >>>
    >>> # Standard CKA (default, faster)
    >>> sim = cka(X, Y)
    >>>
    >>> # Debiased CKA (more accurate for small n)
    >>> sim_debiased = cka(X, Y, debiased=True)

    See Also
    --------
    cka_linear : Standard CKA implementation
    cka_debiased : Debiased CKA implementation
    """
    return cka_debiased(X, Y) if debiased else cka_linear(X, Y)
# =============================================================================
# Procrustes Similarity
# =============================================================================
def procrustes_similarity(
    X: np.ndarray,
    Y: np.ndarray,
    center: bool = True,
    scale: bool = True,
) -> float:
    """
    Procrustes similarity between two representations.

    Finds the optimal orthogonal transformation R that aligns Y to X
    (minimising ||Y R - X||_F). The disparity is the sum of squared
    residuals, and the function returns ``1 - disparity / 2`` clipped to
    [0, 1]. Unlike CKA, Procrustes attempts to directly align the
    representations in their original spaces.

    Parameters
    ----------
    X : np.ndarray
        First representation matrix of shape (n_samples, n_features).
    Y : np.ndarray
        Second representation matrix of shape (n_samples, n_features).
        Must have same shape as X.
    center : bool, default=True
        If True, center both matrices (subtract column means) before
        alignment.
    scale : bool, default=True
        If True, scale to unit Frobenius norm before alignment. The
        disparity then lies in [0, 2], and the similarity equals the sum of
        the singular values of X.T @ Y for the normalized matrices.

    Returns
    -------
    float
        Procrustes similarity in [0, 1]. Higher values indicate better
        alignment. 1.0 means perfect alignment (identical up to
        rotation/reflection, and up to scale when ``scale=True``).
        NaN is returned if the shapes differ, if either input contains
        NaN/Inf, if either matrix has (near-)zero norm after centering, or
        if the alignment fails with ValueError/LinAlgError.

    Examples
    --------
    >>> import numpy as np
    >>> from shesha.similarity import procrustes_similarity
    >>>
    >>> # Two representations that differ by a rotation
    >>> X = np.random.randn(100, 50)
    >>> Q = np.linalg.qr(np.random.randn(50, 50))[0]  # Random rotation
    >>> Y = X @ Q
    >>>
    >>> similarity = procrustes_similarity(X, Y)
    >>> print(f"Procrustes: {similarity:.3f}")  # Should be ~1.0

    Notes
    -----
    Procrustes is more sensitive to outliers and noise than CKA, which can
    be both an advantage (detects small changes) and disadvantage (more
    false alarms). The paper shows CKA is often preferred for representation
    analysis.

    If any column of X or Y has a standard deviation below 1e-12, a tiny
    deterministic jitter (N(0, 1e-8), fixed seed) is added to both matrices
    before alignment to break the degeneracy.

    If dimensions don't match, returns NaN. Unlike CKA, Procrustes requires
    representations to live in the same dimensional space.

    References
    ----------
    Schönemann, P. H. (1966). A generalized solution of the orthogonal
    Procrustes problem. Psychometrika, 31(1), 1-10.
    """
    try:
        X = np.asarray(X, dtype=np.float64)
        Y = np.asarray(Y, dtype=np.float64)

        # Procrustes needs both matrices in the same space. This is
        # documented to yield NaN rather than raise.
        if X.shape != Y.shape:
            return np.nan

        # NaN/Inf entries make any alignment meaningless.
        if not (np.all(np.isfinite(X)) and np.all(np.isfinite(Y))):
            return np.nan

        # Degenerate (constant) columns: add tiny reproducible noise.
        if np.any(X.std(axis=0) < 1e-12) or np.any(Y.std(axis=0) < 1e-12):
            rng = np.random.default_rng(320)
            X = X + rng.normal(0, 1e-8, X.shape)
            Y = Y + rng.normal(0, 1e-8, Y.shape)

        # No in-place operations below, so the caller's arrays are never
        # modified.
        if center:
            X = X - X.mean(axis=0)
            Y = Y - Y.mean(axis=0)

        X_norm = np.linalg.norm(X, 'fro')
        Y_norm = np.linalg.norm(Y, 'fro')
        if X_norm < 1e-12 or Y_norm < 1e-12:
            return np.nan

        if scale:
            X = X / X_norm
            Y = Y / Y_norm

        # orthogonal_procrustes(A, B) returns R minimising ||A @ R - B||_F.
        # Passing (Y, X) therefore yields the map that aligns Y onto X.
        R, _ = orthogonal_procrustes(Y, X)
        Y_aligned = Y @ R

        # Disparity = sum of squared residuals, as in scipy.spatial.procrustes.
        # With unit-norm inputs it ranges from 0 (perfect) to 2.
        disparity = np.sum((X - Y_aligned) ** 2)

        # Map disparity to a similarity in [0, 1].
        similarity = 1.0 - min(disparity / 2.0, 1.0)

        if not np.isfinite(similarity):
            return np.nan

        return float(np.clip(similarity, 0.0, 1.0))

    except (ValueError, np.linalg.LinAlgError):
        return np.nan
# =============================================================================
# RDM-based Similarity
# =============================================================================
def rdm_similarity(
    X: np.ndarray,
    Y: np.ndarray,
    metric: Literal["cosine", "correlation", "euclidean"] = "cosine",
    method: Literal["spearman", "pearson"] = "spearman",
) -> float:
    """
    RDM-based similarity using correlation of pairwise distances.

    Computes Representational Dissimilarity Matrices (RDMs) for X and Y,
    then measures their correlation. This is the same approach used in
    shesha.rdm_similarity but available here for comparison with CKA.

    Parameters
    ----------
    X : np.ndarray
        First representation matrix of shape (n_samples, n_features_x).
    Y : np.ndarray
        Second representation matrix of shape (n_samples, n_features_y).
        Must have same number of samples as X.
    metric : str, default="cosine"
        Distance metric for RDM: 'cosine', 'correlation', or 'euclidean'.
    method : str, default="spearman"
        Correlation method: 'spearman' (rank-based) or 'pearson' (linear).

    Returns
    -------
    float
        RDM similarity in [-1, 1]. Higher values indicate more similar
        pairwise distance structure. Spearman is more robust to outliers.
        Returns 0.0 whenever the correlation is undefined: fewer than
        3 samples (fewer than 2 pairwise distances), or a non-finite
        result such as for constant distance vectors.

    Raises
    ------
    ValueError
        If X and Y have a different number of samples, or if ``method`` is
        not 'spearman' or 'pearson'.

    Examples
    --------
    >>> import numpy as np
    >>> from shesha.similarity import rdm_similarity
    >>>
    >>> X = np.random.randn(100, 50)
    >>> Y = np.random.randn(100, 30)
    >>>
    >>> # Spearman correlation (robust, rank-based)
    >>> sim_spearman = rdm_similarity(X, Y, method='spearman')
    >>>
    >>> # Pearson correlation (linear)
    >>> sim_pearson = rdm_similarity(X, Y, method='pearson')

    Notes
    -----
    RDM similarity is similar to RSA (Representational Similarity Analysis).
    Spearman correlation is preferred as it's robust to monotonic
    transformations of distances and less sensitive to outliers.

    Unlike CKA, RDM similarity operates on pairwise distances rather than
    Gram matrices, making it more interpretable but potentially less
    sensitive.

    See Also
    --------
    shesha.rdm_similarity : Identical implementation in core module
    cka : Alternative similarity metric using kernel alignment
    """
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)

    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have same number of samples: "
            f"X has {X.shape[0]}, Y has {Y.shape[0]}"
        )

    # Validate before the O(n^2) distance computation.
    if method not in ("spearman", "pearson"):
        raise ValueError(f"Unknown method: {method}. Use 'spearman' or 'pearson'")

    # Compute RDMs (condensed form - upper triangle only)
    rdm_x = pdist(X, metric=metric)
    rdm_y = pdist(Y, metric=metric)

    # A correlation needs at least two distances (n_samples >= 3). Without
    # this guard, pearsonr raises while spearmanr yields NaN. Treat both
    # like any other undefined correlation.
    if rdm_x.size < 2:
        return 0.0

    if method == "spearman":
        corr = spearmanr(rdm_x, rdm_y).correlation
    else:
        corr, _ = pearsonr(rdm_x, rdm_y)

    return float(corr) if np.isfinite(corr) else 0.0