1use pyo3::exceptions::PyRuntimeError;
7use pyo3::prelude::*;
8use rayon::prelude::*;
9use scirs2_core::ndarray::Array1;
10use scirs2_stats::distributions::beta::Beta as RustBeta;
11use scirs2_stats::distributions::exponential::Exponential as RustExponential;
12use scirs2_stats::distributions::gamma::Gamma as RustGamma;
13use scirs2_stats::distributions::normal::Normal as RustNormal;
14use scirs2_stats::distributions::uniform::Uniform as RustUniform;
15use scirs2_stats::pearsonr;
16
/// Arithmetic mean of `data`, or `None` when the slice is empty.
///
/// The sum is accumulated in four independent lanes over blocks of eight
/// values: lane `k` receives elements `k` and `k + 4` of every block. The
/// lanes are then combined left to right, and any leftover elements (fewer
/// than eight) are added one at a time before dividing by the length.
fn slice_mean(data: &[f64]) -> Option<f64> {
    if data.is_empty() {
        return None;
    }

    let blocks = data.chunks_exact(8);
    let tail = blocks.remainder();

    let mut lanes = [0.0f64; 4];
    for block in blocks {
        // Pair element k of the first half with element k of the second half.
        let (front, back) = block.split_at(4);
        for ((lane, a), b) in lanes.iter_mut().zip(front).zip(back) {
            *lane += a + b;
        }
    }

    // Keep the lane combination order fixed so the rounding is reproducible.
    let lane_total = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    let total = tail.iter().fold(lane_total, |acc, &v| acc + v);
    Some(total / data.len() as f64)
}
44
45fn slice_mean_var_std(data: &[f64]) -> Option<(f64, f64, f64)> {
47 if data.is_empty() {
48 return None;
49 }
50 let n = data.len();
51 let mean = slice_mean(data)?;
52 let mut sq0 = 0.0f64;
53 let mut sq1 = 0.0f64;
54 let mut sq2 = 0.0f64;
55 let mut sq3 = 0.0f64;
56 let chunks = data.chunks_exact(8);
57 let remainder = chunks.remainder();
58 for chunk in chunks {
59 let d0 = chunk[0] - mean;
60 let d1 = chunk[1] - mean;
61 let d2 = chunk[2] - mean;
62 let d3 = chunk[3] - mean;
63 let d4 = chunk[4] - mean;
64 let d5 = chunk[5] - mean;
65 let d6 = chunk[6] - mean;
66 let d7 = chunk[7] - mean;
67 sq0 += d0 * d0 + d4 * d4;
68 sq1 += d1 * d1 + d5 * d5;
69 sq2 += d2 * d2 + d6 * d6;
70 sq3 += d3 * d3 + d7 * d7;
71 }
72 let mut sq_sum = sq0 + sq1 + sq2 + sq3;
73 for &v in remainder {
74 let d = v - mean;
75 sq_sum += d * d;
76 }
77 let denom = if n > 1 { (n - 1) as f64 } else { 1.0 };
78 let var = sq_sum / denom;
79 let std = var.sqrt();
80 Some((mean, std, var))
81}
82
/// Linearly interpolated percentile of an ascending-sorted slice.
///
/// `p` is a fraction: the result is the value at virtual index
/// `p * (n - 1)`, interpolating between the two neighbouring elements when
/// that index is not an integer.
///
/// Edge cases:
/// * an empty slice yields `NaN` instead of panicking;
/// * a single-element slice yields that element for any `p`;
/// * `p` is clamped to `[0, 1]`, so values below 0 return the first element
///   and values above 1 return the last one;
/// * a `NaN` fraction yields `NaN` when there are at least two elements.
fn sorted_percentile(sorted: &[f64], p: f64) -> f64 {
    let n = sorted.len();
    match n {
        0 => return f64::NAN,
        1 => return sorted[0],
        _ => {}
    }

    // Clamping keeps the virtual index inside the slice. Without it a
    // negative `p` saturates to index 0 in the cast below while leaving a
    // negative interpolation weight, which extrapolates below the minimum.
    // `clamp` passes NaN through unchanged.
    let fraction = p.clamp(0.0, 1.0);

    let virtual_index = fraction * (n - 1) as f64;
    let lower = virtual_index.floor() as usize;
    let weight = virtual_index - lower as f64;

    if weight == 0.0 || lower >= n - 1 {
        sorted[lower.min(n - 1)]
    } else {
        sorted[lower] + weight * (sorted[lower + 1] - sorted[lower])
    }
}
98
99fn descriptive_stats_for_slice(
101 data: &[f64],
102) -> Result<std::collections::HashMap<String, f64>, String> {
103 let n = data.len();
104 if n == 0 {
105 return Err("Empty array".to_string());
106 }
107 let (mean, std, var) =
108 slice_mean_var_std(data).ok_or_else(|| "Failed to compute mean/std".to_string())?;
109 let min = data.iter().cloned().fold(f64::INFINITY, f64::min);
110 let max = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
111
112 let mut sorted = data.to_vec();
113 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
114 let median = sorted_percentile(&sorted, 0.5);
115 let q25 = sorted_percentile(&sorted, 0.25);
116 let q75 = sorted_percentile(&sorted, 0.75);
117
118 let mut map = std::collections::HashMap::new();
119 map.insert("n".to_string(), n as f64);
120 map.insert("mean".to_string(), mean);
121 map.insert("std".to_string(), std);
122 map.insert("var".to_string(), var);
123 map.insert("min".to_string(), min);
124 map.insert("max".to_string(), max);
125 map.insert("median".to_string(), median);
126 map.insert("q25".to_string(), q25);
127 map.insert("q75".to_string(), q75);
128 Ok(map)
129}
130
131#[pyfunction]
145pub fn stats_summary(data: Vec<f64>) -> PyResult<(f64, f64, f64)> {
146 if data.is_empty() {
147 return Err(PyRuntimeError::new_err("Empty array provided"));
148 }
149 let (mean, std, var) = slice_mean_var_std(&data)
150 .ok_or_else(|| PyRuntimeError::new_err("Failed to compute stats"))?;
151 Ok((mean, std, var))
152}
153
154#[pyfunction]
164pub fn batch_descriptive_stats(
165 arrays: Vec<Vec<f64>>,
166) -> PyResult<Vec<std::collections::HashMap<String, f64>>> {
167 if arrays.is_empty() {
168 return Ok(vec![]);
169 }
170 let results: Vec<Result<std::collections::HashMap<String, f64>, String>> = arrays
171 .par_iter()
172 .map(|arr| descriptive_stats_for_slice(arr))
173 .collect();
174
175 results
176 .into_iter()
177 .map(|r| r.map_err(|e| PyRuntimeError::new_err(format!("Descriptive stats failed: {}", e))))
178 .collect()
179}
180
181#[pyfunction]
192pub fn batch_correlation(arrays: Vec<Vec<f64>>) -> PyResult<Vec<Vec<f64>>> {
193 let k = arrays.len();
194 if k == 0 {
195 return Ok(vec![]);
196 }
197 let n = arrays[0].len();
198 for (i, arr) in arrays.iter().enumerate() {
199 if arr.len() != n {
200 return Err(PyRuntimeError::new_err(format!(
201 "Array {} has length {} but expected {}",
202 i,
203 arr.len(),
204 n
205 )));
206 }
207 if arr.is_empty() {
208 return Err(PyRuntimeError::new_err(format!("Array {} is empty", i)));
209 }
210 }
211
212 let pairs: Vec<(usize, usize)> = (0..k).flat_map(|i| (i..k).map(move |j| (i, j))).collect();
214
215 let corr_values: Vec<((usize, usize), f64)> = pairs
217 .par_iter()
218 .map(|&(i, j)| {
219 if i == j {
220 return Ok(((i, j), 1.0_f64));
221 }
222 let x_arr = Array1::from_vec(arrays[i].clone());
223 let y_arr = Array1::from_vec(arrays[j].clone());
224 pearsonr(&x_arr.view(), &y_arr.view(), "two-sided")
225 .map(|(r, _p)| ((i, j), r))
226 .map_err(|e| format!("Pearson correlation ({},{}) failed: {}", i, j, e))
227 })
228 .collect::<Vec<Result<((usize, usize), f64), String>>>()
229 .into_iter()
230 .collect::<Result<Vec<((usize, usize), f64)>, String>>()
231 .map_err(PyRuntimeError::new_err)?;
232
233 let mut matrix = vec![vec![0.0f64; k]; k];
235 for ((i, j), val) in corr_values {
236 matrix[i][j] = val;
237 matrix[j][i] = val;
238 }
239 Ok(matrix)
240}
241
242#[pyfunction]
259pub fn batch_pdf_eval(data: Vec<f64>, distribution: &str, params: Vec<f64>) -> PyResult<Vec<f64>> {
260 if data.is_empty() {
261 return Ok(vec![]);
262 }
263
264 match distribution.to_lowercase().as_str() {
265 "normal" => {
266 if params.len() < 2 {
267 return Err(PyRuntimeError::new_err(
268 "Normal distribution requires [mu, sigma] params",
269 ));
270 }
271 let mu = params[0];
272 let sigma = params[1];
273 if sigma <= 0.0 {
274 return Err(PyRuntimeError::new_err("sigma must be positive"));
275 }
276 let dist = RustNormal::new(mu, sigma).map_err(|e| {
278 PyRuntimeError::new_err(format!("Normal distribution failed: {}", e))
279 })?;
280 let result: Vec<f64> = data.par_iter().map(|&x| dist.pdf(x)).collect();
281 Ok(result)
282 }
283 "exponential" => {
284 if params.is_empty() {
285 return Err(PyRuntimeError::new_err(
286 "Exponential distribution requires [lambda] params",
287 ));
288 }
289 let lambda = params[0];
290 if lambda <= 0.0 {
291 return Err(PyRuntimeError::new_err("lambda must be positive"));
292 }
293 let dist = RustExponential::new(lambda, 0.0).map_err(|e| {
295 PyRuntimeError::new_err(format!("Exponential distribution failed: {}", e))
296 })?;
297 let result: Vec<f64> = data.par_iter().map(|&x| dist.pdf(x)).collect();
298 Ok(result)
299 }
300 "uniform" => {
301 if params.len() < 2 {
302 return Err(PyRuntimeError::new_err(
303 "Uniform distribution requires [low, high] params",
304 ));
305 }
306 let low = params[0];
307 let high = params[1];
308 if high <= low {
309 return Err(PyRuntimeError::new_err("high must be greater than low"));
310 }
311 let dist = RustUniform::new(low, high).map_err(|e| {
312 PyRuntimeError::new_err(format!("Uniform distribution failed: {}", e))
313 })?;
314 let result: Vec<f64> = data.par_iter().map(|&x| dist.pdf(x)).collect();
315 Ok(result)
316 }
317 "gamma" => {
318 if params.len() < 2 {
319 return Err(PyRuntimeError::new_err(
320 "Gamma distribution requires [shape, scale] params",
321 ));
322 }
323 let shape = params[0];
324 let scale = params[1];
325 if shape <= 0.0 || scale <= 0.0 {
326 return Err(PyRuntimeError::new_err("shape and scale must be positive"));
327 }
328 let dist = RustGamma::new(shape, scale, 0.0).map_err(|e| {
329 PyRuntimeError::new_err(format!("Gamma distribution failed: {}", e))
330 })?;
331 let result: Vec<f64> = data.par_iter().map(|&x| dist.pdf(x)).collect();
332 Ok(result)
333 }
334 "beta" => {
335 if params.len() < 2 {
336 return Err(PyRuntimeError::new_err(
337 "Beta distribution requires [alpha, beta] params",
338 ));
339 }
340 let alpha = params[0];
341 let beta_param = params[1];
342 if alpha <= 0.0 || beta_param <= 0.0 {
343 return Err(PyRuntimeError::new_err("alpha and beta must be positive"));
344 }
345 let dist = RustBeta::new(alpha, beta_param, 0.0, 1.0)
346 .map_err(|e| PyRuntimeError::new_err(format!("Beta distribution failed: {}", e)))?;
347 let result: Vec<f64> = data.par_iter().map(|&x| dist.pdf(x)).collect();
348 Ok(result)
349 }
350 other => Err(PyRuntimeError::new_err(format!(
351 "Unknown distribution: '{}'. Supported: normal, exponential, uniform, gamma, beta",
352 other
353 ))),
354 }
355}
356
357pub fn register_batch_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
359 m.add_function(wrap_pyfunction!(stats_summary, m)?)?;
360 m.add_function(wrap_pyfunction!(batch_descriptive_stats, m)?)?;
361 m.add_function(wrap_pyfunction!(batch_correlation, m)?)?;
362 m.add_function(wrap_pyfunction!(batch_pdf_eval, m)?)?;
363 Ok(())
364}