1use std::collections::{HashMap, HashSet, VecDeque};
12use std::time::Instant;
13
14use super::prefetch::AccessPattern;
15
/// Higher-level ways in which a (usually multi-dimensional) structure is walked,
/// as opposed to the flat [`AccessPattern`] classification.
#[derive(Debug, Clone, PartialEq)]
pub enum ComplexPattern {
    /// Consecutive flat indices: the structure is walked row after row.
    RowMajor,

    /// Successive accesses move down a column.
    ColumnMajor,

    /// Serpentine walk: the column direction flips from one row to the next.
    Zigzag,

    /// Each step moves one row down and one column to the right.
    DiagonalMajor,

    /// Each step moves one row down and one column to the left.
    DiagonalMinor,

    /// Tiled traversal over `block_height` x `block_width` tiles.
    Block {
        /// Rows per tile.
        block_height: usize,
        /// Columns per tile.
        block_width: usize,
    },

    /// A dominant `stride` that divides a candidate `block_size`.
    BlockStrided {
        /// Candidate block size, in elements.
        block_size: usize,
        /// Dominant distance between consecutive accesses.
        stride: usize,
    },

    /// Neighbourhood accesses around a centre element.
    Stencil {
        /// Number of dimensions the stencil spans.
        dimensions: usize,
        /// Neighbourhood radius.
        radius: usize,
    },

    /// Block traversal with a rotating orientation (no detector in this file emits it).
    RotatingBlock {
        /// Block size, in elements.
        block_size: usize,
    },

    /// Few distinct indices spread over a wide range; `density` is the ratio
    /// of distinct indices to the largest index plus one.
    Sparse {
        /// Observed density in `[0, 1]`.
        density: f64,
    },

    /// Multi-level traversal (no detector in this file emits it).
    Hierarchical {
        /// Number of levels.
        levels: usize,
    },

    /// A caller-defined pattern identified by name.
    Custom(String),
}
58
/// How strongly a recognised pattern is supported by the observed accesses.
///
/// Variants are declared from least to most confident. The derived
/// `PartialOrd`/`Ord` therefore rank `High` as the greatest value, so
/// comparisons such as `confidence >= Confidence::Medium` and
/// `max_by_key(|p| p.confidence)` select the *more* confident patterns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Confidence {
    /// Weakest level of evidence.
    Tentative,

    /// Some supporting evidence.
    Low,

    /// Moderate evidence.
    Medium,

    /// Strong, repeatedly confirmed evidence.
    High,
}
74
75#[derive(Debug, Clone)]
77pub struct RecognizedPattern {
78 pub pattern_type: ComplexPattern,
80
81 pub confidence: Confidence,
83
84 pub metadata: HashMap<String, String>,
86
87 pub first_detected: Instant,
89
90 pub last_confirmed: Instant,
92
93 pub confirmation_count: usize,
95}
96
97impl RecognizedPattern {
98 pub fn new(patterntype: ComplexPattern, confidence: Confidence) -> Self {
100 let now = Instant::now();
101 Self {
102 pattern_type: patterntype,
103 confidence,
104 metadata: HashMap::new(),
105 first_detected: now,
106 last_confirmed: now,
107 confirmation_count: 1,
108 }
109 }
110
111 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
113 self.metadata.insert(key.to_string(), value.to_string());
114 self
115 }
116
117 pub fn confirm(&mut self) {
119 self.confirmation_count += 1;
120 self.last_confirmed = Instant::now();
121
122 if self.confirmation_count >= 10 {
124 self.confidence = Confidence::High;
125 } else if self.confirmation_count >= 5 {
126 self.confidence = Confidence::Medium;
127 } else if self.confirmation_count >= 2 {
128 self.confidence = Confidence::Low;
129 }
130 }
131
132 pub fn is_valid(&self, maxage: std::time::Duration) -> bool {
134 self.last_confirmed.elapsed() <= maxage
135 }
136}
137
/// Tunables controlling when and which detectors a `PatternRecognizer` runs.
#[derive(Debug, Clone)]
pub struct PatternRecognitionConfig {
    /// Minimum number of buffered accesses before any detection runs.
    pub min_history_size: usize,

