// tensorlogic_trustformers/rule_guided_decoder/engine.rs
//! Rule-guided beam-search engine.
//!
//! [`RuleGuidedBeamSearch`] is a thin composition layer on top of
//! [`tensorlogic_infer::beam_search::BeamSearchDecoder`]:
//!
//! 1. The caller provides a raw scoring closure (same signature as the one
//!    accepted by [`BeamSearchDecoder::decode`]).
//! 2. We wrap that closure in a new one that, after each raw-logits call,
//!    runs the configured [`LogitMasker`] against the active beams' prefixes.
//! 3. The wrapped closure is forwarded to the inner decoder, which takes care
//!    of length penalties, temperature, top-k and EOS bookkeeping.
//!
//! Because masking is applied to the raw logits *before* the inner decoder's
//! softmax, setting a logit to `NEG_INFINITY` is semantically equivalent to
//! assigning zero probability — the beam search will never explore that branch.

17use std::sync::Arc;
18
19use tensorlogic_infer::beam_search::{BeamSearchConfig, BeamSearchDecoder, BeamSearchResult};
20
21use crate::rule_guided_decoder::constraint::RuleConstraint;
22use crate::rule_guided_decoder::error::{RuleGuidedError, RuleGuidedResult};
23use crate::rule_guided_decoder::mask::LogitMasker;
24
25/// High-level composition of a beam-search decoder, a logical constraint and
26/// a masking strategy.
27pub struct RuleGuidedBeamSearch {
28    inner: BeamSearchDecoder,
29    constraint: Arc<RuleConstraint>,
30    masker: Arc<dyn LogitMasker>,
31}
32
33impl RuleGuidedBeamSearch {
34    /// Construct a decoder with an explicit beam-search configuration.
35    pub fn new(
36        config: BeamSearchConfig,
37        constraint: RuleConstraint,
38        masker: Arc<dyn LogitMasker>,
39    ) -> Self {
40        Self {
41            inner: BeamSearchDecoder::new(config),
42            constraint: Arc::new(constraint),
43            masker,
44        }
45    }
46
47    /// Read-only access to the underlying beam-search configuration.
48    pub fn config(&self) -> &BeamSearchConfig {
49        &self.inner.config
50    }
51
52    /// Read-only access to the compiled constraint.
53    pub fn constraint(&self) -> &RuleConstraint {
54        &self.constraint
55    }
56
57    /// Return the masker's name (`"HardMask"` / `"SoftPenaltyMask"` / ...).
58    pub fn masker_name(&self) -> &'static str {
59        self.masker.name()
60    }
61
62    /// Run the decoder.
63    ///
64    /// `score_fn` is the same "raw logits" closure accepted by
65    /// [`BeamSearchDecoder::decode`].  The engine wraps it so that every
66    /// row of the returned `[num_beams][vocab_size]` logit matrix is filtered
67    /// through the configured [`LogitMasker`] before being passed on.
68    pub fn decode<F>(&self, bos_token_id: usize, score_fn: F) -> RuleGuidedResult<BeamSearchResult>
69    where
70        F: Fn(&[&[usize]]) -> Result<Vec<Vec<f64>>, String>,
71    {
72        let constraint = Arc::clone(&self.constraint);
73        let masker = Arc::clone(&self.masker);
74        let expected_vocab = self.inner.config.vocab_size;
75
76        let wrapped = move |beams: &[&[usize]]| -> Result<Vec<Vec<f64>>, String> {
77            let mut raw_logits = score_fn(beams)?;
78            for (beam_idx, logits_row) in raw_logits.iter_mut().enumerate() {
79                if logits_row.len() != expected_vocab {
80                    return Err(format!(
81                        "logits row {beam_idx} has width {}, expected {expected_vocab}",
82                        logits_row.len()
83                    ));
84                }
85                let prefix = beams.get(beam_idx).copied().unwrap_or(&[]);
86                masker
87                    .apply(&constraint, prefix, logits_row)
88                    .map_err(|e| format!("rule-guided mask error: {e}"))?;
89            }
90            Ok(raw_logits)
91        };
92
93        self.inner
94            .decode(bos_token_id, wrapped)
95            .map_err(RuleGuidedError::from)
96    }
97}
98
99impl std::fmt::Debug for RuleGuidedBeamSearch {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("RuleGuidedBeamSearch")
102            .field("config", &self.inner.config)
103            .field("constraint", &self.constraint)
104            .field("masker", &self.masker.name())
105            .finish()
106    }
107}

