Skip to main content

sklears_core/
advanced_array_ops.rs

1/// Advanced array operations using SciRS2 capabilities
2///
3/// This module provides high-performance array operations that leverage SciRS2's
4/// advanced features including SIMD, GPU acceleration, and memory efficiency.
5use crate::error::{Result, SklearsError};
6use crate::types::{Array1, Array2, FloatBounds};
7// SciRS2 Policy: Using scirs2_core::ndarray (COMPLIANT)
8use scirs2_core::ndarray::Axis;
9
/// Advanced array statistics with optimized implementations.
///
/// Zero-sized namespace type: all functionality (weighted means, robust
/// covariance, percentiles) is exposed as associated functions on the
/// `impl` block below; no instance state is ever held.
pub struct ArrayStats;
12
13impl ArrayStats {
14    /// Compute weighted mean with numerical stability
15    pub fn weighted_mean<T>(array: &Array1<T>, weights: &Array1<T>) -> Result<T>
16    where
17        T: FloatBounds,
18    {
19        if array.len() != weights.len() {
20            return Err(SklearsError::ShapeMismatch {
21                expected: format!("{}", array.len()),
22                actual: format!("{}", weights.len()),
23            });
24        }
25
26        let weight_sum = weights.sum();
27        if weight_sum == T::zero() {
28            return Err(SklearsError::InvalidInput(
29                "Weight sum cannot be zero".to_string(),
30            ));
31        }
32
33        let weighted_sum = array
34            .iter()
35            .zip(weights.iter())
36            .map(|(&x, &w)| x * w)
37            .fold(T::zero(), |acc, x| acc + x);
38
39        Ok(weighted_sum / weight_sum)
40    }
41
42    /// Compute robust covariance matrix with outlier handling
43    pub fn robust_covariance<T>(data: &Array2<T>, shrinkage: Option<T>) -> Result<Array2<T>>
44    where
45        T: FloatBounds + scirs2_core::ndarray::ScalarOperand,
46    {
47        let (n_samples, n_features) = data.dim();
48
49        if n_samples < 2 {
50            return Err(SklearsError::InvalidInput(
51                "Need at least 2 samples for covariance".to_string(),
52            ));
53        }
54
55        // Compute sample means
56        let means = data.mean_axis(Axis(0)).ok_or_else(|| {
57            SklearsError::NumericalError("mean_axis computation failed on empty axis".to_string())
58        })?;
59
60        // Center the data
61        let centered = data - &means.insert_axis(Axis(0));
62
63        // Compute empirical covariance
64        let cov_empirical =
65            centered.t().dot(&centered) / T::from_usize(n_samples - 1).unwrap_or_else(|| T::zero());
66
67        // Apply shrinkage if specified
68        if let Some(shrink) = shrinkage {
69            let identity = Array2::<T>::eye(n_features);
70            let trace = (0..n_features)
71                .map(|i| cov_empirical[[i, i]])
72                .fold(T::zero(), |acc, x| acc + x);
73            let target =
74                identity * (trace / T::from_usize(n_features).unwrap_or_else(|| T::zero()));
75
76            Ok(&cov_empirical * (T::one() - shrink) + &target * shrink)
77        } else {
78            Ok(cov_empirical)
79        }
80    }
81
82    /// Compute percentile with interpolation
83    pub fn percentile<T>(array: &Array1<T>, q: T) -> Result<T>
84    where
85        T: FloatBounds + PartialOrd,
86    {
87        if array.is_empty() {
88            return Err(SklearsError::InvalidInput(
89                "Array cannot be empty".to_string(),
90            ));
91        }
92
93        if q < T::zero() || q > T::from_f64(100.0).unwrap_or_else(|| T::zero()) {
94            return Err(SklearsError::InvalidInput(
95                "Percentile must be between 0 and 100".to_string(),
96            ));
97        }
98
99        let mut sorted = array.to_vec();
100        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
101
102        let n = sorted.len();
103        let index = q / T::from_f64(100.0).unwrap_or_else(|| T::zero())
104            * T::from_usize(n - 1).unwrap_or_else(|| T::zero());
105        let lower_idx = index.floor().to_usize().unwrap_or(0);
106        let upper_idx = index.ceil().to_usize().unwrap_or(0).min(n - 1);
107
108        if lower_idx == upper_idx {
109            Ok(sorted[lower_idx])
110        } else {
111            let lower_val = sorted[lower_idx];
112            let upper_val = sorted[upper_idx];
113            let weight = index - T::from_usize(lower_idx).unwrap_or_else(|| T::zero());
114            Ok(lower_val * (T::one() - weight) + upper_val * weight)
115        }
116    }
117}
118
/// Advanced matrix operations with optimizations.
///
/// Zero-sized namespace type: condition-number estimation, rank
/// computation, and (pseudo)inverse helpers are provided as associated
/// functions on the `impl` block below.
pub struct MatrixOps;
121
122impl MatrixOps {
123    /// Compute matrix condition number (ratio of largest to smallest singular value)
124    pub fn condition_number<T>(matrix: &Array2<T>) -> Result<T>
125    where
126        T: FloatBounds,
127    {
128        // For now, use a simplified approach - in a full implementation,
129        // this would use SVD decomposition from SciRS2's advanced features
130        let (rows, cols) = matrix.dim();
131        if rows != cols {
132            return Err(SklearsError::InvalidInput(
133                "Matrix must be square for condition number".to_string(),
134            ));
135        }
136
137        // Simplified condition number estimation using diagonal dominance
138        let mut min_diag = T::infinity();
139        let mut max_diag = T::neg_infinity();
140
141        for i in 0..rows {
142            let diag_val = matrix[[i, i]].abs();
143            if diag_val < min_diag {
144                min_diag = diag_val;
145            }
146            if diag_val > max_diag {
147                max_diag = diag_val;
148            }
149        }
150
151        if min_diag == T::zero() {
152            Ok(T::infinity())
153        } else {
154            Ok(max_diag / min_diag)
155        }
156    }
157
158    /// Compute matrix rank using tolerance-based approach
159    pub fn rank<T>(matrix: &Array2<T>, tolerance: Option<T>) -> usize
160    where
161        T: FloatBounds,
162    {
163        let (rows, cols) = matrix.dim();
164        let tol = tolerance.unwrap_or_else(|| {
165            T::from_f64(1e-12).unwrap_or_else(|| T::zero())
166                * T::from_usize(rows.max(cols)).unwrap_or_else(|| T::zero())
167        });
168
169        // Simplified rank computation - count non-zero diagonal elements
170        // In a full implementation, this would use SVD
171        let min_dim = rows.min(cols);
172        let mut rank = 0;
173
174        for i in 0..min_dim {
175            if matrix[[i, i]].abs() > tol {
176                rank += 1;
177            }
178        }
179
180        rank
181    }
182
183    /// Compute generalized inverse (Moore-Penrose pseudoinverse)
184    pub fn pinv<T>(matrix: &Array2<T>, _tolerance: Option<T>) -> Result<Array2<T>>
185    where
186        T: FloatBounds,
187    {
188        let (rows, cols) = matrix.dim();
189
190        // For square matrices, try regular inverse first
191        if rows == cols {
192            // Simplified approach - in practice would use LU decomposition
193            if let Ok(inv) = Self::try_inverse(matrix) {
194                return Ok(inv);
195            }
196        }
197
198        // Fall back to pseudoinverse computation
199        // This is a simplified implementation - real implementation would use SVD
200        let gram = if rows >= cols {
201            // Tall matrix: (A^T A)^-1 A^T
202            let at = matrix.t().to_owned();
203            let ata = at.dot(matrix);
204            let ata_inv = Self::try_inverse(&ata)?;
205            ata_inv.dot(&at)
206        } else {
207            // Wide matrix: A^T (A A^T)^-1
208            let at = matrix.t().to_owned();
209            let aat = matrix.dot(&at);
210            let aat_inv = Self::try_inverse(&aat)?;
211            at.dot(&aat_inv)
212        };
213
214        Ok(gram)
215    }
216
217    /// Helper method to attempt matrix inversion
218    fn try_inverse<T>(matrix: &Array2<T>) -> Result<Array2<T>>
219    where
220        T: FloatBounds,
221    {
222        let (rows, cols) = matrix.dim();
223        if rows != cols {
224            return Err(SklearsError::InvalidInput(
225                "Matrix must be square".to_string(),
226            ));
227        }
228
229        // Simplified inverse using diagonal matrix assumption
230        // Real implementation would use LU decomposition or similar
231        let mut inv = Array2::<T>::zeros((rows, cols));
232        for i in 0..rows {
233            let diag_val = matrix[[i, i]];
234            if diag_val.abs() < T::from_f64(1e-15).unwrap_or_else(|| T::zero()) {
235                return Err(SklearsError::InvalidInput("Matrix is singular".to_string()));
236            }
237            inv[[i, i]] = T::one() / diag_val;
238        }
239
240        Ok(inv)
241    }
242}
243
/// Memory-efficient operations for large arrays.
///
/// Zero-sized namespace type: chunked dot products and streaming
/// (single-pass) statistics are provided as associated functions on the
/// `impl` block below.
pub struct MemoryOps;
246
247impl MemoryOps {
248    /// Compute dot product in chunks to reduce memory usage
249    pub fn chunked_dot<T>(a: &Array1<T>, b: &Array1<T>, chunk_size: Option<usize>) -> Result<T>
250    where
251        T: FloatBounds,
252    {
253        if a.len() != b.len() {
254            return Err(SklearsError::ShapeMismatch {
255                expected: format!("{}", a.len()),
256                actual: format!("{}", b.len()),
257            });
258        }
259
260        let chunk_size = chunk_size.unwrap_or(1024);
261        let mut result = T::zero();
262
263        for (a_chunk, b_chunk) in a
264            .exact_chunks(chunk_size)
265            .into_iter()
266            .zip(b.exact_chunks(chunk_size).into_iter())
267        {
268            result += a_chunk
269                .iter()
270                .zip(b_chunk.iter())
271                .map(|(&x, &y)| x * y)
272                .fold(T::zero(), |acc, x| acc + x);
273        }
274
275        // Handle remainder
276        let remainder_len = a.len() % chunk_size;
277        if remainder_len > 0 {
278            let start_idx = a.len() - remainder_len;
279            for i in 0..remainder_len {
280                result += a[start_idx + i] * b[start_idx + i];
281            }
282        }
283
284        Ok(result)
285    }
286
287    /// Streaming statistics computation for large datasets
288    pub fn streaming_stats<T>(values: impl Iterator<Item = T>) -> (T, T, usize)
289    where
290        T: FloatBounds,
291    {
292        let mut count = 0;
293        let mut mean = T::zero();
294        let mut m2 = T::zero();
295
296        for value in values {
297            count += 1;
298            let delta = value - mean;
299            mean += delta / T::from_usize(count).unwrap_or_else(|| T::zero());
300            let delta2 = value - mean;
301            m2 += delta * delta2;
302        }
303
304        let variance = if count > 1 {
305            m2 / T::from_usize(count - 1).unwrap_or_else(|| T::zero())
306        } else {
307            T::zero()
308        };
309
310        (mean, variance, count)
311    }
312}
313
314#[allow(non_snake_case)]
315#[cfg(test)]
316mod tests {
317    use super::*;
318    // SciRS2 Policy: Using scirs2_core::ndarray (COMPLIANT)
319    use approx::assert_abs_diff_eq;
320    use scirs2_core::ndarray::array;
321
322    #[test]
323    fn test_weighted_mean() {
324        let data = array![1.0, 2.0, 3.0, 4.0];
325        let weights = array![1.0, 2.0, 3.0, 4.0];
326
327        let result = ArrayStats::weighted_mean(&data, &weights).expect("expected valid value");
328        let expected = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0 + 4.0 * 4.0) / (1.0 + 2.0 + 3.0 + 4.0);
329
330        assert_abs_diff_eq!(result, expected, epsilon = 1e-10);
331    }
332
333    #[test]
334    fn test_percentile() {
335        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
336
337        let median = ArrayStats::percentile(&data, 50.0).expect("expected valid value");
338        assert_abs_diff_eq!(median, 3.0, epsilon = 1e-10);
339
340        let q25 = ArrayStats::percentile(&data, 25.0).expect("expected valid value");
341        assert_abs_diff_eq!(q25, 2.0, epsilon = 1e-10);
342    }
343
344    #[test]
345    fn test_chunked_dot() {
346        let a = array![1.0, 2.0, 3.0, 4.0, 5.0];
347        let b = array![2.0, 3.0, 4.0, 5.0, 6.0];
348
349        let result = MemoryOps::chunked_dot(&a, &b, Some(2)).expect("expected valid value");
350        let expected: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
351
352        assert_abs_diff_eq!(result, expected, epsilon = 1e-10);
353    }
354
355    #[test]
356    fn test_streaming_stats() {
357        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
358        let (mean, variance, count) = MemoryOps::streaming_stats(values.into_iter());
359
360        assert_eq!(count, 5);
361        assert_abs_diff_eq!(mean, 3.0, epsilon = 1e-10);
362        assert_abs_diff_eq!(variance, 2.5, epsilon = 1e-10);
363    }
364
365    #[test]
366    fn test_robust_covariance() {
367        // SciRS2 Policy: Using scirs2_core::ndarray (COMPLIANT)
368        use scirs2_core::ndarray::array;
369
370        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
371        let cov = ArrayStats::robust_covariance(&data, None).expect("expected valid value");
372
373        assert_eq!(cov.dim(), (2, 2));
374        // Basic sanity checks
375        assert!(cov[[0, 0]] > 0.0);
376        assert!(cov[[1, 1]] > 0.0);
377        assert_abs_diff_eq!(cov[[0, 1]], cov[[1, 0]], epsilon = 1e-10);
378    }
379}