tensorlogic_compiler/partial_eval/api.rs
//! Public entry points for the partial evaluator:
//! [`partially_evaluate`], [`specialize`], and [`specialize_batch`].
3
4use std::collections::HashSet;
5
6use tensorlogic_ir::TLExpr;
7
8use super::helpers::collect_free_pred_vars;
9use super::pe_core::pe_rec;
10use super::types::{PEConfig, PEEnv, PEResult, PEStats};
11
12/// Partially evaluate `expr` under `env` according to `config`.
13///
14/// Variables are zero-arity predicates. Any such predicate whose name appears
15/// in `env` is substituted with the bound value. Arithmetic and boolean
16/// identities are applied when both operands are known. Dead branches in
17/// logical operators are pruned when one operand resolves to a concrete boolean.
18/// `Let` bindings whose value reduces to a concrete constant are inlined into
19/// the body.
20///
21/// The returned [`PEResult`] contains:
22/// - The residual expression (partially reduced).
23/// - Accumulated statistics.
24/// - The names of variables still free in the output.
25pub fn partially_evaluate(expr: &TLExpr, env: &PEEnv, config: &PEConfig) -> PEResult {
26 let mut stats = PEStats::default();
27 let result_expr = pe_rec(expr.clone(), env, config, 0, &mut stats);
28
29 // Compute residual free variables in the output expression
30 let mut free_set = HashSet::new();
31 collect_free_pred_vars(&result_expr, &HashSet::new(), &mut free_set);
32
33 let mut residual_vars: Vec<String> = free_set.into_iter().collect();
34 residual_vars.sort();
35
36 PEResult {
37 expr: result_expr,
38 stats,
39 residual_vars,
40 }
41}
42
// ── Specialization helpers ────────────────────────────────────────────────────
44
45/// Specialize `expr` by binding all provided `(name, f64)` pairs and returning
46/// the residual expression.
47///
48/// This is a convenience wrapper around [`partially_evaluate`] that builds a
49/// [`PEEnv`] from the supplied bindings.
50pub fn specialize(expr: &TLExpr, bindings: &[(String, f64)], config: &PEConfig) -> PEResult {
51 let mut env = PEEnv::new();
52 for (name, val) in bindings {
53 env.bind_f64(name.clone(), *val);
54 }
55 partially_evaluate(expr, &env, config)
56}
57
58/// Multi-point specialization: evaluate `expr` at multiple binding sets and
59/// return one [`PEResult`] per binding set, in the same order as `binding_sets`.
60///
61/// This can be used, for example, to compile a parameterised expression for a
62/// batch of concrete parameter values.
63pub fn specialize_batch(
64 expr: &TLExpr,
65 binding_sets: &[Vec<(String, f64)>],
66 config: &PEConfig,
67) -> Vec<PEResult> {
68 binding_sets
69 .iter()
70 .map(|bindings| specialize(expr, bindings, config))
71 .collect()
72}