// ---------------------------------------------------------------------------
// Unit tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule_guided_decoder::constraint::TokenId;
    use crate::rule_guided_decoder::mask::{HardMask, SoftPenaltyMask};
    use tensorlogic_ir::{TLExpr, Term};

    /// Compile `entity(Alice) OR entity(Bob)` over a 4-token vocabulary.
    ///
    /// Token ids map to symbols as 0 → `entity`, 1 → `Alice`, 2 → `Bob`,
    /// 3 → `Eve`. `Eve` does not occur in the expression.
    fn mk_constraint_alice_bob() -> RuleConstraint {
        // entity(Alice) OR entity(Bob) — allow symbol set = {entity, Alice, Bob}.
        let a = TLExpr::Pred {
            name: "entity".into(),
            args: vec![Term::Const("Alice".into())],
        };
        let b = TLExpr::Pred {
            name: "entity".into(),
            args: vec![Term::Const("Bob".into())],
        };
        let expr = TLExpr::Or(Box::new(a), Box::new(b));
        let mapper = |tid: TokenId| match tid {
            0 => Some("entity".into()),
            1 => Some("Alice".into()),
            2 => Some("Bob".into()),
            3 => Some("Eve".into()),
            _ => None,
        };
        RuleConstraint::compile(expr, mapper).expect("compile")
    }

    /// Small search configuration matching the 4-token test vocabulary.
    fn flat_config() -> BeamSearchConfig {
        BeamSearchConfig {
            beam_width: 2,
            max_length: 4,
            eos_token_id: None,
            length_penalty: 0.0,
            min_length: 1,
            vocab_size: 4,
            temperature: 1.0,
            top_k_filter: None,
        }
    }

    /// Scorer returning uniform logits for every active beam.
    fn flat_scores() -> impl Fn(&[&[usize]]) -> Result<Vec<Vec<f64>>, String> {
        // The masker, not the scorer, decides which tokens live or die.
        |beams: &[&[usize]]| Ok(beams.iter().map(|_| vec![1.0_f64, 1.0, 1.0, 1.0]).collect())
    }

    /// Build a decoder over `flat_config()` and `mk_constraint_alice_bob()`.
    fn mk_decoder(masker: Arc<dyn LogitMasker>) -> RuleGuidedBeamSearch {
        RuleGuidedBeamSearch::new(flat_config(), mk_constraint_alice_bob(), masker)
    }

    #[test]
    fn hard_mask_excludes_forbidden_token() {
        let decoder = mk_decoder(Arc::new(HardMask::new()));

        let result = decoder
            .decode(0, flat_scores())
            .expect("decode should succeed");
        // Eve (token id 3) maps to a symbol outside the allow set, so every
        // beam must avoid it. "Unknown" token ids (mapper returns None)
        // don't exist in this vocabulary, so only forbidden symbols matter.
        for hyp in &result.hypotheses {
            assert!(
                !hyp.tokens.contains(&3),
                "hard-masked decoder emitted forbidden token: {:?}",
                hyp.tokens
            );
        }
        assert_eq!(decoder.masker_name(), "HardMask");
    }

    #[test]
    fn soft_mask_with_zero_lambda_still_decodes() {
        // Per the mask design, forbidden tokens remain banned regardless of
        // lambda. This test only checks that soft mode with a zero penalty
        // does not block decoding altogether and reports its name correctly.
        let decoder = mk_decoder(Arc::new(SoftPenaltyMask::new(0.0).expect("lambda")));
        let result = decoder
            .decode(0, flat_scores())
            .expect("decode should succeed");
        assert!(!result.hypotheses.is_empty());
        assert_eq!(decoder.masker_name(), "SoftPenaltyMask");
    }

    #[test]
    fn error_from_score_fn_is_propagated() {
        let decoder = mk_decoder(Arc::new(HardMask::new()));
        let result = decoder.decode(0, |_beams: &[&[usize]]| {
            Err::<Vec<Vec<f64>>, String>("synthetic scoring error".into())
        });
        assert!(result.is_err());
        let msg = format!("{}", result.expect_err("should have returned an error"));
        assert!(msg.contains("synthetic"));
    }

    #[test]
    fn logits_row_with_wrong_width_is_rejected() {
        let decoder = mk_decoder(Arc::new(HardMask::new()));
        // vocab_size is 4, but this scorer only returns 3 logits per beam.
        let result = decoder.decode(0, |beams: &[&[usize]]| {
            Ok::<Vec<Vec<f64>>, String>(beams.iter().map(|_| vec![1.0_f64, 1.0, 1.0]).collect())
        });
        let msg = format!("{}", result.expect_err("a short logits row must be rejected"));
        assert!(msg.contains("width"), "unexpected error message: {msg}");
    }
}