scirs2_transform/
utils.rs

1//! Utility functions and helpers for data transformation
2//!
3//! This module provides common utility functions that are frequently needed
4//! for data transformation tasks, including data validation, memory optimization,
5//! and performance helpers.
6
7use scirs2_core::ndarray::{
8    par_azip, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Data, Ix2, Zip,
9};
10use scirs2_core::numeric::{Float, NumCast};
11use scirs2_core::parallel_ops::*;
12use scirs2_core::validation::check_not_empty;
13use std::collections::HashMap;
14
15use crate::error::{Result, TransformError};
16use statrs::statistics::Statistics;
17
/// Memory-efficient data chunking for large-scale transformations
///
/// Splits `n_samples` rows into contiguous half-open `(start, end)` ranges whose
/// length is bounded by a memory budget, a preferred chunk size and a minimum
/// chunk size.
#[derive(Debug, Clone)]
pub struct DataChunker {
    /// Maximum memory usage in MB
    _max_memorymb: usize,
    /// Preferred chunk size in number of samples
    preferred_chunk_size: usize,
    /// Minimum chunk size to maintain efficiency
    min_chunk_size: usize,
}

impl DataChunker {
    /// Per-sample bookkeeping overhead (bytes) added on top of the raw `f64` payload.
    const PER_SAMPLE_OVERHEAD_BYTES: usize = 64;

    /// Create a new data chunker with memory constraints
    ///
    /// `_max_memorymb` is the memory budget in megabytes. The preferred chunk size
    /// is 10 000 samples and the minimum chunk size is 100 samples.
    pub fn new(_max_memorymb: usize) -> Self {
        DataChunker {
            _max_memorymb,
            preferred_chunk_size: 10000,
            min_chunk_size: 100,
        }
    }

    /// Calculate optimal chunk size for given data dimensions
    ///
    /// The result is the number of samples that fit in the memory budget (each
    /// sample costs `nfeatures` `f64`s plus a fixed overhead), capped at the
    /// preferred chunk size, raised to at least the minimum chunk size, and finally
    /// capped at `n_samples`. It is `0` only when `n_samples == 0`.
    ///
    /// All arithmetic saturates, so very large budgets or feature counts cannot
    /// overflow (which would panic in debug builds, e.g. on 32-bit targets).
    pub fn calculate_chunk_size(&self, n_samples: usize, nfeatures: usize) -> usize {
        let bytes_per_sample = nfeatures
            .saturating_mul(std::mem::size_of::<f64>())
            .saturating_add(Self::PER_SAMPLE_OVERHEAD_BYTES);
        let budget_bytes = self._max_memorymb.saturating_mul(1024 * 1024);
        // bytes_per_sample >= PER_SAMPLE_OVERHEAD_BYTES > 0, so this cannot divide by zero.
        let max_samples_in_memory = budget_bytes / bytes_per_sample;

        max_samples_in_memory
            .min(self.preferred_chunk_size)
            .max(self.min_chunk_size)
            .min(n_samples)
    }

    /// Iterator over data chunks
    ///
    /// Yields half-open `(start, end)` index ranges that cover `0..n_samples`
    /// exactly once, in ascending order.
    pub fn chunk_indices(&self, n_samples: usize, nfeatures: usize) -> ChunkIterator {
        // A zero step would stop the iterator from ever advancing.
        let chunk_size = self.calculate_chunk_size(n_samples, nfeatures).max(1);
        ChunkIterator {
            current: 0,
            total: n_samples,
            chunk_size,
        }
    }
}

/// Iterator for data chunk indices
///
/// Produced by [`DataChunker::chunk_indices`]; each item is a half-open
/// `(start_idx, end_idx)` range. The iterator knows its exact remaining length.
#[derive(Debug, Clone)]
pub struct ChunkIterator {
    current: usize,
    total: usize,
    chunk_size: usize,
}

