tensorlogic_adapters/
symbol_table.rs

1//! Symbol table for managing domains, predicates, and variables.
2
3use anyhow::{bail, Result};
4use indexmap::IndexMap;
5use serde::{Deserialize, Serialize};
6use tensorlogic_ir::TLExpr;
7
8use crate::error::AdapterError;
9use crate::{DomainInfo, PredicateInfo};
10
/// Symbol table containing all domain, predicate, and variable information.
///
/// All three collections are [`IndexMap`]s, so iteration follows insertion order.
///
/// The fields are public, so callers can mutate them directly. Doing so bypasses the checks
/// performed by `add_predicate` (argument domains must exist) and `bind_variable` (the bound
/// domain must exist).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SymbolTable {
    /// Registered domains, keyed by the domain's `name` (see `add_domain`).
    pub domains: IndexMap<String, DomainInfo>,
    /// Registered predicates, keyed by the predicate's `name` (see `add_predicate`).
    pub predicates: IndexMap<String, PredicateInfo>,
    /// Variable bindings: variable name -> name of the domain it is bound to.
    pub variables: IndexMap<String, String>,
}
18
19impl SymbolTable {
20    pub fn new() -> Self {
21        SymbolTable {
22            domains: IndexMap::new(),
23            predicates: IndexMap::new(),
24            variables: IndexMap::new(),
25        }
26    }
27
28    pub fn add_domain(&mut self, domain: DomainInfo) -> Result<()> {
29        self.domains.insert(domain.name.clone(), domain);
30        Ok(())
31    }
32
33    pub fn add_predicate(&mut self, predicate: PredicateInfo) -> Result<()> {
34        for domain in &predicate.arg_domains {
35            if !self.domains.contains_key(domain) {
36                bail!(
37                    "Domain '{}' referenced by predicate '{}' does not exist",
38                    domain,
39                    predicate.name
40                );
41            }
42        }
43        self.predicates.insert(predicate.name.clone(), predicate);
44        Ok(())
45    }
46
47    pub fn bind_variable(
48        &mut self,
49        var: impl Into<String>,
50        domain: impl Into<String>,
51    ) -> Result<()> {
52        let var = var.into();
53        let domain = domain.into();
54
55        if !self.domains.contains_key(&domain) {
56            return Err(AdapterError::DomainNotFound(domain).into());
57        }
58
59        self.variables.insert(var, domain);
60        Ok(())
61    }
62
63    pub fn get_domain(&self, name: &str) -> Option<&DomainInfo> {
64        self.domains.get(name)
65    }
66
67    pub fn get_predicate(&self, name: &str) -> Option<&PredicateInfo> {
68        self.predicates.get(name)
69    }
70
71    pub fn get_variable_domain(&self, var: &str) -> Option<&str> {
72        self.variables.get(var).map(|s| s.as_str())
73    }
74
75    pub fn infer_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
76        self.collect_domains_from_expr(expr)?;
77        self.collect_predicates_from_expr(expr)?;
78        Ok(())
79    }
80
81    fn collect_domains_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
82        match expr {
83            TLExpr::Exists { domain, var, body } | TLExpr::ForAll { domain, var, body } => {
84                if !self.domains.contains_key(domain) {
85                    self.add_domain(DomainInfo::new(domain.clone(), 0))?;
86                }
87                self.bind_variable(var, domain)?;
88                self.collect_domains_from_expr(body)?;
89            }
90            TLExpr::And(l, r)
91            | TLExpr::Or(l, r)
92            | TLExpr::Imply(l, r)
93            | TLExpr::Add(l, r)
94            | TLExpr::Sub(l, r)
95            | TLExpr::Mul(l, r)
96            | TLExpr::Div(l, r)
97            | TLExpr::Pow(l, r)
98            | TLExpr::Mod(l, r)
99            | TLExpr::Min(l, r)
100            | TLExpr::Max(l, r)
101            | TLExpr::Eq(l, r)
102            | TLExpr::Lt(l, r)
103            | TLExpr::Gt(l, r)
104            | TLExpr::Lte(l, r)
105            | TLExpr::Gte(l, r) => {
106                self.collect_domains_from_expr(l)?;
107                self.collect_domains_from_expr(r)?;
108            }
109            TLExpr::Not(e)
110            | TLExpr::Score(e)
111            | TLExpr::Abs(e)
112            | TLExpr::Floor(e)
113            | TLExpr::Ceil(e)
114            | TLExpr::Round(e)
115            | TLExpr::Sqrt(e)
116            | TLExpr::Exp(e)
117            | TLExpr::Log(e)
118            | TLExpr::Sin(e)
119            | TLExpr::Cos(e)
120            | TLExpr::Tan(e) => {
121                self.collect_domains_from_expr(e)?;
122            }
123            TLExpr::IfThenElse {
124                condition,
125                then_branch,
126                else_branch,
127            } => {
128                self.collect_domains_from_expr(condition)?;
129                self.collect_domains_from_expr(then_branch)?;
130                self.collect_domains_from_expr(else_branch)?;
131            }
132            TLExpr::Aggregate {
133                domain, var, body, ..
134            } => {
135                if !self.domains.contains_key(domain) {
136                    self.add_domain(DomainInfo::new(domain.clone(), 0))?;
137                }
138                self.bind_variable(var, domain)?;
139                self.collect_domains_from_expr(body)?;
140            }
141            TLExpr::Let {
142                var: _,
143                value,
144                body,
145            } => {
146                self.collect_domains_from_expr(value)?;
147                // The variable binding is temporary, so we don't add it to the symbol table
148                self.collect_domains_from_expr(body)?;
149            }
150            // Modal/temporal logic operators (future enhancement)
151            TLExpr::Box(inner)
152            | TLExpr::Diamond(inner)
153            | TLExpr::Next(inner)
154            | TLExpr::Eventually(inner)
155            | TLExpr::Always(inner) => {
156                self.collect_domains_from_expr(inner)?;
157            }
158            TLExpr::Until { before, after }
159            | TLExpr::Release {
160                released: before,
161                releaser: after,
162            }
163            | TLExpr::WeakUntil { before, after }
164            | TLExpr::StrongRelease {
165                released: before,
166                releaser: after,
167            } => {
168                self.collect_domains_from_expr(before)?;
169                self.collect_domains_from_expr(after)?;
170            }
171            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
172                self.collect_domains_from_expr(left)?;
173                self.collect_domains_from_expr(right)?;
174            }
175            TLExpr::FuzzyNot { expr, .. } => {
176                self.collect_domains_from_expr(expr)?;
177            }
178            TLExpr::FuzzyImplication {
179                premise,
180                conclusion,
181                ..
182            } => {
183                self.collect_domains_from_expr(premise)?;
184                self.collect_domains_from_expr(conclusion)?;
185            }
186            TLExpr::SoftExists {
187                domain, var, body, ..
188            }
189            | TLExpr::SoftForAll {
190                domain, var, body, ..
191            } => {
192                if !self.domains.contains_key(domain) {
193                    self.add_domain(DomainInfo::new(domain.clone(), 0))?;
194                }
195                self.bind_variable(var, domain)?;
196                self.collect_domains_from_expr(body)?;
197            }
198            TLExpr::WeightedRule { rule, .. } => {
199                self.collect_domains_from_expr(rule)?;
200            }
201            TLExpr::ProbabilisticChoice { alternatives } => {
202                for (_prob, expr) in alternatives {
203                    self.collect_domains_from_expr(expr)?;
204                }
205            }
206            TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
207        }
208        Ok(())
209    }
210
211    fn collect_predicates_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
212        match expr {
213            TLExpr::Pred { name, args } => {
214                if !self.predicates.contains_key(name) {
215                    let arg_domains: Vec<String> =
216                        args.iter().map(|_| "Unknown".to_string()).collect();
217                    self.predicates
218                        .insert(name.clone(), PredicateInfo::new(name.clone(), arg_domains));
219                }
220            }
221            TLExpr::And(l, r)
222            | TLExpr::Or(l, r)
223            | TLExpr::Imply(l, r)
224            | TLExpr::Add(l, r)
225            | TLExpr::Sub(l, r)
226            | TLExpr::Mul(l, r)
227            | TLExpr::Div(l, r)
228            | TLExpr::Pow(l, r)
229            | TLExpr::Mod(l, r)
230            | TLExpr::Min(l, r)
231            | TLExpr::Max(l, r)
232            | TLExpr::Eq(l, r)
233            | TLExpr::Lt(l, r)
234            | TLExpr::Gt(l, r)
235            | TLExpr::Lte(l, r)
236            | TLExpr::Gte(l, r) => {
237                self.collect_predicates_from_expr(l)?;
238                self.collect_predicates_from_expr(r)?;
239            }
240            TLExpr::Not(e)
241            | TLExpr::Score(e)
242            | TLExpr::Abs(e)
243            | TLExpr::Floor(e)
244            | TLExpr::Ceil(e)
245            | TLExpr::Round(e)
246            | TLExpr::Sqrt(e)
247            | TLExpr::Exp(e)
248            | TLExpr::Log(e)
249            | TLExpr::Sin(e)
250            | TLExpr::Cos(e)
251            | TLExpr::Tan(e) => {
252                self.collect_predicates_from_expr(e)?;
253            }
254            TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
255                self.collect_predicates_from_expr(body)?;
256            }
257            TLExpr::IfThenElse {
258                condition,
259                then_branch,
260                else_branch,
261            } => {
262                self.collect_predicates_from_expr(condition)?;
263                self.collect_predicates_from_expr(then_branch)?;
264                self.collect_predicates_from_expr(else_branch)?;
265            }
266            TLExpr::Aggregate { body, .. } => {
267                self.collect_predicates_from_expr(body)?;
268            }
269            TLExpr::Let { value, body, .. } => {
270                self.collect_predicates_from_expr(value)?;
271                self.collect_predicates_from_expr(body)?;
272            }
273            // Modal/temporal logic operators (future enhancement)
274            TLExpr::Box(inner)
275            | TLExpr::Diamond(inner)
276            | TLExpr::Next(inner)
277            | TLExpr::Eventually(inner)
278            | TLExpr::Always(inner) => {
279                self.collect_predicates_from_expr(inner)?;
280            }
281            TLExpr::Until { before, after }
282            | TLExpr::Release {
283                released: before,
284                releaser: after,
285            }
286            | TLExpr::WeakUntil { before, after }
287            | TLExpr::StrongRelease {
288                released: before,
289                releaser: after,
290            } => {
291                self.collect_predicates_from_expr(before)?;
292                self.collect_predicates_from_expr(after)?;
293            }
294            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
295                self.collect_predicates_from_expr(left)?;
296                self.collect_predicates_from_expr(right)?;
297            }
298            TLExpr::FuzzyNot { expr, .. } => {
299                self.collect_predicates_from_expr(expr)?;
300            }
301            TLExpr::FuzzyImplication {
302                premise,
303                conclusion,
304                ..
305            } => {
306                self.collect_predicates_from_expr(premise)?;
307                self.collect_predicates_from_expr(conclusion)?;
308            }
309            TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
310                self.collect_predicates_from_expr(body)?;
311            }
312            TLExpr::WeightedRule { rule, .. } => {
313                self.collect_predicates_from_expr(rule)?;
314            }
315            TLExpr::ProbabilisticChoice { alternatives } => {
316                for (_prob, expr) in alternatives {
317                    self.collect_predicates_from_expr(expr)?;
318                }
319            }
320            TLExpr::Constant(_) => {}
321        }
322        Ok(())
323    }
324
325    pub fn to_json(&self) -> Result<String> {
326        Ok(serde_json::to_string_pretty(self)?)
327    }
328
329    pub fn from_json(json: &str) -> Result<Self> {
330        Ok(serde_json::from_str(json)?)
331    }
332
333    pub fn to_yaml(&self) -> Result<String> {
334        Ok(serde_yaml::to_string(self)?)
335    }
336
337    pub fn from_yaml(yaml: &str) -> Result<Self> {
338        Ok(serde_yaml::from_str(yaml)?)
339    }
340}
341
342impl Default for SymbolTable {
343    fn default() -> Self {
344        Self::new()
345    }
346}