Skip to main content

sol_trade_sdk/perf/
simd.rs

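//! 🚀 SIMD optimization module
//!
//! Uses SIMD instructions to accelerate data processing:
//! - Faster memory copies
//! - Batch hash computation
//! - Vectorized math operations
//! - Parallel data processing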
8
9#[cfg(target_arch = "x86_64")]
10use std::arch::x86_64::*;
11
/// Raw-pointer memory helpers (copy, compare, zero) with an AVX2 fast path.
///
/// On x86_64 every operation checks at runtime whether the CPU supports AVX2
/// and otherwise falls back to the portable standard-library routine. On all
/// other architectures the portable routine is always used.
pub struct SIMDMemory;

impl SIMDMemory {
    /// Copies `len` bytes from `src` to `dst`.
    ///
    /// Uses 256-bit unaligned loads/stores when AVX2 is available, otherwise
    /// [`std::ptr::copy_nonoverlapping`]. A `len` of zero is a no-op.
    ///
    /// # Safety
    /// - `src` must be valid for reads of `len` bytes.
    /// - `dst` must be valid for writes of `len` bytes.
    /// - The two regions must not overlap.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    pub unsafe fn copy_avx2(dst: *mut u8, src: *const u8, len: usize) {
        if len == 0 {
            return;
        }
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just verified; pointer validity is the
            // caller's obligation (see `# Safety`).
            Self::copy_avx2_impl(dst, src, len);
        } else {
            std::ptr::copy_nonoverlapping(src, dst, len);
        }
    }

    /// AVX2 body of [`SIMDMemory::copy_avx2`]; must only run on CPUs with AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn copy_avx2_impl(dst: *mut u8, src: *const u8, len: usize) {
        let mut offset = 0;

        // Whole 32-byte chunks; the unaligned load/store variants are used, so
        // no alignment is required of either pointer.
        while len - offset >= 32 {
            let chunk = _mm256_loadu_si256(src.add(offset).cast());
            _mm256_storeu_si256(dst.add(offset).cast(), chunk);
            offset += 32;
        }

        // Remaining tail of fewer than 32 bytes.
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), len - offset);
    }

    /// Copies `len` bytes from `src` to `dst` (portable version for targets
    /// other than x86_64).
    ///
    /// # Safety
    /// - `src` must be valid for reads of `len` bytes.
    /// - `dst` must be valid for writes of `len` bytes.
    /// - The two regions must not overlap.
    #[cfg(not(target_arch = "x86_64"))]
    #[inline]
    pub unsafe fn copy_avx2(dst: *mut u8, src: *const u8, len: usize) {
        std::ptr::copy_nonoverlapping(src, dst, len);
    }

    /// Returns `true` when the `len` bytes at `a` and `b` are identical.
    ///
    /// Uses 256-bit byte-wise comparison when AVX2 is available, otherwise a
    /// slice comparison. A `len` of zero always yields `true`.
    ///
    /// # Safety
    /// `a` and `b` must each be valid for reads of `len` bytes.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    pub unsafe fn compare_avx2(a: *const u8, b: *const u8, len: usize) -> bool {
        if len == 0 {
            return true;
        }
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just verified; pointer validity is the
            // caller's obligation (see `# Safety`).
            Self::compare_avx2_impl(a, b, len)
        } else {
            std::slice::from_raw_parts(a, len) == std::slice::from_raw_parts(b, len)
        }
    }

    /// AVX2 body of [`SIMDMemory::compare_avx2`]; must only run on CPUs with
    /// AVX2 and with `len > 0`.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn compare_avx2_impl(a: *const u8, b: *const u8, len: usize) -> bool {
        let mut offset = 0;

        while len - offset >= 32 {
            let va = _mm256_loadu_si256(a.add(offset).cast());
            let vb = _mm256_loadu_si256(b.add(offset).cast());
            // One mask bit per byte lane; all 32 bits set (== -1 as i32) means
            // every byte in the chunk matched.
            let mask = _mm256_movemask_epi8(_mm256_cmpeq_epi8(va, vb));
            if mask != -1 {
                return false;
            }
            offset += 32;
        }

        // Remaining tail of fewer than 32 bytes.
        let tail = len - offset;
        std::slice::from_raw_parts(a.add(offset), tail)
            == std::slice::from_raw_parts(b.add(offset), tail)
    }

    /// Returns `true` when the `len` bytes at `a` and `b` are identical
    /// (portable version for targets other than x86_64).
    ///
    /// A `len` of zero always yields `true`.
    ///
    /// # Safety
    /// `a` and `b` must each be valid for reads of `len` bytes.
    #[cfg(not(target_arch = "x86_64"))]
    #[inline]
    pub unsafe fn compare_avx2(a: *const u8, b: *const u8, len: usize) -> bool {
        if len == 0 {
            return true;
        }
        std::slice::from_raw_parts(a, len) == std::slice::from_raw_parts(b, len)
    }

    /// Sets the `len` bytes starting at `ptr` to zero.
    ///
    /// Uses 256-bit stores when AVX2 is available, otherwise
    /// [`std::ptr::write_bytes`]. A `len` of zero is a no-op.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of `len` bytes.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    pub unsafe fn zero_avx2(ptr: *mut u8, len: usize) {
        if len == 0 {
            return;
        }
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just verified; pointer validity is the
            // caller's obligation (see `# Safety`).
            Self::zero_avx2_impl(ptr, len);
        } else {
            std::ptr::write_bytes(ptr, 0, len);
        }
    }

    /// AVX2 body of [`SIMDMemory::zero_avx2`]; must only run on CPUs with AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn zero_avx2_impl(ptr: *mut u8, len: usize) {
        let zero = _mm256_setzero_si256();
        let mut offset = 0;

        while len - offset >= 32 {
            _mm256_storeu_si256(ptr.add(offset).cast(), zero);
            offset += 32;
        }

        // Remaining tail of fewer than 32 bytes.
        std::ptr::write_bytes(ptr.add(offset), 0, len - offset);
    }

    /// Sets the `len` bytes starting at `ptr` to zero (portable version for
    /// targets other than x86_64).
    ///
    /// # Safety
    /// `ptr` must be valid for writes of `len` bytes.
    #[cfg(not(target_arch = "x86_64"))]
    #[inline]
    pub unsafe fn zero_avx2(ptr: *mut u8, len: usize) {
        std::ptr::write_bytes(ptr, 0, len);
    }
}
107
/// Batch arithmetic helpers over `u64` slices.
pub struct SIMDMath;

impl SIMDMath {
    /// Element-wise wrapping addition: `result[i] = a[i] + b[i]` (mod 2^64).
    ///
    /// On CPUs with AVX2 (checked at runtime) four lanes are added per
    /// instruction; otherwise a scalar loop is used. Both paths wrap on
    /// overflow.
    ///
    /// # Panics
    /// Panics if `a`, `b` and `result` do not all have the same length.
    ///
    /// # Safety
    /// The function places no requirements on the caller beyond the length
    /// check above. It stays `unsafe` on x86_64 only so that existing call
    /// sites, which wrap the call in `unsafe`, keep compiling unchanged.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    pub unsafe fn add_u64_batch(a: &[u64], b: &[u64], result: &mut [u64]) {
        assert_eq!(a.len(), b.len());
        assert_eq!(a.len(), result.len());

        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just verified and all three slices have
            // been checked to have the same length.
            Self::add_u64_batch_avx2(a, b, result);
        } else {
            Self::add_u64_batch_scalar(a, b, result);
        }
    }

    /// AVX2 body of [`SIMDMath::add_u64_batch`].
    ///
    /// Must only run on CPUs with AVX2, and all three slices must have the same
    /// length (the raw-pointer loads/stores rely on it).
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn add_u64_batch_avx2(a: &[u64], b: &[u64], result: &mut [u64]) {
        let len = result.len();
        // Largest multiple of 4 not exceeding `len`: the part handled in
        // 256-bit (4 x u64) steps.
        let vector_len = len - len % 4;

        let mut i = 0;
        while i < vector_len {
            let va = _mm256_loadu_si256(a.as_ptr().add(i).cast());
            let vb = _mm256_loadu_si256(b.as_ptr().add(i).cast());
            let sum = _mm256_add_epi64(va, vb);
            _mm256_storeu_si256(result.as_mut_ptr().add(i).cast(), sum);
            i += 4;
        }

        // Up to three leftover elements.
        Self::add_u64_batch_scalar(
            &a[vector_len..],
            &b[vector_len..],
            &mut result[vector_len..],
        );
    }

    /// Element-wise wrapping addition: `result[i] = a[i] + b[i]` (mod 2^64).
    ///
    /// Portable version for targets other than x86_64.
    ///
    /// # Panics
    /// Panics if `a`, `b` and `result` do not all have the same length.
    #[cfg(not(target_arch = "x86_64"))]
    #[inline]
    pub fn add_u64_batch(a: &[u64], b: &[u64], result: &mut [u64]) {
        assert_eq!(a.len(), b.len());
        assert_eq!(a.len(), result.len());

        Self::add_u64_batch_scalar(a, b, result);
    }

    /// Scalar wrapping addition over equally long slices (callers check lengths).
    #[inline]
    fn add_u64_batch_scalar(a: &[u64], b: &[u64], result: &mut [u64]) {
        for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
            *out = x.wrapping_add(y);
        }
    }

    /// Returns the largest value in `data`, or `0` when `data` is empty.
    ///
    /// Note that an empty slice is indistinguishable from a slice whose
    /// maximum is `0`.
    #[inline]
    pub fn max_u64_batch(data: &[u64]) -> u64 {
        data.iter().copied().max().unwrap_or(0)
    }

    /// Returns the smallest value in `data`, or `0` when `data` is empty.
    ///
    /// Note that an empty slice is indistinguishable from a slice whose
    /// minimum is `0`.
    #[inline]
    pub fn min_u64_batch(data: &[u64]) -> u64 {
        data.iter().copied().min().unwrap_or(0)
    }
}
182
183/// SIMD 序列化优化
184pub struct SIMDSerializer;
185
186impl SIMDSerializer {
187    /// 批量序列化 u64 数组
188    #[inline(always)]
189    pub fn serialize_u64_batch(data: &[u64]) -> Vec<u8> {
190        let mut result = Vec::with_capacity(data.len() * 8);
191
192        for &value in data {
193            result.extend_from_slice(&value.to_le_bytes());
194        }
195
196        result
197    }
198
199    /// 批量反序列化 u64 数组
200    #[inline(always)]
201    pub fn deserialize_u64_batch(data: &[u8]) -> Vec<u64> {
202        let count = data.len() / 8;
203        let mut result = Vec::with_capacity(count);
204
205        for i in 0..count {
206            let offset = i * 8;
207            let bytes = [
208                data[offset],
209                data[offset + 1],
210                data[offset + 2],
211                data[offset + 3],
212                data[offset + 4],
213                data[offset + 5],
214                data[offset + 6],
215                data[offset + 7],
216            ];
217            result.push(u64::from_le_bytes(bytes));
218        }
219
220        result
221    }
222
223    /// 使用 SIMD 加速 Base64 编码(简化版)
224    #[inline(always)]
225    pub fn encode_base64_simd(data: &[u8]) -> String {
226        use base64::Engine;
227        base64::engine::general_purpose::STANDARD.encode(data)
228    }
229}
230
231/// SIMD 哈希计算
232pub struct SIMDHash;
233
234impl SIMDHash {
235    /// 批量计算 SHA256 哈希
236    #[inline(always)]
237    pub fn hash_batch_sha256(data: &[&[u8]]) -> Vec<[u8; 32]> {
238        use sha2::{Digest, Sha256};
239
240        data.iter()
241            .map(|item| {
242                let mut hasher = Sha256::new();
243                hasher.update(item);
244                hasher.finalize().into()
245            })
246            .collect()
247    }
248
249    /// 快速哈希(非加密)
250    #[inline(always)]
251    pub fn fast_hash_u64(data: &[u8]) -> u64 {
252        let mut hash: u64 = 0xcbf29ce484222325; // FNV-1a offset
253
254        for &byte in data {
255            hash ^= byte as u64;
256            hash = hash.wrapping_mul(0x100000001b3); // FNV-1a prime
257        }
258
259        hash
260    }
261}
262
/// Slice map/filter helpers.
///
/// Despite the `parallel_*` names, both helpers currently walk the slice
/// sequentially on the calling thread; the `Send + Sync` bounds are required
/// by the signatures but not otherwise used here.
pub struct SIMDIterator;

impl SIMDIterator {
    /// Applies `f` to every element of `data` and returns the results in
    /// input order.
    #[inline(always)]
    pub fn parallel_map<T, F>(data: &[T], f: F) -> Vec<T>
    where
        T: Copy + Send + Sync,
        F: Fn(T) -> T + Send + Sync,
    {
        let mut mapped = Vec::with_capacity(data.len());
        for &item in data {
            mapped.push(f(item));
        }
        mapped
    }

    /// Returns copies of the elements of `data` for which `predicate` returns
    /// `true`, preserving input order.
    #[inline(always)]
    pub fn parallel_filter<T, F>(data: &[T], predicate: F) -> Vec<T>
    where
        T: Copy + Send + Sync,
        F: Fn(&T) -> bool + Send + Sync,
    {
        let mut kept = Vec::new();
        for item in data {
            if predicate(item) {
                kept.push(*item);
            }
        }
        kept
    }
}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Copy must work across full 32-byte chunks and the trailing bytes.
    #[test]
    fn test_simd_memory_copy() {
        // 100 bytes = three full 32-byte chunks plus a 4-byte tail, so both the
        // vector loop and the remainder handling are exercised.
        let src: Vec<u8> = (0..100u8).collect();
        let mut dst = vec![0u8; 100];

        unsafe {
            SIMDMemory::copy_avx2(dst.as_mut_ptr(), src.as_ptr(), src.len());
        }

        assert_eq!(src, dst);
    }

    /// Compare must detect a difference both inside a chunk and in the tail.
    #[test]
    fn test_simd_memory_compare() {
        let a: Vec<u8> = (0..100u8).collect();
        let mut b = a.clone();

        let equal = unsafe { SIMDMemory::compare_avx2(a.as_ptr(), b.as_ptr(), a.len()) };
        assert!(equal);

        // Difference inside the second 32-byte chunk.
        b[40] ^= 1;
        let equal = unsafe { SIMDMemory::compare_avx2(a.as_ptr(), b.as_ptr(), a.len()) };
        assert!(!equal);
        b[40] ^= 1;

        // Difference in the final tail byte.
        b[99] ^= 1;
        let equal = unsafe { SIMDMemory::compare_avx2(a.as_ptr(), b.as_ptr(), a.len()) };
        assert!(!equal);
    }

    /// Zeroing must clear exactly the requested range and nothing after it.
    #[test]
    fn test_simd_memory_zero() {
        let mut buf = vec![0xAAu8; 100];

        unsafe {
            SIMDMemory::zero_avx2(buf.as_mut_ptr(), 70);
        }

        assert!(buf[..70].iter().all(|&byte| byte == 0));
        assert!(buf[70..].iter().all(|&byte| byte == 0xAA));
    }

    /// Addition must wrap on overflow and handle a non-multiple-of-4 length.
    #[test]
    fn test_simd_math() {
        // 7 elements = one 4-lane vector step plus a 3-element scalar tail.
        let a = vec![1u64, 2, 3, 4, u64::MAX, 6, 7];
        let b = vec![5u64, 6, 7, 8, 1, 1, 1];
        let mut result = vec![0u64; 7];

        #[cfg(target_arch = "x86_64")]
        unsafe {
            SIMDMath::add_u64_batch(&a, &b, &mut result);
        }

        #[cfg(not(target_arch = "x86_64"))]
        SIMDMath::add_u64_batch(&a, &b, &mut result);

        assert_eq!(result, vec![6, 8, 10, 12, 0, 7, 8]);
    }

    /// The hash must be deterministic and match known FNV-1a 64-bit values.
    #[test]
    fn test_fast_hash() {
        let data = b"hello world";
        let hash1 = SIMDHash::fast_hash_u64(data);
        let hash2 = SIMDHash::fast_hash_u64(data);

        assert_eq!(hash1, hash2);

        // Empty input returns the FNV-1a offset basis unchanged.
        assert_eq!(SIMDHash::fast_hash_u64(b""), 0xcbf29ce484222325);
        // Published FNV-1a 64-bit test vector for the single byte "a".
        assert_eq!(SIMDHash::fast_hash_u64(b"a"), 0xaf63dc4c8601ec8c);
        assert_ne!(hash1, SIMDHash::fast_hash_u64(b"hello worle"));
    }
}