Skip to main content

tensorlogic_compiler/passes/
metadata_propagation.rs

1//! Metadata propagation for provenance tracking and debugging.
2//!
3//! This module provides utilities for attaching metadata to compiled tensor
4//! graphs, enabling better debugging, provenance tracking, and understanding
5//! of the compilation process.
6
7use std::collections::HashMap;
8use tensorlogic_ir::{EinsumGraph, EinsumNode, Metadata, TLExpr};
9
10use crate::CompilerContext;
11
/// Metadata builder for tracking compilation provenance.
///
/// Carries the optional source file and rule ID that the builder's
/// `*_metadata` methods copy onto the metadata they create, together with a
/// counter used to mint sequential rule IDs.
///
/// `Debug` and `Clone` are derived so a builder can be logged and duplicated;
/// `Default` is implemented by hand elsewhere in this file.
#[derive(Debug, Clone)]
pub struct MetadataBuilder {
    /// Current source file being compiled (`None` until one is set)
    source_file: Option<String>,
    /// Current rule ID being compiled (`None` until one is set)
    rule_id: Option<String>,
    /// Counter for generating unique rule IDs; the next fresh ID uses this value
    rule_counter: usize,
}
21
22impl MetadataBuilder {
23    /// Create a new metadata builder
24    pub fn new() -> Self {
25        Self {
26            source_file: None,
27            rule_id: None,
28            rule_counter: 0,
29        }
30    }
31
32    /// Set the current source file
33    pub fn with_source_file(mut self, file: impl Into<String>) -> Self {
34        self.source_file = Some(file.into());
35        self
36    }
37
38    /// Set the current rule ID
39    pub fn with_rule_id(mut self, rule_id: impl Into<String>) -> Self {
40        self.rule_id = Some(rule_id.into());
41        self
42    }
43
44    /// Generate a fresh rule ID
45    pub fn fresh_rule_id(&mut self) -> String {
46        let id = format!("rule_{}", self.rule_counter);
47        self.rule_counter += 1;
48        id
49    }
50
51    /// Create metadata for a predicate
52    pub fn predicate_metadata(&mut self, name: &str, args: &[String]) -> Metadata {
53        let mut meta = Metadata::new().with_name(format!("predicate:{}", name));
54
55        if let Some(ref file) = self.source_file {
56            meta = meta.with_attribute("source_file", file.clone());
57        }
58
59        if let Some(ref rule) = self.rule_id {
60            meta = meta.with_attribute("rule_id", rule.clone());
61        }
62
63        meta = meta.with_attribute("predicate_name", name.to_string());
64        meta = meta.with_attribute("arity", args.len().to_string());
65
66        for (i, arg) in args.iter().enumerate() {
67            meta = meta.with_attribute(format!("arg_{}", i), arg.clone());
68        }
69
70        meta
71    }
72
73    /// Create metadata for a logical operation
74    pub fn logic_op_metadata(&mut self, op_type: &str, operand_count: usize) -> Metadata {
75        let mut meta = Metadata::new().with_name(format!("logic_op:{}", op_type));
76
77        if let Some(ref file) = self.source_file {
78            meta = meta.with_attribute("source_file", file.clone());
79        }
80
81        if let Some(ref rule) = self.rule_id {
82            meta = meta.with_attribute("rule_id", rule.clone());
83        }
84
85        meta = meta.with_attribute("operation", op_type.to_string());
86        meta = meta.with_attribute("operand_count", operand_count.to_string());
87
88        meta
89    }
90
91    /// Create metadata for a quantifier
92    pub fn quantifier_metadata(
93        &mut self,
94        quantifier_type: &str,
95        var: &str,
96        domain: &str,
97    ) -> Metadata {
98        let mut meta = Metadata::new().with_name(format!("quantifier:{}", quantifier_type));
99
100        if let Some(ref file) = self.source_file {
101            meta = meta.with_attribute("source_file", file.clone());
102        }
103
104        if let Some(ref rule) = self.rule_id {
105            meta = meta.with_attribute("rule_id", rule.clone());
106        }
107
108        meta = meta.with_attribute("quantifier", quantifier_type.to_string());
109        meta = meta.with_attribute("variable", var.to_string());
110        meta = meta.with_attribute("domain", domain.to_string());
111
112        meta
113    }
114
115    /// Create metadata from TLExpr
116    pub fn from_expr(&mut self, expr: &TLExpr) -> Metadata {
117        match expr {
118            TLExpr::Pred { name, args } => {
119                let arg_names: Vec<String> = args.iter().map(|t| format!("{:?}", t)).collect();
120                self.predicate_metadata(name, &arg_names)
121            }
122            TLExpr::And(_, _) => self.logic_op_metadata("AND", 2),
123            TLExpr::Or(_, _) => self.logic_op_metadata("OR", 2),
124            TLExpr::Not(_) => self.logic_op_metadata("NOT", 1),
125            TLExpr::Imply(_, _) => self.logic_op_metadata("IMPLY", 2),
126            TLExpr::Exists { var, domain, .. } => self.quantifier_metadata("EXISTS", var, domain),
127            TLExpr::ForAll { var, domain, .. } => self.quantifier_metadata("FORALL", var, domain),
128            TLExpr::Constant(_) => Metadata::new()
129                .with_name("constant")
130                .with_attribute("type", "constant"),
131            _ => Metadata::new().with_name("expression"),
132        }
133    }
134}
135
136impl Default for MetadataBuilder {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
142/// Propagate metadata through a compiled graph
143pub fn propagate_metadata(
144    graph: &mut EinsumGraph,
145    ctx: &CompilerContext,
146    _builder: &mut MetadataBuilder,
147) {
148    // Collect metadata to add (to avoid borrowing issues)
149    let mut metadata_to_add: Vec<(usize, Metadata)> = Vec::new();
150
151    // Add domain metadata to input tensors
152    for (tensor_idx, tensor_name) in graph.tensors.iter().enumerate() {
153        if graph.inputs.contains(&tensor_idx) {
154            // Check if this tensor corresponds to a predicate
155            if let Some(domain_name) = ctx.var_to_domain.values().find(|d| {
156                tensor_name.starts_with(&format!("{}_", d))
157                    || tensor_name.contains(&format!("_{}_", d))
158            }) {
159                let mut meta = Metadata::new()
160                    .with_name(format!("input_tensor:{}", tensor_name))
161                    .with_attribute("domain", domain_name.clone())
162                    .with_attribute("tensor_type", "input");
163
164                if let Some(domain_info) = ctx.domains.get(domain_name) {
165                    meta = meta.with_attribute("cardinality", domain_info.cardinality.to_string());
166                }
167
168                metadata_to_add.push((tensor_idx, meta));
169            }
170        }
171    }
172
173    // Add domain information as graph-level metadata
174    for (domain_name, domain_info) in &ctx.domains {
175        // This could be stored as a special metadata attribute on the graph
176        // For now, we'll add it as metadata on output tensors if they relate to this domain
177        for &output_idx in &graph.outputs {
178            if let Some(tensor_name) = graph.tensors.get(output_idx) {
179                if tensor_name.contains(domain_name) {
180                    let meta = Metadata::new()
181                        .with_name(format!("output_tensor:{}", tensor_name))
182                        .with_attribute("domain", domain_name.clone())
183                        .with_attribute("cardinality", domain_info.cardinality.to_string())
184                        .with_attribute("tensor_type", "output");
185
186                    metadata_to_add.push((output_idx, meta));
187                }
188            }
189        }
190    }
191
192    // Add all collected metadata
193    for (idx, meta) in metadata_to_add {
194        graph.add_tensor_metadata(idx, meta);
195    }
196}
197
198/// Enhanced compilation result with metadata
199pub struct MetadataCompilationResult {
200    /// The compiled graph
201    pub graph: EinsumGraph,
202    /// Metadata builder used during compilation
203    pub builder: MetadataBuilder,
204    /// Mapping from expression to node indices
205    pub expr_to_nodes: HashMap<String, Vec<usize>>,
206}
207
208impl MetadataCompilationResult {
209    /// Create a new result
210    pub fn new(graph: EinsumGraph, builder: MetadataBuilder) -> Self {
211        Self {
212            graph,
213            builder,
214            expr_to_nodes: HashMap::new(),
215        }
216    }
217
218    /// Record that an expression was compiled to specific nodes
219    pub fn record_expression(&mut self, expr_id: impl Into<String>, node_indices: Vec<usize>) {
220        self.expr_to_nodes.insert(expr_id.into(), node_indices);
221    }
222
223    /// Get nodes for an expression
224    pub fn get_nodes_for_expr(&self, expr_id: &str) -> Option<&[usize]> {
225        self.expr_to_nodes.get(expr_id).map(|v| v.as_slice())
226    }
227}
228
229/// Helper to attach metadata to nodes based on expression type
230pub fn attach_expr_metadata(node: &mut EinsumNode, expr: &TLExpr, builder: &mut MetadataBuilder) {
231    let metadata = builder.from_expr(expr);
232    node.set_metadata(metadata);
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// A new builder has no provenance set and a zeroed rule counter.
    #[test]
    fn test_metadata_builder_new() {
        let fresh = MetadataBuilder::new();
        assert_eq!(fresh.source_file, None);
        assert_eq!(fresh.rule_id, None);
        assert_eq!(fresh.rule_counter, 0);
    }

    /// `with_source_file` stores the given file name.
    #[test]
    fn test_metadata_builder_with_source_file() {
        let configured = MetadataBuilder::new().with_source_file("test.tl");
        assert_eq!(configured.source_file.as_deref(), Some("test.tl"));
    }

    /// Fresh rule IDs are numbered sequentially from zero.
    #[test]
    fn test_metadata_builder_fresh_rule_id() {
        let mut builder = MetadataBuilder::new();
        let ids = [builder.fresh_rule_id(), builder.fresh_rule_id()];
        assert_eq!(ids, ["rule_0", "rule_1"]);
    }

    /// Predicate metadata carries its name, arity and the builder's provenance.
    #[test]
    fn test_predicate_metadata() {
        let args = vec!["x".to_string(), "y".to_string()];
        let mut builder = MetadataBuilder::new()
            .with_source_file("test.tl")
            .with_rule_id("rule_1");

        let meta = builder.predicate_metadata("knows", &args);

        assert_eq!(meta.name.as_deref(), Some("predicate:knows"));
        for (key, expected) in [
            ("predicate_name", "knows"),
            ("arity", "2"),
            ("source_file", "test.tl"),
            ("rule_id", "rule_1"),
        ] {
            assert_eq!(meta.get_attribute(key), Some(expected));
        }
    }

    /// Logic-op metadata records the operation and its operand count.
    #[test]
    fn test_logic_op_metadata() {
        let meta = MetadataBuilder::new().logic_op_metadata("AND", 2);

        assert_eq!(meta.name.as_deref(), Some("logic_op:AND"));
        assert_eq!(meta.get_attribute("operation"), Some("AND"));
        assert_eq!(meta.get_attribute("operand_count"), Some("2"));
    }

    /// Quantifier metadata records the quantifier, its variable and its domain.
    #[test]
    fn test_quantifier_metadata() {
        let meta = MetadataBuilder::new().quantifier_metadata("EXISTS", "x", "Person");

        assert_eq!(meta.name.as_deref(), Some("quantifier:EXISTS"));
        assert_eq!(meta.get_attribute("quantifier"), Some("EXISTS"));
        assert_eq!(meta.get_attribute("variable"), Some("x"));
        assert_eq!(meta.get_attribute("domain"), Some("Person"));
    }

    /// A predicate expression maps to predicate metadata.
    #[test]
    fn test_from_expr_predicate() {
        let knows = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let meta = MetadataBuilder::new().from_expr(&knows);

        assert_eq!(meta.name.as_deref(), Some("predicate:knows"));
        assert_eq!(meta.get_attribute("predicate_name"), Some("knows"));
    }

    /// A conjunction maps to `AND` logic-op metadata.
    #[test]
    fn test_from_expr_and() {
        let lhs = TLExpr::pred("p", vec![Term::var("x")]);
        let rhs = TLExpr::pred("q", vec![Term::var("y")]);
        let conjunction = TLExpr::And(Box::new(lhs), Box::new(rhs));

        let meta = MetadataBuilder::new().from_expr(&conjunction);

        assert_eq!(meta.name.as_deref(), Some("logic_op:AND"));
        assert_eq!(meta.get_attribute("operation"), Some("AND"));
    }

    /// An existential maps to `EXISTS` quantifier metadata.
    #[test]
    fn test_from_expr_exists() {
        let body = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let quantified = TLExpr::exists("x", "Person", body);

        let meta = MetadataBuilder::new().from_expr(&quantified);

        assert_eq!(meta.name.as_deref(), Some("quantifier:EXISTS"));
        assert_eq!(meta.get_attribute("quantifier"), Some("EXISTS"));
        assert_eq!(meta.get_attribute("variable"), Some("x"));
        assert_eq!(meta.get_attribute("domain"), Some("Person"));
    }

    /// An input tensor named after a bound domain receives metadata.
    #[test]
    fn test_propagate_metadata_with_domains() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx.bind_var("x", "Person").unwrap();

        let mut graph = EinsumGraph::new();
        let input_idx = graph.add_tensor("Person_x");
        graph.inputs.push(input_idx);

        let mut builder = MetadataBuilder::new();
        propagate_metadata(&mut graph, &ctx, &mut builder);

        // The input tensor must now carry metadata.
        assert!(graph.get_tensor_metadata(input_idx).is_some());
    }

    /// Recorded expressions can be looked up; unknown ones yield `None`.
    #[test]
    fn test_metadata_compilation_result() {
        let mut result =
            MetadataCompilationResult::new(EinsumGraph::new(), MetadataBuilder::new());

        result.record_expression("expr_1", vec![0, 1, 2]);

        let expected: &[usize] = &[0, 1, 2];
        assert_eq!(result.get_nodes_for_expr("expr_1"), Some(expected));
        assert!(result.get_nodes_for_expr("expr_2").is_none());
    }

    /// Attaching expression metadata sets it on the node.
    #[test]
    fn test_attach_expr_metadata() {
        let knows = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let mut node = EinsumNode::new("ab->a", vec![0], vec![1]);
        let mut builder = MetadataBuilder::new();

        attach_expr_metadata(&mut node, &knows, &mut builder);

        let attached = node.get_metadata();
        assert!(attached.is_some());
        assert_eq!(
            attached.unwrap().get_attribute("predicate_name"),
            Some("knows")
        );
    }
}