tensorlogic_compiler/passes/
metadata_propagation.rs1use std::collections::HashMap;
8use tensorlogic_ir::{EinsumGraph, EinsumNode, Metadata, TLExpr};
9
10use crate::CompilerContext;
11
/// Builds [`Metadata`] values describing TensorLogic expressions, optionally
/// tagged with the source file and rule id configured on the builder.
#[derive(Debug, Clone)]
pub struct MetadataBuilder {
    /// Source file attached as the `source_file` attribute, if set.
    source_file: Option<String>,
    /// Rule id attached as the `rule_id` attribute, if set.
    rule_id: Option<String>,
    /// Counter used to generate sequential `rule_<n>` identifiers.
    rule_counter: usize,
}
21
22impl MetadataBuilder {
23 pub fn new() -> Self {
25 Self {
26 source_file: None,
27 rule_id: None,
28 rule_counter: 0,
29 }
30 }
31
32 pub fn with_source_file(mut self, file: impl Into<String>) -> Self {
34 self.source_file = Some(file.into());
35 self
36 }
37
38 pub fn with_rule_id(mut self, rule_id: impl Into<String>) -> Self {
40 self.rule_id = Some(rule_id.into());
41 self
42 }
43
44 pub fn fresh_rule_id(&mut self) -> String {
46 let id = format!("rule_{}", self.rule_counter);
47 self.rule_counter += 1;
48 id
49 }
50
51 pub fn predicate_metadata(&mut self, name: &str, args: &[String]) -> Metadata {
53 let mut meta = Metadata::new().with_name(format!("predicate:{}", name));
54
55 if let Some(ref file) = self.source_file {
56 meta = meta.with_attribute("source_file", file.clone());
57 }
58
59 if let Some(ref rule) = self.rule_id {
60 meta = meta.with_attribute("rule_id", rule.clone());
61 }
62
63 meta = meta.with_attribute("predicate_name", name.to_string());
64 meta = meta.with_attribute("arity", args.len().to_string());
65
66 for (i, arg) in args.iter().enumerate() {
67 meta = meta.with_attribute(format!("arg_{}", i), arg.clone());
68 }
69
70 meta
71 }
72
73 pub fn logic_op_metadata(&mut self, op_type: &str, operand_count: usize) -> Metadata {
75 let mut meta = Metadata::new().with_name(format!("logic_op:{}", op_type));
76
77 if let Some(ref file) = self.source_file {
78 meta = meta.with_attribute("source_file", file.clone());
79 }
80
81 if let Some(ref rule) = self.rule_id {
82 meta = meta.with_attribute("rule_id", rule.clone());
83 }
84
85 meta = meta.with_attribute("operation", op_type.to_string());
86 meta = meta.with_attribute("operand_count", operand_count.to_string());
87
88 meta
89 }
90
91 pub fn quantifier_metadata(
93 &mut self,
94 quantifier_type: &str,
95 var: &str,
96 domain: &str,
97 ) -> Metadata {
98 let mut meta = Metadata::new().with_name(format!("quantifier:{}", quantifier_type));
99
100 if let Some(ref file) = self.source_file {
101 meta = meta.with_attribute("source_file", file.clone());
102 }
103
104 if let Some(ref rule) = self.rule_id {
105 meta = meta.with_attribute("rule_id", rule.clone());
106 }
107
108 meta = meta.with_attribute("quantifier", quantifier_type.to_string());
109 meta = meta.with_attribute("variable", var.to_string());
110 meta = meta.with_attribute("domain", domain.to_string());
111
112 meta
113 }
114
115 pub fn from_expr(&mut self, expr: &TLExpr) -> Metadata {
117 match expr {
118 TLExpr::Pred { name, args } => {
119 let arg_names: Vec<String> = args.iter().map(|t| format!("{:?}", t)).collect();
120 self.predicate_metadata(name, &arg_names)
121 }
122 TLExpr::And(_, _) => self.logic_op_metadata("AND", 2),
123 TLExpr::Or(_, _) => self.logic_op_metadata("OR", 2),
124 TLExpr::Not(_) => self.logic_op_metadata("NOT", 1),
125 TLExpr::Imply(_, _) => self.logic_op_metadata("IMPLY", 2),
126 TLExpr::Exists { var, domain, .. } => self.quantifier_metadata("EXISTS", var, domain),
127 TLExpr::ForAll { var, domain, .. } => self.quantifier_metadata("FORALL", var, domain),
128 TLExpr::Constant(_) => Metadata::new()
129 .with_name("constant")
130 .with_attribute("type", "constant"),
131 _ => Metadata::new().with_name("expression"),
132 }
133 }
134}
135
136impl Default for MetadataBuilder {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142pub fn propagate_metadata(
144 graph: &mut EinsumGraph,
145 ctx: &CompilerContext,
146 _builder: &mut MetadataBuilder,
147) {
148 let mut metadata_to_add: Vec<(usize, Metadata)> = Vec::new();
150
151 for (tensor_idx, tensor_name) in graph.tensors.iter().enumerate() {
153 if graph.inputs.contains(&tensor_idx) {
154 if let Some(domain_name) = ctx.var_to_domain.values().find(|d| {
156 tensor_name.starts_with(&format!("{}_", d))
157 || tensor_name.contains(&format!("_{}_", d))
158 }) {
159 let mut meta = Metadata::new()
160 .with_name(format!("input_tensor:{}", tensor_name))
161 .with_attribute("domain", domain_name.clone())
162 .with_attribute("tensor_type", "input");
163
164 if let Some(domain_info) = ctx.domains.get(domain_name) {
165 meta = meta.with_attribute("cardinality", domain_info.cardinality.to_string());
166 }
167
168 metadata_to_add.push((tensor_idx, meta));
169 }
170 }
171 }
172
173 for (domain_name, domain_info) in &ctx.domains {
175 for &output_idx in &graph.outputs {
178 if let Some(tensor_name) = graph.tensors.get(output_idx) {
179 if tensor_name.contains(domain_name) {
180 let meta = Metadata::new()
181 .with_name(format!("output_tensor:{}", tensor_name))
182 .with_attribute("domain", domain_name.clone())
183 .with_attribute("cardinality", domain_info.cardinality.to_string())
184 .with_attribute("tensor_type", "output");
185
186 metadata_to_add.push((output_idx, meta));
187 }
188 }
189 }
190 }
191
192 for (idx, meta) in metadata_to_add {
194 graph.add_tensor_metadata(idx, meta);
195 }
196}
197
198pub struct MetadataCompilationResult {
200 pub graph: EinsumGraph,
202 pub builder: MetadataBuilder,
204 pub expr_to_nodes: HashMap<String, Vec<usize>>,
206}
207
208impl MetadataCompilationResult {
209 pub fn new(graph: EinsumGraph, builder: MetadataBuilder) -> Self {
211 Self {
212 graph,
213 builder,
214 expr_to_nodes: HashMap::new(),
215 }
216 }
217
218 pub fn record_expression(&mut self, expr_id: impl Into<String>, node_indices: Vec<usize>) {
220 self.expr_to_nodes.insert(expr_id.into(), node_indices);
221 }
222
223 pub fn get_nodes_for_expr(&self, expr_id: &str) -> Option<&[usize]> {
225 self.expr_to_nodes.get(expr_id).map(|v| v.as_slice())
226 }
227}
228
229pub fn attach_expr_metadata(node: &mut EinsumNode, expr: &TLExpr, builder: &mut MetadataBuilder) {
231 let metadata = builder.from_expr(expr);
232 node.set_metadata(metadata);
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    #[test]
    fn test_metadata_builder_new() {
        let builder = MetadataBuilder::new();
        assert!(builder.source_file.is_none());
        assert!(builder.rule_id.is_none());
        assert_eq!(builder.rule_counter, 0);
    }

    #[test]
    fn test_metadata_builder_default_matches_new() {
        let builder = MetadataBuilder::default();
        assert!(builder.source_file.is_none());
        assert!(builder.rule_id.is_none());
        assert_eq!(builder.rule_counter, 0);
    }

    #[test]
    fn test_metadata_builder_with_source_file() {
        let builder = MetadataBuilder::new().with_source_file("test.tl");
        assert_eq!(builder.source_file, Some("test.tl".to_string()));
    }

    #[test]
    fn test_metadata_builder_with_rule_id() {
        let builder = MetadataBuilder::new().with_rule_id("rule_7");
        assert_eq!(builder.rule_id, Some("rule_7".to_string()));
    }

    #[test]
    fn test_metadata_builder_fresh_rule_id() {
        let mut builder = MetadataBuilder::new();
        let id1 = builder.fresh_rule_id();
        let id2 = builder.fresh_rule_id();
        assert_eq!(id1, "rule_0");
        assert_eq!(id2, "rule_1");
    }

    #[test]
    fn test_predicate_metadata() {
        let mut builder = MetadataBuilder::new()
            .with_source_file("test.tl")
            .with_rule_id("rule_1");

        let meta = builder.predicate_metadata("knows", &["x".to_string(), "y".to_string()]);

        assert_eq!(meta.name, Some("predicate:knows".to_string()));
        assert_eq!(meta.get_attribute("predicate_name"), Some("knows"));
        assert_eq!(meta.get_attribute("arity"), Some("2"));
        assert_eq!(meta.get_attribute("source_file"), Some("test.tl"));
        assert_eq!(meta.get_attribute("rule_id"), Some("rule_1"));
    }

    #[test]
    fn test_predicate_metadata_records_each_argument() {
        let mut builder = MetadataBuilder::new();
        let meta = builder.predicate_metadata("knows", &["x".to_string(), "y".to_string()]);

        assert_eq!(meta.get_attribute("arg_0"), Some("x"));
        assert_eq!(meta.get_attribute("arg_1"), Some("y"));
        assert_eq!(meta.get_attribute("arg_2"), None);
    }

    #[test]
    fn test_predicate_metadata_zero_arity() {
        let mut builder = MetadataBuilder::new();
        let meta = builder.predicate_metadata("flag", &[]);

        assert_eq!(meta.name, Some("predicate:flag".to_string()));
        assert_eq!(meta.get_attribute("arity"), Some("0"));
        assert_eq!(meta.get_attribute("arg_0"), None);
    }

    #[test]
    fn test_logic_op_metadata() {
        let mut builder = MetadataBuilder::new();
        let meta = builder.logic_op_metadata("AND", 2);

        assert_eq!(meta.name, Some("logic_op:AND".to_string()));
        assert_eq!(meta.get_attribute("operation"), Some("AND"));
        assert_eq!(meta.get_attribute("operand_count"), Some("2"));
    }

    #[test]
    fn test_metadata_omits_unset_provenance() {
        let mut builder = MetadataBuilder::new();
        let meta = builder.logic_op_metadata("OR", 2);

        assert_eq!(meta.get_attribute("source_file"), None);
        assert_eq!(meta.get_attribute("rule_id"), None);
    }

    #[test]
    fn test_quantifier_metadata() {
        let mut builder = MetadataBuilder::new();
        let meta = builder.quantifier_metadata("EXISTS", "x", "Person");

        assert_eq!(meta.name, Some("quantifier:EXISTS".to_string()));
        assert_eq!(meta.get_attribute("quantifier"), Some("EXISTS"));
        assert_eq!(meta.get_attribute("variable"), Some("x"));
        assert_eq!(meta.get_attribute("domain"), Some("Person"));
    }

    #[test]
    fn test_quantifier_metadata_includes_provenance() {
        let mut builder = MetadataBuilder::new()
            .with_source_file("q.tl")
            .with_rule_id("rule_9");
        let meta = builder.quantifier_metadata("FORALL", "y", "City");

        assert_eq!(meta.get_attribute("source_file"), Some("q.tl"));
        assert_eq!(meta.get_attribute("rule_id"), Some("rule_9"));
    }

    #[test]
    fn test_from_expr_predicate() {
        let mut builder = MetadataBuilder::new();
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let meta = builder.from_expr(&expr);

        assert_eq!(meta.name, Some("predicate:knows".to_string()));
        assert_eq!(meta.get_attribute("predicate_name"), Some("knows"));
    }

    #[test]
    fn test_from_expr_and() {
        let mut builder = MetadataBuilder::new();
        let expr = TLExpr::And(
            Box::new(TLExpr::pred("p", vec![Term::var("x")])),
            Box::new(TLExpr::pred("q", vec![Term::var("y")])),
        );
        let meta = builder.from_expr(&expr);

        assert_eq!(meta.name, Some("logic_op:AND".to_string()));
        assert_eq!(meta.get_attribute("operation"), Some("AND"));
    }

    #[test]
    fn test_from_expr_or() {
        let mut builder = MetadataBuilder::new();
        let expr = TLExpr::Or(
            Box::new(TLExpr::pred("p", vec![Term::var("x")])),
            Box::new(TLExpr::pred("q", vec![Term::var("y")])),
        );
        let meta = builder.from_expr(&expr);

        assert_eq!(meta.name, Some("logic_op:OR".to_string()));
        assert_eq!(meta.get_attribute("operand_count"), Some("2"));
    }

    #[test]
    fn test_from_expr_not() {
        let mut builder = MetadataBuilder::new();
        let expr = TLExpr::Not(Box::new(TLExpr::pred("p", vec![Term::var("x")])));
        let meta = builder.from_expr(&expr);

        assert_eq!(meta.name, Some("logic_op:NOT".to_string()));
        assert_eq!(meta.get_attribute("operand_count"), Some("1"));
    }

    #[test]
    fn test_from_expr_exists() {
        let mut builder = MetadataBuilder::new();
        let expr = TLExpr::exists(
            "x",
            "Person",
            TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
        );
        let meta = builder.from_expr(&expr);

        assert_eq!(meta.name, Some("quantifier:EXISTS".to_string()));
        assert_eq!(meta.get_attribute("quantifier"), Some("EXISTS"));
        assert_eq!(meta.get_attribute("variable"), Some("x"));
        assert_eq!(meta.get_attribute("domain"), Some("Person"));
    }

    #[test]
    fn test_propagate_metadata_with_domains() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx.bind_var("x", "Person").unwrap();

        let mut graph = EinsumGraph::new();
        let tensor_idx = graph.add_tensor("Person_x");
        graph.inputs.push(tensor_idx);

        let mut builder = MetadataBuilder::new();
        propagate_metadata(&mut graph, &ctx, &mut builder);

        let meta = graph.get_tensor_metadata(tensor_idx);
        assert!(meta.is_some());
    }

    #[test]
    fn test_metadata_compilation_result() {
        let graph = EinsumGraph::new();
        let builder = MetadataBuilder::new();
        let mut result = MetadataCompilationResult::new(graph, builder);

        result.record_expression("expr_1", vec![0, 1, 2]);
        assert_eq!(result.get_nodes_for_expr("expr_1"), Some(&[0, 1, 2][..]));
        assert_eq!(result.get_nodes_for_expr("expr_2"), None);
    }

    #[test]
    fn test_metadata_compilation_result_overwrites_expression() {
        let mut result = MetadataCompilationResult::new(EinsumGraph::new(), MetadataBuilder::new());

        result.record_expression("expr_1", vec![0]);
        result.record_expression("expr_1", vec![3, 4]);
        assert_eq!(result.get_nodes_for_expr("expr_1"), Some(&[3, 4][..]));
    }

    #[test]
    fn test_attach_expr_metadata() {
        let mut builder = MetadataBuilder::new();
        let mut node = EinsumNode::new("ab->a", vec![0], vec![1]);
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        attach_expr_metadata(&mut node, &expr, &mut builder);

        let meta = node.get_metadata();
        assert!(meta.is_some());
        assert_eq!(meta.unwrap().get_attribute("predicate_name"), Some("knows"));
    }
}