// trustformers_core/compiler/mod.rs
/*!
# Compiler Optimization Module

This module provides comprehensive compiler optimizations for TrustformeRS including:

- **JIT Compilation**: Just-in-time compilation for dynamic graphs
- **Kernel Fusion**: Automatic fusion of compatible operations
- **Graph Optimization**: Multi-pass optimization of computation graphs
- **MLIR Integration**: Integration with MLIR for advanced compiler optimizations

## Features

- Automatic operation fusion for reduced memory bandwidth
- Graph-level optimizations including constant folding and dead code elimination
- JIT compilation for dynamic computation graphs
- MLIR-based optimizations for maximum performance
- Adaptive optimization strategies based on hardware characteristics
- Comprehensive performance analysis and optimization recommendations

## Usage

```rust,no_run
use trustformers_core::compiler::{CompilerOptimizer, OptimizationLevel, ComputationGraph};

# fn main() -> Result<(), Box<dyn std::error::Error>> {
let computation_graph = ComputationGraph::new();
let mut optimizer = CompilerOptimizer::with_optimization_level(OptimizationLevel::Aggressive)?;
let optimization_result = optimizer.optimize_graph(computation_graph)?;
let compiled_model = optimizer.compile_graph(optimization_result.optimized_graph)?;
# Ok(())
# }
```
*/
34
35pub mod analysis;
36pub mod graph_optimizer;
37pub mod jit_compiler;
38pub mod kernel_fusion;
39pub mod mlir_backend;
40pub mod passes;
41
42// Re-export key types for convenience
43pub use analysis::{
44    BottleneckInfo, DependencyAnalysis, GraphAnalyzer, HardwareUtilization, MemoryAnalysis,
45    PerformanceAnalysis,
46};
47pub use jit_compiler::{
48    IRInstruction, IROpcode, IntermediateRepresentation, JitBackend, JitCompiler,
49};
50pub use kernel_fusion::{FusionGroup, FusionPattern, FusionResult, FusionType, KernelFusion};
51pub use mlir_backend::{DialectSupport, MlirBackend};
52pub use passes::{
53    CommonSubexpressionEliminationPass, ConstantFoldingPass, DeadCodeEliminationPass,
54    MemoryLayoutOptimizationPass, OperationFusionPass, PassManager,
55};
56
57use crate::errors::invalid_input;
58use crate::errors::TrustformersError;
59use serde::{Deserialize, Serialize};
60use std::collections::HashMap;
61
62/// Optimization level for compiler optimizations
63#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
64pub enum OptimizationLevel {
65    /// No optimizations (debug mode)
66    None,
67    /// Basic optimizations with minimal compilation time
68    Basic,
69    /// Standard optimizations with balanced compilation time/performance
70    #[default]
71    Standard,
72    /// Aggressive optimizations with longer compilation time
73    Aggressive,
74    /// Maximum optimizations (may significantly increase compilation time)
75    Maximum,
76}
77
78/// Configuration for compiler optimizations
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct CompilerConfig {
81    /// Optimization level to apply
82    pub optimization_level: OptimizationLevel,
83    /// Enable JIT compilation
84    pub enable_jit: bool,
85    /// Enable kernel fusion
86    pub enable_fusion: bool,
87    /// Enable graph optimizations
88    pub enable_graph_opts: bool,
89    /// Enable MLIR backend
90    pub enable_mlir: bool,
91    /// Target hardware characteristics
92    pub target_hardware: HardwareTarget,
93    /// Maximum compilation time in seconds (0 = no limit)
94    pub max_compile_time: u64,
95    /// Cache compiled kernels
96    pub enable_cache: bool,
97    /// Additional compiler flags
98    pub compiler_flags: Vec<String>,
99}
100
101impl Default for CompilerConfig {
102    fn default() -> Self {
103        Self {
104            optimization_level: OptimizationLevel::Standard,
105            enable_jit: true,
106            enable_fusion: true,
107            enable_graph_opts: true,
108            enable_mlir: false, // Experimental
109            target_hardware: HardwareTarget::default(),
110            max_compile_time: 300, // 5 minutes
111            enable_cache: true,
112            compiler_flags: Vec::new(),
113        }
114    }
115}
116
117/// Target hardware characteristics for optimization
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct HardwareTarget {
120    /// Target device type
121    pub device_type: DeviceType,
122    /// Number of compute units (cores, SMs, etc.)
123    pub compute_units: u32,
124    /// Memory bandwidth in GB/s
125    pub memory_bandwidth: f64,
126    /// Cache sizes in bytes
127    pub cache_sizes: Vec<u64>,
128    /// Supports specific instruction sets
129    pub instruction_sets: Vec<String>,
130}
131
132impl Default for HardwareTarget {
133    fn default() -> Self {
134        Self {
135            device_type: DeviceType::CPU,
136            compute_units: 8,
137            memory_bandwidth: 100.0,
138            cache_sizes: vec![32768, 262144, 8388608], // L1, L2, L3
139            instruction_sets: vec!["AVX2".to_string(), "FMA".to_string()],
140        }
141    }
142}
143
144/// Device type for optimization targeting
145#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
146pub enum DeviceType {
147    CPU,
148    GPU,
149    TPU,
150    DSP,
151    FPGA,
152    Custom(u32),
153}
154
155/// Compilation statistics and metrics
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct CompilationStats {
158    /// Total compilation time in milliseconds
159    pub compilation_time_ms: u64,
160    /// Number of operations in original graph
161    pub original_ops: usize,
162    /// Number of operations after optimization
163    pub optimized_ops: usize,
164    /// Number of kernels fused
165    pub fused_kernels: usize,
166    /// Estimated performance improvement
167    pub performance_gain: f64,
168    /// Memory usage reduction
169    pub memory_reduction: f64,
170    /// Applied optimization passes
171    pub applied_passes: Vec<String>,
172}
173
174/// Main compiler optimizer interface
175pub struct CompilerOptimizer {
176    config: CompilerConfig,
177    graph_optimizer: graph_optimizer::GraphOptimizer,
178    jit_compiler: jit_compiler::JitCompiler,
179    kernel_fusion: kernel_fusion::KernelFusion,
180    mlir_backend: Option<mlir_backend::MlirBackend>,
181    graph_analyzer: analysis::GraphAnalyzer,
182    pass_manager: passes::PassManager,
183    compilation_cache: HashMap<String, Vec<u8>>,
184}
185
186impl CompilerOptimizer {
187    /// Create a new compiler optimizer with the given configuration
188    pub fn new(config: CompilerConfig) -> Result<Self, TrustformersError> {
189        let graph_optimizer = graph_optimizer::GraphOptimizer::new(&config)?;
190        let jit_compiler = jit_compiler::JitCompiler::new(&config)?;
191        let kernel_fusion = kernel_fusion::KernelFusion::new(&config)?;
192        let mlir_backend = if config.enable_mlir {
193            Some(mlir_backend::MlirBackend::new(&config)?)
194        } else {
195            None
196        };
197
198        let graph_analyzer = analysis::GraphAnalyzer::new(config.target_hardware.clone());
199        let pass_manager = match config.optimization_level {
200            OptimizationLevel::None => passes::PassManager::new(),
201            OptimizationLevel::Basic | OptimizationLevel::Standard => {
202                passes::PassManager::default_pipeline()
203            },
204            OptimizationLevel::Aggressive | OptimizationLevel::Maximum => {
205                passes::PassManager::aggressive_pipeline()
206            },
207        };
208
209        Ok(Self {
210            config,
211            graph_optimizer,
212            jit_compiler,
213            kernel_fusion,
214            mlir_backend,
215            graph_analyzer,
216            pass_manager,
217            compilation_cache: HashMap::new(),
218        })
219    }
220
221    /// Create a new compiler optimizer with a specific optimization level
222    pub fn with_optimization_level(level: OptimizationLevel) -> Result<Self, TrustformersError> {
223        let config = CompilerConfig {
224            optimization_level: level,
225            ..Default::default()
226        };
227        Self::new(config)
228    }
229
230    /// Get the current configuration
231    pub fn config(&self) -> &CompilerConfig {
232        &self.config
233    }
234
235    /// Update the configuration
236    pub fn set_config(&mut self, config: CompilerConfig) -> Result<(), TrustformersError> {
237        self.config = config;
238        self.graph_optimizer.update_config(&self.config)?;
239        self.jit_compiler.update_config(&self.config)?;
240        self.kernel_fusion.update_config(&self.config)?;
241        if let Some(ref mut mlir) = self.mlir_backend {
242            mlir.update_config(&self.config)?;
243        }
244        Ok(())
245    }
246
247    /// Clear the compilation cache
248    pub fn clear_cache(&mut self) {
249        self.compilation_cache.clear();
250        self.jit_compiler.clear_cache();
251        if let Some(ref mut mlir) = self.mlir_backend {
252            mlir.clear_cache();
253        }
254    }
255
256    /// Get compilation cache statistics
257    pub fn cache_stats(&self) -> HashMap<String, usize> {
258        let mut stats = HashMap::new();
259        stats.insert("cache_entries".to_string(), self.compilation_cache.len());
260        stats.insert(
261            "jit_cache_entries".to_string(),
262            self.jit_compiler.cache_size(),
263        );
264        if let Some(ref mlir) = self.mlir_backend {
265            stats.insert("mlir_cache_entries".to_string(), mlir.cache_size());
266        }
267        stats
268    }
269
270    /// Optimize a computation graph using all enabled optimizations
271    pub fn optimize_graph(
272        &mut self,
273        mut graph: ComputationGraph,
274    ) -> Result<OptimizationResult, TrustformersError> {
275        let start_time = std::time::Instant::now();
276        let original_ops = graph.nodes.len();
277        let original_compute_cost = graph.total_compute_cost();
278        let original_memory_cost = graph.total_memory_cost();
279
280        // Apply optimization passes
281        let pass_results = if self.config.enable_graph_opts {
282            self.pass_manager.run(&mut graph)?
283        } else {
284            Vec::new()
285        };
286
287        // Apply kernel fusion
288        let fusion_result = if self.config.enable_fusion {
289            self.kernel_fusion.apply_fusion(&mut graph)?
290        } else {
291            kernel_fusion::FusionResult {
292                fused_operations: 0,
293                estimated_speedup: 1.0,
294                fusion_time_ms: 0,
295                applied_patterns: Vec::new(),
296            }
297        };
298
299        let optimized_ops = graph.nodes.len();
300        let optimized_compute_cost = graph.total_compute_cost();
301        let optimized_memory_cost = graph.total_memory_cost();
302
303        let optimization_time = start_time.elapsed();
304
305        // Calculate improvements
306        let compute_improvement = if original_compute_cost > 0.0 {
307            (original_compute_cost - optimized_compute_cost) / original_compute_cost
308        } else {
309            0.0
310        };
311
312        let memory_improvement = if original_memory_cost > 0.0 {
313            (original_memory_cost - optimized_memory_cost) / original_memory_cost
314        } else {
315            0.0
316        };
317
318        let applied_passes: Vec<String> = pass_results
319            .iter()
320            .enumerate()
321            .filter(|(_, result)| result.changed)
322            .map(|(i, _)| format!("pass_{}", i))
323            .collect();
324
325        Ok(OptimizationResult {
326            optimized_graph: graph,
327            original_operations: original_ops,
328            optimized_operations: optimized_ops,
329            fused_operations: fusion_result.fused_operations,
330            compute_improvement,
331            memory_improvement,
332            estimated_speedup: fusion_result.estimated_speedup,
333            optimization_time_ms: optimization_time.as_millis() as u64,
334            applied_passes,
335            fusion_patterns: fusion_result.applied_patterns,
336        })
337    }
338
339    /// Compile an optimized graph to executable code
340    pub fn compile_graph(
341        &mut self,
342        graph: ComputationGraph,
343    ) -> Result<CompilationResult, TrustformersError> {
344        if self.config.enable_jit {
345            let result = self.jit_compiler.compile(graph)?;
346            Ok(result)
347        } else {
348            // Fallback to basic compilation
349            let stats = CompilationStats {
350                compilation_time_ms: 0,
351                original_ops: graph.nodes.len(),
352                optimized_ops: graph.nodes.len(),
353                fused_kernels: 0,
354                performance_gain: 1.0,
355                memory_reduction: 0.0,
356                applied_passes: vec!["basic".to_string()],
357            };
358
359            Ok(CompilationResult {
360                compiled_code: vec![0u8; 64], // Placeholder
361                stats,
362                metadata: HashMap::new(),
363            })
364        }
365    }
366
367    /// Perform comprehensive performance analysis
368    pub fn analyze_performance(
369        &mut self,
370        graph: &ComputationGraph,
371    ) -> Result<analysis::PerformanceAnalysis, TrustformersError> {
372        self.graph_analyzer.analyze_performance(graph)
373    }
374
375    /// Perform memory usage analysis
376    pub fn analyze_memory(
377        &mut self,
378        graph: &ComputationGraph,
379    ) -> Result<analysis::MemoryAnalysis, TrustformersError> {
380        self.graph_analyzer.analyze_memory(graph)
381    }
382
383    /// Perform dependency analysis
384    pub fn analyze_dependencies(
385        &mut self,
386        graph: &ComputationGraph,
387    ) -> Result<analysis::DependencyAnalysis, TrustformersError> {
388        self.graph_analyzer.analyze_dependencies(graph)
389    }
390
391    /// Generate optimization recommendations for a graph
392    pub fn recommend_optimizations(
393        &mut self,
394        graph: &ComputationGraph,
395    ) -> Result<OptimizationRecommendations, TrustformersError> {
396        let perf_analysis = self.analyze_performance(graph)?;
397        let memory_analysis = self.analyze_memory(graph)?;
398
399        let mut recommendations = Vec::new();
400
401        // Performance-based recommendations
402        for bottleneck in &perf_analysis.bottlenecks {
403            if bottleneck.criticality_score > 50.0 {
404                recommendations.push(OptimizationRecommendation {
405                    category: RecommendationCategory::Performance,
406                    priority: RecommendationPriority::High,
407                    description: format!(
408                        "Optimize {} operation (node {}) - {}% of total time",
409                        bottleneck.operation_type, bottleneck.node_id, bottleneck.criticality_score
410                    ),
411                    suggested_actions: bottleneck.optimization_suggestions.clone(),
412                    estimated_benefit: bottleneck.criticality_score / 100.0,
413                });
414            }
415        }
416
417        // Memory-based recommendations
418        if memory_analysis.peak_memory_usage > 8 * 1024 * 1024 * 1024 {
419            // > 8GB
420            recommendations.push(OptimizationRecommendation {
421                category: RecommendationCategory::Memory,
422                priority: RecommendationPriority::Medium,
423                description: "High memory usage detected - consider memory optimization"
424                    .to_string(),
425                suggested_actions: vec![
426                    "Enable gradient checkpointing".to_string(),
427                    "Use mixed precision training".to_string(),
428                    "Consider model parallelism".to_string(),
429                ],
430                estimated_benefit: 0.3,
431            });
432        }
433
434        // Parallelization recommendations
435        if perf_analysis.parallelizable_operations.len() > 5 {
436            recommendations.push(OptimizationRecommendation {
437                category: RecommendationCategory::Parallelization,
438                priority: RecommendationPriority::Medium,
439                description: format!(
440                    "Found {} parallelizable operation groups",
441                    perf_analysis.parallelizable_operations.len()
442                ),
443                suggested_actions: vec![
444                    "Enable multi-threading".to_string(),
445                    "Consider GPU acceleration".to_string(),
446                    "Use parallel execution backends".to_string(),
447                ],
448                estimated_benefit: 0.4,
449            });
450        }
451
452        // Hardware utilization recommendations
453        if perf_analysis.hardware_utilization.compute_utilization < 0.5 {
454            recommendations.push(OptimizationRecommendation {
455                category: RecommendationCategory::Hardware,
456                priority: RecommendationPriority::Low,
457                description: "Low compute utilization detected".to_string(),
458                suggested_actions: vec![
459                    "Increase batch size".to_string(),
460                    "Enable operation fusion".to_string(),
461                    "Consider different hardware targets".to_string(),
462                ],
463                estimated_benefit: 0.2,
464            });
465        }
466
467        // Sort by priority and benefit
468        recommendations.sort_by(|a, b| match (a.priority.clone(), b.priority.clone()) {
469            (RecommendationPriority::High, RecommendationPriority::High) => b
470                .estimated_benefit
471                .partial_cmp(&a.estimated_benefit)
472                .unwrap_or(std::cmp::Ordering::Equal),
473            (RecommendationPriority::High, _) => std::cmp::Ordering::Less,
474            (_, RecommendationPriority::High) => std::cmp::Ordering::Greater,
475            _ => b
476                .estimated_benefit
477                .partial_cmp(&a.estimated_benefit)
478                .unwrap_or(std::cmp::Ordering::Equal),
479        });
480
481        Ok(OptimizationRecommendations {
482            recommendations,
483            overall_score: self.calculate_optimization_score(graph)?,
484            target_hardware: self.config.target_hardware.clone(),
485        })
486    }
487
488    /// Calculate an overall optimization score for the graph
489    fn calculate_optimization_score(
490        &mut self,
491        graph: &ComputationGraph,
492    ) -> Result<f64, TrustformersError> {
493        let perf_analysis = self.analyze_performance(graph)?;
494
495        // Combine various metrics into a single score (0-100)
496        let utilization_score = perf_analysis.hardware_utilization.compute_utilization * 25.0;
497        let balance_score = perf_analysis.load_balance_score * 25.0;
498        let parallel_score = perf_analysis.hardware_utilization.parallel_efficiency * 25.0;
499        let memory_score =
500            (1.0 - perf_analysis.hardware_utilization.memory_utilization.min(1.0)) * 25.0;
501
502        Ok(utilization_score + balance_score + parallel_score + memory_score)
503    }
504
505    /// Get comprehensive compiler statistics
506    pub fn get_comprehensive_stats(&self) -> CompilerStatistics {
507        CompilerStatistics {
508            jit_stats: self.jit_compiler.get_stats().clone(),
509            fusion_stats: self.kernel_fusion.get_stats().clone(),
510            cache_stats: self.cache_stats(),
511            config: self.config.clone(),
512        }
513    }
514}
515
516/// Result of compilation process
517#[derive(Debug, Clone, Serialize, Deserialize)]
518pub struct CompilationResult {
519    /// Compiled bytecode or machine code
520    pub compiled_code: Vec<u8>,
521    /// Compilation statistics
522    pub stats: CompilationStats,
523    /// Optimization metadata
524    pub metadata: HashMap<String, String>,
525}
526
/// Result of a single optimization pass.
#[derive(Debug)]
pub struct PassResult {
    /// Whether the pass made changes to the graph
    pub changed: bool,
    /// Numeric statistics reported by the pass
    pub stats: HashMap<String, f64>,
    /// Pass-specific metadata
    pub metadata: HashMap<String, String>,
}

