torg_mask/
decoder.rs

1//! High-level constrained decoding orchestrator.
2
3use torg_core::{BuildError, Builder, Graph, Limits, Phase};
4
5use crate::generator::MaskGenerator;
6use crate::mask::LogitMask;
7
8/// High-level orchestrator for constrained LLM decoding.
9///
10/// This struct manages the decode loop for generating TØR-G graphs
11/// from an LLM. It maintains builder state, generates masks, and
12/// converts LLM tokens back to TØR-G tokens.
13///
14/// # Example
15///
16/// ```ignore
17/// let mapping = TokenMapping::sequential(256);
18/// let generator = MaskGenerator::new(mapping, vocab_size);
19/// let mut decoder = ConstrainedDecoder::new(generator);
20///
21/// while !decoder.is_complete() {
22///     let mask = decoder.next_mask();
23///     mask.apply_to_logits(&mut logits);
24///     let token_id = sample(&logits);
25///     decoder.feed_token(token_id)?;
26/// }
27///
28/// let graph = decoder.finish()?;
29/// ```
30#[derive(Debug)]
31pub struct ConstrainedDecoder {
32    builder: Builder,
33    generator: MaskGenerator,
34}
35
36impl ConstrainedDecoder {
37    /// Create a new constrained decoder with default limits.
38    pub fn new(generator: MaskGenerator) -> Self {
39        Self {
40            builder: Builder::new(),
41            generator,
42        }
43    }
44
45    /// Create a new constrained decoder with custom limits.
46    pub fn with_limits(generator: MaskGenerator, limits: Limits) -> Self {
47        Self {
48            builder: Builder::with_limits(limits),
49            generator,
50        }
51    }
52
53    /// Get the current build phase.
54    pub fn phase(&self) -> Phase {
55        self.builder.phase()
56    }
57
58    /// Check if decoding is complete.
59    ///
60    /// Returns `true` when the graph is in a completable state (at least
61    /// one output declared) or when the builder is in the `Done` phase.
62    pub fn is_complete(&self) -> bool {
63        self.builder.phase() == Phase::Done || self.builder.is_completable()
64    }
65
66    /// Generate the logit mask for the next token.
67    ///
68    /// Apply this mask to your LLM's logits before sampling.
69    pub fn next_mask(&self) -> LogitMask {
70        self.generator.generate_from_builder(&self.builder)
71    }
72
73    /// Feed an LLM token ID into the decoder.
74    ///
75    /// The token ID is reverse-mapped to a TØR-G token and pushed
76    /// to the internal builder.
77    ///
78    /// # Errors
79    ///
80    /// Returns an error if:
81    /// - The token ID doesn't map to a valid TØR-G token
82    /// - The token is not valid in the current builder state
83    pub fn feed_token(&mut self, llm_token_id: u32) -> Result<(), DecodeError> {
84        let token = self
85            .generator
86            .mapping()
87            .reverse(llm_token_id)
88            .ok_or(DecodeError::UnmappedToken(llm_token_id))?;
89
90        self.builder.push(token).map_err(DecodeError::BuildError)?;
91
92        Ok(())
93    }
94
95    /// Finish decoding and return the constructed graph.
96    ///
97    /// # Errors
98    ///
99    /// Returns an error if the graph is incomplete (e.g., no outputs declared).
100    pub fn finish(self) -> Result<Graph, BuildError> {
101        self.builder.finish()
102    }
103
104    /// Get a reference to the internal builder.
105    pub fn builder(&self) -> &Builder {
106        &self.builder
107    }
108
109    /// Get the mask generator.
110    pub fn generator(&self) -> &MaskGenerator {
111        &self.generator
112    }
113}
114
115/// Errors that can occur during constrained decoding.
116#[derive(Debug, Clone, PartialEq, Eq)]
117pub enum DecodeError {
118    /// The LLM token ID doesn't map to any TØR-G token.
119    UnmappedToken(u32),
120    /// The TØR-G token is not valid in the current state.
121    BuildError(BuildError),
122}
123
124impl std::fmt::Display for DecodeError {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        match self {
127            DecodeError::UnmappedToken(id) => write!(f, "unmapped LLM token ID: {}", id),
128            DecodeError::BuildError(e) => write!(f, "build error: {}", e),
129        }
130    }
131}
132
133impl std::error::Error for DecodeError {
134    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
135        match self {
136            DecodeError::BuildError(e) => Some(e),
137            _ => None,
138        }
139    }
140}
141
142impl From<BuildError> for DecodeError {
143    fn from(e: BuildError) -> Self {
144        DecodeError::BuildError(e)
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mapping::TokenMapping;
    use torg_core::Token;

    /// Decoder over `TokenMapping::sequential(256)` with a vocabulary size of 1000.
    fn make_decoder() -> ConstrainedDecoder {
        ConstrainedDecoder::new(MaskGenerator::new(TokenMapping::sequential(256), 1000))
    }

    #[test]
    fn test_simple_decode() {
        // LLM token IDs spelling out:
        // InputDecl Id(0) NodeStart Id(1) Or Id(0) True NodeEnd OutputDecl Id(1)
        let script = [
            5,  // InputDecl
            9,  // Id(0)
            3,  // NodeStart
            10, // Id(1)
            0,  // Or
            9,  // Id(0)
            7,  // True
            4,  // NodeEnd
            6,  // OutputDecl
            10, // Id(1)
        ];

        let mut decoder = make_decoder();
        for id in script.iter().copied() {
            // The decoder must not report completion before the last token is fed.
            assert!(!decoder.is_complete());
            assert!(
                decoder.next_mask().is_allowed(id),
                "token {} should be allowed",
                id
            );
            decoder.feed_token(id).unwrap();
        }

        let graph = decoder.finish().unwrap();
        assert_eq!(graph.inputs, vec![0]);
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.outputs, vec![1]);
    }

    #[test]
    fn test_unmapped_token_error() {
        let mut decoder = make_decoder();
        // 9999 has no reverse mapping, so it must be rejected as unmapped.
        assert!(matches!(
            decoder.feed_token(9999),
            Err(DecodeError::UnmappedToken(9999))
        ));
    }

    #[test]
    fn test_invalid_token_error() {
        let mut decoder = make_decoder();
        decoder.feed_token(5).unwrap(); // InputDecl

        // NodeEnd directly after InputDecl must be refused by the builder.
        assert!(matches!(
            decoder.feed_token(4),
            Err(DecodeError::BuildError(_))
        ));
    }

    #[test]
    fn test_mask_constraints() {
        let mut decoder = make_decoder();

        // At the start InputDecl (5), NodeStart (3) and OutputDecl (6) are allowed.
        let initial = decoder.next_mask();
        for id in [5, 3, 6] {
            assert!(initial.is_allowed(id));
        }

        // Feed: InputDecl, Id(0), NodeStart, Id(1).
        for id in [5, 9, 3, 10] {
            decoder.feed_token(id).unwrap();
        }

        // After `NodeStart Id`, exactly the three operators are allowed.
        let after_node_id = decoder.next_mask();
        assert_eq!(after_node_id.allowed_count(), 3);
        for operator in 0..3 {
            // 0 = Or, 1 = Nor, 2 = Xor
            assert!(after_node_id.is_allowed(operator));
        }
        for id in [3, 4] {
            // NodeStart and NodeEnd are both rejected at this point.
            assert!(!after_node_id.is_allowed(id));
        }
    }

    #[test]
    fn test_full_decode_loop() {
        let mapping = TokenMapping::sequential(256);
        let mut decoder = ConstrainedDecoder::new(MaskGenerator::new(mapping.clone(), 1000));

        // Encodes "Admin OR (Owner XOR Public)":
        //   inputs 0, 1, 2
        //   node 3 = Xor(1, 2)
        //   node 4 = Or(0, 3)
        //   output 4
        let tokens = [
            Token::InputDecl,
            Token::Id(0),
            Token::InputDecl,
            Token::Id(1),
            Token::InputDecl,
            Token::Id(2),
            Token::NodeStart,
            Token::Id(3),
            Token::Xor,
            Token::Id(1),
            Token::Id(2),
            Token::NodeEnd,
            Token::NodeStart,
            Token::Id(4),
            Token::Or,
            Token::Id(0),
            Token::Id(3),
            Token::NodeEnd,
            Token::OutputDecl,
            Token::Id(4),
        ];

        for token in tokens.iter().copied() {
            // Take the mask before translating the token, then check and feed it.
            let mask = decoder.next_mask();
            let id = mapping.get(token).unwrap();
            assert!(
                mask.is_allowed(id),
                "token {:?} (id {}) should be allowed",
                token,
                id
            );
            decoder.feed_token(id).unwrap();
        }

        let graph = decoder.finish().unwrap();
        assert_eq!(graph.inputs, vec![0, 1, 2]);
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.outputs, vec![4]);
    }
}