use ndarray::{par_azip, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix2, Zip};
use num_traits::{Float, NumCast};
use scirs2_core::parallel_ops::*;
use scirs2_core::validation::check_not_empty;
use std::collections::HashMap;

use crate::error::{Result, TransformError};
use statrs::statistics::Statistics;

/// Splits large datasets into memory-bounded chunks for out-of-core processing.
#[derive(Debug, Clone)]
pub struct DataChunker {
    /// Memory budget in megabytes.
    _max_memorymb: usize,
    /// Preferred number of samples per chunk.
    preferred_chunk_size: usize,
    /// Minimum number of samples per chunk.
    min_chunk_size: usize,
}

impl DataChunker {
    /// Creates a chunker with the given memory budget in megabytes.
    pub fn new(_max_memorymb: usize) -> Self {
        DataChunker {
            _max_memorymb,
            preferred_chunk_size: 10000,
            min_chunk_size: 100,
        }
    }

    /// Computes a chunk size that fits the memory budget, clamped to the
    /// configured bounds and to the total number of samples.
    pub fn calculate_chunk_size(&self, n_samples: usize, nfeatures: usize) -> usize {
        // Per-sample cost: one f64 per feature plus a small bookkeeping overhead.
        let bytes_per_sample = nfeatures * std::mem::size_of::<f64>() + 64;
        let max_samples_in_memory = (self._max_memorymb * 1024 * 1024) / bytes_per_sample;

        max_samples_in_memory
            .min(self.preferred_chunk_size)
            .max(self.min_chunk_size)
            .min(n_samples)
    }

    /// Returns an iterator over `(start, end)` index ranges covering `n_samples`.
    pub fn chunk_indices(&self, n_samples: usize, nfeatures: usize) -> ChunkIterator {
        let chunk_size = self.calculate_chunk_size(n_samples, nfeatures);
        ChunkIterator {
            current: 0,
            total: n_samples,
            chunk_size,
        }
    }
}

/// Iterator over `(start, end)` sample ranges produced by `DataChunker`.
#[derive(Debug)]
pub struct ChunkIterator {
    current: usize,
    total: usize,
    chunk_size: usize,
}

impl Iterator for ChunkIterator {
    type Item = (usize, usize);

    fn next(&mut self) -> Option<Self::Item> {
        if self.current >= self.total {
            return None;
        }

        let start = self.current;
        let end = (self.current + self.chunk_size).min(self.total);
        self.current = end;

        Some((start, end))
    }
}

/// Conversion helpers between numeric array types.
pub struct TypeConverter;

impl TypeConverter {
    /// Converts an arbitrary float array to `f64`, in parallel for large inputs.
    pub fn to_f64<T, S>(array: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        T: Float + NumCast + Send + Sync,
        S: Data<Elem = T>,
    {
        check_not_empty(array, "array")?;

        let result = if array.is_standard_layout() {
            if array.len() > 10000 {
                let mut result = Array2::zeros(array.raw_dim());
                Zip::from(&mut result).and(array).par_for_each(|out, &inp| {
                    *out = num_traits::cast::<T, f64>(inp).unwrap_or(0.0);
                });
                result
            } else {
                array.mapv(|x| num_traits::cast::<T, f64>(x).unwrap_or(0.0))
            }
        } else {
            let shape = array.shape();
            let mut result = Array2::zeros((shape[0], shape[1]));

            par_azip!((out in result.view_mut(), &inp in array) {
                *out = num_traits::cast::<T, f64>(inp).unwrap_or(0.0);
            });

            result
        };

        for &val in result.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Array contains non-finite values after conversion".to_string(),
                ));
            }
        }
        Ok(result)
    }

    /// Widens an `f32` array to `f64`, using parallel iteration for large inputs.
    pub fn f32_to_f64_simd(array: &ArrayView2<f32>) -> Result<Array2<f64>> {
        check_not_empty(array, "array")?;

        let result = if array.len() > 10000 {
            let mut result = Array2::zeros(array.raw_dim());
            Zip::from(&mut result).and(array).par_for_each(|out, &inp| {
                *out = inp as f64;
            });
            result
        } else {
            array.mapv(|x| x as f64)
        };

        for &val in result.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Array contains non-finite values after conversion".to_string(),
                ));
            }
        }
        Ok(result)
    }

    /// Narrows an `f64` array to `f32`, rejecting non-finite or out-of-range values.
    pub fn f64_to_f32_safe(array: &ArrayView2<f64>) -> Result<Array2<f32>> {
        check_not_empty(array, "array")?;

        for &val in array.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Array contains non-finite values".to_string(),
                ));
            }
        }

        let mut result = Array2::zeros(array.raw_dim());
        for (out, &inp) in result.iter_mut().zip(array.iter()) {
            if inp.abs() > f32::MAX as f64 {
                return Err(TransformError::DataValidationError(
                    "Value too large for f32 conversion".to_string(),
                ));
            }
            *out = inp as f32;
        }

        Ok(result)
    }
}

/// Robust statistics helpers.
pub struct StatUtils;

impl StatUtils {
    /// Returns the median and median absolute deviation (MAD) of `data`.
    pub fn robust_stats(data: &ArrayView1<f64>) -> Result<(f64, f64)> {
        check_not_empty(data, "data")?;

        for &val in data.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Data contains non-finite values".to_string(),
                ));
            }
        }

        let mut sorted_data = data.to_vec();
        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let n = sorted_data.len();
        let median = if n % 2 == 0 {
            (sorted_data[n / 2 - 1] + sorted_data[n / 2]) / 2.0
        } else {
            sorted_data[n / 2]
        };

        let mut deviations: Vec<f64> = sorted_data.iter().map(|&x| (x - median).abs()).collect();
        deviations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let mad = if n % 2 == 0 {
            (deviations[n / 2 - 1] + deviations[n / 2]) / 2.0
        } else {
            deviations[n / 2]
        };

        Ok((median, mad))
    }

    /// Computes per-column medians and MADs in parallel.
    pub fn robust_stats_columns(data: &ArrayView2<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
        check_not_empty(data, "data")?;

        for &val in data.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Data contains non-finite values".to_string(),
                ));
            }
        }

        let nfeatures = data.ncols();
        let mut medians = Array1::zeros(nfeatures);
        let mut mads = Array1::zeros(nfeatures);

        let stats: Result<Vec<_>> = (0..nfeatures)
            .into_par_iter()
            .map(|j| {
                let col = data.column(j);
                Self::robust_stats(&col)
            })
            .collect();

        let stats = stats?;

        for (j, (median, mad)) in stats.into_iter().enumerate() {
            medians[j] = median;
            mads[j] = mad;
        }

        Ok((medians, mads))
    }

    /// Flags values outside `factor` times the interquartile range.
    pub fn detect_outliers_iqr(data: &ArrayView1<f64>, factor: f64) -> Result<Vec<bool>> {
        check_not_empty(data, "data")?;

        for &val in data.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Data contains non-finite values".to_string(),
                ));
            }
        }

        if factor <= 0.0 {
            return Err(TransformError::InvalidInput(
                "Outlier factor must be positive".to_string(),
            ));
        }

        let mut sorted_data = data.to_vec();
        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let n = sorted_data.len();
        let q1_idx = n / 4;
        let q3_idx = 3 * n / 4;

        let q1 = sorted_data[q1_idx];
        let q3 = sorted_data[q3_idx];
        let iqr = q3 - q1;

        let lower_bound = q1 - factor * iqr;
        let upper_bound = q3 + factor * iqr;

        let outliers = data
            .iter()
            .map(|&x| x < lower_bound || x > upper_bound)
            .collect();

        Ok(outliers)
    }

    /// Scores data quality in `[0, 1]` from the finite-value ratio and per-column value diversity.
    pub fn data_quality_score(data: &ArrayView2<f64>) -> Result<f64> {
        check_not_empty(data, "data")?;

        let total_elements = data.len() as f64;

        let finite_count = data.iter().filter(|&&x| x.is_finite()).count() as f64;
        let finite_ratio = finite_count / total_elements;

        let nfeatures = data.ncols();
        let mut diversity_scores = Vec::with_capacity(nfeatures);

        for j in 0..nfeatures {
            let col = data.column(j);
            let mut unique_values = std::collections::HashSet::new();
            for &val in col.iter() {
                if val.is_finite() {
                    // Round to ~12 decimal digits so nearly-equal values count as one.
                    let rounded = (val * 1e12).round() as i64;
                    unique_values.insert(rounded);
                }
            }

            let diversity = if !col.is_empty() {
                unique_values.len() as f64 / col.len() as f64
            } else {
                0.0
            };
            diversity_scores.push(diversity);
        }

        let avg_diversity = if diversity_scores.is_empty() {
            0.0
        } else {
            diversity_scores.iter().sum::<f64>() / diversity_scores.len() as f64
        };

        // Weighted blend: completeness counts more than diversity.
        let quality_score = 0.7 * finite_ratio + 0.3 * avg_diversity;

        Ok(quality_score.clamp(0.0, 1.0))
    }
}

/// Reuses fixed-size `Array2` buffers to reduce allocations.
pub struct ArrayMemoryPool<T> {
    /// Pooled arrays keyed by `(rows, cols)`.
    available_arrays: HashMap<(usize, usize), Vec<Array2<T>>>,
    /// Maximum number of pooled arrays per shape.
    max_persize: usize,
    /// Pool memory limit in bytes.
    memory_limit: usize,
    /// Currently pooled memory in bytes.
    current_memory: usize,
}

impl<T: Clone + Default> ArrayMemoryPool<T> {
    /// Creates a pool with a memory limit (in MB) and a per-shape capacity.
    pub fn new(_memory_limit_mb: usize, max_persize: usize) -> Self {
        ArrayMemoryPool {
            available_arrays: HashMap::new(),
            max_persize,
            memory_limit: _memory_limit_mb * 1024 * 1024,
            current_memory: 0,
        }
    }

    /// Takes a pooled array of the requested shape, or allocates a new one.
    pub fn get_array(&mut self, rows: usize, cols: usize) -> Array2<T> {
        let size_key = (rows, cols);

        if let Some(arrays) = self.available_arrays.get_mut(&size_key) {
            if let Some(array) = arrays.pop() {
                let array_size = rows * cols * std::mem::size_of::<T>();
                self.current_memory = self.current_memory.saturating_sub(array_size);
                return array;
            }
        }

        Array2::default((rows, cols))
    }

    /// Returns an array to the pool if the memory limit and per-shape capacity allow it.
    pub fn return_array(&mut self, mut array: Array2<T>) {
        let (rows, cols) = array.dim();
        let size_key = (rows, cols);
        let array_size = rows * cols * std::mem::size_of::<T>();

        if self.current_memory + array_size > self.memory_limit {
            return;
        }

        array.fill(T::default());

        let arrays = self.available_arrays.entry(size_key).or_default();
        if arrays.len() < self.max_persize {
            arrays.push(array);
            self.current_memory += array_size;
        }
    }

    /// Drops all pooled arrays and resets the memory accounting.
    pub fn clear(&mut self) {
        self.available_arrays.clear();
        self.current_memory = 0;
    }

    /// Currently pooled memory in megabytes.
    pub fn memory_usage_mb(&self) -> f64 {
        self.current_memory as f64 / (1024.0 * 1024.0)
    }
}

/// Input validation helpers shared by the transformers.
pub struct ValidationUtils;

impl ValidationUtils {
    /// Checks that `value` is finite and lies in `[min, max]`.
    pub fn validate_parameter_bounds(
        value: f64,
        min: f64,
        max: f64,
        param_name: &str,
    ) -> Result<()> {
        if !value.is_finite() {
            return Err(TransformError::InvalidInput(format!(
                "{param_name} must be finite"
            )));
        }

        if value < min || value > max {
            return Err(TransformError::InvalidInput(format!(
                "{param_name} must be between {min} and {max}, got {value}"
            )));
        }

        Ok(())
    }

    /// Checks that two shapes have the same rank and equal dimensions.
    pub fn validate_dimensions_compatible(
        shape1: &[usize],
        shape2: &[usize],
        operation: &str,
    ) -> Result<()> {
        if shape1.len() != shape2.len() {
            return Err(TransformError::InvalidInput(format!(
                "Incompatible dimensions for {operation}: {shape1:?} vs {shape2:?}"
            )));
        }

        for (i, (&dim1, &dim2)) in shape1.iter().zip(shape2.iter()).enumerate() {
            if dim1 != dim2 {
                return Err(TransformError::InvalidInput(format!(
                    "Dimension {i} mismatch for {operation}: {dim1} vs {dim2}"
                )));
            }
        }

        Ok(())
    }

    /// Checks that `data` is non-empty, finite, and meets the requirements of
    /// the named transformation.
    pub fn validate_data_for_transformation(
        data: &ArrayView2<f64>,
        transformation: &str,
    ) -> Result<()> {
        check_not_empty(data, "data")?;

        for &val in data.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Data contains non-finite values".to_string(),
                ));
            }
        }

        let (n_samples, nfeatures) = data.dim();

        match transformation {
            "pca" => {
                if n_samples < 2 {
                    return Err(TransformError::InvalidInput(
                        "PCA requires at least 2 samples".to_string(),
                    ));
                }
                if nfeatures < 1 {
                    return Err(TransformError::InvalidInput(
                        "PCA requires at least 1 feature".to_string(),
                    ));
                }
            }
            "standardization" => {
                for j in 0..nfeatures {
                    let col = data.column(j);
                    let variance = col.variance();
                    if variance < 1e-15 {
                        return Err(TransformError::DataValidationError(format!(
                            "Feature {j} has zero variance and cannot be standardized"
                        )));
                    }
                }
            }
            "normalization" => {
                for i in 0..n_samples {
                    let row = data.row(i);
                    let norm = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
                    if norm < 1e-15 {
                        return Err(TransformError::DataValidationError(format!(
                            "Sample {i} has zero norm and cannot be normalized"
                        )));
                    }
                }
            }
            _ => {
                // Unknown transformations only get the generic checks above.
            }
        }

        Ok(())
    }
}

/// Heuristics for memory and runtime estimation.
pub struct PerfUtils;

impl PerfUtils {
    /// Rough estimate of peak memory usage (in bytes) for a transformation.
    pub fn estimate_memory_usage(
        inputshape: &[usize],
        outputshape: &[usize],
        operation: &str,
    ) -> usize {
        let input_size = inputshape.iter().product::<usize>() * std::mem::size_of::<f64>();
        let output_size = outputshape.iter().product::<usize>() * std::mem::size_of::<f64>();

        let overhead = match operation {
            "pca" => input_size * 2,
            "standardization" => input_size / 10,
            "polynomial" => output_size / 2,
            _ => input_size / 4,
        };

        input_size + output_size + overhead
    }

    /// Very rough runtime estimate based on input dimensions and operation type.
    pub fn estimate_computation_time(
        n_samples: usize,
        nfeatures: usize,
        operation: &str,
    ) -> std::time::Duration {
        use std::time::Duration;

        let base_time_ns = match operation {
            "pca" => (n_samples as u64) * (nfeatures as u64).pow(2) / 1000,
            "standardization" => (n_samples as u64) * (nfeatures as u64) / 100,
            "normalization" => (n_samples as u64) * (nfeatures as u64) / 50,
            "polynomial" => (n_samples as u64) * (nfeatures as u64).pow(3) / 10000,
            _ => (n_samples as u64) * (nfeatures as u64) / 100,
        };

        Duration::from_nanos(base_time_ns.max(1000))
    }

    /// Picks a processing strategy from the data size and the available memory.
    pub fn choose_processing_strategy(
        n_samples: usize,
        nfeatures: usize,
        available_memory_mb: usize,
    ) -> ProcessingStrategy {
        let estimated_memory_mb =
            (n_samples * nfeatures * std::mem::size_of::<f64>()) / (1024 * 1024);

        if estimated_memory_mb > available_memory_mb {
            ProcessingStrategy::OutOfCore {
                chunk_size: (available_memory_mb * 1024 * 1024)
                    / (nfeatures * std::mem::size_of::<f64>()),
            }
        } else if n_samples > 10000 && nfeatures > 100 {
            ProcessingStrategy::Parallel
        } else if nfeatures > 1000 {
            ProcessingStrategy::Simd
        } else {
            ProcessingStrategy::Standard
        }
    }
}

/// Processing strategy selected by `PerfUtils::choose_processing_strategy`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "distributed", derive(serde::Serialize, serde::Deserialize))]
pub enum ProcessingStrategy {
    /// Sequential in-memory processing.
    Standard,
    /// SIMD-oriented in-memory processing.
    Simd,
    /// Multi-threaded in-memory processing.
    Parallel,
    /// Chunked processing for data that does not fit in memory.
    OutOfCore {
        /// Number of samples per chunk.
        chunk_size: usize,
    },
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_data_chunker() {
        let chunker = DataChunker::new(100);
        let chunk_size = chunker.calculate_chunk_size(50000, 100);
        assert!(chunk_size > 0);
        assert!(chunk_size <= 50000);
    }
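
    // Additional bounds check, not in the original suite: the expected values
    // below follow directly from calculate_chunk_size as implemented above
    // (integer division of the 1 MB budget by the per-sample byte cost).
    #[test]
    fn test_chunk_size_bounds() {
        let chunker = DataChunker::new(1);

        // Tiny dataset: the chunk size is capped at n_samples.
        assert_eq!(chunker.calculate_chunk_size(50, 10), 50);

        // Very wide rows: the budget allows almost nothing, so the minimum
        // chunk size wins.
        assert_eq!(chunker.calculate_chunk_size(10_000, 100_000), 100);
    }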

    #[test]
    fn test_chunk_iterator() {
        let chunker = DataChunker::new(1);
        let chunks: Vec<_> = chunker.chunk_indices(1000, 10).collect();
        assert!(!chunks.is_empty());

        let total_covered = chunks.iter().map(|(start, end)| end - start).sum::<usize>();
        assert_eq!(total_covered, 1000);
    }
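
    // Extra coverage for ChunkIterator: ranges should be contiguous, non-empty,
    // and cover the full sample count, which follows from next() above.
    #[test]
    fn test_chunk_iterator_contiguous() {
        let chunker = DataChunker::new(1);
        let chunks: Vec<_> = chunker.chunk_indices(1000, 10).collect();

        let mut expected_start = 0;
        for &(start, end) in &chunks {
            assert_eq!(start, expected_start);
            assert!(end > start);
            expected_start = end;
        }
        assert_eq!(expected_start, 1000);
    }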

    #[test]
    fn test_type_converter() {
        let data = Array2::<f32>::ones((10, 5));
        let result = TypeConverter::f32_to_f64_simd(&data.view()).unwrap();
        assert_eq!(result.shape(), &[10, 5]);
        assert!((result[(0, 0)] - 1.0).abs() < 1e-10);
    }
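
    // Extra coverage for f64_to_f32_safe: in-range values convert, and values
    // larger than f32::MAX are rejected, per the checks in that function.
    #[test]
    fn test_f64_to_f32_safe() {
        let data = Array2::from_elem((4, 3), 1.5f64);
        let result = TypeConverter::f64_to_f32_safe(&data.view()).unwrap();
        assert_eq!(result.shape(), &[4, 3]);
        assert!((result[(0, 0)] - 1.5f32).abs() < 1e-6);

        let too_large = Array2::from_elem((2, 2), f64::MAX);
        assert!(TypeConverter::f64_to_f32_safe(&too_large.view()).is_err());
    }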

    #[test]
    fn test_robust_stats() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]);
        let (median, mad) = StatUtils::robust_stats(&data.view()).unwrap();
        assert!((median - 3.5).abs() < 1e-10);
        assert!(mad > 0.0);
    }
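
    // Extra coverage for robust_stats_columns: the per-column medians and MADs
    // should match what robust_stats computes for each column on its own.
    #[test]
    fn test_robust_stats_columns() {
        let data = Array2::from_shape_vec(
            (4, 2),
            vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0],
        )
        .unwrap();
        let (medians, mads) = StatUtils::robust_stats_columns(&data.view()).unwrap();
        assert!((medians[0] - 2.5).abs() < 1e-10);
        assert!((medians[1] - 25.0).abs() < 1e-10);
        assert!((mads[0] - 1.0).abs() < 1e-10);
        assert!((mads[1] - 10.0).abs() < 1e-10);
    }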

    #[test]
    fn test_outlier_detection() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]);
        let outliers = StatUtils::detect_outliers_iqr(&data.view(), 1.5).unwrap();
        assert_eq!(outliers.len(), 6);
        assert!(outliers[5]);
    }
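
    // Extra coverage for detect_outliers_iqr error paths: a non-positive factor
    // and non-finite input data are both rejected by the validation above.
    #[test]
    fn test_outlier_detection_invalid_input() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(StatUtils::detect_outliers_iqr(&data.view(), 0.0).is_err());

        let bad = Array1::from_vec(vec![1.0, f64::NAN, 3.0]);
        assert!(StatUtils::detect_outliers_iqr(&bad.view(), 1.5).is_err());
    }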

    #[test]
    fn test_data_quality_score() {
        let good_data =
            Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
        let quality = StatUtils::data_quality_score(&good_data.view()).unwrap();
        assert!(quality > 0.5);

        let bad_data = Array2::from_elem((10, 3), f64::NAN);
        let quality = StatUtils::data_quality_score(&bad_data.view()).unwrap();
        assert!(quality < 0.5);
    }

    #[test]
    fn test_memory_pool() {
        let mut pool = ArrayMemoryPool::<f64>::new(10, 2);

        let array1 = pool.get_array(10, 5);
        assert_eq!(array1.shape(), &[10, 5]);

        pool.return_array(array1);

        let array2 = pool.get_array(10, 5);
        assert_eq!(array2.shape(), &[10, 5]);
    }
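
    // Extra coverage for the pool's accounting: a returned array is counted
    // toward memory_usage_mb, and clear() resets it, per return_array and
    // clear above.
    #[test]
    fn test_memory_pool_accounting() {
        let mut pool = ArrayMemoryPool::<f64>::new(10, 2);
        let array = pool.get_array(10, 5);
        pool.return_array(array);
        assert!(pool.memory_usage_mb() > 0.0);

        pool.clear();
        assert_eq!(pool.memory_usage_mb(), 0.0);
    }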

    #[test]
    fn test_validation_utils() {
        assert!(ValidationUtils::validate_parameter_bounds(0.5, 0.0, 1.0, "test").is_ok());
        assert!(ValidationUtils::validate_parameter_bounds(1.5, 0.0, 1.0, "test").is_err());

        assert!(
            ValidationUtils::validate_dimensions_compatible(&[10, 5], &[10, 5], "test").is_ok()
        );
        assert!(
            ValidationUtils::validate_dimensions_compatible(&[10, 5], &[10, 6], "test").is_err()
        );
    }
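
    // Extra coverage for validate_data_for_transformation: a constant-valued
    // matrix satisfies the PCA checks but fails standardization because every
    // column has zero variance, per the match arms above.
    #[test]
    fn test_validate_data_for_transformation() {
        let constant = Array2::from_elem((5, 2), 1.0);
        assert!(
            ValidationUtils::validate_data_for_transformation(&constant.view(), "pca").is_ok()
        );
        assert!(
            ValidationUtils::validate_data_for_transformation(&constant.view(), "standardization")
                .is_err()
        );
    }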

    #[test]
    fn test_performance_utils() {
        let memory = PerfUtils::estimate_memory_usage(&[1000, 100], &[1000, 50], "pca");
        assert!(memory > 0);

        let time = PerfUtils::estimate_computation_time(1000, 100, "pca");
        assert!(time.as_nanos() > 0);

        // 20_000 samples x 200 features fits comfortably in 1000 MB, so the
        // parallel in-memory strategy is expected here.
        let strategy = PerfUtils::choose_processing_strategy(20_000, 200, 1000);
        assert!(matches!(strategy, ProcessingStrategy::Parallel));
    }
}