sklears_discriminant_analysis/
multi_task.rs

1//! # Multi-Task Discriminant Learning
2//!
3//! Multi-task discriminant learning extends traditional discriminant analysis to handle
4//! multiple related classification tasks simultaneously. This approach leverages shared
5//! information across tasks to improve performance, especially when individual tasks
6//! have limited training data.
7//!
8//! ## Key Features
9//! - Shared discriminant subspace across multiple tasks
10//! - Task-specific discriminant components  
11//! - Regularization to control sharing vs. task specialization
12//! - Support for both LDA and QDA base classifiers
13//! - Flexible task weighting strategies
14//! - Transfer learning capabilities for new tasks
15//!
16//! ## Applications
17//! - Multi-domain classification (e.g., sentiment analysis across different domains)
18//! - Multi-label classification with structured label dependencies
19//! - Few-shot learning with related tasks
20//! - Domain adaptation scenarios
21
22use crate::lda::{LinearDiscriminantAnalysis, LinearDiscriminantAnalysisConfig};
23use crate::qda::{QuadraticDiscriminantAnalysis, QuadraticDiscriminantAnalysisConfig};
24// ✅ Using SciRS2 dependencies following SciRS2 policy
25use scirs2_core::ndarray::{s, Array1, Array2, Axis};
26use sklears_core::{
27    error::Result,
28    prelude::SklearsError,
29    traits::{Estimator, Fit, Predict, PredictProba, Trained},
30    types::Float,
31};
32
/// Configuration for multi-task discriminant learning.
///
/// NOTE(review): `sharing_penalty`, `task_penalty`, `max_iter`, `tol`,
/// `warm_start`, `random_state`, and the embedded LDA/QDA configs are not
/// consulted anywhere in this module — presumably reserved for a future
/// optimization loop; confirm before relying on them.
#[derive(Debug, Clone)]
pub struct MultiTaskDiscriminantLearningConfig {
    /// Number of shared discriminant components (defaults to n_features/2, at least 1)
    pub n_shared_components: Option<usize>,
    /// Number of task-specific components per task (defaults to n_features/4, at least 1)
    pub n_task_components: Option<usize>,
    /// Regularization parameter for shared components (higher = more sharing)
    pub sharing_penalty: Float,
    /// Regularization parameter for task-specific components
    pub task_penalty: Float,
    /// Base discriminant type: "lda" or "qda" (any other value is rejected at fit time)
    pub base_discriminant: String,
    /// Task weighting strategy: "uniform", "proportional", or "inverse"
    /// (unrecognized values silently fall back to uniform weighting)
    pub task_weighting: String,
    /// Whether to normalize task weights so they sum to 1
    pub normalize_weights: bool,
    /// Maximum iterations for optimization
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: Float,
    /// Whether to use warm start for new tasks
    pub warm_start: bool,
    /// Random state for reproducible results
    pub random_state: Option<u64>,
    /// LDA configuration for base classifiers
    pub lda_config: LinearDiscriminantAnalysisConfig,
    /// QDA configuration for base classifiers
    pub qda_config: QuadraticDiscriminantAnalysisConfig,
}
63
64impl Default for MultiTaskDiscriminantLearningConfig {
65    fn default() -> Self {
66        Self {
67            n_shared_components: None,
68            n_task_components: None,
69            sharing_penalty: 1.0,
70            task_penalty: 1.0,
71            base_discriminant: "lda".to_string(),
72            task_weighting: "uniform".to_string(),
73            normalize_weights: true,
74            max_iter: 100,
75            tol: 1e-6,
76            warm_start: false,
77            random_state: None,
78            lda_config: LinearDiscriminantAnalysisConfig::default(),
79            qda_config: QuadraticDiscriminantAnalysisConfig::default(),
80        }
81    }
82}
83
/// Represents a single task in multi-task learning.
///
/// Holds the task's training data along with a sorted, deduplicated list of
/// the class labels observed in `y` (built by [`Task::new`]).
#[derive(Debug, Clone)]
pub struct Task {
    /// Task identifier (caller-assigned; not required to match its index)
    pub task_id: usize,
    /// Training data for this task, shape (n_samples, n_features)
    pub x: Array2<Float>,
    /// Training labels for this task, length n_samples
    pub y: Array1<i32>,
    /// Task weight (importance); defaults to 1.0 and is multiplied into the
    /// strategy-derived weight at fit time
    pub weight: Float,
    /// Sorted unique class labels present in `y`
    pub classes: Vec<i32>,
}
98
99impl Task {
100    /// Create a new task
101    pub fn new(task_id: usize, x: Array2<Float>, y: Array1<i32>) -> Result<Self> {
102        if x.nrows() != y.len() {
103            return Err(SklearsError::InvalidInput(
104                "Number of samples in X and y must match".to_string(),
105            ));
106        }
107
108        let classes: Vec<i32> = {
109            let mut classes: Vec<i32> = y.iter().cloned().collect();
110            classes.sort_unstable();
111            classes.dedup();
112            classes
113        };
114
115        Ok(Self {
116            task_id,
117            x,
118            y,
119            weight: 1.0,
120            classes,
121        })
122    }
123
124    /// Set task weight
125    pub fn with_weight(mut self, weight: Float) -> Self {
126        self.weight = weight;
127        self
128    }
129
130    /// Get number of samples
131    pub fn n_samples(&self) -> usize {
132        self.x.nrows()
133    }
134
135    /// Get number of features
136    pub fn n_features(&self) -> usize {
137        self.x.ncols()
138    }
139
140    /// Get number of classes
141    pub fn n_classes(&self) -> usize {
142        self.classes.len()
143    }
144}
145
/// Multi-task discriminant learning estimator (builder; not yet fitted).
///
/// Configure via the builder methods, then call `fit` with a `Vec<Task>` to
/// obtain a [`TrainedMultiTaskDiscriminantLearning`].
#[derive(Debug, Clone)]
pub struct MultiTaskDiscriminantLearning {
    // All tunable behavior lives in the config; the estimator itself is stateless.
    config: MultiTaskDiscriminantLearningConfig,
}
151
152impl MultiTaskDiscriminantLearning {
153    /// Create a new multi-task discriminant learning estimator
154    pub fn new() -> Self {
155        Self {
156            config: MultiTaskDiscriminantLearningConfig::default(),
157        }
158    }
159
160    /// Set number of shared discriminant components
161    pub fn n_shared_components(mut self, n_components: Option<usize>) -> Self {
162        self.config.n_shared_components = n_components;
163        self
164    }
165
166    /// Set number of task-specific components
167    pub fn n_task_components(mut self, n_components: Option<usize>) -> Self {
168        self.config.n_task_components = n_components;
169        self
170    }
171
172    /// Set sharing penalty (regularization for shared components)
173    pub fn sharing_penalty(mut self, penalty: Float) -> Self {
174        self.config.sharing_penalty = penalty;
175        self
176    }
177
178    /// Set task penalty (regularization for task-specific components)
179    pub fn task_penalty(mut self, penalty: Float) -> Self {
180        self.config.task_penalty = penalty;
181        self
182    }
183
184    /// Set base discriminant type
185    pub fn base_discriminant(mut self, discriminant_type: &str) -> Self {
186        self.config.base_discriminant = discriminant_type.to_string();
187        self
188    }
189
190    /// Set task weighting strategy
191    pub fn task_weighting(mut self, weighting: &str) -> Self {
192        self.config.task_weighting = weighting.to_string();
193        self
194    }
195
196    /// Set whether to normalize task weights
197    pub fn normalize_weights(mut self, normalize: bool) -> Self {
198        self.config.normalize_weights = normalize;
199        self
200    }
201
202    /// Set maximum iterations
203    pub fn max_iter(mut self, max_iter: usize) -> Self {
204        self.config.max_iter = max_iter;
205        self
206    }
207
208    /// Set convergence tolerance
209    pub fn tol(mut self, tol: Float) -> Self {
210        self.config.tol = tol;
211        self
212    }
213
214    /// Set warm start option
215    pub fn warm_start(mut self, warm_start: bool) -> Self {
216        self.config.warm_start = warm_start;
217        self
218    }
219
220    /// Set random state
221    pub fn random_state(mut self, seed: u64) -> Self {
222        self.config.random_state = Some(seed);
223        self
224    }
225
226    /// Compute task weights based on the weighting strategy
227    fn compute_task_weights(&self, tasks: &[Task]) -> Vec<Float> {
228        let mut weights = match self.config.task_weighting.as_str() {
229            "uniform" => vec![1.0; tasks.len()],
230            "proportional" => tasks.iter().map(|t| t.n_samples() as Float).collect(),
231            "inverse" => tasks
232                .iter()
233                .map(|t| 1.0 / (t.n_samples() as Float))
234                .collect(),
235            _ => vec![1.0; tasks.len()],
236        };
237
238        // Apply manual task weights
239        for (i, task) in tasks.iter().enumerate() {
240            weights[i] *= task.weight;
241        }
242
243        // Normalize if requested
244        if self.config.normalize_weights {
245            let sum: Float = weights.iter().sum();
246            if sum > 0.0 {
247                for weight in &mut weights {
248                    *weight /= sum;
249                }
250            }
251        }
252
253        weights
254    }
255
256    /// Compute shared discriminant subspace across all tasks
257    fn compute_shared_subspace(&self, tasks: &[Task]) -> Result<Array2<Float>> {
258        let n_features = tasks[0].n_features();
259        let n_shared = self
260            .config
261            .n_shared_components
262            .unwrap_or((n_features / 2).max(1));
263
264        // Stack all data from all tasks
265        let mut all_x = Vec::new();
266        let mut all_y = Vec::new();
267        let mut task_indices = Vec::new();
268
269        for (task_idx, task) in tasks.iter().enumerate() {
270            for (i, row) in task.x.axis_iter(Axis(0)).enumerate() {
271                all_x.push(row.to_owned());
272                all_y.push(task.y[i]);
273                task_indices.push(task_idx);
274            }
275        }
276
277        if all_x.is_empty() {
278            return Err(SklearsError::InvalidInput(
279                "No training data provided".to_string(),
280            ));
281        }
282
283        // Convert to arrays
284        let combined_x = Array2::from_shape_vec(
285            (all_x.len(), n_features),
286            all_x.into_iter().flatten().collect(),
287        )
288        .map_err(|_| SklearsError::InvalidInput("Failed to stack task data".to_string()))?;
289
290        let combined_y = Array1::from_vec(all_y);
291
292        // Fit a global discriminant analysis
293        let shared_components = match self.config.base_discriminant.as_str() {
294            "lda" => {
295                let lda = LinearDiscriminantAnalysis::new().n_components(Some(n_shared));
296                let fitted_lda = lda.fit(&combined_x, &combined_y)?;
297                fitted_lda.components().clone()
298            }
299            "qda" => {
300                // For QDA, we use the pooled covariance approach similar to LDA
301                let lda = LinearDiscriminantAnalysis::new().n_components(Some(n_shared));
302                let fitted_lda = lda.fit(&combined_x, &combined_y)?;
303                fitted_lda.components().clone()
304            }
305            _ => {
306                return Err(SklearsError::InvalidParameter {
307                    name: "base_discriminant".to_string(),
308                    reason: format!(
309                        "Unknown base discriminant: {}",
310                        self.config.base_discriminant
311                    ),
312                })
313            }
314        };
315
316        Ok(shared_components)
317    }
318
319    /// Compute task-specific discriminant components
320    fn compute_task_specific_components(
321        &self,
322        task: &Task,
323        shared_components: &Array2<Float>,
324    ) -> Result<Array2<Float>> {
325        let n_features = task.n_features();
326        let n_task = self
327            .config
328            .n_task_components
329            .unwrap_or((n_features / 4).max(1));
330
331        // Project data to the orthogonal space of shared components
332        let task_x = &task.x;
333
334        // Create orthogonal projection matrix (I - P_shared)
335        let shared_proj = shared_components.t().dot(shared_components);
336        let mut ortho_proj = Array2::eye(n_features);
337        ortho_proj = ortho_proj - shared_proj;
338
339        // Project task data to orthogonal space
340        let projected_x = task_x.dot(&ortho_proj);
341
342        // Fit discriminant analysis on projected data
343        let task_components = match self.config.base_discriminant.as_str() {
344            "lda" => {
345                let mut lda_config = self.config.lda_config.clone();
346                lda_config.n_components = Some(n_task);
347                let lda = LinearDiscriminantAnalysis::new();
348                let fitted_lda = lda.fit(&projected_x, &task.y)?;
349                fitted_lda.components().clone()
350            }
351            "qda" => {
352                // For QDA, use LDA approach for component extraction
353                let mut lda_config = self.config.lda_config.clone();
354                lda_config.n_components = Some(n_task);
355                let lda = LinearDiscriminantAnalysis::new();
356                let fitted_lda = lda.fit(&projected_x, &task.y)?;
357                fitted_lda.components().clone()
358            }
359            _ => {
360                return Err(SklearsError::InvalidParameter {
361                    name: "base_discriminant".to_string(),
362                    reason: format!(
363                        "Unknown base discriminant: {}",
364                        self.config.base_discriminant
365                    ),
366                })
367            }
368        };
369
370        // Transform components back to original space
371        let final_components = task_components.dot(&ortho_proj);
372
373        Ok(final_components)
374    }
375}
376
377impl Default for MultiTaskDiscriminantLearning {
378    fn default() -> Self {
379        Self::new()
380    }
381}
382
impl Estimator for MultiTaskDiscriminantLearning {
    type Config = MultiTaskDiscriminantLearningConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Expose the estimator's configuration (required by the `Estimator` trait).
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
392
/// Trained multi-task discriminant learning model.
///
/// Produced by fitting [`MultiTaskDiscriminantLearning`]; supports per-task
/// prediction and transformation, and adding new tasks without refitting
/// the shared subspace.
#[derive(Debug)]
pub struct TrainedMultiTaskDiscriminantLearning {
    /// Shared discriminant components, shape (n_shared, n_features)
    shared_components: Array2<Float>,
    /// Task-specific components, indexed by task position (not `task_id`)
    task_components: Vec<Array2<Float>>,
    /// Per-task classifiers trained in the combined projection space
    task_classifiers: Vec<TaskClassifier>,
    /// Owned copies of the training tasks
    tasks: Vec<Task>,
    /// Task weights computed at fit time (tasks added later get weight 1.0)
    task_weights: Vec<Float>,
    /// Global classes (sorted union of all task classes)
    global_classes: Vec<i32>,
    /// Configuration used for fitting; also drives `add_task`
    config: MultiTaskDiscriminantLearningConfig,
}
411
/// Task-specific classifier: wraps whichever trained base discriminant the
/// configuration selected ("lda" or "qda").
#[derive(Debug)]
pub enum TaskClassifier {
    /// LDA classifier for this task
    LDA(LinearDiscriminantAnalysis<Trained>),
    /// QDA classifier for this task
    QDA(QuadraticDiscriminantAnalysis<Trained>),
}
420
421impl TaskClassifier {
422    /// Predict using the task classifier
423    pub fn predict(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
424        match self {
425            TaskClassifier::LDA(lda) => lda.predict(x),
426            TaskClassifier::QDA(qda) => qda.predict(x),
427        }
428    }
429
430    /// Predict probabilities using the task classifier
431    pub fn predict_proba(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
432        match self {
433            TaskClassifier::LDA(lda) => lda.predict_proba(x),
434            TaskClassifier::QDA(qda) => qda.predict_proba(x),
435        }
436    }
437
438    /// Get classes
439    pub fn classes(&self) -> &[i32] {
440        match self {
441            TaskClassifier::LDA(lda) => lda.classes().as_slice().unwrap(),
442            TaskClassifier::QDA(qda) => qda.classes().as_slice().unwrap(),
443        }
444    }
445}
446
447impl TrainedMultiTaskDiscriminantLearning {
448    /// Get the shared components
449    pub fn shared_components(&self) -> &Array2<Float> {
450        &self.shared_components
451    }
452
453    /// Get task-specific components for a task
454    pub fn task_components(&self, task_id: usize) -> Option<&Array2<Float>> {
455        self.task_components.get(task_id)
456    }
457
458    /// Get global classes
459    pub fn global_classes(&self) -> &[i32] {
460        &self.global_classes
461    }
462
463    /// Get task information
464    pub fn tasks(&self) -> &[Task] {
465        &self.tasks
466    }
467
468    /// Get task weights
469    pub fn task_weights(&self) -> &[Float] {
470        &self.task_weights
471    }
472
473    /// Predict for a specific task
474    pub fn predict_task(&self, x: &Array2<Float>, task_id: usize) -> Result<Array1<i32>> {
475        if task_id >= self.task_classifiers.len() {
476            return Err(SklearsError::InvalidParameter {
477                name: "task_id".to_string(),
478                reason: format!("Task {} not found", task_id),
479            });
480        }
481
482        // Transform data using both shared and task-specific components
483        let transformed_x = self.transform_task(x, task_id)?;
484
485        // Predict using task classifier
486        self.task_classifiers[task_id].predict(&transformed_x)
487    }
488
489    /// Predict probabilities for a specific task
490    pub fn predict_proba_task(&self, x: &Array2<Float>, task_id: usize) -> Result<Array2<Float>> {
491        if task_id >= self.task_classifiers.len() {
492            return Err(SklearsError::InvalidParameter {
493                name: "task_id".to_string(),
494                reason: format!("Task {} not found", task_id),
495            });
496        }
497
498        // Transform data using both shared and task-specific components
499        let transformed_x = self.transform_task(x, task_id)?;
500
501        // Predict probabilities using task classifier
502        self.task_classifiers[task_id].predict_proba(&transformed_x)
503    }
504
505    /// Transform data for a specific task using shared and task-specific components
506    pub fn transform_task(&self, x: &Array2<Float>, task_id: usize) -> Result<Array2<Float>> {
507        if task_id >= self.task_components.len() {
508            return Err(SklearsError::InvalidParameter {
509                name: "task_id".to_string(),
510                reason: format!("Task {} not found", task_id),
511            });
512        }
513
514        // Project to shared subspace
515        let shared_proj = x.dot(&self.shared_components.t());
516
517        // Project to task-specific subspace
518        let task_proj = x.dot(&self.task_components[task_id].t());
519
520        // Concatenate projections
521        let mut combined = Array2::zeros((x.nrows(), shared_proj.ncols() + task_proj.ncols()));
522        combined
523            .slice_mut(s![.., ..shared_proj.ncols()])
524            .assign(&shared_proj);
525        combined
526            .slice_mut(s![.., shared_proj.ncols()..])
527            .assign(&task_proj);
528
529        Ok(combined)
530    }
531
532    /// Transform data using only shared components
533    pub fn transform_shared(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
534        Ok(x.dot(&self.shared_components.t()))
535    }
536
537    /// Add a new task (transfer learning)
538    pub fn add_task(&mut self, task: Task) -> Result<usize> {
539        let task_id = self.tasks.len();
540
541        // Compute task-specific components for new task
542        let task_components = self.compute_task_components(&task)?;
543
544        // Train task classifier
545        let task_classifier = self.train_task_classifier(&task, &task_components)?;
546
547        // Add to model
548        self.tasks.push(task);
549        self.task_components.push(task_components);
550        self.task_classifiers.push(task_classifier);
551        self.task_weights.push(1.0);
552
553        // Update global classes
554        self.update_global_classes();
555
556        Ok(task_id)
557    }
558
559    fn compute_task_components(&self, task: &Task) -> Result<Array2<Float>> {
560        let n_features = task.n_features();
561        let n_task = self
562            .config
563            .n_task_components
564            .unwrap_or((n_features / 4).max(1));
565
566        // Create orthogonal projection matrix (I - P_shared)
567        let shared_proj = self.shared_components.t().dot(&self.shared_components);
568        let mut ortho_proj = Array2::eye(n_features);
569        ortho_proj = ortho_proj - shared_proj;
570
571        // Project task data to orthogonal space
572        let projected_x = task.x.dot(&ortho_proj);
573
574        // Fit discriminant analysis on projected data
575        let task_components = match self.config.base_discriminant.as_str() {
576            "lda" => {
577                let mut lda_config = self.config.lda_config.clone();
578                lda_config.n_components = Some(n_task);
579                let lda = LinearDiscriminantAnalysis::new();
580                let fitted_lda = lda.fit(&projected_x, &task.y)?;
581                fitted_lda.components().clone()
582            }
583            "qda" => {
584                let mut lda_config = self.config.lda_config.clone();
585                lda_config.n_components = Some(n_task);
586                let lda = LinearDiscriminantAnalysis::new();
587                let fitted_lda = lda.fit(&projected_x, &task.y)?;
588                fitted_lda.components().clone()
589            }
590            _ => {
591                return Err(SklearsError::InvalidParameter {
592                    name: "base_discriminant".to_string(),
593                    reason: format!(
594                        "Unknown base discriminant: {}",
595                        self.config.base_discriminant
596                    ),
597                })
598            }
599        };
600
601        // Transform components back to original space
602        Ok(task_components.dot(&ortho_proj))
603    }
604
605    fn train_task_classifier(
606        &self,
607        task: &Task,
608        task_components: &Array2<Float>,
609    ) -> Result<TaskClassifier> {
610        // Transform task data
611        let shared_proj = task.x.dot(&self.shared_components.t());
612        let task_proj = task.x.dot(&task_components.t());
613
614        let mut combined = Array2::zeros((task.x.nrows(), shared_proj.ncols() + task_proj.ncols()));
615        combined
616            .slice_mut(s![.., ..shared_proj.ncols()])
617            .assign(&shared_proj);
618        combined
619            .slice_mut(s![.., shared_proj.ncols()..])
620            .assign(&task_proj);
621
622        match self.config.base_discriminant.as_str() {
623            "lda" => {
624                let lda = LinearDiscriminantAnalysis::new();
625                let fitted_lda = lda.fit(&combined, &task.y)?;
626                Ok(TaskClassifier::LDA(fitted_lda))
627            }
628            "qda" => {
629                let qda = QuadraticDiscriminantAnalysis::new();
630                let fitted_qda = qda.fit(&combined, &task.y)?;
631                Ok(TaskClassifier::QDA(fitted_qda))
632            }
633            _ => Err(SklearsError::InvalidParameter {
634                name: "base_discriminant".to_string(),
635                reason: format!(
636                    "Unknown base discriminant: {}",
637                    self.config.base_discriminant
638                ),
639            }),
640        }
641    }
642
643    fn update_global_classes(&mut self) {
644        let mut all_classes = Vec::new();
645        for task in &self.tasks {
646            all_classes.extend(&task.classes);
647        }
648        all_classes.sort_unstable();
649        all_classes.dedup();
650        self.global_classes = all_classes;
651    }
652}
653
654impl Fit<Vec<Task>, ()> for MultiTaskDiscriminantLearning {
655    type Fitted = TrainedMultiTaskDiscriminantLearning;
656
657    fn fit(self, tasks: &Vec<Task>, _y: &()) -> Result<Self::Fitted> {
658        if tasks.is_empty() {
659            return Err(SklearsError::InvalidInput("No tasks provided".to_string()));
660        }
661
662        // Validate that all tasks have the same number of features
663        let n_features = tasks[0].n_features();
664        for task in tasks {
665            if task.n_features() != n_features {
666                return Err(SklearsError::InvalidInput(
667                    "All tasks must have the same number of features".to_string(),
668                ));
669            }
670        }
671
672        // Compute task weights
673        let task_weights = self.compute_task_weights(tasks);
674
675        // Compute shared discriminant subspace
676        let shared_components = self.compute_shared_subspace(tasks)?;
677
678        // Compute task-specific components and train classifiers
679        let mut task_components = Vec::new();
680        let mut task_classifiers = Vec::new();
681
682        for task in tasks {
683            let task_comp = self.compute_task_specific_components(task, &shared_components)?;
684
685            // Transform task data using both shared and task-specific components
686            let shared_proj = task.x.dot(&shared_components.t());
687            let task_proj = task.x.dot(&task_comp.t());
688
689            let mut combined =
690                Array2::zeros((task.x.nrows(), shared_proj.ncols() + task_proj.ncols()));
691            combined
692                .slice_mut(s![.., ..shared_proj.ncols()])
693                .assign(&shared_proj);
694            combined
695                .slice_mut(s![.., shared_proj.ncols()..])
696                .assign(&task_proj);
697
698            // Train task classifier
699            let classifier = match self.config.base_discriminant.as_str() {
700                "lda" => {
701                    let lda = LinearDiscriminantAnalysis::new();
702                    let fitted_lda = lda.fit(&combined, &task.y)?;
703                    TaskClassifier::LDA(fitted_lda)
704                }
705                "qda" => {
706                    let qda = QuadraticDiscriminantAnalysis::new();
707                    let fitted_qda = qda.fit(&combined, &task.y)?;
708                    TaskClassifier::QDA(fitted_qda)
709                }
710                _ => {
711                    return Err(SklearsError::InvalidParameter {
712                        name: "base_discriminant".to_string(),
713                        reason: format!(
714                            "Unknown base discriminant: {}",
715                            self.config.base_discriminant
716                        ),
717                    })
718                }
719            };
720
721            task_components.push(task_comp);
722            task_classifiers.push(classifier);
723        }
724
725        // Compute global classes
726        let mut global_classes = Vec::new();
727        for task in tasks {
728            global_classes.extend(&task.classes);
729        }
730        global_classes.sort_unstable();
731        global_classes.dedup();
732
733        Ok(TrainedMultiTaskDiscriminantLearning {
734            shared_components,
735            task_components,
736            task_classifiers,
737            tasks: tasks.clone(),
738            task_weights,
739            global_classes,
740            config: self.config.clone(),
741        })
742    }
743}
744
// NOTE: the previous `#[allow(non_snake_case)]` was removed — every
// identifier in this module is already snake_case, so the allow was dead.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Task construction records dimensions and the sorted unique classes.
    #[test]
    fn test_task_creation() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 0, 1, 1];

        let task = Task::new(0, x.clone(), y.clone()).unwrap();

        assert_eq!(task.task_id, 0);
        assert_eq!(task.n_samples(), 4);
        assert_eq!(task.n_features(), 2);
        assert_eq!(task.n_classes(), 2);
        assert_eq!(task.classes, vec![0, 1]);
    }

    /// End-to-end fit + per-task prediction on two related binary tasks.
    #[test]
    fn test_multi_task_discriminant_learning_basic() {
        // Create two related tasks
        let task1_x = array![
            [1.0, 2.0],
            [1.1, 2.1],
            [1.2, 2.2], // Class 0
            [3.0, 4.0],
            [3.1, 4.1],
            [3.2, 4.2] // Class 1
        ];
        let task1_y = array![0, 0, 0, 1, 1, 1];

        let task2_x = array![
            [1.5, 2.5],
            [1.6, 2.6],
            [1.7, 2.7], // Class 0
            [3.5, 4.5],
            [3.6, 4.6],
            [3.7, 4.7] // Class 1
        ];
        let task2_y = array![0, 0, 0, 1, 1, 1];

        let task1 = Task::new(0, task1_x.clone(), task1_y.clone()).unwrap();
        let task2 = Task::new(1, task2_x.clone(), task2_y.clone()).unwrap();
        let tasks = vec![task1, task2];

        let mtdl = MultiTaskDiscriminantLearning::new();
        let fitted = mtdl.fit(&tasks, &()).unwrap();

        // Test prediction for task 1
        let predictions = fitted.predict_task(&task1_x, 0).unwrap();
        assert_eq!(predictions.len(), 6);

        // Test prediction for task 2
        let predictions = fitted.predict_task(&task2_x, 1).unwrap();
        assert_eq!(predictions.len(), 6);
    }

    /// Probability outputs have the right shape and each row sums to 1.
    #[test]
    fn test_multi_task_predict_proba() {
        let task1_x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let task1_y = array![0, 0, 1, 1];

        let task2_x = array![[1.5, 2.5], [2.5, 3.5], [3.5, 4.5], [4.5, 5.5]];
        let task2_y = array![0, 0, 1, 1];

        let task1 = Task::new(0, task1_x.clone(), task1_y.clone()).unwrap();
        let task2 = Task::new(1, task2_x.clone(), task2_y.clone()).unwrap();
        let tasks = vec![task1, task2];

        let mtdl = MultiTaskDiscriminantLearning::new();
        let fitted = mtdl.fit(&tasks, &()).unwrap();

        let probas = fitted.predict_proba_task(&task1_x, 0).unwrap();
        assert_eq!(probas.dim(), (4, 2));

        // Check that probabilities sum to 1
        for row in probas.axis_iter(Axis(0)) {
            let sum: Float = row.sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
        }
    }

    /// Shared-only and combined transforms produce consistent column counts.
    #[test]
    fn test_multi_task_transform() {
        let task1_x = array![
            [1.0, 2.0, 0.5],
            [2.0, 1.0, 1.5],
            [3.0, 4.0, 2.0],
            [4.0, 3.0, 3.5],
            [5.0, 2.0, 4.0],
            [6.0, 1.0, 4.5]
        ];
        let task1_y = array![0, 0, 1, 1, 2, 2];

        let task2_x = array![
            [1.5, 2.5, 0.8],
            [2.5, 1.5, 1.8],
            [3.5, 4.5, 2.5],
            [4.5, 3.5, 3.8],
            [5.5, 2.5, 4.3],
            [6.5, 1.5, 4.8]
        ];
        let task2_y = array![0, 0, 1, 1, 2, 2];

        let task1 = Task::new(0, task1_x.clone(), task1_y.clone()).unwrap();
        let task2 = Task::new(1, task2_x.clone(), task2_y.clone()).unwrap();
        let tasks = vec![task1, task2];

        let mtdl = MultiTaskDiscriminantLearning::new()
            .n_shared_components(Some(2))
            .n_task_components(Some(1));
        let fitted = mtdl.fit(&tasks, &()).unwrap();

        // Test shared transformation
        let shared_transformed = fitted.transform_shared(&task1_x).unwrap();
        assert!(shared_transformed.ncols() >= 1); // Ensure we get some components

        // Test task-specific transformation
        let task_transformed = fitted.transform_task(&task1_x, 0).unwrap();
        assert!(task_transformed.ncols() >= shared_transformed.ncols()); // Should include shared + task components
    }

    /// The "qda" base discriminant trains and predicts end to end.
    #[test]
    fn test_multi_task_with_qda() {
        let task1_x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let task1_y = array![0, 0, 1, 1];

        let task2_x = array![[1.5, 2.5], [2.5, 3.5], [3.5, 4.5], [4.5, 5.5]];
        let task2_y = array![0, 0, 1, 1];

        let task1 = Task::new(0, task1_x.clone(), task1_y.clone()).unwrap();
        let task2 = Task::new(1, task2_x.clone(), task2_y.clone()).unwrap();
        let tasks = vec![task1, task2];

        let mtdl = MultiTaskDiscriminantLearning::new().base_discriminant("qda");
        let fitted = mtdl.fit(&tasks, &()).unwrap();

        let predictions = fitted.predict_task(&task1_x, 0).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    /// Every weighting strategy yields one positive weight per task.
    #[test]
    fn test_task_weighting_strategies() {
        let task1_x = array![[1.0, 2.0], [2.0, 3.0]];
        let task1_y = array![0, 1];

        let task2_x = array![[1.5, 2.5], [2.5, 3.5], [3.5, 4.5], [4.5, 5.5]];
        let task2_y = array![0, 0, 1, 1];

        let task1 = Task::new(0, task1_x.clone(), task1_y.clone()).unwrap();
        let task2 = Task::new(1, task2_x.clone(), task2_y.clone()).unwrap();
        let tasks = vec![task1, task2];

        let strategies = ["uniform", "proportional", "inverse"];
        for strategy in &strategies {
            let mtdl = MultiTaskDiscriminantLearning::new().task_weighting(strategy);
            let fitted = mtdl.fit(&tasks, &()).unwrap();

            assert_eq!(fitted.task_weights().len(), 2);
            assert!(fitted.task_weights().iter().all(|&w| w > 0.0));
        }
    }

    /// Transfer learning: a task added after fitting gets the next id and
    /// supports prediction.
    #[test]
    fn test_add_new_task() {
        let task1_x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let task1_y = array![0, 0, 1, 1];

        let task2_x = array![[1.5, 2.5], [2.5, 3.5], [3.5, 4.5], [4.5, 5.5]];
        let task2_y = array![0, 0, 1, 1];

        let task1 = Task::new(0, task1_x.clone(), task1_y.clone()).unwrap();
        let task2 = Task::new(1, task2_x.clone(), task2_y.clone()).unwrap();
        let tasks = vec![task1, task2];

        let mtdl = MultiTaskDiscriminantLearning::new();
        let mut fitted = mtdl.fit(&tasks, &()).unwrap();

        // Add a new task
        let task3_x = array![[2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];
        let task3_y = array![0, 0, 1, 1];
        let task3 = Task::new(2, task3_x.clone(), task3_y.clone()).unwrap();

        let new_task_id = fitted.add_task(task3).unwrap();
        assert_eq!(new_task_id, 2);

        // Test prediction for new task
        let predictions = fitted.predict_task(&task3_x, new_task_id).unwrap();
        assert_eq!(predictions.len(), 4);
    }
}