1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
6use sklears_core::{
7 error::{Result as SklResult, SklearsError},
8 traits::{Estimator, Fit, Transform, Untrained},
9 types::Float,
10};
11use std::collections::HashMap;
12
/// Missing-value imputer based on kernel ridge regression.
///
/// One regressor is fitted per feature (on the complete cases), predicting
/// that feature from all the others; missing cells are filled with the
/// regressor's prediction at transform time.
#[derive(Debug, Clone)]
pub struct KernelRidgeImputer<S = Untrained> {
    // Typestate: `Untrained` before `fit`, `KernelRidgeImputerTrained` after.
    state: S,
    // Ridge regularization strength added to the kernel matrix diagonal.
    alpha: f64,
    // Kernel name: "linear", "polynomial", "rbf", "sigmoid" or "laplacian"
    // (anything else falls back to linear in `kernel_function`).
    kernel: String,
    // Kernel coefficient used by the polynomial/rbf/sigmoid/laplacian kernels.
    gamma: f64,
    // Degree of the polynomial kernel.
    degree: usize,
    // Independent term of the polynomial and sigmoid kernels.
    coef0: f64,
    // Sentinel treated as "missing" (NaN by default; compared with EPSILON tolerance).
    missing_values: f64,
    // NOTE(review): not referenced by the visible fit/transform code — confirm intended use.
    max_iter: usize,
    // NOTE(review): not referenced by the visible fit/transform code — confirm intended use.
    tol: f64,
}
56
/// Fitted state of [`KernelRidgeImputer`]: the complete-case training data
/// plus one set of dual coefficients per target feature.
#[derive(Debug, Clone)]
pub struct KernelRidgeImputerTrained {
    // Training rows that had no missing values at fit time.
    X_train_: Array2<f64>,
    // Per target feature: its observed values on the complete cases.
    y_train_: HashMap<usize, Array1<f64>>,
    // Per target feature: dual (kernel ridge) coefficients solving (K + alpha*I) a = y.
    alpha_: HashMap<usize, Array1<f64>>,
    // Number of columns seen during fit; checked again in transform.
    n_features_in_: usize,
}
65
66impl KernelRidgeImputer<Untrained> {
67 pub fn new() -> Self {
69 Self {
70 state: Untrained,
71 alpha: 1.0,
72 kernel: "rbf".to_string(),
73 gamma: 1.0,
74 degree: 3,
75 coef0: 1.0,
76 missing_values: f64::NAN,
77 max_iter: 1000,
78 tol: 1e-6,
79 }
80 }
81
82 pub fn alpha(mut self, alpha: f64) -> Self {
84 self.alpha = alpha;
85 self
86 }
87
88 pub fn kernel(mut self, kernel: String) -> Self {
90 self.kernel = kernel;
91 self
92 }
93
94 pub fn gamma(mut self, gamma: f64) -> Self {
96 self.gamma = gamma;
97 self
98 }
99
100 pub fn degree(mut self, degree: usize) -> Self {
102 self.degree = degree;
103 self
104 }
105
106 pub fn coef0(mut self, coef0: f64) -> Self {
108 self.coef0 = coef0;
109 self
110 }
111
112 pub fn missing_values(mut self, missing_values: f64) -> Self {
114 self.missing_values = missing_values;
115 self
116 }
117
118 pub fn max_iter(mut self, max_iter: usize) -> Self {
120 self.max_iter = max_iter;
121 self
122 }
123
124 pub fn tol(mut self, tol: f64) -> Self {
126 self.tol = tol;
127 self
128 }
129
130 fn is_missing(&self, value: f64) -> bool {
131 if self.missing_values.is_nan() {
132 value.is_nan()
133 } else {
134 (value - self.missing_values).abs() < f64::EPSILON
135 }
136 }
137
138 fn kernel_function(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
139 match self.kernel.as_str() {
141 "linear" => x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>(),
142 "polynomial" => {
143 let dot = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>();
144 (self.gamma * dot + self.coef0).powi(self.degree as i32)
145 }
146 "rbf" => {
147 let dist_sq = x1
148 .iter()
149 .zip(x2.iter())
150 .map(|(a, b)| (a - b).powi(2))
151 .sum::<f64>();
152 (-self.gamma * dist_sq).exp()
153 }
154 "sigmoid" => {
155 let dot = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>();
156 (self.gamma * dot + self.coef0).tanh()
157 }
158 "laplacian" => {
159 let dist = x1
160 .iter()
161 .zip(x2.iter())
162 .map(|(a, b)| (a - b).abs())
163 .sum::<f64>();
164 (-self.gamma * dist).exp()
165 }
166 _ => x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>(), }
168 }
169}
170
171impl Default for KernelRidgeImputer<Untrained> {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177impl Estimator for KernelRidgeImputer<Untrained> {
178 type Config = ();
179 type Error = SklearsError;
180 type Float = Float;
181
182 fn config(&self) -> &Self::Config {
183 &()
184 }
185}
186
187impl Fit<ArrayView2<'_, Float>, ()> for KernelRidgeImputer<Untrained> {
188 type Fitted = KernelRidgeImputer<KernelRidgeImputerTrained>;
189
190 #[allow(non_snake_case)]
191 fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
192 let X = X.mapv(|x| x);
193 let (n_samples, n_features) = X.dim();
194
195 let mut complete_rows = Vec::new();
197 for i in 0..n_samples {
198 let mut is_complete = true;
199 for j in 0..n_features {
200 if self.is_missing(X[[i, j]]) {
201 is_complete = false;
202 break;
203 }
204 }
205 if is_complete {
206 complete_rows.push(i);
207 }
208 }
209
210 if complete_rows.is_empty() {
211 return Err(SklearsError::InvalidInput(
212 "No complete cases found for training".to_string(),
213 ));
214 }
215
216 let mut X_train = Array2::zeros((complete_rows.len(), n_features));
218 for (new_i, &orig_i) in complete_rows.iter().enumerate() {
219 for j in 0..n_features {
220 X_train[[new_i, j]] = X[[orig_i, j]];
221 }
222 }
223
224 let mut y_train = HashMap::new();
226 let mut alpha_coeffs = HashMap::new();
227
228 for target_feature in 0..n_features {
229 let mut X_feat = Array2::zeros((complete_rows.len(), n_features - 1));
231 let mut col_idx = 0;
232 for j in 0..n_features {
233 if j != target_feature {
234 for i in 0..complete_rows.len() {
235 X_feat[[i, col_idx]] = X_train[[i, j]];
236 }
237 col_idx += 1;
238 }
239 }
240
241 let y_target = X_train.column(target_feature).to_owned();
243
244 let K = self.compute_kernel_matrix(&X_feat)?;
246
247 let mut K_reg = K.clone();
249 for i in 0..K_reg.nrows() {
250 K_reg[[i, i]] += self.alpha;
251 }
252
253 let alpha_vec = self.solve_linear_system(&K_reg, &y_target)?;
254
255 y_train.insert(target_feature, y_target);
256 alpha_coeffs.insert(target_feature, alpha_vec);
257 }
258
259 Ok(KernelRidgeImputer {
260 state: KernelRidgeImputerTrained {
261 X_train_: X_train,
262 y_train_: y_train,
263 alpha_: alpha_coeffs,
264 n_features_in_: n_features,
265 },
266 alpha: self.alpha,
267 kernel: self.kernel,
268 gamma: self.gamma,
269 degree: self.degree,
270 coef0: self.coef0,
271 missing_values: self.missing_values,
272 max_iter: self.max_iter,
273 tol: self.tol,
274 })
275 }
276}
277
278impl Transform<ArrayView2<'_, Float>, Array2<Float>>
279 for KernelRidgeImputer<KernelRidgeImputerTrained>
280{
281 #[allow(non_snake_case)]
282 fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
283 let X = X.mapv(|x| x);
284 let (n_samples, n_features) = X.dim();
285
286 if n_features != self.state.n_features_in_ {
287 return Err(SklearsError::InvalidInput(format!(
288 "Number of features {} does not match training features {}",
289 n_features, self.state.n_features_in_
290 )));
291 }
292
293 let mut X_imputed = X.clone();
294
295 for i in 0..n_samples {
296 for j in 0..n_features {
297 if self.is_missing(X_imputed[[i, j]]) {
298 let imputed_value = self.predict_feature_value(&X_imputed, i, j)?;
299 X_imputed[[i, j]] = imputed_value;
300 }
301 }
302 }
303
304 Ok(X_imputed.mapv(|x| x as Float))
305 }
306}
307
308impl KernelRidgeImputer<Untrained> {
309 fn compute_kernel_matrix(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
310 let n_samples = X.nrows();
311 let mut K = Array2::<f64>::zeros((n_samples, n_samples));
312
313 for i in 0..n_samples {
314 for j in 0..n_samples {
315 let x1 = X.row(i);
316 let x2 = X.row(j);
317 K[[i, j]] = match self.kernel.as_str() {
318 "linear" => x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>(),
319 "polynomial" => {
320 let dot = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>();
321 (self.gamma * dot + self.coef0).powi(self.degree as i32)
322 }
323 "rbf" => {
324 let dist_sq = x1
325 .iter()
326 .zip(x2.iter())
327 .map(|(a, b)| (a - b).powi(2))
328 .sum::<f64>();
329 (-self.gamma * dist_sq).exp()
330 }
331 "sigmoid" => {
332 let dot = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>();
333 (self.gamma * dot + self.coef0).tanh()
334 }
335 "laplacian" => {
336 let dist = x1
337 .iter()
338 .zip(x2.iter())
339 .map(|(a, b)| (a - b).abs())
340 .sum::<f64>();
341 (-self.gamma * dist).exp()
342 }
343 _ => x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>(),
344 };
345 }
346 }
347 Ok(K)
348 }
349
350 #[allow(non_snake_case)]
351 fn solve_linear_system(&self, A: &Array2<f64>, b: &Array1<f64>) -> SklResult<Array1<f64>> {
352 let n = A.nrows();
353 if n != A.ncols() || n != b.len() {
354 return Err(SklearsError::InvalidInput(
355 "Matrix dimensions don't match".to_string(),
356 ));
357 }
358
359 let L = self.cholesky_decomposition(A)?;
361
362 let mut y = Array1::zeros(n);
364 for i in 0..n {
365 let mut sum = 0.0;
366 for j in 0..i {
367 sum += L[[i, j]] * y[j];
368 }
369 y[i] = (b[i] - sum) / L[[i, i]];
370 }
371
372 let mut x = Array1::zeros(n);
374 for i in (0..n).rev() {
375 let mut sum = 0.0;
376 for j in (i + 1)..n {
377 sum += L[[j, i]] * x[j];
378 }
379 x[i] = (y[i] - sum) / L[[i, i]];
380 }
381
382 Ok(x)
383 }
384
385 fn cholesky_decomposition(&self, A: &Array2<f64>) -> SklResult<Array2<f64>> {
386 let n = A.nrows();
387 let mut L = Array2::zeros((n, n));
388
389 for i in 0..n {
390 for j in 0..=i {
391 if i == j {
392 let mut sum = 0.0;
394 for k in 0..j {
395 sum += L[[j, k]] * L[[j, k]];
396 }
397 let val = A[[j, j]] - sum;
398 if val <= 0.0 {
399 return Err(SklearsError::InvalidInput(
400 "Matrix is not positive definite".to_string(),
401 ));
402 }
403 L[[j, j]] = val.sqrt();
404 } else {
405 let mut sum = 0.0;
407 for k in 0..j {
408 sum += L[[i, k]] * L[[j, k]];
409 }
410 L[[i, j]] = (A[[i, j]] - sum) / L[[j, j]];
411 }
412 }
413 }
414
415 Ok(L)
416 }
417}
418
419impl KernelRidgeImputer<KernelRidgeImputerTrained> {
420 fn is_missing(&self, value: f64) -> bool {
421 if self.missing_values.is_nan() {
422 value.is_nan()
423 } else {
424 (value - self.missing_values).abs() < f64::EPSILON
425 }
426 }
427
428 fn kernel_function(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
429 match self.kernel.as_str() {
431 "linear" => x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>(),
432 "polynomial" => {
433 let dot = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>();
434 (self.gamma * dot + self.coef0).powi(self.degree as i32)
435 }
436 "rbf" => {
437 let dist_sq = x1
438 .iter()
439 .zip(x2.iter())
440 .map(|(a, b)| (a - b).powi(2))
441 .sum::<f64>();
442 (-self.gamma * dist_sq).exp()
443 }
444 "sigmoid" => {
445 let dot = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>();
446 (self.gamma * dot + self.coef0).tanh()
447 }
448 "laplacian" => {
449 let dist = x1
450 .iter()
451 .zip(x2.iter())
452 .map(|(a, b)| (a - b).abs())
453 .sum::<f64>();
454 (-self.gamma * dist).exp()
455 }
456 _ => x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum::<f64>(), }
458 }
459
460 #[allow(non_snake_case)]
461 fn predict_feature_value(
462 &self,
463 X: &Array2<f64>,
464 sample_idx: usize,
465 target_feature: usize,
466 ) -> SklResult<f64> {
467 let mut x_feat = Array1::zeros(self.state.n_features_in_ - 1);
469 let mut col_idx = 0;
470 for j in 0..self.state.n_features_in_ {
471 if j != target_feature {
472 x_feat[col_idx] = if self.is_missing(X[[sample_idx, j]]) {
473 self.state.X_train_.column(j).mean().unwrap_or(0.0)
475 } else {
476 X[[sample_idx, j]]
477 };
478 col_idx += 1;
479 }
480 }
481
482 let X_train_feat = self.get_training_features(target_feature);
484 let mut k_vec = Array1::zeros(X_train_feat.nrows());
485 for i in 0..X_train_feat.nrows() {
486 k_vec[i] = self.kernel_function(&x_feat.view(), &X_train_feat.row(i));
487 }
488
489 let alpha =
491 self.state.alpha_.get(&target_feature).ok_or_else(|| {
492 SklearsError::InvalidInput("Missing dual coefficients".to_string())
493 })?;
494
495 let prediction = k_vec.dot(alpha);
497 Ok(prediction)
498 }
499
500 fn get_training_features(&self, target_feature: usize) -> Array2<f64> {
501 let n_train = self.state.X_train_.nrows();
502 let mut X_feat = Array2::zeros((n_train, self.state.n_features_in_ - 1));
503
504 let mut col_idx = 0;
505 for j in 0..self.state.n_features_in_ {
506 if j != target_feature {
507 for i in 0..n_train {
508 X_feat[[i, col_idx]] = self.state.X_train_[[i, j]];
509 }
510 col_idx += 1;
511 }
512 }
513
514 X_feat
515 }
516}