Skip to main content

tensorlogic_ir/graph/
constant_folding.rs

1//! Constant propagation and folding optimizations.
2//!
3//! This module identifies constant tensors and folds constant computations
4//! at compile time to reduce runtime overhead.
5
6use std::collections::{HashMap, HashSet};
7
8use serde::{Deserialize, Serialize};
9
10use crate::{EinsumGraph, EinsumNode, IrError, OpType};
11
12/// Information about a constant tensor.
13#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
14pub struct ConstantInfo {
15    /// Tensor index
16    pub tensor_idx: usize,
17    /// Whether the value is known at compile time
18    pub is_compile_time_constant: bool,
19    /// Whether the value is an identity (e.g., 1 for multiplication)
20    pub is_identity: bool,
21    /// Whether the value is zero/absorbing
22    pub is_zero: bool,
23}
24
/// Result of constant propagation analysis.
///
/// Built by `analyze_constants`, which fills `constant_tensors` and
/// `constant_info` together. `is_constant` consults the set, while
/// `get_info` consults the map.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ConstantPropagationResult {
    /// Indices of every tensor known to be constant, whether seeded from
    /// metadata or propagated through nodes whose inputs are all constant.
    pub constant_tensors: HashSet<usize>,
    /// Detailed information about each constant, keyed by tensor index.
    pub constant_info: HashMap<usize, ConstantInfo>,
    /// Number of operations that can be constant-folded.
    pub foldable_operations: usize,
    /// Heuristic speedup estimate from constant folding (`1.0` means none).
    pub estimated_speedup: f64,
}
37
38impl ConstantPropagationResult {
39    /// Create a result with no constants.
40    pub fn none() -> Self {
41        Self {
42            constant_tensors: HashSet::new(),
43            constant_info: HashMap::new(),
44            foldable_operations: 0,
45            estimated_speedup: 1.0,
46        }
47    }
48
49    /// Check if a tensor is constant.
50    pub fn is_constant(&self, tensor_idx: usize) -> bool {
51        self.constant_tensors.contains(&tensor_idx)
52    }
53
54    /// Get detailed information about a constant tensor.
55    pub fn get_info(&self, tensor_idx: usize) -> Option<&ConstantInfo> {
56        self.constant_info.get(&tensor_idx)
57    }
58}
59
/// Statistics from constant folding transformation.
///
/// Returned by `apply_constant_folding` and `fold_constants_aggressive`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FoldingStats {
    /// Number of operations folded: nodes whose (first) output is a
    /// compile-time constant.
    pub operations_folded: usize,
    /// Number of operations simplified (e.g., x * 1 → x)
    pub operations_simplified: usize,
    /// Number of operations eliminated (e.g., x * 0 → 0)
    pub operations_eliminated: usize,
    /// Heuristic speedup estimate; stays `1.0` when nothing was transformed.
    pub estimated_speedup: f64,
}
72
73impl FoldingStats {
74    /// Create statistics with no transformations.
75    pub fn none() -> Self {
76        Self {
77            operations_folded: 0,
78            operations_simplified: 0,
79            operations_eliminated: 0,
80            estimated_speedup: 1.0,
81        }
82    }
83
84    /// Total number of transformations.
85    pub fn total_transformations(&self) -> usize {
86        self.operations_folded + self.operations_simplified + self.operations_eliminated
87    }
88}
89
90/// Analyze constant propagation opportunities in a graph.
91pub fn analyze_constants(graph: &EinsumGraph) -> Result<ConstantPropagationResult, IrError> {
92    let mut result = ConstantPropagationResult::none();
93
94    // Start with input tensors marked as potentially constant
95    let _constant_candidates: HashSet<usize> = graph.inputs.iter().copied().collect();
96
97    // Identify tensors that are compile-time constants based on metadata
98    for (tensor_idx, metadata) in &graph.tensor_metadata {
99        if is_compile_time_constant(metadata) {
100            result.constant_tensors.insert(*tensor_idx);
101            result.constant_info.insert(
102                *tensor_idx,
103                ConstantInfo {
104                    tensor_idx: *tensor_idx,
105                    is_compile_time_constant: true,
106                    is_identity: is_identity_value(metadata),
107                    is_zero: is_zero_value(metadata),
108                },
109            );
110        }
111    }
112
113    // Propagate constant information through the graph
114    let mut changed = true;
115    while changed {
116        changed = false;
117
118        for node in graph.nodes.iter() {
119            // Check if all inputs are constants
120            let all_inputs_constant = node
121                .inputs
122                .iter()
123                .all(|&idx| result.constant_tensors.contains(&idx));
124
125            if all_inputs_constant && !node.inputs.is_empty() {
126                // This operation can be constant-folded
127                for &output_idx in &node.outputs {
128                    if !result.constant_tensors.contains(&output_idx) {
129                        result.constant_tensors.insert(output_idx);
130                        result.constant_info.insert(
131                            output_idx,
132                            ConstantInfo {
133                                tensor_idx: output_idx,
134                                is_compile_time_constant: true,
135                                is_identity: false,
136                                is_zero: false,
137                            },
138                        );
139                        result.foldable_operations += 1;
140                        changed = true;
141                    }
142                }
143            }
144        }
145    }
146
147    // Estimate speedup
148    if result.foldable_operations > 0 {
149        let total_ops = graph.nodes.len();
150        let folding_ratio = result.foldable_operations as f64 / total_ops.max(1) as f64;
151        result.estimated_speedup = 1.0 + folding_ratio * 0.3; // Conservative estimate
152    }
153
154    Ok(result)
155}
156
157/// Apply constant folding transformations to a graph.
158pub fn apply_constant_folding(
159    graph: &mut EinsumGraph,
160    constants: &ConstantPropagationResult,
161) -> Result<FoldingStats, IrError> {
162    let mut stats = FoldingStats::none();
163    let mut replacements: HashMap<usize, usize> = HashMap::new();
164
165    // First pass: identify algebraic simplifications
166    for node in graph.nodes.iter() {
167        if let Some(simplified_output) = try_simplify_operation(node, constants) {
168            // Record the simplification
169            if !node.outputs.is_empty() {
170                replacements.insert(node.outputs[0], simplified_output);
171                stats.operations_simplified += 1;
172            }
173        } else if try_eliminate_operation(node, constants) {
174            stats.operations_eliminated += 1;
175        } else if constants.is_constant(node.outputs.first().copied().unwrap_or(usize::MAX)) {
176            stats.operations_folded += 1;
177        }
178    }
179
180    // Second pass: apply replacements
181    for node in &mut graph.nodes {
182        for input_idx in &mut node.inputs {
183            if let Some(&replacement) = replacements.get(input_idx) {
184                *input_idx = replacement;
185            }
186        }
187    }
188
189    // Update outputs
190    for output_idx in &mut graph.outputs {
191        if let Some(&replacement) = replacements.get(output_idx) {
192            *output_idx = replacement;
193        }
194    }
195
196    // Estimate speedup
197    if stats.total_transformations() > 0 {
198        let total_ops = graph.nodes.len().max(1);
199        let optimization_ratio = stats.total_transformations() as f64 / total_ops as f64;
200        stats.estimated_speedup = 1.0 + optimization_ratio * 0.4;
201    }
202
203    Ok(stats)
204}
205
206/// Perform aggressive constant folding with all available optimizations.
207pub fn fold_constants_aggressive(graph: &mut EinsumGraph) -> Result<FoldingStats, IrError> {
208    let mut total_stats = FoldingStats::none();
209
210    // Multiple passes for maximum effect
211    for _ in 0..3 {
212        let constants = analyze_constants(graph)?;
213        let stats = apply_constant_folding(graph, &constants)?;
214
215        total_stats.operations_folded += stats.operations_folded;
216        total_stats.operations_simplified += stats.operations_simplified;
217        total_stats.operations_eliminated += stats.operations_eliminated;
218
219        // Stop if no changes
220        if stats.total_transformations() == 0 {
221            break;
222        }
223    }
224
225    // Update final speedup estimate
226    if total_stats.total_transformations() > 0 {
227        let total_ops = graph.nodes.len().max(1);
228        let optimization_ratio = total_stats.total_transformations() as f64 / total_ops as f64;
229        total_stats.estimated_speedup = 1.0 + optimization_ratio * 0.5;
230    }
231
232    Ok(total_stats)
233}
234
235/// Identify constant subgraphs that can be pre-computed.
236pub fn identify_constant_subgraphs(graph: &EinsumGraph) -> Result<Vec<Vec<usize>>, IrError> {
237    let constants = analyze_constants(graph)?;
238    let mut subgraphs = Vec::new();
239    let mut visited = HashSet::new();
240
241    for (node_idx, node) in graph.nodes.iter().enumerate() {
242        if visited.contains(&node_idx) {
243            continue;
244        }
245
246        // Check if all inputs are constants
247        let all_constant = node.inputs.iter().all(|&idx| constants.is_constant(idx));
248
249        if all_constant && !node.inputs.is_empty() {
250            // Find connected constant subgraph
251            let mut subgraph = vec![node_idx];
252            visited.insert(node_idx);
253
254            // Expand to include dependent constant operations
255            let mut changed = true;
256            while changed {
257                changed = false;
258                for (idx, n) in graph.nodes.iter().enumerate() {
259                    if visited.contains(&idx) {
260                        continue;
261                    }
262
263                    let depends_on_subgraph = n.inputs.iter().any(|&input_idx| {
264                        graph.nodes.iter().enumerate().any(|(sub_idx, sub_node)| {
265                            subgraph.contains(&sub_idx) && sub_node.outputs.contains(&input_idx)
266                        })
267                    });
268
269                    if depends_on_subgraph {
270                        subgraph.push(idx);
271                        visited.insert(idx);
272                        changed = true;
273                    }
274                }
275            }
276
277            if !subgraph.is_empty() {
278                subgraphs.push(subgraph);
279            }
280        }
281    }
282
283    Ok(subgraphs)
284}
285
286// Helper functions
287
288fn is_compile_time_constant(metadata: &crate::Metadata) -> bool {
289    metadata
290        .get_attribute("constant")
291        .map(|v| v == "true")
292        .unwrap_or(false)
293}
294
295fn is_identity_value(metadata: &crate::Metadata) -> bool {
296    metadata
297        .get_attribute("identity")
298        .map(|v| v == "true")
299        .unwrap_or(false)
300}
301
302fn is_zero_value(metadata: &crate::Metadata) -> bool {
303    metadata
304        .get_attribute("zero")
305        .map(|v| v == "true")
306        .unwrap_or(false)
307}
308
309fn try_simplify_operation(
310    node: &EinsumNode,
311    constants: &ConstantPropagationResult,
312) -> Option<usize> {
313    if let OpType::ElemBinary { op } = &node.op {
314        if node.inputs.len() == 2 {
315            let left = node.inputs[0];
316            let right = node.inputs[1];
317
318            // Simplify x + 0 → x or 0 + x → x
319            if op == "add" {
320                if constants.get_info(right).is_some_and(|info| info.is_zero) {
321                    return Some(left);
322                }
323                if constants.get_info(left).is_some_and(|info| info.is_zero) {
324                    return Some(right);
325                }
326            }
327
328            // Simplify x * 1 → x or 1 * x → x
329            if op == "mul" {
330                if constants
331                    .get_info(right)
332                    .is_some_and(|info| info.is_identity)
333                {
334                    return Some(left);
335                }
336                if constants
337                    .get_info(left)
338                    .is_some_and(|info| info.is_identity)
339                {
340                    return Some(right);
341                }
342            }
343        }
344    }
345
346    None
347}
348
349fn try_eliminate_operation(node: &EinsumNode, constants: &ConstantPropagationResult) -> bool {
350    if let OpType::ElemBinary { op } = &node.op {
351        if node.inputs.len() == 2 {
352            let left = node.inputs[0];
353            let right = node.inputs[1];
354
355            // Eliminate x * 0 or 0 * x (result is always 0)
356            if op == "mul" {
357                return constants.get_info(left).is_some_and(|info| info.is_zero)
358                    || constants.get_info(right).is_some_and(|info| info.is_zero);
359            }
360        }
361    }
362
363    false
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Metadata;

    // Metadata fixtures. The analysis reads the string attributes "constant",
    // "zero" and "identity", treating each as set only when it equals "true".

    /// Tensor tagged only as a compile-time constant.
    fn create_constant_metadata() -> Metadata {
        Metadata::new().with_attribute("constant", "true")
    }

    /// Constant tensor that is additionally tagged as zero.
    fn create_zero_metadata() -> Metadata {
        Metadata::new()
            .with_attribute("constant", "true")
            .with_attribute("zero", "true")
    }

    /// Constant tensor that is additionally tagged as an identity value.
    fn create_identity_metadata() -> Metadata {
        Metadata::new()
            .with_attribute("constant", "true")
            .with_attribute("identity", "true")
    }

    // Plain field round-trip for the data struct.
    #[test]
    fn test_constant_info() {
        let info = ConstantInfo {
            tensor_idx: 0,
            is_compile_time_constant: true,
            is_identity: false,
            is_zero: false,
        };

        assert_eq!(info.tensor_idx, 0);
        assert!(info.is_compile_time_constant);
        assert!(!info.is_identity);
        assert!(!info.is_zero);
    }

    // The empty result has no constants and a neutral speedup.
    #[test]
    fn test_constant_propagation_result_none() {
        let result = ConstantPropagationResult::none();
        assert!(result.constant_tensors.is_empty());
        assert!(result.constant_info.is_empty());
        assert_eq!(result.foldable_operations, 0);
        assert_eq!(result.estimated_speedup, 1.0);
    }

    // Empty statistics sum to zero transformations.
    #[test]
    fn test_folding_stats_none() {
        let stats = FoldingStats::none();
        assert_eq!(stats.operations_folded, 0);
        assert_eq!(stats.operations_simplified, 0);
        assert_eq!(stats.operations_eliminated, 0);
        assert_eq!(stats.total_transformations(), 0);
    }

    // A graph with no tensors yields no constants.
    #[test]
    fn test_analyze_constants_empty_graph() {
        let graph = EinsumGraph::new();
        let result = analyze_constants(&graph).unwrap();
        assert!(result.constant_tensors.is_empty());
    }

    // Constness seeded from metadata flows through a unary node.
    #[test]
    fn test_analyze_constants_with_metadata() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor("B");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let result = analyze_constants(&graph).unwrap();
        assert!(result.is_constant(a));
        assert!(result.is_constant(b)); // Propagated from a
        assert_eq!(result.foldable_operations, 1);
    }

    // x + 0 simplifies to x (constant info is built by hand, not via analysis).
    #[test]
    fn test_simplify_add_zero() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("x");
        let zero = graph.add_tensor_with_metadata("zero", create_zero_metadata());
        let result = graph.add_tensor("result");

        let node = EinsumNode::elem_binary("add", x, zero, result);

        let mut const_result = ConstantPropagationResult::none();
        const_result.constant_tensors.insert(zero);
        const_result.constant_info.insert(
            zero,
            ConstantInfo {
                tensor_idx: zero,
                is_compile_time_constant: true,
                is_identity: false,
                is_zero: true,
            },
        );

        let simplified = try_simplify_operation(&node, &const_result);
        assert_eq!(simplified, Some(x));
    }

    // x * 1 simplifies to x.
    #[test]
    fn test_simplify_mul_one() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("x");
        let one = graph.add_tensor_with_metadata("one", create_identity_metadata());
        let result = graph.add_tensor("result");

        let node = EinsumNode::elem_binary("mul", x, one, result);

        let mut const_result = ConstantPropagationResult::none();
        const_result.constant_tensors.insert(one);
        const_result.constant_info.insert(
            one,
            ConstantInfo {
                tensor_idx: one,
                is_compile_time_constant: true,
                is_identity: true,
                is_zero: false,
            },
        );

        let simplified = try_simplify_operation(&node, &const_result);
        assert_eq!(simplified, Some(x));
    }

    // x * 0 is reported as eliminable.
    #[test]
    fn test_eliminate_mul_zero() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("x");
        let zero = graph.add_tensor_with_metadata("zero", create_zero_metadata());
        let result = graph.add_tensor("result");

        let node = EinsumNode::elem_binary("mul", x, zero, result);

        let mut const_result = ConstantPropagationResult::none();
        const_result.constant_tensors.insert(zero);
        const_result.constant_info.insert(
            zero,
            ConstantInfo {
                tensor_idx: zero,
                is_compile_time_constant: true,
                is_identity: false,
                is_zero: true,
            },
        );

        let should_eliminate = try_eliminate_operation(&node, &const_result);
        assert!(should_eliminate);
    }

    // Adding two constants counts as at least one transformation.
    #[test]
    fn test_apply_constant_folding() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor_with_metadata("B", create_constant_metadata());
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_binary("add", a, b, c))
            .unwrap();

        let constants = analyze_constants(&graph).unwrap();
        let stats = apply_constant_folding(&mut graph, &constants).unwrap();

        assert!(stats.operations_folded > 0 || stats.total_transformations() > 0);
    }

    // Multi-pass folding over a two-node constant chain folds at least one node.
    #[test]
    fn test_fold_constants_aggressive() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor_with_metadata("B", create_constant_metadata());
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");

        graph
            .add_node(EinsumNode::elem_binary("add", a, b, c))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("relu", c, d))
            .unwrap();

        let stats = fold_constants_aggressive(&mut graph).unwrap();
        assert!(stats.operations_folded >= 1);
    }

    // A node with only constant inputs forms a constant subgraph.
    #[test]
    fn test_identify_constant_subgraphs() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor_with_metadata("B", create_constant_metadata());
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::elem_binary("add", a, b, c))
            .unwrap();

        let subgraphs = identify_constant_subgraphs(&graph).unwrap();
        assert!(!subgraphs.is_empty());
    }

    // The metadata predicates read the expected attributes.
    #[test]
    fn test_is_constant_metadata_helpers() {
        let const_metadata = create_constant_metadata();
        assert!(is_compile_time_constant(&const_metadata));

        let zero_metadata = create_zero_metadata();
        assert!(is_compile_time_constant(&zero_metadata));
        assert!(is_zero_value(&zero_metadata));

        let identity_metadata = create_identity_metadata();
        assert!(is_compile_time_constant(&identity_metadata));
        assert!(is_identity_value(&identity_metadata));
    }

    // Constness propagates transitively along a chain of unary nodes.
    #[test]
    fn test_constant_propagation_through_chain() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");
        let d = graph.add_tensor("D");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("relu", b, c))
            .unwrap();
        graph
            .add_node(EinsumNode::elem_unary("relu", c, d))
            .unwrap();

        let result = analyze_constants(&graph).unwrap();

        assert!(result.is_constant(a));
        assert!(result.is_constant(b));
        assert!(result.is_constant(c));
        assert!(result.is_constant(d));
        assert_eq!(result.foldable_operations, 3);
    }

    // A single variable input keeps the node's output non-constant.
    #[test]
    fn test_mixed_constant_and_variable_graph() {
        let mut graph = EinsumGraph::new();
        let const_a = graph.add_tensor_with_metadata("const_A", create_constant_metadata());
        let var_x = graph.add_tensor("var_X");
        let result = graph.add_tensor("result");

        graph
            .add_node(EinsumNode::elem_binary("add", const_a, var_x, result))
            .unwrap();

        let analysis = analyze_constants(&graph).unwrap();

        assert!(analysis.is_constant(const_a));
        assert!(!analysis.is_constant(var_x));
        assert!(!analysis.is_constant(result)); // Result is not constant (depends on variable)
    }

    // total_transformations sums all three counters.
    #[test]
    fn test_folding_stats_total_transformations() {
        let stats = FoldingStats {
            operations_folded: 2,
            operations_simplified: 3,
            operations_eliminated: 1,
            estimated_speedup: 1.5,
        };

        assert_eq!(stats.total_transformations(), 6);
    }

    // Any foldable operation raises the speedup estimate above 1.0.
    #[test]
    fn test_speedup_estimation() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor_with_metadata("A", create_constant_metadata());
        let b = graph.add_tensor("B");

        graph
            .add_node(EinsumNode::elem_unary("relu", a, b))
            .unwrap();

        let result = analyze_constants(&graph).unwrap();
        assert!(result.estimated_speedup > 1.0);
    }
}