// scirs2_core/memory_efficient/prefetch.rs
use std::collections::{HashSet, VecDeque};
15#[cfg(feature = "memory_compression")]
16use std::sync::{Arc, Mutex};
17use std::time::{Duration, Instant};
18
19#[cfg(feature = "memory_compression")]
20use super::compressed_memmap::CompressedMemMappedArray;
21use crate::error::CoreResult;
22#[cfg(feature = "memory_compression")]
23use crate::error::{CoreError, ErrorContext};
24
/// Memory access pattern detected from a sequence of block accesses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessPattern {
    /// Consecutive blocks: each access is the previous index plus one.
    Sequential,

    /// Constant positive step between consecutive block indices.
    Strided(usize),

    /// No recognizable pattern.
    Random,

    /// Application-defined pattern; never produced by the built-in
    /// `BlockAccessTracker` detection.
    Custom,
}
40
/// Tuning knobs for block prefetching.
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Master switch; when false no blocks are prefetched.
    pub enabled: bool,

    /// How many blocks ahead to predict and prefetch per access.
    pub prefetch_count: usize,

    /// Maximum number of recent accesses kept for pattern detection.
    pub history_size: usize,

    /// Minimum history length before pattern detection runs.
    pub min_pattern_length: usize,

    /// Prefetch on a background thread instead of synchronously.
    pub async_prefetch: bool,

    /// Receive timeout for the background worker; on timeout it
    /// opportunistically prefetches predicted blocks.
    pub prefetch_timeout: Duration,
}
62
impl Default for PrefetchConfig {
    /// Defaults: prefetching on, 2 blocks ahead, 32-entry history,
    /// 4-access minimum pattern, async worker with a 100 ms timeout.
    fn default() -> Self {
        Self {
            enabled: true,
            prefetch_count: 2,
            history_size: 32,
            min_pattern_length: 4,
            async_prefetch: true,
            prefetch_timeout: Duration::from_millis(100),
        }
    }
}
75
/// Fluent builder for [`PrefetchConfig`].
#[derive(Debug, Clone, Default)]
pub struct PrefetchConfigBuilder {
    // Accumulates settings, starting from `PrefetchConfig::default()`.
    config: PrefetchConfig,
}
81
82impl PrefetchConfigBuilder {
83 pub fn new() -> Self {
85 Self::default()
86 }
87
88 pub const fn enabled(mut self, enabled: bool) -> Self {
90 self.config.enabled = enabled;
91 self
92 }
93
94 pub const fn prefetch_count(mut self, count: usize) -> Self {
96 self.config.prefetch_count = count;
97 self
98 }
99
100 pub const fn history_size(mut self, size: usize) -> Self {
102 self.config.history_size = size;
103 self
104 }
105
106 pub const fn min_pattern_length(mut self, length: usize) -> Self {
108 self.config.min_pattern_length = length;
109 self
110 }
111
112 pub const fn async_prefetch(mut self, asyncprefetch: bool) -> Self {
114 self.config.async_prefetch = asyncprefetch;
115 self
116 }
117
118 pub const fn prefetch_timeout(mut self, timeout: Duration) -> Self {
120 self.config.prefetch_timeout = timeout;
121 self
122 }
123
124 pub fn build(self) -> PrefetchConfig {
126 self.config
127 }
128}
129
/// Records block accesses and predicts which blocks will be needed next.
pub trait AccessPatternTracker: std::fmt::Debug {
    /// Records that block `blockidx` was accessed.
    fn record_access(&mut self, blockidx: usize);

    /// Predicts up to `count` block indices likely to be accessed next.
    fn predict_next_blocks(&self, count: usize) -> Vec<usize>;

    /// Returns the most recently detected access pattern.
    fn current_pattern(&self) -> AccessPattern;

    /// Forgets all recorded history and resets the detected pattern.
    fn clear_history(&mut self);
}
144
/// Default [`AccessPatternTracker`]: keeps a bounded FIFO of recent block
/// indices and classifies them as sequential, strided, or random.
#[derive(Debug)]
pub struct BlockAccessTracker {
    /// Tuning parameters (history size, minimum pattern length, ...).
    config: PrefetchConfig,

    /// Most recent block indices, oldest at the front.
    history: VecDeque<usize>,

