// ruvector_cnn/simd/mod.rs

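//! SIMD Backend Dispatch Module
//!
//! Provides architecture-specific SIMD implementations with automatic dispatch
//! (runtime CPU feature detection on x86_64, compile-time selection elsewhere):
//! - AVX-512 for modern Intel/AMD (16 floats per iteration; currently dispatched for the dot product)
//! - AVX2 with FMA for Intel Haswell+ / AMD Zen+ (8 floats per iteration)
//! - NEON for ARM64/Apple Silicon (4 floats per iteration)
//! - WASM SIMD for WebAssembly (4 floats per iteration)
//! - Winograd F(2,3) 3x3 convolutions (2.25x fewer multiplications than direct
//!   convolution; exposed via re-exports, not selected by the dispatch functions here)
//! - Scalar fallback for x86_64 CPUs without AVX2 and for all other platforms
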
11pub mod avx2;
12pub mod quantize;
13pub mod scalar;
14pub mod winograd;
15
16#[cfg(target_arch = "aarch64")]
17pub mod neon;
18
19#[cfg(target_arch = "wasm32")]
20pub mod wasm;
21
22// Re-export the dispatch functions
23pub use avx2::*;
24pub use scalar::*;
25pub use winograd::{conv_3x3_winograd, transform_filter, transform_input, transform_output, WinogradFilterCache};
26pub use quantize::{
27    QuantParams, QuantizedTensor, QuantizationType, PerChannelQuantParams,
28    quantize_simd, dequantize_simd, quantize_batch, dequantize_batch,
29    pi_constants,
30};
31
32/// SIMD-accelerated dot product with automatic architecture dispatch
33#[inline(always)]
34pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
35    #[cfg(target_arch = "x86_64")]
36    {
37        if is_x86_feature_detected!("avx512f") {
38            unsafe { avx2::dot_product_avx512(a, b) }
39        } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
40            unsafe { avx2::dot_product_avx2_fma(a, b) }
41        } else if is_x86_feature_detected!("avx2") {
42            unsafe { avx2::dot_product_avx2(a, b) }
43        } else {
44            scalar::dot_product_scalar(a, b)
45        }
46    }
47
48    #[cfg(target_arch = "aarch64")]
49    {
50        unsafe { neon::dot_product_neon(a, b) }
51    }
52
53    #[cfg(target_arch = "wasm32")]
54    {
55        wasm::dot_product_wasm(a, b)
56    }
57
58    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
59    {
60        scalar::dot_product_scalar(a, b)
61    }
62}
63
64/// SIMD-accelerated ReLU activation with automatic architecture dispatch
65#[inline(always)]
66pub fn relu_simd(input: &[f32], output: &mut [f32]) {
67    #[cfg(target_arch = "x86_64")]
68    {
69        if is_x86_feature_detected!("avx2") {
70            unsafe { avx2::relu_avx2(input, output) }
71        } else {
72            scalar::relu_scalar(input, output)
73        }
74    }
75
76    #[cfg(target_arch = "aarch64")]
77    {
78        unsafe { neon::relu_neon(input, output) }
79    }
80
81    #[cfg(target_arch = "wasm32")]
82    {
83        wasm::relu_wasm(input, output)
84    }
85
86    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
87    {
88        scalar::relu_scalar(input, output)
89    }
90}
91
92/// SIMD-accelerated ReLU6 activation with automatic architecture dispatch
93#[inline(always)]
94pub fn relu6_simd(input: &[f32], output: &mut [f32]) {
95    #[cfg(target_arch = "x86_64")]
96    {
97        if is_x86_feature_detected!("avx2") {
98            unsafe { avx2::relu6_avx2(input, output) }
99        } else {
100            scalar::relu6_scalar(input, output)
101        }
102    }
103
104    #[cfg(target_arch = "aarch64")]
105    {
106        unsafe { neon::relu6_neon(input, output) }
107    }
108
109    #[cfg(target_arch = "wasm32")]
110    {
111        wasm::relu6_wasm(input, output)
112    }
113
114    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
115    {
116        scalar::relu6_scalar(input, output)
117    }
118}
119
120/// SIMD-accelerated batch normalization with automatic architecture dispatch
121#[inline(always)]
122pub fn batch_norm_simd(
123    input: &[f32],
124    output: &mut [f32],
125    gamma: &[f32],
126    beta: &[f32],
127    mean: &[f32],
128    var: &[f32],
129    epsilon: f32,
130    channels: usize,
131) {
132    #[cfg(target_arch = "x86_64")]
133    {
134        if is_x86_feature_detected!("avx2") {
135            unsafe { avx2::batch_norm_avx2(input, output, gamma, beta, mean, var, epsilon, channels) }
136        } else {
137            scalar::batch_norm_scalar(input, output, gamma, beta, mean, var, epsilon, channels)
138        }
139    }
140
141    #[cfg(target_arch = "aarch64")]
142    {
143        unsafe { neon::batch_norm_neon(input, output, gamma, beta, mean, var, epsilon, channels) }
144    }
145
146    #[cfg(target_arch = "wasm32")]
147    {
148        wasm::batch_norm_wasm(input, output, gamma, beta, mean, var, epsilon, channels)
149    }
150
151    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
152    {
153        scalar::batch_norm_scalar(input, output, gamma, beta, mean, var, epsilon, channels)
154    }
155}
156
157/// SIMD-accelerated 3x3 convolution with automatic architecture dispatch
158#[inline(always)]
159pub fn conv_3x3_simd(
160    input: &[f32],
161    kernel: &[f32],
162    output: &mut [f32],
163    in_h: usize,
164    in_w: usize,
165    in_c: usize,
166    out_c: usize,
167    stride: usize,
168    padding: usize,
169) {
170    #[cfg(target_arch = "x86_64")]
171    {
172        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
173            unsafe {
174                avx2::conv_3x3_avx2_fma(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
175            }
176        } else if is_x86_feature_detected!("avx2") {
177            unsafe {
178                avx2::conv_3x3_avx2(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
179            }
180        } else {
181            scalar::conv_3x3_scalar(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
182        }
183    }
184
185    #[cfg(target_arch = "aarch64")]
186    {
187        unsafe {
188            neon::conv_3x3_neon(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
189        }
190    }
191
192    #[cfg(target_arch = "wasm32")]
193    {
194        wasm::conv_3x3_wasm(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
195    }
196
197    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
198    {
199        scalar::conv_3x3_scalar(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
200    }
201}
202
203/// SIMD-accelerated depthwise 3x3 convolution
204#[inline(always)]
205pub fn depthwise_conv_3x3_simd(
206    input: &[f32],
207    kernel: &[f32],
208    output: &mut [f32],
209    h: usize,
210    w: usize,
211    c: usize,
212    stride: usize,
213    padding: usize,
214) {
215    #[cfg(target_arch = "x86_64")]
216    {
217        if is_x86_feature_detected!("avx2") {
218            unsafe { avx2::depthwise_conv_3x3_avx2(input, kernel, output, h, w, c, stride, padding) }
219        } else {
220            scalar::depthwise_conv_3x3_scalar(input, kernel, output, h, w, c, stride, padding)
221        }
222    }
223
224    #[cfg(target_arch = "aarch64")]
225    {
226        unsafe { neon::depthwise_conv_3x3_neon(input, kernel, output, h, w, c, stride, padding) }
227    }
228
229    #[cfg(target_arch = "wasm32")]
230    {
231        wasm::depthwise_conv_3x3_wasm(input, kernel, output, h, w, c, stride, padding)
232    }
233
234    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
235    {
236        scalar::depthwise_conv_3x3_scalar(input, kernel, output, h, w, c, stride, padding)
237    }
238}
239
240/// SIMD-accelerated global average pooling
241#[inline(always)]
242pub fn global_avg_pool_simd(input: &[f32], output: &mut [f32], h: usize, w: usize, c: usize) {
243    #[cfg(target_arch = "x86_64")]
244    {
245        if is_x86_feature_detected!("avx2") {
246            unsafe { avx2::global_avg_pool_avx2(input, output, h, w, c) }
247        } else {
248            scalar::global_avg_pool_scalar(input, output, h, w, c)
249        }
250    }
251
252    #[cfg(target_arch = "aarch64")]
253    {
254        unsafe { neon::global_avg_pool_neon(input, output, h, w, c) }
255    }
256
257    #[cfg(target_arch = "wasm32")]
258    {
259        wasm::global_avg_pool_wasm(input, output, h, w, c)
260    }
261
262    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
263    {
264        scalar::global_avg_pool_scalar(input, output, h, w, c)
265    }
266}
267
268/// SIMD-accelerated max pooling 2x2
269#[inline(always)]
270pub fn max_pool_2x2_simd(
271    input: &[f32],
272    output: &mut [f32],
273    h: usize,
274    w: usize,
275    c: usize,
276    stride: usize,
277) {
278    #[cfg(target_arch = "x86_64")]
279    {
280        if is_x86_feature_detected!("avx2") {
281            unsafe { avx2::max_pool_2x2_avx2(input, output, h, w, c, stride) }
282        } else {
283            scalar::max_pool_2x2_scalar(input, output, h, w, c, stride)
284        }
285    }
286
287    #[cfg(target_arch = "aarch64")]
288    {
289        unsafe { neon::max_pool_2x2_neon(input, output, h, w, c, stride) }
290    }
291
292    #[cfg(target_arch = "wasm32")]
293    {
294        wasm::max_pool_2x2_wasm(input, output, h, w, c, stride)
295    }
296
297    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
298    {
299        scalar::max_pool_2x2_scalar(input, output, h, w, c, stride)
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Slice lengths that are not multiples of the 4/8/16-lane vector widths
    /// (plus the empty slice), so each backend's remainder handling is exercised
    /// and not only its main vector loop. Length 37 also reaches the 16-lane
    /// AVX-512 dot-product loop.
    const TAIL_LENGTHS: [usize; 6] = [0, 1, 3, 7, 13, 37];

    /// Deterministic ramp `(i + offset) * scale` for `i in 0..len`.
    fn ramp(len: usize, scale: f32, offset: f32) -> Vec<f32> {
        (0..len).map(|i| (i as f32 + offset) * scale).collect()
    }

    /// Values centred on zero, so activation tests see negatives, positives and
    /// (for the longest length) values above 6.
    fn signed_ramp(len: usize) -> Vec<f32> {
        ramp(len, 0.75, -(len as f32) / 2.0)
    }

    #[test]
    fn test_dot_product_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let result = dot_product_simd(&a, &b);
        let expected = scalar::dot_product_scalar(&a, &b);

        assert!((result - expected).abs() < 0.001);
    }

    #[test]
    fn test_dot_product_simd_tail_lengths() {
        for &len in TAIL_LENGTHS.iter() {
            let a = ramp(len, 0.25, 1.0);
            let b = ramp(len, -0.5, 2.0);

            let result = dot_product_simd(&a, &b);
            let expected = scalar::dot_product_scalar(&a, &b);

            // Relative tolerance: vector kernels accumulate in a different order.
            let tolerance = 1e-4 * expected.abs().max(1.0);
            assert!(
                (result - expected).abs() <= tolerance,
                "len = {}: simd = {}, scalar = {}",
                len,
                result,
                expected
            );
        }
    }

    #[test]
    fn test_relu_simd() {
        let input = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0];
        let mut output = vec![0.0; 8];

        relu_simd(&input, &mut output);

        assert_eq!(output, vec![0.0, 2.0, 0.0, 4.0, 0.0, 6.0, 0.0, 8.0]);
    }

    #[test]
    fn test_relu_simd_tail_lengths() {
        for &len in TAIL_LENGTHS.iter() {
            let input = signed_ramp(len);
            // NaN sentinel: any element the kernel fails to write breaks equality.
            let mut output = vec![f32::NAN; len];

            relu_simd(&input, &mut output);

            let expected: Vec<f32> = input.iter().map(|&x| x.max(0.0)).collect();
            assert_eq!(output, expected, "len = {}", len);
        }
    }

    #[test]
    fn test_relu6_simd() {
        let input = vec![-1.0, 2.0, 7.0, 4.0, -5.0, 10.0, 3.0, 8.0];
        let mut output = vec![0.0; 8];

        relu6_simd(&input, &mut output);

        assert_eq!(output, vec![0.0, 2.0, 6.0, 4.0, 0.0, 6.0, 3.0, 6.0]);
    }

    #[test]
    fn test_relu6_simd_tail_lengths() {
        for &len in TAIL_LENGTHS.iter() {
            let input = signed_ramp(len);
            // NaN sentinel: any element the kernel fails to write breaks equality.
            let mut output = vec![f32::NAN; len];

            relu6_simd(&input, &mut output);

            let expected: Vec<f32> = input.iter().map(|&x| x.max(0.0).min(6.0)).collect();
            assert_eq!(output, expected, "len = {}", len);
        }
    }
}