1use crate::error::{Result, SklearsError};
6use crate::types::{Array1, Array2, FloatBounds};
7use scirs2_core::ndarray::Axis;
9
10pub struct ArrayStats;
12
13impl ArrayStats {
14 pub fn weighted_mean<T>(array: &Array1<T>, weights: &Array1<T>) -> Result<T>
16 where
17 T: FloatBounds,
18 {
19 if array.len() != weights.len() {
20 return Err(SklearsError::ShapeMismatch {
21 expected: format!("{}", array.len()),
22 actual: format!("{}", weights.len()),
23 });
24 }
25
26 let weight_sum = weights.sum();
27 if weight_sum == T::zero() {
28 return Err(SklearsError::InvalidInput(
29 "Weight sum cannot be zero".to_string(),
30 ));
31 }
32
33 let weighted_sum = array
34 .iter()
35 .zip(weights.iter())
36 .map(|(&x, &w)| x * w)
37 .fold(T::zero(), |acc, x| acc + x);
38
39 Ok(weighted_sum / weight_sum)
40 }
41
42 pub fn robust_covariance<T>(data: &Array2<T>, shrinkage: Option<T>) -> Result<Array2<T>>
44 where
45 T: FloatBounds + scirs2_core::ndarray::ScalarOperand,
46 {
47 let (n_samples, n_features) = data.dim();
48
49 if n_samples < 2 {
50 return Err(SklearsError::InvalidInput(
51 "Need at least 2 samples for covariance".to_string(),
52 ));
53 }
54
55 let means = data.mean_axis(Axis(0)).unwrap();
57
58 let centered = data - &means.insert_axis(Axis(0));
60
61 let cov_empirical = centered.t().dot(¢ered) / T::from_usize(n_samples - 1).unwrap();
63
64 if let Some(shrink) = shrinkage {
66 let identity = Array2::<T>::eye(n_features);
67 let trace = (0..n_features)
68 .map(|i| cov_empirical[[i, i]])
69 .fold(T::zero(), |acc, x| acc + x);
70 let target = identity * (trace / T::from_usize(n_features).unwrap());
71
72 Ok(&cov_empirical * (T::one() - shrink) + &target * shrink)
73 } else {
74 Ok(cov_empirical)
75 }
76 }
77
78 pub fn percentile<T>(array: &Array1<T>, q: T) -> Result<T>
80 where
81 T: FloatBounds + PartialOrd,
82 {
83 if array.is_empty() {
84 return Err(SklearsError::InvalidInput(
85 "Array cannot be empty".to_string(),
86 ));
87 }
88
89 if q < T::zero() || q > T::from_f64(100.0).unwrap() {
90 return Err(SklearsError::InvalidInput(
91 "Percentile must be between 0 and 100".to_string(),
92 ));
93 }
94
95 let mut sorted = array.to_vec();
96 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
97
98 let n = sorted.len();
99 let index = q / T::from_f64(100.0).unwrap() * T::from_usize(n - 1).unwrap();
100 let lower_idx = index.floor().to_usize().unwrap();
101 let upper_idx = index.ceil().to_usize().unwrap().min(n - 1);
102
103 if lower_idx == upper_idx {
104 Ok(sorted[lower_idx])
105 } else {
106 let lower_val = sorted[lower_idx];
107 let upper_val = sorted[upper_idx];
108 let weight = index - T::from_usize(lower_idx).unwrap();
109 Ok(lower_val * (T::one() - weight) + upper_val * weight)
110 }
111 }
112}
113
114pub struct MatrixOps;
116
117impl MatrixOps {
118 pub fn condition_number<T>(matrix: &Array2<T>) -> Result<T>
120 where
121 T: FloatBounds,
122 {
123 let (rows, cols) = matrix.dim();
126 if rows != cols {
127 return Err(SklearsError::InvalidInput(
128 "Matrix must be square for condition number".to_string(),
129 ));
130 }
131
132 let mut min_diag = T::infinity();
134 let mut max_diag = T::neg_infinity();
135
136 for i in 0..rows {
137 let diag_val = matrix[[i, i]].abs();
138 if diag_val < min_diag {
139 min_diag = diag_val;
140 }
141 if diag_val > max_diag {
142 max_diag = diag_val;
143 }
144 }
145
146 if min_diag == T::zero() {
147 Ok(T::infinity())
148 } else {
149 Ok(max_diag / min_diag)
150 }
151 }
152
153 pub fn rank<T>(matrix: &Array2<T>, tolerance: Option<T>) -> usize
155 where
156 T: FloatBounds,
157 {
158 let (rows, cols) = matrix.dim();
159 let tol = tolerance.unwrap_or_else(|| {
160 T::from_f64(1e-12).unwrap() * T::from_usize(rows.max(cols)).unwrap()
161 });
162
163 let min_dim = rows.min(cols);
166 let mut rank = 0;
167
168 for i in 0..min_dim {
169 if matrix[[i, i]].abs() > tol {
170 rank += 1;
171 }
172 }
173
174 rank
175 }
176
177 pub fn pinv<T>(matrix: &Array2<T>, _tolerance: Option<T>) -> Result<Array2<T>>
179 where
180 T: FloatBounds,
181 {
182 let (rows, cols) = matrix.dim();
183
184 if rows == cols {
186 if let Ok(inv) = Self::try_inverse(matrix) {
188 return Ok(inv);
189 }
190 }
191
192 let gram = if rows >= cols {
195 let at = matrix.t().to_owned();
197 let ata = at.dot(matrix);
198 let ata_inv = Self::try_inverse(&ata)?;
199 ata_inv.dot(&at)
200 } else {
201 let at = matrix.t().to_owned();
203 let aat = matrix.dot(&at);
204 let aat_inv = Self::try_inverse(&aat)?;
205 at.dot(&aat_inv)
206 };
207
208 Ok(gram)
209 }
210
211 fn try_inverse<T>(matrix: &Array2<T>) -> Result<Array2<T>>
213 where
214 T: FloatBounds,
215 {
216 let (rows, cols) = matrix.dim();
217 if rows != cols {
218 return Err(SklearsError::InvalidInput(
219 "Matrix must be square".to_string(),
220 ));
221 }
222
223 let mut inv = Array2::<T>::zeros((rows, cols));
226 for i in 0..rows {
227 let diag_val = matrix[[i, i]];
228 if diag_val.abs() < T::from_f64(1e-15).unwrap() {
229 return Err(SklearsError::InvalidInput("Matrix is singular".to_string()));
230 }
231 inv[[i, i]] = T::one() / diag_val;
232 }
233
234 Ok(inv)
235 }
236}
237
238pub struct MemoryOps;
240
241impl MemoryOps {
242 pub fn chunked_dot<T>(a: &Array1<T>, b: &Array1<T>, chunk_size: Option<usize>) -> Result<T>
244 where
245 T: FloatBounds,
246 {
247 if a.len() != b.len() {
248 return Err(SklearsError::ShapeMismatch {
249 expected: format!("{}", a.len()),
250 actual: format!("{}", b.len()),
251 });
252 }
253
254 let chunk_size = chunk_size.unwrap_or(1024);
255 let mut result = T::zero();
256
257 for (a_chunk, b_chunk) in a
258 .exact_chunks(chunk_size)
259 .into_iter()
260 .zip(b.exact_chunks(chunk_size).into_iter())
261 {
262 result += a_chunk
263 .iter()
264 .zip(b_chunk.iter())
265 .map(|(&x, &y)| x * y)
266 .fold(T::zero(), |acc, x| acc + x);
267 }
268
269 let remainder_len = a.len() % chunk_size;
271 if remainder_len > 0 {
272 let start_idx = a.len() - remainder_len;
273 for i in 0..remainder_len {
274 result += a[start_idx + i] * b[start_idx + i];
275 }
276 }
277
278 Ok(result)
279 }
280
281 pub fn streaming_stats<T>(values: impl Iterator<Item = T>) -> (T, T, usize)
283 where
284 T: FloatBounds,
285 {
286 let mut count = 0;
287 let mut mean = T::zero();
288 let mut m2 = T::zero();
289
290 for value in values {
291 count += 1;
292 let delta = value - mean;
293 mean += delta / T::from_usize(count).unwrap();
294 let delta2 = value - mean;
295 m2 += delta * delta2;
296 }
297
298 let variance = if count > 1 {
299 m2 / T::from_usize(count - 1).unwrap()
300 } else {
301 T::zero()
302 };
303
304 (mean, variance, count)
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance shared by every floating-point comparison below.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_weighted_mean() {
        let data = array![1.0, 2.0, 3.0, 4.0];
        let weights = array![1.0, 2.0, 3.0, 4.0];

        // Σ xᵢ·wᵢ / Σ wᵢ with xᵢ = wᵢ.
        let expected =
            (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0 + 4.0 * 4.0) / (1.0 + 2.0 + 3.0 + 4.0);
        let actual = ArrayStats::weighted_mean(&data, &weights).unwrap();

        assert_abs_diff_eq!(actual, expected, epsilon = EPS);
    }

    #[test]
    fn test_percentile() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        // (q, expected) pairs: the median, then the lower quartile.
        for &(q, expected) in &[(50.0, 3.0), (25.0, 2.0)] {
            let actual = ArrayStats::percentile(&data, q).unwrap();
            assert_abs_diff_eq!(actual, expected, epsilon = EPS);
        }
    }

    #[test]
    fn test_chunked_dot() {
        let a = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = array![2.0, 3.0, 4.0, 5.0, 6.0];

        // A chunk size of 2 leaves a one-element remainder to exercise the tail path.
        let expected: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
        let actual = MemoryOps::chunked_dot(&a, &b, Some(2)).unwrap();

        assert_abs_diff_eq!(actual, expected, epsilon = EPS);
    }

    #[test]
    fn test_streaming_stats() {
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let (mean, variance, count) = MemoryOps::streaming_stats(samples.into_iter());

        assert_eq!(count, 5);
        assert_abs_diff_eq!(mean, 3.0, epsilon = EPS);
        assert_abs_diff_eq!(variance, 2.5, epsilon = EPS);
    }

    #[test]
    fn test_robust_covariance() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let cov = ArrayStats::robust_covariance(&data, None).unwrap();

        assert_eq!(cov.dim(), (2, 2));
        assert!(cov[[0, 0]] > 0.0);
        assert!(cov[[1, 1]] > 0.0);
        // A covariance matrix is symmetric.
        assert_abs_diff_eq!(cov[[0, 1]], cov[[1, 0]], epsilon = EPS);
    }
}