1use crate::graph::{ComputationGraph, NodeId};
46use crate::JitResult;
47use std::collections::HashMap;
48
/// A loop nest as handled by the polyhedral optimizer: the loops themselves,
/// the statements they contain, the dependences between those statements and
/// the iteration domain.
#[derive(Debug, Clone)]
pub struct LoopNest {
    /// Loops of the nest. `extract_loop_nest` lists them with increasing
    /// `depth` (presumably outermost first — TODO confirm).
    pub loops: Vec<Loop>,

    /// Statements of the nest; `compute_schedule` produces one schedule per
    /// statement.
    pub statements: Vec<Statement>,

    /// Dependences between statements; `apply_skewing` inspects their
    /// distance vectors.
    pub dependencies: Vec<Dependence>,

    /// Iteration domain of the nest.
    pub domain: IterationDomain,
}
68
/// A single loop of a [`LoopNest`].
#[derive(Debug, Clone)]
pub struct Loop {
    /// Name of the loop's induction variable (e.g. `"i"`).
    pub variable: String,

    /// Lower bound of the induction variable as an affine expression.
    pub lower_bound: AffineExpr,

    /// Upper bound of the induction variable as an affine expression.
    /// NOTE(review): whether this bound is inclusive is not established by
    /// the code in this file — confirm against the code generator.
    pub upper_bound: AffineExpr,

    /// Increment of the induction variable per iteration.
    pub step: i64,

    /// Position of the loop inside its nest; `extract_loop_nest` assigns 0 to
    /// the first loop and 1 to the second.
    pub depth: usize,
}
87
/// A statement inside a loop nest together with its polyhedral description.
#[derive(Debug, Clone)]
pub struct Statement {
    /// Identifier of the statement; copied into `Dependence::source` and
    /// `Dependence::target` by the dependence analysis.
    pub id: usize,

    /// Node of the computation graph this statement belongs to.
    pub node_id: NodeId,

    /// Polyhedron describing the statement's iteration domain.
    pub domain: Polyhedron,

    /// Affine schedule attached to the statement.
    pub schedule: AffineSchedule,

    /// Memory accesses performed by the statement; the dependence analysis
    /// compares them pairwise by array name.
    pub accesses: Vec<MemoryAccess>,
}
106
/// One array access performed by a [`Statement`].
#[derive(Debug, Clone)]
pub struct MemoryAccess {
    /// Name of the accessed array; two accesses are considered to touch the
    /// same array when these names are equal.
    pub array_name: String,

    /// Affine index expressions of the access (presumably one per array
    /// dimension — TODO confirm; the dependence analysis in this file does
    /// not inspect them yet).
    pub access_fn: Vec<AffineExpr>,

    /// Whether the access reads, writes, or does both.
    pub access_type: AccessType,
}
119
/// Kind of access a statement performs on an array.
///
/// The enum is fieldless, so it is `Copy` and can be used as a key in hashed
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessType {
    /// The array element is only read.
    Read,
    /// The array element is only written.
    Write,
    /// The array element is both read and written.
    ReadWrite,
}
127
/// A data dependence between two statements.
#[derive(Debug, Clone)]
pub struct Dependence {
    /// `Statement::id` of the statement the dependence originates from.
    pub source: usize,

    /// `Statement::id` of the statement that depends on `source`.
    pub target: usize,

    /// Classification of the dependence (flow, anti, output, input).
    pub dep_type: DependenceType,

    /// Polyhedron attached to the dependence. `check_dependence` currently
    /// fills in an unconstrained two-dimensional polyhedron.
    pub polyhedron: Polyhedron,

