1use anyhow::Result;
5use crossbeam_utils::CachePadded;
6use std::mem::size_of;
7use std::ptr;
8use std::sync::atomic::{AtomicU64, Ordering};
9
/// Cache-line size, in bytes, assumed by every padded/aligned type in this module.
pub const CACHE_LINE_SIZE: usize = 1 << 6;

/// Types that can check their own cache-line placement and issue a cache
/// prefetch hint for their hot data.
pub trait CacheLineAligned {
    /// Returns `true` when `self` starts on a cache-line boundary.
    fn ensure_cache_aligned(&self) -> bool;

    /// Issues a best-effort prefetch hint for the data behind `self`.
    fn prefetch_data(&self);
}
18
/// Length-dispatched memory copy, compare and zero routines.
///
/// Every routine touches exactly the `len` bytes it is given. Lengths that are
/// not a multiple of the scalar/vector width are handled with two *overlapping*
/// unaligned moves (one window starting at offset 0, one ending exactly at
/// `len`), never by reading or writing past the end of a buffer.
///
/// 256-bit (AVX) and 512-bit (AVX-512F) paths are compiled only when the
/// matching `target_feature` is enabled for the build. Otherwise x86_64 uses
/// SSE2 (part of the x86_64 baseline), and other architectures use the
/// portable `ptr` routines.
pub struct SIMDMemoryOps;

impl SIMDMemoryOps {
    /// Copies `len` bytes from `src` to `dst`.
    ///
    /// # Safety
    ///
    /// * `src` must be valid for reads of `len` bytes.
    /// * `dst` must be valid for writes of `len` bytes.
    /// * The two regions must not overlap.
    ///
    /// Neither pointer needs any particular alignment.
    #[inline(always)]
    pub unsafe fn memcpy_simd_optimized(dst: *mut u8, src: *const u8, len: usize) {
        match len {
            0 => {}
            1..=8 => Self::memcpy_small(dst, src, len),
            9..=16 => Self::memcpy_sse(dst, src, len),
            17..=32 => Self::memcpy_avx(dst, src, len),
            33..=64 => Self::memcpy_avx2(dst, src, len),
            _ => Self::memcpy_avx512_or_fallback(dst, src, len),
        }
    }

    /// Copies `len` bytes, where `size_of::<T>() <= len <= 2 * size_of::<T>()`.
    ///
    /// It does this with one unaligned `T` move at offset 0 and one ending
    /// exactly at `len`. When `len < 2 * size_of::<T>()` the two windows
    /// overlap. Because `src` and `dst` are disjoint, the shared bytes are
    /// simply written twice with the same value.
    #[inline(always)]
    unsafe fn copy_head_tail<T: Copy>(dst: *mut u8, src: *const u8, len: usize) {
        let width = size_of::<T>();
        debug_assert!(width <= len && len <= 2 * width);
        let tail = len - width;
        let head_val = ptr::read_unaligned(src as *const T);
        let tail_val = ptr::read_unaligned(src.add(tail) as *const T);
        ptr::write_unaligned(dst as *mut T, head_val);
        ptr::write_unaligned(dst.add(tail) as *mut T, tail_val);
    }

    /// Copies 1..=8 bytes with unaligned scalar moves.
    #[inline(always)]
    unsafe fn memcpy_small(dst: *mut u8, src: *const u8, len: usize) {
        match len {
            1 => *dst = *src,
            2..=3 => Self::copy_head_tail::<u16>(dst, src, len),
            4..=8 => Self::copy_head_tail::<u32>(dst, src, len),
            _ => unreachable!(),
        }
    }

    /// Copies 9..=16 bytes with two overlapping unaligned 8-byte moves.
    #[inline(always)]
    unsafe fn memcpy_sse(dst: *mut u8, src: *const u8, len: usize) {
        Self::copy_head_tail::<u64>(dst, src, len);
    }

    /// Copies 16..=32 bytes with two overlapping unaligned 16-byte SSE2 moves.
    #[inline(always)]
    unsafe fn memcpy_avx(dst: *mut u8, src: *const u8, len: usize) {
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::__m128i;
            Self::copy_head_tail::<__m128i>(dst, src, len);
        }

