1use crate::types::Precision;
7
8#[cfg(feature = "simd")]
9use wide::f64x4;
10
11#[cfg(all(feature = "std", feature = "rayon"))]
12use rayon::prelude::*;
13
#[cfg(feature = "simd")]
/// Computes `y = A * x` for a CSR matrix using 4-wide `f64x4` vectors.
///
/// `A` is described by `values` (stored entries), `col_indices` (the column of
/// each stored entry) and `row_ptr` (the entries of row `r` occupy
/// `row_ptr[r]..row_ptr[r + 1]`). The number of rows is taken from `y.len()`,
/// so `row_ptr` must contain at least `y.len() + 1` offsets.
///
/// Rows with 8 or more stored entries are accumulated four products at a time,
/// reduced horizontally, and finished with a scalar tail; shorter rows use a
/// plain scalar loop. A row whose end offset is not past its start is left at
/// `0.0`.
///
/// # Panics
///
/// Panics if an offset or column index is out of bounds for the given slices.
pub fn matrix_vector_multiply_simd(
    values: &[Precision],
    col_indices: &[u32],
    row_ptr: &[u32],
    x: &[Precision],
    y: &mut [Precision],
) {
    // Rows skipped as empty below rely on this reset.
    y.fill(0.0);

    for (row, out) in y.iter_mut().enumerate() {
        let (lo, hi) = (row_ptr[row] as usize, row_ptr[row + 1] as usize);
        if hi <= lo {
            continue;
        }

        let row_vals = &values[lo..hi];
        let row_cols = &col_indices[lo..hi];

        // Short rows are not worth the vector setup cost.
        if row_vals.len() < 8 {
            let mut acc = 0.0;
            for (&v, &c) in row_vals.iter().zip(row_cols) {
                acc += v * x[c as usize];
            }
            *out = acc;
            continue;
        }

        let vals_chunks = row_vals.chunks_exact(4);
        let cols_chunks = row_cols.chunks_exact(4);
        let vals_tail = vals_chunks.remainder();
        let cols_tail = cols_chunks.remainder();

        let mut lanes = f64x4::splat(0.0);
        for (v, c) in vals_chunks.zip(cols_chunks) {
            let v4 = f64x4::new([v[0], v[1], v[2], v[3]]);
            // Gather the x entries addressed by this group of column indices.
            let x4 = f64x4::new([
                x[c[0] as usize],
                x[c[1] as usize],
                x[c[2] as usize],
                x[c[3] as usize],
            ]);
            lanes = lanes + (v4 * x4);
        }

        let [l0, l1, l2, l3] = lanes.to_array();
        let mut acc = l0 + l1 + l2 + l3;
        for (&v, &c) in vals_tail.iter().zip(cols_tail) {
            acc += v * x[c as usize];
        }
        *out = acc;
    }
}
88
89#[cfg(not(feature = "simd"))]
91pub fn matrix_vector_multiply_simd(
92 values: &[Precision],
93 col_indices: &[u32],
94 row_ptr: &[u32],
95 x: &[Precision],
96 y: &mut [Precision],
97) {
98 y.fill(0.0);
99
100 for row in 0..y.len() {
101 let start = row_ptr[row] as usize;
102 let end = row_ptr[row + 1] as usize;
103
104 let mut sum = 0.0;
105 for i in start..end {
106 let col = col_indices[i] as usize;
107 sum += values[i] * x[col];
108 }
109 y[row] = sum;
110 }
111}
112
#[cfg(feature = "simd")]
/// Returns the dot product `sum(x[i] * y[i])`, using 4-wide `f64x4` vectors
/// for the bulk of the slices and a scalar loop for the trailing elements.
///
/// # Panics
///
/// Panics if `x` and `y` have different lengths.
pub fn dot_product_simd(x: &[Precision], y: &[Precision]) -> Precision {
    assert_eq!(x.len(), y.len());

    let x_chunks = x.chunks_exact(4);
    let y_chunks = y.chunks_exact(4);
    // Both tails have the same length because the inputs do.
    let x_tail = x_chunks.remainder();
    let y_tail = y_chunks.remainder();

    let lanes = x_chunks
        .zip(y_chunks)
        .fold(f64x4::splat(0.0), |acc, (a, b)| {
            let a4 = f64x4::new([a[0], a[1], a[2], a[3]]);
            let b4 = f64x4::new([b[0], b[1], b[2], b[3]]);
            acc + (a4 * b4)
        });

    let [l0, l1, l2, l3] = lanes.to_array();
    x_tail
        .iter()
        .zip(y_tail)
        .fold(l0 + l1 + l2 + l3, |total, (&a, &b)| total + a * b)
}
143
144#[cfg(not(feature = "simd"))]
146pub fn dot_product_simd(x: &[Precision], y: &[Precision]) -> Precision {
147 assert_eq!(x.len(), y.len());
148 x.iter().zip(y.iter()).map(|(a, b)| a * b).sum()
149}
150
#[cfg(feature = "simd")]
/// Computes `y += alpha * x` in place, using 4-wide `f64x4` vectors for the
/// bulk of the slices and a scalar loop for the trailing elements.
///
/// # Panics
///
/// Panics if `x` and `y` have different lengths.
pub fn axpy_simd(alpha: Precision, x: &[Precision], y: &mut [Precision]) {
    assert_eq!(x.len(), y.len());

    // Everything before `split` is handled in groups of four.
    let split = x.len() - x.len() % 4;
    let (x_head, x_tail) = x.split_at(split);
    let (y_head, y_tail) = y.split_at_mut(split);

    let alpha4 = f64x4::splat(alpha);
    for (yc, xc) in y_head.chunks_exact_mut(4).zip(x_head.chunks_exact(4)) {
        let x4 = f64x4::new([xc[0], xc[1], xc[2], xc[3]]);
        let y4 = f64x4::new([yc[0], yc[1], yc[2], yc[3]]);
        yc.copy_from_slice(&((alpha4 * x4) + y4).to_array());
    }

    for (yv, &xv) in y_tail.iter_mut().zip(x_tail) {
        *yv += alpha * xv;
    }
}
181
182#[cfg(not(feature = "simd"))]
184pub fn axpy_simd(alpha: Precision, x: &[Precision], y: &mut [Precision]) {
185 assert_eq!(x.len(), y.len());
186 for (y_val, &x_val) in y.iter_mut().zip(x.iter()) {
187 *y_val += alpha * x_val;
188 }
189}
190
#[cfg(all(feature = "std", feature = "rayon"))]
/// Computes `y = A * x` for a CSR matrix, splitting the rows of `y` into
/// contiguous chunks that rayon processes in parallel.
///
/// `A` is described by `values` (stored entries), `col_indices` (the column of
/// each stored entry) and `row_ptr` (the entries of row `r` occupy
/// `row_ptr[r]..row_ptr[r + 1]`). The number of rows is taken from `y.len()`.
///
/// `num_threads` only decides how many chunks the rows are split into. It
/// defaults to `std::thread::available_parallelism()`, or 1 if that fails, and
/// `Some(0)` is treated as 1. The chunks run on whichever rayon thread pool is
/// current; no pool of that size is created here.
///
/// # Panics
///
/// Panics if an offset or column index is out of bounds for the given slices.
pub fn parallel_matrix_vector_multiply(
    values: &[Precision],
    col_indices: &[u32],
    row_ptr: &[u32],
    x: &[Precision],
    y: &mut [Precision],
    num_threads: Option<usize>,
) {
    let rows = y.len();
    // `par_chunks_mut` panics on a zero chunk size, which an empty `y` would produce.
    if rows == 0 {
        return;
    }

    let num_threads = num_threads
        .unwrap_or_else(|| {
            std::thread::available_parallelism()
                .map(|p| p.get())
                .unwrap_or(1)
        })
        .max(1);

    // Ceiling division written without `rows + num_threads - 1`, which could
    // overflow for very large thread hints. Always >= 1 since rows >= 1.
    let chunk_size = rows / num_threads + usize::from(rows % num_threads != 0);

    // Every row of `y` is written below, so no up-front zeroing pass is needed.
    y.par_chunks_mut(chunk_size)
        .enumerate()
        .for_each(|(chunk_idx, y_chunk)| {
            let first_row = chunk_idx * chunk_size;

            for (offset, out) in y_chunk.iter_mut().enumerate() {
                let row = first_row + offset;
                let start = row_ptr[row] as usize;
                let end = row_ptr[row + 1] as usize;

                let mut sum = 0.0;
                for i in start..end {
                    let col = col_indices[i] as usize;
                    sum += values[i] * x[col];
                }
                *out = sum;
            }
        });
}
231
232#[cfg(not(all(feature = "std", feature = "rayon")))]
234pub fn parallel_matrix_vector_multiply(
235 values: &[Precision],
236 col_indices: &[u32],
237 row_ptr: &[u32],
238 x: &[Precision],
239 y: &mut [Precision],
240 _num_threads: Option<usize>,
241) {
242 matrix_vector_multiply_simd(values, col_indices, row_ptr, x, y);
243}
244
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_simd_matrix_vector_multiply() {
        let values = vec![2.0, 1.0, 1.0, 3.0];
        let col_indices = vec![0, 1, 0, 1];
        let row_ptr = vec![0, 2, 4];
        let x = vec![1.0, 2.0];
        let mut y = vec![0.0; 2];

        matrix_vector_multiply_simd(&values, &col_indices, &row_ptr, &x, &mut y);
        assert_eq!(y, vec![4.0, 7.0]);
    }

    #[test]
    fn test_matrix_vector_multiply_long_and_empty_rows() {
        // Row 0 has 9 stored entries: enough to take the 4-wide vector path
        // (with the `simd` feature) plus a one-element tail. Row 1 is empty and
        // must come out as zero even though `y` starts with stale data.
        let values = vec![1.0; 9];
        let col_indices: Vec<u32> = (0..9).collect();
        let row_ptr = vec![0, 9, 9];
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut y = vec![-1.0; 2];

        matrix_vector_multiply_simd(&values, &col_indices, &row_ptr, &x, &mut y);
        assert_eq!(y, vec![45.0, 0.0]);
    }

    #[test]
    fn test_simd_dot_product() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 3.0, 4.0, 5.0, 6.0];
        let result = dot_product_simd(&x, &y);
        assert_eq!(result, 70.0);
    }

    #[test]
    fn test_simd_dot_product_multiple_chunks_and_empty() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let y = vec![2.0; 9];
        assert_eq!(dot_product_simd(&x, &y), 90.0);
        assert_eq!(dot_product_simd(&[], &[]), 0.0);
    }

    #[test]
    #[should_panic]
    fn test_simd_dot_product_length_mismatch() {
        let _ = dot_product_simd(&[1.0], &[1.0, 2.0]);
    }

    #[test]
    fn test_simd_axpy() {
        let alpha = 2.0;
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let mut y = vec![1.0, 1.0, 1.0, 1.0];

        axpy_simd(alpha, &x, &mut y);
        assert_eq!(y, vec![3.0, 5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_simd_axpy_with_tail() {
        // One full group of four plus a two-element tail.
        let mut y = vec![1.0; 6];
        axpy_simd(0.5, &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0], &mut y);
        assert_eq!(y, vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    }

    #[test]
    fn test_parallel_matrix_vector_multiply() {
        let values = vec![2.0, 1.0, 1.0, 3.0];
        let col_indices = vec![0, 1, 0, 1];
        let row_ptr = vec![0, 2, 4];
        let x = vec![1.0, 2.0];

        for threads in [None, Some(1), Some(3)] {
            let mut y = vec![0.0; 2];
            parallel_matrix_vector_multiply(&values, &col_indices, &row_ptr, &x, &mut y, threads);
            assert_eq!(y, vec![4.0, 7.0]);
        }
    }
}