impl Iterator for ChunkIterator {
    type Item = (usize, usize); // (start_idx, end_idx)

    fn next(&mut self) -> Option<Self::Item> {
        if self.current >= self.total {
            return None;
        }

        // Never step by zero, even if a zero chunk size slipped through.
        let step = self.chunk_size.max(1);
        let start = self.current;
        let end = start.saturating_add(step).min(self.total);
        self.current = end;

        Some((start, end))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.total.saturating_sub(self.current);
        let chunks = remaining.div_ceil(self.chunk_size.max(1));
        (chunks, Some(chunks))
    }
}

impl ExactSizeIterator for ChunkIterator {}

// Once `current >= total`, `next` keeps returning `None`.
impl std::iter::FusedIterator for ChunkIterator {}
85
86/// Fast data type conversion utilities
87pub struct TypeConverter;
88
89impl TypeConverter {
90    /// Convert array to f64 with optimized SIMD operations where possible
91    pub fn to_f64<T, S>(array: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
92    where
93        T: Float + NumCast + Send + Sync,
94        S: Data<Elem = T>,
95    {
96        check_not_empty(array, "array")?;
97
98        let result = if array.is_standard_layout() {
99            // Use parallel processing for large arrays
100            if array.len() > 10000 {
101                let mut result = Array2::zeros(array.raw_dim());
102                Zip::from(&mut result).and(array).par_for_each(|out, &inp| {
103                    *out = NumCast::from(inp).unwrap_or(0.0);
104                });
105                result
106            } else {
107                array.mapv(|x| NumCast::from(x).unwrap_or(0.0))
108            }
109        } else {
110            // Handle non-standard layout
111            let shape = array.shape();
112            let mut result = Array2::zeros((shape[0], shape[1]));
113
114            par_azip!((out in result.view_mut(), &inp in array) {
115                *out = NumCast::from(inp).unwrap_or(0.0);
116            });
117
118            result
119        };
120
121        // Validate result for non-finite values
122        for &val in result.iter() {
123            if !val.is_finite() {
124                return Err(crate::error::TransformError::DataValidationError(
125                    "Array contains non-finite values after conversion".to_string(),
126                ));
127            }
128        }
129        Ok(result)
130    }
131
132    /// Convert f32 array to f64 with SIMD optimization
133    pub fn f32_to_f64_simd(array: &ArrayView2<f32>) -> Result<Array2<f64>> {
134        check_not_empty(array, "array")?;
135
136        let result = if array.len() > 10000 {
137            let mut result = Array2::zeros(array.raw_dim());
138            Zip::from(&mut result).and(array).par_for_each(|out, &inp| {
139                *out = inp as f64;
140            });
141            result
142        } else {
143            array.mapv(|x| x as f64)
144        };
145
146        for &val in result.iter() {
147            if !val.is_finite() {
148                return Err(crate::error::TransformError::DataValidationError(
149                    "Array contains non-finite values after conversion".to_string(),
150                ));
151            }
152        }
153        Ok(result)
154    }
155
156    /// Convert f64 array to f32 with overflow checking
157    pub fn f64_to_f32_safe(array: &ArrayView2<f64>) -> Result<Array2<f32>> {
158        check_not_empty(array, "array")?;
159
160        // Check finite values
161        for &val in array.iter() {
162            if !val.is_finite() {
163                return Err(crate::error::TransformError::DataValidationError(
164                    "Array contains non-finite values".to_string(),
165                ));
166            }
167        }
168
169        let mut result = Array2::zeros(array.raw_dim());
170        for (out, &inp) in result.iter_mut().zip(array.iter()) {
171            if inp.abs() > f32::MAX as f64 {
172                return Err(TransformError::DataValidationError(
173                    "Value too large for f32 conversion".to_string(),
174                ));
175            }
176            *out = inp as f32;
177        }
178
179        Ok(result)
180    }
181}
182
183/// Statistical utilities for transformation validation
184pub struct StatUtils;
185
186impl StatUtils {
187    /// Calculate robust statistics (median, MAD) efficiently
188    pub fn robust_stats(data: &ArrayView1<f64>) -> Result<(f64, f64)> {
189        check_not_empty(data, "data")?;
190
191        // Check finite values
192        for &val in data.iter() {
193            if !val.is_finite() {
194                return Err(crate::error::TransformError::DataValidationError(
195                    "Data contains non-finite values".to_string(),
196                ));
197            }
198        }
199
200        let mut sorted_data = data.to_vec();
201        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
202
203        let n = sorted_data.len();
204        let median = if n.is_multiple_of(2) {
205            (sorted_data[n / 2 - 1] + sorted_data[n / 2]) / 2.0
206        } else {
207            sorted_data[n / 2]
208        };
209
210        // Calculate MAD (Median Absolute Deviation)
211        let mut deviations: Vec<f64> = sorted_data.iter().map(|&x| (x - median).abs()).collect();
212        deviations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
213
214        let mad = if n.is_multiple_of(2) {
215            (deviations[n / 2 - 1] + deviations[n / 2]) / 2.0
216        } else {
217            deviations[n / 2]
218        };
219
220        Ok((median, mad))
221    }
222
223    /// Calculate column-wise robust statistics in parallel
224    pub fn robust_stats_columns(data: &ArrayView2<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
225        check_not_empty(data, "data")?;
226
227        // Check finite values
228        for &val in data.iter() {
229            if !val.is_finite() {
230                return Err(crate::error::TransformError::DataValidationError(
231                    "Data contains non-finite values".to_string(),
232                ));
233            }
234        }
235
236        let nfeatures = data.ncols();
237        let mut medians = Array1::zeros(nfeatures);
238        let mut mads = Array1::zeros(nfeatures);
239
240        // Use parallel processing for multiple columns
241        let stats: Result<Vec<_>> = (0..nfeatures)
242            .into_par_iter()
243            .map(|j| {
244                let col = data.column(j);
245                Self::robust_stats(&col)
246            })
247            .collect();
248
249        let stats = stats?;
250
251        for (j, (median, mad)) in stats.into_iter().enumerate() {
252            medians[j] = median;
253            mads[j] = mad;
254        }
255
256        Ok((medians, mads))
257    }
258
259    /// Detect outliers using IQR method
260    pub fn detect_outliers_iqr(data: &ArrayView1<f64>, factor: f64) -> Result<Vec<bool>> {
261        check_not_empty(data, "data")?;
262
263        // Check finite values
264        for &val in data.iter() {
265            if !val.is_finite() {
266                return Err(crate::error::TransformError::DataValidationError(
267                    "Data contains non-finite values".to_string(),
268                ));
269            }
270        }
271
272        if factor <= 0.0 {
273            return Err(TransformError::InvalidInput(
274                "Outlier factor must be positive".to_string(),
275            ));
276        }
277
278        let mut sorted_data = data.to_vec();
279        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
280
281        let n = sorted_data.len();
282        let q1_idx = n / 4;
283        let q3_idx = 3 * n / 4;
284
285        let q1 = sorted_data[q1_idx];
286        let q3 = sorted_data[q3_idx];
287        let iqr = q3 - q1;
288
289        let lower_bound = q1 - factor * iqr;
290        let upper_bound = q3 + factor * iqr;
291
292        let outliers = data
293            .iter()
294            .map(|&x| x < lower_bound || x > upper_bound)
295            .collect();
296
297        Ok(outliers)
298    }
299
300    /// Calculate data quality score
301    pub fn data_quality_score(data: &ArrayView2<f64>) -> Result<f64> {
302        check_not_empty(data, "data")?;
303
304        let total_elements = data.len() as f64;
305
306        // Count finite values
307        let finite_count = data.iter().filter(|&&x| x.is_finite()).count() as f64;
308        let finite_ratio = finite_count / total_elements;
309
310        // Count unique values per column (diversity score)
311        let nfeatures = data.ncols();
312        let mut diversity_scores = Vec::with_capacity(nfeatures);
313
314        for j in 0..nfeatures {
315            let col = data.column(j);
316            let mut unique_values = std::collections::HashSet::new();
317            for &val in col.iter() {
318                if val.is_finite() {
319                    // Round to avoid floating point precision issues
320                    let rounded = (val * 1e12).round() as i64;
321                    unique_values.insert(rounded);
322                }
323            }
324
325            let diversity = if !col.is_empty() {
326                unique_values.len() as f64 / col.len() as f64
327            } else {
328                0.0
329            };
330            diversity_scores.push(diversity);
331        }
332
333        let avg_diversity = if diversity_scores.is_empty() {
334            0.0
335        } else {
336            diversity_scores.iter().sum::<f64>() / diversity_scores.len() as f64
337        };
338
339        // Combine scores with weights
340        let quality_score = 0.7 * finite_ratio + 0.3 * avg_diversity;
341
342        Ok(quality_score.clamp(0.0, 1.0))
343    }
344}
345
346/// Memory pool for efficient array allocation and reuse
347pub struct ArrayMemoryPool<T> {
348    /// Available arrays by size
349    available_arrays: HashMap<(usize, usize), Vec<Array2<T>>>,
350    /// Maximum number of arrays to keep per size
351    max_persize: usize,
352    /// Total memory limit in bytes
353    memory_limit: usize,
354    /// Current memory usage
355    current_memory: usize,
356}
357
358impl<T: Clone + Default> ArrayMemoryPool<T> {
359    /// Create a new array memory pool
360    pub fn new(_memory_limit_mb: usize, max_persize: usize) -> Self {
361        ArrayMemoryPool {
362            available_arrays: HashMap::new(),
363            max_persize,
364            memory_limit: _memory_limit_mb * 1024 * 1024,
365            current_memory: 0,
366        }
367    }
368
369    /// Get an array from the pool or create a new one
370    pub fn get_array(&mut self, rows: usize, cols: usize) -> Array2<T> {
371        let size_key = (rows, cols);
372
373        if let Some(arrays) = self.available_arrays.get_mut(&size_key) {
374            if let Some(array) = arrays.pop() {
375                let array_size = rows * cols * std::mem::size_of::<T>();
376                self.current_memory = self.current_memory.saturating_sub(array_size);
377                return array;
378            }
379        }
380
381        // Create new array if none available
382        Array2::default((rows, cols))
383    }
384
385    /// Return an array to the pool for reuse
386    pub fn return_array(&mut self, mut array: Array2<T>) {
387        let (rows, cols) = array.dim();
388        let size_key = (rows, cols);
389        let array_size = rows * cols * std::mem::size_of::<T>();
390
391        // Check memory limits
392        if self.current_memory + array_size > self.memory_limit {
393            return; // Drop the array
394        }
395
396        // Zero out the array for reuse
397        array.fill(T::default());
398
399        let arrays = self.available_arrays.entry(size_key).or_default();
400        if arrays.len() < self.max_persize {
401            arrays.push(array);
402            self.current_memory += array_size;
403        }
404    }
405
406    /// Clear the pool and free memory
407    pub fn clear(&mut self) {
408        self.available_arrays.clear();
409        self.current_memory = 0;
410    }
411
412    /// Get current memory usage in MB
413    pub fn memory_usage_mb(&self) -> f64 {
414        self.current_memory as f64 / (1024.0 * 1024.0)
415    }
416}
417
418/// Validation utilities for transformation parameters
419pub struct ValidationUtils;
420
421impl ValidationUtils {
422    /// Validate that a parameter is within reasonable bounds
423    pub fn validate_parameter_bounds(
424        value: f64,
425        min: f64,
426        max: f64,
427        param_name: &str,
428    ) -> Result<()> {
429        if !value.is_finite() {
430            return Err(TransformError::InvalidInput(format!(
431                "{param_name} must be finite"
432            )));
433        }
434
435        if value < min || value > max {
436            return Err(TransformError::InvalidInput(format!(
437                "{param_name} must be between {min} and {max}, got {value}"
438            )));
439        }
440
441        Ok(())
442    }
443
444    /// Validate array dimensions for compatibility
445    pub fn validate_dimensions_compatible(
446        shape1: &[usize],
447        shape2: &[usize],
448        operation: &str,
449    ) -> Result<()> {
450        if shape1.len() != shape2.len() {
451            return Err(TransformError::InvalidInput(format!(
452                "Incompatible dimensions for {operation}: {shape1:?} vs {shape2:?}"
453            )));
454        }
455
456        for (i, (&dim1, &dim2)) in shape1.iter().zip(shape2.iter()).enumerate() {
457            if dim1 != dim2 {
458                return Err(TransformError::InvalidInput(format!(
459                    "Dimension {i} mismatch for {operation}: {dim1} vs {dim2}"
460                )));
461            }
462        }
463
464        Ok(())
465    }
466
467    /// Validate that data is suitable for a specific transformation
468    pub fn validate_data_for_transformation(
469        data: &ArrayView2<f64>,
470        transformation: &str,
471    ) -> Result<()> {
472        check_not_empty(data, "data")?;
473
474        // Check finite values
475        for &val in data.iter() {
476            if !val.is_finite() {
477                return Err(crate::error::TransformError::DataValidationError(
478                    "Data contains non-finite values".to_string(),
479                ));
480            }
481        }
482
483        let (n_samples, nfeatures) = data.dim();
484
485        match transformation {
486            "pca" => {
487                if n_samples < 2 {
488                    return Err(TransformError::InvalidInput(
489                        "PCA requires at least 2 samples".to_string(),
490                    ));
491                }
492                if nfeatures < 1 {
493                    return Err(TransformError::InvalidInput(
494                        "PCA requires at least 1 feature".to_string(),
495                    ));
496                }
497            }
498            "standardization" => {
499                // Check for constant features
500                for j in 0..nfeatures {
501                    let col = data.column(j);
502                    let variance = col.variance();
503                    if variance < 1e-15 {
504                        return Err(TransformError::DataValidationError(format!(
505                            "Feature {j} has zero variance and cannot be standardized"
506                        )));
507                    }
508                }
509            }
510            "normalization" => {
511                // Check for zero-norm rows
512                for i in 0..n_samples {
513                    let row = data.row(i);
514                    let norm = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
515                    if norm < 1e-15 {
516                        return Err(TransformError::DataValidationError(format!(
517                            "Sample {i} has zero norm and cannot be normalized"
518                        )));
519                    }
520                }
521            }
522            _ => {
523                // Generic validation
524            }
525        }
526
527        Ok(())
528    }
529}
530
/// Performance monitoring utilities
///
/// All estimates are heuristics. Arithmetic saturates instead of overflowing, so
/// extreme shapes produce large-but-valid estimates rather than panicking.
pub struct PerfUtils;

impl PerfUtils {
    /// Estimate memory usage for an operation
    ///
    /// Returns bytes: input and output buffers (as `f64`) plus an operation-specific
    /// overhead (`"pca"`: 2x input, `"standardization"`: input / 10,
    /// `"polynomial"`: output / 2, otherwise input / 4).
    pub fn estimate_memory_usage(
        inputshape: &[usize],
        outputshape: &[usize],
        operation: &str,
    ) -> usize {
        let f64_bytes = |shape: &[usize]| {
            shape
                .iter()
                .fold(std::mem::size_of::<f64>(), |acc, &dim| acc.saturating_mul(dim))
        };
        let input_size = f64_bytes(inputshape);
        let output_size = f64_bytes(outputshape);

        let overhead = match operation {
            "pca" => input_size.saturating_mul(2), // Covariance matrix + temporaries
            "standardization" => input_size / 10,  // Just statistics
            "polynomial" => output_size / 2,       // Temporary computations
            _ => input_size / 4,                   // Default overhead
        };

        input_size.saturating_add(output_size).saturating_add(overhead)
    }

    /// Estimate computation time based on data size and operation
    ///
    /// Uses rough complexity models per operation; the result is at least 1 microsecond.
    pub fn estimate_computation_time(
        n_samples: usize,
        nfeatures: usize,
        operation: &str,
    ) -> std::time::Duration {
        use std::time::Duration;

        let n = n_samples as u64;
        let m = nfeatures as u64;

        let base_time_ns = match operation {
            "pca" => n.saturating_mul(m.saturating_pow(2)) / 1000, // O(n*m^2)
            "standardization" => n.saturating_mul(m) / 100,        // O(n*m)
            "normalization" => n.saturating_mul(m) / 50,           // O(n*m)
            "polynomial" => n.saturating_mul(m.saturating_pow(3)) / 10000, // O(n*m^3)
            _ => n.saturating_mul(m) / 100,
        };

        Duration::from_nanos(base_time_ns.max(1000)) // At least 1 microsecond
    }

    /// Choose optimal processing strategy based on data characteristics
    ///
    /// Decision order:
    /// 1. `OutOfCore` if the raw `f64` payload (`n_samples * nfeatures * 8` bytes)
    ///    exceeds the budget. The comparison is in bytes, because comparing whole
    ///    megabytes would floor away up to 1 MB of overshoot. The chunk size is the
    ///    number of rows that fit in the budget, at least 1.
    /// 2. `Parallel` if `n_samples > 10_000 && nfeatures > 100`.
    /// 3. `Simd` if `nfeatures > 1000`.
    /// 4. `Standard` otherwise.
    pub fn choose_processing_strategy(
        n_samples: usize,
        nfeatures: usize,
        available_memory_mb: usize,
    ) -> ProcessingStrategy {
        let row_bytes = nfeatures.saturating_mul(std::mem::size_of::<f64>());
        let estimated_bytes = n_samples.saturating_mul(row_bytes);
        let available_bytes = available_memory_mb.saturating_mul(1024 * 1024);

        if estimated_bytes > available_bytes {
            // estimated_bytes > 0 implies row_bytes > 0, so the division is safe.
            ProcessingStrategy::OutOfCore {
                chunk_size: (available_bytes / row_bytes).max(1),
            }
        } else if n_samples > 10000 && nfeatures > 100 {
            ProcessingStrategy::Parallel
        } else if nfeatures > 1000 {
            ProcessingStrategy::Simd
        } else {
            ProcessingStrategy::Standard
        }
    }
}

/// Processing strategy recommendation
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "distributed", derive(serde::Serialize, serde::Deserialize))]
pub enum ProcessingStrategy {
    /// Standard sequential processing
    Standard,
    /// SIMD-accelerated processing
    Simd,
    /// Parallel processing across multiple cores
    Parallel,
    /// Out-of-core processing for large datasets
    OutOfCore {
        /// Size of data chunks for processing (number of rows, at least 1)
        chunk_size: usize,
    },
}
613
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_data_chunker() {
        let chunker = DataChunker::new(100); // 100MB
        let chunk_size = chunker.calculate_chunk_size(50000, 100);
        assert!(chunk_size > 0);
        assert!(chunk_size <= 50000);
    }

    #[test]
    fn test_chunk_iterator() {
        // 1MB budget with 10 features yields chunks well below 20000 samples,
        // so several chunks are produced.
        let chunker = DataChunker::new(1);
        let chunks: Vec<_> = chunker.chunk_indices(20000, 10).collect();
        assert!(chunks.len() > 1);

        // Chunks must be non-empty, contiguous, and cover 0..20000 exactly.
        assert!(chunks.iter().all(|(start, end)| start < end));
        assert!(chunks.windows(2).all(|w| w[0].1 == w[1].0));
        assert_eq!(chunks.first().map(|c| c.0), Some(0));
        assert_eq!(chunks.last().map(|c| c.1), Some(20000));

        let total_covered = chunks.iter().map(|(start, end)| end - start).sum::<usize>();
        assert_eq!(total_covered, 20000);
    }

    #[test]
    fn test_type_converter() {
        let data = Array2::<f32>::ones((10, 5));
        let result = TypeConverter::f32_to_f64_simd(&data.view()).unwrap();
        assert_eq!(result.shape(), &[10, 5]);
        assert!((result[(0, 0)] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_robust_stats() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]); // With outlier
        let (median, mad) = StatUtils::robust_stats(&data.view()).unwrap();
        assert!((median - 3.5).abs() < 1e-10);
        assert!(mad > 0.0);
    }

    #[test]
    fn test_outlier_detection() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]);
        let outliers = StatUtils::detect_outliers_iqr(&data.view(), 1.5).unwrap();
        assert_eq!(outliers.len(), 6);
        assert!(outliers[5]); // 100.0 should be detected as outlier
    }

    #[test]
    fn test_data_quality_score() {
        let good_data =
            Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
        let quality = StatUtils::data_quality_score(&good_data.view()).unwrap();
        assert!(quality > 0.5); // Should have reasonable quality

        let bad_data = Array2::from_elem((10, 3), f64::NAN);
        let quality = StatUtils::data_quality_score(&bad_data.view()).unwrap();
        assert!(quality < 0.5); // Should have poor quality due to NaN values
    }

    #[test]
    fn test_memory_pool() {
        let mut pool = ArrayMemoryPool::<f64>::new(10, 2);

        // Get, dirty and return an array
        let mut array1 = pool.get_array(10, 5);
        assert_eq!(array1.shape(), &[10, 5]);
        array1.fill(3.0);

        pool.return_array(array1);
        assert!(pool.memory_usage_mb() > 0.0);

        // The pooled array comes back reset to the default value.
        let array2 = pool.get_array(10, 5);
        assert_eq!(array2.shape(), &[10, 5]);
        assert!(array2.iter().all(|&x| x == 0.0));
        assert_eq!(pool.memory_usage_mb(), 0.0);
    }

    #[test]
    fn test_validation_utils() {
        // Test parameter bounds validation
        assert!(ValidationUtils::validate_parameter_bounds(0.5, 0.0, 1.0, "test").is_ok());
        assert!(ValidationUtils::validate_parameter_bounds(1.5, 0.0, 1.0, "test").is_err());

        // Test dimension compatibility
        assert!(
            ValidationUtils::validate_dimensions_compatible(&[10, 5], &[10, 5], "test").is_ok()
        );
        assert!(
            ValidationUtils::validate_dimensions_compatible(&[10, 5], &[10, 6], "test").is_err()
        );
    }

    #[test]
    fn test_performance_utils() {
        let memory = PerfUtils::estimate_memory_usage(&[1000, 100], &[1000, 50], "pca");
        assert!(memory > 0);

        let time = PerfUtils::estimate_computation_time(1000, 100, "pca");
        assert!(time.as_nanos() > 0);

        // Large enough in both dimensions to cross the parallel thresholds.
        let strategy = PerfUtils::choose_processing_strategy(20000, 200, 1000);
        assert!(matches!(strategy, ProcessingStrategy::Parallel));

        // 10000 samples does not exceed the `n_samples > 10000` threshold.
        let strategy = PerfUtils::choose_processing_strategy(10000, 100, 100);
        assert!(matches!(strategy, ProcessingStrategy::Standard));
    }
}