1use std::collections::HashMap;
4
5use tensorlogic_ir::TLExpr;
6
7#[derive(Debug, Clone)]
9pub struct CseResult {
10 pub optimized_expr: TLExpr,
11 pub eliminated_count: usize,
12}
13
14pub fn eliminate_common_subexpressions(expr: &TLExpr) -> CseResult {
16 let mut cache: HashMap<String, TLExpr> = HashMap::new();
17 let mut eliminated_count = 0;
18
19 let optimized = cse_recursive(expr, &mut cache, &mut eliminated_count);
20
21 CseResult {
22 optimized_expr: optimized,
23 eliminated_count,
24 }
25}
26
27fn cse_recursive(
28 expr: &TLExpr,
29 cache: &mut HashMap<String, TLExpr>,
30 eliminated_count: &mut usize,
31) -> TLExpr {
32 let key = expr_to_key(expr);
34
35 if let Some(cached) = cache.get(&key) {
37 *eliminated_count += 1;
38 return cached.clone();
39 }
40
41 let result = match expr {
43 TLExpr::Pred { .. } => {
44 expr.clone()
46 }
47 TLExpr::And(left, right) => {
48 let left_opt = cse_recursive(left, cache, eliminated_count);
49 let right_opt = cse_recursive(right, cache, eliminated_count);
50 TLExpr::and(left_opt, right_opt)
51 }
52 TLExpr::Or(left, right) => {
53 let left_opt = cse_recursive(left, cache, eliminated_count);
54 let right_opt = cse_recursive(right, cache, eliminated_count);
55 TLExpr::or(left_opt, right_opt)
56 }
57 TLExpr::Imply(premise, conclusion) => {
58 let premise_opt = cse_recursive(premise, cache, eliminated_count);
59 let conclusion_opt = cse_recursive(conclusion, cache, eliminated_count);
60 TLExpr::imply(premise_opt, conclusion_opt)
61 }
62 TLExpr::Not(inner) => {
63 let inner_opt = cse_recursive(inner, cache, eliminated_count);
64 TLExpr::negate(inner_opt)
65 }
66 TLExpr::Exists { var, domain, body } => {
67 let body_opt = cse_recursive(body, cache, eliminated_count);
68 TLExpr::exists(var, domain, body_opt)
69 }
70 TLExpr::ForAll { var, domain, body } => {
71 let body_opt = cse_recursive(body, cache, eliminated_count);
72 TLExpr::forall(var, domain, body_opt)
73 }
74 TLExpr::Aggregate {
75 op,
76 var,
77 domain,
78 body,
79 group_by,
80 } => {
81 let body_opt = cse_recursive(body, cache, eliminated_count);
82 TLExpr::aggregate_with_group_by(
83 op.clone(),
84 var,
85 domain,
86 body_opt,
87 group_by.clone().unwrap_or_default(),
88 )
89 }
90 TLExpr::Score(inner) => {
91 let inner_opt = cse_recursive(inner, cache, eliminated_count);
92 TLExpr::score(inner_opt)
93 }
94 TLExpr::Add(left, right) => {
96 let left_opt = cse_recursive(left, cache, eliminated_count);
97 let right_opt = cse_recursive(right, cache, eliminated_count);
98 TLExpr::add(left_opt, right_opt)
99 }
100 TLExpr::Sub(left, right) => {
101 let left_opt = cse_recursive(left, cache, eliminated_count);
102 let right_opt = cse_recursive(right, cache, eliminated_count);
103 TLExpr::sub(left_opt, right_opt)
104 }
105 TLExpr::Mul(left, right) => {
106 let left_opt = cse_recursive(left, cache, eliminated_count);
107 let right_opt = cse_recursive(right, cache, eliminated_count);
108 TLExpr::mul(left_opt, right_opt)
109 }
110 TLExpr::Div(left, right) => {
111 let left_opt = cse_recursive(left, cache, eliminated_count);
112 let right_opt = cse_recursive(right, cache, eliminated_count);
113 TLExpr::div(left_opt, right_opt)
114 }
115 TLExpr::Eq(left, right) => {
117 let left_opt = cse_recursive(left, cache, eliminated_count);
118 let right_opt = cse_recursive(right, cache, eliminated_count);
119 TLExpr::eq(left_opt, right_opt)
120 }
121 TLExpr::Lt(left, right) => {
122 let left_opt = cse_recursive(left, cache, eliminated_count);
123 let right_opt = cse_recursive(right, cache, eliminated_count);
124 TLExpr::lt(left_opt, right_opt)
125 }
126 TLExpr::Gt(left, right) => {
127 let left_opt = cse_recursive(left, cache, eliminated_count);
128 let right_opt = cse_recursive(right, cache, eliminated_count);
129 TLExpr::gt(left_opt, right_opt)
130 }
131 TLExpr::Lte(left, right) => {
132 let left_opt = cse_recursive(left, cache, eliminated_count);
133 let right_opt = cse_recursive(right, cache, eliminated_count);
134 TLExpr::lte(left_opt, right_opt)
135 }
136 TLExpr::Gte(left, right) => {
137 let left_opt = cse_recursive(left, cache, eliminated_count);
138 let right_opt = cse_recursive(right, cache, eliminated_count);
139 TLExpr::gte(left_opt, right_opt)
140 }
141 TLExpr::Pow(left, right) => {
142 let left_opt = cse_recursive(left, cache, eliminated_count);
143 let right_opt = cse_recursive(right, cache, eliminated_count);
144 TLExpr::pow(left_opt, right_opt)
145 }
146 TLExpr::Mod(left, right) => {
147 let left_opt = cse_recursive(left, cache, eliminated_count);
148 let right_opt = cse_recursive(right, cache, eliminated_count);
149 TLExpr::modulo(left_opt, right_opt)
150 }
151 TLExpr::Min(left, right) => {
152 let left_opt = cse_recursive(left, cache, eliminated_count);
153 let right_opt = cse_recursive(right, cache, eliminated_count);
154 TLExpr::min(left_opt, right_opt)
155 }
156 TLExpr::Max(left, right) => {
157 let left_opt = cse_recursive(left, cache, eliminated_count);
158 let right_opt = cse_recursive(right, cache, eliminated_count);
159 TLExpr::max(left_opt, right_opt)
160 }
161 TLExpr::Abs(inner) => {
163 let inner_opt = cse_recursive(inner, cache, eliminated_count);
164 TLExpr::abs(inner_opt)
165 }
166 TLExpr::Floor(inner) => {
167 let inner_opt = cse_recursive(inner, cache, eliminated_count);
168 TLExpr::floor(inner_opt)
169 }
170 TLExpr::Ceil(inner) => {
171 let inner_opt = cse_recursive(inner, cache, eliminated_count);
172 TLExpr::ceil(inner_opt)
173 }
174 TLExpr::Round(inner) => {
175 let inner_opt = cse_recursive(inner, cache, eliminated_count);
176 TLExpr::round(inner_opt)
177 }
178 TLExpr::Sqrt(inner) => {
179 let inner_opt = cse_recursive(inner, cache, eliminated_count);
180 TLExpr::sqrt(inner_opt)
181 }
182 TLExpr::Exp(inner) => {
183 let inner_opt = cse_recursive(inner, cache, eliminated_count);
184 TLExpr::exp(inner_opt)
185 }
186 TLExpr::Log(inner) => {
187 let inner_opt = cse_recursive(inner, cache, eliminated_count);
188 TLExpr::log(inner_opt)
189 }
190 TLExpr::Sin(inner) => {
191 let inner_opt = cse_recursive(inner, cache, eliminated_count);
192 TLExpr::sin(inner_opt)
193 }
194 TLExpr::Cos(inner) => {
195 let inner_opt = cse_recursive(inner, cache, eliminated_count);
196 TLExpr::cos(inner_opt)
197 }
198 TLExpr::Tan(inner) => {
199 let inner_opt = cse_recursive(inner, cache, eliminated_count);
200 TLExpr::tan(inner_opt)
201 }
202 TLExpr::Let { var, value, body } => {
204 let value_opt = cse_recursive(value, cache, eliminated_count);
205 let body_opt = cse_recursive(body, cache, eliminated_count);
206 TLExpr::let_binding(var, value_opt, body_opt)
207 }
208 TLExpr::IfThenElse {
210 condition,
211 then_branch,
212 else_branch,
213 } => {
214 let cond_opt = cse_recursive(condition, cache, eliminated_count);
215 let then_opt = cse_recursive(then_branch, cache, eliminated_count);
216 let else_opt = cse_recursive(else_branch, cache, eliminated_count);
217 TLExpr::if_then_else(cond_opt, then_opt, else_opt)
218 }
219 TLExpr::Constant(_) => {
221 expr.clone()
223 }
224
225 TLExpr::Box(inner) => {
227 let inner_opt = cse_recursive(inner, cache, eliminated_count);
228 TLExpr::Box(Box::new(inner_opt))
229 }
230 TLExpr::Diamond(inner) => {
231 let inner_opt = cse_recursive(inner, cache, eliminated_count);
232 TLExpr::Diamond(Box::new(inner_opt))
233 }
234 TLExpr::Next(inner) => {
235 let inner_opt = cse_recursive(inner, cache, eliminated_count);
236 TLExpr::Next(Box::new(inner_opt))
237 }
238 TLExpr::Eventually(inner) => {
239 let inner_opt = cse_recursive(inner, cache, eliminated_count);
240 TLExpr::Eventually(Box::new(inner_opt))
241 }
242 TLExpr::Always(inner) => {
243 let inner_opt = cse_recursive(inner, cache, eliminated_count);
244 TLExpr::Always(Box::new(inner_opt))
245 }
246 TLExpr::Until { before, after } => {
247 let before_opt = cse_recursive(before, cache, eliminated_count);
248 let after_opt = cse_recursive(after, cache, eliminated_count);
249 TLExpr::Until {
250 before: Box::new(before_opt),
251 after: Box::new(after_opt),
252 }
253 }
254 TLExpr::TNorm { kind, left, right } => {
256 let left_opt = cse_recursive(left, cache, eliminated_count);
257 let right_opt = cse_recursive(right, cache, eliminated_count);
258 TLExpr::TNorm {
259 kind: *kind,
260 left: Box::new(left_opt),
261 right: Box::new(right_opt),
262 }
263 }
264 TLExpr::TCoNorm { kind, left, right } => {
265 let left_opt = cse_recursive(left, cache, eliminated_count);
266 let right_opt = cse_recursive(right, cache, eliminated_count);
267 TLExpr::TCoNorm {
268 kind: *kind,
269 left: Box::new(left_opt),
270 right: Box::new(right_opt),
271 }
272 }
273 TLExpr::FuzzyNot { kind, expr: inner } => {
274 let inner_opt = cse_recursive(inner, cache, eliminated_count);
275 TLExpr::FuzzyNot {
276 kind: *kind,
277 expr: Box::new(inner_opt),
278 }
279 }
280 TLExpr::FuzzyImplication {
281 kind,
282 premise,
283 conclusion,
284 } => {
285 let premise_opt = cse_recursive(premise, cache, eliminated_count);
286 let conclusion_opt = cse_recursive(conclusion, cache, eliminated_count);
287 TLExpr::FuzzyImplication {
288 kind: *kind,
289 premise: Box::new(premise_opt),
290 conclusion: Box::new(conclusion_opt),
291 }
292 }
293 TLExpr::SoftExists {
295 var,
296 domain,
297 body,
298 temperature,
299 } => {
300 let body_opt = cse_recursive(body, cache, eliminated_count);
301 TLExpr::SoftExists {
302 var: var.clone(),
303 domain: domain.clone(),
304 body: Box::new(body_opt),
305 temperature: *temperature,
306 }
307 }
308 TLExpr::SoftForAll {
309 var,
310 domain,
311 body,
312 temperature,
313 } => {
314 let body_opt = cse_recursive(body, cache, eliminated_count);
315 TLExpr::SoftForAll {
316 var: var.clone(),
317 domain: domain.clone(),
318 body: Box::new(body_opt),
319 temperature: *temperature,
320 }
321 }
322 TLExpr::WeightedRule { weight, rule } => {
324 let rule_opt = cse_recursive(rule, cache, eliminated_count);
325 TLExpr::WeightedRule {
326 weight: *weight,
327 rule: Box::new(rule_opt),
328 }
329 }
330 TLExpr::ProbabilisticChoice { alternatives } => {
331 let alts_opt: Vec<(f64, TLExpr)> = alternatives
332 .iter()
333 .map(|(prob, expr)| (*prob, cse_recursive(expr, cache, eliminated_count)))
334 .collect();
335 TLExpr::ProbabilisticChoice {
336 alternatives: alts_opt,
337 }
338 }
339 TLExpr::Release { released, releaser } => {
341 let released_opt = cse_recursive(released, cache, eliminated_count);
342 let releaser_opt = cse_recursive(releaser, cache, eliminated_count);
343 TLExpr::Release {
344 released: Box::new(released_opt),
345 releaser: Box::new(releaser_opt),
346 }
347 }
348 TLExpr::WeakUntil { before, after } => {
349 let before_opt = cse_recursive(before, cache, eliminated_count);
350 let after_opt = cse_recursive(after, cache, eliminated_count);
351 TLExpr::WeakUntil {
352 before: Box::new(before_opt),
353 after: Box::new(after_opt),
354 }
355 }
356 TLExpr::StrongRelease { released, releaser } => {
357 let released_opt = cse_recursive(released, cache, eliminated_count);
358 let releaser_opt = cse_recursive(releaser, cache, eliminated_count);
359 TLExpr::StrongRelease {
360 released: Box::new(released_opt),
361 releaser: Box::new(releaser_opt),
362 }
363 }
364 _ => expr.clone(),
366 };
367
368 cache.insert(key, result.clone());
370 result
371}
372
373fn expr_to_key(expr: &TLExpr) -> String {
375 format!("{:?}", expr)
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Asserts that the pass returned a tree whose `Debug` rendering matches the input.
    ///
    /// The pass only deduplicates bookkeeping; it must not alter the expression.
    fn assert_structurally_unchanged(input: &TLExpr, result: &CseResult) {
        assert_eq!(
            format!("{:?}", result.optimized_expr),
            format!("{:?}", input),
            "CSE must not change the expression structure"
        );
    }

    #[test]
    fn test_cse_no_duplicates() {
        let expr = TLExpr::and(
            TLExpr::pred("p", vec![Term::var("x")]),
            TLExpr::pred("q", vec![Term::var("y")]),
        );

        let result = eliminate_common_subexpressions(&expr);
        assert_eq!(result.eliminated_count, 0);
        assert_structurally_unchanged(&expr, &result);
    }

    #[test]
    fn test_cse_duplicate_predicates() {
        let p_x = TLExpr::pred("p", vec![Term::var("x")]);
        let expr = TLExpr::and(p_x.clone(), p_x);

        let result = eliminate_common_subexpressions(&expr);
        // Only the second `p(x)` repeats an earlier subtree.
        assert_eq!(result.eliminated_count, 1);
        assert_structurally_unchanged(&expr, &result);
    }

    #[test]
    fn test_cse_counts_each_repeated_occurrence() {
        let p_x = TLExpr::pred("p", vec![Term::var("x")]);
        let expr = TLExpr::and(p_x.clone(), TLExpr::and(p_x.clone(), p_x));

        let result = eliminate_common_subexpressions(&expr);
        // The first `p(x)` populates the cache; each of the two later copies is a hit.
        assert_eq!(result.eliminated_count, 2);
        assert_structurally_unchanged(&expr, &result);
    }

    #[test]
    fn test_cse_nested_duplicates() {
        let p_x = TLExpr::pred("p", vec![Term::var("x")]);
        let q_y = TLExpr::pred("q", vec![Term::var("y")]);
        let sub = TLExpr::and(p_x, q_y);
        let expr = TLExpr::and(sub.clone(), sub);

        let result = eliminate_common_subexpressions(&expr);
        // The whole second `sub` is a single hit; its children are not revisited.
        assert_eq!(result.eliminated_count, 1);
        assert_structurally_unchanged(&expr, &result);
    }

    #[test]
    fn test_cse_with_quantifiers() {
        let p_x = TLExpr::pred("p", vec![Term::var("x")]);
        let exists = TLExpr::exists("x", "Domain", p_x);
        let expr = TLExpr::and(exists.clone(), exists);

        let result = eliminate_common_subexpressions(&expr);
        assert_eq!(result.eliminated_count, 1);
        assert_structurally_unchanged(&expr, &result);
    }

    #[test]
    fn test_cse_preserves_semantics() {
        let p_x = TLExpr::pred("p", vec![Term::var("x")]);
        let q_y = TLExpr::pred("q", vec![Term::var("y")]);
        let expr = TLExpr::and(p_x.clone(), q_y.clone());

        let result = eliminate_common_subexpressions(&expr);
        assert_structurally_unchanged(&expr, &result);

        match result.optimized_expr {
            TLExpr::And(left, right) => {
                assert!(matches!(*left, TLExpr::Pred { .. }));
                assert!(matches!(*right, TLExpr::Pred { .. }));
            }
            _ => panic!("Expected And expression"),
        }
    }

    #[test]
    fn test_cse_complex_expression() {
        let p_x = TLExpr::pred("p", vec![Term::var("x")]);
        let q_y = TLExpr::pred("q", vec![Term::var("y")]);
        let or_expr = TLExpr::or(q_y, p_x.clone());
        let expr = TLExpr::and(p_x, or_expr);

        let result = eliminate_common_subexpressions(&expr);
        // `p(x)` is seen first as the left operand; its copy inside the `or` is the one hit.
        assert_eq!(result.eliminated_count, 1);
        assert_structurally_unchanged(&expr, &result);
    }
}