Skip to main content

tensorlogic_ir/graph/
tiling.rs

1//! Loop tiling and unrolling optimizations for cache locality.
2//!
3//! This module provides advanced loop transformation techniques to improve
4//! memory access patterns and cache utilization in tensor computations.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use crate::{EinsumGraph, EinsumNode, IrError, OpType};
11
12/// Tiling configuration for a specific axis.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub struct TileConfig {
15    /// The axis/dimension to tile
16    pub axis: usize,
17    /// Tile size (number of elements per tile)
18    pub tile_size: usize,
19    /// Whether to unroll the inner loop
20    pub unroll: bool,
21}
22
23impl TileConfig {
24    /// Create a new tiling configuration.
25    pub fn new(axis: usize, tile_size: usize) -> Self {
26        Self {
27            axis,
28            tile_size,
29            unroll: false,
30        }
31    }
32
33    /// Create a tiling configuration with unrolling enabled.
34    pub fn with_unroll(axis: usize, tile_size: usize) -> Self {
35        Self {
36            axis,
37            tile_size,
38            unroll: true,
39        }
40    }
41}
42
43/// Multi-dimensional tiling strategy.
44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45pub struct TilingStrategy {
46    /// Tiling configurations for each axis
47    pub tiles: Vec<TileConfig>,
48    /// Whether to apply register tiling (very small tiles for register reuse)
49    pub register_tiling: bool,
50    /// Cache line size in bytes (for alignment)
51    pub cache_line_size: usize,
52}
53
54impl Default for TilingStrategy {
55    fn default() -> Self {
56        Self {
57            tiles: Vec::new(),
58            register_tiling: false,
59            cache_line_size: 64, // Common cache line size
60        }
61    }
62}
63
64impl TilingStrategy {
65    /// Create a new tiling strategy.
66    pub fn new() -> Self {
67        Self::default()
68    }
69
70    /// Add a tile configuration for a specific axis.
71    pub fn add_tile(&mut self, config: TileConfig) -> &mut Self {
72        self.tiles.push(config);
73        self
74    }
75
76    /// Enable register tiling for maximum register reuse.
77    pub fn with_register_tiling(mut self) -> Self {
78        self.register_tiling = true;
79        self
80    }
81
82    /// Set the cache line size for alignment optimization.
83    pub fn with_cache_line_size(mut self, size: usize) -> Self {
84        self.cache_line_size = size;
85        self
86    }
87
88    /// Get recommended tile sizes for matrix multiplication (M×K @ K×N).
89    pub fn for_matmul(m: usize, k: usize, n: usize) -> Self {
90        // Recommended tile sizes based on typical L1/L2 cache sizes
91        let tile_m = m.clamp(8, 64);
92        let tile_k = k.clamp(8, 64);
93        let tile_n = n.clamp(8, 64);
94
95        let mut strategy = Self::new();
96        strategy.add_tile(TileConfig::new(0, tile_m)); // M dimension
97        strategy.add_tile(TileConfig::new(1, tile_k)); // K dimension
98        strategy.add_tile(TileConfig::new(2, tile_n)); // N dimension
99        strategy
100    }
101
102    /// Get recommended tile sizes for convolution operations.
103    pub fn for_conv(
104        batch: usize,
105        out_channels: usize,
106        out_height: usize,
107        out_width: usize,
108    ) -> Self {
109        let tile_b = batch.clamp(1, 16);
110        let tile_c = out_channels.clamp(1, 16);
111        let tile_h = out_height.clamp(1, 8);
112        let tile_w = out_width.clamp(1, 8);
113
114        let mut strategy = Self::new();
115        strategy.add_tile(TileConfig::new(0, tile_b));
116        strategy.add_tile(TileConfig::new(1, tile_c));
117        strategy.add_tile(TileConfig::new(2, tile_h));
118        strategy.add_tile(TileConfig::new(3, tile_w));
119        strategy
120    }
121}
122
123/// Result of applying tiling transformations.
124#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
125pub struct TilingResult {
126    /// Number of nodes that were tiled
127    pub nodes_tiled: usize,
128    /// Number of loops unrolled
129    pub loops_unrolled: usize,
130    /// Estimated cache hit rate improvement (0.0 to 1.0)
131    pub estimated_cache_improvement: f64,
132    /// Estimated speedup factor
133    pub estimated_speedup: f64,
134}
135
136impl TilingResult {
137    /// Create a new tiling result with no transformations.
138    pub fn none() -> Self {
139        Self {
140            nodes_tiled: 0,
141            loops_unrolled: 0,
142            estimated_cache_improvement: 0.0,
143            estimated_speedup: 1.0,
144        }
145    }
146}
147
148/// Apply loop tiling to einsum operations in the graph.
149pub fn apply_tiling(
150    graph: &mut EinsumGraph,
151    strategy: &TilingStrategy,
152) -> Result<TilingResult, IrError> {
153    let mut result = TilingResult::none();
154
155    for node in &mut graph.nodes {
156        if let OpType::Einsum { spec } = &node.op {
157            if should_tile_einsum(spec) {
158                // Apply tiling transformation
159                tile_einsum_node(node, strategy)?;
160                result.nodes_tiled += 1;
161
162                // Count unrolled loops
163                for tile in &strategy.tiles {
164                    if tile.unroll {
165                        result.loops_unrolled += 1;
166                    }
167                }
168            }
169        }
170    }
171
172    // Estimate performance improvements
173    if result.nodes_tiled > 0 {
174        result.estimated_cache_improvement = estimate_cache_improvement(strategy);
175        result.estimated_speedup = 1.0 + result.estimated_cache_improvement * 0.5;
176    }
177
178    Ok(result)
179}
180
181/// Apply register-level tiling for maximum register reuse.
182pub fn apply_register_tiling(graph: &mut EinsumGraph) -> Result<TilingResult, IrError> {
183    let mut strategy = TilingStrategy::new().with_register_tiling();
184
185    // Use small tile sizes (4-8 elements) for register tiling
186    strategy.add_tile(TileConfig::with_unroll(0, 4));
187    strategy.add_tile(TileConfig::with_unroll(1, 4));
188
189    apply_tiling(graph, &strategy)
190}
191
192/// Apply multi-level tiling (L1, L2, L3 cache hierarchy).
193pub fn apply_multilevel_tiling(
194    graph: &mut EinsumGraph,
195    l1_tiles: &[usize],
196    l2_tiles: &[usize],
197    l3_tiles: &[usize],
198) -> Result<TilingResult, IrError> {
199    let mut total_result = TilingResult::none();
200
201    // Apply L3 tiles first (outermost)
202    if !l3_tiles.is_empty() {
203        let mut strategy = TilingStrategy::new();
204        for (i, &tile_size) in l3_tiles.iter().enumerate() {
205            strategy.add_tile(TileConfig::new(i, tile_size));
206        }
207        let result = apply_tiling(graph, &strategy)?;
208        total_result.nodes_tiled += result.nodes_tiled;
209    }
210
211    // Apply L2 tiles
212    if !l2_tiles.is_empty() {
213        let mut strategy = TilingStrategy::new();
214        for (i, &tile_size) in l2_tiles.iter().enumerate() {
215            strategy.add_tile(TileConfig::new(i, tile_size));
216        }
217        let result = apply_tiling(graph, &strategy)?;
218        total_result.nodes_tiled += result.nodes_tiled;
219    }
220
221    // Apply L1 tiles with unrolling (innermost)
222    if !l1_tiles.is_empty() {
223        let mut strategy = TilingStrategy::new();
224        for (i, &tile_size) in l1_tiles.iter().enumerate() {
225            strategy.add_tile(TileConfig::with_unroll(i, tile_size));
226        }
227        let result = apply_tiling(graph, &strategy)?;
228        total_result.nodes_tiled += result.nodes_tiled;
229        total_result.loops_unrolled += result.loops_unrolled;
230    }
231
232    // Estimate combined improvements
233    total_result.estimated_cache_improvement = 0.3; // Conservative estimate
234    total_result.estimated_speedup = 1.5; // Typical speedup for multi-level tiling
235
236    Ok(total_result)
237}
238
239/// Analyze a graph and recommend optimal tiling strategies.
240pub fn recommend_tiling_strategy(graph: &EinsumGraph) -> HashMap<usize, TilingStrategy> {
241    let mut recommendations = HashMap::new();
242
243    for (node_idx, node) in graph.nodes.iter().enumerate() {
244        if let OpType::Einsum { spec } = &node.op {
245            if let Some(strategy) = analyze_einsum_for_tiling(spec) {
246                recommendations.insert(node_idx, strategy);
247            }
248        }
249    }
250
251    recommendations
252}
253
254// Helper functions
255
/// Decides whether an einsum spec is worth tiling.
///
/// A spec qualifies only if it has an explicit output (`->`), and then either
/// has more than one operand (contains `,`) or is longer than 6 bytes.
fn should_tile_einsum(spec: &str) -> bool {
    if !spec.contains("->") {
        return false;
    }
    let has_multiple_operands = spec.contains(',');
    let is_long_spec = spec.len() > 6;
    has_multiple_operands || is_long_spec
}
260
261fn tile_einsum_node(node: &mut EinsumNode, strategy: &TilingStrategy) -> Result<(), IrError> {
262    // In a real implementation, this would transform the einsum specification
263    // to include tiling metadata. For now, we just annotate it.
264
265    // Add tiling metadata to the node
266    if node.metadata.is_none() {
267        node.metadata = Some(crate::Metadata::new());
268    }
269
270    if let Some(metadata) = &mut node.metadata {
271        metadata.attributes.push((
272            "tiling_strategy".to_string(),
273            format!("{} tiles", strategy.tiles.len()),
274        ));
275        metadata.attributes.push((
276            "register_tiling".to_string(),
277            strategy.register_tiling.to_string(),
278        ));
279    }
280
281    Ok(())
282}
283
284fn estimate_cache_improvement(strategy: &TilingStrategy) -> f64 {
285    // Estimate based on number of tiling levels and tile sizes
286    let base_improvement = 0.2; // 20% baseline
287    let per_tile_improvement = 0.1; // 10% per tiling dimension
288    let register_bonus = if strategy.register_tiling { 0.15 } else { 0.0 };
289
290    let total =
291        base_improvement + (strategy.tiles.len() as f64 * per_tile_improvement) + register_bonus;
292
293    total.min(0.8) // Cap at 80% improvement
294}
295
296fn analyze_einsum_for_tiling(spec: &str) -> Option<TilingStrategy> {
297    // Parse einsum spec to determine appropriate tiling
298    if let Some(arrow_pos) = spec.find("->") {
299        let inputs = &spec[..arrow_pos];
300        let output = &spec[arrow_pos + 2..];
301
302        // Detect matrix multiplication pattern (e.g., "ik,kj->ij")
303        if inputs.contains(',') {
304            let parts: Vec<&str> = inputs.split(',').collect();
305            if parts.len() == 2 {
306                let a_axes = parts[0].trim();
307                let b_axes = parts[1].trim();
308
309                // Check for matmul-like pattern
310                if a_axes.len() == 2 && b_axes.len() == 2 && output.len() == 2 {
311                    let mut strategy = TilingStrategy::new();
312                    strategy.add_tile(TileConfig::new(0, 32)); // M dimension
313                    strategy.add_tile(TileConfig::new(1, 32)); // K dimension
314                    strategy.add_tile(TileConfig::new(2, 32)); // N dimension
315                    return Some(strategy);
316                }
317            }
318        }
319
320        // For reductions, use smaller tiles
321        if output.len() < inputs.replace(',', "").len() {
322            let mut strategy = TilingStrategy::new();
323            strategy.add_tile(TileConfig::new(0, 16));
324            return Some(strategy);
325        }
326    }
327
328    None
329}
330
// Unit tests for the tiling configuration builders, the einsum heuristics and
// the graph-level tiling passes. Graph tests build a small matmul graph
// (`ik,kj->ij` over tensors A, B -> C) and check the reported counts/estimates.
#[cfg(test)]
mod tests {
    use super::*;

