Skip to main content

sublinear_solver/
simd_ops.rs

1//! SIMD-accelerated linear algebra operations for high-performance computing.
2//!
3//! This module provides vectorized implementations of core matrix and vector
4//! operations using SIMD intrinsics for maximum performance on modern CPUs.
5
6use crate::types::Precision;
7
8#[cfg(feature = "simd")]
9use wide::f64x4;
10
11#[cfg(all(feature = "std", feature = "rayon"))]
12use rayon::prelude::*;
13
/// SIMD-accelerated CSR matrix-vector multiplication: `y = A * x`.
///
/// `A` is stored in compressed sparse row form: the stored entries of row `r`
/// are `values[row_ptr[r]..row_ptr[r + 1]]`, and the matching column numbers
/// sit in the same range of `col_indices`. Every element of `y` is
/// overwritten. A row whose range is empty (or inverted, `end < start`)
/// produces `0.0`.
///
/// Rows with at least 8 stored entries are accumulated four lanes at a time
/// in an `f64x4` register. The lanes are summed, and any leftover entries are
/// added with scalar arithmetic. Shorter rows use a plain scalar loop.
///
/// # Panics
///
/// Panics if `y` is non-empty and `row_ptr` has fewer than `y.len() + 1`
/// entries. Also panics if a row range or column index is out of bounds for
/// `values` / `col_indices` / `x`.
#[cfg(feature = "simd")]
pub fn matrix_vector_multiply_simd(
    values: &[Precision],
    col_indices: &[u32],
    row_ptr: &[u32],
    x: &[Precision],
    y: &mut [Precision],
) {
    // Below this many nonzeros, packing lanes costs more than it saves.
    const SIMD_MIN_NNZ: usize = 8;
    // Width of an `f64x4` register.
    const LANES: usize = 4;

    // Same condition under which indexing `row_ptr[row + 1]` would fail for
    // the last row, but reported before any row of `y` is written.
    assert!(
        y.is_empty() || row_ptr.len() > y.len(),
        "row_ptr must have y.len() + 1 entries (got {} for {} rows)",
        row_ptr.len(),
        y.len()
    );

    y.fill(0.0);

    for (row, out) in y.iter_mut().enumerate() {
        let start = row_ptr[row] as usize;
        let end = row_ptr[row + 1] as usize;
        if end <= start {
            // Empty (or inverted) row: keep the zero written above.
            continue;
        }

        let row_values = &values[start..end];
        let row_indices = &col_indices[start..end];

        if row_values.len() < SIMD_MIN_NNZ {
            let mut sum = 0.0;
            for (&v, &col) in row_values.iter().zip(row_indices) {
                sum += v * x[col as usize];
            }
            *out = sum;
            continue;
        }

        let value_chunks = row_values.chunks_exact(LANES);
        let index_chunks = row_indices.chunks_exact(LANES);
        let value_tail = value_chunks.remainder();
        let index_tail = index_chunks.remainder();

        let mut acc = f64x4::splat(0.0);
        for (v, c) in value_chunks.zip(index_chunks) {
            let vals = f64x4::new([v[0], v[1], v[2], v[3]]);
            // Gather: CSR column numbers are arbitrary, so `x` is loaded lane by lane.
            let xs = f64x4::new([
                x[c[0] as usize],
                x[c[1] as usize],
                x[c[2] as usize],
                x[c[3] as usize],
            ]);
            acc = acc + vals * xs;
        }

        // Horizontal sum, then the entries that did not fill a whole register.
        let lanes = acc.to_array();
        let mut sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];
        for (&v, &col) in value_tail.iter().zip(index_tail) {
            sum += v * x[col as usize];
        }
        *out = sum;
    }
}
88
89/// Fallback implementation for when SIMD is not available
90#[cfg(not(feature = "simd"))]
91pub fn matrix_vector_multiply_simd(
92    values: &[Precision],
93    col_indices: &[u32],
94    row_ptr: &[u32],
95    x: &[Precision],
96    y: &mut [Precision],
97) {
98    y.fill(0.0);
99
100    for row in 0..y.len() {
101        let start = row_ptr[row] as usize;
102        let end = row_ptr[row + 1] as usize;
103
104        let mut sum = 0.0;
105        for i in start..end {
106            let col = col_indices[i] as usize;
107            sum += values[i] * x[col];
108        }
109        y[row] = sum;
110    }
111}
112
/// SIMD-accelerated dot product: returns `x^T * y`.
///
/// Products over the longest prefix whose length is a multiple of 4 are
/// accumulated lane-wise in an `f64x4`. The four lanes are then summed, and
/// the remaining elements are added one at a time.
///
/// # Panics
///
/// Panics if `x` and `y` have different lengths.
#[cfg(feature = "simd")]
pub fn dot_product_simd(x: &[Precision], y: &[Precision]) -> Precision {
    assert_eq!(x.len(), y.len());

    let x_blocks = x.chunks_exact(4);
    let y_blocks = y.chunks_exact(4);
    let x_tail = x_blocks.remainder();
    let y_tail = y_blocks.remainder();

    let acc = x_blocks
        .zip(y_blocks)
        .fold(f64x4::splat(0.0), |acc, (xs, ys)| {
            let xv = f64x4::new([xs[0], xs[1], xs[2], xs[3]]);
            let yv = f64x4::new([ys[0], ys[1], ys[2], ys[3]]);
            acc + (xv * yv)
        });

    // Horizontal reduction of the register, then the scalar tail.
    let lanes = acc.to_array();
    let mut total = lanes[0] + lanes[1] + lanes[2] + lanes[3];
    for (a, b) in x_tail.iter().zip(y_tail) {
        total += a * b;
    }
    total
}
143
144/// Fallback dot product implementation
145#[cfg(not(feature = "simd"))]
146pub fn dot_product_simd(x: &[Precision], y: &[Precision]) -> Precision {
147    assert_eq!(x.len(), y.len());
148    x.iter().zip(y.iter()).map(|(a, b)| a * b).sum()
149}
150
/// SIMD-accelerated AXPY, updating `y` in place: `y = alpha * x + y`.
///
/// Whole groups of four elements are computed in an `f64x4` register and
/// written back. The trailing `len % 4` elements are updated with scalar
/// arithmetic.
///
/// # Panics
///
/// Panics if `x` and `y` have different lengths.
#[cfg(feature = "simd")]
pub fn axpy_simd(alpha: Precision, x: &[Precision], y: &mut [Precision]) {
    assert_eq!(x.len(), y.len());

    let scale = f64x4::splat(alpha);
    let x_chunks = x.chunks_exact(4);
    let x_tail = x_chunks.remainder();
    let mut y_chunks = y.chunks_exact_mut(4);

    for (y4, x4) in (&mut y_chunks).zip(x_chunks) {
        let xv = f64x4::new([x4[0], x4[1], x4[2], x4[3]]);
        let yv = f64x4::new([y4[0], y4[1], y4[2], y4[3]]);
        y4.copy_from_slice(&((scale * xv) + yv).to_array());
    }

    for (yi, &xi) in y_chunks.into_remainder().iter_mut().zip(x_tail) {
        *yi += alpha * xi;
    }
}
181
182/// Fallback AXPY implementation
183#[cfg(not(feature = "simd"))]
184pub fn axpy_simd(alpha: Precision, x: &[Precision], y: &mut [Precision]) {
185    assert_eq!(x.len(), y.len());
186    for (y_val, &x_val) in y.iter_mut().zip(x.iter()) {
187        *y_val += alpha * x_val;
188    }
189}
190
/// Parallel CSR matrix-vector multiplication `y = A * x` using Rayon.
///
/// The rows of `y` are split into contiguous chunks that are processed on the
/// current Rayon thread pool. `num_threads` only determines how many chunks
/// the rows are split into (about one per requested thread); it does not
/// resize the pool. `None` or `Some(0)` uses
/// `std::thread::available_parallelism()`, or 1 if that is unavailable.
///
/// Every element of `y` is overwritten. A row whose range
/// `row_ptr[r]..row_ptr[r + 1]` is empty (or inverted) produces `0.0`. An
/// empty `y` is a no-op.
///
/// # Panics
///
/// Panics if `y` is non-empty and `row_ptr` has fewer than `y.len() + 1`
/// entries. Also panics if a row range or column index is out of bounds for
/// `values` / `col_indices` / `x`.
#[cfg(all(feature = "std", feature = "rayon"))]
pub fn parallel_matrix_vector_multiply(
    values: &[Precision],
    col_indices: &[u32],
    row_ptr: &[u32],
    x: &[Precision],
    y: &mut [Precision],
    num_threads: Option<usize>,
) {
    let rows = y.len();
    if rows == 0 {
        // Nothing to compute. Returning here also avoids `par_chunks_mut(0)`,
        // which panics.
        return;
    }

    assert!(
        row_ptr.len() > rows,
        "row_ptr must have y.len() + 1 entries (got {} for {} rows)",
        row_ptr.len(),
        rows
    );

    // `Some(0)` would otherwise divide by zero below; treat it like `None`.
    let num_threads = num_threads.filter(|&n| n > 0).unwrap_or_else(|| {
        std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(1)
    });

    // Ceiling division. `rows >= 1` and `num_threads >= 1` guarantee a result >= 1.
    let chunk_size = (rows + num_threads - 1) / num_threads;

    y.par_chunks_mut(chunk_size)
        .enumerate()
        .for_each(|(chunk_idx, y_chunk)| {
            let first_row = chunk_idx * chunk_size;
            for (offset, out) in y_chunk.iter_mut().enumerate() {
                let row = first_row + offset;
                let start = row_ptr[row] as usize;
                let end = row_ptr[row + 1] as usize;

                let mut sum = 0.0;
                for i in start..end {
                    let col = col_indices[i] as usize;
                    sum += values[i] * x[col];
                }
                *out = sum;
            }
        });
}
231
232/// Fallback parallel implementation
233#[cfg(not(all(feature = "std", feature = "rayon")))]
234pub fn parallel_matrix_vector_multiply(
235    values: &[Precision],
236    col_indices: &[u32],
237    row_ptr: &[u32],
238    x: &[Precision],
239    y: &mut [Precision],
240    _num_threads: Option<usize>,
241) {
242    matrix_vector_multiply_simd(values, col_indices, row_ptr, x, y);
243}
244
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_simd_matrix_vector_multiply() {
        let values = vec![2.0, 1.0, 1.0, 3.0];
        let col_indices = vec![0, 1, 0, 1];
        let row_ptr = vec![0, 2, 4];
        let x = vec![1.0, 2.0];
        let mut y = vec![0.0; 2];

        matrix_vector_multiply_simd(&values, &col_indices, &row_ptr, &x, &mut y);
        assert_eq!(y, vec![4.0, 7.0]);
    }

    /// Covers the 4-wide path with and without a scalar tail, an empty row,
    /// and the short-row scalar path. All values are small integers, so the
    /// sums are exact regardless of accumulation order.
    #[test]
    fn test_matrix_vector_multiply_long_short_and_empty_rows() {
        let values = vec![
            // Row 0: 9 entries (two SIMD chunks + 1 remainder).
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
            // Row 1: empty.
            // Row 2: 3 entries (scalar path).
            2.0, 3.0, 4.0,
            // Row 3: exactly 8 entries (SIMD, no remainder).
            2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
        ];
        let col_indices = vec![
            0, 1, 2, 3, 4, 5, 6, 7, 8, //
            0, 4, 8, //
            1, 2, 3, 4, 5, 6, 7, 8,
        ];
        let row_ptr = vec![0, 9, 9, 12, 20];
        let x: Vec<Precision> = (1..=9).map(|i| i as Precision).collect();
        // Pre-fill with junk to check that every row, including the empty one,
        // is overwritten.
        let mut y = vec![-1.0; 4];

        matrix_vector_multiply_simd(&values, &col_indices, &row_ptr, &x, &mut y);
        assert_eq!(y, vec![45.0, 0.0, 53.0, 88.0]);
    }

    #[test]
    fn test_matrix_vector_multiply_no_rows() {
        let mut y: Vec<Precision> = Vec::new();
        matrix_vector_multiply_simd(&[], &[], &[0], &[], &mut y);
        assert!(y.is_empty());
    }

    /// The parallel entry point must agree with the serial one for several
    /// chunking choices (5x5 tridiagonal matrix).
    #[test]
    fn test_parallel_matches_serial() {
        let values = vec![
            2.0, -1.0, //
            -1.0, 2.0, -1.0, //
            -1.0, 2.0, -1.0, //
            -1.0, 2.0, -1.0, //
            -1.0, 2.0,
        ];
        let col_indices = vec![0, 1, 0, 1, 2, 1, 2, 3, 2, 3, 4, 3, 4];
        let row_ptr = vec![0, 2, 5, 8, 11, 13];
        let x = vec![1.0, 4.0, 9.0, 16.0, 25.0];

        let mut serial = vec![0.0; 5];
        matrix_vector_multiply_simd(&values, &col_indices, &row_ptr, &x, &mut serial);
        assert_eq!(serial, vec![-2.0, -2.0, -2.0, -2.0, 34.0]);

        for threads in [None, Some(1), Some(2), Some(3), Some(8)] {
            let mut parallel = vec![-1.0; 5];
            parallel_matrix_vector_multiply(
                &values,
                &col_indices,
                &row_ptr,
                &x,
                &mut parallel,
                threads,
            );
            assert_eq!(parallel, serial, "threads = {:?}", threads);
        }
    }

    #[test]
    fn test_simd_dot_product() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 3.0, 4.0, 5.0, 6.0];
        let result = dot_product_simd(&x, &y);
        assert_eq!(result, 70.0); // 1*2 + 2*3 + 3*4 + 4*5 + 5*6
    }

    /// Lengths 0..=9 hit every remainder size around the 4-wide chunks.
    #[test]
    fn test_simd_dot_product_all_remainders() {
        for n in 0..=9usize {
            let x: Vec<Precision> = (1..=n).map(|i| i as Precision).collect();
            let y = vec![1.0; n];
            let expected = (n * (n + 1) / 2) as Precision;
            assert_eq!(dot_product_simd(&x, &y), expected, "n = {}", n);
        }
    }

    #[test]
    #[should_panic]
    fn test_dot_product_length_mismatch_panics() {
        let _ = dot_product_simd(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn test_simd_axpy() {
        let alpha = 2.0;
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let mut y = vec![1.0, 1.0, 1.0, 1.0];

        axpy_simd(alpha, &x, &mut y);
        assert_eq!(y, vec![3.0, 5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_simd_axpy_with_remainder() {
        let x: Vec<Precision> = (1..=7).map(|i| i as Precision).collect();
        let mut y = vec![10.0; 7];

        axpy_simd(0.5, &x, &mut y);
        let expected: Vec<Precision> = (1..=7).map(|i| 10.0 + 0.5 * i as Precision).collect();
        assert_eq!(y, expected);
    }

    #[test]
    #[should_panic]
    fn test_axpy_length_mismatch_panics() {
        let mut y = vec![0.0; 3];
        axpy_simd(1.0, &[1.0, 2.0], &mut y);
    }
}