"""
Shesha Similarity: Representational Similarity Metrics
This module provides metrics for measuring similarity between representations,
complementing the stability metrics in shesha.core. While stability measures
intrinsic geometric robustness, similarity measures extrinsic alignment.
Key distinction from the paper:
- Similarity is an *extrinsic* property (how one representation aligns with another)
- Stability is an *intrinsic* property (how robust a representation's geometry is)
- These are empirically uncorrelated (ρ ≈ 0.01)
"""
import numpy as np
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr, pearsonr
from scipy.linalg import orthogonal_procrustes
from typing import Optional, Union
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
from ._utils import bootstrap_ci_two_sample
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "cka",
    "cka_linear",
    "cka_debiased",
    "procrustes_similarity",
    "rdm_similarity",
]
# Small constant added to the CKA denominator so that a representation with
# zero variance yields 0.0 instead of a division-by-zero.
EPS: float = 1e-12
# =============================================================================
# CKA (Centered Kernel Alignment)
# =============================================================================
[docs]
def cka_linear(
    X: np.ndarray,
    Y: np.ndarray,
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
    seed: Optional[int] = None,
) -> Union[float, dict]:
    """
    Linear Centered Kernel Alignment (CKA) - Standard version.

    Measures similarity between two representations using linear kernels.
    This is the standard (non-debiased) version which is simpler and more
    numerically stable, recommended for most use cases.

    Parameters
    ----------
    X : np.ndarray
        First representation matrix of shape (n_samples, n_features_x).
    Y : np.ndarray
        Second representation matrix of shape (n_samples, n_features_y).
        Must have same number of samples as X.
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling
        the input data this many times.
    ci : float, default=0.95
        Confidence level for the interval.
    seed : int, optional
        Random seed for bootstrap reproducibility.

    Returns
    -------
    float or dict
        If n_bootstrap_ci is None: CKA similarity score in [0, 1].
        If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
        'std', 'n_bootstraps', 'ci_level'.

    Raises
    ------
    ValueError
        If X or Y is not a 2-D array, or if they have a different number of
        samples. These checks run before any bootstrap resampling.

    Examples
    --------
    >>> import numpy as np
    >>> from shesha.similarity import cka_linear
    >>>
    >>> # Two unrelated random representations of the same 100 samples
    >>> X = np.random.randn(100, 50)
    >>> Y = np.random.randn(100, 30)
    >>>
    >>> similarity = cka_linear(X, Y)
    >>> print(f"CKA: {similarity:.3f}")
    >>>
    >>> # With bootstrap CI
    >>> result = cka_linear(X, Y, n_bootstrap_ci=1000)
    >>> print(f"{result['mean']:.3f} [{result['ci_low']:.3f}, {result['ci_high']:.3f}]")

    Notes
    -----
    CKA is invariant to:
    - Orthogonal transformations
    - Isotropic scaling

    CKA measures the similarity of Gram matrices (X @ X.T and Y @ Y.T),
    which capture the pairwise similarities between samples in each
    representation space.

    References
    ----------
    Kornblith, S., Norouzi, M., Lee, H., & Hinton, G. (2019).
    Similarity of neural network representations revisited.
    ICML 2019.
    """
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError(
            f"X and Y must be 2-D (n_samples, n_features) arrays: "
            f"X has shape {X.shape}, Y has shape {Y.shape}"
        )
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have same number of samples: "
            f"X has {X.shape[0]}, Y has {Y.shape[0]}"
        )

    # Validate first so a bad pair of inputs fails fast with a clear error,
    # rather than somewhere inside the resampling helper.
    if n_bootstrap_ci is not None:
        return bootstrap_ci_two_sample(cka_linear, n_bootstrap_ci, ci, seed, X, Y)

    # Center the data (subtract column means)
    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)

    # HSIC(X, Y) up to a constant factor: ||X^T Y||_F^2
    num = np.linalg.norm(X.T @ Y, "fro") ** 2
    # CKA = HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y)); the constant cancels.
    # EPS keeps a zero-variance input from dividing by zero (result 0.0).
    den = np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro")
    return float(num / (den + EPS))