        #[cfg(not(target_arch = "x86_64"))]
        {
            ptr::copy_nonoverlapping(src, dst, len);
        }
    }

    /// Copies 32..=64 bytes as two (possibly overlapping) 32-byte windows.
    ///
    /// The windows are single AVX moves when AVX is enabled at build time,
    /// and pairs of SSE2 moves otherwise.
    #[inline(always)]
    unsafe fn memcpy_avx2(dst: *mut u8, src: *const u8, len: usize) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
        {
            use std::arch::x86_64::__m256i;
            Self::copy_head_tail::<__m256i>(dst, src, len);
        }

        #[cfg(not(all(target_arch = "x86_64", target_feature = "avx")))]
        {
            debug_assert!((32..=64).contains(&len));
            Self::memcpy_avx(dst, src, 32);
            if len > 32 {
                let tail = len - 32;
                Self::memcpy_avx(dst.add(tail), src.add(tail), 32);
            }
        }
    }

    /// Copies more than 64 bytes.
    ///
    /// It copies full-width chunks first. Any leftover is covered by one final
    /// full-width window ending exactly at `len`, which may re-copy a few
    /// already-copied bytes but never overruns.
    #[inline(always)]
    unsafe fn memcpy_avx512_or_fallback(dst: *mut u8, src: *const u8, len: usize) {
        debug_assert!(len > 64);

        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        {
            use std::arch::x86_64::__m512i;
            const WIDTH: usize = 64;

            let mut offset = 0;
            while offset + WIDTH <= len {
                let chunk = ptr::read_unaligned(src.add(offset) as *const __m512i);
                ptr::write_unaligned(dst.add(offset) as *mut __m512i, chunk);
                offset += WIDTH;
            }
            if offset < len {
                let tail = len - WIDTH;
                let chunk = ptr::read_unaligned(src.add(tail) as *const __m512i);
                ptr::write_unaligned(dst.add(tail) as *mut __m512i, chunk);
            }
        }

        #[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
        {
            const WIDTH: usize = 32;

            let mut offset = 0;
            while offset + WIDTH <= len {
                Self::memcpy_avx2(dst.add(offset), src.add(offset), WIDTH);
                offset += WIDTH;
            }
            if offset < len {
                let tail = len - WIDTH;
                Self::memcpy_avx2(dst.add(tail), src.add(tail), WIDTH);
            }
        }
    }

    /// Returns `true` if the `len` bytes at `a` equal the `len` bytes at `b`.
    ///
    /// # Safety
    ///
    /// `a` and `b` must each be valid for reads of `len` bytes. The regions are
    /// only read, so they may overlap. No alignment is required.
    #[inline(always)]
    pub unsafe fn memcmp_simd_optimized(a: *const u8, b: *const u8, len: usize) -> bool {
        match len {
            0 => true,
            1..=8 => Self::memcmp_small(a, b, len),
            9..=16 => Self::memcmp_sse(a, b, len),
            17..=32 => Self::memcmp_avx2(a, b, len),
            _ => Self::memcmp_large(a, b, len),
        }
    }

    /// Compares `len` bytes, where `size_of::<T>() <= len <= 2 * size_of::<T>()`.
    ///
    /// It uses one unaligned `T` window at offset 0 and one ending exactly at
    /// `len`. Re-comparing the overlapping bytes is harmless.
    #[inline(always)]
    unsafe fn eq_head_tail<T: Copy + PartialEq>(a: *const u8, b: *const u8, len: usize) -> bool {
        let width = size_of::<T>();
        debug_assert!(width <= len && len <= 2 * width);
        let tail = len - width;
        ptr::read_unaligned(a as *const T) == ptr::read_unaligned(b as *const T)
            && ptr::read_unaligned(a.add(tail) as *const T)
                == ptr::read_unaligned(b.add(tail) as *const T)
    }

    /// Compares 1..=8 bytes with unaligned scalar loads.
    #[inline(always)]
    unsafe fn memcmp_small(a: *const u8, b: *const u8, len: usize) -> bool {
        match len {
            1 => *a == *b,
            2..=3 => Self::eq_head_tail::<u16>(a, b, len),
            4..=8 => Self::eq_head_tail::<u32>(a, b, len),
            _ => unreachable!(),
        }
    }

    /// Compares 9..=16 bytes with two overlapping unaligned 8-byte loads.
    #[inline(always)]
    unsafe fn memcmp_sse(a: *const u8, b: *const u8, len: usize) -> bool {
        Self::eq_head_tail::<u64>(a, b, len)
    }

    /// Compares 16..=32 bytes with two overlapping 16-byte SSE2 comparisons.
    #[inline(always)]
    unsafe fn memcmp_avx2(a: *const u8, b: *const u8, len: usize) -> bool {
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::{__m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8};

            debug_assert!((16..=32).contains(&len));
            let tail = len - 16;
            let head_mask = _mm_movemask_epi8(_mm_cmpeq_epi8(
                _mm_loadu_si128(a as *const __m128i),
                _mm_loadu_si128(b as *const __m128i),
            ));
            let tail_mask = _mm_movemask_epi8(_mm_cmpeq_epi8(
                _mm_loadu_si128(a.add(tail) as *const __m128i),
                _mm_loadu_si128(b.add(tail) as *const __m128i),
            ));
            // One mask bit per byte lane; 0xFFFF means all 16 lanes matched.
            (head_mask & tail_mask) == 0xFFFF
        }

        #[cfg(not(target_arch = "x86_64"))]
        {
            (0..len).all(|i| *a.add(i) == *b.add(i))
        }
    }

    /// Compares 32 or more bytes.
    ///
    /// It compares 32-byte chunks first. Any leftover is covered by one final
    /// 32-byte window ending exactly at `len`.
    #[inline(always)]
    unsafe fn memcmp_large(a: *const u8, b: *const u8, len: usize) -> bool {
        debug_assert!(len >= 32);

        let mut offset = 0;
        while offset + 32 <= len {
            if !Self::memcmp_avx2(a.add(offset), b.add(offset), 32) {
                return false;
            }
            offset += 32;
        }

        offset == len || Self::memcmp_avx2(a.add(len - 32), b.add(len - 32), 32)
    }

    /// Writes `len` zero bytes starting at `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for writes of `len` bytes. No alignment is required.
    #[inline(always)]
    pub unsafe fn memzero_simd_optimized(ptr: *mut u8, len: usize) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx"))]
        {
            use std::arch::x86_64::{__m256i, _mm256_setzero_si256, _mm256_storeu_si256};

            let zero = _mm256_setzero_si256();
            let mut offset = 0;
            while offset + 32 <= len {
                _mm256_storeu_si256(ptr.add(offset) as *mut __m256i, zero);
                offset += 32;
            }
            ptr::write_bytes(ptr.add(offset), 0, len - offset);
        }

        #[cfg(not(all(target_arch = "x86_64", target_feature = "avx")))]
        {
            ptr::write_bytes(ptr, 0, len);
        }
    }
}
284
285#[repr(align(64))]
287pub struct CacheAlignedCounter {
288 value: AtomicU64,
289 _padding: [u8; CACHE_LINE_SIZE - size_of::<AtomicU64>()],
290}
291
292impl CacheAlignedCounter {
293 pub fn new(initial: u64) -> Self {
295 Self {
296 value: AtomicU64::new(initial),
297 _padding: [0; CACHE_LINE_SIZE - size_of::<AtomicU64>()],
298 }
299 }
300
301 #[inline(always)]
302 pub fn increment(&self) -> u64 {
303 self.value.fetch_add(1, Ordering::Relaxed)
304 }
305
306 #[inline(always)]
307 pub fn load(&self) -> u64 {
308 self.value.load(Ordering::Relaxed)
309 }
310
311 #[inline(always)]
312 pub fn store(&self, val: u64) {
313 self.value.store(val, Ordering::Relaxed)
314 }
315}
316
317impl CacheLineAligned for CacheAlignedCounter {
318 fn ensure_cache_aligned(&self) -> bool {
319 (self as *const Self as usize) % CACHE_LINE_SIZE == 0
320 }
321
322 fn prefetch_data(&self) {
323 #[cfg(target_arch = "x86_64")]
324 unsafe {
325 use std::arch::x86_64::_mm_prefetch;
326 use std::arch::x86_64::_MM_HINT_T0;
327 _mm_prefetch(self as *const Self as *const i8, _MM_HINT_T0);
328 }
329 }
330}
331
332#[repr(align(64))]
334pub struct CacheOptimizedRingBuffer<T> {
335 buffer: Vec<T>,
336 producer_head: CachePadded<AtomicU64>,
337 consumer_tail: CachePadded<AtomicU64>,
338 capacity: usize,
339 mask: usize,
340}
341
342impl<T: Copy + Default> CacheOptimizedRingBuffer<T> {
343 pub fn new(capacity: usize) -> Result<Self> {
345 if !capacity.is_power_of_two() {
346 return Err(anyhow::anyhow!("Capacity must be a power of 2"));
347 }
348
349 let mut buffer = Vec::with_capacity(capacity);
350 buffer.resize_with(capacity, Default::default);
351
352 Ok(Self {
353 buffer,
354 producer_head: CachePadded::new(AtomicU64::new(0)),
355 consumer_tail: CachePadded::new(AtomicU64::new(0)),
356 capacity,
357 mask: capacity - 1,
358 })
359 }
360
361 #[inline(always)]
363 pub fn try_push(&self, item: T) -> bool {
364 let current_head = self.producer_head.load(Ordering::Relaxed);
365 let current_tail = self.consumer_tail.load(Ordering::Acquire);
366 if (current_head + 1) & self.mask as u64 == current_tail & self.mask as u64 {
367 return false;
368 }
369 unsafe {
370 let index = current_head & self.mask as u64;
371 let ptr = self.buffer.as_ptr().add(index as usize) as *mut T;
372 ptr.write(item);
373 }
374 self.producer_head.store(current_head + 1, Ordering::Release);
375 true
376 }
377
378 #[inline(always)]
380 pub fn try_pop(&self) -> Option<T> {
381 let current_tail = self.consumer_tail.load(Ordering::Relaxed);
382 let current_head = self.producer_head.load(Ordering::Acquire);
383 if current_tail == current_head {
384 return None;
385 }
386 let item = unsafe {
387 let index = current_tail & self.mask as u64;
388 let ptr = self.buffer.as_ptr().add(index as usize);
389 ptr.read()
390 };
391 self.consumer_tail.store(current_tail + 1, Ordering::Release);
392 Some(item)
393 }
394
395 #[inline(always)]
397 pub fn len(&self) -> usize {
398 let head = self.producer_head.load(Ordering::Relaxed);
399 let tail = self.consumer_tail.load(Ordering::Relaxed);
400 ((head + self.capacity as u64 - tail) & self.mask as u64) as usize
401 }
402
403 #[inline(always)]
405 pub fn is_empty(&self) -> bool {
406 self.producer_head.load(Ordering::Relaxed) == self.consumer_tail.load(Ordering::Relaxed)
407 }
408}
409
410impl<T> CacheLineAligned for CacheOptimizedRingBuffer<T> {
411 fn ensure_cache_aligned(&self) -> bool {
412 (self as *const Self as usize) % CACHE_LINE_SIZE == 0
413 }
414
415 fn prefetch_data(&self) {
416 #[cfg(target_arch = "x86_64")]
417 unsafe {
418 use std::arch::x86_64::_mm_prefetch;
419 use std::arch::x86_64::_MM_HINT_T0;
420 _mm_prefetch(self.producer_head.as_ptr() as *const i8, _MM_HINT_T0);
421 _mm_prefetch(self.consumer_tail.as_ptr() as *const i8, _MM_HINT_T0);
422 _mm_prefetch(self.buffer.as_ptr() as *const i8, _MM_HINT_T0);
423 }
424 }
425}
426
/// Branch-layout and cache-prefetch hints.
///
/// Every method is a pure hint: it returns its input unchanged or returns
/// nothing at all.
pub struct BranchOptimizer;

