1use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
9use sklears_core::{
10 error::{Result as SklResult, SklearsError},
11 traits::{Estimator, Fit, Predict, Untrained},
12 types::Float,
13};
14use std::collections::HashMap;
15use std::sync::{Arc, Mutex};
16use std::thread;
17
/// Multi-output classifier that fits one independent nearest-centroid model
/// per target column of `y`.
///
/// The type parameter `S` tracks the fitting state: [`Untrained`] before
/// [`Fit::fit`] is called and [`MultiOutputClassifierTrained`] afterwards.
#[derive(Debug, Clone)]
pub struct MultiOutputClassifier<S = Untrained> {
    /// Fitting state (`Untrained`, or the fitted per-target models).
    state: S,
    /// Requested number of worker threads. `fit` only trains in parallel for
    /// `Some(k)` with `k > 1` when `y` has more than one target column.
    n_jobs: Option<i32>,
}
40
41impl MultiOutputClassifier<Untrained> {
42 pub fn new() -> Self {
44 Self {
45 state: Untrained,
46 n_jobs: None,
47 }
48 }
49
50 pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
52 self.n_jobs = n_jobs;
53 self
54 }
55}
56
57impl Default for MultiOutputClassifier<Untrained> {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl Estimator for MultiOutputClassifier<Untrained> {
64 type Config = ();
65 type Error = SklearsError;
66 type Float = Float;
67
68 fn config(&self) -> &Self::Config {
69 &()
70 }
71}
72
73impl Fit<ArrayView2<'_, Float>, Array2<i32>> for MultiOutputClassifier<Untrained> {
74 type Fitted = MultiOutputClassifier<MultiOutputClassifierTrained>;
75
76 #[allow(non_snake_case)]
77 fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
78 let X = X.to_owned();
79 let (n_samples, n_features) = X.dim();
80
81 if n_samples != y.nrows() {
82 return Err(SklearsError::InvalidInput(
83 "X and y must have the same number of samples".to_string(),
84 ));
85 }
86
87 let n_targets = y.ncols();
88 if n_targets == 0 {
89 return Err(SklearsError::InvalidInput(
90 "y must have at least one target".to_string(),
91 ));
92 }
93
94 let mut classes_per_target = Vec::new();
95 let mut target_models = HashMap::new();
96
97 for target_idx in 0..n_targets {
99 let y_target = y.column(target_idx);
100
101 let mut target_classes: Vec<i32> = y_target
103 .iter()
104 .cloned()
105 .collect::<std::collections::HashSet<_>>()
106 .into_iter()
107 .collect();
108 target_classes.sort();
109
110 let mut class_centroids = HashMap::new();
112 for &class_label in &target_classes {
113 let mut centroid = Array1::<Float>::zeros(n_features);
114 let mut count = 0;
115
116 for (sample_idx, &sample_class) in y_target.iter().enumerate() {
117 if sample_class == class_label {
118 for feature_idx in 0..n_features {
119 centroid[feature_idx] += X[[sample_idx, feature_idx]];
120 }
121 count += 1;
122 }
123 }
124
125 if count > 0 {
126 centroid /= count as f64;
127 }
128 class_centroids.insert(class_label, centroid);
129 }
130
131 target_models.insert(target_idx, class_centroids);
132 classes_per_target.push(target_classes);
133 }
134
135 if let Some(n_jobs) = self.n_jobs {
137 if n_jobs > 1 && n_targets > 1 {
138 return self.fit_parallel(X, y, n_jobs as usize);
139 }
140 }
141
142 Ok(MultiOutputClassifier {
143 state: MultiOutputClassifierTrained {
144 classes_per_target,
145 target_models,
146 n_targets,
147 n_features,
148 },
149 n_jobs: self.n_jobs,
150 })
151 }
152}
153
154impl MultiOutputClassifier<Untrained> {
155 #[allow(non_snake_case)]
157 fn fit_parallel(
158 self,
159 X: Array2<Float>,
160 y: &Array2<i32>,
161 n_jobs: usize,
162 ) -> SklResult<MultiOutputClassifier<MultiOutputClassifierTrained>> {
163 let (n_samples, n_features) = X.dim();
164 let n_targets = y.ncols();
165
166 let X_arc = Arc::new(X);
168 let y_arc = Arc::new(y.clone());
169 let classes_per_target = Arc::new(Mutex::new(Vec::with_capacity(n_targets)));
170 let target_models = Arc::new(Mutex::new(HashMap::new()));
171
172 let chunk_size = (n_targets + n_jobs - 1) / n_jobs; let mut handles = vec![];
175
176 for worker_id in 0..n_jobs {
178 let start_target = worker_id * chunk_size;
179 let end_target = std::cmp::min(start_target + chunk_size, n_targets);
180
181 if start_target >= n_targets {
182 break; }
184
185 let X_thread = Arc::clone(&X_arc);
186 let y_thread = Arc::clone(&y_arc);
187 let classes_thread = Arc::clone(&classes_per_target);
188 let models_thread = Arc::clone(&target_models);
189
190 let handle = thread::spawn(move || -> SklResult<()> {
191 let mut local_classes = Vec::new();
192 let mut local_models = HashMap::new();
193
194 for target_idx in start_target..end_target {
195 let y_target = y_thread.column(target_idx);
196
197 let mut target_classes: Vec<i32> = y_target
199 .iter()
200 .cloned()
201 .collect::<std::collections::HashSet<_>>()
202 .into_iter()
203 .collect();
204 target_classes.sort();
205
206 let mut class_centroids = HashMap::new();
208 for &class_label in &target_classes {
209 let mut centroid = Array1::<Float>::zeros(n_features);
210 let mut count = 0;
211
212 for (sample_idx, &sample_class) in y_target.iter().enumerate() {
213 if sample_class == class_label {
214 for feature_idx in 0..n_features {
215 centroid[feature_idx] += X_thread[[sample_idx, feature_idx]];
216 }
217 count += 1;
218 }
219 }
220
221 if count > 0 {
222 centroid /= count as f64;
223 }
224 class_centroids.insert(class_label, centroid);
225 }
226
227 local_models.insert(target_idx, class_centroids);
228 local_classes.push((target_idx, target_classes));
229 }
230
231 {
233 let mut classes_guard = classes_thread.lock().unwrap();
234 let mut models_guard = models_thread.lock().unwrap();
235
236 local_classes.sort_by_key(|(idx, _)| *idx);
238 for (target_idx, target_classes) in local_classes {
239 while classes_guard.len() <= target_idx {
241 classes_guard.push(vec![]);
242 }
243 classes_guard[target_idx] = target_classes;
244 }
245
246 for (target_idx, class_centroids) in local_models {
247 models_guard.insert(target_idx, class_centroids);
248 }
249 }
250
251 Ok(())
252 });
253
254 handles.push(handle);
255 }
256
257 for handle in handles {
259 handle.join().map_err(|_| {
260 SklearsError::InvalidInput("Thread panicked during parallel training".to_string())
261 })??;
262 }
263
264 let final_classes = Arc::try_unwrap(classes_per_target)
266 .map_err(|_| SklearsError::InvalidInput("Failed to extract classes".to_string()))?
267 .into_inner()
268 .unwrap();
269
270 let final_models = Arc::try_unwrap(target_models)
271 .map_err(|_| SklearsError::InvalidInput("Failed to extract models".to_string()))?
272 .into_inner()
273 .unwrap();
274
275 Ok(MultiOutputClassifier {
276 state: MultiOutputClassifierTrained {
277 classes_per_target: final_classes,
278 target_models: final_models,
279 n_targets,
280 n_features,
281 },
282 n_jobs: Some(n_jobs as i32),
283 })
284 }
285}
286
287impl MultiOutputClassifier<MultiOutputClassifierTrained> {
288 pub fn classes(&self) -> &[Vec<i32>] {
290 &self.state.classes_per_target
291 }
292
293 pub fn n_targets(&self) -> usize {
295 self.state.n_targets
296 }
297}
298
299impl Predict<ArrayView2<'_, Float>, Array2<i32>>
300 for MultiOutputClassifier<MultiOutputClassifierTrained>
301{
302 #[allow(non_snake_case)]
303 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
304 let X = X.to_owned();
305 let (n_samples, n_features) = X.dim();
306
307 if n_features != self.state.n_features {
308 return Err(SklearsError::InvalidInput(
309 "Number of features doesn't match training data".to_string(),
310 ));
311 }
312
313 let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_targets));
314
315 for target_idx in 0..self.state.n_targets {
317 if let Some(class_centroids) = self.state.target_models.get(&target_idx) {
318 for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
319 let mut min_distance = f64::INFINITY;
320 let mut best_class = 0;
321
322 for (&class_label, centroid) in class_centroids {
324 let mut distance = 0.0;
325 for feature_idx in 0..n_features {
326 let diff = sample[feature_idx] - centroid[feature_idx];
327 distance += diff * diff;
328 }
329 distance = distance.sqrt();
330
331 if distance < min_distance {
332 min_distance = distance;
333 best_class = class_label;
334 }
335 }
336
337 predictions[[sample_idx, target_idx]] = best_class;
338 }
339 }
340 }
341
342 Ok(predictions)
343 }
344}
345
/// Multi-output regressor that fits one independent linear model per target
/// column of `y`.
///
/// The type parameter `S` tracks the fitting state: [`Untrained`] before
/// [`Fit::fit`] is called and [`MultiOutputRegressorTrained`] afterwards.
#[derive(Debug, Clone)]
pub struct MultiOutputRegressor<S = Untrained> {
    /// Fitting state (`Untrained`, or the fitted per-target models).
    state: S,
    /// Requested number of worker threads. `fit` only trains in parallel for
    /// `Some(k)` with `k > 1` when `y` has more than one target column.
    n_jobs: Option<i32>,
}
368
369impl MultiOutputRegressor<Untrained> {
370 pub fn new() -> Self {
372 Self {
373 state: Untrained,
374 n_jobs: None,
375 }
376 }
377
378 pub fn n_jobs(mut self, n_jobs: Option<i32>) -> Self {
380 self.n_jobs = n_jobs;
381 self
382 }
383}
384
385impl Default for MultiOutputRegressor<Untrained> {
386 fn default() -> Self {
387 Self::new()
388 }
389}
390
391impl Estimator for MultiOutputRegressor<Untrained> {
392 type Config = ();
393 type Error = SklearsError;
394 type Float = Float;
395
396 fn config(&self) -> &Self::Config {
397 &()
398 }
399}
400
401impl Fit<ArrayView2<'_, Float>, Array2<f64>> for MultiOutputRegressor<Untrained> {
402 type Fitted = MultiOutputRegressor<MultiOutputRegressorTrained>;
403
404 #[allow(non_snake_case)]
405 fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<f64>) -> SklResult<Self::Fitted> {
406 let X = X.to_owned();
407 let (n_samples, n_features) = X.dim();
408
409 if n_samples != y.nrows() {
410 return Err(SklearsError::InvalidInput(
411 "X and y must have the same number of samples".to_string(),
412 ));
413 }
414
415 let n_targets = y.ncols();
416 if n_targets == 0 {
417 return Err(SklearsError::InvalidInput(
418 "y must have at least one target".to_string(),
419 ));
420 }
421
422 let mut target_models = HashMap::new();
423
424 for target_idx in 0..n_targets {
426 let y_target = y.column(target_idx);
427
428 let mut weights = Array1::<Float>::zeros(n_features);
431 let mut bias = 0.0;
432
433 let y_mean = y_target.mean().unwrap();
435 bias = y_mean;
436
437 for feature_idx in 0..n_features {
439 let mut correlation = 0.0;
440 let mut x_mean = 0.0;
441
442 for sample_idx in 0..n_samples {
444 x_mean += X[[sample_idx, feature_idx]];
445 }
446 x_mean /= n_samples as f64;
447
448 let mut numerator = 0.0;
450 let mut x_var = 0.0;
451 let mut y_var = 0.0;
452
453 for sample_idx in 0..n_samples {
454 let x_diff = X[[sample_idx, feature_idx]] - x_mean;
455 let y_diff = y_target[sample_idx] - y_mean;
456 numerator += x_diff * y_diff;
457 x_var += x_diff * x_diff;
458 y_var += y_diff * y_diff;
459 }
460
461 if x_var > 1e-10 && y_var > 1e-10 {
462 correlation = numerator / (x_var.sqrt() * y_var.sqrt());
463 }
464
465 weights[feature_idx] = correlation * 0.1; }
467
468 target_models.insert(target_idx, (weights, bias));
469 }
470
471 if let Some(n_jobs) = self.n_jobs {
473 if n_jobs > 1 && n_targets > 1 {
474 return self.fit_parallel(X, y, n_jobs as usize);
475 }
476 }
477
478 Ok(MultiOutputRegressor {
479 state: MultiOutputRegressorTrained {
480 target_models,
481 n_targets,
482 n_features,
483 },
484 n_jobs: self.n_jobs,
485 })
486 }
487}
488
489impl MultiOutputRegressor<Untrained> {
490 #[allow(non_snake_case)]
492 fn fit_parallel(
493 self,
494 X: Array2<Float>,
495 y: &Array2<f64>,
496 n_jobs: usize,
497 ) -> SklResult<MultiOutputRegressor<MultiOutputRegressorTrained>> {
498 let (n_samples, n_features) = X.dim();
499 let n_targets = y.ncols();
500
501 let X_arc = Arc::new(X);
503 let y_arc = Arc::new(y.clone());
504 let target_models = Arc::new(Mutex::new(HashMap::new()));
505
506 let chunk_size = (n_targets + n_jobs - 1) / n_jobs; let mut handles = vec![];
509
510 for worker_id in 0..n_jobs {
512 let start_target = worker_id * chunk_size;
513 let end_target = std::cmp::min(start_target + chunk_size, n_targets);
514
515 if start_target >= n_targets {
516 break; }
518
519 let X_thread = Arc::clone(&X_arc);
520 let y_thread = Arc::clone(&y_arc);
521 let models_thread = Arc::clone(&target_models);
522
523 let handle = thread::spawn(move || -> SklResult<()> {
524 let mut local_models = HashMap::new();
525
526 for target_idx in start_target..end_target {
527 let y_target = y_thread.column(target_idx);
528 let mut weights = Array1::<f64>::zeros(n_features);
529
530 let y_mean = y_target.mean().unwrap();
532 let bias: f64 = y_mean;
533
534 for feature_idx in 0..n_features {
536 let mut correlation = 0.0;
537 let mut x_mean = 0.0;
538
539 for sample_idx in 0..n_samples {
541 x_mean += X_thread[[sample_idx, feature_idx]];
542 }
543 x_mean /= n_samples as f64;
544
545 let mut numerator = 0.0;
547 let mut x_var = 0.0;
548 let mut y_var = 0.0;
549
550 for sample_idx in 0..n_samples {
551 let x_diff = X_thread[[sample_idx, feature_idx]] - x_mean;
552 let y_diff = y_target[sample_idx] - y_mean;
553 numerator += x_diff * y_diff;
554 x_var += x_diff * x_diff;
555 y_var += y_diff * y_diff;
556 }
557
558 if x_var > 1e-10 && y_var > 1e-10 {
559 correlation = numerator / (x_var.sqrt() * y_var.sqrt());
560 }
561
562 weights[feature_idx] = correlation * 0.1; }
564
565 local_models.insert(target_idx, (weights, bias));
566 }
567
568 {
570 let mut models_guard = models_thread.lock().unwrap();
571 for (target_idx, model) in local_models {
572 models_guard.insert(target_idx, model);
573 }
574 }
575
576 Ok(())
577 });
578
579 handles.push(handle);
580 }
581
582 for handle in handles {
584 handle.join().map_err(|_| {
585 SklearsError::InvalidInput("Thread panicked during parallel training".to_string())
586 })??;
587 }
588
589 let final_models = Arc::try_unwrap(target_models)
591 .map_err(|_| SklearsError::InvalidInput("Failed to extract models".to_string()))?
592 .into_inner()
593 .unwrap();
594
595 Ok(MultiOutputRegressor {
596 state: MultiOutputRegressorTrained {
597 target_models: final_models,
598 n_targets,
599 n_features,
600 },
601 n_jobs: Some(n_jobs as i32),
602 })
603 }
604}
605
606impl MultiOutputRegressor<MultiOutputRegressorTrained> {
607 pub fn n_targets(&self) -> usize {
609 self.state.n_targets
610 }
611}
612
613impl Predict<ArrayView2<'_, Float>, Array2<f64>>
614 for MultiOutputRegressor<MultiOutputRegressorTrained>
615{
616 #[allow(non_snake_case)]
617 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
618 let X = X.to_owned();
619 let (n_samples, n_features) = X.dim();
620
621 if n_features != self.state.n_features {
622 return Err(SklearsError::InvalidInput(
623 "Number of features doesn't match training data".to_string(),
624 ));
625 }
626
627 let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_targets));
628
629 for target_idx in 0..self.state.n_targets {
631 if let Some((weights, bias)) = self.state.target_models.get(&target_idx) {
632 for (sample_idx, sample) in X.axis_iter(Axis(0)).enumerate() {
633 let prediction: f64 = sample
635 .iter()
636 .zip(weights.iter())
637 .map(|(&x, &w)| x * w)
638 .sum::<f64>()
639 + bias;
640
641 predictions[[sample_idx, target_idx]] = prediction;
642 }
643 }
644 }
645
646 Ok(predictions)
647 }
648}
649
/// Fitted state of [`MultiOutputClassifier`]: one nearest-centroid model per
/// target column.
#[derive(Debug, Clone)]
pub struct MultiOutputClassifierTrained {
    /// Sorted distinct class labels observed for each target, indexed by target.
    pub classes_per_target: Vec<Vec<i32>>,
    /// Target index -> (class label -> mean feature vector of the training
    /// samples carrying that label).
    pub target_models: HashMap<usize, HashMap<i32, Array1<f64>>>,
    /// Number of target columns seen during fitting.
    pub n_targets: usize,
    /// Number of feature columns seen during fitting; `predict` rejects inputs
    /// with a different column count.
    pub n_features: usize,
}
662
/// Fitted state of [`MultiOutputRegressor`]: one linear model per target column.
#[derive(Debug, Clone)]
pub struct MultiOutputRegressorTrained {
    /// Target index -> `(weights, bias)`; `predict` outputs
    /// `sum(sample[i] * weights[i]) + bias` for each target.
    pub target_models: HashMap<usize, (Array1<f64>, f64)>,
    /// Number of target columns seen during fitting.
    pub n_targets: usize,
    /// Number of feature columns seen during fitting; `predict` rejects inputs
    /// with a different column count.
    pub n_features: usize,
}
673
// Test matrices are named `X` to mirror the scikit-learn convention.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;
    use std::time::Instant;

    /// Asserts two prediction matrices have the same shape and agree
    /// element-wise within `1e-10`.
    fn assert_predictions_close(left: &Array2<f64>, right: &Array2<f64>) {
        assert_eq!(left.shape(), right.shape());
        for (&l, &r) in left.iter().zip(right.iter()) {
            assert_abs_diff_eq!(l, r, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_parallel_multi_output_classifier() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
            [5.0, 6.0, 7.0],
            [6.0, 7.0, 8.0]
        ];
        let y = array![
            [0, 1, 0],
            [1, 0, 1],
            [0, 1, 0],
            [1, 0, 1],
            [0, 1, 0],
            [1, 0, 1]
        ];

        let classifier_parallel = MultiOutputClassifier::new().n_jobs(Some(2));
        let trained_parallel = classifier_parallel.fit(&X.view(), &y).unwrap();

        let classifier_sequential = MultiOutputClassifier::new().n_jobs(Some(1));
        let trained_sequential = classifier_sequential.fit(&X.view(), &y).unwrap();

        assert_eq!(trained_parallel.n_targets(), trained_sequential.n_targets());
        assert_eq!(trained_parallel.classes(), trained_sequential.classes());
        assert_eq!(trained_parallel.classes().to_vec(), vec![vec![0, 1]; 3]);

        let pred_parallel = trained_parallel.predict(&X.view()).unwrap();
        let pred_sequential = trained_sequential.predict(&X.view()).unwrap();

        // In each column the two class centroids are [3, 4, 5] and [4, 5, 6]
        // (the middle column swaps which label owns which), so the first three
        // rows fall nearer the first centroid and the last three nearer the second.
        let expected = array![
            [0, 1, 0],
            [0, 1, 0],
            [0, 1, 0],
            [1, 0, 1],
            [1, 0, 1],
            [1, 0, 1]
        ];
        assert_eq!(pred_parallel, expected);
        assert_eq!(pred_sequential, expected);
    }

    #[test]
    fn test_parallel_multi_output_regressor() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
            [5.0, 6.0, 7.0],
            [6.0, 7.0, 8.0]
        ];
        let y = array![
            [1.5, 2.5, 3.5],
            [2.5, 3.5, 4.5],
            [3.5, 4.5, 5.5],
            [4.5, 5.5, 6.5],
            [5.5, 6.5, 7.5],
            [6.5, 7.5, 8.5]
        ];

        let regressor_parallel = MultiOutputRegressor::new().n_jobs(Some(2));
        let trained_parallel = regressor_parallel.fit(&X.view(), &y).unwrap();

        let regressor_sequential = MultiOutputRegressor::new().n_jobs(Some(1));
        let trained_sequential = regressor_sequential.fit(&X.view(), &y).unwrap();

        assert_eq!(trained_parallel.n_targets(), trained_sequential.n_targets());

        let pred_parallel = trained_parallel.predict(&X.view()).unwrap();
        let pred_sequential = trained_sequential.predict(&X.view()).unwrap();

        assert_eq!(pred_parallel.shape(), &[6, 3]);
        assert_predictions_close(&pred_parallel, &pred_sequential);
    }

    #[test]
    fn test_parallel_training_performance_classifier() {
        let n_samples = 1000;
        let n_features = 50;
        let n_targets = 20;

        let mut X = Array2::<Float>::zeros((n_samples, n_features));
        let mut y = Array2::<i32>::zeros((n_samples, n_targets));

        for i in 0..n_samples {
            for j in 0..n_features {
                X[[i, j]] = (i * j) as Float * 0.01;
            }
            for j in 0..n_targets {
                y[[i, j]] = ((i + j) % 2) as i32;
            }
        }

        let start_sequential = Instant::now();
        let classifier_sequential = MultiOutputClassifier::new().n_jobs(Some(1));
        let trained_sequential = classifier_sequential.fit(&X.view(), &y).unwrap();
        let sequential_time = start_sequential.elapsed();

        let start_parallel = Instant::now();
        let classifier_parallel = MultiOutputClassifier::new().n_jobs(Some(4));
        let trained_parallel = classifier_parallel.fit(&X.view(), &y).unwrap();
        let parallel_time = start_parallel.elapsed();

        assert_eq!(trained_parallel.n_targets(), n_targets);
        assert_eq!(trained_sequential.n_targets(), n_targets);
        assert_eq!(trained_parallel.classes(), trained_sequential.classes());

        let pred_parallel = trained_parallel.predict(&X.view()).unwrap();
        let pred_sequential = trained_sequential.predict(&X.view()).unwrap();
        assert_eq!(pred_parallel, pred_sequential);

        println!(
            "Sequential time: {:?}, Parallel time: {:?}",
            sequential_time, parallel_time
        );
    }

    #[test]
    fn test_parallel_training_performance_regressor() {
        let n_samples = 1000;
        let n_features = 50;
        let n_targets = 20;

        let mut X = Array2::<Float>::zeros((n_samples, n_features));
        let mut y = Array2::<f64>::zeros((n_samples, n_targets));

        for i in 0..n_samples {
            for j in 0..n_features {
                X[[i, j]] = (i * j) as Float * 0.01;
            }
            for j in 0..n_targets {
                y[[i, j]] = (i + j) as f64 * 0.1;
            }
        }

        let start_sequential = Instant::now();
        let regressor_sequential = MultiOutputRegressor::new().n_jobs(Some(1));
        let trained_sequential = regressor_sequential.fit(&X.view(), &y).unwrap();
        let sequential_time = start_sequential.elapsed();

        let start_parallel = Instant::now();
        let regressor_parallel = MultiOutputRegressor::new().n_jobs(Some(4));
        let trained_parallel = regressor_parallel.fit(&X.view(), &y).unwrap();
        let parallel_time = start_parallel.elapsed();

        assert_eq!(trained_parallel.n_targets(), n_targets);
        assert_eq!(trained_sequential.n_targets(), n_targets);

        let pred_parallel = trained_parallel.predict(&X.view()).unwrap();
        let pred_sequential = trained_sequential.predict(&X.view()).unwrap();
        assert_predictions_close(&pred_parallel, &pred_sequential);

        println!(
            "Sequential time: {:?}, Parallel time: {:?}",
            sequential_time, parallel_time
        );
    }

    #[test]
    fn test_parallel_training_thread_safety() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y_class = array![[0, 1], [1, 0], [0, 1], [1, 0]];
        let y_reg = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];

        // Column 0 centroids: label 0 -> [2, 3], label 1 -> [3, 4]; column 1 swaps them.
        let expected_class = array![[0, 1], [0, 1], [1, 0], [1, 0]];
        let reference_reg = MultiOutputRegressor::new()
            .fit(&X.view(), &y_reg)
            .unwrap()
            .predict(&X.view())
            .unwrap();

        for _ in 0..10 {
            let classifier = MultiOutputClassifier::new().n_jobs(Some(2));
            let trained = classifier.fit(&X.view(), &y_class).unwrap();
            let predictions = trained.predict(&X.view()).unwrap();
            assert_eq!(predictions, expected_class);

            let regressor = MultiOutputRegressor::new().n_jobs(Some(2));
            let trained = regressor.fit(&X.view(), &y_reg).unwrap();
            let predictions = trained.predict(&X.view()).unwrap();
            assert_eq!(predictions.shape(), &[4, 2]);
            assert_predictions_close(&predictions, &reference_reg);
        }
    }

    #[test]
    fn test_parallel_training_edge_cases() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y_class = array![[0, 1], [1, 0]];
        let y_reg = array![[1.0, 2.0], [2.0, 3.0]];

        // More workers than targets.
        let classifier = MultiOutputClassifier::new().n_jobs(Some(10));
        let trained = classifier.fit(&X.view(), &y_class).unwrap();
        assert_eq!(trained.n_targets(), 2);
        assert_eq!(trained.predict(&X.view()).unwrap().shape(), &[2, 2]);

        let regressor = MultiOutputRegressor::new().n_jobs(Some(10));
        let trained = regressor.fit(&X.view(), &y_reg).unwrap();
        assert_eq!(trained.n_targets(), 2);
        assert_eq!(trained.predict(&X.view()).unwrap().shape(), &[2, 2]);

        // A single target always trains sequentially.
        let y_single = array![[0], [1]];
        let classifier_single = MultiOutputClassifier::new().n_jobs(Some(4));
        let trained_single = classifier_single.fit(&X.view(), &y_single).unwrap();
        assert_eq!(trained_single.n_targets(), 1);
    }

    #[test]
    fn test_parallel_training_error_handling() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];

        let y_mismatch = array![[0, 1, 0], [1, 0, 1], [0, 1, 0]];
        let classifier = MultiOutputClassifier::new().n_jobs(Some(2));
        let result = classifier.fit(&X.view(), &y_mismatch);
        assert!(result.is_err());

        let regressor = MultiOutputRegressor::new().n_jobs(Some(2));
        let y_reg_mismatch = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]];
        let result = regressor.fit(&X.view(), &y_reg_mismatch);
        assert!(result.is_err());
    }
}