sublinear_solver/matrix/optimized.rs

//! Optimized sparse matrix implementations with SIMD acceleration and buffer pooling.
//!
//! This module provides high-performance matrix storage formats optimized for
//! sublinear-time algorithms, with a focus on minimizing memory allocation overhead
//! and maximizing cache efficiency.

use crate::types::{Precision, DimensionType};
use crate::error::Result;
use crate::matrix::sparse::{CSRStorage, COOStorage};
// `vec` is imported for the `vec!` macro in no_std builds.
use alloc::{vec, vec::Vec, collections::VecDeque};
use core::sync::atomic::{AtomicUsize, Ordering};

#[cfg(feature = "std")]
use std::sync::Mutex;

#[cfg(feature = "simd")]
use wide::f64x4;

/// High-performance buffer pool for reducing allocation overhead.
///
/// This pool maintains pre-allocated buffers of various sizes to minimize
/// runtime allocations during matrix operations.
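///
/// A minimal usage sketch (assuming this crate's `Precision` alias for the
/// element type):
///
/// ```ignore
/// let mut pool = BufferPool::new();
/// let buf = pool.get_buffer(256); // zero-filled, len == 256
/// // ... use `buf` as scratch space ...
/// pool.return_buffer(buf);        // recycled for later requests
/// let stats = pool.stats();
/// assert_eq!(stats.allocations, 1);
/// ```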
pub struct BufferPool {
    /// Small buffers (< 1KB)
    small_buffers: VecDeque<Vec<Precision>>,
    /// Medium buffers (1KB - 64KB)
    medium_buffers: VecDeque<Vec<Precision>>,
    /// Large buffers (> 64KB)
    large_buffers: VecDeque<Vec<Precision>>,
    /// Statistics
    allocations: AtomicUsize,
    deallocations: AtomicUsize,
    cache_hits: AtomicUsize,
    cache_misses: AtomicUsize,
}

/// Buffer size categories (in elements)
const SMALL_BUFFER_THRESHOLD: usize = 128; // 1KB for f64
const MEDIUM_BUFFER_THRESHOLD: usize = 8192; // 64KB for f64

impl BufferPool {
    /// Create a new buffer pool with initial capacity.
    pub fn new() -> Self {
        Self {
            small_buffers: VecDeque::with_capacity(16),
            medium_buffers: VecDeque::with_capacity(8),
            large_buffers: VecDeque::with_capacity(4),
            allocations: AtomicUsize::new(0),
            deallocations: AtomicUsize::new(0),
            cache_hits: AtomicUsize::new(0),
            cache_misses: AtomicUsize::new(0),
        }
    }

    /// Get a zero-filled buffer with `len() == min_size`, reusing a pooled
    /// allocation when one with sufficient capacity is available.
    pub fn get_buffer(&mut self, min_size: usize) -> Vec<Precision> {
        self.allocations.fetch_add(1, Ordering::Relaxed);

        let buffer_queue = if min_size <= SMALL_BUFFER_THRESHOLD {
            &mut self.small_buffers
        } else if min_size <= MEDIUM_BUFFER_THRESHOLD {
            &mut self.medium_buffers
        } else {
            &mut self.large_buffers
        };

        // Rotate through the queue looking for a buffer with enough capacity;
        // unsuitable buffers are pushed to the back so the scan terminates.
        for _ in 0..buffer_queue.len() {
            if let Some(mut buffer) = buffer_queue.pop_front() {
                if buffer.capacity() >= min_size {
                    buffer.clear();
                    buffer.resize(min_size, 0.0);
                    self.cache_hits.fetch_add(1, Ordering::Relaxed);
                    return buffer;
                } else {
                    buffer_queue.push_back(buffer);
                }
            }
        }

        // No suitable buffer found, allocate a new one
        self.cache_misses.fetch_add(1, Ordering::Relaxed);
        vec![0.0; min_size]
    }

    /// Return a buffer to the pool.
    pub fn return_buffer(&mut self, buffer: Vec<Precision>) {
        self.deallocations.fetch_add(1, Ordering::Relaxed);

        let capacity = buffer.capacity();
        let buffer_queue = if capacity <= SMALL_BUFFER_THRESHOLD {
            &mut self.small_buffers
        } else if capacity <= MEDIUM_BUFFER_THRESHOLD {
            &mut self.medium_buffers
        } else {
            &mut self.large_buffers
        };

        // Only store the buffer if there is room and it is reasonably sized;
        // otherwise let it drop and be freed.
        if buffer_queue.len() < 32 && capacity < 1_000_000 {
            buffer_queue.push_back(buffer);
        }
    }

    /// Get buffer pool statistics.
    pub fn stats(&self) -> BufferPoolStats {
        BufferPoolStats {
            allocations: self.allocations.load(Ordering::Relaxed),
            deallocations: self.deallocations.load(Ordering::Relaxed),
            cache_hits: self.cache_hits.load(Ordering::Relaxed),
            cache_misses: self.cache_misses.load(Ordering::Relaxed),
            small_buffers_pooled: self.small_buffers.len(),
            medium_buffers_pooled: self.medium_buffers.len(),
            large_buffers_pooled: self.large_buffers.len(),
        }
    }

    /// Clear all pooled buffers to free memory.
    pub fn clear(&mut self) {
        self.small_buffers.clear();
        self.medium_buffers.clear();
        self.large_buffers.clear();
    }
}

impl Default for BufferPool {
    fn default() -> Self {
        Self::new()
    }
}

/// Buffer pool statistics.
#[derive(Debug, Clone, Copy)]
pub struct BufferPoolStats {
    pub allocations: usize,
    pub deallocations: usize,
    pub cache_hits: usize,
    pub cache_misses: usize,
    pub small_buffers_pooled: usize,
    pub medium_buffers_pooled: usize,
    pub large_buffers_pooled: usize,
}

impl BufferPoolStats {
    /// Calculate cache hit rate as a percentage.
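    ///
    /// For example, 3 hits out of 4 requests yields 75.0:
    ///
    /// ```ignore
    /// let stats = BufferPoolStats {
    ///     allocations: 4, deallocations: 4,
    ///     cache_hits: 3, cache_misses: 1,
    ///     small_buffers_pooled: 0, medium_buffers_pooled: 0, large_buffers_pooled: 0,
    /// };
    /// assert_eq!(stats.hit_rate(), 75.0);
    /// ```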
    pub fn hit_rate(&self) -> f64 {
        if self.allocations == 0 {
            0.0
        } else {
            (self.cache_hits as f64 / self.allocations as f64) * 100.0
        }
    }
}

/// Thread-safe global buffer pool.
#[cfg(all(feature = "std", feature = "lazy_static"))]
lazy_static::lazy_static! {
    static ref GLOBAL_BUFFER_POOL: Mutex<BufferPool> = Mutex::new(BufferPool::new());
}

/// Get a buffer from the global pool.
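///
/// A sketch of the get/return round trip through the process-wide pool
/// (requires the `std` and `lazy_static` features):
///
/// ```ignore
/// let buf = get_global_buffer(1024);  // zero-filled scratch vector
/// // ... temporary computation ...
/// return_global_buffer(buf);          // hand it back for reuse
/// ```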
#[cfg(all(feature = "std", feature = "lazy_static"))]
pub fn get_global_buffer(min_size: usize) -> Vec<Precision> {
    GLOBAL_BUFFER_POOL.lock().unwrap().get_buffer(min_size)
}

/// Return a buffer to the global pool.
#[cfg(all(feature = "std", feature = "lazy_static"))]
pub fn return_global_buffer(buffer: Vec<Precision>) {
    GLOBAL_BUFFER_POOL.lock().unwrap().return_buffer(buffer);
}

/// Optimized CSR storage with SIMD acceleration and buffer pooling.
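///
/// A minimal construction-and-multiply sketch, assuming `COOStorage::from_triplets`
/// accepts `(row, col, value)` triplets as in the tests below:
///
/// ```ignore
/// let coo = COOStorage::from_triplets(vec![(0, 0, 2.0), (1, 1, 3.0)])?;
/// let m = OptimizedCSRStorage::from_coo(&coo, 2, 2)?;
/// let mut y = vec![0.0; 2];
/// m.multiply_vector_optimized(&[1.0, 1.0], &mut y); // y == [2.0, 3.0]
/// ```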
pub struct OptimizedCSRStorage {
    /// Base CSR storage
    storage: CSRStorage,
    /// Buffer pool for temporary vectors
    buffer_pool: BufferPool,
    /// Pre-allocated workspace
    workspace: Vec<Precision>,
    /// Performance counters
    matvec_count: AtomicUsize,
    bytes_processed: AtomicUsize,
}

impl OptimizedCSRStorage {
    /// Create optimized CSR storage from COO format.
    pub fn from_coo(coo: &COOStorage, rows: DimensionType, cols: DimensionType) -> Result<Self> {
        let storage = CSRStorage::from_coo(coo, rows, cols)?;
        let workspace_size = rows.max(cols);

        Ok(Self {
            storage,
            buffer_pool: BufferPool::new(),
            workspace: vec![0.0; workspace_size],
            matvec_count: AtomicUsize::new(0),
            bytes_processed: AtomicUsize::new(0),
        })
    }

    /// SIMD-accelerated matrix-vector multiplication.
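    ///
    /// Processes each row's nonzeros four `f64` lanes at a time, then handles
    /// the remainder with a scalar loop. A sketch of checking it against the
    /// scalar kernel (`m`, `x`, and `n` assumed from context):
    ///
    /// ```ignore
    /// let (mut y_simd, mut y_ref) = (vec![0.0; n], vec![0.0; n]);
    /// m.multiply_vector_simd(&x, &mut y_simd);
    /// m.multiply_vector_optimized(&x, &mut y_ref);
    /// for (a, b) in y_simd.iter().zip(&y_ref) {
    ///     assert!((a - b).abs() < 1e-12); // summation order may differ
    /// }
    /// ```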
    #[cfg(feature = "simd")]
    pub fn multiply_vector_simd(&self, x: &[Precision], result: &mut [Precision]) {
        result.fill(0.0);
        self.matvec_count.fetch_add(1, Ordering::Relaxed);

        let elem = core::mem::size_of::<Precision>();
        let bytes = (self.storage.values.len() + x.len() + result.len()) * elem;
        self.bytes_processed.fetch_add(bytes, Ordering::Relaxed);

        for (row, row_result) in result.iter_mut().enumerate() {
            let start = self.storage.row_ptr[row] as usize;
            let end = self.storage.row_ptr[row + 1] as usize;

            if end <= start {
                continue;
            }

            let row_values = &self.storage.values[start..end];
            let row_indices = &self.storage.col_indices[start..end];

            // Process in chunks of 4 lanes for SIMD
            let simd_chunks = row_values.len() / 4;
            let mut sum = f64x4::splat(0.0);

            for chunk in 0..simd_chunks {
                let val_idx = chunk * 4;
                let values = f64x4::new([
                    row_values[val_idx],
                    row_values[val_idx + 1],
                    row_values[val_idx + 2],
                    row_values[val_idx + 3],
                ]);

                // Gather the matching entries of x by column index
                let x_vals = f64x4::new([
                    x[row_indices[val_idx] as usize],
                    x[row_indices[val_idx + 1] as usize],
                    x[row_indices[val_idx + 2] as usize],
                    x[row_indices[val_idx + 3] as usize],
                ]);

                sum = sum + (values * x_vals);
            }

            // Horizontal sum of the SIMD register
            let sum_array = sum.to_array();
            *row_result = sum_array[0] + sum_array[1] + sum_array[2] + sum_array[3];

            // Handle the remaining (< 4) elements with a scalar loop
            for i in (simd_chunks * 4)..row_values.len() {
                let col = row_indices[i] as usize;
                *row_result += row_values[i] * x[col];
            }
        }
    }

    /// Fallback non-SIMD matrix-vector multiplication.
    #[cfg(not(feature = "simd"))]
    pub fn multiply_vector_simd(&self, x: &[Precision], result: &mut [Precision]) {
        self.multiply_vector_optimized(x, result);
    }

    /// Cache-optimized matrix-vector multiplication.
    pub fn multiply_vector_optimized(&self, x: &[Precision], result: &mut [Precision]) {
        result.fill(0.0);
        self.matvec_count.fetch_add(1, Ordering::Relaxed);

        // Track traffic so bandwidth statistics match the SIMD path
        let elem = core::mem::size_of::<Precision>();
        let bytes = (self.storage.values.len() + x.len() + result.len()) * elem;
        self.bytes_processed.fetch_add(bytes, Ordering::Relaxed);

        // Use blocked computation for better cache behavior
        const BLOCK_SIZE: usize = 64; // Chosen for L1 cache efficiency

        for row_block in (0..result.len()).step_by(BLOCK_SIZE) {
            let row_end = (row_block + BLOCK_SIZE).min(result.len());

            for row in row_block..row_end {
                let start = self.storage.row_ptr[row] as usize;
                let end = self.storage.row_ptr[row + 1] as usize;

                let mut sum = 0.0;
                for i in start..end {
                    let col = self.storage.col_indices[i] as usize;
                    sum += self.storage.values[i] * x[col];
                }
                result[row] = sum;
            }
        }
    }

    /// Streaming matrix-vector multiplication for large matrices.
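    ///
    /// The callback receives `(first_row, chunk_results)` for each block of
    /// rows. A sketch that reassembles the full product (`m`, `x`, and `n`
    /// assumed from context):
    ///
    /// ```ignore
    /// let mut y = vec![0.0; n];
    /// m.multiply_vector_streaming(&x, |start, chunk| {
    ///     y[start..start + chunk.len()].copy_from_slice(chunk);
    /// }, 1024)?;
    /// ```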
    pub fn multiply_vector_streaming<F>(
        &self,
        x: &[Precision],
        mut callback: F,
        chunk_size: usize,
    ) -> Result<()>
    where
        F: FnMut(usize, &[Precision]),
    {
        // Guard against `step_by(0)`, which would panic
        let chunk_size = chunk_size.max(1);
        let mut result_chunk = vec![0.0; chunk_size];

        for chunk_start in (0..self.storage.row_ptr.len() - 1).step_by(chunk_size) {
            let chunk_end = (chunk_start + chunk_size).min(self.storage.row_ptr.len() - 1);
            let actual_chunk_size = chunk_end - chunk_start;

            result_chunk.resize(actual_chunk_size, 0.0);
            result_chunk.fill(0.0);

            // Compute this chunk
            for (local_row, global_row) in (chunk_start..chunk_end).enumerate() {
                let start = self.storage.row_ptr[global_row] as usize;
                let end = self.storage.row_ptr[global_row + 1] as usize;

                let mut sum = 0.0;
                for i in start..end {
                    let col = self.storage.col_indices[i] as usize;
                    sum += self.storage.values[i] * x[col];
                }
                result_chunk[local_row] = sum;
            }

            callback(chunk_start, &result_chunk[..actual_chunk_size]);
        }

        Ok(())
    }

    /// Get performance statistics.
    pub fn performance_stats(&self) -> OptimizedMatrixStats {
        OptimizedMatrixStats {
            matvec_count: self.matvec_count.load(Ordering::Relaxed),
            bytes_processed: self.bytes_processed.load(Ordering::Relaxed),
            buffer_pool_stats: self.buffer_pool.stats(),
            matrix_nnz: self.storage.nnz(),
            matrix_rows: self.storage.row_ptr.len() - 1,
            workspace_size: self.workspace.len(),
        }
    }

    /// Reset performance counters.
    pub fn reset_stats(&self) {
        self.matvec_count.store(0, Ordering::Relaxed);
        self.bytes_processed.store(0, Ordering::Relaxed);
    }

    /// Get a temporary buffer from the pool.
    pub fn get_temp_buffer(&mut self, size: usize) -> Vec<Precision> {
        self.buffer_pool.get_buffer(size)
    }

    /// Return a temporary buffer to the pool.
    pub fn return_temp_buffer(&mut self, buffer: Vec<Precision>) {
        self.buffer_pool.return_buffer(buffer);
    }

    /// Access the underlying CSR storage.
    pub fn storage(&self) -> &CSRStorage {
        &self.storage
    }
}

/// Performance statistics for optimized matrix operations.
#[derive(Debug, Clone)]
pub struct OptimizedMatrixStats {
    pub matvec_count: usize,
    pub bytes_processed: usize,
    pub buffer_pool_stats: BufferPoolStats,
    pub matrix_nnz: usize,
    pub matrix_rows: usize,
    pub workspace_size: usize,
}

impl OptimizedMatrixStats {
    /// Calculate effective bandwidth in GiB/s (bytes processed divided by
    /// elapsed wall-clock time).
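    ///
    /// For example, 2 GiB of traffic (`bytes_processed == 2_147_483_648`)
    /// over 1000 ms reports 2.0:
    ///
    /// ```ignore
    /// assert_eq!(stats.bandwidth_gbs(1000.0), 2.0);
    /// ```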
    pub fn bandwidth_gbs(&self, total_time_ms: f64) -> f64 {
        if total_time_ms <= 0.0 {
            0.0
        } else {
            let total_gb = self.bytes_processed as f64 / 1_073_741_824.0; // bytes → GiB
            let total_seconds = total_time_ms / 1000.0;
            total_gb / total_seconds
        }
    }

    /// Calculate matrix-vector operations per second.
    pub fn ops_per_second(&self, total_time_ms: f64) -> f64 {
        if total_time_ms <= 0.0 {
            0.0
        } else {
            let total_ops = self.matvec_count as f64;
            let total_seconds = total_time_ms / 1000.0;
            total_ops / total_seconds
        }
    }
}

/// Parallel CSR storage for multi-threaded operations.
#[cfg(feature = "std")]
pub struct ParallelCSRStorage {
    storage: OptimizedCSRStorage,
    num_threads: usize,
}

#[cfg(feature = "std")]
impl ParallelCSRStorage {
    /// Create parallel CSR storage, defaulting to the available hardware
    /// parallelism when `num_threads` is `None`.
    pub fn new(storage: OptimizedCSRStorage, num_threads: Option<usize>) -> Self {
        let num_threads = num_threads.unwrap_or_else(|| {
            std::thread::available_parallelism()
                .map(|p| p.get())
                .unwrap_or(1)
        });

        Self {
            storage,
            num_threads,
        }
    }

    /// Parallel matrix-vector multiplication using Rayon.
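    ///
    /// A usage sketch (requires the `std` and `rayon` features; `opt`, `x`,
    /// and `n` assumed from context):
    ///
    /// ```ignore
    /// let par = ParallelCSRStorage::new(opt, None); // None → detect core count
    /// let mut y = vec![0.0; n];
    /// par.multiply_vector_parallel(&x, &mut y);
    /// ```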
    #[cfg(feature = "rayon")]
    pub fn multiply_vector_parallel(&self, x: &[Precision], result: &mut [Precision]) {
        use rayon::prelude::*;

        result.fill(0.0);

        // Determine chunk size for good load balancing; guard against an
        // empty result, since `par_chunks_mut(0)` would panic.
        let rows = result.len();
        if rows == 0 {
            return;
        }
        let chunk_size = (rows + self.num_threads - 1) / self.num_threads;

        result.par_chunks_mut(chunk_size)
            .enumerate()
            .for_each(|(chunk_idx, result_chunk)| {
                let start_row = chunk_idx * chunk_size;
                let end_row = (start_row + result_chunk.len()).min(rows);

                for (local_idx, global_row) in (start_row..end_row).enumerate() {
                    let start = self.storage.storage.row_ptr[global_row] as usize;
                    let end = self.storage.storage.row_ptr[global_row + 1] as usize;

                    let mut sum = 0.0;
                    for i in start..end {
                        let col = self.storage.storage.col_indices[i] as usize;
                        sum += self.storage.storage.values[i] * x[col];
                    }
                    result_chunk[local_idx] = sum;
                }
            });
    }
}

/// Memory-efficient matrix representation for extremely large problems.
pub struct StreamingMatrix {
    /// Matrix stored in chunks
    chunks: Vec<OptimizedCSRStorage>,
    /// Chunk size (number of rows per chunk)
    chunk_size: usize,
    /// Total dimensions
    total_rows: usize,
    total_cols: usize,
    /// Memory limit in bytes
    memory_limit: usize,
}

impl StreamingMatrix {
    /// Create a streaming matrix from triplets with memory constraints.
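    ///
    /// A sketch that caps the estimated working set at roughly 64 MB
    /// (`triplets`, `rows`, and `cols` assumed from context):
    ///
    /// ```ignore
    /// let m = StreamingMatrix::from_triplets(triplets, rows, cols, 64)?;
    /// ```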
    pub fn from_triplets(
        triplets: Vec<(usize, usize, Precision)>,
        rows: usize,
        cols: usize,
        memory_limit_mb: usize,
    ) -> Result<Self> {
        let memory_limit = memory_limit_mb * 1_048_576; // Convert MiB to bytes

        // Estimate memory per row (8-byte value + 4-byte column index per
        // nonzero, plus a 4-byte row pointer)
        let nnz = triplets.len();
        let avg_nnz_per_row = if rows > 0 { nnz / rows } else { 0 };
        let bytes_per_row = avg_nnz_per_row * (8 + 4) + 4;

        // Calculate chunk size to stay within the memory limit
        let target_chunk_size = if bytes_per_row > 0 {
            (memory_limit / (bytes_per_row * 2)).max(1) // Factor of 2 for safety
        } else {
            1000
        };

        // `.max(1)` keeps the chunk count well-defined when `rows == 0`
        let chunk_size = target_chunk_size.min(rows).max(1);

        // Sort triplets by row so each chunk is a contiguous run
        let mut sorted_triplets = triplets;
        sorted_triplets.sort_by_key(|(row, _, _)| *row);

        // Split into chunks with a single pass over the sorted triplets
        let mut chunks = Vec::new();
        let num_chunks = (rows + chunk_size - 1) / chunk_size;
        let mut idx = 0;

        for chunk_idx in 0..num_chunks {
            let chunk_start_row = chunk_idx * chunk_size;
            let chunk_end_row = ((chunk_idx + 1) * chunk_size).min(rows);
            let chunk_rows = chunk_end_row - chunk_start_row;

            // Collect this chunk's triplets, rebased to chunk-local rows
            let mut chunk_triplets = Vec::new();
            while idx < sorted_triplets.len() && sorted_triplets[idx].0 < chunk_end_row {
                let (row, col, val) = sorted_triplets[idx];
                chunk_triplets.push((row - chunk_start_row, col, val));
                idx += 1;
            }

            // Create chunk storage (an empty triplet list yields an empty chunk)
            let coo = COOStorage::from_triplets(chunk_triplets)?;
            let chunk_storage = OptimizedCSRStorage::from_coo(&coo, chunk_rows, cols)?;
            chunks.push(chunk_storage);
        }

        Ok(Self {
            chunks,
            chunk_size,
            total_rows: rows,
            total_cols: cols,
            memory_limit,
        })
    }

    /// Streaming matrix-vector multiplication.
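    ///
    /// Chunks are visited in row order, so the callback can write results
    /// directly into a preallocated output vector (`m`, `x`, and `rows`
    /// assumed from context):
    ///
    /// ```ignore
    /// let mut y = vec![0.0; rows];
    /// m.multiply_vector_streaming(&x, |start, chunk| {
    ///     y[start..start + chunk.len()].copy_from_slice(chunk);
    /// })?;
    /// ```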
    pub fn multiply_vector_streaming<F>(
        &self,
        x: &[Precision],
        mut callback: F,
    ) -> Result<()>
    where
        F: FnMut(usize, &[Precision]),
    {
        for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
            let start_row = chunk_idx * self.chunk_size;
            let end_row = (start_row + self.chunk_size).min(self.total_rows);
            let chunk_rows = end_row - start_row;

            let mut result = vec![0.0; chunk_rows];
            chunk.multiply_vector_optimized(x, &mut result);

            callback(start_row, &result);
        }

        Ok(())
    }

    /// Estimate total memory usage of all chunks in bytes.
    pub fn memory_usage(&self) -> usize {
        self.chunks.iter()
            .map(|chunk| {
                let stats = chunk.performance_stats();
                // Rough estimate: 12 bytes per nonzero (value + column index)
                // plus 4 bytes per row pointer
                stats.matrix_nnz * 12 + stats.matrix_rows * 4
            })
            .sum()
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_buffer_pool() {
        let mut pool = BufferPool::new();

        // Test buffer allocation and return
        let buffer1 = pool.get_buffer(100);
        assert_eq!(buffer1.len(), 100);

        pool.return_buffer(buffer1);

        // The second request should be served from the pooled buffer
        let buffer2 = pool.get_buffer(50);
        assert_eq!(buffer2.len(), 50);

        let stats = pool.stats();
        assert_eq!(stats.allocations, 2);
        assert_eq!(stats.deallocations, 1);
        assert_eq!(stats.cache_hits, 1);
    }

    #[test]
    fn test_optimized_csr_performance() {
        // Create a simple test matrix
        let triplets = vec![
            (0, 0, 2.0), (0, 1, 1.0),
            (1, 0, 1.0), (1, 1, 3.0),
        ];
        let coo = COOStorage::from_triplets(triplets).unwrap();
        let optimized = OptimizedCSRStorage::from_coo(&coo, 2, 2).unwrap();

        let x = vec![1.0, 2.0];
        let mut result = vec![0.0; 2];

        optimized.multiply_vector_optimized(&x, &mut result);
        assert_eq!(result, vec![4.0, 7.0]);

        let stats = optimized.performance_stats();
        assert_eq!(stats.matvec_count, 1);
    }

    #[test]
    fn test_streaming_matrix() {
        let triplets = vec![
            (0, 0, 1.0), (0, 1, 2.0),
            (1, 0, 3.0), (1, 1, 4.0),
            (2, 0, 5.0), (2, 1, 6.0),
        ];

        let streaming = StreamingMatrix::from_triplets(triplets, 3, 2, 1).unwrap();
        let x = vec![1.0, 1.0];

        let mut results = Vec::new();
        streaming.multiply_vector_streaming(&x, |_start_row, chunk_result| {
            results.extend_from_slice(chunk_result);
        }).unwrap();

        // Chunks arrive in row order, so concatenation gives the full product
        assert_eq!(results, vec![3.0, 7.0, 11.0]);
    }
}