1use crate::error::{Result, SklearsError};
6use crate::types::{Array1, Array2, FloatBounds};
7use scirs2_core::ndarray::Axis;
9
/// Stateless namespace for statistical helpers over 1-D and 2-D arrays.
///
/// The type has no fields; all functionality is exposed through associated
/// functions, so it never needs to be instantiated.
#[derive(Debug, Clone, Copy, Default)]
pub struct ArrayStats;
12
13impl ArrayStats {
14 pub fn weighted_mean<T>(array: &Array1<T>, weights: &Array1<T>) -> Result<T>
16 where
17 T: FloatBounds,
18 {
19 if array.len() != weights.len() {
20 return Err(SklearsError::ShapeMismatch {
21 expected: format!("{}", array.len()),
22 actual: format!("{}", weights.len()),
23 });
24 }
25
26 let weight_sum = weights.sum();
27 if weight_sum == T::zero() {
28 return Err(SklearsError::InvalidInput(
29 "Weight sum cannot be zero".to_string(),
30 ));
31 }
32
33 let weighted_sum = array
34 .iter()
35 .zip(weights.iter())
36 .map(|(&x, &w)| x * w)
37 .fold(T::zero(), |acc, x| acc + x);
38
39 Ok(weighted_sum / weight_sum)
40 }
41
42 pub fn robust_covariance<T>(data: &Array2<T>, shrinkage: Option<T>) -> Result<Array2<T>>
44 where
45 T: FloatBounds + scirs2_core::ndarray::ScalarOperand,
46 {
47 let (n_samples, n_features) = data.dim();
48
49 if n_samples < 2 {
50 return Err(SklearsError::InvalidInput(
51 "Need at least 2 samples for covariance".to_string(),
52 ));
53 }
54
55 let means = data.mean_axis(Axis(0)).ok_or_else(|| {
57 SklearsError::NumericalError("mean_axis computation failed on empty axis".to_string())
58 })?;
59
60 let centered = data - &means.insert_axis(Axis(0));
62
63 let cov_empirical =
65 centered.t().dot(¢ered) / T::from_usize(n_samples - 1).unwrap_or_else(|| T::zero());
66
67 if let Some(shrink) = shrinkage {
69 let identity = Array2::<T>::eye(n_features);
70 let trace = (0..n_features)
71 .map(|i| cov_empirical[[i, i]])
72 .fold(T::zero(), |acc, x| acc + x);
73 let target =
74 identity * (trace / T::from_usize(n_features).unwrap_or_else(|| T::zero()));
75
76 Ok(&cov_empirical * (T::one() - shrink) + &target * shrink)
77 } else {
78 Ok(cov_empirical)
79 }
80 }
81
82 pub fn percentile<T>(array: &Array1<T>, q: T) -> Result<T>
84 where
85 T: FloatBounds + PartialOrd,
86 {
87 if array.is_empty() {
88 return Err(SklearsError::InvalidInput(
89 "Array cannot be empty".to_string(),
90 ));
91 }
92
93 if q < T::zero() || q > T::from_f64(100.0).unwrap_or_else(|| T::zero()) {
94 return Err(SklearsError::InvalidInput(
95 "Percentile must be between 0 and 100".to_string(),
96 ));
97 }
98
99 let mut sorted = array.to_vec();
100 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
101
102 let n = sorted.len();
103 let index = q / T::from_f64(100.0).unwrap_or_else(|| T::zero())
104 * T::from_usize(n - 1).unwrap_or_else(|| T::zero());
105 let lower_idx = index.floor().to_usize().unwrap_or(0);
106 let upper_idx = index.ceil().to_usize().unwrap_or(0).min(n - 1);
107
108 if lower_idx == upper_idx {
109 Ok(sorted[lower_idx])
110 } else {
111 let lower_val = sorted[lower_idx];
112 let upper_val = sorted[upper_idx];
113 let weight = index - T::from_usize(lower_idx).unwrap_or_else(|| T::zero());
114 Ok(lower_val * (T::one() - weight) + upper_val * weight)
115 }
116 }
117}
118
/// Stateless namespace for dense-matrix helpers (condition estimate, rank,
/// pseudo-inverse).
///
/// The type has no fields; all functionality is exposed through associated
/// functions, so it never needs to be instantiated.
#[derive(Debug, Clone, Copy, Default)]
pub struct MatrixOps;
121
122impl MatrixOps {
123 pub fn condition_number<T>(matrix: &Array2<T>) -> Result<T>
125 where
126 T: FloatBounds,
127 {
128 let (rows, cols) = matrix.dim();
131 if rows != cols {
132 return Err(SklearsError::InvalidInput(
133 "Matrix must be square for condition number".to_string(),
134 ));
135 }
136
137 let mut min_diag = T::infinity();
139 let mut max_diag = T::neg_infinity();
140
141 for i in 0..rows {
142 let diag_val = matrix[[i, i]].abs();
143 if diag_val < min_diag {
144 min_diag = diag_val;
145 }
146 if diag_val > max_diag {
147 max_diag = diag_val;
148 }
149 }
150
151 if min_diag == T::zero() {
152 Ok(T::infinity())
153 } else {
154 Ok(max_diag / min_diag)
155 }
156 }
157
158 pub fn rank<T>(matrix: &Array2<T>, tolerance: Option<T>) -> usize
160 where
161 T: FloatBounds,
162 {
163 let (rows, cols) = matrix.dim();
164 let tol = tolerance.unwrap_or_else(|| {
165 T::from_f64(1e-12).unwrap_or_else(|| T::zero())
166 * T::from_usize(rows.max(cols)).unwrap_or_else(|| T::zero())
167 });
168
169 let min_dim = rows.min(cols);
172 let mut rank = 0;
173
174 for i in 0..min_dim {
175 if matrix[[i, i]].abs() > tol {
176 rank += 1;
177 }
178 }
179
180 rank
181 }
182
183 pub fn pinv<T>(matrix: &Array2<T>, _tolerance: Option<T>) -> Result<Array2<T>>
185 where
186 T: FloatBounds,
187 {
188 let (rows, cols) = matrix.dim();
189
190 if rows == cols {
192 if let Ok(inv) = Self::try_inverse(matrix) {
194 return Ok(inv);
195 }
196 }
197
198 let gram = if rows >= cols {
201 let at = matrix.t().to_owned();
203 let ata = at.dot(matrix);
204 let ata_inv = Self::try_inverse(&ata)?;
205 ata_inv.dot(&at)
206 } else {
207 let at = matrix.t().to_owned();
209 let aat = matrix.dot(&at);
210 let aat_inv = Self::try_inverse(&aat)?;
211 at.dot(&aat_inv)
212 };
213
214 Ok(gram)
215 }
216
217 fn try_inverse<T>(matrix: &Array2<T>) -> Result<Array2<T>>
219 where
220 T: FloatBounds,
221 {
222 let (rows, cols) = matrix.dim();
223 if rows != cols {
224 return Err(SklearsError::InvalidInput(
225 "Matrix must be square".to_string(),
226 ));
227 }
228
229 let mut inv = Array2::<T>::zeros((rows, cols));
232 for i in 0..rows {
233 let diag_val = matrix[[i, i]];
234 if diag_val.abs() < T::from_f64(1e-15).unwrap_or_else(|| T::zero()) {
235 return Err(SklearsError::InvalidInput("Matrix is singular".to_string()));
236 }
237 inv[[i, i]] = T::one() / diag_val;
238 }
239
240 Ok(inv)
241 }
242}
243
/// Stateless namespace for chunked and streaming numeric helpers.
///
/// The type has no fields; all functionality is exposed through associated
/// functions, so it never needs to be instantiated.
#[derive(Debug, Clone, Copy, Default)]
pub struct MemoryOps;
246
247impl MemoryOps {
248 pub fn chunked_dot<T>(a: &Array1<T>, b: &Array1<T>, chunk_size: Option<usize>) -> Result<T>
250 where
251 T: FloatBounds,
252 {
253 if a.len() != b.len() {
254 return Err(SklearsError::ShapeMismatch {
255 expected: format!("{}", a.len()),
256 actual: format!("{}", b.len()),
257 });
258 }
259
260 let chunk_size = chunk_size.unwrap_or(1024);
261 let mut result = T::zero();
262
263 for (a_chunk, b_chunk) in a
264 .exact_chunks(chunk_size)
265 .into_iter()
266 .zip(b.exact_chunks(chunk_size).into_iter())
267 {
268 result += a_chunk
269 .iter()
270 .zip(b_chunk.iter())
271 .map(|(&x, &y)| x * y)
272 .fold(T::zero(), |acc, x| acc + x);
273 }
274
275 let remainder_len = a.len() % chunk_size;
277 if remainder_len > 0 {
278 let start_idx = a.len() - remainder_len;
279 for i in 0..remainder_len {
280 result += a[start_idx + i] * b[start_idx + i];
281 }
282 }
283
284 Ok(result)
285 }
286
287 pub fn streaming_stats<T>(values: impl Iterator<Item = T>) -> (T, T, usize)
289 where
290 T: FloatBounds,
291 {
292 let mut count = 0;
293 let mut mean = T::zero();
294 let mut m2 = T::zero();
295
296 for value in values {
297 count += 1;
298 let delta = value - mean;
299 mean += delta / T::from_usize(count).unwrap_or_else(|| T::zero());
300 let delta2 = value - mean;
301 m2 += delta * delta2;
302 }
303
304 let variance = if count > 1 {
305 m2 / T::from_usize(count - 1).unwrap_or_else(|| T::zero())
306 } else {
307 T::zero()
308 };
309
310 (mean, variance, count)
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_weighted_mean() {
        let data = array![1.0, 2.0, 3.0, 4.0];
        let weights = array![1.0, 2.0, 3.0, 4.0];

        let result = ArrayStats::weighted_mean(&data, &weights).expect("expected valid value");
        let expected = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0 + 4.0 * 4.0) / (1.0 + 2.0 + 3.0 + 4.0);

        assert_abs_diff_eq!(result, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_mean_rejects_bad_input() {
        let data = array![1.0, 2.0, 3.0];

        // Length mismatch.
        let short_weights = array![1.0, 2.0];
        assert!(matches!(
            ArrayStats::weighted_mean(&data, &short_weights),
            Err(SklearsError::ShapeMismatch { .. })
        ));

        // Weights cancelling out to zero.
        let zero_sum_weights = array![1.0, -1.0, 0.0];
        assert!(matches!(
            ArrayStats::weighted_mean(&data, &zero_sum_weights),
            Err(SklearsError::InvalidInput(_))
        ));
    }

    #[test]
    fn test_percentile() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let median = ArrayStats::percentile(&data, 50.0).expect("expected valid value");
        assert_abs_diff_eq!(median, 3.0, epsilon = 1e-10);

        let q25 = ArrayStats::percentile(&data, 25.0).expect("expected valid value");
        assert_abs_diff_eq!(q25, 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_percentile_bounds_and_interpolation() {
        // Unsorted on purpose: the function must sort internally.
        let data = array![4.0, 1.0, 3.0, 2.0];

        let lowest = ArrayStats::percentile(&data, 0.0).expect("expected valid value");
        assert_abs_diff_eq!(lowest, 1.0, epsilon = 1e-10);

        let highest = ArrayStats::percentile(&data, 100.0).expect("expected valid value");
        assert_abs_diff_eq!(highest, 4.0, epsilon = 1e-10);

        // Rank 1.5 lies halfway between 2.0 and 3.0.
        let median = ArrayStats::percentile(&data, 50.0).expect("expected valid value");
        assert_abs_diff_eq!(median, 2.5, epsilon = 1e-10);
    }

    #[test]
    fn test_percentile_rejects_bad_input() {
        let data = array![1.0, 2.0, 3.0];
        assert!(ArrayStats::percentile(&data, -1.0).is_err());
        assert!(ArrayStats::percentile(&data, 101.0).is_err());

        let empty = Array1::<f64>::zeros(0);
        assert!(ArrayStats::percentile(&empty, 50.0).is_err());
    }

    #[test]
    fn test_chunked_dot() {
        let a = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = array![2.0, 3.0, 4.0, 5.0, 6.0];

        let result = MemoryOps::chunked_dot(&a, &b, Some(2)).expect("expected valid value");
        let expected: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();

        assert_abs_diff_eq!(result, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_chunked_dot_chunk_layouts() {
        let a = array![1.0, 2.0, 3.0, 4.0];
        let b = array![4.0, 3.0, 2.0, 1.0];
        let expected = 1.0 * 4.0 + 2.0 * 3.0 + 3.0 * 2.0 + 4.0 * 1.0;

        // Exact multiple of the chunk size, chunk larger than the input, and the default.
        for chunk in [Some(2), Some(16), None] {
            let result = MemoryOps::chunked_dot(&a, &b, chunk).expect("expected valid value");
            assert_abs_diff_eq!(result, expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_chunked_dot_rejects_length_mismatch() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![1.0, 2.0];

        assert!(matches!(
            MemoryOps::chunked_dot(&a, &b, Some(2)),
            Err(SklearsError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_streaming_stats() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (mean, variance, count) = MemoryOps::streaming_stats(values.into_iter());

        assert_eq!(count, 5);
        assert_abs_diff_eq!(mean, 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(variance, 2.5, epsilon = 1e-10);
    }

    #[test]
    fn test_streaming_stats_small_inputs() {
        let (mean, variance, count) = MemoryOps::streaming_stats(std::iter::empty::<f64>());
        assert_eq!(count, 0);
        assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(variance, 0.0, epsilon = 1e-10);

        let (mean, variance, count) = MemoryOps::streaming_stats(std::iter::once(7.0));
        assert_eq!(count, 1);
        assert_abs_diff_eq!(mean, 7.0, epsilon = 1e-10);
        assert_abs_diff_eq!(variance, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_robust_covariance() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let cov = ArrayStats::robust_covariance(&data, None).expect("expected valid value");

        assert_eq!(cov.dim(), (2, 2));
        // Both columns deviate by (-2, 0, 2) from their mean: 8 / (3 - 1) = 4.
        assert_abs_diff_eq!(cov[[0, 0]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[1, 1]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[0, 1]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[0, 1]], cov[[1, 0]], epsilon = 1e-10);
    }

    #[test]
    fn test_robust_covariance_with_shrinkage() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let cov = ArrayStats::robust_covariance(&data, Some(0.5)).expect("expected valid value");

        // 0.5 * [[4, 4], [4, 4]] + 0.5 * (8 / 2) * I = [[4, 2], [2, 4]].
        assert_abs_diff_eq!(cov[[0, 0]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[1, 1]], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cov[[1, 0]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_robust_covariance_needs_two_samples() {
        let data = array![[1.0, 2.0]];
        assert!(ArrayStats::robust_covariance(&data, None).is_err());
    }

    #[test]
    fn test_condition_number_diagonal() {
        let matrix = array![[2.0, 0.0], [0.0, -8.0]];
        let cond = MatrixOps::condition_number(&matrix).expect("expected valid value");
        assert_abs_diff_eq!(cond, 4.0, epsilon = 1e-10);

        let singular = array![[1.0, 0.0], [0.0, 0.0]];
        let cond = MatrixOps::condition_number(&singular).expect("expected valid value");
        assert!(cond.is_infinite());

        let rectangular = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        assert!(MatrixOps::condition_number(&rectangular).is_err());
    }

    #[test]
    fn test_rank_diagonal() {
        let matrix = array![[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]];
        assert_eq!(MatrixOps::rank(&matrix, None), 2);

        // An explicit tolerance discards the small diagonal entry.
        let scaled = array![[1.0, 0.0], [0.0, 0.1]];
        assert_eq!(MatrixOps::rank(&scaled, Some(0.5)), 1);
    }

    #[test]
    fn test_pinv_diagonal() {
        let matrix = array![[2.0, 0.0], [0.0, 4.0]];
        let inverse = MatrixOps::pinv(&matrix, None).expect("expected valid value");

        assert_eq!(inverse.dim(), (2, 2));
        assert_abs_diff_eq!(inverse[[0, 0]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(inverse[[1, 1]], 0.25, epsilon = 1e-10);
        assert_abs_diff_eq!(inverse[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(inverse[[1, 0]], 0.0, epsilon = 1e-10);
    }
}