// File: scirs2_core/memory_efficient/adaptive_prefetch.rs

//! Adaptive prefetching strategy that dynamically learns from access patterns.
//!
//! This module provides an enhanced prefetching system that uses a lightweight
//! reinforcement-learning scheme (epsilon-greedy selection over per-strategy
//! Q-values) to dynamically adjust its prefetching strategy based on observed
//! access patterns and performance metrics.
6
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::time::{Duration, Instant};
9
10use super::prefetch::{AccessPattern, AccessPatternTracker, PrefetchConfig, PrefetchStats};
11
/// Maximum number of strategies to try during the exploration phase.
///
/// Exploration ends once the summed usage count of all tracked strategies reaches
/// twice this value.
const MAX_EXPLORATION_STRATEGIES: usize = 5;

/// How long each strategy stays active before the tracker reconsiders its choice.
const STRATEGY_TEST_DURATION: Duration = Duration::from_secs(60);

// --- Reinforcement-learning parameters ---

/// Weight given to a new reward when updating a strategy's Q-value
/// (`q = (1 - LEARNING_RATE) * q + LEARNING_RATE * reward`).
const LEARNING_RATE: f64 = 0.1;

/// Discount factor for future rewards.
// Kept for a future multi-step Q-update; the current update is a one-step moving
// average with no look-ahead term, so the constant is not read yet.
#[allow(dead_code)]
const DISCOUNT_FACTOR: f64 = 0.9;

/// Initial exploration rate (epsilon) of the epsilon-greedy strategy selection.
const EXPLORATION_RATE_INITIAL: f64 = 0.3;

/// Multiplicative decay applied to the exploration rate on every strategy change.
const EXPLORATION_RATE_DECAY: f64 = 0.995;

// --- Names of dimension-aware access patterns ---

/// Pattern name: the matrix is walked contiguously, one row after another.
const MATRIX_TRAVERSAL_ROW_MAJOR: &str = "MATRIX_TRAVERSAL_ROW_MAJOR";
/// Pattern name: the matrix is walked down its columns (one row length per step).
const MATRIX_TRAVERSAL_COL_MAJOR: &str = "MATRIX_TRAVERSAL_COL_MAJOR";
/// Pattern name: rows are walked in alternating directions.
const ZIGZAG_SCAN: &str = "ZIGZAG_SCAN";
29
/// Prefetching strategies that the adaptive tracker can select between at runtime.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PrefetchStrategy {
    /// Read ahead a fixed number of consecutive blocks.
    Sequential(usize),

    /// Read ahead blocks spaced a fixed distance apart.
    Strided {
        /// Distance, in blocks, between successive prefetched blocks.
        stride: usize,
        /// How many blocks to prefetch.
        count: usize,
    },

    /// Prefetch blocks predicted from a repeating pattern in the access history.
    Pattern {
        /// Size of the window describing the pattern.
        windowsize: usize,
        /// Maximum number of predicted blocks to prefetch.
        lookahead: usize,
    },

    /// Combination of sequential read-ahead and pattern-based prediction.
    Hybrid {
        /// Number of consecutive blocks to read ahead.
        sequential: usize,
        /// Number of pattern-predicted blocks to add on top.
        pattern: usize,
    },

    /// Minimal prefetching.
    Conservative,

    /// Prefetch many blocks at once.
    Aggressive,

    /// Disable prefetching entirely; serves as the comparison baseline.
    None,
}