    // Constructors set the expected axis/size and the unroll flag.
    #[test]
    fn test_tile_config_creation() {
        let config = TileConfig::new(0, 32);
        assert_eq!(config.axis, 0);
        assert_eq!(config.tile_size, 32);
        assert!(!config.unroll);

        let config_unroll = TileConfig::with_unroll(1, 16);
        assert_eq!(config_unroll.axis, 1);
        assert_eq!(config_unroll.tile_size, 16);
        assert!(config_unroll.unroll);
    }

    #[test]
    fn test_tiling_strategy_builder() {
        let mut strategy = TilingStrategy::new();
        strategy.add_tile(TileConfig::new(0, 32));
        strategy.add_tile(TileConfig::new(1, 32));

        assert_eq!(strategy.tiles.len(), 2);
        assert!(!strategy.register_tiling);
    }

    #[test]
    fn test_matmul_tiling_strategy() {
        let strategy = TilingStrategy::for_matmul(128, 128, 128);
        assert_eq!(strategy.tiles.len(), 3);
        assert!(strategy.tiles[0].tile_size <= 64);
    }

    #[test]
    fn test_conv_tiling_strategy() {
        let strategy = TilingStrategy::for_conv(32, 64, 56, 56);
        assert_eq!(strategy.tiles.len(), 4);
    }

