tensorlogic_sklears_kernels/kernel_pca/
centering.rs1use scirs2_core::ndarray::{Array1, Array2};
19
20use crate::kernel_pca::error::{KernelPcaError, KernelPcaResult};
21
22#[derive(Clone, Debug, PartialEq)]
39pub struct KernelCenteringStats {
40 pub row_means: Array1<f64>,
42 pub grand_mean: f64,
44}
45
46impl KernelCenteringStats {
47 pub fn n(&self) -> usize {
49 self.row_means.len()
50 }
51}
52
53pub fn double_center(k: &Array2<f64>) -> KernelPcaResult<(Array2<f64>, KernelCenteringStats)> {
62 let (rows, cols) = (k.nrows(), k.ncols());
63 if rows == 0 || cols == 0 {
64 return Err(KernelPcaError::InvalidInput(
65 "double_center: Gram matrix must be non-empty".to_string(),
66 ));
67 }
68 if rows != cols {
69 return Err(KernelPcaError::InvalidInput(format!(
70 "double_center: Gram matrix must be square, got {}x{}",
71 rows, cols
72 )));
73 }
74
75 let n = rows;
76 let n_f = n as f64;
77
78 let mut row_means = Array1::<f64>::zeros(n);
80 for j in 0..n {
81 let mut s = 0.0;
82 for i in 0..n {
83 s += k[(i, j)];
84 }
85 row_means[j] = s / n_f;
86 }
87
88 let grand_mean = row_means.iter().copied().sum::<f64>() / n_f;
90
91 let mut centered = Array2::<f64>::zeros((n, n));
93 for i in 0..n {
94 for j in 0..n {
95 centered[(i, j)] = k[(i, j)] - row_means[i] - row_means[j] + grand_mean;
96 }
97 }
98
99 for i in 0..n {
102 for j in (i + 1)..n {
103 let avg = 0.5 * (centered[(i, j)] + centered[(j, i)]);
104 centered[(i, j)] = avg;
105 centered[(j, i)] = avg;
106 }
107 }
108
109 Ok((
110 centered,
111 KernelCenteringStats {
112 row_means,
113 grand_mean,
114 },
115 ))
116}
117
118pub fn center_test_kernel(
129 k_test: &[f64],
130 stats: &KernelCenteringStats,
131) -> KernelPcaResult<Array1<f64>> {
132 let n = stats.n();
133 if k_test.len() != n {
134 return Err(KernelPcaError::DimensionMismatch {
135 expected: n,
136 got: k_test.len(),
137 context: "center_test_kernel: test kernel row length".to_string(),
138 });
139 }
140
141 let n_f = n as f64;
142 let row_mean_test = k_test.iter().copied().sum::<f64>() / n_f;
143
144 let mut out = Array1::<f64>::zeros(n);
145 for j in 0..n {
146 out[j] = k_test[j] - stats.row_means[j] - row_mean_test + stats.grand_mean;
147 }
148 Ok(out)
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n x n` matrix with every entry equal to `value`.
    fn constant_matrix(n: usize, value: f64) -> Array2<f64> {
        Array2::<f64>::from_shape_fn((n, n), |_| value)
    }

    #[test]
    fn double_center_rejects_empty_matrix() {
        let k = Array2::<f64>::zeros((0, 0));
        assert!(matches!(
            double_center(&k),
            Err(KernelPcaError::InvalidInput(_))
        ));
    }

    #[test]
    fn double_center_rejects_non_square() {
        let k = Array2::<f64>::zeros((3, 4));
        assert!(matches!(
            double_center(&k),
            Err(KernelPcaError::InvalidInput(_))
        ));
    }

    #[test]
    fn double_center_of_single_element_is_zero() {
        // Smallest valid input: a 1x1 Gram matrix must center to exactly zero.
        let k = constant_matrix(1, 2.5);
        let (centered, stats) = double_center(&k).expect("double_center");
        assert_eq!(centered.dim(), (1, 1));
        assert!(centered[(0, 0)].abs() < 1e-12);
        assert_eq!(stats.n(), 1);
        assert!((stats.grand_mean - 2.5).abs() < 1e-12);
    }

    #[test]
    fn double_center_of_constant_matrix_is_zero() {
        // A constant kernel carries no variance, so everything is removed by centering.
        let k = constant_matrix(5, 3.7);
        let (centered, stats) = double_center(&k).expect("double_center");
        for v in centered.iter() {
            assert!(v.abs() < 1e-12, "expected zero, got {}", v);
        }
        assert_eq!(stats.row_means.len(), 5);
        for &rm in stats.row_means.iter() {
            assert!((rm - 3.7).abs() < 1e-12);
        }
        assert!((stats.grand_mean - 3.7).abs() < 1e-12);
    }

    #[test]
    fn double_center_row_column_sums_are_zero() {
        let k = Array2::<f64>::from_shape_fn((4, 4), |(i, j)| ((i + 1) as f64) * ((j + 1) as f64));
        let (centered, _) = double_center(&k).expect("double_center");
        for i in 0..4 {
            let row_sum: f64 = (0..4).map(|j| centered[(i, j)]).sum();
            assert!(row_sum.abs() < 1e-10, "row {} sum = {}", i, row_sum);
        }
        for j in 0..4 {
            let col_sum: f64 = (0..4).map(|i| centered[(i, j)]).sum();
            assert!(col_sum.abs() < 1e-10, "col {} sum = {}", j, col_sum);
        }
    }

    #[test]
    fn double_center_is_symmetric() {
        let k = Array2::<f64>::from_shape_fn((6, 6), |(i, j)| {
            let a = (i as f64).sin();
            let b = (j as f64).sin();
            1.0 + a + b + 0.5 * (a * b)
        });
        let (centered, _) = double_center(&k).expect("double_center");
        for i in 0..6 {
            for j in 0..6 {
                assert!(
                    (centered[(i, j)] - centered[(j, i)]).abs() < 1e-14,
                    "asymmetry at ({},{})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn double_center_symmetrizes_asymmetric_input_exactly() {
        // `double_center_is_symmetric` uses an already-symmetric input, so it cannot detect a
        // missing symmetrization step. Feed a deliberately asymmetric matrix instead.
        let k = Array2::<f64>::from_shape_fn((5, 5), |(i, j)| {
            1.0 + (i as f64) + 0.25 * (j as f64) + 0.1 * ((i * j) as f64)
        });
        let (centered, _) = double_center(&k).expect("double_center");
        for i in 0..5 {
            for j in 0..5 {
                assert_eq!(
                    centered[(i, j)],
                    centered[(j, i)],
                    "asymmetry at ({},{})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn center_test_kernel_matches_pulling_row_from_double_center() {
        // RBF-style Gram matrix: symmetric, so centering a training row as a "test" row must
        // reproduce the corresponding row of the double-centered matrix.
        let k = Array2::<f64>::from_shape_fn((4, 4), |(i, j)| {
            (-((i as f64 - j as f64).powi(2)) / 4.0).exp()
        });
        let (centered, stats) = double_center(&k).expect("double_center");
        for i in 0..4 {
            let test_row: Vec<f64> = (0..4).map(|j| k[(i, j)]).collect();
            let out = center_test_kernel(&test_row, &stats).expect("center_test_kernel");
            for j in 0..4 {
                assert!(
                    (out[j] - centered[(i, j)]).abs() < 1e-12,
                    "row {} col {} mismatch: test={}, expected={}",
                    i,
                    j,
                    out[j],
                    centered[(i, j)]
                );
            }
        }
    }

    #[test]
    fn center_test_kernel_rejects_wrong_length() {
        let stats = KernelCenteringStats {
            row_means: Array1::<f64>::zeros(3),
            grand_mean: 0.0,
        };
        let err = center_test_kernel(&[1.0, 2.0], &stats).expect_err("must reject");
        match err {
            KernelPcaError::DimensionMismatch { expected, got, .. } => {
                assert_eq!(expected, 3);
                assert_eq!(got, 2);
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }
}
268}