1use crate::error::{Result, SklearsError};
6use crate::types::{Array1, Array2, FloatBounds};
7use scirs2_core::ndarray::Axis;
9
10pub struct ArrayStats;
12
13impl ArrayStats {
14 pub fn weighted_mean<T>(array: &Array1<T>, weights: &Array1<T>) -> Result<T>
16 where
17 T: FloatBounds,
18 {
19 if array.len() != weights.len() {
20 return Err(SklearsError::ShapeMismatch {
21 expected: format!("{}", array.len()),
22 actual: format!("{}", weights.len()),
23 });
24 }
25
26 let weight_sum = weights.sum();
27 if weight_sum == T::zero() {
28 return Err(SklearsError::InvalidInput(
29 "Weight sum cannot be zero".to_string(),
30 ));
31 }
32
33 let weighted_sum = array
34 .iter()
35 .zip(weights.iter())
36 .map(|(&x, &w)| x * w)
37 .fold(T::zero(), |acc, x| acc + x);
38
39 Ok(weighted_sum / weight_sum)
40 }
41
42 pub fn robust_covariance<T>(data: &Array2<T>, shrinkage: Option<T>) -> Result<Array2<T>>
44 where
45 T: FloatBounds + scirs2_core::ndarray::ScalarOperand,
46 {
47 let (n_samples, n_features) = data.dim();
48
49 if n_samples < 2 {
50 return Err(SklearsError::InvalidInput(
51 "Need at least 2 samples for covariance".to_string(),
52 ));
53 }
54
55 let means = data.mean_axis(Axis(0)).ok_or_else(|| {
57 SklearsError::NumericalError("mean_axis computation failed on empty axis".to_string())
58 })?;
59
60 let centered = data - &means.insert_axis(Axis(0));
62
63 let cov_empirical =
65 centered.t().dot(¢ered) / T::from_usize(n_samples - 1).unwrap_or_else(|| T::zero());
66
67 if let Some(shrink) = shrinkage {
69 let identity = Array2::<T>::eye(n_features);
70 let trace = (0..n_features)
71 .map(|i| cov_empirical[[i, i]])
72 .fold(T::zero(), |acc, x| acc + x);
73 let target =
74 identity * (trace / T::from_usize(n_features).unwrap_or_else(|| T::zero()));
75
76 Ok(&cov_empirical * (T::one() - shrink) + &target * shrink)
77 } else {
78 Ok(cov_empirical)
79 }
80 }
81
82 pub fn percentile<T>(array: &Array1<T>, q: T) -> Result<T>
84 where
85 T: FloatBounds + PartialOrd,
86 {
87 if array.is_empty() {
88 return Err(SklearsError::InvalidInput(
89 "Array cannot be empty".to_string(),
90 ));
91 }
92
93 if q < T::zero() || q > T::from_f64(100.0).unwrap_or_else(|| T::zero()) {
94 return Err(SklearsError::InvalidInput(
95 "Percentile must be between 0 and 100".to_string(),
96 ));
97 }
98
99 let mut sorted = array.to_vec();
100 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
101
102 let n = sorted.len();
103 let index = q / T::from_f64(100.0).unwrap_or_else(|| T::zero())
104 * T::from_usize(n - 1).unwrap_or_else(|| T::zero());
105 let lower_idx = index.floor().to_usize().unwrap_or(0);
106 let upper_idx = index.ceil().to_usize().unwrap_or(0).min(n - 1);
107
108 if lower_idx == upper_idx {
109 Ok(sorted[lower_idx])
110 } else {
111 let lower_val = sorted[lower_idx];
112 let upper_val = sorted[upper_idx];
113 let weight = index - T::from_usize(lower_idx).unwrap_or_else(|| T::zero());
114 Ok(lower_val * (T::one() - weight) + upper_val * weight)
115 }
116 }
117}
118
119pub struct MatrixOps;
121
122impl MatrixOps {
123 pub fn condition_number<T>(matrix: &Array2<T>) -> Result<T>
125 where
126 T: FloatBounds,
127 {
128 let (rows, cols) = matrix.dim();
131 if rows != cols {
132 return Err(SklearsError::InvalidInput(
133 "Matrix must be square for condition number".to_string(),
134 ));
135 }
136
137 let mut min_diag = T::infinity();
139 let mut max_diag = T::neg_infinity();
140
141 for i in 0..rows {
142 let diag_val = matrix[[i, i]].abs();
143 if diag_val < min_diag {
144 min_diag = diag_val;
145 }
146 if diag_val > max_diag {
147 max_diag = diag_val;
148 }
149 }
150
151 if min_diag == T::zero() {
152 Ok(T::infinity())
153 } else {
154 Ok(max_diag / min_diag)
155 }
156 }
157
158 pub fn rank<T>(matrix: &Array2<T>, tolerance: Option<T>) -> usize
160 where
161 T: FloatBounds,
162 {
163 let (rows, cols) = matrix.dim();
164 let tol = tolerance.unwrap_or_else(|| {
165 T::from_f64(1e-12).unwrap_or_else(|| T::zero())
166 * T::from_usize(rows.max(cols)).unwrap_or_else(|| T::zero())
167 });
168
169 let min_dim = rows.min(cols);
172 let mut rank = 0;
173
174 for i in 0..min_dim {
175 if matrix[[i, i]].abs() > tol {
176 rank += 1;
177 }
178 }
179
180 rank
181 }
182
183 pub fn pinv<T>(matrix: &Array2<T>, _tolerance: Option<T>) -> Result<Array2<T>>
185 where
186 T: FloatBounds,
187 {
188 let (rows, cols) = matrix.dim();
189
190 if rows == cols {
192 if let Ok(inv) = Self::try_inverse(matrix) {
194 return Ok(inv);
195 }
196 }
197
198 let gram = if rows >= cols {
201 let at = matrix.t().to_owned();
203 let ata = at.dot(matrix);
204 let ata_inv = Self::try_inverse(&ata)?;
205 ata_inv.dot(&at)
206 } else {
207 let at = matrix.t().to_owned();
209 let aat = matrix.dot(&at);
210 let aat_inv = Self::try_inverse(&aat)?;
211 at.dot(&aat_inv)
212 };
213
214 Ok(gram)
215 }
216
217 fn try_inverse<T>(matrix: &Array2<T>) -> Result<Array2<T>>
219 where
220 T: FloatBounds,
221 {
222 let (rows, cols) = matrix.dim();
223 if rows != cols {
224 return Err(SklearsError::InvalidInput(
225 "Matrix must be square".to_string(),
226 ));
227 }
228
229 let mut inv = Array2::<T>::zeros((rows, cols));
232 for i in 0..rows {
233 let diag_val = matrix[[i, i]];
234 if diag_val.abs() < T::from_f64(1e-15).unwrap_or_else(|| T::zero()) {
235 return Err(SklearsError::InvalidInput("Matrix is singular".to_string()));
236 }
237 inv[[i, i]] = T::one() / diag_val;
238 }
239
240 Ok(inv)
241 }
242}
243
244pub struct MemoryOps;
246
247impl MemoryOps {
248 pub fn chunked_dot<T>(a: &Array1<T>, b: &Array1<T>, chunk_size: Option<usize>) -> Result<T>
250 where
251 T: FloatBounds,
252 {
253 if a.len() != b.len() {
254 return Err(SklearsError::ShapeMismatch {
255 expected: format!("{}", a.len()),
256 actual: format!("{}", b.len()),
257 });
258 }
259
260 let chunk_size = chunk_size.unwrap_or(1024);
261 let mut result = T::zero();
262
263 let a_chunks: Vec<_> = a.exact_chunks(chunk_size).into_iter().collect();
264 let b_chunks: Vec<_> = b.exact_chunks(chunk_size).into_iter().collect();
265 for (a_chunk, b_chunk) in a_chunks.iter().zip(b_chunks.iter()) {
266 result += a_chunk
267 .iter()
268 .zip(b_chunk.iter())
269 .map(|(&x, &y)| x * y)
270 .fold(T::zero(), |acc, x| acc + x);
271 }
272
273 let remainder_len = a.len() % chunk_size;
275 if remainder_len > 0 {
276 let start_idx = a.len() - remainder_len;
277 for i in 0..remainder_len {
278 result += a[start_idx + i] * b[start_idx + i];
279 }
280 }
281
282 Ok(result)
283 }
284
285 pub fn streaming_stats<T>(values: impl Iterator<Item = T>) -> (T, T, usize)
287 where
288 T: FloatBounds,
289 {
290 let mut count = 0;
291 let mut mean = T::zero();
292 let mut m2 = T::zero();
293
294 for value in values {
295 count += 1;
296 let delta = value - mean;
297 mean += delta / T::from_usize(count).unwrap_or_else(|| T::zero());
298 let delta2 = value - mean;
299 m2 += delta * delta2;
300 }
301
302 let variance = if count > 1 {
303 m2 / T::from_usize(count - 1).unwrap_or_else(|| T::zero())
304 } else {
305 T::zero()
306 };
307
308 (mean, variance, count)
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_weighted_mean() {
        let data = array![1.0, 2.0, 3.0, 4.0];
        let weights = array![1.0, 2.0, 3.0, 4.0];

        let result = ArrayStats::weighted_mean(&data, &weights).expect("expected valid value");
        let expected = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0 + 4.0 * 4.0) / (1.0 + 2.0 + 3.0 + 4.0);

        assert_abs_diff_eq!(result, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_mean_rejects_length_mismatch() {
        let data = array![1.0, 2.0, 3.0];
        let weights = array![1.0, 2.0];
        assert!(ArrayStats::weighted_mean(&data, &weights).is_err());
    }

    #[test]
    fn test_weighted_mean_rejects_zero_weight_sum() {
        let data = array![1.0, 2.0];
        let weights = array![1.0, -1.0];
        assert!(ArrayStats::weighted_mean(&data, &weights).is_err());
    }

    #[test]
    fn test_percentile() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let median = ArrayStats::percentile(&data, 50.0).expect("expected valid value");
        assert_abs_diff_eq!(median, 3.0, epsilon = 1e-10);

        let q25 = ArrayStats::percentile(&data, 25.0).expect("expected valid value");
        assert_abs_diff_eq!(q25, 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_percentile_extremes_and_interpolation() {
        // Deliberately unsorted to exercise the internal sort.
        let data = array![5.0, 1.0, 4.0, 2.0, 3.0];

        let min = ArrayStats::percentile(&data, 0.0).expect("expected valid value");
        assert_abs_diff_eq!(min, 1.0, epsilon = 1e-10);

        let max = ArrayStats::percentile(&data, 100.0).expect("expected valid value");
        assert_abs_diff_eq!(max, 5.0, epsilon = 1e-10);

        // index = 0.10 * 4 = 0.4, so the result is 1.0 * 0.6 + 2.0 * 0.4.
        let q10 = ArrayStats::percentile(&data, 10.0).expect("expected valid value");
        assert_abs_diff_eq!(q10, 1.4, epsilon = 1e-10);
    }

    #[test]
    fn test_percentile_rejects_out_of_range() {
        let data = array![1.0, 2.0, 3.0];
        assert!(ArrayStats::percentile(&data, -1.0).is_err());
        assert!(ArrayStats::percentile(&data, 101.0).is_err());
    }

    #[test]
    fn test_chunked_dot() {
        let a = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = array![2.0, 3.0, 4.0, 5.0, 6.0];

        let result = MemoryOps::chunked_dot(&a, &b, Some(2)).expect("expected valid value");
        let expected: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();

        assert_abs_diff_eq!(result, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_chunked_dot_chunk_sizes_agree() {
        let a = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = array![2.0, 3.0, 4.0, 5.0, 6.0];
        let expected: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();

        // Default (1024) is larger than the input: everything is remainder.
        let default_chunk = MemoryOps::chunked_dot(&a, &b, None).expect("expected valid value");
        assert_abs_diff_eq!(default_chunk, expected, epsilon = 1e-10);

        // Chunk equal to the length: no remainder.
        let exact_chunk = MemoryOps::chunked_dot(&a, &b, Some(5)).expect("expected valid value");
        assert_abs_diff_eq!(exact_chunk, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_chunked_dot_rejects_length_mismatch() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![1.0, 2.0];
        assert!(MemoryOps::chunked_dot(&a, &b, Some(2)).is_err());
    }

    #[test]
    fn test_streaming_stats() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (mean, variance, count) = MemoryOps::streaming_stats(values.into_iter());

        assert_eq!(count, 5);
        assert_abs_diff_eq!(mean, 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(variance, 2.5, epsilon = 1e-10);
    }

    #[test]
    fn test_streaming_stats_empty() {
        let (mean, variance, count) = MemoryOps::streaming_stats(std::iter::empty::<f64>());

        assert_eq!(count, 0);
        assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(variance, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_robust_covariance() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let cov = ArrayStats::robust_covariance(&data, None).expect("expected valid value");

        // Centered rows are [-2, -2], [0, 0], [2, 2]; every entry is 8 / (3 - 1) = 4.
        assert_eq!(cov.dim(), (2, 2));
        assert_abs_diff_eq!(cov[[0, 0]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[1, 1]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[0, 1]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[0, 1]], cov[[1, 0]], epsilon = 1e-10);
    }

    #[test]
    fn test_robust_covariance_with_shrinkage() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let cov = ArrayStats::robust_covariance(&data, Some(0.5)).expect("expected valid value");

        // 0.5 * [[4, 4], [4, 4]] + 0.5 * (trace 8 / 2) * I = [[4, 2], [2, 4]].
        assert_abs_diff_eq!(cov[[0, 0]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[1, 1]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[1, 0]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_robust_covariance_needs_two_samples() {
        let data = array![[1.0, 2.0]];
        assert!(ArrayStats::robust_covariance(&data, None).is_err());
    }
}