1use anyhow::{bail, Result};
4use indexmap::IndexMap;
5use serde::{Deserialize, Serialize};
6use tensorlogic_ir::TLExpr;
7
8use crate::error::AdapterError;
9use crate::{DomainInfo, PredicateInfo};
10
11#[derive(Clone, Debug, Serialize, Deserialize)]
13pub struct SymbolTable {
14 pub domains: IndexMap<String, DomainInfo>,
15 pub predicates: IndexMap<String, PredicateInfo>,
16 pub variables: IndexMap<String, String>,
17}
18
19impl SymbolTable {
20 pub fn new() -> Self {
21 SymbolTable {
22 domains: IndexMap::new(),
23 predicates: IndexMap::new(),
24 variables: IndexMap::new(),
25 }
26 }
27
28 pub fn add_domain(&mut self, domain: DomainInfo) -> Result<()> {
29 self.domains.insert(domain.name.clone(), domain);
30 Ok(())
31 }
32
33 pub fn add_predicate(&mut self, predicate: PredicateInfo) -> Result<()> {
34 for domain in &predicate.arg_domains {
35 if !self.domains.contains_key(domain) {
36 bail!(
37 "Domain '{}' referenced by predicate '{}' does not exist",
38 domain,
39 predicate.name
40 );
41 }
42 }
43 self.predicates.insert(predicate.name.clone(), predicate);
44 Ok(())
45 }
46
47 pub fn bind_variable(
48 &mut self,
49 var: impl Into<String>,
50 domain: impl Into<String>,
51 ) -> Result<()> {
52 let var = var.into();
53 let domain = domain.into();
54
55 if !self.domains.contains_key(&domain) {
56 return Err(AdapterError::DomainNotFound(domain).into());
57 }
58
59 self.variables.insert(var, domain);
60 Ok(())
61 }
62
63 pub fn get_domain(&self, name: &str) -> Option<&DomainInfo> {
64 self.domains.get(name)
65 }
66
67 pub fn get_predicate(&self, name: &str) -> Option<&PredicateInfo> {
68 self.predicates.get(name)
69 }
70
71 pub fn get_variable_domain(&self, var: &str) -> Option<&str> {
72 self.variables.get(var).map(|s| s.as_str())
73 }
74
75 pub fn infer_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
76 self.collect_domains_from_expr(expr)?;
77 self.collect_predicates_from_expr(expr)?;
78 Ok(())
79 }
80
81 fn collect_domains_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
82 match expr {
83 TLExpr::Exists { domain, var, body } | TLExpr::ForAll { domain, var, body } => {
84 if !self.domains.contains_key(domain) {
85 self.add_domain(DomainInfo::new(domain.clone(), 0))?;
86 }
87 self.bind_variable(var, domain)?;
88 self.collect_domains_from_expr(body)?;
89 }
90 TLExpr::And(l, r)
91 | TLExpr::Or(l, r)
92 | TLExpr::Imply(l, r)
93 | TLExpr::Add(l, r)
94 | TLExpr::Sub(l, r)
95 | TLExpr::Mul(l, r)
96 | TLExpr::Div(l, r)
97 | TLExpr::Pow(l, r)
98 | TLExpr::Mod(l, r)
99 | TLExpr::Min(l, r)
100 | TLExpr::Max(l, r)
101 | TLExpr::Eq(l, r)
102 | TLExpr::Lt(l, r)
103 | TLExpr::Gt(l, r)
104 | TLExpr::Lte(l, r)
105 | TLExpr::Gte(l, r) => {
106 self.collect_domains_from_expr(l)?;
107 self.collect_domains_from_expr(r)?;
108 }
109 TLExpr::Not(e)
110 | TLExpr::Score(e)
111 | TLExpr::Abs(e)
112 | TLExpr::Floor(e)
113 | TLExpr::Ceil(e)
114 | TLExpr::Round(e)
115 | TLExpr::Sqrt(e)
116 | TLExpr::Exp(e)
117 | TLExpr::Log(e)
118 | TLExpr::Sin(e)
119 | TLExpr::Cos(e)
120 | TLExpr::Tan(e) => {
121 self.collect_domains_from_expr(e)?;
122 }
123 TLExpr::IfThenElse {
124 condition,
125 then_branch,
126 else_branch,
127 } => {
128 self.collect_domains_from_expr(condition)?;
129 self.collect_domains_from_expr(then_branch)?;
130 self.collect_domains_from_expr(else_branch)?;
131 }
132 TLExpr::Aggregate {
133 domain, var, body, ..
134 } => {
135 if !self.domains.contains_key(domain) {
136 self.add_domain(DomainInfo::new(domain.clone(), 0))?;
137 }
138 self.bind_variable(var, domain)?;
139 self.collect_domains_from_expr(body)?;
140 }
141 TLExpr::Let {
142 var: _,
143 value,
144 body,
145 } => {
146 self.collect_domains_from_expr(value)?;
147 self.collect_domains_from_expr(body)?;
149 }
150 TLExpr::Box(inner)
152 | TLExpr::Diamond(inner)
153 | TLExpr::Next(inner)
154 | TLExpr::Eventually(inner)
155 | TLExpr::Always(inner) => {
156 self.collect_domains_from_expr(inner)?;
157 }
158 TLExpr::Until { before, after }
159 | TLExpr::Release {
160 released: before,
161 releaser: after,
162 }
163 | TLExpr::WeakUntil { before, after }
164 | TLExpr::StrongRelease {
165 released: before,
166 releaser: after,
167 } => {
168 self.collect_domains_from_expr(before)?;
169 self.collect_domains_from_expr(after)?;
170 }
171 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
172 self.collect_domains_from_expr(left)?;
173 self.collect_domains_from_expr(right)?;
174 }
175 TLExpr::FuzzyNot { expr, .. } => {
176 self.collect_domains_from_expr(expr)?;
177 }
178 TLExpr::FuzzyImplication {
179 premise,
180 conclusion,
181 ..
182 } => {
183 self.collect_domains_from_expr(premise)?;
184 self.collect_domains_from_expr(conclusion)?;
185 }
186 TLExpr::SoftExists {
187 domain, var, body, ..
188 }
189 | TLExpr::SoftForAll {
190 domain, var, body, ..
191 } => {
192 if !self.domains.contains_key(domain) {
193 self.add_domain(DomainInfo::new(domain.clone(), 0))?;
194 }
195 self.bind_variable(var, domain)?;
196 self.collect_domains_from_expr(body)?;
197 }
198 TLExpr::WeightedRule { rule, .. } => {
199 self.collect_domains_from_expr(rule)?;
200 }
201 TLExpr::ProbabilisticChoice { alternatives } => {
202 for (_prob, expr) in alternatives {
203 self.collect_domains_from_expr(expr)?;
204 }
205 }
206 TLExpr::Pred { .. } | TLExpr::Constant(_) => {}
207 }
208 Ok(())
209 }
210
211 fn collect_predicates_from_expr(&mut self, expr: &TLExpr) -> Result<()> {
212 match expr {
213 TLExpr::Pred { name, args } => {
214 if !self.predicates.contains_key(name) {
215 let arg_domains: Vec<String> =
216 args.iter().map(|_| "Unknown".to_string()).collect();
217 self.predicates
218 .insert(name.clone(), PredicateInfo::new(name.clone(), arg_domains));
219 }
220 }
221 TLExpr::And(l, r)
222 | TLExpr::Or(l, r)
223 | TLExpr::Imply(l, r)
224 | TLExpr::Add(l, r)
225 | TLExpr::Sub(l, r)
226 | TLExpr::Mul(l, r)
227 | TLExpr::Div(l, r)
228 | TLExpr::Pow(l, r)
229 | TLExpr::Mod(l, r)
230 | TLExpr::Min(l, r)
231 | TLExpr::Max(l, r)
232 | TLExpr::Eq(l, r)
233 | TLExpr::Lt(l, r)
234 | TLExpr::Gt(l, r)
235 | TLExpr::Lte(l, r)
236 | TLExpr::Gte(l, r) => {
237 self.collect_predicates_from_expr(l)?;
238 self.collect_predicates_from_expr(r)?;
239 }
240 TLExpr::Not(e)
241 | TLExpr::Score(e)
242 | TLExpr::Abs(e)
243 | TLExpr::Floor(e)
244 | TLExpr::Ceil(e)
245 | TLExpr::Round(e)
246 | TLExpr::Sqrt(e)
247 | TLExpr::Exp(e)
248 | TLExpr::Log(e)
249 | TLExpr::Sin(e)
250 | TLExpr::Cos(e)
251 | TLExpr::Tan(e) => {
252 self.collect_predicates_from_expr(e)?;
253 }
254 TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
255 self.collect_predicates_from_expr(body)?;
256 }
257 TLExpr::IfThenElse {
258 condition,
259 then_branch,
260 else_branch,
261 } => {
262 self.collect_predicates_from_expr(condition)?;
263 self.collect_predicates_from_expr(then_branch)?;
264 self.collect_predicates_from_expr(else_branch)?;
265 }
266 TLExpr::Aggregate { body, .. } => {
267 self.collect_predicates_from_expr(body)?;
268 }
269 TLExpr::Let { value, body, .. } => {
270 self.collect_predicates_from_expr(value)?;
271 self.collect_predicates_from_expr(body)?;
272 }
273 TLExpr::Box(inner)
275 | TLExpr::Diamond(inner)
276 | TLExpr::Next(inner)
277 | TLExpr::Eventually(inner)
278 | TLExpr::Always(inner) => {
279 self.collect_predicates_from_expr(inner)?;
280 }
281 TLExpr::Until { before, after }
282 | TLExpr::Release {
283 released: before,
284 releaser: after,
285 }
286 | TLExpr::WeakUntil { before, after }
287 | TLExpr::StrongRelease {
288 released: before,
289 releaser: after,
290 } => {
291 self.collect_predicates_from_expr(before)?;
292 self.collect_predicates_from_expr(after)?;
293 }
294 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
295 self.collect_predicates_from_expr(left)?;
296 self.collect_predicates_from_expr(right)?;
297 }
298 TLExpr::FuzzyNot { expr, .. } => {
299 self.collect_predicates_from_expr(expr)?;
300 }
301 TLExpr::FuzzyImplication {
302 premise,
303 conclusion,
304 ..
305 } => {
306 self.collect_predicates_from_expr(premise)?;
307 self.collect_predicates_from_expr(conclusion)?;
308 }
309 TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
310 self.collect_predicates_from_expr(body)?;
311 }
312 TLExpr::WeightedRule { rule, .. } => {
313 self.collect_predicates_from_expr(rule)?;
314 }
315 TLExpr::ProbabilisticChoice { alternatives } => {
316 for (_prob, expr) in alternatives {
317 self.collect_predicates_from_expr(expr)?;
318 }
319 }
320 TLExpr::Constant(_) => {}
321 }
322 Ok(())
323 }
324
325 pub fn to_json(&self) -> Result<String> {
326 Ok(serde_json::to_string_pretty(self)?)
327 }
328
329 pub fn from_json(json: &str) -> Result<Self> {
330 Ok(serde_json::from_str(json)?)
331 }
332
333 pub fn to_yaml(&self) -> Result<String> {
334 Ok(serde_yaml::to_string(self)?)
335 }
336
337 pub fn from_yaml(yaml: &str) -> Result<Self> {
338 Ok(serde_yaml::from_str(yaml)?)
339 }
340}
341
342impl Default for SymbolTable {
343 fn default() -> Self {
344 Self::new()
345 }
346}