sklears_preprocessing/
kernel_centerer.rs1use scirs2_core::ndarray::{Array1, Array2, Axis};
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
14#[derive(Debug, Clone)]
26pub struct KernelCenterer<State = Untrained> {
27 state: PhantomData<State>,
28 k_train_mean_: Option<Array1<Float>>,
30 k_train_mean_all_: Option<Float>,
31}
32
33impl KernelCenterer<Untrained> {
34 pub fn new() -> Self {
36 Self {
37 state: PhantomData,
38 k_train_mean_: None,
39 k_train_mean_all_: None,
40 }
41 }
42}
43
44impl Default for KernelCenterer<Untrained> {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl Fit<Array2<Float>, (), Untrained> for KernelCenterer<Untrained> {
51 type Fitted = KernelCenterer<Trained>;
52
53 fn fit(self, k: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
54 let n_samples = k.nrows();
55
56 if n_samples == 0 {
57 return Err(SklearsError::InvalidInput(
58 "Cannot fit KernelCenterer on empty kernel matrix".to_string(),
59 ));
60 }
61
62 if k.nrows() != k.ncols() {
63 return Err(SklearsError::InvalidInput(format!(
64 "Kernel matrix must be square, got shape ({}, {})",
65 k.nrows(),
66 k.ncols()
67 )));
68 }
69
70 let k_train_mean = k
72 .mean_axis(Axis(1))
73 .ok_or_else(|| SklearsError::InvalidInput("Failed to compute row means".to_string()))?;
74
75 let k_train_mean_all = k_train_mean.mean().ok_or_else(|| {
77 SklearsError::InvalidInput("Failed to compute overall mean".to_string())
78 })?;
79
80 Ok(KernelCenterer {
81 state: PhantomData,
82 k_train_mean_: Some(k_train_mean),
83 k_train_mean_all_: Some(k_train_mean_all),
84 })
85 }
86}
87
88impl Transform<Array2<Float>> for KernelCenterer<Trained> {
89 fn transform(&self, k: &Array2<Float>) -> Result<Array2<Float>> {
90 let k_train_mean = self
91 .k_train_mean_
92 .as_ref()
93 .ok_or_else(|| SklearsError::NotFitted {
94 operation: "transform".to_string(),
95 })?;
96 let k_train_mean_all = self
97 .k_train_mean_all_
98 .ok_or_else(|| SklearsError::NotFitted {
99 operation: "transform".to_string(),
100 })?;
101
102 let n_samples_train = k_train_mean.len();
103 let n_samples_test = k.nrows();
104
105 if k.ncols() != n_samples_train {
106 return Err(SklearsError::InvalidInput(format!(
107 "Kernel matrix has wrong number of columns. Expected {}, got {}",
108 n_samples_train,
109 k.ncols()
110 )));
111 }
112
113 let mut k_centered = k.clone();
115
116 let k_test_mean = k
118 .mean_axis(Axis(1))
119 .ok_or_else(|| SklearsError::InvalidInput("Failed to compute row means".to_string()))?;
120
121 for i in 0..n_samples_test {
123 for j in 0..n_samples_train {
124 k_centered[[i, j]] =
125 k[[i, j]] - k_test_mean[i] - k_train_mean[j] + k_train_mean_all;
126 }
127 }
128
129 Ok(k_centered)
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::arr2;

    /// Centering the (symmetric) training matrix itself must leave the whole
    /// matrix, every row and every column with zero mean.
    #[test]
    fn test_kernel_centerer_fit_transform() {
        let k_train = arr2(&[[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]]);

        let centerer = KernelCenterer::new();
        let fitted = centerer.fit(&k_train, &()).unwrap();

        let k_centered = fitted.transform(&k_train).unwrap();

        let mean = k_centered.mean().unwrap();
        assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);

        for i in 0..k_centered.nrows() {
            let row_mean = k_centered.row(i).mean().unwrap();
            assert_abs_diff_eq!(row_mean, 0.0, epsilon = 1e-10);

            let col_mean = k_centered.column(i).mean().unwrap();
            assert_abs_diff_eq!(col_mean, 0.0, epsilon = 1e-10);
        }
    }

    /// New samples are centered against the training statistics.
    #[test]
    fn test_kernel_centerer_transform_new() {
        let k_train = arr2(&[[1.0, 2.0], [2.0, 4.0]]);

        let k_test = arr2(&[[1.5, 3.0], [2.5, 5.0], [3.5, 7.0]]);

        let centerer = KernelCenterer::new();
        let fitted = centerer.fit(&k_train, &()).unwrap();
        let k_test_centered = fitted.transform(&k_test).unwrap();

        assert_eq!(k_test_centered.shape(), &[3, 2]);

        // Training means per sample are [1.5, 3.0] with grand mean 2.25; the
        // row means of `k_test` are [2.25, 3.75, 5.25].
        let expected: Array2<Float> = arr2(&[[0.0, 0.0], [-0.5, 0.5], [-1.0, 1.0]]);
        for (got, want) in k_test_centered.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_kernel_centerer_errors() {
        let k_invalid = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

        let centerer = KernelCenterer::new();
        assert!(centerer.fit(&k_invalid, &()).is_err());

        let k_empty = Array2::<Float>::zeros((0, 0));
        let centerer = KernelCenterer::new();
        assert!(centerer.fit(&k_empty, &()).is_err());
    }

    /// `transform` rejects matrices whose column count differs from the
    /// number of training samples.
    #[test]
    fn test_kernel_centerer_transform_wrong_columns() {
        let k_train = arr2(&[[1.0, 2.0], [2.0, 4.0]]);
        let fitted = KernelCenterer::new().fit(&k_train, &()).unwrap();

        let k_bad = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        assert!(fitted.transform(&k_bad).is_err());
    }

    /// A test matrix with no rows is valid and maps to an empty result.
    #[test]
    fn test_kernel_centerer_transform_no_rows() {
        let k_train = arr2(&[[1.0, 2.0], [2.0, 4.0]]);
        let fitted = KernelCenterer::new().fit(&k_train, &()).unwrap();

        let k_none = Array2::<Float>::zeros((0, 2));
        let k_centered = fitted.transform(&k_none).unwrap();
        assert_eq!(k_centered.shape(), &[0, 2]);
    }
}