simd_kernels/kernels/scientific/distributions/univariate/common/mod.rs
1// Copyright Peter Bower 2025. All Rights Reserved.
2// Licensed under Mozilla Public License (MPL) 2.0.
3
4//! # **Common Distribution Utilities** - *Shared Testing and Helper Infrastructure*
5//!
6//! Common testing utilities, helper functions, and macros shared across all univariate
7//! distribution implementations to ensure consistency and reduce code duplication.
8//!
9//! ## Testing Infrastructure
10//! This module provides standardised testing patterns that validate:
11//! - **Numerical accuracy**: Comparison against reference implementations
12//! - **Null handling**: Proper propagation of missing values through calculations
13//! - **Edge cases**: Boundary conditions and special value handling
//! - **Consistency**: Bulk (slice) results agree with per-element scalar calls
15//!
16//! ## Helper Functions
17//! - **Array extraction**: Safe unwrapping of dense arrays without null masks
18//! - **Scalar testing**: Single-value operation testing utilities
19//! - **Mask creation**: Null mask generation for testing scenarios
20//! - **Tolerance checking**: Numerical comparison with configurable precision
21//!
22//! ## Test Macros
23//! The `common_tests!` macro generates standard test suites for distribution functions,
24//! ensuring consistent validation across all statistical implementations.
25
/// SIMD variants of the common distribution utilities; only compiled when the
/// `simd` cargo feature is enabled.
#[cfg(feature = "simd")]
pub mod simd;
/// Scalar implementations of common distribution utilities.
///
/// NOTE(review): a local module named `std` shadows the standard-library crate
/// name inside this module; write `::std::…` here when the real standard
/// library is meant.
pub mod std;
30
31use minarrow::{Bitmask, Buffer, FloatArray};
32
33// Common test helpers
34
35/// Test Helper: unwrap `FloatArray`, assert *no* null mask, return data.
36pub fn dense_data(arr: FloatArray<f64>) -> Buffer<f64> {
37 assert!(arr.null_mask.is_none(), "unexpected mask on dense path");
38 arr.data
39}
40
41/// Build a 1-lane slice (`&[T]`) on the fly, call `kernel`,
42/// and return the single f64 result for *scalar* comparison.
43pub fn scalar_call<F>(kernel: F, x: f64) -> f64
44where
45 F: Fn(&[f64]) -> FloatArray<f64>,
46{
47 dense_data(kernel(&[x])).into_iter().next().unwrap()
48}
49
50/// Create a mask of given length with exactly the lane `idx` null.
51pub fn single_null_mask(len: usize, idx: usize) -> Bitmask {
52 let mut m = Bitmask::new_set_all(len, true);
53 unsafe { m.set_unchecked(idx, false) };
54 m
55}
56
/// Test helper: assert that `a` and `b` differ by at most `tol` in absolute
/// value (`|a - b| <= tol`).
///
/// Exactly equal values always pass, so matching infinities compare equal
/// even though `inf - inf` is NaN. Any NaN operand fails the check.
///
/// # Panics
/// Panics, reporting both values and the tolerance, when the check fails.
pub fn assert_close(a: f64, b: f64, tol: f64) {
    assert!(
        a == b || (a - b).abs() <= tol,
        "assert_close failed: {} vs {} (tol={})",
        a,
        b,
        tol
    );
}
67
/// Generate the standard test suite for a distribution kernel.
///
/// Two forms are accepted:
///
/// * `common_tests!(name, call)` — `call` is a closure taking `&[f64]` and
///   returning `FloatArray<f64>` for dense (mask-free) input. Generates
///   `empty_input` and `bulk_vs_scalar_consistency`.
/// * `common_tests!(name, call, masked_call)` — additionally generates
///   `mask_propagation`. `masked_call` is a closure taking
///   `(&[f64], &Bitmask)` that forwards the mask to the kernel.
///
/// `mask_propagation` is only generated in the three-argument form, because
/// a dense-only `call` has no way to receive the input mask.
///
/// The tests are emitted inside `mod name`, which glob-imports `super::*` and
/// `crate::tests::common::*`. `dense`, `scalar_call`, `single_null_mask`,
/// `assert_close` and `vec64!` must be reachable from there.
///
/// Usage:
/// ```ignore
/// common_tests!(normal_pdf, |x| normal_pdf(x, 0.0, 1.0, None, None).unwrap());
///
/// // NOTE: assumes the kernel takes its mask as `Option<&Bitmask>`; adapt the
/// // closure to the kernel's actual signature.
/// common_tests!(
///     normal_pdf_masked,
///     |x| normal_pdf(x, 0.0, 1.0, None, None).unwrap(),
///     |x, mask| normal_pdf(x, 0.0, 1.0, Some(mask), None).unwrap()
/// );
/// ```
#[macro_export]
macro_rules! common_tests {
    // $name        – a unique test-group (module) name
    // $call        – closure: &[f64] -> FloatArray<f64> (dense input, no mask)
    // $masked_call – optional closure: (&[f64], &Bitmask) -> FloatArray<f64>
    ($name:ident, $call:expr $(, $masked_call:expr)?) => {
        mod $name {
            use super::*;
            use crate::tests::common::*;

            #[test]
            fn empty_input() {
                let arr = ($call)(&[]);
                assert!(arr.data.is_empty());
                assert!(arr.null_mask.is_none());
            }

            #[test]
            fn bulk_vs_scalar_consistency() {
                let xs = vec64![-3.0, -1.0, 0.0, 1.0, 2.0];
                let bulk = dense(($call)(&xs));
                for (i, &x) in xs.iter().enumerate() {
                    let scalar = scalar_call($call, x);
                    assert_close(bulk[i], scalar, 1e-14);
                }
            }

            $(
                #[test]
                fn mask_propagation() {
                    let xs = vec64![1.0, 2.0, 3.0];
                    let mask = single_null_mask(3, 1); // middle lane null
                    // The mask must actually reach the kernel for this test to
                    // be meaningful.
                    let arr = ($masked_call)(&xs, &mask);
                    let out_mask = arr
                        .null_mask
                        .as_ref()
                        .expect("masked input must produce an output null mask");
                    // Lane 1 -> NaN + null; the remaining lanes stay valid.
                    assert!(!out_mask.get(1));
                    assert!(arr.data[1].is_nan());
                    assert!(out_mask.get(0));
                    assert!(out_mask.get(2));
                }
            )?
        }
    };
}