    /// Distance vector of the dependence. `check_dependence` currently fills
    /// in the fixed vector `[1, 0]`; `apply_skewing` compares the first two
    /// entries.
    pub distance: Vec<i64>,
}
146
/// Classification of a data dependence by the access kinds of its two ends.
///
/// The enum is fieldless, so it is `Copy` and can be used as a key in hashed
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DependenceType {
    /// A write followed by a read of the same array (read-after-write).
    Flow,

    /// A read followed by a write of the same array (write-after-read).
    Anti,

    /// A write followed by another write of the same array
    /// (write-after-write).
    Output,

    /// Two reads of the same array; the dependence analysis does not report
    /// these.
    Input,
}
162
/// An affine expression `constant + Σ coefficient · variable` over named
/// integer variables.
///
/// Expressions produced by [`AffineExpr::add`] and [`AffineExpr::mul`] never
/// store a zero coefficient, so structurally different spellings of the same
/// expression (e.g. `x - x` and `0`) compare equal through the derived
/// `PartialEq`. Zero entries inserted by hand into `coefficients` are still
/// tolerated by [`AffineExpr::is_constant`] and [`AffineExpr::evaluate`].
#[derive(Debug, Clone, PartialEq)]
pub struct AffineExpr {
    /// Constant term of the expression.
    pub constant: i64,

    /// Coefficient of each variable, keyed by variable name. A variable that
    /// is absent has coefficient zero.
    pub coefficients: HashMap<String, i64>,
}

impl AffineExpr {
    /// Creates the constant expression `value` with no variable terms.
    pub fn constant(value: i64) -> Self {
        Self {
            constant: value,
            coefficients: HashMap::new(),
        }
    }

    /// Creates the expression `1 · name` with a zero constant term.
    pub fn variable(name: String) -> Self {
        let mut coefficients = HashMap::new();
        coefficients.insert(name, 1);
        Self {
            constant: 0,
            coefficients,
        }
    }

    /// Returns `self + other`.
    ///
    /// Coefficients that cancel to zero are dropped from the result so that
    /// it stays in normal form.
    pub fn add(&self, other: &AffineExpr) -> AffineExpr {
        let mut coefficients = self.coefficients.clone();
        for (var, &coeff) in &other.coefficients {
            *coefficients.entry(var.clone()).or_insert(0) += coeff;
        }
        // Without pruning, `x + (-x)` would keep an `x -> 0` entry and would
        // neither be recognised as constant nor compare equal to `0`.
        coefficients.retain(|_, coeff| *coeff != 0);

        AffineExpr {
            constant: self.constant + other.constant,
            coefficients,
        }
    }

    /// Returns `self * scalar`.
    ///
    /// Terms whose scaled coefficient is zero (in particular every term when
    /// `scalar` is zero) are dropped from the result.
    pub fn mul(&self, scalar: i64) -> AffineExpr {
        let coefficients = self
            .coefficients
            .iter()
            .filter_map(|(var, &coeff)| {
                let scaled = coeff * scalar;
                if scaled != 0 {
                    Some((var.clone(), scaled))
                } else {
                    None
                }
            })
            .collect();

        AffineExpr {
            constant: self.constant * scalar,
            coefficients,
        }
    }

    /// Evaluates the expression for the variable assignment `vars`.
    ///
    /// A variable of the expression that has no entry in `vars` contributes
    /// nothing to the result, i.e. it is treated as zero.
    pub fn evaluate(&self, vars: &HashMap<String, i64>) -> i64 {
        self.coefficients
            .iter()
            .filter_map(|(var, &coeff)| vars.get(var).map(|&value| coeff * value))
            .fold(self.constant, |total, term| total + term)
    }

    /// Returns `true` when the expression does not depend on any variable,
    /// i.e. every stored coefficient is zero.
    pub fn is_constant(&self) -> bool {
        self.coefficients.values().all(|&coeff| coeff == 0)
    }
}
237
/// A multi-dimensional affine schedule: one affine expression per schedule
/// dimension.
#[derive(Debug, Clone)]
pub struct AffineSchedule {
    /// Expressions of the schedule, one per dimension, in order.
    pub dimensions: Vec<AffineExpr>,
}
244
245impl AffineSchedule {
246 pub fn identity(num_dims: usize) -> Self {
248 let dimensions = (0..num_dims)
249 .map(|i| AffineExpr::variable(format!("i{}", i)))
250 .collect();
251 Self { dimensions }
252 }
253
254 pub fn transform(&self, matrix: &TransformationMatrix) -> AffineSchedule {
256 matrix.apply_schedule(self)
257 }
258}
259
/// A set of integer points described by a conjunction of affine constraints.
#[derive(Debug, Clone)]
pub struct Polyhedron {
    /// Constraints that every point of the polyhedron must satisfy.
    pub constraints: Vec<AffineConstraint>,

