torg_mask/
mapping.rs

1//! Token mapping between TØR-G tokens and LLM vocabulary IDs.
2
3use torg_core::Token;
4
/// Bidirectional table between TØR-G tokens and LLM vocabulary token IDs.
///
/// TØR-G has 9 fixed tokens plus a range of `Id` tokens. Each fixed token
/// is assigned its own vocabulary ID, while the `Id` tokens occupy one
/// contiguous run of vocabulary IDs described by `id_base` and `id_count`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenMapping {
    // Vocabulary IDs of the nine fixed tokens.
    or_id: u32,
    nor_id: u32,
    xor_id: u32,
    node_start_id: u32,
    node_end_id: u32,
    input_decl_id: u32,
    output_decl_id: u32,
    true_id: u32,
    false_id: u32,

    /// Vocabulary ID assigned to `Id(0)`; `Id(n)` maps to `id_base + n`.
    id_base: u32,

    /// How many `Id` tokens are mapped, i.e. `Id(0)` up to (but excluding)
    /// `Id(id_count)`.
    ///
    /// Defaults to 256, matching typical Limits. The value is bounded by
    /// practical usage (Limits::max_inputs + max_nodes), not by the range
    /// of `u16`.
    id_count: u16,
}
31
32impl TokenMapping {
33    /// Create a builder for custom token mappings.
34    pub fn builder() -> TokenMappingBuilder {
35        TokenMappingBuilder::new()
36    }
37
38    /// Example mapping using sequential IDs starting from 0.
39    ///
40    /// **Warning**: This is for testing only. In production, you must
41    /// map to actual unused token IDs in your LLM's vocabulary.
42    ///
43    /// Layout:
44    /// - 0: Or
45    /// - 1: Nor
46    /// - 2: Xor
47    /// - 3: NodeStart
48    /// - 4: NodeEnd
49    /// - 5: InputDecl
50    /// - 6: OutputDecl
51    /// - 7: True
52    /// - 8: False
53    /// - 9..265: Id(0)..Id(255)
54    pub fn sequential(id_count: u16) -> Self {
55        Self {
56            or_id: 0,
57            nor_id: 1,
58            xor_id: 2,
59            node_start_id: 3,
60            node_end_id: 4,
61            input_decl_id: 5,
62            output_decl_id: 6,
63            true_id: 7,
64            false_id: 8,
65            id_base: 9,
66            id_count,
67        }
68    }
69
70    /// Mapping for Mistral/Ministral models using reserved `<SPECIAL_N>` tokens.
71    ///
72    /// Ministral tokenizers reserve token IDs 36-565 as `<SPECIAL_N>` placeholders.
73    /// This mapping uses IDs 36-300 for TØR-G tokens:
74    ///
75    /// | TØR-G Token | Ministral ID |
76    /// |-------------|--------------|
77    /// | `Or`        | 36           |
78    /// | `Nor`       | 37           |
79    /// | `Xor`       | 38           |
80    /// | `NodeStart` | 39           |
81    /// | `NodeEnd`   | 40           |
82    /// | `InputDecl` | 41           |
83    /// | `OutputDecl`| 42           |
84    /// | `True`      | 43           |
85    /// | `False`     | 44           |
86    /// | `Id(0)`     | 45           |
87    /// | `Id(255)`   | 300          |
88    ///
89    /// Compatible with: Ministral-3B, Ministral-8B, Mistral-7B v0.3+
90    pub fn ministral() -> Self {
91        Self {
92            or_id: 36,
93            nor_id: 37,
94            xor_id: 38,
95            node_start_id: 39,
96            node_end_id: 40,
97            input_decl_id: 41,
98            output_decl_id: 42,
99            true_id: 43,
100            false_id: 44,
101            id_base: 45,
102            id_count: 256,
103        }
104    }
105
106    /// Map a TØR-G token to its LLM vocabulary ID.
107    ///
108    /// Returns `None` if the token cannot be mapped (e.g., Id out of range).
109    pub fn get(&self, token: Token) -> Option<u32> {
110        match token {
111            Token::Or => Some(self.or_id),
112            Token::Nor => Some(self.nor_id),
113            Token::Xor => Some(self.xor_id),
114            Token::NodeStart => Some(self.node_start_id),
115            Token::NodeEnd => Some(self.node_end_id),
116            Token::InputDecl => Some(self.input_decl_id),
117            Token::OutputDecl => Some(self.output_decl_id),
118            Token::True => Some(self.true_id),
119            Token::False => Some(self.false_id),
120            Token::Id(n) => {
121                if n < self.id_count {
122                    Some(self.id_base + n as u32)
123                } else {
124                    None
125                }
126            }
127        }
128    }
129
130    /// Map an LLM vocabulary ID back to a TØR-G token.
131    ///
132    /// Returns `None` if the ID doesn't correspond to any mapped token.
133    pub fn reverse(&self, id: u32) -> Option<Token> {
134        if id == self.or_id {
135            Some(Token::Or)
136        } else if id == self.nor_id {
137            Some(Token::Nor)
138        } else if id == self.xor_id {
139            Some(Token::Xor)
140        } else if id == self.node_start_id {
141            Some(Token::NodeStart)
142        } else if id == self.node_end_id {
143            Some(Token::NodeEnd)
144        } else if id == self.input_decl_id {
145            Some(Token::InputDecl)
146        } else if id == self.output_decl_id {
147            Some(Token::OutputDecl)
148        } else if id == self.true_id {
149            Some(Token::True)
150        } else if id == self.false_id {
151            Some(Token::False)
152        } else if id >= self.id_base && id < self.id_base + self.id_count as u32 {
153            Some(Token::Id((id - self.id_base) as u16))
154        } else {
155            None
156        }
157    }
158
159    /// Get the number of Id tokens mapped.
160    pub fn id_count(&self) -> u16 {
161        self.id_count
162    }
163
164    /// Get the total number of mapped tokens (9 fixed + id_count).
165    pub fn total_tokens(&self) -> usize {
166        9 + self.id_count as usize
167    }
168}
169
170impl Default for TokenMapping {
171    fn default() -> Self {
172        Self::sequential(256)
173    }
174}
175
/// Step-by-step constructor for a custom [`TokenMapping`].
///
/// Every vocabulary ID starts out unset (`None`) and has to be supplied
/// before `build` is called; only `id_count` carries a preset value.
#[derive(Debug, Clone)]
pub struct TokenMappingBuilder {
    // Vocabulary IDs of the fixed tokens; `None` until the setter is called.
    or_id: Option<u32>,
    nor_id: Option<u32>,
    xor_id: Option<u32>,
    node_start_id: Option<u32>,
    node_end_id: Option<u32>,
    input_decl_id: Option<u32>,
    output_decl_id: Option<u32>,
    true_id: Option<u32>,
    false_id: Option<u32>,
    // Vocabulary ID for Id(0); `None` until the setter is called.
    id_base: Option<u32>,
    // Number of Id tokens to map.
    id_count: u16,
}
191
192impl TokenMappingBuilder {
193    /// Create a new builder with no mappings set.
194    pub fn new() -> Self {
195        Self {
196            or_id: None,
197            nor_id: None,
198            xor_id: None,
199            node_start_id: None,
200            node_end_id: None,
201            input_decl_id: None,
202            output_decl_id: None,
203            true_id: None,
204            false_id: None,
205            id_base: None,
206            id_count: 256,
207        }
208    }
209
210    /// Set the LLM token ID for Or.
211    pub fn or(mut self, id: u32) -> Self {
212        self.or_id = Some(id);
213        self
214    }
215
216    /// Set the LLM token ID for Nor.
217    pub fn nor(mut self, id: u32) -> Self {
218        self.nor_id = Some(id);
219        self
220    }
221
222    /// Set the LLM token ID for Xor.
223    pub fn xor(mut self, id: u32) -> Self {
224        self.xor_id = Some(id);
225        self
226    }
227
228    /// Set the LLM token ID for NodeStart.
229    pub fn node_start(mut self, id: u32) -> Self {
230        self.node_start_id = Some(id);
231        self
232    }
233
234    /// Set the LLM token ID for NodeEnd.
235    pub fn node_end(mut self, id: u32) -> Self {
236        self.node_end_id = Some(id);
237        self
238    }
239
240    /// Set the LLM token ID for InputDecl.
241    pub fn input_decl(mut self, id: u32) -> Self {
242        self.input_decl_id = Some(id);
243        self
244    }
245
246    /// Set the LLM token ID for OutputDecl.
247    pub fn output_decl(mut self, id: u32) -> Self {
248        self.output_decl_id = Some(id);
249        self
250    }
251
252    /// Set the LLM token ID for True.
253    pub fn true_token(mut self, id: u32) -> Self {
254        self.true_id = Some(id);
255        self
256    }
257
258    /// Set the LLM token ID for False.
259    pub fn false_token(mut self, id: u32) -> Self {
260        self.false_id = Some(id);
261        self
262    }
263
264    /// Set the base LLM token ID for Id tokens.
265    /// Id(n) will map to id_base + n.
266    pub fn id_base(mut self, base: u32) -> Self {
267        self.id_base = Some(base);
268        self
269    }
270
271    /// Set how many Id tokens to map (default 256).
272    pub fn id_count(mut self, count: u16) -> Self {
273        self.id_count = count;
274        self
275    }
276
277    /// Build the token mapping.
278    ///
279    /// # Panics
280    ///
281    /// Panics if any required mapping is not set.
282    pub fn build(self) -> TokenMapping {
283        TokenMapping {
284            or_id: self.or_id.expect("or_id not set"),
285            nor_id: self.nor_id.expect("nor_id not set"),
286            xor_id: self.xor_id.expect("xor_id not set"),
287            node_start_id: self.node_start_id.expect("node_start_id not set"),
288            node_end_id: self.node_end_id.expect("node_end_id not set"),
289            input_decl_id: self.input_decl_id.expect("input_decl_id not set"),
290            output_decl_id: self.output_decl_id.expect("output_decl_id not set"),
291            true_id: self.true_id.expect("true_id not set"),
292            false_id: self.false_id.expect("false_id not set"),
293            id_base: self.id_base.expect("id_base not set"),
294            id_count: self.id_count,
295        }
296    }
297}
298
299impl Default for TokenMappingBuilder {
300    fn default() -> Self {
301        Self::new()
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Map `token` to its vocabulary ID and back, expecting the same token.
    fn assert_round_trip(mapping: &TokenMapping, token: Token) {
        let id = mapping.get(token).unwrap();
        assert_eq!(mapping.reverse(id), Some(token));
    }

    #[test]
    fn test_sequential_mapping() {
        let mapping = TokenMapping::sequential(256);

        // (token, expected vocabulary ID) for the sequential layout.
        let expected = [
            (Token::Or, Some(0)),
            (Token::Nor, Some(1)),
            (Token::Xor, Some(2)),
            (Token::NodeStart, Some(3)),
            (Token::NodeEnd, Some(4)),
            (Token::InputDecl, Some(5)),
            (Token::OutputDecl, Some(6)),
            (Token::True, Some(7)),
            (Token::False, Some(8)),
            (Token::Id(0), Some(9)),
            (Token::Id(255), Some(264)),
            (Token::Id(256), None),
        ];

        for (token, id) in expected {
            assert_eq!(mapping.get(token), id);
        }
    }

    #[test]
    fn test_reverse_mapping() {
        let mapping = TokenMapping::sequential(256);

        // (vocabulary ID, expected token) including both ends of the Id range.
        let expected = [
            (0, Some(Token::Or)),
            (1, Some(Token::Nor)),
            (9, Some(Token::Id(0))),
            (264, Some(Token::Id(255))),
            (265, None),
            (1000, None),
        ];

        for (id, token) in expected {
            assert_eq!(mapping.reverse(id), token);
        }
    }

    #[test]
    fn test_round_trip() {
        let mapping = TokenMapping::sequential(256);

        for token in [
            Token::Or,
            Token::Nor,
            Token::Xor,
            Token::NodeStart,
            Token::NodeEnd,
            Token::InputDecl,
            Token::OutputDecl,
            Token::True,
            Token::False,
            Token::Id(0),
            Token::Id(42),
            Token::Id(255),
        ] {
            assert_round_trip(&mapping, token);
        }
    }

    #[test]
    fn test_builder() {
        let builder = TokenMapping::builder()
            .or(100)
            .nor(101)
            .xor(102)
            .node_start(103)
            .node_end(104)
            .input_decl(105)
            .output_decl(106)
            .true_token(107)
            .false_token(108)
            .id_base(1000)
            .id_count(128);
        let mapping = builder.build();

        let expected = [
            (Token::Or, Some(100)),
            (Token::Id(0), Some(1000)),
            (Token::Id(127), Some(1127)),
            (Token::Id(128), None),
        ];

        for (token, id) in expected {
            assert_eq!(mapping.get(token), id);
        }
    }

    #[test]
    fn test_ministral_mapping() {
        let mapping = TokenMapping::ministral();

        let expected = [
            // Fixed tokens use reserved <SPECIAL_N> IDs 36-44
            (Token::Or, Some(36)),
            (Token::Nor, Some(37)),
            (Token::Xor, Some(38)),
            (Token::NodeStart, Some(39)),
            (Token::NodeEnd, Some(40)),
            (Token::InputDecl, Some(41)),
            (Token::OutputDecl, Some(42)),
            (Token::True, Some(43)),
            (Token::False, Some(44)),
            // Id tokens start at 45
            (Token::Id(0), Some(45)),
            (Token::Id(255), Some(300)),
            (Token::Id(256), None),
        ];

        for (token, id) in expected {
            assert_eq!(mapping.get(token), id);
        }

        // Total: 9 fixed + 256 Id = 265 tokens
        assert_eq!(mapping.total_tokens(), 265);
    }

    #[test]
    fn test_ministral_round_trip() {
        let mapping = TokenMapping::ministral();

        // Every fixed token
        for token in [
            Token::Or,
            Token::Nor,
            Token::Xor,
            Token::NodeStart,
            Token::NodeEnd,
            Token::InputDecl,
            Token::OutputDecl,
            Token::True,
            Token::False,
        ] {
            assert_round_trip(&mapping, token);
        }

        // Id tokens at the boundaries of the mapped range
        for n in [0, 1, 127, 128, 254, 255] {
            assert_round_trip(&mapping, Token::Id(n));
        }
    }
}