scirs2_core/memory_efficient/
pattern_recognition.rs

1//! Advanced pattern recognition for memory access patterns.
2//!
3//! This module provides specialized algorithms for detecting complex access patterns
4//! that are common in scientific computing workloads, such as:
5//! - Diagonal traversals
6//! - Block-based accesses
7//! - Stencil operations
8//! - Strided matrix operations
9//! - Custom patterns defined by mathematical functions
10
11use std::collections::{HashMap, HashSet, VecDeque};
12use std::time::Instant;
13
14use super::prefetch::AccessPattern;
15
/// Higher-level access patterns that can be recognized on top of the basic
/// sequential / strided / random classification.
#[derive(Debug, Clone, PartialEq)]
pub enum ComplexPattern {
    /// Row-by-row traversal of a matrix.
    RowMajor,
    /// Column-by-column traversal of a matrix.
    ColumnMajor,
    /// Row traversal whose direction alternates from one row to the next.
    Zigzag,
    /// Walk along the main diagonal (one row down, one column right per step).
    DiagonalMajor,
    /// Walk along the anti-diagonal (one row down, one column left per step).
    DiagonalMinor,
    /// Tile-by-tile traversal, as used by blocked/tiled algorithms.
    Block { block_height: usize, block_width: usize },
    /// Fixed-stride access inside blocks of `block_size` elements.
    BlockStrided { block_size: usize, stride: usize },
    /// Stencil sweep: a center point visited together with its neighbours.
    Stencil { dimensions: usize, radius: usize },
    /// Blocks visited in a rotating order (e.g. for a blocked transposition).
    RotatingBlock { block_size: usize },
    /// Scattered access; `density` is the fraction of the observed index range touched.
    Sparse { density: f64 },
    /// Hierarchical traversal such as a Z-order curve.
    Hierarchical { levels: usize },
    /// Custom pattern identified by name.
    Custom(String),
}
58
/// Confidence level for pattern detection.
///
/// Variants are declared from weakest to strongest so that the derived
/// `PartialOrd`/`Ord` rank them naturally: `Tentative < Low < Medium < High`.
/// Keep this order — comparisons such as `confidence >= Confidence::Medium`
/// and `max_by_key(|p| p.confidence)` rely on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Confidence {
    /// Pattern is only tentatively detected
    Tentative,

    /// Pattern might be detected
    Low,

    /// Pattern is probably detected
    Medium,

    /// Pattern is definitely detected
    High,
}
74
75/// A recognized pattern with metadata.
76#[derive(Debug, Clone)]
77pub struct RecognizedPattern {
78    /// The type of pattern
79    pub pattern_type: ComplexPattern,
80
81    /// Confidence in the pattern detection
82    pub confidence: Confidence,
83
84    /// Additional metadata about the pattern
85    pub metadata: HashMap<String, String>,
86
87    /// When the pattern was first detected
88    pub first_detected: Instant,
89
90    /// When the pattern was last confirmed
91    pub last_confirmed: Instant,
92
93    /// Number of times the pattern has been confirmed
94    pub confirmation_count: usize,
95}
96
97impl RecognizedPattern {
98    /// Create a new recognized pattern.
99    pub fn new(patterntype: ComplexPattern, confidence: Confidence) -> Self {
100        let now = Instant::now();
101        Self {
102            pattern_type: patterntype,
103            confidence,
104            metadata: HashMap::new(),
105            first_detected: now,
106            last_confirmed: now,
107            confirmation_count: 1,
108        }
109    }
110
111    /// Add metadata to the pattern.
112    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
113        self.metadata.insert(key.to_string(), value.to_string());
114        self
115    }
116
117    /// Confirm the pattern, increasing confidence.
118    pub fn confirm(&mut self) {
119        self.confirmation_count += 1;
120        self.last_confirmed = Instant::now();
121
122        // Increase confidence based on confirmation count
123        if self.confirmation_count >= 10 {
124            self.confidence = Confidence::High;
125        } else if self.confirmation_count >= 5 {
126            self.confidence = Confidence::Medium;
127        } else if self.confirmation_count >= 2 {
128            self.confidence = Confidence::Low;
129        }
130    }
131
132    /// Check if the pattern is still valid.
133    pub fn is_valid(&self, maxage: std::time::Duration) -> bool {
134        self.last_confirmed.elapsed() <= maxage
135    }
136}
137
/// Tuning knobs for the pattern recognizer.
#[derive(Debug, Clone)]
pub struct PatternRecognitionConfig {
    /// Number of buffered accesses required before detection runs
    pub min_history_size: usize,

    /// How long a pattern stays valid after its last confirmation
    pub pattern_expiry: std::time::Duration,

    /// Enable main-/anti-diagonal detection
    pub detect_diagonal: bool,

    /// Enable block (tiled) detection
    pub detect_block: bool,

    /// Enable stencil detection
    pub detect_stencil: bool,

    /// Enable sparse-access detection
    pub detect_sparse: bool,

    /// Enable machine-learning based detection
    pub use_machine_learning: bool,
}