    /// Number of dimensions the polyhedron was declared with. The value is
    /// stored as given; it is not checked against the variables that appear
    /// in `constraints`.
    pub dimension: usize,
}
273
/// A single affine constraint of a [`Polyhedron`].
#[derive(Debug, Clone)]
pub struct AffineConstraint {
    /// Left-hand side of the constraint; it is compared against zero.
    pub expression: AffineExpr,

    /// Whether `expression` is constrained by an inequality or an equality.
    pub constraint_type: ConstraintType,
}
283
/// Relation an [`AffineConstraint`] imposes on its expression.
///
/// The enum is fieldless, so it is `Copy` and can be used as a key in hashed
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstraintType {
    /// `expression >= 0` (this is how `IterationDomain::rectangular` encodes
    /// its lower and upper bounds).
    Inequality,

    /// `expression == 0`.
    Equality,
}
292
293impl Polyhedron {
294 pub fn empty(dimension: usize) -> Self {
296 Self {
297 constraints: Vec::new(),
298 dimension,
299 }
300 }
301
302 pub fn add_constraint(&mut self, constraint: AffineConstraint) {
304 self.constraints.push(constraint);
305 }
306
307 pub fn is_empty(&self) -> bool {
309 for c in &self.constraints {
311 if c.constraint_type == ConstraintType::Equality {
312 if c.expression.is_constant() && c.expression.constant != 0 {
313 return true; }
315 }
316 }
317 false
318 }
319
320 pub fn intersect(&self, other: &Polyhedron) -> Polyhedron {
322 let mut result = self.clone();
323 for constraint in &other.constraints {
324 result.add_constraint(constraint.clone());
325 }
326 result
327 }
328
329 pub fn project_out(&self, _dimension: usize) -> Polyhedron {
331 self.clone()
333 }
334}
335
/// Iteration domain of a loop nest: a polyhedron plus the names of the
/// variables it ranges over.
#[derive(Debug, Clone)]
pub struct IterationDomain {
    /// Constraints describing the valid iterations.
    pub polyhedron: Polyhedron,

    /// Names of the iteration variables, in the order they were supplied.
    pub variables: Vec<String>,
}
345
346impl IterationDomain {
347 pub fn rectangular(bounds: Vec<(String, i64, i64)>) -> Self {
349 let dimension = bounds.len();
350 let mut polyhedron = Polyhedron::empty(dimension);
351 let variables: Vec<String> = bounds.iter().map(|(v, _, _)| v.clone()).collect();
352
353 for (var, lower, upper) in bounds {
354 let mut lower_expr = AffineExpr::variable(var.clone());
356 lower_expr.constant = -lower;
357 polyhedron.add_constraint(AffineConstraint {
358 expression: lower_expr,
359 constraint_type: ConstraintType::Inequality,
360 });
361
362 let mut upper_expr = AffineExpr::constant(upper);
364 *upper_expr.coefficients.entry(var).or_insert(0) -= 1;
365 polyhedron.add_constraint(AffineConstraint {
366 expression: upper_expr,
367 constraint_type: ConstraintType::Inequality,
368 });
369 }
370
371 Self {
372 polyhedron,
373 variables,
374 }
375 }
376}
377
/// An affine schedule transformation `matrix · schedule + offset`.
#[derive(Debug, Clone)]
pub struct TransformationMatrix {
    /// Linear part of the transformation, stored row by row. Row `r` yields
    /// output dimension `r`; column `c` multiplies input dimension `c`.
    pub matrix: Vec<Vec<i64>>,