    /// Pattern derived from `history` by `detect_pattern`.
    current_pattern: AccessPattern,

    /// Stride of the last detected `Strided` pattern; not reset when the
    /// pattern later degrades to `Random`.
    stride: Option<usize>,

    /// Time of the most recent recorded access.
    last_update: Instant,
}
163
164impl BlockAccessTracker {
165 pub fn new(config: PrefetchConfig) -> Self {
167 let history_size = config.history_size;
168 Self {
169 config,
170 history: VecDeque::with_capacity(history_size),
171 current_pattern: AccessPattern::Random,
172 stride: None,
173 last_update: Instant::now(),
174 }
175 }
176
177 fn detect_pattern(&mut self) {
179 if self.history.len() < self.config.min_pattern_length {
180 self.current_pattern = AccessPattern::Random;
182 return;
183 }
184
185 let mut is_sequential = true;
187 let front = match self.history.front() {
188 Some(v) => *v,
189 None => {
190 self.current_pattern = AccessPattern::Random;
191 return;
192 }
193 };
194 let mut prev = front;
195
196 for &block_idx in self.history.iter().skip(1) {
197 if block_idx != prev + 1 {
198 is_sequential = false;
199 break;
200 }
201 prev = block_idx;
202 }
203
204 if is_sequential {
205 self.current_pattern = AccessPattern::Sequential;
206 return;
207 }
208
209 let second = match self.history.get(1) {
211 Some(v) => *v,
212 None => {
213 self.current_pattern = AccessPattern::Random;
214 return;
215 }
216 };
217 let front2 = match self.history.front() {
218 Some(v) => *v,
219 None => {
220 self.current_pattern = AccessPattern::Random;
221 return;
222 }
223 };
224 let mut is_strided = true;
225 let stride = if second >= front2 {
227 second - front2
228 } else {
229 self.current_pattern = AccessPattern::Random;
230 return;
231 };
232 prev = front2;
233
234 for &block_idx in self.history.iter().skip(1) {
235 if block_idx != prev + stride {
236 is_strided = false;
237 break;
238 }
239 prev = block_idx;
240 }
241
242 if is_strided {
243 self.current_pattern = AccessPattern::Strided(stride);
244 self.stride = Some(stride);
245 return;
246 }
247
248 self.current_pattern = AccessPattern::Random;
250 }
251}
252
253impl AccessPatternTracker for BlockAccessTracker {
254 fn record_access(&mut self, blockidx: usize) {
255 self.history.push_back(blockidx);
257
258 if self.history.len() > self.config.history_size {
259 self.history.pop_front();
260 }
261
262 if self.history.len() >= self.config.min_pattern_length {
264 self.detect_pattern();
265 }
266
267 self.last_update = Instant::now();
269 }
270
271 fn predict_next_blocks(&self, count: usize) -> Vec<usize> {
272 if self.history.is_empty() {
273 return Vec::new();
274 }
275
276 let mut predictions = Vec::with_capacity(count);
277 let latest = match self.history.back() {
278 Some(v) => *v,
279 None => return Vec::new(),
280 };
281
282 match self.current_pattern {
283 AccessPattern::Sequential => {
284 for i in 1..=count {
286 predictions.push(latest + i);
287 }
288 }
289 AccessPattern::Strided(stride) => {
290 for i in 1..=count {
292 predictions.push(latest + stride * i);
293 }
294 }
295 _ => {
296 if latest > 0 {
299 predictions.push(latest - 1);
300 }
301 predictions.push(latest + 1);
302
303 let mut offset = 2;
305 while predictions.len() < count {
306 if latest >= offset {
307 predictions.push(latest - offset);
308 }
309 predictions.push(latest + offset);
310 offset += 1;
311 }
312
313 predictions.truncate(count);
315 }
316 }
317
318 predictions
319 }
320
321 fn current_pattern(&self) -> AccessPattern {
322 self.current_pattern
323 }
324
325 fn clear_history(&mut self) {
326 self.history.clear();
327 self.current_pattern = AccessPattern::Random;
328 self.stride = None;
329 }
330}
331
/// Shared prefetch bookkeeping: the pattern tracker plus the sets of blocks
/// currently being prefetched and already prefetched.
#[derive(Debug)]
#[allow(dead_code)]
pub struct PrefetchingState {
    /// Active configuration.
    config: PrefetchConfig,

