1use std::alloc::{alloc, dealloc, Layout};
8use std::ptr::{self, NonNull};
9use std::slice;
10
/// Byte alignment used for every `AlignedVec` allocation. 32 bytes matches
/// the width of a 256-bit AVX/AVX2 register, so aligned 256-bit loads and
/// stores on the buffer are valid.
pub const SIMD_ALIGNMENT: usize = 32;
13
/// A fixed-capacity vector whose backing buffer is allocated with
/// `SIMD_ALIGNMENT`-byte alignment, suitable for aligned SIMD loads/stores.
/// The capacity is chosen at construction time and never grows.
pub struct AlignedVec<T> {
    // Pointer to the allocation; dangling (nothing was ever allocated) when
    // `capacity == 0` or `T` is zero-sized.
    ptr: NonNull<T>,
    // Number of initialized elements; invariant: `len <= capacity`.
    len: usize,
    // Fixed element capacity chosen at construction.
    capacity: usize,
}
20
21impl<T> AlignedVec<T> {
22 pub fn with_capacity(capacity: usize) -> Result<Self, Box<dyn std::error::Error>> {
24 if capacity == 0 {
25 return Ok(Self {
26 ptr: NonNull::dangling(),
27 len: 0,
28 capacity: 0,
29 });
30 }
31
32 if std::mem::size_of::<T>() == 0 {
35 return Ok(Self {
36 ptr: NonNull::dangling(),
37 len: 0,
38 capacity,
39 });
40 }
41
42 let layout = Layout::from_size_align(capacity * std::mem::size_of::<T>(), SIMD_ALIGNMENT)
43 .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
44
45 let ptr = unsafe { alloc(layout) };
46 if ptr.is_null() {
47 return Err("Memory allocation failed".into());
48 }
49
50 Ok(Self {
51 ptr: unsafe { NonNull::new_unchecked(ptr as *mut T) },
52 len: 0,
53 capacity,
54 })
55 }
56
57 pub fn from_vec(vec: Vec<T>) -> Result<Self, Box<dyn std::error::Error>>
59 where
60 T: Copy,
61 {
62 let mut aligned = Self::with_capacity(vec.len())?;
63 for item in vec {
64 aligned.push(item);
65 }
66 Ok(aligned)
67 }
68
69 pub fn push(&mut self, value: T) {
71 if self.len >= self.capacity {
72 panic!("AlignedVec capacity exceeded");
73 }
74
75 unsafe {
76 ptr::write(self.ptr.as_ptr().add(self.len), value);
77 }
78 self.len += 1;
79 }
80
81 pub fn len(&self) -> usize {
83 self.len
84 }
85
86 pub fn is_empty(&self) -> bool {
88 self.len == 0
89 }
90
91 pub fn capacity(&self) -> usize {
93 self.capacity
94 }
95
96 pub fn as_slice(&self) -> &[T] {
98 if self.len == 0 {
99 &[]
100 } else {
101 unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
102 }
103 }
104
105 pub fn as_mut_slice(&mut self) -> &mut [T] {
107 if self.len == 0 {
108 &mut []
109 } else {
110 unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
111 }
112 }
113
114 pub fn set(&mut self, index: usize, value: T) {
119 assert!(
120 index < self.len,
121 "Index {} out of bounds for length {}",
122 index,
123 self.len
124 );
125 unsafe {
126 ptr::write(self.ptr.as_ptr().add(index), value);
127 }
128 }
129
130 pub fn get(&self, index: usize) -> &T {
135 assert!(
136 index < self.len,
137 "Index {} out of bounds for length {}",
138 index,
139 self.len
140 );
141 unsafe { &*self.ptr.as_ptr().add(index) }
142 }
143
144 pub unsafe fn with_capacity_uninit(
149 capacity: usize,
150 ) -> Result<Self, Box<dyn std::error::Error>> {
151 let mut vec = Self::with_capacity(capacity)?;
152 vec.len = capacity; Ok(vec)
154 }
155
156 pub fn fill(&mut self, value: T)
158 where
159 T: Copy,
160 {
161 for i in 0..self.len {
162 unsafe {
163 ptr::write(self.ptr.as_ptr().add(i), value);
164 }
165 }
166 }
167
168 pub fn clear(&mut self) {
170 for i in 0..self.len {
171 unsafe {
172 ptr::drop_in_place(self.ptr.as_ptr().add(i));
173 }
174 }
175 self.len = 0;
176 }
177
178 pub fn to_vec(&self) -> Vec<T>
180 where
181 T: Clone,
182 {
183 self.as_slice().to_vec()
184 }
185
186 pub unsafe fn set_len(&mut self, new_len: usize) {
193 debug_assert!(new_len <= self.capacity);
194 self.len = new_len;
195 }
196
197 pub fn as_mut_ptr(&mut self) -> *mut T {
199 self.ptr.as_ptr()
200 }
201
202 pub fn as_ptr(&self) -> *const T {
204 self.ptr.as_ptr()
205 }
206}
207
208impl<T> Drop for AlignedVec<T> {
209 fn drop(&mut self) {
210 if self.capacity != 0 {
211 unsafe {
212 for i in 0..self.len {
214 ptr::drop_in_place(self.ptr.as_ptr().add(i));
215 }
216
217 let layout = Layout::from_size_align_unchecked(
219 self.capacity * std::mem::size_of::<T>(),
220 SIMD_ALIGNMENT,
221 );
222 dealloc(self.ptr.as_ptr() as *mut u8, layout);
223 }
224 }
225 }
226}
227
// SAFETY: AlignedVec uniquely owns its buffer (raw pointer + len + capacity,
// never aliased), so sending or sharing it across threads is exactly as safe
// as for the contained `T` — the same reasoning std uses for `Vec<T>`.
unsafe impl<T: Send> Send for AlignedVec<T> {}
unsafe impl<T: Sync> Sync for AlignedVec<T> {}
230
/// Element-wise addition of `a` and `b` into a new 32-byte-aligned vector.
///
/// Uses AVX2 (8 lanes) or SSE (4 lanes) on x86_64, NEON (4 lanes) on
/// aarch64, and a scalar loop otherwise; any tail shorter than one SIMD
/// width is finished element-by-element.
///
/// # Errors
/// Fails if the slices differ in length or the aligned allocation fails.
pub fn simd_add_aligned_f32(a: &[f32], b: &[f32]) -> Result<AlignedVec<f32>, &'static str> {
    if a.len() != b.len() {
        return Err("Arrays must have the same length");
    }

    let len = a.len();
    let mut result: AlignedVec<f32> =
        AlignedVec::with_capacity(len).map_err(|_| "Failed to allocate aligned memory")?;

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if is_x86_feature_detected!("avx2") {
            // NOTE(review): AVX intrinsics are called after a runtime
            // feature check but from a function without
            // #[target_feature(enable = "avx2")]; the officially sanctioned
            // pattern is a separate target_feature helper fn — confirm this
            // is acceptable for this codebase.
            unsafe {
                let mut i = 0;

                // Main loop: 8 f32 lanes (32 bytes) per iteration.
                while i + 8 <= len {
                    let a_ptr = a.as_ptr().add(i);
                    let b_ptr = b.as_ptr().add(i);
                    let result_ptr = result.ptr.as_ptr().add(i);

                    // Prefer the aligned load when the caller's data happens
                    // to be 32-byte aligned; otherwise fall back to the
                    // unaligned variant.
                    let a_vec = if (a_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(a_ptr)
                    } else {
                        _mm256_loadu_ps(a_ptr)
                    };

                    let b_vec = if (b_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(b_ptr)
                    } else {
                        _mm256_loadu_ps(b_ptr)
                    };

                    let result_vec = _mm256_add_ps(a_vec, b_vec);

                    // Aligned store is sound: the result buffer is 32-byte
                    // aligned and `i` is a multiple of 8 floats (32 bytes).
                    _mm256_store_ps(result_ptr, result_vec);

                    i += 8;
                }

                // The first `i` elements were written through the raw
                // pointer; publish them so push() appends after them.
                result.len = i;

                for j in i..len {
                    result.push(a[j] + b[j]);
                }
            }
        } else if is_x86_feature_detected!("sse") {
            unsafe {
                let mut i = 0;

                // Main loop: 4 f32 lanes (16 bytes) per iteration.
                while i + 4 <= len {
                    let a_ptr = a.as_ptr().add(i);
                    let b_ptr = b.as_ptr().add(i);
                    let result_ptr = result.ptr.as_ptr().add(i);

                    let a_vec = if (a_ptr as usize) % 16 == 0 {
                        _mm_load_ps(a_ptr)
                    } else {
                        _mm_loadu_ps(a_ptr)
                    };

                    let b_vec = if (b_ptr as usize) % 16 == 0 {
                        _mm_load_ps(b_ptr)
                    } else {
                        _mm_loadu_ps(b_ptr)
                    };

                    let result_vec = _mm_add_ps(a_vec, b_vec);
                    // 16-byte-aligned store is sound: buffer is 32-byte
                    // aligned and `i` is a multiple of 4 floats.
                    _mm_store_ps(result_ptr, result_vec);

                    i += 4;
                }

                // Publish the SIMD-written prefix, then push the tail.
                result.len = i;

                for j in i..len {
                    result.push(a[j] + b[j]);
                }
            }
        } else {
            // Scalar fallback when neither AVX2 nor SSE is available.
            for i in 0..len {
                result.push(a[i] + b[i]);
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        use std::arch::aarch64::*;

        if std::arch::is_aarch64_feature_detected!("neon") {
            unsafe {
                let mut i = 0;

                // Main loop: 4 f32 lanes per iteration. NEON vld1q/vst1q
                // have no extra alignment requirement beyond the element.
                while i + 4 <= len {
                    let a_ptr = a.as_ptr().add(i);
                    let b_ptr = b.as_ptr().add(i);
                    let result_ptr = result.ptr.as_ptr().add(i);

                    let a_vec = vld1q_f32(a_ptr);
                    let b_vec = vld1q_f32(b_ptr);
                    let result_vec = vaddq_f32(a_vec, b_vec);
                    vst1q_f32(result_ptr, result_vec);

                    i += 4;
                }

                // Publish the SIMD-written prefix, then push the tail.
                result.len = i;

                for j in i..len {
                    result.push(a[j] + b[j]);
                }
            }
        } else {
            for i in 0..len {
                result.push(a[i] + b[i]);
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        // Portable scalar path for all other architectures.
        for i in 0..len {
            result.push(a[i] + b[i]);
        }
    }

    Ok(result)
}
374
/// Element-wise multiplication of `a` and `b` into a new 32-byte-aligned
/// vector, using AVX2 on x86_64 when available and a scalar loop otherwise.
///
/// # Errors
/// Fails if the slices differ in length or the aligned allocation fails.
pub fn simd_mul_aligned_f32(a: &[f32], b: &[f32]) -> Result<AlignedVec<f32>, &'static str> {
    if a.len() != b.len() {
        return Err("Arrays must have the same length");
    }

    let len = a.len();
    let mut result: AlignedVec<f32> =
        AlignedVec::with_capacity(len).map_err(|_| "Failed to allocate aligned memory")?;

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if is_x86_feature_detected!("avx2") {
            // NOTE(review): like simd_add_aligned_f32, this relies on the
            // runtime check alone rather than a #[target_feature] helper.
            unsafe {
                let mut i = 0;

                // Main loop: 8 f32 lanes (32 bytes) per iteration.
                while i + 8 <= len {
                    let a_ptr = a.as_ptr().add(i);
                    let b_ptr = b.as_ptr().add(i);
                    let result_ptr = result.ptr.as_ptr().add(i);

                    // Aligned load when the input happens to be 32-byte
                    // aligned; unaligned load otherwise.
                    let a_vec = if (a_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(a_ptr)
                    } else {
                        _mm256_loadu_ps(a_ptr)
                    };

                    let b_vec = if (b_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(b_ptr)
                    } else {
                        _mm256_loadu_ps(b_ptr)
                    };

                    let result_vec = _mm256_mul_ps(a_vec, b_vec);
                    // Aligned store is sound: buffer is 32-byte aligned and
                    // `i` is a multiple of 8 floats.
                    _mm256_store_ps(result_ptr, result_vec);

                    i += 8;
                }

                // Publish the SIMD-written prefix, then push the tail.
                result.len = i;

                for j in i..len {
                    result.push(a[j] * b[j]);
                }
            }
        } else {
            // Scalar fallback when AVX2 is unavailable.
            for i in 0..len {
                result.push(a[i] * b[i]);
            }
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        // Portable scalar path for non-x86_64 targets.
        for i in 0..len {
            result.push(a[i] * b[i]);
        }
    }

    Ok(result)
}
440
/// Dot product of two `f32` slices, using AVX2 on x86_64 when available and
/// a scalar loop otherwise.
///
/// NOTE(review): the SIMD path accumulates in 8 parallel lanes and reduces
/// at the end, so floating-point rounding can differ slightly from the
/// sequential scalar sum — callers comparing against a scalar reference
/// should use a tolerance, as the test module does.
///
/// # Errors
/// Fails if the slices differ in length.
pub fn simd_dot_aligned_f32(a: &[f32], b: &[f32]) -> Result<f32, &'static str> {
    if a.len() != b.len() {
        return Err("Arrays must have the same length");
    }

    let len = a.len();

    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;

        if is_x86_feature_detected!("avx2") {
            unsafe {
                // Eight independent per-lane running sums.
                let mut sums = _mm256_setzero_ps();
                let mut i = 0;

                while i + 8 <= len {
                    let a_ptr = a.as_ptr().add(i);
                    let b_ptr = b.as_ptr().add(i);

                    // Aligned load when possible, unaligned otherwise.
                    let a_vec = if (a_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(a_ptr)
                    } else {
                        _mm256_loadu_ps(a_ptr)
                    };

                    let b_vec = if (b_ptr as usize) % 32 == 0 {
                        _mm256_load_ps(b_ptr)
                    } else {
                        _mm256_loadu_ps(b_ptr)
                    };

                    let product = _mm256_mul_ps(a_vec, b_vec);
                    sums = _mm256_add_ps(sums, product);

                    i += 8;
                }

                // Horizontal reduction: fold the upper 128-bit half onto the
                // lower half, then collapse 4 lanes -> 2 -> 1.
                let high = _mm256_extractf128_ps(sums, 1);
                let low = _mm256_castps256_ps128(sums);
                let sum128 = _mm_add_ps(low, high);

                // Shuffle imm 0b1110 selects lanes [2, 3, 0, 0]; adding
                // leaves s0+s2 and s1+s3 in the low two lanes.
                let shuf = _mm_shuffle_ps(sum128, sum128, 0b1110);
                let sum_partial = _mm_add_ps(sum128, shuf);
                // Shuffle imm 0b0001 moves lane 1 into lane 0; the final add
                // leaves the 4-lane total in lane 0.
                let shuf2 = _mm_shuffle_ps(sum_partial, sum_partial, 0b0001);
                let final_sum = _mm_add_ps(sum_partial, shuf2);

                let mut result = _mm_cvtss_f32(final_sum);

                // Scalar tail for the remaining (< 8) elements.
                for j in i..len {
                    result += a[j] * b[j];
                }

                return Ok(result);
            }
        }
    }

    // Scalar fallback: non-x86_64 targets, or x86_64 without AVX2.
    let mut sum = 0.0f32;
    for i in 0..len {
        sum += a[i] * b[i];
    }
    Ok(sum)
}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh vector reports zero length, the requested capacity, and
    /// faithfully stores pushed elements.
    #[test]
    fn test_aligned_vec_creation() {
        let mut v = AlignedVec::<f32>::with_capacity(16).expect("Operation failed");
        assert_eq!(v.len(), 0);
        assert_eq!(v.capacity(), 16);

        for x in [1.0f32, 2.0] {
            v.push(x);
        }
        assert_eq!(v.len(), 2);
        assert_eq!(v.as_slice(), &[1.0, 2.0]);
    }

    /// Element-wise addition over a length that exercises both the SIMD
    /// main loop and the scalar tail (10 = 8 + 2 under AVX2).
    #[test]
    fn test_simd_add_aligned() {
        let a: Vec<f32> = (1..=10).map(|x| x as f32).collect();
        let b: Vec<f32> = (1..=10).rev().map(|x| x as f32).collect();

        let sum = simd_add_aligned_f32(&a, &b).expect("Operation failed");

        // Every pair sums to 11.
        assert_eq!(sum.to_vec(), vec![11.0f32; 10]);
    }

    /// Dot product matches the scalar reference within a small tolerance.
    #[test]
    fn test_simd_dot_aligned() {
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [5.0f32, 6.0, 7.0, 8.0];

        let got = simd_dot_aligned_f32(&a, &b).expect("Operation failed");
        let want: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        assert!((got - want).abs() < 1e-6);
    }

    /// The backing buffer observes the promised SIMD alignment.
    #[test]
    fn test_alignment() {
        let mut v = AlignedVec::<f32>::with_capacity(32).expect("Operation failed");
        for x in [1.0f32, 2.0, 3.0, 4.0] {
            v.push(x);
        }

        let addr = v.as_slice().as_ptr() as usize;
        assert_eq!(addr % SIMD_ALIGNMENT, 0, "Vector should be properly aligned");
    }
}