Skip to main content

tensorlogic_compiler/passes/
dataflow.rs

1//! Dataflow analysis for logical expressions and einsum graphs.
2//!
3//! This module provides dataflow analysis passes that track how data flows
4//! through expressions and computation graphs. These analyses enable powerful
5//! optimizations and help identify opportunities for parallelization.
6//!
7//! # Overview
8//!
9//! Dataflow analysis is a fundamental compiler technique that tracks:
10//! - **Reaching definitions**: Which variable assignments reach each point
11//! - **Live variables**: Which variables are used after each point
12//! - **Available expressions**: Which expressions have been computed
13//! - **Use-def chains**: Relationships between variable uses and definitions
14//!
15//! # Applications
16//!
17//! - Dead code elimination
18//! - Common subexpression elimination
19//! - Register allocation
20//! - Constant propagation
21//! - Loop optimization
22//!
23//! # Examples
24//!
25//! ```rust
26//! use tensorlogic_compiler::passes::analyze_dataflow;
27//! use tensorlogic_ir::{TLExpr, Term};
28//!
29//! let expr = TLExpr::and(
30//!     TLExpr::pred("P", vec![Term::var("x")]),
31//!     TLExpr::pred("Q", vec![Term::var("x")]),
32//! );
33//!
34//! let analysis = analyze_dataflow(&expr);
35//! println!("Live variables: {:?}", analysis.live_variables);
36//! ```
37
38use std::collections::{HashMap, HashSet};
39use tensorlogic_ir::{EinsumGraph, TLExpr, Term};
40
/// Result of dataflow analysis on an expression.
///
/// Every table is keyed by plain strings so results can be inspected without
/// holding on to the analysed expression. Lookup helpers return owned copies
/// and treat unknown keys as "nothing recorded".
#[derive(Debug, Clone)]
pub struct DataflowAnalysis {
    /// Variables that are live (may be used later) at each expression, keyed by
    /// an expression identifier (the passes in this module use the node's
    /// address formatted with `{:?}`).
    pub live_variables: HashMap<String, HashSet<String>>,
    /// Definition ids (e.g. `let_x`, `quant_x`) recorded for each variable name.
    pub reaching_defs: HashMap<String, HashSet<String>>,
    /// `Debug` renderings of the expressions recorded as available.
    pub available_exprs: HashSet<String>,
    /// Variable name -> definition ids its uses are linked to.
    pub use_def_chains: HashMap<String, Vec<String>>,
    /// Definition id -> variable names, one entry per linked use.
    pub def_use_chains: HashMap<String, Vec<String>>,
}

impl DataflowAnalysis {
    /// Creates an analysis result in which every table is empty.
    pub fn new() -> Self {
        DataflowAnalysis {
            live_variables: HashMap::default(),
            reaching_defs: HashMap::default(),
            available_exprs: HashSet::default(),
            use_def_chains: HashMap::default(),
            def_use_chains: HashMap::default(),
        }
    }

    /// Returns `true` when `var` is recorded as live at expression `expr_id`.
    ///
    /// An unknown `expr_id` is treated as having no live variables.
    pub fn is_live(&self, expr_id: &str, var: &str) -> bool {
        matches!(self.live_variables.get(expr_id), Some(vars) if vars.contains(var))
    }

    /// Returns a copy of the live-variable set for `expr_id` (empty if unknown).
    pub fn get_live_vars(&self, expr_id: &str) -> HashSet<String> {
        match self.live_variables.get(expr_id) {
            Some(vars) => vars.clone(),
            None => HashSet::new(),
        }
    }

    /// Returns a copy of the definition ids recorded for `var` (empty if none).
    pub fn get_reaching_defs(&self, var: &str) -> HashSet<String> {
        if let Some(defs) = self.reaching_defs.get(var) {
            defs.clone()
        } else {
            HashSet::new()
        }
    }

    /// Returns `true` if the rendered expression `expr` is recorded as available.
    pub fn is_available(&self, expr: &str) -> bool {
        let recorded = &self.available_exprs;
        recorded.contains(expr)
    }

    /// Returns a copy of the use-def chain for `var` (empty if none).
    pub fn get_use_def_chain(&self, var: &str) -> Vec<String> {
        self.use_def_chains
            .get(var)
            .map(|defs| defs.to_vec())
            .unwrap_or_default()
    }