538/// Simplified computation graph representation for optimization
539#[derive(Debug, Clone, Serialize, Deserialize)]
540pub struct ComputationGraph {
541    /// Graph nodes (operations)
542    pub nodes: Vec<GraphNode>,
543    /// Graph edges (data dependencies)
544    pub edges: Vec<GraphEdge>,
545    /// Graph metadata
546    pub metadata: HashMap<String, String>,
547}
548
549/// Graph node representing an operation
550#[derive(Debug, Clone, Serialize, Deserialize)]
551pub struct GraphNode {
552    /// Unique node ID
553    pub id: usize,
554    /// Operation type
555    pub op_type: String,
556    /// Node attributes
557    pub attributes: HashMap<String, String>,
558    /// Input tensor shapes
559    pub input_shapes: Vec<Vec<usize>>,
560    /// Output tensor shapes
561    pub output_shapes: Vec<Vec<usize>>,
562    /// Estimated computation cost
563    pub compute_cost: f64,
564    /// Estimated memory cost
565    pub memory_cost: f64,
566}
567
568/// Graph edge representing data flow
569#[derive(Debug, Clone, Serialize, Deserialize)]
570pub struct GraphEdge {
571    /// Source node ID
572    pub from: usize,
573    /// Destination node ID
574    pub to: usize,
575    /// Output index from source
576    pub output_idx: usize,
577    /// Input index to destination
578    pub input_idx: usize,
579    /// Tensor shape
580    pub shape: Vec<usize>,
581    /// Data type
582    pub dtype: String,
583}
584
585impl ComputationGraph {
586    /// Create a new empty computation graph
587    pub fn new() -> Self {
588        Self {
589            nodes: Vec::new(),
590            edges: Vec::new(),
591            metadata: HashMap::new(),
592        }
593    }
594
595    /// Add a node to the graph
596    pub fn add_node(&mut self, node: GraphNode) -> usize {
597        let id = self.nodes.len();
598        self.nodes.push(node);
599        id
600    }
601
602    /// Add an edge to the graph
603    pub fn add_edge(&mut self, edge: GraphEdge) {
604        self.edges.push(edge);
605    }
606
607    /// Get node by ID
608    pub fn get_node(&self, id: usize) -> Option<&GraphNode> {
609        self.nodes.get(id)
610    }
611
612    /// Get mutable node by ID
613    pub fn get_node_mut(&mut self, id: usize) -> Option<&mut GraphNode> {
614        self.nodes.get_mut(id)
615    }
616
617    /// Get all edges connected to a node
618    pub fn get_node_edges(&self, node_id: usize) -> Vec<&GraphEdge> {
619        self.edges
620            .iter()
621            .filter(|edge| edge.from == node_id || edge.to == node_id)
622            .collect()
623    }
624
625    /// Validate the graph structure
626    pub fn validate(&self) -> Result<(), TrustformersError> {
627        // Check for invalid node references in edges
628        for edge in &self.edges {
629            if edge.from >= self.nodes.len() || edge.to >= self.nodes.len() {
630                return Err(invalid_input("Edge references non-existent node"));
631            }
632        }
633
634        // Check for cycles (simplified)
635        if self.has_cycles() {
636            return Err(invalid_input("Graph contains cycles"));
637        }
638
639        Ok(())
640    }
641
642    /// Check if the graph has cycles (simplified DFS)
643    fn has_cycles(&self) -> bool {
644        let mut visited = vec![false; self.nodes.len()];
645        let mut rec_stack = vec![false; self.nodes.len()];
646
647        for i in 0..self.nodes.len() {
648            if !visited[i] && self.dfs_has_cycle(i, &mut visited, &mut rec_stack) {
649                return true;
650            }
651        }
652        false
653    }
654
655    fn dfs_has_cycle(&self, node: usize, visited: &mut [bool], rec_stack: &mut [bool]) -> bool {
656        visited[node] = true;
657        rec_stack[node] = true;
658
659        for edge in &self.edges {
660            if edge.from == node {
661                let next = edge.to;
662                if !visited[next] && self.dfs_has_cycle(next, visited, rec_stack) {
663                    return true;
664                }
665                if rec_stack[next] {
666                    return true;
667                }
668            }
669        }
670
671        rec_stack[node] = false;
672        false
673    }
674
675    /// Calculate total estimated compute cost
676    pub fn total_compute_cost(&self) -> f64 {
677        self.nodes.iter().map(|node| node.compute_cost).sum()
678    }
679
680    /// Calculate total estimated memory cost
681    pub fn total_memory_cost(&self) -> f64 {
682        self.nodes.iter().map(|node| node.memory_cost).sum()
683    }
684}
685
686impl Default for ComputationGraph {
687    fn default() -> Self {
688        Self::new()
689    }
690}
691
692/// Result of graph optimization process
693#[derive(Debug)]
694pub struct OptimizationResult {
695    /// The optimized computation graph
696    pub optimized_graph: ComputationGraph,
697    /// Number of operations in original graph
698    pub original_operations: usize,
699    /// Number of operations after optimization
700    pub optimized_operations: usize,
701    /// Number of operations that were fused
702    pub fused_operations: usize,
703    /// Compute cost improvement (0.0 to 1.0)
704    pub compute_improvement: f64,
705    /// Memory cost improvement (0.0 to 1.0)
706    pub memory_improvement: f64,
707    /// Estimated overall speedup
708    pub estimated_speedup: f64,
709    /// Time spent on optimization in milliseconds
710    pub optimization_time_ms: u64,
711    /// List of applied optimization passes
712    pub applied_passes: Vec<String>,
713    /// List of applied fusion patterns
714    pub fusion_patterns: Vec<String>,
715}
716
717/// Comprehensive compiler statistics
718#[derive(Debug)]
719pub struct CompilerStatistics {
720    /// JIT compiler statistics
721    pub jit_stats: jit_compiler::CompilationStatistics,
722    /// Kernel fusion statistics
723    pub fusion_stats: kernel_fusion::FusionStatistics,
724    /// Cache usage statistics
725    pub cache_stats: HashMap<String, usize>,
726    /// Current configuration
727    pub config: CompilerConfig,
728}
729
730/// Optimization recommendations for a computation graph
731#[derive(Debug)]
732pub struct OptimizationRecommendations {
733    /// List of specific recommendations
734    pub recommendations: Vec<OptimizationRecommendation>,
735    /// Overall optimization score (0-100)
736    pub overall_score: f64,
737    /// Target hardware configuration
738    pub target_hardware: HardwareTarget,
739}
740
741/// Individual optimization recommendation
742#[derive(Debug)]
743pub struct OptimizationRecommendation {
744    /// Category of the recommendation
745    pub category: RecommendationCategory,
746    /// Priority level
747    pub priority: RecommendationPriority,
748    /// Human-readable description
749    pub description: String,
750    /// List of suggested actions
751    pub suggested_actions: Vec<String>,
752    /// Estimated benefit (0.0 to 1.0)
753    pub estimated_benefit: f64,
754}
755
/// Categories of optimization recommendations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecommendationCategory {
    Performance,
    Memory,
    Parallelization,
    Hardware,
    Compilation,
}

