1use anyhow::{bail, Result};
4use indexmap::IndexMap;
5use serde::{Deserialize, Serialize};
6use tensorlogic_ir::TLExpr;
7
8use crate::error::AdapterError;
9use crate::{DomainInfo, PredicateInfo};
10
/// Registry of the domains, predicates and variable bindings known to the adapter.
///
/// All three tables are `IndexMap`s, so iteration (and therefore the JSON/YAML
/// produced by serialization) follows insertion order.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SymbolTable {
    /// Domains, keyed by `DomainInfo::name`.
    pub domains: IndexMap<String, DomainInfo>,
    /// Predicates, keyed by `PredicateInfo::name`.
    pub predicates: IndexMap<String, PredicateInfo>,
    /// Variable name -> name of the domain the variable is bound to.
    pub variables: IndexMap<String, String>,
}
18
19impl SymbolTable {
20 pub fn new() -> Self {
21 SymbolTable {
22 domains: IndexMap::new(),
23 predicates: IndexMap::new(),
24 variables: IndexMap::new(),
25 }
26 }
27
28 pub fn add_domain(&mut self, domain: DomainInfo) -> Result<()> {
29 self.domains.insert(domain.name.clone(), domain);
30 Ok(())
31 }
32
33 pub fn add_predicate(&mut self, predicate: PredicateInfo) -> Result<()> {
34 for domain in &predicate.arg_domains {
35 if !self.domains.contains_key(domain) {
36 bail!(
37 "Domain '{}' referenced by predicate '{}' does not exist",
38 domain,
39 predicate.name
40 );
41 }
42 }
43 self.predicates.insert(predicate.name.clone(), predicate);
44 Ok(())
45 }
46
47 pub fn bind_variable(
48 &mut self,
49 var: impl Into<String>,
50 domain: impl Into<String>,
51 ) -> Result<()> {
52 let var = var.into();
53 let domain = domain.into();
54
55 if !self.domains.contains_key(&domain) {
56 return Err(AdapterError::DomainNotFound(domain).into());
57 }
58
59 self.variables.insert(var, domain);
60 Ok(())
61 }
62
63 pub fn get_domain(&self, name: &str) -> Option<&DomainInfo> {
64 self.domains.get(name)
65 }
66
67 pub fn get_predicate(&self, name: &str) -> Option<&PredicateInfo> {
68 self.predicates.get(name)
69 }
70
71 pub fn get_variable_domain(&self, var: &str) -> Option<&str> {
72 self.variables.get(var).map(|s| s.as_str())
73 }
74
75 pub fn infer_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
76 self.collect_domains_from_expr(expr)?;
77 self.collect_predicates_from_expr(expr)?;
78 Ok(())
79 }
80
81 fn collect_domains_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
82 match expr {
83 TLExpr::Exists { domain, var, body } | TLExpr::ForAll { domain, var, body } => {
84 if !self.domains.contains_key(domain) {
85 self.add_domain(DomainInfo::new(domain.clone(), 0))?;
86 }
87 self.bind_variable(var, domain)?;
88 self.collect_domains_from_expr(body)?;
89 }
90 TLExpr::And(l, r)
91 | TLExpr::Or(l, r)
92 | TLExpr::Imply(l, r)
93 | TLExpr::Add(l, r)
94 | TLExpr::Sub(l, r)
95 | TLExpr::Mul(l, r)
96 | TLExpr::Div(l, r)
97 | TLExpr::Pow(l, r)
98 | TLExpr::Mod(l, r)
99 | TLExpr::Min(l, r)
100 | TLExpr::Max(l, r)
101 | TLExpr::Eq(l, r)
102 | TLExpr::Lt(l, r)
103 | TLExpr::Gt(l, r)
104 | TLExpr::Lte(l, r)
105 | TLExpr::Gte(l, r) => {
106 self.collect_domains_from_expr(l)?;
107 self.collect_domains_from_expr(r)?;
108 }
109 TLExpr::Not(e)
110 | TLExpr::Score(e)
111 | TLExpr::Abs(e)
112 | TLExpr::Floor(e)
113 | TLExpr::Ceil(e)
114 | TLExpr::Round(e)
115 | TLExpr::Sqrt(e)
116 | TLExpr::Exp(e)
117 | TLExpr::Log(e)
118 | TLExpr::Sin(e)
119 | TLExpr::Cos(e)
120 | TLExpr::Tan(e) => {
121 self.collect_domains_from_expr(e)?;
122 }
123 TLExpr::IfThenElse {
124 condition,
125 then_branch,
126 else_branch,
127 } => {
128 self.collect_domains_from_expr(condition)?;
129 self.collect_domains_from_expr(then_branch)?;
130 self.collect_domains_from_expr(else_branch)?;
131 }
132 TLExpr::Aggregate {
133 domain, var, body, ..
134 } => {
135 if !self.domains.contains_key(domain) {
136 self.add_domain(DomainInfo::new(domain.clone(), 0))?;
137 }
138 self.bind_variable(var, domain)?;
139 self.collect_domains_from_expr(body)?;
140 }
141 TLExpr::Let {
142 var: _,
143 value,
144 body,
145 } => {
146 self.collect_domains_from_expr(value)?;
147 self.collect_domains_from_expr(body)?;
149 }
150 TLExpr::Box(inner)
152 | TLExpr::Diamond(inner)
153 | TLExpr::Next(inner)
154 | TLExpr::Eventually(inner)
155 | TLExpr::Always(inner) => {
156 self.collect_domains_from_expr(inner)?;
157 }
158 TLExpr::Until { before, after }
159 | TLExpr::Release {
160 released: before,
161 releaser: after,
162 }
163 | TLExpr::WeakUntil { before, after }
164 | TLExpr::StrongRelease {
165 released: before,
166 releaser: after,
167 } => {
168 self.collect_domains_from_expr(before)?;
169 self.collect_domains_from_expr(after)?;
170 }
171 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
172 self.collect_domains_from_expr(left)?;
173 self.collect_domains_from_expr(right)?;
174 }
175 TLExpr::FuzzyNot { expr, .. } => {
176 self.collect_domains_from_expr(expr)?;
177 }
178 TLExpr::FuzzyImplication {
179 premise,
180 conclusion,
181 ..
182 } => {
183 self.collect_domains_from_expr(premise)?;
184 self.collect_domains_from_expr(conclusion)?;
185 }
186 TLExpr::SoftExists {
187 domain, var, body, ..
188 }
189 | TLExpr::SoftForAll {
190 domain, var, body, ..
191 } => {
192 if !self.domains.contains_key(domain) {
193 self.add_domain(DomainInfo::new(domain.clone(), 0))?;
194 }
195 self.bind_variable(var, domain)?;
196 self.collect_domains_from_expr(body)?;
197 }
198 TLExpr::WeightedRule { rule, .. } => {
199 self.collect_domains_from_expr(rule)?;
200 }
201 TLExpr::ProbabilisticChoice { alternatives } => {
202 for (_prob, expr) in alternatives {
203 self.collect_domains_from_expr(expr)?;
204 }
205 }
206 TLExpr::CountingExists {
208 domain, var, body, ..
209 }
210 | TLExpr::CountingForAll {
211 domain, var, body, ..
212 }
213 | TLExpr::ExactCount {
214 domain, var, body, ..
215 }
216 | TLExpr::Majority { domain, var, body } => {
217 if !self.domains.contains_key(domain) {
218 self.add_domain(DomainInfo::new(domain.clone(), 0))?;
219 }
220 self.bind_variable(var, domain)?;
221 self.collect_domains_from_expr(body)?;
222 }
223 TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
224 _ => {
226 }
229 }
230 Ok(())
231 }
232
233 fn collect_predicates_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
234 match expr {
235 TLExpr::Pred { name, args } => {
236 if !self.predicates.contains_key(name) {
237 let arg_domains: Vec<String> =
238 args.iter().map(|_| "Unknown".to_string()).collect();
239 self.predicates
240 .insert(name.clone(), PredicateInfo::new(name.clone(), arg_domains));
241 }
242 }
243 TLExpr::And(l, r)
244 | TLExpr::Or(l, r)
245 | TLExpr::Imply(l, r)
246 | TLExpr::Add(l, r)
247 | TLExpr::Sub(l, r)
248 | TLExpr::Mul(l, r)
249 | TLExpr::Div(l, r)
250 | TLExpr::Pow(l, r)
251 | TLExpr::Mod(l, r)
252 | TLExpr::Min(l, r)
253 | TLExpr::Max(l, r)
254 | TLExpr::Eq(l, r)
255 | TLExpr::Lt(l, r)
256 | TLExpr::Gt(l, r)
257 | TLExpr::Lte(l, r)
258 | TLExpr::Gte(l, r) => {
259 self.collect_predicates_from_expr(l)?;
260 self.collect_predicates_from_expr(r)?;
261 }
262 TLExpr::Not(e)
263 | TLExpr::Score(e)
264 | TLExpr::Abs(e)
265 | TLExpr::Floor(e)
266 | TLExpr::Ceil(e)
267 | TLExpr::Round(e)
268 | TLExpr::Sqrt(e)
269 | TLExpr::Exp(e)
270 | TLExpr::Log(e)
271 | TLExpr::Sin(e)
272 | TLExpr::Cos(e)
273 | TLExpr::Tan(e) => {
274 self.collect_predicates_from_expr(e)?;
275 }
276 TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
277 self.collect_predicates_from_expr(body)?;
278 }
279 TLExpr::IfThenElse {
280 condition,
281 then_branch,
282 else_branch,
283 } => {
284 self.collect_predicates_from_expr(condition)?;
285 self.collect_predicates_from_expr(then_branch)?;
286 self.collect_predicates_from_expr(else_branch)?;
287 }
288 TLExpr::Aggregate { body, .. } => {
289 self.collect_predicates_from_expr(body)?;
290 }
291 TLExpr::Let { value, body, .. } => {
292 self.collect_predicates_from_expr(value)?;
293 self.collect_predicates_from_expr(body)?;
294 }
295 TLExpr::Box(inner)
297 | TLExpr::Diamond(inner)
298 | TLExpr::Next(inner)
299 | TLExpr::Eventually(inner)
300 | TLExpr::Always(inner) => {
301 self.collect_predicates_from_expr(inner)?;
302 }
303 TLExpr::Until { before, after }
304 | TLExpr::Release {
305 released: before,
306 releaser: after,
307 }
308 | TLExpr::WeakUntil { before, after }
309 | TLExpr::StrongRelease {
310 released: before,
311 releaser: after,
312 } => {
313 self.collect_predicates_from_expr(before)?;
314 self.collect_predicates_from_expr(after)?;
315 }
316 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
317 self.collect_predicates_from_expr(left)?;
318 self.collect_predicates_from_expr(right)?;
319 }
320 TLExpr::FuzzyNot { expr, .. } => {
321 self.collect_predicates_from_expr(expr)?;
322 }
323 TLExpr::FuzzyImplication {
324 premise,
325 conclusion,
326 ..
327 } => {
328 self.collect_predicates_from_expr(premise)?;
329 self.collect_predicates_from_expr(conclusion)?;
330 }
331 TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
332 self.collect_predicates_from_expr(body)?;
333 }
334 TLExpr::WeightedRule { rule, .. } => {
335 self.collect_predicates_from_expr(rule)?;
336 }
337 TLExpr::ProbabilisticChoice { alternatives } => {
338 for (_prob, expr) in alternatives {
339 self.collect_predicates_from_expr(expr)?;
340 }
341 }
342 TLExpr::CountingExists { body, .. }
344 | TLExpr::CountingForAll { body, .. }
345 | TLExpr::ExactCount { body, .. }
346 | TLExpr::Majority { body, .. } => {
347 self.collect_predicates_from_expr(body)?;
348 }
349 TLExpr::Constant(_) => {}
350 _ => {
352 }
355 }
356 Ok(())
357 }
358
359 pub fn to_json(&self) -> Result<String> {
360 Ok(serde_json::to_string_pretty(self)?)
361 }
362
363 pub fn from_json(json: &str) -> Result<Self> {
364 Ok(serde_json::from_str(json)?)
365 }
366
367 pub fn to_yaml(&self) -> Result<String> {
368 Ok(serde_yaml::to_string(self)?)
369 }
370
371 pub fn from_yaml(yaml: &str) -> Result<Self> {
372 Ok(serde_yaml::from_str(yaml)?)
373 }
374}
375
376impl Default for SymbolTable {
377 fn default() -> Self {
378 Self::new()
379 }
380}