Skip to main content

tensorlogic_infer/
compilation.rs

1//! Graph compilation and caching infrastructure.
2//!
3//! This module provides ahead-of-time graph optimization and compilation capabilities:
4//! - `CompiledGraph`: Optimized, executable representation of computation graphs
5//! - `GraphCompiler`: Applies optimization passes and produces compiled graphs
6//! - `CompilationCache`: Caches compiled graphs to avoid recompilation
7//! - `TlCompilableExecutor`: Trait for executors that support graph compilation
8//!
9//! # Example
10//!
11//! ```
12//! use tensorlogic_infer::compilation::{GraphCompiler, CompilationConfig, OptimizationLevel};
13//! use tensorlogic_infer::DummyExecutor;
14//! use tensorlogic_ir::EinsumGraph;
15//!
16//! let mut compiler = GraphCompiler::new(CompilationConfig {
17//!     optimization_level: OptimizationLevel::Aggressive,
18//!     ..Default::default()
19//! });
20//!
21//! let graph = EinsumGraph::new();
22//! let compiled = compiler.compile(&graph).unwrap();
23//! ```
24
25use crate::error::ExecutorError;
26use crate::memory::MemoryEstimator;
27use crate::optimization::{GraphOptimizer, OptimizationResult};
28use crate::scheduling::{ExecutionSchedule, Scheduler, SchedulingStrategy};
29use crate::shape::ShapeInferenceContext;
30use crate::validation::GraphValidator;
31use std::collections::HashMap;
32use std::sync::{Arc, RwLock};
33use std::time::{Duration, SystemTime};
34use tensorlogic_ir::EinsumGraph;
35
/// How aggressively [`GraphCompiler`] should optimize a graph.
///
/// Variants are declared from least to most aggressive, so the derived
/// `PartialOrd`/`Ord` order them that way (`None < Basic < Moderate < Aggressive`).
/// The default level is [`OptimizationLevel::Moderate`].
///
/// Note: `GraphCompiler::compile` currently runs the same optimizer analysis for
/// every level other than `None`; the per-level descriptions state the intended scope.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OptimizationLevel {
    /// Skip optimization entirely and compile the graph as given.
    None,
    /// Cheap passes such as dead code elimination and constant folding.
    Basic,
    /// Fusion, common-subexpression elimination and basic scheduling (the default).
    #[default]
    Moderate,
    /// Every available pass together with advanced scheduling.
    Aggressive,
}
49
50/// Configuration for graph compilation.
51#[derive(Debug, Clone)]
52pub struct CompilationConfig {
53    /// Optimization level to apply
54    pub optimization_level: OptimizationLevel,
55    /// Whether to enable shape inference
56    pub enable_shape_inference: bool,
57    /// Whether to enable memory estimation
58    pub enable_memory_estimation: bool,
59    /// Target device for compilation (e.g., "cpu", "cuda:0")
60    pub target_device: Option<String>,
61    /// Maximum memory budget in bytes (None = unlimited)
62    pub memory_budget: Option<usize>,
63    /// Enable caching of intermediate results
64    pub enable_caching: bool,
65    /// Enable parallel execution planning
66    pub enable_parallelism: bool,
67}
68
69impl Default for CompilationConfig {
70    fn default() -> Self {
71        CompilationConfig {
72            optimization_level: OptimizationLevel::default(),
73            enable_shape_inference: true,
74            enable_memory_estimation: true,
75            target_device: None,
76            memory_budget: None,
77            enable_caching: true,
78            enable_parallelism: true,
79        }
80    }
81}
82
/// Statistics gathered while compiling a graph.
///
/// `Default` yields all-zero counters and a zero `compilation_time`
/// (`Duration::default()` is zero), which is exactly what the derive produces,
/// so no hand-written impl is needed.
#[derive(Debug, Clone, Default)]
pub struct CompilationStats {
    /// Elapsed time spent compiling.
    pub compilation_time: Duration,
    /// Number of nodes in the input graph.
    pub original_nodes: usize,
    /// Number of nodes in the graph after optimization.
    pub optimized_nodes: usize,
    /// Number of fusion opportunities. `GraphCompiler::compile` reports the
    /// opportunities found by the optimizer's analysis; it does not rewrite the graph.
    pub fusions_applied: usize,
    /// Number of dead nodes. `GraphCompiler::compile` reports the dead nodes
    /// found by the optimizer's analysis; it does not rewrite the graph.
    pub dead_nodes_eliminated: usize,
    /// Total estimated memory usage in bytes.
    pub estimated_memory_bytes: usize,
    /// Number of steps in the execution schedule.
    pub execution_steps: usize,
}
115
116/// Compiled representation of a computation graph.
117///
118/// Contains the optimized graph, execution schedule, and metadata
119/// necessary for efficient execution.
120#[derive(Debug, Clone)]
121pub struct CompiledGraph {
122    /// The optimized graph
123    pub graph: EinsumGraph,
124    /// Execution schedule for the graph
125    pub schedule: ExecutionSchedule,
126    /// Shape information (if available)
127    pub shapes: HashMap<usize, Vec<usize>>,
128    /// Estimated memory usage per node
129    pub memory_usage: HashMap<usize, usize>,
130    /// Configuration used for compilation
131    pub config: CompilationConfig,
132    /// Compilation statistics
133    pub stats: CompilationStats,
134    /// Timestamp when compiled
135    pub compiled_at: SystemTime,
136}
137
138impl CompiledGraph {
139    /// Get the number of nodes in the compiled graph
140    pub fn node_count(&self) -> usize {
141        self.graph.nodes.len()
142    }
143
144    /// Get the total estimated memory usage
145    pub fn total_memory(&self) -> usize {
146        self.memory_usage.values().sum()
147    }
148
149    /// Check if this compiled graph is still valid
150    pub fn is_valid(&self) -> bool {
151        // Check if graph structure is valid
152        if self.graph.nodes.is_empty() {
153            return false;
154        }
155
156        // Check if schedule matches graph
157        if self.schedule.execution_order.len() != self.graph.nodes.len() {
158            return false;
159        }
160
161        true
162    }
163
164    /// Get a summary of the compiled graph
165    pub fn summary(&self) -> String {
166        format!(
167            "CompiledGraph: {} nodes, {} steps, {:.2}MB memory, compiled in {:.2}ms",
168            self.node_count(),
169            self.stats.execution_steps,
170            self.total_memory() as f64 / 1_000_000.0,
171            self.stats.compilation_time.as_secs_f64() * 1000.0
172        )
173    }
174}
175
176/// Graph compiler that applies optimization passes.
177pub struct GraphCompiler {
178    config: CompilationConfig,
179    optimizer: GraphOptimizer,
180    validator: GraphValidator,
181    scheduler: Scheduler,
182}
183
184impl GraphCompiler {
185    /// Create a new graph compiler with the given configuration.
186    pub fn new(config: CompilationConfig) -> Self {
187        GraphCompiler {
188            config,
189            optimizer: GraphOptimizer::new(),
190            validator: GraphValidator::new(),
191            scheduler: Scheduler::new(SchedulingStrategy::Balanced),
192        }
193    }
194
195    /// Create a compiler with default configuration.
196    pub fn with_default_config() -> Self {
197        Self::new(CompilationConfig::default())
198    }
199
200    /// Compile a graph with the configured optimization passes.
201    pub fn compile(&mut self, graph: &EinsumGraph) -> Result<CompiledGraph, ExecutorError> {
202        let start_time = SystemTime::now();
203        let original_nodes = graph.nodes.len();
204
205        // Validate the graph
206        let validation_result = self.validator.validate(graph);
207        if !validation_result.is_valid {
208            return Err(ExecutorError::GraphValidationError(format!(
209                "Graph validation failed: {}",
210                validation_result
211                    .errors
212                    .first()
213                    .map(|e| e.as_str())
214                    .unwrap_or("unknown error")
215            )));
216        }
217
218        // Clone the graph for optimization
219        let optimized_graph = graph.clone();
220
221        // Apply optimizations based on level
222        let opt_result = match self.config.optimization_level {
223            OptimizationLevel::None => OptimizationResult {
224                fusion_opportunities: vec![],
225                dead_nodes: vec![],
226                redundant_computations: vec![],
227                estimated_improvement: 0.0,
228            },
229            OptimizationLevel::Basic
230            | OptimizationLevel::Moderate
231            | OptimizationLevel::Aggressive => {
232                // Analyze the graph to find optimization opportunities
233                self.optimizer.analyze(&optimized_graph)
234            }
235        };
236
237        // Generate execution schedule
238        let schedule = self.scheduler.schedule(&optimized_graph);
239
240        // Shape inference (if enabled)
241        let shapes = if self.config.enable_shape_inference {
242            let _shape_ctx = ShapeInferenceContext::new();
243            // Infer shapes for all nodes
244            // Note: This is a simplified version - real implementation would need input shapes
245            HashMap::new()
246        } else {
247            HashMap::new()
248        };
249
250        // Memory estimation (if enabled)
251        let memory_usage = if self.config.enable_memory_estimation {
252            use crate::capabilities::DType;
253            let estimator = MemoryEstimator::new(DType::F32);
254            let estimate = estimator.estimate(&optimized_graph);
255            // Build per-node memory map from estimate
256            let mut per_node: HashMap<usize, usize> = HashMap::new();
257            for (idx, mem) in estimate.intermediate_memory.iter().enumerate() {
258                per_node.insert(idx, mem.bytes);
259            }
260            per_node
261        } else {
262            HashMap::new()
263        };
264
265        let compilation_time = start_time.elapsed().unwrap_or(Duration::from_secs(0));
266
267        let stats = CompilationStats {
268            compilation_time,
269            original_nodes,
270            optimized_nodes: optimized_graph.nodes.len(),
271            fusions_applied: opt_result.fusion_opportunities.len(),
272            dead_nodes_eliminated: opt_result.dead_nodes.len(),
273            estimated_memory_bytes: memory_usage.values().sum(),
274            execution_steps: schedule.execution_order.len(),
275        };
276
277        Ok(CompiledGraph {
278            graph: optimized_graph,
279            schedule,
280            shapes,
281            memory_usage,
282            config: self.config.clone(),
283            stats,
284            compiled_at: SystemTime::now(),
285        })
286    }
287
288    /// Update the compilation configuration.
289    pub fn set_config(&mut self, config: CompilationConfig) {
290        self.config = config;
291    }
292
293    /// Get the current configuration.
294    pub fn config(&self) -> &CompilationConfig {
295        &self.config
296    }
297}
298
299/// Cache key for compiled graphs.
300#[derive(Debug, Clone, PartialEq, Eq, Hash)]
301pub struct CompilationKey {
302    /// Hash of the graph structure
303    pub graph_hash: u64,
304    /// Optimization level used
305    pub optimization_level: OptimizationLevel,
306    /// Target device (if specified)
307    pub target_device: Option<String>,
308}
309
310impl CompilationKey {
311    /// Create a key from a graph and config.
312    pub fn new(graph: &EinsumGraph, config: &CompilationConfig) -> Self {
313        CompilationKey {
314            graph_hash: Self::hash_graph(graph),
315            optimization_level: config.optimization_level,
316            target_device: config.target_device.clone(),
317        }
318    }
319
320    /// Compute a hash of the graph structure.
321    fn hash_graph(graph: &EinsumGraph) -> u64 {
322        use std::collections::hash_map::DefaultHasher;
323        use std::hash::{Hash, Hasher};
324
325        let mut hasher = DefaultHasher::new();
326
327        // Hash node count
328        graph.nodes.len().hash(&mut hasher);
329
330        // Hash each node's operation type and connections
331        for node in &graph.nodes {
332            // Hash operation type
333            match &node.op {
334                tensorlogic_ir::OpType::Einsum { spec } => {
335                    "einsum".hash(&mut hasher);
336                    spec.hash(&mut hasher);
337                }
338                tensorlogic_ir::OpType::Reduce { op, axes } => {
339                    "reduce".hash(&mut hasher);
340                    op.hash(&mut hasher);
341                    axes.hash(&mut hasher);
342                }
343                tensorlogic_ir::OpType::ElemUnary { op } => {
344                    "elemunary".hash(&mut hasher);
345                    op.hash(&mut hasher);
346                }
347                tensorlogic_ir::OpType::ElemBinary { op } => {
348                    "elembinary".hash(&mut hasher);
349                    op.hash(&mut hasher);
350                }
351            }
352
353            // Hash inputs and outputs
354            node.inputs.hash(&mut hasher);
355            node.outputs.hash(&mut hasher);
356        }
357
358        hasher.finish()
359    }
360}
361
/// Hit/miss counters and bookkeeping for a [`CompilationCache`].
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: usize,
    /// Lookups that found nothing.
    pub misses: usize,
    /// Entries currently stored (refreshed by the cache on insert and clear).
    pub size: usize,
    /// Sum of the recorded compilation times of entries returned by hits
    /// (an approximation of the time saved).
    pub time_saved: Duration,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookup has happened yet.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
