//! trustformers_tokenizers — subword_regularization.rs
//!
//! Subword regularization utilities: a generic tokenizer wrapper that injects
//! controlled randomness into encoding, and a Unigram-model sampler that draws
//! alternative segmentations from a scored vocabulary.
1use scirs2_core::random::*; // SciRS2 Integration Policy - Replaces rand
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use trustformers_core::errors::Result;
5use trustformers_core::traits::{TokenizedInput, Tokenizer};
6
/// Configuration for subword regularization
///
/// Controls how much randomness is injected into tokenization and how many
/// alternative segmentations are drawn per input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubwordRegularizationConfig {
    /// Alpha parameter for controlling randomness (0.0 = no randomness, 1.0 = maximum randomness)
    pub alpha: f32,
    /// Number of alternative segmentations to sample
    pub num_samples: usize,
    /// Seed for reproducible randomness; `None` means a fresh seed is drawn.
    pub seed: Option<u64>,
    /// Enable debugging output
    pub debug: bool,
}
19
20impl Default for SubwordRegularizationConfig {
21    fn default() -> Self {
22        Self {
23            alpha: 0.1,
24            num_samples: 1,
25            seed: None,
26            debug: false,
27        }
28    }
29}
30
/// Subword regularization wrapper that adds randomness to tokenization
///
/// Wraps any [`Tokenizer`] together with a seeded RNG so that sampled
/// encodings are reproducible when a seed is configured.
pub struct SubwordRegularizer<T: Tokenizer> {
    // Underlying tokenizer used for actual encode/decode calls.
    tokenizer: T,
    // Sampling parameters (alpha, num_samples, seed, debug).
    config: SubwordRegularizationConfig,
    // RNG driving all perturbation; seeded from `config.seed` when present.
    rng: StdRng,
}
37
impl<T: Tokenizer> SubwordRegularizer<T> {
    /// Wrap `tokenizer` with the given regularization settings.
    ///
    /// The internal RNG is seeded from `config.seed` when present; otherwise
    /// a fresh seed is drawn from `thread_rng()`, making runs nondeterministic.
    pub fn new(tokenizer: T, config: SubwordRegularizationConfig) -> Self {
        let rng = if let Some(seed) = config.seed {
            StdRng::seed_from_u64(seed)
        } else {
            // Generate random seed from thread_rng
            let seed = thread_rng().random();
            StdRng::seed_from_u64(seed)
        };

        Self {
            tokenizer,
            config,
            rng,
        }
    }

    /// Builder-style setter for `alpha` (randomness strength).
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        self.config.alpha = alpha;
        self
    }

    /// Builder-style setter for the number of segmentation samples.
    pub fn with_num_samples(mut self, num_samples: usize) -> Self {
        self.config.num_samples = num_samples;
        self
    }

    /// Builder-style setter for the seed; re-seeds the RNG immediately so
    /// all subsequent sampling is reproducible from this point.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.config.seed = Some(seed);
        self.rng = StdRng::seed_from_u64(seed);
        self
    }

    /// Generate multiple tokenizations with regularization
    ///
    /// Produces `config.num_samples` tokenizations, each from an
    /// independently perturbed copy of `text`. Returns the first encoding
    /// error from the underlying tokenizer, if any.
    pub fn encode_with_regularization(&mut self, text: &str) -> Result<Vec<TokenizedInput>> {
        let mut results = Vec::new();

        for _ in 0..self.config.num_samples {
            let regularized_text = self.apply_regularization(text);
            let tokenized = self.tokenizer.encode(&regularized_text)?;
            results.push(tokenized);
        }

        Ok(results)
    }

    /// Apply regularization to text (simplified version)
    ///
    /// NOTE(review): this perturbs the *characters* of the input (random
    /// skip/duplicate) rather than sampling alternative subword
    /// segmentations; it is a stand-in for true subword regularization.
    /// The RNG draw order below is load-bearing for seeded reproducibility.
    fn apply_regularization(&mut self, text: &str) -> String {
        // alpha <= 0 disables all perturbation: return the input unchanged.
        if self.config.alpha <= 0.0 {
            return text.to_string();
        }

        let mut result = String::new();
        let chars: Vec<char> = text.chars().collect();
        let mut i = 0;

        while i < chars.len() {
            let char = chars[i];

            // With probability `alpha`, consider perturbing this character
            // (each branch below consumes an additional RNG draw).
            if self.rng.random::<f32>() < self.config.alpha {
                // Skip character with some probability (0.1, nested draw)
                if self.rng.random::<f32>() < 0.1 {
                    i += 1;
                    continue;
                }

                // Duplicate character with some probability (0.05, nested draw)
                if self.rng.random::<f32>() < 0.05 {
                    result.push(char);
                    result.push(char);
                    i += 1;
                    continue;
                }
            }

            // Default path: copy the character through unchanged.
            result.push(char);
            i += 1;
        }

        result
    }

    /// Get the underlying tokenizer
    pub fn inner(&self) -> &T {
        &self.tokenizer
    }

    /// Get the configuration
    pub fn config(&self) -> &SubwordRegularizationConfig {
        &self.config
    }
}
131
/// `Tokenizer` passthrough: the basic trait interface delegates directly to
/// the wrapped tokenizer with no regularization applied. Use
/// `encode_with_regularization` to obtain perturbed samples.
impl<T: Tokenizer> Tokenizer for SubwordRegularizer<T> {
    fn encode(&self, text: &str) -> Result<TokenizedInput> {
        // For the basic interface, just use the underlying tokenizer
        self.tokenizer.encode(text)
    }

    fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
        self.tokenizer.encode_pair(text, text2)
    }

    fn decode(&self, ids: &[u32]) -> Result<String> {
        self.tokenizer.decode(ids)
    }

    fn vocab_size(&self) -> usize {
        self.tokenizer.vocab_size()
    }

    fn get_vocab(&self) -> HashMap<String, u32> {
        self.tokenizer.get_vocab()
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.tokenizer.token_to_id(token)
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.tokenizer.id_to_token(id)
    }
}
162
/// Unigram-specific subword regularization implementation
///
/// Samples segmentations from a scored piece vocabulary via a DP with
/// noise-perturbed scores (see `sample_segmentation`).
pub struct UnigramSubwordRegularizer {
    // Subword piece -> score table driving the segmentation DP.
    vocab: HashMap<String, f32>,
    // Sampling parameters (alpha controls the noise magnitude).
    config: SubwordRegularizationConfig,
    // RNG supplying the per-piece score noise; seeded from `config.seed`.
    rng: StdRng,
}
169
170impl UnigramSubwordRegularizer {
171    pub fn new(vocab: HashMap<String, f32>, config: SubwordRegularizationConfig) -> Self {
172        let rng = if let Some(seed) = config.seed {
173            StdRng::seed_from_u64(seed)
174        } else {
175            // Generate random seed from thread_rng
176            let seed = thread_rng().random();
177            StdRng::seed_from_u64(seed)
178        };
179
180        Self { vocab, config, rng }
181    }
182
183    /// Sample a segmentation using the Unigram language model with regularization
184    pub fn sample_segmentation(&mut self, text: &str) -> Result<Vec<String>> {
185        let chars: Vec<char> = text.chars().collect();
186        let n = chars.len();
187
188        if n == 0 {
189            return Ok(vec![]);
190        }
191
192        // Dynamic programming with sampling
193        let mut dp = vec![vec![0.0; n + 1]; n + 1];
194        let mut best_seg = vec![vec![None; n + 1]; n + 1];
195
196        // Initialize
197        for (i, dp_row) in dp.iter_mut().enumerate().take(n + 1) {
198            dp_row[i] = 0.0;
199        }
200
201        // Fill DP table with regularization
202        for length in 1..=n {
203            for start in 0..=n - length {
204                let end = start + length;
205                let substring: String = chars[start..end].iter().collect();
206
207                if let Some(&score) = self.vocab.get(&substring) {
208                    // Apply regularization to the score
209                    let regularized_score = if self.config.alpha > 0.0 {
210                        let noise = self.rng.random::<f32>() * self.config.alpha;
211                        score + noise - self.config.alpha / 2.0
212                    } else {
213                        score
214                    };
215
216                    if dp[start][end] < regularized_score {
217                        dp[start][end] = regularized_score;
218                        best_seg[start][end] = Some(substring);
219                    }
220                }
221
222                // Try splitting at intermediate points
223                for mid in start + 1..end {
224                    let combined_score = dp[start][mid] + dp[mid][end];
225                    if dp[start][end] < combined_score {
226                        dp[start][end] = combined_score;
227                        best_seg[start][end] = None; // Mark as split
228                    }
229                }
230            }
231        }
232
233        // Backtrack to get the segmentation
234        self.backtrack_segmentation(&best_seg, 0, n, &chars)
235    }
236
237    #[allow(clippy::only_used_in_recursion)]
238    fn backtrack_segmentation(
239        &self,
240        best_seg: &[Vec<Option<String>>],
241        start: usize,
242        end: usize,
243        chars: &[char],
244    ) -> Result<Vec<String>> {
245        if start == end {
246            return Ok(vec![]);
247        }
248
249        if let Some(ref segment) = best_seg[start][end] {
250            return Ok(vec![segment.clone()]);
251        }
252
253        // Find the best split point
254        let mut best_split = start + 1;
255        let mut best_score = f32::NEG_INFINITY;
256
257        for (mid, _) in best_seg.iter().enumerate().take(end).skip(start + 1) {
258            let score = best_seg[start][mid].as_ref().map(|_| 1.0).unwrap_or(0.0)
259                + best_seg[mid][end].as_ref().map(|_| 1.0).unwrap_or(0.0);
260            if score > best_score {
261                best_score = score;
262                best_split = mid;
263            }
264        }
265
266        let mut result = self.backtrack_segmentation(best_seg, start, best_split, chars)?;
267        let mut right_part = self.backtrack_segmentation(best_seg, best_split, end, chars)?;
268        result.append(&mut right_part);
269
270        Ok(result)
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;

    // Default config exposes the documented knob values.
    #[test]
    fn test_subword_regularization_config() {
        let config = SubwordRegularizationConfig::default();
        assert_eq!(config.alpha, 0.1);
        assert_eq!(config.num_samples, 1);
        assert_eq!(config.seed, None);
        assert!(!config.debug);
    }

    // Construction stores the config unchanged and exposes it via accessor.
    #[test]
    fn test_subword_regularizer_creation() {
        let tokenizer = CharTokenizer::from_text("hello world", 1000);
        let config = SubwordRegularizationConfig::default();
        let regularizer = SubwordRegularizer::new(tokenizer, config);

        assert_eq!(regularizer.config().alpha, 0.1);
        assert_eq!(regularizer.config().num_samples, 1);
    }

    // The plain Tokenizer interface delegates (no regularization) and yields
    // a non-empty encoding.
    #[test]
    fn test_subword_regularizer_encode() {
        let tokenizer = CharTokenizer::from_text("hello world", 1000);
        let config = SubwordRegularizationConfig::default();
        let regularizer = SubwordRegularizer::new(tokenizer, config);

        let result = regularizer.encode("hello");
        assert!(result.is_ok());

        let tokenized = result.expect("Operation failed in test");
        assert!(!tokenized.input_ids.is_empty());
    }

    // Seeded runs succeed; both regularizers use the same seed (42).
    // NOTE(review): this only checks Ok-ness, not that the two seeded runs
    // produce identical output — a stronger equality assertion would pin
    // reproducibility.
    #[test]
    fn test_subword_regularizer_with_seed() {
        let tokenizer = CharTokenizer::from_text("hello world", 1000);
        let config = SubwordRegularizationConfig::default();
        let mut regularizer = SubwordRegularizer::new(tokenizer, config).with_seed(42);

        let result1 = regularizer.encode_with_regularization("hello world");
        assert!(result1.is_ok());

        // Reset with same seed
        let tokenizer2 = CharTokenizer::from_text("hello world", 1000);
        let config2 = SubwordRegularizationConfig::default();
        let mut regularizer2 = SubwordRegularizer::new(tokenizer2, config2).with_seed(42);

        let result2 = regularizer2.encode_with_regularization("hello world");
        assert!(result2.is_ok());
    }

    // num_samples = 3 yields exactly three non-empty tokenizations.
    #[test]
    fn test_subword_regularizer_multiple_samples() {
        let tokenizer = CharTokenizer::from_text("hello world", 1000);
        let config = SubwordRegularizationConfig::default();
        let mut regularizer =
            SubwordRegularizer::new(tokenizer, config).with_num_samples(3).with_alpha(0.2);

        let results = regularizer.encode_with_regularization("hello world");
        assert!(results.is_ok());

        let tokenized_results = results.expect("Operation failed in test");
        assert_eq!(tokenized_results.len(), 3);

        for result in tokenized_results {
            assert!(!result.input_ids.is_empty());
        }
    }

    // Unigram sampler produces a non-empty segmentation for fully-covered
    // input ("hello" has both a whole-word and per-char vocab entries).
    #[test]
    fn test_unigram_subword_regularizer() {
        let mut vocab = HashMap::new();
        vocab.insert("hello".to_string(), 1.0);
        vocab.insert("world".to_string(), 1.0);
        vocab.insert("h".to_string(), 0.5);
        vocab.insert("e".to_string(), 0.5);
        vocab.insert("l".to_string(), 0.5);
        vocab.insert("o".to_string(), 0.5);

        let config = SubwordRegularizationConfig::default();
        let mut regularizer = UnigramSubwordRegularizer::new(vocab, config);

        let result = regularizer.sample_segmentation("hello");
        assert!(result.is_ok());

        let segmentation = result.expect("Operation failed in test");
        assert!(!segmentation.is_empty());
    }

    // With alpha > 0 and a fixed seed, repeated sampling calls succeed.
    #[test]
    fn test_unigram_regularizer_with_alpha() {
        let mut vocab = HashMap::new();
        vocab.insert("test".to_string(), 1.0);
        vocab.insert("t".to_string(), 0.3);
        vocab.insert("e".to_string(), 0.3);
        vocab.insert("s".to_string(), 0.3);

        let config = SubwordRegularizationConfig {
            alpha: 0.5,
            num_samples: 1,
            seed: Some(123),
            debug: false,
        };

        let mut regularizer = UnigramSubwordRegularizer::new(vocab, config);

        let result1 = regularizer.sample_segmentation("test");
        assert!(result1.is_ok());

        // Results should be different due to regularization
        let result2 = regularizer.sample_segmentation("test");
        assert!(result2.is_ok());
    }

    // Config round-trips through serde_json without losing any field.
    #[test]
    fn test_regularization_config_serialization() {
        let config = SubwordRegularizationConfig {
            alpha: 0.3,
            num_samples: 5,
            seed: Some(42),
            debug: true,
        };

        let serialized = serde_json::to_string(&config).expect("Serialization failed");
        let deserialized: SubwordRegularizationConfig =
            serde_json::from_str(&serialized).expect("Deserialization failed");

        assert_eq!(config.alpha, deserialized.alpha);
        assert_eq!(config.num_samples, deserialized.num_samples);
        assert_eq!(config.seed, deserialized.seed);
        assert_eq!(config.debug, deserialized.debug);
    }

    // alpha = 0.0 is the identity transform; alpha > 0 still yields
    // non-empty (possibly perturbed) output.
    #[test]
    fn test_apply_regularization() {
        let tokenizer = CharTokenizer::from_text("hello world", 1000);
        let config = SubwordRegularizationConfig {
            alpha: 0.0, // No regularization
            num_samples: 1,
            seed: Some(42),
            debug: false,
        };

        let mut regularizer = SubwordRegularizer::new(tokenizer, config);
        let result = regularizer.apply_regularization("hello");
        assert_eq!(result, "hello");

        // With regularization
        regularizer.config.alpha = 0.5;
        let result_with_reg = regularizer.apply_regularization("hello");
        // Result might be different due to randomness
        assert!(!result_with_reg.is_empty());
    }
}