Skip to main content

sklears_kernel_approximation/kernel_ridge_regression/
multitask_regression.rs

1//! Multi-Task Kernel Ridge Regression Implementation
2//!
3//! This module provides multi-task learning capabilities for kernel ridge regression,
4//! allowing simultaneous learning across multiple related tasks with cross-task
5//! regularization strategies to improve generalization.
6
7use crate::{
8    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
9};
10use scirs2_linalg::compat::ArrayLinalgExt;
11// Removed SVD import - using ArrayLinalgExt for both solve and svd methods
12
13use scirs2_core::ndarray::{Array1, Array2, Axis};
14use sklears_core::error::{Result, SklearsError};
15use sklears_core::prelude::{Estimator, Fit, Float, Predict};
16use std::marker::PhantomData;
17
18use super::core_types::*;
19
20/// Multi-Task Kernel Ridge Regression
21///
22/// Performs kernel ridge regression simultaneously across multiple related tasks,
23/// with optional cross-task regularization to encourage similarity between tasks.
24///
25/// This is particularly useful when you have multiple regression targets that are
26/// related and can benefit from shared representations and joint learning.
27///
28/// # Parameters
29///
30/// * `approximation_method` - Method for kernel approximation
31/// * `alpha` - Within-task regularization parameter
32/// * `task_regularization` - Cross-task regularization strategy
33/// * `solver` - Method for solving the linear system
34/// * `random_state` - Random seed for reproducibility
#[derive(Debug, Clone)]
pub struct MultiTaskKernelRidgeRegression<State = Untrained> {
    /// Method used to approximate the kernel feature map.
    pub approximation_method: ApproximationMethod,
    /// Within-task ridge regularization strength, applied to each task's system.
    pub alpha: Float,
    /// Cross-task regularization strategy coupling the per-task weights.
    pub task_regularization: TaskRegularization,
    /// Method for solving the per-task linear systems.
    pub solver: Solver,
    /// Random seed for reproducible feature sampling.
    pub random_state: Option<u64>,

    // Fitted parameters — populated only after `fit` (Trained state).
    weights_: Option<Array2<Float>>, // Shape: (n_features, n_tasks)
    feature_transformer_: Option<FeatureTransformer>,
    n_tasks_: Option<usize>,