impl BranchOptimizer {
    /// Returns `condition` unchanged, hinting that it is usually `true`.
    ///
    /// The `false` path calls a `#[cold]` no-op, which steers the optimizer
    /// toward laying that path out of line.
    #[inline(always)]
    pub fn likely(condition: bool) -> bool {
        #[cold]
        fn rarely_taken() {}

        if condition {
            true
        } else {
            rarely_taken();
            false
        }
    }

    /// Returns `condition` unchanged, hinting that it is usually `false`.
    ///
    /// The `true` path calls a `#[cold]` no-op.
    #[inline(always)]
    pub fn unlikely(condition: bool) -> bool {
        #[cold]
        fn rarely_taken() {}

        if !condition {
            return false;
        }
        rarely_taken();
        true
    }

    /// Issues an x86_64 prefetch of `_ptr` with the `_MM_HINT_T0` hint.
    ///
    /// This is a no-op on other targets.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body only issues a prefetch hint; confirm whether an
    /// `unsafe` contract is actually required of callers.
    #[inline(always)]
    pub unsafe fn prefetch_read_data<T>(_ptr: *const T) {
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
            _mm_prefetch(_ptr.cast::<i8>(), _MM_HINT_T0);
        }
    }

    /// Issues an x86_64 prefetch of `_ptr` with the `_MM_HINT_T1` hint.
    ///
    /// This is a no-op on other targets.
    ///
    /// NOTE(review): despite the name, this uses a read-locality hint (T1),
    /// not a write-intent prefetch (`_MM_HINT_ET0`). Confirm which was intended.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body only issues a prefetch hint; confirm whether an
    /// `unsafe` contract is actually required of callers.
    #[inline(always)]
    pub unsafe fn prefetch_write_data<T>(_ptr: *const T) {
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T1};
            _mm_prefetch(_ptr.cast::<i8>(), _MM_HINT_T1);
        }
    }
}
477
/// Named wrappers around `std::sync::atomic::{compiler_fence, fence}`.
pub struct MemoryBarriers;