/// Priority levels for recommendations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecommendationPriority {
    High,
    Medium,
    Low,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compiler_config_default() {
        let config = CompilerConfig::default();
        assert_eq!(config.optimization_level, OptimizationLevel::Standard);
        assert!(config.enable_jit);
        assert!(config.enable_fusion);
        assert!(config.enable_graph_opts);
    }

    #[test]
    fn test_optimization_levels() {
        assert_ne!(OptimizationLevel::None, OptimizationLevel::Maximum);
        assert_eq!(OptimizationLevel::default(), OptimizationLevel::Standard);
    }

    #[test]
    fn test_computation_graph_basic() {
        let mut graph = ComputationGraph::new();

        let node1 = GraphNode {
            id: 0,
            op_type: "MatMul".to_string(),
            attributes: HashMap::new(),
            input_shapes: vec![vec![128, 256], vec![256, 512]],
            output_shapes: vec![vec![128, 512]],
            compute_cost: 100.0,
            memory_cost: 50.0,
        };

        let node2 = GraphNode {
            id: 1,
            op_type: "ReLU".to_string(),
            attributes: HashMap::new(),
            input_shapes: vec![vec![128, 512]],
            output_shapes: vec![vec![128, 512]],
            compute_cost: 10.0,
            memory_cost: 5.0,
        };

        graph.add_node(node1);
        graph.add_node(node2);

        let edge = GraphEdge {
            from: 0,
            to: 1,
            output_idx: 0,
            input_idx: 0,
            shape: vec![128, 512],
            dtype: "f32".to_string(),
        };

        graph.add_edge(edge);

        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
        assert_eq!(graph.total_compute_cost(), 110.0);
        assert_eq!(graph.total_memory_cost(), 55.0);

        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_compiler_optimizer_creation() {
        let config = CompilerConfig::default();
        let result = CompilerOptimizer::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_hardware_target_default() {
        let target = HardwareTarget::default();
        assert_eq!(target.device_type, DeviceType::CPU);
        assert_eq!(target.compute_units, 8);
        assert!(target.memory_bandwidth > 0.0);
    }
}