    // Zero-sized marker implementing the Untrained/Trained typestate.
    _state: PhantomData<State>,
}
50
/// Cross-task regularization strategies for multi-task learning.
///
/// Each variant defines a penalty on the `(n_features, n_tasks)` weight
/// matrix that couples the per-task weight vectors, encouraging related tasks
/// to learn similar (or jointly sparse) representations.
#[derive(Debug, Clone)]
pub enum TaskRegularization {
    /// No cross-task regularization (independent tasks)
    None,
    /// L2 regularization on pairwise differences between task weight vectors
    L2 { beta: Float },
    /// L1 regularization promoting sparsity across tasks
    L1 { beta: Float },
    /// Nuclear norm regularization on weight matrix (low-rank); this
    /// implementation approximates it with the Frobenius norm
    NuclearNorm { beta: Float },
    /// Group sparsity regularization (L2,1 norm over feature rows)
    GroupSparsity { beta: Float },
    /// Custom regularization: the penalty is `beta * regularizer(weights)`
    Custom {
        beta: Float,
        regularizer: fn(&Array2<Float>) -> Float,
    },
}
70
71impl Default for TaskRegularization {
72    fn default() -> Self {
73        Self::None
74    }
75}
76
77impl MultiTaskKernelRidgeRegression<Untrained> {
78    /// Create a new multi-task kernel ridge regression model
79    pub fn new(approximation_method: ApproximationMethod) -> Self {
80        Self {
81            approximation_method,
82            alpha: 1.0,
83            task_regularization: TaskRegularization::None,
84            solver: Solver::Direct,
85            random_state: None,
86            weights_: None,
87            feature_transformer_: None,
88            n_tasks_: None,
89            _state: PhantomData,
90        }
91    }
92
93    /// Set regularization parameter
94    pub fn alpha(mut self, alpha: Float) -> Self {
95        self.alpha = alpha;
96        self
97    }
98
99    /// Set cross-task regularization strategy
100    pub fn task_regularization(mut self, regularization: TaskRegularization) -> Self {
101        self.task_regularization = regularization;
102        self
103    }
104
105    /// Set solver method
106    pub fn solver(mut self, solver: Solver) -> Self {
107        self.solver = solver;
108        self
109    }
110
111    /// Set random state for reproducibility
112    pub fn random_state(mut self, seed: u64) -> Self {
113        self.random_state = Some(seed);
114        self
115    }
116
117    /// Compute regularization penalty for the weight matrix
118    fn compute_task_regularization_penalty(&self, weights: &Array2<Float>) -> Float {
119        match &self.task_regularization {
120            TaskRegularization::None => 0.0,
121            TaskRegularization::L2 { beta } => {
122                // L2 penalty on differences between task weights
123                let mut penalty = 0.0;
124                let n_tasks = weights.ncols();
125                for i in 0..n_tasks {
126                    for j in (i + 1)..n_tasks {
127                        let diff = &weights.column(i) - &weights.column(j);
128                        penalty += diff.mapv(|x| x * x).sum();
129                    }
130                }
131                beta * penalty / (n_tasks * (n_tasks - 1) / 2) as Float
132            }
133            TaskRegularization::L1 { beta } => {
134                // L1 penalty on task weights
135                beta * weights.mapv(|x| x.abs()).sum()
136            }
137            TaskRegularization::NuclearNorm { beta } => {
138                // Nuclear norm (sum of singular values)
139                // Approximate with Frobenius norm for efficiency
140                beta * weights.mapv(|x| x * x).sum().sqrt()
141            }
142            TaskRegularization::GroupSparsity { beta } => {
143                // Group sparsity: L2,1 norm
144                let mut penalty = 0.0;
145                for row in weights.axis_iter(Axis(0)) {
146                    penalty += row.mapv(|x| x * x).sum().sqrt();
147                }
148                beta * penalty
149            }
150            TaskRegularization::Custom { beta, regularizer } => beta * regularizer(weights),
151        }
152    }
153}
154
impl Estimator for MultiTaskKernelRidgeRegression<Untrained> {
    // Hyperparameters live directly on the struct, so the associated
    // configuration type is the unit type.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns the (empty) configuration. `&()` refers to a constant, so the
    /// reference is valid via static promotion.
    fn config(&self) -> &Self::Config {
        &()
    }
}
164
165impl Fit<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Untrained> {
166    type Fitted = MultiTaskKernelRidgeRegression<Trained>;
167
168    fn fit(self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Self::Fitted> {
169        if x.nrows() != y.nrows() {
170            return Err(SklearsError::InvalidInput(
171                "Number of samples must match".to_string(),
172            ));
173        }
174
175        let _n_samples = x.nrows();
176        let n_tasks = y.ncols();
177
178        // Fit the feature transformer
179        let feature_transformer = self.fit_feature_transformer(x)?;
180        let x_transformed = feature_transformer.transform(x)?;
181        let _n_features = x_transformed.ncols();
182
183        // Solve multi-task regression problem
184        let weights = match self.solver {
185            Solver::Direct => self.solve_direct_multitask(&x_transformed, y)?,
186            Solver::SVD => self.solve_svd_multitask(&x_transformed, y)?,
187            Solver::ConjugateGradient { max_iter, tol } => {
188                self.solve_cg_multitask(&x_transformed, y, max_iter, tol)?
189            }
190        };
191
192        Ok(MultiTaskKernelRidgeRegression {
193            approximation_method: self.approximation_method,
194            alpha: self.alpha,
195            task_regularization: self.task_regularization,
196            solver: self.solver,
197            random_state: self.random_state,
198            weights_: Some(weights),
199            feature_transformer_: Some(feature_transformer),
200            n_tasks_: Some(n_tasks),
201            _state: PhantomData,
202        })
203    }
204}
205
impl MultiTaskKernelRidgeRegression<Untrained> {
    /// Fit the feature transformer based on the approximation method.
    ///
    /// Each arm constructs the corresponding approximator, forwards the
    /// configured random seed (if any) for reproducibility, fits it on `x`,
    /// and wraps the fitted result in the `FeatureTransformer` enum.
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());
                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }
                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    rff = rff.random_state(seed);
                }
                let fitted = rff.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    srf = srf.random_state(seed);
                }
                let fitted = srf.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }
                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Solve the multi-task problem using the direct (normal equations)
    /// method: per task, w = (XᵀX + αI)⁻¹ Xᵀy.
    ///
    /// Returns the weight matrix of shape `(n_features, n_tasks)`.
    fn solve_direct_multitask(
        &self,
        x: &Array2<Float>,
        y: &Array2<Float>,
    ) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();

        // For multi-task learning, we solve each task separately but with shared features
        // and apply cross-task regularization
        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);

            // Standard ridge regression for this task.
            // Element-wise copy into f64 arrays for the linalg routines
            // (presumably `Float` is f64 here, so the copy is a no-op
            // conversion — confirm against sklears_core's Float alias).
            let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
            let y_task_f64 = Array1::from_vec(y_task.iter().copied().collect());

            // Normal equations: (XᵀX + αI) w = Xᵀy.
            let xtx = x_f64.t().dot(&x_f64);
            let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;

            let xty = x_f64.t().dot(&y_task_f64);
            let weights_task_f64 =
                regularized_xtx
                    .solve(&xty)
                    .map_err(|e| SklearsError::InvalidParameter {
                        name: "regularization".to_string(),
                        reason: format!("Linear system solving failed: {:?}", e),
                    })?;

            // Convert back to Float
            let weights_task =
                Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
            all_weights.column_mut(task_idx).assign(&weights_task);
        }

        // Apply cross-task regularization (simplified approach)
        // In practice, you might want to solve a joint optimization problem
        match &self.task_regularization {
            TaskRegularization::L2 { beta } => {
                // Shrink each task's weights toward the across-task mean by
                // a factor of beta (post-hoc adjustment, not joint solving).
                let mean_weight = all_weights.mean_axis(Axis(1)).unwrap();
                for mut col in all_weights.axis_iter_mut(Axis(1)) {
                    let diff = &col.to_owned() - &mean_weight;
                    col.scaled_add(-beta, &diff);
                }
            }
            _ => {} // Other regularization methods would be implemented here
        }

        Ok(all_weights)
    }

    /// Solve the multi-task problem using SVD of the regularized normal
    /// matrix, for better numerical stability than the direct solve.
    ///
    /// Singular values below 1e-10 are zeroed in the pseudo-inverse.
    fn solve_svd_multitask(&self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();
        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);

            // Use SVD for more stable solution.
            // Same f64 round-trip as in `solve_direct_multitask`.
            let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
            let y_task_f64 = Array1::from_vec(y_task.iter().copied().collect());

            let xtx = x_f64.t().dot(&x_f64);
            let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;

            let (u, s, vt) =
                regularized_xtx
                    .svd(true)
                    .map_err(|e| SklearsError::InvalidParameter {
                        name: "svd".to_string(),
                        reason: format!("SVD decomposition failed: {:?}", e),
                    })?;

            // Solve via pseudo-inverse: w = V · diag(1/s) · Uᵀ · (Xᵀy),
            // truncating near-zero singular values to avoid blow-up.
            let xty = x_f64.t().dot(&y_task_f64);
            let ut_b = u.t().dot(&xty);
            let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
            let y_svd = ut_b * s_inv;
            let weights_task_f64 = vt.t().dot(&y_svd);

            // Convert back to Float
            let weights_task =
                Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
            all_weights.column_mut(task_idx).assign(&weights_task);
        }

        Ok(all_weights)
    }

    /// Solve the multi-task problem using conjugate gradient on the system
    /// (XᵀX + αI) w = Xᵀy for each task, without ever forming XᵀX
    /// explicitly (only matrix-vector products with X and Xᵀ).
    ///
    /// Stops after `max_iter` iterations or when the residual norm drops
    /// below `tol`.
    fn solve_cg_multitask(
        &self,
        x: &Array2<Float>,
        y: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();
        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);
            let xty = x.t().dot(&y_task);

            // Conjugate gradient solver for each task.
            // Initial guess w = 0, so the initial residual r = b = Xᵀy.
            let mut weights = Array1::zeros(n_features);
            let mut r = xty.clone();
            let mut p = r.clone();
            let mut rsold = r.dot(&r);

            for _iter in 0..max_iter {
                // A·p computed as Xᵀ(Xp) + αp — avoids materializing XᵀX.
                let xtx_p = x.t().dot(&x.dot(&p)) + &p * self.alpha;
                let alpha_cg = rsold / p.dot(&xtx_p);

                weights = weights + &p * alpha_cg;
                r = r - &xtx_p * alpha_cg;

                let rsnew = r.dot(&r);

                // Converged when the residual norm is below tolerance.
                if rsnew.sqrt() < tol {
                    break;
                }

                // Update the search direction (Fletcher–Reeves style beta).
                let beta = rsnew / rsold;
                p = &r + &p * beta;
                rsold = rsnew;
            }

            all_weights.column_mut(task_idx).assign(&weights);
        }

        Ok(all_weights)
    }
}
400
401impl Predict<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Trained> {
402    fn predict(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
403        let feature_transformer =
404            self.feature_transformer_
405                .as_ref()
406                .ok_or_else(|| SklearsError::NotFitted {
407                    operation: "predict".to_string(),
408                })?;
409
410        let weights = self
411            .weights_
412            .as_ref()
413            .ok_or_else(|| SklearsError::NotFitted {
414                operation: "predict".to_string(),
415            })?;
416
417        let x_transformed = feature_transformer.transform(x)?;
418        let predictions = x_transformed.dot(weights);
419
420        Ok(predictions)
421    }
422}
423
424impl MultiTaskKernelRidgeRegression<Trained> {
425    /// Get the number of tasks
426    pub fn n_tasks(&self) -> usize {
427        self.n_tasks_.unwrap_or(0)
428    }
429
430    /// Get the fitted weights for all tasks
431    pub fn weights(&self) -> Option<&Array2<Float>> {
432        self.weights_.as_ref()
433    }
434
435    /// Get the weights for a specific task
436    pub fn task_weights(&self, task_idx: usize) -> Result<Array1<Float>> {
437        let weights = self
438            .weights_
439            .as_ref()
440            .ok_or_else(|| SklearsError::NotFitted {
441                operation: "predict".to_string(),
442            })?;
443
444        if task_idx >= weights.ncols() {
445            return Err(SklearsError::InvalidInput(format!(
446                "Task index {} out of range",
447                task_idx
448            )));
449        }
450
451        Ok(weights.column(task_idx).to_owned())
452    }
453
454    /// Predict for a specific task only
455    pub fn predict_task(&self, x: &Array2<Float>, task_idx: usize) -> Result<Array1<Float>> {
456        let predictions = self.predict(x)?;
457
458        if task_idx >= predictions.ncols() {
459            return Err(SklearsError::InvalidInput(format!(
460                "Task index {} out of range",
461                task_idx
462            )));
463        }
464
465        Ok(predictions.column(task_idx).to_owned())
466    }
467}
468
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// End-to-end fit/predict on two tasks, plus per-task prediction and a
    /// finiteness sanity check on the outputs.
    #[test]
    fn test_multitask_kernel_ridge_regression() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![[1.0, 2.0], [4.0, 5.0], [9.0, 10.0], [16.0, 17.0]]; // Two tasks

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[4, 2]);
        assert_eq!(fitted.n_tasks(), 2);

        // Test individual task prediction
        let task0_pred = fitted.predict_task(&x, 0).unwrap();
        let task1_pred = fitted.predict_task(&x, 1).unwrap();

        assert_eq!(task0_pred.len(), 4);
        assert_eq!(task1_pred.len(), 4);

        // Check that predictions are reasonable
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    /// L2 cross-task regularization should pull the weight vectors of two
    /// near-identical tasks close together.
    #[test]
    fn test_multitask_with_regularization() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]; // Similar tasks

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 15,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .task_regularization(TaskRegularization::L2 { beta: 0.1 });

        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 2]);

        // With L2 regularization, task weights should be similar
        let weights = fitted.weights().unwrap();
        let task0_weights = weights.column(0);
        let task1_weights = weights.column(1);
        let weight_diff = (&task0_weights - &task1_weights)
            .mapv(|x| x.abs())
            .mean()
            .unwrap();

        // Tasks should have similar weights due to regularization
        assert!(weight_diff < 1.0);
    }

    /// All three solvers (direct, SVD, CG) should fit and produce
    /// correctly-shaped predictions.
    #[test]
    fn test_multitask_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        // Test different solvers
        let solvers = vec![
            Solver::Direct,
            Solver::SVD,
            Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            },
        ];

        for solver in solvers {
            let mtkrr = MultiTaskKernelRidgeRegression::new(approximation.clone())
                .solver(solver)
                .alpha(0.1);

            let fitted = mtkrr.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();

            assert_eq!(predictions.shape(), &[3, 2]);
        }
    }

    /// The multi-task API should degrade gracefully to a single task
    /// (n_tasks == 1).
    #[test]
    fn test_multitask_single_task() {
        // Test that multi-task regression works with a single task
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0], [3.0]]; // Single task

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 1]);
        assert_eq!(fitted.n_tasks(), 1);
    }

    /// Two fits with the same `random_state` must yield (numerically)
    /// identical predictions.
    #[test]
    fn test_multitask_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr1 = MultiTaskKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = mtkrr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let mtkrr2 = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = mtkrr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.shape(), pred2.shape());
        for i in 0..pred1.nrows() {
            for j in 0..pred1.ncols() {
                assert!((pred1[[i, j]] - pred2[[i, j]]).abs() < 1e-10);
            }
        }
    }

    /// Each penalty variant must be non-negative on a fixed 2x3 (3-task)
    /// weight matrix; L1 dominates the others for this matrix.
    #[test]
    fn test_task_regularization_penalties() {
        let weights = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let model =
            MultiTaskKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
                n_components: 10,
                gamma: 1.0,
            });

        // Test different regularization types
        let reg_l2 = TaskRegularization::L2 { beta: 0.1 };
        let reg_l1 = TaskRegularization::L1 { beta: 0.1 };
        let reg_nuclear = TaskRegularization::NuclearNorm { beta: 0.1 };
        let reg_group = TaskRegularization::GroupSparsity { beta: 0.1 };

        let model_l2 = model.clone().task_regularization(reg_l2);
        let model_l1 = model.clone().task_regularization(reg_l1);
        let model_nuclear = model.clone().task_regularization(reg_nuclear);
        let model_group = model.clone().task_regularization(reg_group);

        let penalty_l2 = model_l2.compute_task_regularization_penalty(&weights);
        let penalty_l1 = model_l1.compute_task_regularization_penalty(&weights);
        let penalty_nuclear = model_nuclear.compute_task_regularization_penalty(&weights);
        let penalty_group = model_group.compute_task_regularization_penalty(&weights);

        // All penalties should be non-negative
        assert!(penalty_l2 >= 0.0);
        assert!(penalty_l1 >= 0.0);
        assert!(penalty_nuclear >= 0.0);
        assert!(penalty_group >= 0.0);

        // L1 penalty should be larger than others for this matrix
        assert!(penalty_l1 > penalty_l2);
    }
}