[docs]
def cka_debiased(
    X: np.ndarray,
    Y: np.ndarray,
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
    seed: Optional[int] = None,
) -> Union[float, dict]:
    """
    Debiased Centered Kernel Alignment (CKA).

    Unbiased estimator of CKA that corrects for finite sample effects.
    More accurate for small sample sizes but computationally more expensive
    (it builds n_samples x n_samples Gram matrices).

    Parameters
    ----------
    X : np.ndarray
        First representation matrix of shape (n_samples, n_features_x).
    Y : np.ndarray
        Second representation matrix of shape (n_samples, n_features_y).
        Must have same number of samples as X.
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling
        the input data this many times.
    ci : float, default=0.95
        Confidence level for the interval.
    seed : int, optional
        Random seed for bootstrap reproducibility.

    Returns
    -------
    float or dict
        If n_bootstrap_ci is None: debiased CKA similarity score. It is at
        most 1 and is close to 0 for unrelated representations; because the
        HSIC estimate is unbiased it can be slightly negative. Returns 0.0
        when either self-HSIC estimate is non-positive.
        If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
        'std', 'n_bootstraps', 'ci_level'.

    Raises
    ------
    ValueError
        If X or Y is not a 2-D array, or if they have a different number of
        samples. These checks run before any bootstrap resampling.

    Examples
    --------
    >>> import numpy as np
    >>> from shesha.similarity import cka_debiased
    >>>
    >>> # For small sample sizes, debiased version is more accurate
    >>> X = np.random.randn(50, 20)
    >>> Y = np.random.randn(50, 15)
    >>>
    >>> # Compare standard vs debiased
    >>> from shesha.similarity import cka_linear
    >>> std_cka = cka_linear(X, Y)
    >>> debiased_cka = cka_debiased(X, Y)
    >>>
    >>> print(f"Standard: {std_cka:.3f}")
    >>> print(f"Debiased: {debiased_cka:.3f}")

    Notes
    -----
    For n < 4, falls back to standard CKA as debiasing is not well-defined.

    HSIC is estimated with the unbiased estimator of Song et al. (2012), as
    used for debiased CKA by Kornblith et al. (2019):

        HSIC_1 = [tr(K~ L~) + (1' K~ 1)(1' L~ 1) / ((n-1)(n-2))
                  - 2/(n-2) * 1' K~ L~ 1] / (n (n-3))

    where K~ and L~ are the linear Gram matrices with their diagonals set to
    zero.

    Recommended when:
    - Sample size is small (n < 100)
    - Exact statistical properties are important
    - Computing statistical significance

    References
    ----------
    Kornblith, S., Norouzi, M., Lee, H., & Hinton, G. (2019).
    Similarity of neural network representations revisited.
    ICML 2019.

    Song, L., Smola, A., Gretton, A., Bedo, J., & Borgwardt, K. (2012).
    Feature selection via dependence maximization. JMLR 13, 1393-1434.
    """
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError(
            f"X and Y must be 2-D (n_samples, n_features) arrays: "
            f"X has shape {X.shape}, Y has shape {Y.shape}"
        )
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have same number of samples: "
            f"X has {X.shape[0]}, Y has {Y.shape[0]}"
        )

    # Validate first so a bad pair of inputs fails fast with a clear error,
    # rather than somewhere inside the resampling helper.
    if n_bootstrap_ci is not None:
        return bootstrap_ci_two_sample(cka_debiased, n_bootstrap_ci, ci, seed, X, Y)

    # Center the features (matches cka_linear; also used by the fallback below).
    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)
    n = X.shape[0]

    # For very small samples, fall back to standard CKA: the unbiased
    # estimator divides by n * (n - 3).
    if n < 4:
        num = np.linalg.norm(X.T @ Y, "fro") ** 2
        den = np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro")
        return float(num / (den + EPS))

    # Gram matrices with zeroed diagonals (K~, L~). The features are already
    # centered, so X @ X.T is already double-centered; no extra Gram
    # centering is required.
    K = X @ X.T
    L = Y @ Y.T
    np.fill_diagonal(K, 0.0)
    np.fill_diagonal(L, 0.0)

    def _unbiased_hsic(A: np.ndarray, B: np.ndarray) -> float:
        """Song et al. (2012) unbiased HSIC for zero-diagonal Gram matrices."""
        row_a = A.sum(axis=1)
        row_b = B.sum(axis=1)
        return (
            np.sum(A * B)  # tr(A B) for symmetric A, B: diagonals excluded
            + row_a.sum() * row_b.sum() / ((n - 1) * (n - 2))
            - 2.0 * np.dot(row_a, row_b) / (n - 2)
        ) / (n * (n - 3))

    hsic_xy = _unbiased_hsic(K, L)
    hsic_xx = _unbiased_hsic(K, K)
    hsic_yy = _unbiased_hsic(L, L)

    # An unbiased estimate can be non-positive for degenerate inputs; the
    # normalization is undefined then, so report no similarity.
    if hsic_xx <= 0 or hsic_yy <= 0:
        return 0.0
    return float(hsic_xy / np.sqrt(hsic_xx * hsic_yy))
