use scirs2_core::ndarray::{
    par_azip, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix2, Zip,
};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::parallel_ops::*;
use scirs2_core::validation::check_not_empty;
use std::collections::HashMap;

use crate::error::{Result, TransformError};
use statrs::statistics::Statistics;
/// Splits large datasets into row chunks that fit within a memory budget.
#[derive(Debug, Clone)]
pub struct DataChunker {
    /// Maximum memory budget in megabytes.
    max_memory_mb: usize,
    /// Preferred number of samples per chunk.
    preferred_chunk_size: usize,
    /// Minimum number of samples per chunk.
    min_chunk_size: usize,
}

impl DataChunker {
    /// Creates a chunker with the given memory budget in megabytes.
    pub fn new(max_memory_mb: usize) -> Self {
        DataChunker {
            max_memory_mb,
            preferred_chunk_size: 10000,
            min_chunk_size: 100,
        }
    }

    /// Calculates a chunk size (in samples) that respects the memory budget.
    pub fn calculate_chunk_size(&self, n_samples: usize, nfeatures: usize) -> usize {
        // Estimate bytes per sample: one f64 per feature plus per-row overhead.
        let bytes_per_sample = nfeatures * std::mem::size_of::<f64>() + 64;
        let max_samples_in_memory = (self.max_memory_mb * 1024 * 1024) / bytes_per_sample;

        max_samples_in_memory
            .min(self.preferred_chunk_size)
            .max(self.min_chunk_size)
            .min(n_samples)
    }

    /// Returns an iterator over `(start, end)` row ranges covering `n_samples`.
    pub fn chunk_indices(&self, n_samples: usize, nfeatures: usize) -> ChunkIterator {
        let chunk_size = self.calculate_chunk_size(n_samples, nfeatures);
        ChunkIterator {
            current: 0,
            total: n_samples,
            chunk_size,
        }
    }
}
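
// Illustrative sketch (hypothetical helper, not part of the original API):
// process a matrix in row chunks produced by `DataChunker`, accumulating a
// running sum per chunk.
#[allow(dead_code)]
fn example_chunked_sum(data: &ArrayView2<f64>, chunker: &DataChunker) -> f64 {
    let mut total = 0.0;
    for (start, end) in chunker.chunk_indices(data.nrows(), data.ncols()) {
        // Each (start, end) pair is a half-open range of row indices.
        for i in start..end {
            total += data.row(i).sum();
        }
    }
    total
}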

/// Iterator over `(start, end)` row-index ranges produced by [`DataChunker`].
#[derive(Debug)]
pub struct ChunkIterator {
    current: usize,
    total: usize,
    chunk_size: usize,
}

impl Iterator for ChunkIterator {
    type Item = (usize, usize);

    fn next(&mut self) -> Option<Self::Item> {
        if self.current >= self.total {
            return None;
        }

        let start = self.current;
        let end = (self.current + self.chunk_size).min(self.total);
        self.current = end;

        Some((start, end))
    }
}
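
// Worked example of the iteration (for exposition): with `total = 10` and
// `chunk_size = 4`, the iterator yields the half-open ranges (0, 4), (4, 8),
// and (8, 10), so every sample index is covered exactly once.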

/// Utilities for converting arrays between numeric element types.
pub struct TypeConverter;

impl TypeConverter {
    /// Converts a 2-D array of any float type to `f64`, validating the result.
    pub fn to_f64<T, S>(array: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        T: Float + NumCast + Send + Sync,
        S: Data<Elem = T>,
    {
        check_not_empty(array, "array")?;

        let result = if array.is_standard_layout() {
            if array.len() > 10000 {
                // Large contiguous arrays: convert in parallel.
                let mut result = Array2::zeros(array.raw_dim());
                Zip::from(&mut result).and(array).par_for_each(|out, &inp| {
                    *out = NumCast::from(inp).unwrap_or(0.0);
                });
                result
            } else {
                array.mapv(|x| NumCast::from(x).unwrap_or(0.0))
            }
        } else {
            // Non-contiguous layout: fall back to element-wise parallel zip.
            let shape = array.shape();
            let mut result = Array2::zeros((shape[0], shape[1]));

            par_azip!((out in result.view_mut(), &inp in array) {
                *out = NumCast::from(inp).unwrap_or(0.0);
            });

            result
        };

        // Reject conversions that produced NaN or infinite values.
        for &val in result.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Array contains non-finite values after conversion".to_string(),
                ));
            }
        }
        Ok(result)
    }

    /// Converts `f32` data to `f64`, using parallel element-wise conversion
    /// for large arrays, and validates that the result is finite.
    pub fn f32_to_f64_simd(array: &ArrayView2<f32>) -> Result<Array2<f64>> {
        check_not_empty(array, "array")?;

        let result = if array.len() > 10000 {
            let mut result = Array2::zeros(array.raw_dim());
            Zip::from(&mut result).and(array).par_for_each(|out, &inp| {
                *out = inp as f64;
            });
            result
        } else {
            array.mapv(|x| x as f64)
        };

        for &val in result.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Array contains non-finite values after conversion".to_string(),
                ));
            }
        }
        Ok(result)
    }

    /// Converts `f64` data to `f32`, rejecting non-finite values and values
    /// whose magnitude exceeds `f32::MAX`.
    pub fn f64_to_f32_safe(array: &ArrayView2<f64>) -> Result<Array2<f32>> {
        check_not_empty(array, "array")?;

        for &val in array.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Array contains non-finite values".to_string(),
                ));
            }
        }

        let mut result = Array2::zeros(array.raw_dim());
        for (out, &inp) in result.iter_mut().zip(array.iter()) {
            if inp.abs() > f32::MAX as f64 {
                return Err(TransformError::DataValidationError(
                    "Value too large for f32 conversion".to_string(),
                ));
            }
            *out = inp as f32;
        }

        Ok(result)
    }
}
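
// Illustrative sketch (hypothetical helper, not part of the original API):
// widen an `f32` matrix to `f64` for processing and narrow it back with the
// checked `f64_to_f32_safe` conversion, propagating validation errors.
#[allow(dead_code)]
fn example_roundtrip_conversion(data: &ArrayView2<f32>) -> Result<Array2<f32>> {
    let widened: Array2<f64> = TypeConverter::f32_to_f64_simd(data)?;
    // ... perform f64 computations on `widened` here ...
    TypeConverter::f64_to_f32_safe(&widened.view())
}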

/// Robust statistics helpers (median / MAD based).
pub struct StatUtils;

impl StatUtils {
    /// Returns the median and the median absolute deviation (MAD) of `data`.
    pub fn robust_stats(data: &ArrayView1<f64>) -> Result<(f64, f64)> {
        check_not_empty(data, "data")?;

        for &val in data.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Data contains non-finite values".to_string(),
                ));
            }
        }

        let mut sorted_data = data.to_vec();
        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let n = sorted_data.len();
        let median = if n.is_multiple_of(2) {
            (sorted_data[n / 2 - 1] + sorted_data[n / 2]) / 2.0
        } else {
            sorted_data[n / 2]
        };

        // MAD: median of the absolute deviations from the median.
        let mut deviations: Vec<f64> = sorted_data.iter().map(|&x| (x - median).abs()).collect();
        deviations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let mad = if n.is_multiple_of(2) {
            (deviations[n / 2 - 1] + deviations[n / 2]) / 2.0
        } else {
            deviations[n / 2]
        };

        Ok((median, mad))
    }

    /// Computes per-column medians and MADs in parallel.
    pub fn robust_stats_columns(data: &ArrayView2<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
        check_not_empty(data, "data")?;

        for &val in data.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Data contains non-finite values".to_string(),
                ));
            }
        }

        let nfeatures = data.ncols();
        let mut medians = Array1::zeros(nfeatures);
        let mut mads = Array1::zeros(nfeatures);

        // Compute robust statistics for each column in parallel.
        let stats: Result<Vec<_>> = (0..nfeatures)
            .into_par_iter()
            .map(|j| {
                let col = data.column(j);
                Self::robust_stats(&col)
            })
            .collect();

        let stats = stats?;

        for (j, (median, mad)) in stats.into_iter().enumerate() {
            medians[j] = median;
            mads[j] = mad;
        }

        Ok((medians, mads))
    }

    /// Flags outliers using the interquartile-range (IQR) rule: values outside
    /// `[Q1 - factor * IQR, Q3 + factor * IQR]` are marked `true`.
    pub fn detect_outliers_iqr(data: &ArrayView1<f64>, factor: f64) -> Result<Vec<bool>> {
        check_not_empty(data, "data")?;

        for &val in data.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Data contains non-finite values".to_string(),
                ));
            }
        }

        if factor <= 0.0 {
            return Err(TransformError::InvalidInput(
                "Outlier factor must be positive".to_string(),
            ));
        }

        let mut sorted_data = data.to_vec();
        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        // Approximate quartiles by index; adequate for outlier screening.
        let n = sorted_data.len();
        let q1_idx = n / 4;
        let q3_idx = 3 * n / 4;

        let q1 = sorted_data[q1_idx];
        let q3 = sorted_data[q3_idx];
        let iqr = q3 - q1;

        let lower_bound = q1 - factor * iqr;
        let upper_bound = q3 + factor * iqr;

        let outliers = data
            .iter()
            .map(|&x| x < lower_bound || x > upper_bound)
            .collect();

        Ok(outliers)
    }

    /// Scores data quality in `[0, 1]` from the fraction of finite values
    /// (weight 0.7) and the average per-column value diversity (weight 0.3).
    pub fn data_quality_score(data: &ArrayView2<f64>) -> Result<f64> {
        check_not_empty(data, "data")?;

        let total_elements = data.len() as f64;

        let finite_count = data.iter().filter(|&&x| x.is_finite()).count() as f64;
        let finite_ratio = finite_count / total_elements;

        let nfeatures = data.ncols();
        let mut diversity_scores = Vec::with_capacity(nfeatures);

        for j in 0..nfeatures {
            let col = data.column(j);
            let mut unique_values = std::collections::HashSet::new();
            for &val in col.iter() {
                if val.is_finite() {
                    // Round to ~12 decimal digits so nearly-equal values count once.
                    let rounded = (val * 1e12).round() as i64;
                    unique_values.insert(rounded);
                }
            }

            let diversity = if !col.is_empty() {
                unique_values.len() as f64 / col.len() as f64
            } else {
                0.0
            };
            diversity_scores.push(diversity);
        }

        let avg_diversity = if diversity_scores.is_empty() {
            0.0
        } else {
            diversity_scores.iter().sum::<f64>() / diversity_scores.len() as f64
        };

        let quality_score = 0.7 * finite_ratio + 0.3 * avg_diversity;

        Ok(quality_score.clamp(0.0, 1.0))
    }
}
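
// Illustrative sketch (hypothetical helper, not part of the original API):
// robust z-scores from the median/MAD pair returned by `robust_stats`. The
// factor 1.4826 is the usual consistency constant that makes the MAD estimate
// the standard deviation for normally distributed data.
#[allow(dead_code)]
fn example_robust_z_scores(data: &ArrayView1<f64>) -> Result<Array1<f64>> {
    let (median, mad) = StatUtils::robust_stats(data)?;
    let scale = (1.4826 * mad).max(f64::EPSILON); // guard against zero MAD
    Ok(data.mapv(|x| (x - median) / scale))
}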

/// Simple pool that reuses `Array2` allocations of identical shape.
pub struct ArrayMemoryPool<T> {
    /// Reusable arrays keyed by `(rows, cols)`.
    available_arrays: HashMap<(usize, usize), Vec<Array2<T>>>,
    /// Maximum number of pooled arrays per shape.
    max_persize: usize,
    /// Total pool memory limit in bytes.
    memory_limit: usize,
    /// Current pooled memory in bytes.
    current_memory: usize,
}

impl<T: Clone + Default> ArrayMemoryPool<T> {
    /// Creates a pool with a memory limit (in megabytes) and a per-shape cap.
    pub fn new(memory_limit_mb: usize, max_persize: usize) -> Self {
        ArrayMemoryPool {
            available_arrays: HashMap::new(),
            max_persize,
            memory_limit: memory_limit_mb * 1024 * 1024,
            current_memory: 0,
        }
    }

    /// Returns a pooled array of the requested shape, or allocates a new one.
    pub fn get_array(&mut self, rows: usize, cols: usize) -> Array2<T> {
        let size_key = (rows, cols);

        if let Some(arrays) = self.available_arrays.get_mut(&size_key) {
            if let Some(array) = arrays.pop() {
                let array_size = rows * cols * std::mem::size_of::<T>();
                self.current_memory = self.current_memory.saturating_sub(array_size);
                return array;
            }
        }

        Array2::default((rows, cols))
    }

    /// Returns an array to the pool, zeroing it first; the array is dropped if
    /// the pool is full or the memory limit would be exceeded.
    pub fn return_array(&mut self, mut array: Array2<T>) {
        let (rows, cols) = array.dim();
        let size_key = (rows, cols);
        let array_size = rows * cols * std::mem::size_of::<T>();

        if self.current_memory + array_size > self.memory_limit {
            return;
        }

        // Reset contents so stale data never leaks into the next user.
        array.fill(T::default());

        let arrays = self.available_arrays.entry(size_key).or_default();
        if arrays.len() < self.max_persize {
            arrays.push(array);
            self.current_memory += array_size;
        }
    }

    /// Drops all pooled arrays.
    pub fn clear(&mut self) {
        self.available_arrays.clear();
        self.current_memory = 0;
    }

    /// Current pooled memory in megabytes.
    pub fn memory_usage_mb(&self) -> f64 {
        self.current_memory as f64 / (1024.0 * 1024.0)
    }
}
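
// Illustrative sketch (hypothetical helper, not part of the original API):
// borrow a scratch buffer from the pool, use it, and hand it back so later
// calls of the same shape avoid a fresh allocation.
#[allow(dead_code)]
fn example_pooled_scratch(pool: &mut ArrayMemoryPool<f64>, rows: usize, cols: usize) {
    let mut scratch = pool.get_array(rows, cols);
    scratch.fill(1.0); // ... use the buffer as temporary workspace ...
    pool.return_array(scratch);
}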

/// Input-validation helpers shared across transformers.
pub struct ValidationUtils;

impl ValidationUtils {
    /// Checks that `value` is finite and lies within `[min, max]`.
    pub fn validate_parameter_bounds(
        value: f64,
        min: f64,
        max: f64,
        param_name: &str,
    ) -> Result<()> {
        if !value.is_finite() {
            return Err(TransformError::InvalidInput(format!(
                "{param_name} must be finite"
            )));
        }

        if value < min || value > max {
            return Err(TransformError::InvalidInput(format!(
                "{param_name} must be between {min} and {max}, got {value}"
            )));
        }

        Ok(())
    }

    /// Checks that two shapes have the same rank and matching dimensions.
    pub fn validate_dimensions_compatible(
        shape1: &[usize],
        shape2: &[usize],
        operation: &str,
    ) -> Result<()> {
        if shape1.len() != shape2.len() {
            return Err(TransformError::InvalidInput(format!(
                "Incompatible dimensions for {operation}: {shape1:?} vs {shape2:?}"
            )));
        }

        for (i, (&dim1, &dim2)) in shape1.iter().zip(shape2.iter()).enumerate() {
            if dim1 != dim2 {
                return Err(TransformError::InvalidInput(format!(
                    "Dimension {i} mismatch for {operation}: {dim1} vs {dim2}"
                )));
            }
        }

        Ok(())
    }

    /// Validates `data` against the requirements of a named transformation
    /// ("pca", "standardization", "normalization"); other names only get the
    /// generic finiteness check.
    pub fn validate_data_for_transformation(
        data: &ArrayView2<f64>,
        transformation: &str,
    ) -> Result<()> {
        check_not_empty(data, "data")?;

        for &val in data.iter() {
            if !val.is_finite() {
                return Err(crate::error::TransformError::DataValidationError(
                    "Data contains non-finite values".to_string(),
                ));
            }
        }

        let (n_samples, nfeatures) = data.dim();

        match transformation {
            "pca" => {
                if n_samples < 2 {
                    return Err(TransformError::InvalidInput(
                        "PCA requires at least 2 samples".to_string(),
                    ));
                }
                if nfeatures < 1 {
                    return Err(TransformError::InvalidInput(
                        "PCA requires at least 1 feature".to_string(),
                    ));
                }
            }
            "standardization" => {
                // Standardization divides by the per-feature standard deviation,
                // so zero-variance features are rejected.
                for j in 0..nfeatures {
                    let col = data.column(j);
                    let variance = col.variance();
                    if variance < 1e-15 {
                        return Err(TransformError::DataValidationError(format!(
                            "Feature {j} has zero variance and cannot be standardized"
                        )));
                    }
                }
            }
            "normalization" => {
                // Row normalization divides by the L2 norm, so zero rows are rejected.
                for i in 0..n_samples {
                    let row = data.row(i);
                    let norm = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
                    if norm < 1e-15 {
                        return Err(TransformError::DataValidationError(format!(
                            "Sample {i} has zero norm and cannot be normalized"
                        )));
                    }
                }
            }
            _ => {
                // No transformation-specific checks for other names.
            }
        }

        Ok(())
    }
}
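
// Illustrative sketch (hypothetical helper, not part of the original API):
// run the generic and transformation-specific checks before fitting, so a
// caller gets a descriptive error instead of a failure deep inside the math.
#[allow(dead_code)]
fn example_validate_before_standardization(data: &ArrayView2<f64>) -> Result<()> {
    ValidationUtils::validate_data_for_transformation(data, "standardization")?;
    // ... safe to fit a standardization transform on `data` here ...
    Ok(())
}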

/// Heuristics for estimating resource usage and picking a processing strategy.
pub struct PerfUtils;

impl PerfUtils {
    /// Estimates peak memory (in bytes) for a transformation, including an
    /// operation-dependent working-set overhead.
    pub fn estimate_memory_usage(
        inputshape: &[usize],
        outputshape: &[usize],
        operation: &str,
    ) -> usize {
        let input_size = inputshape.iter().product::<usize>() * std::mem::size_of::<f64>();
        let output_size = outputshape.iter().product::<usize>() * std::mem::size_of::<f64>();

        let overhead = match operation {
            "pca" => input_size * 2,
            "standardization" => input_size / 10,
            "polynomial" => output_size / 2,
            _ => input_size / 4,
        };

        input_size + output_size + overhead
    }

    /// Rough estimate of computation time based on input size and operation.
    pub fn estimate_computation_time(
        n_samples: usize,
        nfeatures: usize,
        operation: &str,
    ) -> std::time::Duration {
        use std::time::Duration;

        let base_time_ns = match operation {
            "pca" => (n_samples as u64) * (nfeatures as u64).pow(2) / 1000,
            "standardization" => (n_samples as u64) * (nfeatures as u64) / 100,
            "normalization" => (n_samples as u64) * (nfeatures as u64) / 50,
            "polynomial" => (n_samples as u64) * (nfeatures as u64).pow(3) / 10000,
            _ => (n_samples as u64) * (nfeatures as u64) / 100,
        };

        Duration::from_nanos(base_time_ns.max(1000))
    }

    /// Picks a processing strategy from the data size and available memory.
    pub fn choose_processing_strategy(
        n_samples: usize,
        nfeatures: usize,
        available_memory_mb: usize,
    ) -> ProcessingStrategy {
        let estimated_memory_mb =
            (n_samples * nfeatures * std::mem::size_of::<f64>()) / (1024 * 1024);

        if estimated_memory_mb > available_memory_mb {
            ProcessingStrategy::OutOfCore {
                chunk_size: (available_memory_mb * 1024 * 1024)
                    / (nfeatures * std::mem::size_of::<f64>()),
            }
        } else if n_samples > 10000 && nfeatures > 100 {
            ProcessingStrategy::Parallel
        } else if nfeatures > 1000 {
            ProcessingStrategy::Simd
        } else {
            ProcessingStrategy::Standard
        }
    }
}
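
// Illustrative sketch (hypothetical helper, not part of the original API):
// estimate the footprint of a standardization pass and report it in MB so a
// caller can compare it against an available-memory budget.
#[allow(dead_code)]
fn example_estimate_standardization_mb(n_samples: usize, nfeatures: usize) -> f64 {
    let bytes = PerfUtils::estimate_memory_usage(
        &[n_samples, nfeatures],
        &[n_samples, nfeatures],
        "standardization",
    );
    bytes as f64 / (1024.0 * 1024.0)
}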

/// Strategy selected for executing a transformation.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "distributed", derive(serde::Serialize, serde::Deserialize))]
pub enum ProcessingStrategy {
    /// Sequential, in-memory processing.
    Standard,
    /// Vectorized processing for wide feature sets.
    Simd,
    /// Multi-threaded processing for large sample counts.
    Parallel,
    /// Chunked, out-of-core processing for data that exceeds available memory.
    OutOfCore {
        /// Number of samples per chunk.
        chunk_size: usize,
    },
}
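
// Illustrative sketch (hypothetical dispatch, not part of the original API):
// how a caller might branch on the chosen strategy. The paths described in
// the comments are placeholders, not functions defined in this module.
#[allow(dead_code)]
fn example_dispatch(strategy: &ProcessingStrategy) {
    match strategy {
        ProcessingStrategy::Standard => { /* run the plain sequential path */ }
        ProcessingStrategy::Simd => { /* run a vectorized kernel */ }
        ProcessingStrategy::Parallel => { /* split work across threads */ }
        ProcessingStrategy::OutOfCore { chunk_size } => {
            // Process `chunk_size` samples at a time, e.g. via `DataChunker`.
            let _ = chunk_size;
        }
    }
}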

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_data_chunker() {
        let chunker = DataChunker::new(100); // 100 MB budget
        let chunk_size = chunker.calculate_chunk_size(50000, 100);
        assert!(chunk_size > 0);
        assert!(chunk_size <= 50000);
    }

    #[test]
    fn test_chunk_iterator() {
        let chunker = DataChunker::new(1); // 1 MB budget forces small chunks
        let chunks: Vec<_> = chunker.chunk_indices(1000, 10).collect();
        assert!(!chunks.is_empty());

        let total_covered = chunks.iter().map(|(start, end)| end - start).sum::<usize>();
        assert_eq!(total_covered, 1000);
    }

    #[test]
    fn test_type_converter() {
        let data = Array2::<f32>::ones((10, 5));
        let result = TypeConverter::f32_to_f64_simd(&data.view()).unwrap();
        assert_eq!(result.shape(), &[10, 5]);
        assert!((result[(0, 0)] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_robust_stats() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]);
        let (median, mad) = StatUtils::robust_stats(&data.view()).unwrap();
        assert!((median - 3.5).abs() < 1e-10);
        assert!(mad > 0.0);
    }

    #[test]
    fn test_outlier_detection() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]);
        let outliers = StatUtils::detect_outliers_iqr(&data.view(), 1.5).unwrap();
        assert_eq!(outliers.len(), 6);
        assert!(outliers[5]); // 100.0 lies far outside the IQR bounds
    }

    #[test]
    fn test_data_quality_score() {
        let good_data =
            Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
        let quality = StatUtils::data_quality_score(&good_data.view()).unwrap();
        assert!(quality > 0.5);

        let bad_data = Array2::from_elem((10, 3), f64::NAN);
        let quality = StatUtils::data_quality_score(&bad_data.view()).unwrap();
        assert!(quality < 0.5);
    }

    #[test]
    fn test_memory_pool() {
        let mut pool = ArrayMemoryPool::<f64>::new(10, 2);

        let array1 = pool.get_array(10, 5);
        assert_eq!(array1.shape(), &[10, 5]);

        pool.return_array(array1);

        let array2 = pool.get_array(10, 5);
        assert_eq!(array2.shape(), &[10, 5]);
    }

    #[test]
    fn test_validation_utils() {
        assert!(ValidationUtils::validate_parameter_bounds(0.5, 0.0, 1.0, "test").is_ok());
        assert!(ValidationUtils::validate_parameter_bounds(1.5, 0.0, 1.0, "test").is_err());

        assert!(
            ValidationUtils::validate_dimensions_compatible(&[10, 5], &[10, 5], "test").is_ok()
        );
        assert!(
            ValidationUtils::validate_dimensions_compatible(&[10, 5], &[10, 6], "test").is_err()
        );
    }

    #[test]
    fn test_performance_utils() {
        let memory = PerfUtils::estimate_memory_usage(&[1000, 100], &[1000, 50], "pca");
        assert!(memory > 0);

        let time = PerfUtils::estimate_computation_time(1000, 100, "pca");
        assert!(time.as_nanos() > 0);

        // Inputs exceed the parallel thresholds (> 10000 samples, > 100 features).
        let strategy = PerfUtils::choose_processing_strategy(20000, 200, 100);
        assert!(matches!(strategy, ProcessingStrategy::Parallel));
    }
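
    // Added test (a sketch based on the checks implemented above): a constant
    // feature has zero variance, so standardization validation must reject it,
    // while varying data passes.
    #[test]
    fn test_validate_data_for_standardization() {
        let constant = Array2::from_elem((10, 2), 1.0);
        assert!(ValidationUtils::validate_data_for_transformation(
            &constant.view(),
            "standardization"
        )
        .is_err());

        let varying =
            Array2::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect()).unwrap();
        assert!(ValidationUtils::validate_data_for_transformation(
            &varying.view(),
            "standardization"
        )
        .is_ok());
    }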
}