//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict};
use scirs2_numpy::{PyArray1, PyArrayMethods};
use scirs2_stats::tests::ttest::{ttest_1samp, ttest_ind, Alternative};
use scirs2_stats::{covariance_simd, pearsonr};

/// Compute descriptive statistics - uses optimized implementations
/// Returns dict with mean, std, var, min, max, median, count
///
/// Notes:
/// - `std`/`var` are the *sample* statistics (divisor `n - 1`); `var` is 0.0 when n == 1.
/// - `median` uses O(n) expected-time selection rather than a full O(n log n) sort.
/// - Errors on an empty array or a non-contiguous array.
#[pyfunction]
pub fn describe_py(py: Python, data: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyAny>> {
    // Read-only borrow of the NumPy buffer for the duration of the call.
    let binding = data.readonly();
    let arr = binding.as_array();
    let n = arr.len();
    if n == 0 {
        return Err(PyRuntimeError::new_err("Empty array provided"));
    }
    // Contiguity is required so the data can be walked as a plain &[f64].
    let slice = arr
        .as_slice()
        .ok_or_else(|| PyRuntimeError::new_err("Array is not contiguous in memory"))?;
    // Pass 1: sum. Four independent accumulators over 8-wide chunks keep
    // several FP additions in flight instead of one serial dependency chain;
    // chunks_exact lets the compiler drop per-element bounds checks.
    let mut sum0 = 0.0f64;
    let mut sum1 = 0.0f64;
    let mut sum2 = 0.0f64;
    let mut sum3 = 0.0f64;
    let chunks = slice.chunks_exact(8);
    let remainder = chunks.remainder();
    for chunk in chunks {
        sum0 += chunk[0] + chunk[4];
        sum1 += chunk[1] + chunk[5];
        sum2 += chunk[2] + chunk[6];
        sum3 += chunk[3] + chunk[7];
    }
    let mut sum = sum0 + sum1 + sum2 + sum3;
    for &val in remainder {
        sum += val;
    }
    let mean_val = sum / n as f64;
    // Pass 2: squared deviations (for variance) and min/max in a single sweep,
    // using the same 4-accumulator / 8-wide layout as pass 1.
    let mut sq0 = 0.0f64;
    let mut sq1 = 0.0f64;
    let mut sq2 = 0.0f64;
    let mut sq3 = 0.0f64;
    let mut min_val = f64::INFINITY;
    let mut max_val = f64::NEG_INFINITY;
    let chunks = slice.chunks_exact(8);
    let remainder = chunks.remainder();
    for chunk in chunks {
        let d0 = chunk[0] - mean_val;
        let d1 = chunk[1] - mean_val;
        let d2 = chunk[2] - mean_val;
        let d3 = chunk[3] - mean_val;
        let d4 = chunk[4] - mean_val;
        let d5 = chunk[5] - mean_val;
        let d6 = chunk[6] - mean_val;
        let d7 = chunk[7] - mean_val;
        sq0 += d0 * d0 + d4 * d4;
        sq1 += d1 * d1 + d5 * d5;
        sq2 += d2 * d2 + d6 * d6;
        sq3 += d3 * d3 + d7 * d7;
        min_val = min_val
            .min(chunk[0])
            .min(chunk[1])
            .min(chunk[2])
            .min(chunk[3])
            .min(chunk[4])
            .min(chunk[5])
            .min(chunk[6])
            .min(chunk[7]);
        max_val = max_val
            .max(chunk[0])
            .max(chunk[1])
            .max(chunk[2])
            .max(chunk[3])
            .max(chunk[4])
            .max(chunk[5])
            .max(chunk[6])
            .max(chunk[7]);
    }
    // Tail elements (< 8) fold into the first accumulator.
    for &val in remainder {
        let d = val - mean_val;
        sq0 += d * d;
        min_val = min_val.min(val);
        max_val = max_val.max(val);
    }
    let sq_sum = sq0 + sq1 + sq2 + sq3;
    // Sample variance (divisor n - 1); defined as 0.0 for a single element.
    let var_val = if n > 1 { sq_sum / (n - 1) as f64 } else { 0.0 };
    let std_val = var_val.sqrt();
    // Median via quickselect on a scratch copy (the input must stay untouched).
    // partial_cmp falls back to Equal for NaN, so comparisons never panic.
    let mut vec: Vec<f64> = arr.iter().cloned().collect();
    let median_val = if n % 2 == 1 {
        let mid = n / 2;
        let (_, val, _) = vec.select_nth_unstable_by(mid, |a, b| {
            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
        });
        *val
    } else {
        // Even length: the mid-th order statistic plus the maximum of the
        // lower partition (i.e. the (mid-1)-th order statistic), averaged.
        let mid = n / 2;
        let (lower, val_at_mid, _) = vec.select_nth_unstable_by(mid, |a, b| {
            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
        });
        let val_mid = *val_at_mid;
        // `lower` is non-empty here (mid >= 1 for even n >= 2), so the
        // unwrap_or fallback is only a defensive default.
        let val_mid_minus_1 = lower
            .iter()
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .copied()
            .unwrap_or(val_mid);
        (val_mid_minus_1 + val_mid) / 2.0
    };
    let dict = PyDict::new(py);
    dict.set_item("mean", mean_val)?;
    dict.set_item("std", std_val)?;
    dict.set_item("var", var_val)?;
    dict.set_item("min", min_val)?;
    dict.set_item("max", max_val)?;
    dict.set_item("median", median_val)?;
    dict.set_item("count", n)?;
    Ok(dict.into())
}
121/// Calculate mean - optimized with 8-way unrolling and multiple accumulators
122#[pyfunction]
123pub fn mean_py(data: &Bound<'_, PyArray1<f64>>) -> PyResult<f64> {
124    let binding = data.readonly();
125    let arr = binding.as_array();
126    let n = arr.len();
127    if n == 0 {
128        return Err(PyRuntimeError::new_err("Empty array provided"));
129    }
130    let slice = arr
131        .as_slice()
132        .ok_or_else(|| PyRuntimeError::new_err("Array is not contiguous in memory"))?;
133    let mut sum0 = 0.0f64;
134    let mut sum1 = 0.0f64;
135    let mut sum2 = 0.0f64;
136    let mut sum3 = 0.0f64;
137    let chunks = slice.chunks_exact(8);
138    let remainder = chunks.remainder();
139    for chunk in chunks {
140        sum0 += chunk[0] + chunk[4];
141        sum1 += chunk[1] + chunk[5];
142        sum2 += chunk[2] + chunk[6];
143        sum3 += chunk[3] + chunk[7];
144    }
145    let mut sum = sum0 + sum1 + sum2 + sum3;
146    for &val in remainder {
147        sum += val;
148    }
149    Ok(sum / n as f64)
150}
/// Calculate standard deviation - optimized two-pass with multi-accumulator
///
/// `ddof` is the delta degrees of freedom: the divisor is `n - ddof`
/// (ddof = 0 gives the population std, ddof = 1 the sample std).
///
/// Errors on an empty array, `n <= ddof`, or a non-contiguous array.
#[pyfunction]
#[pyo3(signature = (data, ddof = 0))]
pub fn std_py(data: &Bound<'_, PyArray1<f64>>, ddof: usize) -> PyResult<f64> {
    let binding = data.readonly();
    let arr = binding.as_array();
    let n = arr.len();
    if n == 0 {
        return Err(PyRuntimeError::new_err("Empty array provided"));
    }
    // Guard against a zero or negative divisor in `n - ddof` below.
    if n <= ddof {
        return Err(PyRuntimeError::new_err(
            "Not enough data points for given ddof",
        ));
    }
    let slice = arr
        .as_slice()
        .ok_or_else(|| PyRuntimeError::new_err("Array is not contiguous in memory"))?;
    // Pass 1: mean. Four accumulators over 8-wide chunks break the serial
    // FP-add dependency chain; chunks_exact removes per-element bounds checks.
    let mut sum0 = 0.0f64;
    let mut sum1 = 0.0f64;
    let mut sum2 = 0.0f64;
    let mut sum3 = 0.0f64;
    let chunks = slice.chunks_exact(8);
    let remainder = chunks.remainder();
    for chunk in chunks {
        sum0 += chunk[0] + chunk[4];
        sum1 += chunk[1] + chunk[5];
        sum2 += chunk[2] + chunk[6];
        sum3 += chunk[3] + chunk[7];
    }
    let mut sum = sum0 + sum1 + sum2 + sum3;
    for &val in remainder {
        sum += val;
    }
    let mean = sum / n as f64;
    // Pass 2: sum of squared deviations from the mean, same accumulator layout.
    let mut sq0 = 0.0f64;
    let mut sq1 = 0.0f64;
    let mut sq2 = 0.0f64;
    let mut sq3 = 0.0f64;
    let chunks = slice.chunks_exact(8);
    let remainder = chunks.remainder();
    for chunk in chunks {
        let d0 = chunk[0] - mean;
        let d1 = chunk[1] - mean;
        let d2 = chunk[2] - mean;
        let d3 = chunk[3] - mean;
        let d4 = chunk[4] - mean;
        let d5 = chunk[5] - mean;
        let d6 = chunk[6] - mean;
        let d7 = chunk[7] - mean;
        sq0 += d0 * d0 + d4 * d4;
        sq1 += d1 * d1 + d5 * d5;
        sq2 += d2 * d2 + d6 * d6;
        sq3 += d3 * d3 + d7 * d7;
    }
    let mut sq_sum = sq0 + sq1 + sq2 + sq3;
    // Tail elements (< 8) are folded in serially.
    for &val in remainder {
        let d = val - mean;
        sq_sum += d * d;
    }
    let variance = sq_sum / (n - ddof) as f64;
    Ok(variance.sqrt())
}
214/// Calculate variance - optimized two-pass with multi-accumulator
215#[pyfunction]
216#[pyo3(signature = (data, ddof = 0))]
217pub fn var_py(data: &Bound<'_, PyArray1<f64>>, ddof: usize) -> PyResult<f64> {
218    let binding = data.readonly();
219    let arr = binding.as_array();
220    let n = arr.len();
221    if n == 0 {
222        return Err(PyRuntimeError::new_err("Empty array provided"));
223    }
224    if n <= ddof {
225        return Err(PyRuntimeError::new_err(
226            "Not enough data points for given ddof",
227        ));
228    }
229    let slice = arr
230        .as_slice()
231        .ok_or_else(|| PyRuntimeError::new_err("Array is not contiguous in memory"))?;
232    let mut sum0 = 0.0f64;
233    let mut sum1 = 0.0f64;
234    let mut sum2 = 0.0f64;
235    let mut sum3 = 0.0f64;
236    let chunks = slice.chunks_exact(8);
237    let remainder = chunks.remainder();
238    for chunk in chunks {
239        sum0 += chunk[0] + chunk[4];
240        sum1 += chunk[1] + chunk[5];
241        sum2 += chunk[2] + chunk[6];
242        sum3 += chunk[3] + chunk[7];
243    }
244    let mut sum = sum0 + sum1 + sum2 + sum3;
245    for &val in remainder {
246        sum += val;
247    }
248    let mean = sum / n as f64;
249    let mut sq0 = 0.0f64;
250    let mut sq1 = 0.0f64;
251    let mut sq2 = 0.0f64;
252    let mut sq3 = 0.0f64;
253    let chunks = slice.chunks_exact(8);
254    let remainder = chunks.remainder();
255    for chunk in chunks {
256        let d0 = chunk[0] - mean;
257        let d1 = chunk[1] - mean;
258        let d2 = chunk[2] - mean;
259        let d3 = chunk[3] - mean;
260        let d4 = chunk[4] - mean;
261        let d5 = chunk[5] - mean;
262        let d6 = chunk[6] - mean;
263        let d7 = chunk[7] - mean;
264        sq0 += d0 * d0 + d4 * d4;
265        sq1 += d1 * d1 + d5 * d5;
266        sq2 += d2 * d2 + d6 * d6;
267        sq3 += d3 * d3 + d7 * d7;
268    }
269    let mut sq_sum = sq0 + sq1 + sq2 + sq3;
270    for &val in remainder {
271        let d = val - mean;
272        sq_sum += d * d;
273    }
274    Ok(sq_sum / (n - ddof) as f64)
275}
276/// Calculate percentile using optimized partial sort
277/// q: percentile value (0-100)
278#[pyfunction]
279pub fn percentile_py(data: &Bound<'_, PyArray1<f64>>, q: f64) -> PyResult<f64> {
280    let binding = data.readonly();
281    let arr = binding.as_array();
282    let n = arr.len();
283    if n == 0 {
284        return Err(PyRuntimeError::new_err("Empty array provided"));
285    }
286    if !(0.0..=100.0).contains(&q) {
287        return Err(PyRuntimeError::new_err(
288            "Percentile must be between 0 and 100",
289        ));
290    }
291    let mut vec: Vec<f64> = arr.iter().cloned().collect();
292    let q_norm = q / 100.0;
293    let virtual_index = q_norm * (n - 1) as f64;
294    let i = virtual_index.floor() as usize;
295    let fraction = virtual_index - i as f64;
296    if fraction == 0.0 || i >= n - 1 {
297        let idx = i.min(n - 1);
298        let (_, val, _) = vec.select_nth_unstable_by(idx, |a, b| {
299            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
300        });
301        Ok(*val)
302    } else {
303        let (lower, val_i1, _) = vec.select_nth_unstable_by(i + 1, |a, b| {
304            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
305        });
306        let val_upper = *val_i1;
307        let val_lower = lower
308            .iter()
309            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
310            .copied()
311            .unwrap_or(val_upper);
312        Ok(val_lower + fraction * (val_upper - val_lower))
313    }
314}
315/// Calculate correlation coefficient
316#[pyfunction]
317pub fn correlation_py(x: &Bound<'_, PyArray1<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<f64> {
318    let x_binding = x.readonly();
319    let x_arr = x_binding.as_array();
320    let y_binding = y.readonly();
321    let y_arr = y_binding.as_array();
322    let (r, _p) = pearsonr(&x_arr, &y_arr, "two-sided")
323        .map_err(|e| PyRuntimeError::new_err(format!("Correlation failed: {}", e)))?;
324    Ok(r)
325}
326/// Calculate covariance
327#[pyfunction]
328#[pyo3(signature = (x, y, ddof = 1))]
329pub fn covariance_py(
330    x: &Bound<'_, PyArray1<f64>>,
331    y: &Bound<'_, PyArray1<f64>>,
332    ddof: usize,
333) -> PyResult<f64> {
334    let x_binding = x.readonly();
335    let x_arr = x_binding.as_array();
336    let y_binding = y.readonly();
337    let y_arr = y_binding.as_array();
338    covariance_simd(&x_arr, &y_arr, ddof)
339        .map_err(|e| PyRuntimeError::new_err(format!("Covariance failed: {}", e)))
340}
341/// Calculate median using optimized partial sort (O(n) instead of O(n log n))
342#[pyfunction]
343pub fn median_py(data: &Bound<'_, PyArray1<f64>>) -> PyResult<f64> {
344    let binding = data.readonly();
345    let arr = binding.as_array();
346    let n = arr.len();
347    if n == 0 {
348        return Err(PyRuntimeError::new_err("Empty array provided"));
349    }
350    let mut vec: Vec<f64> = arr.iter().cloned().collect();
351    if n % 2 == 1 {
352        let mid = n / 2;
353        let (_, median_val, _) = vec.select_nth_unstable_by(mid, |a, b| {
354            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
355        });
356        Ok(*median_val)
357    } else {
358        let mid = n / 2;
359        let (lower, val_at_mid, _) = vec.select_nth_unstable_by(mid, |a, b| {
360            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
361        });
362        let val_mid = *val_at_mid;
363        let val_mid_minus_1 = lower
364            .iter()
365            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
366            .copied()
367            .unwrap_or(val_mid);
368        Ok((val_mid_minus_1 + val_mid) / 2.0)
369    }
370}
371/// Calculate interquartile range (IQR) using optimized partial sort
372#[pyfunction]
373pub fn iqr_py(data: &Bound<'_, PyArray1<f64>>) -> PyResult<f64> {
374    let binding = data.readonly();
375    let arr = binding.as_array();
376    let n = arr.len();
377    if n == 0 {
378        return Err(PyRuntimeError::new_err("Empty array provided"));
379    }
380    let mut vec: Vec<f64> = arr.iter().cloned().collect();
381    let get_percentile = |vec: &mut [f64], q: f64| -> f64 {
382        let virtual_index = q * (n - 1) as f64;
383        let i = virtual_index.floor() as usize;
384        let fraction = virtual_index - i as f64;
385        if fraction == 0.0 || i >= n - 1 {
386            let idx = i.min(n - 1);
387            let (_, val, _) = vec.select_nth_unstable_by(idx, |a, b| {
388                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
389            });
390            *val
391        } else {
392            let (lower, val_i1, _) = vec.select_nth_unstable_by(i + 1, |a, b| {
393                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
394            });
395            let val_upper = *val_i1;
396            let val_lower = lower
397                .iter()
398                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
399                .copied()
400                .unwrap_or(val_upper);
401            val_lower + fraction * (val_upper - val_lower)
402        }
403    };
404    let q75 = get_percentile(&mut vec, 0.75);
405    let mut vec2: Vec<f64> = arr.iter().cloned().collect();
406    let q25 = get_percentile(&mut vec2, 0.25);
407    Ok(q75 - q25)
408}
409/// One-sample t-test
410///
411/// Test whether the mean of a sample is different from a given value.
412///
413/// Parameters:
414/// - data: Input data
415/// - popmean: Population mean for null hypothesis
416/// - alternative: "two-sided", "less", or "greater"
417#[pyfunction]
418#[pyo3(signature = (data, popmean, alternative = "two-sided"))]
419pub fn ttest_1samp_py(
420    py: Python,
421    data: &Bound<'_, PyArray1<f64>>,
422    popmean: f64,
423    alternative: &str,
424) -> PyResult<Py<PyAny>> {
425    let binding = data.readonly();
426    let arr = binding.as_array();
427    let alt = match alternative.to_lowercase().as_str() {
428        "two-sided" | "two_sided" => Alternative::TwoSided,
429        "less" => Alternative::Less,
430        "greater" => Alternative::Greater,
431        _ => {
432            return Err(PyRuntimeError::new_err(format!(
433                "Invalid alternative: {}. Use 'two-sided', 'less', or 'greater'",
434                alternative
435            )));
436        }
437    };
438    let result = ttest_1samp(&arr.view(), popmean, alt, "omit")
439        .map_err(|e| PyRuntimeError::new_err(format!("t-test failed: {}", e)))?;
440    let dict = PyDict::new(py);
441    dict.set_item("statistic", result.statistic)?;
442    dict.set_item("pvalue", result.pvalue)?;
443    dict.set_item("df", result.df)?;
444    Ok(dict.into())
445}
446/// Two-sample independent t-test
447///
448/// Test whether two independent samples have different means.
449///
450/// Parameters:
451/// - a: First sample
452/// - b: Second sample
453/// - equal_var: If true, perform standard t-test assuming equal variance
454/// - alternative: "two-sided", "less", or "greater"
455#[pyfunction]
456#[pyo3(signature = (a, b, equal_var = true, alternative = "two-sided"))]
457pub fn ttest_ind_py(
458    py: Python,
459    a: &Bound<'_, PyArray1<f64>>,
460    b: &Bound<'_, PyArray1<f64>>,
461    equal_var: bool,
462    alternative: &str,
463) -> PyResult<Py<PyAny>> {
464    let a_binding = a.readonly();
465    let a_arr = a_binding.as_array();
466    let b_binding = b.readonly();
467    let b_arr = b_binding.as_array();
468    let alt = match alternative.to_lowercase().as_str() {
469        "two-sided" | "two_sided" => Alternative::TwoSided,
470        "less" => Alternative::Less,
471        "greater" => Alternative::Greater,
472        _ => {
473            return Err(PyRuntimeError::new_err(format!(
474                "Invalid alternative: {}. Use 'two-sided', 'less', or 'greater'",
475                alternative
476            )));
477        }
478    };
479    let result = ttest_ind(&a_arr.view(), &b_arr.view(), equal_var, alt, "omit")
480        .map_err(|e| PyRuntimeError::new_err(format!("t-test failed: {}", e)))?;
481    let dict = PyDict::new(py);
482    dict.set_item("statistic", result.statistic)?;
483    dict.set_item("pvalue", result.pvalue)?;
484    dict.set_item("df", result.df)?;
485    Ok(dict.into())
486}