[docs]
def cka(
    X: np.ndarray,
    Y: np.ndarray,
    debiased: bool = False,
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
    seed: Optional[int] = None,
) -> Union[float, dict]:
    """
    Centered Kernel Alignment (CKA) - Unified interface.

    Dispatches to :func:`cka_debiased` when ``debiased`` is truthy and to
    :func:`cka_linear` otherwise, forwarding all remaining arguments and
    returning the chosen estimator's result unchanged.

    Parameters
    ----------
    X : np.ndarray
        First representation matrix of shape (n_samples, n_features_x).
    Y : np.ndarray
        Second representation matrix of shape (n_samples, n_features_y).
    debiased : bool, default=False
        Select the debiased estimator (recommended for small sample sizes)
        instead of the standard one.
    n_bootstrap_ci : int, optional
        Number of bootstrap resamples for a confidence interval; forwarded
        to the selected estimator.
    ci : float, default=0.95
        Confidence level for the interval.
    seed : int, optional
        Random seed for bootstrap reproducibility.

    Returns
    -------
    float or dict
        Whatever the selected estimator returns: a CKA score when
        n_bootstrap_ci is None, otherwise a dict with keys 'mean', 'ci_low',
        'ci_high', 'std', 'n_bootstraps', 'ci_level'.

    Examples
    --------
    >>> import numpy as np
    >>> from shesha.similarity import cka
    >>>
    >>> X = np.random.randn(100, 50)
    >>> Y = np.random.randn(100, 30)
    >>>
    >>> sim = cka(X, Y)                              # standard (default)
    >>> sim_debiased = cka(X, Y, debiased=True)      # small-n friendly
    >>> result = cka(X, Y, n_bootstrap_ci=1000)      # with bootstrap CI

    See Also
    --------
    cka_linear : Standard CKA implementation
    cka_debiased : Debiased CKA implementation
    """
    estimator = cka_debiased if debiased else cka_linear
    return estimator(X, Y, n_bootstrap_ci=n_bootstrap_ci, ci=ci, seed=seed)
