1use crate::types::Precision;
7use alloc::vec::Vec;
8
9#[cfg(feature = "simd")]
10use wide::f64x4;
11
12#[cfg(all(feature = "std", feature = "rayon"))]
13use rayon::prelude::*;
14
/// Computes `y = A * x` for a sparse matrix stored in compressed sparse row
/// (CSR) form, using 4-lane `f64x4` arithmetic on rows with many non-zeros.
///
/// * `values` / `col_indices` hold the non-zero coefficients and their column
///   positions, row after row.
/// * `row_ptr[r]..row_ptr[r + 1]` is the range of row `r` inside those arrays.
/// * `x` is the input vector; `y` is overwritten with the product, one entry
///   per row (the number of rows is taken from `y.len()`).
///
/// Rows with fewer than 8 non-zeros are accumulated with a plain scalar loop;
/// longer rows are processed four entries at a time with a scalar tail.
///
/// # Panics
///
/// Panics on out-of-bounds access if `row_ptr` has fewer than `y.len() + 1`
/// entries, if a row range exceeds `values` / `col_indices`, or if a column
/// index is not a valid position in `x`.
#[cfg(feature = "simd")]
pub fn matrix_vector_multiply_simd(
    values: &[Precision],
    col_indices: &[u32],
    row_ptr: &[u32],
    x: &[Precision],
    y: &mut [Precision],
) {
    // Rows that are skipped below rely on this zero fill.
    y.fill(0.0);

    for (row, out) in y.iter_mut().enumerate() {
        let lo = row_ptr[row] as usize;
        let hi = row_ptr[row + 1] as usize;

        // Empty (or inverted) range: leave the zero written above.
        if lo >= hi {
            continue;
        }

        let coeffs = &values[lo..hi];
        let cols = &col_indices[lo..hi];

        // Short rows: gathering into vectors is not worth it.
        if coeffs.len() < 8 {
            let mut acc = 0.0;
            for (&a, &c) in coeffs.iter().zip(cols) {
                acc += a * x[c as usize];
            }
            *out = acc;
            continue;
        }

        let coeff_quads = coeffs.chunks_exact(4);
        let col_quads = cols.chunks_exact(4);
        let coeff_tail = coeff_quads.remainder();
        let col_tail = col_quads.remainder();

        // Four independent partial sums, one per lane.
        let mut lanes = f64x4::splat(0.0);
        for (a, c) in coeff_quads.zip(col_quads) {
            let a_vec = f64x4::new([a[0], a[1], a[2], a[3]]);
            // `x` is gathered element by element through the column indices.
            let x_vec = f64x4::new([
                x[c[0] as usize],
                x[c[1] as usize],
                x[c[2] as usize],
                x[c[3] as usize],
            ]);
            lanes = lanes + (a_vec * x_vec);
        }

        // Horizontal reduction, then the up-to-three leftover entries.
        let [s0, s1, s2, s3] = lanes.to_array();
        *out = s0 + s1 + s2 + s3;
        for (&a, &c) in coeff_tail.iter().zip(col_tail) {
            *out += a * x[c as usize];
        }
    }
}
89
90#[cfg(not(feature = "simd"))]
92pub fn matrix_vector_multiply_simd(
93 values: &[Precision],
94 col_indices: &[u32],
95 row_ptr: &[u32],
96 x: &[Precision],
97 y: &mut [Precision],
98) {
99 y.fill(0.0);
100
101 for row in 0..y.len() {
102 let start = row_ptr[row] as usize;
103 let end = row_ptr[row + 1] as usize;
104
105 let mut sum = 0.0;
106 for i in start..end {
107 let col = col_indices[i] as usize;
108 sum += values[i] * x[col];
109 }
110 y[row] = sum;
111 }
112}
113
/// Returns the dot product of `x` and `y`, accumulating four products at a
/// time in an `f64x4` and finishing the leftover elements with scalar math.
///
/// # Panics
///
/// Panics if the two slices differ in length.
#[cfg(feature = "simd")]
pub fn dot_product_simd(x: &[Precision], y: &[Precision]) -> Precision {
    assert_eq!(x.len(), y.len());

    let x_quads = x.chunks_exact(4);
    let y_quads = y.chunks_exact(4);
    // Up to three trailing elements that do not fill a whole vector.
    let x_tail = x_quads.remainder();
    let y_tail = y_quads.remainder();

    let mut lanes = f64x4::splat(0.0);
    for (xq, yq) in x_quads.zip(y_quads) {
        let x_vec = f64x4::new([xq[0], xq[1], xq[2], xq[3]]);
        let y_vec = f64x4::new([yq[0], yq[1], yq[2], yq[3]]);
        lanes = lanes + (x_vec * y_vec);
    }

    // Collapse the four lane sums, then fold in the tail left to right.
    let [s0, s1, s2, s3] = lanes.to_array();
    let mut total = s0 + s1 + s2 + s3;
    for (&a, &b) in x_tail.iter().zip(y_tail) {
        total += a * b;
    }

    total
}
148
149#[cfg(not(feature = "simd"))]
151pub fn dot_product_simd(x: &[Precision], y: &[Precision]) -> Precision {
152 assert_eq!(x.len(), y.len());
153 x.iter().zip(y.iter()).map(|(a, b)| a * b).sum()
154}
155
/// In-place AXPY update `y[i] += alpha * x[i]`, processing four elements per
/// step with `f64x4` and the remaining elements with scalar math.
///
/// # Panics
///
/// Panics if the two slices differ in length.
#[cfg(feature = "simd")]
pub fn axpy_simd(alpha: Precision, x: &[Precision], y: &mut [Precision]) {
    assert_eq!(x.len(), y.len());

    let scale = f64x4::splat(alpha);

    let x_quads = x.chunks_exact(4);
    let x_tail = x_quads.remainder();
    let mut y_quads = y.chunks_exact_mut(4);

    // `by_ref` keeps `y_quads` alive so its remainder can be taken afterwards.
    for (yq, xq) in y_quads.by_ref().zip(x_quads) {
        let x_vec = f64x4::new([xq[0], xq[1], xq[2], xq[3]]);
        let y_vec = f64x4::new([yq[0], yq[1], yq[2], yq[3]]);

        let updated = (scale * x_vec) + y_vec;
        yq.copy_from_slice(&updated.to_array());
    }

    // Up to three trailing elements that did not fill a whole vector.
    for (dst, &src) in y_quads.into_remainder().iter_mut().zip(x_tail) {
        *dst += alpha * src;
    }
}
190
191#[cfg(not(feature = "simd"))]
193pub fn axpy_simd(alpha: Precision, x: &[Precision], y: &mut [Precision]) {
194 assert_eq!(x.len(), y.len());
195 for (y_val, &x_val) in y.iter_mut().zip(x.iter()) {
196 *y_val += alpha * x_val;
197 }
198}
199
/// Computes the sparse matrix-vector product `y = A * x` with the rows split
/// into contiguous chunks that are processed in parallel by rayon.
///
/// The matrix is given in compressed sparse row (CSR) form: the non-zeros of
/// row `r` live at positions `row_ptr[r]..row_ptr[r + 1]` of `values`, with
/// their column numbers at the same positions of `col_indices`. `y` is
/// overwritten, one entry per row; the row count is `y.len()`.
///
/// `num_threads` only controls how many chunks the rows are divided into
/// (`None` uses `std::thread::available_parallelism`, falling back to 1); it
/// does not configure the rayon thread pool. A value of `Some(0)` is treated
/// as 1.
///
/// An empty `y` is a no-op.
///
/// # Panics
///
/// Panics on out-of-bounds access if `row_ptr` has fewer than `y.len() + 1`
/// entries, if a row range reaches past `values` / `col_indices`, or if a
/// column index is not a valid position in `x`.
#[cfg(all(feature = "std", feature = "rayon"))]
pub fn parallel_matrix_vector_multiply(
    values: &[Precision],
    col_indices: &[u32],
    row_ptr: &[u32],
    x: &[Precision],
    y: &mut [Precision],
    num_threads: Option<usize>,
) {
    y.fill(0.0);

    let rows = y.len();
    if rows == 0 {
        // Nothing to compute, and `par_chunks_mut` rejects a chunk size of 0.
        return;
    }

    // Clamp to at least one chunk so `Some(0)` cannot cause a division by zero.
    let num_threads = num_threads
        .unwrap_or_else(|| {
            std::thread::available_parallelism()
                .map(|p| p.get())
                .unwrap_or(1)
        })
        .max(1);

    // ceil(rows / num_threads), written so that it cannot overflow; the result
    // is >= 1 because `rows >= 1`.
    let chunk_size = rows / num_threads + usize::from(rows % num_threads != 0);

    y.par_chunks_mut(chunk_size)
        .enumerate()
        .for_each(|(chunk_idx, y_chunk)| {
            // Chunks are contiguous, so the first global row of this chunk
            // follows directly from its index.
            let first_row = chunk_idx * chunk_size;

            for (offset, out) in y_chunk.iter_mut().enumerate() {
                let row = first_row + offset;
                let start = row_ptr[row] as usize;
                let end = row_ptr[row + 1] as usize;

                let mut sum = 0.0;
                for i in start..end {
                    let col = col_indices[i] as usize;
                    sum += values[i] * x[col];
                }
                *out = sum;
            }
        });
}
240
241#[cfg(not(all(feature = "std", feature = "rayon")))]
243pub fn parallel_matrix_vector_multiply(
244 values: &[Precision],
245 col_indices: &[u32],
246 row_ptr: &[u32],
247 x: &[Precision],
248 y: &mut [Precision],
249 _num_threads: Option<usize>,
250) {
251 matrix_vector_multiply_simd(values, col_indices, row_ptr, x, y);
252}
253
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// 2x2 matrix [[2, 1], [1, 3]] times [1, 2]. Both rows have only two
    /// non-zeros, so this covers the short-row (scalar) path.
    #[test]
    fn test_simd_matrix_vector_multiply() {
        let values = vec![2.0, 1.0, 1.0, 3.0];
        let col_indices = vec![0, 1, 0, 1];
        let row_ptr = vec![0, 2, 4];
        let x = vec![1.0, 2.0];
        let mut y = vec![0.0; 2];

        matrix_vector_multiply_simd(&values, &col_indices, &row_ptr, &x, &mut y);
        assert_eq!(y, vec![4.0, 7.0]);
    }

    /// A row with nine non-zeros reaches the vectorised branch (two full
    /// groups of four plus a one-element tail); it is followed by an empty
    /// row and a short row. `y` starts with junk to prove it is overwritten.
    #[test]
    fn test_simd_matrix_vector_multiply_long_row() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 3.0];
        let col_indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 8];
        let row_ptr: Vec<u32> = vec![0, 9, 9, 11];
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut y = vec![-1.0; 3];

        matrix_vector_multiply_simd(&values, &col_indices, &row_ptr, &x, &mut y);
        // Row 0: 1^2 + 2^2 + ... + 9^2 = 285; row 1: empty; row 2: 2*1 + 3*9.
        assert_eq!(y, vec![285.0, 0.0, 29.0]);
    }

    /// Five elements: one full group of four plus a one-element tail.
    #[test]
    fn test_simd_dot_product() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 3.0, 4.0, 5.0, 6.0];
        let result = dot_product_simd(&x, &y);
        assert_eq!(result, 70.0);
    }

    #[test]
    fn test_simd_dot_product_empty() {
        assert_eq!(dot_product_simd(&[], &[]), 0.0);
    }

    /// Exactly four elements: no tail.
    #[test]
    fn test_simd_axpy() {
        let alpha = 2.0;
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let mut y = vec![1.0, 1.0, 1.0, 1.0];

        axpy_simd(alpha, &x, &mut y);
        assert_eq!(y, vec![3.0, 5.0, 7.0, 9.0]);
    }

    /// Six elements: one full group of four plus a two-element tail.
    #[test]
    fn test_simd_axpy_with_remainder() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut y = vec![1.0; 6];

        axpy_simd(2.0, &x, &mut y);
        assert_eq!(y, vec![3.0, 5.0, 7.0, 9.0, 11.0, 13.0]);
    }

    /// The parallel entry point must agree with the hand-computed product,
    /// both with an explicit chunk count and with the default.
    #[test]
    fn test_parallel_matrix_vector_multiply() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 3.0];
        let col_indices: Vec<u32> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 8];
        let row_ptr: Vec<u32> = vec![0, 9, 9, 11];
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let mut y = vec![-1.0; 3];
        parallel_matrix_vector_multiply(&values, &col_indices, &row_ptr, &x, &mut y, Some(2));
        assert_eq!(y, vec![285.0, 0.0, 29.0]);

        let mut y = vec![-1.0; 3];
        parallel_matrix_vector_multiply(&values, &col_indices, &row_ptr, &x, &mut y, None);
        assert_eq!(y, vec![285.0, 0.0, 29.0]);
    }
}