impl Default for PatternRecognitionConfig {
    /// 20-sample minimum history, 60 s expiry, every heuristic detector
    /// enabled, machine learning disabled.
    fn default() -> Self {
        let heuristics_enabled = true;
        Self {
            min_history_size: 20,
            pattern_expiry: std::time::Duration::from_secs(60),
            detect_diagonal: heuristics_enabled,
            detect_block: heuristics_enabled,
            detect_stencil: heuristics_enabled,
            detect_sparse: heuristics_enabled,
            // Off by default because it requires additional dependencies.
            use_machine_learning: false,
        }
    }
}
176
177/// Pattern recognition engine for complex access patterns.
178#[derive(Debug)]
179pub struct PatternRecognizer {
180    /// Configuration for the recognizer
181    config: PatternRecognitionConfig,
182
183    /// Dimensions of the array
184    dimensions: Option<Vec<usize>>,
185
186    /// History of accessed indices
187    history: VecDeque<usize>,
188
189    /// Recognized patterns
190    patterns: Vec<RecognizedPattern>,
191
192    /// Most recently detected basic pattern
193    basic_pattern: AccessPattern,
194}
195
196impl PatternRecognizer {
197    /// Create a new pattern recognizer.
198    pub fn new(config: PatternRecognitionConfig) -> Self {
199        Self {
200            config,
201            dimensions: None,
202            history: VecDeque::with_capacity(100),
203            patterns: Vec::new(),
204            basic_pattern: AccessPattern::Random,
205        }
206    }
207
208    /// Set the dimensions of the array.
209    pub fn set_dimensions(&mut self, dimensions: Vec<usize>) {
210        self.dimensions = Some(dimensions);
211    }
212
213    /// Add a new access to the history.
214    pub fn record_access(&mut self, index: usize) {
215        self.history.push_back(index);
216
217        // Limit history size
218        while self.history.len() > 100 {
219            self.history.pop_front();
220        }
221
222        // Only try to detect patterns if we have enough history
223        if self.history.len() >= self.config.min_history_size {
224            self.detect_patterns();
225        }
226    }
227
228    /// Detect patterns in the current history.
229    fn detect_patterns(&mut self) {
230        // Remove expired patterns
231        self.patterns
232            .retain(|pattern| pattern.is_valid(self.config.pattern_expiry));
233
234        // Basic patterns
235        self.detect_basic_patterns();
236
237        // Complex patterns based on dimensions
238        if let Some(dims) = self.dimensions.clone() {
239            // Detect matrix traversal patterns
240            if dims.len() >= 2 {
241                self.detectmatrix_patterns(&dims);
242            }
243
244            // Detect block patterns
245            if self.config.detect_block && dims.len() >= 2 {
246                self.detect_block_patterns(&dims);
247            }
248
249            // Detect diagonal patterns
250            if self.config.detect_diagonal && dims.len() == 2 {
251                self.detect_diagonal_patterns(&dims);
252            }
253
254            // Detect stencil patterns
255            if self.config.detect_stencil && dims.len() >= 2 {
256                self.detect_stencil_patterns(&dims);
257            }
258        }
259
260        // Detect sparse patterns
261        if self.config.detect_sparse {
262            self.detect_sparse_pattern();
263        }
264    }
265
266    /// Detect basic sequential and strided patterns.
267    fn detect_basic_patterns(&mut self) {
268        let indices: Vec<_> = self.history.iter().cloned().collect();
269
270        // Check for sequential access
271        let mut sequential_count = 0;
272        for i in 1..indices.len() {
273            if indices[i] == indices[i.saturating_sub(1)] + 1 {
274                sequential_count += 1;
275            }
276        }
277
278        if sequential_count >= indices.len() * 3 / 4 {
279            self.basic_pattern = AccessPattern::Sequential;
280
281            // Check for row-major pattern if dimensions are known
282            if let Some(ref dims) = self.dimensions {
283                if dims.len() >= 2 {
284                    let row_size = dims[1];
285                    let pattern = ComplexPattern::RowMajor;
286
287                    // Check if we already have this pattern
288                    if let Some(existing) = self.find_pattern(&pattern) {
289                        // Confirm existing pattern
290                        existing.confirm();
291                    } else {
292                        // Add new pattern
293                        let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
294                            .with_metadata("row_size", &row_size.to_string());
295                        self.patterns.push(pattern);
296                    }
297                }
298            }
299
300            return;
301        }
302
303        // Check for strided access - try different strides
304        let mut best_stride = 0;
305        let mut best_stride_count = 0;
306
307        for stride in 2..=20 {
308            let mut stride_count = 0;
309            for i in 1..indices.len() {
310                if indices[i].saturating_sub(indices[i.saturating_sub(1)]) == stride {
311                    stride_count += 1;
312                }
313            }
314
315            if stride_count > best_stride_count {
316                best_stride_count = stride_count;
317                best_stride = stride;
318            }
319        }
320
321        if best_stride_count >= indices.len() * 2 / 3 {
322            self.basic_pattern = AccessPattern::Strided(best_stride);
323
324            // Check for column-major pattern if dimensions are known
325            if let Some(ref dims) = self.dimensions {
326                if dims.len() >= 2 {
327                    let num_rows = dims[0];
328
329                    if best_stride == num_rows {
330                        let pattern = ComplexPattern::ColumnMajor;
331
332                        // Check if we already have this pattern
333                        if let Some(existing) = self.find_pattern(&pattern) {
334                            // Confirm existing pattern
335                            existing.confirm();
336                        } else {
337                            // Add new pattern
338                            let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
339                                .with_metadata("num_rows", &num_rows.to_string());
340                            self.patterns.push(pattern);
341                        }
342                    }
343                }
344            }
345
346            return;
347        }
348
349        // No simple pattern detected
350        self.basic_pattern = AccessPattern::Random;
351    }
352
353    /// Detect matrix traversal patterns.
354    fn detectmatrix_patterns(&mut self, dimensions: &[usize]) {
355        if dimensions.len() < 2 {
356            return;
357        }
358
359        let _rows = dimensions[0];
360        let cols = dimensions[1];
361        let indices: Vec<_> = self.history.iter().cloned().collect();
362
363        // Check for zigzag pattern - alternating left-right traversal within rows
364        let mut zigzag_evidence = 0;
365        let mut last_row_direction = None;
366
367        // Group indices by rows and preserve access order
368        let mut rows: HashMap<usize, Vec<(usize, usize)>> = HashMap::new(); // (col_index, access_order)
369        for (access_order, &idx) in indices.iter().enumerate() {
370            let row = idx / cols;
371            let col = idx % cols;
372            rows.entry(row).or_default().push((col, access_order));
373        }
374
375        // Check if consecutive rows alternate direction based on access order
376        let sorted_rows: Vec<_> = {
377            let mut sorted = rows.keys().cloned().collect::<Vec<_>>();
378            sorted.sort();
379            sorted
380        };
381
382        for row_num in &sorted_rows {
383            let mut cols_in_row = rows[row_num].clone();
384            if cols_in_row.len() >= 2 {
385                // Sort by access order to see the actual traversal pattern
386                cols_in_row.sort_by_key(|(_, access_order)| *access_order);
387
388                // Determine direction within this row based on column progression
389                // Check if columns are accessed in increasing or decreasing order
390                let mut increasing = 0;
391                let mut decreasing = 0;
392                for i in 1..cols_in_row.len() {
393                    match cols_in_row[i].0.cmp(&cols_in_row[i.saturating_sub(1)].0) {
394                        std::cmp::Ordering::Greater => increasing += 1,
395                        std::cmp::Ordering::Less => decreasing += 1,
396                        std::cmp::Ordering::Equal => {}
397                    }
398                }
399
400                let current_direction = match increasing.cmp(&decreasing) {
401                    std::cmp::Ordering::Greater => 1, // Left to right
402                    std::cmp::Ordering::Less => -1,   // Right to left
403                    std::cmp::Ordering::Equal => 0,   // No clear direction
404                };
405
406                // Check if direction alternates from previous row
407                if current_direction != 0 {
408                    if let Some(prev_direction) = last_row_direction {
409                        if current_direction != prev_direction && prev_direction != 0 {
410                            zigzag_evidence += 1;
411                        }
412                    }
413                    last_row_direction = Some(current_direction);
414                }
415            }
416        }
417
418        // To confirm zigzag, we need at least 2 direction changes (3 rows minimum)
419        // Also ensure we have seen enough rows to make this determination
420        if zigzag_evidence >= 2 && sorted_rows.len() >= 3 {
421            let pattern = ComplexPattern::Zigzag;
422
423            // Check if we already have this pattern
424            if let Some(existing) = self.find_pattern(&pattern) {
425                // Confirm existing pattern
426                existing.confirm();
427            } else {
428                // Add new pattern
429                let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
430                    .with_metadata("zigzag_evidence", &zigzag_evidence.to_string());
431                self.patterns.push(pattern);
432            }
433        }
434    }
435
436    /// Detect diagonal traversal patterns.
437    fn detect_diagonal_patterns(&mut self, dimensions: &[usize]) {
438        if dimensions.len() != 2 {
439            return;
440        }
441
442        let _rows = dimensions[0];
443        let cols = dimensions[1];
444        let indices: Vec<_> = self.history.iter().cloned().collect();
445
446        // Check for main diagonal traversal
447        let mut diagonal_matches = 0;
448        for i in 1..indices.len() {
449            let prev_idx = indices[i.saturating_sub(1)];
450            let curr_idx = indices[i];
451
452            let prev_row = prev_idx / cols;
453            let prev_col = prev_idx % cols;
454
455            let curr_row = curr_idx / cols;
456            let curr_col = curr_idx % cols;
457
458            // Check if moving along the main diagonal (row+1, col+1)
459            if curr_row == prev_row + 1 && curr_col == prev_col + 1 {
460                diagonal_matches += 1;
461            }
462        }
463
464        // Need a significant portion of transitions to be diagonal
465        // For consecutive diagonal accesses, we expect (n-1) diagonal transitions
466        let expected_transitions = indices.len().saturating_sub(1);
467        // Lower threshold: at least 1/3 of transitions or at least 3 diagonal matches
468        if (diagonal_matches >= expected_transitions / 3 || diagonal_matches >= 3)
469            && diagonal_matches > 0
470        {
471            let pattern = ComplexPattern::DiagonalMajor;
472
473            // Check if we already have this pattern
474            if let Some(existing) = self.find_pattern(&pattern) {
475                // Confirm existing pattern
476                existing.confirm();
477            } else {
478                // Add new pattern
479                let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
480                    .with_metadata("diagonal_matches", &diagonal_matches.to_string());
481                self.patterns.push(pattern);
482            }
483
484            return;
485        }
486
487        // Check for anti-diagonal traversal
488        let mut anti_diagonal_matches = 0;
489        for i in 1..indices.len() {
490            let prev_idx = indices[i.saturating_sub(1)];
491            let curr_idx = indices[i];
492
493            let prev_row = prev_idx / cols;
494            let prev_col = prev_idx % cols;
495
496            let curr_row = curr_idx / cols;
497            let curr_col = curr_idx % cols;
498
499            // Check if moving along the anti-diagonal (row+1, col-1)
500            if curr_row == prev_row + 1 && curr_col + 1 == prev_col {
501                anti_diagonal_matches += 1;
502            }
503        }
504
505        // Need a significant portion of transitions to be anti-diagonal
506        let expected_transitions = indices.len().saturating_sub(1);
507        // Lower threshold: at least 1/3 of transitions or at least 3 anti-diagonal matches
508        if (anti_diagonal_matches >= expected_transitions / 3 || anti_diagonal_matches >= 3)
509            && anti_diagonal_matches > 0
510        {
511            let pattern = ComplexPattern::DiagonalMinor;
512
513            // Check if we already have this pattern
514            if let Some(existing) = self.find_pattern(&pattern) {
515                // Confirm existing pattern
516                existing.confirm();
517            } else {
518                // Add new pattern
519                let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
520                    .with_metadata("anti_diagonal_matches", &anti_diagonal_matches.to_string());
521                self.patterns.push(pattern);
522            }
523        }
524    }
525
526    /// Detect block-based access patterns.
527    fn detect_block_patterns(&mut self, dimensions: &[usize]) {
528        if dimensions.len() < 2 {
529            return;
530        }
531
532        let rows = dimensions[0];
533        let cols = dimensions[1];
534        let indices: Vec<_> = self.history.iter().cloned().collect();
535
536        // Try different block sizes
537        let block_sizes_to_try = [
538            (2, 2),
539            (4, 4),
540            (8, 8),
541            (16, 16),
542            (32, 32),
543            (64, 64),
544            (rows, 4),
545            (4, cols),
546        ];
547
548        for &(block_height, block_width) in &block_sizes_to_try {
549            // Skip invalid block sizes
550            if block_height > rows || block_width > cols {
551                continue;
552            }
553
554            let mut block_accesses = HashMap::new();
555
556            // Group accesses by block
557            for &idx in &indices {
558                let row = idx / cols;
559                let col = idx % cols;
560
561                let block_row = row / block_height;
562                let block_col = col / block_width;
563
564                let block_id = (block_row, block_col);
565                let entry: &mut Vec<usize> = block_accesses.entry(block_id).or_default();
566                entry.push(idx);
567            }
568
569            // Check for complete blocks (where all elements in the block are accessed)
570            let mut complete_blocks = 0;
571            for accesses in block_accesses.values() {
572                if accesses.len() == block_height * block_width {
573                    complete_blocks += 1;
574                }
575            }
576
577            // Check if we have evidence of block-based access
578            if complete_blocks >= 2 && block_accesses.len() <= 10 {
579                let pattern = ComplexPattern::Block {
580                    block_height,
581                    block_width,
582                };
583
584                // Check if we already have this pattern
585                if let Some(existing) = self.find_pattern(&pattern) {
586                    // Confirm existing pattern
587                    existing.confirm();
588                } else {
589                    // Add new pattern
590                    let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
591                        .with_metadata("complete_blocks", &complete_blocks.to_string())
592                        .with_metadata("total_blocks", &block_accesses.len().to_string());
593                    self.patterns.push(pattern);
594                }
595            }
596        }
597
598        // Check for strided access within blocks
599        let mut block_strides = HashMap::new();
600
601        // Group consecutive accesses by stride
602        for i in 1..indices.len() {
603            let stride = indices[i].saturating_sub(indices[i.saturating_sub(1)]);
604            *block_strides.entry(stride).or_insert(0) += 1;
605        }
606
607        // Find the most common stride
608        if let Some((&stride, &count)) = block_strides.iter().max_by_key(|(_, &count)| count) {
609            if count >= indices.len() / 3 && stride > 1 {
610                // Try to determine if this is within-block striding
611                let possible_block_sizes = [8, 16, 32, 64, 128];
612
613                for &block_size in &possible_block_sizes {
614                    if stride < block_size && block_size % stride == 0 {
615                        let pattern = ComplexPattern::BlockStrided { block_size, stride };
616
617                        // Check if we already have this pattern
618                        if let Some(existing) = self.find_pattern(&pattern) {
619                            // Confirm existing pattern
620                            existing.confirm();
621                        } else {
622                            // Add new pattern
623                            let pattern = RecognizedPattern::new(pattern, Confidence::Low)
624                                .with_metadata("stride_count", &count.to_string())
625                                .with_metadata(
626                                    "total_transitions",
627                                    &(indices.len() - 1).to_string(),
628                                );
629                            self.patterns.push(pattern);
630                        }
631
632                        break; // Only add one block stride pattern
633                    }
634                }
635            }
636        }
637    }
638
639    /// Detect stencil operation patterns.
640    fn detect_stencil_patterns(&mut self, dimensions: &[usize]) {
641        if dimensions.len() < 2 {
642            return;
643        }
644
645        let _rows = dimensions[0];
646        let cols = dimensions[1];
647        let indices: Vec<_> = self.history.iter().cloned().collect();
648
649        // Look for classic stencil patterns (5-point stencil)
650        // Pattern: center, then 4 neighbors (N, E, S, W)
651        let mut stencil_groups = 0;
652
653        // Look for groups of 5 consecutive accesses that form a stencil
654        for window_start in 0..indices.len().saturating_sub(4) {
655            if window_start + 4 >= indices.len() {
656                break;
657            }
658
659            let center_idx = indices[window_start];
660            let center_row = center_idx / cols;
661            let center_col = center_idx % cols;
662
663            // Check if the next 4 accesses are neighbors of the center
664            let mut neighbors_found = 0;
665            let expected_neighbors = [
666                center_idx.saturating_sub(cols), // North
667                center_idx + 1,                  // East
668                center_idx + cols,               // South
669                center_idx.saturating_sub(1),    // West
670            ];
671
672            for offset in 1..=4 {
673                if window_start + offset < indices.len() {
674                    let neighbor_idx = indices[window_start + offset];
675                    if expected_neighbors.contains(&neighbor_idx) {
676                        neighbors_found += 1;
677                    }
678                }
679            }
680
681            // If we found at least 3 of the 4 expected neighbors, count as stencil
682            if neighbors_found >= 3 {
683                stencil_groups += 1;
684            }
685        }
686
687        // If we found enough stencil patterns, recognize it
688        if stencil_groups >= 3 {
689            let pattern = ComplexPattern::Stencil {
690                dimensions: 2,
691                radius: 1,
692            };
693
694            // Check if we already have this pattern
695            if let Some(existing) = self.find_pattern(&pattern) {
696                // Confirm existing pattern
697                existing.confirm();
698            } else {
699                // Add new pattern
700                let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
701                    .with_metadata("stencil_groups", &stencil_groups.to_string());
702                self.patterns.push(pattern);
703            }
704        }
705    }
706
707    /// Detect sparse access patterns.
708    fn detect_sparse_pattern(&mut self) {
709        let indices: Vec<_> = self.history.iter().cloned().collect();
710
711        // Skip if history is too small
712        if indices.len() < 20 {
713            return;
714        }
715
716        // Estimate the total space from the max index
717        if let Some(&max_idx) = indices.iter().max() {
718            let unique_indices = indices.iter().collect::<HashSet<_>>().len();
719
720            // Calculate density: unique indices accessed / total space
721            let density = unique_indices as f64 / (max_idx + 1) as f64;
722
723            // If density is low, consider it sparse
724            if density < 0.1 {
725                let pattern = ComplexPattern::Sparse { density };
726
727                // Check if we already have this pattern
728                if let Some(existing) = self.find_pattern(&pattern) {
729                    // Confirm existing pattern
730                    existing.confirm();
731                } else {
732                    // Add new pattern
733                    let confidence = if density < 0.01 {
734                        Confidence::High
735                    } else if density < 0.05 {
736                        Confidence::Medium
737                    } else {
738                        Confidence::Low
739                    };
740
741                    let pattern = RecognizedPattern::new(pattern, confidence)
742                        .with_metadata("unique_indices", &unique_indices.to_string())
743                        .with_metadata("max_index", &max_idx.to_string())
744                        .with_metadata("density", &format!("{density:.6}"));
745                    self.patterns.push(pattern);
746                }
747            }
748        }
749    }
750
751    /// Find an existing pattern by type.
752    fn find_pattern(&mut self, patterntype: &ComplexPattern) -> Option<&mut RecognizedPattern> {
753        self.patterns
754            .iter_mut()
755            .find(|p| &p.pattern_type == patterntype)
756    }
757
758    /// Get all recognized patterns ordered by confidence.
759    pub fn get_patterns(&self) -> Vec<&RecognizedPattern> {
760        let mut patterns: Vec<_> = self.patterns.iter().collect();
761        patterns.sort_by(|a, b| b.confidence.cmp(&a.confidence));
762        patterns
763    }
764
765    /// Get the best pattern for prefetching.
766    pub fn get_best_pattern(&self) -> Option<&RecognizedPattern> {
767        self.patterns
768            .iter()
769            .filter(|p| p.confidence >= Confidence::Medium)
770            .max_by_key(|p| p.confidence)
771    }
772
773    /// Get the current basic access pattern.
774    pub fn get_basic_pattern(&self) -> AccessPattern {
775        self.basic_pattern
776    }
777
778    /// Clear all detected patterns and history.
779    pub fn clear(&mut self) {
780        self.history.clear();
781        self.patterns.clear();
782        self.basic_pattern = AccessPattern::Random;
783    }
784}
785
786/// Factory for creating pattern recognizers.
787#[allow(dead_code)]
788pub struct PatternRecognizerFactory;
789
790#[allow(dead_code)]
791impl PatternRecognizerFactory {
792    /// Create a new pattern recognizer with default configuration.
793    pub fn create() -> PatternRecognizer {
794        PatternRecognizer::new(PatternRecognitionConfig::default())
795    }
796
797    /// Create a new pattern recognizer with the specified configuration.
798    pub fn create_with_config(config: PatternRecognitionConfig) -> PatternRecognizer {
799        PatternRecognizer::new(config)
800    }
801}
802
803/// Helper functions for converting between pattern types.
804pub mod pattern_utils {
805    use super::*;
806    use crate::memory_efficient::prefetch::AccessPattern;
807
808    /// Convert from complex pattern to basic pattern.
809    #[allow(dead_code)]
810    pub fn to_basic_pattern(pattern: &ComplexPattern) -> AccessPattern {
811        match pattern {
812            ComplexPattern::RowMajor => AccessPattern::Sequential,
813            ComplexPattern::ColumnMajor => AccessPattern::Strided(0), // Stride depends on dimensions
814            ComplexPattern::Zigzag => AccessPattern::Custom,
815            ComplexPattern::DiagonalMajor => AccessPattern::Custom,
816            ComplexPattern::DiagonalMinor => AccessPattern::Custom,
817            ComplexPattern::Block { .. } => AccessPattern::Custom,
818            ComplexPattern::BlockStrided { stride, .. } => AccessPattern::Strided(*stride),
819            ComplexPattern::Stencil { .. } => AccessPattern::Custom,
820            ComplexPattern::RotatingBlock { .. } => AccessPattern::Custom,
821            ComplexPattern::Sparse { .. } => AccessPattern::Random,
822            ComplexPattern::Hierarchical { .. } => AccessPattern::Custom,
823            ComplexPattern::Custom(_) => AccessPattern::Custom,
824        }
825    }
826
827    /// Get the prefetch pattern for a complex pattern.
828    #[allow(dead_code)]
829    pub fn get_prefetch_pattern(
830        pattern: &ComplexPattern,
831        dimensions: &[usize],
832        current_idx: usize,
833        prefetch_count: usize,
834    ) -> Vec<usize> {
835        match pattern {
836            ComplexPattern::RowMajor => {
837                // For row-major, prefetch the next sequential indices
838                (1..=prefetch_count).map(|i| current_idx + i).collect()
839            }
840            ComplexPattern::ColumnMajor => {
841                if dimensions.len() >= 2 {
842                    let stride = dimensions[0];
843                    // For column-major, prefetch with stride equal to number of rows
844                    (1..=prefetch_count)
845                        .map(|i| current_idx + stride * i)
846                        .collect()
847                } else {
848                    // Default to sequential if dimensions unknown
849                    (1..=prefetch_count).map(|i| current_idx + i).collect()
850                }
851            }
852            ComplexPattern::Zigzag => {
853                if dimensions.len() >= 2 {
854                    let cols = dimensions[1];
855                    let row = current_idx / cols;
856                    let col = current_idx % cols;
857
858                    // In zigzag, alternating rows go in opposite directions
859                    let mut result = Vec::with_capacity(prefetch_count);
860
861                    if row % 2 == 0 {
862                        // Even rows go left to right
863                        for i in 1..=prefetch_count {
864                            if col + i < cols {
865                                // Continue in this row
866                                result.push(current_idx + i);
867                            } else {
868                                // Next row, right to left
869                                let overflow = (col + i) - cols;
870                                result.push(current_idx + (cols - col) + (cols - 1) - overflow);
871                            }
872                        }
873                    } else {
874                        // Odd rows go right to left
875                        for i in 1..=prefetch_count {
876                            if col >= i {
877                                // Continue in this row
878                                result.push(current_idx - i);
879                            } else {
880                                // Next row, left to right
881                                let overflow = i - col;
882                                result.push(current_idx + (col + 1) + overflow);
883                            }
884                        }
885                    }
886
887                    result
888                } else {
889                    // Default to sequential if dimensions unknown
890                    (1..=prefetch_count).map(|i| current_idx + i).collect()
891                }
892            }
893            ComplexPattern::DiagonalMajor => {
894                if dimensions.len() >= 2 {
895                    let cols = dimensions[1];
896                    // For diagonal, move down and right
897                    (1..=prefetch_count)
898                        .map(|i| current_idx + cols * i + i)
899                        .collect()
900                } else {
901                    // Default to sequential if dimensions unknown
902                    (1..=prefetch_count).map(|i| current_idx + i).collect()
903                }
904            }
905            ComplexPattern::DiagonalMinor => {
906                if dimensions.len() >= 2 {
907                    let cols = dimensions[1];
908                    // For anti-diagonal, move down and left
909                    (1..=prefetch_count)
910                        .map(|i| current_idx + cols * i - i)
911                        .collect()
912                } else {
913                    // Default to sequential if dimensions unknown
914                    (1..=prefetch_count).map(|i| current_idx + i).collect()
915                }
916            }
917            ComplexPattern::Block {
918                block_height,
919                block_width,
920            } => {
921                if dimensions.len() >= 2 {
922                    let cols = dimensions[1];
923                    let row = current_idx / cols;
924                    let col = current_idx % cols;
925
926                    // Calculate block coordinates
927                    let block_row = row / *block_height;
928                    let block_col = col / *block_width;
929
930                    // Calculate position within block
931                    let block_row_offset = row % *block_height;
932                    let block_col_offset = col % *block_width;
933
934                    // Predict next positions within the block (row-major within block)
935                    let mut result = Vec::with_capacity(prefetch_count);
936                    let mut remaining = prefetch_count;
937
938                    // First, complete the current row in the block
939                    for i in 1..=std::cmp::min(*block_width - block_col_offset, remaining) {
940                        result.push(current_idx + i);
941                        remaining -= 1;
942                    }
943
944                    // Then, continue with subsequent rows in the block
945                    let mut next_row = block_row_offset + 1;
946                    while remaining > 0 && next_row < *block_height {
947                        for col_offset in 0..std::cmp::min(*block_width, remaining) {
948                            let idx = (block_row * *block_height + next_row) * cols
949                                + block_col * *block_width
950                                + col_offset;
951                            result.push(idx);
952                            remaining -= 1;
953                        }
954                        next_row += 1;
955                    }
956
957                    // If still remaining, move to next block
958                    if remaining > 0 {
959                        let next_block_row = if block_col + 1 < cols / *block_width {
960                            block_row // Same row, next column
961                        } else {
962                            block_row + 1 // Next row, first column
963                        };
964
965                        let next_block_col = if block_col + 1 < cols / *block_width {
966                            block_col + 1 // Next column
967                        } else {
968                            0 // First column
969                        };
970
971                        // Add first few elements of next block
972                        for i in 0..remaining {
973                            let row_offset = i / *block_width;
974                            let col_offset = i % *block_width;
975                            let idx = (next_block_row * *block_height + row_offset) * cols
976                                + next_block_col * *block_width
977                                + col_offset;
978                            result.push(idx);
979                        }
980                    }
981
982                    result
983                } else {
984                    // Default to sequential if dimensions unknown
985                    (1..=prefetch_count).map(|i| current_idx + i).collect()
986                }
987            }
988            ComplexPattern::BlockStrided { block_size, stride } => {
989                // Prefetch with the specified stride within the block
990                (1..=prefetch_count)
991                    .map(|i| {
992                        let offset = i * stride;
993                        let block_offset = offset % block_size;
994                        let blocks_advanced = offset / block_size;
995
996                        if blocks_advanced == 0 {
997                            // Still in same block
998                            current_idx + offset
999                        } else {
1000                            // Advanced to next block(s)
1001                            current_idx + block_size * blocks_advanced + block_offset
1002                        }
1003                    })
1004                    .collect()
1005            }
1006            ComplexPattern::Stencil {
1007                dimensions: dim_count,
1008                radius,
1009            } => {
1010                if dimensions.len() >= *dim_count {
1011                    let cols = dimensions[1];
1012                    let row = current_idx / cols;
1013                    let col = current_idx % cols;
1014
1015                    // In a stencil operation, predict accesses to neighboring cells
1016                    let mut result = Vec::new();
1017
1018                    // Add cells in a radius around the current position
1019                    for r in -(*radius as isize)..=(*radius as isize) {
1020                        for c in -(*radius as isize)..=(*radius as isize) {
1021                            // Skip the center (current) position
1022                            if r == 0 && c == 0 {
1023                                continue;
1024                            }
1025
1026                            let new_row = row as isize + r;
1027                            let new_col = col as isize + c;
1028
1029                            // Check bounds
1030                            if new_row >= 0
1031                                && new_row < dimensions[0] as isize
1032                                && new_col >= 0
1033                                && new_col < cols as isize
1034                            {
1035                                let idx = (new_row as usize) * cols + (new_col as usize);
1036                                result.push(idx);
1037                            }
1038                        }
1039                    }
1040
1041                    // Take only the requested number of predictions
1042                    result.into_iter().take(prefetch_count).collect()
1043                } else {
1044                    // Default to sequential if dimensions unknown
1045                    (1..=prefetch_count).map(|i| current_idx + i).collect()
1046                }
1047            }
1048            // For other patterns, default to nearby elements
1049            _ => {
1050                let mut result = Vec::with_capacity(prefetch_count);
1051
1052                // Try to prefetch a mix of sequential and nearby indices
1053                for i in 1..=prefetch_count / 2 {
1054                    result.push(current_idx + i);
1055                }
1056
1057                if dimensions.len() >= 2 {
1058                    let cols = dimensions[1];
1059                    // Add row above and below
1060                    result.push(current_idx.saturating_sub(cols));
1061                    result.push(current_idx + cols);
1062                }
1063
1064                // Fill remaining slots with sequential prefetches
1065                while result.len() < prefetch_count {
1066                    result.push(current_idx + result.len() + 1);
1067                }
1068
1069                // Deduplicate
1070                result
1071                    .into_iter()
1072                    .collect::<HashSet<_>>()
1073                    .into_iter()
1074                    .collect()
1075            }
1076        }
1077    }
1078}
1079
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with a lowered minimum history size, so that
    /// short traces are already eligible for detection.
    fn low_threshold_config() -> PatternRecognitionConfig {
        PatternRecognitionConfig {
            min_history_size: 10,
            ..Default::default()
        }
    }

    #[test]
    fn test_row_major_detection() {
        let mut recognizer = PatternRecognizer::new(PatternRecognitionConfig::default());
        recognizer.set_dimensions(vec![8, 8]);

        // A row-major sweep over an 8x8 array visits 0, 1, ..., 63.
        for idx in 0..64 {
            recognizer.record_access(idx);
        }

        let saw_row_major = recognizer
            .get_patterns()
            .iter()
            .any(|p| p.pattern_type == ComplexPattern::RowMajor);
        assert!(saw_row_major);

        // Sequential at the basic level as well.
        assert_eq!(recognizer.get_basic_pattern(), AccessPattern::Sequential);
    }

    #[test]
    fn test_column_major_detection() {
        let mut recognizer = PatternRecognizer::new(PatternRecognitionConfig::default());
        recognizer.set_dimensions(vec![8, 8]);

        // Walk every column from top to bottom before moving to the right.
        let column_sweep = (0..8).flat_map(|col| (0..8).map(move |row| row * 8 + col));
        for idx in column_sweep {
            recognizer.record_access(idx);
        }

        let saw_column_major = recognizer
            .get_patterns()
            .iter()
            .any(|p| p.pattern_type == ComplexPattern::ColumnMajor);
        assert!(saw_column_major);

        // At the basic level a column sweep is a strided access.
        let basic = recognizer.get_basic_pattern();
        assert!(matches!(basic, AccessPattern::Strided(_)));
    }

    #[test]
    fn test_zigzag_detection() {
        let mut recognizer = PatternRecognizer::new(low_threshold_config());
        recognizer.set_dimensions(vec![8, 8]);

        // Every row is visited in full; even rows go right, odd rows go left.
        let zigzag = (0..8).flat_map(|row| {
            (0..8).map(move |step| {
                let col = if row % 2 == 0 { step } else { 7 - step };
                row * 8 + col
            })
        });
        for idx in zigzag {
            recognizer.record_access(idx);
        }

        let saw_zigzag = recognizer
            .get_patterns()
            .iter()
            .any(|p| p.pattern_type == ComplexPattern::Zigzag);
        assert!(saw_zigzag);
    }

    #[test]
    fn test_diagonal_detection() {
        let mut recognizer = PatternRecognizer::new(low_threshold_config());
        recognizer.set_dimensions(vec![16, 16]);

        // The full main diagonal of a 16x16 array, then its first half again
        // to reinforce the pattern. Element (i, i) sits at i * 16 + i = i * 17.
        for i in (0..16).chain(0..8) {
            recognizer.record_access(i * 17);
        }

        let saw_diagonal = recognizer
            .get_patterns()
            .iter()
            .any(|p| p.pattern_type == ComplexPattern::DiagonalMajor);
        assert!(saw_diagonal);
    }

    #[test]
    fn test_block_detection() {
        let mut recognizer = PatternRecognizer::new(PatternRecognitionConfig::default());
        recognizer.set_dimensions(vec![8, 8]);

        // Two 4x4 tiles side by side (top-left, then top-right), each read
        // row by row.
        for tile in 0..2 {
            let first_col = tile * 4;
            for row in 0..4 {
                for col in first_col..first_col + 4 {
                    recognizer.record_access(row * 8 + col);
                }
            }
        }

        let saw_4x4_blocks = recognizer.get_patterns().iter().any(|p| {
            matches!(
                p.pattern_type,
                ComplexPattern::Block {
                    block_height: 4,
                    block_width: 4,
                }
            )
        });
        assert!(saw_4x4_blocks);
    }

    #[test]
    fn test_stencil_detection() {
        let mut recognizer = PatternRecognizer::new(PatternRecognitionConfig::default());
        recognizer.set_dimensions(vec![10, 10]);

        // 5-point stencil over the interior of a 10x10 grid.
        let interior = (1..9).flat_map(|row| (1..9).map(move |col| row * 10 + col));
        for center in interior {
            // Centre, then its north, east, south and west neighbours.
            for idx in [center, center - 10, center + 1, center + 10, center - 1] {
                recognizer.record_access(idx);
            }
        }

        let saw_2d_radius_1 = recognizer.get_patterns().iter().any(|p| {
            matches!(
                p.pattern_type,
                ComplexPattern::Stencil {
                    dimensions: 2,
                    radius: 1,
                }
            )
        });
        assert!(saw_2d_radius_1);
    }

    #[test]
    fn test_pattern_utils() {
        let dimensions = [8, 8];
        let (current_idx, prefetch_count) = (10, 3);

        // Row-major: the next three sequential indices.
        let row_major = pattern_utils::get_prefetch_pattern(
            &ComplexPattern::RowMajor,
            &dimensions,
            current_idx,
            prefetch_count,
        );
        assert_eq!(row_major, vec![11, 12, 13]);

        // Main diagonal: every step moves one row down and one column right,
        // i.e. advances by cols + 1 = 9.
        let diagonal = pattern_utils::get_prefetch_pattern(
            &ComplexPattern::DiagonalMajor,
            &dimensions,
            current_idx,
            prefetch_count,
        );
        assert_eq!(diagonal, vec![19, 28, 37]);
    }
}