// File: torsh_jit/polyhedral_optimization.rs
1// Copyright (c) 2025 ToRSh Contributors
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! # Polyhedral Optimization
5//!
6//! This module implements advanced loop transformations using the polyhedral model,
7//! enabling aggressive optimization of nested loop structures common in deep learning.
8//!
9//! ## Key Concepts
10//!
11//! - **Polyhedral Model**: Represents loop iterations as integer points in polyhedra
12//! - **Affine Scheduling**: Compute optimal execution order using affine transformations
13//! - **Dependence Analysis**: Precise analysis of data dependencies in loop nests
14//! - **Loop Transformations**: Tiling, fusion, interchange, skewing, distribution
15//! - **Locality Optimization**: Maximize cache reuse through careful scheduling
16//!
17//! ## Transformations
18//!
19//! ```text
20//! Original Loop:                  After Tiling:
21//! for i in 0..N                  for ii in 0..N step T
22//!   for j in 0..M                  for jj in 0..M step T
23//!     A[i,j] = ...                   for i in ii..min(ii+T,N)
24//!                                      for j in jj..min(jj+T,M)
25//!                                        A[i,j] = ...
26//! ```
27//!
28//! ## Example
29//!
30//! ```rust,ignore
31//! use torsh_jit::polyhedral_optimization::{PolyhedralOptimizer, LoopNest};
32//!
33//! let optimizer = PolyhedralOptimizer::new();
34//!
35//! // Analyze loop nest
36//! let nest = LoopNest::from_graph(&graph)?;
37//!
38//! // Compute optimal schedule
39//! let schedule = optimizer.compute_schedule(&nest)?;
40//!
41//! // Apply transformations
42//! let optimized = optimizer.apply_schedule(&nest, &schedule)?;
43//! ```
44
45use crate::graph::{ComputationGraph, NodeId};
46use crate::JitResult;
47use std::collections::HashMap;
48
49// ============================================================================
50// Polyhedral Representation
51// ============================================================================
52
/// A loop nest represented in polyhedral form.
///
/// Bundles everything the optimizer needs about one nest: the loops
/// themselves, the statements they execute, the data dependencies between
/// those statements, and the combined iteration domain.
#[derive(Debug, Clone)]
pub struct LoopNest {
    /// Loops in the nest (see `Loop::depth` for nesting level)
    pub loops: Vec<Loop>,

    /// Statements in the loop body
    pub statements: Vec<Statement>,

    /// Data dependencies between the statements (pairs of `Statement::id`s)
    pub dependencies: Vec<Dependence>,

    /// Iteration domain: the set of loop-variable values that actually execute
    pub domain: IterationDomain,
}
68
/// A single loop in the nest.
///
/// Bounds are affine expressions so they may reference enclosing loop
/// variables (e.g. triangular iteration spaces), not just constants.
#[derive(Debug, Clone)]
pub struct Loop {
    /// Loop variable name (e.g. "i")
    pub variable: String,

    /// Lower bound (affine expression, inclusive)
    pub lower_bound: AffineExpr,

    /// Upper bound (affine expression)
    pub upper_bound: AffineExpr,

    /// Step size (increment per iteration)
    pub step: i64,

    /// Nesting depth (0 = outermost)
    pub depth: usize,
}
87
/// A statement within the loop body.
///
/// In the polyhedral model a statement is described by *where* it runs
/// (its iteration domain), *when* it runs (its affine schedule), and
/// *what memory it touches* (its access list).
#[derive(Debug, Clone)]
pub struct Statement {
    /// Statement ID, unique within the nest (used by `Dependence::source`/`target`)
    pub id: usize,

    /// Associated graph node that this statement computes
    pub node_id: NodeId,

    /// Iteration domain (which loop iterations execute this statement)
    pub domain: Polyhedron,

    /// Schedule (when this statement executes relative to others)
    pub schedule: AffineSchedule,

    /// Memory accesses performed by this statement
    pub accesses: Vec<MemoryAccess>,
}
106
/// Memory access pattern for a single array reference.
#[derive(Debug, Clone)]
pub struct MemoryAccess {
    /// Array being accessed (dependence analysis matches accesses by this name)
    pub array_name: String,

