tensorlogic_trustformers/rule_guided_decoder/
engine.rs1use std::sync::Arc;
18
19use tensorlogic_infer::beam_search::{BeamSearchConfig, BeamSearchDecoder, BeamSearchResult};
20
21use crate::rule_guided_decoder::constraint::RuleConstraint;
22use crate::rule_guided_decoder::error::{RuleGuidedError, RuleGuidedResult};
23use crate::rule_guided_decoder::mask::LogitMasker;
24
25pub struct RuleGuidedBeamSearch {
28 inner: BeamSearchDecoder,
29 constraint: Arc<RuleConstraint>,
30 masker: Arc<dyn LogitMasker>,
31}
32
33impl RuleGuidedBeamSearch {
34 pub fn new(
36 config: BeamSearchConfig,
37 constraint: RuleConstraint,
38 masker: Arc<dyn LogitMasker>,
39 ) -> Self {
40 Self {
41 inner: BeamSearchDecoder::new(config),
42 constraint: Arc::new(constraint),
43 masker,
44 }
45 }
46
47 pub fn config(&self) -> &BeamSearchConfig {
49 &self.inner.config
50 }
51
52 pub fn constraint(&self) -> &RuleConstraint {
54 &self.constraint
55 }
56
57 pub fn masker_name(&self) -> &'static str {
59 self.masker.name()
60 }
61
62 pub fn decode<F>(&self, bos_token_id: usize, score_fn: F) -> RuleGuidedResult<BeamSearchResult>
69 where
70 F: Fn(&[&[usize]]) -> Result<Vec<Vec<f64>>, String>,
71 {
72 let constraint = Arc::clone(&self.constraint);
73 let masker = Arc::clone(&self.masker);
74 let expected_vocab = self.inner.config.vocab_size;
75
76 let wrapped = move |beams: &[&[usize]]| -> Result<Vec<Vec<f64>>, String> {
77 let mut raw_logits = score_fn(beams)?;
78 for (beam_idx, logits_row) in raw_logits.iter_mut().enumerate() {
79 if logits_row.len() != expected_vocab {
80 return Err(format!(
81 "logits row {beam_idx} has width {}, expected {expected_vocab}",
82 logits_row.len()
83 ));
84 }
85 let prefix = beams.get(beam_idx).copied().unwrap_or(&[]);
86 masker
87 .apply(&constraint, prefix, logits_row)
88 .map_err(|e| format!("rule-guided mask error: {e}"))?;
89 }
90 Ok(raw_logits)
91 };
92
93 self.inner
94 .decode(bos_token_id, wrapped)
95 .map_err(RuleGuidedError::from)
96 }
97}
98
99impl std::fmt::Debug for RuleGuidedBeamSearch {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("RuleGuidedBeamSearch")
102 .field("config", &self.inner.config)
103 .field("constraint", &self.constraint)
104 .field("masker", &self.masker.name())
105 .finish()
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule_guided_decoder::constraint::TokenId;
    use crate::rule_guided_decoder::mask::{HardMask, SoftPenaltyMask};
    use tensorlogic_ir::{TLExpr, Term};

    /// Compiles the rule `entity(Alice) OR entity(Bob)` with a mapper that names
    /// token ids 0..=3 as `entity`, `Alice`, `Bob` and `Eve`; every other id maps
    /// to `None`.
    fn mk_constraint_alice_bob() -> RuleConstraint {
        let entity_of = |who: &'static str| TLExpr::Pred {
            name: "entity".into(),
            args: vec![Term::Const(who.into())],
        };
        let rule = TLExpr::Or(Box::new(entity_of("Alice")), Box::new(entity_of("Bob")));
        let token_symbol = |tid: TokenId| match tid {
            0 => Some("entity".into()),
            1 => Some("Alice".into()),
            2 => Some("Bob".into()),
            3 => Some("Eve".into()),
            _ => None,
        };
        RuleConstraint::compile(rule, token_symbol).expect("compile")
    }

    /// Small beam-search configuration whose vocabulary size (4) matches the
    /// width of the rows produced by `flat_scores`.
    fn flat_config() -> BeamSearchConfig {
        BeamSearchConfig {
            vocab_size: 4,
            beam_width: 2,
            min_length: 1,
            max_length: 4,
            eos_token_id: None,
            top_k_filter: None,
            temperature: 1.0,
            length_penalty: 0.0,
        }
    }

    /// Scoring function that gives every beam the same uniform four-wide row.
    fn flat_scores() -> impl Fn(&[&[usize]]) -> Result<Vec<Vec<f64>>, String> {
        |beams: &[&[usize]]| Ok(vec![vec![1.0_f64; 4]; beams.len()])
    }

    #[test]
    fn hard_mask_excludes_forbidden_token() {
        let guided = RuleGuidedBeamSearch::new(
            flat_config(),
            mk_constraint_alice_bob(),
            Arc::new(HardMask::new()),
        );

        let outcome = guided
            .decode(0, flat_scores())
            .expect("decode should succeed");

        // Token 3 ("Eve") is not mentioned by the rule, so no hypothesis may use it.
        for hyp in outcome.hypotheses.iter() {
            let emitted_forbidden = hyp.tokens.iter().any(|tok| *tok == 3);
            assert!(
                !emitted_forbidden,
                "hard-masked decoder emitted forbidden token: {:?}",
                hyp.tokens
            );
        }
        assert_eq!(guided.masker_name(), "HardMask");
    }

    #[test]
    fn soft_mask_allows_forbidden_when_lambda_is_zero() {
        let zero_penalty = SoftPenaltyMask::new(0.0).expect("lambda");
        let guided = RuleGuidedBeamSearch::new(
            flat_config(),
            mk_constraint_alice_bob(),
            Arc::new(zero_penalty),
        );

        let outcome = guided
            .decode(0, flat_scores())
            .expect("decode should succeed");

        assert!(!outcome.hypotheses.is_empty());
        assert_eq!(guided.masker_name(), "SoftPenaltyMask");
    }

    #[test]
    fn error_from_score_fn_is_propagated() {
        let guided = RuleGuidedBeamSearch::new(
            flat_config(),
            mk_constraint_alice_bob(),
            Arc::new(HardMask::new()),
        );

        let outcome = guided.decode(
            0,
            |_beams: &[&[usize]]| -> Result<Vec<Vec<f64>>, String> {
                Err(String::from("synthetic scoring error"))
            },
        );

        assert!(outcome.is_err());
        let rendered = outcome
            .expect_err("should have returned an error")
            .to_string();
        assert!(rendered.contains("synthetic"));
    }
}