1use torg_core::{Builder, Token};
4
5use crate::mapping::TokenMapping;
6use crate::mask::LogitMask;
7
8#[derive(Debug, Clone)]
13pub struct MaskGenerator {
14 mapping: TokenMapping,
15 vocab_size: usize,
16}
17
18impl MaskGenerator {
19 pub fn new(mapping: TokenMapping, vocab_size: usize) -> Self {
26 Self {
27 mapping,
28 vocab_size,
29 }
30 }
31
32 pub fn mapping(&self) -> &TokenMapping {
34 &self.mapping
35 }
36
37 pub fn vocab_size(&self) -> usize {
39 self.vocab_size
40 }
41
42 pub fn generate(&self, valid_tokens: &[Token]) -> LogitMask {
46 let allowed: Vec<u32> = valid_tokens
47 .iter()
48 .filter_map(|&token| self.mapping.get(token))
49 .collect();
50
51 LogitMask::new(self.vocab_size, allowed)
52 }
53
54 pub fn generate_from_builder(&self, builder: &Builder) -> LogitMask {
59 self.generate(&builder.valid_next_tokens())
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use torg_core::Token;

    /// Generator used by most tests: `TokenMapping::sequential(256)` with a
    /// vocabulary size of 1000.
    fn default_generator() -> MaskGenerator {
        MaskGenerator::new(TokenMapping::sequential(256), 1000)
    }

    /// The three operator tokens are expected to occupy IDs 0, 1 and 2.
    #[test]
    fn test_generate_operators() {
        let mask = default_generator().generate(&[Token::Or, Token::Nor, Token::Xor]);

        assert_eq!(mask.allowed_count(), 3);
        for id in [0, 1, 2] {
            assert!(mask.is_allowed(id));
        }
        // The next ID up must stay masked out.
        assert!(!mask.is_allowed(3));
    }

    /// `Id(0)`, `Id(1)`, `True` and `False` are expected to cover IDs 9, 10, 7 and 8.
    #[test]
    fn test_generate_with_ids() {
        let tokens = [Token::Id(0), Token::Id(1), Token::True, Token::False];
        let mask = default_generator().generate(&tokens);

        assert_eq!(mask.allowed_count(), 4);
        for id in [9, 10, 7, 8] {
            assert!(mask.is_allowed(id));
        }
    }

    /// After `InputDecl Id(0) NodeStart Id(1)` the builder's valid next tokens
    /// are expected to map to exactly the IDs 0, 1 and 2.
    #[test]
    fn test_generate_from_builder() {
        let generator = default_generator();

        let mut builder = Builder::new();
        let prefix = [
            Token::InputDecl,
            Token::Id(0),
            Token::NodeStart,
            Token::Id(1),
        ];
        for token in prefix {
            builder.push(token).unwrap();
        }

        let mask = generator.generate_from_builder(&builder);

        assert_eq!(mask.allowed_count(), 3);
        for id in [0, 1, 2] {
            assert!(mask.is_allowed(id));
        }
    }

    /// With `TokenMapping::sequential(10)`, `Id(100)` is expected to have no
    /// mapping, so only `Id(0)` (ID 9) ends up in the mask.
    #[test]
    fn test_unmapped_tokens_ignored() {
        let generator = MaskGenerator::new(TokenMapping::sequential(10), 1000);

        let mask = generator.generate(&[Token::Id(0), Token::Id(100)]);

        assert_eq!(mask.allowed_count(), 1);
        assert!(mask.is_allowed(9));
    }
}