Skip to main content

tensorlogic_ir/graph/
layout.rs

//! Tensor layout and stride optimization.
//!
//! This module chooses memory layouts (row-major, column-major, blocked, tiled, …)
//! and stride patterns for the tensors of an [`EinsumGraph`] to improve cache
//! utilization and memory access patterns.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use crate::{EinsumGraph, IrError};
11
12/// Memory layout strategy for tensors.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
14pub enum LayoutStrategy {
15    /// Row-major order (C-style, default for most systems)
16    #[default]
17    RowMajor,
18    /// Column-major order (Fortran-style, good for column operations)
19    ColumnMajor,
20    /// Blocked layout for cache-friendly access
21    Blocked { block_size: usize },
22    /// Tiled layout with specific tile dimensions
23    Tiled {
24        tile_height: usize,
25        tile_width: usize,
26    },
27    /// Z-order (Morton) curve for locality preservation
28    ZOrder,
29    /// Hilbert curve for even better locality
30    Hilbert,
31}
32
33impl LayoutStrategy {
34    /// Get the recommended strategy for a given operation pattern.
35    pub fn for_operation(op: &str) -> Self {
36        match op {
37            "matmul" | "einsum" => Self::Blocked { block_size: 32 },
38            "transpose" => Self::ColumnMajor,
39            "conv2d" => Self::Tiled {
40                tile_height: 8,
41                tile_width: 8,
42            },
43            "scan" | "reduce" => Self::RowMajor,
44            _ => Self::default(),
45        }
46    }
47
48    /// Check if this layout benefits from vectorization.
49    pub fn supports_vectorization(&self) -> bool {
50        matches!(
51            self,
52            Self::RowMajor | Self::Blocked { .. } | Self::Tiled { .. }
53        )
54    }
55
56    /// Check if this layout preserves spatial locality.
57    pub fn preserves_locality(&self) -> bool {
58        matches!(
59            self,
60            Self::Blocked { .. } | Self::Tiled { .. } | Self::ZOrder | Self::Hilbert
61        )
62    }
63}
64
65/// Stride pattern for a tensor.
66#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
67pub struct StridePattern {
68    /// Strides for each dimension (in elements)
69    pub strides: Vec<usize>,
70    /// Whether strides are contiguous
71    pub is_contiguous: bool,
72    /// Alignment in bytes (0 means no specific alignment)
73    pub alignment: usize,
74}
75
76impl StridePattern {
77    /// Create a row-major stride pattern for given dimensions.
78    pub fn row_major(dims: &[usize]) -> Self {
79        let mut strides = vec![1];
80        for i in (0..dims.len() - 1).rev() {
81            strides.insert(0, strides[0] * dims[i + 1]);
82        }
83
84        Self {
85            strides,
86            is_contiguous: true,
87            alignment: 0,
88        }
89    }
90
91    /// Create a column-major stride pattern for given dimensions.
92    pub fn column_major(dims: &[usize]) -> Self {
93        let mut strides = vec![1];
94        for i in 0..dims.len() - 1 {
95            strides.push(strides[i] * dims[i]);
96        }
97
98        Self {
99            strides,
100            is_contiguous: true,
101            alignment: 0,
102        }
103    }
104
105    /// Create a custom stride pattern.
106    pub fn custom(strides: Vec<usize>) -> Self {
107        let is_contiguous = is_contiguous_strides(&strides);
108        Self {
109            strides,
110            is_contiguous,
111            alignment: 0,
112        }
113    }
114
115    /// Set the alignment requirement.
116    pub fn with_alignment(mut self, alignment: usize) -> Self {
117        self.alignment = alignment;
118        self
119    }
120
121    /// Check if the stride pattern allows efficient vectorization.
122    pub fn is_vectorizable(&self) -> bool {
123        self.is_contiguous && self.strides.last().copied().unwrap_or(0) == 1
124    }
125
126    /// Estimate memory access cost (lower is better).
127    pub fn access_cost(&self) -> f64 {
128        if self.is_contiguous {
129            1.0
130        } else {
131            // Non-contiguous access is more expensive
132            1.5 + (self.strides.len() as f64 * 0.1)
133        }
134    }
135}
136
137/// Layout configuration for a tensor.
138#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
139pub struct TensorLayout {
140    /// Tensor index
141    pub tensor_idx: usize,
142    /// Layout strategy
143    pub strategy: LayoutStrategy,
144    /// Stride pattern
145    pub strides: StridePattern,
146    /// Whether this layout can be transformed
147    pub is_mutable: bool,
148}
149
150impl TensorLayout {
151    /// Create a new tensor layout.
152    pub fn new(tensor_idx: usize, strategy: LayoutStrategy, dims: &[usize]) -> Self {
153        let strides = match strategy {
154            LayoutStrategy::RowMajor => StridePattern::row_major(dims),
155            LayoutStrategy::ColumnMajor => StridePattern::column_major(dims),
156            _ => StridePattern::row_major(dims), // Default to row-major
157        };
158
159        Self {
160            tensor_idx,
161            strategy,
162            strides,
163            is_mutable: true,
164        }
165    }
166
167    /// Estimate the memory access efficiency (0.0 to 1.0, higher is better).
168    pub fn access_efficiency(&self) -> f64 {
169        let base_efficiency = if self.strides.is_contiguous { 0.9 } else { 0.5 };
170
171        let locality_bonus: f64 = if self.strategy.preserves_locality() {
172            0.1
173        } else {
174            0.0
175        };
176
177        (base_efficiency + locality_bonus).min(1.0f64)
178    }
179}
180
181/// Result of layout optimization.
182#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
183pub struct LayoutOptimizationResult {
184    /// Optimized layouts for each tensor
185    pub layouts: HashMap<usize, TensorLayout>,
186    /// Number of layout transformations required
187    pub transformations_needed: usize,
188    /// Estimated memory access improvement (0.0 to 1.0)
189    pub estimated_improvement: f64,
190    /// Estimated speedup from better layouts
191    pub estimated_speedup: f64,
192}
193
194impl LayoutOptimizationResult {
195    /// Create a result with no optimizations.
196    pub fn none() -> Self {
197        Self {
198            layouts: HashMap::new(),
199            transformations_needed: 0,
200            estimated_improvement: 0.0,
201            estimated_speedup: 1.0,
202        }
203    }
204
205    /// Get the layout for a tensor.
206    pub fn get_layout(&self, tensor_idx: usize) -> Option<&TensorLayout> {
207        self.layouts.get(&tensor_idx)
208    }
209}
210
211/// Optimize tensor layouts for a graph.
212pub fn optimize_layouts(graph: &EinsumGraph) -> Result<LayoutOptimizationResult, IrError> {
213    let mut result = LayoutOptimizationResult::none();
214
215    // Analyze each tensor and choose optimal layout
216    for (tensor_idx, tensor_name) in graph.tensors.iter().enumerate() {
217        // Infer dimensions from tensor name or metadata
218        let dims = infer_dimensions(tensor_name, graph, tensor_idx);
219
220        // Analyze usage pattern to determine best layout
221        let strategy = analyze_usage_pattern(graph, tensor_idx);
222
223        let layout = TensorLayout::new(tensor_idx, strategy, &dims);
224        result.layouts.insert(tensor_idx, layout);
225    }
226
227    // Count needed transformations
228    result.transformations_needed = count_layout_conversions(&result.layouts);
229
230    // Estimate improvements
231    let avg_efficiency: f64 = result
232        .layouts
233        .values()
234        .map(|l| l.access_efficiency())
235        .sum::<f64>()
236        / result.layouts.len().max(1) as f64;
237
238    result.estimated_improvement = (avg_efficiency - 0.7).max(0.0);
239    result.estimated_speedup = 1.0 + result.estimated_improvement * 0.3;
240
241    Ok(result)
242}
243
244/// Apply the recommended layouts to a graph.
245pub fn apply_layouts(
246    graph: &mut EinsumGraph,
247    layouts: &HashMap<usize, TensorLayout>,
248) -> Result<(), IrError> {
249    // Add layout metadata to tensors
250    for (tensor_idx, layout) in layouts {
251        if *tensor_idx < graph.tensors.len() {
252            let mut metadata = graph
253                .get_tensor_metadata(*tensor_idx)
254                .cloned()
255                .unwrap_or_else(crate::Metadata::new);
256
257            metadata
258                .attributes
259                .push(("layout".to_string(), format!("{:?}", layout.strategy)));
260            metadata.attributes.push((
261                "is_contiguous".to_string(),
262                layout.strides.is_contiguous.to_string(),
263            ));
264
265            graph.add_tensor_metadata(*tensor_idx, metadata);
266        }
267    }
268
269    Ok(())
270}
271
272/// Find opportunities for layout fusion (avoiding layout conversions).
273pub fn find_layout_fusion_opportunities(
274    layouts: &HashMap<usize, TensorLayout>,
275) -> Vec<(usize, usize)> {
276    let mut opportunities = Vec::new();
277
278    // Find pairs of tensors that would benefit from the same layout
279    let tensor_indices: Vec<_> = layouts.keys().copied().collect();
280
281    for i in 0..tensor_indices.len() {
282        for j in (i + 1)..tensor_indices.len() {
283            let idx1 = tensor_indices[i];
284            let idx2 = tensor_indices[j];
285
286            if let (Some(layout1), Some(layout2)) = (layouts.get(&idx1), layouts.get(&idx2)) {
287                if layout1.strategy != layout2.strategy && layout1.is_mutable && layout2.is_mutable
288                {
289                    opportunities.push((idx1, idx2));
290                }
291            }
292        }
293    }
294
295    opportunities
296}
297
// Private helper functions
299
300fn infer_dimensions(_tensor_name: &str, _graph: &EinsumGraph, _tensor_idx: usize) -> Vec<usize> {
301    // Try to infer dimensions from tensor name
302    // For now, return a default 2D shape
303    // In a real implementation, this would use shape inference
304    vec![64, 64]
305}
306
307fn analyze_usage_pattern(graph: &EinsumGraph, tensor_idx: usize) -> LayoutStrategy {
308    // Count how tensor is used
309    let mut read_patterns = Vec::new();
310
311    for node in &graph.nodes {
312        if node.inputs.contains(&tensor_idx) {
313            // Analyze how it's accessed
314            let pattern = match &node.op {
315                crate::OpType::Einsum { spec } => analyze_einsum_pattern(spec),
316                crate::OpType::Reduce { .. } => "reduce",
317                crate::OpType::ElemUnary { .. } => "scan",
318                crate::OpType::ElemBinary { .. } => "scan",
319            };
320            read_patterns.push(pattern);
321        }
322    }
323
324    // Choose best layout based on dominant pattern
325    if read_patterns.contains(&"matmul") {
326        LayoutStrategy::Blocked { block_size: 32 }
327    } else if read_patterns.contains(&"transpose") {
328        LayoutStrategy::ColumnMajor
329    } else if read_patterns.contains(&"conv") {
330        LayoutStrategy::Tiled {
331            tile_height: 8,
332            tile_width: 8,
333        }
334    } else {
335        LayoutStrategy::RowMajor
336    }
337}
338
/// Classify an einsum specification into a coarse access pattern.
///
/// Returns:
/// - `"matmul"` when the spec has more than one operand (contains `,`);
/// - `"reduce"` when a single-operand spec's input side is longer than its
///   output side, e.g. `"ijk->ij"`;
/// - `"transpose"` when the output reorders exactly the input's characters,
///   e.g. `"ij->ji"` or `"bij->bji"`;
/// - `"scan"` otherwise. This covers identity maps such as `"ij->ij"`, specs
///   without `->`, and specs with more than one `->`.
fn analyze_einsum_pattern(spec: &str) -> &'static str {
    /// True when `a` and `b` contain the same characters with the same multiplicities.
    fn same_characters(a: &str, b: &str) -> bool {
        let mut left: Vec<char> = a.chars().collect();
        let mut right: Vec<char> = b.chars().collect();
        left.sort_unstable();
        right.sort_unstable();
        left == right
    }

    if spec.contains(',') {
        return "matmul";
    }

    // Exactly one "->" must be present; anything else is treated as a scan.
    let mut sides = spec.split("->");
    match (sides.next(), sides.next(), sides.next()) {
        (Some(input), Some(output), None) => {
            if input.len() > output.len() {
                "reduce"
            } else if input != output && same_characters(input, output) {
                "transpose"
            } else {
                "scan"
            }
        }
        _ => "scan",
    }
}
353
354fn count_layout_conversions(layouts: &HashMap<usize, TensorLayout>) -> usize {
355    // Count tensors that need non-default layout
356    layouts
357        .values()
358        .filter(|l| l.strategy != LayoutStrategy::RowMajor)
359        .count()
360}
361
/// Return `true` if `strides` could be the row-major strides of a dense tensor.
///
/// That holds when:
/// - the innermost (last) stride is 1; and
/// - each outer stride is a positive whole multiple of the stride just inside it.
///
/// The multiple is the inner dimension's extent, which may be 1, so equal
/// adjacent strides are allowed. An empty slice (a scalar) counts as contiguous.
/// There is no cap on dimension size.
///
/// Patterns that are rejected:
/// - zero strides (broadcasting);
/// - strides that shrink going outwards;
/// - strides that are not divisible by their inner neighbour.
fn is_contiguous_strides(strides: &[usize]) -> bool {
    match strides.last() {
        None => true,
        Some(&innermost) if innermost != 1 => false,
        // The `inner != 0` guard comes first so the modulo can never divide by zero.
        Some(_) => strides
            .windows(2)
            .all(|pair| pair[1] != 0 && pair[0] >= pair[1] && pair[0] % pair[1] == 0),
    }
}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing the floating-point heuristics below.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_layout_strategy_default() {
        assert_eq!(LayoutStrategy::default(), LayoutStrategy::RowMajor);
    }

    #[test]
    fn test_layout_strategy_for_operation() {
        assert_eq!(
            LayoutStrategy::for_operation("matmul"),
            LayoutStrategy::Blocked { block_size: 32 }
        );
        assert_eq!(
            LayoutStrategy::for_operation("einsum"),
            LayoutStrategy::Blocked { block_size: 32 }
        );
        assert_eq!(
            LayoutStrategy::for_operation("transpose"),
            LayoutStrategy::ColumnMajor
        );
        assert_eq!(
            LayoutStrategy::for_operation("conv2d"),
            LayoutStrategy::Tiled {
                tile_height: 8,
                tile_width: 8
            }
        );
        assert_eq!(
            LayoutStrategy::for_operation("reduce"),
            LayoutStrategy::RowMajor
        );
        // Unknown operations fall back to the default layout.
        assert_eq!(
            LayoutStrategy::for_operation("unknown_op"),
            LayoutStrategy::default()
        );
    }

    #[test]
    fn test_layout_strategy_vectorization() {
        assert!(LayoutStrategy::RowMajor.supports_vectorization());
        assert!(LayoutStrategy::Blocked { block_size: 32 }.supports_vectorization());
        assert!(LayoutStrategy::Tiled {
            tile_height: 8,
            tile_width: 8
        }
        .supports_vectorization());
        assert!(!LayoutStrategy::ColumnMajor.supports_vectorization());
        assert!(!LayoutStrategy::ZOrder.supports_vectorization());
        assert!(!LayoutStrategy::Hilbert.supports_vectorization());
    }

    #[test]
    fn test_layout_strategy_locality() {
        assert!(LayoutStrategy::Blocked { block_size: 32 }.preserves_locality());
        assert!(LayoutStrategy::Tiled {
            tile_height: 8,
            tile_width: 8
        }
        .preserves_locality());
        assert!(LayoutStrategy::ZOrder.preserves_locality());
        assert!(LayoutStrategy::Hilbert.preserves_locality());
        assert!(!LayoutStrategy::RowMajor.preserves_locality());
        assert!(!LayoutStrategy::ColumnMajor.preserves_locality());
    }

    #[test]
    fn test_stride_pattern_row_major() {
        let dims = vec![4, 8, 16];
        let pattern = StridePattern::row_major(&dims);

        assert_eq!(pattern.strides, vec![128, 16, 1]);
        assert!(pattern.is_contiguous);
        assert!(pattern.is_vectorizable());
        assert_eq!(pattern.alignment, 0);
    }

    #[test]
    fn test_stride_pattern_single_dimension() {
        assert_eq!(StridePattern::row_major(&[5]).strides, vec![1]);
        assert_eq!(StridePattern::column_major(&[5]).strides, vec![1]);
    }

    #[test]
    fn test_stride_pattern_column_major() {
        let dims = vec![4, 8, 16];
        let pattern = StridePattern::column_major(&dims);

        assert_eq!(pattern.strides, vec![1, 4, 32]);
        assert!(pattern.is_contiguous);
    }

    #[test]
    fn test_stride_pattern_custom() {
        let strides = vec![64, 8, 1];
        let pattern = StridePattern::custom(strides.clone());

        assert_eq!(pattern.strides, strides);
        assert!(pattern.is_contiguous);
        assert_eq!(pattern.alignment, 0);
    }

    #[test]
    fn test_stride_pattern_non_contiguous() {
        let strides = vec![100, 10, 2]; // Innermost stride is not 1.
        let pattern = StridePattern::custom(strides);

        assert!(!pattern.is_contiguous);
        assert!(!pattern.is_vectorizable());
    }

    #[test]
    fn test_stride_pattern_with_alignment() {
        let pattern = StridePattern::row_major(&[4, 8]).with_alignment(64);
        assert_eq!(pattern.alignment, 64);
        assert_eq!(pattern.strides, vec![8, 1]);
    }

    #[test]
    fn test_stride_pattern_access_cost() {
        let contiguous = StridePattern::row_major(&[4, 8]);
        let non_contiguous = StridePattern::custom(vec![100, 10, 2]);

        assert!((contiguous.access_cost() - 1.0).abs() < EPS);
        // 1.5 base penalty plus 0.1 for each of the three dimensions.
        assert!((non_contiguous.access_cost() - 1.8).abs() < EPS);
        assert!(contiguous.access_cost() < non_contiguous.access_cost());
    }

    #[test]
    fn test_tensor_layout_creation() {
        let layout = TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]);

        assert_eq!(layout.tensor_idx, 0);
        assert_eq!(layout.strategy, LayoutStrategy::RowMajor);
        assert!(layout.is_mutable);
        assert!(layout.strides.is_contiguous);
        assert_eq!(layout.strides.strides, vec![8, 1]);
    }

    #[test]
    fn test_tensor_layout_strides_follow_strategy() {
        let column = TensorLayout::new(0, LayoutStrategy::ColumnMajor, &[4, 8]);
        assert_eq!(column.strides.strides, vec![1, 4]);

        // Strategies without a dedicated stride model reuse row-major strides.
        let blocked = TensorLayout::new(1, LayoutStrategy::Blocked { block_size: 32 }, &[4, 8]);
        assert_eq!(blocked.strides.strides, vec![8, 1]);
    }

    #[test]
    fn test_tensor_layout_access_efficiency() {
        let row_major = TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]);
        let blocked = TensorLayout::new(0, LayoutStrategy::Blocked { block_size: 32 }, &[4, 8]);

        let row_efficiency = row_major.access_efficiency();
        let blocked_efficiency = blocked.access_efficiency();

        assert!((row_efficiency - 0.9).abs() < EPS);
        assert!((blocked_efficiency - 1.0).abs() < EPS);
        assert!(blocked_efficiency > row_efficiency); // Locality bonus applies.
    }

    #[test]
    fn test_layout_optimization_result_none() {
        let result = LayoutOptimizationResult::none();
        assert!(result.layouts.is_empty());
        assert!(result.get_layout(0).is_none());
        assert_eq!(result.transformations_needed, 0);
        assert_eq!(result.estimated_improvement, 0.0);
        assert_eq!(result.estimated_speedup, 1.0);
    }

    #[test]
    fn test_optimize_layouts_empty_graph() {
        let graph = EinsumGraph::new();
        let result = optimize_layouts(&graph).unwrap();
        assert!(result.layouts.is_empty());
        assert_eq!(result.transformations_needed, 0);
        assert!((result.estimated_speedup - 1.0).abs() < EPS);
    }

    #[test]
    fn test_optimize_layouts_simple_graph() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(crate::EinsumNode::einsum("ik,kj->ij", vec![a, b], vec![c]))
            .unwrap();

        let result = optimize_layouts(&graph).unwrap();
        assert_eq!(result.layouts.len(), 3);

        // Both contraction operands are blocked. The output is never read by a
        // node, so it stays row-major.
        let blocked = LayoutStrategy::Blocked { block_size: 32 };
        assert_eq!(result.get_layout(a).map(|l| l.strategy), Some(blocked));
        assert_eq!(result.get_layout(b).map(|l| l.strategy), Some(blocked));
        assert_eq!(
            result.get_layout(c).map(|l| l.strategy),
            Some(LayoutStrategy::RowMajor)
        );
        assert_eq!(result.transformations_needed, 2);
        assert!(result.estimated_speedup > 1.0);
    }

    #[test]
    fn test_apply_layouts() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");

        let mut layouts = HashMap::new();
        layouts.insert(
            a,
            TensorLayout::new(a, LayoutStrategy::Blocked { block_size: 32 }, &[64, 64]),
        );

        apply_layouts(&mut graph, &layouts).unwrap();

        let metadata = graph
            .get_tensor_metadata(a)
            .expect("apply_layouts should attach metadata");
        assert!(metadata
            .attributes
            .iter()
            .any(|(k, v)| k == "layout" && v == "Blocked { block_size: 32 }"));
        assert!(metadata
            .attributes
            .iter()
            .any(|(k, v)| k == "is_contiguous" && v == "true"));
    }

    #[test]
    fn test_apply_layouts_skips_unknown_tensors() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");

        let mut layouts = HashMap::new();
        let missing = a + 100;
        layouts.insert(
            missing,
            TensorLayout::new(missing, LayoutStrategy::RowMajor, &[4, 8]),
        );

        assert!(apply_layouts(&mut graph, &layouts).is_ok());
    }

    #[test]
    fn test_find_layout_fusion_opportunities() {
        let mut layouts = HashMap::new();

        layouts.insert(0, TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]));
        layouts.insert(
            1,
            TensorLayout::new(1, LayoutStrategy::ColumnMajor, &[4, 8]),
        );
        layouts.insert(2, TensorLayout::new(2, LayoutStrategy::RowMajor, &[4, 8]));

        let opportunities = find_layout_fusion_opportunities(&layouts);
        // Only pairs with differing strategies qualify: {0, 1} and {1, 2}.
        assert_eq!(opportunities.len(), 2);
        assert!(opportunities
            .iter()
            .all(|&(x, y)| x != y && (x == 1 || y == 1)));
    }

    #[test]
    fn test_find_layout_fusion_ignores_immutable_layouts() {
        let mut layouts = HashMap::new();

        layouts.insert(0, TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]));
        let mut frozen = TensorLayout::new(1, LayoutStrategy::ColumnMajor, &[4, 8]);
        frozen.is_mutable = false;
        layouts.insert(1, frozen);

        assert!(find_layout_fusion_opportunities(&layouts).is_empty());
    }

    #[test]
    fn test_analyze_einsum_pattern() {
        assert_eq!(analyze_einsum_pattern("ik,kj->ij"), "matmul");
        assert_eq!(analyze_einsum_pattern("ijk->ij"), "reduce");
        assert_eq!(analyze_einsum_pattern("ij->ij"), "scan");
        assert_eq!(analyze_einsum_pattern("ij"), "scan");
    }

    #[test]
    fn test_is_contiguous_strides() {
        assert!(is_contiguous_strides(&[8, 4, 1]));
        assert!(is_contiguous_strides(&[1]));
        assert!(is_contiguous_strides(&[]));
        assert!(is_contiguous_strides(&[8, 2, 1])); // Valid: dims [?, 4, 2]
        assert!(!is_contiguous_strides(&[8, 4, 2])); // Doesn't end with 1
        assert!(!is_contiguous_strides(&[9, 2, 1])); // Not divisible: 9 % 2 != 0
    }

    #[test]
    fn test_count_layout_conversions() {
        let mut layouts = HashMap::new();

        layouts.insert(0, TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]));
        layouts.insert(
            1,
            TensorLayout::new(1, LayoutStrategy::ColumnMajor, &[4, 8]),
        );
        layouts.insert(
            2,
            TensorLayout::new(2, LayoutStrategy::Blocked { block_size: 32 }, &[4, 8]),
        );

        let conversions = count_layout_conversions(&layouts);
        assert_eq!(conversions, 2); // column-major and blocked need conversion
    }

    #[test]
    fn test_layout_optimization_with_metadata() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");

        // optimize_layouts does not read metadata hints yet; this only checks that
        // annotated and plain tensors both still receive a layout.
        let metadata = crate::Metadata::new().with_attribute("preferred_layout", "blocked");
        graph.add_tensor_metadata(a, metadata);

        let result = optimize_layouts(&graph).unwrap();
        assert!(result.get_layout(a).is_some());
        assert!(result.get_layout(b).is_some());
    }
}