scirs2_transform/
utils.rs

1//! Utility functions and helpers for data transformation
2//!
3//! This module provides common utility functions that are frequently needed
4//! for data transformation tasks, including data validation, memory optimization,
5//! and performance helpers.
6
7use ndarray::{par_azip, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix2, Zip};
8use num_traits::{Float, NumCast};
9use scirs2_core::parallel_ops::*;
10use scirs2_core::validation::check_not_empty;
11use std::collections::HashMap;
12
13use crate::error::{Result, TransformError};
14use statrs::statistics::Statistics;
15
/// Memory-efficient data chunking for large-scale transformations
///
/// Splits a sample axis of length `n_samples` into contiguous, half-open
/// `[start, end)` ranges whose length is bounded by a memory budget and by a
/// preferred / minimum chunk size.
#[derive(Debug, Clone)]
pub struct DataChunker {
    /// Maximum memory usage in MB
    _max_memorymb: usize,
    /// Preferred chunk size in number of samples
    preferred_chunk_size: usize,
    /// Minimum chunk size to maintain efficiency
    min_chunk_size: usize,
}

impl DataChunker {
    /// Create a new data chunker with memory constraints
    ///
    /// `_max_memorymb` is the memory budget in MiB used by
    /// [`DataChunker::calculate_chunk_size`]. The preferred chunk size is
    /// 10 000 samples and the minimum chunk size is 100 samples.
    pub fn new(_max_memorymb: usize) -> Self {
        DataChunker {
            _max_memorymb,
            preferred_chunk_size: 10000,
            min_chunk_size: 100,
        }
    }

    /// Calculate optimal chunk size for given data dimensions
    ///
    /// Each sample is estimated to cost `nfeatures` `f64` values plus 64 bytes
    /// of overhead. The number of samples fitting in the budget is capped at
    /// the preferred chunk size, raised to at least the minimum chunk size,
    /// and finally capped at `n_samples`. The result is therefore `0` only when
    /// `n_samples == 0`.
    ///
    /// All arithmetic saturates, so very large budgets or feature counts do
    /// not overflow.
    pub fn calculate_chunk_size(&self, n_samples: usize, nfeatures: usize) -> usize {
        // The +64 overhead keeps the divisor non-zero.
        let bytes_per_sample = nfeatures
            .saturating_mul(std::mem::size_of::<f64>())
            .saturating_add(64);
        let budget_bytes = self._max_memorymb.saturating_mul(1024 * 1024);
        let max_samples_in_memory = budget_bytes / bytes_per_sample;

        max_samples_in_memory
            .min(self.preferred_chunk_size)
            .max(self.min_chunk_size)
            .min(n_samples)
    }

    /// Iterator over data chunks
    ///
    /// Yields `(start_idx, end_idx)` half-open ranges that cover `0..n_samples`
    /// exactly once, in ascending order. Each range is at most
    /// `calculate_chunk_size(n_samples, nfeatures)` long.
    pub fn chunk_indices(&self, n_samples: usize, nfeatures: usize) -> ChunkIterator {
        let chunk_size = self.calculate_chunk_size(n_samples, nfeatures);
        ChunkIterator {
            current: 0,
            total: n_samples,
            chunk_size,
        }
    }
}

/// Iterator for data chunk indices
///
/// Produced by [`DataChunker::chunk_indices`]. It yields half-open
/// `(start_idx, end_idx)` ranges and knows exactly how many remain.
#[derive(Debug)]
pub struct ChunkIterator {
    current: usize,
    total: usize,
    chunk_size: usize,
}

impl ChunkIterator {
    /// Number of chunks not yet yielded (ceil of remaining / chunk_size).
    fn remaining_chunks(&self) -> usize {
        let remaining = self.total.saturating_sub(self.current);
        if remaining == 0 || self.chunk_size == 0 {
            0
        } else {
            remaining / self.chunk_size + usize::from(remaining % self.chunk_size != 0)
        }
    }
}

