Skip to main content

simd_kernels/kernels/scientific/distributions/univariate/common/
mod.rs

1// Copyright (c) 2025 SpaceCell Enterprises Ltd
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial licensing available. See LICENSE and LICENSING.md.
4
5//! # **Common Distribution Utilities** - *Shared Testing and Helper Infrastructure*
6//!
7//! Common testing utilities, helper functions, and macros shared across all univariate
8//! distribution implementations to ensure consistency and reduce code duplication.
9//!
10//! ## Testing Infrastructure
11//! This module provides standardised testing patterns that validate:
12//! - **Numerical accuracy**: Comparison against reference implementations
13//! - **Null handling**: Proper propagation of missing values through calculations
14//! - **Edge cases**: Boundary conditions and special value handling
15//! - **Performance**: Bulk vs scalar operation consistency
16//!
17//! ## Helper Functions
18//! - **Array extraction**: Safe unwrapping of dense arrays without null masks
19//! - **Scalar testing**: Single-value operation testing utilities
20//! - **Mask creation**: Null mask generation for testing scenarios
21//! - **Tolerance checking**: Numerical comparison with configurable precision
22//!
23//! ## Test Macros
24//! The `common_tests!` macro generates standard test suites for distribution functions,
25//! ensuring consistent validation across all statistical implementations.
26
27#[cfg(feature = "simd")]
28pub mod simd;
29/// Scalar implementations of common distribution utilities.
30pub mod std;
31
32use minarrow::{Bitmask, Buffer, FloatArray};
33
34// Common test helpers
35
36/// Test Helper: unwrap `FloatArray`, assert *no* null mask, return data.
37pub fn dense_data(arr: FloatArray<f64>) -> Buffer<f64> {
38    assert!(arr.null_mask.is_none(), "unexpected mask on dense path");
39    arr.data
40}
41
42/// Build a 1-lane slice (`&[T]`) on the fly, call `kernel`,
43/// and return the single f64 result for *scalar* comparison.
44pub fn scalar_call<F>(kernel: F, x: f64) -> f64
45where
46    F: Fn(&[f64]) -> FloatArray<f64>,
47{
48    dense_data(kernel(&[x])).into_iter().next().unwrap()
49}
50
51/// Create a mask of given length with exactly the lane `idx` null.
52pub fn single_null_mask(len: usize, idx: usize) -> Bitmask {
53    let mut m = Bitmask::new_set_all(len, true);
54    unsafe { m.set_unchecked(idx, false) };
55    m
56}
57
/// Assert absolute difference ≤ `tol`.
///
/// The bound is inclusive, so a difference of exactly `tol` passes, and
/// `tol = 0.0` accepts exact matches. Values that compare exactly equal always
/// pass; this covers identical infinities, whose difference is NaN and would
/// otherwise fail every tolerance comparison.
///
/// # Panics
/// Panics with an `assert_close failed: ...` message when the values are not
/// equal and differ by more than `tol`. This includes any comparison that
/// involves NaN.
pub fn assert_close(a: f64, b: f64, tol: f64) {
    // Check equality first: `inf - inf` is NaN, and NaN never satisfies `<=`.
    let within_tolerance = a == b || (a - b).abs() <= tol;
    assert!(
        within_tolerance,
        "assert_close failed: {} vs {} (tol={})",
        a,
        b,
        tol
    );
}
68
/// Generate the most-common tests for a distribution kernel.
///
/// Two forms are accepted:
///
/// - `common_tests!(name, call)`: `call` is a *closure* that takes `&[f64]`
///   and returns `FloatArray<f64>`. Emits the dense-path tests `empty_input`
///   and `bulk_vs_scalar_consistency`.
/// - `common_tests!(name, call, masked_call)`: additionally emits
///   `mask_propagation`. `masked_call` is a closure that takes
///   `(&[f64], &Bitmask)` and returns `FloatArray<f64>`; it must forward the
///   mask to the kernel.
///
/// A closure that only receives `&[f64]` has no way to see a null mask, so
/// the mask test is generated only when a mask-aware closure is supplied.
///
/// The generated module glob-imports `super::*` and `crate::tests::common::*`.
/// The helpers used by the tests (`dense_data`, `scalar_call`, `assert_close`,
/// `single_null_mask`) and the `vec64!` macro must be reachable through those
/// imports.
///
/// Usage:
/// ```ignore
/// common_tests!(normal_pdf, |x| normal_pdf(x, 0.0, 1.0, None, None).unwrap());
///
/// // Masked form. How the mask is forwarded depends on the kernel signature;
/// // this assumes the fourth argument is the null mask -- adjust as needed.
/// common_tests!(
///     normal_pdf_masked,
///     |x| normal_pdf(x, 0.0, 1.0, None, None).unwrap(),
///     |x, mask| normal_pdf(x, 0.0, 1.0, Some(mask), None).unwrap()
/// );
/// ```
#[macro_export]
macro_rules! common_tests {
    // Internal rule: the dense-path tests shared by both public forms.
    (@dense $call:expr) => {
        #[test]
        fn empty_input() {
            let arr = ($call)(&[]);
            assert!(arr.data.is_empty());
            assert!(arr.null_mask.is_none());
        }

        #[test]
        fn bulk_vs_scalar_consistency() {
            let xs = vec64![-3.0, -1.0, 0.0, 1.0, 2.0];
            let bulk = dense_data(($call)(&xs));
            for (i, &x) in xs.iter().enumerate() {
                let scalar = scalar_call($call, x);
                assert_close(bulk[i], scalar, 1e-14);
            }
        }
    };

    // $name – a unique test-group prefix
    // $call:expr – *closure* that gets &[f64] and returns FloatArray<f64>
    ($name:ident, $call:expr) => {
        mod $name {
            use super::*;
            // NOTE(review): assumes this path re-exports the helpers named in
            // the macro docs -- confirm against the crate layout.
            use crate::tests::common::*;

            $crate::common_tests!(@dense $call);
        }
    };

    // $masked_call:expr – *closure* that gets (&[f64], &Bitmask) and returns
    // FloatArray<f64>, forwarding the mask to the kernel.
    ($name:ident, $call:expr, $masked_call:expr) => {
        mod $name {
            use super::*;
            // NOTE(review): assumes this path re-exports the helpers named in
            // the macro docs -- confirm against the crate layout.
            use crate::tests::common::*;

            $crate::common_tests!(@dense $call);

            #[test]
            fn mask_propagation() {
                let xs = vec64![1.0, 2.0, 3.0];
                let mask = single_null_mask(3, 1); // middle lane null
                let arr = ($masked_call)(&xs, &mask);
                // lane 1 -> NaN + null
                let out_mask = arr
                    .null_mask
                    .as_ref()
                    .expect("masked input must yield an output null mask");
                assert!(!out_mask.get(1));
                assert!(arr.data[1].is_nan());
            }
        }
    };
}
113}