1pub mod avx2;
12pub mod quantize;
13pub mod scalar;
14pub mod winograd;
15
16#[cfg(target_arch = "aarch64")]
17pub mod neon;
18
19#[cfg(target_arch = "wasm32")]
20pub mod wasm;
21
22pub use avx2::*;
24pub use scalar::*;
25pub use winograd::{conv_3x3_winograd, transform_filter, transform_input, transform_output, WinogradFilterCache};
26pub use quantize::{
27 QuantParams, QuantizedTensor, QuantizationType, PerChannelQuantParams,
28 quantize_simd, dequantize_simd, quantize_batch, dequantize_batch,
29 pi_constants,
30};
31
32#[inline(always)]
34pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
35 #[cfg(target_arch = "x86_64")]
36 {
37 if is_x86_feature_detected!("avx512f") {
38 unsafe { avx2::dot_product_avx512(a, b) }
39 } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
40 unsafe { avx2::dot_product_avx2_fma(a, b) }
41 } else if is_x86_feature_detected!("avx2") {
42 unsafe { avx2::dot_product_avx2(a, b) }
43 } else {
44 scalar::dot_product_scalar(a, b)
45 }
46 }
47
48 #[cfg(target_arch = "aarch64")]
49 {
50 unsafe { neon::dot_product_neon(a, b) }
51 }
52
53 #[cfg(target_arch = "wasm32")]
54 {
55 wasm::dot_product_wasm(a, b)
56 }
57
58 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
59 {
60 scalar::dot_product_scalar(a, b)
61 }
62}
63
64#[inline(always)]
66pub fn relu_simd(input: &[f32], output: &mut [f32]) {
67 #[cfg(target_arch = "x86_64")]
68 {
69 if is_x86_feature_detected!("avx2") {
70 unsafe { avx2::relu_avx2(input, output) }
71 } else {
72 scalar::relu_scalar(input, output)
73 }
74 }
75
76 #[cfg(target_arch = "aarch64")]
77 {
78 unsafe { neon::relu_neon(input, output) }
79 }
80
81 #[cfg(target_arch = "wasm32")]
82 {
83 wasm::relu_wasm(input, output)
84 }
85
86 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
87 {
88 scalar::relu_scalar(input, output)
89 }
90}
91
92#[inline(always)]
94pub fn relu6_simd(input: &[f32], output: &mut [f32]) {
95 #[cfg(target_arch = "x86_64")]
96 {
97 if is_x86_feature_detected!("avx2") {
98 unsafe { avx2::relu6_avx2(input, output) }
99 } else {
100 scalar::relu6_scalar(input, output)
101 }
102 }
103
104 #[cfg(target_arch = "aarch64")]
105 {
106 unsafe { neon::relu6_neon(input, output) }
107 }
108
109 #[cfg(target_arch = "wasm32")]
110 {
111 wasm::relu6_wasm(input, output)
112 }
113
114 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
115 {
116 scalar::relu6_scalar(input, output)
117 }
118}
119
120#[inline(always)]
122pub fn batch_norm_simd(
123 input: &[f32],
124 output: &mut [f32],
125 gamma: &[f32],
126 beta: &[f32],
127 mean: &[f32],
128 var: &[f32],
129 epsilon: f32,
130 channels: usize,
131) {
132 #[cfg(target_arch = "x86_64")]
133 {
134 if is_x86_feature_detected!("avx2") {
135 unsafe { avx2::batch_norm_avx2(input, output, gamma, beta, mean, var, epsilon, channels) }
136 } else {
137 scalar::batch_norm_scalar(input, output, gamma, beta, mean, var, epsilon, channels)
138 }
139 }
140
141 #[cfg(target_arch = "aarch64")]
142 {
143 unsafe { neon::batch_norm_neon(input, output, gamma, beta, mean, var, epsilon, channels) }
144 }
145
146 #[cfg(target_arch = "wasm32")]
147 {
148 wasm::batch_norm_wasm(input, output, gamma, beta, mean, var, epsilon, channels)
149 }
150
151 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
152 {
153 scalar::batch_norm_scalar(input, output, gamma, beta, mean, var, epsilon, channels)
154 }
155}
156
157#[inline(always)]
159pub fn conv_3x3_simd(
160 input: &[f32],
161 kernel: &[f32],
162 output: &mut [f32],
163 in_h: usize,
164 in_w: usize,
165 in_c: usize,
166 out_c: usize,
167 stride: usize,
168 padding: usize,
169) {
170 #[cfg(target_arch = "x86_64")]
171 {
172 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
173 unsafe {
174 avx2::conv_3x3_avx2_fma(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
175 }
176 } else if is_x86_feature_detected!("avx2") {
177 unsafe {
178 avx2::conv_3x3_avx2(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
179 }
180 } else {
181 scalar::conv_3x3_scalar(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
182 }
183 }
184
185 #[cfg(target_arch = "aarch64")]
186 {
187 unsafe {
188 neon::conv_3x3_neon(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
189 }
190 }
191
192 #[cfg(target_arch = "wasm32")]
193 {
194 wasm::conv_3x3_wasm(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
195 }
196
197 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
198 {
199 scalar::conv_3x3_scalar(input, kernel, output, in_h, in_w, in_c, out_c, stride, padding)
200 }
201}
202
203#[inline(always)]
205pub fn depthwise_conv_3x3_simd(
206 input: &[f32],
207 kernel: &[f32],
208 output: &mut [f32],
209 h: usize,
210 w: usize,
211 c: usize,
212 stride: usize,
213 padding: usize,
214) {
215 #[cfg(target_arch = "x86_64")]
216 {
217 if is_x86_feature_detected!("avx2") {
218 unsafe { avx2::depthwise_conv_3x3_avx2(input, kernel, output, h, w, c, stride, padding) }
219 } else {
220 scalar::depthwise_conv_3x3_scalar(input, kernel, output, h, w, c, stride, padding)
221 }
222 }
223
224 #[cfg(target_arch = "aarch64")]
225 {
226 unsafe { neon::depthwise_conv_3x3_neon(input, kernel, output, h, w, c, stride, padding) }
227 }
228
229 #[cfg(target_arch = "wasm32")]
230 {
231 wasm::depthwise_conv_3x3_wasm(input, kernel, output, h, w, c, stride, padding)
232 }
233
234 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
235 {
236 scalar::depthwise_conv_3x3_scalar(input, kernel, output, h, w, c, stride, padding)
237 }
238}
239
240#[inline(always)]
242pub fn global_avg_pool_simd(input: &[f32], output: &mut [f32], h: usize, w: usize, c: usize) {
243 #[cfg(target_arch = "x86_64")]
244 {
245 if is_x86_feature_detected!("avx2") {
246 unsafe { avx2::global_avg_pool_avx2(input, output, h, w, c) }
247 } else {
248 scalar::global_avg_pool_scalar(input, output, h, w, c)
249 }
250 }
251
252 #[cfg(target_arch = "aarch64")]
253 {
254 unsafe { neon::global_avg_pool_neon(input, output, h, w, c) }
255 }
256
257 #[cfg(target_arch = "wasm32")]
258 {
259 wasm::global_avg_pool_wasm(input, output, h, w, c)
260 }
261
262 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
263 {
264 scalar::global_avg_pool_scalar(input, output, h, w, c)
265 }
266}
267
268#[inline(always)]
270pub fn max_pool_2x2_simd(
271 input: &[f32],
272 output: &mut [f32],
273 h: usize,
274 w: usize,
275 c: usize,
276 stride: usize,
277) {
278 #[cfg(target_arch = "x86_64")]
279 {
280 if is_x86_feature_detected!("avx2") {
281 unsafe { avx2::max_pool_2x2_avx2(input, output, h, w, c, stride) }
282 } else {
283 scalar::max_pool_2x2_scalar(input, output, h, w, c, stride)
284 }
285 }
286
287 #[cfg(target_arch = "aarch64")]
288 {
289 unsafe { neon::max_pool_2x2_neon(input, output, h, w, c, stride) }
290 }
291
292 #[cfg(target_arch = "wasm32")]
293 {
294 wasm::max_pool_2x2_wasm(input, output, h, w, c, stride)
295 }
296
297 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "wasm32")))]
298 {
299 scalar::max_pool_2x2_scalar(input, output, h, w, c, stride)
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Input lengths covering empty input, lengths shorter than one SIMD vector,
    /// exact 8/16-lane multiples, and multiples plus a remainder, so that both
    /// the vectorized body and any tail handling in the kernels are exercised.
    const LENGTHS: [usize; 10] = [0, 1, 3, 7, 8, 9, 15, 16, 17, 37];

    /// Deterministic mixed-sign data: `i * scale + offset` for `i in 0..len`.
    fn ramp(len: usize, scale: f32, offset: f32) -> Vec<f32> {
        (0..len).map(|i| i as f32 * scale + offset).collect()
    }

    #[test]
    fn test_dot_product_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let result = dot_product_simd(&a, &b);
        let expected = scalar::dot_product_scalar(&a, &b);

        assert!((result - expected).abs() < 0.001);
    }

    #[test]
    fn test_dot_product_simd_matches_scalar_across_lengths() {
        for &len in &LENGTHS {
            let a = ramp(len, 0.25, -3.0);
            let b = ramp(len, -0.125, 1.5);

            let result = dot_product_simd(&a, &b);
            let expected = scalar::dot_product_scalar(&a, &b);

            // SIMD kernels accumulate in a different order (and possibly with FMA)
            // than the scalar one, so scale the tolerance by the summand magnitude
            // instead of using a fixed absolute bound.
            let magnitude: f32 = a.iter().zip(&b).map(|(x, y)| (x * y).abs()).sum();
            let tolerance = 1e-5 * magnitude.max(1.0);
            assert!(
                (result - expected).abs() <= tolerance,
                "len={}: simd={}, scalar={}",
                len,
                result,
                expected
            );
        }
    }

    #[test]
    fn test_relu_simd() {
        let input = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0];
        let mut output = vec![0.0; 8];

        relu_simd(&input, &mut output);

        assert_eq!(output, vec![0.0, 2.0, 0.0, 4.0, 0.0, 6.0, 0.0, 8.0]);
    }

    #[test]
    fn test_relu_simd_across_lengths() {
        for &len in &LENGTHS {
            let input = ramp(len, 0.75, -4.0);
            // -1.0 can never be a ReLU result, so any element the kernel fails
            // to write shows up as a mismatch.
            let mut output = vec![-1.0; len];

            relu_simd(&input, &mut output);

            let expected: Vec<f32> = input.iter().map(|&x| x.max(0.0)).collect();
            assert_eq!(output, expected, "len={}", len);
        }
    }

    #[test]
    fn test_relu6_simd() {
        let input = vec![-1.0, 2.0, 7.0, 4.0, -5.0, 10.0, 3.0, 8.0];
        let mut output = vec![0.0; 8];

        relu6_simd(&input, &mut output);

        assert_eq!(output, vec![0.0, 2.0, 6.0, 4.0, 0.0, 6.0, 3.0, 6.0]);
    }

    #[test]
    fn test_relu6_simd_across_lengths() {
        for &len in &LENGTHS {
            // Spans roughly [-4, 23], so both clamp bounds are hit for longer inputs.
            let input = ramp(len, 0.75, -4.0);
            // -1.0 can never be a ReLU6 result, so unwritten elements are detected.
            let mut output = vec![-1.0; len];

            relu6_simd(&input, &mut output);

            let expected: Vec<f32> = input.iter().map(|&x| x.max(0.0).min(6.0)).collect();
            assert_eq!(output, expected, "len={}", len);
        }
    }
}