1use std::collections::{HashMap, HashSet};
39use tensorlogic_ir::{EinsumGraph, TLExpr, Term};
40
/// Results of the expression-level dataflow analyses.
///
/// Each table is filled in by one of the passes run by
/// `analyze_dataflow_with_config`; a table whose pass did not run stays empty
/// (except `reaching_defs`, which the use-def pass also fills).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DataflowAnalysis {
    /// Live (free) variables per expression node. Keys produced by
    /// `compute_live_variables` are the node's address rendered with `{:?}`,
    /// so they are only meaningful while the analysed tree is alive and unmoved.
    pub live_variables: HashMap<String, HashSet<String>>,
    /// Variable name -> ids of the bindings that introduce it
    /// (`let_<var>` for `Let`, `quant_<var>` for quantifiers).
    pub reaching_defs: HashMap<String, HashSet<String>>,
    /// `Debug` renderings of the sub-expressions considered available.
    pub available_exprs: HashSet<String>,
    /// Variable name -> definition ids that reach its uses.
    pub use_def_chains: HashMap<String, Vec<String>>,
    /// Definition id -> names of the variables that use it.
    pub def_use_chains: HashMap<String, Vec<String>>,
}
55
56impl DataflowAnalysis {
57 pub fn new() -> Self {
59 Self {
60 live_variables: HashMap::new(),
61 reaching_defs: HashMap::new(),
62 available_exprs: HashSet::new(),
63 use_def_chains: HashMap::new(),
64 def_use_chains: HashMap::new(),
65 }
66 }
67
68 pub fn is_live(&self, expr_id: &str, var: &str) -> bool {
70 self.live_variables
71 .get(expr_id)
72 .map(|vars| vars.contains(var))
73 .unwrap_or(false)
74 }
75
76 pub fn get_live_vars(&self, expr_id: &str) -> HashSet<String> {
78 self.live_variables
79 .get(expr_id)
80 .cloned()
81 .unwrap_or_default()
82 }
83
84 pub fn get_reaching_defs(&self, var: &str) -> HashSet<String> {
86 self.reaching_defs.get(var).cloned().unwrap_or_default()
87 }
88
89 pub fn is_available(&self, expr: &str) -> bool {
91 self.available_exprs.contains(expr)
92 }
93
94 pub fn get_use_def_chain(&self, var: &str) -> Vec<String> {
96 self.use_def_chains.get(var).cloned().unwrap_or_default()
97 }
98
99 pub fn get_def_use_chain(&self, var: &str) -> Vec<String> {
101 self.def_use_chains.get(var).cloned().unwrap_or_default()
102 }
103}
104
105impl Default for DataflowAnalysis {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
/// Selects which passes `analyze_dataflow_with_config` runs.
///
/// The type is only a set of flags, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DataflowConfig {
    /// Compute per-node live (free) variables.
    pub compute_live_vars: bool,
    /// Compute reaching definitions for `Let` and quantifier bindings.
    pub compute_reaching_defs: bool,
    /// Compute the set of available sub-expressions.
    pub compute_available_exprs: bool,
    /// Compute use-def and def-use chains. This pass also fills
    /// `reaching_defs`, even when `compute_reaching_defs` is `false`.
    pub compute_use_def_chains: bool,
}
123
124impl Default for DataflowConfig {
125 fn default() -> Self {
126 Self {
127 compute_live_vars: true,
128 compute_reaching_defs: true,
129 compute_available_exprs: true,
130 compute_use_def_chains: true,
131 }
132 }
133}
134
135pub fn analyze_dataflow(expr: &TLExpr) -> DataflowAnalysis {
137 analyze_dataflow_with_config(expr, &DataflowConfig::default())
138}
139
140pub fn analyze_dataflow_with_config(expr: &TLExpr, config: &DataflowConfig) -> DataflowAnalysis {
142 let mut analysis = DataflowAnalysis::new();
143
144 if config.compute_live_vars {
145 compute_live_variables(expr, &mut analysis);
146 }
147
148 if config.compute_reaching_defs {
149 compute_reaching_definitions(expr, &mut analysis);
150 }
151
152 if config.compute_available_exprs {
153 compute_available_expressions(expr, &mut analysis);
154 }
155
156 if config.compute_use_def_chains {
157 compute_use_def_chains(expr, &mut analysis);
158 }
159
160 analysis
161}
162
163fn compute_live_variables(expr: &TLExpr, analysis: &mut DataflowAnalysis) {
167 let expr_id = format!("{:?}", expr as *const _);
168 let mut live = HashSet::new();
169
170 match expr {
172 TLExpr::Pred { args, .. } => {
173 for arg in args {
174 if let Term::Var(v) = arg {
175 live.insert(v.clone());
176 }
177 }
178 }
179 TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
180 compute_live_variables(lhs, analysis);
182 compute_live_variables(rhs, analysis);
183
184 let lhs_live = analysis.get_live_vars(&format!("{:?}", lhs.as_ref() as *const _));
185 let rhs_live = analysis.get_live_vars(&format!("{:?}", rhs.as_ref() as *const _));
186 live.extend(lhs_live);
187 live.extend(rhs_live);
188 }
189 TLExpr::Not(inner) => {
190 compute_live_variables(inner, analysis);
191 let inner_live = analysis.get_live_vars(&format!("{:?}", inner.as_ref() as *const _));
192 live.extend(inner_live);
193 }
194 TLExpr::Exists { var, body, .. } | TLExpr::ForAll { var, body, .. } => {
195 compute_live_variables(body, analysis);
196 let mut body_live = analysis.get_live_vars(&format!("{:?}", body.as_ref() as *const _));
197
198 body_live.remove(var);
200 live.extend(body_live);
201 }
202 TLExpr::Let { var, value, body } => {
203 compute_live_variables(value, analysis);
204 compute_live_variables(body, analysis);
205
206 let mut body_live = analysis.get_live_vars(&format!("{:?}", body.as_ref() as *const _));
207 let value_live = analysis.get_live_vars(&format!("{:?}", value.as_ref() as *const _));
208
209 body_live.remove(var);
211 live.extend(body_live);
212 live.extend(value_live);
213 }
214 _ => {
215 live.extend(expr.free_vars());
217 }
218 }
219
220 analysis.live_variables.insert(expr_id, live);
221}
222
223fn compute_reaching_definitions(expr: &TLExpr, analysis: &mut DataflowAnalysis) {
228 match expr {
229 TLExpr::Let { var, value, body } => {
230 let def_id = format!("let_{}", var);
232 analysis
233 .reaching_defs
234 .entry(var.clone())
235 .or_default()
236 .insert(def_id);
237
238 compute_reaching_definitions(value, analysis);
239 compute_reaching_definitions(body, analysis);
240 }
241 TLExpr::Exists { var, body, .. } | TLExpr::ForAll { var, body, .. } => {
242 let def_id = format!("quant_{}", var);
244 analysis
245 .reaching_defs
246 .entry(var.clone())
247 .or_default()
248 .insert(def_id);
249
250 compute_reaching_definitions(body, analysis);
251 }
252 TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
253 compute_reaching_definitions(lhs, analysis);
254 compute_reaching_definitions(rhs, analysis);
255 }
256 TLExpr::Not(inner) => {
257 compute_reaching_definitions(inner, analysis);
258 }
259 _ => {
260 }
262 }
263}
264
265fn compute_available_expressions(expr: &TLExpr, analysis: &mut DataflowAnalysis) {
270 let expr_str = format!("{:?}", expr);
271
272 match expr {
273 TLExpr::Pred { .. } | TLExpr::Constant(_) => {
274 analysis.available_exprs.insert(expr_str);
276 }
277 TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
278 compute_available_expressions(lhs, analysis);
279 compute_available_expressions(rhs, analysis);
280
281 analysis.available_exprs.insert(expr_str);
283 }
284 TLExpr::Not(inner) => {
285 compute_available_expressions(inner, analysis);
286 analysis.available_exprs.insert(expr_str);
287 }
288 TLExpr::Let { value, body, .. } => {
289 compute_available_expressions(value, analysis);
290 compute_available_expressions(body, analysis);
291 }
292 _ => {
293 analysis.available_exprs.insert(expr_str);
295 }
296 }
297}
298
299fn compute_use_def_chains(expr: &TLExpr, analysis: &mut DataflowAnalysis) {
303 compute_reaching_definitions(expr, analysis);
305
306 collect_uses(expr, analysis);
308}
309
310fn collect_uses(expr: &TLExpr, analysis: &mut DataflowAnalysis) {
312 match expr {
313 TLExpr::Pred { args, .. } => {
314 for arg in args {
315 if let Term::Var(v) = arg {
316 let defs = analysis.get_reaching_defs(v);
318 analysis
319 .use_def_chains
320 .entry(v.clone())
321 .or_default()
322 .extend(defs.iter().cloned());
323
324 for def in defs {
326 analysis
327 .def_use_chains
328 .entry(def)
329 .or_default()
330 .push(v.clone());
331 }
332 }
333 }
334 }
335 TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
336 collect_uses(lhs, analysis);
337 collect_uses(rhs, analysis);
338 }
339 TLExpr::Not(inner) => {
340 collect_uses(inner, analysis);
341 }
342 TLExpr::Let { value, body, .. } => {
343 collect_uses(value, analysis);
344 collect_uses(body, analysis);
345 }
346 TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
347 collect_uses(body, analysis);
348 }
349 _ => {}
350 }
351}
352
/// Tensor-level dataflow facts for an `EinsumGraph`, keyed by tensor and
/// node indices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GraphDataflow {
    /// Node index -> outputs of that node that are needed to compute a graph output.
    pub live_tensors: HashMap<usize, HashSet<usize>>,
    /// Tensor index -> input tensors of the node that produces it.
    pub dependencies: HashMap<usize, HashSet<usize>>,
    /// Tensor index -> indices of the nodes that read it.
    pub uses: HashMap<usize, HashSet<usize>>,
}
363
364impl GraphDataflow {
365 pub fn new() -> Self {
367 Self {
368 live_tensors: HashMap::new(),
369 dependencies: HashMap::new(),
370 uses: HashMap::new(),
371 }
372 }
373
374 pub fn is_tensor_live(&self, node: usize, tensor: usize) -> bool {
376 self.live_tensors
377 .get(&node)
378 .map(|tensors| tensors.contains(&tensor))
379 .unwrap_or(false)
380 }
381
382 pub fn get_dependencies(&self, tensor: usize) -> HashSet<usize> {
384 self.dependencies.get(&tensor).cloned().unwrap_or_default()
385 }
386
387 pub fn get_uses(&self, tensor: usize) -> HashSet<usize> {
389 self.uses.get(&tensor).cloned().unwrap_or_default()
390 }
391}
392
393impl Default for GraphDataflow {
394 fn default() -> Self {
395 Self::new()
396 }
397}
398
399pub fn analyze_graph_dataflow(graph: &EinsumGraph) -> GraphDataflow {
401 let mut analysis = GraphDataflow::new();
402
403 for (node_idx, node) in graph.nodes.iter().enumerate() {
405 for &output in &node.outputs {
406 let mut deps = HashSet::new();
407 deps.extend(&node.inputs);
408
409 analysis.dependencies.insert(output, deps);
410
411 for &input in &node.inputs {
413 analysis.uses.entry(input).or_default().insert(node_idx);
414 }
415 }
416 }
417
418 let mut live: HashSet<usize> = HashSet::new();
420 live.extend(&graph.outputs);
421
422 for (node_idx, node) in graph.nodes.iter().enumerate().rev() {
423 let node_live: HashSet<usize> = node
425 .outputs
426 .iter()
427 .filter(|&&t| live.contains(&t) || graph.outputs.contains(&t))
428 .copied()
429 .collect();
430
431 if !node_live.is_empty() {
432 live.extend(&node.inputs);
434 }
435
436 analysis.live_tensors.insert(node_idx, node_live);
437 }
438
439 analysis
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Key under which `compute_live_variables` stores a root expression.
    fn root_key(expr: &TLExpr) -> String {
        format!("{:?}", expr as *const TLExpr)
    }

    /// Builds an owned string set from literals.
    fn set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_live_variables_simple() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let analysis = analyze_dataflow(&expr);

        assert!(!analysis.live_variables.is_empty());
        assert_eq!(analysis.get_live_vars(&root_key(&expr)), set(&["x"]));
        assert!(analysis.is_live(&root_key(&expr), "x"));
    }

    #[test]
    fn test_live_variables_and() {
        let expr = TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::pred("Q", vec![Term::var("y")]),
        );

        let analysis = analyze_dataflow(&expr);

        assert!(!analysis.live_variables.is_empty());
        // The conjunction's live set is the union of both operands.
        assert_eq!(analysis.get_live_vars(&root_key(&expr)), set(&["x", "y"]));
    }

    #[test]
    fn test_reaching_definitions_let() {
        let expr = TLExpr::Let {
            var: "x".to_string(),
            value: Box::new(TLExpr::Constant(1.0)),
            body: Box::new(TLExpr::pred("P", vec![Term::var("x")])),
        };

        let analysis = analyze_dataflow(&expr);

        assert!(analysis.reaching_defs.contains_key("x"));
        assert!(analysis.get_reaching_defs("x").contains("let_x"));
    }

    #[test]
    fn test_quantifier_binding() {
        let expr = TLExpr::exists("x", "Domain", TLExpr::pred("P", vec![Term::var("x")]));

        let analysis = analyze_dataflow(&expr);

        // The quantifier binds `x`, so nothing is free at the root.
        let live = analysis.get_live_vars(&root_key(&expr));
        assert!(!live.contains("x"));
        assert!(live.is_empty());
        assert!(analysis.get_reaching_defs("x").contains("quant_x"));
    }

    #[test]
    fn test_available_expressions() {
        let expr = TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::pred("Q", vec![Term::var("x")]),
        );

        let analysis = analyze_dataflow(&expr);

        assert!(!analysis.available_exprs.is_empty());
        // Both predicates and the conjunction itself are recorded.
        assert!(analysis.is_available(&format!("{:?}", expr)));
        assert_eq!(analysis.available_exprs.len(), 3);
    }

    #[test]
    fn test_graph_dataflow() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("t0");
        let t1 = graph.add_tensor("t1");

        let node = graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .unwrap();
        graph.add_output(t1).unwrap();

        let analysis = analyze_graph_dataflow(&graph);

        let deps = analysis.get_dependencies(t1);
        assert!(deps.contains(&t0));

        assert!(analysis.is_tensor_live(node, t1));
        assert!(analysis.get_uses(t0).contains(&node));
        assert!(analysis.get_uses(t1).is_empty());
    }

    #[test]
    fn test_dataflow_config() {
        let config = DataflowConfig {
            compute_live_vars: true,
            compute_reaching_defs: false,
            compute_available_exprs: false,
            compute_use_def_chains: false,
        };

        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let analysis = analyze_dataflow_with_config(&expr, &config);

        assert!(!analysis.live_variables.is_empty());
        // Disabled passes leave their tables untouched.
        assert!(analysis.reaching_defs.is_empty());
        assert!(analysis.available_exprs.is_empty());
        assert!(analysis.use_def_chains.is_empty());
        assert!(analysis.def_use_chains.is_empty());
    }

    #[test]
    fn test_use_def_chains() {
        let expr = TLExpr::Let {
            var: "x".to_string(),
            value: Box::new(TLExpr::Constant(1.0)),
            body: Box::new(TLExpr::pred("P", vec![Term::var("x")])),
        };

        let analysis = analyze_dataflow(&expr);

        assert!(analysis.get_use_def_chain("x").iter().any(|d| d == "let_x"));
        assert!(analysis.get_def_use_chain("let_x").iter().any(|u| u == "x"));
    }

    #[test]
    fn test_graph_dependencies() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("t0");
        let t1 = graph.add_tensor("t1");
        let t2 = graph.add_tensor("t2");

        graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("exp", t0, t1))
            .unwrap();
        graph
            .add_node(tensorlogic_ir::EinsumNode::elem_unary("log", t1, t2))
            .unwrap();

        let analysis = analyze_graph_dataflow(&graph);

        assert!(analysis.get_dependencies(t1).contains(&t0));
        assert!(analysis.get_dependencies(t2).contains(&t1));
        // `t0` is not produced by any node.
        assert!(analysis.get_dependencies(t0).is_empty());
    }

    #[test]
    fn test_dataflow_analysis_default() {
        let analysis = DataflowAnalysis::new();
        assert!(analysis.live_variables.is_empty());
        assert!(analysis.reaching_defs.is_empty());
        assert!(analysis.available_exprs.is_empty());
        assert!(analysis.use_def_chains.is_empty());
        assert!(analysis.def_use_chains.is_empty());
    }
}