1mod abduction;
4mod arithmetic;
5mod comparison;
6mod conditional;
7mod constraints;
8mod counting_quantifiers;
9pub mod custom_ops;
10mod fixpoint;
11mod fuzzy;
12mod higher_order;
13mod hybrid;
14mod implication;
15mod let_binding;
16mod logic_ops;
17mod modal_temporal;
18mod pattern_match;
19mod predicate;
20mod probabilistic;
21mod quantifiers;
22mod set_operations;
23mod strategy_mapping;
24
25use anyhow::Result;
26use tensorlogic_ir::{EinsumGraph, TLExpr};
27
28use crate::context::{CompileState, CompilerContext};
29
30pub(crate) use abduction::{compile_abducible, compile_explain};
31pub(crate) use arithmetic::{
32 compile_abs, compile_add, compile_ceil, compile_cos, compile_div, compile_exp, compile_floor,
33 compile_log, compile_max_binary, compile_min_binary, compile_mod, compile_mul, compile_pow,
34 compile_round, compile_sin, compile_sqrt, compile_sub, compile_tan,
35};
36pub(crate) use comparison::{compile_eq, compile_gt, compile_gte, compile_lt, compile_lte};
37pub(crate) use conditional::{compile_constant, compile_if_then_else};
38pub(crate) use constraints::{compile_all_different, compile_global_cardinality};
39pub(crate) use counting_quantifiers::{
40 compile_counting_exists, compile_counting_forall, compile_exact_count, compile_majority,
41};
42pub use custom_ops::{
43 CustomOpData, CustomOpHandler, CustomOpMetadata, CustomOpRegistry, ExtendedCompilerContext,
44};
45pub(crate) use fixpoint::{compile_greatest_fixpoint, compile_least_fixpoint};
46pub(crate) use fuzzy::{
47 compile_fuzzy_implication, compile_fuzzy_not, compile_tconorm, compile_tnorm,
48};
49pub(crate) use higher_order::{compile_apply, compile_lambda};
50pub(crate) use hybrid::{compile_at, compile_everywhere, compile_nominal, compile_somewhere};
51pub(crate) use implication::compile_imply;
52pub(crate) use let_binding::compile_let;
53pub(crate) use logic_ops::{compile_and, compile_not, compile_or};
54pub(crate) use modal_temporal::{
55 compile_always, compile_box, compile_diamond, compile_eventually, compile_next,
56 compile_release, compile_strong_release, compile_until, compile_weak_until,
57};
58pub(crate) use pattern_match::{compile_match, compile_symbol_literal};
59pub(crate) use predicate::compile_predicate;
60pub(crate) use probabilistic::{compile_probabilistic_choice, compile_weighted_rule};
61pub(crate) use quantifiers::{
62 compile_aggregate, compile_exists, compile_forall, compile_soft_exists, compile_soft_forall,
63};
64pub(crate) use set_operations::{
65 compile_empty_set, compile_set_cardinality, compile_set_comprehension, compile_set_difference,
66 compile_set_intersection, compile_set_membership, compile_set_union,
67};
68
69pub(crate) fn infer_domain(expr: &TLExpr, _var: &str) -> Option<String> {
71 match expr {
72 TLExpr::Exists { domain, .. }
73 | TLExpr::ForAll { domain, .. }
74 | TLExpr::Aggregate { domain, .. }
75 | TLExpr::SoftExists { domain, .. }
76 | TLExpr::SoftForAll { domain, .. }
77 | TLExpr::SetComprehension { domain, .. }
78 | TLExpr::CountingExists { domain, .. }
79 | TLExpr::CountingForAll { domain, .. }
80 | TLExpr::ExactCount { domain, .. }
81 | TLExpr::Majority { domain, .. } => Some(domain.clone()),
82 TLExpr::Box(_)
84 | TLExpr::Diamond(_)
85 | TLExpr::Next(_)
86 | TLExpr::Eventually(_)
87 | TLExpr::Always(_)
88 | TLExpr::Until { .. }
89 | TLExpr::Release { .. }
90 | TLExpr::WeakUntil { .. }
91 | TLExpr::StrongRelease { .. } => None,
92 _ => None,
93 }
94}
95
96pub(crate) fn compile_expr(
98 expr: &TLExpr,
99 ctx: &mut CompilerContext,
100 graph: &mut EinsumGraph,
101) -> Result<CompileState> {
102 match expr {
103 TLExpr::Pred { name, args } => compile_predicate(name, args, ctx, graph),
104 TLExpr::And(left, right) => compile_and(left, right, ctx, graph),
105 TLExpr::Or(left, right) => compile_or(left, right, ctx, graph),
106 TLExpr::Not(inner) => compile_not(inner, ctx, graph),
107 TLExpr::Exists { var, domain, body } => compile_exists(var, domain, body, ctx, graph),
108 TLExpr::ForAll { var, domain, body } => compile_forall(var, domain, body, ctx, graph),
109 TLExpr::Aggregate {
110 op,
111 var,
112 domain,
113 body,
114 group_by,
115 } => compile_aggregate(op, var, domain, body, group_by, ctx, graph),
116 TLExpr::Imply(premise, conclusion) => compile_imply(premise, conclusion, ctx, graph),
117 TLExpr::Score(inner) => compile_expr(inner, ctx, graph),
118
119 TLExpr::Add(left, right) => compile_add(left, right, ctx, graph),
121 TLExpr::Sub(left, right) => compile_sub(left, right, ctx, graph),
122 TLExpr::Mul(left, right) => compile_mul(left, right, ctx, graph),
123 TLExpr::Div(left, right) => compile_div(left, right, ctx, graph),
124
125 TLExpr::Eq(left, right) => compile_eq(left, right, ctx, graph),
127 TLExpr::Lt(left, right) => compile_lt(left, right, ctx, graph),
128 TLExpr::Gt(left, right) => compile_gt(left, right, ctx, graph),
129 TLExpr::Lte(left, right) => compile_lte(left, right, ctx, graph),
130 TLExpr::Gte(left, right) => compile_gte(left, right, ctx, graph),
131
132 TLExpr::Pow(left, right) => compile_pow(left, right, ctx, graph),
134 TLExpr::Mod(left, right) => compile_mod(left, right, ctx, graph),
135 TLExpr::Min(left, right) => compile_min_binary(left, right, ctx, graph),
136 TLExpr::Max(left, right) => compile_max_binary(left, right, ctx, graph),
137
138 TLExpr::Abs(inner) => compile_abs(inner, ctx, graph),
140 TLExpr::Floor(inner) => compile_floor(inner, ctx, graph),
141 TLExpr::Ceil(inner) => compile_ceil(inner, ctx, graph),
142 TLExpr::Round(inner) => compile_round(inner, ctx, graph),
143 TLExpr::Sqrt(inner) => compile_sqrt(inner, ctx, graph),
144 TLExpr::Exp(inner) => compile_exp(inner, ctx, graph),
145 TLExpr::Log(inner) => compile_log(inner, ctx, graph),
146 TLExpr::Sin(inner) => compile_sin(inner, ctx, graph),
147 TLExpr::Cos(inner) => compile_cos(inner, ctx, graph),
148 TLExpr::Tan(inner) => compile_tan(inner, ctx, graph),
149
150 TLExpr::IfThenElse {
152 condition,
153 then_branch,
154 else_branch,
155 } => compile_if_then_else(condition, then_branch, else_branch, ctx, graph),
156
157 TLExpr::Constant(value) => compile_constant(*value, ctx, graph),
159
160 TLExpr::Let { var, value, body } => compile_let(var, value, body, ctx, graph),
162
163 TLExpr::TNorm { kind, left, right } => compile_tnorm(*kind, left, right, ctx, graph),
165 TLExpr::TCoNorm { kind, left, right } => compile_tconorm(*kind, left, right, ctx, graph),
166 TLExpr::FuzzyNot { kind, expr } => compile_fuzzy_not(*kind, expr, ctx, graph),
167 TLExpr::FuzzyImplication {
168 kind,
169 premise,
170 conclusion,
171 } => compile_fuzzy_implication(*kind, premise, conclusion, ctx, graph),
172 TLExpr::SoftExists {
173 var,
174 domain,
175 body,
176 temperature,
177 } => compile_soft_exists(var, domain, body, *temperature, ctx, graph),
178 TLExpr::SoftForAll {
179 var,
180 domain,
181 body,
182 temperature,
183 } => compile_soft_forall(var, domain, body, *temperature, ctx, graph),
184 TLExpr::WeightedRule { weight, rule } => compile_weighted_rule(*weight, rule, ctx, graph),
185 TLExpr::ProbabilisticChoice { alternatives } => {
186 compile_probabilistic_choice(alternatives, ctx, graph)
187 }
188
189 TLExpr::Box(inner) => compile_box(inner, ctx, graph),
191 TLExpr::Diamond(inner) => compile_diamond(inner, ctx, graph),
192
193 TLExpr::Next(inner) => compile_next(inner, ctx, graph),
195 TLExpr::Eventually(inner) => compile_eventually(inner, ctx, graph),
196 TLExpr::Always(inner) => compile_always(inner, ctx, graph),
197 TLExpr::Until { before, after } => compile_until(before, after, ctx, graph),
198 TLExpr::Release { released, releaser } => compile_release(releaser, released, ctx, graph),
199 TLExpr::WeakUntil { before, after } => compile_weak_until(before, after, ctx, graph),
200 TLExpr::StrongRelease { released, releaser } => {
201 compile_strong_release(releaser, released, ctx, graph)
202 }
203
204 TLExpr::CountingExists {
206 var,
207 domain,
208 body,
209 min_count,
210 } => compile_counting_exists(var, domain, body, *min_count, ctx, graph),
211 TLExpr::CountingForAll {
212 var,
213 domain,
214 body,
215 min_count,
216 } => compile_counting_forall(var, domain, body, *min_count, ctx, graph),
217 TLExpr::ExactCount {
218 var,
219 domain,
220 body,
221 count,
222 } => compile_exact_count(var, domain, body, *count, ctx, graph),
223 TLExpr::Majority { var, domain, body } => compile_majority(var, domain, body, ctx, graph),
224
225 TLExpr::Lambda {
227 var,
228 var_type,
229 body,
230 } => compile_lambda(var, var_type, body, ctx, graph),
231 TLExpr::Apply { function, argument } => compile_apply(function, argument, ctx, graph),
232
233 TLExpr::SetMembership { element, set } => compile_set_membership(element, set, ctx, graph),
235 TLExpr::SetUnion { left, right } => compile_set_union(left, right, ctx, graph),
236 TLExpr::SetIntersection { left, right } => {
237 compile_set_intersection(left, right, ctx, graph)
238 }
239 TLExpr::SetDifference { left, right } => compile_set_difference(left, right, ctx, graph),
240 TLExpr::SetCardinality { set } => compile_set_cardinality(set, ctx, graph),
241 TLExpr::EmptySet => compile_empty_set(ctx, graph),
242 TLExpr::SetComprehension {
243 var,
244 domain,
245 condition,
246 } => compile_set_comprehension(var, domain, condition, ctx, graph),
247
248 TLExpr::LeastFixpoint { var, body } => compile_least_fixpoint(var, body, ctx, graph),
250 TLExpr::GreatestFixpoint { var, body } => compile_greatest_fixpoint(var, body, ctx, graph),
251
252 TLExpr::Nominal { name } => compile_nominal(name, ctx, graph),
254 TLExpr::At { nominal, formula } => compile_at(nominal, formula, ctx, graph),
255 TLExpr::Somewhere { formula } => compile_somewhere(formula, ctx, graph),
256 TLExpr::Everywhere { formula } => compile_everywhere(formula, ctx, graph),
257
258 TLExpr::AllDifferent { variables } => compile_all_different(variables, ctx, graph),
260 TLExpr::GlobalCardinality {
261 variables,
262 values,
263 min_occurrences,
264 max_occurrences,
265 } => compile_global_cardinality(
266 variables,
267 values,
268 min_occurrences,
269 max_occurrences,
270 ctx,
271 graph,
272 ),
273
274 TLExpr::Abducible { name, cost } => compile_abducible(name, *cost, ctx, graph),
276 TLExpr::Explain { formula } => compile_explain(formula, ctx, graph),
277
278 TLExpr::SymbolLiteral(s) => compile_symbol_literal(s, ctx, graph),
280 TLExpr::Match { scrutinee, arms } => compile_match(scrutinee, arms, ctx, graph),
281 }
282}