// sklears_multioutput/core.rs

//! Core multi-output algorithms
//!
//! This module contains the core MultiOutputClassifier and MultiOutputRegressor
//! implementations with their trained states and associated methods.
//! Enhanced with parallel processing capabilities for improved performance.

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

/// Multi-Output Classifier
///
/// This strategy consists of fitting one classifier per target. This is a
/// simple strategy for extending classifiers that do not natively support
/// multi-target classification.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::MultiOutputClassifier;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// // A simple example showing the expected input shapes
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let labels = array![[0, 1], [1, 0], [1, 1]];
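///
/// // A hedged sketch of fitting and predicting, assuming the `Fit` and
/// // `Predict` traits from `sklears_core` are available to callers:
/// use sklears_core::traits::{Fit, Predict};
/// let trained = MultiOutputClassifier::new().fit(&data.view(), &labels).unwrap();
/// let predictions = trained.predict(&data.view()).unwrap();
/// assert_eq!(predictions.shape(), &[3, 2]);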
/// ```
#[derive(Debug, Clone)]
pub struct MultiOutputClassifier<S = Untrained> {
    state: S,
    n_jobs: Option<i32>,
}

impl MultiOutputClassifier<Untrained> {
    /// Create a new MultiOutputClassifier instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_jobs: None,
        }
    }

    /// Set the number of parallel jobs
    pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
        self.n_jobs = n_jobs;
        self
    }
}

impl Default for MultiOutputClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiOutputClassifier<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for MultiOutputClassifier<Untrained> {
    type Fitted = MultiOutputClassifier<MultiOutputClassifierTrained>;

    /// Fit one nearest-centroid classifier per target column of `y`.
    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }
        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "X and y must contain at least one sample".to_string(),
            ));
        }

        let n_targets = y.ncols();
        if n_targets == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one target".to_string(),
            ));
        }

        // Dispatch to parallel training up front if n_jobs > 1, instead of
        // training sequentially first and discarding that work.
        if let Some(n_jobs) = self.n_jobs {
            if n_jobs > 1 && n_targets > 1 {
                return self.fit_parallel(X, y, n_jobs as usize);
            }
        }

        let mut classes_per_target = Vec::new();
        let mut target_models = HashMap::new();

        // Fit one classifier per target using a simplified nearest-centroid approach
        for target_idx in 0..n_targets {
            let y_target = y.column(target_idx);

            // Get unique classes for this target
            let mut target_classes: Vec<i32> = y_target
                .iter()
                .cloned()
                .collect::<std::collections::HashSet<_>>()
                .into_iter()
                .collect();
            target_classes.sort();

            // Compute class centroids for the nearest-centroid classifier
            let mut class_centroids = HashMap::new();
            for &class_label in &target_classes {
                let mut centroid = Array1::<Float>::zeros(n_features);
                let mut count = 0;

                for (sample_idx, &sample_class) in y_target.iter().enumerate() {
                    if sample_class == class_label {
                        for feature_idx in 0..n_features {
                            centroid[feature_idx] += X[[sample_idx, feature_idx]];
                        }
                        count += 1;
                    }
                }

                if count > 0 {
                    centroid /= count as f64;
                }
                class_centroids.insert(class_label, centroid);
            }

            target_models.insert(target_idx, class_centroids);
            classes_per_target.push(target_classes);
        }

        Ok(MultiOutputClassifier {
            state: MultiOutputClassifierTrained {
                classes_per_target,
                target_models,
                n_targets,
                n_features,
            },
            n_jobs: self.n_jobs,
        })
    }
}

impl MultiOutputClassifier<Untrained> {
    /// Parallel training implementation
    #[allow(non_snake_case)]
    fn fit_parallel(
        self,
        X: Array2<Float>,
        y: &Array2<i32>,
        n_jobs: usize,
    ) -> SklResult<MultiOutputClassifier<MultiOutputClassifierTrained>> {
        let (_n_samples, n_features) = X.dim();
        let n_targets = y.ncols();

        // Shared data structures
        let X_arc = Arc::new(X);
        let y_arc = Arc::new(y.clone());
        let classes_per_target = Arc::new(Mutex::new(Vec::with_capacity(n_targets)));
        let target_models = Arc::new(Mutex::new(HashMap::new()));

        // Calculate chunk size for work distribution
        let chunk_size = (n_targets + n_jobs - 1) / n_jobs; // Ceiling division
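        // e.g. with n_targets = 5 and n_jobs = 2, chunk_size = 3: worker 0
        // handles targets 0..3 and worker 1 handles targets 3..5.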
        let mut handles = vec![];

        // Spawn worker threads
        for worker_id in 0..n_jobs {
            let start_target = worker_id * chunk_size;
            let end_target = std::cmp::min(start_target + chunk_size, n_targets);

            if start_target >= n_targets {
                break; // No more work for this thread
            }

            let X_thread = Arc::clone(&X_arc);
            let y_thread = Arc::clone(&y_arc);
            let classes_thread = Arc::clone(&classes_per_target);
            let models_thread = Arc::clone(&target_models);

            let handle = thread::spawn(move || -> SklResult<()> {
                let mut local_classes = Vec::new();
                let mut local_models = HashMap::new();

                for target_idx in start_target..end_target {
                    let y_target = y_thread.column(target_idx);

                    // Get unique classes for this target
                    let mut target_classes: Vec<i32> = y_target
                        .iter()
                        .cloned()
                        .collect::<std::collections::HashSet<_>>()
                        .into_iter()
                        .collect();
                    target_classes.sort();

                    // Compute class centroids for the nearest-centroid classifier
                    let mut class_centroids = HashMap::new();
                    for &class_label in &target_classes {
                        let mut centroid = Array1::<Float>::zeros(n_features);
                        let mut count = 0;

                        for (sample_idx, &sample_class) in y_target.iter().enumerate() {
                            if sample_class == class_label {
                                for feature_idx in 0..n_features {
                                    centroid[feature_idx] += X_thread[[sample_idx, feature_idx]];
                                }
                                count += 1;
                            }
                        }

                        if count > 0 {
                            centroid /= count as f64;
                        }
                        class_centroids.insert(class_label, centroid);
                    }

                    local_models.insert(target_idx, class_centroids);
                    local_classes.push((target_idx, target_classes));
                }

                // Merge results back into the shared data structures
                {
                    let mut classes_guard = classes_thread.lock().unwrap();
                    let mut models_guard = models_thread.lock().unwrap();

                    // Ensure proper ordering by sorting local results
                    local_classes.sort_by_key(|(idx, _)| *idx);
                    for (target_idx, target_classes) in local_classes {
                        // Insert at the correct position
                        while classes_guard.len() <= target_idx {
                            classes_guard.push(vec![]);
                        }
                        classes_guard[target_idx] = target_classes;
                    }

                    for (target_idx, class_centroids) in local_models {
                        models_guard.insert(target_idx, class_centroids);
                    }
                }

                Ok(())
            });

            handles.push(handle);
        }

        // Wait for all threads to complete and collect any errors
        for handle in handles {
            handle.join().map_err(|_| {
                SklearsError::InvalidInput("Thread panicked during parallel training".to_string())
            })??;
        }

        // Extract results from Arc<Mutex<>>
        let final_classes = Arc::try_unwrap(classes_per_target)
            .map_err(|_| SklearsError::InvalidInput("Failed to extract classes".to_string()))?
            .into_inner()
            .unwrap();

        let final_models = Arc::try_unwrap(target_models)
            .map_err(|_| SklearsError::InvalidInput("Failed to extract models".to_string()))?
            .into_inner()
            .unwrap();

        Ok(MultiOutputClassifier {
            state: MultiOutputClassifierTrained {
                classes_per_target: final_classes,
                target_models: final_models,
                n_targets,
                n_features,
            },
            n_jobs: Some(n_jobs as i32),
        })
    }
}

impl MultiOutputClassifier<MultiOutputClassifierTrained> {
    /// Get the classes for each target
    pub fn classes(&self) -> &[Vec<i32>] {
        &self.state.classes_per_target
    }

    /// Get the number of targets
    pub fn n_targets(&self) -> usize {
        self.state.n_targets
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for MultiOutputClassifier<MultiOutputClassifierTrained>
{
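    /// Predict one label per target for each sample by assigning the class
    /// of the nearest centroid (Euclidean distance) learned during `fit`.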
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_targets));

        // Get predictions for each target using the nearest centroid
        for target_idx in 0..self.state.n_targets {
            if let Some(class_centroids) = self.state.target_models.get(&target_idx) {
                for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
                    let mut min_distance = f64::INFINITY;
                    let mut best_class = 0;

                    // Find the nearest centroid (HashMap iteration order is
                    // arbitrary, so exact distance ties may break differently
                    // between runs)
                    for (&class_label, centroid) in class_centroids {
                        let mut distance = 0.0;
                        for feature_idx in 0..n_features {
                            let diff = sample[feature_idx] - centroid[feature_idx];
                            distance += diff * diff;
                        }
                        distance = distance.sqrt();

                        if distance < min_distance {
                            min_distance = distance;
                            best_class = class_label;
                        }
                    }

                    predictions[[sample_idx, target_idx]] = best_class;
                }
            }
        }

        Ok(predictions)
    }
}

/// Multi-Output Regressor
///
/// This strategy consists of fitting one regressor per target. This is a
/// simple strategy for extending regressors that do not natively support
/// multi-output regression.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::MultiOutputRegressor;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// // A simple example showing the expected input shapes
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let targets = array![[1.5, 2.5], [2.5, 3.5], [2.0, 1.5]];
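///
/// // A hedged sketch of fitting and predicting, assuming the `Fit` and
/// // `Predict` traits from `sklears_core` are available to callers:
/// use sklears_core::traits::{Fit, Predict};
/// let trained = MultiOutputRegressor::new().fit(&data.view(), &targets).unwrap();
/// let predictions = trained.predict(&data.view()).unwrap();
/// assert_eq!(predictions.shape(), &[3, 2]);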
/// ```
#[derive(Debug, Clone)]
pub struct MultiOutputRegressor<S = Untrained> {
    state: S,
    n_jobs: Option<i32>,
}

impl MultiOutputRegressor<Untrained> {
    /// Create a new MultiOutputRegressor instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_jobs: None,
        }
    }

    /// Set the number of parallel jobs
    pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
        self.n_jobs = n_jobs;
        self
    }
}

impl Default for MultiOutputRegressor<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiOutputRegressor<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<f64>> for MultiOutputRegressor<Untrained> {
    type Fitted = MultiOutputRegressor<MultiOutputRegressorTrained>;

    /// Fit one simplified linear regressor per target column of `y`.
    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<f64>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }
        if n_samples == 0 {
            // Guards the target-mean computation below, which would panic on
            // an empty column.
            return Err(SklearsError::InvalidInput(
                "X and y must contain at least one sample".to_string(),
            ));
        }

        let n_targets = y.ncols();
        if n_targets == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one target".to_string(),
            ));
        }

        // Dispatch to parallel training up front if n_jobs > 1, instead of
        // training sequentially first and discarding that work.
        if let Some(n_jobs) = self.n_jobs {
            if n_jobs > 1 && n_targets > 1 {
                return self.fit_parallel(X, y, n_jobs as usize);
            }
        }

        let mut target_models = HashMap::new();

        // Fit one simplified linear model per target. Rather than solving the
        // normal equations X^T X w = X^T y, this uses the target mean as the
        // bias and weights proportional to each feature's correlation with
        // the target.
        for target_idx in 0..n_targets {
            let y_target = y.column(target_idx);

            let mut weights = Array1::<Float>::zeros(n_features);

            // Use the mean of the targets as the bias
            let y_mean = y_target.mean().unwrap();
            let bias = y_mean;

            // Set weights proportional to feature correlations with the target
            for feature_idx in 0..n_features {
                let mut correlation = 0.0;
                let mut x_mean = 0.0;

                // Compute feature mean
                for sample_idx in 0..n_samples {
                    x_mean += X[[sample_idx, feature_idx]];
                }
                x_mean /= n_samples as f64;

                // Compute correlation
                let mut numerator = 0.0;
                let mut x_var = 0.0;
                let mut y_var = 0.0;

                for sample_idx in 0..n_samples {
                    let x_diff = X[[sample_idx, feature_idx]] - x_mean;
                    let y_diff = y_target[sample_idx] - y_mean;
                    numerator += x_diff * y_diff;
                    x_var += x_diff * x_diff;
                    y_var += y_diff * y_diff;
                }

                if x_var > 1e-10 && y_var > 1e-10 {
                    correlation = numerator / (x_var.sqrt() * y_var.sqrt());
                }

                weights[feature_idx] = correlation * 0.1; // Scale down for stability
            }

            target_models.insert(target_idx, (weights, bias));
        }

        Ok(MultiOutputRegressor {
            state: MultiOutputRegressorTrained {
                target_models,
                n_targets,
                n_features,
            },
            n_jobs: self.n_jobs,
        })
    }
}

impl MultiOutputRegressor<Untrained> {
    /// Parallel training implementation for regression
    #[allow(non_snake_case)]
    fn fit_parallel(
        self,
        X: Array2<Float>,
        y: &Array2<f64>,
        n_jobs: usize,
    ) -> SklResult<MultiOutputRegressor<MultiOutputRegressorTrained>> {
        let (n_samples, n_features) = X.dim();
        let n_targets = y.ncols();

        // Shared data structures
        let X_arc = Arc::new(X);
        let y_arc = Arc::new(y.clone());
        let target_models = Arc::new(Mutex::new(HashMap::new()));

        // Calculate chunk size for work distribution
        let chunk_size = (n_targets + n_jobs - 1) / n_jobs; // Ceiling division
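        // e.g. with n_targets = 20 and n_jobs = 4, chunk_size = 5, so each
        // worker fits five of the per-target models; a worker whose start
        // index reaches n_targets simply exits early below.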
        let mut handles = vec![];

        // Spawn worker threads
        for worker_id in 0..n_jobs {
            let start_target = worker_id * chunk_size;
            let end_target = std::cmp::min(start_target + chunk_size, n_targets);

            if start_target >= n_targets {
                break; // No more work for this thread
            }

            let X_thread = Arc::clone(&X_arc);
            let y_thread = Arc::clone(&y_arc);
            let models_thread = Arc::clone(&target_models);

            let handle = thread::spawn(move || -> SklResult<()> {
                let mut local_models = HashMap::new();

                for target_idx in start_target..end_target {
                    let y_target = y_thread.column(target_idx);
                    let mut weights = Array1::<f64>::zeros(n_features);

                    // Use the mean of the targets as the bias
                    let y_mean = y_target.mean().unwrap();
                    let bias: f64 = y_mean;

                    // Set weights proportional to feature correlations with the target
                    for feature_idx in 0..n_features {
                        let mut correlation = 0.0;
                        let mut x_mean = 0.0;

                        // Compute feature mean
                        for sample_idx in 0..n_samples {
                            x_mean += X_thread[[sample_idx, feature_idx]];
                        }
                        x_mean /= n_samples as f64;

                        // Compute correlation
                        let mut numerator = 0.0;
                        let mut x_var = 0.0;
                        let mut y_var = 0.0;

                        for sample_idx in 0..n_samples {
                            let x_diff = X_thread[[sample_idx, feature_idx]] - x_mean;
                            let y_diff = y_target[sample_idx] - y_mean;
                            numerator += x_diff * y_diff;
                            x_var += x_diff * x_diff;
                            y_var += y_diff * y_diff;
                        }

                        if x_var > 1e-10 && y_var > 1e-10 {
                            correlation = numerator / (x_var.sqrt() * y_var.sqrt());
                        }

                        weights[feature_idx] = correlation * 0.1; // Scale down for stability
                    }

                    local_models.insert(target_idx, (weights, bias));
                }

                // Merge results back into the shared data structure
                {
                    let mut models_guard = models_thread.lock().unwrap();
                    for (target_idx, model) in local_models {
                        models_guard.insert(target_idx, model);
                    }
                }

                Ok(())
            });

            handles.push(handle);
        }

        // Wait for all threads to complete and collect any errors
        for handle in handles {
            handle.join().map_err(|_| {
                SklearsError::InvalidInput("Thread panicked during parallel training".to_string())
            })??;
        }

        // Extract results from Arc<Mutex<>>
        let final_models = Arc::try_unwrap(target_models)
            .map_err(|_| SklearsError::InvalidInput("Failed to extract models".to_string()))?
            .into_inner()
            .unwrap();

        Ok(MultiOutputRegressor {
            state: MultiOutputRegressorTrained {
                target_models: final_models,
                n_targets,
                n_features,
            },
            n_jobs: Some(n_jobs as i32),
        })
    }
}

impl MultiOutputRegressor<MultiOutputRegressorTrained> {
    /// Get the number of targets
    pub fn n_targets(&self) -> usize {
        self.state.n_targets
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<f64>>
    for MultiOutputRegressor<MultiOutputRegressorTrained>
{
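    /// Predict one value per target for each sample as the linear combination
    /// `weights^T * x + bias` learned during `fit`.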
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_targets));

        // Get predictions from each target regressor
        for target_idx in 0..self.state.n_targets {
            if let Some((weights, bias)) = self.state.target_models.get(&target_idx) {
                for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
                    // Linear prediction: weights^T * x + bias
                    let prediction: f64 = sample
                        .iter()
                        .zip(weights.iter())
                        .map(|(&x, &w)| x * w)
                        .sum::<f64>()
                        + bias;

                    predictions[[sample_idx, target_idx]] = prediction;
                }
            }
        }

        Ok(predictions)
    }
}

/// Trained state for MultiOutputClassifier
#[derive(Debug, Clone)]
pub struct MultiOutputClassifierTrained {
    /// The classes for each target
    pub classes_per_target: Vec<Vec<i32>>,
    /// Nearest centroid models for each target
    pub target_models: HashMap<usize, HashMap<i32, Array1<f64>>>,
    /// Number of targets
    pub n_targets: usize,
    /// Number of features
    pub n_features: usize,
}

/// Trained state for MultiOutputRegressor
#[derive(Debug, Clone)]
pub struct MultiOutputRegressorTrained {
    /// Linear models for each target (weights, bias)
    pub target_models: HashMap<usize, (Array1<f64>, f64)>,
    /// Number of targets
    pub n_targets: usize,
    /// Number of features
    pub n_features: usize,
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
    use scirs2_core::ndarray::array;
    use std::time::Instant;

    #[test]
    #[allow(non_snake_case)]
    fn test_parallel_multi_output_classifier() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
            [5.0, 6.0, 7.0],
            [6.0, 7.0, 8.0]
        ];
        let y = array![
            [0, 1, 0],
            [1, 0, 1],
            [0, 1, 0],
            [1, 0, 1],
            [0, 1, 0],
            [1, 0, 1]
        ];

        // Test with parallel training
        let classifier_parallel = MultiOutputClassifier::new().n_jobs(Some(2));
        let trained_parallel = classifier_parallel.fit(&X.view(), &y).unwrap();

        // Test with sequential training
        let classifier_sequential = MultiOutputClassifier::new().n_jobs(Some(1));
        let trained_sequential = classifier_sequential.fit(&X.view(), &y).unwrap();

        // Results should be the same
        assert_eq!(trained_parallel.n_targets(), trained_sequential.n_targets());
        assert_eq!(
            trained_parallel.classes().len(),
            trained_sequential.classes().len()
        );

        // Test predictions
        let pred_parallel = trained_parallel.predict(&X.view()).unwrap();
        let pred_sequential = trained_sequential.predict(&X.view()).unwrap();

        assert_eq!(pred_parallel.shape(), pred_sequential.shape());
        assert_eq!(pred_parallel.shape(), &[6, 3]);
    }
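
    // Sanity check: on two well-separated clusters, the nearest-centroid
    // models should reproduce the training labels exactly for every target.
    #[test]
    #[allow(non_snake_case)]
    fn test_nearest_centroid_recovers_separated_clusters() {
        let X = array![[0.0, 0.0], [0.1, 0.0], [10.0, 10.0], [10.1, 10.0]];
        let y = array![[0, 1], [0, 1], [1, 0], [1, 0]];

        let trained = MultiOutputClassifier::new().fit(&X.view(), &y).unwrap();
        let predictions = trained.predict(&X.view()).unwrap();

        // Each training sample should be assigned back to its own cluster's labels
        assert_eq!(predictions, y);
    }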

    #[test]
    #[allow(non_snake_case)]
    fn test_parallel_multi_output_regressor() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
            [5.0, 6.0, 7.0],
            [6.0, 7.0, 8.0]
        ];
        let y = array![
            [1.5, 2.5, 3.5],
            [2.5, 3.5, 4.5],
            [3.5, 4.5, 5.5],
            [4.5, 5.5, 6.5],
            [5.5, 6.5, 7.5],
            [6.5, 7.5, 8.5]
        ];

        // Test with parallel training
        let regressor_parallel = MultiOutputRegressor::new().n_jobs(Some(2));
        let trained_parallel = regressor_parallel.fit(&X.view(), &y).unwrap();

        // Test with sequential training
        let regressor_sequential = MultiOutputRegressor::new().n_jobs(Some(1));
        let trained_sequential = regressor_sequential.fit(&X.view(), &y).unwrap();

        // Results should be the same
        assert_eq!(trained_parallel.n_targets(), trained_sequential.n_targets());

        // Test predictions
        let pred_parallel = trained_parallel.predict(&X.view()).unwrap();
        let pred_sequential = trained_sequential.predict(&X.view()).unwrap();

        assert_eq!(pred_parallel.shape(), pred_sequential.shape());
        assert_eq!(pred_parallel.shape(), &[6, 3]);

        // Predictions should be approximately equal
        for i in 0..pred_parallel.nrows() {
            for j in 0..pred_parallel.ncols() {
                assert_abs_diff_eq!(
                    pred_parallel[[i, j]],
                    pred_sequential[[i, j]],
                    epsilon = 1e-10
                );
            }
        }
    }
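
    // Property check: with a constant target column, every feature-target
    // correlation is zero, so the fitted weights are zero and the prediction
    // should equal the target mean exactly.
    #[test]
    #[allow(non_snake_case)]
    fn test_regressor_constant_target_predicts_mean() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![[5.0], [5.0], [5.0], [5.0]];

        let trained = MultiOutputRegressor::new().fit(&X.view(), &y).unwrap();
        let predictions = trained.predict(&X.view()).unwrap();

        for i in 0..predictions.nrows() {
            assert_abs_diff_eq!(predictions[[i, 0]], 5.0, epsilon = 1e-12);
        }
    }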

    #[test]
    fn test_parallel_training_performance_classifier() {
        // Create a larger dataset to see potential parallel benefits
        let n_samples = 1000;
        let n_features = 50;
        let n_targets = 20;

        let mut X = Array2::<Float>::zeros((n_samples, n_features));
        let mut y = Array2::<i32>::zeros((n_samples, n_targets));

        // Fill with simple patterns
        for i in 0..n_samples {
            for j in 0..n_features {
                X[[i, j]] = (i * j) as Float * 0.01;
            }
            for j in 0..n_targets {
                y[[i, j]] = ((i + j) % 2) as i32;
            }
        }

        // Time sequential training
        let start_sequential = Instant::now();
        let classifier_sequential = MultiOutputClassifier::new().n_jobs(Some(1));
        let trained_sequential = classifier_sequential.fit(&X.view(), &y).unwrap();
        let sequential_time = start_sequential.elapsed();

        // Time parallel training
        let start_parallel = Instant::now();
        let classifier_parallel = MultiOutputClassifier::new().n_jobs(Some(4));
        let trained_parallel = classifier_parallel.fit(&X.view(), &y).unwrap();
        let parallel_time = start_parallel.elapsed();

        // Ensure both produce valid results
        assert_eq!(trained_parallel.n_targets(), n_targets);
        assert_eq!(trained_sequential.n_targets(), n_targets);

        // Test that predictions are consistent
        let pred_parallel = trained_parallel.predict(&X.view()).unwrap();
        let pred_sequential = trained_sequential.predict(&X.view()).unwrap();
        assert_eq!(pred_parallel.shape(), pred_sequential.shape());

        println!(
            "Sequential time: {:?}, Parallel time: {:?}",
            sequential_time, parallel_time
        );
    }

    #[test]
    fn test_parallel_training_performance_regressor() {
        // Create a larger dataset to see potential parallel benefits
        let n_samples = 1000;
        let n_features = 50;
        let n_targets = 20;

        let mut X = Array2::<Float>::zeros((n_samples, n_features));
        let mut y = Array2::<f64>::zeros((n_samples, n_targets));

        // Fill with simple patterns
        for i in 0..n_samples {
            for j in 0..n_features {
                X[[i, j]] = (i * j) as Float * 0.01;
            }
            for j in 0..n_targets {
                y[[i, j]] = (i + j) as f64 * 0.1;
            }
        }

        // Time sequential training
        let start_sequential = Instant::now();
        let regressor_sequential = MultiOutputRegressor::new().n_jobs(Some(1));
        let trained_sequential = regressor_sequential.fit(&X.view(), &y).unwrap();
        let sequential_time = start_sequential.elapsed();

        // Time parallel training
        let start_parallel = Instant::now();
        let regressor_parallel = MultiOutputRegressor::new().n_jobs(Some(4));
        let trained_parallel = regressor_parallel.fit(&X.view(), &y).unwrap();
        let parallel_time = start_parallel.elapsed();

        // Ensure both produce valid results
        assert_eq!(trained_parallel.n_targets(), n_targets);
        assert_eq!(trained_sequential.n_targets(), n_targets);

        // Test that predictions are consistent
        let pred_parallel = trained_parallel.predict(&X.view()).unwrap();
        let pred_sequential = trained_sequential.predict(&X.view()).unwrap();
        assert_eq!(pred_parallel.shape(), pred_sequential.shape());

        println!(
            "Sequential time: {:?}, Parallel time: {:?}",
            sequential_time, parallel_time
        );
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_parallel_training_thread_safety() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y_class = array![[0, 1], [1, 0], [0, 1], [1, 0]];
        let y_reg = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];

        // Run the parallel path repeatedly to check for race conditions
        for _ in 0..10 {
            let classifier = MultiOutputClassifier::new().n_jobs(Some(2));
            let trained = classifier.fit(&X.view(), &y_class).unwrap();
            let predictions = trained.predict(&X.view()).unwrap();
            assert_eq!(predictions.shape(), &[4, 2]);

            let regressor = MultiOutputRegressor::new().n_jobs(Some(2));
            let trained = regressor.fit(&X.view(), &y_reg).unwrap();
            let predictions = trained.predict(&X.view()).unwrap();
            assert_eq!(predictions.shape(), &[4, 2]);
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_parallel_training_edge_cases() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y_class = array![[0, 1], [1, 0]];
        let y_reg = array![[1.0, 2.0], [2.0, 3.0]];

        // Test with more threads than targets (should be handled gracefully)
        let classifier = MultiOutputClassifier::new().n_jobs(Some(10));
        let trained = classifier.fit(&X.view(), &y_class).unwrap();
        assert_eq!(trained.n_targets(), 2);

        let regressor = MultiOutputRegressor::new().n_jobs(Some(10));
        let trained = regressor.fit(&X.view(), &y_reg).unwrap();
        assert_eq!(trained.n_targets(), 2);

        // Test with a single target (should fall back to sequential training)
        let y_single = array![[0], [1]];
        let classifier_single = MultiOutputClassifier::new().n_jobs(Some(4));
        let trained_single = classifier_single.fit(&X.view(), &y_single).unwrap();
        assert_eq!(trained_single.n_targets(), 1);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_parallel_training_error_handling() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y_mismatch = array![[0, 1, 0], [1, 0, 1], [0, 1, 0]]; // Wrong number of samples

        // Test error handling in parallel mode
        let classifier = MultiOutputClassifier::new().n_jobs(Some(2));
        let result = classifier.fit(&X.view(), &y_mismatch);
        assert!(result.is_err());

        let regressor = MultiOutputRegressor::new().n_jobs(Some(2));
        let y_reg_mismatch = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]];
        let result = regressor.fit(&X.view(), &y_reg_mismatch);
        assert!(result.is_err());
    }
}