impl Default for PrefetchStrategy {
    /// A light sequential read-ahead of two blocks.
    fn default() -> Self {
        Self::Sequential(2)
    }
}
60
61/// Performance metrics for a particular prefetching strategy.
62#[derive(Debug, Clone)]
63struct StrategyPerformance {
64    /// The strategy being evaluated
65    strategy: PrefetchStrategy,
66
67    /// Number of times this strategy has been used
68    usage_count: usize,
69
70    /// Cache hit rate when using this strategy
71    hit_rate: f64,
72
73    /// Average latency for block access with this strategy
74    avg_latency_ns: f64,
75
76    /// Time when this strategy was last used
77    last_used: Instant,
78
79    /// Q-value for reinforcement learning
80    q_value: f64,
81}
82
83/// Advanced access pattern detector with dynamic learning.
84#[derive(Debug)]
85pub struct AdaptivePatternTracker {
86    /// Base configuration
87    config: PrefetchConfig,
88
89    /// History of accessed blocks
90    history: VecDeque<(usize, Instant, Duration)>, // (block_idx, timestamp, access_time)
91
92    /// Current detected pattern
93    current_pattern: AccessPattern,
94
95    /// For strided patterns, the stride value
96    stride: Option<usize>,
97
98    /// Performance of different strategies
99    strategy_performance: HashMap<PrefetchStrategy, StrategyPerformance>,
100
101    /// Current active strategy
102    current_strategy: PrefetchStrategy,
103
104    /// Time when we should try another strategy
105    next_strategy_change: Instant,
106
107    /// Whether we're in exploration or exploitation phase
108    exploring: bool,
109
110    /// Current exploration rate (epsilon) for epsilon-greedy strategy
111    exploration_rate: f64,
112
113    /// Matrix dimension, if known
114    dimensions: Option<Vec<usize>>,
115
116    /// Patterns with dimension-aware context
117    dimensional_patterns: HashMap<String, Vec<usize>>,
118
119    /// Step counter for deterministic exploration
120    exploration_step: usize,
121}
122
123impl AdaptivePatternTracker {
124    /// Create a new adaptive pattern tracker.
125    pub fn new(config: PrefetchConfig) -> Self {
126        let mut strategies = HashMap::new();
127
128        // Store history_size before moving config
129        let history_size = config.history_size;
130
131        // Initialize default strategies with neutral Q-values
132        for strategy in [
133            PrefetchStrategy::Sequential(2),
134            PrefetchStrategy::Sequential(5),
135            PrefetchStrategy::Strided {
136                stride: 10,
137                count: 3,
138            },
139            PrefetchStrategy::Conservative,
140            PrefetchStrategy::Aggressive,
141            PrefetchStrategy::None,
142        ] {
143            strategies.insert(
144                strategy,
145                StrategyPerformance {
146                    strategy,
147                    usage_count: 0,
148                    hit_rate: 0.0,
149                    avg_latency_ns: 0.0,
150                    last_used: Instant::now(),
151                    q_value: 0.0,
152                },
153            );
154        }
155
156        Self {
157            config,
158            history: VecDeque::with_capacity(history_size),
159            current_pattern: AccessPattern::Random,
160            stride: None,
161            strategy_performance: strategies,
162            current_strategy: PrefetchStrategy::default(),
163            next_strategy_change: Instant::now() + STRATEGY_TEST_DURATION,
164            exploring: true,
165            exploration_rate: EXPLORATION_RATE_INITIAL,
166            dimensions: None,
167            dimensional_patterns: HashMap::new(),
168            exploration_step: 0,
169        }
170    }
171
172    /// Set the array dimensions for better pattern detection.
173    pub fn set_dimensions(&mut self, dimensions: Vec<usize>) {
174        self.dimensions = Some(dimensions);
175    }
176
177    /// Update the performance metrics for the current strategy.
178    pub fn ns(&mut self, stats: PrefetchStats, avg_latencyns: f64) {
179        if let Some(perf) = self.strategy_performance.get_mut(&self.current_strategy) {
180            // Update the performance metrics
181            perf.usage_count += 1;
182            perf.hit_rate = stats.hit_rate;
183            perf.avg_latency_ns = avg_latencyns;
184            perf.last_used = Instant::now();
185
186            // Calculate reward (higher hit rate and lower latency are better)
187            let hit_rate_reward = stats.hit_rate;
188            let latency_factor = if perf.avg_latency_ns > 0.0 {
189                1.0 / (1.0 + perf.avg_latency_ns / 1_000_000.0) // Convert to ms and normalize
190            } else {
191                0.0
192            };
193
194            let reward = hit_rate_reward * 0.7 + latency_factor * 0.3;
195
196            // Update Q-value with simple Q-learning
197            perf.q_value = (1.0 - LEARNING_RATE) * perf.q_value + LEARNING_RATE * reward;
198        }
199
200        // Check if it's time to select a new strategy
201        if Instant::now() >= self.next_strategy_change {
202            self.select_next_strategy();
203        }
204    }
205
206    /// Select the next strategy to use.
207    fn select_next_strategy(&mut self) {
208        // Increment exploration step
209        self.exploration_step += 1;
210
211        // Decay exploration rate
212        self.exploration_rate *= EXPLORATION_RATE_DECAY;
213
214        // Decide whether to explore or exploit
215        if self.exploring
216            || (self.exploration_step % 100) < (self.exploration_rate * 100.0) as usize
217        {
218            // Exploration phase: try different strategies
219            let available_strategies: Vec<PrefetchStrategy> =
220                self.strategy_performance.keys().copied().collect();
221
222            // Select a random strategy, but avoid the current one
223            let candidates: Vec<PrefetchStrategy> = available_strategies
224                .into_iter()
225                .filter(|&s| s != self.current_strategy)
226                .collect();
227
228            if !candidates.is_empty() {
229                let idx = self.exploration_step % candidates.len();
230                self.current_strategy = candidates[idx];
231            }
232
233            // Check if we should move to exploitation phase
234            let total_usage: usize = self
235                .strategy_performance
236                .values()
237                .map(|p| p.usage_count)
238                .sum();
239
240            if total_usage >= MAX_EXPLORATION_STRATEGIES * 2 {
241                self.exploring = false;
242            }
243        } else {
244            // Exploitation phase: choose the strategy with the highest Q-value
245            let best_strategy = self
246                .strategy_performance
247                .values()
248                .max_by(|a, b| a.q_value.partial_cmp(&b.q_value).expect("Operation failed"))
249                .map(|p| p.strategy)
250                .unwrap_or_default();
251
252            self.current_strategy = best_strategy;
253        }
254
255        // Set the next time to change strategies
256        self.next_strategy_change = Instant::now() + STRATEGY_TEST_DURATION;
257
258        // Update the strategy if it's based on detected pattern
259        self.update_strategy_from_pattern();
260    }
261
262    /// Update strategy based on the current detected pattern.
263    fn update_strategy_from_pattern(&mut self) {
264        match self.current_pattern {
265            AccessPattern::Sequential => {
266                // If pattern is sequential but we're not using a sequential strategy,
267                // consider switching to a sequential strategy
268                match self.current_strategy {
269                    PrefetchStrategy::Sequential(_) => {
270                        // Already using sequential strategy, nothing to do
271                    }
272                    _ => {
273                        // Consider switching to sequential, but respect the Q-values
274                        let seq_strategy = PrefetchStrategy::Sequential(self.config.prefetch_count);
275
276                        if let Some(seq_perf) = self.strategy_performance.get(&seq_strategy) {
277                            let current_q = self
278                                .strategy_performance
279                                .get(&self.current_strategy)
280                                .map(|p| p.q_value)
281                                .unwrap_or(0.0);
282
283                            if seq_perf.q_value > current_q * 1.2 {
284                                // Sequential is significantly better, switch to it
285                                self.current_strategy = seq_strategy;
286                            }
287                        } else {
288                            // We don't have data on sequential yet, add it and possibly switch
289                            self.strategy_performance.insert(
290                                seq_strategy,
291                                StrategyPerformance {
292                                    strategy: seq_strategy,
293                                    usage_count: 0,
294                                    hit_rate: 0.0,
295                                    avg_latency_ns: 0.0,
296                                    last_used: Instant::now(),
297                                    q_value: 0.2, // Slight bias towards sequential when detected
298                                },
299                            );
300
301                            // Occasionally switch to it for exploration
302                            if (self.exploration_step % 100) < 50 {
303                                self.current_strategy = seq_strategy;
304                            }
305                        }
306                    }
307                }
308            }
309            AccessPattern::Strided(stride) => {
310                // If pattern is strided but we're not using a strided strategy,
311                // consider switching to a strided strategy
312                let strided_strategy = PrefetchStrategy::Strided {
313                    stride,
314                    count: self.config.prefetch_count,
315                };
316
317                // Add or update this strategy in our performance map
318                self.strategy_performance
319                    .entry(strided_strategy)
320                    .or_insert_with(|| {
321                        StrategyPerformance {
322                            strategy: strided_strategy,
323                            usage_count: 0,
324                            hit_rate: 0.0,
325                            avg_latency_ns: 0.0,
326                            last_used: Instant::now(),
327                            q_value: 0.2, // Slight bias when detected
328                        }
329                    });
330
331                // Consider switching to this strided strategy
332                match self.current_strategy {
333                    PrefetchStrategy::Strided {
334                        stride: current_stride,
335                        ..
336                    } => {
337                        // Already using strided strategy, maybe update the stride
338                        if current_stride != stride && (self.exploration_step % 100) < 70 {
339                            self.current_strategy = strided_strategy;
340                        }
341                    }
342                    _ => {
343                        // Not using strided strategy, consider switching
344                        let current_q = self
345                            .strategy_performance
346                            .get(&self.current_strategy)
347                            .map(|p| p.q_value)
348                            .unwrap_or(0.0);
349
350                        if let Some(strided_perf) = self.strategy_performance.get(&strided_strategy)
351                        {
352                            if strided_perf.q_value > current_q * 1.1
353                                || (self.exploration_step % 100) < 30
354                            {
355                                self.current_strategy = strided_strategy;
356                            }
357                        } else {
358                            // Occasionally switch to it for exploration
359                            if (self.exploration_step % 100) < 40 {
360                                self.current_strategy = strided_strategy;
361                            }
362                        }
363                    }
364                }
365            }
366            AccessPattern::Custom => {
367                // If we have dimensional information, try to detect specific patterns
368                if let Some(dims) = self.dimensions.clone() {
369                    // Create pattern-specific strategies
370                    let detected_patterns = self.detect_dimensional_patterns(&dims);
371
372                    for pattern_name in detected_patterns {
373                        // For matrix traversal, use hybrid strategy
374                        if pattern_name == MATRIX_TRAVERSAL_ROW_MAJOR {
375                            let strategy = PrefetchStrategy::Hybrid {
376                                sequential: dims[1], // Row length
377                                pattern: 2,
378                            };
379
380                            // Add this strategy if it doesn't exist
381                            self.strategy_performance
382                                .entry(strategy)
383                                .or_insert_with(|| {
384                                    StrategyPerformance {
385                                        strategy,
386                                        usage_count: 0,
387                                        hit_rate: 0.0,
388                                        avg_latency_ns: 0.0,
389                                        last_used: Instant::now(),
390                                        q_value: 0.3, // Higher bias for dimensional patterns
391                                    }
392                                });
393
394                            // Consider switching to this strategy
395                            if (self.exploration_step % 100) < 60 {
396                                self.current_strategy = strategy;
397                            }
398                        } else if pattern_name == MATRIX_TRAVERSAL_COL_MAJOR {
399                            let strategy = PrefetchStrategy::Strided {
400                                stride: dims[0], // Column stride
401                                count: 3,
402                            };
403
404                            // Add this strategy if it doesn't exist
405                            self.strategy_performance
406                                .entry(strategy)
407                                .or_insert_with(|| StrategyPerformance {
408                                    strategy,
409                                    usage_count: 0,
410                                    hit_rate: 0.0,
411                                    avg_latency_ns: 0.0,
412                                    last_used: Instant::now(),
413                                    q_value: 0.3,
414                                });
415
416                            // Consider switching to this strategy
417                            if (self.exploration_step % 100) < 60 {
418                                self.current_strategy = strategy;
419                            }
420                        }
421                    }
422                } else {
423                    // Without dimensional information, use pattern-based strategy
424                    let strategy = PrefetchStrategy::Pattern {
425                        windowsize: self.config.min_pattern_length,
426                        lookahead: self.config.prefetch_count,
427                    };
428
429                    // Add this strategy if it doesn't exist
430                    self.strategy_performance
431                        .entry(strategy)
432                        .or_insert_with(|| StrategyPerformance {
433                            strategy,
434                            usage_count: 0,
435                            hit_rate: 0.0,
436                            avg_latency_ns: 0.0,
437                            last_used: Instant::now(),
438                            q_value: 0.2,
439                        });
440
441                    // Occasionally switch to pattern-based strategy
442                    if (self.exploration_step % 100) < 40 {
443                        self.current_strategy = strategy;
444                    }
445                }
446            }
447            AccessPattern::Random => {
448                // For random access, favor conservative or aggressive based on past performance
449                let conservative_q = self
450                    .strategy_performance
451                    .get(&PrefetchStrategy::Conservative)
452                    .map(|p| p.q_value)
453                    .unwrap_or(0.1);
454
455                let aggressive_q = self
456                    .strategy_performance
457                    .get(&PrefetchStrategy::Aggressive)
458                    .map(|p| p.q_value)
459                    .unwrap_or(0.1);
460
461                if conservative_q > aggressive_q * 1.2 {
462                    self.current_strategy = PrefetchStrategy::Conservative;
463                } else if aggressive_q > conservative_q * 1.2 {
464                    self.current_strategy = PrefetchStrategy::Aggressive;
465                } else {
466                    // They're similar, choose randomly
467                    self.current_strategy = if (self.exploration_step % 100) < 50 {
468                        PrefetchStrategy::Conservative
469                    } else {
470                        PrefetchStrategy::Aggressive
471                    };
472                }
473            }
474        }
475    }
476
477    /// Detect dimensional patterns in the access history.
478    fn detect_dimensional_patterns(&mut self, dimensions: &[usize]) -> Vec<String> {
479        if dimensions.len() < 2 || self.history.len() < 10 {
480            return Vec::new();
481        }
482
483        let mut detected_patterns = Vec::new();
484
485        // Get the flat indices from history
486        let flat_indices: Vec<usize> = self.history.iter().map(|(idx__, _, _)| *idx__).collect();
487
488        // Check for row-major traversal (adjacent elements in a row)
489        let mut row_major_matches = 0;
490        for i in 1..flat_indices.len() {
491            if flat_indices[i] == flat_indices[i.saturating_sub(1)] + 1 {
492                row_major_matches += 1;
493            }
494        }
495
496        // Check for column-major traversal (adjacent elements in a column)
497        let mut col_major_matches = 0;
498        let col_stride = dimensions[0]; // For 2D array, stride between columns
499        for i in 1..flat_indices.len() {
500            if flat_indices[i] == flat_indices[i.saturating_sub(1)] + col_stride {
501                col_major_matches += 1;
502            }
503        }
504
505        // Calculate match percentages
506        let total_pairs = flat_indices.len() - 1;
507        let row_major_pct = row_major_matches as f64 / total_pairs as f64;
508        let col_major_pct = col_major_matches as f64 / total_pairs as f64;
509
510        // Detect patterns if they match a significant portion of the history
511        if row_major_pct > 0.6 {
512            detected_patterns.push(MATRIX_TRAVERSAL_ROW_MAJOR.to_string());
513        }
514
515        if col_major_pct > 0.6 {
516            detected_patterns.push(MATRIX_TRAVERSAL_COL_MAJOR.to_string());
517        }
518
519        // Try to detect zigzag pattern (alternating row directions)
520        if self.detect_zigzag_pattern(&flat_indices, dimensions) {
521            detected_patterns.push(ZIGZAG_SCAN.to_string());
522        }
523
524        // Keep track of dimensional patterns
525        for pattern in &detected_patterns {
526            self.dimensional_patterns
527                .entry(pattern.clone())
528                .or_default()
529                .push(flat_indices.len());
530        }
531
532        detected_patterns
533    }
534
535    /// Detect zigzag pattern (alternating row directions).
536    fn detect_zigzag_pattern(&self, indices: &[usize], dimensions: &[usize]) -> bool {
537        if indices.len() < 10 || dimensions.len() < 2 {
538            return false;
539        }
540
541        let row_size = dimensions[1];
542
543        // Try to detect changes in direction at row boundaries
544        let mut direction_changes = 0;
545        let mut current_direction = if indices.len() >= 2 {
546            if indices[1] > indices[0] {
547                1
548            } else {
549                -1
550            }
551        } else {
552            return false;
553        };
554
555        for _i in 1..indices.len() - 1 {
556            // Check if we're at a potential row boundary
557            if (indices[_i] % row_size == 0) || (indices[_i] % row_size == row_size - 1) {
558                let next_direction = if indices[_i + 1] > indices[_i] { 1 } else { -1 };
559
560                if next_direction != current_direction {
561                    direction_changes += 1;
562                    current_direction = next_direction;
563                }
564            }
565        }
566
567        // Check if there are enough direction changes to indicate a zigzag pattern
568        let expected_changes = indices.len() / row_size;
569        direction_changes >= expected_changes / 2
570    }
571
572    /// Detect the access pattern based on the history.
573    fn detect_pattern(&mut self) {
574        if self.history.len() < self.config.min_pattern_length {
575            // Not enough history to detect a pattern
576            self.current_pattern = AccessPattern::Random;
577            return;
578        }
579
580        // Extract just the block indices from history
581        let indices: Vec<usize> = self.history.iter().map(|(idx__, _, _)| *idx__).collect();
582
583        // Check for sequential access
584        let mut is_sequential = true;
585        for i in 1..indices.len() {
586            if indices[i] != indices[i.saturating_sub(1)] + 1 {
587                is_sequential = false;
588                break;
589            }
590        }
591
592        if is_sequential {
593            self.current_pattern = AccessPattern::Sequential;
594            self.update_strategy_from_pattern();
595            return;
596        }
597
598        // Check for strided access
599        if indices.len() >= 3 {
600            let mut possible_strides = Vec::new();
601
602            // Calculate potential strides
603            for windowsize in 2..=std::cmp::min(indices.len() / 2, 10) {
604                let mut stride_counts = HashMap::new();
605
606                for i in windowsize..indices.len() {
607                    let stride = match indices[i].checked_sub(indices[i - windowsize]) {
608                        Some(s) => s / windowsize,
609                        None => continue,
610                    };
611
612                    *stride_counts.entry(stride).or_insert(0) += 1;
613                }
614
615                // Find the most common stride
616                if let Some((stride, count)) =
617                    stride_counts.into_iter().max_by_key(|(_, count)| *count)
618                {
619                    // Check if this stride appears enough times to be significant
620                    let threshold = (indices.len() - windowsize) / 2;
621                    if count >= threshold {
622                        possible_strides.push((stride, count, windowsize));
623                    }
624                }
625            }
626
627            // Choose the stride with the highest count
628            if let Some((stride__, _, _)) = possible_strides
629                .into_iter()
630                .max_by_key(|(_, count_, _)| *count_)
631            {
632                if stride__ > 0 {
633                    self.current_pattern = AccessPattern::Strided(stride__);
634                    self.stride = Some(stride__);
635                    self.update_strategy_from_pattern();
636                    return;
637                }
638            }
639        }
640
641        // Check for custom dimensional patterns
642        if let Some(dims) = self.dimensions.clone() {
643            if !self.detect_dimensional_patterns(&dims).is_empty() {
644                self.current_pattern = AccessPattern::Custom;
645                self.update_strategy_from_pattern();
646                return;
647            }
648        }
649
650        // No regular pattern detected
651        self.current_pattern = AccessPattern::Random;
652
653        // Update strategy based on detected pattern
654        self.update_strategy_from_pattern();
655    }
656
657    /// Get the blocks to prefetch based on the current strategy.
658    pub fn get_blocks_to_prefetch(&self, count: usize) -> Vec<usize> {
659        if self.history.is_empty() {
660            return Vec::new();
661        }
662
663        let latest = self.history.back().expect("Operation failed").0;
664
665        match self.current_strategy {
666            PrefetchStrategy::Sequential(n) => {
667                // Prefetch the next n blocks sequentially
668                let prefetch_count = std::cmp::min(n, count);
669                (1..=prefetch_count).map(|i| latest + i).collect()
670            }
671            PrefetchStrategy::Strided { stride, count: n } => {
672                // Prefetch n blocks with the given stride
673                let prefetch_count = std::cmp::min(n, count);
674                (1..=prefetch_count).map(|i| latest + stride * i).collect()
675            }
676            PrefetchStrategy::Pattern {
677                windowsize: _,
678                lookahead,
679            } => {
680                // Use pattern matching to predict future blocks
681                self.predict_from_pattern(latest, std::cmp::min(lookahead, count))
682            }
683            PrefetchStrategy::Hybrid {
684                sequential,
685                pattern,
686            } => {
687                // Combine sequential and pattern-based prefetching
688                let mut blocks = Vec::new();
689
690                // First add sequential blocks
691                for i in 1..=sequential {
692                    blocks.push(latest + i);
693                }
694
695                // Then add pattern-based predictions
696                blocks.extend(self.predict_from_pattern(
697                    latest,
698                    std::cmp::min(pattern, count.saturating_sub(sequential)),
699                ));
700
701                // Return unique blocks
702                blocks
703                    .into_iter()
704                    .collect::<HashSet<_>>()
705                    .into_iter()
706                    .collect()
707            }
708            PrefetchStrategy::Conservative => {
709                // Prefetch conservatively (just 1-2 blocks)
710                vec![latest + 1]
711            }
712            PrefetchStrategy::Aggressive => {
713                // Prefetch aggressively
714                let mut blocks = Vec::with_capacity(count);
715
716                // First try sequential blocks
717                for i in 1..=count / 2 {
718                    blocks.push(latest + i);
719                }
720
721                // Then add some nearby blocks
722                if let Some(stride) = self.stride {
723                    blocks.push(latest + stride);
724                    if stride > 1 && blocks.len() < count {
725                        blocks.push(latest + stride * 2);
726                    }
727                }
728
729                // For the remaining slots, add some pattern-based predictions
730                let remaining = count.saturating_sub(blocks.len());
731                if remaining > 0 {
732                    blocks.extend(self.predict_from_pattern(latest, remaining));
733                }
734
735                // Return unique blocks
736                blocks
737                    .into_iter()
738                    .collect::<HashSet<_>>()
739                    .into_iter()
740                    .collect()
741            }
742            PrefetchStrategy::None => {
743                // Don't prefetch anything
744                Vec::new()
745            }
746        }
747    }
748
749    /// Predict blocks based on pattern matching in history.
750    fn predict_from_pattern(&self, latest: usize, count: usize) -> Vec<usize> {
751        // Get the last few block indices from history
752        let history_window = std::cmp::min(8, self.history.len());
753        let mut pattern = Vec::with_capacity(history_window);
754
755        for i in 0..history_window {
756            if let Some((block_idx, _, _)) = self.history.get(self.history.len() - 1 - i) {
757                pattern.push(*block_idx);
758            }
759        }
760
761        if pattern.is_empty() {
762            return vec![latest + 1]; // Default to next block if no pattern
763        }
764
765        // Look for this pattern elsewhere in history
766        let mut predictions = Vec::new();
767        let mut occurrences = Vec::new();
768
769        for i in 0..self.history.len().saturating_sub(pattern.len()) {
770            let mut matches = true;
771            for (j, &pattern_idx) in pattern.iter().enumerate() {
772                if let Some((block_idx, _, _)) = self.history.get(i + j) {
773                    if *block_idx != pattern_idx {
774                        matches = false;
775                        break;
776                    }
777                } else {
778                    matches = false;
779                    break;
780                }
781            }
782
783            if matches {
784                occurrences.push(i);
785            }
786        }
787
788        // For each occurrence, check what comes next
789        for &occurrence_idx in &occurrences {
790            if occurrence_idx + pattern.len() < self.history.len() {
791                if let Some((next_block_idx, _, _)) =
792                    self.history.get(occurrence_idx + pattern.len())
793                {
794                    predictions.push(*next_block_idx);
795                }
796            }
797        }
798
799        // If no predictions from pattern matching, fall back to recent strides
800        if predictions.is_empty() && pattern.len() >= 2 {
801            if let Some(stride) = pattern[0].checked_sub(pattern[1]) {
802                predictions.push(latest + stride);
803            }
804        }
805
806        // Return unique predictions, limited to count
807        predictions
808            .into_iter()
809            .collect::<HashSet<_>>()
810            .into_iter()
811            .take(count)
812            .collect()
813    }
814}
815
816impl AccessPatternTracker for AdaptivePatternTracker {
817    fn record_access(&mut self, blockidx: usize) {
818        // Record the time since the last access (latency)
819        let now = Instant::now();
820        let access_time = if let Some((_, last_time_, _)) = self.history.back() {
821            now.duration_since(*last_time_)
822        } else {
823            Duration::from_nanos(0)
824        };
825
826        // Add to history and remove oldest if needed
827        self.history.push_back((blockidx, now, access_time));
828
829        if self.history.len() > self.config.history_size {
830            self.history.pop_front();
831        }
832
833        // Update pattern if we have enough history
834        if self.history.len() >= self.config.min_pattern_length {
835            self.detect_pattern();
836        }
837    }
838
839    fn predict_next_blocks(&self, count: usize) -> Vec<usize> {
840        self.get_blocks_to_prefetch(count)
841    }
842
843    fn current_pattern(&self) -> AccessPattern {
844        self.current_pattern
845    }
846
847    fn clear_history(&mut self) {
848        self.history.clear();
849        self.current_pattern = AccessPattern::Random;
850        self.stride = None;
851    }
852}
853
854/// Factory for creating different types of access pattern trackers.
855pub struct PatternTrackerFactory;
856
857impl PatternTrackerFactory {
858    /// Create a new access pattern tracker of the specified type.
859    pub fn create_tracker(
860        tracker_type: &str,
861        config: PrefetchConfig,
862    ) -> Box<dyn AccessPatternTracker + Send + Sync> {
863        match tracker_type {
864            "adaptive" => Box::new(AdaptivePatternTracker::new(config)),
865            _ => Box::new(super::prefetch::BlockAccessTracker::new(config)),
866        }
867    }
868}
869
870/// Extended prefetching configuration with adaptive learning options.
871#[derive(Debug, Clone)]
872pub struct AdaptivePrefetchConfig {
873    /// Base prefetching configuration
874    pub base: PrefetchConfig,
875
876    /// Whether to use the adaptive tracker
877    pub use_adaptive_tracker: bool,
878
879    /// Whether to enable reinforcement learning
880    pub enable_learning: bool,
881
882    /// Dimensions of the array (if known)
883    pub dimensions: Option<Vec<usize>>,
884
885    /// Learning rate for Q-value updates
886    pub learningrate: f64,
887
888    /// How often to evaluate strategies (in seconds)
889    pub evaluation_interval: Duration,
890}
891
892impl Default for AdaptivePrefetchConfig {
893    fn default() -> Self {
894        Self {
895            base: PrefetchConfig::default(),
896            use_adaptive_tracker: true,
897            enable_learning: true,
898            dimensions: None,
899            learningrate: LEARNING_RATE,
900            evaluation_interval: STRATEGY_TEST_DURATION,
901        }
902    }
903}
904
905/// Builder for adaptive prefetch configuration.
906#[derive(Debug, Clone)]
907pub struct AdaptivePrefetchConfigBuilder {
908    config: AdaptivePrefetchConfig,
909}
910
911impl AdaptivePrefetchConfigBuilder {
912    /// Create a new builder with default settings.
913    pub fn new() -> Self {
914        Self {
915            config: AdaptivePrefetchConfig::default(),
916        }
917    }
918
919    /// Enable or disable prefetching.
920    pub const fn enabled(mut self, enabled: bool) -> Self {
921        self.config.base.enabled = enabled;
922        self
923    }
924
925    /// Set the number of blocks to prefetch ahead of the current access.
926    pub const fn prefetch_count(mut self, count: usize) -> Self {
927        self.config.base.prefetch_count = count;
928        self
929    }
930
931    /// Set the maximum number of blocks to keep in the prefetch history.
932    pub const fn history_size(mut self, size: usize) -> Self {
933        self.config.base.history_size = size;
934        self
935    }
936
937    /// Set the minimum number of accesses needed to detect a pattern.
938    pub const fn min_pattern_length(mut self, length: usize) -> Self {
939        self.config.base.min_pattern_length = length;
940        self
941    }
942
943    /// Enable or disable asynchronous prefetching.
944    pub const fn prefetch(mut self, asyncprefetch: bool) -> Self {
945        self.config.base.async_prefetch = asyncprefetch;
946        self
947    }
948
949    /// Set the timeout for prefetch operations.
950    pub const fn prefetch_timeout(mut self, timeout: Duration) -> Self {
951        self.config.base.prefetch_timeout = timeout;
952        self
953    }
954
955    /// Set whether to use the adaptive tracker.
956    pub const fn adaptive(mut self, useadaptive: bool) -> Self {
957        self.config.use_adaptive_tracker = useadaptive;
958        self
959    }
960
961    /// Enable or disable reinforcement learning.
962    pub const fn enable_learning(mut self, enable: bool) -> Self {
963        self.config.enable_learning = enable;
964        self
965    }
966
967    /// Set the dimensions of the array.
968    pub fn dimensions(mut self, dimensions: Vec<usize>) -> Self {
969        self.config.dimensions = Some(dimensions);
970        self
971    }
972
973    /// Set the learning rate for Q-value updates.
974    pub const fn learningrate(mut self, rate: f64) -> Self {
975        self.config.learningrate = rate;
976        self
977    }
978
979    /// Set how often to evaluate strategies.
980    pub const fn evaluation_interval(mut self, interval: Duration) -> Self {
981        self.config.evaluation_interval = interval;
982        self
983    }
984
985    /// Build the configuration.
986    pub fn build(self) -> AdaptivePrefetchConfig {
987        self.config
988    }
989}
990
991impl Default for AdaptivePrefetchConfigBuilder {
992    fn default() -> Self {
993        Self::new()
994    }
995}
996
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adaptive_pattern_detection_sequential() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            ..Default::default()
        };

        let mut tracker = AdaptivePatternTracker::new(config);

        // Record sequential access
        for i in 0..10 {
            tracker.record_access(i);
        }

        // Check that the pattern was detected correctly
        assert_eq!(tracker.current_pattern(), AccessPattern::Sequential);

        // Check predictions
        let predictions = tracker.predict_next_blocks(3);
        assert!(!predictions.is_empty());

        // Should include at least the next sequential block
        assert!(predictions.contains(&10));
    }

    #[test]
    fn test_adaptive_pattern_detection_strided() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            ..Default::default()
        };

        let mut tracker = AdaptivePatternTracker::new(config);

        // Record strided access with stride 3
        for i in (0..30).step_by(3) {
            tracker.record_access(i);
        }

        // Check that the pattern was detected correctly
        assert_eq!(tracker.current_pattern(), AccessPattern::Strided(3));

        // Check predictions
        let predictions = tracker.predict_next_blocks(3);
        assert!(!predictions.is_empty());

        // Should include at least the next strided block
        assert!(predictions.contains(&30));
    }

    #[test]
    fn test_adaptive_strategy_selection() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            ..Default::default()
        };

        let mut tracker = AdaptivePatternTracker::new(config);

        // Repeatedly access a single block: nine consecutive accesses to block 0.
        for _ in 0..9 {
            tracker.record_access(0);
        }

        // No performance feedback is supplied; the strategy is whatever the
        // tracker selected from the recorded accesses alone.
        let strategy = tracker.current_strategy;
        assert!(matches!(
            strategy,
            PrefetchStrategy::Sequential(_)
                | PrefetchStrategy::Strided { .. }
                | PrefetchStrategy::Conservative
                | PrefetchStrategy::Aggressive
        ));

        // Check predictions
        let predictions = tracker.predict_next_blocks(3);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_dimensional_pattern_detection() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            history_size: 50,
            ..Default::default()
        };

        let mut tracker = AdaptivePatternTracker::new(config);

        // Set dimensions to a 5x5 matrix
        tracker.set_dimensions(vec![5, 5]);

        // Record row-major traversal
        for i in 0..5 {
            for j in 0..5 {
                tracker.record_access(i * 5 + j);
            }
        }

        // Check pattern detection
        let dimensions = vec![5, 5];
        let patterns = tracker.detect_dimensional_patterns(&dimensions);
        assert!(!patterns.is_empty());
        assert!(patterns.contains(&MATRIX_TRAVERSAL_ROW_MAJOR.to_string()));

        // Clear history
        tracker.clear_history();

        // Record column-major traversal
        for j in 0..5 {
            for i in 0..5 {
                tracker.record_access(i * 5 + j);
            }
        }

        // Check pattern detection
        let patterns = tracker.detect_dimensional_patterns(&dimensions);
        assert!(!patterns.is_empty());
        assert!(patterns.contains(&MATRIX_TRAVERSAL_COL_MAJOR.to_string()));
    }

    #[test]
    fn test_zigzag_pattern_detection() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            history_size: 50,
            ..Default::default()
        };

        let mut tracker = AdaptivePatternTracker::new(config);

        // Set dimensions to a 5x5 matrix
        tracker.set_dimensions(vec![5, 5]);

        // Record zigzag traversal
        // Row 0: left to right
        for j in 0..5 {
            tracker.record_access(j);
        }
        // Row 1: right to left
        for j in (0..5).rev() {
            tracker.record_access(5 + j);
        }
        // Row 2: left to right
        for j in 0..5 {
            tracker.record_access(10 + j);
        }
        // Row 3: right to left
        for j in (0..5).rev() {
            tracker.record_access(15 + j);
        }

        // Get flat indices from history
        let indices: Vec<usize> = tracker.history.iter().map(|(idx, _, _)| *idx).collect();

        // Check zigzag detection
        let dimensions = vec![5, 5];
        assert!(tracker.detect_zigzag_pattern(&indices, &dimensions));
    }
}