# =============================================================================
# Procrustes Similarity
# =============================================================================
def _validate_procrustes_inputs(
X: np.ndarray, Y: np.ndarray
) -> Optional[float]:
"""Return np.nan if inputs are invalid, else None (meaning inputs are ok)."""
if X.shape != Y.shape:
raise ValueError(
f"X and Y must have same shape for Procrustes: "
f"X is {X.shape}, Y is {Y.shape}"
)
if np.any(np.isnan(X)) or np.any(np.isnan(Y)):
return np.nan
if np.any(np.isinf(X)) or np.any(np.isinf(Y)):
return np.nan
X_std = X.std(axis=0)
Y_std = Y.std(axis=0)
if np.any(X_std < 1e-12) or np.any(Y_std < 1e-12):
return "degenerate"
return None
def _preprocess_procrustes(
X: np.ndarray, Y: np.ndarray, center: bool, scale: bool
) -> Optional[tuple]:
"""Center and/or scale X and Y. Returns (X_out, Y_out) or None if degenerate."""
if center:
X = X - X.mean(axis=0)
Y = Y - Y.mean(axis=0)
X_norm = np.linalg.norm(X, "fro")
Y_norm = np.linalg.norm(Y, "fro")
if X_norm < 1e-12 or Y_norm < 1e-12:
return None
if scale:
X = X / X_norm
Y = Y / Y_norm
return X, Y
[docs]
def procrustes_similarity(
X: np.ndarray,
Y: np.ndarray,
center: bool = True,
scale: bool = True,
n_bootstrap_ci: Optional[int] = None,
ci: float = 0.95,
seed: Optional[int] = None,
) -> Union[float, dict]:
"""
Procrustes similarity between two representations.
Finds the optimal orthogonal transformation that aligns Y to X and
returns the similarity (1 - disparity). Unlike CKA, Procrustes attempts
to directly align the representations in their original spaces.
Parameters
----------
X : np.ndarray
First representation matrix of shape (n_samples, n_features).
Y : np.ndarray
Second representation matrix of shape (n_samples, n_features).
Must have same shape as X.
center : bool, default=True
If True, center both matrices before alignment.
scale : bool, default=True
If True, scale to unit Frobenius norm before alignment.
n_bootstrap_ci : int, optional
If provided, compute bootstrap confidence interval by resampling
the input data this many times.
ci : float, default=0.95
Confidence level for the interval.
seed : int, optional
Random seed for bootstrap reproducibility.
Returns
-------
float or dict
If n_bootstrap_ci is None: Procrustes similarity in [0, 1].
If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
'std', 'n_bootstraps', 'ci_level'.
Examples
--------
>>> import numpy as np
>>> from shesha.similarity import procrustes_similarity
>>>
>>> # Two representations that differ by a rotation
>>> X = np.random.randn(100, 50)
>>> Q = np.linalg.qr(np.random.randn(50, 50))[0] # Random rotation
>>> Y = X @ Q
>>>
>>> similarity = procrustes_similarity(X, Y)
>>> print(f"Procrustes: {similarity:.3f}") # Should be ~1.0
Notes
-----
Procrustes is more sensitive to outliers and noise than CKA, which can
be both an advantage (detects small changes) and disadvantage (more false
alarms). The paper shows CKA is often preferred for representation analysis.
If dimensions don't match, returns NaN. Unlike CKA, Procrustes requires
representations to live in the same dimensional space.
References
----------
Schönemann, P. H. (1966). A generalized solution of the orthogonal
Procrustes problem. Psychometrika, 31(1), 1-10.
"""
if n_bootstrap_ci is not None:
return bootstrap_ci_two_sample(
procrustes_similarity, n_bootstrap_ci, ci, seed,
np.asarray(X, dtype=np.float64),
np.asarray(Y, dtype=np.float64),
center=center, scale=scale,
)
try:
X = np.asarray(X, dtype=np.float64)
Y = np.asarray(Y, dtype=np.float64)
status = _validate_procrustes_inputs(X, Y)
if status == "degenerate":
rng = np.random.default_rng(320)
X = X + rng.normal(0, 1e-8, X.shape)
Y = Y + rng.normal(0, 1e-8, Y.shape)
elif status is not None:
return status # np.nan
result = _preprocess_procrustes(X, Y, center, scale)
if result is None:
return np.nan
X_scaled, Y_scaled = result
R, _ = orthogonal_procrustes(X_scaled, Y_scaled)
disparity = np.mean((X_scaled - Y_scaled @ R) ** 2)
similarity = 1.0 - min(disparity / 2.0, 1.0)
if not np.isfinite(similarity):
return np.nan
return float(np.clip(similarity, 0.0, 1.0))
except (ValueError, np.linalg.LinAlgError):
return np.nan
# =============================================================================
# RDM-based Similarity
# =============================================================================
[docs]
def rdm_similarity(
    X: np.ndarray,
    Y: np.ndarray,
    metric: Literal["cosine", "correlation", "euclidean"] = "cosine",
    method: Literal["spearman", "pearson"] = "spearman",
    n_bootstrap_ci: Optional[int] = None,
    ci: float = 0.95,
    seed: Optional[int] = None,
) -> Union[float, dict]:
    """
    RDM-based similarity using correlation of pairwise distances.

    Computes Representational Dissimilarity Matrices (RDMs) for X and Y,
    then measures their correlation. This is the same approach used in
    shesha.rdm_similarity but available here for comparison with CKA.

    Parameters
    ----------
    X : np.ndarray
        First representation matrix of shape (n_samples, n_features_x).
    Y : np.ndarray
        Second representation matrix of shape (n_samples, n_features_y).
        Must have same number of samples as X.
    metric : str, default="cosine"
        Distance metric for RDM: 'cosine', 'correlation', or 'euclidean'
        (passed to ``scipy.spatial.distance.pdist``).
    method : str, default="spearman"
        Correlation method: 'spearman' (rank-based) or 'pearson' (linear).
    n_bootstrap_ci : int, optional
        If provided, compute bootstrap confidence interval by resampling
        the input data this many times.
    ci : float, default=0.95
        Confidence level for the interval.
    seed : int, optional
        Random seed for bootstrap reproducibility.

    Returns
    -------
    float or dict
        If n_bootstrap_ci is None: RDM similarity in [-1, 1]. Returns 0.0
        when the correlation is undefined (fewer than 3 samples, constant
        RDMs, or NaN distances such as cosine distance to a zero vector).
        If n_bootstrap_ci is set: dict with keys 'mean', 'ci_low', 'ci_high',
        'std', 'n_bootstraps', 'ci_level'.

    Raises
    ------
    ValueError
        If ``method`` is not 'spearman' or 'pearson', or if X and Y have a
        different number of samples. Both checks run before any distance
        computation or bootstrap resampling.

    Examples
    --------
    >>> import numpy as np
    >>> from shesha.similarity import rdm_similarity
    >>>
    >>> X = np.random.randn(100, 50)
    >>> Y = np.random.randn(100, 30)
    >>>
    >>> # Spearman correlation (robust, rank-based)
    >>> sim_spearman = rdm_similarity(X, Y, method='spearman')
    >>>
    >>> # Pearson correlation (linear)
    >>> sim_pearson = rdm_similarity(X, Y, method='pearson')

    Notes
    -----
    RDM similarity is similar to RSA (Representational Similarity Analysis).
    Spearman correlation is preferred as it's robust to monotonic transformations
    of distances and less sensitive to outliers.

    Unlike CKA, RDM similarity operates on pairwise distances rather than
    Gram matrices, making it more interpretable but potentially less sensitive.

    See Also
    --------
    shesha.rdm_similarity : Identical implementation in core module
    cka : Alternative similarity metric using kernel alignment
    """
    # Reject a bad method before paying for O(n^2) distance computations.
    if method not in ("spearman", "pearson"):
        raise ValueError(f"Unknown method: {method}. Use 'spearman' or 'pearson'")

    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"X and Y must have same number of samples: "
            f"X has {X.shape[0]}, Y has {Y.shape[0]}"
        )

    if n_bootstrap_ci is not None:
        return bootstrap_ci_two_sample(
            rdm_similarity, n_bootstrap_ci, ci, seed, X, Y,
            metric=metric, method=method,
        )

    # Compute RDMs (condensed form - upper triangle only)
    rdm_x = pdist(X, metric=metric)
    rdm_y = pdist(Y, metric=metric)

    # A correlation needs at least two distance pairs (i.e. >= 3 samples).
    # Treat it like any other undefined correlation, for both methods.
    if rdm_x.size < 2:
        return 0.0

    if method == "spearman":
        corr, _ = spearmanr(rdm_x, rdm_y)
    else:
        corr, _ = pearsonr(rdm_x, rdm_y)

    # Undefined correlations (e.g. constant or NaN distances) map to 0.0.
    return float(corr) if np.isfinite(corr) else 0.0