    /// Access function: affine index expressions, presumably one per array
    /// subscript (e.g. `A[i+1][j]` → `[i+1, j]`)
    pub access_fn: Vec<AffineExpr>,

    /// Whether the access reads, writes, or both
    pub access_type: AccessType,
}
119
/// Type of memory access; determines the dependence kind when two accesses
/// to the same array are paired (see `classify_dependence`).
#[derive(Debug, Clone, PartialEq)]
pub enum AccessType {
    /// Read-only access
    Read,
    /// Write-only access
    Write,
    /// Access that both reads and writes (e.g. `A[i] += ...`)
    ReadWrite,
}
127
/// Data dependence between two statements.
///
/// A dependence constrains legal schedules: the target iteration must not be
/// reordered before the source iteration it depends on.
#[derive(Debug, Clone)]
pub struct Dependence {
    /// Source statement (`Statement::id` of the earlier access)
    pub source: usize,

    /// Target statement (`Statement::id` of the later access)
    pub target: usize,

    /// Dependence type (flow, anti, output, or input)
    pub dep_type: DependenceType,

    /// Dependence polyhedron (which pairs of iterations are involved)
    pub polyhedron: Polyhedron,

    /// Dependence distance vector, one entry per loop level
    /// (e.g. `[1, 0]` = carried by the outermost loop)
    pub distance: Vec<i64>,
}
146
/// Types of data dependencies, classified by the order of reads and writes.
#[derive(Debug, Clone, PartialEq)]
pub enum DependenceType {
    /// Read-after-Write (true dependence) — the value must be produced first
    Flow,

    /// Write-after-Read (anti dependence) — the read must happen before the overwrite
    Anti,

    /// Write-after-Write (output dependence) — final value depends on write order
    Output,

    /// Read-after-Read (no real dependence; imposes no ordering constraint)
    Input,
}
162
163// ============================================================================
164// Affine Expressions
165// ============================================================================
166
/// Affine expression: a₀ + a₁*x₁ + a₂*x₂ + ... + aₙ*xₙ
///
/// Invariant maintained by `add` and `mul`: no zero coefficients are stored,
/// so cancelled terms (e.g. `x + (-1)·x`) compare equal to the equivalent
/// constant and `is_constant` stays accurate.
#[derive(Debug, Clone, PartialEq)]
pub struct AffineExpr {
    /// Constant term a₀
    pub constant: i64,

    /// Coefficients for each variable (no zero entries after `add`/`mul`)
    pub coefficients: HashMap<String, i64>,
}

impl AffineExpr {
    /// Create a constant expression with no variable terms.
    pub fn constant(value: i64) -> Self {
        Self {
            constant: value,
            coefficients: HashMap::new(),
        }
    }

    /// Create a single-variable expression with coefficient 1.
    pub fn variable(name: String) -> Self {
        let mut coefficients = HashMap::new();
        coefficients.insert(name, 1);
        Self {
            constant: 0,
            coefficients,
        }
    }

    /// Add two expressions term-by-term.
    ///
    /// Coefficients that cancel to zero are removed, so e.g.
    /// `x.add(&x.mul(-1))` equals `AffineExpr::constant(0)`.
    pub fn add(&self, other: &AffineExpr) -> AffineExpr {
        let mut coefficients = self.coefficients.clone();
        for (var, &coeff) in &other.coefficients {
            *coefficients.entry(var.clone()).or_insert(0) += coeff;
        }
        // Prune cancelled terms so is_constant()/PartialEq stay consistent.
        coefficients.retain(|_, c| *c != 0);
        AffineExpr {
            constant: self.constant + other.constant,
            coefficients,
        }
    }

    /// Multiply every term by a constant scalar.
    ///
    /// Multiplying by 0 yields a pure constant (zero coefficients are not
    /// stored).
    pub fn mul(&self, scalar: i64) -> AffineExpr {
        if scalar == 0 {
            return AffineExpr::constant(0);
        }
        let coefficients = self
            .coefficients
            .iter()
            .map(|(k, &v)| (k.clone(), v * scalar))
            .collect();
        AffineExpr {
            constant: self.constant * scalar,
            coefficients,
        }
    }

