#[cfg(not(feature = "std"))]
extern crate alloc;

#[cfg(not(feature = "std"))]
use alloc::collections::BTreeMap as HashMap;
#[cfg(not(feature = "std"))]
use alloc::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::{
    format,
    string::{String, ToString},
    vec,
    vec::Vec,
};
#[cfg(feature = "std")]
use std::{collections::HashMap, sync::Arc};

use crate::{
    error::{Result, TorshError},
    shape::Shape,
    MemoryFormat,
};
32
/// Memory access pattern inferred from a stream of linear indices.
///
/// `PartialOrd`/`Ord` are derived in addition to `Hash` because the
/// `no_std` build aliases `HashMap` to `alloc::collections::BTreeMap`,
/// whose keys must be `Ord` (e.g. for `pattern_distribution`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AccessPattern {
    /// Consecutive indices (average stride ~ 1).
    Sequential,
    /// Constant stride noticeably larger than 1.
    Strided { stride: usize },
    /// No discernible structure.
    Random,
    /// Average stride matching the innermost (last) dimension length.
    RowMajor,
    /// Average stride matching the outermost (first) dimension length.
    ColumnMajor,
    /// Tiled access with the given block size.
    BlockWise { block_size: usize },
    /// Diagonal traversal.
    Diagonal,
    /// Small, stable strides suggesting broadcast-like reuse.
    Broadcast,
}
53
/// Aggregated access statistics maintained by `AccessTracker`.
#[derive(Debug, Clone)]
pub struct AccessStatistics {
    /// Total number of accesses recorded so far.
    pub total_accesses: u64,
    /// Accesses whose stride (scaled by `size_of::<f32>()`) fit within one cache line.
    pub cache_hits: u64,
    /// Accesses whose stride exceeded one cache line.
    pub cache_misses: u64,
    /// Mean absolute stride over the recent-access window.
    pub average_stride: f64,
    /// Population variance of strides over the recent-access window.
    pub stride_variance: f64,
    /// Most frequently detected pattern (initialized to `Random`).
    pub dominant_pattern: AccessPattern,
    /// Count of how many times each pattern was detected.
    pub pattern_distribution: HashMap<AccessPattern, u64>,
}
72
/// Records a bounded window of linear-index accesses for one tensor and
/// derives cache and pattern statistics from it.
#[derive(Debug, Clone)]
pub struct AccessTracker {
    /// Shape of the tracked tensor (used by the row/column heuristics).
    shape: Shape,
    /// Current memory layout of the tensor.
    memory_format: MemoryFormat,
    /// Most recent linear indices, oldest first; bounded by `max_history`.
    recent_accesses: Vec<usize>,
    /// Maximum number of entries kept in `recent_accesses`.
    max_history: usize,
    /// Statistics aggregated from the access stream.
    stats: AccessStatistics,
    /// Cache line size in bytes used for hit/miss classification.
    cache_line_size: usize,
}
89
90impl AccessTracker {
91 pub fn new(shape: Shape, memory_format: MemoryFormat) -> Self {
93 Self {
94 shape,
95 memory_format,
96 recent_accesses: Vec::with_capacity(1000),
97 max_history: 1000,
98 stats: AccessStatistics {
99 total_accesses: 0,
100 cache_hits: 0,
101 cache_misses: 0,
102 average_stride: 0.0,
103 stride_variance: 0.0,
104 dominant_pattern: AccessPattern::Random,
105 pattern_distribution: HashMap::new(),
106 },
107 cache_line_size: 64, }
109 }
110
111 pub fn with_cache_line_size(mut self, cache_line_size: usize) -> Self {
113 self.cache_line_size = cache_line_size;
114 self
115 }
116
117 pub fn record_access(&mut self, linear_index: usize) {
119 if self.recent_accesses.len() >= self.max_history {
121 self.recent_accesses.remove(0);
122 }
123 self.recent_accesses.push(linear_index);
124
125 self.stats.total_accesses += 1;
127
128 if self.recent_accesses.len() >= 2 {
130 let prev_index = self.recent_accesses[self.recent_accesses.len() - 2];
131 let stride = if linear_index > prev_index {
132 linear_index - prev_index
133 } else {
134 prev_index - linear_index
135 };
136
137 if stride * core::mem::size_of::<f32>() <= self.cache_line_size {
139 self.stats.cache_hits += 1;
140 } else {
141 self.stats.cache_misses += 1;
142 }
143 }
144
145 if self.stats.total_accesses % 100 == 0 {
147 self.analyze_pattern();
148 }
149 }
150
151 fn analyze_pattern(&mut self) {
153 if self.recent_accesses.len() < 10 {
154 return;
155 }
156
157 let mut strides = Vec::new();
159 for i in 1..self.recent_accesses.len() {
160 let stride = if self.recent_accesses[i] > self.recent_accesses[i - 1] {
161 self.recent_accesses[i] - self.recent_accesses[i - 1]
162 } else {
163 self.recent_accesses[i - 1] - self.recent_accesses[i]
164 };
165 strides.push(stride as f64);
166 }
167
168 let sum: f64 = strides.iter().sum();
170 let avg = sum / strides.len() as f64;
171 self.stats.average_stride = avg;
172
173 let variance_sum: f64 = strides.iter().map(|&s| (s - avg).powi(2)).sum();
174 self.stats.stride_variance = variance_sum / strides.len() as f64;
175
176 let pattern = self.detect_pattern(&strides);
178 *self.stats.pattern_distribution.entry(pattern).or_insert(0) += 1;
179
180 if let Some((&dominant, _)) = self
182 .stats
183 .pattern_distribution
184 .iter()
185 .max_by_key(|(_, &count)| count)
186 {
187 self.stats.dominant_pattern = dominant;
188 }
189 }
190
191 fn detect_pattern(&self, strides: &[f64]) -> AccessPattern {
193 if strides.is_empty() {
194 return AccessPattern::Random;
195 }
196
197 let avg = self.stats.average_stride;
198 let variance = self.stats.stride_variance;
199
200 if (avg - 1.0).abs() < 0.1 && variance < 0.5 {
202 return AccessPattern::Sequential;
203 }
204
205 if variance < avg * 0.2 && avg > 1.5 {
207 return AccessPattern::Strided {
208 stride: avg.round() as usize,
209 };
210 }
211
212 if let Some(row_len) = self.shape.dims().last() {
214 if (avg - *row_len as f64).abs() < 0.5 {
215 return AccessPattern::RowMajor;
216 }
217 }
218
219 if let Some(&first_dim) = self.shape.dims().first() {
221 if (avg - first_dim as f64).abs() < 0.5 {
222 return AccessPattern::ColumnMajor;
223 }
224 }
225
226 if variance < 1.0 && avg < 2.0 {
228 return AccessPattern::Broadcast;
229 }
230
231 AccessPattern::Random
233 }
234
235 pub fn statistics(&self) -> &AccessStatistics {
237 &self.stats
238 }
239
240 pub fn cache_hit_rate(&self) -> f64 {
242 if self.stats.total_accesses == 0 {
243 return 0.0;
244 }
245 self.stats.cache_hits as f64 / self.stats.total_accesses as f64
246 }
247}
248
/// A suggested memory-layout change for a tracked tensor.
#[derive(Debug, Clone)]
pub struct LayoutRecommendation {
    /// Layout the tensor currently uses.
    pub current_format: MemoryFormat,
    /// Layout the optimizer suggests (may equal `current_format`).
    pub recommended_format: MemoryFormat,
    /// Estimated relative performance gain (0.0 means none expected).
    pub expected_improvement: f64,
    /// Human-readable justification for the recommendation.
    pub reason: String,
    /// Estimated cost of carrying out the transformation.
    pub transformation_cost: TransformationCost,
}
263
/// Estimated cost of materializing a layout transformation.
#[derive(Debug, Clone)]
pub struct TransformationCost {
    /// Number of full-buffer copies required.
    pub memory_copies: usize,
    /// Estimated wall-clock time in microseconds.
    pub estimated_time_us: f64,
    /// Temporary memory needed during the transformation, in bytes.
    pub memory_overhead_bytes: usize,
}
274
/// Tracks per-tensor access patterns and recommends layout changes.
#[derive(Debug)]
pub struct LayoutOptimizer {
    /// Access trackers keyed by tensor id.
    trackers: HashMap<usize, Arc<AccessTracker>>,
    /// Minimum expected improvement before a recommendation is surfaced.
    optimization_threshold: f64,
    /// When true, also recommends changes for block-wise patterns.
    aggressive: bool,
}
285
impl Default for LayoutOptimizer {
    /// Delegates to [`LayoutOptimizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
291
292impl LayoutOptimizer {
293 pub fn new() -> Self {
295 Self {
296 trackers: HashMap::new(),
297 optimization_threshold: 0.1, aggressive: false,
299 }
300 }
301
302 pub fn with_threshold(mut self, threshold: f64) -> Self {
304 self.optimization_threshold = threshold;
305 self
306 }
307
308 pub fn aggressive(mut self, enabled: bool) -> Self {
310 self.aggressive = enabled;
311 self
312 }
313
314 pub fn register_tensor(&mut self, tensor_id: usize, shape: Shape, format: MemoryFormat) {
316 let tracker = AccessTracker::new(shape, format);
317 self.trackers.insert(tensor_id, Arc::new(tracker));
318 }
319
320 pub fn record_access(&mut self, tensor_id: usize, linear_index: usize) -> Result<()> {
322 if let Some(tracker) = self.trackers.get_mut(&tensor_id) {
323 let mut tracker_mut = (**tracker).clone();
325 tracker_mut.record_access(linear_index);
326 *tracker = Arc::new(tracker_mut);
327 Ok(())
328 } else {
329 Err(TorshError::InvalidArgument(format!(
330 "Tensor {} not registered for tracking",
331 tensor_id
332 )))
333 }
334 }
335
336 pub fn recommend_layout(&self, tensor_id: usize) -> Result<Option<LayoutRecommendation>> {
338 let tracker = self.trackers.get(&tensor_id).ok_or_else(|| {
339 TorshError::InvalidArgument(format!("Tensor {} not registered", tensor_id))
340 })?;
341
342 let stats = tracker.statistics();
343
344 if stats.total_accesses < 100 {
346 return Ok(None);
347 }
348
349 let recommendation = self.analyze_and_recommend(tracker)?;
351
352 if recommendation.expected_improvement >= self.optimization_threshold {
354 Ok(Some(recommendation))
355 } else {
356 Ok(None)
357 }
358 }
359
360 fn analyze_and_recommend(&self, tracker: &AccessTracker) -> Result<LayoutRecommendation> {
362 let stats = tracker.statistics();
363 let current_format = tracker.memory_format;
364 let cache_hit_rate = tracker.cache_hit_rate();
365
366 match stats.dominant_pattern {
367 AccessPattern::Sequential | AccessPattern::RowMajor => {
368 if current_format != MemoryFormat::Contiguous {
370 Ok(LayoutRecommendation {
371 current_format,
372 recommended_format: MemoryFormat::Contiguous,
373 expected_improvement: 0.3, reason: "Sequential/row-major access pattern detected. Contiguous layout will improve cache locality.".to_string(),
375 transformation_cost: self.estimate_cost(&tracker.shape),
376 })
377 } else {
378 Ok(LayoutRecommendation {
379 current_format,
380 recommended_format: current_format,
381 expected_improvement: 0.0,
382 reason: "Already using optimal layout".to_string(),
383 transformation_cost: TransformationCost {
384 memory_copies: 0,
385 estimated_time_us: 0.0,
386 memory_overhead_bytes: 0,
387 },
388 })
389 }
390 }
391 AccessPattern::ColumnMajor => {
392 if current_format != MemoryFormat::ChannelsLast {
394 Ok(LayoutRecommendation {
395 current_format,
396 recommended_format: MemoryFormat::ChannelsLast,
397 expected_improvement: 0.25,
398 reason: "Column-major access detected. ChannelsLast layout will improve stride patterns.".to_string(),
399 transformation_cost: self.estimate_cost(&tracker.shape),
400 })
401 } else {
402 Ok(LayoutRecommendation {
403 current_format,
404 recommended_format: current_format,
405 expected_improvement: 0.0,
406 reason: "Already using optimal layout".to_string(),
407 transformation_cost: TransformationCost {
408 memory_copies: 0,
409 estimated_time_us: 0.0,
410 memory_overhead_bytes: 0,
411 },
412 })
413 }
414 }
415 AccessPattern::Strided { stride } => {
416 let improvement = if cache_hit_rate < 0.5 { 0.4 } else { 0.15 };
418 Ok(LayoutRecommendation {
419 current_format,
420 recommended_format: MemoryFormat::Contiguous,
421 expected_improvement: improvement,
422 reason: format!(
423 "Strided access (stride={}) with low cache hit rate ({}%). Contiguous layout recommended.",
424 stride,
425 (cache_hit_rate * 100.0) as u32
426 ),
427 transformation_cost: self.estimate_cost(&tracker.shape),
428 })
429 }
430 AccessPattern::BlockWise { block_size } => {
431 if self.aggressive {
432 Ok(LayoutRecommendation {
433 current_format,
434 recommended_format: MemoryFormat::Contiguous,
435 expected_improvement: 0.2,
436 reason: format!(
437 "Block-wise access (block_size={}) detected. Consider cache-friendly blocking.",
438 block_size
439 ),
440 transformation_cost: self.estimate_cost(&tracker.shape),
441 })
442 } else {
443 Ok(LayoutRecommendation {
444 current_format,
445 recommended_format: current_format,
446 expected_improvement: 0.0,
447 reason: "Block-wise access requires specialized optimization".to_string(),
448 transformation_cost: TransformationCost {
449 memory_copies: 0,
450 estimated_time_us: 0.0,
451 memory_overhead_bytes: 0,
452 },
453 })
454 }
455 }
456 AccessPattern::Random => {
457 Ok(LayoutRecommendation {
459 current_format,
460 recommended_format: current_format,
461 expected_improvement: 0.0,
462 reason: "Random access pattern - layout optimization unlikely to help"
463 .to_string(),
464 transformation_cost: TransformationCost {
465 memory_copies: 0,
466 estimated_time_us: 0.0,
467 memory_overhead_bytes: 0,
468 },
469 })
470 }
471 AccessPattern::Broadcast => Ok(LayoutRecommendation {
472 current_format,
473 recommended_format: current_format,
474 expected_improvement: 0.0,
475 reason: "Broadcast-like access - current layout is fine".to_string(),
476 transformation_cost: TransformationCost {
477 memory_copies: 0,
478 estimated_time_us: 0.0,
479 memory_overhead_bytes: 0,
480 },
481 }),
482 AccessPattern::Diagonal => Ok(LayoutRecommendation {
483 current_format,
484 recommended_format: current_format,
485 expected_improvement: 0.0,
486 reason: "Diagonal access - specialized algorithm recommended".to_string(),
487 transformation_cost: TransformationCost {
488 memory_copies: 0,
489 estimated_time_us: 0.0,
490 memory_overhead_bytes: 0,
491 },
492 }),
493 }
494 }
495
496 fn estimate_cost(&self, shape: &Shape) -> TransformationCost {
498 let numel = shape.numel();
499 let element_size = 4; let total_bytes = numel * element_size;
501
502 let copy_time_us = (total_bytes as f64 / 10_000.0) * 1_000_000.0;
504
505 TransformationCost {
506 memory_copies: 1,
507 estimated_time_us: copy_time_us,
508 memory_overhead_bytes: total_bytes,
509 }
510 }
511
512 pub fn tracked_tensors(&self) -> Vec<usize> {
514 self.trackers.keys().copied().collect()
515 }
516
517 pub fn get_statistics(&self, tensor_id: usize) -> Result<AccessStatistics> {
519 let tracker = self.trackers.get(&tensor_id).ok_or_else(|| {
520 TorshError::InvalidArgument(format!("Tensor {} not registered", tensor_id))
521 })?;
522 Ok(tracker.statistics().clone())
523 }
524
525 pub fn clear_tensor(&mut self, tensor_id: usize) {
527 self.trackers.remove(&tensor_id);
528 }
529
530 pub fn clear_all(&mut self) {
532 self.trackers.clear();
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh tracker starts with zeroed statistics.
    #[test]
    fn test_access_tracker_creation() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);
        assert_eq!(tracker.statistics().total_accesses, 0);
    }

