// scirs2/stats/functions.rs
//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict};

use scirs2_numpy::{PyArray1, PyArrayMethods};
use scirs2_stats::tests::ttest::{ttest_1samp, ttest_ind, Alternative};
use scirs2_stats::{covariance_simd, pearsonr};
11
12/// Compute descriptive statistics - uses optimized implementations
13/// Returns dict with mean, std, var, min, max, median, count
14#[pyfunction]
15pub fn describe_py(py: Python, data: &Bound<'_, PyArray1<f64>>) -> PyResult<Py<PyAny>> {
16    let binding = data.readonly();
17    let arr = binding.as_array();
18    let n = arr.len();
19    if n == 0 {
20        return Err(PyRuntimeError::new_err("Empty array provided"));
21    }
22    let slice = arr.as_slice().expect("Operation failed");
23    let mut sum0 = 0.0f64;
24    let mut sum1 = 0.0f64;
25    let mut sum2 = 0.0f64;
26    let mut sum3 = 0.0f64;
27    let chunks = slice.chunks_exact(8);
28    let remainder = chunks.remainder();
29    for chunk in chunks {
30        sum0 += chunk[0] + chunk[4];
31        sum1 += chunk[1] + chunk[5];
32        sum2 += chunk[2] + chunk[6];
33        sum3 += chunk[3] + chunk[7];
34    }
35    let mut sum = sum0 + sum1 + sum2 + sum3;
36    for &val in remainder {
37        sum += val;
38    }
39    let mean_val = sum / n as f64;
40    let mut sq0 = 0.0f64;
41    let mut sq1 = 0.0f64;
42    let mut sq2 = 0.0f64;
43    let mut sq3 = 0.0f64;
44    let mut min_val = f64::INFINITY;
45    let mut max_val = f64::NEG_INFINITY;
46    let chunks = slice.chunks_exact(8);
47    let remainder = chunks.remainder();
48    for chunk in chunks {
49        let d0 = chunk[0] - mean_val;
50        let d1 = chunk[1] - mean_val;
51        let d2 = chunk[2] - mean_val;
52        let d3 = chunk[3] - mean_val;
53        let d4 = chunk[4] - mean_val;
54        let d5 = chunk[5] - mean_val;
55        let d6 = chunk[6] - mean_val;
56        let d7 = chunk[7] - mean_val;
57        sq0 += d0 * d0 + d4 * d4;
58        sq1 += d1 * d1 + d5 * d5;
59        sq2 += d2 * d2 + d6 * d6;
60        sq3 += d3 * d3 + d7 * d7;
61        min_val = min_val
62            .min(chunk[0])
63            .min(chunk[1])
64            .min(chunk[2])
65            .min(chunk[3])
66            .min(chunk[4])
67            .min(chunk[5])
68            .min(chunk[6])
69            .min(chunk[7]);
70        max_val = max_val
71            .max(chunk[0])
72            .max(chunk[1])
73            .max(chunk[2])
74            .max(chunk[3])
75            .max(chunk[4])
76            .max(chunk[5])
77            .max(chunk[6])
78            .max(chunk[7]);
79    }
80    for &val in remainder {
81        let d = val - mean_val;
82        sq0 += d * d;
83        min_val = min_val.min(val);
84        max_val = max_val.max(val);
85    }
86    let sq_sum = sq0 + sq1 + sq2 + sq3;
87    let var_val = if n > 1 { sq_sum / (n - 1) as f64 } else { 0.0 };
88    let std_val = var_val.sqrt();
89    let mut vec: Vec<f64> = arr.iter().cloned().collect();
90    let median_val = if n % 2 == 1 {
91        let mid = n / 2;
92        let (_, val, _) = vec.select_nth_unstable_by(mid, |a, b| {
93            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
94        });
95        *val
96    } else {
97        let mid = n / 2;
98        let (lower, val_at_mid, _) = vec.select_nth_unstable_by(mid, |a, b| {
99            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
100        });
101        let val_mid = *val_at_mid;
102        let val_mid_minus_1 = lower
103            .iter()
104            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
105            .copied()
106            .unwrap_or(val_mid);
107        (val_mid_minus_1 + val_mid) / 2.0
108    };
109    let dict = PyDict::new(py);
110    dict.set_item("mean", mean_val)?;
111    dict.set_item("std", std_val)?;
112    dict.set_item("var", var_val)?;
113    dict.set_item("min", min_val)?;
114    dict.set_item("max", max_val)?;
115    dict.set_item("median", median_val)?;
116    dict.set_item("count", n)?;
117    Ok(dict.into())
118}
119/// Calculate mean - optimized with 8-way unrolling and multiple accumulators
120#[pyfunction]
121pub fn mean_py(data: &Bound<'_, PyArray1<f64>>) -> PyResult<f64> {
122    let binding = data.readonly();
123    let arr = binding.as_array();
124    let n = arr.len();
125    if n == 0 {
126        return Err(PyRuntimeError::new_err("Empty array provided"));
127    }
128    let slice = arr.as_slice().expect("Operation failed");
129    let mut sum0 = 0.0f64;
130    let mut sum1 = 0.0f64;
131    let mut sum2 = 0.0f64;
132    let mut sum3 = 0.0f64;
133    let chunks = slice.chunks_exact(8);
134    let remainder = chunks.remainder();
135    for chunk in chunks {
136        sum0 += chunk[0] + chunk[4];
137        sum1 += chunk[1] + chunk[5];
138        sum2 += chunk[2] + chunk[6];
139        sum3 += chunk[3] + chunk[7];
140    }
141    let mut sum = sum0 + sum1 + sum2 + sum3;
142    for &val in remainder {
143        sum += val;
144    }
145    Ok(sum / n as f64)
146}
147/// Calculate standard deviation - optimized two-pass with multi-accumulator
148#[pyfunction]
149#[pyo3(signature = (data, ddof = 0))]
150pub fn std_py(data: &Bound<'_, PyArray1<f64>>, ddof: usize) -> PyResult<f64> {
151    let binding = data.readonly();
152    let arr = binding.as_array();
153    let n = arr.len();
154    if n == 0 {
155        return Err(PyRuntimeError::new_err("Empty array provided"));
156    }
157    if n <= ddof {
158        return Err(PyRuntimeError::new_err(
159            "Not enough data points for given ddof",
160        ));
161    }
162    let slice = arr.as_slice().expect("Operation failed");
163    let mut sum0 = 0.0f64;
164    let mut sum1 = 0.0f64;
165    let mut sum2 = 0.0f64;
166    let mut sum3 = 0.0f64;
167    let chunks = slice.chunks_exact(8);
168    let remainder = chunks.remainder();
169    for chunk in chunks {
170        sum0 += chunk[0] + chunk[4];
171        sum1 += chunk[1] + chunk[5];
172        sum2 += chunk[2] + chunk[6];
173        sum3 += chunk[3] + chunk[7];
174    }
175    let mut sum = sum0 + sum1 + sum2 + sum3;
176    for &val in remainder {
177        sum += val;
178    }
179    let mean = sum / n as f64;
180    let mut sq0 = 0.0f64;
181    let mut sq1 = 0.0f64;
182    let mut sq2 = 0.0f64;
183    let mut sq3 = 0.0f64;
184    let chunks = slice.chunks_exact(8);
185    let remainder = chunks.remainder();
186    for chunk in chunks {
187        let d0 = chunk[0] - mean;
188        let d1 = chunk[1] - mean;
189        let d2 = chunk[2] - mean;
190        let d3 = chunk[3] - mean;
191        let d4 = chunk[4] - mean;
192        let d5 = chunk[5] - mean;
193        let d6 = chunk[6] - mean;
194        let d7 = chunk[7] - mean;
195        sq0 += d0 * d0 + d4 * d4;
196        sq1 += d1 * d1 + d5 * d5;
197        sq2 += d2 * d2 + d6 * d6;
198        sq3 += d3 * d3 + d7 * d7;
199    }
200    let mut sq_sum = sq0 + sq1 + sq2 + sq3;
201    for &val in remainder {
202        let d = val - mean;
203        sq_sum += d * d;
204    }
205    let variance = sq_sum / (n - ddof) as f64;
206    Ok(variance.sqrt())
207}
208/// Calculate variance - optimized two-pass with multi-accumulator
209#[pyfunction]
210#[pyo3(signature = (data, ddof = 0))]
211pub fn var_py(data: &Bound<'_, PyArray1<f64>>, ddof: usize) -> PyResult<f64> {
212    let binding = data.readonly();
213    let arr = binding.as_array();
214    let n = arr.len();
215    if n == 0 {
216        return Err(PyRuntimeError::new_err("Empty array provided"));
217    }
218    if n <= ddof {
219        return Err(PyRuntimeError::new_err(
220            "Not enough data points for given ddof",
221        ));
222    }
223    let slice = arr.as_slice().expect("Operation failed");
224    let mut sum0 = 0.0f64;
225    let mut sum1 = 0.0f64;
226    let mut sum2 = 0.0f64;
227    let mut sum3 = 0.0f64;
228    let chunks = slice.chunks_exact(8);
229    let remainder = chunks.remainder();
230    for chunk in chunks {
231        sum0 += chunk[0] + chunk[4];
232        sum1 += chunk[1] + chunk[5];
233        sum2 += chunk[2] + chunk[6];
234        sum3 += chunk[3] + chunk[7];
235    }
236    let mut sum = sum0 + sum1 + sum2 + sum3;
237    for &val in remainder {
238        sum += val;
239    }
240    let mean = sum / n as f64;
241    let mut sq0 = 0.0f64;
242    let mut sq1 = 0.0f64;
243    let mut sq2 = 0.0f64;
244    let mut sq3 = 0.0f64;
245    let chunks = slice.chunks_exact(8);
246    let remainder = chunks.remainder();
247    for chunk in chunks {
248        let d0 = chunk[0] - mean;
249        let d1 = chunk[1] - mean;
250        let d2 = chunk[2] - mean;
251        let d3 = chunk[3] - mean;
252        let d4 = chunk[4] - mean;
253        let d5 = chunk[5] - mean;
254        let d6 = chunk[6] - mean;
255        let d7 = chunk[7] - mean;
256        sq0 += d0 * d0 + d4 * d4;
257        sq1 += d1 * d1 + d5 * d5;
258        sq2 += d2 * d2 + d6 * d6;
259        sq3 += d3 * d3 + d7 * d7;
260    }
261    let mut sq_sum = sq0 + sq1 + sq2 + sq3;
262    for &val in remainder {
263        let d = val - mean;
264        sq_sum += d * d;
265    }
266    Ok(sq_sum / (n - ddof) as f64)
267}
268/// Calculate percentile using optimized partial sort
269/// q: percentile value (0-100)
270#[pyfunction]
271pub fn percentile_py(data: &Bound<'_, PyArray1<f64>>, q: f64) -> PyResult<f64> {
272    let binding = data.readonly();
273    let arr = binding.as_array();
274    let n = arr.len();
275    if n == 0 {
276        return Err(PyRuntimeError::new_err("Empty array provided"));
277    }
278    if !(0.0..=100.0).contains(&q) {
279        return Err(PyRuntimeError::new_err(
280            "Percentile must be between 0 and 100",
281        ));
282    }
283    let mut vec: Vec<f64> = arr.iter().cloned().collect();
284    let q_norm = q / 100.0;
285    let virtual_index = q_norm * (n - 1) as f64;
286    let i = virtual_index.floor() as usize;
287    let fraction = virtual_index - i as f64;
288    if fraction == 0.0 || i >= n - 1 {
289        let idx = i.min(n - 1);
290        let (_, val, _) = vec.select_nth_unstable_by(idx, |a, b| {
291            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
292        });
293        Ok(*val)
294    } else {
295        let (lower, val_i1, _) = vec.select_nth_unstable_by(i + 1, |a, b| {
296            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
297        });
298        let val_upper = *val_i1;
299        let val_lower = lower
300            .iter()
301            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
302            .copied()
303            .unwrap_or(val_upper);
304        Ok(val_lower + fraction * (val_upper - val_lower))
305    }
306}
307/// Calculate correlation coefficient
308#[pyfunction]
309pub fn correlation_py(x: &Bound<'_, PyArray1<f64>>, y: &Bound<'_, PyArray1<f64>>) -> PyResult<f64> {
310    let x_binding = x.readonly();
311    let x_arr = x_binding.as_array();
312    let y_binding = y.readonly();
313    let y_arr = y_binding.as_array();
314    let (r, _p) = pearsonr(&x_arr, &y_arr, "two-sided")
315        .map_err(|e| PyRuntimeError::new_err(format!("Correlation failed: {}", e)))?;
316    Ok(r)
317}
318/// Calculate covariance
319#[pyfunction]
320#[pyo3(signature = (x, y, ddof = 1))]
321pub fn covariance_py(
322    x: &Bound<'_, PyArray1<f64>>,
323    y: &Bound<'_, PyArray1<f64>>,
324    ddof: usize,
325) -> PyResult<f64> {
326    let x_binding = x.readonly();
327    let x_arr = x_binding.as_array();
328    let y_binding = y.readonly();
329    let y_arr = y_binding.as_array();
330    covariance_simd(&x_arr, &y_arr, ddof)
331        .map_err(|e| PyRuntimeError::new_err(format!("Covariance failed: {}", e)))
332}
333/// Calculate median using optimized partial sort (O(n) instead of O(n log n))
334#[pyfunction]
335pub fn median_py(data: &Bound<'_, PyArray1<f64>>) -> PyResult<f64> {
336    let binding = data.readonly();
337    let arr = binding.as_array();
338    let n = arr.len();
339    if n == 0 {
340        return Err(PyRuntimeError::new_err("Empty array provided"));
341    }
342    let mut vec: Vec<f64> = arr.iter().cloned().collect();
343    if n % 2 == 1 {
344        let mid = n / 2;
345        let (_, median_val, _) = vec.select_nth_unstable_by(mid, |a, b| {
346            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
347        });
348        Ok(*median_val)
349    } else {
350        let mid = n / 2;
351        let (lower, val_at_mid, _) = vec.select_nth_unstable_by(mid, |a, b| {
352            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
353        });
354        let val_mid = *val_at_mid;
355        let val_mid_minus_1 = lower
356            .iter()
357            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
358            .copied()
359            .unwrap_or(val_mid);
360        Ok((val_mid_minus_1 + val_mid) / 2.0)
361    }
362}
363/// Calculate interquartile range (IQR) using optimized partial sort
364#[pyfunction]
365pub fn iqr_py(data: &Bound<'_, PyArray1<f64>>) -> PyResult<f64> {
366    let binding = data.readonly();
367    let arr = binding.as_array();
368    let n = arr.len();
369    if n == 0 {
370        return Err(PyRuntimeError::new_err("Empty array provided"));
371    }
372    let mut vec: Vec<f64> = arr.iter().cloned().collect();
373    let get_percentile = |vec: &mut [f64], q: f64| -> f64 {
374        let virtual_index = q * (n - 1) as f64;
375        let i = virtual_index.floor() as usize;
376        let fraction = virtual_index - i as f64;
377        if fraction == 0.0 || i >= n - 1 {
378            let idx = i.min(n - 1);
379            let (_, val, _) = vec.select_nth_unstable_by(idx, |a, b| {
380                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
381            });
382            *val
383        } else {
384            let (lower, val_i1, _) = vec.select_nth_unstable_by(i + 1, |a, b| {
385                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
386            });
387            let val_upper = *val_i1;
388            let val_lower = lower
389                .iter()
390                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
391                .copied()
392                .unwrap_or(val_upper);
393            val_lower + fraction * (val_upper - val_lower)
394        }
395    };
396    let q75 = get_percentile(&mut vec, 0.75);
397    let mut vec2: Vec<f64> = arr.iter().cloned().collect();
398    let q25 = get_percentile(&mut vec2, 0.25);
399    Ok(q75 - q25)
400}
401/// One-sample t-test
402///
403/// Test whether the mean of a sample is different from a given value.
404///
405/// Parameters:
406/// - data: Input data
407/// - popmean: Population mean for null hypothesis
408/// - alternative: "two-sided", "less", or "greater"
409#[pyfunction]
410#[pyo3(signature = (data, popmean, alternative = "two-sided"))]
411pub fn ttest_1samp_py(
412    py: Python,
413    data: &Bound<'_, PyArray1<f64>>,
414    popmean: f64,
415    alternative: &str,
416) -> PyResult<Py<PyAny>> {
417    let binding = data.readonly();
418    let arr = binding.as_array();
419    let alt = match alternative.to_lowercase().as_str() {
420        "two-sided" | "two_sided" => Alternative::TwoSided,
421        "less" => Alternative::Less,
422        "greater" => Alternative::Greater,
423        _ => {
424            return Err(PyRuntimeError::new_err(format!(
425                "Invalid alternative: {}. Use 'two-sided', 'less', or 'greater'",
426                alternative
427            )));
428        }
429    };
430    let result = ttest_1samp(&arr.view(), popmean, alt, "omit")
431        .map_err(|e| PyRuntimeError::new_err(format!("t-test failed: {}", e)))?;
432    let dict = PyDict::new(py);
433    dict.set_item("statistic", result.statistic)?;
434    dict.set_item("pvalue", result.pvalue)?;
435    dict.set_item("df", result.df)?;
436    Ok(dict.into())
437}
438/// Two-sample independent t-test
439///
440/// Test whether two independent samples have different means.
441///
442/// Parameters:
443/// - a: First sample
444/// - b: Second sample
445/// - equal_var: If true, perform standard t-test assuming equal variance
446/// - alternative: "two-sided", "less", or "greater"
447#[pyfunction]
448#[pyo3(signature = (a, b, equal_var = true, alternative = "two-sided"))]
449pub fn ttest_ind_py(
450    py: Python,
451    a: &Bound<'_, PyArray1<f64>>,
452    b: &Bound<'_, PyArray1<f64>>,
453    equal_var: bool,
454    alternative: &str,
455) -> PyResult<Py<PyAny>> {
456    let a_binding = a.readonly();
457    let a_arr = a_binding.as_array();
458    let b_binding = b.readonly();
459    let b_arr = b_binding.as_array();
460    let alt = match alternative.to_lowercase().as_str() {
461        "two-sided" | "two_sided" => Alternative::TwoSided,
462        "less" => Alternative::Less,
463        "greater" => Alternative::Greater,
464        _ => {
465            return Err(PyRuntimeError::new_err(format!(
466                "Invalid alternative: {}. Use 'two-sided', 'less', or 'greater'",
467                alternative
468            )));
469        }
470    };
471    let result = ttest_ind(&a_arr.view(), &b_arr.view(), equal_var, alt, "omit")
472        .map_err(|e| PyRuntimeError::new_err(format!("t-test failed: {}", e)))?;
473    let dict = PyDict::new(py);
474    dict.set_item("statistic", result.statistic)?;
475    dict.set_item("pvalue", result.pvalue)?;
476    dict.set_item("df", result.df)?;
477    Ok(dict.into())
478}