    /// Returns a copy of the def-use chain for definition id `var` (empty if none).
    pub fn get_def_use_chain(&self, var: &str) -> Vec<String> {
        match self.def_use_chains.get(var) {
            Some(uses) => uses.clone(),
            None => Vec::new(),
        }
    }
}

impl Default for DataflowAnalysis {
    /// Same as [`DataflowAnalysis::new`]: all tables empty.
    fn default() -> Self {
        DataflowAnalysis::new()
    }
}
110
/// Configuration for dataflow analysis: one switch per pass.
///
/// Every switch defaults to `true`. Note that the use-def chain pass also
/// populates `reaching_defs` itself, so enabling `compute_use_def_chains`
/// fills that table even when `compute_reaching_defs` is `false`.
#[derive(Debug, Clone)]
pub struct DataflowConfig {
    /// Run the live-variable pass.
    pub compute_live_vars: bool,
    /// Run the reaching-definitions pass.
    pub compute_reaching_defs: bool,
    /// Run the available-expressions pass.
    pub compute_available_exprs: bool,
    /// Run the use-def / def-use chain pass.
    pub compute_use_def_chains: bool,
}

impl Default for DataflowConfig {
    /// Enables every pass.
    fn default() -> Self {
        let enabled = true;
        DataflowConfig {
            compute_live_vars: enabled,
            compute_reaching_defs: enabled,
            compute_available_exprs: enabled,
            compute_use_def_chains: enabled,
        }
    }
}
134
135/// Perform dataflow analysis on a logical expression.
136pub fn analyze_dataflow(expr: &TLExpr) -> DataflowAnalysis {
137    analyze_dataflow_with_config(expr, &DataflowConfig::default())
138}
139
140/// Perform dataflow analysis with custom configuration.
141pub fn analyze_dataflow_with_config(expr: &TLExpr, config: &DataflowConfig) -> DataflowAnalysis {
142    let mut analysis = DataflowAnalysis::new();
143
144    if config.compute_live_vars {
145        compute_live_variables(expr, &mut analysis);
146    }
147
148    if config.compute_reaching_defs {
149        compute_reaching_definitions(expr, &mut analysis);
150    }
151
152    if config.compute_available_exprs {
153        compute_available_expressions(expr, &mut analysis);
154    }
155
156    if config.compute_use_def_chains {
157        compute_use_def_chains(expr, &mut analysis);
158    }
159
160    analysis
161}
162
163/// Compute live variable analysis.
164///
165/// A variable is live at a point if it may be used after that point.
166fn compute_live_variables(expr: &TLExpr, analysis: &mut DataflowAnalysis) {
167    let expr_id = format!("{:?}", expr as *const _);
168    let mut live = HashSet::new();
169
170    // Collect variables used in this expression
171    match expr {
172        TLExpr::Pred { args, .. } => {
173            for arg in args {
174                if let Term::Var(v) = arg {
175                    live.insert(v.clone());
176                }
177            }
178        }
179        TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
180            // Union of live variables from both branches
181            compute_live_variables(lhs, analysis);
182            compute_live_variables(rhs, analysis);
183
184            let lhs_live = analysis.get_live_vars(&format!("{:?}", lhs.as_ref() as *const _));
185            let rhs_live = analysis.get_live_vars(&format!("{:?}", rhs.as_ref() as *const _));
186            live.extend(lhs_live);
187            live.extend(rhs_live);
188        }
189        TLExpr::Not(inner) => {
190            compute_live_variables(inner, analysis);
191            let inner_live = analysis.get_live_vars(&format!("{:?}", inner.as_ref() as *const _));
192            live.extend(inner_live);
193        }
194        TLExpr::Exists { var, body, .. } | TLExpr::ForAll { var, body, .. } => {
195            compute_live_variables(body, analysis);
196            let mut body_live = analysis.get_live_vars(&format!("{:?}", body.as_ref() as *const _));
197
198            // Remove the bound variable
199            body_live.remove(var);
200            live.extend(body_live);
201        }
202        TLExpr::Let { var, value, body } => {
203            compute_live_variables(value, analysis);
204            compute_live_variables(body, analysis);
205
206            let mut body_live = analysis.get_live_vars(&format!("{:?}", body.as_ref() as *const _));
207            let value_live = analysis.get_live_vars(&format!("{:?}", value.as_ref() as *const _));
208
209            // Variable is defined here, remove from live set
210            body_live.remove(var);
211            live.extend(body_live);
212            live.extend(value_live);
213        }
214        _ => {
215            // For other expressions, just collect free variables
216            live.extend(expr.free_vars());
217        }
218    }
219
220    analysis.live_variables.insert(expr_id, live);
221}
222
223/// Compute reaching definitions analysis.
224///
225/// A definition reaches a point if it may be the most recent assignment
226/// to a variable at that point.
227fn compute_reaching_definitions(expr: &TLExpr, analysis: &mut DataflowAnalysis) {
228    match expr {
229        TLExpr::Let { var, value, body } => {
230            // This is a definition of 'var'
231            let def_id = format!("let_{}", var);
232            analysis
233                .reaching_defs
234                .entry(var.clone())
235                .or_default()
236                .insert(def_id);
237
238            compute_reaching_definitions(value, analysis);
239            compute_reaching_definitions(body, analysis);
240        }
241        TLExpr::Exists { var, body, .. } | TLExpr::ForAll { var, body, .. } => {
242            // Quantifier introduces a new scope for var
243            let def_id = format!("quant_{}", var);
244            analysis
245                .reaching_defs
246                .entry(var.clone())
247                .or_default()
248                .insert(def_id);
249
250            compute_reaching_definitions(body, analysis);
251        }
252        TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
253            compute_reaching_definitions(lhs, analysis);
254            compute_reaching_definitions(rhs, analysis);
255        }
256        TLExpr::Not(inner) => {
257            compute_reaching_definitions(inner, analysis);
258        }
259        _ => {
260            // Leaf expressions don't introduce definitions
261        }
262    }
263}
264
265/// Compute available expressions analysis.
266///
267/// An expression is available at a point if it has been computed and
268/// not invalidated since.
269fn compute_available_expressions(expr: &TLExpr, analysis: &mut DataflowAnalysis) {
270    let expr_str = format!("{:?}", expr);
271
272    match expr {
273        TLExpr::Pred { .. } | TLExpr::Constant(_) => {
274            // Atomic expressions are always available
275            analysis.available_exprs.insert(expr_str);
276        }
277        TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
278            compute_available_expressions(lhs, analysis);
279            compute_available_expressions(rhs, analysis);
280
281            // This expression is available if both operands are
282            analysis.available_exprs.insert(expr_str);
283        }
284        TLExpr::Not(inner) => {
285            compute_available_expressions(inner, analysis);
286            analysis.available_exprs.insert(expr_str);
287        }
288        TLExpr::Let { value, body, .. } => {
289            compute_available_expressions(value, analysis);
290            compute_available_expressions(body, analysis);
291        }
292        _ => {
293            // Other expressions may be available
294            analysis.available_exprs.insert(expr_str);
295        }
296    }
297}
298
299/// Compute use-def chains.
300///
301/// Maps each variable use to the definitions that may reach it.
302fn compute_use_def_chains(expr: &TLExpr, analysis: &mut DataflowAnalysis) {
303    // First compute reaching definitions
304    compute_reaching_definitions(expr, analysis);
305
306    // Then build use-def chains by linking uses to their reaching defs
307    collect_uses(expr, analysis);
308}
309
310/// Collect variable uses and link them to definitions.
311fn collect_uses(expr: &TLExpr, analysis: &mut DataflowAnalysis) {
312    match expr {
313        TLExpr::Pred { args, .. } => {
314            for arg in args {
315                if let Term::Var(v) = arg {
316                    // Link this use to its reaching definitions
317                    let defs = analysis.get_reaching_defs(v);
318                    analysis
319                        .use_def_chains
320                        .entry(v.clone())
321                        .or_default()
322                        .extend(defs.iter().cloned());
323
324                    // Also update def-use chains
325                    for def in defs {
326                        analysis
327                            .def_use_chains
328                            .entry(def)
329                            .or_default()
330                            .push(v.clone());
331                    }
332                }
333            }
334        }
335        TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
336            collect_uses(lhs, analysis);
337            collect_uses(rhs, analysis);
338        }
339        TLExpr::Not(inner) => {
340            collect_uses(inner, analysis);
341        }
342        TLExpr::Let { value, body, .. } => {
343            collect_uses(value, analysis);
344            collect_uses(body, analysis);
345        }
346        TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
347            collect_uses(body, analysis);
348        }
349        _ => {}
350    }
351}
352
/// Dataflow analysis for einsum graphs.
///
/// Node and tensor identifiers are plain `usize` indices. Lookup helpers return
/// owned copies and treat unknown indices as "nothing recorded".
#[derive(Debug, Clone)]
pub struct GraphDataflow {
    /// For each node index, the node's output tensors that are live.
    pub live_tensors: HashMap<usize, HashSet<usize>>,
    /// For each produced tensor, the input tensors of the node that produces it.
    pub dependencies: HashMap<usize, HashSet<usize>>,
    /// For each consumed tensor, the indices of the nodes that read it.
    pub uses: HashMap<usize, HashSet<usize>>,
}

