tensorlogic_trustformers/rule_guided_decoder/constraint.rs
1//! Logical-constraint compilation for rule-guided decoding.
2//!
3//! This module wraps a [`tensorlogic_ir::TLExpr`] constraint and compiles it
4//! into a lightweight, per-token lookup representation that the decoder can
5//! consult at every sampling step.
6//!
//! ## Scope
//!
//! Only a deliberate subset of `TLExpr` is honoured by the current implementation:
//!
//! * [`TLExpr::Pred`] — treated as an allow-list of symbol names that the
//!   candidate token must match.
//! * [`TLExpr::And`] — intersection of its operands' constraints.
//! * [`TLExpr::Or`] — union of its operands' constraints.
//! * [`TLExpr::Not`] — recognised, but not compiled yet: complementing an
//!   allow-list needs the full symbol universe, which is unknown at compile
//!   time, so it currently falls through to the no-op path described below.
//!
//! Any other variant collapses to [`ConstraintVerdict::SoftPenalty`]`(0.0)` (no-op)
//! with a `// TODO` pointing at the extension point. See [`extend_tlexpr_support`]
//! for the next step.
21//!
22//! ## Token-to-symbol mapping
23//!
24//! The compiled constraint needs to know which *symbol name* each token
25//! corresponds to — vocabulary encodings are deeply application-specific.
26//! Callers supply a mapper `Fn(TokenId) -> Option<SymbolName>`; an empty option
27//! means "this token is unknown / has no symbolic identity".
28
29use std::collections::HashSet;
30
31use tensorlogic_ir::TLExpr;
32
33#[cfg(test)]
34use crate::rule_guided_decoder::error::RuleGuidedError;
35use crate::rule_guided_decoder::error::RuleGuidedResult;
36
/// Identifier of a vocabulary token as seen by the beam-search backend.
///
/// This is a plain `usize` alias, so any index-like value can be passed
/// wherever a token id is expected.
pub type TokenId = usize;
39
/// Outcome of checking one candidate token against a compiled constraint.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConstraintVerdict {
    /// The token passes the constraint; its logit needs no adjustment.
    Allowed,
    /// The token violates the constraint. Hard masking should set its logit
    /// to `-inf`; soft masking treats it as the maximum-penalty case.
    Forbidden,
    /// A graded violation: a non-negative "violation magnitude" that the
    /// soft-penalty mask multiplies by `lambda`.
    SoftPenalty(f64),
}
52
53/// Token → symbol-name mapper. `None` means the token has no symbolic identity
54/// (constraint evaluation is conservative — see [`RuleConstraint::evaluate`]).
55pub type TokenSymbolMapper = dyn Fn(TokenId) -> Option<String> + Send + Sync;
56
57/// Compiled representation of a single `TLExpr` constraint.
58///
59/// Two forms coexist:
60///
61/// 1. **Eager table** (`allow_set`) — populated when the constraint compiles
62/// to a finite predicate list over a user-supplied vocabulary mapper. This
63/// path is the fast path and is used by the hard/soft masks.
64/// 2. **Fallback pass-through** — used when the expression hit an unsupported
65/// variant. In that case [`RuleConstraint::evaluate`] returns
66/// `ConstraintVerdict::SoftPenalty(0.0)` and the decoder behaves as if no
67/// constraint was present.
68pub struct RuleConstraint {
69 /// Original TLExpr (kept for diagnostics and lazy re-compilation).
70 source: TLExpr,
71 /// Union of symbol names accepted by the constraint, if computable.
72 ///
73 /// Conceptually `None` means "constraint is non-enumerable" (e.g., the
74 /// AST contained variables or unsupported connectives). An empty set
75 /// means the constraint is unsatisfiable — no token passes.
76 allow_set: Option<HashSet<String>>,
77 /// Mapper from token ids to symbol names. Stored so `evaluate` can be
78 /// called many times without re-compiling.
79 mapper: Box<TokenSymbolMapper>,
80 /// Set to true when compilation succeeded against a recognised subset of
81 /// `TLExpr`. When false, the constraint silently no-ops.
82 supported: bool,
83}
84
85impl std::fmt::Debug for RuleConstraint {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("RuleConstraint")
88 .field("source", &self.source)
89 .field("allow_set", &self.allow_set)
90 .field("supported", &self.supported)
91 .finish_non_exhaustive()
92 }
93}
94
95impl RuleConstraint {
96 /// Compile a constraint from a `TLExpr` and a token → symbol-name mapper.
97 ///
98 /// * If the expression only uses supported variants, the returned
99 /// constraint eagerly enumerates allowed symbol names into a `HashSet`.
100 /// * Otherwise, the constraint is still constructed but evaluates to a
101 /// no-op (soft-penalty of zero). This makes the decoder forward-
102 /// compatible: new `TLExpr` variants don't break existing call sites.
103 pub fn compile<M>(expr: TLExpr, mapper: M) -> RuleGuidedResult<Self>
104 where
105 M: Fn(TokenId) -> Option<String> + Send + Sync + 'static,
106 {
107 let mut builder = AllowSetBuilder::default();
108 let supported = builder.visit(&expr)?;
109 let allow_set = if supported {
110 Some(builder.finalize())
111 } else {
112 None
113 };
114 Ok(Self {
115 source: expr,
116 allow_set,
117 mapper: Box::new(mapper),
118 supported,
119 })
120 }
121
122 /// Evaluate the constraint against `(prefix, candidate)`.
123 ///
124 /// `prefix` is the token sequence already committed to the beam; it is
125 /// not used by the current allow-list compiler but is part of the contract
126 /// so stateful constraints (e.g. "no token X after token Y") remain
127 /// implementable without an API break — see [`extend_tlexpr_support`].
128 pub fn evaluate(&self, prefix: &[TokenId], candidate: TokenId) -> ConstraintVerdict {
129 let _ = prefix; // Reserved for future stateful predicates.
130 if !self.supported {
131 return ConstraintVerdict::SoftPenalty(0.0);
132 }
133 let allow_set = match &self.allow_set {
134 Some(set) => set,
135 None => return ConstraintVerdict::SoftPenalty(0.0),
136 };
137
138 let symbol = (self.mapper)(candidate);
139 match symbol {
140 Some(name) if allow_set.contains(&name) => ConstraintVerdict::Allowed,
141 Some(_) => ConstraintVerdict::Forbidden,
142 None => {
143 // Unknown tokens (e.g. punctuation with no symbol) are treated
144 // conservatively as a soft violation so the decoder slightly
145 // prefers fully-symbolic completions without banning them.
146 ConstraintVerdict::SoftPenalty(1.0)
147 }
148 }
149 }
150
151 /// Read-only access to the compiled allow-list, if any.
152 pub fn allow_set(&self) -> Option<&HashSet<String>> {
153 self.allow_set.as_ref()
154 }
155
156 /// `true` when the constraint was compiled against a supported subset of
157 /// `TLExpr`. `false` means the constraint is a no-op.
158 pub fn is_supported(&self) -> bool {
159 self.supported
160 }
161
162 /// Original expression.
163 pub fn source(&self) -> &TLExpr {
164 &self.source
165 }
166}
167
168// ---------------------------------------------------------------------------
169// Allow-set compiler
170// ---------------------------------------------------------------------------
171
/// Scratch state for compiling a `TLExpr` into an allow-set.
#[derive(Default)]
struct AllowSetBuilder {
    /// Allow-set produced by the most recent successful `visit`.
    ///
    /// Stays `None` until a visit succeeds; `finalize` turns a missing
    /// value into an empty set.
    current: Option<HashSet<String>>,
}
179
180impl AllowSetBuilder {
181 /// Recursively visit `expr`, updating `self.current`.
182 ///
183 /// Returns `true` when every sub-expression fell into the supported
184 /// subset; `false` signals the caller to drop the compiled table and
185 /// fall back to the no-op path.
186 fn visit(&mut self, expr: &TLExpr) -> RuleGuidedResult<bool> {
187 let set = match self.classify(expr)? {
188 Some(s) => s,
189 None => return Ok(false),
190 };
191 self.current = Some(set);
192 Ok(true)
193 }
194
195 fn finalize(self) -> HashSet<String> {
196 self.current.unwrap_or_default()
197 }
198
199 /// Attempt to fold `expr` into an allow-set. Returns `Ok(None)` when the
200 /// expression uses an unsupported variant.
201 fn classify(&self, expr: &TLExpr) -> RuleGuidedResult<Option<HashSet<String>>> {
202 match expr {
203 TLExpr::Pred { name, args } => {
204 // Treat the predicate's atoms as allowed symbol names.
205 // The predicate name itself is allowed as a symbol too —
206 // this matches the usual convention where tokenizers emit a
207 // "type" token (e.g., `entity(Alice)`).
208 let mut set = HashSet::with_capacity(1 + args.len());
209 set.insert(name.clone());
210 for arg in args {
211 match arg {
212 tensorlogic_ir::Term::Const(s) => {
213 set.insert(s.clone());
214 }
215 tensorlogic_ir::Term::Var(_) => {
216 // Variables are unbound: the predicate doesn't
217 // restrict the vocabulary symbolically.
218 }
219 tensorlogic_ir::Term::Typed { value, .. } => {
220 if let tensorlogic_ir::Term::Const(s) = value.as_ref() {
221 set.insert(s.clone());
222 }
223 }
224 }
225 }
226 Ok(Some(set))
227 }
228 TLExpr::And(lhs, rhs) => {
229 let l = match self.classify(lhs)? {
230 Some(s) => s,
231 None => return Ok(None),
232 };
233 let r = match self.classify(rhs)? {
234 Some(s) => s,
235 None => return Ok(None),
236 };
237 Ok(Some(l.intersection(&r).cloned().collect()))
238 }
239 TLExpr::Or(lhs, rhs) => {
240 let l = match self.classify(lhs)? {
241 Some(s) => s,
242 None => return Ok(None),
243 };
244 let r = match self.classify(rhs)? {
245 Some(s) => s,
246 None => return Ok(None),
247 };
248 Ok(Some(l.union(&r).cloned().collect()))
249 }
250 TLExpr::Not(inner) => {
251 // Negation of an allow-list has no finite representation in
252 // the closed-vocabulary form we keep here. Callers still get
253 // well-defined behaviour: negation flips "membership" to
254 // "non-membership", but we need the vocabulary-wide symbol
255 // universe for that — which we don't know at compile time.
256 // Instead, synthesize a sentinel allow-set signalling
257 // "complement mode" via an unused variant. See the TODO at
258 // the end of this module.
259 //
260 // For now, fall back to the no-op path.
261 let _ = inner;
262 // TODO(extend_tlexpr_support): Thread Not through evaluate()
263 // with an explicit complement flag or per-token mapper look-up.
264 Ok(None)
265 }
266 // TODO(extend_tlexpr_support): Implement Exists/ForAll/Imply.
267 // Each requires either quantifier elimination against the mapper
268 // or a semantic predicate that consults the prefix, which we do
269 // not currently have at compile time.
270 _ => Ok(None),
271 }
272 }
273}
274
/// Marker item that gives the documentation a linkable extension point; the
/// function itself does nothing.
///
/// The compiler currently understands `Pred`, `And`, and `Or`. Supporting a
/// further variant such as `Imply` means adding its set-algebraic
/// translation to `AllowSetBuilder::classify`. Connectives whose truth
/// depends on the already generated prefix are stateful and should instead
/// get a new arm in [`RuleConstraint::evaluate`] that inspects `prefix`.
pub const fn extend_tlexpr_support() {}
283
284// ---------------------------------------------------------------------------
285// Unit tests
286// ---------------------------------------------------------------------------
287
#[cfg(test)]
mod tests {
    use super::ConstraintVerdict::{Allowed, Forbidden, SoftPenalty};
    use super::*;
    use tensorlogic_ir::Term;

