simd_kernels/kernels/scientific/distributions/univariate/common/mod.rs
1// Copyright (c) 2025 SpaceCell Enterprises Ltd
2// SPDX-License-Identifier: AGPL-3.0-or-later
3// Commercial licensing available. See LICENSE and LICENSING.md.
4
5//! # **Common Distribution Utilities** - *Shared Testing and Helper Infrastructure*
6//!
7//! Common testing utilities, helper functions, and macros shared across all univariate
8//! distribution implementations to ensure consistency and reduce code duplication.
9//!
10//! ## Testing Infrastructure
11//! This module provides standardised testing patterns that validate:
12//! - **Numerical accuracy**: Comparison against reference implementations
13//! - **Null handling**: Proper propagation of missing values through calculations
14//! - **Edge cases**: Boundary conditions and special value handling
//! - **Consistency**: Bulk (whole-array) results agree with per-element scalar results
16//!
17//! ## Helper Functions
18//! - **Array extraction**: Safe unwrapping of dense arrays without null masks
19//! - **Scalar testing**: Single-value operation testing utilities
20//! - **Mask creation**: Null mask generation for testing scenarios
21//! - **Tolerance checking**: Numerical comparison with configurable precision
22//!
23//! ## Test Macros
24//! The `common_tests!` macro generates standard test suites for distribution functions,
25//! ensuring consistent validation across all statistical implementations.
26
/// SIMD implementations of common distribution utilities; compiled only when the
/// `simd` feature is enabled.
#[cfg(feature = "simd")]
pub mod simd;
/// Scalar implementations of common distribution utilities.
pub mod std;
31
32use minarrow::{Bitmask, Buffer, FloatArray};
33
34// Common test helpers
35
36/// Test Helper: unwrap `FloatArray`, assert *no* null mask, return data.
37pub fn dense_data(arr: FloatArray<f64>) -> Buffer<f64> {
38 assert!(arr.null_mask.is_none(), "unexpected mask on dense path");
39 arr.data
40}
41
42/// Build a 1-lane slice (`&[T]`) on the fly, call `kernel`,
43/// and return the single f64 result for *scalar* comparison.
44pub fn scalar_call<F>(kernel: F, x: f64) -> f64
45where
46 F: Fn(&[f64]) -> FloatArray<f64>,
47{
48 dense_data(kernel(&[x])).into_iter().next().unwrap()
49}
50
51/// Create a mask of given length with exactly the lane `idx` null.
52pub fn single_null_mask(len: usize, idx: usize) -> Bitmask {
53 let mut m = Bitmask::new_set_all(len, true);
54 unsafe { m.set_unchecked(idx, false) };
55 m
56}
57
/// Assert that `a` and `b` agree within an absolute tolerance: `|a - b| <= tol`.
///
/// Values that compare exactly equal always pass. This covers equal infinities,
/// where `a - b` would be NaN. Two NaNs are also treated as equal, so NaN results
/// that agree between two code paths do not cause spurious failures. A NaN compared
/// with a non-NaN value always fails.
///
/// # Panics
/// Panics, reporting both values and the tolerance, if they are not close.
#[track_caller]
pub fn assert_close(a: f64, b: f64, tol: f64) {
    let close = a == b || (a.is_nan() && b.is_nan()) || (a - b).abs() <= tol;
    assert!(
        close,
        "assert_close failed: {} vs {} (tol={})",
        a,
        b,
        tol
    );
}
68
/// Generate a standard test module for a univariate distribution kernel.
///
/// `common_tests!(name, call)` expands to a `#[cfg(test)] mod name { ... }` containing:
/// - `empty_input`: an empty slice yields an empty array with no null mask;
/// - `bulk_vs_scalar_consistency`: evaluating a small vector in one call agrees with
///   evaluating each element separately (absolute tolerance `1e-14`).
///
/// `common_tests!(name, call, masked_call)` generates the same tests plus:
/// - `mask_propagation`: with lane 1 of a 3-lane input marked null, the output marks
///   lane 1 null and holds NaN there.
///
/// `call` is a closure `Fn(&[f64]) -> FloatArray<f64>` for the unmasked path.
/// `masked_call` is a closure taking `(&[f64], &Bitmask)` that forwards the mask to
/// the kernel and returns a `FloatArray<f64>`.
///
/// The generated module glob-imports `super::*` and `crate::tests::common::*`. Through
/// those, `scalar_call`, `assert_close` and `single_null_mask` must be reachable, and
/// `vec64!` must be in scope.
///
/// Usage:
/// ```ignore
/// common_tests!(normal_pdf, |x| normal_pdf(x, 0.0, 1.0, None, None).unwrap());
///
/// // With mask propagation; adapt the second closure to the kernel's actual
/// // null-mask parameter.
/// common_tests!(
///     normal_pdf_masked,
///     |x| normal_pdf(x, 0.0, 1.0, None, None).unwrap(),
///     |x, m| normal_pdf(x, 0.0, 1.0, Some(m), None).unwrap()
/// );
/// ```
#[macro_export]
macro_rules! common_tests {
    // Internal rule: tests that only need the unmasked (dense) kernel call.
    (@dense_tests $call:expr) => {
        #[test]
        fn empty_input() {
            let arr = ($call)(&[]);
            assert!(arr.data.is_empty());
            assert!(arr.null_mask.is_none());
        }

        #[test]
        fn bulk_vs_scalar_consistency() {
            let xs = vec64![-3.0, -1.0, 0.0, 1.0, 2.0];
            let bulk = ($call)(&xs);
            // Mask-free input must give a mask-free result (checked inline so the
            // macro does not depend on any extra helper name).
            assert!(bulk.null_mask.is_none(), "unexpected mask on dense path");
            assert_eq!(bulk.data.len(), xs.len(), "bulk output length mismatch");
            for (i, &x) in xs.iter().enumerate() {
                let scalar = scalar_call($call, x);
                assert_close(bulk.data[i], scalar, 1e-14);
            }
        }
    };

    // $name – a unique test-module name
    // $call – closure that gets &[f64] and returns FloatArray<f64>
    ($name:ident, $call:expr) => {
        #[cfg(test)]
        mod $name {
            use super::*;
            use crate::tests::common::*;

            $crate::common_tests!(@dense_tests $call);
        }
    };

    // $masked_call – closure that gets (&[f64], &Bitmask) and returns FloatArray<f64>
    ($name:ident, $call:expr, $masked_call:expr) => {
        #[cfg(test)]
        mod $name {
            use super::*;
            use crate::tests::common::*;

            $crate::common_tests!(@dense_tests $call);

            #[test]
            fn mask_propagation() {
                let xs = vec64![1.0, 2.0, 3.0];
                let mask = single_null_mask(3, 1); // middle lane null
                let arr = ($masked_call)(&xs, &mask);
                let out_mask = arr
                    .null_mask
                    .as_ref()
                    .expect("masked input must produce an output null mask");
                // lane 1 -> NaN + null
                assert!(!out_mask.get(1));
                assert!(arr.data[1].is_nan());
            }
        }
    };
}