#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

use crate::error::{Result, TorshError};

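/// ARM SIMD operations: runtime feature detection, NEON fast paths on
/// `aarch64`, and portable scalar fallbacks elsewhere.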
pub struct ArmSimdOps;

#[cfg(target_arch = "aarch64")]
impl ArmSimdOps {
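    /// Runtime check for NEON (Advanced SIMD) support.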
    pub fn is_neon_available() -> bool {
        std::arch::is_aarch64_feature_detected!("neon")
    }

    /// Advanced SIMD (ASIMD) is the AArch64 name for NEON; `std_detect`
    /// exposes it under the `"neon"` feature string, so delegate to that.
    pub fn is_asimd_available() -> bool {
        Self::is_neon_available()
    }

    pub fn is_fp16_available() -> bool {
        std::arch::is_aarch64_feature_detected!("fp16")
    }

    pub fn is_dotprod_available() -> bool {
        std::arch::is_aarch64_feature_detected!("dotprod")
    }

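    /// Element-wise addition: `result[i] = a[i] + b[i]`.
    ///
    /// Processes four lanes per iteration with 128-bit NEON vectors
    /// (`len & !3` rounds the length down to a multiple of 4) and finishes
    /// the remainder with a scalar loop.
    ///
    /// # Safety
    ///
    /// Compiled with `#[target_feature(enable = "neon")]`, so the caller
    /// must ensure NEON is available at runtime (see
    /// [`Self::is_neon_available`]).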
    #[target_feature(enable = "neon")]
    pub unsafe fn add_f32_neon(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "add_f32_neon",
            ));
        }

        let len = a.len();
        let simd_len = len & !3;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();
        let result_ptr = result.as_mut_ptr();

        for i in (0..simd_len).step_by(4) {
            let va = vld1q_f32(a_ptr.add(i));
            let vb = vld1q_f32(b_ptr.add(i));
            let vresult = vaddq_f32(va, vb);
            vst1q_f32(result_ptr.add(i), vresult);
        }

        for i in simd_len..len {
            result[i] = a[i] + b[i];
        }

        Ok(())
    }

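    /// Element-wise subtraction: `result[i] = a[i] - b[i]`.
    ///
    /// # Safety
    ///
    /// Caller must ensure NEON is available at runtime.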
    #[target_feature(enable = "neon")]
    pub unsafe fn sub_f32_neon(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "sub_f32_neon",
            ));
        }

        let len = a.len();
        let simd_len = len & !3;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();
        let result_ptr = result.as_mut_ptr();

        for i in (0..simd_len).step_by(4) {
            let va = vld1q_f32(a_ptr.add(i));
            let vb = vld1q_f32(b_ptr.add(i));
            let vresult = vsubq_f32(va, vb);
            vst1q_f32(result_ptr.add(i), vresult);
        }

        for i in simd_len..len {
            result[i] = a[i] - b[i];
        }

        Ok(())
    }

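    /// Element-wise multiplication: `result[i] = a[i] * b[i]`.
    ///
    /// # Safety
    ///
    /// Caller must ensure NEON is available at runtime.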
    #[target_feature(enable = "neon")]
    pub unsafe fn mul_f32_neon(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "mul_f32_neon",
            ));
        }

        let len = a.len();
        let simd_len = len & !3;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();
        let result_ptr = result.as_mut_ptr();

        for i in (0..simd_len).step_by(4) {
            let va = vld1q_f32(a_ptr.add(i));
            let vb = vld1q_f32(b_ptr.add(i));
            let vresult = vmulq_f32(va, vb);
            vst1q_f32(result_ptr.add(i), vresult);
        }

        for i in simd_len..len {
            result[i] = a[i] * b[i];
        }

        Ok(())
    }

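    /// Fused multiply-add: `result[i] = a[i] * b[i] + c[i]`.
    ///
    /// `vfmaq_f32(vc, va, vb)` computes `vc + va * vb` per lane with a
    /// single rounding, so the vector lanes and the scalar tail (separate
    /// multiply then add) may differ in the last bit.
    ///
    /// # Safety
    ///
    /// Caller must ensure NEON is available at runtime.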
    #[target_feature(enable = "neon")]
    pub unsafe fn fma_f32_neon(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != c.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "fma_f32_neon",
            ));
        }

        let len = a.len();
        let simd_len = len & !3;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();
        let c_ptr = c.as_ptr();
        let result_ptr = result.as_mut_ptr();

        for i in (0..simd_len).step_by(4) {
            let va = vld1q_f32(a_ptr.add(i));
            let vb = vld1q_f32(b_ptr.add(i));
            let vc = vld1q_f32(c_ptr.add(i));
            let vresult = vfmaq_f32(vc, va, vb);
            vst1q_f32(result_ptr.add(i), vresult);
        }

        for i in simd_len..len {
            result[i] = a[i] * b[i] + c[i];
        }

        Ok(())
    }

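    /// Dot product of `a` and `b`.
    ///
    /// Accumulates four partial sums per lane, then reduces the vector with
    /// pairwise adds; the accumulation order therefore differs from a
    /// scalar loop and results can vary in the low bits.
    ///
    /// # Safety
    ///
    /// Caller must ensure NEON is available at runtime.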
    #[target_feature(enable = "neon")]
    pub unsafe fn dot_product_f32_neon(a: &[f32], b: &[f32]) -> Result<f32> {
        if a.len() != b.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "dot_product_f32_neon",
            ));
        }

        let len = a.len();
        let simd_len = len & !3;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();

        let mut sum_vec = vdupq_n_f32(0.0);

        for i in (0..simd_len).step_by(4) {
            let va = vld1q_f32(a_ptr.add(i));
            let vb = vld1q_f32(b_ptr.add(i));
            let vmul = vmulq_f32(va, vb);
            sum_vec = vaddq_f32(sum_vec, vmul);
        }

        let sum_pair = vadd_f32(vget_low_f32(sum_vec), vget_high_f32(sum_vec));
        let sum_scalar = vpadd_f32(sum_pair, sum_pair);
        let mut result = vget_lane_f32(sum_scalar, 0);

        for i in simd_len..len {
            result += a[i] * b[i];
        }

        Ok(result)
    }

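    /// Dot product of two `i8` slices, accumulated in `i32`.
    ///
    /// Despite the name, this path only requires NEON: it uses widening
    /// multiplies (`vmull_s8`) and pairwise accumulation (`vpadalq_s16`)
    /// rather than the `SDOT` instruction, so it does not depend on
    /// [`Self::is_dotprod_available`].
    ///
    /// # Safety
    ///
    /// Caller must ensure NEON is available at runtime.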
    #[target_feature(enable = "neon")]
    pub unsafe fn dot_product_i8_dotprod(a: &[i8], b: &[i8]) -> Result<i32> {
        if a.len() != b.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "dot_product_i8_dotprod",
            ));
        }

        let len = a.len();
        let simd_len = len & !15;

        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();

        let mut sum_vec = vdupq_n_s32(0);

        for i in (0..simd_len).step_by(16) {
            let va = vld1q_s8(a_ptr.add(i));
            let vb = vld1q_s8(b_ptr.add(i));
            // Widening multiply: i8 x i8 -> i16 (every product fits in i16),
            // then pairwise add-accumulate the i16 lanes into the i32 sums.
            let lo = vmull_s8(vget_low_s8(va), vget_low_s8(vb));
            let hi = vmull_s8(vget_high_s8(va), vget_high_s8(vb));
            sum_vec = vpadalq_s16(sum_vec, lo);
            sum_vec = vpadalq_s16(sum_vec, hi);
        }

        let sum_pair = vadd_s32(vget_low_s32(sum_vec), vget_high_s32(sum_vec));
        let sum_scalar = vpadd_s32(sum_pair, sum_pair);
        let mut result = vget_lane_s32(sum_scalar, 0);

        for i in simd_len..len {
            result += a[i] as i32 * b[i] as i32;
        }

        Ok(result)
    }

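    /// Horizontal sum of all elements.
    ///
    /// # Safety
    ///
    /// Caller must ensure NEON is available at runtime.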
    #[target_feature(enable = "neon")]
    pub unsafe fn sum_f32_neon(data: &[f32]) -> f32 {
        let len = data.len();
        let simd_len = len & !3;
        let data_ptr = data.as_ptr();

        let mut sum_vec = vdupq_n_f32(0.0);

        for i in (0..simd_len).step_by(4) {
            let vdata = vld1q_f32(data_ptr.add(i));
            sum_vec = vaddq_f32(sum_vec, vdata);
        }

        let sum_pair = vadd_f32(vget_low_f32(sum_vec), vget_high_f32(sum_vec));
        let sum_scalar = vpadd_f32(sum_pair, sum_pair);
        let mut result = vget_lane_f32(sum_scalar, 0);

        #[allow(clippy::needless_range_loop)]
        for i in simd_len..len {
            result += data[i];
        }

        result
    }

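    /// ReLU activation: `result[i] = max(data[i], 0.0)` via `vmaxq_f32`.
    ///
    /// # Safety
    ///
    /// Caller must ensure NEON is available at runtime.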
    #[target_feature(enable = "neon")]
    pub unsafe fn relu_f32_neon(data: &[f32], result: &mut [f32]) -> Result<()> {
        if data.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "relu_f32_neon",
            ));
        }

        let len = data.len();
        let simd_len = len & !3;

        let data_ptr = data.as_ptr();
        let result_ptr = result.as_mut_ptr();
        let zero_vec = vdupq_n_f32(0.0);

        for i in (0..simd_len).step_by(4) {
            let vdata = vld1q_f32(data_ptr.add(i));
            let vresult = vmaxq_f32(vdata, zero_vec);
            vst1q_f32(result_ptr.add(i), vresult);
        }

        for i in simd_len..len {
            result[i] = data[i].max(0.0);
        }

        Ok(())
    }

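    /// 4x4 matrix multiply on row-major `[f32; 16]` buffers.
    ///
    /// Columns of `b` are gathered into temporaries so each output element
    /// is a vector dot product of a row of `a` with a column of `b`.
    ///
    /// # Safety
    ///
    /// Caller must ensure NEON is available at runtime.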
    #[target_feature(enable = "neon")]
    pub unsafe fn matmul_f32_4x4_neon(
        a: &[f32; 16],
        b: &[f32; 16],
        result: &mut [f32; 16],
    ) -> Result<()> {
        let a_row0 = vld1q_f32(a.as_ptr());
        let a_row1 = vld1q_f32(a.as_ptr().add(4));
        let a_row2 = vld1q_f32(a.as_ptr().add(8));
        let a_row3 = vld1q_f32(a.as_ptr().add(12));

        let b_col0_arr = [b[0], b[4], b[8], b[12]];
        let b_col1_arr = [b[1], b[5], b[9], b[13]];
        let b_col2_arr = [b[2], b[6], b[10], b[14]];
        let b_col3_arr = [b[3], b[7], b[11], b[15]];

        let b_col0 = vld1q_f32(b_col0_arr.as_ptr());
        let b_col1 = vld1q_f32(b_col1_arr.as_ptr());
        let b_col2 = vld1q_f32(b_col2_arr.as_ptr());
        let b_col3 = vld1q_f32(b_col3_arr.as_ptr());

        let a_rows = [a_row0, a_row1, a_row2, a_row3];
        let b_cols = [b_col0, b_col1, b_col2, b_col3];

        for i in 0..4 {
            for j in 0..4 {
                let dot = vmulq_f32(a_rows[i], b_cols[j]);
                let sum_pair = vadd_f32(vget_low_f32(dot), vget_high_f32(dot));
                let sum_scalar = vpadd_f32(sum_pair, sum_pair);
                result[i * 4 + j] = vget_lane_f32(sum_scalar, 0);
            }
        }

        Ok(())
    }

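    /// Copies `src` into `dest`, 32 bytes (two 16-byte NEON registers) per
    /// iteration, with `copy_from_slice` for the tail.
    ///
    /// # Safety
    ///
    /// Caller must ensure NEON is available at runtime.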
    #[target_feature(enable = "neon")]
    pub unsafe fn memcpy_neon(src: &[u8], dest: &mut [u8]) -> Result<()> {
        if src.len() != dest.len() {
            return Err(TorshError::dimension_error_with_context(
                "Source and destination lengths must match",
                "memcpy_neon",
            ));
        }

        let len = src.len();
        let simd_len = len & !31;

        let src_ptr = src.as_ptr();
        let dest_ptr = dest.as_mut_ptr();

        for i in (0..simd_len).step_by(32) {
            let v0 = vld1q_u8(src_ptr.add(i));
            let v1 = vld1q_u8(src_ptr.add(i + 16));
            vst1q_u8(dest_ptr.add(i), v0);
            vst1q_u8(dest_ptr.add(i + 16), v1);
        }

        dest[simd_len..len].copy_from_slice(&src[simd_len..len]);

        Ok(())
    }
}

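/// Safe entry points: on `aarch64` these dispatch to the NEON kernels when
/// the feature is detected at runtime, and otherwise (or on other
/// architectures) fall back to the portable scalar implementations.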
impl ArmSimdOps {
    pub fn add_f32_safe(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        #[cfg(target_arch = "aarch64")]
        {
            if Self::is_neon_available() {
                unsafe { Self::add_f32_neon(a, b, result) }
            } else {
                Self::add_f32_scalar(a, b, result)
            }
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            Self::add_f32_scalar(a, b, result)
        }
    }

    pub fn mul_f32_safe(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        #[cfg(target_arch = "aarch64")]
        {
            if Self::is_neon_available() {
                unsafe { Self::mul_f32_neon(a, b, result) }
            } else {
                Self::mul_f32_scalar(a, b, result)
            }
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            Self::mul_f32_scalar(a, b, result)
        }
    }

    pub fn dot_product_f32_safe(a: &[f32], b: &[f32]) -> Result<f32> {
        #[cfg(target_arch = "aarch64")]
        {
            if Self::is_neon_available() {
                unsafe { Self::dot_product_f32_neon(a, b) }
            } else {
                Self::dot_product_f32_scalar(a, b)
            }
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            Self::dot_product_f32_scalar(a, b)
        }
    }

    fn add_f32_scalar(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "add_f32_scalar",
            ));
        }

        for i in 0..a.len() {
            result[i] = a[i] + b[i];
        }

        Ok(())
    }

    fn mul_f32_scalar(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
        if a.len() != b.len() || a.len() != result.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "mul_f32_scalar",
            ));
        }

        for i in 0..a.len() {
            result[i] = a[i] * b[i];
        }

        Ok(())
    }

    fn dot_product_f32_scalar(a: &[f32], b: &[f32]) -> Result<f32> {
        if a.len() != b.len() {
            return Err(TorshError::dimension_error_with_context(
                "Array lengths must match",
                "dot_product_f32_scalar",
            ));
        }

        let mut result = 0.0;
        for i in 0..a.len() {
            result += a[i] * b[i];
        }

        Ok(result)
    }
}

#[cfg(not(target_arch = "aarch64"))]
impl ArmSimdOps {
    pub fn is_neon_available() -> bool {
        false
    }
    pub fn is_asimd_available() -> bool {
        false
    }
    pub fn is_fp16_available() -> bool {
        false
    }
    pub fn is_dotprod_available() -> bool {
        false
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neon_availability() {
        #[cfg(target_arch = "aarch64")]
        {
            let _ = ArmSimdOps::is_neon_available();
            let _ = ArmSimdOps::is_asimd_available();
            let _ = ArmSimdOps::is_fp16_available();
            let _ = ArmSimdOps::is_dotprod_available();
        }
    }

    #[test]
    fn test_safe_operations() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut result = vec![0.0; 8];

        ArmSimdOps::add_f32_safe(&a, &b, &mut result).unwrap();
        let expected_add = vec![3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0];
        assert_eq!(result, expected_add);

        ArmSimdOps::mul_f32_safe(&a, &b, &mut result).unwrap();
        let expected_mul = vec![2.0, 6.0, 12.0, 20.0, 30.0, 42.0, 56.0, 72.0];
        assert_eq!(result, expected_mul);

        let dot_result = ArmSimdOps::dot_product_f32_safe(&a, &b).unwrap();
        let expected_dot = 240.0;
        assert_eq!(dot_result, expected_dot);
    }

    #[test]
    fn test_error_handling() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0];
        let mut result = vec![0.0; 3];

        assert!(ArmSimdOps::add_f32_safe(&a, &b, &mut result).is_err());
        assert!(ArmSimdOps::dot_product_f32_safe(&a, &b).is_err());
    }
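
    // Length 5 is not a multiple of the 4-lane vector width, so this also
    // exercises the scalar tail loop behind the safe entry points.
    #[test]
    fn test_tail_handling() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![5.0f32, 4.0, 3.0, 2.0, 1.0];
        let mut result = vec![0.0f32; 5];

        ArmSimdOps::add_f32_safe(&a, &b, &mut result).unwrap();
        assert_eq!(result, vec![6.0; 5]);

        // 1*5 + 2*4 + 3*3 + 4*2 + 5*1 = 35
        let dot = ArmSimdOps::dot_product_f32_safe(&a, &b).unwrap();
        assert_eq!(dot, 35.0);
    }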
}