1use crate::lda::{LinearDiscriminantAnalysis, LinearDiscriminantAnalysisConfig};
23use crate::qda::{QuadraticDiscriminantAnalysis, QuadraticDiscriminantAnalysisConfig};
24use scirs2_core::ndarray::{s, Array1, Array2, Axis};
26use sklears_core::{
27 error::Result,
28 prelude::SklearsError,
29 traits::{Estimator, Fit, Predict, PredictProba, Trained},
30 types::Float,
31};
32
/// Hyper-parameters for [`MultiTaskDiscriminantLearning`].
#[derive(Debug, Clone)]
pub struct MultiTaskDiscriminantLearningConfig {
    /// Number of discriminant components shared by all tasks.
    /// `None` means `max(n_features / 2, 1)`.
    pub n_shared_components: Option<usize>,
    /// Number of task-specific discriminant components learned per task.
    /// `None` means `max(n_features / 4, 1)`.
    pub n_task_components: Option<usize>,
    /// Sharing penalty.
    /// NOTE(review): stored only; not read by the fitting code in this file,
    /// so its intended semantics cannot be confirmed here.
    pub sharing_penalty: Float,
    /// Task penalty.
    /// NOTE(review): stored only; not read by the fitting code in this file.
    pub task_penalty: Float,
    /// Base discriminant: `"lda"` or `"qda"`. Any other value makes fitting
    /// fail with `SklearsError::InvalidParameter`.
    pub base_discriminant: String,
    /// Task weighting strategy: `"uniform"`, `"proportional"` (to the number of
    /// samples) or `"inverse"` (1 / number of samples). Unrecognised values
    /// fall back to uniform weighting.
    pub task_weighting: String,
    /// Rescale task weights so they sum to one (only when their sum is > 0).
    pub normalize_weights: bool,
    /// Maximum iterations.
    /// NOTE(review): not read by the fitting code in this file.
    pub max_iter: usize,
    /// Convergence tolerance.
    /// NOTE(review): not read by the fitting code in this file.
    pub tol: Float,
    /// Warm-start flag.
    /// NOTE(review): not read by the fitting code in this file.
    pub warm_start: bool,
    /// Random seed.
    /// NOTE(review): not read by the fitting code in this file.
    pub random_state: Option<u64>,
    /// LDA settings.
    /// NOTE(review): the estimators in this file are built with
    /// `LinearDiscriminantAnalysis::new()`, so these settings are not applied.
    pub lda_config: LinearDiscriminantAnalysisConfig,
    /// QDA settings.
    /// NOTE(review): the QDA classifiers in this file are built with
    /// `QuadraticDiscriminantAnalysis::new()`, so these settings are not applied.
    pub qda_config: QuadraticDiscriminantAnalysisConfig,
}
63
64impl Default for MultiTaskDiscriminantLearningConfig {
65 fn default() -> Self {
66 Self {
67 n_shared_components: None,
68 n_task_components: None,
69 sharing_penalty: 1.0,
70 task_penalty: 1.0,
71 base_discriminant: "lda".to_string(),
72 task_weighting: "uniform".to_string(),
73 normalize_weights: true,
74 max_iter: 100,
75 tol: 1e-6,
76 warm_start: false,
77 random_state: None,
78 lda_config: LinearDiscriminantAnalysisConfig::default(),
79 qda_config: QuadraticDiscriminantAnalysisConfig::default(),
80 }
81 }
82}
83
/// A single learning task: a feature matrix, its class labels and a weight.
#[derive(Debug, Clone)]
pub struct Task {
    /// Caller-supplied identifier.
    /// NOTE(review): the model addresses tasks by their position in the
    /// training / `add_task` order, not by this value.
    pub task_id: usize,
    /// Feature matrix, one sample per row.
    pub x: Array2<Float>,
    /// Class label of each sample; `Task::new` checks `y.len() == x.nrows()`.
    pub y: Array1<i32>,
    /// Relative task weight, multiplied into the strategy weight; `1.0` from `Task::new`.
    pub weight: Float,
    /// Sorted, de-duplicated labels of `y`, computed by `Task::new`.
    pub classes: Vec<i32>,
}
98
99impl Task {
100 pub fn new(task_id: usize, x: Array2<Float>, y: Array1<i32>) -> Result<Self> {
102 if x.nrows() != y.len() {
103 return Err(SklearsError::InvalidInput(
104 "Number of samples in X and y must match".to_string(),
105 ));
106 }
107
108 let classes: Vec<i32> = {
109 let mut classes: Vec<i32> = y.iter().cloned().collect();
110 classes.sort_unstable();
111 classes.dedup();
112 classes
113 };
114
115 Ok(Self {
116 task_id,
117 x,
118 y,
119 weight: 1.0,
120 classes,
121 })
122 }
123
124 pub fn with_weight(mut self, weight: Float) -> Self {
126 self.weight = weight;
127 self
128 }
129
130 pub fn n_samples(&self) -> usize {
132 self.x.nrows()
133 }
134
135 pub fn n_features(&self) -> usize {
137 self.x.ncols()
138 }
139
140 pub fn n_classes(&self) -> usize {
142 self.classes.len()
143 }
144}
145
/// Unfitted multi-task discriminant learner.
///
/// Fitting learns discriminant directions shared by all tasks, then
/// task-specific directions per task, and trains one classifier per task on
/// the concatenation of both projections.
#[derive(Debug, Clone)]
pub struct MultiTaskDiscriminantLearning {
    /// Hyper-parameters, set through the builder methods.
    config: MultiTaskDiscriminantLearningConfig,
}
151
152impl MultiTaskDiscriminantLearning {
153 pub fn new() -> Self {
155 Self {
156 config: MultiTaskDiscriminantLearningConfig::default(),
157 }
158 }
159
160 pub fn n_shared_components(mut self, n_components: Option<usize>) -> Self {
162 self.config.n_shared_components = n_components;
163 self
164 }
165
166 pub fn n_task_components(mut self, n_components: Option<usize>) -> Self {
168 self.config.n_task_components = n_components;
169 self
170 }
171
172 pub fn sharing_penalty(mut self, penalty: Float) -> Self {
174 self.config.sharing_penalty = penalty;
175 self
176 }
177
178 pub fn task_penalty(mut self, penalty: Float) -> Self {
180 self.config.task_penalty = penalty;
181 self
182 }
183
184 pub fn base_discriminant(mut self, discriminant_type: &str) -> Self {
186 self.config.base_discriminant = discriminant_type.to_string();
187 self
188 }
189
190 pub fn task_weighting(mut self, weighting: &str) -> Self {
192 self.config.task_weighting = weighting.to_string();
193 self
194 }
195
196 pub fn normalize_weights(mut self, normalize: bool) -> Self {
198 self.config.normalize_weights = normalize;
199 self
200 }
201
202 pub fn max_iter(mut self, max_iter: usize) -> Self {
204 self.config.max_iter = max_iter;
205 self
206 }
207
208 pub fn tol(mut self, tol: Float) -> Self {
210 self.config.tol = tol;
211 self
212 }
213
214 pub fn warm_start(mut self, warm_start: bool) -> Self {
216 self.config.warm_start = warm_start;
217 self
218 }
219
220 pub fn random_state(mut self, seed: u64) -> Self {
222 self.config.random_state = Some(seed);
223 self
224 }
225
226 fn compute_task_weights(&self, tasks: &[Task]) -> Vec<Float> {
228 let mut weights = match self.config.task_weighting.as_str() {
229 "uniform" => vec![1.0; tasks.len()],
230 "proportional" => tasks.iter().map(|t| t.n_samples() as Float).collect(),
231 "inverse" => tasks
232 .iter()
233 .map(|t| 1.0 / (t.n_samples() as Float))
234 .collect(),
235 _ => vec![1.0; tasks.len()],
236 };
237
238 for (i, task) in tasks.iter().enumerate() {
240 weights[i] *= task.weight;
241 }
242
243 if self.config.normalize_weights {
245 let sum: Float = weights.iter().sum();
246 if sum > 0.0 {
247 for weight in &mut weights {
248 *weight /= sum;
249 }
250 }
251 }
252
253 weights
254 }
255
256 fn compute_shared_subspace(&self, tasks: &[Task]) -> Result<Array2<Float>> {
258 let n_features = tasks[0].n_features();
259 let n_shared = self
260 .config
261 .n_shared_components
262 .unwrap_or((n_features / 2).max(1));
263
264 let mut all_x = Vec::new();
266 let mut all_y = Vec::new();
267 let mut task_indices = Vec::new();
268
269 for (task_idx, task) in tasks.iter().enumerate() {
270 for (i, row) in task.x.axis_iter(Axis(0)).enumerate() {
271 all_x.push(row.to_owned());
272 all_y.push(task.y[i]);
273 task_indices.push(task_idx);
274 }
275 }
276
277 if all_x.is_empty() {
278 return Err(SklearsError::InvalidInput(
279 "No training data provided".to_string(),
280 ));
281 }
282
283 let combined_x = Array2::from_shape_vec(
285 (all_x.len(), n_features),
286 all_x.into_iter().flatten().collect(),
287 )
288 .map_err(|_| SklearsError::InvalidInput("Failed to stack task data".to_string()))?;
289
290 let combined_y = Array1::from_vec(all_y);
291
292 let shared_components = match self.config.base_discriminant.as_str() {
294 "lda" => {
295 let lda = LinearDiscriminantAnalysis::new().n_components(Some(n_shared));
296 let fitted_lda = lda.fit(&combined_x, &combined_y)?;
297 fitted_lda.components().clone()
298 }
299 "qda" => {
300 let lda = LinearDiscriminantAnalysis::new().n_components(Some(n_shared));
302 let fitted_lda = lda.fit(&combined_x, &combined_y)?;
303 fitted_lda.components().clone()
304 }
305 _ => {
306 return Err(SklearsError::InvalidParameter {
307 name: "base_discriminant".to_string(),
308 reason: format!(
309 "Unknown base discriminant: {}",
310 self.config.base_discriminant
311 ),
312 })
313 }
314 };
315
316 Ok(shared_components)
317 }
318
319 fn compute_task_specific_components(
321 &self,
322 task: &Task,
323 shared_components: &Array2<Float>,
324 ) -> Result<Array2<Float>> {
325 let n_features = task.n_features();
326 let n_task = self
327 .config
328 .n_task_components
329 .unwrap_or((n_features / 4).max(1));
330
331 let task_x = &task.x;
333
334 let shared_proj = shared_components.t().dot(shared_components);
336 let mut ortho_proj = Array2::eye(n_features);
337 ortho_proj = ortho_proj - shared_proj;
338
339 let projected_x = task_x.dot(&ortho_proj);
341
342 let task_components = match self.config.base_discriminant.as_str() {
344 "lda" => {
345 let mut lda_config = self.config.lda_config.clone();
346 lda_config.n_components = Some(n_task);
347 let lda = LinearDiscriminantAnalysis::new();
348 let fitted_lda = lda.fit(&projected_x, &task.y)?;
349 fitted_lda.components().clone()
350 }
351 "qda" => {
352 let mut lda_config = self.config.lda_config.clone();
354 lda_config.n_components = Some(n_task);
355 let lda = LinearDiscriminantAnalysis::new();
356 let fitted_lda = lda.fit(&projected_x, &task.y)?;
357 fitted_lda.components().clone()
358 }
359 _ => {
360 return Err(SklearsError::InvalidParameter {
361 name: "base_discriminant".to_string(),
362 reason: format!(
363 "Unknown base discriminant: {}",
364 self.config.base_discriminant
365 ),
366 })
367 }
368 };
369
370 let final_components = task_components.dot(&ortho_proj);
372
373 Ok(final_components)
374 }
375}
376
377impl Default for MultiTaskDiscriminantLearning {
378 fn default() -> Self {
379 Self::new()
380 }
381}
382
383impl Estimator for MultiTaskDiscriminantLearning {
384 type Config = MultiTaskDiscriminantLearningConfig;
385 type Error = SklearsError;
386 type Float = Float;
387
388 fn config(&self) -> &Self::Config {
389 &self.config
390 }
391}
392
/// A fitted multi-task discriminant model.
#[derive(Debug)]
pub struct TrainedMultiTaskDiscriminantLearning {
    /// Shared discriminant directions, one per row; inputs are projected as `x · Cᵀ`.
    shared_components: Array2<Float>,
    /// Task-specific directions, indexed by task position.
    task_components: Vec<Array2<Float>>,
    /// Per-task classifiers, trained on `[x · Cᵀ | x · Tᵀ]`.
    task_classifiers: Vec<TaskClassifier>,
    /// Copies of the tasks the model was trained on, in task-position order.
    tasks: Vec<Task>,
    /// Per-task weights from the configured weighting strategy.
    task_weights: Vec<Float>,
    /// Sorted, de-duplicated union of the class labels of all tasks.
    global_classes: Vec<i32>,
    /// Configuration the model was fitted with.
    config: MultiTaskDiscriminantLearningConfig,
}
411
/// Per-task classifier trained on the stacked shared + task-specific projections.
#[derive(Debug)]
pub enum TaskClassifier {
    /// Built when `base_discriminant == "lda"`.
    LDA(LinearDiscriminantAnalysis<Trained>),
    /// Built when `base_discriminant == "qda"`.
    QDA(QuadraticDiscriminantAnalysis<Trained>),
}
420
421impl TaskClassifier {
422 pub fn predict(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
424 match self {
425 TaskClassifier::LDA(lda) => lda.predict(x),
426 TaskClassifier::QDA(qda) => qda.predict(x),
427 }
428 }
429
430 pub fn predict_proba(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
432 match self {
433 TaskClassifier::LDA(lda) => lda.predict_proba(x),
434 TaskClassifier::QDA(qda) => qda.predict_proba(x),
435 }
436 }
437
438 pub fn classes(&self) -> &[i32] {
440 match self {
441 TaskClassifier::LDA(lda) => lda.classes().as_slice().unwrap(),
442 TaskClassifier::QDA(qda) => qda.classes().as_slice().unwrap(),
443 }
444 }
445}
446
447impl TrainedMultiTaskDiscriminantLearning {
448 pub fn shared_components(&self) -> &Array2<Float> {
450 &self.shared_components
451 }
452
453 pub fn task_components(&self, task_id: usize) -> Option<&Array2<Float>> {
455 self.task_components.get(task_id)
456 }
457
458 pub fn global_classes(&self) -> &[i32] {
460 &self.global_classes
461 }
462
463 pub fn tasks(&self) -> &[Task] {
465 &self.tasks
466 }
467
468 pub fn task_weights(&self) -> &[Float] {
470 &self.task_weights
471 }
472
473 pub fn predict_task(&self, x: &Array2<Float>, task_id: usize) -> Result<Array1<i32>> {
475 if task_id >= self.task_classifiers.len() {
476 return Err(SklearsError::InvalidParameter {
477 name: "task_id".to_string(),
478 reason: format!("Task {} not found", task_id),
479 });
480 }
481
482 let transformed_x = self.transform_task(x, task_id)?;
484
485 self.task_classifiers[task_id].predict(&transformed_x)
487 }
488
489 pub fn predict_proba_task(&self, x: &Array2<Float>, task_id: usize) -> Result<Array2<Float>> {
491 if task_id >= self.task_classifiers.len() {
492 return Err(SklearsError::InvalidParameter {
493 name: "task_id".to_string(),
494 reason: format!("Task {} not found", task_id),
495 });
496 }
497
498 let transformed_x = self.transform_task(x, task_id)?;
500
501 self.task_classifiers[task_id].predict_proba(&transformed_x)
503 }
504
505 pub fn transform_task(&self, x: &Array2<Float>, task_id: usize) -> Result<Array2<Float>> {
507 if task_id >= self.task_components.len() {
508 return Err(SklearsError::InvalidParameter {
509 name: "task_id".to_string(),
510 reason: format!("Task {} not found", task_id),
511 });
512 }
513
514 let shared_proj = x.dot(&self.shared_components.t());
516
517 let task_proj = x.dot(&self.task_components[task_id].t());
519
520 let mut combined = Array2::zeros((x.nrows(), shared_proj.ncols() + task_proj.ncols()));
522 combined
523 .slice_mut(s![.., ..shared_proj.ncols()])
524 .assign(&shared_proj);
525 combined
526 .slice_mut(s![.., shared_proj.ncols()..])
527 .assign(&task_proj);
528
529 Ok(combined)
530 }
531
532 pub fn transform_shared(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
534 Ok(x.dot(&self.shared_components.t()))
535 }
536
537 pub fn add_task(&mut self, task: Task) -> Result<usize> {
539 let task_id = self.tasks.len();
540
541 let task_components = self.compute_task_components(&task)?;
543
544 let task_classifier = self.train_task_classifier(&task, &task_components)?;
546
547 self.tasks.push(task);
549 self.task_components.push(task_components);
550 self.task_classifiers.push(task_classifier);
551 self.task_weights.push(1.0);
552
553 self.update_global_classes();
555
556 Ok(task_id)
557 }
558
559 fn compute_task_components(&self, task: &Task) -> Result<Array2<Float>> {
560 let n_features = task.n_features();
561 let n_task = self
562 .config
563 .n_task_components
564 .unwrap_or((n_features / 4).max(1));
565
566 let shared_proj = self.shared_components.t().dot(&self.shared_components);
568 let mut ortho_proj = Array2::eye(n_features);
569 ortho_proj = ortho_proj - shared_proj;
570
571 let projected_x = task.x.dot(&ortho_proj);
573
574 let task_components = match self.config.base_discriminant.as_str() {
576 "lda" => {
577 let mut lda_config = self.config.lda_config.clone();
578 lda_config.n_components = Some(n_task);
579 let lda = LinearDiscriminantAnalysis::new();
580 let fitted_lda = lda.fit(&projected_x, &task.y)?;
581 fitted_lda.components().clone()
582 }
583 "qda" => {
584 let mut lda_config = self.config.lda_config.clone();
585 lda_config.n_components = Some(n_task);
586 let lda = LinearDiscriminantAnalysis::new();
587 let fitted_lda = lda.fit(&projected_x, &task.y)?;
588 fitted_lda.components().clone()
589 }
590 _ => {
591 return Err(SklearsError::InvalidParameter {
592 name: "base_discriminant".to_string(),
593 reason: format!(
594 "Unknown base discriminant: {}",
595 self.config.base_discriminant
596 ),
597 })
598 }
599 };
600
601 Ok(task_components.dot(&ortho_proj))
603 }
604
605 fn train_task_classifier(
606 &self,
607 task: &Task,
608 task_components: &Array2<Float>,
609 ) -> Result<TaskClassifier> {
610 let shared_proj = task.x.dot(&self.shared_components.t());
612 let task_proj = task.x.dot(&task_components.t());
613
614 let mut combined = Array2::zeros((task.x.nrows(), shared_proj.ncols() + task_proj.ncols()));
615 combined
616 .slice_mut(s![.., ..shared_proj.ncols()])
617 .assign(&shared_proj);
618 combined
619 .slice_mut(s![.., shared_proj.ncols()..])
620 .assign(&task_proj);
621
622 match self.config.base_discriminant.as_str() {
623 "lda" => {
624 let lda = LinearDiscriminantAnalysis::new();
625 let fitted_lda = lda.fit(&combined, &task.y)?;
626 Ok(TaskClassifier::LDA(fitted_lda))
627 }
628 "qda" => {
629 let qda = QuadraticDiscriminantAnalysis::new();
630 let fitted_qda = qda.fit(&combined, &task.y)?;
631 Ok(TaskClassifier::QDA(fitted_qda))
632 }
633 _ => Err(SklearsError::InvalidParameter {
634 name: "base_discriminant".to_string(),
635 reason: format!(
636 "Unknown base discriminant: {}",
637 self.config.base_discriminant
638 ),
639 }),
640 }
641 }
642
643 fn update_global_classes(&mut self) {
644 let mut all_classes = Vec::new();
645 for task in &self.tasks {
646 all_classes.extend(&task.classes);
647 }
648 all_classes.sort_unstable();
649 all_classes.dedup();
650 self.global_classes = all_classes;
651 }
652}
653
654impl Fit<Vec<Task>, ()> for MultiTaskDiscriminantLearning {
655 type Fitted = TrainedMultiTaskDiscriminantLearning;
656
657 fn fit(self, tasks: &Vec<Task>, _y: &()) -> Result<Self::Fitted> {
658 if tasks.is_empty() {
659 return Err(SklearsError::InvalidInput("No tasks provided".to_string()));
660 }
661
662 let n_features = tasks[0].n_features();
664 for task in tasks {
665 if task.n_features() != n_features {
666 return Err(SklearsError::InvalidInput(
667 "All tasks must have the same number of features".to_string(),
668 ));
669 }
670 }
671
672 let task_weights = self.compute_task_weights(tasks);
674
675 let shared_components = self.compute_shared_subspace(tasks)?;
677
678 let mut task_components = Vec::new();
680 let mut task_classifiers = Vec::new();
681
682 for task in tasks {
683 let task_comp = self.compute_task_specific_components(task, &shared_components)?;
684
685 let shared_proj = task.x.dot(&shared_components.t());
687 let task_proj = task.x.dot(&task_comp.t());
688
689 let mut combined =
690 Array2::zeros((task.x.nrows(), shared_proj.ncols() + task_proj.ncols()));
691 combined
692 .slice_mut(s![.., ..shared_proj.ncols()])
693 .assign(&shared_proj);
694 combined
695 .slice_mut(s![.., shared_proj.ncols()..])
696 .assign(&task_proj);
697
698 let classifier = match self.config.base_discriminant.as_str() {
700 "lda" => {
701 let lda = LinearDiscriminantAnalysis::new();
702 let fitted_lda = lda.fit(&combined, &task.y)?;
703 TaskClassifier::LDA(fitted_lda)
704 }
705 "qda" => {
706 let qda = QuadraticDiscriminantAnalysis::new();
707 let fitted_qda = qda.fit(&combined, &task.y)?;
708 TaskClassifier::QDA(fitted_qda)
709 }
710 _ => {
711 return Err(SklearsError::InvalidParameter {
712 name: "base_discriminant".to_string(),
713 reason: format!(
714 "Unknown base discriminant: {}",
715 self.config.base_discriminant
716 ),
717 })
718 }
719 };
720
721 task_components.push(task_comp);
722 task_classifiers.push(classifier);
723 }
724
725 let mut global_classes = Vec::new();
727 for task in tasks {
728 global_classes.extend(&task.classes);
729 }
730 global_classes.sort_unstable();
731 global_classes.dedup();
732
733 Ok(TrainedMultiTaskDiscriminantLearning {
734 shared_components,
735 task_components,
736 task_classifiers,
737 tasks: tasks.clone(),
738 task_weights,
739 global_classes,
740 config: self.config.clone(),
741 })
742 }
743}
744
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Wraps two `(X, y)` pairs into tasks 0 and 1.
    fn make_tasks(
        x1: &Array2<Float>,
        y1: &Array1<i32>,
        x2: &Array2<Float>,
        y2: &Array1<i32>,
    ) -> Vec<Task> {
        vec![
            Task::new(0, x1.clone(), y1.clone()).unwrap(),
            Task::new(1, x2.clone(), y2.clone()).unwrap(),
        ]
    }

    /// Two small tasks (2 features, 2 balanced classes) reused by several tests.
    fn small_pair() -> (Array2<Float>, Array1<i32>, Array2<Float>, Array1<i32>) {
        (
            array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]],
            array![0, 0, 1, 1],
            array![[1.5, 2.5], [2.5, 3.5], [3.5, 4.5], [4.5, 5.5]],
            array![0, 0, 1, 1],
        )
    }

    #[test]
    fn test_task_creation() {
        let task = Task::new(
            0,
            array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]],
            array![0, 0, 1, 1],
        )
        .unwrap();

        assert_eq!(task.task_id, 0);
        assert_eq!(task.n_samples(), 4);
        assert_eq!(task.n_features(), 2);
        assert_eq!(task.n_classes(), 2);
        assert_eq!(task.classes, vec![0, 1]);
    }

    #[test]
    fn test_multi_task_discriminant_learning_basic() {
        let x1 = array![
            [1.0, 2.0],
            [1.1, 2.1],
            [1.2, 2.2],
            [3.0, 4.0],
            [3.1, 4.1],
            [3.2, 4.2]
        ];
        let y1 = array![0, 0, 0, 1, 1, 1];
        let x2 = array![
            [1.5, 2.5],
            [1.6, 2.6],
            [1.7, 2.7],
            [3.5, 4.5],
            [3.6, 4.6],
            [3.7, 4.7]
        ];
        let y2 = array![0, 0, 0, 1, 1, 1];
        let tasks = make_tasks(&x1, &y1, &x2, &y2);

        let fitted = MultiTaskDiscriminantLearning::new().fit(&tasks, &()).unwrap();

        assert_eq!(fitted.predict_task(&x1, 0).unwrap().len(), 6);
        assert_eq!(fitted.predict_task(&x2, 1).unwrap().len(), 6);
    }

    #[test]
    fn test_multi_task_predict_proba() {
        let (x1, y1, x2, y2) = small_pair();
        let tasks = make_tasks(&x1, &y1, &x2, &y2);

        let fitted = MultiTaskDiscriminantLearning::new().fit(&tasks, &()).unwrap();

        let probas = fitted.predict_proba_task(&x1, 0).unwrap();
        assert_eq!(probas.dim(), (4, 2));

        // Every row must be a probability distribution.
        for row in probas.axis_iter(Axis(0)) {
            let total: Float = row.sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_multi_task_transform() {
        let x1 = array![
            [1.0, 2.0, 0.5],
            [2.0, 1.0, 1.5],
            [3.0, 4.0, 2.0],
            [4.0, 3.0, 3.5],
            [5.0, 2.0, 4.0],
            [6.0, 1.0, 4.5]
        ];
        let y1 = array![0, 0, 1, 1, 2, 2];
        let x2 = array![
            [1.5, 2.5, 0.8],
            [2.5, 1.5, 1.8],
            [3.5, 4.5, 2.5],
            [4.5, 3.5, 3.8],
            [5.5, 2.5, 4.3],
            [6.5, 1.5, 4.8]
        ];
        let y2 = array![0, 0, 1, 1, 2, 2];
        let tasks = make_tasks(&x1, &y1, &x2, &y2);

        let fitted = MultiTaskDiscriminantLearning::new()
            .n_shared_components(Some(2))
            .n_task_components(Some(1))
            .fit(&tasks, &())
            .unwrap();

        let shared_transformed = fitted.transform_shared(&x1).unwrap();
        assert!(shared_transformed.ncols() >= 1);

        // The task view appends task-specific columns to the shared ones.
        let task_transformed = fitted.transform_task(&x1, 0).unwrap();
        assert!(task_transformed.ncols() >= shared_transformed.ncols());
    }

    #[test]
    fn test_multi_task_with_qda() {
        let (x1, y1, x2, y2) = small_pair();
        let tasks = make_tasks(&x1, &y1, &x2, &y2);

        let fitted = MultiTaskDiscriminantLearning::new()
            .base_discriminant("qda")
            .fit(&tasks, &())
            .unwrap();

        assert_eq!(fitted.predict_task(&x1, 0).unwrap().len(), 4);
    }

    #[test]
    fn test_task_weighting_strategies() {
        let x1 = array![[1.0, 2.0], [2.0, 3.0]];
        let y1 = array![0, 1];
        let x2 = array![[1.5, 2.5], [2.5, 3.5], [3.5, 4.5], [4.5, 5.5]];
        let y2 = array![0, 0, 1, 1];
        let tasks = make_tasks(&x1, &y1, &x2, &y2);

        for &strategy in &["uniform", "proportional", "inverse"] {
            let fitted = MultiTaskDiscriminantLearning::new()
                .task_weighting(strategy)
                .fit(&tasks, &())
                .unwrap();

            let weights = fitted.task_weights();
            assert_eq!(weights.len(), 2);
            assert!(weights.iter().all(|&w| w > 0.0));
        }
    }

    #[test]
    fn test_add_new_task() {
        let (x1, y1, x2, y2) = small_pair();
        let tasks = make_tasks(&x1, &y1, &x2, &y2);

        let mut fitted = MultiTaskDiscriminantLearning::new().fit(&tasks, &()).unwrap();

        let x3 = array![[2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];
        let y3 = array![0, 0, 1, 1];
        let task3 = Task::new(2, x3.clone(), y3.clone()).unwrap();

        let new_task_id = fitted.add_task(task3).unwrap();
        assert_eq!(new_task_id, 2);

        assert_eq!(fitted.predict_task(&x3, new_task_id).unwrap().len(), 4);
    }
}