simd_kernels/kernels/scientific/distributions/univariate/common/
mod.rs

1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Common Distribution Utilities** - *Shared Testing and Helper Infrastructure*
5//!
6//! Common testing utilities, helper functions, and macros shared across all univariate
7//! distribution implementations to ensure consistency and reduce code duplication.
8//!
9//! ## Testing Infrastructure
10//! This module provides standardised testing patterns that validate:
11//! - **Numerical accuracy**: Comparison against reference implementations
12//! - **Null handling**: Proper propagation of missing values through calculations
13//! - **Edge cases**: Boundary conditions and special value handling
//! - **Consistency**: Bulk (whole-slice) results agree with per-element scalar results
15//!
16//! ## Helper Functions
17//! - **Array extraction**: Safe unwrapping of dense arrays without null masks
18//! - **Scalar testing**: Single-value operation testing utilities
19//! - **Mask creation**: Null mask generation for testing scenarios
20//! - **Tolerance checking**: Numerical comparison with configurable precision
21//!
22//! ## Test Macros
23//! The `common_tests!` macro generates standard test suites for distribution functions,
24//! ensuring consistent validation across all statistical implementations.
25
26#[cfg(feature = "simd")]
27pub mod simd;
28/// Scalar implementations of common distribution utilities.
29pub mod std;
30
31use minarrow::{Bitmask, Buffer, FloatArray};
32
33// Common test helpers
34
35/// Test Helper: unwrap `FloatArray`, assert *no* null mask, return data.
36pub fn dense_data(arr: FloatArray<f64>) -> Buffer<f64> {
37    assert!(arr.null_mask.is_none(), "unexpected mask on dense path");
38    arr.data
39}
40
41/// Build a 1-lane slice (`&[T]`) on the fly, call `kernel`,
42/// and return the single f64 result for *scalar* comparison.
43pub fn scalar_call<F>(kernel: F, x: f64) -> f64
44where
45    F: Fn(&[f64]) -> FloatArray<f64>,
46{
47    dense_data(kernel(&[x])).into_iter().next().unwrap()
48}
49
50/// Create a mask of given length with exactly the lane `idx` null.
51pub fn single_null_mask(len: usize, idx: usize) -> Bitmask {
52    let mut m = Bitmask::new_set_all(len, true);
53    unsafe { m.set_unchecked(idx, false) };
54    m
55}
56
57/// Assert absolute difference ≤ `tol`.
58pub fn assert_close(a: f64, b: f64, tol: f64) {
59    assert!(
60        (a - b).abs() < tol,
61        "assert_close failed: {} vs {} (tol={})",
62        a,
63        b,
64        tol
65    );
66}
67
68/// Generate the three most-common tests (empty-input, mask propagation,
69/// bulk-vs-scalar) for a `fn kernel(&[f64]) -> FloatArray<f64>`.
70///
71/// Usage:
72/// ```ignore
73/// common_tests!(normal_pdf, |x| normal_pdf(x, 0.0, 1.0, None, None).unwrap());
74/// ```
75#[macro_export]
76macro_rules! common_tests {
77    // $name    – a unique test-group prefix
78    // $call:expr – *closure* that gets &[f64] and returns FloatArray<f64>
79    ($name:ident, $call:expr) => {
80        mod $name {
81            use super::*;
82            use crate::tests::common::*;
83
84            #[test]
85            fn empty_input() {
86                let arr = ($call)(&[]);
87                assert!(arr.data.is_empty());
88                assert!(arr.null_mask.is_none());
89            }
90
91            #[test]
92            fn bulk_vs_scalar_consistency() {
93                let xs = vec64![-3.0, -1.0, 0.0, 1.0, 2.0];
94                let bulk = dense(($call)(&xs));
95                for (i, &x) in xs.iter().enumerate() {
96                    let scalar = scalar_call($call, x);
97                    assert_close(bulk[i], scalar, 1e-14);
98                }
99            }
100
101            #[test]
102            fn mask_propagation() {
103                let xs = vec64![1.0, 2.0, 3.0];
104                let mask = single_null_mask(3, 1); // middle lane null
105                let arr = ($call)(&xs);
106                // lane 1 -> NaN + null
107                assert!(!arr.null_mask.as_ref().unwrap().get(1));
108                assert!(arr.data[1].is_nan());
109            }
110        }
111    };
112}