    /// Evaluate with the given variable values.
    ///
    /// Variables missing from `vars` contribute nothing (treated as 0).
    pub fn evaluate(&self, vars: &HashMap<String, i64>) -> i64 {
        let mut result = self.constant;
        for (var, &coeff) in &self.coefficients {
            if let Some(&val) = vars.get(var) {
                result += coeff * val;
            }
        }
        result
    }

    /// Check if the expression is a pure constant (no variable terms).
    pub fn is_constant(&self) -> bool {
        self.coefficients.is_empty()
    }
}
237
238/// Affine schedule: maps iterations to execution time
239#[derive(Debug, Clone)]
240pub struct AffineSchedule {
241    /// Schedule dimensions (one per level)
242    pub dimensions: Vec<AffineExpr>,
243}
244
245impl AffineSchedule {
246    /// Create identity schedule (original order)
247    pub fn identity(num_dims: usize) -> Self {
248        let dimensions = (0..num_dims)
249            .map(|i| AffineExpr::variable(format!("i{}", i)))
250            .collect();
251        Self { dimensions }
252    }
253
254    /// Apply transformation matrix
255    pub fn transform(&self, matrix: &TransformationMatrix) -> AffineSchedule {
256        matrix.apply_schedule(self)
257    }
258}
259
260// ============================================================================
261// Polyhedra
262// ============================================================================
263
/// A polyhedron defined by affine constraints: Ax + b ≥ 0 (and equalities).
///
/// Used both as an iteration domain and as a dependence polyhedron.
#[derive(Debug, Clone)]
pub struct Polyhedron {
    /// Affine constraints whose conjunction defines the set
    pub constraints: Vec<AffineConstraint>,

    /// Dimension (number of variables the constraints range over)
    pub dimension: usize,
}
273
/// Single affine constraint: `expression ≥ 0` or `expression = 0`,
/// depending on `constraint_type`.
#[derive(Debug, Clone)]
pub struct AffineConstraint {
    /// Affine expression constrained against zero
    pub expression: AffineExpr,

    /// Whether this is an inequality (≥ 0) or an equality (= 0)
    pub constraint_type: ConstraintType,
}
283
/// Kind of relation an `AffineConstraint` imposes against zero.
#[derive(Debug, Clone, PartialEq)]
pub enum ConstraintType {
    /// expr ≥ 0
    Inequality,

    /// expr = 0
    Equality,
}
292
293impl Polyhedron {
294    /// Create empty polyhedron
295    pub fn empty(dimension: usize) -> Self {
296        Self {
297            constraints: Vec::new(),
298            dimension,
299        }
300    }
301
302    /// Add constraint
303    pub fn add_constraint(&mut self, constraint: AffineConstraint) {
304        self.constraints.push(constraint);
305    }
306
307    /// Check if polyhedron is empty
308    pub fn is_empty(&self) -> bool {
309        // Simplified: check for obvious contradictions
310        for c in &self.constraints {
311            if c.constraint_type == ConstraintType::Equality {
312                if c.expression.is_constant() && c.expression.constant != 0 {
313                    return true; // 0 = constant (non-zero) is contradiction
314                }
315            }
316        }
317        false
318    }
319
320    /// Compute intersection with another polyhedron
321    pub fn intersect(&self, other: &Polyhedron) -> Polyhedron {
322        let mut result = self.clone();
323        for constraint in &other.constraints {
324            result.add_constraint(constraint.clone());
325        }
326        result
327    }
328
329    /// Project out a dimension
330    pub fn project_out(&self, _dimension: usize) -> Polyhedron {
331        // Simplified: Fourier-Motzkin elimination would be used here
332        self.clone()
333    }
334}
335
/// Iteration domain for a loop nest: the set of loop-variable tuples that
/// actually execute, described as a polyhedron over `variables`.
#[derive(Debug, Clone)]
pub struct IterationDomain {
    /// Polyhedron whose integer points are the valid iterations
    pub polyhedron: Polyhedron,