    /// Builds the predicate `name(consts...)` with only constant arguments.
    fn mk_pred(name: &str, consts: &[&str]) -> TLExpr {
        TLExpr::Pred {
            name: name.into(),
            args: consts.iter().map(|c| Term::Const((*c).into())).collect(),
        }
    }

    /// Three-symbol vocabulary: 1 → "Alice", 2 → "Bob", 3 → "entity".
    /// Every other token id has no symbol.
    fn demo_mapper() -> impl Fn(TokenId) -> Option<String> + Send + Sync + 'static {
        const VOCAB: [(TokenId, &str); 3] = [(1, "Alice"), (2, "Bob"), (3, "entity")];
        |tid: TokenId| {
            VOCAB
                .iter()
                .find(|(id, _)| *id == tid)
                .map(|(_, name)| (*name).to_string())
        }
    }

    /// Compiles `expr` against the demo vocabulary, panicking on failure.
    fn compile_demo(expr: TLExpr) -> RuleConstraint {
        RuleConstraint::compile(expr, demo_mapper()).expect("compile")
    }

    /// Asserts the verdict for each `(token, expected)` pair on an empty prefix.
    fn expect_verdicts(rc: &RuleConstraint, cases: &[(TokenId, ConstraintVerdict)]) {
        for &(token, expected) in cases {
            assert_eq!(rc.evaluate(&[], token), expected, "token id {token}");
        }
    }

    #[test]
    fn predicate_allow_list_accepts_named_consts() {
        let rc = compile_demo(mk_pred("entity", &["Alice"]));
        assert!(rc.is_supported());
        expect_verdicts(&rc, &[(1, Allowed), (2, Forbidden), (3, Allowed)]);
    }

    #[test]
    fn conjunction_intersects_allow_sets() {
        // entity(Alice) AND entity(Bob): only the shared "entity" symbol
        // survives, so the Alice and Bob tokens become Forbidden.
        let expr = TLExpr::And(
            Box::new(mk_pred("entity", &["Alice"])),
            Box::new(mk_pred("entity", &["Bob"])),
        );
        let rc = compile_demo(expr);
        assert!(rc.is_supported());
        expect_verdicts(&rc, &[(1, Forbidden), (2, Forbidden), (3, Allowed)]);
    }

    #[test]
    fn disjunction_unions_allow_sets() {
        let expr = TLExpr::Or(
            Box::new(mk_pred("entity", &["Alice"])),
            Box::new(mk_pred("entity", &["Bob"])),
        );
        let rc = compile_demo(expr);
        assert!(rc.is_supported());
        expect_verdicts(&rc, &[(1, Allowed), (2, Allowed), (3, Allowed)]);
    }

    #[test]
    fn unsupported_variant_returns_soft_noop() {
        let expr = TLExpr::Not(Box::new(mk_pred("entity", &["Alice"])));
        let rc = compile_demo(expr);
        assert!(!rc.is_supported());
        expect_verdicts(&rc, &[(1, SoftPenalty(0.0))]);
    }

    #[test]
    fn unknown_token_yields_soft_penalty() {
        let rc = compile_demo(mk_pred("entity", &["Alice"]));
        // Token id 99 has no mapping -> conservative SoftPenalty(1.0).
        expect_verdicts(&rc, &[(99, SoftPenalty(1.0))]);
    }

    #[test]
    fn empty_intersection_forbids_all_known_tokens() {
        // entity(Alice) AND user(Charlie): the two allow-sets share no
        // symbol (the predicate names differ as well), so the intersection
        // is empty and every known symbol is forbidden.
        let expr = TLExpr::And(
            Box::new(mk_pred("entity", &["Alice"])),
            Box::new(mk_pred("user", &["Charlie"])),
        );
        let rc = compile_demo(expr);
        assert!(rc.is_supported());
        expect_verdicts(&rc, &[(1, Forbidden), (3, Forbidden)]);
    }

    #[test]
    fn error_type_has_useful_display() {
        // Sanity check that RuleGuidedError links correctly.
        let err: RuleGuidedError =
            RuleGuidedError::CompilationError("synthetic failure".to_string());
        assert!(err.to_string().contains("synthetic"));
    }
}
385}