    /// Patterns not confirmed within this duration are discarded.
    pub pattern_expiry: std::time::Duration,

    /// Enables main/anti-diagonal detection (only for exactly two dimensions).
    pub detect_diagonal: bool,

    /// Enables tiled-block and block-strided detection.
    pub detect_block: bool,

    /// Enables stencil (neighbourhood) detection.
    pub detect_stencil: bool,

    /// Enables sparse-access detection.
    pub detect_sparse: bool,

    /// Reserved flag; nothing in this file reads it.
    pub use_machine_learning: bool,
}

impl Default for PatternRecognitionConfig {
    /// All rule-based detectors are on, detection starts after 20 accesses,
    /// patterns expire after 60 seconds, and the ML flag is off.
    fn default() -> Self {
        const DEFAULT_MIN_HISTORY: usize = 20;
        const DEFAULT_EXPIRY_SECS: u64 = 60;

        Self {
            min_history_size: DEFAULT_MIN_HISTORY,
            pattern_expiry: std::time::Duration::from_secs(DEFAULT_EXPIRY_SECS),
            detect_diagonal: true,
            detect_block: true,
            detect_stencil: true,
            detect_sparse: true,
            use_machine_learning: false,
        }
    }
}
176
177#[derive(Debug)]
179pub struct PatternRecognizer {
180 config: PatternRecognitionConfig,
182
183 dimensions: Option<Vec<usize>>,
185
186 history: VecDeque<usize>,
188
189 patterns: Vec<RecognizedPattern>,
191
192 basic_pattern: AccessPattern,
194}
195
196impl PatternRecognizer {
197 pub fn new(config: PatternRecognitionConfig) -> Self {
199 Self {
200 config,
201 dimensions: None,
202 history: VecDeque::with_capacity(100),
203 patterns: Vec::new(),
204 basic_pattern: AccessPattern::Random,
205 }
206 }
207
208 pub fn set_dimensions(&mut self, dimensions: Vec<usize>) {
210 self.dimensions = Some(dimensions);
211 }
212
213 pub fn record_access(&mut self, index: usize) {
215 self.history.push_back(index);
216
217 while self.history.len() > 100 {
219 self.history.pop_front();
220 }
221
222 if self.history.len() >= self.config.min_history_size {
224 self.detect_patterns();
225 }
226 }
227
228 fn detect_patterns(&mut self) {
230 self.patterns
232 .retain(|pattern| pattern.is_valid(self.config.pattern_expiry));
233
234 self.detect_basic_patterns();
236
237 if let Some(dims) = self.dimensions.clone() {
239 if dims.len() >= 2 {
241 self.detectmatrix_patterns(&dims);
242 }
243
244 if self.config.detect_block && dims.len() >= 2 {
246 self.detect_block_patterns(&dims);
247 }
248
249 if self.config.detect_diagonal && dims.len() == 2 {
251 self.detect_diagonal_patterns(&dims);
252 }
253
254 if self.config.detect_stencil && dims.len() >= 2 {
256 self.detect_stencil_patterns(&dims);
257 }
258 }
259
260 if self.config.detect_sparse {
262 self.detect_sparse_pattern();
263 }
264 }
265
266 fn detect_basic_patterns(&mut self) {
268 let indices: Vec<_> = self.history.iter().cloned().collect();
269
270 let mut sequential_count = 0;
272 for i in 1..indices.len() {
273 if indices[i] == indices[i.saturating_sub(1)] + 1 {
274 sequential_count += 1;
275 }
276 }
277
278 if sequential_count >= indices.len() * 3 / 4 {
279 self.basic_pattern = AccessPattern::Sequential;
280
281 if let Some(ref dims) = self.dimensions {
283 if dims.len() >= 2 {
284 let row_size = dims[1];
285 let pattern = ComplexPattern::RowMajor;
286
287 if let Some(existing) = self.find_pattern(&pattern) {
289 existing.confirm();
291 } else {
292 let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
294 .with_metadata("row_size", &row_size.to_string());
295 self.patterns.push(pattern);
296 }
297 }
298 }
299
300 return;
301 }
302
303 let mut best_stride = 0;
305 let mut best_stride_count = 0;
306
307 for stride in 2..=20 {
308 let mut stride_count = 0;
309 for i in 1..indices.len() {
310 if indices[i].saturating_sub(indices[i.saturating_sub(1)]) == stride {
311 stride_count += 1;
312 }
313 }
314
315 if stride_count > best_stride_count {
316 best_stride_count = stride_count;
317 best_stride = stride;
318 }
319 }
320
321 if best_stride_count >= indices.len() * 2 / 3 {
322 self.basic_pattern = AccessPattern::Strided(best_stride);
323
324 if let Some(ref dims) = self.dimensions {
326 if dims.len() >= 2 {
327 let num_rows = dims[0];
328
329 if best_stride == num_rows {
330 let pattern = ComplexPattern::ColumnMajor;
331
332 if let Some(existing) = self.find_pattern(&pattern) {
334 existing.confirm();
336 } else {
337 let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
339 .with_metadata("num_rows", &num_rows.to_string());
340 self.patterns.push(pattern);
341 }
342 }
343 }
344 }
345
346 return;
347 }
348
349 self.basic_pattern = AccessPattern::Random;
351 }
352
353 fn detectmatrix_patterns(&mut self, dimensions: &[usize]) {
355 if dimensions.len() < 2 {
356 return;
357 }
358
359 let _rows = dimensions[0];
360 let cols = dimensions[1];
361 let indices: Vec<_> = self.history.iter().cloned().collect();
362
363 let mut zigzag_evidence = 0;
365 let mut last_row_direction = None;
366
367 let mut rows: HashMap<usize, Vec<(usize, usize)>> = HashMap::new(); for (access_order, &idx) in indices.iter().enumerate() {
370 let row = idx / cols;
371 let col = idx % cols;
372 rows.entry(row).or_default().push((col, access_order));
373 }
374
375 let sorted_rows: Vec<_> = {
377 let mut sorted = rows.keys().cloned().collect::<Vec<_>>();
378 sorted.sort();
379 sorted
380 };
381
382 for row_num in &sorted_rows {
383 let mut cols_in_row = rows[row_num].clone();
384 if cols_in_row.len() >= 2 {
385 cols_in_row.sort_by_key(|(_, access_order)| *access_order);
387
388 let mut increasing = 0;
391 let mut decreasing = 0;
392 for i in 1..cols_in_row.len() {
393 match cols_in_row[i].0.cmp(&cols_in_row[i.saturating_sub(1)].0) {
394 std::cmp::Ordering::Greater => increasing += 1,
395 std::cmp::Ordering::Less => decreasing += 1,
396 std::cmp::Ordering::Equal => {}
397 }
398 }
399
400 let current_direction = match increasing.cmp(&decreasing) {
401 std::cmp::Ordering::Greater => 1, std::cmp::Ordering::Less => -1, std::cmp::Ordering::Equal => 0, };
405
406 if current_direction != 0 {
408 if let Some(prev_direction) = last_row_direction {
409 if current_direction != prev_direction && prev_direction != 0 {
410 zigzag_evidence += 1;
411 }
412 }
413 last_row_direction = Some(current_direction);
414 }
415 }
416 }
417
418 if zigzag_evidence >= 2 && sorted_rows.len() >= 3 {
421 let pattern = ComplexPattern::Zigzag;
422
423 if let Some(existing) = self.find_pattern(&pattern) {
425 existing.confirm();
427 } else {
428 let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
430 .with_metadata("zigzag_evidence", &zigzag_evidence.to_string());
431 self.patterns.push(pattern);
432 }
433 }
434 }
435
436 fn detect_diagonal_patterns(&mut self, dimensions: &[usize]) {
438 if dimensions.len() != 2 {
439 return;
440 }
441
442 let _rows = dimensions[0];
443 let cols = dimensions[1];
444 let indices: Vec<_> = self.history.iter().cloned().collect();
445
446 let mut diagonal_matches = 0;
448 for i in 1..indices.len() {
449 let prev_idx = indices[i.saturating_sub(1)];
450 let curr_idx = indices[i];
451
452 let prev_row = prev_idx / cols;
453 let prev_col = prev_idx % cols;
454
455 let curr_row = curr_idx / cols;
456 let curr_col = curr_idx % cols;
457
458 if curr_row == prev_row + 1 && curr_col == prev_col + 1 {
460 diagonal_matches += 1;
461 }
462 }
463
464 let expected_transitions = indices.len().saturating_sub(1);
467 if (diagonal_matches >= expected_transitions / 3 || diagonal_matches >= 3)
469 && diagonal_matches > 0
470 {
471 let pattern = ComplexPattern::DiagonalMajor;
472
473 if let Some(existing) = self.find_pattern(&pattern) {
475 existing.confirm();
477 } else {
478 let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
480 .with_metadata("diagonal_matches", &diagonal_matches.to_string());
481 self.patterns.push(pattern);
482 }
483
484 return;
485 }
486
487 let mut anti_diagonal_matches = 0;
489 for i in 1..indices.len() {
490 let prev_idx = indices[i.saturating_sub(1)];
491 let curr_idx = indices[i];
492
493 let prev_row = prev_idx / cols;
494 let prev_col = prev_idx % cols;
495
496 let curr_row = curr_idx / cols;
497 let curr_col = curr_idx % cols;
498
499 if curr_row == prev_row + 1 && curr_col + 1 == prev_col {
501 anti_diagonal_matches += 1;
502 }
503 }
504
505 let expected_transitions = indices.len().saturating_sub(1);
507 if (anti_diagonal_matches >= expected_transitions / 3 || anti_diagonal_matches >= 3)
509 && anti_diagonal_matches > 0
510 {
511 let pattern = ComplexPattern::DiagonalMinor;
512
513 if let Some(existing) = self.find_pattern(&pattern) {
515 existing.confirm();
517 } else {
518 let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
520 .with_metadata("anti_diagonal_matches", &anti_diagonal_matches.to_string());
521 self.patterns.push(pattern);
522 }
523 }
524 }
525
526 fn detect_block_patterns(&mut self, dimensions: &[usize]) {
528 if dimensions.len() < 2 {
529 return;
530 }
531
532 let rows = dimensions[0];
533 let cols = dimensions[1];
534 let indices: Vec<_> = self.history.iter().cloned().collect();
535
536 let block_sizes_to_try = [
538 (2, 2),
539 (4, 4),
540 (8, 8),
541 (16, 16),
542 (32, 32),
543 (64, 64),
544 (rows, 4),
545 (4, cols),
546 ];
547
548 for &(block_height, block_width) in &block_sizes_to_try {
549 if block_height > rows || block_width > cols {
551 continue;
552 }
553
554 let mut block_accesses = HashMap::new();
555
556 for &idx in &indices {
558 let row = idx / cols;
559 let col = idx % cols;
560
561 let block_row = row / block_height;
562 let block_col = col / block_width;
563
564 let block_id = (block_row, block_col);
565 let entry: &mut Vec<usize> = block_accesses.entry(block_id).or_default();
566 entry.push(idx);
567 }
568
569 let mut complete_blocks = 0;
571 for accesses in block_accesses.values() {
572 if accesses.len() == block_height * block_width {
573 complete_blocks += 1;
574 }
575 }
576
577 if complete_blocks >= 2 && block_accesses.len() <= 10 {
579 let pattern = ComplexPattern::Block {
580 block_height,
581 block_width,
582 };
583
584 if let Some(existing) = self.find_pattern(&pattern) {
586 existing.confirm();
588 } else {
589 let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
591 .with_metadata("complete_blocks", &complete_blocks.to_string())
592 .with_metadata("total_blocks", &block_accesses.len().to_string());
593 self.patterns.push(pattern);
594 }
595 }
596 }
597
598 let mut block_strides = HashMap::new();
600
601 for i in 1..indices.len() {
603 let stride = indices[i].saturating_sub(indices[i.saturating_sub(1)]);
604 *block_strides.entry(stride).or_insert(0) += 1;
605 }
606
607 if let Some((&stride, &count)) = block_strides.iter().max_by_key(|(_, &count)| count) {
609 if count >= indices.len() / 3 && stride > 1 {
610 let possible_block_sizes = [8, 16, 32, 64, 128];
612
613 for &block_size in &possible_block_sizes {
614 if stride < block_size && block_size % stride == 0 {
615 let pattern = ComplexPattern::BlockStrided { block_size, stride };
616
617 if let Some(existing) = self.find_pattern(&pattern) {
619 existing.confirm();
621 } else {
622 let pattern = RecognizedPattern::new(pattern, Confidence::Low)
624 .with_metadata("stride_count", &count.to_string())
625 .with_metadata(
626 "total_transitions",
627 &(indices.len() - 1).to_string(),
628 );
629 self.patterns.push(pattern);
630 }
631
632 break; }
634 }
635 }
636 }
637 }
638
639 fn detect_stencil_patterns(&mut self, dimensions: &[usize]) {
641 if dimensions.len() < 2 {
642 return;
643 }
644
645 let _rows = dimensions[0];
646 let cols = dimensions[1];
647 let indices: Vec<_> = self.history.iter().cloned().collect();
648
649 let mut stencil_groups = 0;
652
653 for window_start in 0..indices.len().saturating_sub(4) {
655 if window_start + 4 >= indices.len() {
656 break;
657 }
658
659 let center_idx = indices[window_start];
660 let center_row = center_idx / cols;
661 let center_col = center_idx % cols;
662
663 let mut neighbors_found = 0;
665 let expected_neighbors = [
666 center_idx.saturating_sub(cols), center_idx + 1, center_idx + cols, center_idx.saturating_sub(1), ];
671
672 for offset in 1..=4 {
673 if window_start + offset < indices.len() {
674 let neighbor_idx = indices[window_start + offset];
675 if expected_neighbors.contains(&neighbor_idx) {
676 neighbors_found += 1;
677 }
678 }
679 }
680
681 if neighbors_found >= 3 {
683 stencil_groups += 1;
684 }
685 }
686
687 if stencil_groups >= 3 {
689 let pattern = ComplexPattern::Stencil {
690 dimensions: 2,
691 radius: 1,
692 };
693
694 if let Some(existing) = self.find_pattern(&pattern) {
696 existing.confirm();
698 } else {
699 let pattern = RecognizedPattern::new(pattern, Confidence::Medium)
701 .with_metadata("stencil_groups", &stencil_groups.to_string());
702 self.patterns.push(pattern);
703 }
704 }
705 }
706
707 fn detect_sparse_pattern(&mut self) {
709 let indices: Vec<_> = self.history.iter().cloned().collect();
710
711 if indices.len() < 20 {
713 return;
714 }
715
716 if let Some(&max_idx) = indices.iter().max() {
718 let unique_indices = indices.iter().collect::<HashSet<_>>().len();
719
720 let density = unique_indices as f64 / (max_idx + 1) as f64;
722
723 if density < 0.1 {
725 let pattern = ComplexPattern::Sparse { density };
726
727 if let Some(existing) = self.find_pattern(&pattern) {
729 existing.confirm();
731 } else {
732 let confidence = if density < 0.01 {
734 Confidence::High
735 } else if density < 0.05 {
736 Confidence::Medium
737 } else {
738 Confidence::Low
739 };
740
741 let pattern = RecognizedPattern::new(pattern, confidence)
742 .with_metadata("unique_indices", &unique_indices.to_string())
743 .with_metadata("max_index", &max_idx.to_string())
744 .with_metadata("density", &format!("{density:.6}"));
745 self.patterns.push(pattern);
746 }
747 }
748 }
749 }
750
751 fn find_pattern(&mut self, patterntype: &ComplexPattern) -> Option<&mut RecognizedPattern> {
753 self.patterns
754 .iter_mut()
755 .find(|p| &p.pattern_type == patterntype)
756 }
757
758 pub fn get_patterns(&self) -> Vec<&RecognizedPattern> {
760 let mut patterns: Vec<_> = self.patterns.iter().collect();
761 patterns.sort_by(|a, b| b.confidence.cmp(&a.confidence));
762 patterns
763 }
764
765 pub fn get_best_pattern(&self) -> Option<&RecognizedPattern> {
767 self.patterns
768 .iter()
769 .filter(|p| p.confidence >= Confidence::Medium)
770 .max_by_key(|p| p.confidence)
771 }
772
773 pub fn get_basic_pattern(&self) -> AccessPattern {
775 self.basic_pattern
776 }
777
778 pub fn clear(&mut self) {
780 self.history.clear();
781 self.patterns.clear();
782 self.basic_pattern = AccessPattern::Random;
783 }
784}
785
786#[allow(dead_code)]
788pub struct PatternRecognizerFactory;
789
790#[allow(dead_code)]
791impl PatternRecognizerFactory {
792 pub fn create() -> PatternRecognizer {
794 PatternRecognizer::new(PatternRecognitionConfig::default())
795 }
796
797 pub fn create_with_config(config: PatternRecognitionConfig) -> PatternRecognizer {
799 PatternRecognizer::new(config)
800 }
801}
802
803pub mod pattern_utils {
805 use super::*;
806 use crate::memory_efficient::prefetch::AccessPattern;
807
808 #[allow(dead_code)]
810 pub fn to_basic_pattern(pattern: &ComplexPattern) -> AccessPattern {
811 match pattern {
812 ComplexPattern::RowMajor => AccessPattern::Sequential,
813 ComplexPattern::ColumnMajor => AccessPattern::Strided(0), ComplexPattern::Zigzag => AccessPattern::Custom,
815 ComplexPattern::DiagonalMajor => AccessPattern::Custom,
816 ComplexPattern::DiagonalMinor => AccessPattern::Custom,
817 ComplexPattern::Block { .. } => AccessPattern::Custom,
818 ComplexPattern::BlockStrided { stride, .. } => AccessPattern::Strided(*stride),
819 ComplexPattern::Stencil { .. } => AccessPattern::Custom,
820 ComplexPattern::RotatingBlock { .. } => AccessPattern::Custom,
821 ComplexPattern::Sparse { .. } => AccessPattern::Random,
822 ComplexPattern::Hierarchical { .. } => AccessPattern::Custom,
823 ComplexPattern::Custom(_) => AccessPattern::Custom,
824 }
825 }
826
827 #[allow(dead_code)]
829 pub fn get_prefetch_pattern(
830 pattern: &ComplexPattern,
831 dimensions: &[usize],
832 current_idx: usize,
833 prefetch_count: usize,
834 ) -> Vec<usize> {
835 match pattern {
836 ComplexPattern::RowMajor => {
837 (1..=prefetch_count).map(|i| current_idx + i).collect()
839 }
840 ComplexPattern::ColumnMajor => {
841 if dimensions.len() >= 2 {
842 let stride = dimensions[0];
843 (1..=prefetch_count)
845 .map(|i| current_idx + stride * i)
846 .collect()
847 } else {
848 (1..=prefetch_count).map(|i| current_idx + i).collect()
850 }
851 }
852 ComplexPattern::Zigzag => {
853 if dimensions.len() >= 2 {
854 let cols = dimensions[1];
855 let row = current_idx / cols;
856 let col = current_idx % cols;
857
858 let mut result = Vec::with_capacity(prefetch_count);
860
861 if row % 2 == 0 {
862 for i in 1..=prefetch_count {
864 if col + i < cols {
865 result.push(current_idx + i);
867 } else {
868 let overflow = (col + i) - cols;
870 result.push(current_idx + (cols - col) + (cols - 1) - overflow);
871 }
872 }
873 } else {
874 for i in 1..=prefetch_count {
876 if col >= i {
877 result.push(current_idx - i);
879 } else {
880 let overflow = i - col;
882 result.push(current_idx + (col + 1) + overflow);
883 }
884 }
885 }
886
887 result
888 } else {
889 (1..=prefetch_count).map(|i| current_idx + i).collect()
891 }
892 }
893 ComplexPattern::DiagonalMajor => {
894 if dimensions.len() >= 2 {
895 let cols = dimensions[1];
896 (1..=prefetch_count)
898 .map(|i| current_idx + cols * i + i)
899 .collect()
900 } else {
901 (1..=prefetch_count).map(|i| current_idx + i).collect()
903 }
904 }
905 ComplexPattern::DiagonalMinor => {
906 if dimensions.len() >= 2 {
907 let cols = dimensions[1];
908 (1..=prefetch_count)
910 .map(|i| current_idx + cols * i - i)
911 .collect()
912 } else {
913 (1..=prefetch_count).map(|i| current_idx + i).collect()
915 }
916 }
917 ComplexPattern::Block {
918 block_height,
919 block_width,
920 } => {
921 if dimensions.len() >= 2 {
922 let cols = dimensions[1];
923 let row = current_idx / cols;
924 let col = current_idx % cols;
925
926 let block_row = row / *block_height;
928 let block_col = col / *block_width;
929
930 let block_row_offset = row % *block_height;
932 let block_col_offset = col % *block_width;
933
934 let mut result = Vec::with_capacity(prefetch_count);
936 let mut remaining = prefetch_count;
937
938 for i in 1..=std::cmp::min(*block_width - block_col_offset, remaining) {
940 result.push(current_idx + i);
941 remaining -= 1;
942 }
943
944 let mut next_row = block_row_offset + 1;
946 while remaining > 0 && next_row < *block_height {
947 for col_offset in 0..std::cmp::min(*block_width, remaining) {
948 let idx = (block_row * *block_height + next_row) * cols
949 + block_col * *block_width
950 + col_offset;
951 result.push(idx);
952 remaining -= 1;
953 }
954 next_row += 1;
955 }
956
957 if remaining > 0 {
959 let next_block_row = if block_col + 1 < cols / *block_width {
960 block_row } else {
962 block_row + 1 };
964
965 let next_block_col = if block_col + 1 < cols / *block_width {
966 block_col + 1 } else {
968 0 };
970
971 for i in 0..remaining {
973 let row_offset = i / *block_width;
974 let col_offset = i % *block_width;
975 let idx = (next_block_row * *block_height + row_offset) * cols
976 + next_block_col * *block_width
977 + col_offset;
978 result.push(idx);
979 }
980 }
981
982 result
983 } else {
984 (1..=prefetch_count).map(|i| current_idx + i).collect()
986 }
987 }
988 ComplexPattern::BlockStrided { block_size, stride } => {
989 (1..=prefetch_count)
991 .map(|i| {
992 let offset = i * stride;
993 let block_offset = offset % block_size;
994 let blocks_advanced = offset / block_size;
995
996 if blocks_advanced == 0 {
997 current_idx + offset
999 } else {
1000 current_idx + block_size * blocks_advanced + block_offset
1002 }
1003 })
1004 .collect()
1005 }
1006 ComplexPattern::Stencil {
1007 dimensions: dim_count,
1008 radius,
1009 } => {
1010 if dimensions.len() >= *dim_count {
1011 let cols = dimensions[1];
1012 let row = current_idx / cols;
1013 let col = current_idx % cols;
1014
1015 let mut result = Vec::new();
1017
1018 for r in -(*radius as isize)..=(*radius as isize) {
1020 for c in -(*radius as isize)..=(*radius as isize) {
1021 if r == 0 && c == 0 {
1023 continue;
1024 }
1025
1026 let new_row = row as isize + r;
1027 let new_col = col as isize + c;
1028
1029 if new_row >= 0
1031 && new_row < dimensions[0] as isize
1032 && new_col >= 0
1033 && new_col < cols as isize
1034 {
1035 let idx = (new_row as usize) * cols + (new_col as usize);
1036 result.push(idx);
1037 }
1038 }
1039 }
1040
1041 result.into_iter().take(prefetch_count).collect()
1043 } else {
1044 (1..=prefetch_count).map(|i| current_idx + i).collect()
1046 }
1047 }
1048 _ => {
1050 let mut result = Vec::with_capacity(prefetch_count);
1051
1052 for i in 1..=prefetch_count / 2 {
1054 result.push(current_idx + i);
1055 }
1056
1057 if dimensions.len() >= 2 {
1058 let cols = dimensions[1];
1059 result.push(current_idx.saturating_sub(cols));
1061 result.push(current_idx + cols);
1062 }
1063
1064 while result.len() < prefetch_count {
1066 result.push(current_idx + result.len() + 1);
1067 }
1068
1069 result
1071 .into_iter()
1072 .collect::<HashSet<_>>()
1073 .into_iter()
1074 .collect()
1075 }
1076 }
1077 }
1078}
1079
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a recogniser for a matrix of the given shape.
    fn recognizer_for(dims: &[usize], config: PatternRecognitionConfig) -> PatternRecognizer {
        let mut recognizer = PatternRecognizer::new(config);
        recognizer.set_dimensions(dims.to_vec());
        recognizer
    }

    /// Configuration that starts detecting after only ten accesses.
    fn short_history() -> PatternRecognitionConfig {
        PatternRecognitionConfig {
            min_history_size: 10,
            ..Default::default()
        }
    }

    /// True when any recognised pattern satisfies `wanted`.
    fn has_pattern(
        recognizer: &PatternRecognizer,
        wanted: impl Fn(&ComplexPattern) -> bool,
    ) -> bool {
        recognizer
            .get_patterns()
            .iter()
            .any(|p| wanted(&p.pattern_type))
    }

    #[test]
    fn test_row_major_detection() {
        let mut recognizer = recognizer_for(&[8, 8], PatternRecognitionConfig::default());
        (0..64).for_each(|idx| recognizer.record_access(idx));

        assert!(has_pattern(&recognizer, |p| matches!(p, ComplexPattern::RowMajor)));
        assert_eq!(recognizer.get_basic_pattern(), AccessPattern::Sequential);
    }

    #[test]
    fn test_column_major_detection() {
        let mut recognizer = recognizer_for(&[8, 8], PatternRecognitionConfig::default());
        for col in 0..8 {
            for row in 0..8 {
                recognizer.record_access(row * 8 + col);
            }
        }

        assert!(has_pattern(&recognizer, |p| matches!(p, ComplexPattern::ColumnMajor)));
        assert!(matches!(
            recognizer.get_basic_pattern(),
            AccessPattern::Strided(_)
        ));
    }

    #[test]
    fn test_zigzag_detection() {
        let mut recognizer = recognizer_for(&[8, 8], short_history());
        for row in 0..8usize {
            let columns: Vec<usize> = if row % 2 == 0 {
                (0..8).collect()
            } else {
                (0..8).rev().collect()
            };
            for col in columns {
                recognizer.record_access(row * 8 + col);
            }
        }

        assert!(has_pattern(&recognizer, |p| matches!(p, ComplexPattern::Zigzag)));
    }

    #[test]
    fn test_diagonal_detection() {
        let mut recognizer = recognizer_for(&[16, 16], short_history());
        // Element (i, i) of a 16-column matrix sits at flat index i * 17.
        (0..16)
            .chain(0..8)
            .for_each(|i| recognizer.record_access(i * 17));

        assert!(has_pattern(&recognizer, |p| matches!(p, ComplexPattern::DiagonalMajor)));
    }

    #[test]
    fn test_block_detection() {
        let mut recognizer = recognizer_for(&[8, 8], PatternRecognitionConfig::default());
        for &(first_col, end_col) in &[(0usize, 4usize), (4, 8)] {
            for row in 0..4 {
                for col in first_col..end_col {
                    recognizer.record_access(row * 8 + col);
                }
            }
        }

        assert!(has_pattern(&recognizer, |p| matches!(
            p,
            ComplexPattern::Block {
                block_height: 4,
                block_width: 4
            }
        )));
    }

    #[test]
    fn test_stencil_detection() {
        let mut recognizer = recognizer_for(&[10, 10], PatternRecognitionConfig::default());
        for row in 1..9usize {
            for col in 1..9usize {
                let center = row * 10 + col;
                // Centre, then up, right, down, left.
                for idx in [center, center - 10, center + 1, center + 10, center - 1].iter() {
                    recognizer.record_access(*idx);
                }
            }
        }

        assert!(has_pattern(&recognizer, |p| matches!(
            p,
            ComplexPattern::Stencil {
                dimensions: 2,
                radius: 1
            }
        )));
    }

    #[test]
    fn test_pattern_utils() {
        let dims = [8usize, 8];

        let row_major =
            pattern_utils::get_prefetch_pattern(&ComplexPattern::RowMajor, &dims, 10, 3);
        assert_eq!(row_major, vec![11, 12, 13]);

        let diagonal =
            pattern_utils::get_prefetch_pattern(&ComplexPattern::DiagonalMajor, &dims, 10, 3);
        assert_eq!(diagonal, vec![19, 28, 37]);
    }
}