Skip to main content

tensorlogic_trustformers/rule_guided_decoder/
constraint.rs

1//! Logical-constraint compilation for rule-guided decoding.
2//!
3//! This module wraps a [`tensorlogic_ir::TLExpr`] constraint and compiles it
4//! into a lightweight, per-token lookup representation that the decoder can
5//! consult at every sampling step.
6//!
7//! ## Scope
8//!
9//! Only a deliberate subset of `TLExpr` is honoured by the current implementation:
10//!
11//! * [`TLExpr::Pred`] — treated as an allow-list of symbol names that the
12//!   candidate token must match.
13//! * [`TLExpr::And`] — intersection of its operands' constraints.
14//! * [`TLExpr::Or`] — union of its operands' constraints.
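//! * [`TLExpr::Not`] — recognised, but not compiled yet: complementing an
//!   allow-list needs the full symbol universe, which is unknown at compile
//!   time, so negation currently takes the same no-op path as unsupported
//!   variants.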
17//!
18//! Any other variant collapses to [`ConstraintVerdict::SoftPenalty`]`(0.0)` (no-op)
19//! with a `// TODO` pointing at the extension point. See [`extend_tlexpr_support`]
20//! for the next step.
21//!
22//! ## Token-to-symbol mapping
23//!
24//! The compiled constraint needs to know which *symbol name* each token
25//! corresponds to — vocabulary encodings are deeply application-specific.
26//! Callers supply a mapper `Fn(TokenId) -> Option<SymbolName>`; an empty option
27//! means "this token is unknown / has no symbolic identity".
28
29use std::collections::HashSet;
30
31use tensorlogic_ir::TLExpr;
32
33#[cfg(test)]
34use crate::rule_guided_decoder::error::RuleGuidedError;
35use crate::rule_guided_decoder::error::RuleGuidedResult;
36
/// Identifier of a single vocabulary token.
///
/// This is a plain `usize` alias, so raw indices can be passed wherever a
/// `TokenId` is expected.
pub type TokenId = usize;
39
/// Outcome of checking one candidate token against a compiled constraint.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConstraintVerdict {
    /// The candidate satisfies the constraint; its logit can be left as is.
    Allowed,
    /// The candidate violates the constraint.  Hard masking is expected to
    /// drive the logit to `-inf`, while soft masking should treat this as the
    /// largest possible penalty.
    Forbidden,
    /// Graded violation.  The payload is a non-negative magnitude that a
    /// soft-penalty mask scales by its `lambda`; `0.0` means "no penalty".
    SoftPenalty(f64),
}
52
53/// Token → symbol-name mapper. `None` means the token has no symbolic identity
54/// (constraint evaluation is conservative — see [`RuleConstraint::evaluate`]).
55pub type TokenSymbolMapper = dyn Fn(TokenId) -> Option<String> + Send + Sync;
56
57/// Compiled representation of a single `TLExpr` constraint.
58///
59/// Two forms coexist:
60///
61/// 1. **Eager table** (`allow_set`) — populated when the constraint compiles
62///    to a finite predicate list over a user-supplied vocabulary mapper.  This
63///    path is the fast path and is used by the hard/soft masks.
64/// 2. **Fallback pass-through** — used when the expression hit an unsupported
65///    variant.  In that case [`RuleConstraint::evaluate`] returns
66///    `ConstraintVerdict::SoftPenalty(0.0)` and the decoder behaves as if no
67///    constraint was present.
68pub struct RuleConstraint {
69    /// Original TLExpr (kept for diagnostics and lazy re-compilation).
70    source: TLExpr,
71    /// Union of symbol names accepted by the constraint, if computable.
72    ///
73    /// Conceptually `None` means "constraint is non-enumerable" (e.g., the
74    /// AST contained variables or unsupported connectives).  An empty set
75    /// means the constraint is unsatisfiable — no token passes.
76    allow_set: Option<HashSet<String>>,
77    /// Mapper from token ids to symbol names.  Stored so `evaluate` can be
78    /// called many times without re-compiling.
79    mapper: Box<TokenSymbolMapper>,
80    /// Set to true when compilation succeeded against a recognised subset of
81    /// `TLExpr`.  When false, the constraint silently no-ops.
82    supported: bool,
83}
84
85impl std::fmt::Debug for RuleConstraint {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("RuleConstraint")
88            .field("source", &self.source)
89            .field("allow_set", &self.allow_set)
90            .field("supported", &self.supported)
91            .finish_non_exhaustive()
92    }
93}
94
95impl RuleConstraint {
96    /// Compile a constraint from a `TLExpr` and a token → symbol-name mapper.
97    ///
98    /// * If the expression only uses supported variants, the returned
99    ///   constraint eagerly enumerates allowed symbol names into a `HashSet`.
100    /// * Otherwise, the constraint is still constructed but evaluates to a
101    ///   no-op (soft-penalty of zero).  This makes the decoder forward-
102    ///   compatible: new `TLExpr` variants don't break existing call sites.
103    pub fn compile<M>(expr: TLExpr, mapper: M) -> RuleGuidedResult<Self>
104    where
105        M: Fn(TokenId) -> Option<String> + Send + Sync + 'static,
106    {
107        let mut builder = AllowSetBuilder::default();
108        let supported = builder.visit(&expr)?;
109        let allow_set = if supported {
110            Some(builder.finalize())
111        } else {
112            None
113        };
114        Ok(Self {
115            source: expr,
116            allow_set,
117            mapper: Box::new(mapper),
118            supported,
119        })
120    }
121
122    /// Evaluate the constraint against `(prefix, candidate)`.
123    ///
124    /// `prefix` is the token sequence already committed to the beam; it is
125    /// not used by the current allow-list compiler but is part of the contract
126    /// so stateful constraints (e.g. "no token X after token Y") remain
127    /// implementable without an API break — see [`extend_tlexpr_support`].
128    pub fn evaluate(&self, prefix: &[TokenId], candidate: TokenId) -> ConstraintVerdict {
129        let _ = prefix; // Reserved for future stateful predicates.
130        if !self.supported {
131            return ConstraintVerdict::SoftPenalty(0.0);
132        }
133        let allow_set = match &self.allow_set {
134            Some(set) => set,
135            None => return ConstraintVerdict::SoftPenalty(0.0),
136        };
137
138        let symbol = (self.mapper)(candidate);
139        match symbol {
140            Some(name) if allow_set.contains(&name) => ConstraintVerdict::Allowed,
141            Some(_) => ConstraintVerdict::Forbidden,
142            None => {
143                // Unknown tokens (e.g. punctuation with no symbol) are treated
144                // conservatively as a soft violation so the decoder slightly
145                // prefers fully-symbolic completions without banning them.
146                ConstraintVerdict::SoftPenalty(1.0)
147            }
148        }
149    }
150
151    /// Read-only access to the compiled allow-list, if any.
152    pub fn allow_set(&self) -> Option<&HashSet<String>> {
153        self.allow_set.as_ref()
154    }
155
156    /// `true` when the constraint was compiled against a supported subset of
157    /// `TLExpr`.  `false` means the constraint is a no-op.
158    pub fn is_supported(&self) -> bool {
159        self.supported
160    }
161
162    /// Original expression.
163    pub fn source(&self) -> &TLExpr {
164        &self.source
165    }
166}
167
168// ---------------------------------------------------------------------------
169// Allow-set compiler
170// ---------------------------------------------------------------------------
171
/// Scratch state for compiling a `TLExpr` into an allow-list.
#[derive(Default)]
struct AllowSetBuilder {
    /// Allow-list stored by the last successful `visit`.  Stays `None` until
    /// a visit succeeds; `finalize` turns a missing list into an empty set.
    current: Option<HashSet<String>>,
}
179
180impl AllowSetBuilder {
181    /// Recursively visit `expr`, updating `self.current`.
182    ///
183    /// Returns `true` when every sub-expression fell into the supported
184    /// subset; `false` signals the caller to drop the compiled table and
185    /// fall back to the no-op path.
186    fn visit(&mut self, expr: &TLExpr) -> RuleGuidedResult<bool> {
187        let set = match self.classify(expr)? {
188            Some(s) => s,
189            None => return Ok(false),
190        };
191        self.current = Some(set);
192        Ok(true)
193    }
194
195    fn finalize(self) -> HashSet<String> {
196        self.current.unwrap_or_default()
197    }
198
199    /// Attempt to fold `expr` into an allow-set.  Returns `Ok(None)` when the
200    /// expression uses an unsupported variant.
201    fn classify(&self, expr: &TLExpr) -> RuleGuidedResult<Option<HashSet<String>>> {
202        match expr {
203            TLExpr::Pred { name, args } => {
204                // Treat the predicate's atoms as allowed symbol names.
205                // The predicate name itself is allowed as a symbol too —
206                // this matches the usual convention where tokenizers emit a
207                // "type" token (e.g., `entity(Alice)`).
208                let mut set = HashSet::with_capacity(1 + args.len());
209                set.insert(name.clone());
210                for arg in args {
211                    match arg {
212                        tensorlogic_ir::Term::Const(s) => {
213                            set.insert(s.clone());
214                        }
215                        tensorlogic_ir::Term::Var(_) => {
216                            // Variables are unbound: the predicate doesn't
217                            // restrict the vocabulary symbolically.
218                        }
219                        tensorlogic_ir::Term::Typed { value, .. } => {
220                            if let tensorlogic_ir::Term::Const(s) = value.as_ref() {
221                                set.insert(s.clone());
222                            }
223                        }
224                    }
225                }
226                Ok(Some(set))
227            }
228            TLExpr::And(lhs, rhs) => {
229                let l = match self.classify(lhs)? {
230                    Some(s) => s,
231                    None => return Ok(None),
232                };
233                let r = match self.classify(rhs)? {
234                    Some(s) => s,
235                    None => return Ok(None),
236                };
237                Ok(Some(l.intersection(&r).cloned().collect()))
238            }
239            TLExpr::Or(lhs, rhs) => {
240                let l = match self.classify(lhs)? {
241                    Some(s) => s,
242                    None => return Ok(None),
243                };
244                let r = match self.classify(rhs)? {
245                    Some(s) => s,
246                    None => return Ok(None),
247                };
248                Ok(Some(l.union(&r).cloned().collect()))
249            }
250            TLExpr::Not(inner) => {
251                // Negation of an allow-list has no finite representation in
252                // the closed-vocabulary form we keep here.  Callers still get
253                // well-defined behaviour: negation flips "membership" to
254                // "non-membership", but we need the vocabulary-wide symbol
255                // universe for that — which we don't know at compile time.
256                // Instead, synthesize a sentinel allow-set signalling
257                // "complement mode" via an unused variant.  See the TODO at
258                // the end of this module.
259                //
260                // For now, fall back to the no-op path.
261                let _ = inner;
262                // TODO(extend_tlexpr_support): Thread Not through evaluate()
263                // with an explicit complement flag or per-token mapper look-up.
264                Ok(None)
265            }
266            // TODO(extend_tlexpr_support): Implement Exists/ForAll/Imply.
267            // Each requires either quantifier elimination against the mapper
268            // or a semantic predicate that consults the prefix, which we do
269            // not currently have at compile time.
270            _ => Ok(None),
271        }
272    }
273}
274
/// Anchor for documentation links: where to add support for more `TLExpr`
/// variants.  The function itself does nothing.
///
/// The allow-set compiler currently understands `Pred`, `And`, and `Or`.
/// A connective that can be expressed with set algebra (e.g. `Imply`)
/// belongs in `AllowSetBuilder::classify`.  A connective whose truth depends
/// on the tokens generated so far needs a new branch in
/// [`RuleConstraint::evaluate`] that reads its `prefix` argument.
pub const fn extend_tlexpr_support() {}
283
284// ---------------------------------------------------------------------------
285// Unit tests
286// ---------------------------------------------------------------------------
287
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Build the predicate `name(consts...)` with constant arguments only.
    fn mk_pred(name: &str, consts: &[&str]) -> TLExpr {
        let args = consts.iter().map(|&c| Term::Const(c.into())).collect();
        TLExpr::Pred {
            name: name.into(),
            args,
        }
    }

    /// Three-symbol vocabulary: 1 → Alice, 2 → Bob, 3 → entity; every other
    /// token id has no symbol.
    fn demo_mapper() -> impl Fn(TokenId) -> Option<String> + Send + Sync + 'static {
        |tid: TokenId| {
            let symbol = match tid {
                1 => "Alice",
                2 => "Bob",
                3 => "entity",
                _ => return None,
            };
            Some(symbol.to_owned())
        }
    }

    /// Compile `expr` against the demo vocabulary.
    fn compile_demo(expr: TLExpr) -> RuleConstraint {
        RuleConstraint::compile(expr, demo_mapper()).expect("compile")
    }

    /// Assert the verdict for each `(token, expected)` pair with an empty
    /// prefix.
    fn assert_verdicts(rc: &RuleConstraint, expected: &[(TokenId, ConstraintVerdict)]) {
        for &(token, verdict) in expected {
            assert_eq!(rc.evaluate(&[], token), verdict);
        }
    }

    #[test]
    fn predicate_allow_list_accepts_named_consts() {
        let rc = compile_demo(mk_pred("entity", &["Alice"]));
        assert!(rc.is_supported());
        assert_verdicts(
            &rc,
            &[
                (1, ConstraintVerdict::Allowed),
                (2, ConstraintVerdict::Forbidden),
                (3, ConstraintVerdict::Allowed),
            ],
        );
    }

    #[test]
    fn conjunction_intersects_allow_sets() {
        // entity(Alice) AND entity(Bob): the operands share only the
        // predicate name, so the Alice and Bob tokens are rejected.
        let both = TLExpr::And(
            Box::new(mk_pred("entity", &["Alice"])),
            Box::new(mk_pred("entity", &["Bob"])),
        );
        let rc = compile_demo(both);
        assert!(rc.is_supported());
        assert_verdicts(
            &rc,
            &[
                (1, ConstraintVerdict::Forbidden),
                (2, ConstraintVerdict::Forbidden),
                (3, ConstraintVerdict::Allowed),
            ],
        );
    }

    #[test]
    fn disjunction_unions_allow_sets() {
        let either = TLExpr::Or(
            Box::new(mk_pred("entity", &["Alice"])),
            Box::new(mk_pred("entity", &["Bob"])),
        );
        let rc = compile_demo(either);
        assert!(rc.is_supported());
        assert_verdicts(
            &rc,
            &[
                (1, ConstraintVerdict::Allowed),
                (2, ConstraintVerdict::Allowed),
                (3, ConstraintVerdict::Allowed),
            ],
        );
    }

    #[test]
    fn unsupported_variant_returns_soft_noop() {
        let negated = TLExpr::Not(Box::new(mk_pred("entity", &["Alice"])));
        let rc = compile_demo(negated);
        assert!(!rc.is_supported());
        assert_verdicts(&rc, &[(1, ConstraintVerdict::SoftPenalty(0.0))]);
    }

    #[test]
    fn unknown_token_yields_soft_penalty() {
        let rc = compile_demo(mk_pred("entity", &["Alice"]));
        // Token 99 is outside the demo vocabulary, so it gets the
        // conservative unit penalty instead of a hard verdict.
        assert_verdicts(&rc, &[(99, ConstraintVerdict::SoftPenalty(1.0))]);
    }

    #[test]
    fn empty_intersection_forbids_all_known_tokens() {
        // entity(Alice) AND user(Charlie): the two allow-lists have nothing
        // in common (not even a predicate name), so the intersection is
        // empty and every token with a symbol is rejected.
        let disjoint = TLExpr::And(
            Box::new(mk_pred("entity", &["Alice"])),
            Box::new(mk_pred("user", &["Charlie"])),
        );
        let rc = compile_demo(disjoint);
        assert!(rc.is_supported());
        assert_verdicts(
            &rc,
            &[
                (1, ConstraintVerdict::Forbidden),
                (3, ConstraintVerdict::Forbidden),
            ],
        );
    }

    #[test]
    fn error_type_has_useful_display() {
        // Smoke test that the error type is reachable from this module and
        // renders its message.
        let err: RuleGuidedError =
            RuleGuidedError::CompilationError("synthetic failure".to_string());
        assert!(err.to_string().contains("synthetic"));
    }
}