    /// Pattern tracker fed by every recorded access.
    tracker: Box<dyn AccessPatternTracker + Send + Sync>,

    /// Blocks with a prefetch currently in flight.
    prefetching: HashSet<usize>,

    /// Blocks whose prefetch completed and has not yet been consumed by an
    /// access (consumed entries are removed in `idx`).
    prefetched: HashSet<usize>,

    /// Hit/miss counters.
    #[allow(dead_code)]
    stats: PrefetchStats,
}
352
/// Counters describing prefetch effectiveness.
#[derive(Debug, Default, Clone)]
pub struct PrefetchStats {
    /// Total number of completed prefetches.
    pub prefetch_count: usize,

    /// Accesses that found their block already prefetched.
    pub prefetch_hits: usize,

    /// Accesses whose block had not been prefetched.
    pub prefetch_misses: usize,

    /// `prefetch_hits / (prefetch_hits + prefetch_misses)`; 0.0 until the
    /// first recorded access.
    pub hit_rate: f64,
}
368
369impl PrefetchingState {
370 #[allow(dead_code)]
372 pub fn new(config: PrefetchConfig) -> Self {
373 Self {
374 tracker: Box::new(BlockAccessTracker::new(config.clone())),
375 config,
376 prefetching: HashSet::new(),
377 prefetched: HashSet::new(),
378 stats: PrefetchStats::default(),
379 }
380 }
381
382 #[allow(dead_code)]
384 pub fn idx(&mut self, blockidx: usize) {
385 self.tracker.record_access(blockidx);
386
387 if self.prefetched.contains(&blockidx) {
389 self.stats.prefetch_hits += 1;
390 self.prefetched.remove(&blockidx);
391 } else {
392 self.stats.prefetch_misses += 1;
393 }
394
395 let total = self.stats.prefetch_hits + self.stats.prefetch_misses;
397 if total > 0 {
398 self.stats.hit_rate = self.stats.prefetch_hits as f64 / total as f64;
399 }
400 }
401
402 #[allow(dead_code)]
404 pub fn get_blocks_to_prefetch(&self) -> Vec<usize> {
405 if !self.config.enabled {
406 return Vec::new();
407 }
408
409 let predicted = self.tracker.predict_next_blocks(self.config.prefetch_count);
411
412 predicted
414 .into_iter()
415 .filter(|&block_idx| {
416 !self.prefetched.contains(&block_idx) && !self.prefetching.contains(&block_idx)
417 })
418 .collect()
419 }
420
421 #[allow(dead_code)]
423 pub fn idx_2(&mut self, blockidx: usize) {
424 self.prefetching.insert(blockidx);
425 }
426
427 #[allow(dead_code)]
429 pub fn idx_3(&mut self, blockidx: usize) {
430 self.prefetching.remove(&blockidx);
431 self.prefetched.insert(blockidx);
432 self.stats.prefetch_count += 1;
433 }
434
435 #[allow(dead_code)]
437 pub fn stats(&self) -> PrefetchStats {
438 self.stats.clone()
439 }
440}
441
/// Control interface for arrays that support block prefetching.
pub trait Prefetching {
    /// Enables prefetching with `config`, (re)starting workers as needed.
    fn enable_prefetching(&mut self, config: PrefetchConfig) -> CoreResult<()>;

    /// Disables prefetching and stops any background worker.
    fn disable_prefetching(&mut self) -> CoreResult<()>;

    /// Returns current hit/miss statistics.
    fn prefetch_stats(&self) -> CoreResult<PrefetchStats>;

    /// Prefetches a single block by index; a no-op when disabled.
    fn prefetch_block_by_idx_by_idx(&mut self, idx: usize) -> CoreResult<()>;

    /// Prefetches every block index in `indices`; a no-op when disabled.
    fn prefetch_indices(&mut self, indices: &[usize]) -> CoreResult<()>;