impl Iterator for ChunkIterator {
    type Item = (usize, usize); // (start_idx, end_idx)

    fn next(&mut self) -> Option<Self::Item> {
        // A zero chunk size could never advance `current`. Treat it as
        // exhausted rather than yielding empty ranges forever.
        if self.current >= self.total || self.chunk_size == 0 {
            return None;
        }

        let start = self.current;
        let end = self.current.saturating_add(self.chunk_size).min(self.total);
        self.current = end;

        Some((start, end))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let n = self.remaining_chunks();
        (n, Some(n))
    }
}

impl ExactSizeIterator for ChunkIterator {}

impl std::iter::FusedIterator for ChunkIterator {}
83
84/// Fast data type conversion utilities
85pub struct TypeConverter;
86
87impl TypeConverter {
88    /// Convert array to f64 with optimized SIMD operations where possible
89    pub fn to_f64<T, S>(array: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
90    where
91        T: Float + NumCast + Send + Sync,
92        S: Data<Elem = T>,
93    {
94        check_not_empty(array, "array")?;
95
96        let result = if array.is_standard_layout() {
97            // Use parallel processing for large arrays
98            if array.len() > 10000 {
99                let mut result = Array2::zeros(array.raw_dim());
100                Zip::from(&mut result).and(array).par_for_each(|out, &inp| {
101                    *out = num_traits::cast::<T, f64>(inp).unwrap_or(0.0);
102                });
103                result
104            } else {
105                array.mapv(|x| num_traits::cast::<T, f64>(x).unwrap_or(0.0))
106            }
107        } else {
108            // Handle non-standard layout
109            let shape = array.shape();
110            let mut result = Array2::zeros((shape[0], shape[1]));
111
112            par_azip!((out in result.view_mut(), &inp in array) {
113                *out = num_traits::cast::<T, f64>(inp).unwrap_or(0.0);
114            });
115
116            result
117        };
118
119        // Validate result for non-finite values
120        for &val in result.iter() {
121            if !val.is_finite() {
122                return Err(crate::error::TransformError::DataValidationError(
123                    "Array contains non-finite values after conversion".to_string(),
124                ));
125            }
126        }
127        Ok(result)
128    }
129
130    /// Convert f32 array to f64 with SIMD optimization
131    pub fn f32_to_f64_simd(array: &ArrayView2<f32>) -> Result<Array2<f64>> {
132        check_not_empty(array, "array")?;
133
134        let result = if array.len() > 10000 {
135            let mut result = Array2::zeros(array.raw_dim());
136            Zip::from(&mut result).and(array).par_for_each(|out, &inp| {
137                *out = inp as f64;
138            });
139            result
140        } else {
141            array.mapv(|x| x as f64)
142        };
143
144        for &val in result.iter() {
145            if !val.is_finite() {
146                return Err(crate::error::TransformError::DataValidationError(
147                    "Array contains non-finite values after conversion".to_string(),
148                ));
149            }
150        }
151        Ok(result)
152    }
153
154    /// Convert f64 array to f32 with overflow checking
155    pub fn f64_to_f32_safe(array: &ArrayView2<f64>) -> Result<Array2<f32>> {
156        check_not_empty(array, "array")?;
157
158        // Check finite values
159        for &val in array.iter() {
160            if !val.is_finite() {
161                return Err(crate::error::TransformError::DataValidationError(
162                    "Array contains non-finite values".to_string(),
163                ));
164            }
165        }
166
167        let mut result = Array2::zeros(array.raw_dim());
168        for (out, &inp) in result.iter_mut().zip(array.iter()) {
169            if inp.abs() > f32::MAX as f64 {
170                return Err(TransformError::DataValidationError(
171                    "Value too large for f32 conversion".to_string(),
172                ));
173            }
174            *out = inp as f32;
175        }
176
177        Ok(result)
178    }
179}
180
181/// Statistical utilities for transformation validation
182pub struct StatUtils;
183
184impl StatUtils {
185    /// Calculate robust statistics (median, MAD) efficiently
186    pub fn robust_stats(data: &ArrayView1<f64>) -> Result<(f64, f64)> {
187        check_not_empty(data, "data")?;
188
189        // Check finite values
190        for &val in data.iter() {
191            if !val.is_finite() {
192                return Err(crate::error::TransformError::DataValidationError(
193                    "Data contains non-finite values".to_string(),
194                ));
195            }
196        }
197
198        let mut sorted_data = data.to_vec();
199        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
200
201        let n = sorted_data.len();
202        let median = if n % 2 == 0 {
203            (sorted_data[n / 2 - 1] + sorted_data[n / 2]) / 2.0
204        } else {
205            sorted_data[n / 2]
206        };
207
208        // Calculate MAD (Median Absolute Deviation)
209        let mut deviations: Vec<f64> = sorted_data.iter().map(|&x| (x - median).abs()).collect();
210        deviations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
211
212        let mad = if n % 2 == 0 {
213            (deviations[n / 2 - 1] + deviations[n / 2]) / 2.0
214        } else {
215            deviations[n / 2]
216        };
217
218        Ok((median, mad))
219    }
220
221    /// Calculate column-wise robust statistics in parallel
222    pub fn robust_stats_columns(data: &ArrayView2<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
223        check_not_empty(data, "data")?;
224
225        // Check finite values
226        for &val in data.iter() {
227            if !val.is_finite() {
228                return Err(crate::error::TransformError::DataValidationError(
229                    "Data contains non-finite values".to_string(),
230                ));
231            }
232        }
233
234        let nfeatures = data.ncols();
235        let mut medians = Array1::zeros(nfeatures);
236        let mut mads = Array1::zeros(nfeatures);
237
238        // Use parallel processing for multiple columns
239        let stats: Result<Vec<_>> = (0..nfeatures)
240            .into_par_iter()
241            .map(|j| {
242                let col = data.column(j);
243                Self::robust_stats(&col)
244            })
245            .collect();
246
247        let stats = stats?;
248
249        for (j, (median, mad)) in stats.into_iter().enumerate() {
250            medians[j] = median;
251            mads[j] = mad;
252        }
253
254        Ok((medians, mads))
255    }
256
257    /// Detect outliers using IQR method
258    pub fn detect_outliers_iqr(data: &ArrayView1<f64>, factor: f64) -> Result<Vec<bool>> {
259        check_not_empty(data, "data")?;
260
261        // Check finite values
262        for &val in data.iter() {
263            if !val.is_finite() {
264                return Err(crate::error::TransformError::DataValidationError(
265                    "Data contains non-finite values".to_string(),
266                ));
267            }
268        }
269
270        if factor <= 0.0 {
271            return Err(TransformError::InvalidInput(
272                "Outlier factor must be positive".to_string(),
273            ));
274        }
275
276        let mut sorted_data = data.to_vec();
277        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
278
279        let n = sorted_data.len();
280        let q1_idx = n / 4;
281        let q3_idx = 3 * n / 4;
282
283        let q1 = sorted_data[q1_idx];
284        let q3 = sorted_data[q3_idx];
285        let iqr = q3 - q1;
286
287        let lower_bound = q1 - factor * iqr;
288        let upper_bound = q3 + factor * iqr;
289
290        let outliers = data
291            .iter()
292            .map(|&x| x < lower_bound || x > upper_bound)
293            .collect();
294
295        Ok(outliers)
296    }
297
298    /// Calculate data quality score
299    pub fn data_quality_score(data: &ArrayView2<f64>) -> Result<f64> {
300        check_not_empty(data, "data")?;
301
302        let total_elements = data.len() as f64;
303
304        // Count finite values
305        let finite_count = data.iter().filter(|&&x| x.is_finite()).count() as f64;
306        let finite_ratio = finite_count / total_elements;
307
308        // Count unique values per column (diversity score)
309        let nfeatures = data.ncols();
310        let mut diversity_scores = Vec::with_capacity(nfeatures);
311
312        for j in 0..nfeatures {
313            let col = data.column(j);
314            let mut unique_values = std::collections::HashSet::new();
315            for &val in col.iter() {
316                if val.is_finite() {
317                    // Round to avoid floating point precision issues
318                    let rounded = (val * 1e12).round() as i64;
319                    unique_values.insert(rounded);
320                }
321            }
322
323            let diversity = if !col.is_empty() {
324                unique_values.len() as f64 / col.len() as f64
325            } else {
326                0.0
327            };
328            diversity_scores.push(diversity);
329        }
330
331        let avg_diversity = if diversity_scores.is_empty() {
332            0.0
333        } else {
334            diversity_scores.iter().sum::<f64>() / diversity_scores.len() as f64
335        };
336
337        // Combine scores with weights
338        let quality_score = 0.7 * finite_ratio + 0.3 * avg_diversity;
339
340        Ok(quality_score.clamp(0.0, 1.0))
341    }
342}
343
344/// Memory pool for efficient array allocation and reuse
345pub struct ArrayMemoryPool<T> {
346    /// Available arrays by size
347    available_arrays: HashMap<(usize, usize), Vec<Array2<T>>>,
348    /// Maximum number of arrays to keep per size
349    max_persize: usize,
350    /// Total memory limit in bytes
351    memory_limit: usize,
352    /// Current memory usage
353    current_memory: usize,
354}
355
356impl<T: Clone + Default> ArrayMemoryPool<T> {
357    /// Create a new array memory pool
358    pub fn new(_memory_limit_mb: usize, max_persize: usize) -> Self {
359        ArrayMemoryPool {
360            available_arrays: HashMap::new(),
361            max_persize,
362            memory_limit: _memory_limit_mb * 1024 * 1024,
363            current_memory: 0,
364        }
365    }
366
367    /// Get an array from the pool or create a new one
368    pub fn get_array(&mut self, rows: usize, cols: usize) -> Array2<T> {
369        let size_key = (rows, cols);
370
371        if let Some(arrays) = self.available_arrays.get_mut(&size_key) {
372            if let Some(array) = arrays.pop() {
373                let array_size = rows * cols * std::mem::size_of::<T>();
374                self.current_memory = self.current_memory.saturating_sub(array_size);
375                return array;
376            }
377        }
378
379        // Create new array if none available
380        Array2::default((rows, cols))
381    }
382
383    /// Return an array to the pool for reuse
384    pub fn return_array(&mut self, mut array: Array2<T>) {
385        let (rows, cols) = array.dim();
386        let size_key = (rows, cols);
387        let array_size = rows * cols * std::mem::size_of::<T>();
388
389        // Check memory limits
390        if self.current_memory + array_size > self.memory_limit {
391            return; // Drop the array
392        }
393
394        // Zero out the array for reuse
395        array.fill(T::default());
396
397        let arrays = self.available_arrays.entry(size_key).or_default();
398        if arrays.len() < self.max_persize {
399            arrays.push(array);
400            self.current_memory += array_size;
401        }
402    }
403
404    /// Clear the pool and free memory
405    pub fn clear(&mut self) {
406        self.available_arrays.clear();
407        self.current_memory = 0;
408    }
409
410    /// Get current memory usage in MB
411    pub fn memory_usage_mb(&self) -> f64 {
412        self.current_memory as f64 / (1024.0 * 1024.0)
413    }
414}
415
416/// Validation utilities for transformation parameters
417pub struct ValidationUtils;
418
419impl ValidationUtils {
420    /// Validate that a parameter is within reasonable bounds
421    pub fn validate_parameter_bounds(
422        value: f64,
423        min: f64,
424        max: f64,
425        param_name: &str,
426    ) -> Result<()> {
427        if !value.is_finite() {
428            return Err(TransformError::InvalidInput(format!(
429                "{param_name} must be finite"
430            )));
431        }
432
433        if value < min || value > max {
434            return Err(TransformError::InvalidInput(format!(
435                "{param_name} must be between {min} and {max}, got {value}"
436            )));
437        }
438
439        Ok(())
440    }
441
442    /// Validate array dimensions for compatibility
443    pub fn validate_dimensions_compatible(
444        shape1: &[usize],
445        shape2: &[usize],
446        operation: &str,
447    ) -> Result<()> {
448        if shape1.len() != shape2.len() {
449            return Err(TransformError::InvalidInput(format!(
450                "Incompatible dimensions for {operation}: {shape1:?} vs {shape2:?}"
451            )));
452        }
453
454        for (i, (&dim1, &dim2)) in shape1.iter().zip(shape2.iter()).enumerate() {
455            if dim1 != dim2 {
456                return Err(TransformError::InvalidInput(format!(
457                    "Dimension {i} mismatch for {operation}: {dim1} vs {dim2}"
458                )));
459            }
460        }
461
462        Ok(())
463    }
464
465    /// Validate that data is suitable for a specific transformation
466    pub fn validate_data_for_transformation(
467        data: &ArrayView2<f64>,
468        transformation: &str,
469    ) -> Result<()> {
470        check_not_empty(data, "data")?;
471
472        // Check finite values
473        for &val in data.iter() {
474            if !val.is_finite() {
475                return Err(crate::error::TransformError::DataValidationError(
476                    "Data contains non-finite values".to_string(),
477                ));
478            }
479        }
480
481        let (n_samples, nfeatures) = data.dim();
482
483        match transformation {
484            "pca" => {
485                if n_samples < 2 {
486                    return Err(TransformError::InvalidInput(
487                        "PCA requires at least 2 samples".to_string(),
488                    ));
489                }
490                if nfeatures < 1 {
491                    return Err(TransformError::InvalidInput(
492                        "PCA requires at least 1 feature".to_string(),
493                    ));
494                }
495            }
496            "standardization" => {
497                // Check for constant features
498                for j in 0..nfeatures {
499                    let col = data.column(j);
500                    let variance = col.variance();
501                    if variance < 1e-15 {
502                        return Err(TransformError::DataValidationError(format!(
503                            "Feature {j} has zero variance and cannot be standardized"
504                        )));
505                    }
506                }
507            }
508            "normalization" => {
509                // Check for zero-norm rows
510                for i in 0..n_samples {
511                    let row = data.row(i);
512                    let norm = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
513                    if norm < 1e-15 {
514                        return Err(TransformError::DataValidationError(format!(
515                            "Sample {i} has zero norm and cannot be normalized"
516                        )));
517                    }
518                }
519            }
520            _ => {
521                // Generic validation
522            }
523        }
524
525        Ok(())
526    }
527}
528
/// Performance monitoring utilities
#[derive(Debug, Clone, Copy)]
pub struct PerfUtils;

impl PerfUtils {
    /// Size in bytes of an `f64` array with the given shape, saturating at
    /// `usize::MAX`.
    fn f64_bytes(shape: &[usize]) -> usize {
        shape
            .iter()
            .fold(1usize, |acc, &d| acc.saturating_mul(d))
            .saturating_mul(std::mem::size_of::<f64>())
    }

    /// Estimate memory usage for an operation
    ///
    /// Returns a byte count: input size plus output size (both as `f64`
    /// arrays) plus an operation-specific overhead heuristic.
    /// * `"pca"`: twice the input.
    /// * `"standardization"`: a tenth of the input.
    /// * `"polynomial"`: half the output.
    /// * anything else: a quarter of the input.
    ///
    /// Arithmetic saturates at `usize::MAX`.
    pub fn estimate_memory_usage(
        inputshape: &[usize],
        outputshape: &[usize],
        operation: &str,
    ) -> usize {
        let input_size = Self::f64_bytes(inputshape);
        let output_size = Self::f64_bytes(outputshape);

        let overhead = match operation {
            "pca" => input_size.saturating_mul(2), // Covariance matrix + temporaries
            "standardization" => input_size / 10,  // Just statistics
            "polynomial" => output_size / 2,       // Temporary computations
            _ => input_size / 4,                   // Default overhead
        };

        input_size
            .saturating_add(output_size)
            .saturating_add(overhead)
    }

    /// Estimate computation time based on data size and operation
    ///
    /// This is a rough heuristic from the operation's asymptotic cost in
    /// samples (n) and features (m):
    /// * pca: O(n·m²)
    /// * standardization, normalization, default: O(n·m)
    /// * polynomial: O(n·m³)
    ///
    /// The result is never below 1 µs. Intermediate products saturate at
    /// `u64::MAX`.
    pub fn estimate_computation_time(
        n_samples: usize,
        nfeatures: usize,
        operation: &str,
    ) -> std::time::Duration {
        use std::time::Duration;

        let n = n_samples as u64;
        let m = nfeatures as u64;

        let base_time_ns = match operation {
            "pca" => n.saturating_mul(m.saturating_pow(2)) / 1000,
            "standardization" => n.saturating_mul(m) / 100,
            "normalization" => n.saturating_mul(m) / 50,
            "polynomial" => n.saturating_mul(m.saturating_pow(3)) / 10000,
            _ => n.saturating_mul(m) / 100,
        };

        Duration::from_nanos(base_time_ns.max(1000)) // At least 1 microsecond
    }

    /// Choose optimal processing strategy based on data characteristics
    ///
    /// Data that does not fit in `available_memory_mb` MiB (compared in bytes,
    /// as `f64`s) is processed out of core. In that case the chunk size is the
    /// number of rows that fit, and always at least 1. Otherwise:
    /// * `Parallel` for more than 10 000 samples and more than 100 features;
    /// * `Simd` for more than 1000 features;
    /// * `Standard` for everything else.
    pub fn choose_processing_strategy(
        n_samples: usize,
        nfeatures: usize,
        available_memory_mb: usize,
    ) -> ProcessingStrategy {
        let row_bytes = nfeatures.saturating_mul(std::mem::size_of::<f64>());
        let estimated_bytes = n_samples.saturating_mul(row_bytes);
        let available_bytes = available_memory_mb.saturating_mul(1024 * 1024);

        // Compare bytes rather than floored MiB, so data up to 1 MiB over
        // budget is not mistaken for data that fits.
        if estimated_bytes > available_bytes {
            // estimated_bytes > 0 implies row_bytes > 0, so the division is
            // defined. Clamp to 1 row even when one row exceeds the budget.
            ProcessingStrategy::OutOfCore {
                chunk_size: (available_bytes / row_bytes).max(1),
            }
        } else if n_samples > 10000 && nfeatures > 100 {
            ProcessingStrategy::Parallel
        } else if nfeatures > 1000 {
            ProcessingStrategy::Simd
        } else {
            ProcessingStrategy::Standard
        }
    }
}

/// Processing strategy recommendation
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "distributed", derive(serde::Serialize, serde::Deserialize))]
pub enum ProcessingStrategy {
    /// Standard sequential processing
    Standard,
    /// SIMD-accelerated processing
    Simd,
    /// Parallel processing across multiple cores
    Parallel,
    /// Out-of-core processing for large datasets
    OutOfCore {
        /// Size of data chunks for processing
        chunk_size: usize,
    },
}
611
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_data_chunker() {
        let chunker = DataChunker::new(100); // 100MB
        let chunk_size = chunker.calculate_chunk_size(50000, 100);
        assert!(chunk_size > 0);
        assert!(chunk_size <= 50000);
    }

    #[test]
    fn test_chunk_iterator() {
        // With a 1 MB budget and 10 000 features only ~13 samples fit in
        // memory. The 100-sample minimum therefore applies, and 1000 samples
        // split into several chunks.
        let chunker = DataChunker::new(1);
        let chunks: Vec<_> = chunker.chunk_indices(1000, 10000).collect();
        assert!(chunks.len() > 1);

        // Chunks must be non-empty, contiguous, and cover 0..1000 exactly.
        assert_eq!(chunks.first().map(|&(start, _)| start), Some(0));
        assert_eq!(chunks.last().map(|&(_, end)| end), Some(1000));
        assert!(chunks.iter().all(|&(start, end)| start < end));
        for pair in chunks.windows(2) {
            assert_eq!(pair[0].1, pair[1].0);
        }

        let total_covered = chunks.iter().map(|(start, end)| end - start).sum::<usize>();
        assert_eq!(total_covered, 1000);
    }

    #[test]
    fn test_type_converter() {
        let data = Array2::<f32>::ones((10, 5));
        let result = TypeConverter::f32_to_f64_simd(&data.view()).unwrap();
        assert_eq!(result.shape(), &[10, 5]);
        assert!((result[(0, 0)] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_robust_stats() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]); // With outlier
        let (median, mad) = StatUtils::robust_stats(&data.view()).unwrap();
        assert!((median - 3.5).abs() < 1e-10);
        assert!(mad > 0.0);
    }

    #[test]
    fn test_outlier_detection() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]);
        let outliers = StatUtils::detect_outliers_iqr(&data.view(), 1.5).unwrap();
        assert_eq!(outliers.len(), 6);
        assert!(outliers[5]); // 100.0 should be detected as outlier
    }

    #[test]
    fn test_data_quality_score() {
        let good_data =
            Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
        let quality = StatUtils::data_quality_score(&good_data.view()).unwrap();
        assert!(quality > 0.5); // Should have reasonable quality

        let bad_data = Array2::from_elem((10, 3), f64::NAN);
        let quality = StatUtils::data_quality_score(&bad_data.view()).unwrap();
        assert!(quality < 0.5); // Should have poor quality due to NaN values
    }

    #[test]
    fn test_memory_pool() {
        let mut pool = ArrayMemoryPool::<f64>::new(10, 2);

        // Get and return arrays
        let array1 = pool.get_array(10, 5);
        assert_eq!(array1.shape(), &[10, 5]);

        // A returned array is retained and counted against the budget.
        pool.return_array(array1);
        assert!(pool.memory_usage_mb() > 0.0);

        // Taking it back out releases the accounted memory.
        let array2 = pool.get_array(10, 5);
        assert_eq!(array2.shape(), &[10, 5]);
        assert_eq!(pool.memory_usage_mb(), 0.0);
    }

    #[test]
    fn test_validation_utils() {
        // Test parameter bounds validation
        assert!(ValidationUtils::validate_parameter_bounds(0.5, 0.0, 1.0, "test").is_ok());
        assert!(ValidationUtils::validate_parameter_bounds(1.5, 0.0, 1.0, "test").is_err());

        // Test dimension compatibility
        assert!(
            ValidationUtils::validate_dimensions_compatible(&[10, 5], &[10, 5], "test").is_ok()
        );
        assert!(
            ValidationUtils::validate_dimensions_compatible(&[10, 5], &[10, 6], "test").is_err()
        );
    }

    #[test]
    fn test_performance_utils() {
        let memory = PerfUtils::estimate_memory_usage(&[1000, 100], &[1000, 50], "pca");
        assert!(memory > 0);

        let time = PerfUtils::estimate_computation_time(1000, 100, "pca");
        assert!(time.as_nanos() > 0);

        // 20 000 x 200 f64s (~30 MiB) fit in 1000 MiB and are large enough
        // to parallelise.
        let strategy = PerfUtils::choose_processing_strategy(20000, 200, 1000);
        assert!(matches!(strategy, ProcessingStrategy::Parallel));

        // ~763 MiB of data does not fit in 100 MiB.
        let strategy = PerfUtils::choose_processing_strategy(1_000_000, 100, 100);
        assert!(matches!(
            strategy,
            ProcessingStrategy::OutOfCore { chunk_size } if chunk_size > 0
        ));
    }
}