    /// Constant added to each output dimension; `apply_schedule` reads one
    /// entry per matrix row.
    pub offset: Vec<i64>,
}
391
392impl TransformationMatrix {
393 pub fn identity(size: usize) -> Self {
395 let mut matrix = vec![vec![0; size]; size];
396 for i in 0..size {
397 matrix[i][i] = 1;
398 }
399 Self {
400 matrix,
401 offset: vec![0; size],
402 }
403 }
404
405 pub fn interchange(size: usize, i: usize, j: usize) -> Self {
407 let mut matrix = Self::identity(size);
408 matrix.matrix.swap(i, j);
409 matrix
410 }
411
412 pub fn reversal(size: usize, i: usize) -> Self {
414 let mut matrix = Self::identity(size);
415 matrix.matrix[i][i] = -1;
416 matrix
417 }
418
419 pub fn skew(size: usize, i: usize, j: usize, factor: i64) -> Self {
421 let mut matrix = Self::identity(size);
422 matrix.matrix[i][j] = factor;
423 matrix
424 }
425
426 pub fn apply_schedule(&self, schedule: &AffineSchedule) -> AffineSchedule {
428 let mut new_dims = Vec::new();
429
430 for (row_idx, row) in self.matrix.iter().enumerate() {
431 let mut new_expr = AffineExpr::constant(self.offset[row_idx]);
432
433 for (col_idx, &coeff) in row.iter().enumerate() {
434 if coeff != 0 && col_idx < schedule.dimensions.len() {
435 let scaled = schedule.dimensions[col_idx].mul(coeff);
436 new_expr = new_expr.add(&scaled);
437 }
438 }
439
440 new_dims.push(new_expr);
441 }
442
443 AffineSchedule {
444 dimensions: new_dims,
445 }
446 }
447}
448
449pub struct PolyhedralOptimizer {
455 config: PolyhedralConfig,
457
458 stats: OptimizationStats,
460}
461
/// Configuration of the [`PolyhedralOptimizer`].
#[derive(Debug, Clone)]
pub struct PolyhedralConfig {
    /// Run the tiling step in `compute_schedule`.
    pub enable_tiling: bool,

    /// Requested tile edge length. The current tiling step does not encode
    /// this value in the schedules it produces.
    pub tile_size: usize,

    /// Allow loop fusion. Not consulted by the code visible in this file.
    pub enable_fusion: bool,

    /// Run the loop-interchange step in `compute_schedule`.
    pub enable_interchange: bool,

    /// Run the skewing step in `compute_schedule`.
    pub enable_skewing: bool,

    /// Prefer schedules that expose parallelism. Not consulted by the code
    /// visible in this file.
    pub maximize_parallelism: bool,

    /// Prefer schedules that improve locality. Not consulted by the code
    /// visible in this file.
    pub optimize_locality: bool,
}

impl Default for PolyhedralConfig {
    /// Enables every transformation and objective, with a tile size of 32.
    fn default() -> Self {
        const DEFAULT_TILE_SIZE: usize = 32;

        PolyhedralConfig {
            tile_size: DEFAULT_TILE_SIZE,
            enable_tiling: true,
            enable_fusion: true,
            enable_interchange: true,
            enable_skewing: true,
            maximize_parallelism: true,
            optimize_locality: true,
        }
    }
}
500
/// Statistics collected by the [`PolyhedralOptimizer`]. All fields start at
/// zero (`Default`).
#[derive(Debug, Clone, Default)]
pub struct OptimizationStats {
    /// Running count updated by the schedule transformations of
    /// `compute_schedule`; each one adds the number of loops in the nest.
    pub loops_transformed: usize,

    /// Number of fused statements. Not updated by the code visible in this
    /// file.
    pub statements_fused: usize,

    /// Estimated speedup. Not updated by the code visible in this file.
    pub estimated_speedup: f32,

