1#![allow(dead_code)]
2#![allow(non_snake_case)]
3#![allow(missing_docs)]
4#![allow(deprecated)]
5#![allow(clippy::all)]
6#![allow(clippy::pedantic)]
7#![allow(clippy::nursery)]
8#![allow(unused_imports)]
9#![allow(unused_variables)]
10#![allow(unused_mut)]
11#![allow(unused_assignments)]
12#![allow(unused_doc_comments)]
13#![allow(unused_parens)]
14#![allow(unused_comparisons)]
15#![allow(clippy::type_complexity)]
16#![allow(clippy::assign_op_pattern)]
17#![allow(clippy::upper_case_acronyms)]
18#![allow(clippy::new_without_default)]
19#![allow(clippy::ptr_arg)]
20#![allow(clippy::useless_asref)]
21#![allow(clippy::needless_range_loop)]
22#![allow(clippy::empty_line_after_doc_comments)]
23#![allow(clippy::let_and_return)]
24#![allow(clippy::needless_borrow)]
25#![allow(clippy::cast_lossless)]
26#![allow(clippy::unnecessary_cast)]
27#![allow(clippy::len_zero)]
28#![allow(clippy::useless_vec)]
29#![allow(clippy::derivable_impls)]
30#![allow(clippy::only_used_in_recursion)]
31#![allow(clippy::op_ref)]
32#![allow(clippy::manual_clamp)]
33#![allow(clippy::manual_ok_err)]
34#![allow(clippy::cast_abs_to_unsigned)]
35#![allow(clippy::unnecessary_sort_by)]
36#![allow(clippy::implicit_saturating_sub)]
37#![allow(clippy::manual_abs_diff)]
38mod adaptive_lasso;
46mod alternating_least_squares;
47mod alternating_projections;
48mod bayesian_covariance;
49mod bigquic;
50mod chen_stein;
51mod clime;
52mod composable_regularization;
53mod coordinate_descent;
54mod differential_privacy;
55mod elastic_net;
56mod elliptic_envelope;
57mod em_missing_data;
58mod empirical;
59mod factor_model;
60mod financial_applications;
61mod fluent_api;
62mod frank_wolfe;
63mod genomics_bioinformatics;
64mod graphical_lasso;
65mod group_lasso;
66mod huber;
67mod hyperparameter_tuning;
68mod ica_covariance;
69mod information_theory;
70mod iterative_proportional_fitting;
71mod ledoit_wolf;
72mod low_rank_sparse;
73mod meta_learning;
74mod min_cov_det;
75mod model_selection;
76mod neighborhood_selection;
77mod nmf_covariance;
78mod nonlinear_shrinkage;
79mod nonparametric_covariance;
80mod nuclear_norm;
81mod oas;
82mod optimization;
83mod pca_integration;
84mod performance_optimizations;
85mod polars_integration;
86mod presets;
87mod rao_blackwell_lw;
88mod ridge;
89mod robust_pca;
90mod rotation_equivariant;
91mod rust_improvements;
92mod shrunk;
93mod signal_processing;
94mod space;
95mod sparse_factor_models;
96mod testing_quality;
97mod tiger;
98mod time_varying_covariance;
99mod utils;
100
101mod diagnostics;
103
104mod adversarial_robustness;
106mod federated_learning;
107mod plugin_architecture;
108mod quantum_methods;
109
110pub use utils::{
116 adaptive_shrinkage, frobenius_norm, is_diagonally_dominant, matrix_determinant, matrix_inverse,
117 nuclear_norm_approximation, rank_estimate, spectral_radius_estimate,
118 validate_covariance_matrix, BenchmarkResult, CovarianceBenchmark, CovarianceCV,
119 CovarianceProperties, ScoringMethod,
120};
121
122pub use diagnostics::{
124 compare_covariance_matrices, CorrelationDiagnostics, CovarianceDiagnostics, DiagonalStats,
125 OffDiagonalStats, QualityAssessment,
126};
127
128pub use polars_integration::utils as polars_utils;
130pub use polars_integration::{
131 ColumnStatistics, ConvergenceInfo, CovarianceDataFrame, CovarianceResult, DataFrameDescription,
132 DataFrameEstimator, DataFrameMetadata, EstimatorInfo, PerformanceMetrics,
133};
134
135pub use hyperparameter_tuning::presets as tuning_presets;
137pub use hyperparameter_tuning::{
138 AcquisitionFunction, CVResult, CovarianceEstimatorFitted, CovarianceEstimatorTunable,
139 CovarianceHyperparameterTuner, CrossValidationConfig, EarlyStoppingConfig, ExplorationMetrics,
140 OptimizationHistory, ParameterSpec, ParameterType, ParameterValue, ScoringMetric,
141 SearchStrategy, TuningConfig, TuningResult,
142};
143
144pub use model_selection::presets as model_selection_presets;
146pub use model_selection::{
147 AutoCovarianceSelector, BestEstimator, CandidateResult, ComputationalComplexity,
148 ComputationalConstraints, CorrelationStructure, DataCharacteristics, DataCharacterizationRules,
149 DistributionCharacteristics, HeuristicRule, InformationCriterion, MissingDataInfo,
150 ModelSelectionCV, ModelSelectionResult, ModelSelectionScoring, PerformanceComparison,
151 RuleCondition, SelectionRule, SelectionStrategy, StratificationStrategy,
152};
153
154pub use presets::{
156 CovariancePresets, DomainPresets, Financial, Genomics, PresetRecommendations, SignalProcessing,
157};
158
159pub use empirical::{EmpiricalCovariance, EmpiricalCovarianceTrained};
161
162pub use shrunk::{ShrunkCovariance, ShrunkCovarianceTrained};
164
165pub use min_cov_det::{MinCovDet, MinCovDetTrained};
167
168pub use graphical_lasso::{GraphicalLasso, GraphicalLassoTrained};
170
171pub use ledoit_wolf::{LedoitWolf, LedoitWolfTrained};
173
174pub use elliptic_envelope::{EllipticEnvelope, EllipticEnvelopeTrained};
176
177pub use oas::{OASTrained, OAS};
179
180pub use huber::{HuberCovariance, HuberCovarianceTrained};
182
183pub use ridge::{RidgeCovariance, RidgeCovarianceTrained};
185
186pub use elastic_net::{ElasticNetCovariance, ElasticNetCovarianceTrained};
188
189pub use chen_stein::{ChenSteinCovariance, ChenSteinCovarianceTrained};
191
192pub use adaptive_lasso::{AdaptiveLassoCovariance, AdaptiveLassoCovarianceTrained};
194
195pub use group_lasso::{GroupLassoCovariance, GroupLassoCovarianceTrained};
197
198pub use rao_blackwell_lw::{RaoBlackwellLedoitWolf, RaoBlackwellLedoitWolfTrained};
200
pub use nonlinear_shrinkage::NonlinearShrinkage;
/// Crate-root name for `NonlinearShrinkage` instantiated with the module's
/// `nonlinear_shrinkage::NonlinearShrinkageTrained` type parameter.
///
/// Most estimators in this crate re-export a dedicated `*Trained` type directly.
/// This one builds its `*Trained` name by applying the (module-private)
/// state type to the generic `NonlinearShrinkage`, so callers get a matching
/// `NonlinearShrinkageTrained` name at the crate root.
///
/// NOTE(review): presumably this is the type produced by fitting
/// `NonlinearShrinkage` — confirm against the `nonlinear_shrinkage` module.
pub type NonlinearShrinkageTrained =
    NonlinearShrinkage<nonlinear_shrinkage::NonlinearShrinkageTrained>;
