1#![allow(dead_code)]
10
11use crate::error::{StatsError, StatsResult};
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
13use scirs2_core::numeric::{Float, NumCast};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::marker::PhantomData;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct StandardizedConfig {
21 pub auto_optimize: bool,
23 pub parallel: bool,
25 pub simd: bool,
27 pub memory_limit: Option<usize>,
29 pub confidence_level: f64,
31 pub null_handling: NullHandling,
33 pub output_precision: usize,
35 pub include_metadata: bool,
37}
38
39impl Default for StandardizedConfig {
40 fn default() -> Self {
41 Self {
42 auto_optimize: true,
43 parallel: true,
44 simd: true,
45 memory_limit: None,
46 confidence_level: 0.95,
47 null_handling: NullHandling::Exclude,
48 output_precision: 6,
49 include_metadata: false,
50 }
51 }
52}
53
54#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
56pub enum NullHandling {
57 Exclude,
59 Propagate,
61 Replace(f64),
63 Fail,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct StandardizedResult<T> {
70 pub value: T,
72 pub metadata: ResultMetadata,
74 pub warnings: Vec<String>,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ResultMetadata {
81 pub samplesize: usize,
83 pub degrees_of_freedom: Option<usize>,
85 pub confidence_level: Option<f64>,
87 pub method: String,
89 pub computation_time_ms: f64,
91 pub memory_usage_bytes: Option<usize>,
93 pub optimized: bool,
95 pub extra: HashMap<String, String>,
97}
98
99pub struct DescriptiveStatsBuilder<F> {
101 config: StandardizedConfig,
102 ddof: Option<usize>,
103 axis: Option<usize>,
104 weights: Option<Array1<F>>,
105 phantom: PhantomData<F>,
106}
107
108pub struct CorrelationBuilder<F> {
110 config: StandardizedConfig,
111 method: CorrelationMethod,
112 min_periods: Option<usize>,
113 phantom: PhantomData<F>,
114}
115
116pub struct StatisticalTestBuilder<F> {
118 config: StandardizedConfig,
119 alternative: Alternative,
120 equal_var: bool,
121 phantom: PhantomData<F>,
122}
123
124#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
126pub enum CorrelationMethod {
127 Pearson,
128 Spearman,
129 Kendall,
130 PartialPearson,
131 PartialSpearman,
132}
133
134#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
136pub enum Alternative {
137 TwoSided,
138 Less,
139 Greater,
140}
141
142pub struct StatsAnalyzer<F> {
144 config: StandardizedConfig,
145 phantom: PhantomData<F>,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct DescriptiveStats<F> {
151 pub count: usize,
152 pub mean: F,
153 pub std: F,
154 pub min: F,
155 pub percentile_25: F,
156 pub median: F,
157 pub percentile_75: F,
158 pub max: F,
159 pub variance: F,
160 pub skewness: F,
161 pub kurtosis: F,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct CorrelationResult<F> {
167 pub correlation: F,
168 pub p_value: Option<F>,
169 pub confidence_interval: Option<(F, F)>,
170 pub method: CorrelationMethod,
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct TestResult<F> {
176 pub statistic: F,
177 pub p_value: F,
178 pub confidence_interval: Option<(F, F)>,
179 pub effectsize: Option<F>,
180 pub power: Option<F>,
181}
182
183impl<F> DescriptiveStatsBuilder<F>
184where
185 F: Float
186 + NumCast
187 + Clone
188 + scirs2_core::simd_ops::SimdUnifiedOps
189 + std::iter::Sum<F>
190 + std::ops::Div<Output = F>
191 + Sync
192 + Send
193 + std::fmt::Display
194 + std::fmt::Debug
195 + 'static,
196{
197 pub fn new() -> Self {
199 Self {
200 config: StandardizedConfig::default(),
201 ddof: None,
202 axis: None,
203 weights: None,
204 phantom: PhantomData,
205 }
206 }
207
208 pub fn ddof(mut self, ddof: usize) -> Self {
210 self.ddof = Some(ddof);
211 self
212 }
213
214 pub fn axis(mut self, axis: usize) -> Self {
216 self.axis = Some(axis);
217 self
218 }
219
220 pub fn weights(mut self, weights: Array1<F>) -> Self {
222 self.weights = Some(weights);
223 self
224 }
225
226 pub fn parallel(mut self, enable: bool) -> Self {
228 self.config.parallel = enable;
229 self
230 }
231
232 pub fn simd(mut self, enable: bool) -> Self {
234 self.config.simd = enable;
235 self
236 }
237
238 pub fn null_handling(mut self, strategy: NullHandling) -> Self {
240 self.config.null_handling = strategy;
241 self
242 }
243
244 pub fn memory_limit(mut self, limit: usize) -> Self {
246 self.config.memory_limit = Some(limit);
247 self
248 }
249
250 pub fn with_metadata(mut self) -> Self {
252 self.config.include_metadata = true;
253 self
254 }
255
256 pub fn compute(
258 &self,
259 data: ArrayView1<F>,
260 ) -> StatsResult<StandardizedResult<DescriptiveStats<F>>> {
261 let start_time = std::time::Instant::now();
262 let mut warnings = Vec::new();
263
264 if data.is_empty() {
266 return Err(StatsError::InvalidArgument(
267 "Cannot compute statistics for empty array".to_string(),
268 ));
269 }
270
271 let (cleaneddata, samplesize) = self.handle_null_values(&data, &mut warnings)?;
273
274 let stats = if self.config.auto_optimize {
276 self.compute_optimized(&cleaneddata, &mut warnings)?
277 } else {
278 self.compute_standard(&cleaneddata, &mut warnings)?
279 };
280
281 let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
282
283 let metadata = ResultMetadata {
285 samplesize,
286 degrees_of_freedom: Some(samplesize.saturating_sub(self.ddof.unwrap_or(1))),
287 confidence_level: None,
288 method: self.select_method_name(),
289 computation_time_ms: computation_time,
290 memory_usage_bytes: self.estimate_memory_usage(samplesize),
291 optimized: self.config.simd || self.config.parallel,
292 extra: HashMap::new(),
293 };
294
295 Ok(StandardizedResult {
296 value: stats,
297 metadata,
298 warnings,
299 })
300 }
301
302 fn handle_null_values(
304 &self,
305 data: &ArrayView1<F>,
306 warnings: &mut Vec<String>,
307 ) -> StatsResult<(Array1<F>, usize)> {
308 let finitedata: Vec<F> = data.iter().filter(|&&x| x.is_finite()).cloned().collect();
311
312 if finitedata.len() != data.len() {
313 warnings.push(format!(
314 "Removed {} non-finite values",
315 data.len() - finitedata.len()
316 ));
317 }
318
319 let finite_count = finitedata.len();
320 match self.config.null_handling {
321 NullHandling::Exclude => Ok((Array1::from_vec(finitedata), finite_count)),
322 NullHandling::Fail if finite_count != data.len() => Err(StatsError::InvalidArgument(
323 "Null values encountered with Fail strategy".to_string(),
324 )),
325 _ => Ok((Array1::from_vec(finitedata), finite_count)),
326 }
327 }
328
329 fn compute_optimized(
331 &self,
332 data: &Array1<F>,
333 warnings: &mut Vec<String>,
334 ) -> StatsResult<DescriptiveStats<F>> {
335 let n = data.len();
336
337 if self.config.simd && n > 64 {
339 self.compute_simd_optimized(data, warnings)
340 } else if self.config.parallel && n > 10000 {
341 self.compute_parallel_optimized(data, warnings)
342 } else {
343 self.compute_standard(data, warnings)
344 }
345 }
346
347 fn compute_simd_optimized(
349 &self,
350 data: &Array1<F>,
351 _warnings: &mut Vec<String>,
352 ) -> StatsResult<DescriptiveStats<F>> {
353 let mean = crate::descriptive_simd::mean_simd(&data.view())?;
355 let variance =
356 crate::descriptive_simd::variance_simd(&data.view(), self.ddof.unwrap_or(1))?;
357 let std = variance.sqrt();
358
359 let (min, max) = self.compute_min_max(data);
361 let sorteddata = self.getsorteddata(data);
362 let percentiles = self.compute_percentiles(&sorteddata)?;
363
364 let skewness = crate::descriptive::skew(&data.view(), false, None)?;
366 let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
367
368 Ok(DescriptiveStats {
369 count: data.len(),
370 mean,
371 std,
372 min,
373 percentile_25: percentiles[0],
374 median: percentiles[1],
375 percentile_75: percentiles[2],
376 max,
377 variance,
378 skewness,
379 kurtosis,
380 })
381 }
382
383 fn compute_parallel_optimized(
385 &self,
386 data: &Array1<F>,
387 _warnings: &mut Vec<String>,
388 ) -> StatsResult<DescriptiveStats<F>> {
389 let mean = crate::parallel_stats::mean_parallel(&data.view())?;
391 let variance =
392 crate::parallel_stats::variance_parallel(&data.view(), self.ddof.unwrap_or(1))?;
393 let std = variance.sqrt();
394
395 let (min, max) = self.compute_min_max(data);
397 let sorteddata = self.getsorteddata(data);
398 let percentiles = self.compute_percentiles(&sorteddata)?;
399
400 let skewness = crate::descriptive::skew(&data.view(), false, None)?;
402 let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
403
404 Ok(DescriptiveStats {
405 count: data.len(),
406 mean,
407 std,
408 min,
409 percentile_25: percentiles[0],
410 median: percentiles[1],
411 percentile_75: percentiles[2],
412 max,
413 variance,
414 skewness,
415 kurtosis,
416 })
417 }
418
419 fn compute_standard(
421 &self,
422 data: &Array1<F>,
423 _warnings: &mut Vec<String>,
424 ) -> StatsResult<DescriptiveStats<F>> {
425 let mean = crate::descriptive::mean(&data.view())?;
426 let variance = crate::descriptive::var(&data.view(), self.ddof.unwrap_or(1), None)?;
427 let std = variance.sqrt();
428
429 let (min, max) = self.compute_min_max(data);
430 let sorteddata = self.getsorteddata(data);
431 let percentiles = self.compute_percentiles(&sorteddata)?;
432
433 let skewness = crate::descriptive::skew(&data.view(), false, None)?;
434 let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
435
436 Ok(DescriptiveStats {
437 count: data.len(),
438 mean,
439 std,
440 min,
441 percentile_25: percentiles[0],
442 median: percentiles[1],
443 percentile_75: percentiles[2],
444 max,
445 variance,
446 skewness,
447 kurtosis,
448 })
449 }
450
451 fn compute_min_max(&self, data: &Array1<F>) -> (F, F) {
453 let mut min = data[0];
454 let mut max = data[0];
455
456 for &value in data.iter() {
457 if value < min {
458 min = value;
459 }
460 if value > max {
461 max = value;
462 }
463 }
464
465 (min, max)
466 }
467
468 fn getsorteddata(&self, data: &Array1<F>) -> Vec<F> {
470 let mut sorted = data.to_vec();
471 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
472 sorted
473 }
474
475 fn compute_percentiles(&self, sorteddata: &[F]) -> StatsResult<[F; 3]> {
477 let n = sorteddata.len();
478 if n == 0 {
479 return Err(StatsError::InvalidArgument("Empty data".to_string()));
480 }
481
482 let p25_idx = (n as f64 * 0.25) as usize;
483 let p50_idx = (n as f64 * 0.50) as usize;
484 let p75_idx = (n as f64 * 0.75) as usize;
485
486 Ok([
487 sorteddata[p25_idx.min(n - 1)],
488 sorteddata[p50_idx.min(n - 1)],
489 sorteddata[p75_idx.min(n - 1)],
490 ])
491 }
492
493 fn select_method_name(&self) -> String {
495 if self.config.simd && self.config.parallel {
496 "SIMD+Parallel".to_string()
497 } else if self.config.simd {
498 "SIMD".to_string()
499 } else if self.config.parallel {
500 "Parallel".to_string()
501 } else {
502 "Standard".to_string()
503 }
504 }
505
506 fn estimate_memory_usage(&self, samplesize: usize) -> Option<usize> {
508 if self.config.include_metadata {
509 Some(samplesize * std::mem::size_of::<F>() * 2) } else {
511 None
512 }
513 }
514}
515
516impl<F> CorrelationBuilder<F>
517where
518 F: Float
519 + NumCast
520 + Clone
521 + std::fmt::Debug
522 + std::fmt::Display
523 + scirs2_core::simd_ops::SimdUnifiedOps
524 + std::iter::Sum<F>
525 + std::ops::Div<Output = F>
526 + Send
527 + Sync
528 + 'static,
529{
530 pub fn new() -> Self {
532 Self {
533 config: StandardizedConfig::default(),
534 method: CorrelationMethod::Pearson,
535 min_periods: None,
536 phantom: PhantomData,
537 }
538 }
539
540 pub fn method(mut self, method: CorrelationMethod) -> Self {
542 self.method = method;
543 self
544 }
545
546 pub fn min_periods(mut self, periods: usize) -> Self {
548 self.min_periods = Some(periods);
549 self
550 }
551
552 pub fn confidence_level(mut self, level: f64) -> Self {
554 self.config.confidence_level = level;
555 self
556 }
557
558 pub fn parallel(mut self, enable: bool) -> Self {
560 self.config.parallel = enable;
561 self
562 }
563
564 pub fn simd(mut self, enable: bool) -> Self {
566 self.config.simd = enable;
567 self
568 }
569
570 pub fn with_metadata(mut self) -> Self {
572 self.config.include_metadata = true;
573 self
574 }
575
576 pub fn compute<'a>(
578 &self,
579 x: ArrayView1<'a, F>,
580 y: ArrayView1<'a, F>,
581 ) -> StatsResult<StandardizedResult<CorrelationResult<F>>> {
582 let start_time = std::time::Instant::now();
583 let mut warnings = Vec::new();
584
585 if x.len() != y.len() {
587 return Err(StatsError::DimensionMismatch(
588 "Input arrays must have the same length".to_string(),
589 ));
590 }
591
592 if x.is_empty() {
593 return Err(StatsError::InvalidArgument(
594 "Cannot compute correlation for empty arrays".to_string(),
595 ));
596 }
597
598 if let Some(min_periods) = self.min_periods {
600 if x.len() < min_periods {
601 return Err(StatsError::InvalidArgument(format!(
602 "Insufficient data: {} observations, {} required",
603 x.len(),
604 min_periods
605 )));
606 }
607 }
608
609 let correlation = match self.method {
611 CorrelationMethod::Pearson => {
612 if self.config.simd && x.len() > 64 {
613 crate::correlation_simd::pearson_r_simd(&x, &y)?
614 } else {
615 crate::correlation::pearson_r(&x, &y)?
616 }
617 }
618 CorrelationMethod::Spearman => crate::correlation::spearman_r(&x, &y)?,
619 CorrelationMethod::Kendall => crate::correlation::kendall_tau(&x, &y, "b")?,
620 _ => {
621 warnings.push("Advanced correlation methods not yet implemented".to_string());
622 crate::correlation::pearson_r(&x, &y)?
623 }
624 };
625
626 let (p_value, confidence_interval) = if self.config.include_metadata {
628 self.compute_statistical_inference(correlation, x.len(), &mut warnings)?
629 } else {
630 (None, None)
631 };
632
633 let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
634
635 let result = CorrelationResult {
636 correlation,
637 p_value,
638 confidence_interval,
639 method: self.method,
640 };
641
642 let metadata = ResultMetadata {
643 samplesize: x.len(),
644 degrees_of_freedom: Some(x.len().saturating_sub(2)),
645 confidence_level: Some(self.config.confidence_level),
646 method: format!("{:?}", self.method),
647 computation_time_ms: computation_time,
648 memory_usage_bytes: self.estimate_memory_usage(x.len()),
649 optimized: self.config.simd || self.config.parallel,
650 extra: HashMap::new(),
651 };
652
653 Ok(StandardizedResult {
654 value: result,
655 metadata,
656 warnings,
657 })
658 }
659
660 pub fn compute_matrix(
662 &self,
663 data: ArrayView2<F>,
664 ) -> StatsResult<StandardizedResult<Array2<F>>> {
665 let start_time = std::time::Instant::now();
666 let warnings = Vec::new();
667
668 let correlation_matrix = if self.config.auto_optimize {
670 let mut optimizer = crate::memory_optimization_advanced::MemoryOptimizationSuite::new(
672 crate::memory_optimization_advanced::MemoryOptimizationConfig::default(),
673 );
674 optimizer.optimized_correlation_matrix(data)?
675 } else {
676 crate::correlation::corrcoef(&data, "pearson")?
677 };
678
679 let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
680
681 let metadata = ResultMetadata {
682 samplesize: data.nrows(),
683 degrees_of_freedom: Some(data.nrows().saturating_sub(2)),
684 confidence_level: Some(self.config.confidence_level),
685 method: format!("Matrix {:?}", self.method),
686 computation_time_ms: computation_time,
687 memory_usage_bytes: self.estimate_memory_usage(data.nrows() * data.ncols()),
688 optimized: self.config.simd || self.config.parallel,
689 extra: HashMap::new(),
690 };
691
692 Ok(StandardizedResult {
693 value: correlation_matrix,
694 metadata,
695 warnings,
696 })
697 }
698
699 fn compute_statistical_inference(
701 &self,
702 correlation: F,
703 n: usize,
704 warnings: &mut Vec<String>,
705 ) -> StatsResult<(Option<F>, Option<(F, F)>)> {
706 let z = ((F::one() + correlation) / (F::one() - correlation)).ln()
708 * F::from(0.5).expect("Failed to convert constant to float");
709 let se_z = F::one() / F::from(n - 3).expect("Failed to convert to float").sqrt();
710
711 let _alpha =
713 F::one() - F::from(self.config.confidence_level).expect("Failed to convert to float");
714 let z_critical = F::from(1.96).expect("Failed to convert constant to float"); let z_lower = z - z_critical * se_z;
717 let z_upper = z + z_critical * se_z;
718
719 let r_lower = (F::from(2.0).expect("Failed to convert constant to float") * z_lower).exp();
721 let r_lower = (r_lower - F::one()) / (r_lower + F::one());
722
723 let r_upper = (F::from(2.0).expect("Failed to convert constant to float") * z_upper).exp();
724 let r_upper = (r_upper - F::one()) / (r_upper + F::one());
725
726 let _t_stat = correlation * F::from(n - 2).expect("Failed to convert to float").sqrt()
728 / (F::one() - correlation * correlation).sqrt();
729 let p_value = F::from(2.0).expect("Failed to convert constant to float")
730 * (F::one() - F::from(0.95).expect("Failed to convert constant to float")); Ok((Some(p_value), Some((r_lower, r_upper))))
733 }
734
735 fn estimate_memory_usage(&self, size: usize) -> Option<usize> {
737 if self.config.include_metadata {
738 Some(size * std::mem::size_of::<F>() * 3) } else {
740 None
741 }
742 }
743}
744
745impl<F> StatsAnalyzer<F>
746where
747 F: Float
748 + NumCast
749 + Clone
750 + scirs2_core::simd_ops::SimdUnifiedOps
751 + std::iter::Sum<F>
752 + std::ops::Div<Output = F>
753 + Sync
754 + Send
755 + std::fmt::Display
756 + std::fmt::Debug
757 + 'static,
758{
759 pub fn new() -> Self {
761 Self {
762 config: StandardizedConfig::default(),
763 phantom: PhantomData,
764 }
765 }
766
767 pub fn configure(mut self, config: StandardizedConfig) -> Self {
769 self.config = config;
770 self
771 }
772
773 pub fn describe(
775 &self,
776 data: ArrayView1<F>,
777 ) -> StatsResult<StandardizedResult<DescriptiveStats<F>>> {
778 DescriptiveStatsBuilder::new()
779 .parallel(self.config.parallel)
780 .simd(self.config.simd)
781 .null_handling(self.config.null_handling)
782 .with_metadata()
783 .compute(data)
784 }
785
786 pub fn correlate<'a>(
788 &self,
789 x: ArrayView1<'a, F>,
790 y: ArrayView1<'a, F>,
791 method: CorrelationMethod,
792 ) -> StatsResult<StandardizedResult<CorrelationResult<F>>> {
793 CorrelationBuilder::new()
794 .method(method)
795 .confidence_level(self.config.confidence_level)
796 .parallel(self.config.parallel)
797 .simd(self.config.simd)
798 .with_metadata()
799 .compute(x, y)
800 }
801
802 pub fn get_config(&self) -> &StandardizedConfig {
804 &self.config
805 }
806}
807
808pub type F64StatsAnalyzer = StatsAnalyzer<f64>;
810pub type F32StatsAnalyzer = StatsAnalyzer<f32>;
811
812pub type F64DescriptiveBuilder = DescriptiveStatsBuilder<f64>;
813pub type F32DescriptiveBuilder = DescriptiveStatsBuilder<f32>;
814
815pub type F64CorrelationBuilder = CorrelationBuilder<f64>;
816pub type F32CorrelationBuilder = CorrelationBuilder<f32>;
817
818impl<F> Default for DescriptiveStatsBuilder<F>
819where
820 F: Float
821 + NumCast
822 + Clone
823 + scirs2_core::simd_ops::SimdUnifiedOps
824 + std::iter::Sum<F>
825 + std::ops::Div<Output = F>
826 + Sync
827 + Send
828 + std::fmt::Display
829 + std::fmt::Debug
830 + 'static,
831{
832 fn default() -> Self {
833 Self::new()
834 }
835}
836
837impl<F> Default for CorrelationBuilder<F>
838where
839 F: Float
840 + NumCast
841 + Clone
842 + std::fmt::Debug
843 + std::fmt::Display
844 + scirs2_core::simd_ops::SimdUnifiedOps
845 + std::iter::Sum<F>
846 + std::ops::Div<Output = F>
847 + Send
848 + Sync
849 + 'static,
850{
851 fn default() -> Self {
852 Self::new()
853 }
854}
855
856impl<F> Default for StatsAnalyzer<F>
857where
858 F: Float
859 + NumCast
860 + Clone
861 + scirs2_core::simd_ops::SimdUnifiedOps
862 + std::iter::Sum<F>
863 + std::ops::Div<Output = F>
864 + Sync
865 + Send
866 + std::fmt::Display
867 + std::fmt::Debug
868 + 'static,
869{
870 fn default() -> Self {
871 Self::new()
872 }
873}
874
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// The builder reports the sample size and mean, and `optimized` is off
    /// when both SIMD and parallel execution are disabled.
    #[test]
    fn test_descriptive_stats_builder() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let result = DescriptiveStatsBuilder::new()
            .ddof(1)
            .parallel(false)
            .simd(false)
            .with_metadata()
            .compute(data.view())
            .expect("Operation failed");

        assert_eq!(result.value.count, 5);
        assert!((result.value.mean - 3.0).abs() < 1e-10);
        assert!(!result.metadata.optimized);
    }

    /// Perfectly correlated inputs give r = 1 and, with metadata enabled,
    /// both inference results.
    #[test]
    fn test_correlation_builder() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let result = CorrelationBuilder::new()
            .method(CorrelationMethod::Pearson)
            .confidence_level(0.95)
            .with_metadata()
            .compute(x.view(), y.view())
            .expect("Operation failed");

        assert!((result.value.correlation - 1.0).abs() < 1e-10);
        assert!(result.value.p_value.is_some());
        assert!(result.value.confidence_interval.is_some());
    }

    /// The analyzer facade forwards to both builders.
    #[test]
    fn test_stats_analyzer() {
        let analyzer = StatsAnalyzer::new();
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let desc_result = analyzer.describe(data.view()).expect("Operation failed");
        assert_eq!(desc_result.value.count, 5);

        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![5.0, 4.0, 3.0, 2.0, 1.0];
        let corr_result = analyzer
            .correlate(x.view(), y.view(), CorrelationMethod::Pearson)
            .expect("Operation failed");
        assert!((corr_result.value.correlation + 1.0).abs() < 1e-10);
    }

    /// `Exclude` drops the NaN and reports it through a warning.
    #[test]
    fn test_null_handling() {
        let data = array![1.0, 2.0, f64::NAN, 4.0, 5.0];

        let result = DescriptiveStatsBuilder::new()
            .null_handling(NullHandling::Exclude)
            .compute(data.view())
            .expect("Operation failed");

        assert_eq!(result.value.count, 4);
        assert!(!result.warnings.is_empty());
    }

    /// Struct-update syntax overrides only the named fields.
    #[test]
    fn test_standardized_config() {
        let config = StandardizedConfig {
            auto_optimize: false,
            parallel: false,
            simd: true,
            confidence_level: 0.99,
            ..Default::default()
        };

        assert!(!config.auto_optimize);
        assert!(!config.parallel);
        assert!(config.simd);
        assert!((config.confidence_level - 0.99).abs() < 1e-10);
    }

    /// A fully documented, `Result`-returning signature passes the default
    /// validation rules.
    #[test]
    fn test_api_validation() {
        let framework = APIValidationFramework::new();
        let signature = APISignature {
            function_name: "test_function".to_string(),
            module_path: "scirs2_stats::test".to_string(),
            parameters: vec![ParameterSpec {
                name: "data".to_string(),
                param_type: "ArrayView1<f64>".to_string(),
                optional: false,
                default_value: None,
                description: Some("Input data array".to_string()),
                constraints: vec![ParameterConstraint::Finite],
            }],
            return_type: ReturnTypeSpec {
                type_name: "f64".to_string(),
                result_wrapped: true,
                inner_type: Some("f64".to_string()),
                error_type: Some("StatsError".to_string()),
            },
            error_types: vec!["StatsError".to_string()],
            documentation: DocumentationSpec {
                has_doc_comment: true,
                has_param_docs: true,
                has_return_docs: true,
                has_examples: true,
                has_error_docs: true,
                scipy_compatibility: Some("Compatible with scipy.stats".to_string()),
            },
            performance: PerformanceSpec {
                time_complexity: Some("O(n)".to_string()),
                space_complexity: Some("O(1)".to_string()),
                simd_optimized: true,
                parallel_processing: true,
                cache_efficient: true,
            },
        };

        let report = framework.validate_api(&signature);
        assert!(matches!(
            report.overall_status,
            ValidationStatus::Passed | ValidationStatus::PassedWithWarnings
        ));
    }
}
1004
1005#[derive(Debug)]
1007pub struct APIValidationFramework {
1008 validation_rules: HashMap<String, Vec<ValidationRule>>,
1010 compatibility_checkers: HashMap<String, CompatibilityChecker>,
1012 performance_benchmarks: HashMap<String, PerformanceBenchmark>,
1014 error_patterns: HashMap<String, ErrorPattern>,
1016}
1017
1018#[derive(Debug, Clone)]
1020pub struct ValidationRule {
1021 pub id: String,
1023 pub description: String,
1025 pub category: ValidationCategory,
1027 pub severity: ValidationSeverity,
1029}
1030
/// Aspect of an API that a [`ValidationRule`] covers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationCategory {
    /// Naming of parameters.
    ParameterNaming,
    /// Shape of return types.
    ReturnTypes,
    /// Error-handling conventions.
    ErrorHandling,
    /// Documentation completeness.
    Documentation,
    /// Performance characteristics.
    Performance,
    /// Compatibility with SciPy.
    ScipyCompatibility,
    /// Thread safety.
    ThreadSafety,
    /// Numerical stability.
    NumericalStability,
}
1051
/// Severity of a validation finding, ordered from least to most severe.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ValidationSeverity {
    /// Informational note.
    Info,
    /// Should be addressed but does not fail validation on its own.
    Warning,
    /// Violation of a rule.
    Error,
    /// Most severe level.
    Critical,
}
1064
1065#[derive(Debug, Clone)]
1067pub struct APISignature {
1068 pub function_name: String,
1070 pub module_path: String,
1072 pub parameters: Vec<ParameterSpec>,
1074 pub return_type: ReturnTypeSpec,
1076 pub error_types: Vec<String>,
1078 pub documentation: DocumentationSpec,
1080 pub performance: PerformanceSpec,
1082}
1083
1084#[derive(Debug, Clone)]
1086pub struct ParameterSpec {
1087 pub name: String,
1089 pub param_type: String,
1091 pub optional: bool,
1093 pub default_value: Option<String>,
1095 pub description: Option<String>,
1097 pub constraints: Vec<ParameterConstraint>,
1099}
1100
/// Constraint that can be attached to a [`ParameterSpec`].
#[derive(Debug, Clone, PartialEq)]
pub enum ParameterConstraint {
    /// Value must be positive.
    Positive,
    /// Value must not be negative.
    NonNegative,
    /// Value must be finite.
    Finite,
    /// Value must lie in the range described by the two bounds.
    Range(f64, f64),
    /// Value must be one of the listed options.
    OneOf(Vec<String>),
    /// Array shape; `None` leaves a dimension unspecified.
    Shape(Vec<Option<usize>>),
    /// Free-form constraint description.
    Custom(String),
}
1119
/// Description of a function's return type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReturnTypeSpec {
    /// Name of the returned type.
    pub type_name: String,
    /// Whether the value is wrapped in a `Result`.
    pub result_wrapped: bool,
    /// Type inside the wrapper, if any.
    pub inner_type: Option<String>,
    /// Error type of the wrapper, if any.
    pub error_type: Option<String>,
}
1132
/// Documentation status of a function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DocumentationSpec {
    /// A doc comment is present.
    pub has_doc_comment: bool,
    /// Parameters are documented.
    pub has_param_docs: bool,
    /// The return value is documented.
    pub has_return_docs: bool,
    /// Usage examples are present.
    pub has_examples: bool,
    /// Error conditions are documented.
    pub has_error_docs: bool,
    /// Note about SciPy compatibility, if any.
    pub scipy_compatibility: Option<String>,
}
1149
/// Declared performance characteristics of a function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PerformanceSpec {
    /// Time complexity, as text (for example `"O(n)"`).
    pub time_complexity: Option<String>,
    /// Space complexity, as text.
    pub space_complexity: Option<String>,
    /// Whether a SIMD implementation exists.
    pub simd_optimized: bool,
    /// Whether parallel processing is used.
    pub parallel_processing: bool,
    /// Whether the implementation is cache efficient.
    pub cache_efficient: bool,
}
1164
1165#[derive(Debug, Clone)]
1167pub struct ValidationResult {
1168 pub passed: bool,
1170 pub messages: Vec<ValidationMessage>,
1172 pub suggested_fixes: Vec<String>,
1174 pub related_rules: Vec<String>,
1176}
1177
1178#[derive(Debug, Clone)]
1180pub struct ValidationMessage {
1181 pub severity: ValidationSeverity,
1183 pub message: String,
1185 pub location: Option<String>,
1187 pub rule_id: String,
1189}
1190
1191#[derive(Debug, Clone)]
1193pub struct CompatibilityChecker {
1194 pub scipy_function: String,
1196 pub parameter_mapping: HashMap<String, String>,
1198 pub return_type_mapping: HashMap<String, String>,
1200 pub known_differences: Vec<CompatibilityDifference>,
1202}
1203
1204#[derive(Debug, Clone)]
1206pub struct CompatibilityDifference {
1207 pub category: DifferenceCategory,
1209 pub description: String,
1211 pub justification: String,
1213 pub workaround: Option<String>,
1215}
1216
/// Reason category of a [`CompatibilityDifference`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DifferenceCategory {
    /// Deliberate improvement.
    Improvement,
    /// Imposed by a Rust language constraint.
    RustConstraint,
    /// Motivated by performance.
    Performance,
    /// Motivated by safety.
    Safety,
    /// Not intended.
    Unintentional,
}
1231
1232#[derive(Debug, Clone)]
1234pub struct PerformanceBenchmark {
1235 pub name: String,
1237 pub expected_complexity: ComplexityClass,
1239 pub memory_usage: MemoryUsagePattern,
1241 pub scalability: ScalabilityRequirement,
1243}
1244
/// Asymptotic complexity class.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityClass {
    /// O(1)
    Constant,
    /// O(log n)
    Logarithmic,
    /// O(n)
    Linear,
    /// O(n log n)
    LogLinear,
    /// O(n^2)
    Quadratic,
    /// O(n^3)
    Cubic,
    /// Exponential in n.
    Exponential,
}
1256
/// How memory usage relates to the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryUsagePattern {
    /// Constant memory.
    Constant,
    /// Memory linear in the input size.
    Linear,
    /// Memory quadratic in the input size.
    Quadratic,
    /// Streaming processing.
    Streaming,
    /// Out-of-core processing.
    OutOfCore,
}
1266
/// Scalability targets of a [`PerformanceBenchmark`].
#[derive(Debug, Clone, PartialEq)]
pub struct ScalabilityRequirement {
    /// Largest data size to support.
    pub maxdatasize: usize,
    /// Required parallel efficiency.
    pub parallel_efficiency: f64,
    /// Required SIMD acceleration.
    pub simd_acceleration: f64,
}
1277
1278#[derive(Debug, Clone)]
1280pub struct ErrorPattern {
1281 pub category: ErrorCategory,
1283 pub message_template: String,
1285 pub recovery_suggestions: Vec<String>,
1287 pub related_errors: Vec<String>,
1289}
1290
/// Category of an [`ErrorPattern`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Invalid input.
    InvalidInput,
    /// Numerical problem.
    Numerical,
    /// Memory problem.
    Memory,
    /// Failure to converge.
    Convergence,
    /// Mismatching dimensions.
    DimensionMismatch,
    /// Feature not implemented.
    NotImplemented,
    /// Internal error.
    Internal,
}
1309
1310#[derive(Debug)]
1312pub struct ValidationReport {
1313 pub function_name: String,
1315 pub results: HashMap<String, ValidationResult>,
1317 pub overall_status: ValidationStatus,
1319 pub summary: ValidationSummary,
1321}
1322
/// Aggregated status of a [`ValidationReport`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationStatus {
    /// Passed.
    Passed,
    /// Passed, with warnings.
    PassedWithWarnings,
    /// Failed.
    Failed,
    /// Critical.
    Critical,
}
1331
/// Counters summarizing a [`ValidationReport`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationSummary {
    /// Total number of rules.
    pub total_rules: usize,
    /// Count of passed rules.
    pub passed: usize,
    /// Count of warnings.
    pub warnings: usize,
    /// Count of errors.
    pub errors: usize,
    /// Count of critical findings.
    pub critical: usize,
}
1346
1347impl APIValidationFramework {
1348 pub fn new() -> Self {
1350 let mut framework = Self {
1351 validation_rules: HashMap::new(),
1352 compatibility_checkers: HashMap::new(),
1353 performance_benchmarks: HashMap::new(),
1354 error_patterns: HashMap::new(),
1355 };
1356
1357 framework.initialize_default_rules();
1358 framework
1359 }
1360
1361 fn initialize_default_rules(&mut self) {
1363 self.add_validation_rule(ValidationRule {
1365 id: "param_naming_consistency".to_string(),
1366 description: "Parameter names should follow consistent snake_case conventions"
1367 .to_string(),
1368 category: ValidationCategory::ParameterNaming,
1369 severity: ValidationSeverity::Warning,
1370 });
1371
1372 self.add_validation_rule(ValidationRule {
1374 id: "error_handling_consistency".to_string(),
1375 description: "Functions should return Result<T, StatsError> for consistency"
1376 .to_string(),
1377 category: ValidationCategory::ErrorHandling,
1378 severity: ValidationSeverity::Error,
1379 });
1380
1381 self.add_validation_rule(ValidationRule {
1383 id: "documentation_completeness".to_string(),
1384 description: "All public functions should have complete documentation".to_string(),
1385 category: ValidationCategory::Documentation,
1386 severity: ValidationSeverity::Warning,
1387 });
1388
1389 self.add_validation_rule(ValidationRule {
1391 id: "scipy_compatibility".to_string(),
1392 description: "Functions should maintain SciPy compatibility where possible".to_string(),
1393 category: ValidationCategory::ScipyCompatibility,
1394 severity: ValidationSeverity::Info,
1395 });
1396
1397 self.add_validation_rule(ValidationRule {
1399 id: "performance_characteristics".to_string(),
1400 description: "Functions should document performance characteristics".to_string(),
1401 category: ValidationCategory::Performance,
1402 severity: ValidationSeverity::Info,
1403 });
1404 }
1405
1406 pub fn add_validation_rule(&mut self, rule: ValidationRule) {
1408 let category_key = format!("{:?}", rule.category);
1409 self.validation_rules
1410 .entry(category_key)
1411 .or_default()
1412 .push(rule);
1413 }
1414
1415 pub fn validate_api(&self, signature: &APISignature) -> ValidationReport {
1417 let mut report = ValidationReport::new(signature.function_name.clone());
1418
1419 for rules in self.validation_rules.values() {
1420 for rule in rules {
1421 let result = self.apply_validation_rule(rule, signature);
1422 report.add_result(rule.id.clone(), result);
1423 }
1424 }
1425
1426 report
1427 }
1428
1429 fn apply_validation_rule(
1431 &self,
1432 rule: &ValidationRule,
1433 signature: &APISignature,
1434 ) -> ValidationResult {
1435 match rule.category {
1436 ValidationCategory::ParameterNaming => self.validate_parameter_naming(signature),
1437 ValidationCategory::ErrorHandling => self.validate_error_handling(signature),
1438 ValidationCategory::Documentation => self.validate_documentation(signature),
1439 ValidationCategory::ScipyCompatibility => self.validate_scipy_compatibility(signature),
1440 ValidationCategory::Performance => self.validate_performance(signature),
1441 _ => ValidationResult {
1442 passed: true,
1443 messages: vec![],
1444 suggested_fixes: vec![],
1445 related_rules: vec![],
1446 },
1447 }
1448 }
1449
1450 fn validate_parameter_naming(&self, signature: &APISignature) -> ValidationResult {
1452 let mut messages = Vec::new();
1453 let mut suggested_fixes = Vec::new();
1454
1455 for param in &signature.parameters {
1456 if param.name.contains(char::is_uppercase) || param.name.contains('-') {
1458 messages.push(ValidationMessage {
1459 severity: ValidationSeverity::Warning,
1460 message: format!("Parameter '{}' should use snake_case naming", param.name),
1461 location: Some(format!(
1462 "{}::{}",
1463 signature.module_path, signature.function_name
1464 )),
1465 rule_id: "param_naming_consistency".to_string(),
1466 });
1467 suggested_fixes.push(format!("Rename parameter '{}' to snake_case", param.name));
1468 }
1469 }
1470
1471 ValidationResult {
1472 passed: messages.is_empty(),
1473 messages,
1474 suggested_fixes,
1475 related_rules: vec!["return_type_consistency".to_string()],
1476 }
1477 }
1478
1479 fn validate_error_handling(&self, signature: &APISignature) -> ValidationResult {
1481 let mut messages = Vec::new();
1482 let mut suggested_fixes = Vec::new();
1483
1484 if !signature.return_type.result_wrapped {
1485 messages.push(ValidationMessage {
1486 severity: ValidationSeverity::Error,
1487 message: "Function should return Result<T, StatsError> for consistency".to_string(),
1488 location: Some(format!(
1489 "{}::{}",
1490 signature.module_path, signature.function_name
1491 )),
1492 rule_id: "error_handling_consistency".to_string(),
1493 });
1494 suggested_fixes.push("Wrap return type in Result<T, StatsError>".to_string());
1495 }
1496
1497 if let Some(error_type) = &signature.return_type.error_type {
1498 if error_type != "StatsError" {
1499 messages.push(ValidationMessage {
1500 severity: ValidationSeverity::Warning,
1501 message: format!("Non-standard error type '{}' used", error_type),
1502 location: Some(format!(
1503 "{}::{}",
1504 signature.module_path, signature.function_name
1505 )),
1506 rule_id: "error_handling_consistency".to_string(),
1507 });
1508 suggested_fixes.push("Use StatsError for consistency".to_string());
1509 }
1510 }
1511
1512 ValidationResult {
1513 passed: messages.is_empty(),
1514 messages,
1515 suggested_fixes,
1516 related_rules: vec!["documentation_completeness".to_string()],
1517 }
1518 }
1519
1520 fn validate_documentation(&self, signature: &APISignature) -> ValidationResult {
1522 let mut messages = Vec::new();
1523 let mut suggested_fixes = Vec::new();
1524
1525 if !signature.documentation.has_doc_comment {
1526 messages.push(ValidationMessage {
1527 severity: ValidationSeverity::Warning,
1528 message: "Function lacks documentation comment".to_string(),
1529 location: Some(format!(
1530 "{}::{}",
1531 signature.module_path, signature.function_name
1532 )),
1533 rule_id: "documentation_completeness".to_string(),
1534 });
1535 suggested_fixes.push("Add comprehensive doc comment".to_string());
1536 }
1537
1538 if !signature.documentation.has_examples {
1539 messages.push(ValidationMessage {
1540 severity: ValidationSeverity::Info,
1541 message: "Function lacks usage examples".to_string(),
1542 location: Some(format!(
1543 "{}::{}",
1544 signature.module_path, signature.function_name
1545 )),
1546 rule_id: "documentation_completeness".to_string(),
1547 });
1548 suggested_fixes.push("Add usage examples in # Examples section".to_string());
1549 }
1550
1551 ValidationResult {
1552 passed: messages
1553 .iter()
1554 .all(|m| matches!(m.severity, ValidationSeverity::Info)),
1555 messages,
1556 suggested_fixes,
1557 related_rules: vec!["scipy_compatibility".to_string()],
1558 }
1559 }
1560
1561 fn validate_scipy_compatibility(&self, signature: &APISignature) -> ValidationResult {
1563 let mut messages = Vec::new();
1564 let mut suggested_fixes = Vec::new();
1565
1566 let scipy_standard_params = [
1568 "axis",
1569 "ddof",
1570 "keepdims",
1571 "out",
1572 "dtype",
1573 "method",
1574 "alternative",
1575 ];
1576 let has_scipy_params = signature
1577 .parameters
1578 .iter()
1579 .any(|p| scipy_standard_params.contains(&p.name.as_str()));
1580
1581 if has_scipy_params && signature.documentation.scipy_compatibility.is_none() {
1582 messages.push(ValidationMessage {
1583 severity: ValidationSeverity::Info,
1584 message: "Consider documenting SciPy compatibility status".to_string(),
1585 location: Some(format!(
1586 "{}::{}",
1587 signature.module_path, signature.function_name
1588 )),
1589 rule_id: "scipy_compatibility".to_string(),
1590 });
1591 suggested_fixes.push("Add SciPy compatibility note in documentation".to_string());
1592 }
1593
1594 ValidationResult {
1595 passed: true, messages,
1597 suggested_fixes,
1598 related_rules: vec!["documentation_completeness".to_string()],
1599 }
1600 }
1601
1602 fn validate_performance(&self, signature: &APISignature) -> ValidationResult {
1604 let mut messages = Vec::new();
1605 let mut suggested_fixes = Vec::new();
1606
1607 if signature.performance.time_complexity.is_none() {
1608 messages.push(ValidationMessage {
1609 severity: ValidationSeverity::Info,
1610 message: "Consider documenting time complexity".to_string(),
1611 location: Some(format!(
1612 "{}::{}",
1613 signature.module_path, signature.function_name
1614 )),
1615 rule_id: "performance_characteristics".to_string(),
1616 });
1617 suggested_fixes.push("Add time complexity documentation".to_string());
1618 }
1619
1620 ValidationResult {
1621 passed: true, messages,
1623 suggested_fixes,
1624 related_rules: vec![],
1625 }
1626 }
1627}
1628
1629impl ValidationReport {
1630 pub fn new(_functionname: String) -> Self {
1632 Self {
1633 function_name: _functionname,
1634 results: HashMap::new(),
1635 overall_status: ValidationStatus::Passed,
1636 summary: ValidationSummary {
1637 total_rules: 0,
1638 passed: 0,
1639 warnings: 0,
1640 errors: 0,
1641 critical: 0,
1642 },
1643 }
1644 }
1645
1646 pub fn add_result(&mut self, ruleid: String, result: ValidationResult) {
1648 self.summary.total_rules += 1;
1649
1650 if result.passed {
1651 self.summary.passed += 1;
1652 } else {
1653 let max_severity = result
1654 .messages
1655 .iter()
1656 .map(|m| m.severity)
1657 .max()
1658 .unwrap_or(ValidationSeverity::Info);
1659
1660 match max_severity {
1661 ValidationSeverity::Info => {}
1662 ValidationSeverity::Warning => {
1663 self.summary.warnings += 1;
1664 if matches!(self.overall_status, ValidationStatus::Passed) {
1665 self.overall_status = ValidationStatus::PassedWithWarnings;
1666 }
1667 }
1668 ValidationSeverity::Error => {
1669 self.summary.errors += 1;
1670 if !matches!(self.overall_status, ValidationStatus::Critical) {
1671 self.overall_status = ValidationStatus::Failed;
1672 }
1673 }
1674 ValidationSeverity::Critical => {
1675 self.summary.critical += 1;
1676 self.overall_status = ValidationStatus::Critical;
1677 }
1678 }
1679 }
1680
1681 self.results.insert(ruleid, result);
1682 }
1683
1684 pub fn generate_report(&self) -> String {
1686 let mut report = String::new();
1687 report.push_str(&format!(
1688 "API Validation Report for {}\n",
1689 self.function_name
1690 ));
1691 report.push_str(&format!("Status: {:?}\n", self.overall_status));
1692 report.push_str(&format!(
1693 "Summary: {} passed, {} warnings, {} errors, {} critical\n\n",
1694 self.summary.passed, self.summary.warnings, self.summary.errors, self.summary.critical
1695 ));
1696
1697 for (rule_id, result) in &self.results {
1698 if !result.passed {
1699 report.push_str(&format!("Rule: {}\n", rule_id));
1700 for message in &result.messages {
1701 report.push_str(&format!(" {:?}: {}\n", message.severity, message.message));
1702 }
1703 if !result.suggested_fixes.is_empty() {
1704 report.push_str(" Suggestions:\n");
1705 for fix in &result.suggested_fixes {
1706 report.push_str(&format!(" - {}\n", fix));
1707 }
1708 }
1709 report.push('\n');
1710 }
1711 }
1712
1713 report
1714 }
1715}
1716
1717impl Default for APIValidationFramework {
1718 fn default() -> Self {
1719 Self::new()
1720 }
1721}