Skip to main content

vforce/
lib.rs

1//! Safe no_std Rust bindings for the [VForce](https://developer.apple.com/documentation/accelerate/vforce-library?language=objc) family of hardware-accelerated transcendental vectorized math functions in the [Accelerate](https://developer.apple.com/documentation/accelerate?language=objc) framework on MacOS.
2//! 
3//! Provides a safe API for VForce functions generic over single and double precision floats, with automatic chunking for very large arrays. 
4//!
5//! ```rust
6//! use vforce::arithmetic::{pow_array, pow_array_in_place};
7//!
8//! let mut bases: Vec<f64> = vec![1.0, 6.0, 2.5];
9//! let exponents: Vec<f64> = vec![2.0, 4.0, 1.3];
10//! let mut out = vec![0.0f64; 3];
11//!
12//! // results can be written out to another array:
13//! pow_array(&mut out, &bases, &exponents).unwrap();
14//! assert_eq!(out, vec![1.0f64, 1296.0f64, 2.5f64.powf(1.3)]);
15//!
16//! // or overwrite one of the original arrays:
17//! pow_array_in_place(&mut bases, &exponents).unwrap();
18//! assert_eq!(bases, vec![1.0f64, 1296.0f64, 2.5f64.powf(1.3)]);
19//! ```
20//!
21//! Either `f64` or `f32` may be used, but the type must be consistent for all arrays used for a given function call.
22//!
23//! The VForce functions are hand-tuned implementations of transcendental vectorized array functions built with NEON and optimized for Apple hardware. VForce is part of the Apple Accelerate framework, which ships on all MacOS versions since 10.3 (October 2003) and many more Apple devices since then.
24//!
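25//! On any OS other than macOS the whole crate is compiled out (`#![cfg(target_os = "macos")]`), so none of its items exist there and code that uses them will fail to build; linking would also be invalid on platforms without Accelerate. Take care to put code that uses these functions behind conditional compilation flags (for example `#[cfg(target_os = "macos")]`).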
26//!
27//! The original VForce functions take their element count as an `i32`, causing them to fail when processing arrays longer than `i32::MAX` (2,147,483,647) elements. This implementation checks for excessive array length and instead processes such arrays sequentially in chunks of at most `i32::MAX` elements.
28//!
29//! Almost all functions provide an out-of-place variant and an in-place variant, in order to allow safe overwriting without breaking Rust's aliasing-XOR-mutability rule.
30#![no_std]
31#![cfg(target_os = "macos")]
32
33mod accelerate;
34
35use core::fmt::Display;
36use accelerate::fns::*;
37
38pub use accelerate::AccelerateComplex;
39
/// Errors returned by the safe wrappers in this crate.
///
/// The type is `Copy` and comparable, so callers can match on it or use it directly in
/// `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelerateError {
    /// Inputs and outputs are not all the same length.
    ///
    /// `expected` is the reference length the other slices were compared against and `got` is
    /// the length of the first slice found to differ from it.
    LengthMismatch { expected: usize, got: usize },
}
45
46impl Display for AccelerateError {
47    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
48        match self {
49            Self::LengthMismatch { expected, got } => {
50                write!(f, "AccelerateError::LengthMismatch - vforce received arrays of different lengths: expected {} elements, got {} elements", expected, got)
51            }
52        }
53    }
54}
55
/// Private home of the sealing marker: because this module is not exported, code outside the
/// crate cannot name `Sealed` and therefore cannot implement traits that require it.
mod sealed {
    /// Marker implemented only for the float types this crate supports.
    pub trait Sealed {}

    impl Sealed for f64 {}
    impl Sealed for f32 {}
}
61
62/// Ensures that all inputs to an accelerate function must be the same numeric type: either f64 or
63/// f32
64pub trait AccelerateFloat: sealed::Sealed + Copy {
65    // Binary operations (out, a, b, count)
66    /// # Safety
67    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
68    /// type, either f64 or f32, and all arrays must be of length 'count'
69    unsafe fn accelerate_pow(out: *mut Self, base: *const Self, exp: *const Self, count: *const i32);
70    /// # Safety
71    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
72    /// type, either f64 or f32, and all arrays must be of length 'count'
73    unsafe fn accelerate_div(out: *mut Self, numerator: *const Self, denominator: *const Self, count: *const i32);
74    /// # Safety
75    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
76    /// type, either f64 or f32, and all arrays must be of length 'count'
77    unsafe fn accelerate_copysign(out: *mut Self, magnitude: *const Self, sign: *const Self, count: *const i32);
78    /// # Safety
79    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
80    /// type, either f64 or f32, and all arrays must be of length 'count'
81    unsafe fn accelerate_fmod(out: *mut Self, numerator: *const Self, denominator: *const Self, count: *const i32);
82    /// # Safety
83    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
84    /// type, either f64 or f32, and all arrays must be of length 'count'
85    unsafe fn accelerate_remainder(out: *mut Self, numerator: *const Self, denominator: *const Self, count: *const i32);
86    /// # Safety
87    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
88    /// type, either f64 or f32, and all arrays must be of length 'count'
89    unsafe fn accelerate_nextafter(out: *mut Self, input: *const Self, direction: *const Self, count: *const i32);
90    /// # Safety
91    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
92    /// type, either f64 or f32, and all arrays must be of length 'count'
93    unsafe fn accelerate_atan2(out: *mut Self, y: *const Self, x: *const Self, count: *const i32);
94
95    // Unary operations (out, input, count)
96
97    /// # Safety
98    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
99    /// type, either f64 or f32, and all arrays must be of length 'count'
100    unsafe fn accelerate_ceil(out: *mut Self, input: *const Self, count: *const i32);
101    /// # Safety
102    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
103    /// type, either f64 or f32, and all arrays must be of length 'count'
104    unsafe fn accelerate_floor(out: *mut Self, input: *const Self, count: *const i32);
105    /// # Safety
106    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
107    /// type, either f64 or f32, and all arrays must be of length 'count'
108    unsafe fn accelerate_fabs(out: *mut Self, input: *const Self, count: *const i32);
109    /// # Safety
110    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
111    /// type, either f64 or f32, and all arrays must be of length 'count'
112    unsafe fn accelerate_int(out: *mut Self, input: *const Self, count: *const i32);
113    /// # Safety
114    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
115    /// type, either f64 or f32, and all arrays must be of length 'count'
116    unsafe fn accelerate_nint(out: *mut Self, input: *const Self, count: *const i32);
117    /// # Safety
118    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
119    /// type, either f64 or f32, and all arrays must be of length 'count'
120    unsafe fn accelerate_rsqrt(out: *mut Self, input: *const Self, count: *const i32);
121    /// # Safety
122    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
123    /// type, either f64 or f32, and all arrays must be of length 'count'
124    unsafe fn accelerate_sqrt(out: *mut Self, input: *const Self, count: *const i32);
125    /// # Safety
126    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
127    /// type, either f64 or f32, and all arrays must be of length 'count'
128    unsafe fn accelerate_rec(out: *mut Self, input: *const Self, count: *const i32);
129    /// # Safety
130    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
131    /// type, either f64 or f32, and all arrays must be of length 'count'
132    unsafe fn accelerate_exp(out: *mut Self, input: *const Self, count: *const i32);
133    /// # Safety
134    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
135    /// type, either f64 or f32, and all arrays must be of length 'count'
136    unsafe fn accelerate_exp2(out: *mut Self, input: *const Self, count: *const i32);
137    /// # Safety
138    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
139    /// type, either f64 or f32, and all arrays must be of length 'count'
140    unsafe fn accelerate_expm1(out: *mut Self, input: *const Self, count: *const i32);
141    /// # Safety
142    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
143    /// type, either f64 or f32, and all arrays must be of length 'count'
144    unsafe fn accelerate_log(out: *mut Self, input: *const Self, count: *const i32);
145    /// # Safety
146    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
147    /// type, either f64 or f32, and all arrays must be of length 'count'
148    unsafe fn accelerate_log1p(out: *mut Self, input: *const Self, count: *const i32);
149    /// # Safety
150    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
151    /// type, either f64 or f32, and all arrays must be of length 'count'
152    unsafe fn accelerate_log2(out: *mut Self, input: *const Self, count: *const i32);
153    /// # Safety
154    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
155    /// type, either f64 or f32, and all arrays must be of length 'count'
156    unsafe fn accelerate_log10(out: *mut Self, input: *const Self, count: *const i32);
157    /// # Safety
158    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
159    /// type, either f64 or f32, and all arrays must be of length 'count'
160    unsafe fn accelerate_logb(out: *mut Self, input: *const Self, count: *const i32);
161    /// # Safety
162    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
163    /// type, either f64 or f32, and all arrays must be of length 'count'
164    unsafe fn accelerate_sin(out: *mut Self, input: *const Self, count: *const i32);
165    /// # Safety
166    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
167    /// type, either f64 or f32, and all arrays must be of length 'count'
168    unsafe fn accelerate_sinpi(out: *mut Self, input: *const Self, count: *const i32);
169    /// # Safety
170    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
171    /// type, either f64 or f32, and all arrays must be of length 'count'
172    unsafe fn accelerate_cos(out: *mut Self, input: *const Self, count: *const i32);
173    /// # Safety
174    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
175    /// type, either f64 or f32, and all arrays must be of length 'count'
176    unsafe fn accelerate_cospi(out: *mut Self, input: *const Self, count: *const i32);
177    /// # Safety
178    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
179    /// type, either f64 or f32, and all arrays must be of length 'count'
180    unsafe fn accelerate_tan(out: *mut Self, input: *const Self, count: *const i32);
181    /// # Safety
182    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
183    /// type, either f64 or f32, and all arrays must be of length 'count'
184    unsafe fn accelerate_tanpi(out: *mut Self, input: *const Self, count: *const i32);
185    /// # Safety
186    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
187    /// type, either f64 or f32, and all arrays must be of length 'count'
188    unsafe fn accelerate_asin(out: *mut Self, input: *const Self, count: *const i32);
189    /// # Safety
190    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
191    /// type, either f64 or f32, and all arrays must be of length 'count'
192    unsafe fn accelerate_acos(out: *mut Self, input: *const Self, count: *const i32);
193    /// # Safety
194    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
195    /// type, either f64 or f32, and all arrays must be of length 'count'
196    unsafe fn accelerate_atan(out: *mut Self, input: *const Self, count: *const i32);
197    /// # Safety
198    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
199    /// type, either f64 or f32, and all arrays must be of length 'count'
200    unsafe fn accelerate_sinh(out: *mut Self, input: *const Self, count: *const i32);
201    /// # Safety
202    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
203    /// type, either f64 or f32, and all arrays must be of length 'count'
204    unsafe fn accelerate_cosh(out: *mut Self, input: *const Self, count: *const i32);
205    /// # Safety
206    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
207    /// type, either f64 or f32, and all arrays must be of length 'count'
208    unsafe fn accelerate_tanh(out: *mut Self, input: *const Self, count: *const i32);
209    /// # Safety
210    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
211    /// type, either f64 or f32, and all arrays must be of length 'count'
212    unsafe fn accelerate_asinh(out: *mut Self, input: *const Self, count: *const i32);
213    /// # Safety
214    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
215    /// type, either f64 or f32, and all arrays must be of length 'count'
216    unsafe fn accelerate_acosh(out: *mut Self, input: *const Self, count: *const i32);
217    /// # Safety
218    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
219    /// type, either f64 or f32, and all arrays must be of length 'count'
220    unsafe fn accelerate_atanh(out: *mut Self, input: *const Self, count: *const i32);
221
222    // Special: sincos (sin_out, cos_out, input, count)
223    /// # Safety
224    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
225    /// type, either f64 or f32, and all arrays must be of length 'count'
226    unsafe fn accelerate_sincos(sin_out: *mut Self, cos_out: *mut Self, input: *const Self, count: *const i32);
227
228    // Special: cosisin
229    /// # Safety
230    /// All inputs must point to valid arrays of floating-point numbers. All must be of the same
231    /// type, either f64 or f32, and all arrays must be of length 'count'
232    unsafe fn accelerate_cosisin(out: *mut AccelerateComplex<Self>, input: *const Self, count: *const i32);
233}
234
/// Implements `AccelerateFloat` for one float type.
///
/// The first argument is the element type; the remaining forty identifiers are the vForce
/// routines for that precision, listed positionally in the same order as the trait's methods.
/// Each generated method only forwards its raw pointers, so the caller obligations documented
/// on the trait method carry over unchanged.
macro_rules! impl_accelerate_float {
    ($ty:ty, $pow:ident, $div:ident, $copysign:ident, $fmod:ident, $remainder:ident,
     $nextafter:ident, $atan2:ident, $ceil:ident, $floor:ident, $fabs:ident,
     $int:ident, $nint:ident, $rsqrt:ident, $sqrt:ident, $rec:ident,
     $exp:ident, $exp2:ident, $expm1:ident, $log:ident, $log1p:ident,
     $log2:ident, $log10:ident, $logb:ident, $sin:ident, $sinpi:ident,
     $cos:ident, $cospi:ident, $tan:ident, $tanpi:ident, $asin:ident,
     $acos:ident, $atan:ident, $sinh:ident, $cosh:ident, $tanh:ident,
     $asinh:ident, $acosh:ident, $atanh:ident, $sincos:ident, $cosisin:ident) => {
        impl AccelerateFloat for $ty {
            // ── Binary operations ──

            unsafe fn accelerate_pow(
                out: *mut Self,
                base: *const Self,
                exp: *const Self,
                count: *const i32,
            ) {
                // The exponent is deliberately passed ahead of the base: vForce declares
                // `vvpow(z, y, x, n)` and raises `x` to the power `y`.
                unsafe { $pow(out, exp, base, count) }
            }

            unsafe fn accelerate_div(
                out: *mut Self,
                numerator: *const Self,
                denominator: *const Self,
                count: *const i32,
            ) {
                unsafe { $div(out, numerator, denominator, count) }
            }

            unsafe fn accelerate_copysign(
                out: *mut Self,
                magnitude: *const Self,
                sign: *const Self,
                count: *const i32,
            ) {
                unsafe { $copysign(out, magnitude, sign, count) }
            }

            unsafe fn accelerate_fmod(
                out: *mut Self,
                numerator: *const Self,
                denominator: *const Self,
                count: *const i32,
            ) {
                unsafe { $fmod(out, numerator, denominator, count) }
            }

            unsafe fn accelerate_remainder(
                out: *mut Self,
                numerator: *const Self,
                denominator: *const Self,
                count: *const i32,
            ) {
                unsafe { $remainder(out, numerator, denominator, count) }
            }

            unsafe fn accelerate_nextafter(
                out: *mut Self,
                input: *const Self,
                direction: *const Self,
                count: *const i32,
            ) {
                unsafe { $nextafter(out, input, direction, count) }
            }

            unsafe fn accelerate_atan2(
                out: *mut Self,
                y: *const Self,
                x: *const Self,
                count: *const i32,
            ) {
                unsafe { $atan2(out, y, x, count) }
            }

            // ── Unary operations ──

            unsafe fn accelerate_ceil(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $ceil(out, input, count) }
            }

            unsafe fn accelerate_floor(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $floor(out, input, count) }
            }

            unsafe fn accelerate_fabs(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $fabs(out, input, count) }
            }

            unsafe fn accelerate_int(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $int(out, input, count) }
            }

            unsafe fn accelerate_nint(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $nint(out, input, count) }
            }

            unsafe fn accelerate_rsqrt(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $rsqrt(out, input, count) }
            }

            unsafe fn accelerate_sqrt(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $sqrt(out, input, count) }
            }

            unsafe fn accelerate_rec(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $rec(out, input, count) }
            }

            unsafe fn accelerate_exp(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $exp(out, input, count) }
            }

            unsafe fn accelerate_exp2(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $exp2(out, input, count) }
            }

            unsafe fn accelerate_expm1(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $expm1(out, input, count) }
            }

            unsafe fn accelerate_log(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $log(out, input, count) }
            }

            unsafe fn accelerate_log1p(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $log1p(out, input, count) }
            }

            unsafe fn accelerate_log2(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $log2(out, input, count) }
            }

            unsafe fn accelerate_log10(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $log10(out, input, count) }
            }

            unsafe fn accelerate_logb(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $logb(out, input, count) }
            }

            unsafe fn accelerate_sin(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $sin(out, input, count) }
            }

            unsafe fn accelerate_sinpi(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $sinpi(out, input, count) }
            }

            unsafe fn accelerate_cos(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $cos(out, input, count) }
            }

            unsafe fn accelerate_cospi(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $cospi(out, input, count) }
            }

            unsafe fn accelerate_tan(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $tan(out, input, count) }
            }

            unsafe fn accelerate_tanpi(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $tanpi(out, input, count) }
            }

            unsafe fn accelerate_asin(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $asin(out, input, count) }
            }

            unsafe fn accelerate_acos(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $acos(out, input, count) }
            }

            unsafe fn accelerate_atan(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $atan(out, input, count) }
            }

            unsafe fn accelerate_sinh(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $sinh(out, input, count) }
            }

            unsafe fn accelerate_cosh(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $cosh(out, input, count) }
            }

            unsafe fn accelerate_tanh(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $tanh(out, input, count) }
            }

            unsafe fn accelerate_asinh(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $asinh(out, input, count) }
            }

            unsafe fn accelerate_acosh(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $acosh(out, input, count) }
            }

            unsafe fn accelerate_atanh(out: *mut Self, input: *const Self, count: *const i32) {
                unsafe { $atanh(out, input, count) }
            }

            // ── Special cases ──

            unsafe fn accelerate_sincos(
                sin_out: *mut Self,
                cos_out: *mut Self,
                input: *const Self,
                count: *const i32,
            ) {
                unsafe { $sincos(sin_out, cos_out, input, count) }
            }

            unsafe fn accelerate_cosisin(
                out: *mut AccelerateComplex<Self>,
                input: *const Self,
                count: *const i32,
            ) {
                unsafe { $cosisin(out, input, count) }
            }
        }
    };
}
329
330impl_accelerate_float!(f64,
331    vvpow, vvdiv, vvcopysign, vvfmod, vvremainder, vvnextafter, vvatan2,
332    vvceil, vvfloor, vvfabs, vvint, vvnint, vvrsqrt, vvsqrt, vvrec,
333    vvexp, vvexp2, vvexpm1, vvlog, vvlog1p, vvlog2, vvlog10, vvlogb,
334    vvsin, vvsinpi, vvcos, vvcospi, vvtan, vvtanpi, vvasin, vvacos, vvatan,
335    vvsinh, vvcosh, vvtanh, vvasinh, vvacosh, vvatanh, vvsincos, vvcosisin
336);
337
338impl_accelerate_float!(f32,
339    vvpowf, vvdivf, vvcopysignf, vvfmodf, vvremainderf, vvnextafterf, vvatan2f,
340    vvceilf, vvfloorf, vvfabsf, vvintf, vvnintf, vvrsqrtf, vvsqrtf, vvrecf,
341    vvexpf, vvexp2f, vvexpm1f, vvlogf, vvlog1pf, vvlog2f, vvlog10f, vvlogbf,
342    vvsinf, vvsinpif, vvcosf, vvcospif, vvtanf, vvtanpif, vvasinf, vvacosf, vvatanf,
343    vvsinhf, vvcoshf, vvtanhf, vvasinhf, vvacoshf, vvatanhf, vvsincosf, vvcosisinf
344);
345
/// Generates the two public, safe wrappers for a two-input `AccelerateFloat` method.
///
/// * `$name(out, a, b)` writes the result into a separate `out` slice. All three slices must
///   have the same length, otherwise `AccelerateError::LengthMismatch` is returned before
///   anything is computed.
/// * `$name_in_place(a, b)` overwrites `a` with the result. `a` and `b` must have the same
///   length.
///
/// Arguments: optional attributes (normally doc comments) and the name of the out-of-place
/// function, optional attributes and the name of the in-place function, the `AccelerateFloat`
/// method to call, and the parameter names to use for the two inputs.
///
/// Slices are processed in pieces of at most `CHUNK` elements so that every piece's length
/// fits in the `i32` count the raw method takes.
macro_rules! binary_vforce_op {
    (
    $(#[$out_attr:meta])*
    $name:ident,
    $(#[$in_place_attr:meta])*
    $name_in_place:ident,
    $method:ident,
    $a_name:ident,
    $b_name:ident
    ) => {
        $(#[$out_attr])*
        pub fn $name<AF: AccelerateFloat>(
            out: &mut [AF], $a_name: &[AF], $b_name: &[AF]
        ) -> Result<(), AccelerateError> {
            check_lengths_2($a_name.len(), $b_name.len(), out.len())?;
            for (out_chunk, (a_chunk, b_chunk)) in out.chunks_mut(CHUNK)
                .zip($a_name.chunks(CHUNK).zip($b_name.chunks(CHUNK)))
            {
                // Cannot truncate: a chunk never holds more than `CHUNK` elements.
                let count = out_chunk.len() as i32;
                // SAFETY: the three slices have equal lengths (checked above) and are chunked
                // identically, so every pointer covers `count` elements; `out` is a `&mut`
                // slice and therefore does not overlap the inputs.
                unsafe { AF::$method(out_chunk.as_mut_ptr(), a_chunk.as_ptr(), b_chunk.as_ptr(), &count); }
            }
            Ok(())
        }
        $(#[$in_place_attr])*
        pub fn $name_in_place<AF: AccelerateFloat>(
            $a_name: &mut [AF], $b_name: &[AF]
        ) -> Result<(), AccelerateError> {
            check_lengths_1($a_name.len(), $b_name.len())?;
            for (a_chunk, b_chunk) in $a_name.chunks_mut(CHUNK).zip($b_name.chunks(CHUNK)) {
                // Cannot truncate: a chunk never holds more than `CHUNK` elements.
                let count = a_chunk.len() as i32;
                // Output and first input are the same memory. Both pointers are derived from
                // ONE raw pointer: taking a separate shared reborrow (`as_ptr()`) after
                // `as_mut_ptr()` would hand out a read-only pointer that the writes through
                // the output pointer invalidate.
                let a_ptr = a_chunk.as_mut_ptr();
                // SAFETY: `a_ptr` is valid for reads and writes of `count` elements, and
                // `b_chunk` has the same length because the slices have equal lengths
                // (checked above) and are chunked identically.
                unsafe { AF::$method(a_ptr, a_ptr as *const AF, b_chunk.as_ptr(), &count); }
            }
            Ok(())
        }
    };
}
382
/// Generates the two public, safe wrappers for a one-input `AccelerateFloat` method.
///
/// * `$name(out, input)` writes the result into a separate `out` slice. Both slices must have
///   the same length, otherwise `AccelerateError::LengthMismatch` is returned before anything
///   is computed.
/// * `$name_in_place(input)` overwrites the slice with the result and cannot fail.
///
/// Arguments: optional attributes (normally doc comments) and the name of the out-of-place
/// function, optional attributes and the name of the in-place function, the `AccelerateFloat`
/// method to call, and the parameter name to use for the input.
///
/// Slices are processed in pieces of at most `CHUNK` elements so that every piece's length
/// fits in the `i32` count the raw method takes.
macro_rules! unary_vforce_op {
    (
    $(#[$out_attr:meta])*
    $name:ident,
    $(#[$in_place_attr:meta])*
    $name_in_place:ident,
    $method:ident,
    $input_name:ident
    ) => {
        $(#[$out_attr])*
        pub fn $name<AF: AccelerateFloat>(
            out: &mut [AF], $input_name: &[AF]
        ) -> Result<(), AccelerateError> {
            check_lengths_1($input_name.len(), out.len())?;
            for (out_chunk, in_chunk) in out.chunks_mut(CHUNK).zip($input_name.chunks(CHUNK)) {
                // Cannot truncate: a chunk never holds more than `CHUNK` elements.
                let count = out_chunk.len() as i32;
                // SAFETY: both slices have equal lengths (checked above) and are chunked
                // identically, so each pointer covers `count` elements; `out` is a `&mut`
                // slice and therefore does not overlap the input.
                unsafe { AF::$method(out_chunk.as_mut_ptr(), in_chunk.as_ptr(), &count); }
            }
            Ok(())
        }
        $(#[$in_place_attr])*
        pub fn $name_in_place<AF: AccelerateFloat>(
            $input_name: &mut [AF]
        ) {
            for chunk in $input_name.chunks_mut(CHUNK) {
                // Cannot truncate: a chunk never holds more than `CHUNK` elements.
                let count = chunk.len() as i32;
                // Output and input are the same memory. Both pointers are derived from ONE
                // raw pointer: taking a separate shared reborrow (`as_ptr()`) after
                // `as_mut_ptr()` would hand out a read-only pointer that the writes through
                // the output pointer invalidate.
                let ptr = chunk.as_mut_ptr();
                // SAFETY: `ptr` is valid for reads and writes of `count` elements.
                unsafe { AF::$method(ptr, ptr as *const AF, &count); }
            }
        }
    };
}
414
/// Maximum number of elements handed to a single raw call.
///
/// The raw `AccelerateFloat` methods take their element count as an `i32`, so the wrapper
/// macros split slices into pieces of at most `i32::MAX` elements.
415pub(crate) const CHUNK: usize = i32::MAX as usize;
416
417pub(crate) fn check_lengths_1(a: usize, b: usize) -> Result<(), AccelerateError> {
418    if a != b {
419        return Err(AccelerateError::LengthMismatch { expected: a, got: b });
420    }
421    Ok(())
422}
423pub(crate) fn check_lengths_2(a: usize, b: usize, c: usize) -> Result<(), AccelerateError> {
424    if a != b {
425        return Err(AccelerateError::LengthMismatch { expected: a, got: b });
426    }
427    if a != c {
428        return Err(AccelerateError::LengthMismatch { expected: a, got: c });
429    }
430    Ok(())
431}
432
433pub mod arithmetic;
434pub mod exponential;
435pub mod trig;
436pub mod hyperbolic;
437
438#[cfg(test)]
439extern crate alloc;
440
441#[cfg(test)]
442mod tests {
443    use super::*;
444    use super::arithmetic::*;
445    use super::exponential::*;
446    use super::trig::*;
447    use super::hyperbolic::*;
448    use alloc::vec;
449    use alloc::vec::Vec;
450
    // Shared fixtures for the tests below.
    // General positive, non-zero samples (used by the reciprocal and exp/exp2 tests).
451    const INPUTS: [f64; 4] = [0.5, 1.0, 2.0, 3.5];
    // Strictly positive samples (used by the sqrt and rsqrt tests).
452    const POSITIVE: [f64; 4] = [0.25, 0.5, 1.0, 4.0];
    // Samples with magnitude at most 0.5; no use is visible in this part of the file.
453    const UNIT: [f64; 3] = [0.0, 0.25, -0.5];
    // Samples with magnitude below 1 (used by the expm1 test).
454    const SMALL: [f64; 3] = [0.1, -0.3, 0.9];
455
456    fn assert_approx(actual: &[f64], expected: &[f64], tol: f64, name: &str) {
457        assert_eq!(actual.len(), expected.len(), "{name}: length mismatch");
458        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
459            assert!(
460                (a - e).abs() < tol,
461                "{name}[{i}]: got {a}, expected {e}, diff {}",
462                (a - e).abs()
463            );
464        }
465    }
466
467    fn assert_approx_f32(actual: &[f32], expected: &[f32], tol: f32, name: &str) {
468        assert_eq!(actual.len(), expected.len(), "{name}: length mismatch");
469        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
470            assert!(
471                (a - e).abs() < tol,
472                "{name}[{i}]: got {a}, expected {e}, diff {}",
473                (a - e).abs()
474            );
475        }
476    }
477
478    // Helper to test a unary out-of-place function against a scalar reference
479    fn check_unary(
480        vforce_fn: fn(&mut [f64], &[f64]) -> Result<(), AccelerateError>,
481        scalar_fn: fn(f64) -> f64,
482        inputs: &[f64],
483        name: &str,
484    ) {
485        let mut out = vec![0.0f64; inputs.len()];
486        vforce_fn(&mut out, inputs).unwrap();
487        let expected: Vec<f64> = inputs.iter().map(|&x| scalar_fn(x)).collect();
488        assert_approx(&out, &expected, 1e-10, name);
489    }
490
491    // Helper to test a unary in-place function against a scalar reference
492    fn check_unary_in_place(
493        vforce_fn: fn(&mut [f64]),
494        scalar_fn: fn(f64) -> f64,
495        inputs: &[f64],
496        name: &str,
497    ) {
498        let mut buf: Vec<f64> = inputs.to_vec();
499        vforce_fn(&mut buf);
500        let expected: Vec<f64> = inputs.iter().map(|&x| scalar_fn(x)).collect();
501        assert_approx(&buf, &expected, 1e-10, name);
502    }
503
504    // Helper to test a binary out-of-place function against a scalar reference
505    #[allow(clippy::type_complexity)]
506    fn check_binary(
507        vforce_fn: fn(&mut [f64], &[f64], &[f64]) -> Result<(), AccelerateError>,
508        scalar_fn: fn(f64, f64) -> f64,
509        a: &[f64],
510        b: &[f64],
511        name: &str,
512    ) {
513        let mut out = vec![0.0f64; a.len()];
514        vforce_fn(&mut out, a, b).unwrap();
515        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| scalar_fn(x, y)).collect();
516        assert_approx(&out, &expected, 1e-10, name);
517    }
518
519    // Helper to test a binary in-place function against a scalar reference
520    fn check_binary_in_place(
521        vforce_fn: fn(&mut [f64], &[f64]) -> Result<(), AccelerateError>,
522        scalar_fn: fn(f64, f64) -> f64,
523        a: &[f64],
524        b: &[f64],
525        name: &str,
526    ) {
527        let mut buf: Vec<f64> = a.to_vec();
528        vforce_fn(&mut buf, b).unwrap();
529        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| scalar_fn(x, y)).collect();
530        assert_approx(&buf, &expected, 1e-10, name);
531    }
532
533    // ── Power functions ──
534
535    #[test]
536    fn test_pow_array() {
537        let bases = [2.0, 3.0, 4.0, 5.0];
538        let exponents = [3.0, 2.0, 0.5, 1.0];
539        check_binary(pow_array, f64::powf, &bases, &exponents, "pow_array");
540        check_binary_in_place(pow_array_in_place, f64::powf, &bases, &exponents, "pow_array_in_place");
541    }
542
543    // ── Arithmetic and Auxiliary functions ──
544
545    #[test]
546    fn test_div_array() {
547        let num = [10.0, 9.0, 8.0, 7.0];
548        let den = [2.0, 3.0, 4.0, 0.5];
549        check_binary(div_array, |a, b| a / b, &num, &den, "div_array");
550        check_binary_in_place(div_array_in_place, |a, b| a / b, &num, &den, "div_array_in_place");
551    }
552
553    #[test]
554    fn test_copysign_array() {
555        let mag = [1.0, -2.0, 3.0, -4.0];
556        let sign = [-1.0, 1.0, -1.0, 1.0];
557        check_binary(copysign_array, f64::copysign, &mag, &sign, "copysign_array");
558        check_binary_in_place(copysign_array_in_place, f64::copysign, &mag, &sign, "copysign_array_in_place");
559    }
560
561    #[test]
562    fn test_fmod_array() {
563        let num = [5.5, 7.0, -3.5, 10.0];
564        let den = [2.0, 3.0, 1.5, 3.0];
565        check_binary(fmod_array, |a, b| a % b, &num, &den, "fmod_array");
566        check_binary_in_place(fmod_array_in_place, |a, b| a % b, &num, &den, "fmod_array_in_place");
567    }
568
    // Placeholder: skipped by default via `#[ignore]`, and panics through `todo!` if it is
    // run explicitly. No reference implementation has been written for it yet.
569    #[test]
570    #[ignore = "stub: no core equivalent for IEEE remainder"]
571    fn test_remainder_array() {
572        todo!("IEEE remainder differs from fmod; needs manual verification");
573    }
574
    // Placeholder: skipped by default via `#[ignore]`, and panics through `todo!` if it is
    // run explicitly. No reference implementation has been written for it yet.
575    #[test]
576    #[ignore = "stub: no stable core equivalent for nextafter"]
577    fn test_nextafter_array() {
578        todo!("f64::next_up/next_down are not the same as nextafter(x, direction)");
579    }
580
581    #[test]
582    fn test_ceil_array() {
583        let inputs = [-1.5, 0.0, 0.3, 2.7];
584        check_unary(ceil_array, f64::ceil, &inputs, "ceil_array");
585        check_unary_in_place(ceil_array_in_place, f64::ceil, &inputs, "ceil_array_in_place");
586    }
587
588    #[test]
589    fn test_floor_array() {
590        let inputs = [-1.5, 0.0, 0.3, 2.7];
591        check_unary(floor_array, f64::floor, &inputs, "floor_array");
592        check_unary_in_place(floor_array_in_place, f64::floor, &inputs, "floor_array_in_place");
593    }
594
595    #[test]
596    fn test_fabs_array() {
597        let inputs = [-3.5, 0.0, 2.5, -0.1];
598        check_unary(fabs_array, f64::abs, &inputs, "fabs_array");
599        check_unary_in_place(fabs_array_in_place, f64::abs, &inputs, "fabs_array_in_place");
600    }
601
602    #[test]
603    fn test_int_array() {
604        let inputs = [-1.7, 0.0, 0.9, 2.3];
605        check_unary(int_array, f64::trunc, &inputs, "int_array");
606        check_unary_in_place(int_array_in_place, f64::trunc, &inputs, "int_array_in_place");
607    }
608
609    #[test]
610    #[ignore = "stub: vvnint rounding mode for ties may differ from f64::round"]
611    fn test_nint_array() {
612        todo!("vvnint may round ties to even vs f64::round which rounds ties away from zero");
613    }
614
615    #[test]
616    fn test_rsqrt_array() {
617        check_unary(rsqrt_array, |x| 1.0 / x.sqrt(), &POSITIVE, "rsqrt_array");
618        check_unary_in_place(rsqrt_array_in_place, |x| 1.0 / x.sqrt(), &POSITIVE, "rsqrt_array_in_place");
619    }
620
621    #[test]
622    fn test_sqrt_array() {
623        check_unary(sqrt_array, f64::sqrt, &POSITIVE, "sqrt_array");
624        check_unary_in_place(sqrt_array_in_place, f64::sqrt, &POSITIVE, "sqrt_array_in_place");
625    }
626
627    #[test]
628    fn test_rec_array() {
629        check_unary(rec_array, |x| 1.0 / x, &INPUTS, "rec_array");
630        check_unary_in_place(rec_array_in_place, |x| 1.0 / x, &INPUTS, "rec_array_in_place");
631    }
632
633    // ── Exponential and Logarithmic functions ──
634
635    #[test]
636    fn test_exp_array() {
637        check_unary(exp_array, f64::exp, &INPUTS, "exp_array");
638        check_unary_in_place(exp_array_in_place, f64::exp, &INPUTS, "exp_array_in_place");
639    }
640
641    #[test]
642    fn test_exp2_array() {
643        check_unary(exp2_array, f64::exp2, &INPUTS, "exp2_array");
644        check_unary_in_place(exp2_array_in_place, f64::exp2, &INPUTS, "exp2_array_in_place");
645    }
646
647    #[test]
648    fn test_expm1_array() {
649        check_unary(expm1_array, |x| x.exp_m1(), &SMALL, "expm1_array");
650        check_unary_in_place(expm1_array_in_place, |x| x.exp_m1(), &SMALL, "expm1_array_in_place");
651    }
652
653    #[test]
654    fn test_log_array() {
655        check_unary(log_array, f64::ln, &POSITIVE, "log_array");
656        check_unary_in_place(log_array_in_place, f64::ln, &POSITIVE, "log_array_in_place");
657    }
658
659    #[test]
660    fn test_log1p_array() {
661        check_unary(log1p_array, |x| x.ln_1p(), &SMALL, "log1p_array");
662        check_unary_in_place(log1p_array_in_place, |x| x.ln_1p(), &SMALL, "log1p_array_in_place");
663    }
664
665    #[test]
666    fn test_log2_array() {
667        check_unary(log2_array, f64::log2, &POSITIVE, "log2_array");
668        check_unary_in_place(log2_array_in_place, f64::log2, &POSITIVE, "log2_array_in_place");
669    }
670
671    #[test]
672    fn test_log10_array() {
673        check_unary(log10_array, f64::log10, &POSITIVE, "log10_array");
674        check_unary_in_place(log10_array_in_place, f64::log10, &POSITIVE, "log10_array_in_place");
675    }
676
677    #[test]
678    #[ignore = "stub: no core equivalent for logb (exponent extraction)"]
679    fn test_logb_array() {
680        todo!("logb extracts the exponent as a float; no direct core equivalent");
681    }
682
683    // ── Trigonometric functions ──
684
685    #[test]
686    fn test_sin_array() {
687        check_unary(sin_array, f64::sin, &INPUTS, "sin_array");
688        check_unary_in_place(sin_array_in_place, f64::sin, &INPUTS, "sin_array_in_place");
689    }
690
691    #[test]
692    #[ignore = "stub: no core equivalent for sinpi"]
693    fn test_sinpi_array() {
694        todo!("sinpi computes sin(x * pi); no direct core equivalent");
695    }
696
697    #[test]
698    fn test_cos_array() {
699        check_unary(cos_array, f64::cos, &INPUTS, "cos_array");
700        check_unary_in_place(cos_array_in_place, f64::cos, &INPUTS, "cos_array_in_place");
701    }
702
703    #[test]
704    #[ignore = "stub: no core equivalent for cospi"]
705    fn test_cospi_array() {
706        todo!("cospi computes cos(x * pi); no direct core equivalent");
707    }
708
709    #[test]
710    fn test_tan_array() {
711        check_unary(tan_array, f64::tan, &INPUTS, "tan_array");
712        check_unary_in_place(tan_array_in_place, f64::tan, &INPUTS, "tan_array_in_place");
713    }
714
715    #[test]
716    #[ignore = "stub: no core equivalent for tanpi"]
717    fn test_tanpi_array() {
718        todo!("tanpi computes tan(x * pi); no direct core equivalent");
719    }
720
721    #[test]
722    fn test_asin_array() {
723        check_unary(asin_array, f64::asin, &UNIT, "asin_array");
724        check_unary_in_place(asin_array_in_place, f64::asin, &UNIT, "asin_array_in_place");
725    }
726
727    #[test]
728    fn test_acos_array() {
729        check_unary(acos_array, f64::acos, &UNIT, "acos_array");
730        check_unary_in_place(acos_array_in_place, f64::acos, &UNIT, "acos_array_in_place");
731    }
732
733    #[test]
734    fn test_atan_array() {
735        check_unary(atan_array, f64::atan, &INPUTS, "atan_array");
736        check_unary_in_place(atan_array_in_place, f64::atan, &INPUTS, "atan_array_in_place");
737    }
738
739    #[test]
740    fn test_atan2_array() {
741        let y = [1.0, -1.0, 3.0, 0.0];
742        let x = [1.0, 1.0, -2.0, 5.0];
743        check_binary(atan2_array, f64::atan2, &y, &x, "atan2_array");
744        check_binary_in_place(atan2_array_in_place, f64::atan2, &y, &x, "atan2_array_in_place");
745    }
746
747    // ── Hyperbolic functions ──
748
749    #[test]
750    fn test_sinh_array() {
751        check_unary(sinh_array, f64::sinh, &INPUTS, "sinh_array");
752        check_unary_in_place(sinh_array_in_place, f64::sinh, &INPUTS, "sinh_array_in_place");
753    }
754
755    #[test]
756    fn test_cosh_array() {
757        let inputs = [0.0, 0.5, 1.0, 2.0];
758        check_unary(cosh_array, f64::cosh, &inputs, "cosh_array");
759        check_unary_in_place(cosh_array_in_place, f64::cosh, &inputs, "cosh_array_in_place");
760    }
761
762    #[test]
763    fn test_tanh_array() {
764        check_unary(tanh_array, f64::tanh, &INPUTS, "tanh_array");
765        check_unary_in_place(tanh_array_in_place, f64::tanh, &INPUTS, "tanh_array_in_place");
766    }
767
768    #[test]
769    fn test_asinh_array() {
770        check_unary(asinh_array, f64::asinh, &INPUTS, "asinh_array");
771        check_unary_in_place(asinh_array_in_place, f64::asinh, &INPUTS, "asinh_array_in_place");
772    }
773
774    #[test]
775    fn test_acosh_array() {
776        let inputs = [1.0, 1.5, 2.0, 4.0];
777        check_unary(acosh_array, f64::acosh, &inputs, "acosh_array");
778        check_unary_in_place(acosh_array_in_place, f64::acosh, &inputs, "acosh_array_in_place");
779    }
780
781    #[test]
782    fn test_atanh_array() {
783        check_unary(atanh_array, f64::atanh, &UNIT, "atanh_array");
784        check_unary_in_place(atanh_array_in_place, f64::atanh, &UNIT, "atanh_array_in_place");
785    }
786
787    // ── Special: sincos ──
788
789    #[test]
790    fn test_sincos_array() {
791        let mut sin_out = [0.0f64; 4];
792        let mut cos_out = [0.0f64; 4];
793        sincos_array(&mut sin_out, &mut cos_out, &INPUTS).unwrap();
794        let expected_sin: Vec<f64> = INPUTS.iter().map(|&x| x.sin()).collect();
795        let expected_cos: Vec<f64> = INPUTS.iter().map(|&x| x.cos()).collect();
796        assert_approx(&sin_out, &expected_sin, 1e-10, "sincos_array (sin)");
797        assert_approx(&cos_out, &expected_cos, 1e-10, "sincos_array (cos)");
798    }
799
800    // ── f32 spot check ──
801
802    #[test]
803    fn test_f32_sin_array() {
804        let inputs: [f32; 4] = [0.5, 1.0, 2.0, 3.5];
805        let mut out = [0.0f32; 4];
806        sin_array(&mut out, &inputs).unwrap();
807        let expected: Vec<f32> = inputs.iter().map(|&x| x.sin()).collect();
808        assert_approx_f32(&out, &expected, 1e-6, "sin_array (f32)");
809    }
810
811    #[test]
812    fn test_f32_pow_array() {
813        let bases: [f32; 4] = [2.0, 3.0, 4.0, 5.0];
814        let exponents: [f32; 4] = [3.0, 2.0, 0.5, 1.0];
815        let mut out = [0.0f32; 4];
816        pow_array(&mut out, &bases, &exponents).unwrap();
817        let expected: Vec<f32> = bases.iter().zip(exponents.iter()).map(|(&b, &e)| b.powf(e)).collect();
818        assert_approx_f32(&out, &expected, 1e-5, "pow_array (f32)");
819    }
820
821    // ── Error handling ──
822
823    #[test]
824    fn test_length_mismatch() {
825        let mut out = [0.0f64; 3];
826        let input = [1.0f64; 4];
827        let result = sin_array(&mut out, &input);
828        assert!(matches!(result, Err(AccelerateError::LengthMismatch { .. })));
829    }
830
831    #[test]
832    fn test_binary_length_mismatch() {
833        let mut out = [0.0f64; 4];
834        let a = [1.0f64; 4];
835        let b = [2.0f64; 3];
836        let result = pow_array(&mut out, &a, &b);
837        assert!(matches!(result, Err(AccelerateError::LengthMismatch { .. })));
838    }
839}