Skip to main content

xore_process/
simd.rs

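//! SIMD-style numeric kernels for high-performance data processing.
//!
//! Provides optimized sum, mean, variance, standard deviation, min and max over `&[f64]`.
//! Note: `std::simd` is still unstable, so no explicit SIMD types or intrinsics are used;
//! instead each kernel is hand-unrolled over `chunks_exact(4)` with four independent
//! accumulators, intended to let the compiler pipeline or auto-vectorize the loop.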
5
/// Sums all elements of `data`.
///
/// The bulk of the slice is split with `chunks_exact(4)` and accumulated into four
/// independent partial sums (one per position within a chunk). This breaks the serial
/// dependency between consecutive additions. The partial sums are then combined
/// left-to-right, and the `len % 4` leftover elements are added one by one.
///
/// Because the additions are regrouped, the result may differ in the last bits from a
/// strictly sequential `iter().sum()`. An empty slice sums to `0.0`.
pub fn sum_f64_simd(data: &[f64]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    let blocks = data.chunks_exact(4);
    // `remainder` borrows from `data`, not from `blocks`, so `blocks` can still be consumed.
    let tail = blocks.remainder();

    let lanes = blocks.fold([0.0_f64; 4], |mut acc, block| {
        acc[0] += block[0];
        acc[1] += block[1];
        acc[2] += block[2];
        acc[3] += block[3];
        acc
    });

    let combined = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    tail.iter().fold(combined, |total, &x| total + x)
}
39
40/// 优化的 f64 数组均值计算
41pub fn mean_f64_simd(data: &[f64]) -> Option<f64> {
42    if data.is_empty() {
43        return None;
44    }
45
46    let sum = sum_f64_simd(data);
47    Some(sum / data.len() as f64)
48}
49
50/// 优化的 f64 数组方差计算
51///
52/// 使用两次遍历算法:
53/// 1. 计算均值
54/// 2. 计算平方差之和
55pub fn variance_f64_simd(data: &[f64]) -> Option<f64> {
56    if data.len() < 2 {
57        return None;
58    }
59
60    let mean = mean_f64_simd(data)?;
61
62    // 使用 4 路循环展开计算平方差
63    let mut sum0 = 0.0;
64    let mut sum1 = 0.0;
65    let mut sum2 = 0.0;
66    let mut sum3 = 0.0;
67
68    let chunks = data.chunks_exact(4);
69    let remainder = chunks.remainder();
70
71    for chunk in chunks {
72        let diff0 = chunk[0] - mean;
73        let diff1 = chunk[1] - mean;
74        let diff2 = chunk[2] - mean;
75        let diff3 = chunk[3] - mean;
76
77        sum0 += diff0 * diff0;
78        sum1 += diff1 * diff1;
79        sum2 += diff2 * diff2;
80        sum3 += diff3 * diff3;
81    }
82
83    let mut variance = sum0 + sum1 + sum2 + sum3;
84
85    // 处理剩余元素
86    for &val in remainder {
87        let diff = val - mean;
88        variance += diff * diff;
89    }
90
91    Some(variance / (data.len() - 1) as f64)
92}
93
94/// 优化的 f64 数组标准差计算
95pub fn std_dev_f64_simd(data: &[f64]) -> Option<f64> {
96    variance_f64_simd(data).map(|v| v.sqrt())
97}
98
/// Returns the smallest value in `data`, or `None` if `data` is empty.
///
/// The slice is scanned with four independent running minima, one per position within
/// each `chunks_exact(4)` chunk. The lanes are merged at the end, and the `len % 4`
/// leftover elements are then folded in.
///
/// NaN handling: `f64::min` returns the other operand when exactly one operand is NaN,
/// so NaN elements are skipped. Every lane starts at NaN (not `+inf`), so an input made
/// only of NaNs yields `Some(NaN)` rather than a fabricated `+inf` that never occurs in
/// the data.
pub fn min_f64_simd(data: &[f64]) -> Option<f64> {
    if data.is_empty() {
        return None;
    }

    // NaN acts as the identity for `f64::min`: min(NaN, x) == x.
    let mut lanes = [f64::NAN; 4];

    let chunks = data.chunks_exact(4);
    let remainder = chunks.remainder();

    for chunk in chunks {
        lanes[0] = lanes[0].min(chunk[0]);
        lanes[1] = lanes[1].min(chunk[1]);
        lanes[2] = lanes[2].min(chunk[2]);
        lanes[3] = lanes[3].min(chunk[3]);
    }

    let lanes_min = lanes[0].min(lanes[1]).min(lanes[2]).min(lanes[3]);
    Some(remainder.iter().fold(lanes_min, |acc, &val| acc.min(val)))
}
139
/// Returns the largest value in `data`, or `None` if `data` is empty.
///
/// The slice is scanned with four independent running maxima, one per position within
/// each `chunks_exact(4)` chunk. The lanes are merged at the end, and the `len % 4`
/// leftover elements are then folded in.
///
/// NaN handling: `f64::max` returns the other operand when exactly one operand is NaN,
/// so NaN elements are skipped. Every lane starts at NaN (not `-inf`), so an input made
/// only of NaNs yields `Some(NaN)` rather than a fabricated `-inf` that never occurs in
/// the data.
pub fn max_f64_simd(data: &[f64]) -> Option<f64> {
    if data.is_empty() {
        return None;
    }

    // NaN acts as the identity for `f64::max`: max(NaN, x) == x.
    let mut lanes = [f64::NAN; 4];

    let chunks = data.chunks_exact(4);
    let remainder = chunks.remainder();

    for chunk in chunks {
        lanes[0] = lanes[0].max(chunk[0]);
        lanes[1] = lanes[1].max(chunk[1]);
        lanes[2] = lanes[2].max(chunk[2]);
        lanes[3] = lanes[3].max(chunk[3]);
    }

    let lanes_max = lanes[0].max(lanes[1]).max(lanes[2]).max(lanes[3]);
    Some(remainder.iter().fold(lanes_max, |acc, &val| acc.max(val)))
}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum_f64_simd() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(sum_f64_simd(&data), 15.0);
    }

    #[test]
    fn test_sum_f64_simd_empty() {
        let data: [f64; 0] = [];
        assert_eq!(sum_f64_simd(&data), 0.0);
    }

    #[test]
    fn test_sum_f64_simd_large() {
        // Many 4-element chunks, so every accumulator lane is exercised.
        let data: Vec<f64> = (1..=100).map(|x| x as f64).collect();
        let expected: f64 = (1..=100).map(|x| x as f64).sum();
        assert!((sum_f64_simd(&data) - expected).abs() < 1e-10);
    }

    #[test]
    fn test_mean_f64_simd() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(mean_f64_simd(&data).unwrap(), 3.0);
    }

    #[test]
    fn test_mean_f64_simd_empty() {
        let data: [f64; 0] = [];
        assert!(mean_f64_simd(&data).is_none());
    }

    #[test]
    fn test_variance_f64_simd() {
        // 8 elements: a multiple of 4, so only the chunked path runs.
        let data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let variance = variance_f64_simd(&data).unwrap();
        // Sample variance = 32 / 7 = 4.571428...
        assert!((variance - 4.571428571428571).abs() < 1e-10);
    }

    #[test]
    fn test_variance_f64_simd_with_remainder() {
        // 5 elements: one full chunk plus one leftover element.
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        let variance = variance_f64_simd(&data).unwrap();
        // mean = 3, squared deviations sum to 10, sample variance = 10 / 4 = 2.5
        assert!((variance - 2.5).abs() < 1e-12);
    }

    #[test]
    fn test_variance_f64_simd_two_elements() {
        // Smallest input with a defined sample variance: (1 + 1) / (2 - 1) = 2.
        let data = [1.0, 3.0];
        let variance = variance_f64_simd(&data).unwrap();
        assert!((variance - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_variance_f64_simd_small() {
        let single = [1.0];
        assert!(variance_f64_simd(&single).is_none());
        let empty: [f64; 0] = [];
        assert!(variance_f64_simd(&empty).is_none());
    }

    #[test]
    fn test_std_dev_f64_simd() {
        let data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let std_dev = std_dev_f64_simd(&data).unwrap();
        // Standard deviation = sqrt(4.571428...) = 2.138...
        assert!((std_dev - 2.1380899352993947).abs() < 1e-10);
    }

    #[test]
    fn test_min_f64_simd() {
        let data = [5.0, 2.0, 8.0, 1.0, 9.0, 3.0];
        assert_eq!(min_f64_simd(&data).unwrap(), 1.0);
    }

    #[test]
    fn test_min_f64_simd_in_remainder() {
        // The minimum sits at index 5, i.e. in the leftover (non-chunked) tail.
        let data = [5.0, 2.0, 8.0, 4.0, 9.0, -3.0];
        assert_eq!(min_f64_simd(&data).unwrap(), -3.0);
    }

    #[test]
    fn test_min_f64_simd_empty() {
        let data: [f64; 0] = [];
        assert!(min_f64_simd(&data).is_none());
    }

    #[test]
    fn test_max_f64_simd() {
        // The maximum sits at index 4, i.e. in the leftover (non-chunked) tail.
        let data = [5.0, 2.0, 8.0, 1.0, 9.0, 3.0];
        assert_eq!(max_f64_simd(&data).unwrap(), 9.0);
    }

    #[test]
    fn test_max_f64_simd_all_negative() {
        let data = [-5.0, -2.0, -8.0, -1.0];
        assert_eq!(max_f64_simd(&data).unwrap(), -1.0);
    }

    #[test]
    fn test_max_f64_simd_empty() {
        let data: [f64; 0] = [];
        assert!(max_f64_simd(&data).is_none());
    }

    #[test]
    fn test_simd_consistency() {
        // The unrolled implementations must agree with the plain iterator versions.
        let data: Vec<f64> = (1..=1000).map(|x| x as f64 * 0.1).collect();

        let simd_sum = sum_f64_simd(&data);
        let std_sum: f64 = data.iter().sum();
        assert!((simd_sum - std_sum).abs() < 1e-8);

        let simd_mean = mean_f64_simd(&data).unwrap();
        let std_mean = std_sum / data.len() as f64;
        assert!((simd_mean - std_mean).abs() < 1e-8);
    }

    #[test]
    fn test_unaligned_data() {
        // Length is not a multiple of 4, so the remainder path is exercised.
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        assert_eq!(sum_f64_simd(&data), 28.0);
        assert_eq!(mean_f64_simd(&data).unwrap(), 4.0);
    }
}