torsh_core/
simd_arm.rs

//! ARM NEON SIMD optimizations for ToRSh core operations
//!
//! This module provides optimized implementations of common operations using
//! ARM NEON SIMD instructions for improved performance on ARM64 platforms.
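//!
//! # Example
//!
//! A minimal usage sketch of the safe wrappers defined below (the
//! `torsh_core::simd_arm` path is assumed from the file location):
//!
//! ```ignore
//! use torsh_core::simd_arm::ArmSimdOps;
//!
//! let a = [1.0f32, 2.0, 3.0, 4.0];
//! let b = [5.0f32, 6.0, 7.0, 8.0];
//! let mut sum = [0.0f32; 4];
//!
//! // Dispatches to NEON when available, otherwise falls back to scalar code.
//! ArmSimdOps::add_f32_safe(&a, &b, &mut sum).unwrap();
//! let dot = ArmSimdOps::dot_product_f32_safe(&a, &b).unwrap();
//! assert_eq!(dot, 70.0);
//! ```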

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

use crate::error::{Result, TorshError};

/// ARM NEON optimized operations
pub struct ArmSimdOps;

#[cfg(target_arch = "aarch64")]
impl ArmSimdOps {
    /// Check if NEON is available at runtime
    pub fn is_neon_available() -> bool {
        std::arch::is_aarch64_feature_detected!("neon")
    }

    /// Check if Advanced SIMD is available
    pub fn is_asimd_available() -> bool {
        std::arch::is_aarch64_feature_detected!("asimd")
    }

    /// Check if FP16 arithmetic is supported
    pub fn is_fp16_available() -> bool {
        std::arch::is_aarch64_feature_detected!("fp16")
    }

    /// Check if dot product instructions are supported
    pub fn is_dotprod_available() -> bool {
        std::arch::is_aarch64_feature_detected!("dotprod")
    }

    /// Vectorized addition of f32 arrays using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn add_f32_neon(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "add_f32_neon",
            ));
        }

        let len = a.len();
        let simd_len = len & !3; // Process 4 elements at a time

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();
        let result_ptr = result.as_mut_ptr();

        // Process 4 f32 elements at a time using NEON
        for i in (0..simd_len).step_by(4) {
            let va = vld1q_f32(a_ptr.add(i));
            let vb = vld1q_f32(b_ptr.add(i));
            let vresult = vaddq_f32(va, vb);
            vst1q_f32(result_ptr.add(i), vresult);
        }

        // Handle remaining elements
        for i in simd_len..len {
            result[i] = a[i] + b[i];
        }

        Ok(())
    }

    /// Vectorized subtraction of f32 arrays using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn sub_f32_neon(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "simd_operation",
            ));
        }

        let len = a.len();
        let simd_len = len & !3;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();
        let result_ptr = result.as_mut_ptr();

        for i in (0..simd_len).step_by(4) {
            let va = vld1q_f32(a_ptr.add(i));
            let vb = vld1q_f32(b_ptr.add(i));
            let vresult = vsubq_f32(va, vb);
            vst1q_f32(result_ptr.add(i), vresult);
        }

        for i in simd_len..len {
            result[i] = a[i] - b[i];
        }

        Ok(())
    }

    /// Vectorized multiplication of f32 arrays using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn mul_f32_neon(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "simd_operation",
            ));
        }

        let len = a.len();
        let simd_len = len & !3;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();
        let result_ptr = result.as_mut_ptr();

        for i in (0..simd_len).step_by(4) {
            let va = vld1q_f32(a_ptr.add(i));
            let vb = vld1q_f32(b_ptr.add(i));
            let vresult = vmulq_f32(va, vb);
            vst1q_f32(result_ptr.add(i), vresult);
        }

        for i in simd_len..len {
            result[i] = a[i] * b[i];
        }

        Ok(())
    }

    /// Vectorized fused multiply-add operation using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn fma_f32_neon(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != c.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "simd_operation",
            ));
        }

        let len = a.len();
        let simd_len = len & !3;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();
        let c_ptr = c.as_ptr();
        let result_ptr = result.as_mut_ptr();

        for i in (0..simd_len).step_by(4) {
            let va = vld1q_f32(a_ptr.add(i));
            let vb = vld1q_f32(b_ptr.add(i));
            let vc = vld1q_f32(c_ptr.add(i));
            let vresult = vfmaq_f32(vc, va, vb); // c + a * b
            vst1q_f32(result_ptr.add(i), vresult);
        }

        for i in simd_len..len {
            result[i] = a[i] * b[i] + c[i];
        }

        Ok(())
    }

    /// Vectorized dot product using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn dot_product_f32_neon(a: &[f32], b: &[f32]) -> Result<f32> {
        if a.len() != b.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "simd_operation",
            ));
        }

        let len = a.len();
        let simd_len = len & !3;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();

        let mut sum_vec = vdupq_n_f32(0.0);

        // SIMD accumulation
        for i in (0..simd_len).step_by(4) {
            let va = vld1q_f32(a_ptr.add(i));
            let vb = vld1q_f32(b_ptr.add(i));
            let vmul = vmulq_f32(va, vb);
            sum_vec = vaddq_f32(sum_vec, vmul);
        }

        // Horizontal sum of the vector
        let sum_pair = vadd_f32(vget_low_f32(sum_vec), vget_high_f32(sum_vec));
        let sum_scalar = vpadd_f32(sum_pair, sum_pair);
        let mut result = vget_lane_f32(sum_scalar, 0);

        // Handle remaining elements
        for i in simd_len..len {
            result += a[i] * b[i];
        }

        Ok(result)
    }

    /// Optimized dot product for i8 vectors
    /// Note: the NEON `vdotq_s32` intrinsic is currently unstable in Rust, so this
    /// uses a manual multiply-accumulate implementation instead.
    #[target_feature(enable = "neon")]
    pub unsafe fn dot_product_i8_dotprod(a: &[i8], b: &[i8]) -> Result<i32> {
        if a.len() != b.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "simd_operation",
            ));
        }

        let len = a.len();
        let simd_len = len & !15; // Process 16 elements at a time

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();

        let mut sum_vec = vdupq_n_s32(0);

        // SIMD accumulation using manual multiply-accumulate
        // TODO: Re-enable vdotq_s32 when it becomes stable
        for i in (0..simd_len).step_by(16) {
            // Load 16 i8 values and process in groups of 4
            for j in 0..4 {
                let offset = i + j * 4;
                if offset < len {
                    // Load 4 i8 values and convert to i32
                    let a_vals = [
                        *a_ptr.add(offset) as i32,
                        *a_ptr.add(offset + 1) as i32,
                        *a_ptr.add(offset + 2) as i32,
                        *a_ptr.add(offset + 3) as i32,
                    ];
                    let b_vals = [
                        *b_ptr.add(offset) as i32,
                        *b_ptr.add(offset + 1) as i32,
                        *b_ptr.add(offset + 2) as i32,
                        *b_ptr.add(offset + 3) as i32,
                    ];

                    let va = vld1q_s32(a_vals.as_ptr());
                    let vb = vld1q_s32(b_vals.as_ptr());
                    sum_vec = vmlaq_s32(sum_vec, va, vb);
                }
            }
        }

        // Horizontal sum
        let sum_pair = vadd_s32(vget_low_s32(sum_vec), vget_high_s32(sum_vec));
        let sum_scalar = vpadd_s32(sum_pair, sum_pair);
        let mut result = vget_lane_s32(sum_scalar, 0);

        // Handle remaining elements
        for i in simd_len..len {
            result += a[i] as i32 * b[i] as i32;
        }

        Ok(result)
    }

    /// Vectorized sum reduction using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn sum_f32_neon(data: &[f32]) -> f32 {
        let len = data.len();
        let simd_len = len & !3;
        let data_ptr = data.as_ptr();

        let mut sum_vec = vdupq_n_f32(0.0);

        // SIMD accumulation
        for i in (0..simd_len).step_by(4) {
            let vdata = vld1q_f32(data_ptr.add(i));
            sum_vec = vaddq_f32(sum_vec, vdata);
        }

        // Horizontal sum
        let sum_pair = vadd_f32(vget_low_f32(sum_vec), vget_high_f32(sum_vec));
        let sum_scalar = vpadd_f32(sum_pair, sum_pair);
        let mut result = vget_lane_f32(sum_scalar, 0);

        // Handle remaining elements
        #[allow(clippy::needless_range_loop)] // Indexing is clearer for accumulation
        for i in simd_len..len {
            result += data[i];
        }

        result
    }

    /// Vectorized ReLU activation using NEON
    #[target_feature(enable = "neon")]
    pub unsafe fn relu_f32_neon(data: &[f32], result: &mut [f32]) -> Result<()> {
        if data.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "simd_operation",
            ));
        }

        let len = data.len();
        let simd_len = len & !3;

        let data_ptr = data.as_ptr();
        let result_ptr = result.as_mut_ptr();
        let zero_vec = vdupq_n_f32(0.0);

        for i in (0..simd_len).step_by(4) {
            let vdata = vld1q_f32(data_ptr.add(i));
            let vresult = vmaxq_f32(vdata, zero_vec);
            vst1q_f32(result_ptr.add(i), vresult);
        }

        for i in simd_len..len {
            result[i] = data[i].max(0.0);
        }

        Ok(())
    }

    /// Vectorized matrix multiplication for small matrices using NEON
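    /// Both `a` and `b` are treated as 4x4 matrices in row-major order;
    /// `result[i * 4 + j]` holds the dot product of row `i` of `a` with column `j` of `b`.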
    #[target_feature(enable = "neon")]
    pub unsafe fn matmul_f32_4x4_neon(
        a: &[f32; 16],
        b: &[f32; 16],
        result: &mut [f32; 16],
    ) -> Result<()> {
        // Load matrix A rows
        let a_row0 = vld1q_f32(a.as_ptr());
        let a_row1 = vld1q_f32(a.as_ptr().add(4));
        let a_row2 = vld1q_f32(a.as_ptr().add(8));
        let a_row3 = vld1q_f32(a.as_ptr().add(12));

        // Load matrix B columns (transposed for efficient access)
        let b_col0_arr = [b[0], b[4], b[8], b[12]];
        let b_col1_arr = [b[1], b[5], b[9], b[13]];
        let b_col2_arr = [b[2], b[6], b[10], b[14]];
        let b_col3_arr = [b[3], b[7], b[11], b[15]];

        let b_col0 = vld1q_f32(b_col0_arr.as_ptr());
        let b_col1 = vld1q_f32(b_col1_arr.as_ptr());
        let b_col2 = vld1q_f32(b_col2_arr.as_ptr());
        let b_col3 = vld1q_f32(b_col3_arr.as_ptr());

        // Compute result matrix
        let a_rows = [a_row0, a_row1, a_row2, a_row3];
        let b_cols = [b_col0, b_col1, b_col2, b_col3];

        for i in 0..4 {
            for j in 0..4 {
                let dot = vmulq_f32(a_rows[i], b_cols[j]);
                let sum_pair = vadd_f32(vget_low_f32(dot), vget_high_f32(dot));
                let sum_scalar = vpadd_f32(sum_pair, sum_pair);
                let final_sum = vget_lane_f32(sum_scalar, 0);
                result[i * 4 + j] = final_sum;
            }
        }

        Ok(())
    }

    // Note: f16 NEON operations require unstable Rust features
    // Commented out until f16 support is stabilized
    // #[cfg(feature = "fp16")]
    // #[target_feature(enable = "neon", enable = "fp16")]
    // pub unsafe fn add_f16_neon(a: &[f16], b: &[f16], result: &mut [f16]) -> Result<()> {
    //     if a.len() != b.len() || a.len() != result.len() {
    //         return Err(TorshError::dimension_error_with_context(
    //             "Array lengths must match",
    //             "simd_operation",
    //         ));
    //     }
    //
    //     let len = a.len();
    //     let simd_len = len & !7; // Process 8 f16 elements at a time
    //
    //     let a_ptr = a.as_ptr() as *const __fp16;
    //     let b_ptr = b.as_ptr() as *const __fp16;
    //     let result_ptr = result.as_mut_ptr() as *mut __fp16;
    //
    //     for i in (0..simd_len).step_by(8) {
    //         let va = vld1q_f16(a_ptr.add(i));
    //         let vb = vld1q_f16(b_ptr.add(i));
    //         let vresult = vaddq_f16(va, vb);
    //         vst1q_f16(result_ptr.add(i), vresult);
    //     }
    //
    //     // Handle remaining elements (fallback to scalar)
    //     for i in simd_len..len {
    //         result[i] = a[i] + b[i];
    //     }
    //
    //     Ok(())
    // }

    /// Optimized memcpy using NEON for large data transfers
    #[target_feature(enable = "neon")]
    pub unsafe fn memcpy_neon(src: &[u8], dest: &mut [u8]) -> Result<()> {
        if src.len() != dest.len() {
            return Err(TorshError::dimension_error_with_context(
                "Source and destination lengths must match",
                "memcpy_neon",
            ));
        }

        let len = src.len();
        let simd_len = len & !31; // Process 32 bytes at a time

        let src_ptr = src.as_ptr();
        let dest_ptr = dest.as_mut_ptr();

        // NEON optimized copy for large blocks
        for i in (0..simd_len).step_by(32) {
            let v0 = vld1q_u8(src_ptr.add(i));
            let v1 = vld1q_u8(src_ptr.add(i + 16));
            vst1q_u8(dest_ptr.add(i), v0);
            vst1q_u8(dest_ptr.add(i + 16), v1);
        }

        // Handle remaining bytes
        dest[simd_len..len].copy_from_slice(&src[simd_len..len]);

        Ok(())
    }
}

/// Safe wrapper functions for ARM SIMD operations
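/// Each wrapper checks for NEON support at runtime and falls back to the
/// scalar implementations below when it is unavailable (or on non-ARM targets).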
impl ArmSimdOps {
    /// Safe wrapper for NEON f32 addition
    pub fn add_f32_safe(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        #[cfg(target_arch = "aarch64")]
        {
            if Self::is_neon_available() {
                unsafe { Self::add_f32_neon(a, b, result) }
            } else {
                Self::add_f32_scalar(a, b, result)
            }
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            Self::add_f32_scalar(a, b, result)
        }
    }

    /// Safe wrapper for NEON f32 multiplication
    pub fn mul_f32_safe(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        #[cfg(target_arch = "aarch64")]
        {
            if Self::is_neon_available() {
                unsafe { Self::mul_f32_neon(a, b, result) }
            } else {
                Self::mul_f32_scalar(a, b, result)
            }
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            Self::mul_f32_scalar(a, b, result)
        }
    }

    /// Safe wrapper for NEON dot product
    pub fn dot_product_f32_safe(a: &[f32], b: &[f32]) -> Result<f32> {
        #[cfg(target_arch = "aarch64")]
        {
            if Self::is_neon_available() {
                unsafe { Self::dot_product_f32_neon(a, b) }
            } else {
                Self::dot_product_f32_scalar(a, b)
            }
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            Self::dot_product_f32_scalar(a, b)
        }
    }

    /// Scalar fallback for f32 addition
    fn add_f32_scalar(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "simd_operation",
            ));
        }

        for i in 0..a.len() {
            result[i] = a[i] + b[i];
        }

        Ok(())
    }

    /// Scalar fallback for f32 multiplication
    fn mul_f32_scalar(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "simd_operation",
            ));
        }

        for i in 0..a.len() {
            result[i] = a[i] * b[i];
        }

        Ok(())
    }

    /// Scalar fallback for f32 dot product
    fn dot_product_f32_scalar(a: &[f32], b: &[f32]) -> Result<f32> {
        if a.len() != b.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "simd_operation",
            ));
        }

        let mut result = 0.0;
        for i in 0..a.len() {
            result += a[i] * b[i];
        }

        Ok(result)
    }
}

#[cfg(not(target_arch = "aarch64"))]
impl ArmSimdOps {
    /// Stub implementation for non-ARM platforms
    pub fn is_neon_available() -> bool {
        false
    }
    pub fn is_asimd_available() -> bool {
        false
    }
    pub fn is_fp16_available() -> bool {
        false
    }
    pub fn is_dotprod_available() -> bool {
        false
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neon_availability() {
        #[cfg(target_arch = "aarch64")]
        {
            // Test availability checks - these should not panic
            let _ = ArmSimdOps::is_neon_available();
            let _ = ArmSimdOps::is_asimd_available();
            let _ = ArmSimdOps::is_fp16_available();
            let _ = ArmSimdOps::is_dotprod_available();
        }
    }

    #[test]
    fn test_safe_operations() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut result = vec![0.0; 8];

        // Test addition
        ArmSimdOps::add_f32_safe(&a, &b, &mut result).unwrap();
        let expected_add = vec![3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0];
        assert_eq!(result, expected_add);

        // Test multiplication
        ArmSimdOps::mul_f32_safe(&a, &b, &mut result).unwrap();
        let expected_mul = vec![2.0, 6.0, 12.0, 20.0, 30.0, 42.0, 56.0, 72.0];
        assert_eq!(result, expected_mul);

        // Test dot product
        let dot_result = ArmSimdOps::dot_product_f32_safe(&a, &b).unwrap();
        let expected_dot = 240.0; // 1*2 + 2*3 + 3*4 + 4*5 + 5*6 + 6*7 + 7*8 + 8*9
        assert_eq!(dot_result, expected_dot);
    }
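
    // Additional coverage sketch for the NEON-only kernels (ReLU, sum, FMA);
    // the unsafe paths are only exercised when NEON is detected at runtime.
    #[test]
    fn test_neon_kernels() {
        #[cfg(target_arch = "aarch64")]
        {
            if ArmSimdOps::is_neon_available() {
                let data = vec![-1.0f32, 2.0, -3.0, 4.0, -5.0, 6.0];
                let mut relu = vec![0.0f32; 6];
                unsafe { ArmSimdOps::relu_f32_neon(&data, &mut relu).unwrap() };
                assert_eq!(relu, vec![0.0, 2.0, 0.0, 4.0, 0.0, 6.0]);

                // -1 + 2 - 3 + 4 - 5 + 6 = 3
                let sum = unsafe { ArmSimdOps::sum_f32_neon(&data) };
                assert!((sum - 3.0).abs() < 1e-6);

                // a * b + c element-wise
                let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
                let b = vec![2.0f32; 5];
                let c = vec![1.0f32; 5];
                let mut fma = vec![0.0f32; 5];
                unsafe { ArmSimdOps::fma_f32_neon(&a, &b, &c, &mut fma).unwrap() };
                assert_eq!(fma, vec![3.0, 5.0, 7.0, 9.0, 11.0]);
            }
        }
    }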

    #[test]
    fn test_error_handling() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0];
        let mut result = vec![0.0; 3];

        // Test mismatched lengths
        assert!(ArmSimdOps::add_f32_safe(&a, &b, &mut result).is_err());
        assert!(ArmSimdOps::dot_product_f32_safe(&a, &b).is_err());
    }
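
    // Minimal sketch checking the 4x4 NEON matmul against the identity matrix:
    // multiplying by the identity should leave the left operand unchanged.
    #[test]
    fn test_matmul_4x4_identity() {
        #[cfg(target_arch = "aarch64")]
        {
            if ArmSimdOps::is_neon_available() {
                let mut a = [0.0f32; 16];
                for (i, v) in a.iter_mut().enumerate() {
                    *v = i as f32;
                }
                let mut identity = [0.0f32; 16];
                for i in 0..4 {
                    identity[i * 4 + i] = 1.0;
                }
                let mut result = [0.0f32; 16];
                unsafe { ArmSimdOps::matmul_f32_4x4_neon(&a, &identity, &mut result).unwrap() };
                assert_eq!(result, a);
            }
        }
    }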
}