    /// Loop variables, in nesting order, naming the polyhedron's dimensions
    pub variables: Vec<String>,
}
345
346impl IterationDomain {
347    /// Create domain for simple rectangular iteration space
348    pub fn rectangular(bounds: Vec<(String, i64, i64)>) -> Self {
349        let dimension = bounds.len();
350        let mut polyhedron = Polyhedron::empty(dimension);
351        let variables: Vec<String> = bounds.iter().map(|(v, _, _)| v.clone()).collect();
352
353        for (var, lower, upper) in bounds {
354            // var - lower ≥ 0
355            let mut lower_expr = AffineExpr::variable(var.clone());
356            lower_expr.constant = -lower;
357            polyhedron.add_constraint(AffineConstraint {
358                expression: lower_expr,
359                constraint_type: ConstraintType::Inequality,
360            });
361
362            // upper - var ≥ 0
363            let mut upper_expr = AffineExpr::constant(upper);
364            *upper_expr.coefficients.entry(var).or_insert(0) -= 1;
365            polyhedron.add_constraint(AffineConstraint {
366                expression: upper_expr,
367                constraint_type: ConstraintType::Inequality,
368            });
369        }
370
371        Self {
372            polyhedron,
373            variables,
374        }
375    }
376}
377
378// ============================================================================
379// Transformations
380// ============================================================================
381
/// Transformation matrix for affine scheduling: applied to a schedule it
/// computes `matrix * dims + offset` (see `apply_schedule`).
#[derive(Debug, Clone)]
pub struct TransformationMatrix {
    /// Matrix coefficients (row-major; one row per output dimension)
    pub matrix: Vec<Vec<i64>>,

    /// Constant vector added per output dimension (indexed per row)
    pub offset: Vec<i64>,
}
391
392impl TransformationMatrix {
393    /// Create identity transformation
394    pub fn identity(size: usize) -> Self {
395        let mut matrix = vec![vec![0; size]; size];
396        for i in 0..size {
397            matrix[i][i] = 1;
398        }
399        Self {
400            matrix,
401            offset: vec![0; size],
402        }
403    }
404
405    /// Create loop interchange (swap dimensions i and j)
406    pub fn interchange(size: usize, i: usize, j: usize) -> Self {
407        let mut matrix = Self::identity(size);
408        matrix.matrix.swap(i, j);
409        matrix
410    }
411
412    /// Create loop reversal (reverse dimension i)
413    pub fn reversal(size: usize, i: usize) -> Self {
414        let mut matrix = Self::identity(size);
415        matrix.matrix[i][i] = -1;
416        matrix
417    }
418
419    /// Create skewing transformation
420    pub fn skew(size: usize, i: usize, j: usize, factor: i64) -> Self {
421        let mut matrix = Self::identity(size);
422        matrix.matrix[i][j] = factor;
423        matrix
424    }
425
426    /// Apply to affine schedule
427    pub fn apply_schedule(&self, schedule: &AffineSchedule) -> AffineSchedule {
428        let mut new_dims = Vec::new();
429
430        for (row_idx, row) in self.matrix.iter().enumerate() {
431            let mut new_expr = AffineExpr::constant(self.offset[row_idx]);
432
433            for (col_idx, &coeff) in row.iter().enumerate() {
434                if coeff != 0 && col_idx < schedule.dimensions.len() {
435                    let scaled = schedule.dimensions[col_idx].mul(coeff);
436                    new_expr = new_expr.add(&scaled);
437                }
438            }
439
440            new_dims.push(new_expr);
441        }
442
443        AffineSchedule {
444            dimensions: new_dims,
445        }
446    }
447}
448
449// ============================================================================
450// Polyhedral Optimizer
451// ============================================================================
452
/// Main polyhedral optimization engine.
///
/// Holds the transformation configuration and accumulates statistics as
/// transformations are applied during scheduling.
pub struct PolyhedralOptimizer {
    /// Configuration controlling which transformations are applied
    config: PolyhedralConfig,

