1use crate::{
8 FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
9};
10use scirs2_core::ndarray::ndarray_linalg::solve::Solve;
11use scirs2_core::ndarray::ndarray_linalg::SVD;
12use scirs2_core::ndarray::{Array1, Array2, Axis};
13use sklears_core::error::{Result, SklearsError};
14use sklears_core::prelude::{Estimator, Fit, Float, Predict};
15use std::marker::PhantomData;
16
17use super::core_types::*;
18
/// Multi-task kernel ridge regression on approximate kernel features.
///
/// Fits one ridge regression per output task on features produced by a
/// kernel-approximation transformer (Nystroem, random Fourier features,
/// structured random features, or Fastfood), optionally coupling the tasks
/// through `task_regularization`.
///
/// The `State` type parameter is a compile-time typestate marker: builder
/// methods and `fit` live on `<Untrained>`, prediction and accessors on
/// `<Trained>`.
#[derive(Debug, Clone)]
pub struct MultiTaskKernelRidgeRegression<State = Untrained> {
    /// Kernel approximation used to build the feature map.
    pub approximation_method: ApproximationMethod,
    /// Ridge (L2) regularization strength applied independently per task.
    pub alpha: Float,
    /// Cross-task coupling penalty configuration.
    pub task_regularization: TaskRegularization,
    /// Linear-system solver used during fitting.
    pub solver: Solver,
    /// Seed for the randomized feature transformers; `None` = nondeterministic.
    pub random_state: Option<u64>,

    // Fitted state, populated by `fit`: the (n_features, n_tasks) weight
    // matrix and the fitted feature transformer.
    weights_: Option<Array2<Float>>, feature_transformer_: Option<FeatureTransformer>,
    // Number of output tasks seen during `fit`.
    n_tasks_: Option<usize>,

