// sklears_multioutput/lib.rs

1#![allow(dead_code)]
2#![allow(non_snake_case)]
3#![allow(missing_docs)]
4#![allow(deprecated)]
5#![allow(clippy::all)]
6#![allow(clippy::pedantic)]
7#![allow(clippy::nursery)]
8#![allow(unused_imports)]
9#![allow(unused_variables)]
10#![allow(unused_mut)]
11#![allow(unused_assignments)]
12#![allow(unused_doc_comments)]
13#![allow(unused_parens)]
14#![allow(unused_comparisons)]
15//! Multi-output regression and classification
16//!
17//! This module provides meta-estimators for multi-target prediction problems.
18//! It includes strategies for independent multi-output prediction.
19
20// #![warn(missing_docs)]
21
22pub mod activation;
23pub mod adversarial;
24pub mod chains;
25pub mod classification;
26pub mod core;
27pub mod correlation;
28pub mod ensemble;
29pub mod hierarchical;
30pub mod label_analysis;
31pub mod loss;
32pub mod metrics;
33pub mod mlp;
34pub mod multi_label;
35pub mod multitask;
36pub mod neighbors;
37pub mod neural;
38pub mod optimization;
39pub mod performance;
40pub mod probabilistic;
41pub mod ranking;
42pub mod recurrent;
43pub mod regularization;
44pub mod sequence;
45pub mod sparse_storage;
46pub mod streaming;
47pub mod svm;
48pub mod transfer_learning;
49pub mod tree;
50pub mod utilities;
51pub mod utils;
52
53// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
54
55// Re-export core multi-output algorithms
56pub use core::{
57    MultiOutputClassifier, MultiOutputClassifierTrained, MultiOutputRegressor,
58    MultiOutputRegressorTrained,
59};
60
61// Re-export chain-based algorithms
62pub use chains::{
63    BayesianClassifierChain, BayesianClassifierChainTrained, ChainMethod, ClassifierChain,
64    ClassifierChainTrained, EnsembleOfChains, EnsembleOfChainsTrained, RegressorChain,
65    RegressorChainTrained,
66};
67
68// Re-export ensemble algorithms
69pub use ensemble::{GradientBoostingMultiOutput, GradientBoostingMultiOutputTrained, WeakLearner};
70
71// Re-export neural network algorithms
72pub use neural::{
73    ActivationFunction, AdversarialMultiTaskNetwork, AdversarialMultiTaskNetworkTrained,
74    AdversarialStrategy, CellType, GradientReversalConfig, LambdaSchedule, LossFunction,
75    MultiOutputMLP, MultiOutputMLPClassifier, MultiOutputMLPRegressor, MultiOutputMLPTrained,
76    MultiTaskNeuralNetwork, MultiTaskNeuralNetworkTrained, RecurrentNeuralNetwork,
77    RecurrentNeuralNetworkTrained, SequenceMode, TaskBalancing, TaskDiscriminator,
78};
79
80// Re-export adversarial learning types that are not in neural
81pub use adversarial::AdversarialConfig;
82
83// Re-export regularization algorithms
84pub use regularization::{
85    GroupLasso, GroupLassoTrained, MetaLearningMultiTask, MetaLearningMultiTaskTrained,
86    MultiTaskElasticNet, MultiTaskElasticNetTrained, NuclearNormRegression,
87    NuclearNormRegressionTrained, RegularizationStrategy, TaskClusteringRegressionTrained,
88    TaskClusteringRegularization, TaskRelationshipLearning, TaskRelationshipLearningTrained,
89    TaskSimilarityMethod,
90};
91
92// Re-export correlation and dependency analysis
93pub use correlation::{
94    CITestMethod, CITestResult, CITestResults, ConditionalIndependenceTester, CorrelationAnalysis,
95    CorrelationType, DependencyGraph, DependencyGraphBuilder, DependencyMethod, GraphStatistics,
96    OutputCorrelationAnalyzer,
97};
98
99// Re-export transfer learning algorithms
100pub use transfer_learning::{
101    ContinualLearning, ContinualLearningTrained, CrossTaskTransferLearning,
102    CrossTaskTransferLearningTrained, DomainAdaptation, DomainAdaptationTrained,
103    KnowledgeDistillation, KnowledgeDistillationTrained, ProgressiveTransferLearning,
104    ProgressiveTransferLearningTrained,
105};
106
107// Re-export optimization algorithms
108pub use optimization::{
109    JointLossConfig, JointLossOptimizer, JointLossOptimizerTrained, LossCombination,
110    LossFunction as OptimizationLossFunction, MultiObjectiveConfig, MultiObjectiveOptimizer,
111    MultiObjectiveOptimizerTrained, NSGA2Algorithm, NSGA2Config, NSGA2Optimizer,
112    NSGA2OptimizerTrained, ParetoSolution, ScalarizationConfig, ScalarizationMethod,
113    ScalarizationOptimizer, ScalarizationOptimizerTrained,
114};
115
116// Re-export probabilistic algorithms
117pub use probabilistic::{
118    BayesianMultiOutputConfig, BayesianMultiOutputModel, BayesianMultiOutputModelTrained,
119    EnsembleBayesianConfig, EnsembleBayesianModel, EnsembleBayesianModelTrained, EnsembleStrategy,
120    GaussianProcessMultiOutput, GaussianProcessMultiOutputTrained, InferenceMethod, KernelFunction,
121    PosteriorDistribution, PredictionWithUncertainty, PriorDistribution,
122};
123
124// Re-export ranking algorithms
125pub use ranking::{
126    BinaryClassifierModel, IndependentLabelPrediction, IndependentLabelPredictionTrained,
127    ThresholdStrategy as RankingThresholdStrategy,
128};
129
130// Re-export sparse storage algorithms
131pub use sparse_storage::{
132    sparse_utils, CSRMatrix, MemoryUsage, SparseMultiOutput, SparseMultiOutputTrained,
133    SparsityAnalysis, StorageRecommendation,
134};
135
136// Re-export streaming and incremental learning algorithms
137pub use streaming::{
138    IncrementalMultiOutputRegression, IncrementalMultiOutputRegressionConfig,
139    IncrementalMultiOutputRegressionTrained, StreamingMultiOutput, StreamingMultiOutputConfig,
140    StreamingMultiOutputTrained,
141};
142
143// Re-export performance optimization algorithms
144pub use performance::{
145    EarlyStopping, EarlyStoppingConfig, PredictionCache, WarmStartRegressor,
146    WarmStartRegressorConfig, WarmStartRegressorTrained,
147};
148
149// Re-export multi-label algorithms
150pub use multi_label::{
151    BinaryRelevance, BinaryRelevanceTrained, LabelPowerset, LabelPowersetTrained,
152    OneVsRestClassifier, OneVsRestClassifierTrained, PrunedLabelPowerset,
153    PrunedLabelPowersetTrained, PruningStrategy,
154};
155
156// Re-export tree-based algorithms
157pub use tree::{
158    ClassificationCriterion, DAGInferenceMethod, MultiTargetDecisionTreeClassifier,
159    MultiTargetDecisionTreeClassifierTrained, MultiTargetRegressionTree,
160    MultiTargetRegressionTreeTrained, RandomForestMultiOutput, RandomForestMultiOutputTrained,
161    TreeStructuredPredictor, TreeStructuredPredictorTrained,
162};
163
164// Re-export instance-based learning algorithms
165pub use neighbors::{IBLRTrained, WeightFunction, IBLR};
166
167// Re-export SVM algorithms
168pub use svm::{
169    MLTSVMTrained, MultiOutputSVM, MultiOutputSVMTrained, RankSVM, RankSVMTrained, RankingSVMModel,
170    SVMKernel, SVMModel, ThresholdStrategy as SVMThresholdStrategy, TwinSVMModel, MLTSVM,
171};
172
173// Re-export sequence/structured prediction algorithms
174pub use sequence::{
175    FeatureFunction, FeatureType, HiddenMarkovModel, HiddenMarkovModelTrained,
176    MaximumEntropyMarkovModel, MaximumEntropyMarkovModelTrained, StructuredPerceptron,
177    StructuredPerceptronTrained,
178};
179
180// Re-export hierarchical classification and graph neural network algorithms
181pub use hierarchical::{
182    AggregationFunction, ConsistencyEnforcement, CostSensitiveHierarchicalClassifier,
183    CostSensitiveHierarchicalClassifierTrained, CostStrategy, GraphNeuralNetwork,
184    GraphNeuralNetworkTrained, MessagePassingVariant, OntologyAwareClassifier,
185    OntologyAwareClassifierTrained,
186};
187
188// Re-export multi-label classification algorithms
189pub use classification::{
190    CalibratedBinaryRelevance, CalibratedBinaryRelevanceTrained, CalibrationMethod, CostMatrix,
191    CostSensitiveBinaryRelevance, CostSensitiveBinaryRelevanceTrained, DistanceMetric, MLkNN,
192    MLkNNTrained, RandomLabelCombinations, SimpleBinaryModel,
193};
194
195// Re-export comprehensive metrics and statistical testing functionality
196pub use metrics::{
197    average_precision_score,
198    confidence_interval,
199    coverage_error,
200    f1_score,
201    // Basic multi-label metrics
202    hamming_loss,
203    jaccard_score,
204    label_ranking_average_precision,
205    // Statistical significance testing
206    mcnemar_test,
207    one_error,
208    paired_t_test,
209    // Per-label performance metrics
210    per_label_metrics,
211    precision_score_micro,
212    ranking_loss,
213    recall_score_micro,
214
215    subset_accuracy,
216    wilcoxon_signed_rank_test,
217    ConfidenceInterval,
218    PerLabelMetrics,
219
220    StatisticalTestResult,
221};
222
223#[allow(non_snake_case)]
224#[cfg(test)]
225mod tests {
226    use super::*;
227    // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
228    use crate::utilities::CLARE;
229    use scirs2_core::ndarray::{array, Array2};
230    use sklears_core::traits::{Fit, Predict};
231    use sklears_core::types::Float;
232
233    #[test]
234    fn test_multi_output_classifier() {
235        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
236        let y = array![[0, 1], [1, 0], [1, 1], [0, 0]];
237
238        let moc = MultiOutputClassifier::new();
239        let fitted = moc.fit(&X.view(), &y).unwrap();
240
241        assert_eq!(fitted.n_targets(), 2);
242        assert_eq!(fitted.classes().len(), 2);
243
244        let predictions = fitted.predict(&X.view()).unwrap();
245        assert_eq!(predictions.dim(), (4, 2));
246
247        // Check that predictions are valid (within the classes for each target)
248        for target_idx in 0..2 {
249            let target_classes = &fitted.classes()[target_idx];
250            for sample_idx in 0..4 {
251                let pred = predictions[[sample_idx, target_idx]];
252                assert!(target_classes.contains(&pred));
253            }
254        }
255    }
256
257    #[test]
258    fn test_multi_output_regressor() {
259        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
260        let y = array![[1.5, 2.5], [2.5, 3.5], [2.0, 1.5], [1.0, 1.5]];
261
262        let mor = MultiOutputRegressor::new();
263        let fitted = mor.fit(&X.view(), &y).unwrap();
264
265        assert_eq!(fitted.n_targets(), 2);
266
267        let predictions = fitted.predict(&X.view()).unwrap();
268        assert_eq!(predictions.dim(), (4, 2));
269
270        // Predictions should be finite numbers
271        for pred in predictions.iter() {
272            assert!(pred.is_finite());
273        }
274    }
275
276    #[test]
277    fn test_invalid_input() {
278        let X = array![[1.0, 2.0], [2.0, 3.0]];
279        let y = array![[0, 1], [1, 0], [0, 1]]; // Wrong number of rows
280
281        let moc = MultiOutputClassifier::new();
282        assert!(moc.fit(&X.view(), &y).is_err());
283    }
284
285    #[test]
286    fn test_empty_targets() {
287        let X = array![[1.0, 2.0], [2.0, 3.0]];
288        let y = Array2::<i32>::zeros((2, 0)); // No targets
289
290        let moc = MultiOutputClassifier::new();
291        assert!(moc.fit(&X.view(), &y).is_err());
292    }
293
294    #[test]
295    fn test_prediction_shape_mismatch() {
296        let X = array![[1.0, 2.0], [2.0, 3.0]];
297        let y = array![[0, 1], [1, 0]];
298
299        let moc = MultiOutputClassifier::new();
300        let fitted = moc.fit(&X.view(), &y).unwrap();
301
302        // Test with wrong number of features
303        let X_wrong = array![[1.0, 2.0, 3.0]]; // 3 features instead of 2
304        assert!(fitted.predict(&X_wrong.view()).is_err());
305    }
306
307    #[test]
308    fn test_classifier_chain() {
309        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
310        let y = array![[0, 1], [1, 0], [1, 1], [0, 0]];
311
312        let cc = ClassifierChain::new();
313        let fitted = cc.fit_simple(&X.view(), &y).unwrap();
314
315        assert_eq!(fitted.n_targets(), 2);
316        assert_eq!(fitted.chain_order(), &[0, 1]); // Default order
317
318        let predictions = fitted.predict_simple(&X.view()).unwrap();
319        assert_eq!(predictions.dim(), (4, 2));
320
321        // Check that predictions are valid
322        for sample_idx in 0..4 {
323            for target_idx in 0..2 {
324                let pred = predictions[[sample_idx, target_idx]];
325                // Predictions should be either 0 or 1 for binary classification
326                assert!(pred == 0 || pred == 1);
327            }
328        }
329    }
330
331    #[test]
332    fn test_classifier_chain_custom_order() {
333        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
334        let y = array![[0, 1], [1, 0], [1, 1]];
335
336        let cc = ClassifierChain::new().order(vec![1, 0]); // Reverse order
337        let fitted = cc.fit_simple(&X.view(), &y).unwrap();
338
339        assert_eq!(fitted.chain_order(), &[1, 0]);
340
341        let predictions = fitted.predict_simple(&X.view()).unwrap();
342        assert_eq!(predictions.dim(), (3, 2));
343    }
344
345    #[test]
346    fn test_classifier_chain_invalid_order() {
347        let X = array![[1.0, 2.0], [2.0, 3.0]];
348        let y = array![[0, 1], [1, 0]];
349
350        let cc = ClassifierChain::new().order(vec![0, 1, 2]); // Too many indices
351        assert!(cc.fit_simple(&X.view(), &y).is_err());
352    }
353
354    #[test]
355    fn test_classifier_chain_monte_carlo() {
356        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
357        let y = array![[0, 1], [1, 0], [1, 1], [0, 0]];
358
359        let cc = ClassifierChain::new();
360        let fitted = cc.fit_simple(&X.view(), &y).unwrap();
361
362        // Test Monte Carlo predictions with probabilities
363        let mc_probs = fitted
364            .predict_monte_carlo(&X.view(), 100, Some(42))
365            .unwrap();
366        assert_eq!(mc_probs.dim(), (4, 2));
367
368        // All probabilities should be between 0 and 1
369        for prob in mc_probs.iter() {
370            assert!(*prob >= 0.0 && *prob <= 1.0);
371        }
372
373        // Test Monte Carlo predictions with labels
374        let mc_labels = fitted
375            .predict_monte_carlo_labels(&X.view(), 100, Some(42))
376            .unwrap();
377        assert_eq!(mc_labels.dim(), (4, 2));
378
379        // All predictions should be binary (0 or 1)
380        for pred in mc_labels.iter() {
381            assert!(*pred == 0 || *pred == 1);
382        }
383
384        // Test reproducibility with same random state
385        let mc_probs2 = fitted
386            .predict_monte_carlo(&X.view(), 100, Some(42))
387            .unwrap();
388        for (i, (&prob1, &prob2)) in mc_probs.iter().zip(mc_probs2.iter()).enumerate() {
389            assert!(
390                (prob1 - prob2).abs() < 1e-10,
391                "Probabilities should be identical with same random state at index {}",
392                i
393            );
394        }
395    }
396
397    #[test]
398    fn test_classifier_chain_monte_carlo_invalid_input() {
399        let X = array![[1.0, 2.0], [2.0, 3.0]];
400        let y = array![[0, 1], [1, 0]];
401
402        let cc = ClassifierChain::new();
403        let fitted = cc.fit_simple(&X.view(), &y).unwrap();
404
405        // Test with zero samples
406        assert!(fitted.predict_monte_carlo(&X.view(), 0, None).is_err());
407
408        // Test with wrong number of features
409        let X_wrong = array![[1.0, 2.0, 3.0]]; // 3 features instead of 2
410        assert!(fitted
411            .predict_monte_carlo(&X_wrong.view(), 10, None)
412            .is_err());
413    }
414
415    #[test]
416    fn test_regressor_chain() {
417        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
418        let y = array![[1.5, 2.5], [2.5, 3.5], [2.0, 1.5], [1.0, 1.5]];
419
420        let rc = RegressorChain::new();
421        let fitted = rc.fit_simple(&X.view(), &y).unwrap();
422
423        assert_eq!(fitted.n_targets(), 2);
424        assert_eq!(fitted.chain_order(), &[0, 1]); // Default order
425
426        let predictions = fitted.predict_simple(&X.view()).unwrap();
427        assert_eq!(predictions.dim(), (4, 2));
428
429        // Predictions should be finite numbers
430        for pred in predictions.iter() {
431            assert!(pred.is_finite());
432        }
433    }
434
435    #[test]
436    fn test_regressor_chain_custom_order() {
437        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
438        let y = array![[1.5, 2.5], [2.5, 3.5], [2.0, 1.5]];
439
440        let rc = RegressorChain::new().order(vec![1, 0]); // Reverse order
441        let fitted = rc.fit_simple(&X.view(), &y).unwrap();
442
443        assert_eq!(fitted.chain_order(), &[1, 0]);
444
445        let predictions = fitted.predict_simple(&X.view()).unwrap();
446        assert_eq!(predictions.dim(), (3, 2));
447
448        // Predictions should be finite numbers
449        for pred in predictions.iter() {
450            assert!(pred.is_finite());
451        }
452    }
453
454    #[test]
455    fn test_regressor_chain_invalid_input() {
456        let X = array![[1.0, 2.0], [2.0, 3.0]];
457        let y = array![[1.5, 2.5], [2.5, 3.5], [2.0, 1.5]]; // Wrong number of rows
458
459        let rc = RegressorChain::new();
460        assert!(rc.fit_simple(&X.view(), &y).is_err());
461    }
462
463    #[test]
464    fn test_regressor_chain_invalid_order() {
465        let X = array![[1.0, 2.0], [2.0, 3.0]];
466        let y = array![[1.5, 2.5], [2.5, 3.5]];
467
468        let rc = RegressorChain::new().order(vec![0, 1, 2]); // Too many indices
469        assert!(rc.fit_simple(&X.view(), &y).is_err());
470    }
471
472    #[test]
473    fn test_binary_relevance() {
474        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
475        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]]; // Multi-label binary
476
477        let br = BinaryRelevance::new();
478        let fitted = br.fit(&X.view(), &y).unwrap();
479
480        assert_eq!(fitted.n_labels(), 2);
481        assert_eq!(fitted.classes().len(), 2);
482
483        let predictions = fitted.predict(&X.view()).unwrap();
484        assert_eq!(predictions.dim(), (4, 2));
485
486        // Check that predictions are binary (0 or 1)
487        for pred in predictions.iter() {
488            assert!(*pred == 0 || *pred == 1);
489        }
490
491        // Test probability predictions
492        let probabilities = fitted.predict_proba(&X.view()).unwrap();
493        assert_eq!(probabilities.dim(), (4, 2));
494
495        // Check that probabilities are in [0, 1]
496        for prob in probabilities.iter() {
497            assert!(*prob >= 0.0 && *prob <= 1.0);
498        }
499    }
500
501    #[test]
502    fn test_binary_relevance_single_label() {
503        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
504        let y = array![[1], [0], [1]]; // Single binary label
505
506        let br = BinaryRelevance::new();
507        let fitted = br.fit(&X.view(), &y).unwrap();
508
509        assert_eq!(fitted.n_labels(), 1);
510
511        let predictions = fitted.predict(&X.view()).unwrap();
512        assert_eq!(predictions.dim(), (3, 1));
513
514        // Check that predictions are binary
515        for pred in predictions.iter() {
516            assert!(*pred == 0 || *pred == 1);
517        }
518    }
519
520    #[test]
521    fn test_binary_relevance_invalid_input() {
522        let X = array![[1.0, 2.0], [2.0, 3.0]];
523        let y = array![[1, 0], [0, 1], [1, 1]]; // Wrong number of rows
524
525        let br = BinaryRelevance::new();
526        assert!(br.fit(&X.view(), &y).is_err());
527    }
528
529    #[test]
530    fn test_binary_relevance_non_binary_labels() {
531        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
532        let y = array![[0, 1], [1, 2], [2, 0]]; // Non-binary labels
533
534        let br = BinaryRelevance::new();
535        assert!(br.fit(&X.view(), &y).is_err());
536    }
537
538    #[test]
539    fn test_binary_relevance_predict_shape_mismatch() {
540        let X = array![[1.0, 2.0], [2.0, 3.0]];
541        let y = array![[1, 0], [0, 1]];
542
543        let br = BinaryRelevance::new();
544        let fitted = br.fit(&X.view(), &y).unwrap();
545
546        // Test with wrong number of features
547        let X_wrong = array![[1.0, 2.0, 3.0]]; // 3 features instead of 2
548        assert!(fitted.predict(&X_wrong.view()).is_err());
549        assert!(fitted.predict_proba(&X_wrong.view()).is_err());
550    }
551
552    #[test]
553    fn test_label_powerset() {
554        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
555        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]]; // Multi-label binary combinations
556
557        let lp = LabelPowerset::new();
558        let fitted = lp.fit(&X.view(), &y).unwrap();
559
560        assert_eq!(fitted.n_labels(), 2);
561        assert_eq!(fitted.n_classes(), 4); // 4 unique combinations: [1,0], [0,1], [1,1], [0,0]
562
563        let predictions = fitted.predict(&X.view()).unwrap();
564        assert_eq!(predictions.dim(), (4, 2));
565
566        // Check that predictions are binary (0 or 1)
567        for pred in predictions.iter() {
568            assert!(*pred == 0 || *pred == 1);
569        }
570
571        // Test decision function
572        let scores = fitted.decision_function(&X.view()).unwrap();
573        assert_eq!(scores.dim(), (4, 4)); // 4 samples, 4 classes
574
575        // Scores should be finite
576        for score in scores.iter() {
577            assert!(score.is_finite());
578        }
579    }
580
581    #[test]
582    fn test_label_powerset_simple_case() {
583        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
584        let y = array![[1, 0], [0, 1], [1, 0]]; // Only 2 unique combinations
585
586        let lp = LabelPowerset::new();
587        let fitted = lp.fit(&X.view(), &y).unwrap();
588
589        assert_eq!(fitted.n_labels(), 2);
590        assert_eq!(fitted.n_classes(), 2); // Only 2 unique combinations: [1,0], [0,1]
591
592        let predictions = fitted.predict(&X.view()).unwrap();
593        assert_eq!(predictions.dim(), (3, 2));
594
595        // Check that predictions are binary
596        for pred in predictions.iter() {
597            assert!(*pred == 0 || *pred == 1);
598        }
599    }
600
601    #[test]
602    fn test_label_powerset_single_label() {
603        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
604        let y = array![[1], [0], [1]]; // Single binary label
605
606        let lp = LabelPowerset::new();
607        let fitted = lp.fit(&X.view(), &y).unwrap();
608
609        assert_eq!(fitted.n_labels(), 1);
610        assert_eq!(fitted.n_classes(), 2); // 2 unique combinations: [1], [0]
611
612        let predictions = fitted.predict(&X.view()).unwrap();
613        assert_eq!(predictions.dim(), (3, 1));
614
615        // Check that predictions are binary
616        for pred in predictions.iter() {
617            assert!(*pred == 0 || *pred == 1);
618        }
619    }
620
621    #[test]
622    fn test_label_powerset_invalid_input() {
623        let X = array![[1.0, 2.0], [2.0, 3.0]];
624        let y = array![[1, 0], [0, 1], [1, 1]]; // Wrong number of rows
625
626        let lp = LabelPowerset::new();
627        assert!(lp.fit(&X.view(), &y).is_err());
628    }
629
630    #[test]
631    fn test_label_powerset_non_binary_labels() {
632        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
633        let y = array![[0, 1], [1, 2], [2, 0]]; // Non-binary labels
634
635        let lp = LabelPowerset::new();
636        assert!(lp.fit(&X.view(), &y).is_err());
637    }
638
639    #[test]
640    fn test_label_powerset_predict_shape_mismatch() {
641        let X = array![[1.0, 2.0], [2.0, 3.0]];
642        let y = array![[1, 0], [0, 1]];
643
644        let lp = LabelPowerset::new();
645        let fitted = lp.fit(&X.view(), &y).unwrap();
646
647        // Test with wrong number of features
648        let X_wrong = array![[1.0, 2.0, 3.0]]; // 3 features instead of 2
649        assert!(fitted.predict(&X_wrong.view()).is_err());
650        assert!(fitted.decision_function(&X_wrong.view()).is_err());
651    }
652
653    #[test]
654    fn test_label_powerset_all_same_combination() {
655        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
656        let y = array![[1, 0], [1, 0], [1, 0]]; // All samples have same label combination
657
658        let lp = LabelPowerset::new();
659        let fitted = lp.fit(&X.view(), &y).unwrap();
660
661        assert_eq!(fitted.n_labels(), 2);
662        assert_eq!(fitted.n_classes(), 1); // Only 1 unique combination
663
664        let predictions = fitted.predict(&X.view()).unwrap();
665        assert_eq!(predictions.dim(), (3, 2));
666
667        // All predictions should be [1, 0]
668        for sample_idx in 0..3 {
669            assert_eq!(predictions[[sample_idx, 0]], 1);
670            assert_eq!(predictions[[sample_idx, 1]], 0);
671        }
672    }
673
674    #[test]
675    fn test_pruned_label_powerset_default_strategy() {
676        // Test data with some rare combinations
677        let X = array![
678            [1.0, 2.0],
679            [2.0, 3.0],
680            [3.0, 1.0],
681            [1.0, 1.0],
682            [2.0, 2.0],
683            [3.0, 3.0],
684            [1.5, 2.5],
685            [2.5, 1.5]
686        ];
687        let y = array![
688            [1, 0],
689            [0, 1],
690            [1, 1],
691            [0, 0], // Frequent combinations
692            [1, 0],
693            [0, 1],
694            [1, 0],
695            [0, 1], // More frequent ones
696        ];
697
698        let plp = PrunedLabelPowerset::new()
699            .min_frequency(2)
700            .strategy(PruningStrategy::DefaultMapping(vec![0, 0]));
701
702        let fitted = plp.fit(&X.view(), &y).unwrap();
703
704        // Should have pruned to only frequent combinations
705        assert!(fitted.n_frequent_classes() <= 4); // At most [1,0], [0,1], [1,1], [0,0]
706        assert_eq!(fitted.min_frequency(), 2);
707
708        let predictions = fitted.predict(&X.view()).unwrap();
709        assert_eq!(predictions.dim(), (8, 2));
710
711        // All predictions should be binary
712        for pred in predictions.iter() {
713            assert!(*pred == 0 || *pred == 1);
714        }
715    }
716
717    #[test]
718    fn test_pruned_label_powerset_similarity_strategy() {
719        // Test with similarity mapping strategy
720        let X = array![
721            [1.0, 2.0],
722            [2.0, 3.0],
723            [3.0, 1.0],
724            [1.0, 1.0],
725            [2.0, 2.0],
726            [3.0, 3.0]
727        ];
728        let y = array![
729            [1, 0],
730            [1, 0],
731            [1, 0], // Frequent: [1, 0] appears 3 times
732            [0, 1],
733            [0, 1], // Frequent: [0, 1] appears 2 times
734            [1, 1]  // Rare: [1, 1] appears 1 time
735        ];
736
737        let plp = PrunedLabelPowerset::new()
738            .min_frequency(2)
739            .strategy(PruningStrategy::SimilarityMapping);
740
741        let fitted = plp.fit(&X.view(), &y).unwrap();
742
743        // Should have only 2 frequent combinations: [1,0] and [0,1]
744        assert_eq!(fitted.n_frequent_classes(), 2);
745
746        // The rare combination [1,1] should be mapped to one of the frequent ones
747        let mapping = fitted.combination_mapping();
748        let rare_combo = vec![1, 1];
749        assert!(mapping.contains_key(&rare_combo));
750
751        // The mapped combination should be one of the frequent ones
752        let mapped = mapping.get(&rare_combo).unwrap();
753        assert!(mapped == &vec![1, 0] || mapped == &vec![0, 1]);
754
755        let predictions = fitted.predict(&X.view()).unwrap();
756        assert_eq!(predictions.dim(), (6, 2));
757
758        // All predictions should be binary
759        for pred in predictions.iter() {
760            assert!(*pred == 0 || *pred == 1);
761        }
762    }
763
764    #[test]
765    fn test_pruned_label_powerset_invalid_input() {
766        let X = array![[1.0, 2.0], [2.0, 3.0]];
767        let y = array![[0, 1], [1, 0]];
768
769        // Test with minimum frequency that results in no frequent combinations
770        let plp = PrunedLabelPowerset::new().min_frequency(5); // Too high
771        assert!(plp.fit(&X.view(), &y).is_err());
772
773        // Test with invalid default combination length
774        let plp =
775            PrunedLabelPowerset::new().strategy(PruningStrategy::DefaultMapping(vec![0, 1, 0])); // 3 elements for 2 labels
776        assert!(plp.fit(&X.view(), &y).is_err());
777
778        // Test with non-binary labels
779        let y_bad = array![[2, 1], [1, 0]]; // Contains non-binary value
780        let plp = PrunedLabelPowerset::new();
781        assert!(plp.fit(&X.view(), &y_bad).is_err());
782    }
783
784    #[test]
785    fn test_pruned_label_powerset_edge_cases() {
786        // Test with minimal data that meets frequency requirement
787        let X = array![[1.0, 2.0], [2.0, 3.0]];
788        let y = array![[1, 0], [1, 0]]; // Same combination twice
789
790        let plp = PrunedLabelPowerset::new().min_frequency(2);
791        let fitted = plp.fit(&X.view(), &y).unwrap();
792
793        // Should have at least 1 combination, possibly 2 if default is added
794        assert!(fitted.n_frequent_classes() >= 1);
795        assert!(fitted.frequent_combinations().len() >= 1);
796
797        // The frequent combinations should include [1, 0]
798        assert!(fitted.frequent_combinations().contains(&vec![1, 0]));
799
800        let predictions = fitted.predict(&X.view()).unwrap();
801        assert_eq!(predictions.dim(), (2, 2));
802
803        // All predictions should be [1, 0]
804        for sample_idx in 0..2 {
805            assert_eq!(predictions[[sample_idx, 0]], 1);
806            assert_eq!(predictions[[sample_idx, 1]], 0);
807        }
808    }
809
810    #[test]
811    fn test_metrics_hamming_loss() {
812        let y_true = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
813        let y_pred = array![[1, 0, 0], [0, 1, 1], [1, 0, 1]]; // 3 errors out of 9
814
815        let loss = metrics::hamming_loss(&y_true.view(), &y_pred.view()).unwrap();
816        assert!((loss - 3.0 / 9.0).abs() < 1e-10);
817    }
818
819    #[test]
820    fn test_metrics_subset_accuracy() {
821        let y_true = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
822        let y_pred = array![[1, 0, 1], [0, 1, 1], [1, 0, 1]]; // Only first subset matches
823
824        let accuracy = metrics::subset_accuracy(&y_true.view(), &y_pred.view()).unwrap();
825        assert!((accuracy - 1.0 / 3.0).abs() < 1e-10);
826    }
827
828    #[test]
829    fn test_metrics_jaccard_score() {
830        let y_true = array![[1, 0, 1], [0, 1, 0]];
831        let y_pred = array![[1, 0, 0], [0, 1, 1]];
832
833        let score = metrics::jaccard_score(&y_true.view(), &y_pred.view()).unwrap();
834        // Sample 1: intersection=1, union=2, jaccard=0.5
835        // Sample 2: intersection=1, union=2, jaccard=0.5
836        // Average: 0.5
837        assert!((score - 0.5).abs() < 1e-10);
838    }
839
840    #[test]
841    fn test_metrics_f1_score_micro() {
842        let y_true = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
843        let y_pred = array![[1, 0, 0], [0, 1, 1], [1, 0, 1]];
844
845        let f1 = metrics::f1_score(&y_true.view(), &y_pred.view(), "micro").unwrap();
846        // TP=4, FP=1, FN=2
847        // Precision = 4/5 = 0.8, Recall = 4/6 = 0.6667
848        // F1 = 2 * 0.8 * 0.6667 / (0.8 + 0.6667) = 0.727
849        assert!((f1 - 0.7272727272727273).abs() < 1e-10);
850    }
851
852    #[test]
853    fn test_metrics_f1_score_macro() {
854        let y_true = array![[1, 0], [0, 1], [1, 1]];
855        let y_pred = array![[1, 0], [0, 1], [1, 0]]; // Perfect for label 0, imperfect for label 1
856
857        let f1 = metrics::f1_score(&y_true.view(), &y_pred.view(), "macro").unwrap();
858        // Label 0: TP=2, FP=0, FN=0 -> F1=1.0
859        // Label 1: TP=1, FP=0, FN=1 -> Precision=1.0, Recall=0.5, F1=0.667
860        // Macro average: (1.0 + 0.667) / 2 = 0.833
861        assert!((f1 - 0.8333333333333334).abs() < 1e-10);
862    }
863
864    #[test]
865    fn test_metrics_f1_score_samples() {
866        let y_true = array![[1, 0], [0, 1], [1, 1]];
867        let y_pred = array![[1, 0], [0, 1], [1, 0]];
868
869        let f1 = metrics::f1_score(&y_true.view(), &y_pred.view(), "samples").unwrap();
870        // Sample 0: TP=1, FP=0, FN=0 -> F1=1.0
871        // Sample 1: TP=1, FP=0, FN=0 -> F1=1.0
872        // Sample 2: TP=1, FP=0, FN=1 -> Precision=1.0, Recall=0.5, F1=0.667
873        // Average: (1.0 + 1.0 + 0.667) / 3 = 0.889
874        assert!((f1 - 0.8888888888888888).abs() < 1e-10);
875    }
876
877    #[test]
878    fn test_metrics_coverage_error() {
879        let y_true = array![[1, 0, 1], [0, 1, 0]];
880        let y_scores = array![[0.9, 0.1, 0.8], [0.2, 0.9, 0.3]];
881
882        let coverage = metrics::coverage_error(&y_true.view(), &y_scores.view()).unwrap();
883        // Sample 0: sorted scores [0.9, 0.8, 0.1] -> labels [0, 2, 1]
884        //          true labels are at positions 1 and 2, so coverage = 2
885        // Sample 1: sorted scores [0.9, 0.3, 0.2] -> labels [1, 2, 0]
886        //          true label is at position 1, so coverage = 1
887        // Average: (2 + 1) / 2 = 1.5
888        assert!((coverage - 1.5).abs() < 1e-10);
889    }
890
891    #[test]
892    fn test_metrics_label_ranking_average_precision() {
893        let y_true = array![[1, 0, 1], [0, 1, 0]];
894        let y_scores = array![[0.9, 0.1, 0.8], [0.2, 0.9, 0.3]];
895
896        let lrap =
897            metrics::label_ranking_average_precision(&y_true.view(), &y_scores.view()).unwrap();
898        // Sample 0: sorted scores [0.9, 0.8, 0.1] -> labels [0, 2, 1]
899        //          true labels: 0 (pos 1), 2 (pos 2)
900        //          precision at pos 1: 1/1=1.0, precision at pos 2: 2/2=1.0
901        //          LRAP = (1.0 + 1.0) / 2 = 1.0
902        // Sample 1: sorted scores [0.9, 0.3, 0.2] -> labels [1, 2, 0]
903        //          true label: 1 (pos 1)
904        //          precision at pos 1: 1/1=1.0
905        //          LRAP = 1.0 / 1 = 1.0
906        // Average: (1.0 + 1.0) / 2 = 1.0
907        assert!((lrap - 1.0).abs() < 1e-10);
908    }
909
910    #[test]
911    fn test_metrics_invalid_shapes() {
912        let y_true = array![[1, 0], [0, 1]];
913        let y_pred = array![[1, 0, 1]]; // Wrong shape
914
915        assert!(metrics::hamming_loss(&y_true.view(), &y_pred.view()).is_err());
916        assert!(metrics::subset_accuracy(&y_true.view(), &y_pred.view()).is_err());
917        assert!(metrics::jaccard_score(&y_true.view(), &y_pred.view()).is_err());
918        assert!(metrics::f1_score(&y_true.view(), &y_pred.view(), "micro").is_err());
919    }
920
921    #[test]
922    fn test_metrics_invalid_f1_average() {
923        let y_true = array![[1, 0], [0, 1]];
924        let y_pred = array![[1, 0], [0, 1]];
925
926        assert!(metrics::f1_score(&y_true.view(), &y_pred.view(), "invalid").is_err());
927    }
928
929    #[test]
930    fn test_ensemble_of_chains() {
931        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
932        let y = array![[0, 1], [1, 0], [1, 1], [0, 0]];
933
934        let eoc = EnsembleOfChains::new().n_chains(3).random_state(42);
935        let fitted = eoc.fit_simple(&X.view(), &y).unwrap();
936
937        assert_eq!(fitted.n_chains(), 3);
938        assert_eq!(fitted.n_targets(), 2);
939
940        let predictions = fitted.predict_simple(&X.view()).unwrap();
941        assert_eq!(predictions.dim(), (4, 2));
942
943        // Check that predictions are binary (0 or 1)
944        for pred in predictions.iter() {
945            assert!(*pred == 0 || *pred == 1);
946        }
947
948        // Test probability predictions
949        let probabilities = fitted.predict_proba_simple(&X.view()).unwrap();
950        assert_eq!(probabilities.dim(), (4, 2));
951
952        // Check that probabilities are in [0, 1]
953        for prob in probabilities.iter() {
954            assert!(*prob >= 0.0 && *prob <= 1.0);
955        }
956    }
957
958    #[test]
959    fn test_ensemble_of_chains_single_chain() {
960        let X = array![[1.0, 2.0], [2.0, 3.0]];
961        let y = array![[1, 0], [0, 1]];
962
963        let eoc = EnsembleOfChains::new().n_chains(1);
964        let fitted = eoc.fit_simple(&X.view(), &y).unwrap();
965
966        assert_eq!(fitted.n_chains(), 1);
967
968        let predictions = fitted.predict_simple(&X.view()).unwrap();
969        assert_eq!(predictions.dim(), (2, 2));
970    }
971
972    #[test]
973    fn test_ensemble_of_chains_invalid_input() {
974        let X = array![[1.0, 2.0], [2.0, 3.0]];
975        let y = array![[1, 0], [0, 1], [1, 1]]; // Wrong number of rows
976
977        let eoc = EnsembleOfChains::new();
978        assert!(eoc.fit_simple(&X.view(), &y).is_err());
979    }
980
981    #[test]
982    fn test_one_vs_rest_classifier() {
983        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
984        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]]; // Multi-label binary
985
986        let ovr = OneVsRestClassifier::new();
987        let fitted = ovr.fit(&X.view(), &y).unwrap();
988
989        assert_eq!(fitted.n_labels(), 2);
990        assert_eq!(fitted.classes().len(), 2);
991
992        let predictions = fitted.predict(&X.view()).unwrap();
993        assert_eq!(predictions.dim(), (4, 2));
994
995        // Check that predictions are binary (0 or 1)
996        for pred in predictions.iter() {
997            assert!(*pred == 0 || *pred == 1);
998        }
999
1000        // Test probability predictions
1001        let probabilities = fitted.predict_proba(&X.view()).unwrap();
1002        assert_eq!(probabilities.dim(), (4, 2));
1003
1004        // Check that probabilities are in [0, 1]
1005        for prob in probabilities.iter() {
1006            assert!(*prob >= 0.0 && *prob <= 1.0);
1007        }
1008
1009        // Test decision function
1010        let scores = fitted.decision_function(&X.view()).unwrap();
1011        assert_eq!(scores.dim(), (4, 2));
1012
1013        // Scores should be finite
1014        for score in scores.iter() {
1015            assert!(score.is_finite());
1016        }
1017    }
1018
1019    #[test]
1020    fn test_one_vs_rest_classifier_invalid_input() {
1021        let X = array![[1.0, 2.0], [2.0, 3.0]];
1022        let y = array![[1, 0], [0, 1], [1, 1]]; // Wrong number of rows
1023
1024        let ovr = OneVsRestClassifier::new();
1025        assert!(ovr.fit(&X.view(), &y).is_err());
1026    }
1027
1028    #[test]
1029    fn test_metrics_one_error() {
1030        let y_true = array![[1, 0, 0], [0, 1, 0], [0, 0, 1]];
1031        let y_scores = array![[0.9, 0.1, 0.05], [0.1, 0.8, 0.1], [0.05, 0.1, 0.85]];
1032
1033        let one_err = metrics::one_error(&y_true.view(), &y_scores.view()).unwrap();
1034        // All top-ranked labels are correct, so one-error should be 0
1035        assert!((one_err - 0.0).abs() < 1e-10);
1036    }
1037
1038    #[test]
1039    fn test_metrics_one_error_with_errors() {
1040        let y_true = array![[1, 0], [0, 1]];
1041        let y_scores = array![[0.3, 0.7], [0.6, 0.4]]; // Top predictions are wrong
1042
1043        let one_err = metrics::one_error(&y_true.view(), &y_scores.view()).unwrap();
1044        // Both samples have incorrect top predictions, so one-error should be 1.0
1045        assert!((one_err - 1.0).abs() < 1e-10);
1046    }
1047
1048    #[test]
1049    fn test_metrics_ranking_loss() {
1050        let y_true = array![[1, 0], [0, 1]];
1051        let y_scores = array![[0.8, 0.2], [0.3, 0.7]]; // Correct ordering
1052
1053        let ranking_loss = metrics::ranking_loss(&y_true.view(), &y_scores.view()).unwrap();
1054        // Perfect ranking, so loss should be 0
1055        assert!((ranking_loss - 0.0).abs() < 1e-10);
1056    }
1057
1058    #[test]
1059    fn test_metrics_ranking_loss_with_errors() {
1060        let y_true = array![[1, 0], [0, 1]];
1061        let y_scores = array![[0.2, 0.8], [0.7, 0.3]]; // Incorrect ordering
1062
1063        let ranking_loss = metrics::ranking_loss(&y_true.view(), &y_scores.view()).unwrap();
1064        // All pairs are incorrectly ordered, so loss should be 1.0
1065        assert!((ranking_loss - 1.0).abs() < 1e-10);
1066    }
1067
1068    #[test]
1069    fn test_metrics_average_precision_score() {
1070        let y_true = array![[1, 0, 1], [0, 1, 0]];
1071        let y_scores = array![[0.9, 0.1, 0.8], [0.2, 0.9, 0.3]];
1072
1073        let ap_score = metrics::average_precision_score(&y_true.view(), &y_scores.view()).unwrap();
1074        // With perfect ranking for both samples, AP should be 1.0
1075        assert!((ap_score - 1.0).abs() < 1e-10);
1076    }
1077
1078    #[test]
1079    fn test_metrics_precision_recall_micro() {
1080        let y_true = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
1081        let y_pred = array![[1, 0, 0], [0, 1, 1], [1, 0, 1]];
1082
1083        let precision = metrics::precision_score_micro(&y_true.view(), &y_pred.view()).unwrap();
1084        let recall = metrics::recall_score_micro(&y_true.view(), &y_pred.view()).unwrap();
1085
1086        // TP=4, FP=1, FN=2
1087        // Precision = 4/5 = 0.8, Recall = 4/6 = 0.6667
1088        assert!((precision - 0.8).abs() < 1e-10);
1089        assert!((recall - 0.6666666666666666).abs() < 1e-10);
1090    }
1091
1092    #[test]
1093    fn test_metrics_invalid_shapes_new_metrics() {
1094        let y_true = array![[1, 0], [0, 1]];
1095        let y_pred = array![[1, 0, 1]]; // Wrong shape
1096        let y_scores = array![[0.8, 0.2, 0.1]]; // Wrong shape
1097
1098        assert!(metrics::one_error(&y_true.view(), &y_scores.view()).is_err());
1099        assert!(metrics::ranking_loss(&y_true.view(), &y_scores.view()).is_err());
1100        assert!(metrics::average_precision_score(&y_true.view(), &y_scores.view()).is_err());
1101        assert!(metrics::precision_score_micro(&y_true.view(), &y_pred.view()).is_err());
1102        assert!(metrics::recall_score_micro(&y_true.view(), &y_pred.view()).is_err());
1103    }
1104
1105    #[test]
1106    fn test_label_analysis_basic() {
1107        let y = array![
1108            [1, 0, 1], // Cardinality 2
1109            [0, 1, 0], // Cardinality 1
1110            [1, 0, 1], // Cardinality 2 (duplicate)
1111            [0, 0, 0], // Cardinality 0
1112            [1, 1, 1], // Cardinality 3
1113        ];
1114
1115        let results = label_analysis::analyze_combinations(&y.view()).unwrap();
1116
1117        assert_eq!(results.total_samples, 5);
1118        assert_eq!(results.combinations[0].combination.len(), 3); // Number of labels = 3
1119        assert_eq!(results.unique_combinations, 4); // [1,0,1], [0,1,0], [0,0,0], [1,1,1]
1120
1121        // Check that [1,0,1] is most frequent (appears twice)
1122        assert_eq!(
1123            results.most_frequent.as_ref().unwrap().combination,
1124            vec![1, 0, 1]
1125        );
1126        assert_eq!(results.most_frequent.as_ref().unwrap().frequency, 2);
1127        assert!((results.most_frequent.as_ref().unwrap().relative_frequency - 0.4).abs() < 1e-10);
1128        assert_eq!(results.most_frequent.as_ref().unwrap().cardinality, 2);
1129
1130        // Average cardinality should be (2+1+2+0+3)/5 = 1.6
1131        assert!((results.average_cardinality - 1.6).abs() < 1e-10);
1132    }
1133
1134    #[test]
1135    fn test_label_analysis_utility_functions() {
1136        let y = array![
1137            [1, 0],
1138            [1, 0],
1139            [1, 0], // Frequent: [1, 0] appears 3 times
1140            [0, 1],
1141            [0, 1], // Frequent: [0, 1] appears 2 times
1142            [1, 1]  // Rare: [1, 1] appears 1 time
1143        ];
1144
1145        let results = label_analysis::analyze_combinations(&y.view()).unwrap();
1146
1147        // Test get_rare_combinations
1148        let rare = label_analysis::get_rare_combinations(&results, 2);
1149        assert_eq!(rare.len(), 2); // [0,1] freq=2 and [1,1] freq=1 are both <= threshold
1150                                   // Find the combination with frequency 1
1151        let freq_1_combo = rare.iter().find(|combo| combo.frequency == 1).unwrap();
1152        assert_eq!(freq_1_combo.combination, vec![1, 1]);
1153
1154        // Test get_combinations_by_cardinality
1155        let cardinality_1 = label_analysis::get_combinations_by_cardinality(&results, 1);
1156        assert_eq!(cardinality_1.len(), 2); // [1, 0] and [0, 1]
1157
1158        let cardinality_2 = label_analysis::get_combinations_by_cardinality(&results, 2);
1159        assert_eq!(cardinality_2.len(), 1); // [1, 1]
1160        assert_eq!(cardinality_2[0].combination, vec![1, 1]);
1161    }
1162
1163    #[test]
1164    fn test_label_cooccurrence_matrix() {
1165        let y = array![
1166            [1, 1, 0], // Labels 0 and 1 co-occur
1167            [1, 0, 1], // Labels 0 and 2 co-occur
1168            [0, 1, 1], // Labels 1 and 2 co-occur
1169            [1, 1, 1], // All labels co-occur
1170        ];
1171
1172        let cooccurrence = label_analysis::label_cooccurrence_matrix(&y.view()).unwrap();
1173        assert_eq!(cooccurrence.dim(), (3, 3));
1174
1175        // Label 0 appears with itself in samples 0, 1, 3 = 3 times
1176        assert_eq!(cooccurrence[[0, 0]], 3);
1177        // Label 1 appears with itself in samples 0, 2, 3 = 3 times
1178        assert_eq!(cooccurrence[[1, 1]], 3);
1179        // Label 2 appears with itself in samples 1, 2, 3 = 3 times
1180        assert_eq!(cooccurrence[[2, 2]], 3);
1181
1182        // Labels 0 and 1 co-occur in samples 0, 3 = 2 times
1183        assert_eq!(cooccurrence[[0, 1]], 2);
1184        assert_eq!(cooccurrence[[1, 0]], 2);
1185
1186        // Labels 0 and 2 co-occur in samples 1, 3 = 2 times
1187        assert_eq!(cooccurrence[[0, 2]], 2);
1188        assert_eq!(cooccurrence[[2, 0]], 2);
1189
1190        // Labels 1 and 2 co-occur in samples 2, 3 = 2 times
1191        assert_eq!(cooccurrence[[1, 2]], 2);
1192        assert_eq!(cooccurrence[[2, 1]], 2);
1193    }
1194
1195    #[test]
1196    fn test_label_correlation_matrix() {
1197        let y = array![[1, 1, 0], [1, 0, 1], [0, 1, 1], [0, 0, 0],];
1198
1199        let correlation = label_analysis::label_correlation_matrix(&y.view()).unwrap();
1200        assert_eq!(correlation.dim(), (3, 3));
1201
1202        // Diagonal should be 1.0 (perfect self-correlation)
1203        assert!((correlation[[0, 0]] - 1.0).abs() < 1e-10);
1204        assert!((correlation[[1, 1]] - 1.0).abs() < 1e-10);
1205        assert!((correlation[[2, 2]] - 1.0).abs() < 1e-10);
1206
1207        // Matrix should be symmetric
1208        for i in 0..3 {
1209            for j in 0..3 {
1210                assert!((correlation[[i, j]] - correlation[[j, i]]).abs() < 1e-10);
1211            }
1212        }
1213
1214        // All correlations should be between -1 and 1
1215        for i in 0..3 {
1216            for j in 0..3 {
1217                assert!(correlation[[i, j]] >= -1.0 && correlation[[i, j]] <= 1.0);
1218            }
1219        }
1220    }
1221
1222    #[test]
1223    fn test_label_analysis_invalid_input() {
1224        // Test with non-binary labels
1225        let y_bad = array![[2, 1], [1, 0]]; // Contains non-binary value
1226        assert!(label_analysis::analyze_combinations(&y_bad.view()).is_err());
1227
1228        // Test with empty array
1229        let y_empty = Array2::<i32>::zeros((0, 2));
1230        assert!(label_analysis::analyze_combinations(&y_empty.view()).is_err());
1231
1232        let y_no_labels = Array2::<i32>::zeros((2, 0));
1233        assert!(label_analysis::analyze_combinations(&y_no_labels.view()).is_err());
1234
1235        // Test cooccurrence matrix with empty data
1236        assert!(label_analysis::label_cooccurrence_matrix(&y_empty.view()).is_err());
1237        assert!(label_analysis::label_correlation_matrix(&y_empty.view()).is_err());
1238    }
1239
1240    #[test]
1241    fn test_label_analysis_edge_cases() {
1242        // Test with single sample
1243        let y_single = array![[1, 0, 1]];
1244        let results = label_analysis::analyze_combinations(&y_single.view()).unwrap();
1245
1246        assert_eq!(results.total_samples, 1);
1247        assert_eq!(results.unique_combinations, 1);
1248        assert_eq!(
1249            results.most_frequent.as_ref().unwrap().combination,
1250            vec![1, 0, 1]
1251        );
1252        assert_eq!(
1253            results.least_frequent.as_ref().unwrap().combination,
1254            vec![1, 0, 1]
1255        );
1256        assert_eq!(results.average_cardinality, 2.0);
1257
1258        // Test with all zeros
1259        let y_zeros = array![[0, 0], [0, 0]];
1260        let results = label_analysis::analyze_combinations(&y_zeros.view()).unwrap();
1261
1262        assert_eq!(results.average_cardinality, 0.0);
1263
1264        // Test with all ones
1265        let y_ones = array![[1, 1], [1, 1]];
1266        let results = label_analysis::analyze_combinations(&y_ones.view()).unwrap();
1267
1268        assert_eq!(results.average_cardinality, 2.0);
1269    }
1270
1271    #[test]
1272    fn test_iblr_basic_functionality() {
1273        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
1274        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]]; // Multi-label classification targets
1275
1276        let iblr = IBLR::new().k_neighbors(2);
1277        let trained_iblr = iblr.fit(&X.view(), &y.view()).unwrap();
1278        let predictions = trained_iblr.predict(&X.view()).unwrap();
1279
1280        assert_eq!(predictions.dim(), (4, 2));
1281
1282        // Check that predictions are binary (0 or 1)
1283        for i in 0..4 {
1284            for j in 0..2 {
1285                assert!(predictions[[i, j]] == 0 || predictions[[i, j]] == 1);
1286            }
1287        }
1288    }
1289
1290    #[test]
1291    fn test_iblr_configuration() {
1292        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
1293        let y = array![[1, 0], [0, 1], [1, 1]]; // Multi-label classification targets
1294
1295        // Test different k values
1296        let iblr1 = IBLR::new().k_neighbors(1);
1297        let iblr2 = IBLR::new().k_neighbors(2); // Must be < n_samples (3)
1298
1299        let trained1 = iblr1.fit(&X.view(), &y.view()).unwrap();
1300        let trained2 = iblr2.fit(&X.view(), &y.view()).unwrap();
1301
1302        let pred1 = trained1.predict(&X.view()).unwrap();
1303        let pred2 = trained2.predict(&X.view()).unwrap();
1304
1305        assert_eq!(pred1.dim(), (3, 2));
1306        assert_eq!(pred2.dim(), (3, 2));
1307
1308        // Test weight functions
1309        let iblr_uniform = IBLR::new().k_neighbors(2).weights(WeightFunction::Uniform);
1310        let iblr_distance = IBLR::new().k_neighbors(2).weights(WeightFunction::Distance);
1311
1312        let trained_uniform = iblr_uniform.fit(&X.view(), &y.view()).unwrap();
1313        let trained_distance = iblr_distance.fit(&X.view(), &y.view()).unwrap();
1314
1315        let pred_uniform = trained_uniform.predict(&X.view()).unwrap();
1316        let pred_distance = trained_distance.predict(&X.view()).unwrap();
1317
1318        assert_eq!(pred_uniform.dim(), (3, 2));
1319        assert_eq!(pred_distance.dim(), (3, 2));
1320    }
1321
1322    #[test]
1323    fn test_iblr_error_handling() {
1324        let X = array![[1.0, 2.0], [2.0, 3.0]];
1325        let y = array![[1, 0], [0, 1], [1, 1]]; // Mismatched samples (3 vs 2)
1326
1327        let iblr = IBLR::new();
1328        assert!(iblr.fit(&X.view(), &y.view()).is_err());
1329
1330        // Test k_neighbors validation
1331        let y_valid = array![[1, 0], [0, 1]]; // Matching samples
1332
1333        let iblr_zero_k = IBLR::new().k_neighbors(0);
1334        assert!(iblr_zero_k.fit(&X.view(), &y_valid.view()).is_err());
1335
1336        let iblr_large_k = IBLR::new().k_neighbors(5); // More than samples
1337        assert!(iblr_large_k.fit(&X.view(), &y_valid.view()).is_err());
1338
1339        // Test prediction with wrong feature dimensions
1340        let X_train = array![[1.0, 2.0], [2.0, 3.0]];
1341        let y_train = array![[1, 0], [0, 1]];
1342        let iblr_for_predict = IBLR::new().k_neighbors(1); // Must be < n_samples (2)
1343        let trained = iblr_for_predict
1344            .fit(&X_train.view(), &y_train.view())
1345            .unwrap();
1346
1347        let X_wrong_features = array![[1.0, 2.0, 3.0]]; // Extra feature
1348        assert!(trained.predict(&X_wrong_features.view()).is_err());
1349
1350        // Test empty data
1351        let X_empty = Array2::<Float>::zeros((0, 2));
1352        let y_empty = Array2::<i32>::zeros((0, 2));
1353        assert!(IBLR::new().fit(&X_empty.view(), &y_empty.view()).is_err());
1354    }
1355
1356    #[test]
1357    fn test_iblr_weight_functions() {
1358        let X = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, 2.0]];
1359        let y = array![[1, 1], [0, 1], [1, 0], [0, 0]]; // Binary classification labels
1360
1361        // Test uniform weighting
1362        let iblr_uniform = IBLR::new().k_neighbors(3).weights(WeightFunction::Uniform);
1363        let trained_uniform = iblr_uniform.fit(&X.view(), &y.view()).unwrap();
1364        let pred_uniform = trained_uniform.predict(&X.view()).unwrap();
1365
1366        // Test distance weighting
1367        let iblr_distance = IBLR::new().k_neighbors(3).weights(WeightFunction::Distance);
1368        let trained_distance = iblr_distance.fit(&X.view(), &y.view()).unwrap();
1369        let pred_distance = trained_distance.predict(&X.view()).unwrap();
1370
1371        // Predictions should be reasonable for both
1372        assert_eq!(pred_uniform.dim(), (4, 2));
1373        assert_eq!(pred_distance.dim(), (4, 2));
1374
1375        // Check that all predictions are binary (0 or 1)
1376        for i in 0..4 {
1377            for j in 0..2 {
1378                assert!(pred_uniform[[i, j]] == 0 || pred_uniform[[i, j]] == 1);
1379                assert!(pred_distance[[i, j]] == 0 || pred_distance[[i, j]] == 1);
1380            }
1381        }
1382    }
1383
1384    #[test]
1385    fn test_iblr_single_neighbor() {
1386        let X = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
1387        let y = array![[1, 0], [0, 1], [1, 1]]; // Binary classification labels
1388
1389        let iblr = IBLR::new().k_neighbors(1);
1390        let trained = iblr.fit(&X.view(), &y.view()).unwrap();
1391
1392        // Test prediction on training data (should be exact for k=1)
1393        let predictions = trained.predict(&X.view()).unwrap();
1394
1395        for i in 0..3 {
1396            for j in 0..2 {
1397                assert_eq!(predictions[[i, j]], y[[i, j]]);
1398            }
1399        }
1400    }
1401
1402    #[test]
1403    fn test_iblr_interpolation() {
1404        let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
1405        let y = array![[0, 0], [1, 1], [0, 1]]; // Binary classification labels
1406
1407        let iblr = IBLR::new().k_neighbors(2);
1408        let trained = iblr.fit(&X.view(), &y.view()).unwrap();
1409
1410        // Test prediction at midpoint
1411        let X_test = array![[0.5, 0.5]];
1412        let prediction = trained.predict(&X_test.view()).unwrap();
1413
1414        // Should predict binary values (0 or 1)
1415        assert!(prediction[[0, 0]] == 0 || prediction[[0, 0]] == 1);
1416        assert!(prediction[[0, 1]] == 0 || prediction[[0, 1]] == 1);
1417    }
1418
1419    #[test]
1420    fn test_clare_basic_functionality() {
1421        let X = array![[1.0, 1.0], [1.5, 1.5], [5.0, 5.0], [5.5, 5.5]];
1422        let y = array![[1, 0], [1, 0], [0, 1], [0, 1]]; // Two clear clusters with different label patterns
1423
1424        let clare = CLARE::new().n_clusters(2).random_state(42);
1425        let trained_clare = clare.fit(&X.view(), &y).unwrap();
1426        let predictions = trained_clare.predict(&X.view()).unwrap();
1427
1428        assert_eq!(predictions.dim(), (4, 2));
1429
1430        // Verify cluster centers and assignments were learned
1431        assert_eq!(trained_clare.cluster_centers().dim(), (2, 2));
1432        assert_eq!(trained_clare.cluster_assignments().len(), 4);
1433    }
1434
1435    #[test]
1436    fn test_clare_configuration() {
1437        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
1438        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];
1439
1440        // Test different configurations
1441        let clare1 = CLARE::new().n_clusters(2).threshold(0.3);
1442        let clare2 = CLARE::new().n_clusters(3).max_iter(50);
1443        let clare3 = CLARE::new().random_state(123);
1444
1445        let trained1 = clare1.fit(&X.view(), &y).unwrap();
1446        let trained2 = clare2.fit(&X.view(), &y).unwrap();
1447        let trained3 = clare3.fit(&X.view(), &y).unwrap();
1448
1449        let pred1 = trained1.predict(&X.view()).unwrap();
1450        let pred2 = trained2.predict(&X.view()).unwrap();
1451        let pred3 = trained3.predict(&X.view()).unwrap();
1452
1453        assert_eq!(pred1.dim(), (4, 2));
1454        assert_eq!(pred2.dim(), (4, 2));
1455        assert_eq!(pred3.dim(), (4, 2));
1456
1457        // Test accessors
1458        assert_eq!(trained1.threshold(), 0.3);
1459        assert_eq!(trained1.cluster_centers().dim(), (2, 2));
1460        assert_eq!(trained2.cluster_centers().dim(), (3, 2));
1461    }
1462
1463    #[test]
1464    fn test_clare_error_handling() {
1465        let X = array![[1.0, 2.0], [2.0, 3.0]];
1466        let y = array![[1, 0], [0, 1], [1, 1]]; // Mismatched samples
1467
1468        let clare = CLARE::new();
1469        assert!(clare.fit(&X.view(), &y).is_err());
1470
1471        // Test n_clusters validation
1472        let y_valid = array![[1, 0], [0, 1]];
1473
1474        let clare_zero_clusters = CLARE::new().n_clusters(0);
1475        assert!(clare_zero_clusters.fit(&X.view(), &y_valid).is_err());
1476
1477        let clare_too_many_clusters = CLARE::new().n_clusters(5); // More than samples
1478        assert!(clare_too_many_clusters.fit(&X.view(), &y_valid).is_err());
1479
1480        // Test non-binary labels
1481        let y_non_binary = array![[1, 2], [0, 1]]; // Contains 2
1482        assert!(CLARE::new().fit(&X.view(), &y_non_binary).is_err());
1483
1484        // Test prediction with wrong feature dimensions
1485        let X_train = array![[1.0, 2.0], [2.0, 3.0]];
1486        let y_train = array![[1, 0], [0, 1]];
1487        let clare_for_predict = CLARE::new().n_clusters(2);
1488        let trained = clare_for_predict.fit(&X_train.view(), &y_train).unwrap();
1489
1490        let X_wrong_features = array![[1.0, 2.0, 3.0]]; // Extra feature
1491        assert!(trained.predict(&X_wrong_features.view()).is_err());
1492
1493        // Test empty data
1494        let X_empty = Array2::<Float>::zeros((0, 2));
1495        let y_empty = Array2::<i32>::zeros((0, 2));
1496        assert!(CLARE::new().fit(&X_empty.view(), &y_empty).is_err());
1497    }
1498
1499    #[test]
1500    fn test_clare_threshold_prediction() {
1501        let X = array![[1.0, 1.0], [1.2, 1.2], [5.0, 5.0], [5.2, 5.2]];
1502        let y = array![[1, 0], [1, 0], [0, 1], [0, 1]];
1503
1504        let clare = CLARE::new().n_clusters(2).threshold(0.3).random_state(42);
1505        let trained_clare = clare.fit(&X.view(), &y).unwrap();
1506
1507        // Test predictions are binary
1508        let predictions = trained_clare.predict(&X.view()).unwrap();
1509        assert_eq!(predictions.dim(), (4, 2));
1510
1511        // All predictions should be 0 or 1
1512        for pred in predictions.iter() {
1513            assert!(*pred == 0 || *pred == 1);
1514        }
1515
1516        // Verify threshold was set correctly
1517        assert_eq!(trained_clare.threshold(), 0.3);
1518    }
1519
1520    #[test]
1521    fn test_clare_clustering_consistency() {
1522        let X = array![
1523            [0.0, 0.0],
1524            [0.1, 0.1], // First cluster
1525            [5.0, 5.0],
1526            [5.1, 5.1] // Second cluster
1527        ];
1528        let y = array![
1529            [1, 0],
1530            [1, 0], // First cluster: always label 0 active
1531            [0, 1],
1532            [0, 1] // Second cluster: always label 1 active
1533        ];
1534
1535        let clare = CLARE::new().n_clusters(2).threshold(0.5).random_state(42);
1536        let trained_clare = clare.fit(&X.view(), &y).unwrap();
1537        let predictions = trained_clare.predict(&X.view()).unwrap();
1538
1539        // With clear clustering, predictions should match patterns
1540        assert_eq!(predictions.dim(), (4, 2));
1541        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
1542
1543        // Test threshold accessor
1544        assert!((trained_clare.threshold() - 0.5).abs() < 1e-10);
1545    }
1546
1547    #[test]
1548    fn test_clare_single_cluster() {
1549        let X = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
1550        let y = array![[1, 0], [0, 1], [1, 1]];
1551
1552        // Use only 1 cluster
1553        let clare = CLARE::new().n_clusters(1);
1554        let trained_clare = clare.fit(&X.view(), &y).unwrap();
1555        let predictions = trained_clare.predict(&X.view()).unwrap();
1556
1557        assert_eq!(predictions.dim(), (3, 2));
1558        assert_eq!(trained_clare.cluster_centers().dim(), (1, 2));
1559
1560        // With 1 cluster, all samples should get same prediction
1561        // (based on average label frequency)
1562        for i in 1..3 {
1563            for j in 0..2 {
1564                assert_eq!(predictions[[0, j]], predictions[[i, j]]);
1565            }
1566        }
1567    }
1568
1569    #[test]
1570    fn test_clare_reproducibility() {
1571        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
1572        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];
1573
1574        // Train two models with same random state
1575        let clare1 = CLARE::new().n_clusters(2).random_state(42);
1576        let trained1 = clare1.fit(&X.view(), &y).unwrap();
1577
1578        let clare2 = CLARE::new().n_clusters(2).random_state(42);
1579        let trained2 = clare2.fit(&X.view(), &y).unwrap();
1580
1581        // Should produce same cluster centers
1582        let centers1 = trained1.cluster_centers();
1583        let centers2 = trained2.cluster_centers();
1584
1585        for i in 0..centers1.nrows() {
1586            for j in 0..centers1.ncols() {
1587                assert!((centers1[[i, j]] - centers2[[i, j]]).abs() < 1e-10);
1588            }
1589        }
1590
1591        // Should produce same predictions
1592        let pred1 = trained1.predict(&X.view()).unwrap();
1593        let pred2 = trained2.predict(&X.view()).unwrap();
1594
1595        for i in 0..pred1.nrows() {
1596            for j in 0..pred1.ncols() {
1597                assert_eq!(pred1[[i, j]], pred2[[i, j]]);
1598            }
1599        }
1600    }
1601
1602    #[test]
1603    fn test_mltsvm_basic_functionality() {
1604        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
1605        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]]; // Multi-label binary
1606
1607        let mltsvm = MLTSVM::new().c1(1.0).c2(1.0);
1608        let trained_mltsvm = mltsvm.fit(&X.view(), &y).unwrap();
1609        let predictions = trained_mltsvm.predict(&X.view()).unwrap();
1610
1611        assert_eq!(predictions.dim(), (4, 2));
1612        assert_eq!(trained_mltsvm.n_labels(), 2);
1613
1614        // All predictions should be binary (0 or 1)
1615        for &pred in predictions.iter() {
1616            assert!(pred == 0 || pred == 1);
1617        }
1618    }
1619
1620    #[test]
1621    fn test_mltsvm_configuration() {
1622        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
1623        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];
1624
1625        // Test different configurations
1626        let mltsvm1 = MLTSVM::new().c1(0.5).c2(1.5);
1627        let mltsvm2 = MLTSVM::new().epsilon(1e-8).max_iter(500);
1628
1629        let trained1 = mltsvm1.fit(&X.view(), &y).unwrap();
1630        let trained2 = mltsvm2.fit(&X.view(), &y).unwrap();
1631
1632        let pred1 = trained1.predict(&X.view()).unwrap();
1633        let pred2 = trained2.predict(&X.view()).unwrap();
1634
1635        assert_eq!(pred1.dim(), (4, 2));
1636        assert_eq!(pred2.dim(), (4, 2));
1637    }
1638
1639    #[test]
1640    fn test_mltsvm_error_handling() {
1641        let X = array![[1.0, 2.0], [2.0, 3.0]];
1642        let y = array![[1, 0], [0, 1], [1, 1]]; // Mismatched samples
1643
1644        let mltsvm = MLTSVM::new();
1645        assert!(mltsvm.fit(&X.view(), &y).is_err());
1646
1647        // Test non-binary labels
1648        let y_non_binary = array![[1, 2], [0, 1]]; // Contains 2
1649        assert!(MLTSVM::new().fit(&X.view(), &y_non_binary).is_err());
1650
1651        // Test prediction with wrong feature dimensions
1652        let X_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 2.0]];
1653        let y_train = array![[1, 0], [0, 1], [1, 1], [0, 0]];
1654        let mltsvm_for_predict = MLTSVM::new();
1655        let trained = mltsvm_for_predict.fit(&X_train.view(), &y_train).unwrap();
1656
1657        let X_wrong_features = array![[1.0, 2.0, 3.0]]; // Extra feature
1658        assert!(trained.predict(&X_wrong_features.view()).is_err());
1659
1660        // Test empty data
1661        let X_empty = Array2::<Float>::zeros((0, 2));
1662        let y_empty = Array2::<i32>::zeros((0, 2));
1663        assert!(MLTSVM::new().fit(&X_empty.view(), &y_empty).is_err());
1664    }
1665
1666    #[test]
1667    fn test_mltsvm_decision_function() {
1668        let X = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
1669        let y = array![[1, 0], [1, 0], [0, 1], [0, 1]];
1670
1671        let mltsvm = MLTSVM::new();
1672        let trained_mltsvm = mltsvm.fit(&X.view(), &y).unwrap();
1673
1674        // Test decision function
1675        let decision_values = trained_mltsvm.decision_function(&X.view()).unwrap();
1676        assert_eq!(decision_values.dim(), (4, 2));
1677
1678        // Decision values should be real numbers (no constraints on range)
1679        // Just check that we get reasonable outputs
1680        for &val in decision_values.iter() {
1681            assert!(val.is_finite());
1682        }
1683    }
1684
1685    #[test]
1686    fn test_mltsvm_separable_data() {
1687        let X = array![
1688            [0.0, 0.0],
1689            [0.5, 0.5], // Negative class cluster
1690            [3.0, 3.0],
1691            [3.5, 3.5] // Positive class cluster
1692        ];
1693        let y = array![
1694            [0, 1],
1695            [0, 1], // First label: negative, Second label: positive
1696            [1, 0],
1697            [1, 0] // First label: positive, Second label: negative
1698        ];
1699
1700        let mltsvm = MLTSVM::new().c1(1.0).c2(1.0);
1701        let trained_mltsvm = mltsvm.fit(&X.view(), &y).unwrap();
1702        let predictions = trained_mltsvm.predict(&X.view()).unwrap();
1703
1704        // With linearly separable data, MLTSVM should perform well
1705        let mut correct_predictions = 0;
1706        let total_predictions = predictions.len();
1707
1708        for i in 0..predictions.nrows() {
1709            for j in 0..predictions.ncols() {
1710                if predictions[[i, j]] == y[[i, j]] {
1711                    correct_predictions += 1;
1712                }
1713            }
1714        }
1715
1716        let accuracy = correct_predictions as Float / total_predictions as Float;
1717        // Should get reasonably good accuracy on separable data
1718        assert!(accuracy >= 0.5); // At least better than random
1719    }
1720
1721    #[test]
1722    fn test_mltsvm_feature_scaling() {
1723        // Test with features of very different scales
1724        let X = array![
1725            [1000.0, 0.001],
1726            [2000.0, 0.002],
1727            [3000.0, 0.003],
1728            [4000.0, 0.004]
1729        ];
1730        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];
1731
1732        let mltsvm = MLTSVM::new();
1733        let trained_mltsvm = mltsvm.fit(&X.view(), &y).unwrap();
1734        let predictions = trained_mltsvm.predict(&X.view()).unwrap();
1735
1736        // Should handle feature scaling internally
1737        assert_eq!(predictions.dim(), (4, 2));
1738
1739        // All predictions should be binary
1740        for &pred in predictions.iter() {
1741            assert!(pred == 0 || pred == 1);
1742        }
1743    }
1744
1745    #[test]
1746    fn test_mltsvm_consistency() {
1747        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
1748        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];
1749
1750        // Train the same model multiple times (deterministic should give same results)
1751        let mltsvm1 = MLTSVM::new().c1(1.0).c2(1.0);
1752        let trained1 = mltsvm1.fit(&X.view(), &y).unwrap();
1753
1754        let mltsvm2 = MLTSVM::new().c1(1.0).c2(1.0);
1755        let trained2 = mltsvm2.fit(&X.view(), &y).unwrap();
1756
1757        let pred1 = trained1.predict(&X.view()).unwrap();
1758        let pred2 = trained2.predict(&X.view()).unwrap();
1759
1760        // Should be deterministic (same predictions)
1761        for i in 0..pred1.nrows() {
1762            for j in 0..pred1.ncols() {
1763                assert_eq!(pred1[[i, j]], pred2[[i, j]]);
1764            }
1765        }
1766    }
1767
1768    #[test]
1769    fn test_ranksvm_basic_functionality() {
1770        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
1771        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]]; // Multi-label binary
1772
1773        let ranksvm = RankSVM::new().c(1.0);
1774        let trained_ranksvm = ranksvm.fit(&X.view(), &y).unwrap();
1775        let predictions = trained_ranksvm.predict(&X.view()).unwrap();
1776
1777        assert_eq!(predictions.dim(), (4, 2));
1778        assert_eq!(trained_ranksvm.n_labels(), 2);
1779
1780        // All predictions should be binary (0 or 1)
1781        for &pred in predictions.iter() {
1782            assert!(pred == 0 || pred == 1);
1783        }
1784    }
1785
1786    #[test]
1787    fn test_ranksvm_threshold_strategies() {
1788        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
1789        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];
1790
1791        // Test different threshold strategies
1792        let ranksvm1 = RankSVM::new().threshold_strategy(SVMThresholdStrategy::Fixed(0.5));
1793        let ranksvm2 = RankSVM::new().threshold_strategy(SVMThresholdStrategy::OptimizeF1);
1794        let ranksvm3 = RankSVM::new().threshold_strategy(SVMThresholdStrategy::TopK(2));
1795        let ranksvm4 = RankSVM::new().threshold_strategy(SVMThresholdStrategy::OptimizeF1);
1796
1797        let trained1 = ranksvm1.fit(&X.view(), &y).unwrap();
1798        let trained2 = ranksvm2.fit(&X.view(), &y).unwrap();
1799        let trained3 = ranksvm3.fit(&X.view(), &y).unwrap();
1800        let trained4 = ranksvm4.fit(&X.view(), &y).unwrap();
1801
1802        let pred1 = trained1.predict(&X.view()).unwrap();
1803        let pred2 = trained2.predict(&X.view()).unwrap();
1804        let pred3 = trained3.predict(&X.view()).unwrap();
1805        let pred4 = trained4.predict(&X.view()).unwrap();
1806
1807        assert_eq!(pred1.dim(), (4, 2));
1808        assert_eq!(pred2.dim(), (4, 2));
1809        assert_eq!(pred3.dim(), (4, 2));
1810        assert_eq!(pred4.dim(), (4, 2));
1811
1812        // Test threshold accessors
1813        assert_eq!(trained1.thresholds().len(), 2);
1814        assert_eq!(trained2.thresholds().len(), 2);
1815        assert_eq!(trained3.thresholds().len(), 2);
1816    }
1817
1818    #[test]
1819    fn test_ranksvm_decision_function() {
1820        let X = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
1821        let y = array![[1, 0], [1, 0], [0, 1], [0, 1]];
1822
1823        let ranksvm = RankSVM::new();
1824        let trained_ranksvm = ranksvm.fit(&X.view(), &y).unwrap();
1825
1826        // Test decision function
1827        let scores = trained_ranksvm.decision_function(&X.view()).unwrap();
1828        assert_eq!(scores.dim(), (4, 2));
1829
1830        // Scores should be real numbers
1831        for &score in scores.iter() {
1832            assert!(score.is_finite());
1833        }
1834    }
1835
1836    #[test]
1837    fn test_ranksvm_predict_ranking() {
1838        let X = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
1839        let y = array![[1, 0, 0], [0, 1, 0], [0, 0, 1]]; // Three labels, one active per sample
1840
1841        let ranksvm = RankSVM::new();
1842        let trained_ranksvm = ranksvm.fit(&X.view(), &y).unwrap();
1843
1844        // Test ranking prediction
1845        let rankings = trained_ranksvm.predict_ranking(&X.view()).unwrap();
1846
1847        assert_eq!(rankings.dim(), (3, 3)); // 3 samples, 3 labels
1848        for sample_idx in 0..3 {
1849            let mut ranking_vec = Vec::new();
1850            for label_idx in 0..3 {
1851                ranking_vec.push(rankings[[sample_idx, label_idx]]);
1852            }
1853            // Should contain all label indices
1854            let mut sorted_ranking = ranking_vec.clone();
1855            sorted_ranking.sort();
1856            assert_eq!(sorted_ranking, vec![0, 1, 2]);
1857        }
1858    }
1859
1860    #[test]
1861    fn test_ranksvm_error_handling() {
1862        let X = array![[1.0, 2.0], [2.0, 3.0]];
1863        let y = array![[1, 0], [0, 1], [1, 1]]; // Mismatched samples
1864
1865        let ranksvm = RankSVM::new();
1866        assert!(ranksvm.fit(&X.view(), &y).is_err());
1867
1868        // Test non-binary labels
1869        let y_non_binary = array![[1, 2], [0, 1]]; // Contains 2
1870        assert!(RankSVM::new().fit(&X.view(), &y_non_binary).is_err());
1871
1872        // Test prediction with wrong feature dimensions
1873        let X_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 2.0]];
1874        let y_train = array![[1, 0], [0, 1], [1, 1], [0, 0]];
1875        let ranksvm_for_predict = RankSVM::new();
1876        let trained = ranksvm_for_predict.fit(&X_train.view(), &y_train).unwrap();
1877
1878        let X_wrong_features = array![[1.0, 2.0, 3.0]]; // Extra feature
1879        assert!(trained.predict(&X_wrong_features.view()).is_err());
1880        assert!(trained.decision_function(&X_wrong_features.view()).is_err());
1881        assert!(trained.predict_ranking(&X_wrong_features.view()).is_err());
1882
1883        // Test empty data
1884        let X_empty = Array2::<Float>::zeros((0, 2));
1885        let y_empty = Array2::<i32>::zeros((0, 2));
1886        assert!(RankSVM::new().fit(&X_empty.view(), &y_empty).is_err());
1887    }
1888
1889    #[test]
1890    fn test_ranksvm_configuration() {
1891        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
1892        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];
1893
1894        // Test different configurations
1895        let ranksvm1 = RankSVM::new().c(0.5).epsilon(1e-8);
1896        let ranksvm2 = RankSVM::new().max_iter(500);
1897
1898        let trained1 = ranksvm1.fit(&X.view(), &y).unwrap();
1899        let trained2 = ranksvm2.fit(&X.view(), &y).unwrap();
1900
1901        let pred1 = trained1.predict(&X.view()).unwrap();
1902        let pred2 = trained2.predict(&X.view()).unwrap();
1903
1904        assert_eq!(pred1.dim(), (4, 2));
1905        assert_eq!(pred2.dim(), (4, 2));
1906    }
1907
1908    #[test]
1909    fn test_ranksvm_ranking_consistency() {
1910        let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
1911        let y = array![
1912            [0, 0, 1], // Last label should rank highest
1913            [0, 1, 0], // Middle label should rank highest
1914            [1, 0, 0]  // First label should rank highest
1915        ];
1916
1917        let ranksvm = RankSVM::new().threshold_strategy(SVMThresholdStrategy::OptimizeF1);
1918        let trained_ranksvm = ranksvm.fit(&X.view(), &y).unwrap();
1919
1920        let rankings = trained_ranksvm.predict_ranking(&X.view()).unwrap();
1921        let scores = trained_ranksvm.decision_function(&X.view()).unwrap();
1922
1923        // Check that rankings are consistent with scores
1924        for i in 0..3 {
1925            // First ranked label should have highest score
1926            let top_label = rankings[[i, 0]];
1927            for j in 1..3 {
1928                let other_label = rankings[[i, j]];
1929                assert!(scores[[i, top_label]] >= scores[[i, other_label]]);
1930            }
1931        }
1932    }
1933
1934    #[test]
1935    fn test_ranksvm_single_class_handling() {
1936        let X = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
1937
1938        // Test with all positive for one label, mixed for other
1939        let y = array![[1, 0], [1, 1], [1, 0]]; // First label: all positive, second label: mixed
1940
1941        let ranksvm = RankSVM::new();
1942        let trained = ranksvm.fit(&X.view(), &y).unwrap();
1943        let predictions = trained.predict(&X.view()).unwrap();
1944        let scores = trained.decision_function(&X.view()).unwrap();
1945
1946        assert_eq!(predictions.dim(), (3, 2));
1947        assert_eq!(scores.dim(), (3, 2));
1948
1949        // All scores should be finite
1950        for &score in scores.iter() {
1951            assert!(score.is_finite());
1952        }
1953    }
1954
1955    #[test]
1956    fn test_ranksvm_reproducibility() {
1957        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
1958        let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];
1959
1960        // Train two models with same configuration
1961        let ranksvm1 = RankSVM::new().c(1.0).epsilon(1e-6);
1962        let trained1 = ranksvm1.fit(&X.view(), &y).unwrap();
1963
1964        let ranksvm2 = RankSVM::new().c(1.0).epsilon(1e-6);
1965        let trained2 = ranksvm2.fit(&X.view(), &y).unwrap();
1966
1967        let pred1 = trained1.predict(&X.view()).unwrap();
1968        let pred2 = trained2.predict(&X.view()).unwrap();
1969        let scores1 = trained1.decision_function(&X.view()).unwrap();
1970        let scores2 = trained2.decision_function(&X.view()).unwrap();
1971
1972        // Should be deterministic (same predictions and scores)
1973        for i in 0..pred1.nrows() {
1974            for j in 0..pred1.ncols() {
1975                assert_eq!(pred1[[i, j]], pred2[[i, j]]);
1976                assert!((scores1[[i, j]] - scores2[[i, j]]).abs() < 1e-10);
1977            }
1978        }
1979    }
1980}