1use std::collections::{HashMap, HashSet, VecDeque};
8use std::time::{Duration, Instant};
9
10use super::prefetch::{AccessPattern, AccessPatternTracker, PrefetchConfig, PrefetchStats};
11
/// Doubled in `select_next_strategy` to obtain the total number of recorded strategy
/// usages after which the initial forced-exploration phase is switched off.
const MAX_EXPLORATION_STRATEGIES: usize = 5;

/// How long a selected strategy stays active before `ns` asks for the next selection.
/// Also the default `evaluation_interval` of `AdaptivePrefetchConfig`.
const STRATEGY_TEST_DURATION: Duration = Duration::from_secs(60);

/// Weight of the newest reward when a strategy's Q-value is updated in `ns`
/// (`q = (1 - rate) * q + rate * reward`). Also the default `learningrate` of
/// `AdaptivePrefetchConfig`.
const LEARNING_RATE: f64 = 0.1;
// NOTE(review): not referenced by any code visible in this file, hence the lint allowance.
#[allow(dead_code)]
const DISCOUNT_FACTOR: f64 = 0.9;
/// Exploration rate a freshly created tracker starts with.
const EXPLORATION_RATE_INITIAL: f64 = 0.3;
/// Factor the exploration rate is multiplied by on every strategy selection.
const EXPLORATION_RATE_DECAY: f64 = 0.995;

/// Name reported when most consecutive accesses advance by exactly one block.
const MATRIX_TRAVERSAL_ROW_MAJOR: &str = "MATRIX_TRAVERSAL_ROW_MAJOR";
/// Name reported when most consecutive accesses advance by `dimensions[0]` blocks.
const MATRIX_TRAVERSAL_COL_MAJOR: &str = "MATRIX_TRAVERSAL_COL_MAJOR";
/// Name reported when `detect_zigzag_pattern` recognises the access history.
const ZIGZAG_SCAN: &str = "ZIGZAG_SCAN";
29
/// The ways the tracker can choose which blocks to prefetch after the latest access.
///
/// Values are small and `Copy`; they double as keys of the per-strategy performance table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrefetchStrategy {
    /// Prefetch up to this many blocks directly following the latest access.
    Sequential(usize),

    /// Prefetch up to `count` blocks spaced `stride` apart after the latest access.
    Strided { stride: usize, count: usize },

    /// Prefetch up to `lookahead` blocks predicted from repetitions in the access history.
    /// `windowsize` is carried along but ignored when blocks are computed.
    Pattern { windowsize: usize, lookahead: usize },

    /// Combine consecutive blocks (`sequential`) with history-based predictions (`pattern`).
    Hybrid { sequential: usize, pattern: usize },

    /// Prefetch only the single block following the latest access.
    Conservative,

    /// Mix consecutive, strided and history-based candidates.
    Aggressive,

    /// Do not prefetch anything.
    None,
}

impl Default for PrefetchStrategy {
    /// A tracker starts out prefetching the two blocks after the latest access.
    fn default() -> Self {
        Self::Sequential(2)
    }
}
60
/// Bookkeeping the tracker keeps for every prefetch strategy it knows about.
#[derive(Debug, Clone)]
struct StrategyPerformance {
    /// The strategy this record belongs to (same value as its key in the table).
    strategy: PrefetchStrategy,

    /// How many times statistics were reported while this strategy was active.
    usage_count: usize,

    /// Hit rate from the most recent statistics report.
    hit_rate: f64,

    /// Average latency in nanoseconds from the most recent statistics report.
    avg_latency_ns: f64,

    /// When this record was created or last updated by a statistics report.
    last_used: Instant,

    /// Running reward estimate; the exploitation step picks the strategy with the largest one.
    q_value: f64,
}
82
/// Access-pattern tracker that classifies recent block accesses and switches between
/// prefetch strategies based on the performance reported for them.
#[derive(Debug)]
pub struct AdaptivePatternTracker {
    /// Base prefetch settings (history size, minimum pattern length, prefetch count, ...).
    config: PrefetchConfig,

