1use anyhow::{bail, Result};
4use indexmap::IndexMap;
5use serde::{Deserialize, Serialize};
6use tensorlogic_ir::TLExpr;
7
8use crate::error::AdapterError;
9use crate::{DomainInfo, PredicateInfo};
10
11#[derive(Clone, Debug, Serialize, Deserialize)]
13pub struct SymbolTable {
14 pub domains: IndexMap<String, DomainInfo>,
15 pub predicates: IndexMap<String, PredicateInfo>,
16 pub variables: IndexMap<String, String>,
17}
18
19impl SymbolTable {
20 pub fn new() -> Self {
21 SymbolTable {
22 domains: IndexMap::new(),
23 predicates: IndexMap::new(),
24 variables: IndexMap::new(),
25 }
26 }
27
28 pub fn add_domain(&mut self, domain: DomainInfo) -> Result<()> {
29 self.domains.insert(domain.name.clone(), domain);
30 Ok(())
31 }
32
33 pub fn add_predicate(&mut self, predicate: PredicateInfo) -> Result<()> {
34 for domain in &predicate.arg_domains {
35 if !self.domains.contains_key(domain) {
36 bail!(
37 "Domain '{}' referenced by predicate '{}' does not exist",
38 domain,
39 predicate.name
40 );
41 }
42 }
43 self.predicates.insert(predicate.name.clone(), predicate);
44 Ok(())
45 }
46
47 pub fn bind_variable(
48 &mut self,
49 var: impl Into<String>,
50 domain: impl Into<String>,
51 ) -> Result<()> {
52 let var = var.into();
53 let domain = domain.into();
54
55 if !self.domains.contains_key(&domain) {
56 return Err(AdapterError::DomainNotFound(domain).into());
57 }
58
59 self.variables.insert(var, domain);
60 Ok(())
61 }
62
63 pub fn get_domain(&self, name: &str) -> Option<&DomainInfo> {
64 self.domains.get(name)
65 }
66
67 pub fn get_predicate(&self, name: &str) -> Option<&PredicateInfo> {
68 self.predicates.get(name)
69 }
70
71 pub fn get_variable_domain(&self, var: &str) -> Option<&str> {
72 self.variables.get(var).map(|s| s.as_str())
73 }
74
75 pub fn infer_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
76 self.collect_domains_from_expr(expr)?;
77 self.collect_predicates_from_expr(expr)?;
78 Ok(())
79 }
80
81 fn collect_domains_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
82 match expr {
83 TLExpr::Exists { domain, var, body } | TLExpr::ForAll { domain, var, body } => {
84 if !self.domains.contains_key(domain) {
85 self.add_domain(DomainInfo::new(domain.clone(), 0))?;
86 }
87 self.bind_variable(var, domain)?;
88 self.collect_domains_from_expr(body)?;
89 }
90 TLExpr::And(l, r)
91 | TLExpr::Or(l, r)
92 | TLExpr::Imply(l, r)
93 | TLExpr::Add(l, r)
94 | TLExpr::Sub(l, r)
95 | TLExpr::Mul(l, r)
96 | TLExpr::Div(l, r)
97 | TLExpr::Pow(l, r)
98 | TLExpr::Mod(l, r)
99 | TLExpr::Min(l, r)
100 | TLExpr::Max(l, r)
101 | TLExpr::Eq(l, r)
102 | TLExpr::Lt(l, r)
103 | TLExpr::Gt(l, r)
104 | TLExpr::Lte(l, r)
105 | TLExpr::Gte(l, r) => {
106 self.collect_domains_from_expr(l)?;
107 self.collect_domains_from_expr(r)?;
108 }
109 TLExpr::Not(e)
110 | TLExpr::Score(e)
111 | TLExpr::Abs(e)
112 | TLExpr::Floor(e)
113 | TLExpr::Ceil(e)
114 | TLExpr::Round(e)
115 | TLExpr::Sqrt(e)
116 | TLExpr::Exp(e)
117 | TLExpr::Log(e)
118 | TLExpr::Sin(e)
119 | TLExpr::Cos(e)
120 | TLExpr::Tan(e) => {
121 self.collect_domains_from_expr(e)?;
122 }
123 TLExpr::IfThenElse {
124 condition,
125 then_branch,
126 else_branch,
127 } => {
128 self.collect_domains_from_expr(condition)?;
129 self.collect_domains_from_expr(then_branch)?;
130 self.collect_domains_from_expr(else_branch)?;
131 }
132 TLExpr::Aggregate {
133 domain, var, body, ..
134 } => {
135 if !self.domains.contains_key(domain) {
136 self.add_domain(DomainInfo::new(domain.clone(), 0))?;
137 }
138 self.bind_variable(var, domain)?;
139 self.collect_domains_from_expr(body)?;
140 }
141 TLExpr::Let {
142 var: _,
143 value,
144 body,
145 } => {
146 self.collect_domains_from_expr(value)?;
147 self.collect_domains_from_expr(body)?;
149 }
150 TLExpr::Box(inner)
152 | TLExpr::Diamond(inner)
153 | TLExpr::Next(inner)
154 | TLExpr::Eventually(inner)
155 | TLExpr::Always(inner) => {
156 self.collect_domains_from_expr(inner)?;
157 }
158 TLExpr::Until { before, after }
159 | TLExpr::Release {
160 released: before,
161 releaser: after,
162 }
163 | TLExpr::WeakUntil { before, after }
164 | TLExpr::StrongRelease {
165 released: before,
166 releaser: after,
167 } => {
168 self.collect_domains_from_expr(before)?;
169 self.collect_domains_from_expr(after)?;
170 }
171 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
172 self.collect_domains_from_expr(left)?;
173 self.collect_domains_from_expr(right)?;
174 }
175 TLExpr::FuzzyNot { expr, .. } => {
176 self.collect_domains_from_expr(expr)?;
177 }
178 TLExpr::FuzzyImplication {
179 premise,
180 conclusion,
181 ..
182 } => {
183 self.collect_domains_from_expr(premise)?;
184 self.collect_domains_from_expr(conclusion)?;
185 }
186 TLExpr::SoftExists {
187 domain, var, body, ..
188 }
189 | TLExpr::SoftForAll {
190 domain, var, body, ..
191 } => {
192 if !self.domains.contains_key(domain) {
193 self.add_domain(DomainInfo::new(domain.clone(), 0))?;
194 }
195 self.bind_variable(var, domain)?;
196 self.collect_domains_from_expr(body)?;
197 }
198 TLExpr::WeightedRule { rule, .. } => {
199 self.collect_domains_from_expr(rule)?;
200 }
201 TLExpr::ProbabilisticChoice { alternatives } => {
202 for (_prob, expr) in alternatives {
203 self.collect_domains_from_expr(expr)?;
204 }
205 }
206 TLExpr::CountingExists {
208 domain, var, body, ..
209 }
210 | TLExpr::CountingForAll {
211 domain, var, body, ..
212 }
213 | TLExpr::ExactCount {
214 domain, var, body, ..
215 }
216 | TLExpr::Majority { domain, var, body } => {
217 if !self.domains.contains_key(domain) {
218 self.add_domain(DomainInfo::new(domain.clone(), 0))?;
219 }
220 self.bind_variable(var, domain)?;
221 self.collect_domains_from_expr(body)?;
222 }
223 TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
224 _ => {
226 }
229 }
230 Ok(())
231 }
232
233 fn collect_predicates_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
234 match expr {
235 TLExpr::Pred { name, args } if !self.predicates.contains_key(name) => {
236 let arg_domains: Vec<String> = args.iter().map(|_| "Unknown".to_string()).collect();
237 self.predicates
238 .insert(name.clone(), PredicateInfo::new(name.clone(), arg_domains));
239 }
240 TLExpr::And(l, r)
241 | TLExpr::Or(l, r)
242 | TLExpr::Imply(l, r)
243 | TLExpr::Add(l, r)
244 | TLExpr::Sub(l, r)
245 | TLExpr::Mul(l, r)
246 | TLExpr::Div(l, r)
247 | TLExpr::Pow(l, r)
248 | TLExpr::Mod(l, r)
249 | TLExpr::Min(l, r)
250 | TLExpr::Max(l, r)
251 | TLExpr::Eq(l, r)
252 | TLExpr::Lt(l, r)
253 | TLExpr::Gt(l, r)
254 | TLExpr::Lte(l, r)
255 | TLExpr::Gte(l, r) => {
256 self.collect_predicates_from_expr(l)?;
257 self.collect_predicates_from_expr(r)?;
258 }
259 TLExpr::Not(e)
260 | TLExpr::Score(e)
261 | TLExpr::Abs(e)
262 | TLExpr::Floor(e)
263 | TLExpr::Ceil(e)
264 | TLExpr::Round(e)
265 | TLExpr::Sqrt(e)
266 | TLExpr::Exp(e)
267 | TLExpr::Log(e)
268 | TLExpr::Sin(e)
269 | TLExpr::Cos(e)
270 | TLExpr::Tan(e) => {
271 self.collect_predicates_from_expr(e)?;
272 }
273 TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
274 self.collect_predicates_from_expr(body)?;
275 }
276 TLExpr::IfThenElse {
277 condition,
278 then_branch,
279 else_branch,
280 } => {
281 self.collect_predicates_from_expr(condition)?;
282 self.collect_predicates_from_expr(then_branch)?;
283 self.collect_predicates_from_expr(else_branch)?;
284 }
285 TLExpr::Aggregate { body, .. } => {
286 self.collect_predicates_from_expr(body)?;
287 }
288 TLExpr::Let { value, body, .. } => {
289 self.collect_predicates_from_expr(value)?;
290 self.collect_predicates_from_expr(body)?;
291 }
292 TLExpr::Box(inner)
294 | TLExpr::Diamond(inner)
295 | TLExpr::Next(inner)
296 | TLExpr::Eventually(inner)
297 | TLExpr::Always(inner) => {
298 self.collect_predicates_from_expr(inner)?;
299 }
300 TLExpr::Until { before, after }
301 | TLExpr::Release {
302 released: before,
303 releaser: after,
304 }
305 | TLExpr::WeakUntil { before, after }
306 | TLExpr::StrongRelease {
307 released: before,
308 releaser: after,
309 } => {
310 self.collect_predicates_from_expr(before)?;
311 self.collect_predicates_from_expr(after)?;
312 }
313 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
314 self.collect_predicates_from_expr(left)?;
315 self.collect_predicates_from_expr(right)?;
316 }
317 TLExpr::FuzzyNot { expr, .. } => {
318 self.collect_predicates_from_expr(expr)?;
319 }
320 TLExpr::FuzzyImplication {
321 premise,
322 conclusion,
323 ..
324 } => {
325 self.collect_predicates_from_expr(premise)?;
326 self.collect_predicates_from_expr(conclusion)?;
327 }
328 TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
329 self.collect_predicates_from_expr(body)?;
330 }
331 TLExpr::WeightedRule { rule, .. } => {
332 self.collect_predicates_from_expr(rule)?;
333 }
334 TLExpr::ProbabilisticChoice { alternatives } => {
335 for (_prob, expr) in alternatives {
336 self.collect_predicates_from_expr(expr)?;
337 }
338 }
339 TLExpr::CountingExists { body, .. }
341 | TLExpr::CountingForAll { body, .. }
342 | TLExpr::ExactCount { body, .. }
343 | TLExpr::Majority { body, .. } => {
344 self.collect_predicates_from_expr(body)?;
345 }
346 TLExpr::Constant(_) => {}
347 _ => {
349 }
352 }
353 Ok(())
354 }
355
356 pub fn to_json(&self) -> Result<String> {
357 Ok(serde_json::to_string_pretty(self)?)
358 }
359
360 pub fn from_json(json: &str) -> Result<Self> {
361 Ok(serde_json::from_str(json)?)
362 }
363
364 pub fn to_yaml(&self) -> Result<String> {
365 Ok(serde_yaml::to_string(self)?)
366 }
367
368 pub fn from_yaml(yaml: &str) -> Result<Self> {
369 Ok(serde_yaml::from_str(yaml)?)
370 }
371}
372
373impl Default for SymbolTable {
374 fn default() -> Self {
375 Self::new()
376 }
377}