1use torg_core::{BuildError, Builder, Graph, Limits, Phase};
4
5use crate::generator::MaskGenerator;
6use crate::mask::LogitMask;
7
8#[derive(Debug)]
31pub struct ConstrainedDecoder {
32 builder: Builder,
33 generator: MaskGenerator,
34}
35
36impl ConstrainedDecoder {
37 pub fn new(generator: MaskGenerator) -> Self {
39 Self {
40 builder: Builder::new(),
41 generator,
42 }
43 }
44
45 pub fn with_limits(generator: MaskGenerator, limits: Limits) -> Self {
47 Self {
48 builder: Builder::with_limits(limits),
49 generator,
50 }
51 }
52
53 pub fn phase(&self) -> Phase {
55 self.builder.phase()
56 }
57
58 pub fn is_complete(&self) -> bool {
63 self.builder.phase() == Phase::Done || self.builder.is_completable()
64 }
65
66 pub fn next_mask(&self) -> LogitMask {
70 self.generator.generate_from_builder(&self.builder)
71 }
72
73 pub fn feed_token(&mut self, llm_token_id: u32) -> Result<(), DecodeError> {
84 let token = self
85 .generator
86 .mapping()
87 .reverse(llm_token_id)
88 .ok_or(DecodeError::UnmappedToken(llm_token_id))?;
89
90 self.builder.push(token).map_err(DecodeError::BuildError)?;
91
92 Ok(())
93 }
94
95 pub fn finish(self) -> Result<Graph, BuildError> {
101 self.builder.finish()
102 }
103
104 pub fn builder(&self) -> &Builder {
106 &self.builder
107 }
108
109 pub fn generator(&self) -> &MaskGenerator {
111 &self.generator
112 }
113}
114
115#[derive(Debug, Clone, PartialEq, Eq)]
117pub enum DecodeError {
118 UnmappedToken(u32),
120 BuildError(BuildError),
122}
123
124impl std::fmt::Display for DecodeError {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 match self {
127 DecodeError::UnmappedToken(id) => write!(f, "unmapped LLM token ID: {}", id),
128 DecodeError::BuildError(e) => write!(f, "build error: {}", e),
129 }
130 }
131}
132
133impl std::error::Error for DecodeError {
134 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
135 match self {
136 DecodeError::BuildError(e) => Some(e),
137 _ => None,
138 }
139 }
140}
141
142impl From<BuildError> for DecodeError {
143 fn from(e: BuildError) -> Self {
144 DecodeError::BuildError(e)
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mapping::TokenMapping;
    use torg_core::Token;

    /// Shared fixture: a decoder over `TokenMapping::sequential(256)` whose
    /// generator is constructed with `1000` as its second argument.
    fn make_decoder() -> ConstrainedDecoder {
        let generator = MaskGenerator::new(TokenMapping::sequential(256), 1000);
        ConstrainedDecoder::new(generator)
    }

    /// Feeds a ten-id sequence one id at a time. Before each feed the decoder
    /// must not be complete and the mask must admit the id.
    #[test]
    fn test_simple_decode() {
        let mut decoder = make_decoder();
        let token_ids = [5, 9, 3, 10, 0, 9, 7, 4, 6, 10];

        for id in token_ids {
            assert!(!decoder.is_complete());
            assert!(
                decoder.next_mask().is_allowed(id),
                "token {} should be allowed",
                id
            );
            decoder.feed_token(id).unwrap();
        }

        let graph = decoder.finish().unwrap();
        assert_eq!(graph.inputs, vec![0]);
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.outputs, vec![1]);
    }

    /// An id outside the mapping is reported as `UnmappedToken`.
    #[test]
    fn test_unmapped_token_error() {
        let mut decoder = make_decoder();
        assert!(matches!(
            decoder.feed_token(9999),
            Err(DecodeError::UnmappedToken(9999))
        ));
    }

    /// A mapped id that the builder rejects surfaces as `BuildError`.
    #[test]
    fn test_invalid_token_error() {
        let mut decoder = make_decoder();
        decoder.feed_token(5).unwrap();
        assert!(matches!(
            decoder.feed_token(4),
            Err(DecodeError::BuildError(_))
        ));
    }

    /// Checks the mask at the start and again after four tokens have been fed.
    #[test]
    fn test_mask_constraints() {
        let mut decoder = make_decoder();

        // Ids the very first mask must admit.
        let initial_mask = decoder.next_mask();
        for id in [5, 3, 6] {
            assert!(initial_mask.is_allowed(id));
        }

        for id in [5, 9, 3, 10] {
            decoder.feed_token(id).unwrap();
        }

        // At this point exactly ids 0, 1 and 2 are allowed.
        let mask = decoder.next_mask();
        assert_eq!(mask.allowed_count(), 3);
        for id in [0, 1, 2] {
            assert!(mask.is_allowed(id));
        }
        for id in [3, 4] {
            assert!(!mask.is_allowed(id));
        }
    }

    /// Drives a full program expressed as `Token`s, translating each through
    /// the mapping and verifying the mask admits it before feeding.
    #[test]
    fn test_full_decode_loop() {
        let mapping = TokenMapping::sequential(256);
        let mut decoder = ConstrainedDecoder::new(MaskGenerator::new(mapping.clone(), 1000));

        let program = [
            // Three input declarations.
            Token::InputDecl,
            Token::Id(0),
            Token::InputDecl,
            Token::Id(1),
            Token::InputDecl,
            Token::Id(2),
            // Node 3 = Xor(1, 2).
            Token::NodeStart,
            Token::Id(3),
            Token::Xor,
            Token::Id(1),
            Token::Id(2),
            Token::NodeEnd,
            // Node 4 = Or(0, 3).
            Token::NodeStart,
            Token::Id(4),
            Token::Or,
            Token::Id(0),
            Token::Id(3),
            Token::NodeEnd,
            // Output declaration for node 4.
            Token::OutputDecl,
            Token::Id(4),
        ];

        for token in program {
            let mask = decoder.next_mask();
            let id = mapping.get(token).unwrap();
            assert!(
                mask.is_allowed(id),
                "token {:?} (id {}) should be allowed",
                token,
                id
            );
            decoder.feed_token(id).unwrap();
        }

        let graph = decoder.finish().unwrap();
        assert_eq!(graph.inputs, vec![0, 1, 2]);
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.outputs, vec![4]);
    }
}