Skip to main content

torsh_core/
layout_optimizer.rs

1//! Automatic Memory Layout Optimization based on Access Patterns
2//!
3//! This module provides intelligent memory layout optimization by:
4//! - Tracking tensor access patterns at runtime
5//! - Analyzing access patterns to determine optimal memory layouts
6//! - Recommending layout transformations for performance improvement
7//! - Providing cache-aware optimization strategies
8//!
9//! # SciRS2 POLICY COMPLIANCE
10//! This module uses scirs2_core abstractions exclusively:
11//! - ✅ Uses torsh_core::numeric for numerical traits
12//! - ✅ Uses torsh_core::parallel for parallel operations (when enabled)
13//! - ❌ NO direct external dependencies
14
#[cfg(not(feature = "std"))]
use alloc::{
    collections::VecDeque,
    format,
    string::{String, ToString},
    vec,
    vec::Vec,
};
#[cfg(feature = "std")]
use std::{
    collections::{HashMap, VecDeque},
    sync::Arc,
};

#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::collections::BTreeMap as HashMap;
#[cfg(not(feature = "std"))]
use alloc::sync::Arc;

use crate::{
    error::{Result, TorshError},
    shape::Shape,
    MemoryFormat,
};
32
/// Access pattern types that can be detected.
///
/// `PartialOrd`/`Ord` are derived so this type can be used as the key of the
/// `BTreeMap` that replaces `HashMap` in `no_std` builds (`BTreeMap` requires
/// `Ord` keys). The ordering follows variant declaration order, then the
/// variant's field values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum AccessPattern {
    /// Sequential access (consecutive elements)
    Sequential,
    /// Strided access (regular stride pattern)
    Strided { stride: usize },
    /// Random access (no clear pattern)
    Random,
    /// Row-major access (scanning rows)
    RowMajor,
    /// Column-major access (scanning columns)
    ColumnMajor,
    /// Block-wise access (accessing blocks of data)
    BlockWise { block_size: usize },
    /// Diagonal access (accessing diagonal elements)
    Diagonal,
    /// Broadcast-like access (repeated access to same elements)
    Broadcast,
}
53
/// Statistics about memory access patterns, accumulated by an `AccessTracker`.
#[derive(Debug, Clone)]
pub struct AccessStatistics {
    /// Total number of accesses recorded
    pub total_accesses: u64,
    /// Number of estimated cache hits: accesses whose distance from the previous
    /// access, assuming 4-byte (`f32`) elements, fits within one cache line
    pub cache_hits: u64,
    /// Number of estimated cache misses: accesses whose distance from the
    /// previous access exceeds one cache line
    pub cache_misses: u64,
    /// Mean absolute stride (in elements) between consecutive accesses in the
    /// most recently analyzed history window
    pub average_stride: f64,
    /// Population variance (not the standard deviation) of the absolute stride
    /// over the same history window
    pub stride_variance: f64,
    /// Pattern detected most often across analysis passes
    pub dominant_pattern: AccessPattern,
    /// Number of analysis passes that classified the history window as each pattern
    pub pattern_distribution: HashMap<AccessPattern, u64>,
}
72
73/// Access pattern tracker for a tensor
74#[derive(Debug, Clone)]
75pub struct AccessTracker {
76    /// Tensor shape being tracked
77    shape: Shape,
78    /// Current memory format
79    memory_format: MemoryFormat,
80    /// Recent access indices (circular buffer)
81    recent_accesses: Vec<usize>,
82    /// Maximum size of access history
83    max_history: usize,
84    /// Statistics accumulator
85    stats: AccessStatistics,
86    /// Cache line size (in bytes)
87    cache_line_size: usize,
88}
89
90impl AccessTracker {
91    /// Create a new access tracker
92    pub fn new(shape: Shape, memory_format: MemoryFormat) -> Self {
93        Self {
94            shape,
95            memory_format,
96            recent_accesses: Vec::with_capacity(1000),
97            max_history: 1000,
98            stats: AccessStatistics {
99                total_accesses: 0,
100                cache_hits: 0,
101                cache_misses: 0,
102                average_stride: 0.0,
103                stride_variance: 0.0,
104                dominant_pattern: AccessPattern::Random,
105                pattern_distribution: HashMap::new(),
106            },
107            cache_line_size: 64, // Common cache line size
108        }
109    }
110
111    /// Create with custom cache line size
112    pub fn with_cache_line_size(mut self, cache_line_size: usize) -> Self {
113        self.cache_line_size = cache_line_size;
114        self
115    }
116
117    /// Record a memory access
118    pub fn record_access(&mut self, linear_index: usize) {
119        // Add to recent accesses
120        if self.recent_accesses.len() >= self.max_history {
121            self.recent_accesses.remove(0);
122        }
123        self.recent_accesses.push(linear_index);
124
125        // Update statistics
126        self.stats.total_accesses += 1;
127
128        // Estimate cache hit/miss based on access pattern
129        if self.recent_accesses.len() >= 2 {
130            let prev_index = self.recent_accesses[self.recent_accesses.len() - 2];
131            let stride = if linear_index > prev_index {
132                linear_index - prev_index
133            } else {
134                prev_index - linear_index
135            };
136
137            // If stride is within cache line, likely a cache hit
138            if stride * core::mem::size_of::<f32>() <= self.cache_line_size {
139                self.stats.cache_hits += 1;
140            } else {
141                self.stats.cache_misses += 1;
142            }
143        }
144
145        // Analyze pattern periodically
146        if self.stats.total_accesses % 100 == 0 {
147            self.analyze_pattern();
148        }
149    }
150
151    /// Analyze the access pattern
152    fn analyze_pattern(&mut self) {
153        if self.recent_accesses.len() < 10 {
154            return;
155        }
156
157        // Calculate stride statistics
158        let mut strides = Vec::new();
159        for i in 1..self.recent_accesses.len() {
160            let stride = if self.recent_accesses[i] > self.recent_accesses[i - 1] {
161                self.recent_accesses[i] - self.recent_accesses[i - 1]
162            } else {
163                self.recent_accesses[i - 1] - self.recent_accesses[i]
164            };
165            strides.push(stride as f64);
166        }
167
168        // Calculate average and variance
169        let sum: f64 = strides.iter().sum();
170        let avg = sum / strides.len() as f64;
171        self.stats.average_stride = avg;
172
173        let variance_sum: f64 = strides.iter().map(|&s| (s - avg).powi(2)).sum();
174        self.stats.stride_variance = variance_sum / strides.len() as f64;
175
176        // Detect pattern based on stride statistics
177        let pattern = self.detect_pattern(&strides);
178        *self.stats.pattern_distribution.entry(pattern).or_insert(0) += 1;
179
180        // Update dominant pattern
181        if let Some((&dominant, _)) = self
182            .stats
183            .pattern_distribution
184            .iter()
185            .max_by_key(|(_, &count)| count)
186        {
187            self.stats.dominant_pattern = dominant;
188        }
189    }
190
191    /// Detect specific access pattern from stride data
192    fn detect_pattern(&self, strides: &[f64]) -> AccessPattern {
193        if strides.is_empty() {
194            return AccessPattern::Random;
195        }
196
197        let avg = self.stats.average_stride;
198        let variance = self.stats.stride_variance;
199
200        // Sequential: average stride ~1, low variance
201        if (avg - 1.0).abs() < 0.1 && variance < 0.5 {
202            return AccessPattern::Sequential;
203        }
204
205        // Strided: consistent stride, low variance
206        if variance < avg * 0.2 && avg > 1.5 {
207            return AccessPattern::Strided {
208                stride: avg.round() as usize,
209            };
210        }
211
212        // Row-major: stride equals row length
213        if let Some(row_len) = self.shape.dims().last() {
214            if (avg - *row_len as f64).abs() < 0.5 {
215                return AccessPattern::RowMajor;
216            }
217        }
218
219        // Column-major: stride equals column height
220        if let Some(&first_dim) = self.shape.dims().first() {
221            if (avg - first_dim as f64).abs() < 0.5 {
222                return AccessPattern::ColumnMajor;
223            }
224        }
225
226        // Broadcast: very low variance, repeated accesses
227        if variance < 1.0 && avg < 2.0 {
228            return AccessPattern::Broadcast;
229        }
230
231        // Default to random
232        AccessPattern::Random
233    }
234
235    /// Get current statistics
236    pub fn statistics(&self) -> &AccessStatistics {
237        &self.stats
238    }
239
240    /// Get cache hit rate
241    pub fn cache_hit_rate(&self) -> f64 {
242        if self.stats.total_accesses == 0 {
243            return 0.0;
244        }
245        self.stats.cache_hits as f64 / self.stats.total_accesses as f64
246    }
247}
248
/// Layout optimization recommendation produced by `LayoutOptimizer`.
///
/// When no change is advised, `recommended_format == current_format`,
/// `expected_improvement` is 0.0 and the transformation cost is all zeros.
#[derive(Debug, Clone)]
pub struct LayoutRecommendation {
    /// Current memory format
    pub current_format: MemoryFormat,
    /// Recommended memory format
    pub recommended_format: MemoryFormat,
    /// Expected performance improvement as a fraction (0.0 to 1.0); a heuristic
    /// estimate, not a measurement
    pub expected_improvement: f64,
    /// Human-readable reason for the recommendation
    pub reason: String,
    /// Estimated cost of converting to `recommended_format`
    pub transformation_cost: TransformationCost,
}
263
/// Estimated cost of transforming a tensor's memory layout.
#[derive(Debug, Clone)]
pub struct TransformationCost {
    /// Number of full-tensor memory copies required
    pub memory_copies: usize,
    /// Estimated time in microseconds
    pub estimated_time_us: f64,
    /// Extra memory (in bytes) needed while the transformation is in progress
    pub memory_overhead_bytes: usize,
}
274
/// Layout optimizer that analyzes access patterns and recommends layouts.
///
/// Tensors are identified by caller-supplied `usize` ids. Each registered id
/// owns an `AccessTracker` that accumulates its access history.
#[derive(Debug)]
pub struct LayoutOptimizer {
    /// Access trackers keyed by tensor id
    trackers: HashMap<usize, Arc<AccessTracker>>,
    /// Minimum `expected_improvement` a recommendation must reach to be returned
    optimization_threshold: f64,
    /// Enable aggressive optimizations (currently only affects block-wise patterns)
    aggressive: bool,
}
285
286impl Default for LayoutOptimizer {
287    fn default() -> Self {
288        Self::new()
289    }
290}
291
292impl LayoutOptimizer {
293    /// Create a new layout optimizer
294    pub fn new() -> Self {
295        Self {
296            trackers: HashMap::new(),
297            optimization_threshold: 0.1, // 10% improvement threshold
298            aggressive: false,
299        }
300    }
301
302    /// Create with custom optimization threshold
303    pub fn with_threshold(mut self, threshold: f64) -> Self {
304        self.optimization_threshold = threshold;
305        self
306    }
307
308    /// Enable aggressive optimizations (may use more memory)
309    pub fn aggressive(mut self, enabled: bool) -> Self {
310        self.aggressive = enabled;
311        self
312    }
313
314    /// Register a tensor for tracking
315    pub fn register_tensor(&mut self, tensor_id: usize, shape: Shape, format: MemoryFormat) {
316        let tracker = AccessTracker::new(shape, format);
317        self.trackers.insert(tensor_id, Arc::new(tracker));
318    }
319
320    /// Record an access for a tensor
321    pub fn record_access(&mut self, tensor_id: usize, linear_index: usize) -> Result<()> {
322        if let Some(tracker) = self.trackers.get_mut(&tensor_id) {
323            // Make mutable copy for modification
324            let mut tracker_mut = (**tracker).clone();
325            tracker_mut.record_access(linear_index);
326            *tracker = Arc::new(tracker_mut);
327            Ok(())
328        } else {
329            Err(TorshError::InvalidArgument(format!(
330                "Tensor {} not registered for tracking",
331                tensor_id
332            )))
333        }
334    }
335
336    /// Get optimization recommendation for a tensor
337    pub fn recommend_layout(&self, tensor_id: usize) -> Result<Option<LayoutRecommendation>> {
338        let tracker = self.trackers.get(&tensor_id).ok_or_else(|| {
339            TorshError::InvalidArgument(format!("Tensor {} not registered", tensor_id))
340        })?;
341
342        let stats = tracker.statistics();
343
344        // Need sufficient data for recommendation
345        if stats.total_accesses < 100 {
346            return Ok(None);
347        }
348
349        // Analyze dominant pattern and recommend layout
350        let recommendation = self.analyze_and_recommend(tracker)?;
351
352        // Only recommend if improvement exceeds threshold
353        if recommendation.expected_improvement >= self.optimization_threshold {
354            Ok(Some(recommendation))
355        } else {
356            Ok(None)
357        }
358    }
359
360    /// Analyze pattern and generate recommendation
361    fn analyze_and_recommend(&self, tracker: &AccessTracker) -> Result<LayoutRecommendation> {
362        let stats = tracker.statistics();
363        let current_format = tracker.memory_format;
364        let cache_hit_rate = tracker.cache_hit_rate();
365
366        match stats.dominant_pattern {
367            AccessPattern::Sequential | AccessPattern::RowMajor => {
368                // Row-major access benefits from contiguous layout
369                if current_format != MemoryFormat::Contiguous {
370                    Ok(LayoutRecommendation {
371                        current_format,
372                        recommended_format: MemoryFormat::Contiguous,
373                        expected_improvement: 0.3, // 30% improvement
374                        reason: "Sequential/row-major access pattern detected. Contiguous layout will improve cache locality.".to_string(),
375                        transformation_cost: self.estimate_cost(&tracker.shape),
376                    })
377                } else {
378                    Ok(LayoutRecommendation {
379                        current_format,
380                        recommended_format: current_format,
381                        expected_improvement: 0.0,
382                        reason: "Already using optimal layout".to_string(),
383                        transformation_cost: TransformationCost {
384                            memory_copies: 0,
385                            estimated_time_us: 0.0,
386                            memory_overhead_bytes: 0,
387                        },
388                    })
389                }
390            }
391            AccessPattern::ColumnMajor => {
392                // Column-major access benefits from channels-last layout
393                if current_format != MemoryFormat::ChannelsLast {
394                    Ok(LayoutRecommendation {
395                        current_format,
396                        recommended_format: MemoryFormat::ChannelsLast,
397                        expected_improvement: 0.25,
398                        reason: "Column-major access detected. ChannelsLast layout will improve stride patterns.".to_string(),
399                        transformation_cost: self.estimate_cost(&tracker.shape),
400                    })
401                } else {
402                    Ok(LayoutRecommendation {
403                        current_format,
404                        recommended_format: current_format,
405                        expected_improvement: 0.0,
406                        reason: "Already using optimal layout".to_string(),
407                        transformation_cost: TransformationCost {
408                            memory_copies: 0,
409                            estimated_time_us: 0.0,
410                            memory_overhead_bytes: 0,
411                        },
412                    })
413                }
414            }
415            AccessPattern::Strided { stride } => {
416                // Large strides indicate poor cache locality
417                let improvement = if cache_hit_rate < 0.5 { 0.4 } else { 0.15 };
418                Ok(LayoutRecommendation {
419                    current_format,
420                    recommended_format: MemoryFormat::Contiguous,
421                    expected_improvement: improvement,
422                    reason: format!(
423                        "Strided access (stride={}) with low cache hit rate ({}%). Contiguous layout recommended.",
424                        stride,
425                        (cache_hit_rate * 100.0) as u32
426                    ),
427                    transformation_cost: self.estimate_cost(&tracker.shape),
428                })
429            }
430            AccessPattern::BlockWise { block_size } => {
431                if self.aggressive {
432                    Ok(LayoutRecommendation {
433                        current_format,
434                        recommended_format: MemoryFormat::Contiguous,
435                        expected_improvement: 0.2,
436                        reason: format!(
437                            "Block-wise access (block_size={}) detected. Consider cache-friendly blocking.",
438                            block_size
439                        ),
440                        transformation_cost: self.estimate_cost(&tracker.shape),
441                    })
442                } else {
443                    Ok(LayoutRecommendation {
444                        current_format,
445                        recommended_format: current_format,
446                        expected_improvement: 0.0,
447                        reason: "Block-wise access requires specialized optimization".to_string(),
448                        transformation_cost: TransformationCost {
449                            memory_copies: 0,
450                            estimated_time_us: 0.0,
451                            memory_overhead_bytes: 0,
452                        },
453                    })
454                }
455            }
456            AccessPattern::Random => {
457                // Random access doesn't benefit much from layout changes
458                Ok(LayoutRecommendation {
459                    current_format,
460                    recommended_format: current_format,
461                    expected_improvement: 0.0,
462                    reason: "Random access pattern - layout optimization unlikely to help"
463                        .to_string(),
464                    transformation_cost: TransformationCost {
465                        memory_copies: 0,
466                        estimated_time_us: 0.0,
467                        memory_overhead_bytes: 0,
468                    },
469                })
470            }
471            AccessPattern::Broadcast => Ok(LayoutRecommendation {
472                current_format,
473                recommended_format: current_format,
474                expected_improvement: 0.0,
475                reason: "Broadcast-like access - current layout is fine".to_string(),
476                transformation_cost: TransformationCost {
477                    memory_copies: 0,
478                    estimated_time_us: 0.0,
479                    memory_overhead_bytes: 0,
480                },
481            }),
482            AccessPattern::Diagonal => Ok(LayoutRecommendation {
483                current_format,
484                recommended_format: current_format,
485                expected_improvement: 0.0,
486                reason: "Diagonal access - specialized algorithm recommended".to_string(),
487                transformation_cost: TransformationCost {
488                    memory_copies: 0,
489                    estimated_time_us: 0.0,
490                    memory_overhead_bytes: 0,
491                },
492            }),
493        }
494    }
495
496    /// Estimate transformation cost
497    fn estimate_cost(&self, shape: &Shape) -> TransformationCost {
498        let numel = shape.numel();
499        let element_size = 4; // Assume f32 for estimation
500        let total_bytes = numel * element_size;
501
502        // Memory copy cost: ~10 GB/s throughput
503        let copy_time_us = (total_bytes as f64 / 10_000.0) * 1_000_000.0;
504
505        TransformationCost {
506            memory_copies: 1,
507            estimated_time_us: copy_time_us,
508            memory_overhead_bytes: total_bytes,
509        }
510    }
511
512    /// Get all tracked tensor IDs
513    pub fn tracked_tensors(&self) -> Vec<usize> {
514        self.trackers.keys().copied().collect()
515    }
516
517    /// Get statistics for a tensor
518    pub fn get_statistics(&self, tensor_id: usize) -> Result<AccessStatistics> {
519        let tracker = self.trackers.get(&tensor_id).ok_or_else(|| {
520            TorshError::InvalidArgument(format!("Tensor {} not registered", tensor_id))
521        })?;
522        Ok(tracker.statistics().clone())
523    }
524
525    /// Clear tracking data for a tensor
526    pub fn clear_tensor(&mut self, tensor_id: usize) {
527        self.trackers.remove(&tensor_id);
528    }
529
530    /// Clear all tracking data
531    pub fn clear_all(&mut self) {
532        self.trackers.clear();
533    }
534}
535
// Unit tests for the access tracker and the layout optimizer. Some tests touch
// private members (`aggressive`, `cache_line_size`, `estimate_cost`); this works
// because `tests` is a child module of the module under test.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_access_tracker_creation() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);
        assert_eq!(tracker.statistics().total_accesses, 0);
    }

    #[test]
    fn test_sequential_access_detection() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);

        // Simulate sequential access
        for i in 0..1000 {
            tracker.record_access(i);
        }

        let stats = tracker.statistics();
        assert!(stats.total_accesses == 1000);
        assert!(stats.cache_hits > stats.cache_misses); // Sequential should have good cache hits
    }

    #[test]
    fn test_strided_access_detection() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);

        // Simulate strided access (every 10th element). 10 elements * 4 bytes =
        // 40 bytes still fits a 64-byte line, so these count as estimated hits;
        // only the stride statistics distinguish this case.
        for i in 0..100 {
            tracker.record_access(i * 10);
        }

        let stats = tracker.statistics();
        assert!(stats.total_accesses == 100);
        // Strided access should show in average_stride
        assert!(stats.average_stride > 8.0);
    }

    #[test]
    fn test_random_access_detection() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);

        // Simulate random access. With only 7 accesses, the periodic pattern
        // analysis (every 100 accesses) never runs; only counting is checked.
        let indices = [42, 1000, 5, 9999, 50, 7500, 200];
        for &idx in &indices {
            tracker.record_access(idx);
        }

        let stats = tracker.statistics();
        assert!(stats.total_accesses == indices.len() as u64);
    }

    #[test]
    fn test_cache_hit_rate() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let mut tracker = AccessTracker::new(shape, MemoryFormat::Contiguous);

        // Sequential access should have high cache hit rate: 99 of 100 accesses
        // are hits (the first access has no predecessor and is not classified).
        for i in 0..100 {
            tracker.record_access(i);
        }

        let hit_rate = tracker.cache_hit_rate();
        assert!(hit_rate > 0.5); // Should have >50% hit rate
    }

    #[test]
    fn test_layout_optimizer_creation() {
        let optimizer = LayoutOptimizer::new();
        assert!(optimizer.tracked_tensors().is_empty());
    }

    #[test]
    fn test_register_and_track_tensor() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");

        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);
        assert_eq!(optimizer.tracked_tensors().len(), 1);
        assert!(optimizer.tracked_tensors().contains(&1));
    }

    #[test]
    fn test_record_access() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");

        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);

        for i in 0..50 {
            optimizer
                .record_access(1, i)
                .expect("record_access should succeed");
        }

        // Accesses recorded through the optimizer must be visible in its statistics.
        let stats = optimizer
            .get_statistics(1)
            .expect("get_statistics should succeed");
        assert_eq!(stats.total_accesses, 50);
    }

    #[test]
    fn test_optimization_recommendation() {
        let mut optimizer = LayoutOptimizer::new().with_threshold(0.05);
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");

        // Start from a non-contiguous format so a conversion can be recommended.
        optimizer.register_tensor(1, shape, MemoryFormat::Strided);

        // Simulate sequential access pattern (200 accesses: enough for two
        // analysis passes and above the 100-access recommendation minimum)
        for i in 0..200 {
            optimizer
                .record_access(1, i)
                .expect("record_access should succeed");
        }

        let recommendation = optimizer
            .recommend_layout(1)
            .expect("recommend_layout should succeed");
        assert!(recommendation.is_some());

        if let Some(rec) = recommendation {
            // Should recommend Contiguous for sequential access
            assert_eq!(rec.recommended_format, MemoryFormat::Contiguous);
            assert!(rec.expected_improvement > 0.0);
        }
    }

    #[test]
    fn test_insufficient_data_no_recommendation() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");

        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);

        // Only a few accesses (below the 100-access minimum)
        for i in 0..10 {
            optimizer
                .record_access(1, i)
                .expect("record_access should succeed");
        }

        let recommendation = optimizer
            .recommend_layout(1)
            .expect("recommend_layout should succeed");
        assert!(recommendation.is_none()); // Not enough data
    }

    #[test]
    fn test_clear_tensor() {
        let mut optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");

        optimizer.register_tensor(1, shape, MemoryFormat::Contiguous);
        assert_eq!(optimizer.tracked_tensors().len(), 1);

        optimizer.clear_tensor(1);
        assert!(optimizer.tracked_tensors().is_empty());
    }

    #[test]
    fn test_aggressive_optimization() {
        // Reads the private `aggressive` field directly.
        let optimizer = LayoutOptimizer::new().aggressive(true);
        assert!(optimizer.aggressive);
    }

    #[test]
    fn test_transformation_cost_estimation() {
        let optimizer = LayoutOptimizer::new();
        let shape = Shape::from_array([1000, 1000]).expect("shape creation should succeed");

        // A full-tensor conversion must report a nonzero copy, time and overhead.
        let cost = optimizer.estimate_cost(&shape);
        assert!(cost.memory_copies > 0);
        assert!(cost.estimated_time_us > 0.0);
        assert!(cost.memory_overhead_bytes > 0);
    }

    #[test]
    fn test_custom_cache_line_size() {
        let shape = Shape::from_array([100, 100]).expect("shape creation should succeed");
        let tracker = AccessTracker::new(shape, MemoryFormat::Contiguous).with_cache_line_size(128);

        assert_eq!(tracker.cache_line_size, 128);
    }
}