scirs2_sparse/neural_adaptive_sparse/pattern_memory.rs

//! Pattern memory and fingerprinting for sparse matrix optimization
//!
//! This module implements memory systems that learn and store optimal strategies
//! for different types of sparse matrix patterns and access behaviors.
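//!
//! # Example
//!
//! A minimal sketch of the intended flow (illustrative; `rows`, `cols`, and
//! `data` are assumed to be COO-style triplet arrays supplied by the caller):
//!
//! ```ignore
//! let mut memory = PatternMemory::new(1024);
//! let fingerprint = MatrixFingerprint::new(&rows, &cols, &data, (1000, 1000));
//! let strategy = memory.suggest_strategy(&fingerprint);
//! memory.store_pattern(fingerprint, strategy);
//! ```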

use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};

/// Optimization strategies learned by the neural network
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationStrategy {
    /// Row-wise processing with cache optimization
    RowWiseCache,
    /// Column-wise processing for memory locality
    ColumnWiseLocality,
    /// Block-based processing for structured matrices
    BlockStructured,
    /// Diagonal-optimized processing
    DiagonalOptimized,
    /// Hierarchical decomposition
    Hierarchical,
    /// Streaming computation for large matrices
    StreamingCompute,
    /// SIMD-vectorized computation
    SIMDVectorized,
    /// Parallel work-stealing
    ParallelWorkStealing,
    /// Adaptive hybrid approach
    AdaptiveHybrid,
}

/// Pattern memory for learning matrix characteristics
#[derive(Debug)]
pub(crate) struct PatternMemory {
    pub matrix_patterns: HashMap<MatrixFingerprint, OptimizationStrategy>,
    #[allow(dead_code)]
    pub access_patterns: VecDeque<AccessPattern>,
    #[allow(dead_code)]
    pub performance_cache: HashMap<String, f64>,
}

/// Matrix fingerprint for pattern recognition
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct MatrixFingerprint {
    pub rows: usize,
    pub cols: usize,
    pub nnz: usize,
    pub sparsity_pattern_hash: u64,
    pub row_distribution_type: DistributionType,
    pub column_distribution_type: DistributionType,
}

/// Distribution types for sparsity patterns
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum DistributionType {
    Uniform,
    Clustered,
    BandDiagonal,
    #[allow(dead_code)]
    BlockStructured,
    Random,
    PowerLaw,
}

/// Access pattern for memory optimization
#[derive(Debug, Clone)]
pub(crate) struct AccessPattern {
    #[allow(dead_code)]
    pub timestamp: u64,
    #[allow(dead_code)]
    pub row_sequence: Vec<usize>,
    #[allow(dead_code)]
    pub column_sequence: Vec<usize>,
    #[allow(dead_code)]
    pub cache_hits: usize,
    #[allow(dead_code)]
    pub cache_misses: usize,
}

impl PatternMemory {
    /// Create a new pattern memory system with room for `capacity` recorded
    /// access patterns
    pub fn new(capacity: usize) -> Self {
        Self {
            matrix_patterns: HashMap::new(),
            access_patterns: VecDeque::with_capacity(capacity),
            performance_cache: HashMap::new(),
        }
    }

    /// Store a learned pattern-strategy mapping
    pub fn store_pattern(
        &mut self,
        fingerprint: MatrixFingerprint,
        strategy: OptimizationStrategy,
    ) {
        self.matrix_patterns.insert(fingerprint, strategy);
    }

    /// Retrieve optimal strategy for a matrix pattern
    pub fn get_strategy(&self, fingerprint: &MatrixFingerprint) -> Option<OptimizationStrategy> {
        self.matrix_patterns.get(fingerprint).copied()
    }

    /// Find similar patterns using fingerprint matching
    pub fn find_similar_patterns(
        &self,
        fingerprint: &MatrixFingerprint,
        similarity_threshold: f64,
    ) -> Vec<(MatrixFingerprint, OptimizationStrategy)> {
        // Score every stored pattern once, keeping those above the threshold
        let mut scored: Vec<(MatrixFingerprint, OptimizationStrategy, f64)> = self
            .matrix_patterns
            .iter()
            .filter_map(|(stored_fingerprint, strategy)| {
                let similarity = self.compute_similarity(fingerprint, stored_fingerprint);
                (similarity >= similarity_threshold)
                    .then(|| (stored_fingerprint.clone(), *strategy, similarity))
            })
            .collect();

        // Sort by similarity (descending); total_cmp avoids a panic on NaN
        scored.sort_by(|a, b| b.2.total_cmp(&a.2));

        scored
            .into_iter()
            .map(|(fp, strategy, _)| (fp, strategy))
            .collect()
    }

    /// Compute similarity between two matrix fingerprints
    fn compute_similarity(&self, fp1: &MatrixFingerprint, fp2: &MatrixFingerprint) -> f64 {
        let size_similarity = self.size_similarity(fp1, fp2);
        let sparsity_similarity = self.sparsity_similarity(fp1, fp2);
        let pattern_similarity = self.pattern_similarity(fp1, fp2);
        let distribution_similarity = self.distribution_similarity(fp1, fp2);

        // Weighted combination of different similarity measures
        0.3 * size_similarity
            + 0.3 * sparsity_similarity
            + 0.2 * pattern_similarity
            + 0.2 * distribution_similarity
    }
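
    // Worked example of the weighting above (illustrative): two fingerprints
    // that agree on sparsity, pattern hash, and both distribution types, but
    // where one matrix has half as many rows, score
    //     0.3 * 0.75 + 0.3 * 1.0 + 0.2 * 1.0 + 0.2 * 1.0 = 0.925
    // since size_similarity averages a row ratio of 0.5 and a column ratio of 1.0.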

    /// Compute size similarity between matrices
    fn size_similarity(&self, fp1: &MatrixFingerprint, fp2: &MatrixFingerprint) -> f64 {
        let row_ratio = (fp1.rows.min(fp2.rows) as f64) / (fp1.rows.max(fp2.rows) as f64);
        let col_ratio = (fp1.cols.min(fp2.cols) as f64) / (fp1.cols.max(fp2.cols) as f64);
        (row_ratio + col_ratio) / 2.0
    }

    /// Compute sparsity similarity
    fn sparsity_similarity(&self, fp1: &MatrixFingerprint, fp2: &MatrixFingerprint) -> f64 {
        let sparsity1 = fp1.nnz as f64 / (fp1.rows * fp1.cols) as f64;
        let sparsity2 = fp2.nnz as f64 / (fp2.rows * fp2.cols) as f64;
        1.0 - (sparsity1 - sparsity2).abs()
    }

    /// Compute pattern similarity using hash comparison
    fn pattern_similarity(&self, fp1: &MatrixFingerprint, fp2: &MatrixFingerprint) -> f64 {
        // Simple hash-based similarity (in practice, you might use more sophisticated methods)
        let hash_diff = (fp1.sparsity_pattern_hash ^ fp2.sparsity_pattern_hash).count_ones() as f64;
        1.0 - (hash_diff / 64.0) // Assuming 64-bit hash
    }

    /// Compute distribution similarity
    fn distribution_similarity(&self, fp1: &MatrixFingerprint, fp2: &MatrixFingerprint) -> f64 {
        let row_match = if fp1.row_distribution_type == fp2.row_distribution_type {
            1.0
        } else {
            0.0
        };
        let col_match = if fp1.column_distribution_type == fp2.column_distribution_type {
            1.0
        } else {
            0.0
        };
        (row_match + col_match) / 2.0
    }

    /// Record access pattern for learning
    pub fn record_access_pattern(&mut self, pattern: AccessPattern) {
        self.access_patterns.push_back(pattern);

        // Keep only recent patterns (sliding window)
        if self.access_patterns.len() > 1000 {
            self.access_patterns.pop_front();
        }
    }

    /// Cache performance result
    pub fn cache_performance(&mut self, key: String, performance: f64) {
        self.performance_cache.insert(key, performance);

        // Limit cache size by evicting an arbitrary batch of entries
        // (HashMap iteration order is unspecified, so these are not
        // necessarily the oldest)
        if self.performance_cache.len() > 10000 {
            let keys_to_remove: Vec<String> =
                self.performance_cache.keys().take(1000).cloned().collect();
            for key in keys_to_remove {
                self.performance_cache.remove(&key);
            }
        }
    }

    /// Get cached performance
    pub fn get_cached_performance(&self, key: &str) -> Option<f64> {
        self.performance_cache.get(key).copied()
    }

    /// Analyze access patterns to detect trends
    pub fn analyze_access_patterns(&self) -> AccessPatternAnalysis {
        if self.access_patterns.is_empty() {
            return AccessPatternAnalysis::default();
        }

        let mut sequential_count = 0;
        let mut random_count = 0;
        let mut block_count = 0;

        for pattern in &self.access_patterns {
            let access_type = self.classify_access_pattern(pattern);
            match access_type {
                AccessType::Sequential => sequential_count += 1,
                AccessType::Random => random_count += 1,
                AccessType::Block => block_count += 1,
            }
        }

        let total = self.access_patterns.len();
        AccessPatternAnalysis {
            sequential_ratio: sequential_count as f64 / total as f64,
            random_ratio: random_count as f64 / total as f64,
            block_ratio: block_count as f64 / total as f64,
            cache_hit_rate: self.compute_average_cache_hit_rate(),
        }
    }

    /// Classify access pattern type
    fn classify_access_pattern(&self, pattern: &AccessPattern) -> AccessType {
        // Simplified classification logic
        if pattern.row_sequence.is_empty() {
            return AccessType::Random;
        }

        // Sequential access: each row index repeats or advances by one
        let sequential = pattern
            .row_sequence
            .windows(2)
            .all(|w| w[1] == w[0] || w[1] == w[0] + 1);
        if sequential {
            return AccessType::Sequential;
        }

        // Block access: many repeated rows relative to the sequence length
        let unique_rows: std::collections::HashSet<_> = pattern.row_sequence.iter().collect();
        if unique_rows.len() < pattern.row_sequence.len() / 2 {
            return AccessType::Block;
        }

        AccessType::Random
    }

    /// Compute average cache hit rate
    fn compute_average_cache_hit_rate(&self) -> f64 {
        if self.access_patterns.is_empty() {
            return 0.0;
        }

        let total_hits: usize = self.access_patterns.iter().map(|p| p.cache_hits).sum();
        let total_accesses: usize = self
            .access_patterns
            .iter()
            .map(|p| p.cache_hits + p.cache_misses)
            .sum();

        if total_accesses == 0 {
            0.0
        } else {
            total_hits as f64 / total_accesses as f64
        }
    }

    /// Suggest optimization strategy based on patterns
    pub fn suggest_strategy(&self, fingerprint: &MatrixFingerprint) -> OptimizationStrategy {
        // First, try exact match
        if let Some(strategy) = self.get_strategy(fingerprint) {
            return strategy;
        }

        // Find similar patterns
        let similar = self.find_similar_patterns(fingerprint, 0.7);
        if !similar.is_empty() {
            return similar[0].1; // Return strategy from most similar pattern
        }

        // Fallback to heuristic-based suggestion
        self.heuristic_strategy_suggestion(fingerprint)
    }

    /// Heuristic-based strategy suggestion
    fn heuristic_strategy_suggestion(
        &self,
        fingerprint: &MatrixFingerprint,
    ) -> OptimizationStrategy {
        let sparsity = fingerprint.nnz as f64 / (fingerprint.rows * fingerprint.cols) as f64;
        let size = fingerprint.rows * fingerprint.cols;

        match (
            fingerprint.row_distribution_type,
            fingerprint.column_distribution_type,
        ) {
            (DistributionType::BandDiagonal, _) | (_, DistributionType::BandDiagonal) => {
                OptimizationStrategy::DiagonalOptimized
            }
            (DistributionType::Clustered, DistributionType::Clustered) => {
                OptimizationStrategy::BlockStructured
            }
            _ => {
                if sparsity < 0.01 && size > 10000 {
                    OptimizationStrategy::StreamingCompute
                } else if size > 100000 {
                    OptimizationStrategy::ParallelWorkStealing
                } else if sparsity > 0.1 {
                    OptimizationStrategy::SIMDVectorized
                } else {
                    OptimizationStrategy::AdaptiveHybrid
                }
            }
        }
    }
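
    // Illustrative reading of the heuristics above: a 10_000 x 10_000 matrix
    // at 0.5% density falls through to `StreamingCompute`, while a 100 x 100
    // matrix at 20% density is steered to `SIMDVectorized`.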

    /// Get memory statistics
    pub fn get_statistics(&self) -> PatternMemoryStats {
        PatternMemoryStats {
            stored_patterns: self.matrix_patterns.len(),
            access_patterns_recorded: self.access_patterns.len(),
            cached_performances: self.performance_cache.len(),
            most_common_strategy: self.get_most_common_strategy(),
        }
    }

    /// Get most commonly used optimization strategy
    fn get_most_common_strategy(&self) -> Option<OptimizationStrategy> {
        let mut strategy_counts = HashMap::new();

        for strategy in self.matrix_patterns.values() {
            *strategy_counts.entry(*strategy).or_insert(0) += 1;
        }

        strategy_counts
            .into_iter()
            .max_by_key(|&(_, count)| count)
            .map(|(strategy, _)| strategy)
    }
}

impl MatrixFingerprint {
    /// Create a new matrix fingerprint from COO-style triplet data
    ///
    /// Only the length of `data` is used, so no bounds are required on `T`,
    /// and the index arrays are borrowed rather than consumed.
    pub fn new<T>(rows: &[usize], cols: &[usize], data: &[T], shape: (usize, usize)) -> Self {
        let nnz = data.len();
        let sparsity_pattern_hash = Self::compute_pattern_hash(rows, cols);
        let row_distribution_type = Self::analyze_distribution(rows, shape.0);
        let column_distribution_type = Self::analyze_distribution(cols, shape.1);

        Self {
            rows: shape.0,
            cols: shape.1,
            nnz,
            sparsity_pattern_hash,
            row_distribution_type,
            column_distribution_type,
        }
    }

    /// Compute hash of sparsity pattern
    fn compute_pattern_hash(rows: &[usize], cols: &[usize]) -> u64 {
        use std::collections::hash_map::DefaultHasher;

        let mut hasher = DefaultHasher::new();

        // Sample pattern for hash computation (to avoid excessive computation)
        let step = (rows.len() / 100).max(1);
        for i in (0..rows.len()).step_by(step) {
            rows[i].hash(&mut hasher);
            if i < cols.len() {
                cols[i].hash(&mut hasher);
            }
        }

        hasher.finish()
    }

    /// Analyze distribution type of indices
    fn analyze_distribution(indices: &[usize], max_value: usize) -> DistributionType {
        if indices.is_empty() {
            return DistributionType::Uniform;
        }

        // Check for band diagonal pattern
        if Self::is_band_diagonal(indices) {
            return DistributionType::BandDiagonal;
        }

        // Check for clustering
        if Self::is_clustered(indices, max_value) {
            return DistributionType::Clustered;
        }

        // Check for uniform distribution
        if Self::is_uniform(indices, max_value) {
            return DistributionType::Uniform;
        }

        // Check for power law distribution
        if Self::is_power_law(indices) {
            return DistributionType::PowerLaw;
        }

        DistributionType::Random
    }

    /// Check if indices follow a band diagonal pattern
    fn is_band_diagonal(indices: &[usize]) -> bool {
        if indices.len() < 2 {
            return false;
        }

        let mut sorted_indices = indices.to_vec();
        sorted_indices.sort_unstable();

        // Check if indices are within a small range (band)
        let range = sorted_indices[sorted_indices.len() - 1] - sorted_indices[0];
        let density = indices.len() as f64 / (range + 1) as f64;

        density > 0.5 && range < indices.len() * 3
    }
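
    // Worked example for `is_band_diagonal` (illustrative): indices [3, 4, 5, 6]
    // span a range of 3 with density 4 / (3 + 1) = 1.0, so they classify as a
    // band; the widely spread indices [0, 500, 1000] do not.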

    /// Check if indices are clustered
    fn is_clustered(indices: &[usize], max_value: usize) -> bool {
        if indices.is_empty() {
            return false;
        }

        let mut histogram = vec![0; (max_value / 10).max(10)];
        for &idx in indices {
            let bucket = (idx * histogram.len()) / (max_value + 1);
            if bucket < histogram.len() {
                histogram[bucket] += 1;
            }
        }

        // Check if most values are in few buckets
        histogram.sort_unstable();
        let top_buckets = histogram.len() / 3;
        let top_count: usize = histogram.iter().rev().take(top_buckets).sum();
        let total_count: usize = histogram.iter().sum();

        top_count as f64 / total_count as f64 > 0.7
    }

    /// Check if indices are uniformly distributed
    fn is_uniform(indices: &[usize], max_value: usize) -> bool {
        if indices.is_empty() {
            return false;
        }

        let bucket_count = (max_value / 10).max(10);
        let mut histogram = vec![0; bucket_count];

        for &idx in indices {
            let bucket = (idx * bucket_count) / (max_value + 1);
            if bucket < histogram.len() {
                histogram[bucket] += 1;
            }
        }

        // Check variance of histogram
        let mean = indices.len() as f64 / bucket_count as f64;
        let variance: f64 = histogram
            .iter()
            .map(|&count| (count as f64 - mean).powi(2))
            .sum::<f64>()
            / bucket_count as f64;
        let std_dev = variance.sqrt();

        std_dev / mean < 0.5 // Low relative variance indicates uniformity
    }

    /// Check if indices follow a power law distribution
    fn is_power_law(indices: &[usize]) -> bool {
        if indices.is_empty() {
            return false;
        }

        // Simple power law check: few values appear very frequently
        let mut frequency_map = HashMap::new();
        for &idx in indices {
            *frequency_map.entry(idx).or_insert(0) += 1;
        }

        let mut frequencies: Vec<usize> = frequency_map.values().copied().collect();
        frequencies.sort_unstable();
        frequencies.reverse();

        if frequencies.len() < 3 {
            return false;
        }

        // Check if top few frequencies dominate
        let top_10_percent = (frequencies.len() / 10).max(1);
        let top_sum: usize = frequencies.iter().take(top_10_percent).sum();
        let total_sum: usize = frequencies.iter().sum();

        top_sum as f64 / total_sum as f64 > 0.8
    }
}

/// Access pattern classification
#[derive(Debug, Clone, Copy)]
enum AccessType {
    Sequential,
    Random,
    Block,
}

/// Analysis results for access patterns
#[derive(Debug, Clone, Default)]
pub struct AccessPatternAnalysis {
    pub sequential_ratio: f64,
    pub random_ratio: f64,
    pub block_ratio: f64,
    pub cache_hit_rate: f64,
}

/// Statistics for pattern memory
#[derive(Debug, Clone)]
pub struct PatternMemoryStats {
    pub stored_patterns: usize,
    pub access_patterns_recorded: usize,
    pub cached_performances: usize,
    pub most_common_strategy: Option<OptimizationStrategy>,
}
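
// A few smoke tests sketching expected behavior (an illustrative, non-exhaustive
// suite; the 10x10 band-diagonal fingerprint and the strategy choices are made
// up for demonstration).
#[cfg(test)]
mod tests {
    use super::*;

    fn diagonal_fingerprint() -> MatrixFingerprint {
        // 10x10 matrix with nonzeros on the main diagonal
        let indices: Vec<usize> = (0..10).collect();
        let data = vec![1.0f64; 10];
        MatrixFingerprint::new(&indices, &indices, &data, (10, 10))
    }

    #[test]
    fn diagonal_pattern_is_detected_as_band() {
        let fp = diagonal_fingerprint();
        assert_eq!(fp.row_distribution_type, DistributionType::BandDiagonal);
        assert_eq!(fp.nnz, 10);
    }

    #[test]
    fn stored_strategy_round_trips() {
        let mut memory = PatternMemory::new(16);
        let fp = diagonal_fingerprint();
        memory.store_pattern(fp.clone(), OptimizationStrategy::RowWiseCache);
        assert_eq!(
            memory.get_strategy(&fp),
            Some(OptimizationStrategy::RowWiseCache)
        );
    }

    #[test]
    fn heuristic_falls_back_to_diagonal_strategy() {
        // With nothing stored, suggest_strategy reaches the heuristic, which
        // maps band-diagonal distributions to DiagonalOptimized
        let memory = PatternMemory::new(16);
        let fp = diagonal_fingerprint();
        assert_eq!(
            memory.suggest_strategy(&fp),
            OptimizationStrategy::DiagonalOptimized
        );
    }
}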