1use crate::Dataset;
7use std::collections::{HashMap, VecDeque};
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::{Arc, Mutex, RwLock};
10use std::thread;
11use std::time::{Duration, Instant};
12use tenflowers_core::{Result, Tensor, TensorError};
13
14#[cfg(feature = "serialize")]
15use serde::{Deserialize, Serialize};
16
/// Adaptive prefetcher that learns dataset access patterns from observed
/// reads and keeps predicted samples warm in an in-memory buffer.
pub struct StreamPrefetchOptimizer<T>
where
    T: Clone,
{
    /// Tuning knobs for buffering, worker count, and pattern learning.
    config: PrefetchOptimizerConfig,
    /// Shared analyzer that records accesses and derives predictions.
    pattern_analyzer: Arc<Mutex<AccessPatternAnalyzer>>,
    /// LRU-style cache of prefetched samples, shared with worker threads.
    prefetch_buffer: Arc<RwLock<PrefetchBuffer<T>>>,
    /// Aggregated hit-rate / latency metrics.
    metrics: Arc<Mutex<PrefetchMetrics>>,
    /// Join handles for background workers spawned by `start`.
    worker_handles: Vec<thread::JoinHandle<()>>,
    /// Cooperative shutdown flag polled by the workers.
    shutdown: Arc<AtomicBool>,
}
35
/// Configuration for [`StreamPrefetchOptimizer`].
///
/// NOTE(review): several knobs (`learning_rate`, `max_lookahead_distance`,
/// `adaptive_buffer_resizing`, `buffer_resize_factor`,
/// `min_buffer_utilization`, `cross_epoch_learning`) are not yet consumed by
/// the code in this module — confirm intended behavior before relying on them.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct PrefetchOptimizerConfig {
    /// Maximum number of samples kept in the prefetch buffer.
    pub max_buffer_size: usize,
    /// Number of background prefetch worker threads.
    pub worker_count: usize,
    /// Minimum confidence a prediction needs before it is acted on.
    pub prediction_confidence_threshold: f64,
    /// Learning rate for pattern adaptation (currently unused here).
    pub learning_rate: f64,
    /// How far ahead (in sample indices) predictions may reach.
    pub max_lookahead_distance: usize,
    /// Whether the buffer may grow/shrink based on utilization.
    pub adaptive_buffer_resizing: bool,
    /// Multiplier applied when resizing the buffer.
    pub buffer_resize_factor: f64,
    /// Utilization floor below which the buffer is considered oversized.
    pub min_buffer_utilization: f64,
    /// Number of recent accesses retained for pattern analysis.
    pub pattern_window_size: usize,
    /// Whether learned patterns should persist across epochs.
    pub cross_epoch_learning: bool,
}
61
62impl Default for PrefetchOptimizerConfig {
63 fn default() -> Self {
64 Self {
65 max_buffer_size: 1000,
66 worker_count: 2,
67 prediction_confidence_threshold: 0.7,
68 learning_rate: 0.1,
69 max_lookahead_distance: 100,
70 adaptive_buffer_resizing: true,
71 buffer_resize_factor: 1.5,
72 min_buffer_utilization: 0.3,
73 pattern_window_size: 500,
74 cross_epoch_learning: true,
75 }
76 }
77}
78
/// Records a sliding window of access events and mines them for sequential,
/// strided, cyclic, and hot-spot patterns.
#[derive(Debug)]
pub struct AccessPatternAnalyzer {
    /// Sliding window of recent accesses, bounded by `pattern_window_size`.
    access_history: VecDeque<AccessEvent>,
    /// Learned patterns keyed by their signature.
    patterns: HashMap<PatternSignature, PatternPrediction>,
    /// Incremental working state of the individual pattern detectors.
    detection_state: PatternDetectionState,
    /// Optimizer configuration (window size, confidence threshold, ...).
    config: PrefetchOptimizerConfig,
}
91
/// A single observed dataset access.
#[derive(Debug, Clone)]
pub struct AccessEvent {
    /// Index of the sample that was requested.
    pub index: usize,
    /// When the access happened.
    pub timestamp: Instant,
    /// Caller-supplied classification of the access.
    pub access_type: AccessType,
    /// Training-loop context (epoch, batch, worker) if known.
    pub context: AccessContext,
}
100
/// Coarse classification of how a sample index relates to prior accesses.
#[derive(Debug, Clone, PartialEq)]
pub enum AccessType {
    /// Consecutive indices (i, i+1, ...).
    Sequential,
    /// No discernible order.
    Random,
    /// Constant step between consecutive indices.
    Strided { stride: usize },
    /// The same index sequence repeats with the given period.
    Repetitive { cycle_length: usize },
}
109
/// Optional training-loop context attached to an access event.
#[derive(Debug, Clone)]
pub struct AccessContext {
    /// Epoch number, if the caller tracks epochs.
    pub epoch: Option<usize>,
    /// Batch index within the epoch, if known.
    pub batch_index: Option<usize>,
    /// Id of the worker issuing the access, if any.
    pub worker_id: Option<usize>,
}
117
/// Hashable identity of a detected pattern; used as the key in the
/// analyzer's pattern table.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct PatternSignature {
    /// Broad pattern family this signature belongs to.
    pub pattern_type: PatternType,
    /// Hash of the index window the pattern was detected in.
    pub window_hash: u64,
    /// Hash of the access context (always 0 where constructed in this file).
    pub context_hash: u64,
}
125
/// Families of access patterns the analyzer can learn.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum PatternType {
    /// Consecutive ascending indices.
    Sequential,
    /// Constant non-unit step between indices.
    Strided,
    /// Repeating index cycle.
    Cyclic,
    /// Locally random movement.
    RandomWalk,
    /// Concentrated repeated access to a few indices.
    HotSpot,
}
135
/// A prediction of which indices will be requested next.
#[derive(Debug, Clone)]
pub struct PatternPrediction {
    /// Indices expected to be accessed soon, in order.
    pub next_indices: Vec<usize>,
    /// Confidence in [0, 1]; compared against the configured threshold.
    pub confidence: f64,
    /// When this prediction was last refreshed.
    pub last_updated: Instant,
    /// How many times the pattern has been used.
    pub usage_count: usize,
    /// Rolling record of whether past predictions proved correct.
    pub accuracy_history: VecDeque<bool>,
}
145
/// Mutable working state shared by the individual pattern detectors.
#[derive(Debug)]
pub struct PatternDetectionState {
    /// Short rolling window of raw indices (bounded to 100 entries).
    pub current_sequence: VecDeque<usize>,
    /// Constant-stride detector.
    pub stride_detector: StrideDetector,
    /// Repeating-cycle detector.
    pub cycle_detector: CycleDetector,
    /// Hot-index detector.
    pub hotspot_detector: HotspotDetector,
}
154
/// Detects constant-stride index sequences (e.g. 0, 3, 6, 9 → stride 3).
#[derive(Debug)]
pub struct StrideDetector {
    /// Observed stride → number of windows that exhibited it.
    pub candidate_strides: HashMap<usize, usize>,
    /// Smallest history length worth analyzing.
    pub min_sequence_length: usize,
}
161
/// Detects repeating index cycles up to a bounded period.
#[derive(Debug)]
pub struct CycleDetector {
    /// Candidate cycle (as an index sequence) → times it was seen to repeat.
    pub candidate_cycles: HashMap<Vec<usize>, usize>,
    /// Longest cycle period considered.
    pub max_cycle_length: usize,
}
168
/// Tracks frequently re-accessed ("hot") indices over temporal windows.
#[derive(Debug)]
pub struct HotspotDetector {
    /// Observed accesses per index.
    pub access_counts: HashMap<usize, usize>,
    /// Per-window access counts, newest at the back.
    pub temporal_windows: VecDeque<HashMap<usize, usize>>,
    /// Length of one temporal window.
    pub window_size: Duration,
}
176
/// Bounded LRU cache of prefetched samples, shared between the request path
/// and the background workers (always accessed behind an outer `RwLock`).
#[derive(Debug)]
pub struct PrefetchBuffer<T>
where
    T: Clone,
{
    /// Samples keyed by dataset index.
    buffer: HashMap<usize, BufferedSample<T>>,
    /// Indices ordered least- to most-recently used (front = eviction victim).
    access_order: VecDeque<usize>,
    /// Entry count. NOTE(review): redundant with `buffer.len()` given the
    /// outer lock — kept as-is for interface stability.
    current_size: AtomicUsize,
    /// Maximum number of entries before eviction kicks in.
    max_size: usize,
    /// Hit / miss / eviction counters.
    utilization_stats: UtilizationStats,
}
194
/// A sample held in the prefetch buffer along with bookkeeping metadata.
#[derive(Debug)]
pub struct BufferedSample<T>
where
    T: Clone,
{
    /// The (features, target) tensor pair loaded from the dataset.
    pub data: (Tensor<T>, Tensor<T>),
    /// When the sample was loaded into the buffer.
    pub load_time: Instant,
    /// How many times this entry has been served from the buffer.
    pub access_count: usize,
    /// Confidence of the prediction that caused the prefetch.
    pub prediction_confidence: f64,
}
206
/// Lock-free counters describing how well the buffer is being used.
#[derive(Debug, Default)]
pub struct UtilizationStats {
    /// Requests served from the buffer.
    pub hit_count: AtomicUsize,
    /// Requests that were not in the buffer.
    pub miss_count: AtomicUsize,
    /// Entries evicted to make room for new ones.
    pub eviction_count: AtomicUsize,
    /// Total lookups (hits + misses).
    pub total_requests: AtomicUsize,
}
215
/// Snapshot of prefetcher performance metrics.
///
/// NOTE(review): only `hit_rate` and `average_latency_us` are updated by the
/// code visible in this module; the remaining fields stay at their defaults
/// and appear reserved for future reporting — confirm before consuming them.
#[derive(Debug, Default)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct PrefetchMetrics {
    /// Fraction of requests served from the buffer, in [0, 1].
    pub hit_rate: f64,
    /// Fraction of predictions that proved correct.
    pub prediction_accuracy: f64,
    /// How full the buffer has been on average.
    pub buffer_utilization: f64,
    /// Running mean request latency in microseconds.
    pub average_latency_us: f64,
    /// Number of distinct patterns currently learned.
    pub patterns_learned: usize,
    /// Useful-prefetch ratio.
    pub prefetch_efficiency: f64,
    /// Fraction of available load bandwidth used by prefetching.
    pub bandwidth_utilization: f64,
    /// Extra memory consumed by prefetching, as a fraction.
    pub memory_overhead: f64,
}
237
238impl<T> StreamPrefetchOptimizer<T>
239where
240 T: Clone + Default + Send + Sync + 'static,
241{
242 pub fn new(config: PrefetchOptimizerConfig) -> Self {
244 let pattern_analyzer = Arc::new(Mutex::new(AccessPatternAnalyzer::new(config.clone())));
245 let prefetch_buffer = Arc::new(RwLock::new(PrefetchBuffer::new(config.max_buffer_size)));
246 let metrics = Arc::new(Mutex::new(PrefetchMetrics::default()));
247 let shutdown = Arc::new(AtomicBool::new(false));
248
249 Self {
250 config,
251 pattern_analyzer,
252 prefetch_buffer,
253 metrics,
254 worker_handles: Vec::new(),
255 shutdown,
256 }
257 }
258
259 pub fn start<D>(&mut self, dataset: Arc<D>) -> Result<()>
261 where
262 D: Dataset<T> + Send + Sync + 'static,
263 {
264 for worker_id in 0..self.config.worker_count {
266 let dataset_clone = Arc::clone(&dataset);
267 let pattern_analyzer = Arc::clone(&self.pattern_analyzer);
268 let prefetch_buffer = Arc::clone(&self.prefetch_buffer);
269 let metrics = Arc::clone(&self.metrics);
270 let shutdown = Arc::clone(&self.shutdown);
271 let config = self.config.clone();
272
273 let handle = thread::spawn(move || {
274 Self::prefetch_worker(
275 worker_id,
276 dataset_clone,
277 pattern_analyzer,
278 prefetch_buffer,
279 metrics,
280 shutdown,
281 config,
282 );
283 });
284
285 self.worker_handles.push(handle);
286 }
287
288 Ok(())
289 }
290
291 pub fn get(&self, index: usize, context: AccessContext) -> Result<(Tensor<T>, Tensor<T>)> {
293 let start_time = Instant::now();
294
295 self.record_access(index, context.clone());
297
298 if let Some(sample) = self.get_from_buffer(index) {
300 self.update_hit_metrics(start_time);
301 return Ok(sample.data);
302 }
303
304 self.update_miss_metrics(start_time);
306
307 Err(TensorError::invalid_argument(format!(
310 "Data not available in prefetch buffer for index {index}"
311 )))
312 }
313
314 fn record_access(&self, index: usize, context: AccessContext) {
316 let event = AccessEvent {
317 index,
318 timestamp: Instant::now(),
319 access_type: AccessType::Sequential, context,
321 };
322
323 if let Ok(mut analyzer) = self.pattern_analyzer.lock() {
324 analyzer.record_access(event);
325 }
326 }
327
328 fn get_from_buffer(&self, index: usize) -> Option<BufferedSample<T>> {
330 if let Ok(mut buffer) = self.prefetch_buffer.write() {
331 buffer.get_sample(index)
332 } else {
333 None
334 }
335 }
336
337 fn update_hit_metrics(&self, start_time: Instant) {
339 let latency = start_time.elapsed().as_micros() as f64;
340
341 if let Ok(mut metrics) = self.metrics.lock() {
342 let total_requests = metrics.hit_rate + metrics.prediction_accuracy + 1.0;
343 metrics.hit_rate = (metrics.hit_rate * (total_requests - 1.0) + 1.0) / total_requests;
344 metrics.average_latency_us =
345 (metrics.average_latency_us * (total_requests - 1.0) + latency) / total_requests;
346 }
347 }
348
349 fn update_miss_metrics(&self, start_time: Instant) {
351 let latency = start_time.elapsed().as_micros() as f64;
352
353 if let Ok(mut metrics) = self.metrics.lock() {
354 let total_requests = metrics.hit_rate + metrics.prediction_accuracy + 1.0;
355 metrics.hit_rate = (metrics.hit_rate * (total_requests - 1.0)) / total_requests;
356 metrics.average_latency_us =
357 (metrics.average_latency_us * (total_requests - 1.0) + latency) / total_requests;
358 }
359 }
360
361 fn prefetch_worker<D>(
363 worker_id: usize,
364 dataset: Arc<D>,
365 pattern_analyzer: Arc<Mutex<AccessPatternAnalyzer>>,
366 prefetch_buffer: Arc<RwLock<PrefetchBuffer<T>>>,
367 _metrics: Arc<Mutex<PrefetchMetrics>>,
368 shutdown: Arc<AtomicBool>,
369 _config: PrefetchOptimizerConfig,
370 ) where
371 D: Dataset<T> + Send + Sync + 'static,
372 {
373 while !shutdown.load(Ordering::Relaxed) {
374 let predictions = if let Ok(analyzer) = pattern_analyzer.lock() {
376 analyzer.get_predictions()
377 } else {
378 Vec::new()
379 };
380
381 for prediction in predictions {
383 for &index in &prediction.next_indices {
384 if index < dataset.len() {
385 if let Ok(sample) = dataset.get(index) {
386 let buffered_sample = BufferedSample {
387 data: sample,
388 load_time: Instant::now(),
389 access_count: 0,
390 prediction_confidence: prediction.confidence,
391 };
392
393 if let Ok(mut buffer) = prefetch_buffer.write() {
394 buffer.add_sample(index, buffered_sample);
395 }
396 }
397 }
398 }
399 }
400
401 thread::sleep(Duration::from_millis(10));
403 }
404
405 println!("Prefetch worker {worker_id} shutting down");
406 }
407
408 pub fn get_metrics(&self) -> PrefetchMetrics {
410 if let Ok(metrics) = self.metrics.lock() {
411 PrefetchMetrics {
413 hit_rate: metrics.hit_rate,
414 prediction_accuracy: metrics.prediction_accuracy,
415 buffer_utilization: metrics.buffer_utilization,
416 average_latency_us: metrics.average_latency_us,
417 patterns_learned: metrics.patterns_learned,
418 prefetch_efficiency: metrics.prefetch_efficiency,
419 bandwidth_utilization: metrics.bandwidth_utilization,
420 memory_overhead: metrics.memory_overhead,
421 }
422 } else {
423 PrefetchMetrics::default()
424 }
425 }
426
427 pub fn stop(&mut self) {
429 self.shutdown.store(true, Ordering::Relaxed);
430
431 while let Some(handle) = self.worker_handles.pop() {
433 let _ = handle.join();
434 }
435 }
436}
437
438impl AccessPatternAnalyzer {
439 fn new(config: PrefetchOptimizerConfig) -> Self {
441 Self {
442 access_history: VecDeque::with_capacity(config.pattern_window_size),
443 patterns: HashMap::new(),
444 detection_state: PatternDetectionState::new(),
445 config,
446 }
447 }
448
449 fn record_access(&mut self, event: AccessEvent) {
451 self.access_history.push_back(event.clone());
453
454 if self.access_history.len() > self.config.pattern_window_size {
456 self.access_history.pop_front();
457 }
458
459 self.detection_state.current_sequence.push_back(event.index);
461 if self.detection_state.current_sequence.len() > 100 {
462 self.detection_state.current_sequence.pop_front();
463 }
464
465 self.analyze_patterns();
467 }
468
469 fn analyze_patterns(&mut self) {
471 self.detect_sequential_patterns();
473
474 self.detect_strided_patterns();
476
477 self.detect_cyclic_patterns();
479
480 self.detect_hotspot_patterns();
482 }
483
484 fn detect_sequential_patterns(&mut self) {
486 if self.access_history.len() < 3 {
487 return;
488 }
489
490 let recent_accesses: Vec<usize> = self
491 .access_history
492 .iter()
493 .rev()
494 .take(10)
495 .map(|event| event.index)
496 .collect();
497
498 let mut sequential_count = 0;
499 for window in recent_accesses.windows(2) {
500 if window[1] == window[0] + 1 {
501 sequential_count += 1;
502 }
503 }
504
505 if sequential_count >= 5 {
506 let signature = PatternSignature {
507 pattern_type: PatternType::Sequential,
508 window_hash: self.hash_sequence(&recent_accesses),
509 context_hash: 0, };
511
512 let next_index = recent_accesses[0] + 1;
513 let prediction = PatternPrediction {
514 next_indices: vec![next_index, next_index + 1, next_index + 2],
515 confidence: 0.9,
516 last_updated: Instant::now(),
517 usage_count: 1,
518 accuracy_history: VecDeque::new(),
519 };
520
521 self.patterns.insert(signature, prediction);
522 }
523 }
524
525 fn detect_strided_patterns(&mut self) {
527 self.detection_state
528 .stride_detector
529 .analyze(&self.access_history);
530 }
531
532 fn detect_cyclic_patterns(&mut self) {
534 self.detection_state
535 .cycle_detector
536 .analyze(&self.access_history);
537 }
538
539 fn detect_hotspot_patterns(&mut self) {
541 self.detection_state
542 .hotspot_detector
543 .analyze(&self.access_history);
544 }
545
546 fn get_predictions(&self) -> Vec<PatternPrediction> {
548 self.patterns
549 .values()
550 .filter(|p| p.confidence >= self.config.prediction_confidence_threshold)
551 .cloned()
552 .collect()
553 }
554
555 fn hash_sequence(&self, sequence: &[usize]) -> u64 {
557 use std::collections::hash_map::DefaultHasher;
558 use std::hash::{Hash, Hasher};
559
560 let mut hasher = DefaultHasher::new();
561 sequence.hash(&mut hasher);
562 hasher.finish()
563 }
564}
565
566impl<T> PrefetchBuffer<T>
567where
568 T: Clone,
569{
570 fn new(max_size: usize) -> Self {
572 Self {
573 buffer: HashMap::new(),
574 access_order: VecDeque::new(),
575 current_size: AtomicUsize::new(0),
576 max_size,
577 utilization_stats: UtilizationStats::default(),
578 }
579 }
580
581 fn add_sample(&mut self, index: usize, sample: BufferedSample<T>) {
583 if self.current_size.load(Ordering::Relaxed) >= self.max_size {
585 self.evict_lru();
586 }
587
588 self.buffer.insert(index, sample);
590 self.access_order.push_back(index);
591 self.current_size.fetch_add(1, Ordering::Relaxed);
592 }
593
594 fn get_sample(&mut self, index: usize) -> Option<BufferedSample<T>> {
596 if let Some(mut sample) = self.buffer.remove(&index) {
597 sample.access_count += 1;
598
599 if let Some(pos) = self.access_order.iter().position(|&x| x == index) {
601 self.access_order.remove(pos);
602 self.access_order.push_back(index);
603 }
604
605 let updated_sample = BufferedSample {
607 data: sample.data.clone(),
608 load_time: sample.load_time,
609 access_count: sample.access_count,
610 prediction_confidence: sample.prediction_confidence,
611 };
612 self.buffer.insert(index, updated_sample);
613
614 self.utilization_stats
615 .hit_count
616 .fetch_add(1, Ordering::Relaxed);
617 self.utilization_stats
618 .total_requests
619 .fetch_add(1, Ordering::Relaxed);
620
621 Some(sample)
622 } else {
623 self.utilization_stats
624 .miss_count
625 .fetch_add(1, Ordering::Relaxed);
626 self.utilization_stats
627 .total_requests
628 .fetch_add(1, Ordering::Relaxed);
629 None
630 }
631 }
632
633 fn evict_lru(&mut self) {
635 if let Some(lru_index) = self.access_order.pop_front() {
636 self.buffer.remove(&lru_index);
637 self.current_size.fetch_sub(1, Ordering::Relaxed);
638 self.utilization_stats
639 .eviction_count
640 .fetch_add(1, Ordering::Relaxed);
641 }
642 }
643}
644
645impl PatternDetectionState {
646 fn new() -> Self {
647 Self {
648 current_sequence: VecDeque::new(),
649 stride_detector: StrideDetector::new(),
650 cycle_detector: CycleDetector::new(),
651 hotspot_detector: HotspotDetector::new(),
652 }
653 }
654}
655
656impl StrideDetector {
657 fn new() -> Self {
658 Self {
659 candidate_strides: HashMap::new(),
660 min_sequence_length: 5,
661 }
662 }
663
664 fn analyze(&mut self, access_history: &VecDeque<AccessEvent>) {
665 if access_history.len() < self.min_sequence_length {
666 return;
667 }
668
669 let indices: Vec<usize> = access_history.iter().map(|e| e.index).collect();
670
671 for window_size in 3..=self.min_sequence_length {
673 if indices.len() >= window_size {
674 let window = &indices[indices.len() - window_size..];
675
676 if let Some(stride) = self.detect_stride(window) {
677 *self.candidate_strides.entry(stride).or_insert(0) += 1;
678 }
679 }
680 }
681 }
682
683 fn detect_stride(&self, window: &[usize]) -> Option<usize> {
684 if window.len() < 3 {
685 return None;
686 }
687
688 let first_diff = window[1] as i64 - window[0] as i64;
689
690 for i in 2..window.len() {
691 let diff = window[i] as i64 - window[i - 1] as i64;
692 if diff != first_diff {
693 return None;
694 }
695 }
696
697 if first_diff > 0 {
698 Some(first_diff as usize)
699 } else {
700 None
701 }
702 }
703}
704
705impl CycleDetector {
706 fn new() -> Self {
707 Self {
708 candidate_cycles: HashMap::new(),
709 max_cycle_length: 20,
710 }
711 }
712
713 fn analyze(&mut self, access_history: &VecDeque<AccessEvent>) {
714 let indices: Vec<usize> = access_history.iter().map(|e| e.index).collect();
715
716 for cycle_len in 2..=self.max_cycle_length.min(indices.len() / 2) {
718 if indices.len() >= cycle_len * 2 {
719 let potential_cycle = &indices[indices.len() - cycle_len..];
720 let prev_cycle = &indices[indices.len() - cycle_len * 2..indices.len() - cycle_len];
721
722 if potential_cycle == prev_cycle {
723 *self
724 .candidate_cycles
725 .entry(potential_cycle.to_vec())
726 .or_insert(0) += 1;
727 }
728 }
729 }
730 }
731}
732
733impl HotspotDetector {
734 fn new() -> Self {
735 Self {
736 access_counts: HashMap::new(),
737 temporal_windows: VecDeque::new(),
738 window_size: Duration::from_secs(60),
739 }
740 }
741
742 fn analyze(&mut self, access_history: &VecDeque<AccessEvent>) {
743 for event in access_history {
745 *self.access_counts.entry(event.index).or_insert(0) += 1;
746 }
747
748 if let Some(latest_event) = access_history.back() {
750 let cutoff_time = latest_event.timestamp - self.window_size;
751
752 while let Some(front_window) = self.temporal_windows.front() {
754 if front_window.is_empty() {
755 self.temporal_windows.pop_front();
756 } else {
757 break;
758 }
759 }
760
761 let mut recent_window = HashMap::new();
763 for event in access_history {
764 if event.timestamp >= cutoff_time {
765 *recent_window.entry(event.index).or_insert(0) += 1;
766 }
767 }
768
769 if !recent_window.is_empty() {
770 self.temporal_windows.push_back(recent_window);
771 }
772 }
773 }
774}
775
#[cfg(test)]
mod tests {
    use super::*;
    use tenflowers_core::Tensor;