impl MemoryBarriers {
    /// An `Acquire` fence.
    #[inline(always)]
    pub fn load_barrier() {
        std::sync::atomic::fence(Ordering::Acquire)
    }

    /// A `Release` fence.
    #[inline(always)]
    pub fn store_barrier() {
        std::sync::atomic::fence(Ordering::Release)
    }

    /// A `SeqCst` fence.
    #[inline(always)]
    pub fn memory_barrier_heavy() {
        std::sync::atomic::fence(Ordering::SeqCst)
    }

    /// The same `Acquire` fence as [`MemoryBarriers::load_barrier`].
    #[inline(always)]
    pub fn memory_barrier_light() {
        Self::load_barrier()
    }

    /// A `SeqCst` compiler-only fence (`compiler_fence`).
    ///
    /// It restricts compiler reordering, not hardware reordering.
    #[inline(always)]
    pub fn compiler_barrier() {
        std::sync::atomic::compiler_fence(Ordering::SeqCst)
    }
}
512
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_aligned_counter() {
        let ctr = CacheAlignedCounter::new(0);
        assert!(ctr.ensure_cache_aligned());

        assert_eq!(ctr.load(), 0);
        ctr.increment();
        assert_eq!(ctr.load(), 1);
    }

    #[test]
    fn test_simd_memcpy() {
        let source: [u8; 10] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut target = [0u8; 10];
        let n = source.len();

        unsafe { SIMDMemoryOps::memcpy_simd_optimized(target.as_mut_ptr(), source.as_ptr(), n) };

        assert_eq!(source, target);
    }

    #[test]
    fn test_cache_optimized_ring_buffer() {
        let ring = CacheOptimizedRingBuffer::<u64>::new(16).unwrap();
        assert!(ring.is_empty());

        assert!(ring.try_push(42));
        assert_eq!(ring.len(), 1);

        assert_eq!(ring.try_pop(), Some(42));
        assert!(ring.is_empty());
    }

    #[test]
    fn test_simd_memcmp() {
        let lhs = [1u8, 2, 3, 4, 5];
        let same = [1u8, 2, 3, 4, 5];
        let differs = [1u8, 2, 3, 4, 6];
        let n = lhs.len();

        let equal = unsafe { SIMDMemoryOps::memcmp_simd_optimized(lhs.as_ptr(), same.as_ptr(), n) };
        let unequal =
            unsafe { SIMDMemoryOps::memcmp_simd_optimized(lhs.as_ptr(), differs.as_ptr(), n) };

        assert!(equal);
        assert!(!unequal);
    }
}