1#![allow(dead_code)]
2#![allow(non_snake_case)]
3#![allow(missing_docs)]
4#![allow(deprecated)]
5#![allow(clippy::all)]
6#![allow(clippy::pedantic)]
7#![allow(clippy::nursery)]
8#![allow(unused_imports)]
9#![allow(unused_variables)]
10#![allow(unused_mut)]
11#![allow(unused_assignments)]
12#![allow(unused_doc_comments)]
13#![allow(unused_parens)]
14#![allow(unused_comparisons)]
15pub mod activation;
23pub mod adversarial;
24pub mod chains;
25pub mod classification;
26pub mod core;
27pub mod correlation;
28pub mod ensemble;
29pub mod hierarchical;
30pub mod label_analysis;
31pub mod loss;
32pub mod metrics;
33pub mod mlp;
34pub mod multi_label;
35pub mod multitask;
36pub mod neighbors;
37pub mod neural;
38pub mod optimization;
39pub mod performance;
40pub mod probabilistic;
41pub mod ranking;
42pub mod recurrent;
43pub mod regularization;
44pub mod sequence;
45pub mod sparse_storage;
46pub mod streaming;
47pub mod svm;
48pub mod transfer_learning;
49pub mod tree;
50pub mod utilities;
51pub mod utils;
52
53pub use core::{
57 MultiOutputClassifier, MultiOutputClassifierTrained, MultiOutputRegressor,
58 MultiOutputRegressorTrained,
59};
60
61pub use chains::{
63 BayesianClassifierChain, BayesianClassifierChainTrained, ChainMethod, ClassifierChain,
64 ClassifierChainTrained, EnsembleOfChains, EnsembleOfChainsTrained, RegressorChain,
65 RegressorChainTrained,
66};
67
68pub use ensemble::{GradientBoostingMultiOutput, GradientBoostingMultiOutputTrained, WeakLearner};
70
71pub use neural::{
73 ActivationFunction, AdversarialMultiTaskNetwork, AdversarialMultiTaskNetworkTrained,
74 AdversarialStrategy, CellType, GradientReversalConfig, LambdaSchedule, LossFunction,
75 MultiOutputMLP, MultiOutputMLPClassifier, MultiOutputMLPRegressor, MultiOutputMLPTrained,
76 MultiTaskNeuralNetwork, MultiTaskNeuralNetworkTrained, RecurrentNeuralNetwork,
77 RecurrentNeuralNetworkTrained, SequenceMode, TaskBalancing, TaskDiscriminator,
78};
79
80pub use adversarial::AdversarialConfig;
82
83pub use regularization::{
85 GroupLasso, GroupLassoTrained, MetaLearningMultiTask, MetaLearningMultiTaskTrained,
86 MultiTaskElasticNet, MultiTaskElasticNetTrained, NuclearNormRegression,
87 NuclearNormRegressionTrained, RegularizationStrategy, TaskClusteringRegressionTrained,
88 TaskClusteringRegularization, TaskRelationshipLearning, TaskRelationshipLearningTrained,
89 TaskSimilarityMethod,
90};
91
92pub use correlation::{
94 CITestMethod, CITestResult, CITestResults, ConditionalIndependenceTester, CorrelationAnalysis,
95 CorrelationType, DependencyGraph, DependencyGraphBuilder, DependencyMethod, GraphStatistics,
96 OutputCorrelationAnalyzer,
97};
98
99pub use transfer_learning::{
101 ContinualLearning, ContinualLearningTrained, CrossTaskTransferLearning,
102 CrossTaskTransferLearningTrained, DomainAdaptation, DomainAdaptationTrained,
103 KnowledgeDistillation, KnowledgeDistillationTrained, ProgressiveTransferLearning,
104 ProgressiveTransferLearningTrained,
105};
106
107pub use optimization::{
109 JointLossConfig, JointLossOptimizer, JointLossOptimizerTrained, LossCombination,
110 LossFunction as OptimizationLossFunction, MultiObjectiveConfig, MultiObjectiveOptimizer,
111 MultiObjectiveOptimizerTrained, NSGA2Algorithm, NSGA2Config, NSGA2Optimizer,
112 NSGA2OptimizerTrained, ParetoSolution, ScalarizationConfig, ScalarizationMethod,
113 ScalarizationOptimizer, ScalarizationOptimizerTrained,
114};
115
116pub use probabilistic::{
118 BayesianMultiOutputConfig, BayesianMultiOutputModel, BayesianMultiOutputModelTrained,
119 EnsembleBayesianConfig, EnsembleBayesianModel, EnsembleBayesianModelTrained, EnsembleStrategy,
120 GaussianProcessMultiOutput, GaussianProcessMultiOutputTrained, InferenceMethod, KernelFunction,
121 PosteriorDistribution, PredictionWithUncertainty, PriorDistribution,
122};
123
124pub use ranking::{
126 BinaryClassifierModel, IndependentLabelPrediction, IndependentLabelPredictionTrained,
127 ThresholdStrategy as RankingThresholdStrategy,
128};
129
130pub use sparse_storage::{
132 sparse_utils, CSRMatrix, MemoryUsage, SparseMultiOutput, SparseMultiOutputTrained,
133 SparsityAnalysis, StorageRecommendation,
134};
135
136pub use streaming::{
138 IncrementalMultiOutputRegression, IncrementalMultiOutputRegressionConfig,
139 IncrementalMultiOutputRegressionTrained, StreamingMultiOutput, StreamingMultiOutputConfig,
140 StreamingMultiOutputTrained,
141};
142
143pub use performance::{
145 EarlyStopping, EarlyStoppingConfig, PredictionCache, WarmStartRegressor,
146 WarmStartRegressorConfig, WarmStartRegressorTrained,
147};
148
149pub use multi_label::{
151 BinaryRelevance, BinaryRelevanceTrained, LabelPowerset, LabelPowersetTrained,
152 OneVsRestClassifier, OneVsRestClassifierTrained, PrunedLabelPowerset,
153 PrunedLabelPowersetTrained, PruningStrategy,
154};
155
156pub use tree::{
158 ClassificationCriterion, DAGInferenceMethod, MultiTargetDecisionTreeClassifier,
159 MultiTargetDecisionTreeClassifierTrained, MultiTargetRegressionTree,
160 MultiTargetRegressionTreeTrained, RandomForestMultiOutput, RandomForestMultiOutputTrained,
161 TreeStructuredPredictor, TreeStructuredPredictorTrained,
162};
163
164pub use neighbors::{IBLRTrained, WeightFunction, IBLR};
166
167pub use svm::{
169 MLTSVMTrained, MultiOutputSVM, MultiOutputSVMTrained, RankSVM, RankSVMTrained, RankingSVMModel,
170 SVMKernel, SVMModel, ThresholdStrategy as SVMThresholdStrategy, TwinSVMModel, MLTSVM,
171};
172
173pub use sequence::{
175 FeatureFunction, FeatureType, HiddenMarkovModel, HiddenMarkovModelTrained,
176 MaximumEntropyMarkovModel, MaximumEntropyMarkovModelTrained, StructuredPerceptron,
177 StructuredPerceptronTrained,
178};
179
180pub use hierarchical::{
182 AggregationFunction, ConsistencyEnforcement, CostSensitiveHierarchicalClassifier,
183 CostSensitiveHierarchicalClassifierTrained, CostStrategy, GraphNeuralNetwork,
184 GraphNeuralNetworkTrained, MessagePassingVariant, OntologyAwareClassifier,
185 OntologyAwareClassifierTrained,
186};
187
188pub use classification::{
190 CalibratedBinaryRelevance, CalibratedBinaryRelevanceTrained, CalibrationMethod, CostMatrix,
191 CostSensitiveBinaryRelevance, CostSensitiveBinaryRelevanceTrained, DistanceMetric, MLkNN,
192 MLkNNTrained, RandomLabelCombinations, SimpleBinaryModel,
193};
194
195pub use metrics::{
197 average_precision_score,
198 confidence_interval,
199 coverage_error,
200 f1_score,
201 hamming_loss,
203 jaccard_score,
204 label_ranking_average_precision,
205 mcnemar_test,
207 one_error,
208 paired_t_test,
209 per_label_metrics,
211 precision_score_micro,
212 ranking_loss,
213 recall_score_micro,
214
215 subset_accuracy,
216 wilcoxon_signed_rank_test,
217 ConfidenceInterval,
218 PerLabelMetrics,
219
220 StatisticalTestResult,
221};
222
223#[allow(non_snake_case)]
224#[cfg(test)]
225mod tests {
226 use super::*;
227 use crate::utilities::CLARE;
229 use scirs2_core::ndarray::{array, Array2};
230 use sklears_core::traits::{Fit, Predict};
231 use sklears_core::types::Float;
232
/// Fits a `MultiOutputClassifier` on four samples with two targets and checks
/// the target count, the length of `classes()`, the prediction shape, and that
/// each predicted value is contained in `classes()[target]`.
#[test]
fn test_multi_output_classifier() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
    let targets = array![[0, 1], [1, 0], [1, 1], [0, 0]];

    let model = MultiOutputClassifier::new()
        .fit(&features.view(), &targets)
        .unwrap();

    assert_eq!(model.n_targets(), 2);
    assert_eq!(model.classes().len(), 2);

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (4, 2));

    // Walk column by column so each prediction is compared against the
    // class list stored for that same target.
    for col in 0..2 {
        let known = &model.classes()[col];
        for row in 0..4 {
            assert!(known.contains(&predicted[[row, col]]));
        }
    }
}
256
/// Fits a `MultiOutputRegressor` on four samples with two real-valued targets
/// and checks the target count, the prediction shape, and that every predicted
/// value is finite.
#[test]
fn test_multi_output_regressor() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
    let targets = array![[1.5, 2.5], [2.5, 3.5], [2.0, 1.5], [1.0, 1.5]];

    let model = MultiOutputRegressor::new()
        .fit(&features.view(), &targets)
        .unwrap();

    assert_eq!(model.n_targets(), 2);

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (4, 2));

    assert!(predicted.iter().all(|value| value.is_finite()));
}
275
/// `MultiOutputClassifier::fit` must return an error when the feature matrix
/// has two rows but the target matrix has three.
#[test]
fn test_invalid_input() {
    let two_samples = array![[1.0, 2.0], [2.0, 3.0]];
    let three_target_rows = array![[0, 1], [1, 0], [0, 1]];

    let outcome = MultiOutputClassifier::new().fit(&two_samples.view(), &three_target_rows);
    assert!(outcome.is_err());
}
284
/// `MultiOutputClassifier::fit` must return an error when the target matrix
/// has zero columns.
#[test]
fn test_empty_targets() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];
    let no_target_columns = Array2::<i32>::zeros((2, 0));

    let outcome = MultiOutputClassifier::new().fit(&features.view(), &no_target_columns);
    assert!(outcome.is_err());
}
293
/// A classifier fitted on two-column features must return an error when asked
/// to predict on a three-column input.
#[test]
fn test_prediction_shape_mismatch() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];
    let targets = array![[0, 1], [1, 0]];

    let model = MultiOutputClassifier::new()
        .fit(&features.view(), &targets)
        .unwrap();

    let three_columns = array![[1.0, 2.0, 3.0]];
    let outcome = model.predict(&three_columns.view());
    assert!(outcome.is_err());
}
306
/// Fits a `ClassifierChain` without an explicit order and checks the target
/// count, that `chain_order()` is `[0, 1]`, the prediction shape, and that all
/// predictions are 0 or 1.
#[test]
fn test_classifier_chain() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
    let targets = array![[0, 1], [1, 0], [1, 1], [0, 0]];

    let chain = ClassifierChain::new()
        .fit_simple(&features.view(), &targets)
        .unwrap();

    assert_eq!(chain.n_targets(), 2);
    assert_eq!(chain.chain_order(), &[0, 1]);

    let predicted = chain.predict_simple(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (4, 2));

    // The shape was just asserted to be (4, 2), so iterating all elements
    // visits exactly the cells [0..4, 0..2].
    assert!(predicted.iter().all(|&label| matches!(label, 0 | 1)));
}
330
/// A `ClassifierChain` built with `.order(vec![1, 0])` must report that order
/// after fitting and produce a (3, 2) prediction matrix.
#[test]
fn test_classifier_chain_custom_order() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
    let targets = array![[0, 1], [1, 0], [1, 1]];

    let reversed = ClassifierChain::new().order(vec![1, 0]);
    let chain = reversed.fit_simple(&features.view(), &targets).unwrap();

    assert_eq!(chain.chain_order(), &[1, 0]);

    let predicted = chain.predict_simple(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (3, 2));
}
344
/// Fitting a `ClassifierChain` whose order lists three indices against a
/// two-column target matrix must return an error.
#[test]
fn test_classifier_chain_invalid_order() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];
    let targets = array![[0, 1], [1, 0]];

    let outcome = ClassifierChain::new()
        .order(vec![0, 1, 2])
        .fit_simple(&features.view(), &targets);
    assert!(outcome.is_err());
}
353
/// Exercises the Monte Carlo prediction methods of a fitted `ClassifierChain`:
/// probabilities have shape (4, 2) and lie in [0, 1], labels have shape (4, 2)
/// and are 0 or 1, and two probability runs with the same seed argument
/// (`Some(42)`) agree to within 1e-10.
#[test]
fn test_classifier_chain_monte_carlo() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
    let targets = array![[0, 1], [1, 0], [1, 1], [0, 0]];

    let chain = ClassifierChain::new()
        .fit_simple(&features.view(), &targets)
        .unwrap();

    let first_probs = chain
        .predict_monte_carlo(&features.view(), 100, Some(42))
        .unwrap();
    assert_eq!(first_probs.dim(), (4, 2));
    assert!(first_probs.iter().all(|p| (0.0..=1.0).contains(p)));

    let sampled_labels = chain
        .predict_monte_carlo_labels(&features.view(), 100, Some(42))
        .unwrap();
    assert_eq!(sampled_labels.dim(), (4, 2));
    assert!(sampled_labels.iter().all(|&label| matches!(label, 0 | 1)));

    // Repeat the probability run with the identical seed argument and compare
    // element by element.
    let repeat_probs = chain
        .predict_monte_carlo(&features.view(), 100, Some(42))
        .unwrap();
    for (idx, (&a, &b)) in first_probs.iter().zip(repeat_probs.iter()).enumerate() {
        let gap = (a - b).abs();
        assert!(
            gap < 1e-10,
            "Probabilities should be identical with same random state at index {}",
            idx
        );
    }
}
396
/// `predict_monte_carlo` must return an error for a sample count of zero and
/// for an input with three columns when the chain was fitted on two.
#[test]
fn test_classifier_chain_monte_carlo_invalid_input() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];
    let targets = array![[0, 1], [1, 0]];

    let chain = ClassifierChain::new()
        .fit_simple(&features.view(), &targets)
        .unwrap();

    let zero_samples = chain.predict_monte_carlo(&features.view(), 0, None);
    assert!(zero_samples.is_err());

    let three_columns = array![[1.0, 2.0, 3.0]];
    let wrong_width = chain.predict_monte_carlo(&three_columns.view(), 10, None);
    assert!(wrong_width.is_err());
}
414
/// Fits a `RegressorChain` without an explicit order and checks the target
/// count, that `chain_order()` is `[0, 1]`, the prediction shape, and that all
/// predictions are finite.
#[test]
fn test_regressor_chain() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
    let targets = array![[1.5, 2.5], [2.5, 3.5], [2.0, 1.5], [1.0, 1.5]];

    let chain = RegressorChain::new()
        .fit_simple(&features.view(), &targets)
        .unwrap();

    assert_eq!(chain.n_targets(), 2);
    assert_eq!(chain.chain_order(), &[0, 1]);

    let predicted = chain.predict_simple(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (4, 2));

    assert!(predicted.iter().all(|value| value.is_finite()));
}
434
/// A `RegressorChain` built with `.order(vec![1, 0])` must report that order
/// after fitting and produce a finite (3, 2) prediction matrix.
#[test]
fn test_regressor_chain_custom_order() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
    let targets = array![[1.5, 2.5], [2.5, 3.5], [2.0, 1.5]];

    let reversed = RegressorChain::new().order(vec![1, 0]);
    let chain = reversed.fit_simple(&features.view(), &targets).unwrap();

    assert_eq!(chain.chain_order(), &[1, 0]);

    let predicted = chain.predict_simple(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (3, 2));

    assert!(predicted.iter().all(|value| value.is_finite()));
}
453
/// `RegressorChain::fit_simple` must return an error when the feature matrix
/// has two rows but the target matrix has three.
#[test]
fn test_regressor_chain_invalid_input() {
    let two_samples = array![[1.0, 2.0], [2.0, 3.0]];
    let three_target_rows = array![[1.5, 2.5], [2.5, 3.5], [2.0, 1.5]];

    let outcome = RegressorChain::new().fit_simple(&two_samples.view(), &three_target_rows);
    assert!(outcome.is_err());
}
462
/// Fitting a `RegressorChain` whose order lists three indices against a
/// two-column target matrix must return an error.
#[test]
fn test_regressor_chain_invalid_order() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];
    let targets = array![[1.5, 2.5], [2.5, 3.5]];

    let outcome = RegressorChain::new()
        .order(vec![0, 1, 2])
        .fit_simple(&features.view(), &targets);
    assert!(outcome.is_err());
}
471
/// Fits `BinaryRelevance` on four samples with two binary labels and checks
/// the label count, the length of `classes()`, that predictions are (4, 2)
/// with values 0 or 1, and that `predict_proba` yields (4, 2) values in [0, 1].
#[test]
fn test_binary_relevance() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
    let labels = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let model = BinaryRelevance::new()
        .fit(&features.view(), &labels)
        .unwrap();

    assert_eq!(model.n_labels(), 2);
    assert_eq!(model.classes().len(), 2);

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (4, 2));
    assert!(predicted.iter().all(|&label| matches!(label, 0 | 1)));

    let probs = model.predict_proba(&features.view()).unwrap();
    assert_eq!(probs.dim(), (4, 2));
    assert!(probs.iter().all(|p| (0.0..=1.0).contains(p)));
}
500
/// `BinaryRelevance` fitted on a single-column label matrix must report one
/// label and return (3, 1) predictions whose values are 0 or 1.
#[test]
fn test_binary_relevance_single_label() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
    let one_label = array![[1], [0], [1]];

    let model = BinaryRelevance::new()
        .fit(&features.view(), &one_label)
        .unwrap();

    assert_eq!(model.n_labels(), 1);

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (3, 1));
    assert!(predicted.iter().all(|&label| matches!(label, 0 | 1)));
}
519
/// `BinaryRelevance::fit` must return an error when the feature matrix has two
/// rows but the label matrix has three.
#[test]
fn test_binary_relevance_invalid_input() {
    let two_samples = array![[1.0, 2.0], [2.0, 3.0]];
    let three_label_rows = array![[1, 0], [0, 1], [1, 1]];

    let outcome = BinaryRelevance::new().fit(&two_samples.view(), &three_label_rows);
    assert!(outcome.is_err());
}
528
/// `BinaryRelevance::fit` must return an error when the label matrix contains
/// the value 2.
#[test]
fn test_binary_relevance_non_binary_labels() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
    let labels_with_twos = array![[0, 1], [1, 2], [2, 0]];

    let outcome = BinaryRelevance::new().fit(&features.view(), &labels_with_twos);
    assert!(outcome.is_err());
}
537
/// A `BinaryRelevance` model fitted on two-column features must return an
/// error from both `predict` and `predict_proba` on a three-column input.
#[test]
fn test_binary_relevance_predict_shape_mismatch() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];
    let labels = array![[1, 0], [0, 1]];

    let model = BinaryRelevance::new()
        .fit(&features.view(), &labels)
        .unwrap();

    let three_columns = array![[1.0, 2.0, 3.0]];

    let hard = model.predict(&three_columns.view());
    assert!(hard.is_err());

    let soft = model.predict_proba(&three_columns.view());
    assert!(soft.is_err());
}
551
/// Fits `LabelPowerset` on four samples whose label rows are all distinct and
/// checks that it reports 2 labels and 4 classes, that predictions are (4, 2)
/// with values 0 or 1, and that `decision_function` returns finite (4, 4) scores.
#[test]
fn test_label_powerset() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
    let labels = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let model = LabelPowerset::new()
        .fit(&features.view(), &labels)
        .unwrap();

    assert_eq!(model.n_labels(), 2);
    assert_eq!(model.n_classes(), 4);

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (4, 2));
    assert!(predicted.iter().all(|&label| matches!(label, 0 | 1)));

    let scores = model.decision_function(&features.view()).unwrap();
    assert_eq!(scores.dim(), (4, 4));
    assert!(scores.iter().all(|score| score.is_finite()));
}
580
/// `LabelPowerset` fitted on label rows `[1, 0]`, `[0, 1]`, `[1, 0]` must
/// report 2 labels and 2 classes and return (3, 2) predictions of 0 or 1.
#[test]
fn test_label_powerset_simple_case() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
    let labels = array![[1, 0], [0, 1], [1, 0]];

    let model = LabelPowerset::new()
        .fit(&features.view(), &labels)
        .unwrap();

    assert_eq!(model.n_labels(), 2);
    assert_eq!(model.n_classes(), 2);

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (3, 2));
    assert!(predicted.iter().all(|&label| matches!(label, 0 | 1)));
}
600
/// `LabelPowerset` fitted on a single-column label matrix containing both 0
/// and 1 must report 1 label and 2 classes and return (3, 1) predictions of
/// 0 or 1.
#[test]
fn test_label_powerset_single_label() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
    let one_label = array![[1], [0], [1]];

    let model = LabelPowerset::new()
        .fit(&features.view(), &one_label)
        .unwrap();

    assert_eq!(model.n_labels(), 1);
    assert_eq!(model.n_classes(), 2);

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (3, 1));
    assert!(predicted.iter().all(|&label| matches!(label, 0 | 1)));
}
620
/// `LabelPowerset::fit` must return an error when the feature matrix has two
/// rows but the label matrix has three.
#[test]
fn test_label_powerset_invalid_input() {
    let two_samples = array![[1.0, 2.0], [2.0, 3.0]];
    let three_label_rows = array![[1, 0], [0, 1], [1, 1]];

    let outcome = LabelPowerset::new().fit(&two_samples.view(), &three_label_rows);
    assert!(outcome.is_err());
}
629
/// `LabelPowerset::fit` must return an error when the label matrix contains
/// the value 2.
#[test]
fn test_label_powerset_non_binary_labels() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
    let labels_with_twos = array![[0, 1], [1, 2], [2, 0]];

    let outcome = LabelPowerset::new().fit(&features.view(), &labels_with_twos);
    assert!(outcome.is_err());
}
638
/// A `LabelPowerset` model fitted on two-column features must return an error
/// from both `predict` and `decision_function` on a three-column input.
#[test]
fn test_label_powerset_predict_shape_mismatch() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];
    let labels = array![[1, 0], [0, 1]];

    let model = LabelPowerset::new()
        .fit(&features.view(), &labels)
        .unwrap();

    let three_columns = array![[1.0, 2.0, 3.0]];

    let hard = model.predict(&three_columns.view());
    assert!(hard.is_err());

    let scored = model.decision_function(&three_columns.view());
    assert!(scored.is_err());
}
652
/// When every training row carries the labels `[1, 0]`, `LabelPowerset` must
/// report a single class and predict `[1, 0]` for each of the three samples.
#[test]
fn test_label_powerset_all_same_combination() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
    let identical_rows = array![[1, 0], [1, 0], [1, 0]];

    let model = LabelPowerset::new()
        .fit(&features.view(), &identical_rows)
        .unwrap();

    assert_eq!(model.n_labels(), 2);
    assert_eq!(model.n_classes(), 1);

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (3, 2));

    for row in 0..3 {
        let pair = (predicted[[row, 0]], predicted[[row, 1]]);
        assert_eq!(pair, (1, 0));
    }
}
673
/// Fits `PrunedLabelPowerset` with `min_frequency(2)` and a
/// `PruningStrategy::DefaultMapping(vec![0, 0])` strategy, then checks that at
/// most four frequent classes are reported, that `min_frequency()` is 2, and
/// that predictions are (8, 2) with values 0 or 1.
#[test]
fn test_pruned_label_powerset_default_strategy() {
    let features = array![
        [1.0, 2.0],
        [2.0, 3.0],
        [3.0, 1.0],
        [1.0, 1.0],
        [2.0, 2.0],
        [3.0, 3.0],
        [1.5, 2.5],
        [2.5, 1.5]
    ];
    // Row counts in this literal: [1, 0] x3, [0, 1] x3, [1, 1] x1, [0, 0] x1.
    let labels = array![
        [1, 0],
        [0, 1],
        [1, 1],
        [0, 0],
        [1, 0],
        [0, 1],
        [1, 0],
        [0, 1],
    ];

    let fallback = PruningStrategy::DefaultMapping(vec![0, 0]);
    let estimator = PrunedLabelPowerset::new()
        .min_frequency(2)
        .strategy(fallback);

    let model = estimator.fit(&features.view(), &labels).unwrap();

    assert!(model.n_frequent_classes() <= 4);
    assert_eq!(model.min_frequency(), 2);

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (8, 2));
    assert!(predicted.iter().all(|&label| matches!(label, 0 | 1)));
}
716
/// Fits `PrunedLabelPowerset` with `min_frequency(2)` and
/// `PruningStrategy::SimilarityMapping`, then checks that two frequent classes
/// are reported, that `combination_mapping()` has an entry for `[1, 1]` whose
/// value is `[1, 0]` or `[0, 1]`, and that predictions are (6, 2) with values
/// 0 or 1.
#[test]
fn test_pruned_label_powerset_similarity_strategy() {
    let features = array![
        [1.0, 2.0],
        [2.0, 3.0],
        [3.0, 1.0],
        [1.0, 1.0],
        [2.0, 2.0],
        [3.0, 3.0]
    ];
    // Row counts in this literal: [1, 0] x3, [0, 1] x2, [1, 1] x1.
    let labels = array![[1, 0], [1, 0], [1, 0], [0, 1], [0, 1], [1, 1]];

    let estimator = PrunedLabelPowerset::new()
        .min_frequency(2)
        .strategy(PruningStrategy::SimilarityMapping);

    let model = estimator.fit(&features.view(), &labels).unwrap();

    assert_eq!(model.n_frequent_classes(), 2);

    let remap = model.combination_mapping();
    let single_occurrence = vec![1, 1];
    assert!(remap.contains_key(&single_occurrence));

    let mapped = remap.get(&single_occurrence).unwrap();
    let to_first = mapped == &vec![1, 0];
    let to_second = mapped == &vec![0, 1];
    assert!(to_first || to_second);

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (6, 2));
    assert!(predicted.iter().all(|&label| matches!(label, 0 | 1)));
}
763
/// `PrunedLabelPowerset::fit` must return an error in three situations:
/// `min_frequency(5)` on two samples, a `DefaultMapping` of length three with
/// two-column labels, and a label matrix containing the value 2.
#[test]
fn test_pruned_label_powerset_invalid_input() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];
    let labels = array![[0, 1], [1, 0]];

    let high_threshold = PrunedLabelPowerset::new().min_frequency(5);
    assert!(high_threshold.fit(&features.view(), &labels).is_err());

    let wrong_default_len =
        PrunedLabelPowerset::new().strategy(PruningStrategy::DefaultMapping(vec![0, 1, 0]));
    assert!(wrong_default_len.fit(&features.view(), &labels).is_err());

    let labels_with_two = array![[2, 1], [1, 0]];
    let plain = PrunedLabelPowerset::new();
    assert!(plain.fit(&features.view(), &labels_with_two).is_err());
}
783
/// With two identical label rows `[1, 0]` and `min_frequency(2)`, the fitted
/// `PrunedLabelPowerset` must report at least one frequent class, list
/// `[1, 0]` among `frequent_combinations()`, and predict `[1, 0]` for both
/// samples.
#[test]
fn test_pruned_label_powerset_edge_cases() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];
    let identical_rows = array![[1, 0], [1, 0]];

    let model = PrunedLabelPowerset::new()
        .min_frequency(2)
        .fit(&features.view(), &identical_rows)
        .unwrap();

    assert!(model.n_frequent_classes() >= 1);
    assert!(model.frequent_combinations().len() >= 1);

    let expected_combo = vec![1, 0];
    assert!(model.frequent_combinations().contains(&expected_combo));

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (2, 2));

    for row in 0..2 {
        let pair = (predicted[[row, 0]], predicted[[row, 1]]);
        assert_eq!(pair, (1, 0));
    }
}
809
/// `metrics::hamming_loss` on 3x3 matrices that differ in three cells must
/// equal 3/9 within 1e-10.
#[test]
fn test_metrics_hamming_loss() {
    let truth = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
    let guess = array![[1, 0, 0], [0, 1, 1], [1, 0, 1]];

    let observed = metrics::hamming_loss(&truth.view(), &guess.view()).unwrap();

    let expected = 3.0 / 9.0;
    assert!((observed - expected).abs() < 1e-10);
}
818
/// `metrics::subset_accuracy` must equal 1/3 within 1e-10 when only the first
/// of three rows matches exactly.
#[test]
fn test_metrics_subset_accuracy() {
    let truth = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
    let guess = array![[1, 0, 1], [0, 1, 1], [1, 0, 1]];

    let observed = metrics::subset_accuracy(&truth.view(), &guess.view()).unwrap();

    let expected = 1.0 / 3.0;
    assert!((observed - expected).abs() < 1e-10);
}
827
/// `metrics::jaccard_score` on the two-row fixture below must equal 0.5
/// within 1e-10.
#[test]
fn test_metrics_jaccard_score() {
    let truth = array![[1, 0, 1], [0, 1, 0]];
    let guess = array![[1, 0, 0], [0, 1, 1]];

    let observed = metrics::jaccard_score(&truth.view(), &guess.view()).unwrap();

    let expected = 0.5;
    assert!((observed - expected).abs() < 1e-10);
}
839
/// `metrics::f1_score` with the `"micro"` average on the 3x3 fixture must
/// equal 0.7272727272727273 within 1e-10.
#[test]
fn test_metrics_f1_score_micro() {
    let truth = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
    let guess = array![[1, 0, 0], [0, 1, 1], [1, 0, 1]];

    let observed = metrics::f1_score(&truth.view(), &guess.view(), "micro").unwrap();

    let expected = 0.7272727272727273;
    assert!((observed - expected).abs() < 1e-10);
}
851
/// `metrics::f1_score` with the `"macro"` average on the 3x2 fixture must
/// equal 0.8333333333333334 within 1e-10.
#[test]
fn test_metrics_f1_score_macro() {
    let truth = array![[1, 0], [0, 1], [1, 1]];
    let guess = array![[1, 0], [0, 1], [1, 0]];

    let observed = metrics::f1_score(&truth.view(), &guess.view(), "macro").unwrap();

    let expected = 0.8333333333333334;
    assert!((observed - expected).abs() < 1e-10);
}
863
/// `metrics::f1_score` with the `"samples"` average on the 3x2 fixture must
/// equal 0.8888888888888888 within 1e-10.
#[test]
fn test_metrics_f1_score_samples() {
    let truth = array![[1, 0], [0, 1], [1, 1]];
    let guess = array![[1, 0], [0, 1], [1, 0]];

    let observed = metrics::f1_score(&truth.view(), &guess.view(), "samples").unwrap();

    let expected = 0.8888888888888888;
    assert!((observed - expected).abs() < 1e-10);
}
876
/// `metrics::coverage_error` on the two-row score fixture must equal 1.5
/// within 1e-10.
#[test]
fn test_metrics_coverage_error() {
    let truth = array![[1, 0, 1], [0, 1, 0]];
    let scores = array![[0.9, 0.1, 0.8], [0.2, 0.9, 0.3]];

    let observed = metrics::coverage_error(&truth.view(), &scores.view()).unwrap();

    let expected = 1.5;
    assert!((observed - expected).abs() < 1e-10);
}
890
/// `metrics::label_ranking_average_precision` must equal 1.0 within 1e-10 on
/// a fixture where every true label outscores every false label in its row.
#[test]
fn test_metrics_label_ranking_average_precision() {
    let truth = array![[1, 0, 1], [0, 1, 0]];
    let scores = array![[0.9, 0.1, 0.8], [0.2, 0.9, 0.3]];

    let observed =
        metrics::label_ranking_average_precision(&truth.view(), &scores.view()).unwrap();

    let expected = 1.0;
    assert!((observed - expected).abs() < 1e-10);
}
909
/// `hamming_loss`, `subset_accuracy`, `jaccard_score`, and micro `f1_score`
/// must each return an error when given a (2, 2) truth matrix and a (1, 3)
/// prediction matrix.
#[test]
fn test_metrics_invalid_shapes() {
    let truth_data = array![[1, 0], [0, 1]];
    let guess_data = array![[1, 0, 1]];

    let truth = truth_data.view();
    let guess = guess_data.view();

    assert!(metrics::hamming_loss(&truth, &guess).is_err());
    assert!(metrics::subset_accuracy(&truth, &guess).is_err());
    assert!(metrics::jaccard_score(&truth, &guess).is_err());
    assert!(metrics::f1_score(&truth, &guess, "micro").is_err());
}
920
/// `metrics::f1_score` must return an error when the averaging mode string is
/// `"invalid"`, even though the two matrices have matching shapes.
#[test]
fn test_metrics_invalid_f1_average() {
    let truth = array![[1, 0], [0, 1]];
    let guess = array![[1, 0], [0, 1]];

    let outcome = metrics::f1_score(&truth.view(), &guess.view(), "invalid");
    assert!(outcome.is_err());
}
928
/// Fits `EnsembleOfChains` configured with `n_chains(3)` and
/// `random_state(42)` and checks the chain and target counts, that
/// predictions are (4, 2) with values 0 or 1, and that `predict_proba_simple`
/// yields (4, 2) values in [0, 1].
#[test]
fn test_ensemble_of_chains() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
    let targets = array![[0, 1], [1, 0], [1, 1], [0, 0]];

    let ensemble = EnsembleOfChains::new()
        .n_chains(3)
        .random_state(42)
        .fit_simple(&features.view(), &targets)
        .unwrap();

    assert_eq!(ensemble.n_chains(), 3);
    assert_eq!(ensemble.n_targets(), 2);

    let predicted = ensemble.predict_simple(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (4, 2));
    assert!(predicted.iter().all(|&label| matches!(label, 0 | 1)));

    let probs = ensemble.predict_proba_simple(&features.view()).unwrap();
    assert_eq!(probs.dim(), (4, 2));
    assert!(probs.iter().all(|p| (0.0..=1.0).contains(p)));
}
957
/// `EnsembleOfChains` configured with `n_chains(1)` must fit, report one
/// chain, and return (2, 2) predictions.
#[test]
fn test_ensemble_of_chains_single_chain() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];
    let targets = array![[1, 0], [0, 1]];

    let ensemble = EnsembleOfChains::new()
        .n_chains(1)
        .fit_simple(&features.view(), &targets)
        .unwrap();

    assert_eq!(ensemble.n_chains(), 1);

    let predicted = ensemble.predict_simple(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (2, 2));
}
971
/// `EnsembleOfChains::fit_simple` must return an error when the feature
/// matrix has two rows but the target matrix has three.
#[test]
fn test_ensemble_of_chains_invalid_input() {
    let two_samples = array![[1.0, 2.0], [2.0, 3.0]];
    let three_target_rows = array![[1, 0], [0, 1], [1, 1]];

    let outcome = EnsembleOfChains::new().fit_simple(&two_samples.view(), &three_target_rows);
    assert!(outcome.is_err());
}
980
/// Fits `OneVsRestClassifier` on four samples with two binary labels and
/// checks the label count, the length of `classes()`, that predictions are
/// (4, 2) with values 0 or 1, that `predict_proba` yields (4, 2) values in
/// [0, 1], and that `decision_function` yields finite (4, 2) scores.
#[test]
fn test_one_vs_rest_classifier() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 1.0]];
    let labels = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let model = OneVsRestClassifier::new()
        .fit(&features.view(), &labels)
        .unwrap();

    assert_eq!(model.n_labels(), 2);
    assert_eq!(model.classes().len(), 2);

    let predicted = model.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (4, 2));
    assert!(predicted.iter().all(|&label| matches!(label, 0 | 1)));

    let probs = model.predict_proba(&features.view()).unwrap();
    assert_eq!(probs.dim(), (4, 2));
    assert!(probs.iter().all(|p| (0.0..=1.0).contains(p)));

    let scores = model.decision_function(&features.view()).unwrap();
    assert_eq!(scores.dim(), (4, 2));
    assert!(scores.iter().all(|score| score.is_finite()));
}
1018
/// `OneVsRestClassifier::fit` must return an error when the feature matrix
/// has two rows but the label matrix has three.
#[test]
fn test_one_vs_rest_classifier_invalid_input() {
    let two_samples = array![[1.0, 2.0], [2.0, 3.0]];
    let three_label_rows = array![[1, 0], [0, 1], [1, 1]];

    let outcome = OneVsRestClassifier::new().fit(&two_samples.view(), &three_label_rows);
    assert!(outcome.is_err());
}
1027
/// `metrics::one_error` must be 0 (within 1e-10) when the highest-scored
/// label in every row is a true label.
#[test]
fn test_metrics_one_error() {
    let truth = array![[1, 0, 0], [0, 1, 0], [0, 0, 1]];
    let scores = array![[0.9, 0.1, 0.05], [0.1, 0.8, 0.1], [0.05, 0.1, 0.85]];

    let observed = metrics::one_error(&truth.view(), &scores.view()).unwrap();

    let expected = 0.0;
    assert!((observed - expected).abs() < 1e-10);
}
1037
/// `metrics::one_error` must be 1 (within 1e-10) when the highest-scored
/// label in every row is a false label.
#[test]
fn test_metrics_one_error_with_errors() {
    let truth = array![[1, 0], [0, 1]];
    let scores = array![[0.3, 0.7], [0.6, 0.4]];

    let observed = metrics::one_error(&truth.view(), &scores.view()).unwrap();

    let expected = 1.0;
    assert!((observed - expected).abs() < 1e-10);
}
1047
/// `metrics::ranking_loss` must be 0 (within 1e-10) when each row scores its
/// true label above its false label.
#[test]
fn test_metrics_ranking_loss() {
    let truth = array![[1, 0], [0, 1]];
    let scores = array![[0.8, 0.2], [0.3, 0.7]];

    let observed = metrics::ranking_loss(&truth.view(), &scores.view()).unwrap();

    let expected = 0.0;
    assert!((observed - expected).abs() < 1e-10);
}
1057
/// `metrics::ranking_loss` must be 1 (within 1e-10) when each row scores its
/// false label above its true label.
#[test]
fn test_metrics_ranking_loss_with_errors() {
    let truth = array![[1, 0], [0, 1]];
    let scores = array![[0.2, 0.8], [0.7, 0.3]];

    let observed = metrics::ranking_loss(&truth.view(), &scores.view()).unwrap();

    let expected = 1.0;
    assert!((observed - expected).abs() < 1e-10);
}
1067
/// `metrics::average_precision_score` must equal 1.0 within 1e-10 on a
/// fixture where every true label outscores every false label in its row.
#[test]
fn test_metrics_average_precision_score() {
    let truth = array![[1, 0, 1], [0, 1, 0]];
    let scores = array![[0.9, 0.1, 0.8], [0.2, 0.9, 0.3]];

    let observed = metrics::average_precision_score(&truth.view(), &scores.view()).unwrap();

    let expected = 1.0;
    assert!((observed - expected).abs() < 1e-10);
}
1077
/// On the 3x3 fixture, `precision_score_micro` must equal 0.8 and
/// `recall_score_micro` must equal 0.6666666666666666, each within 1e-10.
#[test]
fn test_metrics_precision_recall_micro() {
    let truth = array![[1, 0, 1], [0, 1, 0], [1, 1, 1]];
    let guess = array![[1, 0, 0], [0, 1, 1], [1, 0, 1]];

    // Both metrics are computed before either is asserted.
    let micro_precision = metrics::precision_score_micro(&truth.view(), &guess.view()).unwrap();
    let micro_recall = metrics::recall_score_micro(&truth.view(), &guess.view()).unwrap();

    let precision_gap = (micro_precision - 0.8).abs();
    assert!(precision_gap < 1e-10);

    let recall_gap = (micro_recall - 0.6666666666666666).abs();
    assert!(recall_gap < 1e-10);
}
1091
/// `one_error`, `ranking_loss`, `average_precision_score`,
/// `precision_score_micro`, and `recall_score_micro` must each return an
/// error when the (2, 2) truth matrix is paired with a (1, 3) score or
/// prediction matrix.
#[test]
fn test_metrics_invalid_shapes_new_metrics() {
    let truth_data = array![[1, 0], [0, 1]];
    let guess_data = array![[1, 0, 1]];
    let score_data = array![[0.8, 0.2, 0.1]];

    let truth = truth_data.view();
    let guess = guess_data.view();
    let scores = score_data.view();

    // Score-based metrics.
    assert!(metrics::one_error(&truth, &scores).is_err());
    assert!(metrics::ranking_loss(&truth, &scores).is_err());
    assert!(metrics::average_precision_score(&truth, &scores).is_err());

    // Prediction-based metrics.
    assert!(metrics::precision_score_micro(&truth, &guess).is_err());
    assert!(metrics::recall_score_micro(&truth, &guess).is_err());
}
1104
/// Runs `label_analysis::analyze_combinations` on five rows with three labels
/// and checks the sample count, combination width, number of unique
/// combinations, the fields of `most_frequent`, and the average cardinality.
#[test]
fn test_label_analysis_basic() {
    // Row sums are 2, 1, 2, 0, 3, so the mean number of set labels is 1.6;
    // the row [1, 0, 1] appears twice out of five.
    let labels = array![[1, 0, 1], [0, 1, 0], [1, 0, 1], [0, 0, 0], [1, 1, 1]];

    let report = label_analysis::analyze_combinations(&labels.view()).unwrap();

    assert_eq!(report.total_samples, 5);
    assert_eq!(report.combinations[0].combination.len(), 3);
    assert_eq!(report.unique_combinations, 4);

    let top = report.most_frequent.as_ref().unwrap();
    assert_eq!(top.combination, vec![1, 0, 1]);
    assert_eq!(top.frequency, 2);
    assert!((top.relative_frequency - 0.4).abs() < 1e-10);
    assert_eq!(top.cardinality, 2);

    assert!((report.average_cardinality - 1.6).abs() < 1e-10);
}
1133
/// Checks `get_rare_combinations` and `get_combinations_by_cardinality` on an
/// analysis of six rows: two combinations are returned for threshold 2 (one
/// of them `[1, 1]` with frequency 1), two combinations have cardinality 1,
/// and exactly one (`[1, 1]`) has cardinality 2.
#[test]
fn test_label_analysis_utility_functions() {
    // Row counts in this literal: [1, 0] x3, [0, 1] x2, [1, 1] x1.
    let labels = array![[1, 0], [1, 0], [1, 0], [0, 1], [0, 1], [1, 1]];

    let report = label_analysis::analyze_combinations(&labels.view()).unwrap();

    let rare = label_analysis::get_rare_combinations(&report, 2);
    assert_eq!(rare.len(), 2);

    let seen_once = rare.iter().find(|combo| combo.frequency == 1).unwrap();
    assert_eq!(seen_once.combination, vec![1, 1]);

    let singles = label_analysis::get_combinations_by_cardinality(&report, 1);
    assert_eq!(singles.len(), 2);

    let doubles = label_analysis::get_combinations_by_cardinality(&report, 2);
    assert_eq!(doubles.len(), 1);
    assert_eq!(doubles[0].combination, vec![1, 1]);
}
1162
/// `label_analysis::label_cooccurrence_matrix` on the four-row fixture must
/// return a 3x3 matrix with 3 on every diagonal cell and 2 on every
/// off-diagonal cell.
#[test]
fn test_label_cooccurrence_matrix() {
    let labels = array![[1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1]];

    let counts = label_analysis::label_cooccurrence_matrix(&labels.view()).unwrap();
    assert_eq!(counts.dim(), (3, 3));

    // Diagonal cells.
    for d in 0..3 {
        assert_eq!(counts[[d, d]], 3);
    }

    // Each unordered label pair, checked in both orientations.
    for &(i, j) in [(0usize, 1usize), (0, 2), (1, 2)].iter() {
        assert_eq!(counts[[i, j]], 2);
        assert_eq!(counts[[j, i]], 2);
    }
}
1194
/// Checks structural properties of `label_correlation_matrix`: a unit
/// diagonal, symmetry, and all entries within [-1, 1].
#[test]
fn test_label_correlation_matrix() {
    let labels = array![[1, 1, 0], [1, 0, 1], [0, 1, 1], [0, 0, 0]];

    let corr = label_analysis::label_correlation_matrix(&labels.view()).unwrap();
    assert_eq!(corr.dim(), (3, 3));

    // Each label correlates perfectly with itself.
    for label in 0..3 {
        assert!((corr[[label, label]] - 1.0).abs() < 1e-10);
    }

    // corr[r, c] must equal corr[c, r].
    for row in 0..3 {
        for col in 0..3 {
            let mirrored_gap = corr[[row, col]] - corr[[col, row]];
            assert!(mirrored_gap.abs() < 1e-10);
        }
    }

    // The shape was asserted to be 3x3 above, so walking every element covers
    // the same nine entries as an index loop.
    assert!(corr.iter().all(|&value| value >= -1.0 && value <= 1.0));
}
1221
/// Checks that the label-analysis entry points return `Err` for a matrix with
/// a value other than 0/1, for zero samples, and for zero labels.
#[test]
fn test_label_analysis_invalid_input() {
    // Contains a 2, which is neither 0 nor 1.
    let non_binary = array![[2, 1], [1, 0]];
    let outcome = label_analysis::analyze_combinations(&non_binary.view());
    assert!(outcome.is_err());

    // Zero rows.
    let no_samples = Array2::<i32>::zeros((0, 2));
    let outcome = label_analysis::analyze_combinations(&no_samples.view());
    assert!(outcome.is_err());

    // Zero columns.
    let no_labels = Array2::<i32>::zeros((2, 0));
    let outcome = label_analysis::analyze_combinations(&no_labels.view());
    assert!(outcome.is_err());

    // The matrix helpers must reject the zero-row input as well.
    let cooccurrence = label_analysis::label_cooccurrence_matrix(&no_samples.view());
    assert!(cooccurrence.is_err());
    let correlation = label_analysis::label_correlation_matrix(&no_samples.view());
    assert!(correlation.is_err());
}
1239
/// Exercises `analyze_combinations` on degenerate inputs: a single sample, an
/// all-zero matrix, and an all-one matrix.
#[test]
fn test_label_analysis_edge_cases() {
    // One sample: it is both the most and the least frequent combination.
    let lone_row = array![[1, 0, 1]];
    let summary = label_analysis::analyze_combinations(&lone_row.view()).unwrap();

    assert_eq!(summary.total_samples, 1);
    assert_eq!(summary.unique_combinations, 1);
    let most = summary.most_frequent.as_ref().unwrap();
    assert_eq!(most.combination, vec![1, 0, 1]);
    let least = summary.least_frequent.as_ref().unwrap();
    assert_eq!(least.combination, vec![1, 0, 1]);
    assert_eq!(summary.average_cardinality, 2.0);

    // No active labels anywhere.
    let all_off = array![[0, 0], [0, 0]];
    let summary = label_analysis::analyze_combinations(&all_off.view()).unwrap();
    assert_eq!(summary.average_cardinality, 0.0);

    // Both labels active in every sample.
    let all_on = array![[1, 1], [1, 1]];
    let summary = label_analysis::analyze_combinations(&all_on.view()).unwrap();
    assert_eq!(summary.average_cardinality, 2.0);
}
1270
/// Fits `IBLR` with two neighbours and checks that predictions on the
/// training data have the expected shape and contain only 0/1 values.
#[test]
fn test_iblr_basic_functionality() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
    let targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let estimator = IBLR::new().k_neighbors(2);
    let fitted = estimator.fit(&features.view(), &targets.view()).unwrap();
    let predicted = fitted.predict(&features.view()).unwrap();

    assert_eq!(predicted.dim(), (4, 2));

    // Shape is fixed at 4x2 above, so scanning every element is equivalent to
    // indexing each (row, column) pair.
    assert!(predicted.iter().all(|&label| label == 0 || label == 1));
}
1289
/// Checks that `IBLR` can be fitted and used for prediction under different
/// neighbour counts and weight functions, always producing a 3x2 output.
#[test]
fn test_iblr_configuration() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
    let targets = array![[1, 0], [0, 1], [1, 1]];

    // Vary the neighbourhood size.
    for k in [1, 2] {
        let estimator = IBLR::new().k_neighbors(k);
        let fitted = estimator.fit(&features.view(), &targets.view()).unwrap();
        let predicted = fitted.predict(&features.view()).unwrap();
        assert_eq!(predicted.dim(), (3, 2));
    }

    // Vary the neighbour weighting with k fixed at 2.
    for scheme in [WeightFunction::Uniform, WeightFunction::Distance] {
        let estimator = IBLR::new().k_neighbors(2).weights(scheme);
        let fitted = estimator.fit(&features.view(), &targets.view()).unwrap();
        let predicted = fitted.predict(&features.view()).unwrap();
        assert_eq!(predicted.dim(), (3, 2));
    }
}
1321
/// Checks that `IBLR` reports errors for mismatched sample counts, invalid
/// neighbour counts, a wrong feature count at prediction time, and empty
/// training data.
#[test]
fn test_iblr_error_handling() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];

    // Two feature rows against three label rows.
    let too_many_rows = array![[1, 0], [0, 1], [1, 1]];
    let default_model = IBLR::new();
    assert!(default_model
        .fit(&features.view(), &too_many_rows.view())
        .is_err());

    let targets = array![[1, 0], [0, 1]];

    // k = 0 is rejected.
    let zero_k = IBLR::new().k_neighbors(0);
    assert!(zero_k.fit(&features.view(), &targets.view()).is_err());

    // k = 5 with only two training samples is rejected.
    let oversized_k = IBLR::new().k_neighbors(5);
    assert!(oversized_k.fit(&features.view(), &targets.view()).is_err());

    // A model trained on two features must refuse three-feature input.
    let train_features = array![[1.0, 2.0], [2.0, 3.0]];
    let train_targets = array![[1, 0], [0, 1]];
    let estimator = IBLR::new().k_neighbors(1);
    let fitted = estimator
        .fit(&train_features.view(), &train_targets.view())
        .unwrap();

    let three_features = array![[1.0, 2.0, 3.0]];
    assert!(fitted.predict(&three_features.view()).is_err());

    // Zero-sample training data is rejected.
    let no_features = Array2::<Float>::zeros((0, 2));
    let no_targets = Array2::<i32>::zeros((0, 2));
    assert!(IBLR::new()
        .fit(&no_features.view(), &no_targets.view())
        .is_err());
}
1355
/// Fits `IBLR` with uniform and distance weighting (k = 3) and checks that
/// both variants produce 4x2 predictions containing only 0/1 values.
#[test]
fn test_iblr_weight_functions() {
    let features = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, 2.0]];
    let targets = array![[1, 1], [0, 1], [1, 0], [0, 0]];

    let uniform_model = IBLR::new().k_neighbors(3).weights(WeightFunction::Uniform);
    let uniform_fit = uniform_model
        .fit(&features.view(), &targets.view())
        .unwrap();
    let uniform_out = uniform_fit.predict(&features.view()).unwrap();

    let distance_model = IBLR::new().k_neighbors(3).weights(WeightFunction::Distance);
    let distance_fit = distance_model
        .fit(&features.view(), &targets.view())
        .unwrap();
    let distance_out = distance_fit.predict(&features.view()).unwrap();

    assert_eq!(uniform_out.dim(), (4, 2));
    assert_eq!(distance_out.dim(), (4, 2));

    // Walk the 4x2 grid once and validate both outputs at each position.
    for row in 0..4 {
        for col in 0..2 {
            let u = uniform_out[[row, col]];
            let d = distance_out[[row, col]];
            assert!(u == 0 || u == 1);
            assert!(d == 0 || d == 1);
        }
    }
}
1383
/// With k = 1, predicting on the training points is expected to reproduce the
/// training labels exactly.
#[test]
fn test_iblr_single_neighbor() {
    let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
    let targets = array![[1, 0], [0, 1], [1, 1]];

    let estimator = IBLR::new().k_neighbors(1);
    let fitted = estimator.fit(&features.view(), &targets.view()).unwrap();
    let predicted = fitted.predict(&features.view()).unwrap();

    // Compare every (sample, label) cell against the training target.
    for sample in 0..3 {
        for label in 0..2 {
            let cell = [sample, label];
            assert_eq!(predicted[cell], targets[cell]);
        }
    }
}
1401
/// Predicts for a query point lying between two training points and checks
/// that both label outputs are 0 or 1.
#[test]
fn test_iblr_interpolation() {
    let features = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
    let targets = array![[0, 0], [1, 1], [0, 1]];

    let estimator = IBLR::new().k_neighbors(2);
    let fitted = estimator.fit(&features.view(), &targets.view()).unwrap();

    // Midway between the first two training points.
    let query = array![[0.5, 0.5]];
    let predicted = fitted.predict(&query.view()).unwrap();

    let first = predicted[[0, 0]];
    assert!(first == 0 || first == 1);
    let second = predicted[[0, 1]];
    assert!(second == 0 || second == 1);
}
1418
/// Fits `CLARE` with two clusters on two well-separated groups of points and
/// checks the shapes of the predictions, cluster centres and assignments.
#[test]
fn test_clare_basic_functionality() {
    let features = array![[1.0, 1.0], [1.5, 1.5], [5.0, 5.0], [5.5, 5.5]];
    let targets = array![[1, 0], [1, 0], [0, 1], [0, 1]];

    let estimator = CLARE::new().n_clusters(2).random_state(42);
    let fitted = estimator.fit(&features.view(), &targets).unwrap();
    let predicted = fitted.predict(&features.view()).unwrap();

    assert_eq!(predicted.dim(), (4, 2));

    // Two clusters in a two-feature space; one assignment per training sample.
    assert_eq!(fitted.cluster_centers().dim(), (2, 2));
    assert_eq!(fitted.cluster_assignments().len(), 4);
}
1434
/// Checks that `CLARE` builder options (`n_clusters`, `threshold`, `max_iter`,
/// `random_state`) are accepted and reflected in the trained model.
#[test]
fn test_clare_configuration() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
    let targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let with_threshold = CLARE::new().n_clusters(2).threshold(0.3);
    let with_max_iter = CLARE::new().n_clusters(3).max_iter(50);
    let with_seed = CLARE::new().random_state(123);

    let fit_threshold = with_threshold.fit(&features.view(), &targets).unwrap();
    let fit_max_iter = with_max_iter.fit(&features.view(), &targets).unwrap();
    let fit_seed = with_seed.fit(&features.view(), &targets).unwrap();

    let out_threshold = fit_threshold.predict(&features.view()).unwrap();
    let out_max_iter = fit_max_iter.predict(&features.view()).unwrap();
    let out_seed = fit_seed.predict(&features.view()).unwrap();

    assert_eq!(out_threshold.dim(), (4, 2));
    assert_eq!(out_max_iter.dim(), (4, 2));
    assert_eq!(out_seed.dim(), (4, 2));

    // The configured threshold and cluster counts must survive training.
    assert_eq!(fit_threshold.threshold(), 0.3);
    assert_eq!(fit_threshold.cluster_centers().dim(), (2, 2));
    assert_eq!(fit_max_iter.cluster_centers().dim(), (3, 2));
}
1462
/// Checks that `CLARE` reports errors for mismatched sample counts, invalid
/// cluster counts, non-binary labels, a wrong feature count at prediction
/// time, and empty training data.
#[test]
fn test_clare_error_handling() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];

    // Two feature rows against three label rows.
    let too_many_rows = array![[1, 0], [0, 1], [1, 1]];
    let default_model = CLARE::new();
    assert!(default_model
        .fit(&features.view(), &too_many_rows)
        .is_err());

    let targets = array![[1, 0], [0, 1]];

    // Zero clusters is rejected.
    let no_clusters = CLARE::new().n_clusters(0);
    assert!(no_clusters.fit(&features.view(), &targets).is_err());

    // Five clusters for two samples is rejected.
    let excess_clusters = CLARE::new().n_clusters(5);
    assert!(excess_clusters.fit(&features.view(), &targets).is_err());

    // Labels containing a 2 are rejected.
    let non_binary = array![[1, 2], [0, 1]];
    assert!(CLARE::new().fit(&features.view(), &non_binary).is_err());

    // A model trained on two features must refuse three-feature input.
    let train_features = array![[1.0, 2.0], [2.0, 3.0]];
    let train_targets = array![[1, 0], [0, 1]];
    let estimator = CLARE::new().n_clusters(2);
    let fitted = estimator
        .fit(&train_features.view(), &train_targets)
        .unwrap();

    let three_features = array![[1.0, 2.0, 3.0]];
    assert!(fitted.predict(&three_features.view()).is_err());

    // Zero-sample training data is rejected.
    let no_features = Array2::<Float>::zeros((0, 2));
    let no_targets = Array2::<i32>::zeros((0, 2));
    assert!(CLARE::new()
        .fit(&no_features.view(), &no_targets)
        .is_err());
}
1498
/// Fits `CLARE` with a 0.3 threshold and checks that predictions are binary
/// and that the trained model reports the configured threshold.
#[test]
fn test_clare_threshold_prediction() {
    let features = array![[1.0, 1.0], [1.2, 1.2], [5.0, 5.0], [5.2, 5.2]];
    let targets = array![[1, 0], [1, 0], [0, 1], [0, 1]];

    let estimator = CLARE::new().n_clusters(2).threshold(0.3).random_state(42);
    let fitted = estimator.fit(&features.view(), &targets).unwrap();

    let predicted = fitted.predict(&features.view()).unwrap();
    assert_eq!(predicted.dim(), (4, 2));

    // Every emitted label must be 0 or 1.
    assert!(predicted.iter().all(|&label| label == 0 || label == 1));

    assert_eq!(fitted.threshold(), 0.3);
}
1519
/// Fits `CLARE` on two tight, well-separated groups and checks the output
/// shape, that predictions are binary, and that the threshold is preserved.
#[test]
fn test_clare_clustering_consistency() {
    // Two points near the origin and two points near (5, 5).
    let features = array![[0.0, 0.0], [0.1, 0.1], [5.0, 5.0], [5.1, 5.1]];
    // The first group carries label 0, the second group label 1.
    let targets = array![[1, 0], [1, 0], [0, 1], [0, 1]];

    let estimator = CLARE::new().n_clusters(2).threshold(0.5).random_state(42);
    let fitted = estimator.fit(&features.view(), &targets).unwrap();
    let predicted = fitted.predict(&features.view()).unwrap();

    assert_eq!(predicted.dim(), (4, 2));
    let only_binary = predicted.iter().all(|&label| label == 0 || label == 1);
    assert!(only_binary);

    let threshold_gap = fitted.threshold() - 0.5;
    assert!(threshold_gap.abs() < 1e-10);
}
1546
/// With a single cluster, every sample is expected to receive the same
/// prediction as the first sample.
#[test]
fn test_clare_single_cluster() {
    let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
    let targets = array![[1, 0], [0, 1], [1, 1]];

    let estimator = CLARE::new().n_clusters(1);
    let fitted = estimator.fit(&features.view(), &targets).unwrap();
    let predicted = fitted.predict(&features.view()).unwrap();

    assert_eq!(predicted.dim(), (3, 2));
    assert_eq!(fitted.cluster_centers().dim(), (1, 2));

    // Rows 1 and 2 must match row 0 label by label.
    for sample in 1..3 {
        for label in 0..2 {
            assert_eq!(predicted[[0, label]], predicted[[sample, label]]);
        }
    }
}
1568
/// Two `CLARE` models built with the same `random_state` must produce the
/// same cluster centres (within 1e-10) and identical predictions.
#[test]
fn test_clare_reproducibility() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
    let targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let first_model = CLARE::new().n_clusters(2).random_state(42);
    let first_fit = first_model.fit(&features.view(), &targets).unwrap();

    let second_model = CLARE::new().n_clusters(2).random_state(42);
    let second_fit = second_model.fit(&features.view(), &targets).unwrap();

    // Centres are compared element-wise, driven by the first model's shape.
    let first_centers = first_fit.cluster_centers();
    let second_centers = second_fit.cluster_centers();
    for row in 0..first_centers.nrows() {
        for col in 0..first_centers.ncols() {
            let gap = first_centers[[row, col]] - second_centers[[row, col]];
            assert!(gap.abs() < 1e-10);
        }
    }

    // Predictions are compared element-wise in the same way.
    let first_out = first_fit.predict(&features.view()).unwrap();
    let second_out = second_fit.predict(&features.view()).unwrap();
    for row in 0..first_out.nrows() {
        for col in 0..first_out.ncols() {
            assert_eq!(first_out[[row, col]], second_out[[row, col]]);
        }
    }
}
1601
/// Fits `MLTSVM` with unit penalties and checks the prediction shape, the
/// reported label count, and that every prediction is 0 or 1.
#[test]
fn test_mltsvm_basic_functionality() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
    let targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let estimator = MLTSVM::new().c1(1.0).c2(1.0);
    let fitted = estimator.fit(&features.view(), &targets).unwrap();
    let predicted = fitted.predict(&features.view()).unwrap();

    assert_eq!(predicted.dim(), (4, 2));
    assert_eq!(fitted.n_labels(), 2);

    assert!(predicted.iter().all(|&label| label == 0 || label == 1));
}
1619
/// Checks that `MLTSVM` accepts custom `c1`/`c2` penalties and custom
/// `epsilon`/`max_iter` settings, producing 4x2 predictions in both cases.
#[test]
fn test_mltsvm_configuration() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
    let targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let custom_penalties = MLTSVM::new().c1(0.5).c2(1.5);
    let custom_solver = MLTSVM::new().epsilon(1e-8).max_iter(500);

    let fit_penalties = custom_penalties.fit(&features.view(), &targets).unwrap();
    let fit_solver = custom_solver.fit(&features.view(), &targets).unwrap();

    let out_penalties = fit_penalties.predict(&features.view()).unwrap();
    let out_solver = fit_solver.predict(&features.view()).unwrap();

    assert_eq!(out_penalties.dim(), (4, 2));
    assert_eq!(out_solver.dim(), (4, 2));
}
1638
/// Checks that `MLTSVM` reports errors for mismatched sample counts,
/// non-binary labels, a wrong feature count at prediction time, and empty
/// training data.
#[test]
fn test_mltsvm_error_handling() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];

    // Two feature rows against three label rows.
    let too_many_rows = array![[1, 0], [0, 1], [1, 1]];
    let default_model = MLTSVM::new();
    assert!(default_model
        .fit(&features.view(), &too_many_rows)
        .is_err());

    // Labels containing a 2 are rejected.
    let non_binary = array![[1, 2], [0, 1]];
    assert!(MLTSVM::new().fit(&features.view(), &non_binary).is_err());

    // A model trained on two features must refuse three-feature input.
    let train_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 2.0]];
    let train_targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];
    let estimator = MLTSVM::new();
    let fitted = estimator
        .fit(&train_features.view(), &train_targets)
        .unwrap();

    let three_features = array![[1.0, 2.0, 3.0]];
    assert!(fitted.predict(&three_features.view()).is_err());

    // Zero-sample training data is rejected.
    let no_features = Array2::<Float>::zeros((0, 2));
    let no_targets = Array2::<i32>::zeros((0, 2));
    assert!(MLTSVM::new()
        .fit(&no_features.view(), &no_targets)
        .is_err());
}
1665
/// Checks that `MLTSVM::decision_function` returns one finite value per
/// (sample, label) pair.
#[test]
fn test_mltsvm_decision_function() {
    let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
    let targets = array![[1, 0], [1, 0], [0, 1], [0, 1]];

    let estimator = MLTSVM::new();
    let fitted = estimator.fit(&features.view(), &targets).unwrap();

    let margins = fitted.decision_function(&features.view()).unwrap();
    assert_eq!(margins.dim(), (4, 2));

    // No NaN or infinite decision values are allowed.
    assert!(margins.iter().all(|value| value.is_finite()));
}
1684
/// Fits `MLTSVM` on two well-separated groups and requires that at least half
/// of the individual (sample, label) predictions match the training targets.
#[test]
fn test_mltsvm_separable_data() {
    // Two points near the origin and two points near (3, 3).
    let features = array![[0.0, 0.0], [0.5, 0.5], [3.0, 3.0], [3.5, 3.5]];
    // The first group carries label 1, the second group label 0.
    let targets = array![[0, 1], [0, 1], [1, 0], [1, 0]];

    let estimator = MLTSVM::new().c1(1.0).c2(1.0);
    let fitted = estimator.fit(&features.view(), &targets).unwrap();
    let predicted = fitted.predict(&features.view()).unwrap();

    // Count matching cells, indexing both arrays by the prediction's shape.
    let hits = (0..predicted.nrows())
        .flat_map(|row| (0..predicted.ncols()).map(move |col| (row, col)))
        .filter(|&(row, col)| predicted[[row, col]] == targets[[row, col]])
        .count();
    let cells = predicted.len();

    let accuracy = hits as Float / cells as Float;
    assert!(accuracy >= 0.5);
}
1720
/// Fits `MLTSVM` on features whose two columns differ by six orders of
/// magnitude and checks that predictions are still 4x2 and binary.
#[test]
fn test_mltsvm_feature_scaling() {
    // First column in the thousands, second column in the thousandths.
    let features = array![
        [1000.0, 0.001],
        [2000.0, 0.002],
        [3000.0, 0.003],
        [4000.0, 0.004]
    ];
    let targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let estimator = MLTSVM::new();
    let fitted = estimator.fit(&features.view(), &targets).unwrap();
    let predicted = fitted.predict(&features.view()).unwrap();

    assert_eq!(predicted.dim(), (4, 2));

    assert!(predicted.iter().all(|&label| label == 0 || label == 1));
}
1744
/// Two identically configured `MLTSVM` models trained on the same data must
/// produce identical predictions.
#[test]
fn test_mltsvm_consistency() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
    let targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let first_model = MLTSVM::new().c1(1.0).c2(1.0);
    let first_fit = first_model.fit(&features.view(), &targets).unwrap();

    let second_model = MLTSVM::new().c1(1.0).c2(1.0);
    let second_fit = second_model.fit(&features.view(), &targets).unwrap();

    let first_out = first_fit.predict(&features.view()).unwrap();
    let second_out = second_fit.predict(&features.view()).unwrap();

    // Element-wise comparison, driven by the first output's shape.
    for row in 0..first_out.nrows() {
        for col in 0..first_out.ncols() {
            let cell = [row, col];
            assert_eq!(first_out[cell], second_out[cell]);
        }
    }
}
1767
/// Fits `RankSVM` with `c = 1.0` and checks the prediction shape, the
/// reported label count, and that every prediction is 0 or 1.
#[test]
fn test_ranksvm_basic_functionality() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
    let targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let estimator = RankSVM::new().c(1.0);
    let fitted = estimator.fit(&features.view(), &targets).unwrap();
    let predicted = fitted.predict(&features.view()).unwrap();

    assert_eq!(predicted.dim(), (4, 2));
    assert_eq!(fitted.n_labels(), 2);

    assert!(predicted.iter().all(|&label| label == 0 || label == 1));
}
1785
/// Fits `RankSVM` under each threshold strategy used here (`Fixed`,
/// `OptimizeF1`, `TopK`) and checks that every trained model yields 4x2
/// predictions and exposes one threshold per label.
#[test]
fn test_ranksvm_threshold_strategies() {
    let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
    let y = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let ranksvm1 = RankSVM::new().threshold_strategy(SVMThresholdStrategy::Fixed(0.5));
    let ranksvm2 = RankSVM::new().threshold_strategy(SVMThresholdStrategy::OptimizeF1);
    let ranksvm3 = RankSVM::new().threshold_strategy(SVMThresholdStrategy::TopK(2));
    // NOTE(review): this repeats the `OptimizeF1` configuration of `ranksvm2`.
    // It is kept so the fourth fit/predict path stays covered; swap in another
    // strategy here if the enum offers one that is not yet exercised.
    let ranksvm4 = RankSVM::new().threshold_strategy(SVMThresholdStrategy::OptimizeF1);

    let trained1 = ranksvm1.fit(&X.view(), &y).unwrap();
    let trained2 = ranksvm2.fit(&X.view(), &y).unwrap();
    let trained3 = ranksvm3.fit(&X.view(), &y).unwrap();
    let trained4 = ranksvm4.fit(&X.view(), &y).unwrap();

    let pred1 = trained1.predict(&X.view()).unwrap();
    let pred2 = trained2.predict(&X.view()).unwrap();
    let pred3 = trained3.predict(&X.view()).unwrap();
    let pred4 = trained4.predict(&X.view()).unwrap();

    assert_eq!(pred1.dim(), (4, 2));
    assert_eq!(pred2.dim(), (4, 2));
    assert_eq!(pred3.dim(), (4, 2));
    assert_eq!(pred4.dim(), (4, 2));

    // One threshold per label for every trained model, including the fourth.
    assert_eq!(trained1.thresholds().len(), 2);
    assert_eq!(trained2.thresholds().len(), 2);
    assert_eq!(trained3.thresholds().len(), 2);
    assert_eq!(trained4.thresholds().len(), 2);
}
1817
/// Checks that `RankSVM::decision_function` returns one finite score per
/// (sample, label) pair.
#[test]
fn test_ranksvm_decision_function() {
    let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]];
    let targets = array![[1, 0], [1, 0], [0, 1], [0, 1]];

    let estimator = RankSVM::new();
    let fitted = estimator.fit(&features.view(), &targets).unwrap();

    let scores = fitted.decision_function(&features.view()).unwrap();
    assert_eq!(scores.dim(), (4, 2));

    // No NaN or infinite scores are allowed.
    assert!(scores.iter().all(|score| score.is_finite()));
}
1835
/// Checks that `RankSVM::predict_ranking` returns, for every sample, a
/// permutation of the label indices 0, 1 and 2.
#[test]
fn test_ranksvm_predict_ranking() {
    let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
    // One distinct label per sample.
    let targets = array![[1, 0, 0], [0, 1, 0], [0, 0, 1]];

    let estimator = RankSVM::new();
    let fitted = estimator.fit(&features.view(), &targets).unwrap();

    let rankings = fitted.predict_ranking(&features.view()).unwrap();
    assert_eq!(rankings.dim(), (3, 3));

    for sample in 0..3 {
        // Gather the row, then sort it: a valid ranking sorts to [0, 1, 2].
        let mut ordered: Vec<_> = (0..3).map(|slot| rankings[[sample, slot]]).collect();
        ordered.sort();
        assert_eq!(ordered, vec![0, 1, 2]);
    }
}
1859
/// Checks that `RankSVM` reports errors for mismatched sample counts,
/// non-binary labels, a wrong feature count in every inference method, and
/// empty training data.
#[test]
fn test_ranksvm_error_handling() {
    let features = array![[1.0, 2.0], [2.0, 3.0]];

    // Two feature rows against three label rows.
    let too_many_rows = array![[1, 0], [0, 1], [1, 1]];
    let default_model = RankSVM::new();
    assert!(default_model
        .fit(&features.view(), &too_many_rows)
        .is_err());

    // Labels containing a 2 are rejected.
    let non_binary = array![[1, 2], [0, 1]];
    assert!(RankSVM::new().fit(&features.view(), &non_binary).is_err());

    // A model trained on two features must refuse three-feature input in
    // `predict`, `decision_function` and `predict_ranking` alike.
    let train_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 2.0]];
    let train_targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];
    let estimator = RankSVM::new();
    let fitted = estimator
        .fit(&train_features.view(), &train_targets)
        .unwrap();

    let three_features = array![[1.0, 2.0, 3.0]];
    assert!(fitted.predict(&three_features.view()).is_err());
    assert!(fitted.decision_function(&three_features.view()).is_err());
    assert!(fitted.predict_ranking(&three_features.view()).is_err());

    // Zero-sample training data is rejected.
    let no_features = Array2::<Float>::zeros((0, 2));
    let no_targets = Array2::<i32>::zeros((0, 2));
    assert!(RankSVM::new()
        .fit(&no_features.view(), &no_targets)
        .is_err());
}
1888
/// Checks that `RankSVM` accepts custom `c`/`epsilon` values and a custom
/// `max_iter`, producing 4x2 predictions in both cases.
#[test]
fn test_ranksvm_configuration() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
    let targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let custom_penalty = RankSVM::new().c(0.5).epsilon(1e-8);
    let custom_iterations = RankSVM::new().max_iter(500);

    let fit_penalty = custom_penalty.fit(&features.view(), &targets).unwrap();
    let fit_iterations = custom_iterations.fit(&features.view(), &targets).unwrap();

    let out_penalty = fit_penalty.predict(&features.view()).unwrap();
    let out_iterations = fit_iterations.predict(&features.view()).unwrap();

    assert_eq!(out_penalty.dim(), (4, 2));
    assert_eq!(out_iterations.dim(), (4, 2));
}
1907
/// Checks that the rankings from `predict_ranking` agree with the scores from
/// `decision_function`: the first-ranked label of each sample must score at
/// least as high as every other ranked label.
#[test]
fn test_ranksvm_ranking_consistency() {
    let features = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];
    // One distinct label per sample, in reverse order.
    let targets = array![[0, 0, 1], [0, 1, 0], [1, 0, 0]];

    let estimator = RankSVM::new().threshold_strategy(SVMThresholdStrategy::OptimizeF1);
    let fitted = estimator.fit(&features.view(), &targets).unwrap();

    let rankings = fitted.predict_ranking(&features.view()).unwrap();
    let scores = fitted.decision_function(&features.view()).unwrap();

    for sample in 0..3 {
        // Column 0 of the ranking holds the label index placed first.
        let best_label = rankings[[sample, 0]];
        let best_score = scores[[sample, best_label]];
        for position in 1..3 {
            let rival_label = rankings[[sample, position]];
            assert!(best_score >= scores[[sample, rival_label]]);
        }
    }
}
1933
/// Fits `RankSVM` on targets whose first label is active for every sample and
/// checks that training succeeds and all scores are finite.
#[test]
fn test_ranksvm_single_class_handling() {
    let features = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];

    // Column 0 is constantly 1; only column 1 varies.
    let targets = array![[1, 0], [1, 1], [1, 0]];

    let estimator = RankSVM::new();
    let fitted = estimator.fit(&features.view(), &targets).unwrap();
    let predicted = fitted.predict(&features.view()).unwrap();
    let scores = fitted.decision_function(&features.view()).unwrap();

    assert_eq!(predicted.dim(), (3, 2));
    assert_eq!(scores.dim(), (3, 2));

    assert!(scores.iter().all(|score| score.is_finite()));
}
1954
/// Two identically configured `RankSVM` models trained on the same data must
/// produce identical predictions and decision scores equal within 1e-10.
#[test]
fn test_ranksvm_reproducibility() {
    let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
    let targets = array![[1, 0], [0, 1], [1, 1], [0, 0]];

    let first_model = RankSVM::new().c(1.0).epsilon(1e-6);
    let first_fit = first_model.fit(&features.view(), &targets).unwrap();

    let second_model = RankSVM::new().c(1.0).epsilon(1e-6);
    let second_fit = second_model.fit(&features.view(), &targets).unwrap();

    let first_out = first_fit.predict(&features.view()).unwrap();
    let second_out = second_fit.predict(&features.view()).unwrap();
    let first_scores = first_fit.decision_function(&features.view()).unwrap();
    let second_scores = second_fit.decision_function(&features.view()).unwrap();

    // Element-wise comparison, driven by the first prediction's shape.
    for row in 0..first_out.nrows() {
        for col in 0..first_out.ncols() {
            let cell = [row, col];
            assert_eq!(first_out[cell], second_out[cell]);
            let score_gap = first_scores[cell] - second_scores[cell];
            assert!(score_gap.abs() < 1e-10);
        }
    }
}
1980}