    // Constructing the optimizer stores the supplied config unchanged and
    // spawns no workers until `start` is called.
    #[test]
    fn test_optimizer_creation() {
        let config = PrefetchOptimizerConfig::default();
        let optimizer: StreamPrefetchOptimizer<f32> = StreamPrefetchOptimizer::new(config);

        assert_eq!(optimizer.config.max_buffer_size, 1000);
        assert_eq!(optimizer.config.worker_count, 2);
    }

    // Feeding strictly sequential accesses keeps the bounded history intact;
    // predictions are queried but not asserted on here.
    #[test]
    fn test_access_pattern_analyzer() {
        let config = PrefetchOptimizerConfig {
            prediction_confidence_threshold: 0.5, ..Default::default()
        };
        let mut analyzer = AccessPatternAnalyzer::new(config);

        // 15 sequential accesses: indices 0, 1, ..., 14.
        for i in 0..15 {
            let event = AccessEvent {
                index: i,
                timestamp: Instant::now(),
                access_type: AccessType::Sequential,
                context: AccessContext {
                    epoch: Some(0),
                    batch_index: Some(i / 4),
                    worker_id: Some(0),
                },
            };
            analyzer.record_access(event);
        }

        let _predictions = analyzer.get_predictions();
        // 15 < pattern_window_size (500), so nothing was evicted.
        assert!(analyzer.access_history.len() == 15);
    }

