sklears_kernel_approximation/kernel_ridge_regression/multitask_regression.rs

//! Multi-Task Kernel Ridge Regression Implementation
//!
//! This module provides multi-task learning capabilities for kernel ridge regression,
//! allowing simultaneous learning across multiple related tasks with cross-task
//! regularization strategies to improve generalization.

use crate::{
    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
};
use scirs2_core::ndarray::ndarray_linalg::solve::Solve;
use scirs2_core::ndarray::ndarray_linalg::SVD;
use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::{Estimator, Fit, Float, Predict};
use std::marker::PhantomData;

use super::core_types::*;
/// Multi-Task Kernel Ridge Regression
///
/// Performs kernel ridge regression simultaneously across multiple related tasks,
/// with optional cross-task regularization to encourage similarity between tasks.
///
/// This is particularly useful when you have multiple regression targets that are
/// related and can benefit from shared representations and joint learning.
///
/// # Parameters
///
/// * `approximation_method` - Method for kernel approximation
/// * `alpha` - Within-task regularization parameter
/// * `task_regularization` - Cross-task regularization strategy
/// * `solver` - Method for solving the linear system
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::kernel_ridge_regression::{
///     ApproximationMethod, MultiTaskKernelRidgeRegression,
/// };
///
/// // Mirrors the usage exercised in the tests below; `x` is an
/// // Array2<Float> of samples and `y` holds one column per task.
/// let model = MultiTaskKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
///     n_components: 20,
///     gamma: 0.1,
/// })
/// .alpha(0.1);
/// let fitted = model.fit(&x, &y)?;
/// let predictions = fitted.predict(&x)?; // shape: (n_samples, n_tasks)
/// ```
#[derive(Debug, Clone)]
pub struct MultiTaskKernelRidgeRegression<State = Untrained> {
    pub approximation_method: ApproximationMethod,
    pub alpha: Float,
    pub task_regularization: TaskRegularization,
    pub solver: Solver,
    pub random_state: Option<u64>,

    // Fitted parameters
    weights_: Option<Array2<Float>>, // Shape: (n_features, n_tasks)
    feature_transformer_: Option<FeatureTransformer>,
    n_tasks_: Option<usize>,

    _state: PhantomData<State>,
}

/// Cross-task regularization strategies for multi-task learning
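///
/// # Examples
///
/// A minimal sketch of selecting a strategy (illustrative; these are the
/// variants defined below):
///
/// ```rust,ignore
/// use sklears_kernel_approximation::kernel_ridge_regression::TaskRegularization;
///
/// // Penalize squared pairwise differences between task weight vectors.
/// let reg = TaskRegularization::L2 { beta: 0.1 };
/// ```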
#[derive(Debug, Clone)]
pub enum TaskRegularization {
    /// No cross-task regularization (independent tasks)
    None,
    /// L2 regularization on task weight differences
    L2 { beta: Float },
    /// L1 regularization promoting sparsity across tasks
    L1 { beta: Float },
    /// Nuclear norm regularization on weight matrix (low-rank)
    NuclearNorm { beta: Float },
    /// Group sparsity regularization
    GroupSparsity { beta: Float },
    /// Custom regularization function
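    ///
    /// An illustrative sketch (the field accepts any plain
    /// `fn(&Array2<Float>) -> Float`), e.g. penalizing the largest absolute
    /// entry of the weight matrix:
    ///
    /// ```rust,ignore
    /// fn max_abs(w: &Array2<Float>) -> Float {
    ///     w.iter().fold(0.0, |m, &v| m.max(v.abs()))
    /// }
    /// let reg = TaskRegularization::Custom { beta: 0.05, regularizer: max_abs };
    /// ```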
    Custom {
        beta: Float,
        regularizer: fn(&Array2<Float>) -> Float,
    },
}

impl Default for TaskRegularization {
    fn default() -> Self {
        Self::None
    }
}

impl MultiTaskKernelRidgeRegression<Untrained> {
    /// Create a new multi-task kernel ridge regression model
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            approximation_method,
            alpha: 1.0,
            task_regularization: TaskRegularization::None,
            solver: Solver::Direct,
            random_state: None,
            weights_: None,
            feature_transformer_: None,
            n_tasks_: None,
            _state: PhantomData,
        }
    }

    /// Set regularization parameter
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set cross-task regularization strategy
    pub fn task_regularization(mut self, regularization: TaskRegularization) -> Self {
        self.task_regularization = regularization;
        self
    }

    /// Set solver method
    pub fn solver(mut self, solver: Solver) -> Self {
        self.solver = solver;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Compute regularization penalty for the weight matrix
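    ///
    /// For `L2`, this is the mean squared pairwise difference between task
    /// weight vectors, `beta * sum_{i<j} ||w_i - w_j||^2 / (T * (T - 1) / 2)`,
    /// where `w_k` is column `k` of `weights` and `T` is the number of tasks.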
    fn compute_task_regularization_penalty(&self, weights: &Array2<Float>) -> Float {
        match &self.task_regularization {
            TaskRegularization::None => 0.0,
            TaskRegularization::L2 { beta } => {
                // L2 penalty on differences between task weights
                let n_tasks = weights.ncols();
                if n_tasks < 2 {
                    // A single task has no pairs; guard against dividing by zero below.
                    return 0.0;
                }
                let mut penalty = 0.0;
                for i in 0..n_tasks {
                    for j in (i + 1)..n_tasks {
                        let diff = &weights.column(i) - &weights.column(j);
                        penalty += diff.mapv(|x| x * x).sum();
                    }
                }
                // Average over the n_tasks * (n_tasks - 1) / 2 task pairs
                beta * penalty / (n_tasks * (n_tasks - 1) / 2) as Float
            }
            TaskRegularization::L1 { beta } => {
                // L1 penalty on all task weights (promotes sparsity)
                beta * weights.mapv(|x| x.abs()).sum()
            }
            TaskRegularization::NuclearNorm { beta } => {
                // Nuclear norm is the sum of singular values; approximated
                // here by the Frobenius norm to avoid an SVD per evaluation
                beta * weights.mapv(|x| x * x).sum().sqrt()
            }
            TaskRegularization::GroupSparsity { beta } => {
                // Group sparsity: L2,1 norm (sum of per-feature row norms)
                let mut penalty = 0.0;
                for row in weights.axis_iter(Axis(0)) {
                    penalty += row.mapv(|x| x * x).sum().sqrt();
                }
                beta * penalty
            }
            TaskRegularization::Custom { beta, regularizer } => beta * regularizer(weights),
        }
    }
}

impl Estimator for MultiTaskKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Untrained> {
    type Fitted = MultiTaskKernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Self::Fitted> {
        if x.nrows() != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "x and y must have the same number of samples".to_string(),
            ));
        }

        let n_tasks = y.ncols();

        // Fit the feature transformer and map the inputs into the
        // approximate kernel feature space
        let feature_transformer = self.fit_feature_transformer(x)?;
        let x_transformed = feature_transformer.transform(x)?;

        // Solve the multi-task regression problem in the transformed space
        let weights = match self.solver {
            Solver::Direct => self.solve_direct_multitask(&x_transformed, y)?,
            Solver::SVD => self.solve_svd_multitask(&x_transformed, y)?,
            Solver::ConjugateGradient { max_iter, tol } => {
                self.solve_cg_multitask(&x_transformed, y, max_iter, tol)?
            }
        };

        Ok(MultiTaskKernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            task_regularization: self.task_regularization,
            solver: self.solver,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            n_tasks_: Some(n_tasks),
            _state: PhantomData,
        })
    }
}

impl MultiTaskKernelRidgeRegression<Untrained> {
    /// Fit the feature transformer based on the approximation method
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());
                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }
                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    rff = rff.random_state(seed);
                }
                let fitted = rff.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    srf = srf.random_state(seed);
                }
                let fitted = srf.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }
                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Solve the multi-task problem with a direct method: for each task `t`,
    /// solve the ridge normal equations `(X^T X + alpha * I) w_t = X^T y_t`.
    fn solve_direct_multitask(
        &self,
        x: &Array2<Float>,
        y: &Array2<Float>,
    ) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();

        // Each task is solved separately over the shared feature space;
        // cross-task regularization is applied afterwards
        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);

            // Standard ridge regression for this task (computed in f64 for
            // numerical stability)
            let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]] as f64);
            let y_task_f64 = Array1::from_vec(y_task.iter().map(|&val| val as f64).collect());

            let xtx = x_f64.t().dot(&x_f64);
            let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * (self.alpha as f64);

            let xty = x_f64.t().dot(&y_task_f64);
            let weights_task_f64 =
                regularized_xtx
                    .solve(&xty)
                    .map_err(|e| SklearsError::InvalidParameter {
                        name: "regularization".to_string(),
                        reason: format!("Linear system solving failed: {:?}", e),
                    })?;

            // Convert back to Float
            let weights_task =
                Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
            all_weights.column_mut(task_idx).assign(&weights_task);
        }

        // Apply cross-task regularization as a post-hoc correction. This is a
        // simplification; a full treatment would solve the joint optimization
        // problem over all tasks at once.
        match &self.task_regularization {
            TaskRegularization::L2 { beta } => {
                // Shrink each task's weights toward the mean weight vector
                // by a factor of beta
                let mean_weight = all_weights.mean_axis(Axis(1)).unwrap();
                for mut col in all_weights.axis_iter_mut(Axis(1)) {
                    let diff = &col.to_owned() - &mean_weight;
                    col.scaled_add(-beta, &diff);
                }
            }
            _ => {} // Other regularization strategies would be implemented here
        }

        Ok(all_weights)
    }

    /// Solve the multi-task problem via SVD of the regularized normal matrix,
    /// which is more stable than a direct solve when `X^T X` is ill-conditioned.
    fn solve_svd_multitask(&self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();
        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);

            // Compute in f64 for numerical stability
            let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]] as f64);
            let y_task_f64 = Array1::from_vec(y_task.iter().map(|&val| val as f64).collect());

            let xtx = x_f64.t().dot(&x_f64);
            let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * (self.alpha as f64);

            let (u, s, vt) =
                regularized_xtx
                    .svd(true, true)
                    .map_err(|e| SklearsError::InvalidParameter {
                        name: "svd".to_string(),
                        reason: format!("SVD decomposition failed: {:?}", e),
                    })?;
            let u = u.unwrap();
            let vt = vt.unwrap();

            // Solve via the pseudo-inverse: w = V * S^{-1} * U^T * (X^T y),
            // truncating singular values below 1e-10 for stability
            let xty = x_f64.t().dot(&y_task_f64);
            let ut_b = u.t().dot(&xty);
            let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
            let y_svd = ut_b * s_inv;
            let weights_task_f64 = vt.t().dot(&y_svd);

            // Convert back to Float
            let weights_task =
                Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
            all_weights.column_mut(task_idx).assign(&weights_task);
        }

        Ok(all_weights)
    }

    /// Solve the multi-task problem with conjugate gradient on the normal
    /// equations `(X^T X + alpha * I) w_t = X^T y_t`, applying the matrix as
    /// products with `X` and `X^T` so `X^T X` is never formed explicitly.
    fn solve_cg_multitask(
        &self,
        x: &Array2<Float>,
        y: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();
        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);
            let xty = x.t().dot(&y_task);

            // Conjugate gradient for this task, starting from w = 0, so the
            // initial residual equals the right-hand side X^T y
            let mut weights = Array1::zeros(n_features);
            let mut r = xty.clone();
            let mut p = r.clone();
            let mut rsold = r.dot(&r);

            for _iter in 0..max_iter {
                // Matrix-vector product (X^T X + alpha * I) p
                let xtx_p = x.t().dot(&x.dot(&p)) + &p * self.alpha;
                let alpha_cg = rsold / p.dot(&xtx_p);

                weights = weights + &p * alpha_cg;
                r = r - &xtx_p * alpha_cg;

                let rsnew = r.dot(&r);

                if rsnew.sqrt() < tol {
                    break;
                }

                let beta = rsnew / rsold;
                p = &r + &p * beta;
                rsold = rsnew;
            }

            all_weights.column_mut(task_idx).assign(&weights);
        }

        Ok(all_weights)
    }
}

impl Predict<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let feature_transformer =
            self.feature_transformer_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "predict".to_string(),
                })?;

        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;

        let x_transformed = feature_transformer.transform(x)?;
        let predictions = x_transformed.dot(weights);

        Ok(predictions)
    }
}

impl MultiTaskKernelRidgeRegression<Trained> {
    /// Get the number of tasks
    pub fn n_tasks(&self) -> usize {
        self.n_tasks_.unwrap_or(0)
    }

    /// Get the fitted weights for all tasks
    pub fn weights(&self) -> Option<&Array2<Float>> {
        self.weights_.as_ref()
    }

    /// Get the weights for a specific task
    pub fn task_weights(&self, task_idx: usize) -> Result<Array1<Float>> {
        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "task_weights".to_string(),
            })?;

        if task_idx >= weights.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Task index {} out of range",
                task_idx
            )));
        }

        Ok(weights.column(task_idx).to_owned())
    }

    /// Predict for a specific task only
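    ///
    /// Illustrative usage on a fitted model (`fitted` and `x` are assumed
    /// bindings; see the tests below for full setup):
    ///
    /// ```rust,ignore
    /// let task0_pred = fitted.predict_task(&x, 0)?; // one column of predict()
    /// ```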
    pub fn predict_task(&self, x: &Array2<Float>, task_idx: usize) -> Result<Array1<Float>> {
        let predictions = self.predict(x)?;

        if task_idx >= predictions.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Task index {} out of range",
                task_idx
            )));
        }

        Ok(predictions.column(task_idx).to_owned())
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_multitask_kernel_ridge_regression() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![[1.0, 2.0], [4.0, 5.0], [9.0, 10.0], [16.0, 17.0]]; // Two tasks

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[4, 2]);
        assert_eq!(fitted.n_tasks(), 2);

        // Test individual task prediction
        let task0_pred = fitted.predict_task(&x, 0).unwrap();
        let task1_pred = fitted.predict_task(&x, 1).unwrap();

        assert_eq!(task0_pred.len(), 4);
        assert_eq!(task1_pred.len(), 4);

        // Check that all predictions are finite
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    #[test]
    fn test_multitask_with_regularization() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]; // Similar tasks

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 15,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .task_regularization(TaskRegularization::L2 { beta: 0.1 });

        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 2]);

        // With L2 regularization, task weights should be similar
        let weights = fitted.weights().unwrap();
        let task0_weights = weights.column(0);
        let task1_weights = weights.column(1);
        let weight_diff = (&task0_weights - &task1_weights)
            .mapv(|x| x.abs())
            .mean()
            .unwrap();

        // Tasks should have similar weights due to regularization
        assert!(weight_diff < 1.0);
    }

    #[test]
    fn test_multitask_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        // Test different solvers
        let solvers = vec![
            Solver::Direct,
            Solver::SVD,
            Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            },
        ];

        for solver in solvers {
            let mtkrr = MultiTaskKernelRidgeRegression::new(approximation.clone())
                .solver(solver)
                .alpha(0.1);

            let fitted = mtkrr.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();

            assert_eq!(predictions.shape(), &[3, 2]);
        }
    }

    #[test]
    fn test_multitask_single_task() {
        // Test that multi-task regression works with a single task
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0], [3.0]]; // Single task

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 1]);
        assert_eq!(fitted.n_tasks(), 1);
    }

    #[test]
    fn test_multitask_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr1 = MultiTaskKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = mtkrr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let mtkrr2 = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = mtkrr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.shape(), pred2.shape());
        for i in 0..pred1.nrows() {
            for j in 0..pred1.ncols() {
                assert!((pred1[[i, j]] - pred2[[i, j]]).abs() < 1e-10);
            }
        }
    }
624    #[test]
625    fn test_task_regularization_penalties() {
626        let weights = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
627
628        let model =
629            MultiTaskKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
630                n_components: 10,
631                gamma: 1.0,
632            });
633
634        // Test different regularization types
635        let reg_l2 = TaskRegularization::L2 { beta: 0.1 };
636        let reg_l1 = TaskRegularization::L1 { beta: 0.1 };
637        let reg_nuclear = TaskRegularization::NuclearNorm { beta: 0.1 };
638        let reg_group = TaskRegularization::GroupSparsity { beta: 0.1 };
639
640        let model_l2 = model.clone().task_regularization(reg_l2);
641        let model_l1 = model.clone().task_regularization(reg_l1);
642        let model_nuclear = model.clone().task_regularization(reg_nuclear);
643        let model_group = model.clone().task_regularization(reg_group);
644
645        let penalty_l2 = model_l2.compute_task_regularization_penalty(&weights);
646        let penalty_l1 = model_l1.compute_task_regularization_penalty(&weights);
647        let penalty_nuclear = model_nuclear.compute_task_regularization_penalty(&weights);
648        let penalty_group = model_group.compute_task_regularization_penalty(&weights);
649
650        // All penalties should be non-negative
651        assert!(penalty_l2 >= 0.0);
652        assert!(penalty_l1 >= 0.0);
653        assert!(penalty_nuclear >= 0.0);
654        assert!(penalty_group >= 0.0);
655
656        // L1 penalty should be larger than others for this matrix
657        assert!(penalty_l1 > penalty_l2);
658    }
659}