// torg_mask/generator.rs

//! Mask generator that turns TØR-G builder state into logit masks.
2
3use torg_core::{Builder, Token};
4
5use crate::mapping::TokenMapping;
6use crate::mask::LogitMask;
7
8/// Generates logit masks from TØR-G builder state.
9///
10/// This struct combines a token mapping with a vocabulary size to
11/// produce `LogitMask` instances from `Builder::valid_next_tokens()`.
12#[derive(Debug, Clone)]
13pub struct MaskGenerator {
14    mapping: TokenMapping,
15    vocab_size: usize,
16}
17
18impl MaskGenerator {
19    /// Create a new mask generator.
20    ///
21    /// # Arguments
22    ///
23    /// * `mapping` - Token mapping from TØR-G tokens to LLM vocab IDs
24    /// * `vocab_size` - Size of the LLM vocabulary
25    pub fn new(mapping: TokenMapping, vocab_size: usize) -> Self {
26        Self {
27            mapping,
28            vocab_size,
29        }
30    }
31
32    /// Get the token mapping.
33    pub fn mapping(&self) -> &TokenMapping {
34        &self.mapping
35    }
36
37    /// Get the vocabulary size.
38    pub fn vocab_size(&self) -> usize {
39        self.vocab_size
40    }
41
42    /// Generate a logit mask from a list of valid tokens.
43    ///
44    /// Tokens that cannot be mapped are silently ignored.
45    pub fn generate(&self, valid_tokens: &[Token]) -> LogitMask {
46        let allowed: Vec<u32> = valid_tokens
47            .iter()
48            .filter_map(|&token| self.mapping.get(token))
49            .collect();
50
51        LogitMask::new(self.vocab_size, allowed)
52    }
53
54    /// Generate a logit mask from builder state.
55    ///
56    /// This is a convenience method that calls `builder.valid_next_tokens()`
57    /// and then generates the mask.
58    pub fn generate_from_builder(&self, builder: &Builder) -> LogitMask {
59        self.generate(&builder.valid_next_tokens())
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use torg_core::Token;

    /// Vocabulary size used by every generator in these tests.
    const VOCAB: usize = 1000;

    /// Wraps `mapping` in a generator over a `VOCAB`-sized vocabulary.
    fn generator_for(mapping: TokenMapping) -> MaskGenerator {
        MaskGenerator::new(mapping, VOCAB)
    }

    #[test]
    fn test_generate_operators() {
        let generator = generator_for(TokenMapping::sequential(256));

        let mask = generator.generate(&[Token::Or, Token::Nor, Token::Xor]);

        assert_eq!(mask.allowed_count(), 3);
        // Vocab IDs 0, 1, 2 are Or, Nor, Xor respectively.
        for id in [0, 1, 2] {
            assert!(mask.is_allowed(id));
        }
        // ID 3 is NodeStart, which was not offered.
        assert!(!mask.is_allowed(3));
    }

    #[test]
    fn test_generate_with_ids() {
        let generator = generator_for(TokenMapping::sequential(256));

        let mask = generator.generate(&[Token::Id(0), Token::Id(1), Token::True, Token::False]);

        assert_eq!(mask.allowed_count(), 4);
        // Id(0) -> 9, Id(1) -> 10, True -> 7, False -> 8.
        for id in [9, 10, 7, 8] {
            assert!(mask.is_allowed(id));
        }
    }

    #[test]
    fn test_generate_from_builder() {
        let generator = generator_for(TokenMapping::sequential(256));

        let mut builder = Builder::new();
        for token in [Token::InputDecl, Token::Id(0), Token::NodeStart, Token::Id(1)] {
            builder.push(token).unwrap();
        }

        // Once a node ID has been pushed, the builder only accepts operators.
        let mask = generator.generate_from_builder(&builder);

        assert_eq!(mask.allowed_count(), 3);
        // Or, Nor, Xor.
        for id in [0, 1, 2] {
            assert!(mask.is_allowed(id));
        }
    }

    #[test]
    fn test_unmapped_tokens_ignored() {
        // The mapping only knows Id(0)..Id(9), so Id(100) has no vocab ID.
        let generator = generator_for(TokenMapping::sequential(10));

        let mask = generator.generate(&[Token::Id(0), Token::Id(100)]);

        assert_eq!(mask.allowed_count(), 1);
        assert!(mask.is_allowed(9)); // Id(0)
    }
}