    // A sample added to the buffer can be retrieved, and retrieval bumps
    // its access count to 1.
    #[test]
    fn test_prefetch_buffer() {
        let mut buffer: PrefetchBuffer<f32> = PrefetchBuffer::new(5);

        let sample_data = (
            Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("test: tensor creation should succeed"),
            Tensor::from_vec(vec![0.0], &[1]).expect("test: tensor creation should succeed"),
        );

        let buffered_sample = BufferedSample {
            data: sample_data,
            load_time: Instant::now(),
            access_count: 0,
            prediction_confidence: 0.8,
        };

        buffer.add_sample(0, buffered_sample);
        assert_eq!(buffer.current_size.load(Ordering::Relaxed), 1);

        let retrieved = buffer.get_sample(0);
        assert!(retrieved.is_some());
        assert_eq!(
            retrieved
                .expect("test: operation should succeed")
                .access_count,
            1
        );
    }

    // Accesses at a constant stride of 3 should register stride 3 as a
    // candidate.
    #[test]
    fn test_stride_detector() {
        let mut detector = StrideDetector::new();

        // Indices 0, 3, 6, ..., 27 — a pure stride-3 sequence.
        let events: Vec<AccessEvent> = (0..10)
            .map(|i| AccessEvent {
                index: i * 3, timestamp: Instant::now(),
                access_type: AccessType::Sequential,
                context: AccessContext {
                    epoch: Some(0),
                    batch_index: None,
                    worker_id: None,
                },
            })
            .collect();

        let access_history: VecDeque<AccessEvent> = events.into();
        detector.analyze(&access_history);

        assert!(detector.candidate_strides.contains_key(&3));
    }

    // A freshly constructed optimizer reports zeroed metrics.
    #[test]
    fn test_metrics_tracking() {
        let config = PrefetchOptimizerConfig::default();
        let optimizer: StreamPrefetchOptimizer<f32> = StreamPrefetchOptimizer::new(config);

        let metrics = optimizer.get_metrics();
        assert_eq!(metrics.hit_rate, 0.0);
        assert_eq!(metrics.patterns_learned, 0);
    }
}