1#![allow(dead_code)]
10
11use crate::error::{StatsError, StatsResult};
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
13use scirs2_core::numeric::{Float, NumCast};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::marker::PhantomData;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct StandardizedConfig {
21 pub auto_optimize: bool,
23 pub parallel: bool,
25 pub simd: bool,
27 pub memory_limit: Option<usize>,
29 pub confidence_level: f64,
31 pub null_handling: NullHandling,
33 pub output_precision: usize,
35 pub include_metadata: bool,
37}
38
39impl Default for StandardizedConfig {
40 fn default() -> Self {
41 Self {
42 auto_optimize: true,
43 parallel: true,
44 simd: true,
45 memory_limit: None,
46 confidence_level: 0.95,
47 null_handling: NullHandling::Exclude,
48 output_precision: 6,
49 include_metadata: false,
50 }
51 }
52}
53
54#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
56pub enum NullHandling {
57 Exclude,
59 Propagate,
61 Replace(f64),
63 Fail,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct StandardizedResult<T> {
70 pub value: T,
72 pub metadata: ResultMetadata,
74 pub warnings: Vec<String>,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ResultMetadata {
81 pub samplesize: usize,
83 pub degrees_of_freedom: Option<usize>,
85 pub confidence_level: Option<f64>,
87 pub method: String,
89 pub computation_time_ms: f64,
91 pub memory_usage_bytes: Option<usize>,
93 pub optimized: bool,
95 pub extra: HashMap<String, String>,
97}
98
99pub struct DescriptiveStatsBuilder<F> {
101 config: StandardizedConfig,
102 ddof: Option<usize>,
103 axis: Option<usize>,
104 weights: Option<Array1<F>>,
105 phantom: PhantomData<F>,
106}
107
108pub struct CorrelationBuilder<F> {
110 config: StandardizedConfig,
111 method: CorrelationMethod,
112 min_periods: Option<usize>,
113 phantom: PhantomData<F>,
114}
115
116pub struct StatisticalTestBuilder<F> {
118 config: StandardizedConfig,
119 alternative: Alternative,
120 equal_var: bool,
121 phantom: PhantomData<F>,
122}
123
124#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
126pub enum CorrelationMethod {
127 Pearson,
128 Spearman,
129 Kendall,
130 PartialPearson,
131 PartialSpearman,
132}
133
134#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
136pub enum Alternative {
137 TwoSided,
138 Less,
139 Greater,
140}
141
142pub struct StatsAnalyzer<F> {
144 config: StandardizedConfig,
145 phantom: PhantomData<F>,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct DescriptiveStats<F> {
151 pub count: usize,
152 pub mean: F,
153 pub std: F,
154 pub min: F,
155 pub percentile_25: F,
156 pub median: F,
157 pub percentile_75: F,
158 pub max: F,
159 pub variance: F,
160 pub skewness: F,
161 pub kurtosis: F,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct CorrelationResult<F> {
167 pub correlation: F,
168 pub p_value: Option<F>,
169 pub confidence_interval: Option<(F, F)>,
170 pub method: CorrelationMethod,
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct TestResult<F> {
176 pub statistic: F,
177 pub p_value: F,
178 pub confidence_interval: Option<(F, F)>,
179 pub effectsize: Option<F>,
180 pub power: Option<F>,
181}
182
183impl<F> DescriptiveStatsBuilder<F>
184where
185 F: Float
186 + NumCast
187 + Clone
188 + scirs2_core::simd_ops::SimdUnifiedOps
189 + std::iter::Sum<F>
190 + std::ops::Div<Output = F>
191 + Sync
192 + Send
193 + std::fmt::Display
194 + std::fmt::Debug
195 + 'static,
196{
197 pub fn new() -> Self {
199 Self {
200 config: StandardizedConfig::default(),
201 ddof: None,
202 axis: None,
203 weights: None,
204 phantom: PhantomData,
205 }
206 }
207
208 pub fn ddof(mut self, ddof: usize) -> Self {
210 self.ddof = Some(ddof);
211 self
212 }
213
214 pub fn axis(mut self, axis: usize) -> Self {
216 self.axis = Some(axis);
217 self
218 }
219
220 pub fn weights(mut self, weights: Array1<F>) -> Self {
222 self.weights = Some(weights);
223 self
224 }
225
226 pub fn parallel(mut self, enable: bool) -> Self {
228 self.config.parallel = enable;
229 self
230 }
231
232 pub fn simd(mut self, enable: bool) -> Self {
234 self.config.simd = enable;
235 self
236 }
237
238 pub fn null_handling(mut self, strategy: NullHandling) -> Self {
240 self.config.null_handling = strategy;
241 self
242 }
243
244 pub fn memory_limit(mut self, limit: usize) -> Self {
246 self.config.memory_limit = Some(limit);
247 self
248 }
249
250 pub fn with_metadata(mut self) -> Self {
252 self.config.include_metadata = true;
253 self
254 }
255
256 pub fn compute(
258 &self,
259 data: ArrayView1<F>,
260 ) -> StatsResult<StandardizedResult<DescriptiveStats<F>>> {
261 let start_time = std::time::Instant::now();
262 let mut warnings = Vec::new();
263
264 if data.is_empty() {
266 return Err(StatsError::InvalidArgument(
267 "Cannot compute statistics for empty array".to_string(),
268 ));
269 }
270
271 let (cleaneddata, samplesize) = self.handle_null_values(&data, &mut warnings)?;
273
274 let stats = if self.config.auto_optimize {
276 self.compute_optimized(&cleaneddata, &mut warnings)?
277 } else {
278 self.compute_standard(&cleaneddata, &mut warnings)?
279 };
280
281 let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
282
283 let metadata = ResultMetadata {
285 samplesize,
286 degrees_of_freedom: Some(samplesize.saturating_sub(self.ddof.unwrap_or(1))),
287 confidence_level: None,
288 method: self.select_method_name(),
289 computation_time_ms: computation_time,
290 memory_usage_bytes: self.estimate_memory_usage(samplesize),
291 optimized: self.config.simd || self.config.parallel,
292 extra: HashMap::new(),
293 };
294
295 Ok(StandardizedResult {
296 value: stats,
297 metadata,
298 warnings,
299 })
300 }
301
302 fn handle_null_values(
304 &self,
305 data: &ArrayView1<F>,
306 warnings: &mut Vec<String>,
307 ) -> StatsResult<(Array1<F>, usize)> {
308 let finitedata: Vec<F> = data.iter().filter(|&&x| x.is_finite()).cloned().collect();
311
312 if finitedata.len() != data.len() {
313 warnings.push(format!(
314 "Removed {} non-finite values",
315 data.len() - finitedata.len()
316 ));
317 }
318
319 let finite_count = finitedata.len();
320 match self.config.null_handling {
321 NullHandling::Exclude => Ok((Array1::from_vec(finitedata), finite_count)),
322 NullHandling::Fail if finite_count != data.len() => Err(StatsError::InvalidArgument(
323 "Null values encountered with Fail strategy".to_string(),
324 )),
325 _ => Ok((Array1::from_vec(finitedata), finite_count)),
326 }
327 }
328
329 fn compute_optimized(
331 &self,
332 data: &Array1<F>,
333 warnings: &mut Vec<String>,
334 ) -> StatsResult<DescriptiveStats<F>> {
335 let n = data.len();
336
337 if self.config.simd && n > 64 {
339 self.compute_simd_optimized(data, warnings)
340 } else if self.config.parallel && n > 10000 {
341 self.compute_parallel_optimized(data, warnings)
342 } else {
343 self.compute_standard(data, warnings)
344 }
345 }
346
347 fn compute_simd_optimized(
349 &self,
350 data: &Array1<F>,
351 _warnings: &mut Vec<String>,
352 ) -> StatsResult<DescriptiveStats<F>> {
353 let mean = crate::descriptive_simd::mean_simd(&data.view())?;
355 let variance =
356 crate::descriptive_simd::variance_simd(&data.view(), self.ddof.unwrap_or(1))?;
357 let std = variance.sqrt();
358
359 let (min, max) = self.compute_min_max(data);
361 let sorteddata = self.getsorteddata(data);
362 let percentiles = self.compute_percentiles(&sorteddata)?;
363
364 let skewness = crate::descriptive::skew(&data.view(), false, None)?;
366 let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
367
368 Ok(DescriptiveStats {
369 count: data.len(),
370 mean,
371 std,
372 min,
373 percentile_25: percentiles[0],
374 median: percentiles[1],
375 percentile_75: percentiles[2],
376 max,
377 variance,
378 skewness,
379 kurtosis,
380 })
381 }
382
383 fn compute_parallel_optimized(
385 &self,
386 data: &Array1<F>,
387 _warnings: &mut Vec<String>,
388 ) -> StatsResult<DescriptiveStats<F>> {
389 let mean = crate::parallel_stats::mean_parallel(&data.view())?;
391 let variance =
392 crate::parallel_stats::variance_parallel(&data.view(), self.ddof.unwrap_or(1))?;
393 let std = variance.sqrt();
394
395 let (min, max) = self.compute_min_max(data);
397 let sorteddata = self.getsorteddata(data);
398 let percentiles = self.compute_percentiles(&sorteddata)?;
399
400 let skewness = crate::descriptive::skew(&data.view(), false, None)?;
402 let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
403
404 Ok(DescriptiveStats {
405 count: data.len(),
406 mean,
407 std,
408 min,
409 percentile_25: percentiles[0],
410 median: percentiles[1],
411 percentile_75: percentiles[2],
412 max,
413 variance,
414 skewness,
415 kurtosis,
416 })
417 }
418
419 fn compute_standard(
421 &self,
422 data: &Array1<F>,
423 _warnings: &mut Vec<String>,
424 ) -> StatsResult<DescriptiveStats<F>> {
425 let mean = crate::descriptive::mean(&data.view())?;
426 let variance = crate::descriptive::var(&data.view(), self.ddof.unwrap_or(1), None)?;
427 let std = variance.sqrt();
428
429 let (min, max) = self.compute_min_max(data);
430 let sorteddata = self.getsorteddata(data);
431 let percentiles = self.compute_percentiles(&sorteddata)?;
432
433 let skewness = crate::descriptive::skew(&data.view(), false, None)?;
434 let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
435
436 Ok(DescriptiveStats {
437 count: data.len(),
438 mean,
439 std,
440 min,
441 percentile_25: percentiles[0],
442 median: percentiles[1],
443 percentile_75: percentiles[2],
444 max,
445 variance,
446 skewness,
447 kurtosis,
448 })
449 }
450
451 fn compute_min_max(&self, data: &Array1<F>) -> (F, F) {
453 let mut min = data[0];
454 let mut max = data[0];
455
456 for &value in data.iter() {
457 if value < min {
458 min = value;
459 }
460 if value > max {
461 max = value;
462 }
463 }
464
465 (min, max)
466 }
467
468 fn getsorteddata(&self, data: &Array1<F>) -> Vec<F> {
470 let mut sorted = data.to_vec();
471 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
472 sorted
473 }
474
475 fn compute_percentiles(&self, sorteddata: &[F]) -> StatsResult<[F; 3]> {
477 let n = sorteddata.len();
478 if n == 0 {
479 return Err(StatsError::InvalidArgument("Empty data".to_string()));
480 }
481
482 let p25_idx = (n as f64 * 0.25) as usize;
483 let p50_idx = (n as f64 * 0.50) as usize;
484 let p75_idx = (n as f64 * 0.75) as usize;
485
486 Ok([
487 sorteddata[p25_idx.min(n - 1)],
488 sorteddata[p50_idx.min(n - 1)],
489 sorteddata[p75_idx.min(n - 1)],
490 ])
491 }
492
493 fn select_method_name(&self) -> String {
495 if self.config.simd && self.config.parallel {
496 "SIMD+Parallel".to_string()
497 } else if self.config.simd {
498 "SIMD".to_string()
499 } else if self.config.parallel {
500 "Parallel".to_string()
501 } else {
502 "Standard".to_string()
503 }
504 }
505
506 fn estimate_memory_usage(&self, samplesize: usize) -> Option<usize> {
508 if self.config.include_metadata {
509 Some(samplesize * std::mem::size_of::<F>() * 2) } else {
511 None
512 }
513 }
514}
515
516impl<F> CorrelationBuilder<F>
517where
518 F: Float
519 + NumCast
520 + Clone
521 + std::fmt::Debug
522 + std::fmt::Display
523 + scirs2_core::simd_ops::SimdUnifiedOps
524 + std::iter::Sum<F>
525 + std::ops::Div<Output = F>
526 + Send
527 + Sync
528 + 'static,
529{
530 pub fn new() -> Self {
532 Self {
533 config: StandardizedConfig::default(),
534 method: CorrelationMethod::Pearson,
535 min_periods: None,
536 phantom: PhantomData,
537 }
538 }
539
540 pub fn method(mut self, method: CorrelationMethod) -> Self {
542 self.method = method;
543 self
544 }
545
546 pub fn min_periods(mut self, periods: usize) -> Self {
548 self.min_periods = Some(periods);
549 self
550 }
551
552 pub fn confidence_level(mut self, level: f64) -> Self {
554 self.config.confidence_level = level;
555 self
556 }
557
558 pub fn parallel(mut self, enable: bool) -> Self {
560 self.config.parallel = enable;
561 self
562 }
563
564 pub fn simd(mut self, enable: bool) -> Self {
566 self.config.simd = enable;
567 self
568 }
569
570 pub fn with_metadata(mut self) -> Self {
572 self.config.include_metadata = true;
573 self
574 }
575
576 pub fn compute<'a>(
578 &self,
579 x: ArrayView1<'a, F>,
580 y: ArrayView1<'a, F>,
581 ) -> StatsResult<StandardizedResult<CorrelationResult<F>>> {
582 let start_time = std::time::Instant::now();
583 let mut warnings = Vec::new();
584
585 if x.len() != y.len() {
587 return Err(StatsError::DimensionMismatch(
588 "Input arrays must have the same length".to_string(),
589 ));
590 }
591
592 if x.is_empty() {
593 return Err(StatsError::InvalidArgument(
594 "Cannot compute correlation for empty arrays".to_string(),
595 ));
596 }
597
598 if let Some(min_periods) = self.min_periods {
600 if x.len() < min_periods {
601 return Err(StatsError::InvalidArgument(format!(
602 "Insufficient data: {} observations, {} required",
603 x.len(),
604 min_periods
605 )));
606 }
607 }
608
609 let correlation = match self.method {
611 CorrelationMethod::Pearson => {
612 if self.config.simd && x.len() > 64 {
613 crate::correlation_simd::pearson_r_simd(&x, &y)?
614 } else {
615 crate::correlation::pearson_r(&x, &y)?
616 }
617 }
618 CorrelationMethod::Spearman => crate::correlation::spearman_r(&x, &y)?,
619 CorrelationMethod::Kendall => crate::correlation::kendall_tau(&x, &y, "b")?,
620 _ => {
621 warnings.push("Advanced correlation methods not yet implemented".to_string());
622 crate::correlation::pearson_r(&x, &y)?
623 }
624 };
625
626 let (p_value, confidence_interval) = if self.config.include_metadata {
628 self.compute_statistical_inference(correlation, x.len(), &mut warnings)?
629 } else {
630 (None, None)
631 };
632
633 let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
634
635 let result = CorrelationResult {
636 correlation,
637 p_value,
638 confidence_interval,
639 method: self.method,
640 };
641
642 let metadata = ResultMetadata {
643 samplesize: x.len(),
644 degrees_of_freedom: Some(x.len().saturating_sub(2)),
645 confidence_level: Some(self.config.confidence_level),
646 method: format!("{:?}", self.method),
647 computation_time_ms: computation_time,
648 memory_usage_bytes: self.estimate_memory_usage(x.len()),
649 optimized: self.config.simd || self.config.parallel,
650 extra: HashMap::new(),
651 };
652
653 Ok(StandardizedResult {
654 value: result,
655 metadata,
656 warnings,
657 })
658 }
659
660 pub fn compute_matrix(
662 &self,
663 data: ArrayView2<F>,
664 ) -> StatsResult<StandardizedResult<Array2<F>>> {
665 let start_time = std::time::Instant::now();
666 let warnings = Vec::new();
667
668 let correlation_matrix = if self.config.auto_optimize {
670 let mut optimizer = crate::memory_optimization_advanced::MemoryOptimizationSuite::new(
672 crate::memory_optimization_advanced::MemoryOptimizationConfig::default(),
673 );
674 optimizer.optimized_correlation_matrix(data)?
675 } else {
676 crate::correlation::corrcoef(&data, "pearson")?
677 };
678
679 let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
680
681 let metadata = ResultMetadata {
682 samplesize: data.nrows(),
683 degrees_of_freedom: Some(data.nrows().saturating_sub(2)),
684 confidence_level: Some(self.config.confidence_level),
685 method: format!("Matrix {:?}", self.method),
686 computation_time_ms: computation_time,
687 memory_usage_bytes: self.estimate_memory_usage(data.nrows() * data.ncols()),
688 optimized: self.config.simd || self.config.parallel,
689 extra: HashMap::new(),
690 };
691
692 Ok(StandardizedResult {
693 value: correlation_matrix,
694 metadata,
695 warnings,
696 })
697 }
698
699 fn compute_statistical_inference(
701 &self,
702 correlation: F,
703 n: usize,
704 warnings: &mut Vec<String>,
705 ) -> StatsResult<(Option<F>, Option<(F, F)>)> {
706 let z = ((F::one() + correlation) / (F::one() - correlation)).ln() * F::from(0.5).unwrap();
708 let se_z = F::one() / F::from(n - 3).unwrap().sqrt();
709
710 let _alpha = F::one() - F::from(self.config.confidence_level).unwrap();
712 let z_critical = F::from(1.96).unwrap(); let z_lower = z - z_critical * se_z;
715 let z_upper = z + z_critical * se_z;
716
717 let r_lower = (F::from(2.0).unwrap() * z_lower).exp();
719 let r_lower = (r_lower - F::one()) / (r_lower + F::one());
720
721 let r_upper = (F::from(2.0).unwrap() * z_upper).exp();
722 let r_upper = (r_upper - F::one()) / (r_upper + F::one());
723
724 let _t_stat = correlation * F::from(n - 2).unwrap().sqrt()
726 / (F::one() - correlation * correlation).sqrt();
727 let p_value = F::from(2.0).unwrap() * (F::one() - F::from(0.95).unwrap()); Ok((Some(p_value), Some((r_lower, r_upper))))
730 }
731
732 fn estimate_memory_usage(&self, size: usize) -> Option<usize> {
734 if self.config.include_metadata {
735 Some(size * std::mem::size_of::<F>() * 3) } else {
737 None
738 }
739 }
740}
741
742impl<F> StatsAnalyzer<F>
743where
744 F: Float
745 + NumCast
746 + Clone
747 + scirs2_core::simd_ops::SimdUnifiedOps
748 + std::iter::Sum<F>
749 + std::ops::Div<Output = F>
750 + Sync
751 + Send
752 + std::fmt::Display
753 + std::fmt::Debug
754 + 'static,
755{
756 pub fn new() -> Self {
758 Self {
759 config: StandardizedConfig::default(),
760 phantom: PhantomData,
761 }
762 }
763
764 pub fn configure(mut self, config: StandardizedConfig) -> Self {
766 self.config = config;
767 self
768 }
769
770 pub fn describe(
772 &self,
773 data: ArrayView1<F>,
774 ) -> StatsResult<StandardizedResult<DescriptiveStats<F>>> {
775 DescriptiveStatsBuilder::new()
776 .parallel(self.config.parallel)
777 .simd(self.config.simd)
778 .null_handling(self.config.null_handling)
779 .with_metadata()
780 .compute(data)
781 }
782
783 pub fn correlate<'a>(
785 &self,
786 x: ArrayView1<'a, F>,
787 y: ArrayView1<'a, F>,
788 method: CorrelationMethod,
789 ) -> StatsResult<StandardizedResult<CorrelationResult<F>>> {
790 CorrelationBuilder::new()
791 .method(method)
792 .confidence_level(self.config.confidence_level)
793 .parallel(self.config.parallel)
794 .simd(self.config.simd)
795 .with_metadata()
796 .compute(x, y)
797 }
798
799 pub fn get_config(&self) -> &StandardizedConfig {
801 &self.config
802 }
803}
804
805pub type F64StatsAnalyzer = StatsAnalyzer<f64>;
807pub type F32StatsAnalyzer = StatsAnalyzer<f32>;
808
809pub type F64DescriptiveBuilder = DescriptiveStatsBuilder<f64>;
810pub type F32DescriptiveBuilder = DescriptiveStatsBuilder<f32>;
811
812pub type F64CorrelationBuilder = CorrelationBuilder<f64>;
813pub type F32CorrelationBuilder = CorrelationBuilder<f32>;
814
815impl<F> Default for DescriptiveStatsBuilder<F>
816where
817 F: Float
818 + NumCast
819 + Clone
820 + scirs2_core::simd_ops::SimdUnifiedOps
821 + std::iter::Sum<F>
822 + std::ops::Div<Output = F>
823 + Sync
824 + Send
825 + std::fmt::Display
826 + std::fmt::Debug
827 + 'static,
828{
829 fn default() -> Self {
830 Self::new()
831 }
832}
833
834impl<F> Default for CorrelationBuilder<F>
835where
836 F: Float
837 + NumCast
838 + Clone
839 + std::fmt::Debug
840 + std::fmt::Display
841 + scirs2_core::simd_ops::SimdUnifiedOps
842 + std::iter::Sum<F>
843 + std::ops::Div<Output = F>
844 + Send
845 + Sync
846 + 'static,
847{
848 fn default() -> Self {
849 Self::new()
850 }
851}
852
853impl<F> Default for StatsAnalyzer<F>
854where
855 F: Float
856 + NumCast
857 + Clone
858 + scirs2_core::simd_ops::SimdUnifiedOps
859 + std::iter::Sum<F>
860 + std::ops::Div<Output = F>
861 + Sync
862 + Send
863 + std::fmt::Display
864 + std::fmt::Debug
865 + 'static,
866{
867 fn default() -> Self {
868 Self::new()
869 }
870}
871
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// The builder honors its switches and reports basic statistics.
    #[test]
    fn test_descriptive_stats_builder() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let result = DescriptiveStatsBuilder::new()
            .ddof(1)
            .parallel(false)
            .simd(false)
            .with_metadata()
            .compute(data.view())
            .unwrap();

        assert_eq!(result.value.count, 5);
        assert!((result.value.mean - 3.0).abs() < 1e-10);
        assert!(!result.metadata.optimized);
    }

    /// Empty input is rejected instead of producing statistics.
    #[test]
    fn test_descriptive_stats_empty_input() {
        let data: Array1<f64> = Array1::from_vec(vec![]);

        let result = DescriptiveStatsBuilder::new().compute(data.view());

        assert!(matches!(result, Err(StatsError::InvalidArgument(_))));
    }

    /// Perfectly correlated data yields r = 1 plus inference results.
    #[test]
    fn test_correlation_builder() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let result = CorrelationBuilder::new()
            .method(CorrelationMethod::Pearson)
            .confidence_level(0.95)
            .with_metadata()
            .compute(x.view(), y.view())
            .unwrap();

        assert!((result.value.correlation - 1.0).abs() < 1e-10);
        assert!(result.value.p_value.is_some());
        assert!(result.value.confidence_interval.is_some());
    }

    /// Inputs of different lengths are rejected.
    #[test]
    fn test_correlation_length_mismatch() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![1.0, 2.0];

        let result = CorrelationBuilder::new().compute(x.view(), y.view());

        assert!(matches!(result, Err(StatsError::DimensionMismatch(_))));
    }

    /// The analyzer facade delegates to both builders.
    #[test]
    fn test_stats_analyzer() {
        let analyzer = StatsAnalyzer::new();
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let desc_result = analyzer.describe(data.view()).unwrap();
        assert_eq!(desc_result.value.count, 5);

        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![5.0, 4.0, 3.0, 2.0, 1.0];
        let corr_result = analyzer
            .correlate(x.view(), y.view(), CorrelationMethod::Pearson)
            .unwrap();
        assert!((corr_result.value.correlation + 1.0).abs() < 1e-10);
    }

    /// `Exclude` drops the NaN and reports it through a warning.
    #[test]
    fn test_null_handling() {
        let data = array![1.0, 2.0, f64::NAN, 4.0, 5.0];

        let result = DescriptiveStatsBuilder::new()
            .null_handling(NullHandling::Exclude)
            .compute(data.view())
            .unwrap();

        assert_eq!(result.value.count, 4);
        assert!(!result.warnings.is_empty());
    }

    /// Struct-update syntax overrides only the named fields.
    #[test]
    fn test_standardized_config() {
        let config = StandardizedConfig {
            auto_optimize: false,
            parallel: false,
            simd: true,
            confidence_level: 0.99,
            ..Default::default()
        };

        assert!(!config.auto_optimize);
        assert!(!config.parallel);
        assert!(config.simd);
        assert!((config.confidence_level - 0.99).abs() < 1e-10);
    }

    /// A well-formed signature passes the default validation rules.
    #[test]
    fn test_api_validation() {
        let framework = APIValidationFramework::new();
        let signature = APISignature {
            function_name: "test_function".to_string(),
            module_path: "scirs2_stats::test".to_string(),
            parameters: vec![ParameterSpec {
                name: "data".to_string(),
                param_type: "ArrayView1<f64>".to_string(),
                optional: false,
                default_value: None,
                description: Some("Input data array".to_string()),
                constraints: vec![ParameterConstraint::Finite],
            }],
            return_type: ReturnTypeSpec {
                type_name: "f64".to_string(),
                result_wrapped: true,
                inner_type: Some("f64".to_string()),
                error_type: Some("StatsError".to_string()),
            },
            error_types: vec!["StatsError".to_string()],
            documentation: DocumentationSpec {
                has_doc_comment: true,
                has_param_docs: true,
                has_return_docs: true,
                has_examples: true,
                has_error_docs: true,
                scipy_compatibility: Some("Compatible with scipy.stats".to_string()),
            },
            performance: PerformanceSpec {
                time_complexity: Some("O(n)".to_string()),
                space_complexity: Some("O(1)".to_string()),
                simd_optimized: true,
                parallel_processing: true,
                cache_efficient: true,
            },
        };

        let report = framework.validate_api(&signature);
        assert!(matches!(
            report.overall_status,
            ValidationStatus::Passed | ValidationStatus::PassedWithWarnings
        ));
    }
}
1001
1002#[derive(Debug)]
1004pub struct APIValidationFramework {
1005 validation_rules: HashMap<String, Vec<ValidationRule>>,
1007 compatibility_checkers: HashMap<String, CompatibilityChecker>,
1009 performance_benchmarks: HashMap<String, PerformanceBenchmark>,
1011 error_patterns: HashMap<String, ErrorPattern>,
1013}
1014
1015#[derive(Debug, Clone)]
1017pub struct ValidationRule {
1018 pub id: String,
1020 pub description: String,
1022 pub category: ValidationCategory,
1024 pub severity: ValidationSeverity,
1026}
1027
/// Area of API quality that a [`ValidationRule`] belongs to.
///
/// The `Debug` name of a variant is used as the grouping key for rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationCategory {
    /// Naming of parameters.
    ParameterNaming,
    /// Return types.
    ReturnTypes,
    /// Error handling.
    ErrorHandling,
    /// Documentation.
    Documentation,
    /// Performance.
    Performance,
    /// Compatibility with SciPy.
    ScipyCompatibility,
    /// Thread safety.
    ThreadSafety,
    /// Numerical stability.
    NumericalStability,
}
1048
/// Severity of a validation finding, ordered from least to most severe.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ValidationSeverity {
    /// Informational finding.
    Info,
    /// Warning.
    Warning,
    /// Error.
    Error,
    /// Critical finding.
    Critical,
}
1061
1062#[derive(Debug, Clone)]
1064pub struct APISignature {
1065 pub function_name: String,
1067 pub module_path: String,
1069 pub parameters: Vec<ParameterSpec>,
1071 pub return_type: ReturnTypeSpec,
1073 pub error_types: Vec<String>,
1075 pub documentation: DocumentationSpec,
1077 pub performance: PerformanceSpec,
1079}
1080
1081#[derive(Debug, Clone)]
1083pub struct ParameterSpec {
1084 pub name: String,
1086 pub param_type: String,
1088 pub optional: bool,
1090 pub default_value: Option<String>,
1092 pub description: Option<String>,
1094 pub constraints: Vec<ParameterConstraint>,
1096}
1097
/// Constraint that can be declared for a parameter.
#[derive(Debug, Clone)]
pub enum ParameterConstraint {
    /// Value is positive.
    Positive,
    /// Value is non-negative.
    NonNegative,
    /// Value is finite.
    Finite,
    /// Value lies in a range described by two bounds.
    Range(f64, f64),
    /// Value is one of the listed alternatives.
    OneOf(Vec<String>),
    /// Shape description; `None` entries leave a dimension unspecified.
    Shape(Vec<Option<usize>>),
    /// Free-form constraint described by text.
    Custom(String),
}
1116
/// Description of a function's return type.
#[derive(Debug, Clone)]
pub struct ReturnTypeSpec {
    /// Name of the return type.
    pub type_name: String,
    /// Whether the value is wrapped in a `Result`; checked by the error-handling validation.
    pub result_wrapped: bool,
    /// Name of the wrapped type, if any.
    pub inner_type: Option<String>,
    /// Name of the error type, if any; checked by the error-handling validation.
    pub error_type: Option<String>,
}
1129
/// Documentation status of a function.
#[derive(Debug, Clone)]
pub struct DocumentationSpec {
    /// A doc comment is present.
    pub has_doc_comment: bool,
    /// Parameters are documented.
    pub has_param_docs: bool,
    /// The return value is documented.
    pub has_return_docs: bool,
    /// Examples are present.
    pub has_examples: bool,
    /// Errors are documented.
    pub has_error_docs: bool,
    /// Note about SciPy compatibility, if any.
    pub scipy_compatibility: Option<String>,
}
1146
/// Performance characteristics declared for a function.
#[derive(Debug, Clone)]
pub struct PerformanceSpec {
    /// Time complexity, as text, if declared.
    pub time_complexity: Option<String>,
    /// Space complexity, as text, if declared.
    pub space_complexity: Option<String>,
    /// The function is SIMD-optimized.
    pub simd_optimized: bool,
    /// The function uses parallel processing.
    pub parallel_processing: bool,
    /// The function is cache-efficient.
    pub cache_efficient: bool,
}
1161
1162#[derive(Debug, Clone)]
1164pub struct ValidationResult {
1165 pub passed: bool,
1167 pub messages: Vec<ValidationMessage>,
1169 pub suggested_fixes: Vec<String>,
1171 pub related_rules: Vec<String>,
1173}
1174
1175#[derive(Debug, Clone)]
1177pub struct ValidationMessage {
1178 pub severity: ValidationSeverity,
1180 pub message: String,
1182 pub location: Option<String>,
1184 pub rule_id: String,
1186}
1187
1188#[derive(Debug, Clone)]
1190pub struct CompatibilityChecker {
1191 pub scipy_function: String,
1193 pub parameter_mapping: HashMap<String, String>,
1195 pub return_type_mapping: HashMap<String, String>,
1197 pub known_differences: Vec<CompatibilityDifference>,
1199}
1200
1201#[derive(Debug, Clone)]
1203pub struct CompatibilityDifference {
1204 pub category: DifferenceCategory,
1206 pub description: String,
1208 pub justification: String,
1210 pub workaround: Option<String>,
1212}
1213
/// Kind of difference recorded in a [`CompatibilityDifference`].
///
/// Derives `PartialEq`/`Eq` so categories can be compared with `==`, in line
/// with `ValidationCategory` and `ValidationSeverity`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DifferenceCategory {
    /// Difference that is an improvement.
    Improvement,
    /// Difference imposed by a Rust constraint.
    RustConstraint,
    /// Difference made for performance.
    Performance,
    /// Difference made for safety.
    Safety,
    /// Difference that is unintentional.
    Unintentional,
}
1228
1229#[derive(Debug, Clone)]
1231pub struct PerformanceBenchmark {
1232 pub name: String,
1234 pub expected_complexity: ComplexityClass,
1236 pub memory_usage: MemoryUsagePattern,
1238 pub scalability: ScalabilityRequirement,
1240}
1241
/// Complexity class used in a [`PerformanceBenchmark`].
///
/// Derives `PartialEq`/`Eq` so classes can be compared with `==`, in line
/// with `ValidationCategory` and `ValidationSeverity`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityClass {
    /// Constant.
    Constant,
    /// Logarithmic.
    Logarithmic,
    /// Linear.
    Linear,
    /// Log-linear.
    LogLinear,
    /// Quadratic.
    Quadratic,
    /// Cubic.
    Cubic,
    /// Exponential.
    Exponential,
}
1253
/// Memory-usage pattern used in a [`PerformanceBenchmark`].
///
/// Derives `PartialEq`/`Eq` so patterns can be compared with `==`, in line
/// with `ValidationCategory` and `ValidationSeverity`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryUsagePattern {
    /// Constant memory usage.
    Constant,
    /// Linear memory usage.
    Linear,
    /// Quadratic memory usage.
    Quadratic,
    /// Streaming.
    Streaming,
    /// Out-of-core.
    OutOfCore,
}
1263
/// Scalability figures used in a [`PerformanceBenchmark`].
#[derive(Debug, Clone)]
pub struct ScalabilityRequirement {
    /// Maximum data size (units are not established by the code visible here).
    pub maxdatasize: usize,
    /// Parallel-efficiency figure (scale is not established by the code visible here).
    pub parallel_efficiency: f64,
    /// SIMD-acceleration figure (scale is not established by the code visible here).
    pub simd_acceleration: f64,
}
1274
1275#[derive(Debug, Clone)]
1277pub struct ErrorPattern {
1278 pub category: ErrorCategory,
1280 pub message_template: String,
1282 pub recovery_suggestions: Vec<String>,
1284 pub related_errors: Vec<String>,
1286}
1287
/// Category of an [`ErrorPattern`].
///
/// Derives `PartialEq`/`Eq` so categories can be compared with `==`, in line
/// with `ValidationCategory` and `ValidationSeverity`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Invalid input.
    InvalidInput,
    /// Numerical error.
    Numerical,
    /// Memory error.
    Memory,
    /// Convergence error.
    Convergence,
    /// Dimension mismatch.
    DimensionMismatch,
    /// Functionality that is not implemented.
    NotImplemented,
    /// Internal error.
    Internal,
}
1306
1307#[derive(Debug)]
1309pub struct ValidationReport {
1310 pub function_name: String,
1312 pub results: HashMap<String, ValidationResult>,
1314 pub overall_status: ValidationStatus,
1316 pub summary: ValidationSummary,
1318}
1319
/// Overall status of a [`ValidationReport`].
///
/// Derives `PartialEq`/`Eq` so statuses can be compared with `==` and used in
/// `assert_eq!`, in line with `ValidationCategory` and `ValidationSeverity`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationStatus {
    /// Validation passed.
    Passed,
    /// Validation passed, with warnings.
    PassedWithWarnings,
    /// Validation failed.
    Failed,
    /// Validation hit a critical finding.
    Critical,
}
1328
/// Counters summarizing a [`ValidationReport`].
#[derive(Debug, Clone)]
pub struct ValidationSummary {
    /// Number of rules counted.
    pub total_rules: usize,
    /// Number counted as passed.
    pub passed: usize,
    /// Number counted as warnings.
    pub warnings: usize,
    /// Number counted as errors.
    pub errors: usize,
    /// Number counted as critical.
    pub critical: usize,
}
1343
1344impl APIValidationFramework {
1345 pub fn new() -> Self {
1347 let mut framework = Self {
1348 validation_rules: HashMap::new(),
1349 compatibility_checkers: HashMap::new(),
1350 performance_benchmarks: HashMap::new(),
1351 error_patterns: HashMap::new(),
1352 };
1353
1354 framework.initialize_default_rules();
1355 framework
1356 }
1357
1358 fn initialize_default_rules(&mut self) {
1360 self.add_validation_rule(ValidationRule {
1362 id: "param_naming_consistency".to_string(),
1363 description: "Parameter names should follow consistent snake_case conventions"
1364 .to_string(),
1365 category: ValidationCategory::ParameterNaming,
1366 severity: ValidationSeverity::Warning,
1367 });
1368
1369 self.add_validation_rule(ValidationRule {
1371 id: "error_handling_consistency".to_string(),
1372 description: "Functions should return Result<T, StatsError> for consistency"
1373 .to_string(),
1374 category: ValidationCategory::ErrorHandling,
1375 severity: ValidationSeverity::Error,
1376 });
1377
1378 self.add_validation_rule(ValidationRule {
1380 id: "documentation_completeness".to_string(),
1381 description: "All public functions should have complete documentation".to_string(),
1382 category: ValidationCategory::Documentation,
1383 severity: ValidationSeverity::Warning,
1384 });
1385
1386 self.add_validation_rule(ValidationRule {
1388 id: "scipy_compatibility".to_string(),
1389 description: "Functions should maintain SciPy compatibility where possible".to_string(),
1390 category: ValidationCategory::ScipyCompatibility,
1391 severity: ValidationSeverity::Info,
1392 });
1393
1394 self.add_validation_rule(ValidationRule {
1396 id: "performance_characteristics".to_string(),
1397 description: "Functions should document performance characteristics".to_string(),
1398 category: ValidationCategory::Performance,
1399 severity: ValidationSeverity::Info,
1400 });
1401 }
1402
1403 pub fn add_validation_rule(&mut self, rule: ValidationRule) {
1405 let category_key = format!("{:?}", rule.category);
1406 self.validation_rules
1407 .entry(category_key)
1408 .or_default()
1409 .push(rule);
1410 }
1411
1412 pub fn validate_api(&self, signature: &APISignature) -> ValidationReport {
1414 let mut report = ValidationReport::new(signature.function_name.clone());
1415
1416 for rules in self.validation_rules.values() {
1417 for rule in rules {
1418 let result = self.apply_validation_rule(rule, signature);
1419 report.add_result(rule.id.clone(), result);
1420 }
1421 }
1422
1423 report
1424 }
1425
1426 fn apply_validation_rule(
1428 &self,
1429 rule: &ValidationRule,
1430 signature: &APISignature,
1431 ) -> ValidationResult {
1432 match rule.category {
1433 ValidationCategory::ParameterNaming => self.validate_parameter_naming(signature),
1434 ValidationCategory::ErrorHandling => self.validate_error_handling(signature),
1435 ValidationCategory::Documentation => self.validate_documentation(signature),
1436 ValidationCategory::ScipyCompatibility => self.validate_scipy_compatibility(signature),
1437 ValidationCategory::Performance => self.validate_performance(signature),
1438 _ => ValidationResult {
1439 passed: true,
1440 messages: vec![],
1441 suggested_fixes: vec![],
1442 related_rules: vec![],
1443 },
1444 }
1445 }
1446
1447 fn validate_parameter_naming(&self, signature: &APISignature) -> ValidationResult {
1449 let mut messages = Vec::new();
1450 let mut suggested_fixes = Vec::new();
1451
1452 for param in &signature.parameters {
1453 if param.name.contains(char::is_uppercase) || param.name.contains('-') {
1455 messages.push(ValidationMessage {
1456 severity: ValidationSeverity::Warning,
1457 message: format!("Parameter '{}' should use snake_case naming", param.name),
1458 location: Some(format!(
1459 "{}::{}",
1460 signature.module_path, signature.function_name
1461 )),
1462 rule_id: "param_naming_consistency".to_string(),
1463 });
1464 suggested_fixes.push(format!("Rename parameter '{}' to snake_case", param.name));
1465 }
1466 }
1467
1468 ValidationResult {
1469 passed: messages.is_empty(),
1470 messages,
1471 suggested_fixes,
1472 related_rules: vec!["return_type_consistency".to_string()],
1473 }
1474 }
1475
1476 fn validate_error_handling(&self, signature: &APISignature) -> ValidationResult {
1478 let mut messages = Vec::new();
1479 let mut suggested_fixes = Vec::new();
1480
1481 if !signature.return_type.result_wrapped {
1482 messages.push(ValidationMessage {
1483 severity: ValidationSeverity::Error,
1484 message: "Function should return Result<T, StatsError> for consistency".to_string(),
1485 location: Some(format!(
1486 "{}::{}",
1487 signature.module_path, signature.function_name
1488 )),
1489 rule_id: "error_handling_consistency".to_string(),
1490 });
1491 suggested_fixes.push("Wrap return type in Result<T, StatsError>".to_string());
1492 }
1493
1494 if let Some(error_type) = &signature.return_type.error_type {
1495 if error_type != "StatsError" {
1496 messages.push(ValidationMessage {
1497 severity: ValidationSeverity::Warning,
1498 message: format!("Non-standard error type '{}' used", error_type),
1499 location: Some(format!(
1500 "{}::{}",
1501 signature.module_path, signature.function_name
1502 )),
1503 rule_id: "error_handling_consistency".to_string(),
1504 });
1505 suggested_fixes.push("Use StatsError for consistency".to_string());
1506 }
1507 }
1508
1509 ValidationResult {
1510 passed: messages.is_empty(),
1511 messages,
1512 suggested_fixes,
1513 related_rules: vec!["documentation_completeness".to_string()],
1514 }
1515 }
1516
1517 fn validate_documentation(&self, signature: &APISignature) -> ValidationResult {
1519 let mut messages = Vec::new();
1520 let mut suggested_fixes = Vec::new();
1521
1522 if !signature.documentation.has_doc_comment {
1523 messages.push(ValidationMessage {
1524 severity: ValidationSeverity::Warning,
1525 message: "Function lacks documentation comment".to_string(),
1526 location: Some(format!(
1527 "{}::{}",
1528 signature.module_path, signature.function_name
1529 )),
1530 rule_id: "documentation_completeness".to_string(),
1531 });
1532 suggested_fixes.push("Add comprehensive doc comment".to_string());
1533 }
1534
1535 if !signature.documentation.has_examples {
1536 messages.push(ValidationMessage {
1537 severity: ValidationSeverity::Info,
1538 message: "Function lacks usage examples".to_string(),
1539 location: Some(format!(
1540 "{}::{}",
1541 signature.module_path, signature.function_name
1542 )),
1543 rule_id: "documentation_completeness".to_string(),
1544 });
1545 suggested_fixes.push("Add usage examples in # Examples section".to_string());
1546 }
1547
1548 ValidationResult {
1549 passed: messages
1550 .iter()
1551 .all(|m| matches!(m.severity, ValidationSeverity::Info)),
1552 messages,
1553 suggested_fixes,
1554 related_rules: vec!["scipy_compatibility".to_string()],
1555 }
1556 }
1557
1558 fn validate_scipy_compatibility(&self, signature: &APISignature) -> ValidationResult {
1560 let mut messages = Vec::new();
1561 let mut suggested_fixes = Vec::new();
1562
1563 let scipy_standard_params = [
1565 "axis",
1566 "ddof",
1567 "keepdims",
1568 "out",
1569 "dtype",
1570 "method",
1571 "alternative",
1572 ];
1573 let has_scipy_params = signature
1574 .parameters
1575 .iter()
1576 .any(|p| scipy_standard_params.contains(&p.name.as_str()));
1577
1578 if has_scipy_params && signature.documentation.scipy_compatibility.is_none() {
1579 messages.push(ValidationMessage {
1580 severity: ValidationSeverity::Info,
1581 message: "Consider documenting SciPy compatibility status".to_string(),
1582 location: Some(format!(
1583 "{}::{}",
1584 signature.module_path, signature.function_name
1585 )),
1586 rule_id: "scipy_compatibility".to_string(),
1587 });
1588 suggested_fixes.push("Add SciPy compatibility note in documentation".to_string());
1589 }
1590
1591 ValidationResult {
1592 passed: true, messages,
1594 suggested_fixes,
1595 related_rules: vec!["documentation_completeness".to_string()],
1596 }
1597 }
1598
1599 fn validate_performance(&self, signature: &APISignature) -> ValidationResult {
1601 let mut messages = Vec::new();
1602 let mut suggested_fixes = Vec::new();
1603
1604 if signature.performance.time_complexity.is_none() {
1605 messages.push(ValidationMessage {
1606 severity: ValidationSeverity::Info,
1607 message: "Consider documenting time complexity".to_string(),
1608 location: Some(format!(
1609 "{}::{}",
1610 signature.module_path, signature.function_name
1611 )),
1612 rule_id: "performance_characteristics".to_string(),
1613 });
1614 suggested_fixes.push("Add time complexity documentation".to_string());
1615 }
1616
1617 ValidationResult {
1618 passed: true, messages,
1620 suggested_fixes,
1621 related_rules: vec![],
1622 }
1623 }
1624}
1625
1626impl ValidationReport {
1627 pub fn new(_functionname: String) -> Self {
1629 Self {
1630 function_name: _functionname,
1631 results: HashMap::new(),
1632 overall_status: ValidationStatus::Passed,
1633 summary: ValidationSummary {
1634 total_rules: 0,
1635 passed: 0,
1636 warnings: 0,
1637 errors: 0,
1638 critical: 0,
1639 },
1640 }
1641 }
1642
1643 pub fn add_result(&mut self, ruleid: String, result: ValidationResult) {
1645 self.summary.total_rules += 1;
1646
1647 if result.passed {
1648 self.summary.passed += 1;
1649 } else {
1650 let max_severity = result
1651 .messages
1652 .iter()
1653 .map(|m| m.severity)
1654 .max()
1655 .unwrap_or(ValidationSeverity::Info);
1656
1657 match max_severity {
1658 ValidationSeverity::Info => {}
1659 ValidationSeverity::Warning => {
1660 self.summary.warnings += 1;
1661 if matches!(self.overall_status, ValidationStatus::Passed) {
1662 self.overall_status = ValidationStatus::PassedWithWarnings;
1663 }
1664 }
1665 ValidationSeverity::Error => {
1666 self.summary.errors += 1;
1667 if !matches!(self.overall_status, ValidationStatus::Critical) {
1668 self.overall_status = ValidationStatus::Failed;
1669 }
1670 }
1671 ValidationSeverity::Critical => {
1672 self.summary.critical += 1;
1673 self.overall_status = ValidationStatus::Critical;
1674 }
1675 }
1676 }
1677
1678 self.results.insert(ruleid, result);
1679 }
1680
1681 pub fn generate_report(&self) -> String {
1683 let mut report = String::new();
1684 report.push_str(&format!(
1685 "API Validation Report for {}\n",
1686 self.function_name
1687 ));
1688 report.push_str(&format!("Status: {:?}\n", self.overall_status));
1689 report.push_str(&format!(
1690 "Summary: {} passed, {} warnings, {} errors, {} critical\n\n",
1691 self.summary.passed, self.summary.warnings, self.summary.errors, self.summary.critical
1692 ));
1693
1694 for (rule_id, result) in &self.results {
1695 if !result.passed {
1696 report.push_str(&format!("Rule: {}\n", rule_id));
1697 for message in &result.messages {
1698 report.push_str(&format!(" {:?}: {}\n", message.severity, message.message));
1699 }
1700 if !result.suggested_fixes.is_empty() {
1701 report.push_str(" Suggestions:\n");
1702 for fix in &result.suggested_fixes {
1703 report.push_str(&format!(" - {}\n", fix));
1704 }
1705 }
1706 report.push('\n');
1707 }
1708 }
1709
1710 report
1711 }
1712}
1713
1714impl Default for APIValidationFramework {
1715 fn default() -> Self {
1716 Self::new()
1717 }
1718}