1use crate::{
8 FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
9};
10use scirs2_linalg::compat::ArrayLinalgExt;
11use scirs2_core::ndarray::{Array1, Array2, Axis};
14use sklears_core::error::{Result, SklearsError};
15use sklears_core::prelude::{Estimator, Fit, Float, Predict};
16use std::marker::PhantomData;
17
18use super::core_types::*;
19
20#[derive(Debug, Clone)]
36pub struct MultiTaskKernelRidgeRegression<State = Untrained> {
37 pub approximation_method: ApproximationMethod,
38 pub alpha: Float,
39 pub task_regularization: TaskRegularization,
40 pub solver: Solver,
41 pub random_state: Option<u64>,
42
43 weights_: Option<Array2<Float>>, feature_transformer_: Option<FeatureTransformer>,
46 n_tasks_: Option<usize>,
47
48 _state: PhantomData<State>,
49}
50
/// Regularization schemes coupling the per-task weight columns.
///
/// `beta` scales the cross-task term in each variant; it is applied on top
/// of the per-task ridge penalty `alpha`.
#[derive(Debug, Clone)]
pub enum TaskRegularization {
    /// No cross-task coupling; tasks are fit independently.
    None,
    /// Mean squared distance between all pairs of task weight vectors.
    L2 { beta: Float },
    /// Entry-wise L1 norm of the full weight matrix.
    L1 { beta: Float },
    /// NOTE(review): the penalty as implemented computes the Frobenius norm
    /// of the weight matrix as a surrogate for the true nuclear norm (sum of
    /// singular values) — confirm this approximation is intended.
    NuclearNorm { beta: Float },
    /// Sum of row-wise (per-feature) L2 norms, encouraging whole feature
    /// rows to shrink jointly across all tasks.
    GroupSparsity { beta: Float },
    /// User-supplied penalty function, scaled by `beta`.
    Custom {
        beta: Float,
        regularizer: fn(&Array2<Float>) -> Float,
    },
}
70
71impl Default for TaskRegularization {
72 fn default() -> Self {
73 Self::None
74 }
75}
76
77impl MultiTaskKernelRidgeRegression<Untrained> {
78 pub fn new(approximation_method: ApproximationMethod) -> Self {
80 Self {
81 approximation_method,
82 alpha: 1.0,
83 task_regularization: TaskRegularization::None,
84 solver: Solver::Direct,
85 random_state: None,
86 weights_: None,
87 feature_transformer_: None,
88 n_tasks_: None,
89 _state: PhantomData,
90 }
91 }
92
93 pub fn alpha(mut self, alpha: Float) -> Self {
95 self.alpha = alpha;
96 self
97 }
98
99 pub fn task_regularization(mut self, regularization: TaskRegularization) -> Self {
101 self.task_regularization = regularization;
102 self
103 }
104
105 pub fn solver(mut self, solver: Solver) -> Self {
107 self.solver = solver;
108 self
109 }
110
111 pub fn random_state(mut self, seed: u64) -> Self {
113 self.random_state = Some(seed);
114 self
115 }
116
117 fn compute_task_regularization_penalty(&self, weights: &Array2<Float>) -> Float {
119 match &self.task_regularization {
120 TaskRegularization::None => 0.0,
121 TaskRegularization::L2 { beta } => {
122 let mut penalty = 0.0;
124 let n_tasks = weights.ncols();
125 for i in 0..n_tasks {
126 for j in (i + 1)..n_tasks {
127 let diff = &weights.column(i) - &weights.column(j);
128 penalty += diff.mapv(|x| x * x).sum();
129 }
130 }
131 beta * penalty / (n_tasks * (n_tasks - 1) / 2) as Float
132 }
133 TaskRegularization::L1 { beta } => {
134 beta * weights.mapv(|x| x.abs()).sum()
136 }
137 TaskRegularization::NuclearNorm { beta } => {
138 beta * weights.mapv(|x| x * x).sum().sqrt()
141 }
142 TaskRegularization::GroupSparsity { beta } => {
143 let mut penalty = 0.0;
145 for row in weights.axis_iter(Axis(0)) {
146 penalty += row.mapv(|x| x * x).sum().sqrt();
147 }
148 beta * penalty
149 }
150 TaskRegularization::Custom { beta, regularizer } => beta * regularizer(weights),
151 }
152 }
153}
154
/// Marker `Estimator` implementation. This model carries no separate
/// configuration object, so `Config` is the unit type and `config` returns
/// a reference to `()`.
impl Estimator for MultiTaskKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}
164
165impl Fit<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Untrained> {
166 type Fitted = MultiTaskKernelRidgeRegression<Trained>;
167
168 fn fit(self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Self::Fitted> {
169 if x.nrows() != y.nrows() {
170 return Err(SklearsError::InvalidInput(
171 "Number of samples must match".to_string(),
172 ));
173 }
174
175 let _n_samples = x.nrows();
176 let n_tasks = y.ncols();
177
178 let feature_transformer = self.fit_feature_transformer(x)?;
180 let x_transformed = feature_transformer.transform(x)?;
181 let _n_features = x_transformed.ncols();
182
183 let weights = match self.solver {
185 Solver::Direct => self.solve_direct_multitask(&x_transformed, y)?,
186 Solver::SVD => self.solve_svd_multitask(&x_transformed, y)?,
187 Solver::ConjugateGradient { max_iter, tol } => {
188 self.solve_cg_multitask(&x_transformed, y, max_iter, tol)?
189 }
190 };
191
192 Ok(MultiTaskKernelRidgeRegression {
193 approximation_method: self.approximation_method,
194 alpha: self.alpha,
195 task_regularization: self.task_regularization,
196 solver: self.solver,
197 random_state: self.random_state,
198 weights_: Some(weights),
199 feature_transformer_: Some(feature_transformer),
200 n_tasks_: Some(n_tasks),
201 _state: PhantomData,
202 })
203 }
204}
205
206impl MultiTaskKernelRidgeRegression<Untrained> {
207 fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
209 match &self.approximation_method {
210 ApproximationMethod::Nystroem {
211 kernel,
212 n_components,
213 sampling_strategy,
214 } => {
215 let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
216 .sampling_strategy(sampling_strategy.clone());
217 if let Some(seed) = self.random_state {
218 nystroem = nystroem.random_state(seed);
219 }
220 let fitted = nystroem.fit(x, &())?;
221 Ok(FeatureTransformer::Nystroem(fitted))
222 }
223 ApproximationMethod::RandomFourierFeatures {
224 n_components,
225 gamma,
226 } => {
227 let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
228 if let Some(seed) = self.random_state {
229 rff = rff.random_state(seed);
230 }
231 let fitted = rff.fit(x, &())?;
232 Ok(FeatureTransformer::RBFSampler(fitted))
233 }
234 ApproximationMethod::StructuredRandomFeatures {
235 n_components,
236 gamma,
237 } => {
238 let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
239 if let Some(seed) = self.random_state {
240 srf = srf.random_state(seed);
241 }
242 let fitted = srf.fit(x, &())?;
243 Ok(FeatureTransformer::StructuredRFF(fitted))
244 }
245 ApproximationMethod::Fastfood {
246 n_components,
247 gamma,
248 } => {
249 let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
250 if let Some(seed) = self.random_state {
251 fastfood = fastfood.random_state(seed);
252 }
253 let fitted = fastfood.fit(x, &())?;
254 Ok(FeatureTransformer::Fastfood(fitted))
255 }
256 }
257 }
258
259 fn solve_direct_multitask(
261 &self,
262 x: &Array2<Float>,
263 y: &Array2<Float>,
264 ) -> Result<Array2<Float>> {
265 let n_features = x.ncols();
266 let n_tasks = y.ncols();
267
268 let mut all_weights = Array2::zeros((n_features, n_tasks));
271
272 for task_idx in 0..n_tasks {
273 let y_task = y.column(task_idx);
274
275 let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
277 let y_task_f64 = Array1::from_vec(y_task.iter().copied().collect());
278
279 let xtx = x_f64.t().dot(&x_f64);
280 let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;
281
282 let xty = x_f64.t().dot(&y_task_f64);
283 let weights_task_f64 =
284 regularized_xtx
285 .solve(&xty)
286 .map_err(|e| SklearsError::InvalidParameter {
287 name: "regularization".to_string(),
288 reason: format!("Linear system solving failed: {:?}", e),
289 })?;
290
291 let weights_task =
293 Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
294 all_weights.column_mut(task_idx).assign(&weights_task);
295 }
296
297 match &self.task_regularization {
300 TaskRegularization::L2 { beta } => {
301 let mean_weight = all_weights.mean_axis(Axis(1)).unwrap();
303 for mut col in all_weights.axis_iter_mut(Axis(1)) {
304 let diff = &col.to_owned() - &mean_weight;
305 col.scaled_add(-beta, &diff);
306 }
307 }
308 _ => {} }
310
311 Ok(all_weights)
312 }
313
314 fn solve_svd_multitask(&self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Array2<Float>> {
316 let n_features = x.ncols();
317 let n_tasks = y.ncols();
318 let mut all_weights = Array2::zeros((n_features, n_tasks));
319
320 for task_idx in 0..n_tasks {
321 let y_task = y.column(task_idx);
322
323 let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
325 let y_task_f64 = Array1::from_vec(y_task.iter().copied().collect());
326
327 let xtx = x_f64.t().dot(&x_f64);
328 let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;
329
330 let (u, s, vt) =
331 regularized_xtx
332 .svd(true)
333 .map_err(|e| SklearsError::InvalidParameter {
334 name: "svd".to_string(),
335 reason: format!("SVD decomposition failed: {:?}", e),
336 })?;
337
338 let xty = x_f64.t().dot(&y_task_f64);
340 let ut_b = u.t().dot(&xty);
341 let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
342 let y_svd = ut_b * s_inv;
343 let weights_task_f64 = vt.t().dot(&y_svd);
344
345 let weights_task =
347 Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
348 all_weights.column_mut(task_idx).assign(&weights_task);
349 }
350
351 Ok(all_weights)
352 }
353
354 fn solve_cg_multitask(
356 &self,
357 x: &Array2<Float>,
358 y: &Array2<Float>,
359 max_iter: usize,
360 tol: Float,
361 ) -> Result<Array2<Float>> {
362 let n_features = x.ncols();
363 let n_tasks = y.ncols();
364 let mut all_weights = Array2::zeros((n_features, n_tasks));
365
366 for task_idx in 0..n_tasks {
367 let y_task = y.column(task_idx);
368 let xty = x.t().dot(&y_task);
369
370 let mut weights = Array1::zeros(n_features);
372 let mut r = xty.clone();
373 let mut p = r.clone();
374 let mut rsold = r.dot(&r);
375
376 for _iter in 0..max_iter {
377 let xtx_p = x.t().dot(&x.dot(&p)) + &p * self.alpha;
378 let alpha_cg = rsold / p.dot(&xtx_p);
379
380 weights = weights + &p * alpha_cg;
381 r = r - &xtx_p * alpha_cg;
382
383 let rsnew = r.dot(&r);
384
385 if rsnew.sqrt() < tol {
386 break;
387 }
388
389 let beta = rsnew / rsold;
390 p = &r + &p * beta;
391 rsold = rsnew;
392 }
393
394 all_weights.column_mut(task_idx).assign(&weights);
395 }
396
397 Ok(all_weights)
398 }
399}
400
401impl Predict<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Trained> {
402 fn predict(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
403 let feature_transformer =
404 self.feature_transformer_
405 .as_ref()
406 .ok_or_else(|| SklearsError::NotFitted {
407 operation: "predict".to_string(),
408 })?;
409
410 let weights = self
411 .weights_
412 .as_ref()
413 .ok_or_else(|| SklearsError::NotFitted {
414 operation: "predict".to_string(),
415 })?;
416
417 let x_transformed = feature_transformer.transform(x)?;
418 let predictions = x_transformed.dot(weights);
419
420 Ok(predictions)
421 }
422}
423
424impl MultiTaskKernelRidgeRegression<Trained> {
425 pub fn n_tasks(&self) -> usize {
427 self.n_tasks_.unwrap_or(0)
428 }
429
430 pub fn weights(&self) -> Option<&Array2<Float>> {
432 self.weights_.as_ref()
433 }
434
435 pub fn task_weights(&self, task_idx: usize) -> Result<Array1<Float>> {
437 let weights = self
438 .weights_
439 .as_ref()
440 .ok_or_else(|| SklearsError::NotFitted {
441 operation: "predict".to_string(),
442 })?;
443
444 if task_idx >= weights.ncols() {
445 return Err(SklearsError::InvalidInput(format!(
446 "Task index {} out of range",
447 task_idx
448 )));
449 }
450
451 Ok(weights.column(task_idx).to_owned())
452 }
453
454 pub fn predict_task(&self, x: &Array2<Float>, task_idx: usize) -> Result<Array1<Float>> {
456 let predictions = self.predict(x)?;
457
458 if task_idx >= predictions.ncols() {
459 return Err(SklearsError::InvalidInput(format!(
460 "Task index {} out of range",
461 task_idx
462 )));
463 }
464
465 Ok(predictions.column(task_idx).to_owned())
466 }
467}
468
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // End-to-end smoke test: fit on two tasks, predict, and verify output
    // shapes, task count, per-task access, and finiteness of predictions.
    #[test]
    fn test_multitask_kernel_ridge_regression() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![[1.0, 2.0], [4.0, 5.0], [9.0, 10.0], [16.0, 17.0]];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[4, 2]);
        assert_eq!(fitted.n_tasks(), 2);

        // Per-task predictions should each cover all samples.
        let task0_pred = fitted.predict_task(&x, 0).unwrap();
        let task1_pred = fitted.predict_task(&x, 1).unwrap();

        assert_eq!(task0_pred.len(), 4);
        assert_eq!(task1_pred.len(), 4);

        // No NaN/inf should leak out of the solver.
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    // With near-identical targets and L2 task coupling, the two tasks'
    // weight columns should end up close together.
    #[test]
    fn test_multitask_with_regularization() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 15,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .task_regularization(TaskRegularization::L2 { beta: 0.1 });

        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 2]);

        // Mean absolute difference between the two task weight columns.
        let weights = fitted.weights().unwrap();
        let task0_weights = weights.column(0);
        let task1_weights = weights.column(1);
        let weight_diff = (&task0_weights - &task1_weights)
            .mapv(|x| x.abs())
            .mean()
            .unwrap();

        assert!(weight_diff < 1.0);
    }

    // Every supported solver should fit and produce correctly-shaped
    // predictions on the same data.
    #[test]
    fn test_multitask_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let solvers = vec![
            Solver::Direct,
            Solver::SVD,
            Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            },
        ];

        for solver in solvers {
            let mtkrr = MultiTaskKernelRidgeRegression::new(approximation.clone())
                .solver(solver)
                .alpha(0.1);

            let fitted = mtkrr.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();

            assert_eq!(predictions.shape(), &[3, 2]);
        }
    }

    // The multi-task API should degrade gracefully to a single task
    // (one-column target matrix).
    #[test]
    fn test_multitask_single_task() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 1]);
        assert_eq!(fitted.n_tasks(), 1);
    }

    // Two fits with the same random_state must yield identical predictions
    // (the random feature map is the only source of randomness).
    #[test]
    fn test_multitask_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr1 = MultiTaskKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = mtkrr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let mtkrr2 = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = mtkrr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.shape(), pred2.shape());
        for i in 0..pred1.nrows() {
            for j in 0..pred1.ncols() {
                assert!((pred1[[i, j]] - pred2[[i, j]]).abs() < 1e-10);
            }
        }
    }

    // Sanity checks on each regularization penalty: all non-negative, and
    // for this weight matrix the L1 penalty dominates the averaged L2 one.
    #[test]
    fn test_task_regularization_penalties() {
        let weights = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let model =
            MultiTaskKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
                n_components: 10,
                gamma: 1.0,
            });

        let reg_l2 = TaskRegularization::L2 { beta: 0.1 };
        let reg_l1 = TaskRegularization::L1 { beta: 0.1 };
        let reg_nuclear = TaskRegularization::NuclearNorm { beta: 0.1 };
        let reg_group = TaskRegularization::GroupSparsity { beta: 0.1 };

        let model_l2 = model.clone().task_regularization(reg_l2);
        let model_l1 = model.clone().task_regularization(reg_l1);
        let model_nuclear = model.clone().task_regularization(reg_nuclear);
        let model_group = model.clone().task_regularization(reg_group);

        let penalty_l2 = model_l2.compute_task_regularization_penalty(&weights);
        let penalty_l1 = model_l1.compute_task_regularization_penalty(&weights);
        let penalty_nuclear = model_nuclear.compute_task_regularization_penalty(&weights);
        let penalty_group = model_group.compute_task_regularization_penalty(&weights);

        assert!(penalty_l2 >= 0.0);
        assert!(penalty_l1 >= 0.0);
        assert!(penalty_nuclear >= 0.0);
        assert!(penalty_group >= 0.0);

        assert!(penalty_l1 > penalty_l2);
    }
}