1use scirs2_core::random::*; use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use trustformers_core::errors::Result;
5use trustformers_core::traits::{TokenizedInput, Tokenizer};
6
/// Configuration for subword regularization sampling.
///
/// Controls how strongly and how often tokenizations are randomly perturbed
/// so a model is exposed to multiple segmentations of the same text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubwordRegularizationConfig {
    /// Regularization strength; values `<= 0.0` disable perturbation entirely.
    pub alpha: f32,
    /// Number of regularized samples produced per input text.
    pub num_samples: usize,
    /// Optional RNG seed for reproducible sampling; `None` seeds from entropy.
    pub seed: Option<u64>,
    /// Debug flag. NOTE(review): not read anywhere in this file — confirm
    /// whether it is consumed elsewhere before relying on it.
    pub debug: bool,
}
19
20impl Default for SubwordRegularizationConfig {
21 fn default() -> Self {
22 Self {
23 alpha: 0.1,
24 num_samples: 1,
25 seed: None,
26 debug: false,
27 }
28 }
29}
30
/// Wraps an inner [`Tokenizer`] and produces stochastically perturbed
/// tokenizations of the same text, for data-augmentation-style training.
pub struct SubwordRegularizer<T: Tokenizer> {
    /// Underlying tokenizer that performs the actual encoding.
    tokenizer: T,
    /// Sampling parameters (strength, sample count, seed).
    config: SubwordRegularizationConfig,
    /// RNG driving the character-level perturbations.
    rng: StdRng,
}
37
38impl<T: Tokenizer> SubwordRegularizer<T> {
39 pub fn new(tokenizer: T, config: SubwordRegularizationConfig) -> Self {
40 let rng = if let Some(seed) = config.seed {
41 StdRng::seed_from_u64(seed)
42 } else {
43 let seed = thread_rng().random();
45 StdRng::seed_from_u64(seed)
46 };
47
48 Self {
49 tokenizer,
50 config,
51 rng,
52 }
53 }
54
55 pub fn with_alpha(mut self, alpha: f32) -> Self {
56 self.config.alpha = alpha;
57 self
58 }
59
60 pub fn with_num_samples(mut self, num_samples: usize) -> Self {
61 self.config.num_samples = num_samples;
62 self
63 }
64
65 pub fn with_seed(mut self, seed: u64) -> Self {
66 self.config.seed = Some(seed);
67 self.rng = StdRng::seed_from_u64(seed);
68 self
69 }
70
71 pub fn encode_with_regularization(&mut self, text: &str) -> Result<Vec<TokenizedInput>> {
73 let mut results = Vec::new();
74
75 for _ in 0..self.config.num_samples {
76 let regularized_text = self.apply_regularization(text);
77 let tokenized = self.tokenizer.encode(®ularized_text)?;
78 results.push(tokenized);
79 }
80
81 Ok(results)
82 }
83
84 fn apply_regularization(&mut self, text: &str) -> String {
86 if self.config.alpha <= 0.0 {
87 return text.to_string();
88 }
89
90 let mut result = String::new();
91 let chars: Vec<char> = text.chars().collect();
92 let mut i = 0;
93
94 while i < chars.len() {
95 let char = chars[i];
96
97 if self.rng.random::<f32>() < self.config.alpha {
99 if self.rng.random::<f32>() < 0.1 {
101 i += 1;
102 continue;
103 }
104
105 if self.rng.random::<f32>() < 0.05 {
107 result.push(char);
108 result.push(char);
109 i += 1;
110 continue;
111 }
112 }
113
114 result.push(char);
115 i += 1;
116 }
117
118 result
119 }
120
121 pub fn inner(&self) -> &T {
123 &self.tokenizer
124 }
125
126 pub fn config(&self) -> &SubwordRegularizationConfig {
128 &self.config
129 }
130}
131
/// Pass-through [`Tokenizer`] implementation: all trait methods delegate
/// directly to the wrapped tokenizer with no regularization applied. Use
/// `encode_with_regularization` for perturbed samples.
impl<T: Tokenizer> Tokenizer for SubwordRegularizer<T> {
    fn encode(&self, text: &str) -> Result<TokenizedInput> {
        self.tokenizer.encode(text)
    }

    fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
        self.tokenizer.encode_pair(text, text2)
    }

    fn decode(&self, ids: &[u32]) -> Result<String> {
        self.tokenizer.decode(ids)
    }

    fn vocab_size(&self) -> usize {
        self.tokenizer.vocab_size()
    }

    fn get_vocab(&self) -> HashMap<String, u32> {
        self.tokenizer.get_vocab()
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.tokenizer.token_to_id(token)
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.tokenizer.id_to_token(id)
    }
}
162
/// Unigram-style subword regularizer that segments text using a scored
/// subword vocabulary rather than wrapping an existing tokenizer.
pub struct UnigramSubwordRegularizer {
    /// Maps each known subword to its unigram score.
    vocab: HashMap<String, f32>,
    /// Sampling parameters (noise strength, seed).
    config: SubwordRegularizationConfig,
    /// RNG used to perturb subword scores during segmentation.
    rng: StdRng,
}
169
170impl UnigramSubwordRegularizer {
171 pub fn new(vocab: HashMap<String, f32>, config: SubwordRegularizationConfig) -> Self {
172 let rng = if let Some(seed) = config.seed {
173 StdRng::seed_from_u64(seed)
174 } else {
175 let seed = thread_rng().random();
177 StdRng::seed_from_u64(seed)
178 };
179
180 Self { vocab, config, rng }
181 }
182
183 pub fn sample_segmentation(&mut self, text: &str) -> Result<Vec<String>> {
185 let chars: Vec<char> = text.chars().collect();
186 let n = chars.len();
187
188 if n == 0 {
189 return Ok(vec![]);
190 }
191
192 let mut dp = vec![vec![0.0; n + 1]; n + 1];
194 let mut best_seg = vec![vec![None; n + 1]; n + 1];
195
196 for (i, dp_row) in dp.iter_mut().enumerate().take(n + 1) {
198 dp_row[i] = 0.0;
199 }
200
201 for length in 1..=n {
203 for start in 0..=n - length {
204 let end = start + length;
205 let substring: String = chars[start..end].iter().collect();
206
207 if let Some(&score) = self.vocab.get(&substring) {
208 let regularized_score = if self.config.alpha > 0.0 {
210 let noise = self.rng.random::<f32>() * self.config.alpha;
211 score + noise - self.config.alpha / 2.0
212 } else {
213 score
214 };
215
216 if dp[start][end] < regularized_score {
217 dp[start][end] = regularized_score;
218 best_seg[start][end] = Some(substring);
219 }
220 }
221
222 for mid in start + 1..end {
224 let combined_score = dp[start][mid] + dp[mid][end];
225 if dp[start][end] < combined_score {
226 dp[start][end] = combined_score;
227 best_seg[start][end] = None; }
229 }
230 }
231 }
232
233 self.backtrack_segmentation(&best_seg, 0, n, &chars)
235 }
236
237 #[allow(clippy::only_used_in_recursion)]
238 fn backtrack_segmentation(
239 &self,
240 best_seg: &[Vec<Option<String>>],
241 start: usize,
242 end: usize,
243 chars: &[char],
244 ) -> Result<Vec<String>> {
245 if start == end {
246 return Ok(vec![]);
247 }
248
249 if let Some(ref segment) = best_seg[start][end] {
250 return Ok(vec![segment.clone()]);
251 }
252
253 let mut best_split = start + 1;
255 let mut best_score = f32::NEG_INFINITY;
256
257 for (mid, _) in best_seg.iter().enumerate().take(end).skip(start + 1) {
258 let score = best_seg[start][mid].as_ref().map(|_| 1.0).unwrap_or(0.0)
259 + best_seg[mid][end].as_ref().map(|_| 1.0).unwrap_or(0.0);
260 if score > best_score {
261 best_score = score;
262 best_split = mid;
263 }
264 }
265
266 let mut result = self.backtrack_segmentation(best_seg, start, best_split, chars)?;
267 let mut right_part = self.backtrack_segmentation(best_seg, best_split, end, chars)?;
268 result.append(&mut right_part);
269
270 Ok(result)
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;

    /// Default config should match the documented defaults.
    #[test]
    fn test_subword_regularization_config() {
        let config = SubwordRegularizationConfig::default();
        assert_eq!(config.alpha, 0.1);
        assert_eq!(config.num_samples, 1);
        assert_eq!(config.seed, None);
        assert!(!config.debug);
    }

    /// Construction stores the given config unchanged.
    #[test]
    fn test_subword_regularizer_creation() {
        let tokenizer = CharTokenizer::from_text("hello world", 1000);
        let config = SubwordRegularizationConfig::default();
        let regularizer = SubwordRegularizer::new(tokenizer, config);

        assert_eq!(regularizer.config().alpha, 0.1);
        assert_eq!(regularizer.config().num_samples, 1);
    }

    /// The pass-through `Tokenizer::encode` delegates to the inner tokenizer.
    #[test]
    fn test_subword_regularizer_encode() {
        let tokenizer = CharTokenizer::from_text("hello world", 1000);
        let config = SubwordRegularizationConfig::default();
        let regularizer = SubwordRegularizer::new(tokenizer, config);

        let result = regularizer.encode("hello");
        assert!(result.is_ok());

        let tokenized = result.expect("Operation failed in test");
        assert!(!tokenized.input_ids.is_empty());
    }

    /// Seeded regularizers should both succeed on the same input.
    /// NOTE(review): this only checks `is_ok`, not that the two seeded runs
    /// produce identical output — consider comparing the results.
    #[test]
    fn test_subword_regularizer_with_seed() {
        let tokenizer = CharTokenizer::from_text("hello world", 1000);
        let config = SubwordRegularizationConfig::default();
        let mut regularizer = SubwordRegularizer::new(tokenizer, config).with_seed(42);

        let result1 = regularizer.encode_with_regularization("hello world");
        assert!(result1.is_ok());

        let tokenizer2 = CharTokenizer::from_text("hello world", 1000);
        let config2 = SubwordRegularizationConfig::default();
        let mut regularizer2 = SubwordRegularizer::new(tokenizer2, config2).with_seed(42);

        let result2 = regularizer2.encode_with_regularization("hello world");
        assert!(result2.is_ok());
    }

    /// `num_samples` controls how many tokenizations come back.
    #[test]
    fn test_subword_regularizer_multiple_samples() {
        let tokenizer = CharTokenizer::from_text("hello world", 1000);
        let config = SubwordRegularizationConfig::default();
        let mut regularizer =
            SubwordRegularizer::new(tokenizer, config).with_num_samples(3).with_alpha(0.2);

        let results = regularizer.encode_with_regularization("hello world");
        assert!(results.is_ok());

        let tokenized_results = results.expect("Operation failed in test");
        assert_eq!(tokenized_results.len(), 3);

        for result in tokenized_results {
            assert!(!result.input_ids.is_empty());
        }
    }

    /// Segmentation of a fully in-vocabulary word yields a non-empty result.
    #[test]
    fn test_unigram_subword_regularizer() {
        let mut vocab = HashMap::new();
        vocab.insert("hello".to_string(), 1.0);
        vocab.insert("world".to_string(), 1.0);
        vocab.insert("h".to_string(), 0.5);
        vocab.insert("e".to_string(), 0.5);
        vocab.insert("l".to_string(), 0.5);
        vocab.insert("o".to_string(), 0.5);

        let config = SubwordRegularizationConfig::default();
        let mut regularizer = UnigramSubwordRegularizer::new(vocab, config);

        let result = regularizer.sample_segmentation("hello");
        assert!(result.is_ok());

        let segmentation = result.expect("Operation failed in test");
        assert!(!segmentation.is_empty());
    }

    /// With a fixed seed and nonzero alpha, repeated segmentations succeed.
    #[test]
    fn test_unigram_regularizer_with_alpha() {
        let mut vocab = HashMap::new();
        vocab.insert("test".to_string(), 1.0);
        vocab.insert("t".to_string(), 0.3);
        vocab.insert("e".to_string(), 0.3);
        vocab.insert("s".to_string(), 0.3);

        let config = SubwordRegularizationConfig {
            alpha: 0.5,
            num_samples: 1,
            seed: Some(123),
            debug: false,
        };

        let mut regularizer = UnigramSubwordRegularizer::new(vocab, config);

        let result1 = regularizer.sample_segmentation("test");
        assert!(result1.is_ok());

        let result2 = regularizer.sample_segmentation("test");
        assert!(result2.is_ok());
    }

    /// Config round-trips through serde_json without losing any field.
    #[test]
    fn test_regularization_config_serialization() {
        let config = SubwordRegularizationConfig {
            alpha: 0.3,
            num_samples: 5,
            seed: Some(42),
            debug: true,
        };

        let serialized = serde_json::to_string(&config).expect("Serialization failed");
        let deserialized: SubwordRegularizationConfig =
            serde_json::from_str(&serialized).expect("Deserialization failed");

        assert_eq!(config.alpha, deserialized.alpha);
        assert_eq!(config.num_samples, deserialized.num_samples);
        assert_eq!(config.seed, deserialized.seed);
        assert_eq!(config.debug, deserialized.debug);
    }

    /// alpha == 0.0 must be a no-op; alpha > 0.0 still yields some output.
    #[test]
    fn test_apply_regularization() {
        let tokenizer = CharTokenizer::from_text("hello world", 1000);
        let config = SubwordRegularizationConfig {
            alpha: 0.0, num_samples: 1,
            seed: Some(42),
            debug: false,
        };

        let mut regularizer = SubwordRegularizer::new(tokenizer, config);
        let result = regularizer.apply_regularization("hello");
        assert_eq!(result, "hello");

        // Child module may touch the parent's private field directly.
        regularizer.config.alpha = 0.5;
        let result_with_reg = regularizer.apply_regularization("hello");
        assert!(!result_with_reg.is_empty());
    }
}