    /// Degree of parallelism. Not updated by the code visible in this file.
    pub parallelism_degree: usize,
}
516
517impl PolyhedralOptimizer {
518 pub fn new() -> Self {
520 Self::with_config(PolyhedralConfig::default())
521 }
522
523 pub fn with_config(config: PolyhedralConfig) -> Self {
525 Self {
526 config,
527 stats: OptimizationStats::default(),
528 }
529 }
530
531 pub fn extract_loop_nest(&self, _graph: &ComputationGraph) -> JitResult<LoopNest> {
533 let loops = vec![
537 Loop {
538 variable: "i".to_string(),
539 lower_bound: AffineExpr::constant(0),
540 upper_bound: AffineExpr::constant(100),
541 step: 1,
542 depth: 0,
543 },
544 Loop {
545 variable: "j".to_string(),
546 lower_bound: AffineExpr::constant(0),
547 upper_bound: AffineExpr::constant(100),
548 step: 1,
549 depth: 1,
550 },
551 ];
552
553 let domain = IterationDomain::rectangular(vec![
554 ("i".to_string(), 0, 100),
555 ("j".to_string(), 0, 100),
556 ]);
557
558 Ok(LoopNest {
559 loops,
560 statements: Vec::new(),
561 dependencies: Vec::new(),
562 domain,
563 })
564 }
565
566 pub fn compute_schedule(&mut self, nest: &LoopNest) -> JitResult<Vec<AffineSchedule>> {
568 let mut schedules = Vec::new();
569
570 for stmt in &nest.statements {
572 let schedule = AffineSchedule::identity(nest.loops.len());
573 schedules.push(schedule);
574 }
575
576 if schedules.is_empty() {
578 schedules.push(AffineSchedule::identity(nest.loops.len()));
579 }
580
581 if self.config.enable_interchange {
583 schedules = self.apply_interchange(nest, schedules)?;
584 }
585
586 if self.config.enable_skewing {
587 schedules = self.apply_skewing(nest, schedules)?;
588 }
589
590 if self.config.enable_tiling {
591 schedules = self.apply_tiling(nest, schedules)?;
592 }
593
594 Ok(schedules)
595 }
596
597 fn apply_interchange(
599 &mut self,
600 nest: &LoopNest,
601 schedules: Vec<AffineSchedule>,
602 ) -> JitResult<Vec<AffineSchedule>> {
603 let num_loops = nest.loops.len();
604
605 if num_loops < 2 {
606 return Ok(schedules);
607 }
608
609 let transform = TransformationMatrix::interchange(num_loops, 0, 1);
611
612 let new_schedules = schedules
613 .iter()
614 .map(|sched| transform.apply_schedule(sched))
615 .collect();
616
617 self.stats.loops_transformed += num_loops;
618
619 Ok(new_schedules)
620 }
621
622 fn apply_skewing(
624 &mut self,
625 nest: &LoopNest,
626 schedules: Vec<AffineSchedule>,
627 ) -> JitResult<Vec<AffineSchedule>> {
628 let num_loops = nest.loops.len();
629
630 if num_loops < 2 {
631 return Ok(schedules);
632 }
633
634 let has_diagonal_deps = nest
636 .dependencies
637 .iter()
638 .any(|dep| dep.distance.len() >= 2 && dep.distance[0] == dep.distance[1]);
639
640 if has_diagonal_deps {
641 let transform = TransformationMatrix::skew(num_loops, 0, 1, 1);
642 let new_schedules = schedules
643 .iter()
644 .map(|sched| transform.apply_schedule(sched))
645 .collect();
646 return Ok(new_schedules);
647 }
648
649 Ok(schedules)
650 }
651
652 fn apply_tiling(
654 &mut self,
655 nest: &LoopNest,
656 schedules: Vec<AffineSchedule>,
657 ) -> JitResult<Vec<AffineSchedule>> {
658 let tile_size = self.config.tile_size as i64;
663 let num_loops = nest.loops.len();
664
665 let mut new_schedules = Vec::new();
667
668 for schedule in schedules {
669 let mut tiled_dims = Vec::new();
670
671 for dim in &schedule.dimensions {
673 let tiled = dim.mul(1); tiled_dims.push(tiled);
676 }
677
678 tiled_dims.extend(schedule.dimensions.clone());
680
681 new_schedules.push(AffineSchedule {
682 dimensions: tiled_dims,
683 });
684 }
685
686 self.stats.loops_transformed += num_loops;
687
688 Ok(new_schedules)
689 }
690
691 pub fn analyze_dependencies(&self, nest: &LoopNest) -> Vec<Dependence> {
693 let mut dependencies = Vec::new();
694
695 for (i, stmt1) in nest.statements.iter().enumerate() {
697 for (j, stmt2) in nest.statements.iter().enumerate().skip(i) {
698 if let Some(dep) = self.check_dependence(stmt1, stmt2) {
699 dependencies.push(dep);
700 }
701 }
702 }
703
704 dependencies
705 }
706
707 fn check_dependence(&self, stmt1: &Statement, stmt2: &Statement) -> Option<Dependence> {
709 for access1 in &stmt1.accesses {
711 for access2 in &stmt2.accesses {
712 if access1.array_name == access2.array_name {
713 let dep_type = self.classify_dependence(access1, access2);
714
715 if dep_type != DependenceType::Input {
716 let polyhedron = Polyhedron::empty(2);
718
719 return Some(Dependence {
720 source: stmt1.id,
721 target: stmt2.id,
722 dep_type,
723 polyhedron,
724 distance: vec![1, 0], });
726 }
727 }
728 }
729 }
730
731 None
732 }
733
734 fn classify_dependence(
736 &self,
737 access1: &MemoryAccess,
738 access2: &MemoryAccess,
739 ) -> DependenceType {
740 match (&access1.access_type, &access2.access_type) {
741 (AccessType::Write, AccessType::Read) => DependenceType::Flow,
742 (AccessType::Read, AccessType::Write) => DependenceType::Anti,
743 (AccessType::Write, AccessType::Write) => DependenceType::Output,
744 (AccessType::Read, AccessType::Read) => DependenceType::Input,
745 _ => DependenceType::Flow,
746 }
747 }
748
749 pub fn is_fusion_legal(&self, nest1: &LoopNest, nest2: &LoopNest) -> bool {
751 if nest1.loops.len() != nest2.loops.len() {
753 return false;
754 }
755
756 for (loop1, loop2) in nest1.loops.iter().zip(nest2.loops.iter()) {
758 if loop1.lower_bound != loop2.lower_bound || loop1.upper_bound != loop2.upper_bound {
759 return false;
760 }
761 }
762
763 true
764 }
765
766 pub fn statistics(&self) -> &OptimizationStats {
768 &self.stats
769 }
770
771 pub fn reset_stats(&mut self) {
773 self.stats = OptimizationStats::default();
774 }
775}
776
777impl Default for PolyhedralOptimizer {
778 fn default() -> Self {
779 Self::new()
780 }
781}
782
/// High-level optimization goal.
///
/// NOTE(review): no code visible in this file consumes this type; the
/// variant descriptions below follow the variant names only — confirm
/// against the caller.
#[derive(Debug, Clone)]
pub enum OptimizationStrategy {
    /// Presumably: favour parallelism.
    MaxParallelism,

