1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10use crate::{EinsumGraph, IrError};
11
/// High-level memory-layout strategy for a tensor.
///
/// The concrete strides for a strategy are produced by [`StridePattern`];
/// note that only row- and column-major currently have dedicated stride
/// builders (see `TensorLayout::new`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum LayoutStrategy {
    /// C-style order: the last dimension is unit-stride. This is the default.
    #[default]
    RowMajor,
    /// Fortran-style order: the first dimension is unit-stride.
    ColumnMajor,
    /// Square blocks of `block_size` x `block_size` elements (cache blocking).
    Blocked { block_size: usize },
    /// Rectangular tiles of `tile_height` x `tile_width` elements.
    Tiled {
        tile_height: usize,
        tile_width: usize,
    },
    /// Z-order (Morton) space-filling-curve layout; treated as
    /// locality-preserving but not vectorizable by this module.
    ZOrder,
    /// Hilbert space-filling-curve layout; treated as locality-preserving
    /// but not vectorizable by this module.
    Hilbert,
}
32
33impl LayoutStrategy {
34 pub fn for_operation(op: &str) -> Self {
36 match op {
37 "matmul" | "einsum" => Self::Blocked { block_size: 32 },
38 "transpose" => Self::ColumnMajor,
39 "conv2d" => Self::Tiled {
40 tile_height: 8,
41 tile_width: 8,
42 },
43 "scan" | "reduce" => Self::RowMajor,
44 _ => Self::default(),
45 }
46 }
47
48 pub fn supports_vectorization(&self) -> bool {
50 matches!(
51 self,
52 Self::RowMajor | Self::Blocked { .. } | Self::Tiled { .. }
53 )
54 }
55
56 pub fn preserves_locality(&self) -> bool {
58 matches!(
59 self,
60 Self::Blocked { .. } | Self::Tiled { .. } | Self::ZOrder | Self::Hilbert
61 )
62 }
63}
64
/// Per-dimension stride description for a tensor's memory layout.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct StridePattern {
    /// Element stride for each dimension, outermost first
    /// (a dense row-major pattern ends with 1).
    pub strides: Vec<usize>,
    /// Whether the strides describe a dense (gap-free) layout,
    /// as judged by `is_contiguous_strides`.
    pub is_contiguous: bool,
    /// Requested alignment; 0 means unspecified (constructors default to 0,
    /// set via `with_alignment`).
    pub alignment: usize,
}
75
76impl StridePattern {
77 pub fn row_major(dims: &[usize]) -> Self {
79 let mut strides = vec![1];
80 for i in (0..dims.len() - 1).rev() {
81 strides.insert(0, strides[0] * dims[i + 1]);
82 }
83
84 Self {
85 strides,
86 is_contiguous: true,
87 alignment: 0,
88 }
89 }
90
91 pub fn column_major(dims: &[usize]) -> Self {
93 let mut strides = vec![1];
94 for i in 0..dims.len() - 1 {
95 strides.push(strides[i] * dims[i]);
96 }
97
98 Self {
99 strides,
100 is_contiguous: true,
101 alignment: 0,
102 }
103 }
104
105 pub fn custom(strides: Vec<usize>) -> Self {
107 let is_contiguous = is_contiguous_strides(&strides);
108 Self {
109 strides,
110 is_contiguous,
111 alignment: 0,
112 }
113 }
114
115 pub fn with_alignment(mut self, alignment: usize) -> Self {
117 self.alignment = alignment;
118 self
119 }
120
121 pub fn is_vectorizable(&self) -> bool {
123 self.is_contiguous && self.strides.last().copied().unwrap_or(0) == 1
124 }
125
126 pub fn access_cost(&self) -> f64 {
128 if self.is_contiguous {
129 1.0
130 } else {
131 1.5 + (self.strides.len() as f64 * 0.1)
133 }
134 }
135}
136
/// Chosen memory layout for a single tensor in an `EinsumGraph`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TensorLayout {
    /// Index of the tensor within the graph's tensor list.
    pub tensor_idx: usize,
    /// High-level layout strategy selected for this tensor.
    pub strategy: LayoutStrategy,
    /// Concrete stride pattern realizing the strategy.
    pub strides: StridePattern,
    /// Whether later passes may still change this layout
    /// (checked by `find_layout_fusion_opportunities`).
    pub is_mutable: bool,
}
149
150impl TensorLayout {
151 pub fn new(tensor_idx: usize, strategy: LayoutStrategy, dims: &[usize]) -> Self {
153 let strides = match strategy {
154 LayoutStrategy::RowMajor => StridePattern::row_major(dims),
155 LayoutStrategy::ColumnMajor => StridePattern::column_major(dims),
156 _ => StridePattern::row_major(dims), };
158
159 Self {
160 tensor_idx,
161 strategy,
162 strides,
163 is_mutable: true,
164 }
165 }
166
167 pub fn access_efficiency(&self) -> f64 {
169 let base_efficiency = if self.strides.is_contiguous { 0.9 } else { 0.5 };
170
171 let locality_bonus: f64 = if self.strategy.preserves_locality() {
172 0.1
173 } else {
174 0.0
175 };
176
177 (base_efficiency + locality_bonus).min(1.0f64)
178 }
179}
180
/// Outcome of a layout-optimization pass over an `EinsumGraph`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LayoutOptimizationResult {
    /// Selected layout per tensor index.
    pub layouts: HashMap<usize, TensorLayout>,
    /// Number of tensors whose layout deviates from the row-major default
    /// and therefore needs a conversion (see `count_layout_conversions`).
    pub transformations_needed: usize,
    /// Estimated efficiency improvement over a 0.7 baseline (>= 0.0).
    pub estimated_improvement: f64,
    /// Estimated overall speedup factor (>= 1.0).
    pub estimated_speedup: f64,
}
193
194impl LayoutOptimizationResult {
195 pub fn none() -> Self {
197 Self {
198 layouts: HashMap::new(),
199 transformations_needed: 0,
200 estimated_improvement: 0.0,
201 estimated_speedup: 1.0,
202 }
203 }
204
205 pub fn get_layout(&self, tensor_idx: usize) -> Option<&TensorLayout> {
207 self.layouts.get(&tensor_idx)
208 }
209}
210
211pub fn optimize_layouts(graph: &EinsumGraph) -> Result<LayoutOptimizationResult, IrError> {
213 let mut result = LayoutOptimizationResult::none();
214
215 for (tensor_idx, tensor_name) in graph.tensors.iter().enumerate() {
217 let dims = infer_dimensions(tensor_name, graph, tensor_idx);
219
220 let strategy = analyze_usage_pattern(graph, tensor_idx);
222
223 let layout = TensorLayout::new(tensor_idx, strategy, &dims);
224 result.layouts.insert(tensor_idx, layout);
225 }
226
227 result.transformations_needed = count_layout_conversions(&result.layouts);
229
230 let avg_efficiency: f64 = result
232 .layouts
233 .values()
234 .map(|l| l.access_efficiency())
235 .sum::<f64>()
236 / result.layouts.len().max(1) as f64;
237
238 result.estimated_improvement = (avg_efficiency - 0.7).max(0.0);
239 result.estimated_speedup = 1.0 + result.estimated_improvement * 0.3;
240
241 Ok(result)
242}
243
244pub fn apply_layouts(
246 graph: &mut EinsumGraph,
247 layouts: &HashMap<usize, TensorLayout>,
248) -> Result<(), IrError> {
249 for (tensor_idx, layout) in layouts {
251 if *tensor_idx < graph.tensors.len() {
252 let mut metadata = graph
253 .get_tensor_metadata(*tensor_idx)
254 .cloned()
255 .unwrap_or_else(crate::Metadata::new);
256
257 metadata
258 .attributes
259 .push(("layout".to_string(), format!("{:?}", layout.strategy)));
260 metadata.attributes.push((
261 "is_contiguous".to_string(),
262 layout.strides.is_contiguous.to_string(),
263 ));
264
265 graph.add_tensor_metadata(*tensor_idx, metadata);
266 }
267 }
268
269 Ok(())
270}
271
272pub fn find_layout_fusion_opportunities(
274 layouts: &HashMap<usize, TensorLayout>,
275) -> Vec<(usize, usize)> {
276 let mut opportunities = Vec::new();
277
278 let tensor_indices: Vec<_> = layouts.keys().copied().collect();
280
281 for i in 0..tensor_indices.len() {
282 for j in (i + 1)..tensor_indices.len() {
283 let idx1 = tensor_indices[i];
284 let idx2 = tensor_indices[j];
285
286 if let (Some(layout1), Some(layout2)) = (layouts.get(&idx1), layouts.get(&idx2)) {
287 if layout1.strategy != layout2.strategy && layout1.is_mutable && layout2.is_mutable
288 {
289 opportunities.push((idx1, idx2));
290 }
291 }
292 }
293 }
294
295 opportunities
296}
297
298fn infer_dimensions(_tensor_name: &str, _graph: &EinsumGraph, _tensor_idx: usize) -> Vec<usize> {
301 vec![64, 64]
305}
306
307fn analyze_usage_pattern(graph: &EinsumGraph, tensor_idx: usize) -> LayoutStrategy {
308 let mut read_patterns = Vec::new();
310
311 for node in &graph.nodes {
312 if node.inputs.contains(&tensor_idx) {
313 let pattern = match &node.op {
315 crate::OpType::Einsum { spec } => analyze_einsum_pattern(spec),
316 crate::OpType::Reduce { .. } => "reduce",
317 crate::OpType::ElemUnary { .. } => "scan",
318 crate::OpType::ElemBinary { .. } => "scan",
319 };
320 read_patterns.push(pattern);
321 }
322 }
323
324 if read_patterns.contains(&"matmul") {
326 LayoutStrategy::Blocked { block_size: 32 }
327 } else if read_patterns.contains(&"transpose") {
328 LayoutStrategy::ColumnMajor
329 } else if read_patterns.contains(&"conv") {
330 LayoutStrategy::Tiled {
331 tile_height: 8,
332 tile_width: 8,
333 }
334 } else {
335 LayoutStrategy::RowMajor
336 }
337}
338
/// Classify an einsum spec into a coarse access-pattern tag.
///
/// Any spec with a comma (multiple operands) counts as "matmul". A
/// single-operand spec with exactly one "->" whose output has fewer
/// indices than its input is a "reduce"; everything else is a "scan".
fn analyze_einsum_pattern(spec: &str) -> &'static str {
    if spec.contains(',') {
        return "matmul";
    }

    let parts: Vec<&str> = spec.split("->").collect();
    match parts.as_slice() {
        // Exactly input->output with a shrinking index set.
        [input, output] if input.len() > output.len() => "reduce",
        _ => "scan",
    }
}
353
354fn count_layout_conversions(layouts: &HashMap<usize, TensorLayout>) -> usize {
355 layouts
357 .values()
358 .filter(|l| l.strategy != LayoutStrategy::RowMajor)
359 .count()
360}
361
/// Heuristic contiguity check on a stride list (outermost first).
///
/// Accepts an empty list (scalar) as trivially contiguous. Otherwise the
/// innermost stride must be exactly 1, and walking outward each stride
/// must be a strictly larger multiple of its inner neighbour, with the
/// ratio capped at 10000 as a sanity guard against absurd padding.
fn is_contiguous_strides(strides: &[usize]) -> bool {
    let Some(&innermost) = strides.last() else {
        return true;
    };
    if innermost != 1 {
        return false;
    }

    // Check adjacent pairs from innermost to outermost, mirroring the
    // short-circuit order of a reverse walk.
    strides.windows(2).rev().all(|pair| {
        let (outer, inner) = (pair[0], pair[1]);
        outer > inner && outer % inner == 0 && outer / inner <= 10000
    })
}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layout_strategy_default() {
        // Row-major is the declared #[default] variant.
        assert_eq!(LayoutStrategy::default(), LayoutStrategy::RowMajor);
    }

    #[test]
    fn test_layout_strategy_for_operation() {
        assert!(matches!(
            LayoutStrategy::for_operation("matmul"),
            LayoutStrategy::Blocked { .. }
        ));
        assert_eq!(
            LayoutStrategy::for_operation("transpose"),
            LayoutStrategy::ColumnMajor
        );
        assert!(matches!(
            LayoutStrategy::for_operation("conv2d"),
            LayoutStrategy::Tiled { .. }
        ));
    }

    #[test]
    fn test_layout_strategy_vectorization() {
        assert!(LayoutStrategy::RowMajor.supports_vectorization());
        assert!(LayoutStrategy::Blocked { block_size: 32 }.supports_vectorization());
        assert!(!LayoutStrategy::ZOrder.supports_vectorization());
    }

    #[test]
    fn test_layout_strategy_locality() {
        assert!(LayoutStrategy::Blocked { block_size: 32 }.preserves_locality());
        assert!(LayoutStrategy::ZOrder.preserves_locality());
        assert!(LayoutStrategy::Hilbert.preserves_locality());
        assert!(!LayoutStrategy::RowMajor.preserves_locality());
    }

    #[test]
    fn test_stride_pattern_row_major() {
        let pattern = StridePattern::row_major(&[4, 8, 16]);
        assert_eq!(pattern.strides, vec![128, 16, 1]);
        assert!(pattern.is_contiguous);
        assert!(pattern.is_vectorizable());
    }

    #[test]
    fn test_stride_pattern_column_major() {
        let pattern = StridePattern::column_major(&[4, 8, 16]);
        assert_eq!(pattern.strides, vec![1, 4, 32]);
        assert!(pattern.is_contiguous);
    }

    #[test]
    fn test_stride_pattern_custom() {
        let expected = vec![64, 8, 1];
        let pattern = StridePattern::custom(expected.clone());
        assert_eq!(pattern.strides, expected);
        assert!(pattern.is_contiguous);
    }

    #[test]
    fn test_stride_pattern_non_contiguous() {
        // Innermost stride of 2 means every other element is skipped.
        let pattern = StridePattern::custom(vec![100, 10, 2]);
        assert!(!pattern.is_contiguous);
        assert!(!pattern.is_vectorizable());
    }

    #[test]
    fn test_stride_pattern_with_alignment() {
        let pattern = StridePattern::row_major(&[4, 8]).with_alignment(64);
        assert_eq!(pattern.alignment, 64);
    }

    #[test]
    fn test_stride_pattern_access_cost() {
        let dense = StridePattern::row_major(&[4, 8]);
        let strided = StridePattern::custom(vec![100, 10, 2]);
        assert!(dense.access_cost() < strided.access_cost());
    }

    #[test]
    fn test_tensor_layout_creation() {
        let layout = TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]);
        assert_eq!(layout.tensor_idx, 0);
        assert_eq!(layout.strategy, LayoutStrategy::RowMajor);
        assert!(layout.is_mutable);
        assert!(layout.strides.is_contiguous);
    }

    #[test]
    fn test_tensor_layout_access_efficiency() {
        let row_major = TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]);
        let blocked = TensorLayout::new(0, LayoutStrategy::Blocked { block_size: 32 }, &[4, 8]);

        let row_efficiency = row_major.access_efficiency();
        assert!(row_efficiency > 0.0 && row_efficiency <= 1.0);
        // Blocked layouts earn the locality bonus on top of the same base.
        assert!(blocked.access_efficiency() > row_efficiency);
    }

    #[test]
    fn test_layout_optimization_result_none() {
        let empty = LayoutOptimizationResult::none();
        assert!(empty.layouts.is_empty());
        assert_eq!(empty.transformations_needed, 0);
        assert_eq!(empty.estimated_improvement, 0.0);
        assert_eq!(empty.estimated_speedup, 1.0);
    }

    #[test]
    fn test_optimize_layouts_empty_graph() {
        let result = optimize_layouts(&EinsumGraph::new()).unwrap();
        assert!(result.layouts.is_empty());
    }

    #[test]
    fn test_optimize_layouts_simple_graph() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        graph
            .add_node(crate::EinsumNode::einsum("ik,kj->ij", vec![a, b], vec![c]))
            .unwrap();

        let result = optimize_layouts(&graph).unwrap();
        assert_eq!(result.layouts.len(), 3);
        assert!(result.estimated_speedup >= 1.0);
    }

    #[test]
    fn test_apply_layouts() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");

        let mut layouts = HashMap::new();
        layouts.insert(
            a,
            TensorLayout::new(a, LayoutStrategy::Blocked { block_size: 32 }, &[64, 64]),
        );

        apply_layouts(&mut graph, &layouts).unwrap();
        assert!(graph.get_tensor_metadata(a).is_some());
    }

    #[test]
    fn test_find_layout_fusion_opportunities() {
        let mut layouts = HashMap::new();
        layouts.insert(0, TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]));
        layouts.insert(
            1,
            TensorLayout::new(1, LayoutStrategy::ColumnMajor, &[4, 8]),
        );
        layouts.insert(2, TensorLayout::new(2, LayoutStrategy::RowMajor, &[4, 8]));

        // Tensor 1's strategy differs from 0 and 2, so pairs must exist.
        assert!(!find_layout_fusion_opportunities(&layouts).is_empty());
    }

    #[test]
    fn test_analyze_einsum_pattern() {
        assert_eq!(analyze_einsum_pattern("ik,kj->ij"), "matmul");
        assert_eq!(analyze_einsum_pattern("ijk->ij"), "reduce");
        assert_eq!(analyze_einsum_pattern("ij->ij"), "scan");
    }

    #[test]
    fn test_is_contiguous_strides() {
        assert!(is_contiguous_strides(&[8, 4, 1]));
        assert!(is_contiguous_strides(&[1]));
        assert!(is_contiguous_strides(&[]));
        // Divisible increasing strides still count as contiguous here.
        assert!(is_contiguous_strides(&[8, 2, 1]));
        // Innermost stride must be exactly 1.
        assert!(!is_contiguous_strides(&[8, 4, 2]));
        // Non-divisible outer stride is rejected.
        assert!(!is_contiguous_strides(&[9, 2, 1]));
    }

    #[test]
    fn test_count_layout_conversions() {
        let mut layouts = HashMap::new();
        layouts.insert(0, TensorLayout::new(0, LayoutStrategy::RowMajor, &[4, 8]));
        layouts.insert(
            1,
            TensorLayout::new(1, LayoutStrategy::ColumnMajor, &[4, 8]),
        );
        layouts.insert(
            2,
            TensorLayout::new(2, LayoutStrategy::Blocked { block_size: 32 }, &[4, 8]),
        );

        // Two layouts deviate from the row-major default.
        assert_eq!(count_layout_conversions(&layouts), 2);
    }

    #[test]
    fn test_layout_optimization_with_metadata() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");

        let metadata = crate::Metadata::new().with_attribute("preferred_layout", "blocked");
        graph.add_tensor_metadata(a, metadata);

        let result = optimize_layouts(&graph).unwrap();
        assert!(result.get_layout(a).is_some());
        assert!(result.get_layout(b).is_some());
    }
}