1use crate::{
8 FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
9};
10use scirs2_linalg::compat::ArrayLinalgExt;
11use scirs2_core::ndarray::{Array1, Array2, Axis};
14use sklears_core::error::{Result, SklearsError};
15use sklears_core::prelude::{Estimator, Fit, Float, Predict};
16use std::marker::PhantomData;
17
18use super::core_types::*;
19
#[derive(Debug, Clone)]
pub struct MultiTaskKernelRidgeRegression<State = Untrained> {
    /// Kernel approximation used to build the explicit feature map.
    pub approximation_method: ApproximationMethod,
    /// Ridge (L2) regularization strength applied to each per-task problem.
    pub alpha: Float,
    /// Cross-task coupling penalty applied to the weight matrix.
    pub task_regularization: TaskRegularization,
    /// Linear-system solver used during fitting.
    pub solver: Solver,
    /// Optional RNG seed for reproducible feature sampling.
    pub random_state: Option<u64>,

    // Fitted state — populated by `fit`, `None` while untrained:
    // learned weight matrix of shape (n_features, n_tasks).
    weights_: Option<Array2<Float>>,
    // fitted feature transformer, reused at predict time.
    feature_transformer_: Option<FeatureTransformer>,
    // number of output tasks (columns of y) seen during fitting.
    n_tasks_: Option<usize>,