386
387/// Cache for compiled graphs.
388///
389/// Stores compiled graphs by their cache key to avoid recompilation
390/// of the same graph with the same configuration.
391pub struct CompilationCache {
392    cache: Arc<RwLock<HashMap<CompilationKey, Arc<CompiledGraph>>>>,
393    stats: Arc<RwLock<CacheStats>>,
394    max_size: usize,
395}
396
397impl CompilationCache {
398    /// Create a new compilation cache with the given maximum size.
399    pub fn new(max_size: usize) -> Self {
400        CompilationCache {
401            cache: Arc::new(RwLock::new(HashMap::new())),
402            stats: Arc::new(RwLock::new(CacheStats::default())),
403            max_size,
404        }
405    }
406
407    /// Create a cache with default size (100 entries).
408    pub fn with_default_size() -> Self {
409        Self::new(100)
410    }
411
412    /// Get a compiled graph from the cache.
413    pub fn get(&self, key: &CompilationKey) -> Option<Arc<CompiledGraph>> {
414        let cache = self.cache.read().unwrap();
415        let result = cache.get(key).cloned();
416
417        // Update stats
418        let mut stats = self.stats.write().unwrap();
419        if let Some(ref compiled) = result {
420            stats.hits += 1;
421            stats.time_saved += compiled.stats.compilation_time;
422        } else {
423            stats.misses += 1;
424        }
425
426        result
427    }
428
429    /// Insert a compiled graph into the cache.
430    pub fn insert(&self, key: CompilationKey, compiled: CompiledGraph) {
431        let mut cache = self.cache.write().unwrap();
432
433        // Evict oldest entry if at capacity
434        if cache.len() >= self.max_size && !cache.contains_key(&key) {
435            if let Some(oldest_key) = cache.keys().next().cloned() {
436                cache.remove(&oldest_key);
437            }
438        }
439
440        cache.insert(key, Arc::new(compiled));
441
442        // Update size stat
443        let mut stats = self.stats.write().unwrap();
444        stats.size = cache.len();
445    }
446
447    /// Clear the cache.
448    pub fn clear(&self) {
449        let mut cache = self.cache.write().unwrap();
450        cache.clear();
451
452        let mut stats = self.stats.write().unwrap();
453        stats.size = 0;
454    }
455
456    /// Get cache statistics.
457    pub fn stats(&self) -> CacheStats {
458        self.stats.read().unwrap().clone()
459    }
460
461    /// Get the number of entries in the cache.
462    pub fn len(&self) -> usize {
463        self.cache.read().unwrap().len()
464    }
465
466    /// Check if the cache is empty.
467    pub fn is_empty(&self) -> bool {
468        self.len() == 0
469    }
470}
471
472/// Trait for executors that support graph compilation.
473///
474/// Executors implementing this trait can execute pre-compiled graphs
475/// more efficiently than executing the original graph.
476pub trait TlCompilableExecutor {
477    /// Compile a graph for efficient execution.
478    ///
479    /// Returns a compiled graph that can be executed multiple times
480    /// with different inputs without recompiling.
481    fn compile_graph(
482        &mut self,
483        graph: &EinsumGraph,
484        config: &CompilationConfig,
485    ) -> Result<CompiledGraph, ExecutorError>;
486
487    /// Execute a compiled graph.
488    ///
489    /// This should be more efficient than executing the original graph
490    /// since optimization passes have already been applied.
491    fn execute_compiled(
492        &mut self,
493        compiled: &CompiledGraph,
494        inputs: &HashMap<usize, Box<dyn std::any::Any>>,
495    ) -> Result<HashMap<usize, Box<dyn std::any::Any>>, ExecutorError>;
496
497    /// Check if compilation is supported for this executor.
498    fn supports_compilation(&self) -> bool {
499        true
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::EinsumNode;

    /// Build a three-node chain over tensor ids 0 -> 1 -> 2 -> 3, with 0 declared
    /// as the graph input and 3 as the graph output.
    fn create_test_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();

        // Add input tensors
        graph.tensors.push("input".to_string());
        graph.inputs.push(0);

        // Add nodes that process the input
        graph
            .nodes
            .push(EinsumNode::new("ij->ij", vec![0], vec![1]));
        graph
            .nodes
            .push(EinsumNode::new("ij,jk->ik", vec![1], vec![2]));
        graph
            .nodes
            .push(EinsumNode::new("ik->ik", vec![2], vec![3]));

        // Mark final output
        graph.outputs.push(3);

        graph
    }

    #[test]
    fn test_compilation_key_equality() {
        let graph1 = create_test_graph();
        let graph2 = create_test_graph();

        let config = CompilationConfig::default();

        let key1 = CompilationKey::new(&graph1, &config);
        let key2 = CompilationKey::new(&graph2, &config);

        assert_eq!(key1, key2);
    }

    #[test]
    fn test_compilation_key_different_graphs() {
        let graph1 = create_test_graph();
        let mut graph2 = create_test_graph();
        graph2.nodes.push(EinsumNode::new("i->i", vec![3], vec![4]));

        let config = CompilationConfig::default();

        let key1 = CompilationKey::new(&graph1, &config);
        let key2 = CompilationKey::new(&graph2, &config);

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_compilation_key_different_config() {
        let graph = create_test_graph();

        let config1 = CompilationConfig {
            optimization_level: OptimizationLevel::Basic,
            ..Default::default()
        };

        let config2 = CompilationConfig {
            optimization_level: OptimizationLevel::Aggressive,
            ..Default::default()
        };

        let key1 = CompilationKey::new(&graph, &config1);
        let key2 = CompilationKey::new(&graph, &config2);

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_graph_compiler_basic() {
        let graph = create_test_graph();
        let mut compiler = GraphCompiler::new(CompilationConfig {
            optimization_level: OptimizationLevel::Basic,
            ..Default::default()
        });

        let result = compiler.compile(&graph);
        assert!(result.is_ok());

        let compiled = result.unwrap();
        assert!(compiled.is_valid());
        assert_eq!(compiled.stats.original_nodes, 3);
    }

    #[test]
    fn test_graph_compiler_moderate() {
        let graph = create_test_graph();
        let mut compiler = GraphCompiler::new(CompilationConfig {
            optimization_level: OptimizationLevel::Moderate,
            ..Default::default()
        });

        let result = compiler.compile(&graph);
        assert!(result.is_ok());

        let compiled = result.unwrap();
        assert!(compiled.is_valid());
        // Compilation time can legitimately measure as zero on coarse clocks, so
        // asserting `compilation_time > 0` is flaky. Check the structural stats instead.
        assert_eq!(compiled.stats.original_nodes, graph.nodes.len());
        assert_eq!(compiled.stats.optimized_nodes, compiled.node_count());
        assert_eq!(
            compiled.stats.execution_steps,
            compiled.schedule.execution_order.len()
        );
    }

    #[test]
    fn test_graph_compiler_aggressive() {
        let graph = create_test_graph();
        let mut compiler = GraphCompiler::new(CompilationConfig {
            optimization_level: OptimizationLevel::Aggressive,
            ..Default::default()
        });

        let result = compiler.compile(&graph);
        assert!(result.is_ok());

        let compiled = result.unwrap();
        assert!(compiled.is_valid());
        assert_eq!(compiled.node_count(), compiled.stats.optimized_nodes);
    }

    #[test]
    fn test_no_optimization_reports_no_changes() {
        let graph = create_test_graph();
        let mut compiler = GraphCompiler::new(CompilationConfig {
            optimization_level: OptimizationLevel::None,
            ..Default::default()
        });

        let compiled = compiler.compile(&graph).unwrap();
        assert_eq!(compiled.stats.fusions_applied, 0);
        assert_eq!(compiled.stats.dead_nodes_eliminated, 0);
        assert_eq!(compiled.node_count(), graph.nodes.len());
    }

    #[test]
    fn test_compiled_graph_summary() {
        let graph = create_test_graph();
        let mut compiler = GraphCompiler::with_default_config();
        let compiled = compiler.compile(&graph).unwrap();

        let summary = compiled.summary();
        assert!(summary.contains("CompiledGraph"));
        assert!(summary.contains("nodes"));
        assert!(summary.contains("MB"));
    }

    #[test]
    fn test_compilation_cache_basic() {
        let cache = CompilationCache::new(10);
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());

        let graph = create_test_graph();
        let config = CompilationConfig::default();
        let key = CompilationKey::new(&graph, &config);

        // Cache miss
        assert!(cache.get(&key).is_none());

        // Insert and retrieve
        let mut compiler = GraphCompiler::with_default_config();
        let compiled = compiler.compile(&graph).unwrap();
        cache.insert(key.clone(), compiled);

        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());

        // Cache hit
        let cached = cache.get(&key);
        assert!(cached.is_some());
    }

    #[test]
    fn test_compilation_cache_eviction() {
        let cache = CompilationCache::new(2);

        let graph1 = create_test_graph();
        let mut graph2 = create_test_graph();
        graph2.nodes.push(EinsumNode::new("i->i", vec![3], vec![4]));
        let mut graph3 = create_test_graph();
        graph3
            .nodes
            .push(EinsumNode::new("ij->ji", vec![3], vec![5]));

        let config = CompilationConfig::default();
        let mut compiler = GraphCompiler::with_default_config();

        let key1 = CompilationKey::new(&graph1, &config);
        let key2 = CompilationKey::new(&graph2, &config);
        let key3 = CompilationKey::new(&graph3, &config);

        // Fill cache
        cache.insert(key1.clone(), compiler.compile(&graph1).unwrap());
        cache.insert(key2.clone(), compiler.compile(&graph2).unwrap());
        assert_eq!(cache.len(), 2);

        // Add a third entry: one of the earlier entries must make room
        cache.insert(key3.clone(), compiler.compile(&graph3).unwrap());
        assert_eq!(cache.len(), 2);

        // The newest entry always survives; exactly one older entry remains.
        assert!(cache.get(&key3).is_some());
        let survivors =
            usize::from(cache.get(&key1).is_some()) + usize::from(cache.get(&key2).is_some());
        assert_eq!(survivors, 1);
    }

    #[test]
    fn test_compilation_cache_stats() {
        let cache = CompilationCache::new(10);

        let graph = create_test_graph();
        let config = CompilationConfig::default();
        let key = CompilationKey::new(&graph, &config);

        // Initial stats
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate(), 0.0);

        // Cache miss
        cache.get(&key);
        let stats = cache.stats();
        assert_eq!(stats.misses, 1);

        // Insert and hit
        let mut compiler = GraphCompiler::with_default_config();
        let compiled = compiler.compile(&graph).unwrap();
        cache.insert(key.clone(), compiled);
        cache.get(&key);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.size, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    #[test]
    fn test_compilation_cache_clear() {
        let cache = CompilationCache::new(10);
        let graph = create_test_graph();
        let config = CompilationConfig::default();
        let key = CompilationKey::new(&graph, &config);

        let mut compiler = GraphCompiler::with_default_config();
        let compiled = compiler.compile(&graph).unwrap();
        cache.insert(key.clone(), compiled);

        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_optimization_levels() {
        let graph = create_test_graph();

        let levels = [
            OptimizationLevel::None,
            OptimizationLevel::Basic,
            OptimizationLevel::Moderate,
            OptimizationLevel::Aggressive,
        ];

        for level in levels.iter().copied() {
            let mut compiler = GraphCompiler::new(CompilationConfig {
                optimization_level: level,
                ..Default::default()
            });

            let result = compiler.compile(&graph);
            assert!(result.is_ok(), "Compilation failed for level {:?}", level);

            let compiled = result.unwrap();
            assert!(compiled.is_valid());
        }
    }

    #[test]
    fn test_compiled_graph_memory_estimation() {
        let graph = create_test_graph();
        let mut compiler = GraphCompiler::new(CompilationConfig {
            enable_memory_estimation: true,
            ..Default::default()
        });

        let compiled = compiler.compile(&graph).unwrap();
        // The reported total must agree with the statistic recorded at compile time.
        assert_eq!(compiled.total_memory(), compiled.stats.estimated_memory_bytes);
    }

    #[test]
    fn test_memory_estimation_disabled() {
        let graph = create_test_graph();
        let mut compiler = GraphCompiler::new(CompilationConfig {
            enable_memory_estimation: false,
            ..Default::default()
        });

        let compiled = compiler.compile(&graph).unwrap();
        assert!(compiled.memory_usage.is_empty());
        assert_eq!(compiled.total_memory(), 0);
        assert_eq!(compiled.stats.estimated_memory_bytes, 0);
    }

    #[test]
    fn test_config_update() {
        let mut compiler = GraphCompiler::with_default_config();

        let new_config = CompilationConfig {
            optimization_level: OptimizationLevel::Aggressive,
            enable_parallelism: false,
            ..Default::default()
        };

        compiler.set_config(new_config);

        let config = compiler.config();
        assert_eq!(config.optimization_level, OptimizationLevel::Aggressive);
        assert!(!config.enable_parallelism);
    }
}