    // Zero-sized typestate marker; carries no runtime data.
    _state: PhantomData<State>,
}
54
/// Penalty coupling the per-task weight vectors of a multi-task model.
#[derive(Debug, Clone)]
pub enum TaskRegularization {
    /// No cross-task coupling; tasks are fit independently.
    None,
    /// Mean squared pairwise difference between task weight columns,
    /// scaled by `beta` — encourages tasks to share similar weights.
    L2 { beta: Float },
    /// `beta` times the elementwise L1 norm of the weight matrix
    /// (promotes sparsity across all tasks).
    L1 { beta: Float },
    /// Low-rank-promoting penalty scaled by `beta`.
    /// NOTE(review): the current implementation uses the Frobenius norm as a
    /// cheap stand-in for the true nuclear norm (no SVD) — confirm intent.
    NuclearNorm { beta: Float },
    /// Sum of per-row Euclidean norms (L2,1 norm), scaled by `beta` —
    /// encourages each feature to be kept or dropped across all tasks.
    GroupSparsity { beta: Float },
    /// User-supplied penalty, scaled by `beta`.
    Custom {
        beta: Float,
        /// Function computing the raw penalty from the weight matrix.
        regularizer: fn(&Array2<Float>) -> Float,
    },
}
74
75impl Default for TaskRegularization {
76 fn default() -> Self {
77 Self::None
78 }
79}
80
81impl MultiTaskKernelRidgeRegression<Untrained> {
82 pub fn new(approximation_method: ApproximationMethod) -> Self {
84 Self {
85 approximation_method,
86 alpha: 1.0,
87 task_regularization: TaskRegularization::None,
88 solver: Solver::Direct,
89 random_state: None,
90 weights_: None,
91 feature_transformer_: None,
92 n_tasks_: None,
93 _state: PhantomData,
94 }
95 }
96
97 pub fn alpha(mut self, alpha: Float) -> Self {
99 self.alpha = alpha;
100 self
101 }
102
103 pub fn task_regularization(mut self, regularization: TaskRegularization) -> Self {
105 self.task_regularization = regularization;
106 self
107 }
108
109 pub fn solver(mut self, solver: Solver) -> Self {
111 self.solver = solver;
112 self
113 }
114
115 pub fn random_state(mut self, seed: u64) -> Self {
117 self.random_state = Some(seed);
118 self
119 }
120
121 fn compute_task_regularization_penalty(&self, weights: &Array2<Float>) -> Float {
123 match &self.task_regularization {
124 TaskRegularization::None => 0.0,
125 TaskRegularization::L2 { beta } => {
126 let mut penalty = 0.0;
128 let n_tasks = weights.ncols();
129 for i in 0..n_tasks {
130 for j in (i + 1)..n_tasks {
131 let diff = &weights.column(i) - &weights.column(j);
132 penalty += diff.mapv(|x| x * x).sum();
133 }
134 }
135 beta * penalty / (n_tasks * (n_tasks - 1) / 2) as Float
136 }
137 TaskRegularization::L1 { beta } => {
138 beta * weights.mapv(|x| x.abs()).sum()
140 }
141 TaskRegularization::NuclearNorm { beta } => {
142 beta * weights.mapv(|x| x * x).sum().sqrt()
145 }
146 TaskRegularization::GroupSparsity { beta } => {
147 let mut penalty = 0.0;
149 for row in weights.axis_iter(Axis(0)) {
150 penalty += row.mapv(|x| x * x).sum().sqrt();
151 }
152 beta * penalty
153 }
154 TaskRegularization::Custom { beta, regularizer } => beta * regularizer(weights),
155 }
156 }
157}
158
/// Minimal `Estimator` wiring; this model carries no separate configuration
/// object, so `Config` is the unit type.
impl Estimator for MultiTaskKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is valid here because `()` is zero-sized and the reference
        // can be promoted to any lifetime.
        &()
    }
}
168
169impl Fit<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Untrained> {
170 type Fitted = MultiTaskKernelRidgeRegression<Trained>;
171
172 fn fit(self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Self::Fitted> {
173 if x.nrows() != y.nrows() {
174 return Err(SklearsError::InvalidInput(
175 "Number of samples must match".to_string(),
176 ));
177 }
178
179 let _n_samples = x.nrows();
180 let n_tasks = y.ncols();
181
182 let feature_transformer = self.fit_feature_transformer(x)?;
184 let x_transformed = feature_transformer.transform(x)?;
185 let _n_features = x_transformed.ncols();
186
187 let weights = match self.solver {
189 Solver::Direct => self.solve_direct_multitask(&x_transformed, y)?,
190 Solver::SVD => self.solve_svd_multitask(&x_transformed, y)?,
191 Solver::ConjugateGradient { max_iter, tol } => {
192 self.solve_cg_multitask(&x_transformed, y, max_iter, tol)?
193 }
194 };
195
196 Ok(MultiTaskKernelRidgeRegression {
197 approximation_method: self.approximation_method,
198 alpha: self.alpha,
199 task_regularization: self.task_regularization,
200 solver: self.solver,
201 random_state: self.random_state,
202 weights_: Some(weights),
203 feature_transformer_: Some(feature_transformer),
204 n_tasks_: Some(n_tasks),
205 _state: PhantomData,
206 })
207 }
208}
209
impl MultiTaskKernelRidgeRegression<Untrained> {
    /// Builds and fits the configured kernel-approximation transformer on
    /// `x`, forwarding the optional random seed for reproducibility.
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());
                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }
                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rff = RBFSampler::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    rff = rff.random_state(seed);
                }
                let fitted = rff.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut srf = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    srf = srf.random_state(seed);
                }
                let fitted = srf.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }
                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Solves the per-task ridge normal equations (XᵀX + αI) w = Xᵀy with a
    /// direct linear solve, then (for the L2 variant only) applies a post-hoc
    /// shrinkage of each task's weights toward the cross-task mean.
    ///
    /// Returns the weight matrix of shape (n_features, n_tasks).
    fn solve_direct_multitask(
        &self,
        x: &Array2<Float>,
        y: &Array2<Float>,
    ) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();

        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);

            // Copy the data into f64 arrays for the linear-algebra backend.
            // NOTE(review): if `Float` is already f64 this copy is redundant.
            let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
            let y_task_f64 = Array1::from_vec(y_task.iter().copied().collect());

            // Ridge-regularized normal equations: (XᵀX + αI) w = Xᵀy.
            let xtx = x_f64.t().dot(&x_f64);
            let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;

            let xty = x_f64.t().dot(&y_task_f64);
            let weights_task_f64 =
                regularized_xtx
                    .solve(&xty)
                    .map_err(|e| SklearsError::InvalidParameter {
                        name: "regularization".to_string(),
                        reason: format!("Linear system solving failed: {:?}", e),
                    })?;

            // Convert back to `Float` and store as this task's column.
            let weights_task =
                Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
            all_weights.column_mut(task_idx).assign(&weights_task);
        }

        // Post-processing for cross-task coupling: the L2 variant pulls each
        // task's weights toward the mean weight vector by a factor of `beta`.
        // Other regularization variants are not applied by this solver.
        match &self.task_regularization {
            TaskRegularization::L2 { beta } => {
                let mean_weight = all_weights.mean_axis(Axis(1)).unwrap();
                for mut col in all_weights.axis_iter_mut(Axis(1)) {
                    let diff = &col.to_owned() - &mean_weight;
                    col.scaled_add(-beta, &diff);
                }
            }
            _ => {}
        }

        Ok(all_weights)
    }

    /// Like `solve_direct_multitask`, but inverts the regularized normal
    /// matrix via its SVD, zeroing reciprocal singular values below 1e-10.
    /// More robust than the direct solve for ill-conditioned problems.
    /// Note: no task-regularization post-processing is applied here.
    fn solve_svd_multitask(&self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();
        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);

            // f64 copies for the linear-algebra backend (see direct solver).
            let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
            let y_task_f64 = Array1::from_vec(y_task.iter().copied().collect());

            let xtx = x_f64.t().dot(&x_f64);
            let regularized_xtx = xtx + Array2::<f64>::eye(n_features) * self.alpha;

            // Full SVD: regularized_xtx = U Σ Vᵀ (both U and Vᵀ requested).
            let (u, s, vt) =
                regularized_xtx
                    .svd(true, true)
                    .map_err(|e| SklearsError::InvalidParameter {
                        name: "svd".to_string(),
                        reason: format!("SVD decomposition failed: {:?}", e),
                    })?;
            let u = u.unwrap();
            let vt = vt.unwrap();

            // Pseudo-inverse solve: w = V Σ⁺ Uᵀ (Xᵀy), with tiny singular
            // values truncated to zero for numerical stability.
            let xty = x_f64.t().dot(&y_task_f64);
            let ut_b = u.t().dot(&xty);
            let s_inv = s.mapv(|x| if x > 1e-10 { 1.0 / x } else { 0.0 });
            let y_svd = ut_b * s_inv;
            let weights_task_f64 = vt.t().dot(&y_svd);

            // Convert back to `Float` and store as this task's column.
            let weights_task =
                Array1::from_vec(weights_task_f64.iter().map(|&val| val as Float).collect());
            all_weights.column_mut(task_idx).assign(&weights_task);
        }

        Ok(all_weights)
    }

    /// Solves each task's system (XᵀX + αI) w = Xᵀy with conjugate gradient,
    /// never forming XᵀX explicitly (only matrix-vector products).
    ///
    /// Iterates at most `max_iter` times per task, stopping early when the
    /// residual norm drops below `tol`. Starts from a zero weight vector, so
    /// the initial residual equals Xᵀy.
    fn solve_cg_multitask(
        &self,
        x: &Array2<Float>,
        y: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<Array2<Float>> {
        let n_features = x.ncols();
        let n_tasks = y.ncols();
        let mut all_weights = Array2::zeros((n_features, n_tasks));

        for task_idx in 0..n_tasks {
            let y_task = y.column(task_idx);
            let xty = x.t().dot(&y_task);

            // Standard CG state: residual r, search direction p, and the
            // squared residual norm from the previous iteration.
            let mut weights = Array1::zeros(n_features);
            let mut r = xty.clone();
            let mut p = r.clone();
            let mut rsold = r.dot(&r);

            for _iter in 0..max_iter {
                // (XᵀX + αI) p computed as Xᵀ(Xp) + αp to avoid forming XᵀX.
                let xtx_p = x.t().dot(&x.dot(&p)) + &p * self.alpha;
                let alpha_cg = rsold / p.dot(&xtx_p);

                weights = weights + &p * alpha_cg;
                r = r - &xtx_p * alpha_cg;

                let rsnew = r.dot(&r);

                // Converged: residual norm below tolerance.
                if rsnew.sqrt() < tol {
                    break;
                }

                // Fletcher–Reeves update of the search direction.
                let beta = rsnew / rsold;
                p = &r + &p * beta;
                rsold = rsnew;
            }

            all_weights.column_mut(task_idx).assign(&weights);
        }

        Ok(all_weights)
    }
}
406
407impl Predict<Array2<Float>, Array2<Float>> for MultiTaskKernelRidgeRegression<Trained> {
408 fn predict(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
409 let feature_transformer =
410 self.feature_transformer_
411 .as_ref()
412 .ok_or_else(|| SklearsError::NotFitted {
413 operation: "predict".to_string(),
414 })?;
415
416 let weights = self
417 .weights_
418 .as_ref()
419 .ok_or_else(|| SklearsError::NotFitted {
420 operation: "predict".to_string(),
421 })?;
422
423 let x_transformed = feature_transformer.transform(x)?;
424 let predictions = x_transformed.dot(weights);
425
426 Ok(predictions)
427 }
428}
429
430impl MultiTaskKernelRidgeRegression<Trained> {
431 pub fn n_tasks(&self) -> usize {
433 self.n_tasks_.unwrap_or(0)
434 }
435
436 pub fn weights(&self) -> Option<&Array2<Float>> {
438 self.weights_.as_ref()
439 }
440
441 pub fn task_weights(&self, task_idx: usize) -> Result<Array1<Float>> {
443 let weights = self
444 .weights_
445 .as_ref()
446 .ok_or_else(|| SklearsError::NotFitted {
447 operation: "predict".to_string(),
448 })?;
449
450 if task_idx >= weights.ncols() {
451 return Err(SklearsError::InvalidInput(format!(
452 "Task index {} out of range",
453 task_idx
454 )));
455 }
456
457 Ok(weights.column(task_idx).to_owned())
458 }
459
460 pub fn predict_task(&self, x: &Array2<Float>, task_idx: usize) -> Result<Array1<Float>> {
462 let predictions = self.predict(x)?;
463
464 if task_idx >= predictions.ncols() {
465 return Err(SklearsError::InvalidInput(format!(
466 "Task index {} out of range",
467 task_idx
468 )));
469 }
470
471 Ok(predictions.column(task_idx).to_owned())
472 }
473}
474
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// End-to-end smoke test: fit on a tiny two-task dataset, then check
    /// prediction shapes, task count, per-task extraction, and finiteness.
    #[test]
    fn test_multitask_kernel_ridge_regression() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![[1.0, 2.0], [4.0, 5.0], [9.0, 10.0], [16.0, 17.0]];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.1,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[4, 2]);
        assert_eq!(fitted.n_tasks(), 2);

        let task0_pred = fitted.predict_task(&x, 0).unwrap();
        let task1_pred = fitted.predict_task(&x, 1).unwrap();

        assert_eq!(task0_pred.len(), 4);
        assert_eq!(task1_pred.len(), 4);

        // No NaN/inf should leak out of the solver or feature map.
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    /// With L2 task coupling on two nearly identical tasks, the learned
    /// weight columns should end up close to each other.
    #[test]
    fn test_multitask_with_regularization() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 15,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .task_regularization(TaskRegularization::L2 { beta: 0.1 });

        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 2]);

        // Mean absolute difference between the two task weight vectors
        // should be small because the tasks are almost identical and L2
        // coupling shrinks them toward each other.
        let weights = fitted.weights().unwrap();
        let task0_weights = weights.column(0);
        let task1_weights = weights.column(1);
        let weight_diff = (&task0_weights - &task1_weights)
            .mapv(|x| x.abs())
            .mean()
            .unwrap();

        assert!(weight_diff < 1.0);
    }

    /// All three solver variants should fit successfully and produce
    /// correctly shaped predictions on the same data.
    #[test]
    fn test_multitask_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let solvers = vec![
            Solver::Direct,
            Solver::SVD,
            Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            },
        ];

        for solver in solvers {
            let mtkrr = MultiTaskKernelRidgeRegression::new(approximation.clone())
                .solver(solver)
                .alpha(0.1);

            let fitted = mtkrr.fit(&x, &y).unwrap();
            let predictions = fitted.predict(&x).unwrap();

            assert_eq!(predictions.shape(), &[3, 2]);
        }
    }

    /// A single-column target matrix degenerates to ordinary single-task
    /// ridge regression (n_tasks == 1).
    #[test]
    fn test_multitask_single_task() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0], [2.0], [3.0]];
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr = MultiTaskKernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = mtkrr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.shape(), &[3, 1]);
        assert_eq!(fitted.n_tasks(), 1);
    }

    /// Fitting twice with the same seed must produce (numerically) identical
    /// predictions — the random feature maps are fully seeded.
    #[test]
    fn test_multitask_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let mtkrr1 = MultiTaskKernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = mtkrr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let mtkrr2 = MultiTaskKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = mtkrr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.shape(), pred2.shape());
        for i in 0..pred1.nrows() {
            for j in 0..pred1.ncols() {
                assert!((pred1[[i, j]] - pred2[[i, j]]).abs() < 1e-10);
            }
        }
    }

    /// Sanity checks on the penalty values for a fixed weight matrix:
    /// each variant is non-negative, and for these values the L1 penalty
    /// exceeds the (pair-averaged) L2 penalty.
    #[test]
    fn test_task_regularization_penalties() {
        let weights = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let model =
            MultiTaskKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
                n_components: 10,
                gamma: 1.0,
            });

        let reg_l2 = TaskRegularization::L2 { beta: 0.1 };
        let reg_l1 = TaskRegularization::L1 { beta: 0.1 };
        let reg_nuclear = TaskRegularization::NuclearNorm { beta: 0.1 };
        let reg_group = TaskRegularization::GroupSparsity { beta: 0.1 };

        let model_l2 = model.clone().task_regularization(reg_l2);
        let model_l1 = model.clone().task_regularization(reg_l1);
        let model_nuclear = model.clone().task_regularization(reg_nuclear);
        let model_group = model.clone().task_regularization(reg_group);

        let penalty_l2 = model_l2.compute_task_regularization_penalty(&weights);
        let penalty_l1 = model_l1.compute_task_regularization_penalty(&weights);
        let penalty_nuclear = model_nuclear.compute_task_regularization_penalty(&weights);
        let penalty_group = model_group.compute_task_regularization_penalty(&weights);

        assert!(penalty_l2 >= 0.0);
        assert!(penalty_l1 >= 0.0);
        assert!(penalty_nuclear >= 0.0);
        assert!(penalty_group >= 0.0);

        assert!(penalty_l1 > penalty_l2);
    }
}