    /// Optimization statistics accumulated across scheduling runs
    stats: OptimizationStats,
}
461
/// Configuration for polyhedral optimization.
///
/// Each flag gates one transformation pass; `tile_size` is the edge length
/// used for cache blocking when tiling is enabled.
#[derive(Debug, Clone)]
pub struct PolyhedralConfig {
    /// Enable loop tiling
    pub enable_tiling: bool,

    /// Tile size for cache blocking
    pub tile_size: usize,

    /// Enable loop fusion
    pub enable_fusion: bool,

    /// Enable loop interchange
    pub enable_interchange: bool,

    /// Enable loop skewing
    pub enable_skewing: bool,

    /// Maximize parallelism
    pub maximize_parallelism: bool,

    /// Optimize for cache locality
    pub optimize_locality: bool,
}

impl Default for PolyhedralConfig {
    /// Everything enabled, with 32-element tiles.
    fn default() -> Self {
        PolyhedralConfig {
            tile_size: 32,
            enable_tiling: true,
            enable_fusion: true,
            enable_interchange: true,
            enable_skewing: true,
            maximize_parallelism: true,
            optimize_locality: true,
        }
    }
}
500
/// Optimization statistics collected while computing schedules.
///
/// Note: only `loops_transformed` is currently updated by the optimizer;
/// the remaining fields are reserved for passes not yet implemented.
#[derive(Debug, Clone, Default)]
pub struct OptimizationStats {
    /// Number of loops transformed (incremented by interchange and tiling)
    pub loops_transformed: usize,

    /// Number of statements fused
    pub statements_fused: usize,

    /// Estimated speedup
    pub estimated_speedup: f32,

