Skip to main content

tensorlogic_infer/
compilation.rs

1//! Graph compilation and caching infrastructure.
2//!
3//! This module provides ahead-of-time graph optimization and compilation capabilities:
4//! - `CompiledGraph`: Optimized, executable representation of computation graphs
5//! - `GraphCompiler`: Applies optimization passes and produces compiled graphs
6//! - `CompilationCache`: Caches compiled graphs to avoid recompilation
7//! - `TlCompilableExecutor`: Trait for executors that support graph compilation
8//!
9//! # Example
10//!
11//! ```
12//! use tensorlogic_infer::compilation::{GraphCompiler, CompilationConfig, OptimizationLevel};
13//! use tensorlogic_infer::DummyExecutor;
14//! use tensorlogic_ir::EinsumGraph;
15//!
16//! let mut compiler = GraphCompiler::new(CompilationConfig {
17//!     optimization_level: OptimizationLevel::Aggressive,
18//!     ..Default::default()
19//! });
20//!
21//! let graph = EinsumGraph::new();
22//! let compiled = compiler.compile(&graph).expect("unwrap");
23//! ```
24
25use crate::error::ExecutorError;
26use crate::memory::MemoryEstimator;
27use crate::optimization::{GraphOptimizer, OptimizationResult};
28use crate::scheduling::{ExecutionSchedule, Scheduler, SchedulingStrategy};
29use crate::shape::ShapeInferenceContext;
30use crate::validation::GraphValidator;
31use std::collections::HashMap;
32use std::sync::{Arc, RwLock};
33use std::time::{Duration, SystemTime};
34use tensorlogic_ir::EinsumGraph;
35
/// How much optimization work [`GraphCompiler`] should attempt.
///
/// Levels are ordered from least to most work (`None < Basic < Moderate <
/// Aggressive`), and [`OptimizationLevel::Moderate`] is the default.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum OptimizationLevel {
    /// Skip optimization entirely and compile the graph as given
    None,
    /// Basic passes: dead code elimination and constant folding
    Basic,
    /// Moderate passes: fusion, CSE and basic scheduling
    #[default]
    Moderate,
    /// Aggressive: every pass plus advanced scheduling
    Aggressive,
}
49
50/// Configuration for graph compilation.
51#[derive(Debug, Clone)]
52pub struct CompilationConfig {
53    /// Optimization level to apply
54    pub optimization_level: OptimizationLevel,
55    /// Whether to enable shape inference
56    pub enable_shape_inference: bool,
57    /// Whether to enable memory estimation
58    pub enable_memory_estimation: bool,
59    /// Target device for compilation (e.g., "cpu", "cuda:0")
60    pub target_device: Option<String>,
61    /// Maximum memory budget in bytes (None = unlimited)
62    pub memory_budget: Option<usize>,
63    /// Enable caching of intermediate results
64    pub enable_caching: bool,
65    /// Enable parallel execution planning
66    pub enable_parallelism: bool,
67}
68
69impl Default for CompilationConfig {
70    fn default() -> Self {
71        CompilationConfig {
72            optimization_level: OptimizationLevel::default(),
73            enable_shape_inference: true,
74            enable_memory_estimation: true,
75            target_device: None,
76            memory_budget: None,
77            enable_caching: true,
78            enable_parallelism: true,
79        }
80    }
81}
82
/// Statistics about the compilation process.
///
/// `CompilationStats::default()` yields a zero duration and all counters at 0.
/// It is derived because every field's own `Default` is exactly that; a
/// hand-written impl added nothing (clippy `derivable_impls`).
#[derive(Debug, Clone, Default)]
pub struct CompilationStats {
    /// Time taken for compilation
    pub compilation_time: Duration,
    /// Number of nodes in original graph
    pub original_nodes: usize,
    /// Number of nodes after optimization
    pub optimized_nodes: usize,
    /// Number of fusion opportunities found by the optimizer's analysis.
    ///
    /// NOTE: `GraphCompiler::compile` in this module does not rewrite the
    /// graph, so this counts opportunities rather than fusions actually applied.
    pub fusions_applied: usize,
    /// Number of dead nodes reported by the optimizer's analysis.
    ///
    /// NOTE: as with `fusions_applied`, these nodes are reported, not removed.
    pub dead_nodes_eliminated: usize,
    /// Estimated memory usage in bytes
    pub estimated_memory_bytes: usize,
    /// Scheduled execution steps
    pub execution_steps: usize,
}
115
116/// Compiled representation of a computation graph.
117///
118/// Contains the optimized graph, execution schedule, and metadata
119/// necessary for efficient execution.
120#[derive(Debug, Clone)]
121pub struct CompiledGraph {
122    /// The optimized graph
123    pub graph: EinsumGraph,
124    /// Execution schedule for the graph
125    pub schedule: ExecutionSchedule,
126    /// Shape information (if available)
127    pub shapes: HashMap<usize, Vec<usize>>,
128    /// Estimated memory usage per node
129    pub memory_usage: HashMap<usize, usize>,
130    /// Configuration used for compilation
131    pub config: CompilationConfig,
132    /// Compilation statistics
133    pub stats: CompilationStats,
134    /// Timestamp when compiled
135    pub compiled_at: SystemTime,
136}
137
138impl CompiledGraph {
139    /// Get the number of nodes in the compiled graph
140    pub fn node_count(&self) -> usize {
141        self.graph.nodes.len()
142    }
143
144    /// Get the total estimated memory usage
145    pub fn total_memory(&self) -> usize {
146        self.memory_usage.values().sum()
147    }
148
149    /// Check if this compiled graph is still valid
150    pub fn is_valid(&self) -> bool {
151        // Check if graph structure is valid
152        if self.graph.nodes.is_empty() {
153            return false;
154        }
155
156        // Check if schedule matches graph
157        if self.schedule.execution_order.len() != self.graph.nodes.len() {
158            return false;
159        }
160
161        true
162    }
163
164    /// Get a summary of the compiled graph
165    pub fn summary(&self) -> String {
166        format!(
167            "CompiledGraph: {} nodes, {} steps, {:.2}MB memory, compiled in {:.2}ms",
168            self.node_count(),
169            self.stats.execution_steps,
170            self.total_memory() as f64 / 1_000_000.0,
171            self.stats.compilation_time.as_secs_f64() * 1000.0
172        )
173    }
174}
175
176/// Graph compiler that applies optimization passes.
177pub struct GraphCompiler {
178    config: CompilationConfig,
179    optimizer: GraphOptimizer,
180    validator: GraphValidator,
181    scheduler: Scheduler,
182}
183
184impl GraphCompiler {
185    /// Create a new graph compiler with the given configuration.
186    pub fn new(config: CompilationConfig) -> Self {
187        GraphCompiler {
188            config,
189            optimizer: GraphOptimizer::new(),
190            validator: GraphValidator::new(),
191            scheduler: Scheduler::new(SchedulingStrategy::Balanced),
192        }
193    }
194
195    /// Create a compiler with default configuration.
196    pub fn with_default_config() -> Self {
197        Self::new(CompilationConfig::default())
198    }
199
200    /// Compile a graph with the configured optimization passes.
201    pub fn compile(&mut self, graph: &EinsumGraph) -> Result<CompiledGraph, ExecutorError> {
202        let start_time = SystemTime::now();
203        let original_nodes = graph.nodes.len();
204
205        // Validate the graph
206        let validation_result = self.validator.validate(graph);
207        if !validation_result.is_valid {
208            return Err(ExecutorError::GraphValidationError(format!(
209                "Graph validation failed: {}",
210                validation_result
211                    .errors
212                    .first()
213                    .map(|e| e.as_str())
214                    .unwrap_or("unknown error")
215            )));
216        }
217
218        // Clone the graph for optimization
219        let optimized_graph = graph.clone();
220
221        // Apply optimizations based on level
222        let opt_result = match self.config.optimization_level {
223            OptimizationLevel::None => OptimizationResult {
224                fusion_opportunities: vec![],
225                dead_nodes: vec![],
226                redundant_computations: vec![],
227                estimated_improvement: 0.0,
228            },
229            OptimizationLevel::Basic
230            | OptimizationLevel::Moderate
231            | OptimizationLevel::Aggressive => {
232                // Analyze the graph to find optimization opportunities
233                self.optimizer.analyze(&optimized_graph)
234            }
235        };
236
237        // Generate execution schedule
238        let schedule = self.scheduler.schedule(&optimized_graph);
239
240        // Shape inference (if enabled)
241        let shapes = if self.config.enable_shape_inference {
242            let _shape_ctx = ShapeInferenceContext::new();
243            // Infer shapes for all nodes
244            // Note: This is a simplified version - real implementation would need input shapes
245            HashMap::new()
246        } else {
247            HashMap::new()
248        };
249
250        // Memory estimation (if enabled)
251        let memory_usage = if self.config.enable_memory_estimation {
252            use crate::capabilities::DType;
253            let estimator = MemoryEstimator::new(DType::F32);
254            let estimate = estimator.estimate(&optimized_graph);
255            // Build per-node memory map from estimate
256            let mut per_node: HashMap<usize, usize> = HashMap::new();
257            for (idx, mem) in estimate.intermediate_memory.iter().enumerate() {
258                per_node.insert(idx, mem.bytes);
259            }
260            per_node
261        } else {
262            HashMap::new()
263        };
264
265        let compilation_time = start_time.elapsed().unwrap_or(Duration::from_secs(0));
266
267        let stats = CompilationStats {
268            compilation_time,
269            original_nodes,
270            optimized_nodes: optimized_graph.nodes.len(),
271            fusions_applied: opt_result.fusion_opportunities.len(),
272            dead_nodes_eliminated: opt_result.dead_nodes.len(),
273            estimated_memory_bytes: memory_usage.values().sum(),
274            execution_steps: schedule.execution_order.len(),
275        };
276
277        Ok(CompiledGraph {
278            graph: optimized_graph,
279            schedule,
280            shapes,
281            memory_usage,
282            config: self.config.clone(),
283            stats,
284            compiled_at: SystemTime::now(),
285        })
286    }
287
288    /// Update the compilation configuration.
289    pub fn set_config(&mut self, config: CompilationConfig) {
290        self.config = config;
291    }
292
293    /// Get the current configuration.
294    pub fn config(&self) -> &CompilationConfig {
295        &self.config
296    }
297}
298
299/// Cache key for compiled graphs.
300#[derive(Debug, Clone, PartialEq, Eq, Hash)]
301pub struct CompilationKey {
302    /// Hash of the graph structure
303    pub graph_hash: u64,
304    /// Optimization level used
305    pub optimization_level: OptimizationLevel,
306    /// Target device (if specified)
307    pub target_device: Option<String>,
308}
309
310impl CompilationKey {
311    /// Create a key from a graph and config.
312    pub fn new(graph: &EinsumGraph, config: &CompilationConfig) -> Self {
313        CompilationKey {
314            graph_hash: Self::hash_graph(graph),
315            optimization_level: config.optimization_level,
316            target_device: config.target_device.clone(),
317        }
318    }
319
320    /// Compute a hash of the graph structure.
321    fn hash_graph(graph: &EinsumGraph) -> u64 {
322        use std::collections::hash_map::DefaultHasher;
323        use std::hash::{Hash, Hasher};
324
325        let mut hasher = DefaultHasher::new();
326
327        // Hash node count
328        graph.nodes.len().hash(&mut hasher);
329
330        // Hash each node's operation type and connections
331        for node in &graph.nodes {
332            // Hash operation type
333            match &node.op {
334                tensorlogic_ir::OpType::Einsum { spec } => {
335                    "einsum".hash(&mut hasher);
336                    spec.hash(&mut hasher);
337                }
338                tensorlogic_ir::OpType::Reduce { op, axes } => {
339                    "reduce".hash(&mut hasher);
340                    op.hash(&mut hasher);
341                    axes.hash(&mut hasher);
342                }
343                tensorlogic_ir::OpType::ElemUnary { op } => {
344                    "elemunary".hash(&mut hasher);
345                    op.hash(&mut hasher);
346                }
347                tensorlogic_ir::OpType::ElemBinary { op } => {
348                    "elembinary".hash(&mut hasher);
349                    op.hash(&mut hasher);
350                }
351            }
352
353            // Hash inputs and outputs
354            node.inputs.hash(&mut hasher);
355            node.outputs.hash(&mut hasher);
356        }
357
358        hasher.finish()
359    }
360}
361
/// Counters describing how effective a compilation cache has been.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Lookups that found an entry
    pub hits: usize,
    /// Lookups that found nothing
    pub misses: usize,
    /// Entry count as last recorded by the cache
    pub size: usize,
    /// Approximate compilation time avoided thanks to hits
    pub time_saved: Duration,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` before any lookup has happened.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
386
387/// Cache for compiled graphs.
388///
389/// Stores compiled graphs by their cache key to avoid recompilation
390/// of the same graph with the same configuration.
391pub struct CompilationCache {
392    cache: Arc<RwLock<HashMap<CompilationKey, Arc<CompiledGraph>>>>,
393    stats: Arc<RwLock<CacheStats>>,
394    max_size: usize,
395}
396
397impl CompilationCache {
398    /// Create a new compilation cache with the given maximum size.
399    pub fn new(max_size: usize) -> Self {
400        CompilationCache {
401            cache: Arc::new(RwLock::new(HashMap::new())),
402            stats: Arc::new(RwLock::new(CacheStats::default())),
403            max_size,
404        }
405    }
406
407    /// Create a cache with default size (100 entries).
408    pub fn with_default_size() -> Self {
409        Self::new(100)
410    }
411
412    /// Get a compiled graph from the cache.
413    pub fn get(&self, key: &CompilationKey) -> Option<Arc<CompiledGraph>> {
414        let cache = self.cache.read().expect("lock should not be poisoned");
415        let result = cache.get(key).cloned();
416
417        // Update stats
418        let mut stats = self.stats.write().expect("lock should not be poisoned");
419        if let Some(ref compiled) = result {
420            stats.hits += 1;
421            stats.time_saved += compiled.stats.compilation_time;
422        } else {
423            stats.misses += 1;
424        }
425
426        result
427    }
428
429    /// Insert a compiled graph into the cache.
430    pub fn insert(&self, key: CompilationKey, compiled: CompiledGraph) {
431        let mut cache = self.cache.write().expect("lock should not be poisoned");
432
433        // Evict oldest entry if at capacity
434        if cache.len() >= self.max_size && !cache.contains_key(&key) {
435            if let Some(oldest_key) = cache.keys().next().cloned() {
436                cache.remove(&oldest_key);
437            }
438        }
439
440        cache.insert(key, Arc::new(compiled));
441
442        // Update size stat
443        let mut stats = self.stats.write().expect("lock should not be poisoned");
444        stats.size = cache.len();
445    }
446
447    /// Clear the cache.
448    pub fn clear(&self) {
449        let mut cache = self.cache.write().expect("lock should not be poisoned");
450        cache.clear();
451
452        let mut stats = self.stats.write().expect("lock should not be poisoned");
453        stats.size = 0;
454    }
455
456    /// Get cache statistics.
457    pub fn stats(&self) -> CacheStats {
458        self.stats
459            .read()
460            .expect("lock should not be poisoned")
461            .clone()
462    }
463
464    /// Get the number of entries in the cache.
465    pub fn len(&self) -> usize {
466        self.cache
467            .read()
468            .expect("lock should not be poisoned")
469            .len()
470    }
471
472    /// Check if the cache is empty.
473    pub fn is_empty(&self) -> bool {
474        self.len() == 0
475    }
476}
477
478/// Trait for executors that support graph compilation.
479///
480/// Executors implementing this trait can execute pre-compiled graphs
481/// more efficiently than executing the original graph.
482pub trait TlCompilableExecutor {
483    /// Compile a graph for efficient execution.
484    ///
485    /// Returns a compiled graph that can be executed multiple times
486    /// with different inputs without recompiling.
487    fn compile_graph(
488        &mut self,
489        graph: &EinsumGraph,
490        config: &CompilationConfig,
491    ) -> Result<CompiledGraph, ExecutorError>;
492
493    /// Execute a compiled graph.
494    ///
495    /// This should be more efficient than executing the original graph
496    /// since optimization passes have already been applied.
497    fn execute_compiled(
498        &mut self,
499        compiled: &CompiledGraph,
500        inputs: &HashMap<usize, Box<dyn std::any::Any>>,
501    ) -> Result<HashMap<usize, Box<dyn std::any::Any>>, ExecutorError>;
502
503    /// Check if compilation is supported for this executor.
504    fn supports_compilation(&self) -> bool {
505        true
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    /// Three-node chain over one input tensor: 0 -> 1 -> 2 -> 3, output 3.
    fn create_test_graph() -> EinsumGraph {
        let mut g = EinsumGraph::new();

        g.tensors.push("input".to_string());
        g.inputs.push(0);

        let chain = [("ij->ij", 0, 1), ("ij,jk->ik", 1, 2), ("ik->ik", 2, 3)];
        for (spec, src, dst) in chain {
            g.nodes.push(EinsumNode::new(spec, vec![src], vec![dst]));
        }

        g.outputs.push(3);
        g
    }

    /// Compiler whose config differs from the default only in its level.
    fn compiler_at(level: OptimizationLevel) -> GraphCompiler {
        GraphCompiler::new(CompilationConfig {
            optimization_level: level,
            ..Default::default()
        })
    }

    /// Cache key for `graph` under the default configuration.
    fn key_for(graph: &EinsumGraph) -> CompilationKey {
        CompilationKey::new(graph, &CompilationConfig::default())
    }

    #[test]
    fn test_compilation_key_equality() {
        assert_eq!(key_for(&create_test_graph()), key_for(&create_test_graph()));
    }

    #[test]
    fn test_compilation_key_different_graphs() {
        let base = create_test_graph();
        let mut extended = create_test_graph();
        extended.nodes.push(EinsumNode::new("i->i", vec![3], vec![4]));

        assert_ne!(key_for(&base), key_for(&extended));
    }

    #[test]
    fn test_compilation_key_different_config() {
        let graph = create_test_graph();
        let key_at = |level: OptimizationLevel| {
            CompilationKey::new(
                &graph,
                &CompilationConfig {
                    optimization_level: level,
                    ..Default::default()
                },
            )
        };

        assert_ne!(
            key_at(OptimizationLevel::Basic),
            key_at(OptimizationLevel::Aggressive)
        );
    }

    #[test]
    fn test_graph_compiler_basic() {
        let compiled = compiler_at(OptimizationLevel::Basic)
            .compile(&create_test_graph())
            .expect("unwrap");

        assert!(compiled.is_valid());
        assert_eq!(compiled.stats.original_nodes, 3);
    }

    #[test]
    fn test_graph_compiler_moderate() {
        let compiled = compiler_at(OptimizationLevel::Moderate)
            .compile(&create_test_graph())
            .expect("unwrap");

        assert!(compiled.is_valid());
        assert!(compiled.stats.compilation_time > Duration::from_secs(0));
    }

    #[test]
    fn test_graph_compiler_aggressive() {
        let compiled = compiler_at(OptimizationLevel::Aggressive)
            .compile(&create_test_graph())
            .expect("unwrap");

        assert!(compiled.is_valid());
        assert_eq!(compiled.node_count(), compiled.stats.optimized_nodes);
    }

    #[test]
    fn test_compiled_graph_summary() {
        let summary = GraphCompiler::with_default_config()
            .compile(&create_test_graph())
            .expect("unwrap")
            .summary();

        for needle in ["CompiledGraph", "nodes", "MB"] {
            assert!(summary.contains(needle), "summary {summary:?} lacks {needle:?}");
        }
    }

    #[test]
    fn test_compilation_cache_basic() {
        let cache = CompilationCache::new(10);
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());

        let graph = create_test_graph();
        let key = key_for(&graph);

        // Nothing cached yet
        assert!(cache.get(&key).is_none());

        let compiled = GraphCompiler::with_default_config()
            .compile(&graph)
            .expect("unwrap");
        cache.insert(key.clone(), compiled);
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());

        // Now it is found
        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn test_compilation_cache_eviction() {
        let cache = CompilationCache::new(2);
        let mut compiler = GraphCompiler::with_default_config();

        let first = create_test_graph();
        let mut second = create_test_graph();
        second.nodes.push(EinsumNode::new("i->i", vec![3], vec![4]));
        let mut third = create_test_graph();
        third.nodes.push(EinsumNode::new("ij->ji", vec![3], vec![5]));

        for graph in [&first, &second] {
            cache.insert(key_for(graph), compiler.compile(graph).expect("unwrap"));
        }
        assert_eq!(cache.len(), 2);

        // A third distinct entry must not grow the cache past its capacity
        cache.insert(key_for(&third), compiler.compile(&third).expect("unwrap"));
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_compilation_cache_stats() {
        let cache = CompilationCache::new(10);
        let graph = create_test_graph();
        let key = key_for(&graph);

        let initial = cache.stats();
        assert_eq!((initial.hits, initial.misses), (0, 0));
        assert_eq!(initial.hit_rate(), 0.0);

        // One miss
        let _ = cache.get(&key);
        assert_eq!(cache.stats().misses, 1);

        // One hit
        let compiled = GraphCompiler::with_default_config()
            .compile(&graph)
            .expect("unwrap");
        cache.insert(key.clone(), compiled);
        let _ = cache.get(&key);

        let after = cache.stats();
        assert_eq!(after.hits, 1);
        assert_eq!(after.misses, 1);
        assert_eq!(after.hit_rate(), 0.5);
    }

    #[test]
    fn test_compilation_cache_clear() {
        let cache = CompilationCache::new(10);
        let graph = create_test_graph();

        let compiled = GraphCompiler::with_default_config()
            .compile(&graph)
            .expect("unwrap");
        cache.insert(key_for(&graph), compiled);
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_optimization_levels() {
        let graph = create_test_graph();

        for level in [
            OptimizationLevel::None,
            OptimizationLevel::Basic,
            OptimizationLevel::Moderate,
            OptimizationLevel::Aggressive,
        ] {
            let result = compiler_at(level).compile(&graph);
            assert!(result.is_ok(), "Compilation failed for level {:?}", level);
            assert!(result.expect("unwrap").is_valid());
        }
    }

    #[test]
    fn test_compiled_graph_memory_estimation() {
        let mut compiler = GraphCompiler::new(CompilationConfig {
            enable_memory_estimation: true,
            ..Default::default()
        });

        let compiled = compiler.compile(&create_test_graph()).expect("unwrap");
        // Smoke test only: estimation must run; any `usize` total is acceptable.
        let _ = compiled.total_memory();
    }

    #[test]
    fn test_config_update() {
        let mut compiler = GraphCompiler::with_default_config();

        compiler.set_config(CompilationConfig {
            optimization_level: OptimizationLevel::Aggressive,
            enable_parallelism: false,
            ..Default::default()
        });

        let active = compiler.config();
        assert_eq!(active.optimization_level, OptimizationLevel::Aggressive);
        assert!(!active.enable_parallelism);
    }
}