    // Zero-sized marker encoding the Trained/Untrained typestate.
    _state: PhantomData<State>,
}
55
/// Cross-task regularization schemes for the multi-task weight matrix
/// (columns are tasks, rows are features).
#[derive(Debug, Clone)]
pub enum TaskRegularization {
    /// No coupling between tasks.
    None,
    /// Mean squared difference between every unordered pair of task
    /// weight vectors, scaled by `beta`.
    L2 { beta: Float },
    /// Element-wise L1 norm of the weight matrix (promotes sparsity).
    L1 { beta: Float },
    /// Low-rank-promoting penalty. NOTE(review): the implementation in
    /// `compute_task_regularization_penalty` uses the Frobenius norm as
    /// a surrogate, not a true nuclear norm (no SVD is performed).
    NuclearNorm { beta: Float },
    /// L2,1 group norm over features: sum of per-row L2 norms, so whole
    /// features can be zeroed across all tasks.
    GroupSparsity { beta: Float },
    /// User-supplied penalty function, scaled by `beta`.
    Custom {
        beta: Float,
        regularizer: fn(&Array2<Float>) -> Float,
    },
}
75
76impl Default for TaskRegularization {
77 fn default() -> Self {
78 Self::None
79 }
80}
81
82impl MultiTaskKernelRidgeRegression<Untrained> {
83 pub fn new(approximation_method: ApproximationMethod) -> Self {
85 Self {
86 approximation_method,
87 alpha: 1.0,
88 task_regularization: TaskRegularization::None,
89 solver: Solver::Direct,
90 random_state: None,
91 weights_: None,
92 feature_transformer_: None,
93 n_tasks_: None,
94 _state: PhantomData,
95 }
96 }
97
98 pub fn alpha(mut self, alpha: Float) -> Self {
100 self.alpha = alpha;
101 self
102 }
103
104 pub fn task_regularization(mut self, regularization: TaskRegularization) -> Self {
106 self.task_regularization = regularization;
107 self
108 }
109
110 pub fn solver(mut self, solver: Solver) -> Self {
112 self.solver = solver;
113 self
114 }
115
116 pub fn random_state(mut self, seed: u64) -> Self {
118 self.random_state = Some(seed);
119 self
120 }
121
122 fn compute_task_regularization_penalty(&self, weights: &Array2<Float>) -> Float {
124 match &self.task_regularization {
125 TaskRegularization::None => 0.0,
126 TaskRegularization::L2 { beta } => {
127 let mut penalty = 0.0;
129 let n_tasks = weights.ncols();
130 for i in 0..n_tasks {
131 for j in (i + 1)..n_tasks {
132 let diff = &weights.column(i) - &weights.column(j);
133 penalty += diff.mapv(|x| x * x).sum();
134 }
135 }
136 beta * penalty / (n_tasks * (n_tasks - 1) / 2) as Float
137 }
138 TaskRegularization::L1 { beta } => {
139 beta * weights.mapv(|x| x.abs()).sum()
141 }
142 TaskRegularization::NuclearNorm { beta } => {
143 beta * weights.mapv(|x| x * x).sum().sqrt()
146 }
147 TaskRegularization::GroupSparsity { beta } => {
148 let mut penalty = 0.0;
150 for row in weights.axis_iter(Axis(0)) {
151 penalty += row.mapv(|x| x * x).sum().sqrt();
152 }
153 beta * penalty
154 }
155 TaskRegularization::Custom { beta, regularizer } => beta * regularizer(weights),
156 }
157 }
158}
159
// Wires the untrained model into the sklears estimator API.
impl Estimator for MultiTaskKernelRidgeRegression<Untrained> {
    // No separate config object is used; all settings live on the struct.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is a valid `'static` reference to the zero-sized unit value.
        &()
    }
}
169
170impl Fit<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Untrained> {
171 type Fitted = MultiTaskKernelRidgeRegression<Trained>;
172
173 fn fit(self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Self::Fitted> {
174 if x.nrows() != y.nrows() {
175 return Err(SklearsError::InvalidInput(
176 "Number of samples must match".to_string(),
177 ));
178 }
179
180 let _n_samples = x.nrows();
181 let n_tasks = y.ncols();
182
183 let feature_transformer = self.fit_feature_transformer(x)?;
185 let x_transformed = feature_transformer.transform(x)?;
186 let _n_features = x_transformed.ncols();
187
188 let weights = match self.solver {
190 Solver::Direct => self.solve_direct_multitask(&x_transformed, y)?,
191 Solver::SVD => self.solve_svd_multitask(&x_transformed, y)?,
192 Solver::ConjugateGradient { max_iter, tol } => {
193 self.solve_cg_multitask(&x_transformed, y, max_iter, tol)?
194 }
195 };
196
197 Ok(MultiTaskKernelRidgeRegression {
198 approximation_method: self.approximation_method,
199 alpha: self.alpha,
200 task_regularization: self.task_regularization,
201 solver: self.solver,
202 random_state: self.random_state,
203 weights_: Some(weights),
204 feature_transformer_: Some(feature_transformer),
205 n_tasks_: Some(n_tasks),
206 _state: PhantomData,
207 })
208 }
209}
210
211impl MultiTaskKernelRidgeRegression<Untrained> {
212 fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
214 match &self.approximation_method {
215 ApproximationMethod::Nystroem {
216 kernel,
217 n_components,
218 sampling_strategy,
219 } => {
220 let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
221 .sampling_strategy(sampling_strategy.clone());
222 if let Some(seed) = self.random_state {
223 nystroem = nystroem.random_state(seed);
224 }
225 let fitted = nystroem.fit(x, &())?;
226 Ok(FeatureTransformer::Nystroem(fitted))
227 }
228 ApproximationMethod::RandomFourierFeatures {
229 n_components,
230 gamma,
231 } => {
232 let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
233 if let Some(seed) = self.random_state {
234 rff = rff.random_state(seed);
235 }
236 let fitted = rff.fit(x, &())?;
237 Ok(FeatureTransformer::RBFSampler(fitted))
238 }
239 ApproximationMethod::StructuredRandomFeatures {
240 n_components,
241 gamma,
242 } => {
243 let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
244 if let Some(seed) = self.random_state {
245 srf = srf.random_state(seed);
246 }
247 let fitted = srf.fit(x, &())?;
248 Ok(FeatureTransformer::StructuredRFF(fitted))
249 }
250 ApproximationMethod::Fastfood {
251 n_components,
252 gamma,
253 } => {
254 let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
255 if let Some(seed) = self.random_state {
256 fastfood = fastfood.random_state(seed);
257 }
258 let fitted = fastfood.fit(x, &())?;
259 Ok(FeatureTransformer::Fastfood(fitted))
260 }
261 }
262 }
263
264 fn solve_direct_multitask(
266 &self,
267 x: &Array2<Float>,
268 y: &Array2<Float>,
269 ) -> Result<Array2<Float>> {
270 let n_features = x.ncols();
271 let n_tasks = y.ncols();
272
273 let mut all_weights = Array2::zeros((n_features, n_tasks));
276
277 for task_idx in 0..n_tasks {
278 let y_task = y.column(task_idx);
279
280 let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
282 let y_task_f64 = Array1::from_vec(y_task.iter().copied().collect());
283
284 let xtx = x_f64.t().dot(&x_f64);
285 let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;
286
287 let xty = x_f64.t().dot(&y_task_f64);
288 let weights_task_f64 =
289 regularized_xtx
290 .solve(&xty)
291 .map_err(|e| SklearsError::InvalidParameter {
292 name: "regularization".to_string(),
293 reason: format!("Linear system solving failed: {:?}", e),
294 })?;
295
296 let weights_task =
298 Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
299 all_weights.column_mut(task_idx).assign(&weights_task);
300 }
301
302 match &self.task_regularization {
305 TaskRegularization::L2 { beta } => {
306 let mean_weight = all_weights.mean_axis(Axis(1)).unwrap();
308 for mut col in all_weights.axis_iter_mut(Axis(1)) {
309 let diff = &col.to_owned() - &mean_weight;
310 col.scaled_add(-beta, &diff);
311 }
312 }
313 _ => {} }
315
316 Ok(all_weights)
317 }
318
319 fn solve_svd_multitask(&self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Array2<Float>> {
321 let n_features = x.ncols();
322 let n_tasks = y.ncols();
323 let mut all_weights = Array2::zeros((n_features, n_tasks));
324
325 for task_idx in 0..n_tasks {
326 let y_task = y.column(task_idx);
327
328 let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
330 let y_task_f64 = Array1::from_vec(y_task.iter().copied().collect());
331
332 let xtx = x_f64.t().dot(&x_f64);
333 let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;
334
335 let (u, s, vt) =
336 regularized_xtx
337 .svd(true)
338 .map_err(|e| SklearsError::InvalidParameter {
339 name: "svd".to_string(),
340 reason: format!("SVD decomposition failed: {:?}", e),
341 })?;
342
343 let xty = x_f64.t().dot(&y_task_f64);
345 let ut_b = u.t().dot(&xty);
346 let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
347 let y_svd = ut_b * s_inv;
348 let weights_task_f64 = vt.t().dot(&y_svd);
349
350 let weights_task =
352 Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
353 all_weights.column_mut(task_idx).assign(&weights_task);
354 }
355
356 Ok(all_weights)
357 }
358
359 fn solve_cg_multitask(
361 &self,
362 x: &Array2<Float>,
363 y: &Array2<Float>,
364 max_iter: usize,
365 tol: Float,
366 ) -> Result<Array2<Float>> {
367 let n_features = x.ncols();
368 let n_tasks = y.ncols();
369 let mut all_weights = Array2::zeros((n_features, n_tasks));
370
371 for task_idx in 0..n_tasks {
372 let y_task = y.column(task_idx);
373 let xty = x.t().dot(&y_task);
374
375 let mut weights = Array1::zeros(n_features);
377 let mut r = xty.clone();
378 let mut p = r.clone();
379 let mut rsold = r.dot(&r);
380
381 for _iter in 0..max_iter {
382 let xtx_p = x.t().dot(&x.dot(&p)) + &p * self.alpha;
383 let alpha_cg = rsold / p.dot(&xtx_p);
384
385 weights = weights + &p * alpha_cg;
386 r = r - &xtx_p * alpha_cg;
387
388 let rsnew = r.dot(&r);
389
390 if rsnew.sqrt() < tol {
391 break;
392 }
393
394 let beta = rsnew / rsold;
395 p = &r + &p * beta;
396 rsold = rsnew;
397 }
398
399 all_weights.column_mut(task_idx).assign(&weights);
400 }
401
402 Ok(all_weights)
403 }
404}
405
406impl Predict<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Trained> {
407 fn predict(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
408 let feature_transformer =
409 self.feature_transformer_
410 .as_ref()
411 .ok_or_else(|| SklearsError::NotFitted {
412 operation: "predict".to_string(),
413 })?;
414
415 let weights = self
416 .weights_
417 .as_ref()
418 .ok_or_else(|| SklearsError::NotFitted {
419 operation: "predict".to_string(),
420 })?;
421
422 let x_transformed = feature_transformer.transform(x)?;
423 let predictions = x_transformed.dot(weights);
424
425 Ok(predictions)
426 }
427}
428
429impl MultiTaskKernelRidgeRegression<Trained> {
430 pub fn n_tasks(&self) -> usize {
432 self.n_tasks_.unwrap_or(0)
433 }
434
435 pub fn weights(&self) -> Option<&Array2<Float>> {
437 self.weights_.as_ref()
438 }
439
440 pub fn task_weights(&self, task_idx: usize) -> Result<Array1<Float>> {
442 let weights = self
443 .weights_
444 .as_ref()
445 .ok_or_else(|| SklearsError::NotFitted {
446 operation: "predict".to_string(),
447 })?;
448
449 if task_idx >= weights.ncols() {
450 return Err(SklearsError::InvalidInput(format!(
451 "Task index {} out of range",
452 task_idx
453 )));
454 }
455
456 Ok(weights.column(task_idx).to_owned())
457 }
458
459 pub fn predict_task(&self, x: &Array2<Float>, task_idx: usize) -> Result<Array1<Float>> {
461 let predictions = self.predict(x)?;
462
463 if task_idx >= predictions.ncols() {
464 return Err(SklearsError::InvalidInput(format!(
465 "Task index {} out of range",
466 task_idx
467 )));
468 }
469
470 Ok(predictions.column(task_idx).to_owned())
471 }
472}
473
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// End-to-end smoke test: fit two tasks with random Fourier features
    /// and check prediction shapes, task count, and finiteness.
    #[test]
    fn test_multitask_kernel_ridge_regression() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        // Two closely related target columns.
        let y = array![[1.0, 2.0], [4.0, 5.0], [9.0, 10.0], [16.0, 17.0]];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[4, 2]);
        assert_eq!(fitted.n_tasks(), 2);

        // Per-task prediction vectors must match the sample count.
        let task0_pred = fitted.predict_task(&x, 0).unwrap();
        let task1_pred = fitted.predict_task(&x, 1).unwrap();

        assert_eq!(task0_pred.len(), 4);
        assert_eq!(task1_pred.len(), 4);

        // No NaNs/infinities should leak out of the solver.
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    /// Fitting with L2 task coupling should keep the two tasks' weight
    /// vectors close (targets differ only by a small offset).
    #[test]
    fn test_multitask_with_regularization() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 15,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .task_regularization(TaskRegularization::L2 { beta: 0.1 });

        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 2]);

        // With L2 coupling, the mean absolute difference between the two
        // task weight vectors should stay small.
        let weights = fitted.weights().unwrap();
        let task0_weights = weights.column(0);
        let task1_weights = weights.column(1);
        let weight_diff = (&task0_weights - &task1_weights)
            .mapv(|x| x.abs())
            .mean()
            .unwrap();

        assert!(weight_diff < 1.0);
    }

    /// Every supported solver should fit and produce correctly shaped
    /// predictions on the same small problem.
    #[test]
    fn test_multitask_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let solvers = vec![
            Solver::Direct,
            Solver::SVD,
            Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            },
        ];

        for solver in solvers {
            let mtkrr = MultiTaskKernelRidgeRegression::new(approximation.clone())
                .solver(solver)
                .alpha(0.1);

            let fitted = mtkrr.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();

            assert_eq!(predictions.shape(), &[3, 2]);
        }
    }

    /// The multi-task API should degrade gracefully to a single task
    /// (y with one column).
    #[test]
    fn test_multitask_single_task() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 1]);
        assert_eq!(fitted.n_tasks(), 1);
    }

    /// Two fits with the same random_state must yield (numerically)
    /// identical predictions.
    #[test]
    fn test_multitask_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr1 = MultiTaskKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = mtkrr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let mtkrr2 = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = mtkrr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        // Element-wise comparison with a tight tolerance.
        assert_eq!(pred1.shape(), pred2.shape());
        for i in 0..pred1.nrows() {
            for j in 0..pred1.ncols() {
                assert!((pred1[[i, j]] - pred2[[i, j]]).abs() < 1e-10);
            }
        }
    }

    /// Each regularization variant should produce a non-negative penalty
    /// on a fixed weight matrix.
    #[test]
    fn test_task_regularization_penalties() {
        let weights = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let model =
            MultiTaskKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
                n_components: 10,
                gamma: 1.0,
            });

        let reg_l2 = TaskRegularization::L2 { beta: 0.1 };
        let reg_l1 = TaskRegularization::L1 { beta: 0.1 };
        let reg_nuclear = TaskRegularization::NuclearNorm { beta: 0.1 };
        let reg_group = TaskRegularization::GroupSparsity { beta: 0.1 };

        let model_l2 = model.clone().task_regularization(reg_l2);
        let model_l1 = model.clone().task_regularization(reg_l1);
        let model_nuclear = model.clone().task_regularization(reg_nuclear);
        let model_group = model.clone().task_regularization(reg_group);

        let penalty_l2 = model_l2.compute_task_regularization_penalty(&weights);
        let penalty_l1 = model_l1.compute_task_regularization_penalty(&weights);
        let penalty_nuclear = model_nuclear.compute_task_regularization_penalty(&weights);
        let penalty_group = model_group.compute_task_regularization_penalty(&weights);

        assert!(penalty_l2 >= 0.0);
        assert!(penalty_l1 >= 0.0);
        assert!(penalty_nuclear >= 0.0);
        assert!(penalty_group >= 0.0);

        // Holds for this particular weight matrix: the element-wise L1
        // penalty exceeds the pairwise-difference L2 penalty.
        assert!(penalty_l1 > penalty_l2);
    }
}