// tensorlogic_adapters/symbol_table.rs

//! Symbol table for managing domains, predicates, and variables.
2
3use anyhow::{bail, Result};
4use indexmap::IndexMap;
5use serde::{Deserialize, Serialize};
6use tensorlogic_ir::TLExpr;
7
8use crate::error::AdapterError;
9use crate::{DomainInfo, PredicateInfo};
10
11/// Symbol table containing all domain, predicate, and variable information
12#[derive(Clone, Debug, Serialize, Deserialize)]
13pub struct SymbolTable {
14    pub domains: IndexMap<String, DomainInfo>,
15    pub predicates: IndexMap<String, PredicateInfo>,
16    pub variables: IndexMap<String, String>,
17}
18
19impl SymbolTable {
20    pub fn new() -> Self {
21        SymbolTable {
22            domains: IndexMap::new(),
23            predicates: IndexMap::new(),
24            variables: IndexMap::new(),
25        }
26    }
27
28    pub fn add_domain(&mut self, domain: DomainInfo) -> Result<()> {
29        self.domains.insert(domain.name.clone(), domain);
30        Ok(())
31    }
32
33    pub fn add_predicate(&mut self, predicate: PredicateInfo) -> Result<()> {
34        for domain in &predicate.arg_domains {
35            if !self.domains.contains_key(domain) {
36                bail!(
37                    "Domain '{}' referenced by predicate '{}' does not exist",
38                    domain,
39                    predicate.name
40                );
41            }
42        }
43        self.predicates.insert(predicate.name.clone(), predicate);
44        Ok(())
45    }
46
47    pub fn bind_variable(
48        &mut self,
49        var: impl Into<String>,
50        domain: impl Into<String>,
51    ) -> Result<()> {
52        let var = var.into();
53        let domain = domain.into();
54
55        if !self.domains.contains_key(&domain) {
56            return Err(AdapterError::DomainNotFound(domain).into());
57        }
58
59        self.variables.insert(var, domain);
60        Ok(())
61    }
62
63    pub fn get_domain(&self, name: &str) -> Option<&DomainInfo> {
64        self.domains.get(name)
65    }
66
67    pub fn get_predicate(&self, name: &str) -> Option<&PredicateInfo> {
68        self.predicates.get(name)
69    }
70
71    pub fn get_variable_domain(&self, var: &str) -> Option<&str> {
72        self.variables.get(var).map(|s| s.as_str())
73    }
74
75    pub fn infer_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
76        self.collect_domains_from_expr(expr)?;
77        self.collect_predicates_from_expr(expr)?;
78        Ok(())
79    }
80
81    fn collect_domains_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
82        match expr {
83            TLExpr::Exists { domain, var, body } | TLExpr::ForAll { domain, var, body } => {
84                if !self.domains.contains_key(domain) {
85                    self.add_domain(DomainInfo::new(domain.clone(), 0))?;
86                }
87                self.bind_variable(var, domain)?;
88                self.collect_domains_from_expr(body)?;
89            }
90            TLExpr::And(l, r)
91            | TLExpr::Or(l, r)
92            | TLExpr::Imply(l, r)
93            | TLExpr::Add(l, r)
94            | TLExpr::Sub(l, r)
95            | TLExpr::Mul(l, r)
96            | TLExpr::Div(l, r)
97            | TLExpr::Pow(l, r)
98            | TLExpr::Mod(l, r)
99            | TLExpr::Min(l, r)
100            | TLExpr::Max(l, r)
101            | TLExpr::Eq(l, r)
102            | TLExpr::Lt(l, r)
103            | TLExpr::Gt(l, r)
104            | TLExpr::Lte(l, r)
105            | TLExpr::Gte(l, r) => {
106                self.collect_domains_from_expr(l)?;
107                self.collect_domains_from_expr(r)?;
108            }
109            TLExpr::Not(e)
110            | TLExpr::Score(e)
111            | TLExpr::Abs(e)
112            | TLExpr::Floor(e)
113            | TLExpr::Ceil(e)
114            | TLExpr::Round(e)
115            | TLExpr::Sqrt(e)
116            | TLExpr::Exp(e)
117            | TLExpr::Log(e)
118            | TLExpr::Sin(e)
119            | TLExpr::Cos(e)
120            | TLExpr::Tan(e) => {
121                self.collect_domains_from_expr(e)?;
122            }
123            TLExpr::IfThenElse {
124                condition,
125                then_branch,
126                else_branch,
127            } => {
128                self.collect_domains_from_expr(condition)?;
129                self.collect_domains_from_expr(then_branch)?;
130                self.collect_domains_from_expr(else_branch)?;
131            }
132            TLExpr::Aggregate {
133                domain, var, body, ..
134            } => {
135                if !self.domains.contains_key(domain) {
136                    self.add_domain(DomainInfo::new(domain.clone(), 0))?;
137                }
138                self.bind_variable(var, domain)?;
139                self.collect_domains_from_expr(body)?;
140            }
141            TLExpr::Let {
142                var: _,
143                value,
144                body,
145            } => {
146                self.collect_domains_from_expr(value)?;
147                // The variable binding is temporary, so we don't add it to the symbol table
148                self.collect_domains_from_expr(body)?;
149            }
150            // Modal/temporal logic operators (future enhancement)
151            TLExpr::Box(inner)
152            | TLExpr::Diamond(inner)
153            | TLExpr::Next(inner)
154            | TLExpr::Eventually(inner)
155            | TLExpr::Always(inner) => {
156                self.collect_domains_from_expr(inner)?;
157            }
158            TLExpr::Until { before, after }
159            | TLExpr::Release {
160                released: before,
161                releaser: after,
162            }
163            | TLExpr::WeakUntil { before, after }
164            | TLExpr::StrongRelease {
165                released: before,
166                releaser: after,
167            } => {
168                self.collect_domains_from_expr(before)?;
169                self.collect_domains_from_expr(after)?;
170            }
171            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
172                self.collect_domains_from_expr(left)?;
173                self.collect_domains_from_expr(right)?;
174            }
175            TLExpr::FuzzyNot { expr, .. } => {
176                self.collect_domains_from_expr(expr)?;
177            }
178            TLExpr::FuzzyImplication {
179                premise,
180                conclusion,
181                ..
182            } => {
183                self.collect_domains_from_expr(premise)?;
184                self.collect_domains_from_expr(conclusion)?;
185            }
186            TLExpr::SoftExists {
187                domain, var, body, ..
188            }
189            | TLExpr::SoftForAll {
190                domain, var, body, ..
191            } => {
192                if !self.domains.contains_key(domain) {
193                    self.add_domain(DomainInfo::new(domain.clone(), 0))?;
194                }
195                self.bind_variable(var, domain)?;
196                self.collect_domains_from_expr(body)?;
197            }
198            TLExpr::WeightedRule { rule, .. } => {
199                self.collect_domains_from_expr(rule)?;
200            }
201            TLExpr::ProbabilisticChoice { alternatives } => {
202                for (_prob, expr) in alternatives {
203                    self.collect_domains_from_expr(expr)?;
204                }
205            }
206            // Counting quantifiers
207            TLExpr::CountingExists {
208                domain, var, body, ..
209            }
210            | TLExpr::CountingForAll {
211                domain, var, body, ..
212            }
213            | TLExpr::ExactCount {
214                domain, var, body, ..
215            }
216            | TLExpr::Majority { domain, var, body } => {
217                if !self.domains.contains_key(domain) {
218                    self.add_domain(DomainInfo::new(domain.clone(), 0))?;
219                }
220                self.bind_variable(var, domain)?;
221                self.collect_domains_from_expr(body)?;
222            }
223            TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
224            // All other expression types (enhancements) - don't introduce new domains
225            _ => {
226                // For now, skip domain collection for unimplemented expression types
227                // This allows the code to compile while features are being implemented
228            }
229        }
230        Ok(())
231    }
232
233    fn collect_predicates_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
234        match expr {
235            TLExpr::Pred { name, args } => {
236                if !self.predicates.contains_key(name) {
237                    let arg_domains: Vec<String> =
238                        args.iter().map(|_| "Unknown".to_string()).collect();
239                    self.predicates
240                        .insert(name.clone(), PredicateInfo::new(name.clone(), arg_domains));
241                }
242            }
243            TLExpr::And(l, r)
244            | TLExpr::Or(l, r)
245            | TLExpr::Imply(l, r)
246            | TLExpr::Add(l, r)
247            | TLExpr::Sub(l, r)
248            | TLExpr::Mul(l, r)
249            | TLExpr::Div(l, r)
250            | TLExpr::Pow(l, r)
251            | TLExpr::Mod(l, r)
252            | TLExpr::Min(l, r)
253            | TLExpr::Max(l, r)
254            | TLExpr::Eq(l, r)
255            | TLExpr::Lt(l, r)
256            | TLExpr::Gt(l, r)
257            | TLExpr::Lte(l, r)
258            | TLExpr::Gte(l, r) => {
259                self.collect_predicates_from_expr(l)?;
260                self.collect_predicates_from_expr(r)?;
261            }
262            TLExpr::Not(e)
263            | TLExpr::Score(e)
264            | TLExpr::Abs(e)
265            | TLExpr::Floor(e)
266            | TLExpr::Ceil(e)
267            | TLExpr::Round(e)
268            | TLExpr::Sqrt(e)
269            | TLExpr::Exp(e)
270            | TLExpr::Log(e)
271            | TLExpr::Sin(e)
272            | TLExpr::Cos(e)
273            | TLExpr::Tan(e) => {
274                self.collect_predicates_from_expr(e)?;
275            }
276            TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
277                self.collect_predicates_from_expr(body)?;
278            }
279            TLExpr::IfThenElse {
280                condition,
281                then_branch,
282                else_branch,
283            } => {
284                self.collect_predicates_from_expr(condition)?;
285                self.collect_predicates_from_expr(then_branch)?;
286                self.collect_predicates_from_expr(else_branch)?;
287            }
288            TLExpr::Aggregate { body, .. } => {
289                self.collect_predicates_from_expr(body)?;
290            }
291            TLExpr::Let { value, body, .. } => {
292                self.collect_predicates_from_expr(value)?;
293                self.collect_predicates_from_expr(body)?;
294            }
295            // Modal/temporal logic operators (future enhancement)
296            TLExpr::Box(inner)
297            | TLExpr::Diamond(inner)
298            | TLExpr::Next(inner)
299            | TLExpr::Eventually(inner)
300            | TLExpr::Always(inner) => {
301                self.collect_predicates_from_expr(inner)?;
302            }
303            TLExpr::Until { before, after }
304            | TLExpr::Release {
305                released: before,
306                releaser: after,
307            }
308            | TLExpr::WeakUntil { before, after }
309            | TLExpr::StrongRelease {
310                released: before,
311                releaser: after,
312            } => {
313                self.collect_predicates_from_expr(before)?;
314                self.collect_predicates_from_expr(after)?;
315            }
316            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
317                self.collect_predicates_from_expr(left)?;
318                self.collect_predicates_from_expr(right)?;
319            }
320            TLExpr::FuzzyNot { expr, .. } => {
321                self.collect_predicates_from_expr(expr)?;
322            }
323            TLExpr::FuzzyImplication {
324                premise,
325                conclusion,
326                ..
327            } => {
328                self.collect_predicates_from_expr(premise)?;
329                self.collect_predicates_from_expr(conclusion)?;
330            }
331            TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
332                self.collect_predicates_from_expr(body)?;
333            }
334            TLExpr::WeightedRule { rule, .. } => {
335                self.collect_predicates_from_expr(rule)?;
336            }
337            TLExpr::ProbabilisticChoice { alternatives } => {
338                for (_prob, expr) in alternatives {
339                    self.collect_predicates_from_expr(expr)?;
340                }
341            }
342            // Counting quantifiers
343            TLExpr::CountingExists { body, .. }
344            | TLExpr::CountingForAll { body, .. }
345            | TLExpr::ExactCount { body, .. }
346            | TLExpr::Majority { body, .. } => {
347                self.collect_predicates_from_expr(body)?;
348            }
349            TLExpr::Constant(_) => {}
350            // All other expression types (enhancements) - don't introduce predicates
351            _ => {
352                // For now, skip predicate collection for unimplemented expression types
353                // This allows the code to compile while features are being implemented
354            }
355        }
356        Ok(())
357    }
358
359    pub fn to_json(&self) -> Result<String> {
360        Ok(serde_json::to_string_pretty(self)?)
361    }
362
363    pub fn from_json(json: &str) -> Result<Self> {
364        Ok(serde_json::from_str(json)?)
365    }
366
367    pub fn to_yaml(&self) -> Result<String> {
368        Ok(serde_yaml::to_string(self)?)
369    }
370
371    pub fn from_yaml(yaml: &str) -> Result<Self> {
372        Ok(serde_yaml::from_str(yaml)?)
373    }
374}
375
376impl Default for SymbolTable {
377    fn default() -> Self {
378        Self::new()
379    }
380}