    /// Clears prefetch bookkeeping and access history.
    fn clear_prefetch_state(&mut self) -> CoreResult<()>;
}
462
/// A [`CompressedMemMappedArray`] wrapper that tracks block accesses and
/// prefetches predicted blocks, synchronously or on a background thread.
#[cfg(feature = "memory_compression")]
#[derive(Debug)]
pub struct PrefetchingCompressedArray<A: Clone + Copy + 'static + Send + Sync> {
    /// The wrapped compressed memory-mapped array.
    array: CompressedMemMappedArray<A>,

    /// Shared tracker and bookkeeping, also used by the background worker.
    prefetch_state: Arc<Mutex<PrefetchingState>>,

    /// Whether prefetching is currently active.
    prefetching_enabled: bool,

    /// Background worker handle when async prefetching is on.
    #[allow(dead_code)]
    prefetch_thread: Option<std::thread::JoinHandle<()>>,

    /// Command channel to the background worker.
    #[allow(dead_code)]
    prefetch_sender: Option<std::sync::mpsc::Sender<PrefetchCommand>>,
}
484
/// Messages sent to the background prefetch worker.
#[cfg(feature = "memory_compression")]
enum PrefetchCommand {
    /// Prefetch the block with the given index.
    Prefetch(usize),

    /// Shut the worker down.
    Stop,
}
494
#[cfg(feature = "memory_compression")]
impl<A: Clone + Copy + 'static + Send + Sync> PrefetchingCompressedArray<A> {
    /// Wraps `array` with prefetching machinery; prefetching starts disabled.
    pub fn new(array: CompressedMemMappedArray<A>) -> Self {
        let prefetch_state = Arc::new(Mutex::new(PrefetchingState::new(PrefetchConfig::default())));

        Self {
            array,
            prefetch_state,
            prefetching_enabled: false,
            prefetch_thread: None,
            prefetch_sender: None,
        }
    }

    /// Wraps `array` and immediately enables prefetching with `config`.
    ///
    /// # Errors
    /// Propagates any error from [`Prefetching::enable_prefetching`].
    pub fn new_with_config(
        array: CompressedMemMappedArray<A>,
        config: PrefetchConfig,
    ) -> CoreResult<Self> {
        let mut prefetching_array = Self::new(array);
        prefetching_array.enable_prefetching(config)?;
        Ok(prefetching_array)
    }

    /// Spawns the background worker that services prefetch requests.
    ///
    /// The worker blocks on the command channel; on a receive timeout it
    /// opportunistically prefetches the blocks predicted by the tracker.
    fn start_background_prefetching(
        &mut self,
        state: Arc<Mutex<PrefetchingState>>,
    ) -> CoreResult<()> {
        let (sender, receiver) = std::sync::mpsc::channel();
        self.prefetch_sender = Some(sender);

        let array = self.array.clone();
        let prefetch_state = state.clone();

        // Read the timeout from the state that was handed to us -- the one
        // the worker will use. (Previously this locked `self.prefetch_state`,
        // which is only coincidentally the same Arc at the call site.)
        let timeout = {
            let guard = state.lock().map_err(|_| {
                CoreError::MutexError(ErrorContext::new(
                    "Failed to lock prefetch state".to_string(),
                ))
            })?;
            guard.config.prefetch_timeout
        };

        let thread = std::thread::spawn(move || {
            // Prefetch a single block, keeping the in-flight/done sets in
            // sync. Lock failures are deliberately ignored: a poisoned mutex
            // should not kill the worker, it only degrades bookkeeping.
            let prefetch_one = |block_idx: usize| {
                if let Ok(mut guard) = prefetch_state.lock() {
                    guard.idx_2(block_idx);
                }
                if array.preload_block(block_idx).is_ok() {
                    if let Ok(mut guard) = prefetch_state.lock() {
                        guard.idx_3(block_idx);
                    }
                }
            };

            loop {
                match receiver.recv_timeout(timeout) {
                    Ok(PrefetchCommand::Prefetch(block_idx)) => prefetch_one(block_idx),
                    Ok(PrefetchCommand::Stop) => break,
                    Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
                        // Idle: prefetch whatever the tracker predicts next.
                        let blocks = match prefetch_state.lock() {
                            Ok(guard) => guard.get_blocks_to_prefetch(),
                            Err(_) => continue,
                        };
                        for &block_idx in &blocks {
                            prefetch_one(block_idx);
                        }
                    }
                    Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => break,
                }
            }
        });

        self.prefetch_thread = Some(thread);
        Ok(())
    }

    /// Stops the background worker, if any, and joins its thread.
    fn stop_prefetch_thread(&mut self) -> CoreResult<()> {
        if let Some(sender) = self.prefetch_sender.take() {
            sender.send(PrefetchCommand::Stop).map_err(|_| {
                CoreError::ThreadError(ErrorContext::new("Failed to send stop command".to_string()))
            })?;

            if let Some(thread) = self.prefetch_thread.take() {
                thread.join().map_err(|_| {
                    CoreError::ThreadError(ErrorContext::new(
                        "Failed to join prefetch thread".to_string(),
                    ))
                })?;
            }
        }

        Ok(())
    }

    /// Returns the wrapped array.
    pub const fn inner(&self) -> &CompressedMemMappedArray<A> {
        &self.array
    }

    /// Returns the wrapped array mutably.
    pub fn inner_mut(&mut self) -> &mut CompressedMemMappedArray<A> {
        &mut self.array
    }

    /// Queues an asynchronous prefetch of `blockidx`; a no-op when the
    /// background worker is not running.
    fn request_prefetch(&self, blockidx: usize) -> CoreResult<()> {
        if let Some(sender) = &self.prefetch_sender {
            sender
                .send(PrefetchCommand::Prefetch(blockidx))
                .map_err(|_| {
                    CoreError::ThreadError(ErrorContext::new(
                        "Failed to send prefetch command".to_string(),
                    ))
                })?;
        }

        Ok(())
    }
}
655
#[cfg(feature = "memory_compression")]
impl<A: Clone + Copy + 'static + Send + Sync> Prefetching for PrefetchingCompressedArray<A> {
    /// Enables prefetching. When already enabled, updates the configuration
    /// in place if possible, otherwise restarts the prefetching machinery.
    fn enable_prefetching(&mut self, config: PrefetchConfig) -> CoreResult<()> {
        if self.prefetching_enabled {
            let current_config = {
                let guard = self.prefetch_state.lock().map_err(|_| {
                    CoreError::MutexError(ErrorContext::new(
                        "Failed to lock prefetch state".to_string(),
                    ))
                })?;
                guard.config.clone()
            };

            // Only update in place when every setting baked into the running
            // tracker/worker matches; the worker captures `prefetch_timeout`
            // at spawn and the tracker clones the config at construction, so
            // differing values require a restart. (The previous comparison
            // ignored `min_pattern_length` and `prefetch_timeout`, leaving
            // them stale after an in-place update.)
            if current_config.async_prefetch == config.async_prefetch
                && current_config.prefetch_count == config.prefetch_count
                && current_config.history_size == config.history_size
                && current_config.min_pattern_length == config.min_pattern_length
                && current_config.prefetch_timeout == config.prefetch_timeout
            {
                let mut guard = self.prefetch_state.lock().map_err(|_| {
                    CoreError::MutexError(ErrorContext::new(
                        "Failed to lock prefetch state".to_string(),
                    ))
                })?;
                guard.config = config;
                return Ok(());
            }

            self.disable_prefetching()?;
        }

        // Fresh state so history/statistics start clean under the new config.
        let prefetch_state = Arc::new(Mutex::new(PrefetchingState::new(config.clone())));
        self.prefetch_state = prefetch_state.clone();

        if config.async_prefetch {
            self.start_background_prefetching(prefetch_state)?;
        }

        self.prefetching_enabled = true;
        Ok(())
    }

    /// Stops the worker (if running) and marks prefetching disabled.
    fn disable_prefetching(&mut self) -> CoreResult<()> {
        if self.prefetching_enabled {
            self.stop_prefetch_thread()?;

            let mut guard = self.prefetch_state.lock().map_err(|_| {
                CoreError::MutexError(ErrorContext::new(
                    "Failed to lock prefetch state".to_string(),
                ))
            })?;

            guard.config.enabled = false;

            self.prefetching_enabled = false;
        }

        Ok(())
    }

    /// Returns a snapshot of the hit/miss statistics.
    fn prefetch_stats(&self) -> CoreResult<PrefetchStats> {
        let guard = self.prefetch_state.lock().map_err(|_| {
            CoreError::MutexError(ErrorContext::new(
                "Failed to lock prefetch state".to_string(),
            ))
        })?;

        Ok(guard.stats())
    }

    /// Prefetches one block, asynchronously when a worker is running,
    /// otherwise synchronously; skips blocks already prefetched or in flight.
    fn prefetch_block_by_idx_by_idx(&mut self, blockidx: usize) -> CoreResult<()> {
        if !self.prefetching_enabled {
            return Ok(());
        }

        // Read everything needed under a single lock instead of two separate
        // acquisitions (avoids re-locking and the race between the reads).
        let (should_prefetch, is_async) = {
            let guard = self.prefetch_state.lock().map_err(|_| {
                CoreError::MutexError(ErrorContext::new(
                    "Failed to lock prefetch state".to_string(),
                ))
            })?;

            (
                !guard.prefetched.contains(&blockidx) && !guard.prefetching.contains(&blockidx),
                guard.config.async_prefetch,
            )
        };

        if !should_prefetch {
            return Ok(());
        }

        if is_async {
            self.request_prefetch(blockidx)?;
        } else {
            // Synchronous path: mark in flight, load, then mark complete.
            {
                let mut guard = self.prefetch_state.lock().map_err(|_| {
                    CoreError::MutexError(ErrorContext::new(
                        "Failed to lock prefetch state".to_string(),
                    ))
                })?;

                guard.idx_2(blockidx);
            }

            self.array.preload_block(blockidx)?;

            let mut guard = self.prefetch_state.lock().map_err(|_| {
                CoreError::MutexError(ErrorContext::new(
                    "Failed to lock prefetch state".to_string(),
                ))
            })?;

            guard.idx_3(blockidx);
        }

        Ok(())
    }

    /// Prefetches each block in `indices` via `prefetch_block_by_idx_by_idx`.
    fn prefetch_indices(&mut self, indices: &[usize]) -> CoreResult<()> {
        if !self.prefetching_enabled {
            return Ok(());
        }

        for &block_idx in indices {
            self.prefetch_block_by_idx_by_idx(block_idx)?;
        }

        Ok(())
    }

    /// Clears bookkeeping sets and access history; statistics are kept.
    fn clear_prefetch_state(&mut self) -> CoreResult<()> {
        let mut guard = self.prefetch_state.lock().map_err(|_| {
            CoreError::MutexError(ErrorContext::new(
                "Failed to lock prefetch state".to_string(),
            ))
        })?;

        guard.prefetched.clear();
        guard.prefetching.clear();
        guard.tracker.clear_history();

        Ok(())
    }
}
819
#[cfg(feature = "memory_compression")]
impl<A: Clone + Copy + 'static + Send + Sync> CompressedMemMappedArray<A> {
    /// Wraps this array in a [`PrefetchingCompressedArray`] with default
    /// settings; prefetching is not yet enabled.
    pub fn with_prefetching(self) -> PrefetchingCompressedArray<A> {
        PrefetchingCompressedArray::new(self)
    }

    /// Wraps this array and immediately enables prefetching with `config`.
    pub fn with_prefetching_config(
        self,
        config: PrefetchConfig,
    ) -> CoreResult<PrefetchingCompressedArray<A>> {
        PrefetchingCompressedArray::new_with_config(self, config)
    }
}
836
// Deref to the wrapped array so read-only methods not overridden by the
// wrapper remain directly callable on it.
#[cfg(feature = "memory_compression")]
impl<A> std::ops::Deref for PrefetchingCompressedArray<A>
where
    A: Clone + Copy + 'static + Send + Sync,
{
    type Target = CompressedMemMappedArray<A>;

    fn deref(&self) -> &Self::Target {
        &self.array
    }
}
849
#[cfg(feature = "memory_compression")]
impl<A: Clone + Copy + 'static + Send + Sync> PrefetchingCompressedArray<A> {
    /// Reads a single element, recording the access so the tracker can learn
    /// the pattern, and queuing asynchronous prefetches of predicted blocks.
    pub fn get(&self, indices: &[usize]) -> CoreResult<A> {
        let flat_index = self.calculate_flat_index(indices)?;
        let block_idx = flat_index / self.metadata().block_size;

        if self.prefetching_enabled {
            let mut guard = self.prefetch_state.lock().map_err(|_| {
                CoreError::MutexError(ErrorContext::new(
                    "Failed to lock prefetch state".to_string(),
                ))
            })?;

            guard.idx(block_idx);

            let to_prefetch = guard.get_blocks_to_prefetch();

            // Release the lock before queuing requests; the worker needs it.
            drop(guard);

            for &idx in &to_prefetch {
                // Prefetching is best-effort: a failed request must not fail
                // the read itself.
                if let Err(_e) = self.request_prefetch(idx) {}
            }
        }

        self.array.get(indices)
    }

    /// Converts multi-dimensional `indices` into a row-major flat index.
    ///
    /// # Errors
    /// `DimensionError` when the number of indices does not match the array
    /// rank; `IndexError` when any index is out of bounds for its dimension.
    fn calculate_flat_index(&self, indices: &[usize]) -> CoreResult<usize> {
        if indices.len() != self.metadata().shape.len() {
            return Err(CoreError::DimensionError(ErrorContext::new(format!(
                "Expected {} indices, got {}",
                self.metadata().shape.len(),
                indices.len()
            ))));
        }

        // Validate each index against its own dimension (the previous
        // version compared every index against dimension 0's extent and
        // reported dimension 0 in the error message).
        for (dim, &idx) in indices.iter().enumerate() {
            let extent = self.metadata().shape[dim];
            if idx >= extent {
                return Err(CoreError::IndexError(ErrorContext::new(format!(
                    "Index {} out of bounds for dimension {} (max {})",
                    idx,
                    dim,
                    // saturating_sub avoids an underflow panic for a
                    // zero-sized dimension while building the message.
                    extent.saturating_sub(1)
                ))));
            }
        }

        // Row-major flattening: the last dimension varies fastest.
        let mut flat_index = 0;
        let mut stride = 1;
        for i in (0..indices.len()).rev() {
            flat_index += indices[i] * stride;
            if i > 0 {
                stride *= self.metadata().shape[i];
            }
        }

        Ok(flat_index)
    }

    /// Extracts a slice, recording every block the slice touches and queuing
    /// asynchronous prefetches of predicted blocks.
    pub fn slice(
        &self,
        ranges: &[(usize, usize)],
    ) -> CoreResult<crate::ndarray::Array<A, crate::ndarray::IxDyn>> {
        if self.prefetching_enabled {
            let blocks = self.calculate_blocks_for_slice(ranges)?;

            let mut guard = self.prefetch_state.lock().map_err(|_| {
                CoreError::MutexError(ErrorContext::new(
                    "Failed to lock prefetch state".to_string(),
                ))
            })?;

            for &block_idx in &blocks {
                guard.idx(block_idx);
            }

            let to_prefetch = guard.get_blocks_to_prefetch();

            drop(guard);

            for &idx in &to_prefetch {
                // Best-effort, as in `get`.
                if let Err(_e) = self.request_prefetch(idx) {}
            }
        }

        self.array.slice(ranges)
    }

    /// Computes the set of block indices a slice touches.
    ///
    /// Maps every corner of the requested hyper-rectangle to its block, then
    /// conservatively fills in every block between the minimum and the
    /// maximum so interior blocks are not missed.
    fn calculate_blocks_for_slice(&self, ranges: &[(usize, usize)]) -> CoreResult<HashSet<usize>> {
        if ranges.len() != self.metadata().shape.len() {
            return Err(CoreError::DimensionError(ErrorContext::new(format!(
                "Expected {} ranges, got {}",
                self.metadata().shape.len(),
                ranges.len()
            ))));
        }

        // Validate each half-open range against its own dimension (the
        // previous version compared every range against dimension 0's extent
        // and reported dimension 0 in the error messages).
        let mut resultshape = Vec::with_capacity(ranges.len());
        for (dim, &(start, end)) in ranges.iter().enumerate() {
            if start >= end {
                return Err(CoreError::ValueError(ErrorContext::new(format!(
                    "Invalid range for dimension {}: {}..{}",
                    dim, start, end
                ))));
            }
            if end > self.metadata().shape[dim] {
                return Err(CoreError::IndexError(ErrorContext::new(format!(
                    "Range {}..{} out of bounds for dimension {} (max {})",
                    start,
                    end,
                    dim,
                    self.metadata().shape[dim]
                ))));
            }
            resultshape.push(end - start);
        }

        // Row-major strides for the full array shape.
        let mut strides = Vec::with_capacity(self.metadata().shape.len());
        let mut stride = 1;
        for i in (0..self.metadata().shape.len()).rev() {
            strides.push(stride);
            if i > 0 {
                stride *= self.metadata().shape[i];
            }
        }
        strides.reverse();

        let mut blocks = HashSet::new();
        let block_size = self.metadata().block_size;

        // Enumerate all 2^n corners of the slice hyper-rectangle.
        let mut corners = Vec::with_capacity(1 << ranges.len());
        corners.push(vec![0; ranges.len()]);

        for dim in 0..ranges.len() {
            let mut new_corners = Vec::new();
            for corner in &corners {
                let mut corner1 = corner.clone();
                let mut corner2 = corner.clone();
                corner1[dim] = 0;
                corner2[dim] = resultshape[dim] - 1;
                new_corners.push(corner1);
                new_corners.push(corner2);
            }
            corners = new_corners;
        }

        // Map each corner to the block containing it.
        for corner in corners {
            let mut flat_index = 0;
            for (dim, &offset) in corner.iter().enumerate() {
                flat_index += (ranges[dim].0 + offset) * strides[dim];
            }

            let block_idx = flat_index / block_size;
            blocks.insert(block_idx);
        }

        // Corners alone can skip interior blocks of a large slice, so fill
        // in every block between the extremes.
        if blocks.len() > 1 {
            let min_block = match blocks.iter().min() {
                Some(v) => *v,
                None => return Ok(blocks),
            };
            let max_block = match blocks.iter().max() {
                Some(v) => *v,
                None => return Ok(blocks),
            };

            for block_idx in min_block..=max_block {
                blocks.insert(block_idx);
            }
        }

        Ok(blocks)
    }
}
1063
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten consecutive accesses must be classified as `Sequential`, and the
    /// predictions must continue the run past the last accessed block.
    #[test]
    fn test_access_pattern_detection_sequential() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            ..Default::default()
        };

        let mut tracker = BlockAccessTracker::new(config);

        for i in 0..10 {
            tracker.record_access(i);
        }

        assert_eq!(tracker.current_pattern(), AccessPattern::Sequential);

        let predictions = tracker.predict_next_blocks(3);
        assert_eq!(predictions, vec![10, 11, 12]);
    }

    /// Accesses with a constant step of 3 must be classified as
    /// `Strided(3)`, and predictions must continue at that stride.
    #[test]
    fn test_access_pattern_detection_strided() {
        let config = PrefetchConfig {
            min_pattern_length: 4,
            ..Default::default()
        };

        let mut tracker = BlockAccessTracker::new(config);

        // Accesses 0, 3, 6, ..., 27.
        for i in (0..30).step_by(3) {
            tracker.record_access(i);
        }

        assert_eq!(tracker.current_pattern(), AccessPattern::Strided(3));

        let predictions = tracker.predict_next_blocks(3);
        assert_eq!(predictions, vec![30, 33, 36]);
    }

    /// Exercises `PrefetchingState` bookkeeping: prediction after sequential
    /// accesses, and hit/miss counting when a prefetched block is consumed.
    #[test]
    fn test_prefetching_state() {
        let config = PrefetchConfig {
            prefetch_count: 3,
            ..Default::default()
        };

        let mut state = PrefetchingState::new(config);

        // Sequential accesses 0..5; each counts as a miss (nothing prefetched).
        for i in 0..5 {
            state.idx(i);
        }

        // The three blocks after 4 should be suggested next.
        let to_prefetch = state.get_blocks_to_prefetch();
        assert_eq!(to_prefetch, vec![5, 6, 7]);

        // Mark all suggestions as in flight.
        for &block in &to_prefetch {
            state.prefetching.insert(block);
        }

        // Complete the prefetch of block 5 only.
        state.prefetched.insert(5);
        state.prefetching.remove(&5);

        // Accessing block 5 now counts as a hit.
        state.idx(5);

        let stats = state.stats();
        assert_eq!(stats.prefetch_hits, 1);
        assert_eq!(stats.prefetch_misses, 5);
        assert!(stats.hit_rate > 0.0);
    }
}