Skip to main content

tensorlogic_adapters/
symbol_table.rs

1//! Symbol table for managing domains, predicates, and variables.
2
3use anyhow::{bail, Result};
4use indexmap::IndexMap;
5use serde::{Deserialize, Serialize};
6use tensorlogic_ir::TLExpr;
7
8use crate::error::AdapterError;
9use crate::{DomainInfo, PredicateInfo};
10
11/// Symbol table containing all domain, predicate, and variable information
12#[derive(Clone, Debug, Serialize, Deserialize)]
13pub struct SymbolTable {
14    pub domains: IndexMap<String, DomainInfo>,
15    pub predicates: IndexMap<String, PredicateInfo>,
16    pub variables: IndexMap<String, String>,
17}
18
19impl SymbolTable {
20    pub fn new() -> Self {
21        SymbolTable {
22            domains: IndexMap::new(),
23            predicates: IndexMap::new(),
24            variables: IndexMap::new(),
25        }
26    }
27
28    pub fn add_domain(&mut self, domain: DomainInfo) -> Result<()> {
29        self.domains.insert(domain.name.clone(), domain);
30        Ok(())
31    }
32
33    pub fn add_predicate(&mut self, predicate: PredicateInfo) -> Result<()> {
34        for domain in &predicate.arg_domains {
35            if !self.domains.contains_key(domain) {
36                bail!(
37                    "Domain '{}' referenced by predicate '{}' does not exist",
38                    domain,
39                    predicate.name
40                );
41            }
42        }
43        self.predicates.insert(predicate.name.clone(), predicate);
44        Ok(())
45    }
46
47    pub fn bind_variable(
48        &mut self,
49        var: impl Into<String>,
50        domain: impl Into<String>,
51    ) -> Result<()> {
52        let var = var.into();
53        let domain = domain.into();
54
55        if !self.domains.contains_key(&domain) {
56            return Err(AdapterError::DomainNotFound(domain).into());
57        }
58
59        self.variables.insert(var, domain);
60        Ok(())
61    }
62
63    pub fn get_domain(&self, name: &str) -> Option<&DomainInfo> {
64        self.domains.get(name)
65    }
66
67    pub fn get_predicate(&self, name: &str) -> Option<&PredicateInfo> {
68        self.predicates.get(name)
69    }
70
71    pub fn get_variable_domain(&self, var: &str) -> Option<&str> {
72        self.variables.get(var).map(|s| s.as_str())
73    }
74
75    pub fn infer_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
76        self.collect_domains_from_expr(expr)?;
77        self.collect_predicates_from_expr(expr)?;
78        Ok(())
79    }
80
81    fn collect_domains_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
82        match expr {
83            TLExpr::Exists { domain, var, body } | TLExpr::ForAll { domain, var, body } => {
84                if !self.domains.contains_key(domain) {
85                    self.add_domain(DomainInfo::new(domain.clone(), 0))?;
86                }
87                self.bind_variable(var, domain)?;
88                self.collect_domains_from_expr(body)?;
89            }
90            TLExpr::And(l, r)
91            | TLExpr::Or(l, r)
92            | TLExpr::Imply(l, r)
93            | TLExpr::Add(l, r)
94            | TLExpr::Sub(l, r)
95            | TLExpr::Mul(l, r)
96            | TLExpr::Div(l, r)
97            | TLExpr::Pow(l, r)
98            | TLExpr::Mod(l, r)
99            | TLExpr::Min(l, r)
100            | TLExpr::Max(l, r)
101            | TLExpr::Eq(l, r)
102            | TLExpr::Lt(l, r)
103            | TLExpr::Gt(l, r)
104            | TLExpr::Lte(l, r)
105            | TLExpr::Gte(l, r) => {
106                self.collect_domains_from_expr(l)?;
107                self.collect_domains_from_expr(r)?;
108            }
109            TLExpr::Not(e)
110            | TLExpr::Score(e)
111            | TLExpr::Abs(e)
112            | TLExpr::Floor(e)
113            | TLExpr::Ceil(e)
114            | TLExpr::Round(e)
115            | TLExpr::Sqrt(e)
116            | TLExpr::Exp(e)
117            | TLExpr::Log(e)
118            | TLExpr::Sin(e)
119            | TLExpr::Cos(e)
120            | TLExpr::Tan(e) => {
121                self.collect_domains_from_expr(e)?;
122            }
123            TLExpr::IfThenElse {
124                condition,
125                then_branch,
126                else_branch,
127            } => {
128                self.collect_domains_from_expr(condition)?;
129                self.collect_domains_from_expr(then_branch)?;
130                self.collect_domains_from_expr(else_branch)?;
131            }
132            TLExpr::Aggregate {
133                domain, var, body, ..
134            } => {
135                if !self.domains.contains_key(domain) {
136                    self.add_domain(DomainInfo::new(domain.clone(), 0))?;
137                }
138                self.bind_variable(var, domain)?;
139                self.collect_domains_from_expr(body)?;
140            }
141            TLExpr::Let {
142                var: _,
143                value,
144                body,
145            } => {
146                self.collect_domains_from_expr(value)?;
147                // The variable binding is temporary, so we don't add it to the symbol table
148                self.collect_domains_from_expr(body)?;
149            }
150            // Modal/temporal logic operators (future enhancement)
151            TLExpr::Box(inner)
152            | TLExpr::Diamond(inner)
153            | TLExpr::Next(inner)
154            | TLExpr::Eventually(inner)
155            | TLExpr::Always(inner) => {
156                self.collect_domains_from_expr(inner)?;
157            }
158            TLExpr::Until { before, after }
159            | TLExpr::Release {
160                released: before,
161                releaser: after,
162            }
163            | TLExpr::WeakUntil { before, after }
164            | TLExpr::StrongRelease {
165                released: before,
166                releaser: after,
167            } => {
168                self.collect_domains_from_expr(before)?;
169                self.collect_domains_from_expr(after)?;
170            }
171            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
172                self.collect_domains_from_expr(left)?;
173                self.collect_domains_from_expr(right)?;
174            }
175            TLExpr::FuzzyNot { expr, .. } => {
176                self.collect_domains_from_expr(expr)?;
177            }
178            TLExpr::FuzzyImplication {
179                premise,
180                conclusion,
181                ..
182            } => {
183                self.collect_domains_from_expr(premise)?;
184                self.collect_domains_from_expr(conclusion)?;
185            }
186            TLExpr::SoftExists {
187                domain, var, body, ..
188            }
189            | TLExpr::SoftForAll {
190                domain, var, body, ..
191            } => {
192                if !self.domains.contains_key(domain) {
193                    self.add_domain(DomainInfo::new(domain.clone(), 0))?;
194                }
195                self.bind_variable(var, domain)?;
196                self.collect_domains_from_expr(body)?;
197            }
198            TLExpr::WeightedRule { rule, .. } => {
199                self.collect_domains_from_expr(rule)?;
200            }
201            TLExpr::ProbabilisticChoice { alternatives } => {
202                for (_prob, expr) in alternatives {
203                    self.collect_domains_from_expr(expr)?;
204                }
205            }
206            // Counting quantifiers
207            TLExpr::CountingExists {
208                domain, var, body, ..
209            }
210            | TLExpr::CountingForAll {
211                domain, var, body, ..
212            }
213            | TLExpr::ExactCount {
214                domain, var, body, ..
215            }
216            | TLExpr::Majority { domain, var, body } => {
217                if !self.domains.contains_key(domain) {
218                    self.add_domain(DomainInfo::new(domain.clone(), 0))?;
219                }
220                self.bind_variable(var, domain)?;
221                self.collect_domains_from_expr(body)?;
222            }
223            TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
224            // All other expression types (enhancements) - don't introduce new domains
225            _ => {
226                // For now, skip domain collection for unimplemented expression types
227                // This allows the code to compile while features are being implemented
228            }
229        }
230        Ok(())
231    }
232
233    fn collect_predicates_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
234        match expr {
235            TLExpr::Pred { name, args } if !self.predicates.contains_key(name) => {
236                let arg_domains: Vec<String> = args.iter().map(|_| "Unknown".to_string()).collect();
237                self.predicates
238                    .insert(name.clone(), PredicateInfo::new(name.clone(), arg_domains));
239            }
240            TLExpr::And(l, r)
241            | TLExpr::Or(l, r)
242            | TLExpr::Imply(l, r)
243            | TLExpr::Add(l, r)
244            | TLExpr::Sub(l, r)
245            | TLExpr::Mul(l, r)
246            | TLExpr::Div(l, r)
247            | TLExpr::Pow(l, r)
248            | TLExpr::Mod(l, r)
249            | TLExpr::Min(l, r)
250            | TLExpr::Max(l, r)
251            | TLExpr::Eq(l, r)
252            | TLExpr::Lt(l, r)
253            | TLExpr::Gt(l, r)
254            | TLExpr::Lte(l, r)
255            | TLExpr::Gte(l, r) => {
256                self.collect_predicates_from_expr(l)?;
257                self.collect_predicates_from_expr(r)?;
258            }
259            TLExpr::Not(e)
260            | TLExpr::Score(e)
261            | TLExpr::Abs(e)
262            | TLExpr::Floor(e)
263            | TLExpr::Ceil(e)
264            | TLExpr::Round(e)
265            | TLExpr::Sqrt(e)
266            | TLExpr::Exp(e)
267            | TLExpr::Log(e)
268            | TLExpr::Sin(e)
269            | TLExpr::Cos(e)
270            | TLExpr::Tan(e) => {
271                self.collect_predicates_from_expr(e)?;
272            }
273            TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
274                self.collect_predicates_from_expr(body)?;
275            }
276            TLExpr::IfThenElse {
277                condition,
278                then_branch,
279                else_branch,
280            } => {
281                self.collect_predicates_from_expr(condition)?;
282                self.collect_predicates_from_expr(then_branch)?;
283                self.collect_predicates_from_expr(else_branch)?;
284            }
285            TLExpr::Aggregate { body, .. } => {
286                self.collect_predicates_from_expr(body)?;
287            }
288            TLExpr::Let { value, body, .. } => {
289                self.collect_predicates_from_expr(value)?;
290                self.collect_predicates_from_expr(body)?;
291            }
292            // Modal/temporal logic operators (future enhancement)
293            TLExpr::Box(inner)
294            | TLExpr::Diamond(inner)
295            | TLExpr::Next(inner)
296            | TLExpr::Eventually(inner)
297            | TLExpr::Always(inner) => {
298                self.collect_predicates_from_expr(inner)?;
299            }
300            TLExpr::Until { before, after }
301            | TLExpr::Release {
302                released: before,
303                releaser: after,
304            }
305            | TLExpr::WeakUntil { before, after }
306            | TLExpr::StrongRelease {
307                released: before,
308                releaser: after,
309            } => {
310                self.collect_predicates_from_expr(before)?;
311                self.collect_predicates_from_expr(after)?;
312            }
313            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
314                self.collect_predicates_from_expr(left)?;
315                self.collect_predicates_from_expr(right)?;
316            }
317            TLExpr::FuzzyNot { expr, .. } => {
318                self.collect_predicates_from_expr(expr)?;
319            }
320            TLExpr::FuzzyImplication {
321                premise,
322                conclusion,
323                ..
324            } => {
325                self.collect_predicates_from_expr(premise)?;
326                self.collect_predicates_from_expr(conclusion)?;
327            }
328            TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
329                self.collect_predicates_from_expr(body)?;
330            }
331            TLExpr::WeightedRule { rule, .. } => {
332                self.collect_predicates_from_expr(rule)?;
333            }
334            TLExpr::ProbabilisticChoice { alternatives } => {
335                for (_prob, expr) in alternatives {
336                    self.collect_predicates_from_expr(expr)?;
337                }
338            }
339            // Counting quantifiers
340            TLExpr::CountingExists { body, .. }
341            | TLExpr::CountingForAll { body, .. }
342            | TLExpr::ExactCount { body, .. }
343            | TLExpr::Majority { body, .. } => {
344                self.collect_predicates_from_expr(body)?;
345            }
346            TLExpr::Constant(_) => {}
347            // All other expression types (enhancements) - don't introduce predicates
348            _ => {
349                // For now, skip predicate collection for unimplemented expression types
350                // This allows the code to compile while features are being implemented
351            }
352        }
353        Ok(())
354    }
355
356    pub fn to_json(&self) -> Result<String> {
357        Ok(serde_json::to_string_pretty(self)?)
358    }
359
360    pub fn from_json(json: &str) -> Result<Self> {
361        Ok(serde_json::from_str(json)?)
362    }
363
364    pub fn to_yaml(&self) -> Result<String> {
365        Ok(serde_yaml::to_string(self)?)
366    }
367
368    pub fn from_yaml(yaml: &str) -> Result<Self> {
369        Ok(serde_yaml::from_str(yaml)?)
370    }
371}
372
373impl Default for SymbolTable {
374    fn default() -> Self {
375        Self::new()
376    }
377}