205
pub use clime::CLIME;
/// Crate-root name for `CLIME` instantiated with the module's
/// `clime::CLIMETrained` type parameter.
///
/// This alias mirrors the `*Trained` naming used by the other estimators
/// re-exported from this crate.
///
/// NOTE(review): presumably this is the type produced by fitting `CLIME` —
/// confirm against the `clime` module.
pub type CLIMETrained = CLIME<clime::CLIMETrained>;
209
pub use nuclear_norm::NuclearNormMinimization;
/// Crate-root name for `NuclearNormMinimization` instantiated with the
/// module's `nuclear_norm::NuclearNormMinimizationTrained` type parameter.
///
/// This alias mirrors the `*Trained` naming used by the other estimators
/// re-exported from this crate.
///
/// NOTE(review): presumably this is the type produced by fitting
/// `NuclearNormMinimization` — confirm against the `nuclear_norm` module.
pub type NuclearNormMinimizationTrained =
    NuclearNormMinimization<nuclear_norm::NuclearNormMinimizationTrained>;
214
215pub use rotation_equivariant::{RotationEquivariant, RotationEquivariantTrained};
217
218pub use neighborhood_selection::{NeighborhoodSelection, NeighborhoodSelectionTrained};
220
221pub use space::{SPACETrained, SPACE};
223
224pub use tiger::{TIGERTrained, TIGER};
226
227pub use robust_pca::{RobustPCA, RobustPCATrained};
229
230pub use bigquic::{BigQUIC, BigQUICTrained};
232
233pub use factor_model::{FactorModelCovariance, FactorModelCovarianceTrained};
235
236pub use low_rank_sparse::{LowRankSparseCovariance, LowRankSparseCovarianceTrained};
238
239pub use alternating_least_squares::{ALSCovariance, ALSCovarianceTrained};
241
242pub use pca_integration::{KernelFunction, PCACovariance, PCACovarianceTrained, PCAMethod};
244
245pub use ica_covariance::{
247 ContrastFunction, ICAAlgorithm, ICACovariance, ICACovarianceTrained, WhiteningMethod,
248};
249
250pub use em_missing_data::{
252 EMCovarianceMissingData, EMCovarianceMissingDataTrained, MissingDataMethod,
253};
254
255pub use iterative_proportional_fitting::{
257 ConstraintType, IPFCovariance, IPFCovarianceTrained, MarginalConstraint, RankMethod,
258};
259
260pub use coordinate_descent::{
262 CoordinateDescentCovariance, CoordinateDescentCovarianceTrained, OptimizationTarget,
263 RegularizationMethod,
264};
265
266pub use nmf_covariance::{
268 NMFAlgorithm, NMFCovariance, NMFCovarianceTrained, NMFInitialization, UpdateRule,
269};
270
271pub use sparse_factor_models::{
273 SparseFactorModel, SparseFactorModelTrained, SparseInitialization, SparseRegularization,
274};
275
276pub use alternating_projections::{
278 APAlgorithm, AlternatingProjections, AlternatingProjectionsTrained, ConvergenceCriterion,
279 ProjectionConstraint,
280};
281
282pub use frank_wolfe::{
284 FrankWolfeAlgorithm, FrankWolfeConstraint, FrankWolfeCovariance, FrankWolfeCovarianceTrained,
285 LineSearchMethod as FrankWolfeLineSearchMethod,
286 ObjectiveFunction as FrankWolfeObjectiveFunction,
287};
288
289pub use bayesian_covariance::{
291 BayesianCovariance, BayesianCovarianceFitted, BayesianMethod, BayesianPrior, McmcConfig,
292 VariationalConfig, VariationalParameters,
293};
294
295pub use time_varying_covariance::{
297 DccConfig, ExponentialWeightedConfig, GarchConfig, GarchType, RegimeSwitchingConfig,
298 RollingWindowConfig, TimeVaryingCovariance, TimeVaryingCovarianceFitted, TimeVaryingMethod,
299};
300
301pub use nonparametric_covariance::{
303 CopulaConfig, CopulaType, DistributionFreeConfig, KdeConfig, KernelType,
304 NonparametricCovariance, NonparametricCovarianceFitted, NonparametricMethod, RankBasedConfig,
305 RankCorrelationType, RobustCorrelationConfig, RobustCorrelationType,
306};
307
308pub use financial_applications::{
310 FactorModelMethod, OptimizationMethod, PortfolioOptimizer, RiskDecomposition, RiskFactorModel,
311 RiskFactorModelTrained, RiskFactorModelUntrained, StressScenario, StressTestResult,
312 StressTesting, VolatilityModel, VolatilityModelType,
313};
314
315pub use performance_optimizations::{
317 ComputationStats, DistributedCovariance, DistributionStrategy, MemoryEfficientCovariance,
318 MemoryEstimate, ParallelCovariance, ParallelCovarianceTrained, ParallelCovarianceUntrained,
319 SIMDCovariance, StreamingCovariance, StreamingMethod,
320};
321
322pub use testing_quality::{
324 AccuracyTestResult, BenchmarkConfig, BenchmarkResult as TestingBenchmarkResult, BenchmarkSuite,
325 ComparisonResult, DifficultyLevel, GroundTruthTestCase, NumericalAccuracyTester,
326 PropertyFailure, PropertyTestResult, PropertyTester, SingleAccuracyResult, SingleComparison,
327};
328
329pub use rust_improvements::{
331 CovarianceAccumulator, CovarianceError, CovarianceEstimator, CovarianceIterator,
332 CovarianceMetadata, Diagonal, General, GenericEmpiricalCovariance, MatrixStructure,
333 NumericallyStableCovariance, PivotingStrategy, PositiveDefinite, RobustCovarianceEstimator,
334 SharedCovariance, SparseCovarianceEstimator, Symmetric, ThreadSafeCovarianceView, TypedMatrix,
335 ZeroCostCovariance,
336};
337
338pub use genomics_bioinformatics::{
340 BranchLengthMethod, ClusteringMethod, ComplexDetectionMethod, CorrectionMethod,
341 EnrichmentMethod, EvolutionaryModel, GeneExpressionNetwork, GeneExpressionNetworkTrained,
342 GeneExpressionNetworkUntrained, IntegrationMethod, ModelFitStatistics, MultiOmicsCovariance,
343 MultiOmicsCovarianceTrained, MultiOmicsCovarianceUntrained, NetworkStatistics,
344 NormalizationMethod, PathwayAnalysis, PathwayAnalysisTrained, PathwayAnalysisUntrained,
345 PhylogeneticCovariance, PhylogeneticCovarianceTrained, PhylogeneticCovarianceUntrained,
346 ProteinInteractionNetwork, ProteinInteractionNetworkTrained,
347 ProteinInteractionNetworkUntrained, RateVariationModel, TopologyMetrics,
348};
349
350pub use signal_processing::{
352 AdaptiveAlgorithm, AdaptiveFilteringCovariance, AdaptiveFilteringCovarianceTrained,
353 AdaptiveFilteringCovarianceUntrained, ArrayGeometry, ArraySignalProcessing,
354 ArraySignalProcessingTrained, ArraySignalProcessingUntrained, BeamformingAlgorithm,
355 BeamformingCovariance, BeamformingCovarianceTrained, BeamformingCovarianceUntrained,
356 ClutterStatistics, ClutterSuppression, ConvergenceAnalysis, ConvergenceParams,
357 CorrelationHandling, DOAMethod, DetectionMethod, DopplerProcessing, FilterType,
358 NoiseCharacteristics, NoiseType, PerformanceMetric, RadarSonarCovariance,
359 RadarSonarCovarianceTrained, RadarSonarCovarianceUntrained, RangeProcessing,
360 SpatialCovarianceEstimator, SpatialCovarianceEstimatorTrained,
361 SpatialCovarianceEstimatorUntrained, SpatialEstimationMethod, SpatialSmoothing, SystemType,
362};
363
364pub use composable_regularization::{
366 CombinationMethod, CompositeRegularization, GroupLassoRegularization, L1Regularization,
367 L2Regularization, NuclearNormRegularization, RegularizationFactory, RegularizationStrategy,
368};
369
370pub use fluent_api::{
372 ConditioningMethod, ConditioningStep, CovariancePipeline,
373 CrossValidationConfig as FluentCrossValidationConfig, EstimatorConfig, EstimatorType, Fitted,
374 OutlierMethod, OutlierRemovalStep, PipelineMetadata, PostprocessingStep, PreprocessingStep,
375 ScoringMetric as FluentScoringMetric, StandardizationStep, StepResult, StepType, Unfit,
376};
377
378pub use differential_privacy::{
380 BudgetAllocation, CompositionMethod, DifferentialPrivacyCovariance,
381 DifferentialPrivacyCovarianceTrained, DifferentialPrivacyCovarianceUntrained, NoiseCalibration,
382 PrivacyAccountant, PrivacyMechanism, PrivacyOperation, UtilityMetrics,
383};
384
385pub use information_theory::{
387 DivergenceMeasure, EntropyEstimator, InformationMethod, InformationMetrics,
388 InformationRegularization, InformationTheoryCovariance, InformationTheoryCovarianceTrained,
389 InformationTheoryCovarianceUntrained,
390};
391
392pub use meta_learning::{
394 CovarianceMethod, MetaFeatures, MetaLearningCovariance, MetaLearningCovarianceTrained,
395 MetaLearningCovarianceUntrained, MetaLearningStrategy,
396 OptimizationMethod as MetaOptimizationMethod, PerformanceHistory,
397 PerformanceMetrics as MetaPerformanceMetrics,
398};
399
400pub use optimization::{
402 AdamOptimizer, CoordinateDescentOptimizer, LineSearchMethod, NelderMeadOptimizer,
403 ObjectiveFunction, OptimizationAlgorithm, OptimizationConfig, OptimizationConfigBuilder,
404 OptimizationHistory as FrameworkOptimizationHistory, OptimizationResult, OptimizerRegistry,
405 OptimizerType, ProximalGradientOptimizer, SGDOptimizer,
406};
407
408pub use quantum_methods::{
410 AlgorithmComplexity, QuantumAdvantageAnalysis, QuantumAlgorithmType, QuantumInspiredCovariance,
411 QuantumInspiredCovarianceTrained, QuantumInspiredCovarianceUntrained,
412};
413
414pub use adversarial_robustness::{
416 AdversarialRobustCovariance, AdversarialRobustCovarianceTrained,
417 AdversarialRobustCovarianceUntrained, RobustnessDiagnostics, RobustnessMethod,
418};
419
420pub use federated_learning::{
422 create_federated_parties, split_data_for_federation, AggregationMethod, CommunicationCost,
423 FederatedCovariance, FederatedCovarianceTrained, FederatedCovarianceUntrained, FederatedParty,
424 PrivacyMechanism as FederatedPrivacyMechanism,
425};
426
427pub use plugin_architecture::{
429 CovariancePluginRegistry, CustomCovarianceEstimator, EstimatorFactory, EstimatorMetadata,
430 EstimatorState, Hook, HookContext, HookType, Middleware, ParameterSpec as PluginParameterSpec,
431 RegularizationFunction, GLOBAL_REGISTRY,
432};
433
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    //! Smoke tests for the estimators re-exported from the crate root.
    //!
    //! Each test fits one estimator on a tiny hand-written dataset. It then checks
    //! the shapes and hyper-parameters reported by the fitted model.
    //!
    //! Several fixtures are degenerate on purpose: some columns are exact
    //! multiples or shifts of each other, and one fixture has rank one. For that
    //! reason, the tests for iterative and robust estimators tolerate a fit error.
    //! A tolerated error is reported on stderr rather than silently discarded.
    //! An estimator that starts failing on every input therefore stays visible
    //! with `cargo test -- --nocapture`.

    use super::*;
    use scirs2_core::ndarray::{array, Array2};
    use sklears_core::traits::Fit;

    /// Report a fit error that the calling test tolerates because its input is
    /// tiny or degenerate, instead of dropping it on the floor.
    fn report_tolerated_fit_error(estimator: &str, err: impl std::fmt::Debug) {
        eprintln!("{estimator}: fit failed on degenerate test data (tolerated): {err:?}");
    }

    /// 4 samples x 2 features; the second column is the first plus one
    /// (perfectly collinear).
    fn shifted_pair_4x2() -> Array2<f64> {
        array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [2.0, 3.0]]
    }

    /// 3 samples x 2 features; the second column is the first plus one
    /// (perfectly collinear).
    fn shifted_pair_3x2() -> Array2<f64> {
        array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    }

    /// 8 samples x 2 strongly (but not perfectly) correlated features.
    fn correlated_8x2() -> Array2<f64> {
        array![
            [1.0, 0.5],
            [2.0, 1.5],
            [3.0, 2.8],
            [4.0, 3.9],
            [5.0, 4.1],
            [1.5, 0.8],
            [2.5, 1.9],
            [3.5, 3.1]
        ]
    }

    /// 6 samples x 3 features; the third column is one tenth of the first.
    fn three_features_6x3() -> Array2<f64> {
        array![
            [1.0, 0.5, 0.1],
            [2.0, 1.5, 0.2],
            [3.0, 2.8, 0.3],
            [4.0, 3.9, 0.4],
            [5.0, 4.1, 0.5],
            [1.5, 0.8, 0.15]
        ]
    }

    /// 8 samples x 3 features; the third column is one tenth of the first.
    fn three_features_8x3() -> Array2<f64> {
        array![
            [1.0, 0.5, 0.1],
            [2.0, 1.5, 0.2],
            [3.0, 2.8, 0.3],
            [4.0, 3.9, 0.4],
            [5.0, 4.1, 0.5],
            [1.5, 0.8, 0.15],
            [2.5, 1.9, 0.25],
            [3.5, 3.1, 0.35]
        ]
    }

    /// 8 samples x 4 features; every row is a multiple of `[1.0, 0.8, 0.6, 0.4]`,
    /// so the sample covariance has rank one.
    fn rank_one_8x4() -> Array2<f64> {
        array![
            [1.0, 0.8, 0.6, 0.4],
            [2.0, 1.6, 1.2, 0.8],
            [3.0, 2.4, 1.8, 1.2],
            [4.0, 3.2, 2.4, 1.6],
            [5.0, 4.0, 3.0, 2.0],
            [1.5, 1.2, 0.9, 0.6],
            [2.5, 2.0, 1.5, 1.0],
            [3.5, 2.8, 2.1, 1.4]
        ]
    }

    #[test]
    fn test_empirical_covariance_basic() {
        let x = array![[1.0, 0.1], [2.0, 1.9], [3.0, 2.8], [4.0, 4.1], [5.0, 4.9]];

        let estimator = EmpiricalCovariance::new();
        let fitted = estimator
            .fit(&x.view(), &())
            .expect("EmpiricalCovariance should fit the fixture");

        assert_eq!(fitted.get_covariance().dim(), (2, 2));
        assert!(fitted.get_precision().is_some());
    }

    #[test]
    fn test_ledoit_wolf_basic() {
        let x = shifted_pair_4x2();

        let estimator = LedoitWolf::new();
        let fitted = estimator
            .fit(&x.view(), &())
            .expect("LedoitWolf should fit the fixture");

        assert_eq!(fitted.get_covariance().dim(), (2, 2));
        assert!(fitted.get_precision().is_some());
        assert!((0.0..=1.0).contains(&fitted.get_shrinkage()));
    }

    #[test]
    fn test_shrunk_covariance_basic() {
        let x = shifted_pair_3x2();

        let estimator = ShrunkCovariance::new().shrinkage(0.1);
        let fitted = estimator
            .fit(&x.view(), &())
            .expect("ShrunkCovariance should fit the fixture");

        assert_eq!(fitted.get_covariance().dim(), (2, 2));
        assert_eq!(fitted.get_shrinkage(), 0.1);
    }

    #[test]
    fn test_min_cov_det_basic() {
        // Every point lies on the line y = x, so the covariance is singular and
        // the estimator may legitimately reject this input.
        let x = array![
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [4.0, 4.0],
            [5.0, 5.0],
            [6.0, 6.0]
        ];

        let estimator = MinCovDet::new().support_fraction(0.8);
        match estimator.fit(&x.view(), &()) {
            Ok(fitted) => {
                assert_eq!(fitted.get_covariance().dim(), (2, 2));
                assert_eq!(fitted.get_support().len(), 6);
            }
            Err(err) => report_tolerated_fit_error("MinCovDet", err),
        }
    }

    #[test]
    fn test_oas_basic() {
        let x = shifted_pair_4x2();

        let estimator = OAS::new();
        let fitted = estimator
            .fit(&x.view(), &())
            .expect("OAS should fit the fixture");

        assert_eq!(fitted.get_covariance().dim(), (2, 2));
        assert!(fitted.get_precision().is_some());
        assert!((0.0..=1.0).contains(&fitted.get_shrinkage()));
    }

    #[test]
    fn test_elliptic_envelope_basic() {
        // Eight inliers plus one far outlier, all on the line y = x (singular
        // covariance), so a fit error is tolerated.
        let x = array![
            [1.0, 1.0],
            [2.0, 2.0],
            [3.0, 3.0],
            [4.0, 4.0],
            [5.0, 5.0],
            [6.0, 6.0],
            [7.0, 7.0],
            [8.0, 8.0],
            [100.0, 100.0]
        ];

        let estimator = EllipticEnvelope::new().contamination(0.25);
        match estimator.fit(&x.view(), &()) {
            Ok(fitted) => {
                assert_eq!(fitted.get_covariance().dim(), (2, 2));
                let predictions = fitted
                    .predict(&x.view())
                    .expect("a fitted EllipticEnvelope should predict on its training data");
                assert_eq!(predictions.len(), 9);
            }
            Err(err) => report_tolerated_fit_error("EllipticEnvelope", err),
        }
    }

    #[test]
    fn test_huber_covariance_basic() {
        let x = shifted_pair_3x2();

        let estimator = HuberCovariance::new().delta(1.5);
        match estimator.fit(&x.view(), &()) {
            Ok(fitted) => {
                assert_eq!(fitted.get_covariance().dim(), (2, 2));
                assert!(fitted.get_precision().is_some());
                assert_eq!(fitted.get_weights().len(), 3);
            }
            Err(err) => report_tolerated_fit_error("HuberCovariance", err),
        }
    }

    #[test]
    fn test_ridge_covariance_basic() {
        let x = shifted_pair_3x2();

        let estimator = RidgeCovariance::new().alpha(0.1);
        let fitted = estimator
            .fit(&x.view(), &())
            .expect("RidgeCovariance should fit the fixture");

        assert_eq!(fitted.get_covariance().dim(), (2, 2));
        assert!(fitted.get_precision().is_some());
        assert_eq!(fitted.get_alpha(), 0.1);
    }

    #[test]
    fn test_elastic_net_covariance_basic() {
        let x = shifted_pair_3x2();

        let estimator = ElasticNetCovariance::new().alpha(0.1).l1_ratio(0.5);
        let fitted = estimator
            .fit(&x.view(), &())
            .expect("ElasticNetCovariance should fit the fixture");

        assert_eq!(fitted.get_covariance().dim(), (2, 2));
        assert!(fitted.get_precision().is_some());
        assert_eq!(fitted.get_alpha(), 0.1);
        assert_eq!(fitted.get_l1_ratio(), 0.5);
    }

    #[test]
    fn test_chen_stein_covariance_basic() {
        let x = shifted_pair_4x2();

        let estimator = ChenSteinCovariance::new();
        let fitted = estimator
            .fit(&x.view(), &())
            .expect("ChenSteinCovariance should fit the fixture");

        assert_eq!(fitted.get_covariance().dim(), (2, 2));
        assert!(fitted.get_precision().is_some());
        assert!((0.0..=1.0).contains(&fitted.get_shrinkage()));
    }

    #[test]
    fn test_adaptive_lasso_covariance_basic() {
        let x = correlated_8x2();

        let estimator = AdaptiveLassoCovariance::new()
            .alpha(0.01)
            .gamma(0.5)
            .max_iter(20);
        match estimator.fit(&x.view(), &()) {
            Ok(fitted) => {
                assert_eq!(fitted.get_covariance().dim(), (2, 2));
                assert!(fitted.get_precision().is_some());
                assert_eq!(fitted.get_alpha(), 0.01);
                assert_eq!(fitted.get_gamma(), 0.5);
                assert!(fitted.get_n_iter() > 0);
            }
            Err(err) => report_tolerated_fit_error("AdaptiveLassoCovariance", err),
        }
    }

    #[test]
    fn test_group_lasso_covariance_basic() {
        let x = three_features_6x3();
        let groups = vec![0, 0, 1];

        let estimator = GroupLassoCovariance::new()
            .alpha(0.05)
            .groups(groups.clone())
            .max_iter(30);

        match estimator.fit(&x.view(), &()) {
            Ok(fitted) => {
                assert_eq!(fitted.get_covariance().dim(), (3, 3));
                assert!(fitted.get_precision().is_some());
                assert_eq!(fitted.get_groups(), &groups);
                assert_eq!(fitted.get_alpha(), 0.05);
                assert!(fitted.get_n_iter() > 0);
            }
            Err(err) => report_tolerated_fit_error("GroupLassoCovariance", err),
        }
    }

    #[test]
    fn test_rao_blackwell_ledoit_wolf_basic() {
        let x = correlated_8x2();

        let estimator = RaoBlackwellLedoitWolf::new();
        let fitted = estimator
            .fit(&x.view(), &())
            .expect("RaoBlackwellLedoitWolf should fit the fixture");

        assert_eq!(fitted.get_covariance().dim(), (2, 2));
        assert!(fitted.get_precision().is_some());
        assert!((0.0..=1.0).contains(&fitted.get_shrinkage()));
        assert!((0.0..=1.0).contains(&fitted.get_effective_shrinkage()));
    }

    #[test]
    fn test_nonlinear_shrinkage_basic() {
        let x = correlated_8x2();

        let estimator = NonlinearShrinkage::new();
        let fitted = estimator
            .fit(&x.view(), &())
            .expect("NonlinearShrinkage should fit the fixture");

        assert_eq!(fitted.get_covariance().dim(), (2, 2));
        assert!(fitted.get_precision().is_some());
        assert_eq!(fitted.get_eigenvalues().len(), 2);
    }

    #[test]
    fn test_clime_basic() {
        let x = three_features_6x3();

        let estimator = CLIME::new().lambda(0.1);
        let fitted = estimator
            .fit(&x.view(), &())
            .expect("CLIME should fit the fixture");

        assert_eq!(fitted.get_covariance().dim(), (3, 3));
        assert_eq!(fitted.get_precision().dim(), (3, 3));
        assert_eq!(fitted.get_lambda(), 0.1);
    }

    #[test]
    fn test_nuclear_norm_basic() {
        // Every row is a multiple of [1.0, 0.8, 0.6]: a rank-one dataset.
        let x = array![
            [1.0, 0.8, 0.6],
            [2.0, 1.6, 1.2],
            [3.0, 2.4, 1.8],
            [4.0, 3.2, 2.4],
            [5.0, 4.0, 3.0]
        ];

        let estimator = NuclearNormMinimization::new().lambda(0.1);
        let fitted = estimator
            .fit(&x.view(), &())
            .expect("NuclearNormMinimization should fit the fixture");

        assert_eq!(fitted.get_covariance().dim(), (3, 3));
        assert!(fitted.get_precision().is_some());
        assert_eq!(fitted.get_lambda(), 0.1);
        assert!(fitted.get_rank() <= 3);
    }

    #[test]
    fn test_bigquic_basic() {
        let x = three_features_8x3();

        let estimator = BigQUIC::new().lambda(0.1).max_iter(50).block_size(100);

        match estimator.fit(&x.view(), &()) {
            Ok(fitted) => {
                assert_eq!(fitted.get_precision().dim(), (3, 3));
                assert_eq!(fitted.get_covariance().dim(), (3, 3));
                assert_eq!(fitted.get_lambda(), 0.1);
                assert_eq!(fitted.get_block_size(), 100);
                assert!(fitted.get_n_iter() > 0);
                assert!(fitted.get_nnz() > 0);
                assert!((0.0..=1.0).contains(&fitted.get_sparsity_ratio()));
            }
            Err(err) => report_tolerated_fit_error("BigQUIC", err),
        }
    }

    #[test]
    fn test_factor_model_basic() {
        let x = rank_one_8x4();

        let estimator = FactorModelCovariance::new().n_factors(2).max_iter(50);

        match estimator.fit(&x.view(), &()) {
            Ok(fitted) => {
                assert_eq!(fitted.get_loadings().dim(), (4, 2));
                assert_eq!(fitted.get_specific_variances().len(), 4);
                assert_eq!(fitted.get_covariance().dim(), (4, 4));
                assert_eq!(fitted.get_n_factors(), 2);
                assert!(fitted.get_n_iter() > 0);
                assert!((0.0..=1.0).contains(&fitted.get_goodness_of_fit()));
            }
            Err(err) => report_tolerated_fit_error("FactorModelCovariance", err),
        }
    }

    #[test]
    fn test_low_rank_sparse_basic() {
        let x = rank_one_8x4();

        let estimator = LowRankSparseCovariance::new()
            .lambda_nuclear(0.1)
            .lambda_l1(0.1)
            .max_iter(50);

        match estimator.fit(&x.view(), &()) {
            Ok(fitted) => {
                assert_eq!(fitted.get_low_rank_component().dim(), (4, 4));
                assert_eq!(fitted.get_sparse_component().dim(), (4, 4));
                assert_eq!(fitted.get_covariance().dim(), (4, 4));
                assert_eq!(fitted.get_lambda_nuclear(), 0.1);
                assert_eq!(fitted.get_lambda_l1(), 0.1);
                assert!(fitted.get_n_iter() > 0);
                assert!((0.0..=1.0).contains(&fitted.get_sparsity_ratio()));
                assert!((0.0..=1.0).contains(&fitted.get_low_rank_ratio()));
            }
            Err(err) => report_tolerated_fit_error("LowRankSparseCovariance", err),
        }
    }

    #[test]
    fn test_als_covariance_basic() {
        let x = rank_one_8x4();

        let estimator = ALSCovariance::new()
            .n_factors(2)
            .max_iter(50)
            .reg_param(0.01);

        match estimator.fit(&x.view(), &()) {
            Ok(fitted) => {
                assert_eq!(fitted.get_left_factors().dim(), (4, 2));
                assert_eq!(fitted.get_right_factors().dim(), (2, 4));
                assert_eq!(fitted.get_covariance().dim(), (4, 4));
                assert_eq!(fitted.get_n_factors(), 2);
                assert!(fitted.get_n_iter() > 0);
                assert!(fitted.get_reconstruction_error() >= 0.0);
                assert!(fitted.get_explained_variance_ratio() >= 0.0);
            }
            Err(err) => report_tolerated_fit_error("ALSCovariance", err),
        }
    }
}