    /// Presumably: favour data locality.
    MaxLocality,

    /// Presumably: trade parallelism against locality.
    Balanced,

    /// An explicit list of transformations.
    Custom(Vec<TransformationType>),
}
802
/// A single loop transformation, as listed in
/// [`OptimizationStrategy::Custom`].
///
/// All payloads are integers or vectors of integers, so the type supports
/// full equality and hashing.
///
/// NOTE(review): the meaning of the payloads is not established by code
/// visible in this file; the descriptions below are assumptions to confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TransformationType {
    /// Presumably the two loop indices to swap.
    Interchange(usize, usize),
    /// Presumably `(skewed loop, reference loop, factor)`.
    Skewing(usize, usize, i64),
    /// Presumably the loops to tile (or the tile sizes) — confirm.
    Tiling(Vec<usize>),
    /// Presumably the items to fuse.
    Fusion(Vec<usize>),
    /// Presumably the items to distribute.
    Distribution(Vec<usize>),
    /// Presumably the index of the loop to reverse.
    Reversal(usize),
}
813
// Unit tests for the polyhedral data structures and the optimizer.
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding a constant and a variable keeps both parts; scaling multiplies
    /// the variable's coefficient.
    #[test]
    fn test_affine_expr() {
        let expr1 = AffineExpr::constant(5);
        let expr2 = AffineExpr::variable("x".to_string());

        let sum = expr1.add(&expr2);
        assert_eq!(sum.constant, 5);
        assert_eq!(sum.coefficients.get("x"), Some(&1));

        let scaled = expr2.mul(3);
        assert_eq!(scaled.coefficients.get("x"), Some(&3));
    }

    /// Evaluating `10 + 2x + 3y` at `x = 4, y = 5`.
    #[test]
    fn test_affine_evaluation() {
        let mut expr = AffineExpr::constant(10);
        expr.coefficients.insert("x".to_string(), 2);
        expr.coefficients.insert("y".to_string(), 3);

        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 4);
        vars.insert("y".to_string(), 5);

        let result = expr.evaluate(&vars);
        assert_eq!(result, 10 + 2 * 4 + 3 * 5);
    }

    /// A polyhedron with a single non-constant inequality is not reported as
    /// empty.
    #[test]
    fn test_polyhedron() {
        let mut poly = Polyhedron::empty(2);
        assert_eq!(poly.dimension, 2);

        poly.add_constraint(AffineConstraint {
            expression: AffineExpr::variable("x".to_string()),
            constraint_type: ConstraintType::Inequality,
        });

        assert_eq!(poly.constraints.len(), 1);
        assert!(!poly.is_empty());
    }

    /// A rectangular domain has one variable and two constraints (lower and
    /// upper bound) per input triple.
    #[test]
    fn test_iteration_domain() {
        let domain =
            IterationDomain::rectangular(vec![("i".to_string(), 0, 10), ("j".to_string(), 0, 20)]);

        assert_eq!(domain.variables.len(), 2);
        assert_eq!(domain.polyhedron.constraints.len(), 4);
    }

    /// The identity has ones on the diagonal; interchanging 0 and 1 swaps the
    /// first two rows.
    #[test]
    fn test_transformation_matrix() {
        let identity = TransformationMatrix::identity(3);
        assert_eq!(identity.matrix[0][0], 1);
        assert_eq!(identity.matrix[0][1], 0);

        let interchange = TransformationMatrix::interchange(3, 0, 1);
        assert_eq!(interchange.matrix[0][0], 0);
        assert_eq!(interchange.matrix[0][1], 1);
    }

    /// The default configuration enables tiling and fusion.
    #[test]
    fn test_polyhedral_optimizer() {
        let optimizer = PolyhedralOptimizer::new();
        assert!(optimizer.config.enable_tiling);
        assert!(optimizer.config.enable_fusion);
    }

    /// The extracted nest has no statements, so `compute_schedule` must fall
    /// back to producing at least one schedule.
    #[test]
    fn test_schedule_computation() {
        use crate::graph::GraphBuilder;
        use torsh_core::{DType, Shape};

        let mut optimizer = PolyhedralOptimizer::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![10, 10]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let nest = optimizer.extract_loop_nest(&graph).unwrap();
        let schedules = optimizer.compute_schedule(&nest).unwrap();

        assert!(!schedules.is_empty());
    }

    /// A write to `A` followed by a read of `A` is a flow dependence.
    #[test]
    fn test_dependence_analysis() {
        let optimizer = PolyhedralOptimizer::new();

        let stmt1 = Statement {
            id: 0,
            node_id: 0.into(),
            domain: Polyhedron::empty(2),
            schedule: AffineSchedule::identity(2),
            accesses: vec![MemoryAccess {
                array_name: "A".to_string(),
                access_fn: vec![AffineExpr::variable("i".to_string())],
                access_type: AccessType::Write,
            }],
        };

        let stmt2 = Statement {
            id: 1,
            node_id: 1.into(),
            domain: Polyhedron::empty(2),
            schedule: AffineSchedule::identity(2),
            accesses: vec![MemoryAccess {
                array_name: "A".to_string(),
                access_fn: vec![AffineExpr::variable("i".to_string())],
                access_type: AccessType::Read,
            }],
        };

        let dep = optimizer.check_dependence(&stmt1, &stmt2);
        assert!(dep.is_some());
        assert_eq!(dep.unwrap().dep_type, DependenceType::Flow);
    }

    /// Two identical single-loop nests are fusable.
    #[test]
    fn test_fusion_legality() {
        let optimizer = PolyhedralOptimizer::new();

        let nest1 = LoopNest {
            loops: vec![Loop {
                variable: "i".to_string(),
                lower_bound: AffineExpr::constant(0),
                upper_bound: AffineExpr::constant(10),
                step: 1,
                depth: 0,
            }],
            statements: Vec::new(),
            dependencies: Vec::new(),
            domain: IterationDomain::rectangular(vec![("i".to_string(), 0, 10)]),
        };

        let nest2 = nest1.clone();

        assert!(optimizer.is_fusion_legal(&nest1, &nest2));
    }
}