    /// Degree of parallelism exposed
    pub parallelism_degree: usize,
}
516
517impl PolyhedralOptimizer {
518    /// Create new polyhedral optimizer
519    pub fn new() -> Self {
520        Self::with_config(PolyhedralConfig::default())
521    }
522
523    /// Create with custom configuration
524    pub fn with_config(config: PolyhedralConfig) -> Self {
525        Self {
526            config,
527            stats: OptimizationStats::default(),
528        }
529    }
530
531    /// Extract loop nest from computation graph
532    pub fn extract_loop_nest(&self, _graph: &ComputationGraph) -> JitResult<LoopNest> {
533        // Simplified: Create a sample loop nest
534        // In production, would analyze graph structure
535
536        let loops = vec![
537            Loop {
538                variable: "i".to_string(),
539                lower_bound: AffineExpr::constant(0),
540                upper_bound: AffineExpr::constant(100),
541                step: 1,
542                depth: 0,
543            },
544            Loop {
545                variable: "j".to_string(),
546                lower_bound: AffineExpr::constant(0),
547                upper_bound: AffineExpr::constant(100),
548                step: 1,
549                depth: 1,
550            },
551        ];
552
553        let domain = IterationDomain::rectangular(vec![
554            ("i".to_string(), 0, 100),
555            ("j".to_string(), 0, 100),
556        ]);
557
558        Ok(LoopNest {
559            loops,
560            statements: Vec::new(),
561            dependencies: Vec::new(),
562            domain,
563        })
564    }
565
566    /// Compute optimal affine schedule
567    pub fn compute_schedule(&mut self, nest: &LoopNest) -> JitResult<Vec<AffineSchedule>> {
568        let mut schedules = Vec::new();
569
570        // Start with identity schedule
571        for stmt in &nest.statements {
572            let schedule = AffineSchedule::identity(nest.loops.len());
573            schedules.push(schedule);
574        }
575
576        // If no statements, create one for each loop
577        if schedules.is_empty() {
578            schedules.push(AffineSchedule::identity(nest.loops.len()));
579        }
580
581        // Apply transformations based on config
582        if self.config.enable_interchange {
583            schedules = self.apply_interchange(nest, schedules)?;
584        }
585
586        if self.config.enable_skewing {
587            schedules = self.apply_skewing(nest, schedules)?;
588        }
589
590        if self.config.enable_tiling {
591            schedules = self.apply_tiling(nest, schedules)?;
592        }
593
594        Ok(schedules)
595    }
596
597    /// Apply loop interchange
598    fn apply_interchange(
599        &mut self,
600        nest: &LoopNest,
601        schedules: Vec<AffineSchedule>,
602    ) -> JitResult<Vec<AffineSchedule>> {
603        let num_loops = nest.loops.len();
604
605        if num_loops < 2 {
606            return Ok(schedules);
607        }
608
609        // Simple heuristic: interchange if inner loop has better locality
610        let transform = TransformationMatrix::interchange(num_loops, 0, 1);
611
612        let new_schedules = schedules
613            .iter()
614            .map(|sched| transform.apply_schedule(sched))
615            .collect();
616
617        self.stats.loops_transformed += num_loops;
618
619        Ok(new_schedules)
620    }
621
622    /// Apply loop skewing
623    fn apply_skewing(
624        &mut self,
625        nest: &LoopNest,
626        schedules: Vec<AffineSchedule>,
627    ) -> JitResult<Vec<AffineSchedule>> {
628        let num_loops = nest.loops.len();
629
630        if num_loops < 2 {
631            return Ok(schedules);
632        }
633
634        // Check if skewing would help with dependencies
635        let has_diagonal_deps = nest
636            .dependencies
637            .iter()
638            .any(|dep| dep.distance.len() >= 2 && dep.distance[0] == dep.distance[1]);
639
640        if has_diagonal_deps {
641            let transform = TransformationMatrix::skew(num_loops, 0, 1, 1);
642            let new_schedules = schedules
643                .iter()
644                .map(|sched| transform.apply_schedule(sched))
645                .collect();
646            return Ok(new_schedules);
647        }
648
649        Ok(schedules)
650    }
651
652    /// Apply loop tiling
653    fn apply_tiling(
654        &mut self,
655        nest: &LoopNest,
656        schedules: Vec<AffineSchedule>,
657    ) -> JitResult<Vec<AffineSchedule>> {
658        // Tiling creates new loop dimensions
659        // Original: i, j  → Tiled: ii, jj, i, j
660        // where ii, jj are outer tile loops
661
662        let tile_size = self.config.tile_size as i64;
663        let num_loops = nest.loops.len();
664
665        // Create tiled schedule (simplified)
666        let mut new_schedules = Vec::new();
667
668        for schedule in schedules {
669            let mut tiled_dims = Vec::new();
670
671            // Outer tile loops (ii, jj, ...)
672            for dim in &schedule.dimensions {
673                // ii = floor(i / tile_size)
674                let tiled = dim.mul(1); // Simplified
675                tiled_dims.push(tiled);
676            }
677
678            // Inner tile loops (original dimensions)
679            tiled_dims.extend(schedule.dimensions.clone());
680
681            new_schedules.push(AffineSchedule {
682                dimensions: tiled_dims,
683            });
684        }
685
686        self.stats.loops_transformed += num_loops;
687
688        Ok(new_schedules)
689    }
690
691    /// Analyze dependencies
692    pub fn analyze_dependencies(&self, nest: &LoopNest) -> Vec<Dependence> {
693        let mut dependencies = Vec::new();
694
695        // Simplified: check all statement pairs
696        for (i, stmt1) in nest.statements.iter().enumerate() {
697            for (j, stmt2) in nest.statements.iter().enumerate().skip(i) {
698                if let Some(dep) = self.check_dependence(stmt1, stmt2) {
699                    dependencies.push(dep);
700                }
701            }
702        }
703
704        dependencies
705    }
706
707    /// Check for dependence between two statements
708    fn check_dependence(&self, stmt1: &Statement, stmt2: &Statement) -> Option<Dependence> {
709        // Check if there's a memory access conflict
710        for access1 in &stmt1.accesses {
711            for access2 in &stmt2.accesses {
712                if access1.array_name == access2.array_name {
713                    let dep_type = self.classify_dependence(access1, access2);
714
715                    if dep_type != DependenceType::Input {
716                        // Compute dependence polyhedron (simplified)
717                        let polyhedron = Polyhedron::empty(2);
718
719                        return Some(Dependence {
720                            source: stmt1.id,
721                            target: stmt2.id,
722                            dep_type,
723                            polyhedron,
724                            distance: vec![1, 0], // Simplified
725                        });
726                    }
727                }
728            }
729        }
730
731        None
732    }
733
734    /// Classify dependence type
735    fn classify_dependence(
736        &self,
737        access1: &MemoryAccess,
738        access2: &MemoryAccess,
739    ) -> DependenceType {
740        match (&access1.access_type, &access2.access_type) {
741            (AccessType::Write, AccessType::Read) => DependenceType::Flow,
742            (AccessType::Read, AccessType::Write) => DependenceType::Anti,
743            (AccessType::Write, AccessType::Write) => DependenceType::Output,
744            (AccessType::Read, AccessType::Read) => DependenceType::Input,
745            _ => DependenceType::Flow,
746        }
747    }
748
749    /// Check if loop fusion is legal
750    pub fn is_fusion_legal(&self, nest1: &LoopNest, nest2: &LoopNest) -> bool {
751        // Check if loops have compatible bounds and no conflicting dependencies
752        if nest1.loops.len() != nest2.loops.len() {
753            return false;
754        }
755
756        // Check bounds compatibility
757        for (loop1, loop2) in nest1.loops.iter().zip(nest2.loops.iter()) {
758            if loop1.lower_bound != loop2.lower_bound || loop1.upper_bound != loop2.upper_bound {
759                return false;
760            }
761        }
762
763        true
764    }
765
766    /// Get optimization statistics
767    pub fn statistics(&self) -> &OptimizationStats {
768        &self.stats
769    }
770
771    /// Reset statistics
772    pub fn reset_stats(&mut self) {
773        self.stats = OptimizationStats::default();
774    }
775}
776
777impl Default for PolyhedralOptimizer {
778    fn default() -> Self {
779        Self::new()
780    }
781}
782
783// ============================================================================
784// Optimization Strategies
785// ============================================================================
786
/// Strategy for selecting which polyhedral transformations to apply.
///
/// NOTE(review): currently declarative only — no code in this module reads
/// this enum yet; confirm intended consumers before extending.
#[derive(Debug, Clone)]
pub enum OptimizationStrategy {
    /// Maximize parallelism (favor transformations exposing parallel loops)
    MaxParallelism,

