// scirs2/stats/batch.rs
1//! Batch/vectorized Python APIs for statistics
2//!
3//! Provides batch operations that reduce FFI overhead when calling many small
4//! statistical computations from Python.
5
6use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use rayon::prelude::*;
9use scirs2_core::ndarray::Array1;
10use scirs2_stats::distributions::beta::Beta as RustBeta;
11use scirs2_stats::distributions::exponential::Exponential as RustExponential;
12use scirs2_stats::distributions::gamma::Gamma as RustGamma;
13use scirs2_stats::distributions::normal::Normal as RustNormal;
14use scirs2_stats::distributions::uniform::Uniform as RustUniform;
15use scirs2_stats::pearsonr;
16
17// ============================================================
18// Internal helpers (no PyO3 dependency)
19// ============================================================
20
/// Compute the arithmetic mean of a slice in a single pass.
///
/// Uses four independent partial sums over 8-element chunks so the
/// additions can be pipelined; returns `None` for an empty slice.
fn slice_mean(data: &[f64]) -> Option<f64> {
    if data.is_empty() {
        return None;
    }
    // Four partial sums, each fed by two lanes of every 8-wide chunk.
    let mut acc = [0.0f64; 4];
    let mut blocks = data.chunks_exact(8);
    for block in blocks.by_ref() {
        for lane in 0..4 {
            acc[lane] += block[lane] + block[lane + 4];
        }
    }
    // Combine the partials, then fold in the (< 8) leftover elements.
    let mut total = acc[0] + acc[1] + acc[2] + acc[3];
    for &x in blocks.remainder() {
        total += x;
    }
    Some(total / data.len() as f64)
}
44
45/// Compute (mean, variance_sample, std_sample) in two passes.
46fn slice_mean_var_std(data: &[f64]) -> Option<(f64, f64, f64)> {
47    if data.is_empty() {
48        return None;
49    }
50    let n = data.len();
51    let mean = slice_mean(data)?;
52    let mut sq0 = 0.0f64;
53    let mut sq1 = 0.0f64;
54    let mut sq2 = 0.0f64;
55    let mut sq3 = 0.0f64;
56    let chunks = data.chunks_exact(8);
57    let remainder = chunks.remainder();
58    for chunk in chunks {
59        let d0 = chunk[0] - mean;
60        let d1 = chunk[1] - mean;
61        let d2 = chunk[2] - mean;
62        let d3 = chunk[3] - mean;
63        let d4 = chunk[4] - mean;
64        let d5 = chunk[5] - mean;
65        let d6 = chunk[6] - mean;
66        let d7 = chunk[7] - mean;
67        sq0 += d0 * d0 + d4 * d4;
68        sq1 += d1 * d1 + d5 * d5;
69        sq2 += d2 * d2 + d6 * d6;
70        sq3 += d3 * d3 + d7 * d7;
71    }
72    let mut sq_sum = sq0 + sq1 + sq2 + sq3;
73    for &v in remainder {
74        let d = v - mean;
75        sq_sum += d * d;
76    }
77    let denom = if n > 1 { (n - 1) as f64 } else { 1.0 };
78    let var = sq_sum / denom;
79    let std = var.sqrt();
80    Some((mean, std, var))
81}
82
/// Percentile (`p` in [0, 1]) of an already-sorted slice via linear
/// interpolation between the two nearest ranks.
///
/// Precondition: `sorted` is non-empty and ascending (callers sort first).
fn sorted_percentile(sorted: &[f64], p: f64) -> f64 {
    let n = sorted.len();
    if n == 1 {
        return sorted[0];
    }
    // Fractional rank within [0, n - 1].
    let rank = p * (n - 1) as f64;
    let lo = rank.floor() as usize;
    let weight = rank - lo as f64;
    // Exact hit, or at/past the last element: no interpolation needed.
    if weight == 0.0 || lo >= n - 1 {
        return sorted[lo.min(n - 1)];
    }
    sorted[lo] + weight * (sorted[lo + 1] - sorted[lo])
}
98
99/// Full descriptive stats dict for a single slice.
100fn descriptive_stats_for_slice(
101    data: &[f64],
102) -> Result<std::collections::HashMap<String, f64>, String> {
103    let n = data.len();
104    if n == 0 {
105        return Err("Empty array".to_string());
106    }
107    let (mean, std, var) =
108        slice_mean_var_std(data).ok_or_else(|| "Failed to compute mean/std".to_string())?;
109    let min = data.iter().cloned().fold(f64::INFINITY, f64::min);
110    let max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
111
112    let mut sorted = data.to_vec();
113    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
114    let median = sorted_percentile(&sorted, 0.5);
115    let q25 = sorted_percentile(&sorted, 0.25);
116    let q75 = sorted_percentile(&sorted, 0.75);
117
118    let mut map = std::collections::HashMap::new();
119    map.insert("n".to_string(), n as f64);
120    map.insert("mean".to_string(), mean);
121    map.insert("std".to_string(), std);
122    map.insert("var".to_string(), var);
123    map.insert("min".to_string(), min);
124    map.insert("max".to_string(), max);
125    map.insert("median".to_string(), median);
126    map.insert("q25".to_string(), q25);
127    map.insert("q75".to_string(), q75);
128    Ok(map)
129}
130
131// ============================================================
132// Public batch #[pyfunction]s
133// ============================================================
134
135/// Compute mean, std, and variance in a single batch call.
136///
137/// Avoids 3 separate FFI calls by computing all three in a single pass.
138///
139/// Parameters:
140///     data: Input data array
141///
142/// Returns:
143///     Tuple (mean, std, variance) — sample variance (ddof=1)
144#[pyfunction]
145pub fn stats_summary(data: Vec<f64>) -> PyResult<(f64, f64, f64)> {
146    if data.is_empty() {
147        return Err(PyRuntimeError::new_err("Empty array provided"));
148    }
149    let (mean, std, var) = slice_mean_var_std(&data)
150        .ok_or_else(|| PyRuntimeError::new_err("Failed to compute stats"))?;
151    Ok((mean, std, var))
152}
153
154/// Batch descriptive stats for multiple arrays.
155///
156/// For each array computes: n, mean, std, var, min, max, median, q25, q75.
157///
158/// Parameters:
159///     arrays: List of data arrays
160///
161/// Returns:
162///     List of dicts, one per input array
163#[pyfunction]
164pub fn batch_descriptive_stats(
165    arrays: Vec<Vec<f64>>,
166) -> PyResult<Vec<std::collections::HashMap<String, f64>>> {
167    if arrays.is_empty() {
168        return Ok(vec![]);
169    }
170    let results: Vec<Result<std::collections::HashMap<String, f64>, String>> = arrays
171        .par_iter()
172        .map(|arr| descriptive_stats_for_slice(arr))
173        .collect();
174
175    results
176        .into_iter()
177        .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("Descriptive stats failed: {}", e))))
178        .collect()
179}
180
181/// Batch Pearson correlation matrix for a list of arrays.
182///
183/// Computes the full correlation matrix for the provided arrays.
184/// Entry [i][j] is the Pearson correlation between arrays[i] and arrays[j].
185///
186/// Parameters:
187///     arrays: List of arrays (all must have the same length)
188///
189/// Returns:
190///     Correlation matrix as Vec<Vec<f64>>
191#[pyfunction]
192pub fn batch_correlation(arrays: Vec<Vec<f64>>) -> PyResult<Vec<Vec<f64>>> {
193    let k = arrays.len();
194    if k == 0 {
195        return Ok(vec![]);
196    }
197    let n = arrays[0].len();
198    for (i, arr) in arrays.iter().enumerate() {
199        if arr.len() != n {
200            return Err(PyRuntimeError::new_err(format!(
201                "Array {} has length {} but expected {}",
202                i,
203                arr.len(),
204                n
205            )));
206        }
207        if arr.is_empty() {
208            return Err(PyRuntimeError::new_err(format!("Array {} is empty", i)));
209        }
210    }
211
212    // Build index pairs for upper triangle (including diagonal)
213    let pairs: Vec<(usize, usize)> = (0..k).flat_map(|i| (i..k).map(move |j| (i, j))).collect();
214
215    // Compute correlations in parallel for upper triangle
216    let corr_values: Vec<((usize, usize), f64)> = pairs
217        .par_iter()
218        .map(|&(i, j)| {
219            if i == j {
220                return Ok(((i, j), 1.0_f64));
221            }
222            let x_arr = Array1::from_vec(arrays[i].clone());
223            let y_arr = Array1::from_vec(arrays[j].clone());
224            pearsonr(&x_arr.view(), &y_arr.view(), "two-sided")
225                .map(|(r, _p)| ((i, j), r))
226                .map_err(|e| format!("Pearson correlation ({},{}) failed: {}", i, j, e))
227        })
228        .collect::<Vec<Result<((usize, usize), f64), String>>>()
229        .into_iter()
230        .collect::<Result<Vec<((usize, usize), f64)>, String>>()
231        .map_err(PyRuntimeError::new_err)?;
232
233    // Fill symmetric matrix
234    let mut matrix = vec![vec![0.0f64; k]; k];
235    for ((i, j), val) in corr_values {
236        matrix[i][j] = val;
237        matrix[j][i] = val;
238    }
239    Ok(matrix)
240}
241
242/// Evaluate the PDF of a named distribution at each data point.
243///
244/// Supported distributions: "normal", "exponential", "uniform", "gamma", "beta"
245///
246/// Parameters:
247///     data: Points at which to evaluate the PDF
248///     distribution: Distribution name (e.g., "normal")
249///     params: Distribution parameters
250///         - normal: [mu, sigma]
251///         - exponential: [lambda] (rate = 1/scale)
252///         - uniform: [low, high]
253///         - gamma: [shape, scale]
254///         - beta: [alpha, beta_param]
255///
256/// Returns:
257///     Vec<f64> of PDF values
258#[pyfunction]
259pub fn batch_pdf_eval(data: Vec<f64>, distribution: &str, params: Vec<f64>) -> PyResult<Vec<f64>> {
260    if data.is_empty() {
261        return Ok(vec![]);
262    }
263
264    match distribution.to_lowercase().as_str() {
265        "normal" => {
266            if params.len() < 2 {
267                return Err(PyRuntimeError::new_err(
268                    "Normal distribution requires [mu, sigma] params",
269                ));
270            }
271            let mu = params[0];
272            let sigma = params[1];
273            if sigma <= 0.0 {
274                return Err(PyRuntimeError::new_err("sigma must be positive"));
275            }
276            // Normal::new(loc, scale) where loc=mu, scale=sigma
277            let dist = RustNormal::new(mu, sigma).map_err(|e| {
278                PyRuntimeError::new_err(format!("Normal distribution failed: {}", e))
279            })?;
280            let result: Vec<f64> = data.par_iter().map(|&x| dist.pdf(x)).collect();
281            Ok(result)
282        }
283        "exponential" => {
284            if params.is_empty() {
285                return Err(PyRuntimeError::new_err(
286                    "Exponential distribution requires [lambda] params",
287                ));
288            }
289            let lambda = params[0];
290            if lambda <= 0.0 {
291                return Err(PyRuntimeError::new_err("lambda must be positive"));
292            }
293            // Exponential::new(rate, loc) where rate=lambda, loc=0
294            let dist = RustExponential::new(lambda, 0.0).map_err(|e| {
295                PyRuntimeError::new_err(format!("Exponential distribution failed: {}", e))
296            })?;
297            let result: Vec<f64> = data.par_iter().map(|&x| dist.pdf(x)).collect();
298            Ok(result)
299        }
300        "uniform" => {
301            if params.len() < 2 {
302                return Err(PyRuntimeError::new_err(
303                    "Uniform distribution requires [low, high] params",
304                ));
305            }
306            let low = params[0];
307            let high = params[1];
308            if high <= low {
309                return Err(PyRuntimeError::new_err("high must be greater than low"));
310            }
311            let dist = RustUniform::new(low, high).map_err(|e| {
312                PyRuntimeError::new_err(format!("Uniform distribution failed: {}", e))
313            })?;
314            let result: Vec<f64> = data.par_iter().map(|&x| dist.pdf(x)).collect();
315            Ok(result)
316        }
317        "gamma" => {
318            if params.len() < 2 {
319                return Err(PyRuntimeError::new_err(
320                    "Gamma distribution requires [shape, scale] params",
321                ));
322            }
323            let shape = params[0];
324            let scale = params[1];
325            if shape <= 0.0 || scale <= 0.0 {
326                return Err(PyRuntimeError::new_err("shape and scale must be positive"));
327            }
328            let dist = RustGamma::new(shape, scale, 0.0).map_err(|e| {
329                PyRuntimeError::new_err(format!("Gamma distribution failed: {}", e))
330            })?;
331            let result: Vec<f64> = data.par_iter().map(|&x| dist.pdf(x)).collect();
332            Ok(result)
333        }
334        "beta" => {
335            if params.len() < 2 {
336                return Err(PyRuntimeError::new_err(
337                    "Beta distribution requires [alpha, beta] params",
338                ));
339            }
340            let alpha = params[0];
341            let beta_param = params[1];
342            if alpha <= 0.0 || beta_param <= 0.0 {
343                return Err(PyRuntimeError::new_err("alpha and beta must be positive"));
344            }
345            let dist = RustBeta::new(alpha, beta_param, 0.0, 1.0)
346                .map_err(|e| PyRuntimeError::new_err(format!("Beta distribution failed: {}", e)))?;
347            let result: Vec<f64> = data.par_iter().map(|&x| dist.pdf(x)).collect();
348            Ok(result)
349        }
350        other => Err(PyRuntimeError::new_err(format!(
351            "Unknown distribution: '{}'. Supported: normal, exponential, uniform, gamma, beta",
352            other
353        ))),
354    }
355}
356
/// Register batch stats functions into the Python module.
///
/// Adds `stats_summary`, `batch_descriptive_stats`, `batch_correlation`,
/// and `batch_pdf_eval` to the PyO3 module `m`.
pub fn register_batch_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(stats_summary, m)?)?;
    m.add_function(wrap_pyfunction!(batch_descriptive_stats, m)?)?;
    m.add_function(wrap_pyfunction!(batch_correlation, m)?)?;
    m.add_function(wrap_pyfunction!(batch_pdf_eval, m)?)?;
    Ok(())
}