sklears_preprocessing/
kernel_centerer.rs1use scirs2_core::ndarray::{Array1, Array2, Axis};
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
14#[derive(Debug, Clone)]
26pub struct KernelCenterer<State = Untrained> {
27 state: PhantomData<State>,
28 k_train_mean_: Option<Array1<Float>>,
30 k_train_mean_all_: Option<Float>,
31}
32
33impl KernelCenterer<Untrained> {
34 pub fn new() -> Self {
36 Self {
37 state: PhantomData,
38 k_train_mean_: None,
39 k_train_mean_all_: None,
40 }
41 }
42}
43
44impl Default for KernelCenterer<Untrained> {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl Fit<Array2<Float>, (), Untrained> for KernelCenterer<Untrained> {
51 type Fitted = KernelCenterer<Trained>;
52
53 fn fit(self, k: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
54 let n_samples = k.nrows();
55
56 if n_samples == 0 {
57 return Err(SklearsError::InvalidInput(
58 "Cannot fit KernelCenterer on empty kernel matrix".to_string(),
59 ));
60 }
61
62 if k.nrows() != k.ncols() {
63 return Err(SklearsError::InvalidInput(format!(
64 "Kernel matrix must be square, got shape ({}, {})",
65 k.nrows(),
66 k.ncols()
67 )));
68 }
69
70 let k_train_mean = k
72 .mean_axis(Axis(1))
73 .ok_or_else(|| SklearsError::InvalidInput("Failed to compute row means".to_string()))?;
74
75 let k_train_mean_all = k_train_mean.mean().ok_or_else(|| {
77 SklearsError::InvalidInput("Failed to compute overall mean".to_string())
78 })?;
79
80 Ok(KernelCenterer {
81 state: PhantomData,
82 k_train_mean_: Some(k_train_mean),
83 k_train_mean_all_: Some(k_train_mean_all),
84 })
85 }
86}
87
88impl Transform<Array2<Float>> for KernelCenterer<Trained> {
89 fn transform(&self, k: &Array2<Float>) -> Result<Array2<Float>> {
90 let k_train_mean = self
91 .k_train_mean_
92 .as_ref()
93 .ok_or_else(|| SklearsError::NotFitted {
94 operation: "transform".to_string(),
95 })?;
96 let k_train_mean_all = self
97 .k_train_mean_all_
98 .ok_or_else(|| SklearsError::NotFitted {
99 operation: "transform".to_string(),
100 })?;
101
102 let n_samples_train = k_train_mean.len();
103 let n_samples_test = k.nrows();
104
105 if k.ncols() != n_samples_train {
106 return Err(SklearsError::InvalidInput(format!(
107 "Kernel matrix has wrong number of columns. Expected {}, got {}",
108 n_samples_train,
109 k.ncols()
110 )));
111 }
112
113 let mut k_centered = k.clone();
115
116 let k_test_mean = k
118 .mean_axis(Axis(1))
119 .ok_or_else(|| SklearsError::InvalidInput("Failed to compute row means".to_string()))?;
120
121 for i in 0..n_samples_test {
123 for j in 0..n_samples_train {
124 k_centered[[i, j]] =
125 k[[i, j]] - k_test_mean[i] - k_train_mean[j] + k_train_mean_all;
126 }
127 }
128
129 Ok(k_centered)
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::arr2;

    #[test]
    fn test_kernel_centerer_fit_transform() {
        // Symmetric rank-one training kernel.
        let k_train = arr2(&[[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]]);

        let centerer = KernelCenterer::new();
        let fitted = centerer
            .fit(&k_train, &())
            .expect("model fitting should succeed");

        let k_centered = fitted
            .transform(&k_train)
            .expect("transformation should succeed");

        let mean = k_centered
            .mean()
            .expect("array should have elements for mean computation");
        assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);

        // Centering the training kernel must zero every row and column mean.
        for i in 0..k_centered.nrows() {
            let row_mean = k_centered
                .row(i)
                .mean()
                .expect("array should have elements for mean computation");
            assert_abs_diff_eq!(row_mean, 0.0, epsilon = 1e-10);

            let col_mean = k_centered
                .column(i)
                .mean()
                .expect("array should have elements for mean computation");
            assert_abs_diff_eq!(col_mean, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_kernel_centerer_transform_new() {
        let k_train = arr2(&[[1.0, 2.0], [2.0, 4.0]]);
        // Three test samples evaluated against the two training samples.
        let k_test = arr2(&[[1.5, 3.0], [2.5, 5.0], [3.5, 7.0]]);

        let centerer = KernelCenterer::new();
        let fitted = centerer
            .fit(&k_train, &())
            .expect("model fitting should succeed");
        let k_test_centered = fitted
            .transform(&k_test)
            .expect("transformation should succeed");

        assert_eq!(k_test_centered.shape(), &[3, 2]);

        // Training means per sample: [1.5, 3.0]; overall mean: 2.25.
        // Test row means: [2.25, 3.75, 5.25].
        let expected = arr2(&[[0.0, 0.0], [-0.5, 0.5], [-1.0, 1.0]]);
        for (got, want) in k_test_centered.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_kernel_centerer_errors() {
        // Non-square kernel is rejected by fit.
        let k_invalid = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let centerer = KernelCenterer::new();
        assert!(centerer.fit(&k_invalid, &()).is_err());

        // Empty kernel is rejected by fit.
        let k_empty = Array2::<Float>::zeros((0, 0));
        let centerer = KernelCenterer::new();
        assert!(centerer.fit(&k_empty, &()).is_err());

        // Transform rejects a kernel whose column count differs from n_train.
        let fitted = KernelCenterer::new()
            .fit(&arr2(&[[1.0, 2.0], [2.0, 4.0]]), &())
            .expect("model fitting should succeed");
        let k_wrong_cols = arr2(&[[1.0, 2.0, 3.0]]);
        assert!(fitted.transform(&k_wrong_cols).is_err());
    }
}