    /// Maximize cache locality (favor tiling/fusion-style reuse)
    MaxLocality,

    /// Balanced approach between parallelism and locality
    Balanced,

    /// Custom, explicitly ordered transformation sequence
    Custom(Vec<TransformationType>),
}
802
/// Types of polyhedral transformations, for use in
/// `OptimizationStrategy::Custom` sequences.
#[derive(Debug, Clone, PartialEq)]
pub enum TransformationType {
    /// Swap two loop dimensions (indices)
    Interchange(usize, usize),
    /// Skew dimension `0` by `factor` times dimension `1`
    Skewing(usize, usize, i64),
    /// Tile the listed dimensions
    Tiling(Vec<usize>),
    /// Fuse the listed loops
    Fusion(Vec<usize>),
    /// Distribute (fission) the listed loops
    Distribution(Vec<usize>),
    /// Reverse the given dimension
    Reversal(usize),
}
813
814// ============================================================================
815// Tests
816// ============================================================================
817
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic AffineExpr construction, addition, and scalar multiplication.
    #[test]
    fn test_affine_expr() {
        let expr1 = AffineExpr::constant(5);
        let expr2 = AffineExpr::variable("x".to_string());

        // 5 + x: constant and coefficient both preserved.
        let sum = expr1.add(&expr2);
        assert_eq!(sum.constant, 5);
        assert_eq!(sum.coefficients.get("x"), Some(&1));

        // 3 * x scales the coefficient.
        let scaled = expr2.mul(3);
        assert_eq!(scaled.coefficients.get("x"), Some(&3));
    }

    /// Evaluation substitutes variable values into the expression.
    #[test]
    fn test_affine_evaluation() {
        let mut expr = AffineExpr::constant(10);
        expr.coefficients.insert("x".to_string(), 2);
        expr.coefficients.insert("y".to_string(), 3);

        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 4);
        vars.insert("y".to_string(), 5);

        let result = expr.evaluate(&vars);
        assert_eq!(result, 10 + 2 * 4 + 3 * 5); // 10 + 8 + 15 = 33
    }

    /// A polyhedron accepts constraints and a satisfiable system is non-empty.
    #[test]
    fn test_polyhedron() {
        let mut poly = Polyhedron::empty(2);
        assert_eq!(poly.dimension, 2);

        // x >= 0 is satisfiable, so the polyhedron is not detected as empty.
        poly.add_constraint(AffineConstraint {
            expression: AffineExpr::variable("x".to_string()),
            constraint_type: ConstraintType::Inequality,
        });

        assert_eq!(poly.constraints.len(), 1);
        assert!(!poly.is_empty());
    }

    /// Rectangular domains generate two constraints per loop variable.
    #[test]
    fn test_iteration_domain() {
        let domain =
            IterationDomain::rectangular(vec![("i".to_string(), 0, 10), ("j".to_string(), 0, 20)]);

        assert_eq!(domain.variables.len(), 2);
        assert_eq!(domain.polyhedron.constraints.len(), 4); // 2 bounds × 2 variables
    }

    /// Identity matrix has 1s on the diagonal; interchange swaps rows.
    #[test]
    fn test_transformation_matrix() {
        let identity = TransformationMatrix::identity(3);
        assert_eq!(identity.matrix[0][0], 1);
        assert_eq!(identity.matrix[0][1], 0);

        let interchange = TransformationMatrix::interchange(3, 0, 1);
        assert_eq!(interchange.matrix[0][0], 0);
        assert_eq!(interchange.matrix[0][1], 1);
    }

    /// Default optimizer has tiling and fusion enabled.
    #[test]
    fn test_polyhedral_optimizer() {
        let optimizer = PolyhedralOptimizer::new();
        assert!(optimizer.config.enable_tiling);
        assert!(optimizer.config.enable_fusion);
    }

    /// End-to-end: extract a nest from a trivial graph and schedule it.
    #[test]
    fn test_schedule_computation() {
        use crate::graph::GraphBuilder;
        use torsh_core::{DType, Shape};

        let mut optimizer = PolyhedralOptimizer::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![10, 10]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let nest = optimizer.extract_loop_nest(&graph).unwrap();
        let schedules = optimizer.compute_schedule(&nest).unwrap();

        assert!(!schedules.is_empty());
    }

    /// A write followed by a read of the same array is a flow dependence.
    #[test]
    fn test_dependence_analysis() {
        let optimizer = PolyhedralOptimizer::new();

        // Statement 0 writes A[i].
        let stmt1 = Statement {
            id: 0,
            node_id: 0.into(),
            domain: Polyhedron::empty(2),
            schedule: AffineSchedule::identity(2),
            accesses: vec![MemoryAccess {
                array_name: "A".to_string(),
                access_fn: vec![AffineExpr::variable("i".to_string())],
                access_type: AccessType::Write,
            }],
        };

        // Statement 1 reads A[i].
        let stmt2 = Statement {
            id: 1,
            node_id: 1.into(),
            domain: Polyhedron::empty(2),
            schedule: AffineSchedule::identity(2),
            accesses: vec![MemoryAccess {
                array_name: "A".to_string(),
                access_fn: vec![AffineExpr::variable("i".to_string())],
                access_type: AccessType::Read,
            }],
        };

        let dep = optimizer.check_dependence(&stmt1, &stmt2);
        assert!(dep.is_some());
        assert_eq!(dep.unwrap().dep_type, DependenceType::Flow);
    }

    /// Two nests with identical depth and bounds are fusable.
    #[test]
    fn test_fusion_legality() {
        let optimizer = PolyhedralOptimizer::new();

        let nest1 = LoopNest {
            loops: vec![Loop {
                variable: "i".to_string(),
                lower_bound: AffineExpr::constant(0),
                upper_bound: AffineExpr::constant(10),
                step: 1,
                depth: 0,
            }],
            statements: Vec::new(),
            dependencies: Vec::new(),
            domain: IterationDomain::rectangular(vec![("i".to_string(), 0, 10)]),
        };

        let nest2 = nest1.clone();

        assert!(optimizer.is_fusion_legal(&nest1, &nest2));
    }
}