1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use crate::{EinsumGraph, EinsumNode, IrError, OpType};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub struct TileConfig {
15 pub axis: usize,
17 pub tile_size: usize,
19 pub unroll: bool,
21}
22
23impl TileConfig {
24 pub fn new(axis: usize, tile_size: usize) -> Self {
26 Self {
27 axis,
28 tile_size,
29 unroll: false,
30 }
31 }
32
33 pub fn with_unroll(axis: usize, tile_size: usize) -> Self {
35 Self {
36 axis,
37 tile_size,
38 unroll: true,
39 }
40 }
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45pub struct TilingStrategy {
46 pub tiles: Vec<TileConfig>,
48 pub register_tiling: bool,
50 pub cache_line_size: usize,
52}
53
54impl Default for TilingStrategy {
55 fn default() -> Self {
56 Self {
57 tiles: Vec::new(),
58 register_tiling: false,
59 cache_line_size: 64, }
61 }
62}
63
64impl TilingStrategy {
65 pub fn new() -> Self {
67 Self::default()
68 }
69
70 pub fn add_tile(&mut self, config: TileConfig) -> &mut Self {
72 self.tiles.push(config);
73 self
74 }
75
76 pub fn with_register_tiling(mut self) -> Self {
78 self.register_tiling = true;
79 self
80 }
81
82 pub fn with_cache_line_size(mut self, size: usize) -> Self {
84 self.cache_line_size = size;
85 self
86 }
87
88 pub fn for_matmul(m: usize, k: usize, n: usize) -> Self {
90 let tile_m = m.clamp(8, 64);
92 let tile_k = k.clamp(8, 64);
93 let tile_n = n.clamp(8, 64);
94
95 let mut strategy = Self::new();
96 strategy.add_tile(TileConfig::new(0, tile_m)); strategy.add_tile(TileConfig::new(1, tile_k)); strategy.add_tile(TileConfig::new(2, tile_n)); strategy
100 }
101
102 pub fn for_conv(
104 batch: usize,
105 out_channels: usize,
106 out_height: usize,
107 out_width: usize,
108 ) -> Self {
109 let tile_b = batch.clamp(1, 16);
110 let tile_c = out_channels.clamp(1, 16);
111 let tile_h = out_height.clamp(1, 8);
112 let tile_w = out_width.clamp(1, 8);
113
114 let mut strategy = Self::new();
115 strategy.add_tile(TileConfig::new(0, tile_b));
116 strategy.add_tile(TileConfig::new(1, tile_c));
117 strategy.add_tile(TileConfig::new(2, tile_h));
118 strategy.add_tile(TileConfig::new(3, tile_w));
119 strategy
120 }
121}
122
123#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
125pub struct TilingResult {
126 pub nodes_tiled: usize,
128 pub loops_unrolled: usize,
130 pub estimated_cache_improvement: f64,
132 pub estimated_speedup: f64,
134}
135
136impl TilingResult {
137 pub fn none() -> Self {
139 Self {
140 nodes_tiled: 0,
141 loops_unrolled: 0,
142 estimated_cache_improvement: 0.0,
143 estimated_speedup: 1.0,
144 }
145 }
146}
147
148pub fn apply_tiling(
150 graph: &mut EinsumGraph,
151 strategy: &TilingStrategy,
152) -> Result<TilingResult, IrError> {
153 let mut result = TilingResult::none();
154
155 for node in &mut graph.nodes {
156 if let OpType::Einsum { spec } = &node.op {
157 if should_tile_einsum(spec) {
158 tile_einsum_node(node, strategy)?;
160 result.nodes_tiled += 1;
161
162 for tile in &strategy.tiles {
164 if tile.unroll {
165 result.loops_unrolled += 1;
166 }
167 }
168 }
169 }
170 }
171
172 if result.nodes_tiled > 0 {
174 result.estimated_cache_improvement = estimate_cache_improvement(strategy);
175 result.estimated_speedup = 1.0 + result.estimated_cache_improvement * 0.5;
176 }
177
178 Ok(result)
179}
180
181pub fn apply_register_tiling(graph: &mut EinsumGraph) -> Result<TilingResult, IrError> {
183 let mut strategy = TilingStrategy::new().with_register_tiling();
184
185 strategy.add_tile(TileConfig::with_unroll(0, 4));
187 strategy.add_tile(TileConfig::with_unroll(1, 4));
188
189 apply_tiling(graph, &strategy)
190}
191
192pub fn apply_multilevel_tiling(
194 graph: &mut EinsumGraph,
195 l1_tiles: &[usize],
196 l2_tiles: &[usize],
197 l3_tiles: &[usize],
198) -> Result<TilingResult, IrError> {
199 let mut total_result = TilingResult::none();
200
201 if !l3_tiles.is_empty() {
203 let mut strategy = TilingStrategy::new();
204 for (i, &tile_size) in l3_tiles.iter().enumerate() {
205 strategy.add_tile(TileConfig::new(i, tile_size));
206 }
207 let result = apply_tiling(graph, &strategy)?;
208 total_result.nodes_tiled += result.nodes_tiled;
209 }
210
211 if !l2_tiles.is_empty() {
213 let mut strategy = TilingStrategy::new();
214 for (i, &tile_size) in l2_tiles.iter().enumerate() {
215 strategy.add_tile(TileConfig::new(i, tile_size));
216 }
217 let result = apply_tiling(graph, &strategy)?;
218 total_result.nodes_tiled += result.nodes_tiled;
219 }
220
221 if !l1_tiles.is_empty() {
223 let mut strategy = TilingStrategy::new();
224 for (i, &tile_size) in l1_tiles.iter().enumerate() {
225 strategy.add_tile(TileConfig::with_unroll(i, tile_size));
226 }
227 let result = apply_tiling(graph, &strategy)?;
228 total_result.nodes_tiled += result.nodes_tiled;
229 total_result.loops_unrolled += result.loops_unrolled;
230 }
231
232 total_result.estimated_cache_improvement = 0.3; total_result.estimated_speedup = 1.5; Ok(total_result)
237}
238
239pub fn recommend_tiling_strategy(graph: &EinsumGraph) -> HashMap<usize, TilingStrategy> {
241 let mut recommendations = HashMap::new();
242
243 for (node_idx, node) in graph.nodes.iter().enumerate() {
244 if let OpType::Einsum { spec } = &node.op {
245 if let Some(strategy) = analyze_einsum_for_tiling(spec) {
246 recommendations.insert(node_idx, strategy);
247 }
248 }
249 }
250
251 recommendations
252}
253
/// Decides whether an einsum `spec` should be tiled.
///
/// The spec must contain `->`. Beyond that it qualifies if it has a comma
/// (more than one operand) or is longer than 6 bytes.
fn should_tile_einsum(spec: &str) -> bool {
    if !spec.contains("->") {
        return false;
    }

    let has_multiple_operands = spec.contains(',');
    let is_long_spec = spec.len() > 6;
    has_multiple_operands || is_long_spec
}
260
261fn tile_einsum_node(node: &mut EinsumNode, strategy: &TilingStrategy) -> Result<(), IrError> {
262 if node.metadata.is_none() {
267 node.metadata = Some(crate::Metadata::new());
268 }
269
270 if let Some(metadata) = &mut node.metadata {
271 metadata.attributes.push((
272 "tiling_strategy".to_string(),
273 format!("{} tiles", strategy.tiles.len()),
274 ));
275 metadata.attributes.push((
276 "register_tiling".to_string(),
277 strategy.register_tiling.to_string(),
278 ));
279 }
280
281 Ok(())
282}
283
284fn estimate_cache_improvement(strategy: &TilingStrategy) -> f64 {
285 let base_improvement = 0.2; let per_tile_improvement = 0.1; let register_bonus = if strategy.register_tiling { 0.15 } else { 0.0 };
289
290 let total =
291 base_improvement + (strategy.tiles.len() as f64 * per_tile_improvement) + register_bonus;
292
293 total.min(0.8) }
295
296fn analyze_einsum_for_tiling(spec: &str) -> Option<TilingStrategy> {
297 if let Some(arrow_pos) = spec.find("->") {
299 let inputs = &spec[..arrow_pos];
300 let output = &spec[arrow_pos + 2..];
301
302 if inputs.contains(',') {
304 let parts: Vec<&str> = inputs.split(',').collect();
305 if parts.len() == 2 {
306 let a_axes = parts[0].trim();
307 let b_axes = parts[1].trim();
308
309 if a_axes.len() == 2 && b_axes.len() == 2 && output.len() == 2 {
311 let mut strategy = TilingStrategy::new();
312 strategy.add_tile(TileConfig::new(0, 32)); strategy.add_tile(TileConfig::new(1, 32)); strategy.add_tile(TileConfig::new(2, 32)); return Some(strategy);
316 }
317 }
318 }
319
320 if output.len() < inputs.replace(',', "").len() {
322 let mut strategy = TilingStrategy::new();
323 strategy.add_tile(TileConfig::new(0, 16));
324 return Some(strategy);
325 }
326 }
327
328 None
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph with tensors A, B, C and one `ik,kj->ij` einsum node.
    fn single_matmul_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::einsum("ik,kj->ij", vec![a, b], vec![c]))
            .unwrap();

        graph
    }

    #[test]
    fn test_tile_config_creation() {
        let plain = TileConfig::new(0, 32);
        assert_eq!(plain.axis, 0);
        assert_eq!(plain.tile_size, 32);
        assert!(!plain.unroll);

        let unrolled = TileConfig::with_unroll(1, 16);
        assert_eq!(unrolled.axis, 1);
        assert_eq!(unrolled.tile_size, 16);
        assert!(unrolled.unroll);
    }

    #[test]
    fn test_tiling_strategy_builder() {
        let mut strategy = TilingStrategy::new();
        strategy
            .add_tile(TileConfig::new(0, 32))
            .add_tile(TileConfig::new(1, 32));

        assert_eq!(strategy.tiles.len(), 2);
        assert!(!strategy.register_tiling);
    }

    #[test]
    fn test_matmul_tiling_strategy() {
        let strategy = TilingStrategy::for_matmul(128, 128, 128);
        assert_eq!(strategy.tiles.len(), 3);
        assert!(strategy.tiles[0].tile_size <= 64);
    }

    #[test]
    fn test_conv_tiling_strategy() {
        let strategy = TilingStrategy::for_conv(32, 64, 56, 56);
        assert_eq!(strategy.tiles.len(), 4);
    }

    #[test]
    fn test_should_tile_einsum() {
        assert!(should_tile_einsum("ik,kj->ij"));
        assert!(should_tile_einsum("ijk->ij"));
        assert!(!should_tile_einsum("i->i"));
    }

    #[test]
    fn test_analyze_einsum_for_tiling() {
        let matmul = analyze_einsum_for_tiling("ik,kj->ij");
        assert!(matmul.is_some());
        assert_eq!(matmul.unwrap().tiles.len(), 3);

        let reduction = analyze_einsum_for_tiling("ijk->ij");
        assert!(reduction.is_some());
    }

    #[test]
    fn test_apply_tiling_to_graph() {
        let mut graph = single_matmul_graph();

        let strategy = TilingStrategy::for_matmul(64, 64, 64);
        let result = apply_tiling(&mut graph, &strategy).unwrap();

        assert_eq!(result.nodes_tiled, 1);
        assert!(result.estimated_speedup >= 1.0);
    }

    #[test]
    fn test_register_tiling() {
        let mut graph = single_matmul_graph();

        let result = apply_register_tiling(&mut graph).unwrap();
        assert_eq!(result.nodes_tiled, 1);
        assert!(result.loops_unrolled > 0);
    }

    #[test]
    fn test_multilevel_tiling() {
        let mut graph = single_matmul_graph();

        let l1_tiles = [8, 8, 8];
        let l2_tiles = [32, 32, 32];
        let l3_tiles = [128, 128, 128];

        let result = apply_multilevel_tiling(&mut graph, &l1_tiles, &l2_tiles, &l3_tiles).unwrap();
        assert!(result.nodes_tiled > 0);
        assert!(result.estimated_speedup > 1.0);
    }

    #[test]
    fn test_recommend_tiling_strategy() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");

        // Node 0: einsum, expected to receive a recommendation.
        graph
            .add_node(EinsumNode::einsum("ik,kj->ij", vec![a, b], vec![c]))
            .unwrap();

        // Node 1: element-wise op, expected to be skipped.
        graph
            .add_node(EinsumNode::elem_unary("relu", c, d))
            .unwrap();

        let recommendations = recommend_tiling_strategy(&graph);
        assert_eq!(recommendations.len(), 1);
        assert!(recommendations.contains_key(&0));
    }

    #[test]
    fn test_estimate_cache_improvement() {
        let mut strategy = TilingStrategy::new();
        strategy.add_tile(TileConfig::new(0, 32));
        strategy.add_tile(TileConfig::new(1, 32));

        let without_register = estimate_cache_improvement(&strategy);
        assert!(without_register > 0.0 && without_register <= 0.8);

        let with_register = estimate_cache_improvement(&strategy.with_register_tiling());
        assert!(with_register > without_register);
    }

    #[test]
    fn test_tiling_result_none() {
        let result = TilingResult::none();
        assert_eq!(result.nodes_tiled, 0);
        assert_eq!(result.loops_unrolled, 0);
        assert_eq!(result.estimated_cache_improvement, 0.0);
        assert_eq!(result.estimated_speedup, 1.0);
    }

    #[test]
    fn test_tiling_with_metadata() {
        let mut graph = single_matmul_graph();

        let strategy = TilingStrategy::for_matmul(64, 64, 64);
        apply_tiling(&mut graph, &strategy).unwrap();

        let node = &graph.nodes[0];
        assert!(node.metadata.is_some());
        if let Some(metadata) = &node.metadata {
            assert!(metadata.get_attribute("tiling_strategy").is_some());
        }
    }

    #[test]
    fn test_cache_line_size_configuration() {
        let strategy = TilingStrategy::new().with_cache_line_size(128);
        assert_eq!(strategy.cache_line_size, 128);
    }

    #[test]
    fn test_small_matrix_tiling() {
        let strategy = TilingStrategy::for_matmul(4, 4, 4);
        assert_eq!(strategy.tiles.len(), 3);
        // The lower clamp bound applies even when the matrix is smaller.
        assert!(strategy.tiles.iter().all(|tile| tile.tile_size >= 8));
    }

    #[test]
    fn test_large_matrix_tiling() {
        let strategy = TilingStrategy::for_matmul(1024, 1024, 1024);
        assert_eq!(strategy.tiles.len(), 3);
        assert!(strategy.tiles.iter().all(|tile| tile.tile_size <= 64));
    }
}