impl GraphDataflow {
    /// Creates an analysis with no recorded facts.
    pub fn new() -> Self {
        GraphDataflow {
            live_tensors: HashMap::default(),
            dependencies: HashMap::default(),
            uses: HashMap::default(),
        }
    }

    /// Returns `true` if `tensor` is recorded as live at `node`.
    pub fn is_tensor_live(&self, node: usize, tensor: usize) -> bool {
        match self.live_tensors.get(&node) {
            Some(tensors) => tensors.contains(&tensor),
            None => false,
        }
    }

    /// Returns a copy of the tensors `tensor` depends on (empty if unknown).
    pub fn get_dependencies(&self, tensor: usize) -> HashSet<usize> {
        if let Some(deps) = self.dependencies.get(&tensor) {
            deps.clone()
        } else {
            HashSet::new()
        }
    }

    /// Returns a copy of the node indices that read `tensor` (empty if unknown).
    pub fn get_uses(&self, tensor: usize) -> HashSet<usize> {
        match self.uses.get(&tensor) {
            Some(readers) => readers.clone(),
            None => HashSet::new(),
        }
    }
}

impl Default for GraphDataflow {
    /// Same as [`GraphDataflow::new`].
    fn default() -> Self {
        GraphDataflow::new()
    }
}
398
399/// Analyze dataflow in an einsum graph.
400pub fn analyze_graph_dataflow(graph: &EinsumGraph) -> GraphDataflow {
401    let mut analysis = GraphDataflow::new();
402
403    // Compute dependencies
404    for (node_idx, node) in graph.nodes.iter().enumerate() {
405        for &output in &node.outputs {
406            let mut deps = HashSet::new();
407            deps.extend(&node.inputs);
408
409            analysis.dependencies.insert(output, deps);
410
411            // Update reverse dependencies (uses)
412            for &input in &node.inputs {
413                analysis.uses.entry(input).or_default().insert(node_idx);
414            }
415        }
416    }
417
418    // Compute live tensors (backward analysis)
419    let mut live: HashSet<usize> = HashSet::new();
420    live.extend(&graph.outputs);
421
422    for (node_idx, node) in graph.nodes.iter().enumerate().rev() {
423        // Tensors are live if they're used by later nodes or are outputs
424        let node_live: HashSet<usize> = node
425            .outputs
426            .iter()
427            .filter(|&&t| live.contains(&t) || graph.outputs.contains(&t))
428            .copied()
429            .collect();
430
431        if !node_live.is_empty() {
432            // Add inputs to live set
433            live.extend(&node.inputs);
434        }
435
436        analysis.live_tensors.insert(node_idx, node_live);
437    }
438
439    analysis
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Key under which `compute_live_variables` stores facts for `expr`.
    fn node_id(expr: &TLExpr) -> String {
        format!("{:?}", expr as *const TLExpr)
    }

    #[test]
    fn test_live_variables_simple() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let analysis = analyze_dataflow(&expr);

        // x is used by the predicate, so it is live at the predicate itself.
        assert!(analysis.is_live(&node_id(&expr), "x"));
    }

    #[test]
    fn test_live_variables_and() {
        let expr = TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::pred("Q", vec![Term::var("y")]),
        );

        let analysis = analyze_dataflow(&expr);

        // Both x and y should be live at the conjunction.
        let live = analysis.get_live_vars(&node_id(&expr));
        assert!(live.contains("x"));
        assert!(live.contains("y"));
    }

    #[test]
    fn test_reaching_definitions_let() {
        let expr = TLExpr::Let {
            var: "x".to_string(),
            value: Box::new(TLExpr::Constant(1.0)),
            body: Box::new(TLExpr::pred("P", vec![Term::var("x")])),
        };

        let analysis = analyze_dataflow(&expr);

        // The let binding is recorded as a definition of x.
        assert!(analysis.get_reaching_defs("x").contains("let_x"));
    }

    #[test]
    fn test_quantifier_binding() {
        let expr = TLExpr::exists("x", "Domain", TLExpr::pred("P", vec![Term::var("x")]));

        let analysis = analyze_dataflow(&expr);

        // x is bound by exists, so it must not be live at the outer expression...
        let live = analysis.get_live_vars(&node_id(&expr));
        assert!(!live.contains("x"));

        // ...but it is live inside the quantifier body.
        match &expr {
            TLExpr::Exists { body, .. } => assert!(analysis.is_live(&node_id(body), "x")),
            other => panic!("expected an Exists node, got {:?}", other),
        }
    }

    #[test]
    fn test_available_expressions() {
        let expr = TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::pred("Q", vec![Term::var("x")]),
        );

        let analysis = analyze_dataflow(&expr);

        // The whole expression is recorded as available.
        assert!(analysis.is_available(&format!("{:?}", expr)));
    }

    #[test]
    fn test_graph_dataflow() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("t0");
        let t1 = graph.add_tensor("t1");

        let node = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .unwrap();
        graph.add_output(t1).unwrap();

        let analysis = analyze_graph_dataflow(&graph);

        // t0 should be a dependency of t1
        let deps = analysis.get_dependencies(t1);
        assert!(deps.contains(&t0));

        // The node reads t0
        assert!(analysis.get_uses(t0).contains(&node));

        // t1 should be live
        assert!(analysis.is_tensor_live(node, t1));
    }

    #[test]
    fn test_dataflow_config() {
        let config = DataflowConfig {
            compute_live_vars: true,
            compute_reaching_defs: false,
            compute_available_exprs: false,
            compute_use_def_chains: false,
        };

        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let analysis = analyze_dataflow_with_config(&expr, &config);

        // Only live variables should be computed
        assert!(!analysis.live_variables.is_empty());
        assert!(analysis.reaching_defs.is_empty());
        assert!(analysis.available_exprs.is_empty());
        assert!(analysis.use_def_chains.is_empty());
        assert!(analysis.def_use_chains.is_empty());
    }

    #[test]
    fn test_use_def_chains() {
        let expr = TLExpr::Let {
            var: "x".to_string(),
            value: Box::new(TLExpr::Constant(1.0)),
            body: Box::new(TLExpr::pred("P", vec![Term::var("x")])),
        };

        let analysis = analyze_dataflow(&expr);

        // The use of x in the body links to the let binding, and vice versa.
        assert!(analysis
            .get_use_def_chain("x")
            .contains(&"let_x".to_string()));
        assert!(analysis
            .get_def_use_chain("let_x")
            .contains(&"x".to_string()));
    }

    #[test]
    fn test_graph_dependencies() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("t0");
        let t1 = graph.add_tensor("t1");
        let t2 = graph.add_tensor("t2");

        graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .unwrap();
        graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("log", t1, t2))
            .unwrap();

        let analysis = analyze_graph_dataflow(&graph);

        // t2 depends on t1, t1 depends on t0
        assert!(analysis.get_dependencies(t1).contains(&t0));
        assert!(analysis.get_dependencies(t2).contains(&t1));
    }

    #[test]
    fn test_dataflow_analysis_default() {
        let analysis = DataflowAnalysis::new();
        assert!(analysis.live_variables.is_empty());
        assert!(analysis.reaching_defs.is_empty());
        assert!(analysis.available_exprs.is_empty());
        assert!(analysis.use_def_chains.is_empty());
        assert!(analysis.def_use_chains.is_empty());

        let defaulted = DataflowAnalysis::default();
        assert!(defaulted.live_variables.is_empty());
    }
}