    #[test]
    fn test_should_tile_einsum() {
        assert!(should_tile_einsum("ik,kj->ij"));
        assert!(should_tile_einsum("ijk->ij"));
        assert!(!should_tile_einsum("i->i"));
    }

    // Matmul specs get the 3-axis strategy; reductions get a strategy too.
    #[test]
    fn test_analyze_einsum_for_tiling() {
        let strategy = analyze_einsum_for_tiling("ik,kj->ij");
        assert!(strategy.is_some());
        let s = strategy.unwrap();
        assert_eq!(s.tiles.len(), 3);

        let strategy_reduction = analyze_einsum_for_tiling("ijk->ij");
        assert!(strategy_reduction.is_some());
    }

    #[test]
    fn test_apply_tiling_to_graph() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::einsum("ik,kj->ij", vec![a, b], vec![c]))
            .unwrap();

        let strategy = TilingStrategy::for_matmul(64, 64, 64);
        let result = apply_tiling(&mut graph, &strategy).unwrap();

        assert_eq!(result.nodes_tiled, 1);
        assert!(result.estimated_speedup >= 1.0);
    }

    // Register tiling uses unrolled tiles, so some loops must be reported unrolled.
    #[test]
    fn test_register_tiling() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::einsum("ik,kj->ij", vec![a, b], vec![c]))
            .unwrap();

        let result = apply_register_tiling(&mut graph).unwrap();
        assert_eq!(result.nodes_tiled, 1);
        assert!(result.loops_unrolled > 0);
    }

    #[test]
    fn test_multilevel_tiling() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::einsum("ik,kj->ij", vec![a, b], vec![c]))
            .unwrap();

        let l1_tiles = vec![8, 8, 8];
        let l2_tiles = vec![32, 32, 32];
        let l3_tiles = vec![128, 128, 128];

        let result = apply_multilevel_tiling(&mut graph, &l1_tiles, &l2_tiles, &l3_tiles).unwrap();
        assert!(result.nodes_tiled > 0);
        assert!(result.estimated_speedup > 1.0);
    }

    // Only einsum nodes are analyzed; the element-wise unary node is skipped.
    #[test]
    fn test_recommend_tiling_strategy() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");

        // Matrix multiplication
        graph
            .add_node(EinsumNode::einsum("ik,kj->ij", vec![a, b], vec![c]))
            .unwrap();

        // Element-wise operation (should not be tiled)
        graph
            .add_node(EinsumNode::elem_unary("relu", c, d))
            .unwrap();

        let recommendations = recommend_tiling_strategy(&graph);
        assert_eq!(recommendations.len(), 1); // Only matmul should have recommendation
        assert!(recommendations.contains_key(&0));
    }

    // The estimate stays within (0, 0.8], and register tiling strictly increases it.
    #[test]
    fn test_estimate_cache_improvement() {
        let mut strategy = TilingStrategy::new();
        strategy.add_tile(TileConfig::new(0, 32));
        strategy.add_tile(TileConfig::new(1, 32));

        let improvement = estimate_cache_improvement(&strategy);
        assert!(improvement > 0.0 && improvement <= 0.8);

        let strategy_with_register = strategy.with_register_tiling();
        let improvement_with_register = estimate_cache_improvement(&strategy_with_register);
        assert!(improvement_with_register > improvement);
    }

    #[test]
    fn test_tiling_result_none() {
        let result = TilingResult::none();
        assert_eq!(result.nodes_tiled, 0);
        assert_eq!(result.loops_unrolled, 0);
        assert_eq!(result.estimated_cache_improvement, 0.0);
        assert_eq!(result.estimated_speedup, 1.0);
    }

    #[test]
    fn test_tiling_with_metadata() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::einsum("ik,kj->ij", vec![a, b], vec![c]))
            .unwrap();

        let strategy = TilingStrategy::for_matmul(64, 64, 64);
        apply_tiling(&mut graph, &strategy).unwrap();

        // Check that metadata was added
        let node = &graph.nodes[0];
        assert!(node.metadata.is_some());
        if let Some(metadata) = &node.metadata {
            assert!(metadata.get_attribute("tiling_strategy").is_some());
        }
    }

    #[test]
    fn test_cache_line_size_configuration() {
        let strategy = TilingStrategy::new().with_cache_line_size(128);
        assert_eq!(strategy.cache_line_size, 128);
    }

    #[test]
    fn test_small_matrix_tiling() {
        // Test tiling for small matrices (should use minimum tile sizes)
        let strategy = TilingStrategy::for_matmul(4, 4, 4);
        assert_eq!(strategy.tiles.len(), 3);
        // All tiles should be at least 8 (the minimum)
        for tile in &strategy.tiles {
            assert!(tile.tile_size >= 8);
        }
    }

    #[test]
    fn test_large_matrix_tiling() {
        // Test tiling for large matrices (should cap at 64)
        let strategy = TilingStrategy::for_matmul(1024, 1024, 1024);
        assert_eq!(strategy.tiles.len(), 3);
        // All tiles should be capped at 64
        for tile in &strategy.tiles {
            assert!(tile.tile_size <= 64);
        }
    }
}