    // Unit-stride accesses: all counted, mostly within one cache line.
    #[test]
    fn test_sequential_access_detection() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);

        for i in 0..1000 {
            tracker.record_access(i);
        }

        let stats = tracker.statistics();
        assert!(stats.total_accesses == 1000);
        assert!(stats.cache_hits > stats.cache_misses); }

    // Stride-10 accesses: the average stride should reflect the jump size.
    #[test]
    fn test_strided_access_detection() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);

        for i in 0..100 {
            tracker.record_access(i * 10);
        }

        let stats = tracker.statistics();
        assert!(stats.total_accesses == 100);
        assert!(stats.average_stride > 8.0);
    }

    // Scattered indices are still counted even if no pattern emerges.
    #[test]
    fn test_random_access_detection() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);

        let indices = [42, 1000, 5, 9999, 50, 7500, 200];
        for &idx in &indices {
            tracker.record_access(idx);
        }

        let stats = tracker.statistics();
        assert!(stats.total_accesses == indices.len() as u64);
    }

    // Sequential accesses should yield a high cache hit rate.
    #[test]
    fn test_cache_hit_rate() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);

        for i in 0..100 {
            tracker.record_access(i);
        }

        let hit_rate = tracker.cache_hit_rate();
        assert!(hit_rate > 0.5); }

    // A new optimizer tracks nothing.
    #[test]
    fn test_layout_optimizer_creation() {
        let optimizer = LayoutOptimizer::new();
        assert!(optimizer.tracked_tensors().is_empty());
    }

    // Registration makes the tensor id visible in tracked_tensors().
    #[test]
    fn test_register_and_track_tensor() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");

        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);
        assert_eq!(optimizer.tracked_tensors().len(), 1);
        assert!(optimizer.tracked_tensors().contains(&1));
    }

    // Accesses routed through the optimizer reach the tensor's tracker.
    #[test]
    fn test_record_access() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");

        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);

        for i in 0..50 {
            optimizer
                .record_access(1, i)
                .expect("record_access should succeed");
        }

        let stats = optimizer
            .get_statistics(1)
            .expect("get_statistics should succeed");
        assert_eq!(stats.total_accesses, 50);
    }

    // Sequential access on a strided tensor should trigger a
    // recommendation to switch to the contiguous layout.
    #[test]
    fn test_optimization_recommendation() {
        let mut optimizer = LayoutOptimizer::new().with_threshold(0.05);
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");

        optimizer.register_tensor(1, shape, MemoryFormat::Strided);

        for i in 0..200 {
            optimizer
                .record_access(1, i)
                .expect("record_access should succeed");
        }

        let recommendation = optimizer
            .recommend_layout(1)
            .expect("recommend_layout should succeed");
        assert!(recommendation.is_some());

        if let Some(rec) = recommendation {
            assert_eq!(rec.recommended_format, MemoryFormat::Contiguous);
            assert!(rec.expected_improvement > 0.0);
        }
    }

    // Fewer than 100 accesses: no recommendation is produced.
    #[test]
    fn test_insufficient_data_no_recommendation() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");

        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);

        for i in 0..10 {
            optimizer
                .record_access(1, i)
                .expect("record_access should succeed");
        }

        let recommendation = optimizer
            .recommend_layout(1)
            .expect("recommend_layout should succeed");
        assert!(recommendation.is_none()); }

    // clear_tensor removes the tracker for that id.
    #[test]
    fn test_clear_tensor() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");

        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);
        assert_eq!(optimizer.tracked_tensors().len(), 1);

        optimizer.clear_tensor(1);
        assert!(optimizer.tracked_tensors().is_empty());
    }

    // The aggressive() builder flips the internal flag.
    #[test]
    fn test_aggressive_optimization() {
        let optimizer = LayoutOptimizer::new().aggressive(true);
        assert!(optimizer.aggressive);
    }

    // Cost estimation returns non-zero copies, time, and memory.
    #[test]
    fn test_transformation_cost_estimation() {
        let optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([1000, 1000]).expect("shape creation should succeed");

        let cost = optimizer.estimate_cost(&shape);
        assert!(cost.memory_copies > 0);
        assert!(cost.estimated_time_us > 0.0);
        assert!(cost.memory_overhead_bytes > 0);
    }

    // with_cache_line_size overrides the default 64-byte line.
    #[test]
    fn test_custom_cache_line_size() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let tracker = AccessTracker::new(shape, MemoryFormat::Contiguous).with_cache_line_size(128);

        assert_eq!(tracker.cache_line_size, 128);
    }
}