sklears_preprocessing/
kernel_centerer.rs

//! Kernel centering for precomputed kernel matrices.
//!
//! This module provides the `KernelCenterer` transformer, which centers a kernel
//! matrix in the feature space induced by the kernel.
5
6use scirs2_core::ndarray::{Array1, Array2, Axis};
7use sklears_core::{
8    error::{Result, SklearsError},
9    traits::{Fit, Trained, Transform, Untrained},
10    types::Float,
11};
12use std::marker::PhantomData;
13
14/// KernelCenterer centers a kernel matrix
15///
16/// Let K(x, z) be a kernel matrix computed from samples x and z.
17/// KernelCenterer computes the centered kernel matrix K_c as:
18///
19/// K_c(x, z) = K(x, z) - K_mean_x - K_mean_z + K_mean_all
20///
21/// where:
22/// - K_mean_x is the mean of K along samples x
23/// - K_mean_z is the mean of K along samples z  
24/// - K_mean_all is the overall mean of K
25#[derive(Debug, Clone)]
26pub struct KernelCenterer<State = Untrained> {
27    state: PhantomData<State>,
28    // Fitted parameters
29    k_train_mean_: Option<Array1<Float>>,
30    k_train_mean_all_: Option<Float>,
31}
32
33impl KernelCenterer<Untrained> {
34    /// Create a new KernelCenterer
35    pub fn new() -> Self {
36        Self {
37            state: PhantomData,
38            k_train_mean_: None,
39            k_train_mean_all_: None,
40        }
41    }
42}
43
44impl Default for KernelCenterer<Untrained> {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl Fit<Array2<Float>, (), Untrained> for KernelCenterer<Untrained> {
51    type Fitted = KernelCenterer<Trained>;
52
53    fn fit(self, k: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
54        let n_samples = k.nrows();
55
56        if n_samples == 0 {
57            return Err(SklearsError::InvalidInput(
58                "Cannot fit KernelCenterer on empty kernel matrix".to_string(),
59            ));
60        }
61
62        if k.nrows() != k.ncols() {
63            return Err(SklearsError::InvalidInput(format!(
64                "Kernel matrix must be square, got shape ({}, {})",
65                k.nrows(),
66                k.ncols()
67            )));
68        }
69
70        // Compute mean of each row (mean over training samples)
71        let k_train_mean = k
72            .mean_axis(Axis(1))
73            .ok_or_else(|| SklearsError::InvalidInput("Failed to compute row means".to_string()))?;
74
75        // Compute overall mean
76        let k_train_mean_all = k_train_mean.mean().ok_or_else(|| {
77            SklearsError::InvalidInput("Failed to compute overall mean".to_string())
78        })?;
79
80        Ok(KernelCenterer {
81            state: PhantomData,
82            k_train_mean_: Some(k_train_mean),
83            k_train_mean_all_: Some(k_train_mean_all),
84        })
85    }
86}
87
88impl Transform<Array2<Float>> for KernelCenterer<Trained> {
89    fn transform(&self, k: &Array2<Float>) -> Result<Array2<Float>> {
90        let k_train_mean = self
91            .k_train_mean_
92            .as_ref()
93            .ok_or_else(|| SklearsError::NotFitted {
94                operation: "transform".to_string(),
95            })?;
96        let k_train_mean_all = self
97            .k_train_mean_all_
98            .ok_or_else(|| SklearsError::NotFitted {
99                operation: "transform".to_string(),
100            })?;
101
102        let n_samples_train = k_train_mean.len();
103        let n_samples_test = k.nrows();
104
105        if k.ncols() != n_samples_train {
106            return Err(SklearsError::InvalidInput(format!(
107                "Kernel matrix has wrong number of columns. Expected {}, got {}",
108                n_samples_train,
109                k.ncols()
110            )));
111        }
112
113        // Center the kernel matrix
114        let mut k_centered = k.clone();
115
116        // Subtract row means (mean over training samples for each test sample)
117        let k_test_mean = k
118            .mean_axis(Axis(1))
119            .ok_or_else(|| SklearsError::InvalidInput("Failed to compute row means".to_string()))?;
120
121        // Apply centering formula: K_c(x, z) = K(x, z) - K_mean_x - K_mean_z + K_mean_all
122        for i in 0..n_samples_test {
123            for j in 0..n_samples_train {
124                k_centered[[i, j]] =
125                    k[[i, j]] - k_test_mean[i] - k_train_mean[j] + k_train_mean_all;
126            }
127        }
128
129        Ok(k_centered)
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::arr2;

    #[test]
    fn test_kernel_centerer_fit_transform() {
        // Create a simple kernel matrix (e.g., linear kernel)
        let k_train = arr2(&[[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]]);

        let centerer = KernelCenterer::new();
        let fitted = centerer.fit(&k_train, &()).unwrap();

        // Transform the training kernel itself
        let k_centered = fitted.transform(&k_train).unwrap();

        // Check that the centered kernel has zero mean
        let mean = k_centered.mean().unwrap();
        assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);

        // Check row and column means are zero
        for i in 0..k_centered.nrows() {
            let row_mean = k_centered.row(i).mean().unwrap();
            assert_abs_diff_eq!(row_mean, 0.0, epsilon = 1e-10);

            let col_mean = k_centered.column(i).mean().unwrap();
            assert_abs_diff_eq!(col_mean, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_kernel_centerer_transform_new() {
        // Train kernel (symmetric)
        let k_train = arr2(&[[1.0, 2.0], [2.0, 4.0]]);

        // Test kernel (new samples vs training samples)
        let k_test = arr2(&[[1.5, 3.0], [2.5, 5.0], [3.5, 7.0]]);

        let centerer = KernelCenterer::new();
        let fitted = centerer.fit(&k_train, &()).unwrap();
        let k_test_centered = fitted.transform(&k_test).unwrap();

        // Verify shape
        assert_eq!(k_test_centered.shape(), &[3, 2]);

        // Hand-computed: training means [1.5, 3.0], grand mean 2.25,
        // test row means [2.25, 3.75, 5.25].
        let expected: Array2<Float> = arr2(&[[0.0, 0.0], [-0.5, 0.5], [-1.0, 1.0]]);
        for (got, want) in k_test_centered.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_kernel_centerer_transform_wrong_columns() {
        let k_train = arr2(&[[1.0, 2.0], [2.0, 4.0]]);
        let fitted = KernelCenterer::new().fit(&k_train, &()).unwrap();

        // Three columns, but the centerer was fitted on two training samples.
        let k_bad = arr2(&[[1.0, 2.0, 3.0]]);
        assert!(fitted.transform(&k_bad).is_err());
    }

    #[test]
    fn test_kernel_centerer_errors() {
        // Non-square kernel matrix
        let k_invalid = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

        let centerer = KernelCenterer::new();
        assert!(centerer.fit(&k_invalid, &()).is_err());

        // Empty kernel matrix
        let k_empty = Array2::<Float>::zeros((0, 0));
        let centerer = KernelCenterer::new();
        assert!(centerer.fit(&k_empty, &()).is_err());
    }
}