1use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
8use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities, SimdUnifiedOps};
9
/// Snapshot of SIMD-related CPU capabilities, used to choose between
/// vectorized and scalar code paths.
#[derive(Debug, Clone)]
pub struct SimdConfig {
    /// True when AVX2 is available on the running CPU.
    pub avx2_available: bool,
    /// NOTE(review): in `detect()` this is populated from the generic
    /// `simd_available` flag, not an SSE4.1-specific probe — confirm
    /// against the `PlatformCapabilities` API.
    pub sse41_available: bool,
    /// NOTE(review): also populated from the generic `simd_available`
    /// flag in `detect()`, not an FMA-specific probe.
    pub fma_available: bool,
    /// Vector width hint: 4 when AVX2 is available, 2 with generic SIMD,
    /// 1 for the scalar fallback (see `detect()`).
    pub vector_width: usize,
}
22
23impl SimdConfig {
24 pub fn detect() -> Self {
26 let caps = PlatformCapabilities::detect();
27 Self {
28 avx2_available: caps.avx2_available,
29 sse41_available: caps.simd_available, fma_available: caps.simd_available, vector_width: if caps.avx2_available {
32 4 } else if caps.simd_available {
34 2 } else {
36 1 },
38 }
39 }
40
41 pub fn has_simd(&self) -> bool {
43 self.avx2_available || self.sse41_available
44 }
45}
46
/// High-level vector operations that dispatch to SIMD kernels when the
/// runtime optimizer judges the input large enough, and fall back to
/// scalar iterator code otherwise.
pub struct SimdVectorOps {
    // Decides per call (based on input length) whether the SIMD path
    // is worthwhile; see `should_use_simd` call sites below.
    optimizer: AutoOptimizer,
}
51
52impl SimdVectorOps {
53 pub fn new() -> Self {
55 Self {
56 optimizer: AutoOptimizer::new(),
57 }
58 }
59
60 pub fn with_config(config: SimdConfig) -> Self {
62 Self::new()
64 }
65
66 pub fn config(&self) -> SimdConfig {
68 SimdConfig::detect()
69 }
70
71 pub fn platform_capabilities(&self) -> PlatformCapabilities {
73 PlatformCapabilities::detect()
74 }
75
76 pub fn dot_product(&self, a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
78 assert_eq!(a.len(), b.len());
79
80 if self.optimizer.should_use_simd(a.len()) {
81 f64::simd_dot(a, b)
82 } else {
83 a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
85 }
86 }
87
88 pub fn add(&self, a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
90 assert_eq!(a.len(), b.len());
91
92 if self.optimizer.should_use_simd(a.len()) {
93 f64::simd_add(a, b)
94 } else {
95 a.iter().zip(b.iter()).map(|(&ai, &bi)| ai + bi).collect()
96 }
97 }
98
99 pub fn sub(&self, a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
101 assert_eq!(a.len(), b.len());
102
103 if self.optimizer.should_use_simd(a.len()) {
104 f64::simd_sub(a, b)
105 } else {
106 a.iter().zip(b.iter()).map(|(&ai, &bi)| ai - bi).collect()
107 }
108 }
109
110 pub fn scale(&self, alpha: f64, a: &ArrayView1<f64>) -> Array1<f64> {
112 if self.optimizer.should_use_simd(a.len()) {
113 f64::simd_scalar_mul(a, alpha)
114 } else {
115 a.iter().map(|&ai| alpha * ai).collect()
116 }
117 }
118
119 pub fn axpy(&self, alpha: f64, x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> Array1<f64> {
121 assert_eq!(x.len(), y.len());
122
123 if self.optimizer.should_use_simd(x.len()) {
124 let alpha_x = f64::simd_scalar_mul(x, alpha);
126 f64::simd_add(&alpha_x.view(), y)
127 } else {
128 x.iter()
129 .zip(y.iter())
130 .map(|(&xi, &yi)| alpha * xi + yi)
131 .collect()
132 }
133 }
134
135 pub fn norm(&self, a: &ArrayView1<f64>) -> f64 {
137 if self.optimizer.should_use_simd(a.len()) {
138 f64::simd_norm(a)
139 } else {
140 a.iter().map(|&ai| ai * ai).sum::<f64>().sqrt()
141 }
142 }
143
144 pub fn matvec(&self, matrix: &ArrayView2<f64>, vector: &ArrayView1<f64>) -> Array1<f64> {
146 assert_eq!(matrix.ncols(), vector.len());
147
148 let mut result = Array1::zeros(matrix.nrows());
149 for (i, row) in matrix.outer_iter().enumerate() {
150 result[i] = self.dot_product(&row, vector);
151 }
152 result
153 }
154}
155
156impl Default for SimdVectorOps {
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_platform_capabilities() {
        let ops = SimdVectorOps::new();
        let caps = ops.platform_capabilities();

        // Informational only: show what the host CPU supports.
        println!(
            "Platform capabilities - SIMD: {}, GPU: {}, AVX2: {}",
            caps.simd_available, caps.gpu_available, caps.avx2_available
        );
    }

    #[test]
    fn test_dot_product() {
        let ops = SimdVectorOps::new();
        let a = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = array![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        // 1*2 + 2*3 + 3*4 + 4*5 + 5*6 + 6*7 + 7*8 + 8*9 = 240
        let result = ops.dot_product(&a.view(), &b.view());
        assert_abs_diff_eq!(result, 240.0, epsilon = 1e-10);
    }

    #[test]
    fn test_vector_operations() {
        let ops = SimdVectorOps::new();
        let a = array![1.0, 2.0, 3.0, 4.0];
        let b = array![5.0, 6.0, 7.0, 8.0];

        let sum = ops.add(&a.view(), &b.view());
        for (got, want) in sum.iter().zip([6.0, 8.0, 10.0, 12.0]) {
            assert_abs_diff_eq!(*got, want, epsilon = 1e-10);
        }

        // b - a is uniformly 4.
        let diff = ops.sub(&b.view(), &a.view());
        for value in diff.iter() {
            assert_abs_diff_eq!(*value, 4.0, epsilon = 1e-10);
        }

        let scaled = ops.scale(2.0, &a.view());
        for (got, want) in scaled.iter().zip([2.0, 4.0, 6.0, 8.0]) {
            assert_abs_diff_eq!(*got, want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_axpy() {
        let ops = SimdVectorOps::new();
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = array![5.0, 6.0, 7.0, 8.0];
        let alpha = 2.0;

        // 2*x + y
        let result = ops.axpy(alpha, &x.view(), &y.view());
        for (got, want) in result.iter().zip([7.0, 10.0, 13.0, 16.0]) {
            assert_abs_diff_eq!(*got, want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_norm() {
        let ops = SimdVectorOps::new();
        // Classic 3-4-5 right triangle.
        let a = array![3.0, 4.0];
        let norm = ops.norm(&a.view());
        assert_abs_diff_eq!(norm, 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_matvec() {
        let ops = SimdVectorOps::new();
        let matrix = array![[1.0, 2.0], [3.0, 4.0]];
        let vector = array![1.0, 2.0];

        // [[1,2],[3,4]] * [1,2] = [5, 11]
        let result = ops.matvec(&matrix.view(), &vector.view());
        assert_abs_diff_eq!(result[0], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(result[1], 11.0, epsilon = 1e-10);
    }

    #[test]
    #[ignore = "timeout"]
    fn test_large_vectors() {
        let ops = SimdVectorOps::new();
        let n = 100;
        let a: Array1<f64> = Array1::from_shape_fn(n, |i| i as f64);
        let b: Array1<f64> = Array1::from_shape_fn(n, |i| (i + 1) as f64);

        // Smoke checks only: positive results and correct output length.
        assert!(ops.dot_product(&a.view(), &b.view()) > 0.0);
        assert!(ops.norm(&a.view()) > 0.0);
        assert_eq!(ops.add(&a.view(), &b.view()).len(), n);
    }
}