use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Transform, Untrained},
    types::Float,
};
use std::collections::HashMap;

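/// Missing-value imputer based on regression in a reproducing kernel Hilbert
/// space (RKHS). A weighted combination of base kernels defines the hypothesis
/// space; `fit` learns one regressor per feature from the complete rows, and
/// `transform` fills missing entries with the resulting predictions.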
#[derive(Debug, Clone)]
pub struct ReproducingKernelImputer<S = Untrained> {
    state: S,
    kernels: Vec<String>,
    kernel_weights: Vec<f64>,
    kernel_params: HashMap<String, HashMap<String, f64>>,
    regularization: String,
    lambda_reg: f64,
    alpha_elastic: f64,
    adaptive_weights: bool,
    interpolation_method: String,
    smoothing_parameter: f64,
    missing_values: f64,
    max_iter: usize,
    tol: f64,
    normalize_kernels: bool,
    use_bias: bool,
}

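/// State produced by `fit`: the complete training rows plus, per target
/// feature, the regression targets, learned RKHS coefficients, kernel
/// weights, and bias term.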
#[derive(Debug, Clone)]
#[allow(non_snake_case)]
pub struct ReproducingKernelImputerTrained {
    X_train_: Array2<f64>,
    y_train_: HashMap<usize, Array1<f64>>,
    learned_weights_: HashMap<usize, Array1<f64>>,
    kernel_weights_: HashMap<usize, Vec<f64>>,
    bias_: HashMap<usize, f64>,
    n_features_in_: usize,
    kernel_matrices_: HashMap<usize, Vec<Array2<f64>>>,
    regularization_path_: HashMap<usize, Vec<f64>>,
}

impl ReproducingKernelImputer<Untrained> {
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernels: vec!["rbf".to_string(), "linear".to_string()],
            kernel_weights: vec![0.5, 0.5],
            kernel_params: HashMap::new(),
            regularization: "ridge".to_string(),
            lambda_reg: 0.01,
            alpha_elastic: 0.5,
            adaptive_weights: true,
            interpolation_method: "nyström".to_string(),
            smoothing_parameter: 1.0,
            missing_values: f64::NAN,
            max_iter: 1000,
            tol: 1e-6,
            normalize_kernels: true,
            use_bias: true,
        }
    }

    pub fn kernels(mut self, kernels: Vec<String>) -> Self {
        self.kernel_weights = vec![1.0 / kernels.len() as f64; kernels.len()];
        self.kernels = kernels;
        self
    }

    pub fn kernel_weights(mut self, weights: Vec<f64>) -> Self {
        self.kernel_weights = weights;
        self
    }

    pub fn kernel_params(mut self, params: HashMap<String, HashMap<String, f64>>) -> Self {
        self.kernel_params = params;
        self
    }

    pub fn regularization(mut self, regularization: String) -> Self {
        self.regularization = regularization;
        self
    }

    pub fn lambda_reg(mut self, lambda_reg: f64) -> Self {
        self.lambda_reg = lambda_reg;
        self
    }

    pub fn alpha_elastic(mut self, alpha_elastic: f64) -> Self {
        self.alpha_elastic = alpha_elastic;
        self
    }

    pub fn adaptive_weights(mut self, adaptive_weights: bool) -> Self {
        self.adaptive_weights = adaptive_weights;
        self
    }

    pub fn interpolation_method(mut self, method: String) -> Self {
        self.interpolation_method = method;
        self
    }

    pub fn smoothing_parameter(mut self, smoothing_parameter: f64) -> Self {
        self.smoothing_parameter = smoothing_parameter;
        self
    }

    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    pub fn normalize_kernels(mut self, normalize_kernels: bool) -> Self {
        self.normalize_kernels = normalize_kernels;
        self
    }

    pub fn use_bias(mut self, use_bias: bool) -> Self {
        self.use_bias = use_bias;
        self
    }

    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

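    /// Evaluates a single base kernel k(x1, x2), looking parameters up by
    /// name with defaults (e.g. `gamma = 1.0` for the RBF kernel).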
    fn kernel_function(
        &self,
        x1: &ArrayView1<f64>,
        x2: &ArrayView1<f64>,
        kernel_type: &str,
        params: &HashMap<String, f64>,
    ) -> f64 {
        match kernel_type {
            "rbf" => {
                let gamma = params.get("gamma").unwrap_or(&1.0);
                let diff = (x1 - x2).mapv(|x| x * x).sum();
                (-gamma * diff).exp()
            }
            "linear" => {
                let offset = params.get("offset").unwrap_or(&0.0);
                x1.dot(x2) + offset
            }
            "polynomial" => {
                let degree = *params.get("degree").unwrap_or(&3.0) as i32;
                let gamma = params.get("gamma").unwrap_or(&1.0);
                let coef0 = params.get("coef0").unwrap_or(&1.0);
                (gamma * x1.dot(x2) + coef0).powi(degree)
            }
            "sobolev" => {
                let order = params.get("order").unwrap_or(&1.0);
                let diff = (x1 - x2).mapv(|x| x.abs()).sum();
                if diff < f64::EPSILON {
                    1.0
                } else {
                    (1.0 + diff).powf(-order)
                }
            }
            "periodic" => {
                let period = params.get("period").unwrap_or(&1.0);
                let length_scale = params.get("length_scale").unwrap_or(&1.0);
                let diff = (x1 - x2)
                    .mapv(|x| (std::f64::consts::PI * x / period).sin().powi(2))
                    .sum();
                (-2.0 * diff / (length_scale * length_scale)).exp()
            }
            "matern32" => {
                let length_scale = params.get("length_scale").unwrap_or(&1.0);
                let r = ((x1 - x2).mapv(|x| x * x).sum().sqrt()) / length_scale;
                (1.0 + (3.0_f64).sqrt() * r) * (-(3.0_f64).sqrt() * r).exp()
            }
            "matern52" => {
                let length_scale = params.get("length_scale").unwrap_or(&1.0);
                let r = ((x1 - x2).mapv(|x| x * x).sum().sqrt()) / length_scale;
                (1.0 + (5.0_f64).sqrt() * r + 5.0 * r * r / 3.0) * (-(5.0_f64).sqrt() * r).exp()
            }
            "rational_quadratic" => {
                let alpha = params.get("alpha").unwrap_or(&1.0);
                let length_scale = params.get("length_scale").unwrap_or(&1.0);
                let diff = (x1 - x2).mapv(|x| x * x).sum();
                (1.0 + diff / (2.0 * alpha * length_scale * length_scale)).powf(-alpha)
            }
            "laplacian" => {
                let gamma = params.get("gamma").unwrap_or(&1.0);
                let diff = (x1 - x2).mapv(|x| x.abs()).sum();
                (-gamma * diff).exp()
            }
            // Unknown kernel names fall back to an RBF with gamma = 1.
            _ => {
                let diff = (x1 - x2).mapv(|x| x * x).sum();
                (-diff).exp()
            }
        }
    }

    #[allow(non_snake_case)]
    fn compute_combined_kernel_matrix(
        &self,
        X: &Array2<f64>,
        weights: &[f64],
    ) -> SklResult<Array2<f64>> {
        let n_samples = X.nrows();
        let mut K_combined = Array2::zeros((n_samples, n_samples));

        for (i, kernel_type) in self.kernels.iter().enumerate() {
            let params = self
                .kernel_params
                .get(kernel_type)
                .cloned()
                .unwrap_or_default();
            let mut K = Array2::zeros((n_samples, n_samples));

            for row in 0..n_samples {
                for col in 0..n_samples {
                    K[[row, col]] =
                        self.kernel_function(&X.row(row), &X.row(col), kernel_type, &params);
                }
            }

            // Scale each kernel matrix so that its mean diagonal entry is 1.
            if self.normalize_kernels {
                let trace = K.diag().sum();
                if trace > f64::EPSILON {
                    K /= trace / n_samples as f64;
                }
            }

            // Fall back to equal weights if fewer weights than kernels were given.
            if i < weights.len() {
                K_combined = K_combined + weights[i] * K;
            } else {
                K_combined = K_combined + (1.0 / self.kernels.len() as f64) * K;
            }
        }

        Ok(K_combined)
    }

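    /// Solves the regularized RKHS regression problem for coefficients alpha.
    /// For ridge this approximates the classic kernel ridge solution
    /// (K + lambda * I) alpha = y; lasso and elastic net reuse the same
    /// diagonally regularized matrix with sparsity-inducing solvers.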
    #[allow(non_snake_case)]
    fn solve_rkhs_problem(&self, K: &Array2<f64>, y: &Array1<f64>) -> SklResult<Array1<f64>> {
        let n = K.nrows();
        let mut K_reg = K.clone();

        // Tikhonov regularization: add lambda to the diagonal.
        for i in 0..n {
            K_reg[[i, i]] += self.lambda_reg;
        }

        match self.regularization.as_str() {
            "ridge" => self.solve_linear_system(&K_reg, y),
            "lasso" => self.solve_lasso(&K_reg, y),
            "elastic_net" => self.solve_elastic_net(&K_reg, y),
            // Unknown regularization names default to the ridge solver.
            _ => self.solve_linear_system(&K_reg, y),
        }
    }

    #[allow(non_snake_case)]
    fn solve_linear_system(&self, A: &Array2<f64>, b: &Array1<f64>) -> SklResult<Array1<f64>> {
        // Approximate solve of A x = b by 100 steps of gradient descent on
        // ||A x - b||^2 with a fixed step size of 0.01.
        let n = A.nrows();
        let mut x = Array1::zeros(n);

        for _iter in 0..100 {
            let residual = A.dot(&x) - b;
            let gradient = A.t().dot(&residual);
            x = &x - 0.01 * &gradient;
        }

        Ok(x)
    }

    #[allow(non_snake_case)]
    fn solve_lasso(&self, K: &Array2<f64>, y: &Array1<f64>) -> SklResult<Array1<f64>> {
        // ISTA-style proximal gradient: a gradient step on the squared error
        // followed by soft-thresholding to induce sparsity in alpha.
        let n = K.nrows();
        let mut alpha = Array1::zeros(n);
        let step_size = 0.001;

        for _iter in 0..self.max_iter {
            let residual = K.dot(&alpha) - y;
            let gradient = K.t().dot(&residual);
            alpha = &alpha - step_size * &gradient;

            // Soft-thresholding (proximal operator of the L1 penalty).
            let threshold = self.lambda_reg * step_size;
            alpha = alpha.mapv(|x| {
                if x > threshold {
                    x - threshold
                } else if x < -threshold {
                    x + threshold
                } else {
                    0.0
                }
            });
        }

        Ok(alpha)
    }

    #[allow(non_snake_case)]
    fn solve_elastic_net(&self, K: &Array2<f64>, y: &Array1<f64>) -> SklResult<Array1<f64>> {
        // Approximates elastic net as a convex blend of the separate ridge and
        // lasso solutions, weighted by `alpha_elastic`, rather than solving the
        // jointly penalized problem.
        let ridge_result = self.solve_linear_system(K, y)?;
        let lasso_result = self.solve_lasso(K, y)?;

        let alpha_val = self.alpha_elastic;
        Ok(alpha_val * lasso_result + (1.0 - alpha_val) * ridge_result)
    }
}

impl Default for ReproducingKernelImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for ReproducingKernelImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

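// Fitting keeps only the complete rows, then trains one RKHS regressor per
// feature: each target column is regressed on the remaining columns through
// the combined kernel matrix, with an optional mean bias term.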
impl Fit<ArrayView2<'_, Float>, ()> for ReproducingKernelImputer<Untrained> {
    type Fitted = ReproducingKernelImputer<ReproducingKernelImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        // Collect the indices of rows with no missing entries.
        let mut complete_rows = Vec::new();
        for i in 0..n_samples {
            let mut is_complete = true;
            for j in 0..n_features {
                if self.is_missing(X[[i, j]]) {
                    is_complete = false;
                    break;
                }
            }
            if is_complete {
                complete_rows.push(i);
            }
        }

        if complete_rows.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No complete cases found for training".to_string(),
            ));
        }

        // Training matrix restricted to the complete rows.
        let mut X_train = Array2::zeros((complete_rows.len(), n_features));
        for (new_i, &orig_i) in complete_rows.iter().enumerate() {
            for j in 0..n_features {
                X_train[[new_i, j]] = X[[orig_i, j]];
            }
        }

        let mut y_train = HashMap::new();
        let mut learned_weights = HashMap::new();
        let mut kernel_weights = HashMap::new();
        let mut bias_terms = HashMap::new();
        let mut kernel_matrices = HashMap::new();
        let mut regularization_path = HashMap::new();

        for target_feature in 0..n_features {
            // Design matrix: all columns except the current target feature.
            let mut X_feat = Array2::zeros((complete_rows.len(), n_features - 1));
            let mut col_idx = 0;
            for j in 0..n_features {
                if j != target_feature {
                    for i in 0..complete_rows.len() {
                        X_feat[[i, col_idx]] = X_train[[i, j]];
                    }
                    col_idx += 1;
                }
            }

            let y_target = X_train.column(target_feature).to_owned();

            // The configured kernel weights are used as-is (no adaptive
            // re-weighting is performed here).
            let optimal_weights = self.kernel_weights.clone();

            let K_combined = self.compute_combined_kernel_matrix(&X_feat, &optimal_weights)?;

            // With a bias term the coefficients are fit on mean-centered
            // targets, so the bias is not double counted when predictions
            // add it back.
            let bias = if self.use_bias {
                y_target.mean().unwrap_or(0.0)
            } else {
                0.0
            };
            let y_centered = y_target.mapv(|v| v - bias);

            let alpha_coeffs = self.solve_rkhs_problem(&K_combined, &y_centered)?;

            y_train.insert(target_feature, y_target);
            learned_weights.insert(target_feature, alpha_coeffs);
            kernel_weights.insert(target_feature, optimal_weights);
            bias_terms.insert(target_feature, bias);
            kernel_matrices.insert(target_feature, Vec::new());
            regularization_path.insert(target_feature, Vec::new());
        }

        Ok(ReproducingKernelImputer {
            state: ReproducingKernelImputerTrained {
                X_train_: X_train,
                y_train_: y_train,
                learned_weights_: learned_weights,
                kernel_weights_: kernel_weights,
                bias_: bias_terms,
                n_features_in_: n_features,
                kernel_matrices_: kernel_matrices,
                regularization_path_: regularization_path,
            },
            kernels: self.kernels,
            kernel_weights: self.kernel_weights,
            kernel_params: self.kernel_params,
            regularization: self.regularization,
            lambda_reg: self.lambda_reg,
            alpha_elastic: self.alpha_elastic,
            adaptive_weights: self.adaptive_weights,
            interpolation_method: self.interpolation_method,
            smoothing_parameter: self.smoothing_parameter,
            missing_values: self.missing_values,
            max_iter: self.max_iter,
            tol: self.tol,
            normalize_kernels: self.normalize_kernels,
            use_bias: self.use_bias,
        })
    }
}

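// Transformation replaces each missing entry with the per-feature RKHS
// prediction; other missing covariates in the same row are temporarily filled
// with the corresponding training-column mean (see `predict_rkhs_value`).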
impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for ReproducingKernelImputer<ReproducingKernelImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();

        for i in 0..n_samples {
            for j in 0..n_features {
                if self.is_missing(X_imputed[[i, j]]) {
                    let imputed_value = self.predict_rkhs_value(&X_imputed, i, j)?;
                    X_imputed[[i, j]] = imputed_value;
                }
            }
        }

        Ok(X_imputed)
    }
}

impl ReproducingKernelImputer<ReproducingKernelImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    #[allow(non_snake_case)]
    fn predict_rkhs_value(
        &self,
        X: &Array2<f64>,
        sample_idx: usize,
        target_feature: usize,
    ) -> SklResult<f64> {
        // Feature vector for the sample, excluding the target column; any
        // still-missing covariates are filled with the training-column mean.
        let mut x_feat = Array1::zeros(self.state.n_features_in_ - 1);
        let mut col_idx = 0;
        for j in 0..self.state.n_features_in_ {
            if j != target_feature {
                x_feat[col_idx] = if self.is_missing(X[[sample_idx, j]]) {
                    self.state.X_train_.column(j).mean().unwrap_or(0.0)
                } else {
                    X[[sample_idx, j]]
                };
                col_idx += 1;
            }
        }

        let alpha_coeffs = self
            .state
            .learned_weights_
            .get(&target_feature)
            .ok_or_else(|| {
                SklearsError::InvalidInput("Missing learned coefficients".to_string())
            })?;
        let optimal_weights = self
            .state
            .kernel_weights_
            .get(&target_feature)
            .ok_or_else(|| SklearsError::InvalidInput("Missing kernel weights".to_string()))?;
        let bias = self.state.bias_.get(&target_feature).unwrap_or(&0.0);

        let X_train_feat = self.get_training_features(target_feature);

        // Kernel regression prediction: bias + sum_i alpha_i * k(x, x_i).
        let mut prediction = *bias;
        for i in 0..X_train_feat.nrows() {
            let kernel_val = self.compute_combined_kernel_value(
                &x_feat.view(),
                &X_train_feat.row(i),
                optimal_weights,
            );
            prediction += alpha_coeffs[i] * kernel_val;
        }

        Ok(prediction)
    }

    fn compute_combined_kernel_value(
        &self,
        x1: &ArrayView1<f64>,
        x2: &ArrayView1<f64>,
        weights: &[f64],
    ) -> f64 {
        let mut kernel_val = 0.0;
        for (i, kernel_type) in self.kernels.iter().enumerate() {
            let params = self
                .kernel_params
                .get(kernel_type)
                .cloned()
                .unwrap_or_default();
            let k_val = self.kernel_function(x1, x2, kernel_type, &params);
            if i < weights.len() {
                kernel_val += weights[i] * k_val;
            } else {
                kernel_val += (1.0 / self.kernels.len() as f64) * k_val;
            }
        }
        kernel_val
    }

    // Kept in sync with the kernel set supported at fit time; without the
    // extra arms, kernels used during training would silently fall back to
    // RBF at prediction time.
    fn kernel_function(
        &self,
        x1: &ArrayView1<f64>,
        x2: &ArrayView1<f64>,
        kernel_type: &str,
        params: &HashMap<String, f64>,
    ) -> f64 {
        match kernel_type {
            "rbf" => {
                let gamma = params.get("gamma").unwrap_or(&1.0);
                (-gamma * (x1 - x2).mapv(|x| x * x).sum()).exp()
            }
            "linear" => {
                let offset = params.get("offset").unwrap_or(&0.0);
                x1.dot(x2) + offset
            }
            "polynomial" => {
                let degree = *params.get("degree").unwrap_or(&3.0) as i32;
                let gamma = params.get("gamma").unwrap_or(&1.0);
                let coef0 = params.get("coef0").unwrap_or(&1.0);
                (gamma * x1.dot(x2) + coef0).powi(degree)
            }
            "sobolev" => {
                let order = params.get("order").unwrap_or(&1.0);
                let diff = (x1 - x2).mapv(|x| x.abs()).sum();
                if diff < f64::EPSILON { 1.0 } else { (1.0 + diff).powf(-order) }
            }
            "periodic" => {
                let period = params.get("period").unwrap_or(&1.0);
                let ls = params.get("length_scale").unwrap_or(&1.0);
                let s = (x1 - x2)
                    .mapv(|x| (std::f64::consts::PI * x / period).sin().powi(2))
                    .sum();
                (-2.0 * s / (ls * ls)).exp()
            }
            "matern32" => {
                let ls = params.get("length_scale").unwrap_or(&1.0);
                let r = (x1 - x2).mapv(|x| x * x).sum().sqrt() / ls;
                (1.0 + 3.0_f64.sqrt() * r) * (-(3.0_f64.sqrt()) * r).exp()
            }
            "matern52" => {
                let ls = params.get("length_scale").unwrap_or(&1.0);
                let r = (x1 - x2).mapv(|x| x * x).sum().sqrt() / ls;
                (1.0 + 5.0_f64.sqrt() * r + 5.0 * r * r / 3.0) * (-(5.0_f64.sqrt()) * r).exp()
            }
            "rational_quadratic" => {
                let alpha = params.get("alpha").unwrap_or(&1.0);
                let ls = params.get("length_scale").unwrap_or(&1.0);
                let d = (x1 - x2).mapv(|x| x * x).sum();
                (1.0 + d / (2.0 * alpha * ls * ls)).powf(-alpha)
            }
            "laplacian" => {
                let gamma = params.get("gamma").unwrap_or(&1.0);
                (-gamma * (x1 - x2).mapv(|x| x.abs()).sum()).exp()
            }
            // Unknown kernel names fall back to an RBF with gamma = 1.
            _ => (-(x1 - x2).mapv(|x| x * x).sum()).exp(),
        }
    }

    #[allow(non_snake_case)]
    fn get_training_features(&self, target_feature: usize) -> Array2<f64> {
        let n_train = self.state.X_train_.nrows();
        let mut X_feat = Array2::zeros((n_train, self.state.n_features_in_ - 1));

        let mut col_idx = 0;
        for j in 0..self.state.n_features_in_ {
            if j != target_feature {
                for i in 0..n_train {
                    X_feat[[i, col_idx]] = self.state.X_train_[[i, j]];
                }
                col_idx += 1;
            }
        }

        X_feat
    }

    pub fn learned_kernel_weights(&self) -> &HashMap<usize, Vec<f64>> {
        &self.state.kernel_weights_
    }

    pub fn regularization_path(&self) -> &HashMap<usize, Vec<f64>> {
        &self.state.regularization_path_
    }

    pub fn bias_terms(&self) -> &HashMap<usize, f64> {
        &self.state.bias_
    }
}
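
// A minimal usage sketch (not part of the original module): it fits the
// imputer on a small matrix with one missing entry and checks that the entry
// is replaced by a finite value. It assumes `Float = f64` and exercises the
// `Fit`/`Transform` trait implementations above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn imputes_missing_entry() {
        // 4 x 3 matrix with a single NaN at row 3, column 1.
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, //
                2.0, 4.0, 6.0, //
                3.0, 6.0, 9.0, //
                4.0, f64::NAN, 12.0,
            ],
        )
        .unwrap();

        let imputer = ReproducingKernelImputer::new()
            .kernels(vec!["rbf".to_string(), "linear".to_string()])
            .lambda_reg(0.1);

        let fitted = imputer.fit(&x.view(), &()).unwrap();
        let imputed = fitted.transform(&x.view()).unwrap();

        // The missing entry is filled with a finite prediction; observed
        // entries pass through unchanged.
        assert!(imputed[[3, 1]].is_finite());
        assert_eq!(imputed[[0, 0]], 1.0);
    }
}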