1use crate::marshal::{register_typed_fn_1, register_typed_fn_2, register_typed_fn_3};
24use crate::module_exports::ModuleExports;
25use crate::typed_module_exports::{ConcreteReturn, ConcreteType, TypedReturn};
26use shape_value::aligned_vec::AlignedVec;
27use std::sync::Arc;
28use wide::f64x4;
29
/// Minimum input length at which the element-wise helpers in this file switch
/// to four-lane `f64x4` processing. Shorter inputs are handled one element at a
/// time with the scalar operation.
const SIMD_THRESHOLD: usize = 16;
31
32pub fn create_vector_intrinsics_module() -> ModuleExports {
36 let mut module = ModuleExports::new("std::core::intrinsics::vector");
37 module.description = "SIMD vector element-wise intrinsics".to_string();
38
39 register_typed_fn_1::<_, Arc<Vec<f64>>>(
40 &mut module,
41 "__intrinsic_vec_abs",
42 "Element-wise absolute value of a Vec<number>",
43 "input",
44 "Array<number>",
45 ConcreteType::ArrayNumber,
46 |input, _ctx| {
47 let result = unary_apply(input.as_slice(), |v| v.abs(), f64::abs);
48 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(result)))
49 },
50 );
51
52 register_typed_fn_1::<_, Arc<Vec<f64>>>(
53 &mut module,
54 "__intrinsic_vec_sqrt",
55 "Element-wise square root of a Vec<number>",
56 "input",
57 "Array<number>",
58 ConcreteType::ArrayNumber,
59 |input, _ctx| {
60 let result = unary_apply(input.as_slice(), |v| v.sqrt(), f64::sqrt);
61 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(result)))
62 },
63 );
64
65 register_typed_fn_1::<_, Arc<Vec<f64>>>(
67 &mut module,
68 "__intrinsic_vec_ln",
69 "Element-wise natural logarithm of a Vec<number>",
70 "input",
71 "Array<number>",
72 ConcreteType::ArrayNumber,
73 |input, _ctx| {
74 let result: Vec<f64> = input.as_slice().iter().map(|x| x.ln()).collect();
75 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(result)))
76 },
77 );
78
79 register_typed_fn_1::<_, Arc<Vec<f64>>>(
80 &mut module,
81 "__intrinsic_vec_exp",
82 "Element-wise exponential of a Vec<number>",
83 "input",
84 "Array<number>",
85 ConcreteType::ArrayNumber,
86 |input, _ctx| {
87 let result: Vec<f64> = input.as_slice().iter().map(|x| x.exp()).collect();
88 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(result)))
89 },
90 );
91
92 register_typed_fn_2::<_, Arc<Vec<f64>>, Arc<Vec<f64>>>(
93 &mut module,
94 "__intrinsic_vec_add",
95 "Element-wise addition of two Vec<number>",
96 [("a", "Array<number>"), ("b", "Array<number>")],
97 ConcreteType::ArrayNumber,
98 |a, b, _ctx| {
99 check_lens(a.len(), b.len(), "vec_add")?;
100 let result = binary_apply(a.as_slice(), b.as_slice(), |va, vb| va + vb, |x, y| x + y);
101 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(result)))
102 },
103 );
104
105 register_typed_fn_2::<_, Arc<Vec<f64>>, Arc<Vec<f64>>>(
106 &mut module,
107 "__intrinsic_vec_sub",
108 "Element-wise subtraction of two Vec<number>",
109 [("a", "Array<number>"), ("b", "Array<number>")],
110 ConcreteType::ArrayNumber,
111 |a, b, _ctx| {
112 check_lens(a.len(), b.len(), "vec_sub")?;
113 let result = binary_apply(a.as_slice(), b.as_slice(), |va, vb| va - vb, |x, y| x - y);
114 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(result)))
115 },
116 );
117
118 register_typed_fn_2::<_, Arc<Vec<f64>>, Arc<Vec<f64>>>(
119 &mut module,
120 "__intrinsic_vec_mul",
121 "Element-wise multiplication of two Vec<number>",
122 [("a", "Array<number>"), ("b", "Array<number>")],
123 ConcreteType::ArrayNumber,
124 |a, b, _ctx| {
125 check_lens(a.len(), b.len(), "vec_mul")?;
126 let result = binary_apply(a.as_slice(), b.as_slice(), |va, vb| va * vb, |x, y| x * y);
127 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(result)))
128 },
129 );
130
131 register_typed_fn_2::<_, Arc<Vec<f64>>, Arc<Vec<f64>>>(
132 &mut module,
133 "__intrinsic_vec_div",
134 "Element-wise division of two Vec<number>",
135 [("a", "Array<number>"), ("b", "Array<number>")],
136 ConcreteType::ArrayNumber,
137 |a, b, _ctx| {
138 check_lens(a.len(), b.len(), "vec_div")?;
139 let result = binary_apply(a.as_slice(), b.as_slice(), |va, vb| va / vb, |x, y| x / y);
140 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(result)))
141 },
142 );
143
144 register_typed_fn_2::<_, Arc<Vec<f64>>, Arc<Vec<f64>>>(
145 &mut module,
146 "__intrinsic_vec_max",
147 "Element-wise max of two Vec<number>",
148 [("a", "Array<number>"), ("b", "Array<number>")],
149 ConcreteType::ArrayNumber,
150 |a, b, _ctx| {
151 check_lens(a.len(), b.len(), "vec_max")?;
152 let result = binary_apply(a.as_slice(), b.as_slice(), |va, vb| va.max(vb), f64::max);
153 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(result)))
154 },
155 );
156
157 register_typed_fn_2::<_, Arc<Vec<f64>>, Arc<Vec<f64>>>(
158 &mut module,
159 "__intrinsic_vec_min",
160 "Element-wise min of two Vec<number>",
161 [("a", "Array<number>"), ("b", "Array<number>")],
162 ConcreteType::ArrayNumber,
163 |a, b, _ctx| {
164 check_lens(a.len(), b.len(), "vec_min")?;
165 let result = binary_apply(a.as_slice(), b.as_slice(), |va, vb| va.min(vb), f64::min);
166 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(result)))
167 },
168 );
169
170 register_typed_fn_3::<_, Arc<Vec<f64>>, Arc<Vec<f64>>, Arc<Vec<f64>>>(
171 &mut module,
172 "__intrinsic_vec_select",
173 "Element-wise select: cond[i] != 0 ? t[i] : f[i]",
174 [
175 ("cond", "Array<number>"),
176 ("t", "Array<number>"),
177 ("f", "Array<number>"),
178 ],
179 ConcreteType::ArrayNumber,
180 |cond, t, f, _ctx| {
181 let n = cond.len();
182 if t.len() != n || f.len() != n {
183 return Err(format!(
184 "vec_select: length mismatch cond={}, t={}, f={}",
185 n,
186 t.len(),
187 f.len()
188 ));
189 }
190 let cond_data = cond.as_slice();
191 let t_data = t.as_slice();
192 let f_data = f.as_slice();
193 let mut result = Vec::with_capacity(n);
194 for i in 0..n {
195 result.push(if cond_data[i] != 0.0 {
196 t_data[i]
197 } else {
198 f_data[i]
199 });
200 }
201 Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(result)))
202 },
203 );
204
205 register_typed_fn_2::<_, Arc<Vec<i64>>, Arc<Vec<i64>>>(
206 &mut module,
207 "__intrinsic_vec_add_i64",
208 "Element-wise addition of two Vec<int>, overflow-checked",
209 [("a", "Array<int>"), ("b", "Array<int>")],
210 ConcreteType::ArrayInt,
211 |a, b, _ctx| {
212 check_lens(a.len(), b.len(), "vec_add_i64")?;
213 simd_vec_add_i64(a.as_slice(), b.as_slice())
214 .map(|r| TypedReturn::Concrete(ConcreteReturn::ArrayI64(r)))
215 .map_err(|()| "Integer overflow in Vec<int> element-wise addition".to_string())
216 },
217 );
218
219 module
220}
221
/// Checks that two vector operands have the same length.
///
/// Returns `Ok(())` when `a == b`. Otherwise it returns an error message that
/// names the operation `name` and both lengths.
#[inline]
fn check_lens(a: usize, b: usize, name: &str) -> Result<(), String> {
    if a == b {
        return Ok(());
    }
    Err(format!("Vector length mismatch in {}: {} vs {}", name, a, b))
}
232
233fn unary_apply(
235 data: &[f64],
236 simd_op: impl Fn(f64x4) -> f64x4,
237 scalar_op: impl Fn(f64) -> f64,
238) -> Vec<f64> {
239 let len = data.len();
240 let mut result = vec![0.0; len];
241 if len >= SIMD_THRESHOLD {
242 let chunks = len / 4;
243 for i in 0..chunks {
244 let idx = i * 4;
245 let v = f64x4::from(&data[idx..idx + 4]);
246 let res = simd_op(v);
247 result[idx..idx + 4].copy_from_slice(&res.to_array());
248 }
249 for i in (chunks * 4)..len {
250 result[i] = scalar_op(data[i]);
251 }
252 } else {
253 for i in 0..len {
254 result[i] = scalar_op(data[i]);
255 }
256 }
257 result
258}
259
260fn binary_apply(
263 a: &[f64],
264 b: &[f64],
265 simd_op: impl Fn(f64x4, f64x4) -> f64x4,
266 scalar_op: impl Fn(f64, f64) -> f64,
267) -> Vec<f64> {
268 let len = a.len();
269 let mut result = vec![0.0; len];
270 if len >= SIMD_THRESHOLD {
271 let chunks = len / 4;
272 for i in 0..chunks {
273 let idx = i * 4;
274 let va = f64x4::from(&a[idx..idx + 4]);
275 let vb = f64x4::from(&b[idx..idx + 4]);
276 let res = simd_op(va, vb);
277 result[idx..idx + 4].copy_from_slice(&res.to_array());
278 }
279 for i in (chunks * 4)..len {
280 result[i] = scalar_op(a[i], b[i]);
281 }
282 } else {
283 for i in 0..len {
284 result[i] = scalar_op(a[i], b[i]);
285 }
286 }
287 result
288}
289
290pub fn simd_vec_add_f64(a: &[f64], b: &[f64]) -> AlignedVec<f64> {
300 debug_assert_eq!(a.len(), b.len());
301 let len = a.len();
302 let mut result = AlignedVec::with_capacity(len);
303 if len >= SIMD_THRESHOLD {
304 let chunks = len / 4;
305 for i in 0..chunks {
306 let idx = i * 4;
307 let va = f64x4::from(&a[idx..idx + 4]);
308 let vb = f64x4::from(&b[idx..idx + 4]);
309 let res = va + vb;
310 for &v in res.to_array().iter() {
311 result.push(v);
312 }
313 }
314 for i in (chunks * 4)..len {
315 result.push(a[i] + b[i]);
316 }
317 } else {
318 for i in 0..len {
319 result.push(a[i] + b[i]);
320 }
321 }
322 result
323}
324
325pub fn simd_vec_sub_f64(a: &[f64], b: &[f64]) -> AlignedVec<f64> {
327 debug_assert_eq!(a.len(), b.len());
328 let len = a.len();
329 let mut result = AlignedVec::with_capacity(len);
330 if len >= SIMD_THRESHOLD {
331 let chunks = len / 4;
332 for i in 0..chunks {
333 let idx = i * 4;
334 let va = f64x4::from(&a[idx..idx + 4]);
335 let vb = f64x4::from(&b[idx..idx + 4]);
336 let res = va - vb;
337 for &v in res.to_array().iter() {
338 result.push(v);
339 }
340 }
341 for i in (chunks * 4)..len {
342 result.push(a[i] - b[i]);
343 }
344 } else {
345 for i in 0..len {
346 result.push(a[i] - b[i]);
347 }
348 }
349 result
350}
351
352pub fn simd_vec_mul_f64(a: &[f64], b: &[f64]) -> AlignedVec<f64> {
354 debug_assert_eq!(a.len(), b.len());
355 let len = a.len();
356 let mut result = AlignedVec::with_capacity(len);
357 if len >= SIMD_THRESHOLD {
358 let chunks = len / 4;
359 for i in 0..chunks {
360 let idx = i * 4;
361 let va = f64x4::from(&a[idx..idx + 4]);
362 let vb = f64x4::from(&b[idx..idx + 4]);
363 let res = va * vb;
364 for &v in res.to_array().iter() {
365 result.push(v);
366 }
367 }
368 for i in (chunks * 4)..len {
369 result.push(a[i] * b[i]);
370 }
371 } else {
372 for i in 0..len {
373 result.push(a[i] * b[i]);
374 }
375 }
376 result
377}
378
379pub fn simd_vec_div_f64(a: &[f64], b: &[f64]) -> AlignedVec<f64> {
381 debug_assert_eq!(a.len(), b.len());
382 let len = a.len();
383 let mut result = AlignedVec::with_capacity(len);
384 if len >= SIMD_THRESHOLD {
385 let chunks = len / 4;
386 for i in 0..chunks {
387 let idx = i * 4;
388 let va = f64x4::from(&a[idx..idx + 4]);
389 let vb = f64x4::from(&b[idx..idx + 4]);
390 let res = va / vb;
391 for &v in res.to_array().iter() {
392 result.push(v);
393 }
394 }
395 for i in (chunks * 4)..len {
396 result.push(a[i] / b[i]);
397 }
398 } else {
399 for i in 0..len {
400 result.push(a[i] / b[i]);
401 }
402 }
403 result
404}
405
406pub fn simd_vec_scale_f64(a: &[f64], scalar: f64) -> AlignedVec<f64> {
408 let len = a.len();
409 let mut result = AlignedVec::with_capacity(len);
410 if len >= SIMD_THRESHOLD {
411 let s_vec = f64x4::splat(scalar);
412 let chunks = len / 4;
413 for i in 0..chunks {
414 let idx = i * 4;
415 let va = f64x4::from(&a[idx..idx + 4]);
416 let res = va * s_vec;
417 for &v in res.to_array().iter() {
418 result.push(v);
419 }
420 }
421 for i in (chunks * 4)..len {
422 result.push(a[i] * scalar);
423 }
424 } else {
425 for i in 0..len {
426 result.push(a[i] * scalar);
427 }
428 }
429 result
430}
431
/// Element-wise checked sum `a[i] + b[i]`.
///
/// Returns `Err(())` as soon as any addition overflows `i64`. Equal lengths
/// are required; this is only checked by a debug assertion.
pub fn simd_vec_add_i64(a: &[i64], b: &[i64]) -> std::result::Result<Vec<i64>, ()> {
    debug_assert_eq!(a.len(), b.len());
    (0..a.len())
        .try_fold(Vec::with_capacity(a.len()), |mut acc, i| {
            acc.push(a[i].checked_add(b[i])?);
            Some(acc)
        })
        .ok_or(())
}
446
/// Element-wise checked difference `a[i] - b[i]`.
///
/// Returns `Err(())` as soon as any subtraction overflows `i64`. Equal lengths
/// are required; this is only checked by a debug assertion.
pub fn simd_vec_sub_i64(a: &[i64], b: &[i64]) -> std::result::Result<Vec<i64>, ()> {
    debug_assert_eq!(a.len(), b.len());
    (0..a.len())
        .try_fold(Vec::with_capacity(a.len()), |mut acc, i| {
            acc.push(a[i].checked_sub(b[i])?);
            Some(acc)
        })
        .ok_or(())
}
460
/// Element-wise checked product `a[i] * b[i]`.
///
/// Returns `Err(())` as soon as any multiplication overflows `i64`. Equal
/// lengths are required; this is only checked by a debug assertion.
pub fn simd_vec_mul_i64(a: &[i64], b: &[i64]) -> std::result::Result<Vec<i64>, ()> {
    debug_assert_eq!(a.len(), b.len());
    (0..a.len())
        .try_fold(Vec::with_capacity(a.len()), |mut acc, i| {
            acc.push(a[i].checked_mul(b[i])?);
            Some(acc)
        })
        .ok_or(())
}
474
/// Element-wise checked integer quotient `a[i] / b[i]`, truncating toward zero.
///
/// Returns `Err(())` as soon as any divisor is zero or the division overflows
/// (`i64::MIN / -1`). `checked_div` reports both cases as `None`. Equal lengths
/// are required; this is only checked by a debug assertion.
pub fn simd_vec_div_i64(a: &[i64], b: &[i64]) -> std::result::Result<Vec<i64>, ()> {
    debug_assert_eq!(a.len(), b.len());
    (0..a.len())
        .try_fold(Vec::with_capacity(a.len()), |mut acc, i| {
            acc.push(a[i].checked_div(b[i])?);
            Some(acc)
        })
        .ok_or(())
}
491
492pub fn i64_slice_to_f64(data: &[i64]) -> AlignedVec<f64> {
494 let mut result = AlignedVec::with_capacity(data.len());
495 for &v in data {
496 result.push(v as f64);
497 }
498 result
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_vec_add_f64_small() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let result = simd_vec_add_f64(&a, &b);
        assert_eq!(&*result, &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_simd_vec_add_f64_large() {
        let a: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let b: Vec<f64> = (0..20).map(|i| (i * 2) as f64).collect();
        let result = simd_vec_add_f64(&a, &b);
        for i in 0..20 {
            assert_eq!(result[i], (i * 3) as f64);
        }
    }

    // 19 elements: four SIMD chunks plus a three-element scalar tail. The
    // length-20 test above divides evenly by 4 and never reaches the tail.
    #[test]
    fn test_simd_vec_add_f64_with_tail() {
        let a: Vec<f64> = (0..19).map(|i| i as f64).collect();
        let b: Vec<f64> = (0..19).map(|i| (i * 2) as f64).collect();
        let result = simd_vec_add_f64(&a, &b);
        assert_eq!(result.len(), 19);
        for i in 0..19 {
            assert_eq!(result[i], (i * 3) as f64);
        }
    }

    #[test]
    fn test_simd_vec_sub_f64() {
        let a = [10.0, 20.0, 30.0];
        let b = [3.0, 5.0, 7.0];
        let result = simd_vec_sub_f64(&a, &b);
        assert_eq!(&*result, &[7.0, 15.0, 23.0]);
    }

    #[test]
    fn test_simd_vec_mul_f64() {
        let a = [2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0];
        let result = simd_vec_mul_f64(&a, &b);
        assert_eq!(&*result, &[10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_simd_vec_mul_f64_with_tail() {
        let a: Vec<f64> = (0..17).map(|i| i as f64).collect();
        let b = vec![3.0; 17];
        let result = simd_vec_mul_f64(&a, &b);
        assert_eq!(result.len(), 17);
        for i in 0..17 {
            assert_eq!(result[i], (i * 3) as f64);
        }
    }

    #[test]
    fn test_simd_vec_div_f64() {
        let a = [10.0, 20.0, 30.0];
        let b = [2.0, 5.0, 6.0];
        let result = simd_vec_div_f64(&a, &b);
        assert_eq!(&*result, &[5.0, 4.0, 5.0]);
    }

    #[test]
    fn test_simd_vec_div_f64_with_tail() {
        let a: Vec<f64> = (0..18).map(|i| (i * 2) as f64).collect();
        let b = vec![2.0; 18];
        let result = simd_vec_div_f64(&a, &b);
        assert_eq!(result.len(), 18);
        for i in 0..18 {
            assert_eq!(result[i], i as f64);
        }
    }

    #[test]
    fn test_simd_vec_scale_f64() {
        let a = [1.0, 2.0, 3.0];
        let result = simd_vec_scale_f64(&a, 10.0);
        assert_eq!(&*result, &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_simd_vec_scale_f64_with_tail() {
        let a: Vec<f64> = (0..18).map(|i| i as f64).collect();
        let result = simd_vec_scale_f64(&a, 0.5);
        assert_eq!(result.len(), 18);
        for i in 0..18 {
            assert_eq!(result[i], i as f64 * 0.5);
        }
    }

    #[test]
    fn test_simd_vec_add_i64_ok() {
        let a = [1i64, 2, 3];
        let b = [4i64, 5, 6];
        let result = simd_vec_add_i64(&a, &b).unwrap();
        assert_eq!(result, vec![5, 7, 9]);
    }

    #[test]
    fn test_simd_vec_add_i64_overflow() {
        let a = [i64::MAX];
        let b = [1i64];
        assert!(simd_vec_add_i64(&a, &b).is_err());
    }

    #[test]
    fn test_simd_vec_sub_i64_ok_and_overflow() {
        assert_eq!(simd_vec_sub_i64(&[10, 0], &[3, 5]), Ok(vec![7, -5]));
        assert!(simd_vec_sub_i64(&[i64::MIN], &[1]).is_err());
    }

    #[test]
    fn test_simd_vec_mul_i64_ok_and_overflow() {
        assert_eq!(simd_vec_mul_i64(&[2, -3], &[4, 5]), Ok(vec![8, -15]));
        assert!(simd_vec_mul_i64(&[i64::MAX], &[2]).is_err());
    }

    #[test]
    fn test_simd_vec_div_i64_zero() {
        let a = [10i64];
        let b = [0i64];
        assert!(simd_vec_div_i64(&a, &b).is_err());
    }

    #[test]
    fn test_simd_vec_div_i64_ok_and_overflow() {
        // Integer division truncates toward zero.
        assert_eq!(simd_vec_div_i64(&[10, -9, 7], &[3, 3, -2]), Ok(vec![3, -3, -3]));
        // i64::MIN / -1 does not fit in i64.
        assert!(simd_vec_div_i64(&[i64::MIN], &[-1]).is_err());
    }

    #[test]
    fn test_i64_slice_to_f64() {
        let data = [1i64, -2, 100];
        let result = i64_slice_to_f64(&data);
        assert_eq!(&*result, &[1.0, -2.0, 100.0]);
    }

    #[test]
    fn test_unary_apply_abs_simd() {
        let data: Vec<f64> = (0..20).map(|i| -(i as f64)).collect();
        let result = unary_apply(&data, |v| v.abs(), f64::abs);
        for i in 0..20 {
            assert_eq!(result[i], i as f64);
        }
    }

    #[test]
    fn test_unary_apply_abs_with_tail() {
        let data: Vec<f64> = (0..19).map(|i| -(i as f64)).collect();
        let result = unary_apply(&data, |v| v.abs(), f64::abs);
        assert_eq!(result.len(), 19);
        for i in 0..19 {
            assert_eq!(result[i], i as f64);
        }
    }

    #[test]
    fn test_unary_apply_small_scalar_path() {
        let result = unary_apply(&[-1.0, 4.0, -9.0], |v| v.abs(), f64::abs);
        assert_eq!(result, vec![1.0, 4.0, 9.0]);
    }

    #[test]
    fn test_binary_apply_add_simd() {
        let a: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let b: Vec<f64> = (0..20).map(|i| (i * 2) as f64).collect();
        let result = binary_apply(&a, &b, |va, vb| va + vb, |x, y| x + y);
        for i in 0..20 {
            assert_eq!(result[i], (i * 3) as f64);
        }
    }

    #[test]
    fn test_binary_apply_max_with_tail() {
        let a: Vec<f64> = (0..19).map(|i| i as f64).collect();
        let b = vec![10.0; 19];
        let result = binary_apply(&a, &b, |va, vb| va.max(vb), f64::max);
        assert_eq!(result.len(), 19);
        for i in 0..19 {
            assert_eq!(result[i], (i as f64).max(10.0));
        }
    }

    #[test]
    fn test_binary_apply_empty() {
        let result = binary_apply(&[], &[], |va, vb| va + vb, |x, y| x + y);
        assert!(result.is_empty());
    }
}