sklears_kernel_approximation/kernel_ridge_regression/multitask_regression.rs

//! Multi-Task Kernel Ridge Regression Implementation
//!
//! This module provides multi-task learning capabilities for kernel ridge regression,
//! allowing simultaneous learning across multiple related tasks with cross-task
//! regularization strategies to improve generalization.

use crate::{
    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
};
use scirs2_linalg::compat::ArrayLinalgExt;
// ArrayLinalgExt provides both the `solve` and `svd` methods used below.

use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::{Estimator, Fit, Float, Predict};
use std::marker::PhantomData;

use super::core_types::*;

/// Multi-Task Kernel Ridge Regression
///
/// Performs kernel ridge regression simultaneously across multiple related tasks,
/// with optional cross-task regularization to encourage similarity between tasks.
///
/// This is particularly useful when you have multiple regression targets that are
/// related and can benefit from shared representations and joint learning.
///
/// # Parameters
///
/// * `approximation_method` - Method for kernel approximation
/// * `alpha` - Within-task regularization parameter
/// * `task_regularization` - Cross-task regularization strategy
/// * `solver` - Method for solving the linear system
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
/// A minimal fitting sketch (mirrors the unit tests below; the import list
/// assumes these items are re-exported from `kernel_ridge_regression`):
///
/// ```rust,ignore
/// use sklears_kernel_approximation::kernel_ridge_regression::{
///     ApproximationMethod, MultiTaskKernelRidgeRegression, TaskRegularization,
/// };
/// use scirs2_core::ndarray::array;
/// use sklears_core::prelude::{Fit, Predict};
///
/// let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
/// let y = array![[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]; // two related tasks
///
/// let model = MultiTaskKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
///     n_components: 20,
///     gamma: 0.1,
/// })
/// .alpha(0.1)
/// .task_regularization(TaskRegularization::L2 { beta: 0.1 });
///
/// let fitted = model.fit(&x, &y)?;
/// let predictions = fitted.predict(&x)?; // shape: (n_samples, n_tasks)
/// ```
#[derive(Debug, Clone)]
pub struct MultiTaskKernelRidgeRegression<State = Untrained> {
    pub approximation_method: ApproximationMethod,
    pub alpha: Float,
    pub task_regularization: TaskRegularization,
    pub solver: Solver,
    pub random_state: Option<u64>,

    // Fitted parameters
    weights_: Option<Array2<Float>>, // Shape: (n_features, n_tasks)
    feature_transformer_: Option<FeatureTransformer>,
    n_tasks_: Option<usize>,

    _state: PhantomData<State>,
}

/// Cross-task regularization strategies for multi-task learning
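///
/// A minimal construction sketch (the `Custom` regularizer must be a
/// non-capturing `fn(&Array2<Float>) -> Float`; the import path assumes the
/// same re-exports as the example above):
///
/// ```rust,ignore
/// use sklears_kernel_approximation::kernel_ridge_regression::TaskRegularization;
///
/// // Pull task weight vectors toward each other.
/// let l2 = TaskRegularization::L2 { beta: 0.1 };
///
/// // Encourage whole feature rows to be zero across all tasks (L2,1 norm).
/// let group = TaskRegularization::GroupSparsity { beta: 0.05 };
///
/// // User-supplied penalty on the (n_features, n_tasks) weight matrix.
/// let custom = TaskRegularization::Custom {
///     beta: 0.1,
///     regularizer: |w| w.iter().map(|v| v.abs()).sum(),
/// };
/// ```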
#[derive(Debug, Clone)]
pub enum TaskRegularization {
    /// No cross-task regularization (independent tasks)
    None,
    /// L2 regularization on task weight differences
    L2 { beta: Float },
    /// L1 regularization promoting sparsity across tasks
    L1 { beta: Float },
    /// Nuclear norm regularization on weight matrix (low-rank)
    NuclearNorm { beta: Float },
    /// Group sparsity regularization
    GroupSparsity { beta: Float },
    /// Custom regularization function
    Custom {
        beta: Float,
        regularizer: fn(&Array2<Float>) -> Float,
    },
}

impl Default for TaskRegularization {
    fn default() -> Self {
        Self::None
    }
}

impl MultiTaskKernelRidgeRegression<Untrained> {
    /// Create a new multi-task kernel ridge regression model
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            approximation_method,
            alpha: 1.0,
            task_regularization: TaskRegularization::None,
            solver: Solver::Direct,
            random_state: None,
            weights_: None,
            feature_transformer_: None,
            n_tasks_: None,
            _state: PhantomData,
        }
    }

    /// Set regularization parameter
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set cross-task regularization strategy
    pub fn task_regularization(mut self, regularization: TaskRegularization) -> Self {
        self.task_regularization = regularization;
        self
    }

    /// Set solver method
    pub fn solver(mut self, solver: Solver) -> Self {
        self.solver = solver;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Compute regularization penalty for the weight matrix
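    ///
    /// For a weight matrix `W` of shape `(n_features, n_tasks)` with task
    /// columns `w_t`, the penalties are:
    ///
    /// * `L2`: `beta * mean over pairs i < j of ||w_i - w_j||^2`
    /// * `L1`: `beta * sum of |W[i][j]|`
    /// * `NuclearNorm`: `beta * sum of singular values`, approximated below by
    ///   the Frobenius norm `beta * ||W||_F` (a lower bound) to avoid an SVD
    /// * `GroupSparsity`: `beta * sum over rows r of ||W[r, :]||_2` (L2,1 norm)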
    fn compute_task_regularization_penalty(&self, weights: &Array2<Float>) -> Float {
        match &self.task_regularization {
            TaskRegularization::None => 0.0,
            TaskRegularization::L2 { beta } => {
                // L2 penalty on differences between task weights
                let mut penalty = 0.0;
                let n_tasks = weights.ncols();
                for i in 0..n_tasks {
                    for j in (i + 1)..n_tasks {
                        let diff = &weights.column(i) - &weights.column(j);
                        penalty += diff.mapv(|x| x * x).sum();
                    }
                }
                // A single task has no pairwise differences; guard the
                // division so the penalty is 0.0 rather than NaN.
                let n_pairs = n_tasks * (n_tasks - 1) / 2;
                if n_pairs == 0 {
                    0.0
                } else {
                    beta * penalty / n_pairs as Float
                }
            }
            TaskRegularization::L1 { beta } => {
                // L1 penalty on task weights
                beta * weights.mapv(|x| x.abs()).sum()
            }
            TaskRegularization::NuclearNorm { beta } => {
                // Nuclear norm (sum of singular values), approximated by the
                // Frobenius norm (a lower bound) to avoid computing an SVD
                beta * weights.mapv(|x| x * x).sum().sqrt()
            }
            TaskRegularization::GroupSparsity { beta } => {
                // Group sparsity: L2,1 norm (sum of row-wise L2 norms)
                let mut penalty = 0.0;
                for row in weights.axis_iter(Axis(0)) {
                    penalty += row.mapv(|x| x * x).sum().sqrt();
                }
                beta * penalty
            }
            TaskRegularization::Custom { beta, regularizer } => beta * regularizer(weights),
        }
    }
}

impl Estimator for MultiTaskKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Untrained> {
    type Fitted = MultiTaskKernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Self::Fitted> {
        if x.nrows() != y.nrows() {
            return Err(SklearsError::InvalidInput(format!(
                "x and y must have the same number of samples (got {} and {})",
                x.nrows(),
                y.nrows()
            )));
        }

        let _n_samples = x.nrows();
        let n_tasks = y.ncols();

        // Fit the feature transformer
        let feature_transformer = self.fit_feature_transformer(x)?;
        let x_transformed = feature_transformer.transform(x)?;
        let _n_features = x_transformed.ncols();

        // Solve multi-task regression problem
        let weights = match self.solver {
            Solver::Direct => self.solve_direct_multitask(&x_transformed, y)?,
            Solver::SVD => self.solve_svd_multitask(&x_transformed, y)?,
            Solver::ConjugateGradient { max_iter, tol } => {
                self.solve_cg_multitask(&x_transformed, y, max_iter, tol)?
            }
        };

        Ok(MultiTaskKernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            task_regularization: self.task_regularization,
            solver: self.solver,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            n_tasks_: Some(n_tasks),
            _state: PhantomData,
        })
    }
}

impl MultiTaskKernelRidgeRegression<Untrained> {
    /// Fit the feature transformer based on the approximation method
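    ///
    /// Each `ApproximationMethod` variant maps to its corresponding
    /// transformer (`Nystroem`, `RBFSampler`, `StructuredRandomFeatures`, or
    /// `FastfoodTransform`), with `random_state` threaded through when set.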
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());
                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }
                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    rff = rff.random_state(seed);
                }
                let fitted = rff.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    srf = srf.random_state(seed);
                }
                let fitted = srf.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }
                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Solve multi-task problem using direct method
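    ///
    /// Per task `t`, solves the ridge normal equations
    /// `(X^T X + alpha * I) w_t = X^T y_t` in `f64`, then applies the
    /// simplified cross-task shrinkage for `TaskRegularization::L2`.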
    fn solve_direct_multitask(
        &self,
        x: &Array2<Float>,
        y: &Array2<Float>,
    ) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();

        // For multi-task learning, we solve each task separately but with shared features
        // and apply cross-task regularization
        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);

            // Standard ridge regression for this task
            let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
            let y_task_f64 = Array1::from_vec(y_task.iter().copied().collect());

            let xtx = x_f64.t().dot(&x_f64);
            let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;

            let xty = x_f64.t().dot(&y_task_f64);
            let weights_task_f64 =
                regularized_xtx
                    .solve(&xty)
                    .map_err(|e| SklearsError::InvalidParameter {
                        name: "regularization".to_string(),
                        reason: format!("Linear system solving failed: {:?}", e),
                    })?;

            // Convert back to Float
            let weights_task =
                Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
            all_weights.column_mut(task_idx).assign(&weights_task);
        }

        // Apply cross-task regularization (simplified approach)
        // In practice, you might want to solve a joint optimization problem
        match &self.task_regularization {
            TaskRegularization::L2 { beta } => {
                // Shrink each task's weights toward the across-task mean
                let mean_weight = all_weights.mean_axis(Axis(1)).unwrap();
                for mut col in all_weights.axis_iter_mut(Axis(1)) {
                    let diff = &col.to_owned() - &mean_weight;
                    col.scaled_add(-beta, &diff);
                }
            }
            _ => {} // Other regularization methods would be implemented here
        }

        Ok(all_weights)
    }

    /// Solve multi-task problem using SVD
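    ///
    /// Factors the regularized Gram matrix as `A = U S V^T` and applies the
    /// pseudo-inverse `w_t = V S^+ U^T (X^T y_t)`, zeroing singular values
    /// below `1e-10` for numerical stability.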
    fn solve_svd_multitask(&self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();
        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);

            // Use SVD for a more stable solution
            let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
            let y_task_f64 = Array1::from_vec(y_task.iter().copied().collect());

            let xtx = x_f64.t().dot(&x_f64);
            let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;

            let (u, s, vt) =
                regularized_xtx
                    .svd(true)
                    .map_err(|e| SklearsError::InvalidParameter {
                        name: "svd".to_string(),
                        reason: format!("SVD decomposition failed: {:?}", e),
                    })?;

            // Solve using the SVD pseudo-inverse
            let xty = x_f64.t().dot(&y_task_f64);
            let ut_b = u.t().dot(&xty);
            let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
            let y_svd = ut_b * s_inv;
            let weights_task_f64 = vt.t().dot(&y_svd);

            // Convert back to Float
            let weights_task =
                Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
            all_weights.column_mut(task_idx).assign(&weights_task);
        }

        Ok(all_weights)
    }

    /// Solve multi-task problem using conjugate gradient
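    ///
    /// Runs standard conjugate gradient on `(X^T X + alpha * I) w_t = X^T y_t`
    /// for each task without forming `X^T X` explicitly; each iteration costs
    /// one multiply by `X` and one by `X^T`. Stops when the residual norm
    /// drops below `tol` or after `max_iter` iterations.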
    fn solve_cg_multitask(
        &self,
        x: &Array2<Float>,
        y: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();
        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);
            let xty = x.t().dot(&y_task);

            // Conjugate gradient solver for each task
            let mut weights = Array1::zeros(n_features);
            let mut r = xty.clone();
            let mut p = r.clone();
            let mut rsold = r.dot(&r);

            for _iter in 0..max_iter {
                let xtx_p = x.t().dot(&x.dot(&p)) + &p * self.alpha;
                let alpha_cg = rsold / p.dot(&xtx_p);

                weights = weights + &p * alpha_cg;
                r = r - &xtx_p * alpha_cg;

                let rsnew = r.dot(&r);

                if rsnew.sqrt() < tol {
                    break;
                }

                let beta = rsnew / rsold;
                p = &r + &p * beta;
                rsold = rsnew;
            }

            all_weights.column_mut(task_idx).assign(&weights);
        }

        Ok(all_weights)
    }
}

impl Predict<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let feature_transformer =
            self.feature_transformer_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "predict".to_string(),
                })?;

        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;

        let x_transformed = feature_transformer.transform(x)?;
        let predictions = x_transformed.dot(weights);

        Ok(predictions)
    }
}

impl MultiTaskKernelRidgeRegression<Trained> {
    /// Get the number of tasks
    pub fn n_tasks(&self) -> usize {
        self.n_tasks_.unwrap_or(0)
    }

    /// Get the fitted weights for all tasks
    pub fn weights(&self) -> Option<&Array2<Float>> {
        self.weights_.as_ref()
    }

    /// Get the weights for a specific task
    pub fn task_weights(&self, task_idx: usize) -> Result<Array1<Float>> {
        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "task_weights".to_string(),
            })?;

        if task_idx >= weights.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Task index {} out of range (n_tasks = {})",
                task_idx,
                weights.ncols()
            )));
        }

        Ok(weights.column(task_idx).to_owned())
    }

    /// Predict for a specific task only
    pub fn predict_task(&self, x: &Array2<Float>, task_idx: usize) -> Result<Array1<Float>> {
        let predictions = self.predict(x)?;

        if task_idx >= predictions.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Task index {} out of range (n_tasks = {})",
                task_idx,
                predictions.ncols()
            )));
        }

        Ok(predictions.column(task_idx).to_owned())
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_multitask_kernel_ridge_regression() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![[1.0, 2.0], [4.0, 5.0], [9.0, 10.0], [16.0, 17.0]]; // Two tasks

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[4, 2]);
        assert_eq!(fitted.n_tasks(), 2);

        // Test individual task prediction
        let task0_pred = fitted.predict_task(&x, 0).unwrap();
        let task1_pred = fitted.predict_task(&x, 1).unwrap();

        assert_eq!(task0_pred.len(), 4);
        assert_eq!(task1_pred.len(), 4);

        // Check that predictions are reasonable
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    #[test]
    fn test_multitask_with_regularization() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]; // Similar tasks

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 15,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .task_regularization(TaskRegularization::L2 { beta: 0.1 });

        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 2]);

        // With L2 regularization, task weights should be similar
        let weights = fitted.weights().unwrap();
        let task0_weights = weights.column(0);
        let task1_weights = weights.column(1);
        let weight_diff = (&task0_weights - &task1_weights)
            .mapv(|x| x.abs())
            .mean()
            .unwrap();

        // Tasks should have similar weights due to regularization
        assert!(weight_diff < 1.0);
    }

    #[test]
    fn test_multitask_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        // Test different solvers
        let solvers = vec![
            Solver::Direct,
            Solver::SVD,
            Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            },
        ];

        for solver in solvers {
            let mtkrr = MultiTaskKernelRidgeRegression::new(approximation.clone())
                .solver(solver)
                .alpha(0.1);

            let fitted = mtkrr.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();

            assert_eq!(predictions.shape(), &[3, 2]);
        }
    }

    #[test]
    fn test_multitask_single_task() {
        // Test that multi-task regression works with a single task
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0], [3.0]]; // Single task

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 1]);
        assert_eq!(fitted.n_tasks(), 1);
    }

    #[test]
    fn test_multitask_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr1 = MultiTaskKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = mtkrr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let mtkrr2 = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = mtkrr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.shape(), pred2.shape());
        for i in 0..pred1.nrows() {
            for j in 0..pred1.ncols() {
                assert!((pred1[[i, j]] - pred2[[i, j]]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_task_regularization_penalties() {
        let weights = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let model =
            MultiTaskKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
                n_components: 10,
                gamma: 1.0,
            });

        // Test different regularization types
        let reg_l2 = TaskRegularization::L2 { beta: 0.1 };
        let reg_l1 = TaskRegularization::L1 { beta: 0.1 };
        let reg_nuclear = TaskRegularization::NuclearNorm { beta: 0.1 };
        let reg_group = TaskRegularization::GroupSparsity { beta: 0.1 };

        let model_l2 = model.clone().task_regularization(reg_l2);
        let model_l1 = model.clone().task_regularization(reg_l1);
        let model_nuclear = model.clone().task_regularization(reg_nuclear);
        let model_group = model.clone().task_regularization(reg_group);

        let penalty_l2 = model_l2.compute_task_regularization_penalty(&weights);
        let penalty_l1 = model_l1.compute_task_regularization_penalty(&weights);
        let penalty_nuclear = model_nuclear.compute_task_regularization_penalty(&weights);
        let penalty_group = model_group.compute_task_regularization_penalty(&weights);

        // All penalties should be non-negative
        assert!(penalty_l2 >= 0.0);
        assert!(penalty_l1 >= 0.0);
        assert!(penalty_nuclear >= 0.0);
        assert!(penalty_group >= 0.0);

        // L1 penalty should be larger than others for this matrix
        assert!(penalty_l1 > penalty_l2);
    }
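
    // A regression guard for the pair-count division in
    // compute_task_regularization_penalty (relies on the single-task guard
    // added there): with one task there are no pairwise differences, so the
    // L2 penalty must be exactly zero rather than NaN.
    #[test]
    fn test_l2_penalty_single_task() {
        let weights = array![[1.0], [2.0]]; // one task, two features

        let model =
            MultiTaskKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
                n_components: 10,
                gamma: 1.0,
            })
            .task_regularization(TaskRegularization::L2 { beta: 0.1 });

        let penalty = model.compute_task_regularization_penalty(&weights);
        assert!(penalty.is_finite());
        assert_eq!(penalty, 0.0);
    }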
}