    /// Recent accesses, oldest first: (block index, time of access, time since previous access).
    history: VecDeque<(usize, Instant, Duration)>,
    /// Classification of the current history as computed by `detect_pattern`.
    current_pattern: AccessPattern,

    /// Stride found by the most recent successful strided detection, if any.
    stride: Option<usize>,

    /// Performance record for every strategy seen so far.
    strategy_performance: HashMap<PrefetchStrategy, StrategyPerformance>,

    /// Strategy used by `get_blocks_to_prefetch` right now.
    current_strategy: PrefetchStrategy,

    /// Earliest instant at which `ns` will trigger another strategy selection.
    next_strategy_change: Instant,

    /// `true` while the initial forced-exploration phase is still running.
    exploring: bool,

    /// Decaying exploration rate; compared (as a percentage) against `exploration_step % 100`.
    exploration_rate: f64,

    /// Optional array dimensions supplied via `set_dimensions`.
    dimensions: Option<Vec<usize>>,

    /// For each detected dimensional pattern name, the history length at every detection.
    dimensional_patterns: HashMap<String, Vec<usize>>,

    /// Number of strategy selections performed so far.
    exploration_step: usize,
}
122
123impl AdaptivePatternTracker {
124 pub fn new(config: PrefetchConfig) -> Self {
126 let mut strategies = HashMap::new();
127
128 let history_size = config.history_size;
130
131 for strategy in [
133 PrefetchStrategy::Sequential(2),
134 PrefetchStrategy::Sequential(5),
135 PrefetchStrategy::Strided {
136 stride: 10,
137 count: 3,
138 },
139 PrefetchStrategy::Conservative,
140 PrefetchStrategy::Aggressive,
141 PrefetchStrategy::None,
142 ] {
143 strategies.insert(
144 strategy,
145 StrategyPerformance {
146 strategy,
147 usage_count: 0,
148 hit_rate: 0.0,
149 avg_latency_ns: 0.0,
150 last_used: Instant::now(),
151 q_value: 0.0,
152 },
153 );
154 }
155
156 Self {
157 config,
158 history: VecDeque::with_capacity(history_size),
159 current_pattern: AccessPattern::Random,
160 stride: None,
161 strategy_performance: strategies,
162 current_strategy: PrefetchStrategy::default(),
163 next_strategy_change: Instant::now() + STRATEGY_TEST_DURATION,
164 exploring: true,
165 exploration_rate: EXPLORATION_RATE_INITIAL,
166 dimensions: None,
167 dimensional_patterns: HashMap::new(),
168 exploration_step: 0,
169 }
170 }
171
172 pub fn set_dimensions(&mut self, dimensions: Vec<usize>) {
174 self.dimensions = Some(dimensions);
175 }
176
177 pub fn ns(&mut self, stats: PrefetchStats, avg_latencyns: f64) {
179 if let Some(perf) = self.strategy_performance.get_mut(&self.current_strategy) {
180 perf.usage_count += 1;
182 perf.hit_rate = stats.hit_rate;
183 perf.avg_latency_ns = avg_latencyns;
184 perf.last_used = Instant::now();
185
186 let hit_rate_reward = stats.hit_rate;
188 let latency_factor = if perf.avg_latency_ns > 0.0 {
189 1.0 / (1.0 + perf.avg_latency_ns / 1_000_000.0) } else {
191 0.0
192 };
193
194 let reward = hit_rate_reward * 0.7 + latency_factor * 0.3;
195
196 perf.q_value = (1.0 - LEARNING_RATE) * perf.q_value + LEARNING_RATE * reward;
198 }
199
200 if Instant::now() >= self.next_strategy_change {
202 self.select_next_strategy();
203 }
204 }
205
206 fn select_next_strategy(&mut self) {
208 self.exploration_step += 1;
210
211 self.exploration_rate *= EXPLORATION_RATE_DECAY;
213
214 if self.exploring
216 || (self.exploration_step % 100) < (self.exploration_rate * 100.0) as usize
217 {
218 let available_strategies: Vec<PrefetchStrategy> =
220 self.strategy_performance.keys().copied().collect();
221
222 let candidates: Vec<PrefetchStrategy> = available_strategies
224 .into_iter()
225 .filter(|&s| s != self.current_strategy)
226 .collect();
227
228 if !candidates.is_empty() {
229 let idx = self.exploration_step % candidates.len();
230 self.current_strategy = candidates[idx];
231 }
232
233 let total_usage: usize = self
235 .strategy_performance
236 .values()
237 .map(|p| p.usage_count)
238 .sum();
239
240 if total_usage >= MAX_EXPLORATION_STRATEGIES * 2 {
241 self.exploring = false;
242 }
243 } else {
244 let best_strategy = self
246 .strategy_performance
247 .values()
248 .max_by(|a, b| a.q_value.partial_cmp(&b.q_value).expect("Operation failed"))
249 .map(|p| p.strategy)
250 .unwrap_or_default();
251
252 self.current_strategy = best_strategy;
253 }
254
255 self.next_strategy_change = Instant::now() + STRATEGY_TEST_DURATION;
257
258 self.update_strategy_from_pattern();
260 }
261
262 fn update_strategy_from_pattern(&mut self) {
264 match self.current_pattern {
265 AccessPattern::Sequential => {
266 match self.current_strategy {
269 PrefetchStrategy::Sequential(_) => {
270 }
272 _ => {
273 let seq_strategy = PrefetchStrategy::Sequential(self.config.prefetch_count);
275
276 if let Some(seq_perf) = self.strategy_performance.get(&seq_strategy) {
277 let current_q = self
278 .strategy_performance
279 .get(&self.current_strategy)
280 .map(|p| p.q_value)
281 .unwrap_or(0.0);
282
283 if seq_perf.q_value > current_q * 1.2 {
284 self.current_strategy = seq_strategy;
286 }
287 } else {
288 self.strategy_performance.insert(
290 seq_strategy,
291 StrategyPerformance {
292 strategy: seq_strategy,
293 usage_count: 0,
294 hit_rate: 0.0,
295 avg_latency_ns: 0.0,
296 last_used: Instant::now(),
297 q_value: 0.2, },
299 );
300
301 if (self.exploration_step % 100) < 50 {
303 self.current_strategy = seq_strategy;
304 }
305 }
306 }
307 }
308 }
309 AccessPattern::Strided(stride) => {
310 let strided_strategy = PrefetchStrategy::Strided {
313 stride,
314 count: self.config.prefetch_count,
315 };
316
317 self.strategy_performance
319 .entry(strided_strategy)
320 .or_insert_with(|| {
321 StrategyPerformance {
322 strategy: strided_strategy,
323 usage_count: 0,
324 hit_rate: 0.0,
325 avg_latency_ns: 0.0,
326 last_used: Instant::now(),
327 q_value: 0.2, }
329 });
330
331 match self.current_strategy {
333 PrefetchStrategy::Strided {
334 stride: current_stride,
335 ..
336 } => {
337 if current_stride != stride && (self.exploration_step % 100) < 70 {
339 self.current_strategy = strided_strategy;
340 }
341 }
342 _ => {
343 let current_q = self
345 .strategy_performance
346 .get(&self.current_strategy)
347 .map(|p| p.q_value)
348 .unwrap_or(0.0);
349
350 if let Some(strided_perf) = self.strategy_performance.get(&strided_strategy)
351 {
352 if strided_perf.q_value > current_q * 1.1
353 || (self.exploration_step % 100) < 30
354 {
355 self.current_strategy = strided_strategy;
356 }
357 } else {
358 if (self.exploration_step % 100) < 40 {
360 self.current_strategy = strided_strategy;
361 }
362 }
363 }
364 }
365 }
366 AccessPattern::Custom => {
367 if let Some(dims) = self.dimensions.clone() {
369 let detected_patterns = self.detect_dimensional_patterns(&dims);
371
372 for pattern_name in detected_patterns {
373 if pattern_name == MATRIX_TRAVERSAL_ROW_MAJOR {
375 let strategy = PrefetchStrategy::Hybrid {
376 sequential: dims[1], pattern: 2,
378 };
379
380 self.strategy_performance
382 .entry(strategy)
383 .or_insert_with(|| {
384 StrategyPerformance {
385 strategy,
386 usage_count: 0,
387 hit_rate: 0.0,
388 avg_latency_ns: 0.0,
389 last_used: Instant::now(),
390 q_value: 0.3, }
392 });
393
394 if (self.exploration_step % 100) < 60 {
396 self.current_strategy = strategy;
397 }
398 } else if pattern_name == MATRIX_TRAVERSAL_COL_MAJOR {
399 let strategy = PrefetchStrategy::Strided {
400 stride: dims[0], count: 3,
402 };
403
404 self.strategy_performance
406 .entry(strategy)
407 .or_insert_with(|| StrategyPerformance {
408 strategy,
409 usage_count: 0,
410 hit_rate: 0.0,
411 avg_latency_ns: 0.0,
412 last_used: Instant::now(),
413 q_value: 0.3,
414 });
415
416 if (self.exploration_step % 100) < 60 {
418 self.current_strategy = strategy;
419 }
420 }
421 }
422 } else {
423 let strategy = PrefetchStrategy::Pattern {
425 windowsize: self.config.min_pattern_length,
426 lookahead: self.config.prefetch_count,
427 };
428
429 self.strategy_performance
431 .entry(strategy)
432 .or_insert_with(|| StrategyPerformance {
433 strategy,
434 usage_count: 0,
435 hit_rate: 0.0,
436 avg_latency_ns: 0.0,
437 last_used: Instant::now(),
438 q_value: 0.2,
439 });
440
441 if (self.exploration_step % 100) < 40 {
443 self.current_strategy = strategy;
444 }
445 }
446 }
447 AccessPattern::Random => {
448 let conservative_q = self
450 .strategy_performance
451 .get(&PrefetchStrategy::Conservative)
452 .map(|p| p.q_value)
453 .unwrap_or(0.1);
454
455 let aggressive_q = self
456 .strategy_performance
457 .get(&PrefetchStrategy::Aggressive)
458 .map(|p| p.q_value)
459 .unwrap_or(0.1);
460
461 if conservative_q > aggressive_q * 1.2 {
462 self.current_strategy = PrefetchStrategy::Conservative;
463 } else if aggressive_q > conservative_q * 1.2 {
464 self.current_strategy = PrefetchStrategy::Aggressive;
465 } else {
466 self.current_strategy = if (self.exploration_step % 100) < 50 {
468 PrefetchStrategy::Conservative
469 } else {
470 PrefetchStrategy::Aggressive
471 };
472 }
473 }
474 }
475 }
476
477 fn detect_dimensional_patterns(&mut self, dimensions: &[usize]) -> Vec<String> {
479 if dimensions.len() < 2 || self.history.len() < 10 {
480 return Vec::new();
481 }
482
483 let mut detected_patterns = Vec::new();
484
485 let flat_indices: Vec<usize> = self.history.iter().map(|(idx__, _, _)| *idx__).collect();
487
488 let mut row_major_matches = 0;
490 for i in 1..flat_indices.len() {
491 if flat_indices[i] == flat_indices[i.saturating_sub(1)] + 1 {
492 row_major_matches += 1;
493 }
494 }
495
496 let mut col_major_matches = 0;
498 let col_stride = dimensions[0]; for i in 1..flat_indices.len() {
500 if flat_indices[i] == flat_indices[i.saturating_sub(1)] + col_stride {
501 col_major_matches += 1;
502 }
503 }
504
505 let total_pairs = flat_indices.len() - 1;
507 let row_major_pct = row_major_matches as f64 / total_pairs as f64;
508 let col_major_pct = col_major_matches as f64 / total_pairs as f64;
509
510 if row_major_pct > 0.6 {
512 detected_patterns.push(MATRIX_TRAVERSAL_ROW_MAJOR.to_string());
513 }
514
515 if col_major_pct > 0.6 {
516 detected_patterns.push(MATRIX_TRAVERSAL_COL_MAJOR.to_string());
517 }
518
519 if self.detect_zigzag_pattern(&flat_indices, dimensions) {
521 detected_patterns.push(ZIGZAG_SCAN.to_string());
522 }
523
524 for pattern in &detected_patterns {
526 self.dimensional_patterns
527 .entry(pattern.clone())
528 .or_default()
529 .push(flat_indices.len());
530 }
531
532 detected_patterns
533 }
534
535 fn detect_zigzag_pattern(&self, indices: &[usize], dimensions: &[usize]) -> bool {
537 if indices.len() < 10 || dimensions.len() < 2 {
538 return false;
539 }
540
541 let row_size = dimensions[1];
542
543 let mut direction_changes = 0;
545 let mut current_direction = if indices.len() >= 2 {
546 if indices[1] > indices[0] {
547 1
548 } else {
549 -1
550 }
551 } else {
552 return false;
553 };
554
555 for _i in 1..indices.len() - 1 {
556 if (indices[_i] % row_size == 0) || (indices[_i] % row_size == row_size - 1) {
558 let next_direction = if indices[_i + 1] > indices[_i] { 1 } else { -1 };
559
560 if next_direction != current_direction {
561 direction_changes += 1;
562 current_direction = next_direction;
563 }
564 }
565 }
566
567 let expected_changes = indices.len() / row_size;
569 direction_changes >= expected_changes / 2
570 }
571
572 fn detect_pattern(&mut self) {
574 if self.history.len() < self.config.min_pattern_length {
575 self.current_pattern = AccessPattern::Random;
577 return;
578 }
579
580 let indices: Vec<usize> = self.history.iter().map(|(idx__, _, _)| *idx__).collect();
582
583 let mut is_sequential = true;
585 for i in 1..indices.len() {
586 if indices[i] != indices[i.saturating_sub(1)] + 1 {
587 is_sequential = false;
588 break;
589 }
590 }
591
592 if is_sequential {
593 self.current_pattern = AccessPattern::Sequential;
594 self.update_strategy_from_pattern();
595 return;
596 }
597
598 if indices.len() >= 3 {
600 let mut possible_strides = Vec::new();
601
602 for windowsize in 2..=std::cmp::min(indices.len() / 2, 10) {
604 let mut stride_counts = HashMap::new();
605
606 for i in windowsize..indices.len() {
607 let stride = match indices[i].checked_sub(indices[i - windowsize]) {
608 Some(s) => s / windowsize,
609 None => continue,
610 };
611
612 *stride_counts.entry(stride).or_insert(0) += 1;
613 }
614
615 if let Some((stride, count)) =
617 stride_counts.into_iter().max_by_key(|(_, count)| *count)
618 {
619 let threshold = (indices.len() - windowsize) / 2;
621 if count >= threshold {
622 possible_strides.push((stride, count, windowsize));
623 }
624 }
625 }
626
627 if let Some((stride__, _, _)) = possible_strides
629 .into_iter()
630 .max_by_key(|(_, count_, _)| *count_)
631 {
632 if stride__ > 0 {
633 self.current_pattern = AccessPattern::Strided(stride__);
634 self.stride = Some(stride__);
635 self.update_strategy_from_pattern();
636 return;
637 }
638 }
639 }
640
641 if let Some(dims) = self.dimensions.clone() {
643 if !self.detect_dimensional_patterns(&dims).is_empty() {
644 self.current_pattern = AccessPattern::Custom;
645 self.update_strategy_from_pattern();
646 return;
647 }
648 }
649
650 self.current_pattern = AccessPattern::Random;
652
653 self.update_strategy_from_pattern();
655 }
656
657 pub fn get_blocks_to_prefetch(&self, count: usize) -> Vec<usize> {
659 if self.history.is_empty() {
660 return Vec::new();
661 }
662
663 let latest = self.history.back().expect("Operation failed").0;
664
665 match self.current_strategy {
666 PrefetchStrategy::Sequential(n) => {
667 let prefetch_count = std::cmp::min(n, count);
669 (1..=prefetch_count).map(|i| latest + i).collect()
670 }
671 PrefetchStrategy::Strided { stride, count: n } => {
672 let prefetch_count = std::cmp::min(n, count);
674 (1..=prefetch_count).map(|i| latest + stride * i).collect()
675 }
676 PrefetchStrategy::Pattern {
677 windowsize: _,
678 lookahead,
679 } => {
680 self.predict_from_pattern(latest, std::cmp::min(lookahead, count))
682 }
683 PrefetchStrategy::Hybrid {
684 sequential,
685 pattern,
686 } => {
687 let mut blocks = Vec::new();
689
690 for i in 1..=sequential {
692 blocks.push(latest + i);
693 }
694
695 blocks.extend(self.predict_from_pattern(
697 latest,
698 std::cmp::min(pattern, count.saturating_sub(sequential)),
699 ));
700
701 blocks
703 .into_iter()
704 .collect::<HashSet<_>>()
705 .into_iter()
706 .collect()
707 }
708 PrefetchStrategy::Conservative => {
709 vec![latest + 1]
711 }
712 PrefetchStrategy::Aggressive => {
713 let mut blocks = Vec::with_capacity(count);
715
716 for i in 1..=count / 2 {
718 blocks.push(latest + i);
719 }
720
721 if let Some(stride) = self.stride {
723 blocks.push(latest + stride);
724 if stride > 1 && blocks.len() < count {
725 blocks.push(latest + stride * 2);
726 }
727 }
728
729 let remaining = count.saturating_sub(blocks.len());
731 if remaining > 0 {
732 blocks.extend(self.predict_from_pattern(latest, remaining));
733 }
734
735 blocks
737 .into_iter()
738 .collect::<HashSet<_>>()
739 .into_iter()
740 .collect()
741 }
742 PrefetchStrategy::None => {
743 Vec::new()
745 }
746 }
747 }
748
749 fn predict_from_pattern(&self, latest: usize, count: usize) -> Vec<usize> {
751 let history_window = std::cmp::min(8, self.history.len());
753 let mut pattern = Vec::with_capacity(history_window);
754
755 for i in 0..history_window {
756 if let Some((block_idx, _, _)) = self.history.get(self.history.len() - 1 - i) {
757 pattern.push(*block_idx);
758 }
759 }
760
761 if pattern.is_empty() {
762 return vec![latest + 1]; }
764
765 let mut predictions = Vec::new();
767 let mut occurrences = Vec::new();
768
769 for i in 0..self.history.len().saturating_sub(pattern.len()) {
770 let mut matches = true;
771 for (j, &pattern_idx) in pattern.iter().enumerate() {
772 if let Some((block_idx, _, _)) = self.history.get(i + j) {
773 if *block_idx != pattern_idx {
774 matches = false;
775 break;
776 }
777 } else {
778 matches = false;
779 break;
780 }
781 }
782
783 if matches {
784 occurrences.push(i);
785 }
786 }
787
788 for &occurrence_idx in &occurrences {
790 if occurrence_idx + pattern.len() < self.history.len() {
791 if let Some((next_block_idx, _, _)) =
792 self.history.get(occurrence_idx + pattern.len())
793 {
794 predictions.push(*next_block_idx);
795 }
796 }
797 }
798
799 if predictions.is_empty() && pattern.len() >= 2 {
801 if let Some(stride) = pattern[0].checked_sub(pattern[1]) {
802 predictions.push(latest + stride);
803 }
804 }
805
806 predictions
808 .into_iter()
809 .collect::<HashSet<_>>()
810 .into_iter()
811 .take(count)
812 .collect()
813 }
814}
815
816impl AccessPatternTracker for AdaptivePatternTracker {
817 fn record_access(&mut self, blockidx: usize) {
818 let now = Instant::now();
820 let access_time = if let Some((_, last_time_, _)) = self.history.back() {
821 now.duration_since(*last_time_)
822 } else {
823 Duration::from_nanos(0)
824 };
825
826 self.history.push_back((blockidx, now, access_time));
828
829 if self.history.len() > self.config.history_size {
830 self.history.pop_front();
831 }
832
833 if self.history.len() >= self.config.min_pattern_length {
835 self.detect_pattern();
836 }
837 }
838
839 fn predict_next_blocks(&self, count: usize) -> Vec<usize> {
840 self.get_blocks_to_prefetch(count)
841 }
842
843 fn current_pattern(&self) -> AccessPattern {
844 self.current_pattern
845 }
846
847 fn clear_history(&mut self) {
848 self.history.clear();
849 self.current_pattern = AccessPattern::Random;
850 self.stride = None;
851 }
852}
853
854pub struct PatternTrackerFactory;
856
857impl PatternTrackerFactory {
858 pub fn create_tracker(
860 tracker_type: &str,
861 config: PrefetchConfig,
862 ) -> Box<dyn AccessPatternTracker + Send + Sync> {
863 match tracker_type {
864 "adaptive" => Box::new(AdaptivePatternTracker::new(config)),
865 _ => Box::new(super::prefetch::BlockAccessTracker::new(config)),
866 }
867 }
868}
869
870#[derive(Debug, Clone)]
872pub struct AdaptivePrefetchConfig {
873 pub base: PrefetchConfig,
875
876 pub use_adaptive_tracker: bool,
878
879 pub enable_learning: bool,
881
882 pub dimensions: Option<Vec<usize>>,
884
885 pub learningrate: f64,
887
888 pub evaluation_interval: Duration,
890}
891
892impl Default for AdaptivePrefetchConfig {
893 fn default() -> Self {
894 Self {
895 base: PrefetchConfig::default(),
896 use_adaptive_tracker: true,
897 enable_learning: true,
898 dimensions: None,
899 learningrate: LEARNING_RATE,
900 evaluation_interval: STRATEGY_TEST_DURATION,
901 }
902 }
903}
904
905#[derive(Debug, Clone)]
907pub struct AdaptivePrefetchConfigBuilder {
908 config: AdaptivePrefetchConfig,
909}
910
911impl AdaptivePrefetchConfigBuilder {
912 pub fn new() -> Self {
914 Self {
915 config: AdaptivePrefetchConfig::default(),
916 }
917 }
918
919 pub const fn enabled(mut self, enabled: bool) -> Self {
921 self.config.base.enabled = enabled;
922 self
923 }
924
925 pub const fn prefetch_count(mut self, count: usize) -> Self {
927 self.config.base.prefetch_count = count;
928 self
929 }
930
931 pub const fn history_size(mut self, size: usize) -> Self {
933 self.config.base.history_size = size;
934 self
935 }
936
937 pub const fn min_pattern_length(mut self, length: usize) -> Self {
939 self.config.base.min_pattern_length = length;
940 self
941 }
942
943 pub const fn prefetch(mut self, asyncprefetch: bool) -> Self {
945 self.config.base.async_prefetch = asyncprefetch;
946 self
947 }
948
949 pub const fn prefetch_timeout(mut self, timeout: Duration) -> Self {
951 self.config.base.prefetch_timeout = timeout;
952 self
953 }
954
955 pub const fn adaptive(mut self, useadaptive: bool) -> Self {
957 self.config.use_adaptive_tracker = useadaptive;
958 self
959 }
960
961 pub const fn enable_learning(mut self, enable: bool) -> Self {
963 self.config.enable_learning = enable;
964 self
965 }
966
967 pub fn dimensions(mut self, dimensions: Vec<usize>) -> Self {
969 self.config.dimensions = Some(dimensions);
970 self
971 }
972
973 pub const fn learningrate(mut self, rate: f64) -> Self {
975 self.config.learningrate = rate;
976 self
977 }
978
979 pub const fn evaluation_interval(mut self, interval: Duration) -> Self {
981 self.config.evaluation_interval = interval;
982 self
983 }
984
985 pub fn build(self) -> AdaptivePrefetchConfig {
987 self.config
988 }
989}
990
991impl Default for AdaptivePrefetchConfigBuilder {
992 fn default() -> Self {
993 Self::new()
994 }
995}
996
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten consecutive block accesses must be classified as sequential and the block right
    /// after the last access must be among the predictions.
    #[test]
    fn test_adaptive_pattern_detection_sequential() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            ..Default::default()
        };

        let mut tracker = AdaptivePatternTracker::new(config);

        for i in 0..10 {
            tracker.record_access(i);
        }

        assert_eq!(tracker.current_pattern(), AccessPattern::Sequential);

        let predictions = tracker.predict_next_blocks(3);
        assert!(!predictions.is_empty());

        assert!(predictions.contains(&10));
    }

    /// Accesses spaced three blocks apart must be classified as stride 3 and the next block
    /// on that stride must be among the predictions.
    #[test]
    fn test_adaptive_pattern_detection_strided() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            ..Default::default()
        };

        let mut tracker = AdaptivePatternTracker::new(config);

        for i in (0..30).step_by(3) {
            tracker.record_access(i);
        }

        assert_eq!(tracker.current_pattern(), AccessPattern::Strided(3));

        let predictions = tracker.predict_next_blocks(3);
        assert!(!predictions.is_empty());

        assert!(predictions.contains(&30));
    }

    /// Repeated accesses to the same block are neither sequential nor strided; after a
    /// statistics report the tracker must still hold a strategy that produces predictions.
    #[test]
    fn test_adaptive_strategy_selection() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            ..Default::default()
        };

        let mut tracker = AdaptivePatternTracker::new(config);

        // Two batches (five and four accesses) of the same block.
        for _ in 0..5 {
            tracker.record_access(0);
        }

        for _ in (10..30).step_by(5) {
            tracker.record_access(0);
        }

        let stats = PrefetchStats {
            prefetch_count: 10,
            prefetch_hits: 8,
            prefetch_misses: 2,
            hit_rate: 0.8,
        };

        // Report the statistics so the active strategy's record is actually updated; the
        // test period has not elapsed, so this does not trigger a strategy change.
        tracker.ns(stats, 1_000.0);

        let strategy = tracker.current_strategy;
        assert!(matches!(
            strategy,
            PrefetchStrategy::Sequential(_)
                | PrefetchStrategy::Strided { .. }
                | PrefetchStrategy::Conservative
                | PrefetchStrategy::Aggressive
        ));

        let predictions = tracker.predict_next_blocks(3);
        assert!(!predictions.is_empty());
    }

    /// Row-by-row and column-by-column walks over a 5x5 grid must be reported as row-major
    /// and column-major traversals respectively.
    #[test]
    fn test_dimensional_pattern_detection() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            history_size: 50,
            ..Default::default()
        };

        let mut tracker = AdaptivePatternTracker::new(config);

        tracker.set_dimensions(vec![5, 5]);

        // Row by row.
        for i in 0..5 {
            for j in 0..5 {
                tracker.record_access(i * 5 + j);
            }
        }

        let dimensions = vec![5, 5];
        let patterns = tracker.detect_dimensional_patterns(&dimensions);
        assert!(!patterns.is_empty());
        assert!(patterns.contains(&MATRIX_TRAVERSAL_ROW_MAJOR.to_string()));

        tracker.clear_history();

        // Column by column.
        for j in 0..5 {
            for i in 0..5 {
                tracker.record_access(i * 5 + j);
            }
        }

        let patterns = tracker.detect_dimensional_patterns(&dimensions);
        assert!(!patterns.is_empty());
        assert!(patterns.contains(&MATRIX_TRAVERSAL_COL_MAJOR.to_string()));
    }

    /// Four rows of a 5x5 grid scanned alternately left-to-right and right-to-left must be
    /// recognised as a zigzag scan.
    #[test]
    fn test_zigzag_pattern_detection() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            history_size: 50,
            ..Default::default()
        };

        let mut tracker = AdaptivePatternTracker::new(config);

        tracker.set_dimensions(vec![5, 5]);

        // Row 0 forwards, row 1 backwards, row 2 forwards, row 3 backwards.
        for j in 0..5 {
            tracker.record_access(j);
        }
        for j in (0..5).rev() {
            tracker.record_access(5 + j);
        }
        for j in 0..5 {
            tracker.record_access(10 + j);
        }
        for j in (0..5).rev() {
            tracker.record_access(15 + j);
        }

        let indices: Vec<usize> = tracker.history.iter().map(|(idx, _, _)| *idx).collect();

        let dimensions = vec![5, 5];
        assert!(tracker.detect_zigzag_pattern(&indices, &dimensions));
    }
}