1use anyhow::Result;
2use tokenizers::Tokenizer;
3
4#[cfg(feature = "hf-http")]
5use tokenizers::FromPretrainedParameters;
6
/// Selects which candidate pool `TokenizerEngine::random_token` draws from.
#[derive(Clone, Copy, Debug)]
pub enum RandomMode {
    /// Sample from ids that passed the English filter (ASCII letters).
    English,
    /// Sample from ids that passed the Chinese filter (CJK ideographs).
    Chinese,
}
12
/// Per-character rules a decoded token must satisfy to enter a candidate
/// pool; evaluated by `matches_filter`.
#[derive(Clone, Copy, Debug)]
pub struct CandidateFilter {
    /// Permit ASCII digits (`0`-`9`).
    pub allow_digits: bool,
    /// Permit ASCII characters that are neither letters nor digits.
    pub allow_punct: bool,
    /// Permit whitespace characters.
    pub allow_space: bool,
    /// Permit non-ASCII (non-whitespace) characters.
    pub allow_non_ascii: bool,
    /// Require at least one ASCII letter.
    pub require_letter: bool,
    /// Require at least one CJK ideograph (see `is_cjk_char`).
    pub require_cjk: bool,
}
22
23impl CandidateFilter {
24 pub fn english_default() -> Self {
25 Self {
26 allow_digits: false,
27 allow_punct: false,
28 allow_space: false,
29 allow_non_ascii: false,
30 require_letter: true,
31 require_cjk: false,
32 }
33 }
34
35 pub fn chinese_default() -> Self {
36 Self {
37 allow_digits: false,
38 allow_punct: false,
39 allow_space: false,
40 allow_non_ascii: true,
41 require_letter: false,
42 require_cjk: true,
43 }
44 }
45}
46
/// One `CandidateFilter` per `RandomMode`.
#[derive(Clone, Copy, Debug)]
pub struct FilterConfig {
    /// Filter used when building the English candidate pool.
    pub english: CandidateFilter,
    /// Filter used when building the Chinese candidate pool.
    pub chinese: CandidateFilter,
}
52
53impl FilterConfig {
54 pub fn filter(&self, mode: RandomMode) -> CandidateFilter {
55 match mode {
56 RandomMode::English => self.english,
57 RandomMode::Chinese => self.chinese,
58 }
59 }
60
61 pub fn filter_mut(&mut self, mode: RandomMode) -> &mut CandidateFilter {
62 match mode {
63 RandomMode::English => &mut self.english,
64 RandomMode::Chinese => &mut self.chinese,
65 }
66 }
67}
68
69impl Default for FilterConfig {
70 fn default() -> Self {
71 Self {
72 english: CandidateFilter::english_default(),
73 chinese: CandidateFilter::chinese_default(),
74 }
75 }
76}
77
/// Settings used by `TokenizerEngine` for decoding ids and building the
/// candidate pools.
#[derive(Clone, Copy, Debug)]
pub struct EngineOptions {
    /// Forwarded to `Tokenizer::decode`; when true, special tokens are
    /// omitted from decoded text.
    pub skip_special_tokens: bool,
    /// Filters deciding which vocabulary entries become candidates.
    pub filters: FilterConfig,
}
83
84impl Default for EngineOptions {
85 fn default() -> Self {
86 Self {
87 skip_special_tokens: true,
88 filters: FilterConfig::default(),
89 }
90 }
91}
92
/// How many passes a `TextStream` makes over its id sequence.
#[derive(Clone, Copy, Debug)]
pub enum TextRepeat {
    /// Play the sequence a single time.
    Once,
    /// Play the sequence the given number of times.
    Finite(usize),
    /// Loop over the sequence forever.
    Infinite,
}
99
/// A vocabulary id together with its decoded text.
#[derive(Clone, Debug)]
pub struct TokenPiece {
    /// Token id in the tokenizer's vocabulary.
    pub id: u32,
    /// Decoded text for `id`; non-empty when produced by `random_token`.
    pub text: String,
}
105
/// Wraps a `Tokenizer` together with precomputed candidate id pools for
/// random token sampling; pools are rebuilt whenever options change.
#[derive(Debug)]
pub struct TokenizerEngine {
    tokenizer: Tokenizer,
    /// Ids whose decoded text passes the English filter.
    english_ids: Vec<u32>,
    /// Ids whose decoded text passes the Chinese filter.
    chinese_ids: Vec<u32>,
    options: EngineOptions,
}
113
114impl TokenizerEngine {
115 pub fn new(tokenizer: Tokenizer) -> Result<Self> {
116 Self::new_with_options(tokenizer, EngineOptions::default())
117 }
118
119 pub fn new_with_options(tokenizer: Tokenizer, options: EngineOptions) -> Result<Self> {
120 let (english_ids, chinese_ids) = build_candidate_ids(&tokenizer, &options)?;
121 Ok(Self {
122 tokenizer,
123 english_ids,
124 chinese_ids,
125 options,
126 })
127 }
128
129 pub fn tokenizer(&self) -> &Tokenizer {
130 &self.tokenizer
131 }
132
133 pub fn set_options(&mut self, options: EngineOptions) -> Result<()> {
134 let (english_ids, chinese_ids) = build_candidate_ids(&self.tokenizer, &options)?;
135 self.english_ids = english_ids;
136 self.chinese_ids = chinese_ids;
137 self.options = options;
138 Ok(())
139 }
140
141 pub fn decode_id(&self, id: u32) -> Result<String> {
142 self.tokenizer
143 .decode(&[id], self.options.skip_special_tokens)
144 .map_err(|err| anyhow::anyhow!("decode token id failed: {err}"))
145 }
146
147 pub fn encode_text(&self, text: &str) -> Result<Vec<u32>> {
148 let encoding = self
149 .tokenizer
150 .encode(text, false)
151 .map_err(|err| anyhow::anyhow!("encode text failed: {err}"))?;
152 Ok(encoding.get_ids().to_vec())
153 }
154
155 pub fn random_token(&self, rng: &mut SimpleRng, mode: RandomMode) -> Option<TokenPiece> {
156 let ids = match mode {
157 RandomMode::English => &self.english_ids,
158 RandomMode::Chinese => &self.chinese_ids,
159 };
160 if ids.is_empty() {
161 return None;
162 }
163
164 for _ in 0..8 {
165 let index = rng.gen_usize(ids.len());
166 let id = ids[index];
167 if let Ok(text) = self.decode_id(id)
168 && !text.is_empty()
169 {
170 return Some(TokenPiece { id, text });
171 }
172 }
173 None
174 }
175
176 pub fn text_stream(&self, text: &str, repeat: TextRepeat) -> Result<TextStream> {
177 let ids = self.encode_text(text)?;
178 Ok(TextStream::new(ids, repeat))
179 }
180}
181
182pub fn tokenizer_from_json_bytes(bytes: &[u8]) -> Result<Tokenizer> {
183 Tokenizer::from_bytes(bytes)
184 .map_err(|err| anyhow::anyhow!("load tokenizer from bytes failed: {err}"))
185}
186
/// Downloads a tokenizer from the Hugging Face hub.
///
/// `revision` and `token` override the defaults of
/// `FromPretrainedParameters` only when provided.
#[cfg(feature = "hf-http")]
pub fn tokenizer_from_hub(
    model: &str,
    revision: Option<&str>,
    token: Option<&str>,
) -> Result<Tokenizer> {
    let params = {
        let mut p = FromPretrainedParameters::default();
        if let Some(rev) = revision {
            p.revision = rev.to_string();
        }
        if let Some(tok) = token {
            p.token = Some(tok.to_string());
        }
        p
    };
    Tokenizer::from_pretrained(model, Some(params))
        .map_err(|err| anyhow::anyhow!("load tokenizer from hub failed: {err}"))
}
203
/// Iterator-like replay of a pre-encoded token id sequence, with optional
/// looping (see `TextRepeat`).
#[derive(Clone, Debug)]
pub struct TextStream {
    /// Encoded ids for one pass over the source text.
    ids: Vec<u32>,
    /// Cursor into `ids` for the current pass.
    index: usize,
    /// Passes still permitted; `None` means loop forever.
    remaining_loops: Option<usize>,
}
210
211impl TextStream {
212 fn new(ids: Vec<u32>, repeat: TextRepeat) -> Self {
213 let remaining_loops = match repeat {
214 TextRepeat::Once => Some(1),
215 TextRepeat::Finite(times) => Some(times),
216 TextRepeat::Infinite => None,
217 };
218 Self {
219 ids,
220 index: 0,
221 remaining_loops,
222 }
223 }
224
225 pub fn next_id(&mut self) -> Option<u32> {
226 if self.ids.is_empty() {
227 self.remaining_loops = Some(0);
228 return None;
229 }
230
231 if self.index >= self.ids.len() {
232 match self.remaining_loops {
233 Some(0) => return None,
234 Some(1) => {
235 self.remaining_loops = Some(0);
236 return None;
237 }
238 Some(left) => {
239 self.remaining_loops = Some(left.saturating_sub(1));
240 self.index = 0;
241 }
242 None => {
243 self.index = 0;
244 }
245 }
246 }
247
248 if self.index >= self.ids.len() {
249 return None;
250 }
251 let id = self.ids[self.index];
252 self.index += 1;
253 Some(id)
254 }
255
256 pub fn is_exhausted(&self) -> bool {
257 match self.remaining_loops {
258 Some(0) => true,
259 Some(1) => self.index >= self.ids.len(),
260 Some(_) => false,
261 None => false,
262 }
263 }
264}
265
/// Small deterministic xorshift64* generator; identical seeds always
/// reproduce identical output streams.
#[derive(Clone, Debug)]
pub struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    /// Creates a generator from `seed`. A zero seed would lock xorshift at
    /// zero forever, so it is replaced with a fixed nonzero constant.
    pub fn new(seed: u64) -> Self {
        Self {
            state: if seed == 0 { 0x9e3779b97f4a7c15 } else { seed },
        }
    }

    /// Advances the state (xorshift 12/25/27) and returns the state
    /// multiplied by the xorshift64* output constant.
    pub fn next_u64(&mut self) -> u64 {
        let mut s = self.state;
        s ^= s >> 12;
        s ^= s << 25;
        s ^= s >> 27;
        self.state = s;
        s.wrapping_mul(0x2545f4914f6cdd1d)
    }

    /// Uniform-ish index in `0..upper` via modulo; returns 0 when `upper`
    /// is 0.
    pub fn gen_usize(&mut self, upper: usize) -> usize {
        match upper {
            0 => 0,
            bound => (self.next_u64() % bound as u64) as usize,
        }
    }

    /// Float in `[0, 1)` built from the top 53 random bits.
    pub fn gen_f64(&mut self) -> f64 {
        let bits = self.next_u64() >> 11;
        (bits as f64) / ((1u64 << 53) as f64)
    }

    /// Float in `[min, max)`; returns `min` unchanged when `min >= max`.
    pub fn gen_f64_range(&mut self, min: f64, max: f64) -> f64 {
        if min >= max {
            min
        } else {
            min + (max - min) * self.gen_f64()
        }
    }
}
305
306fn build_candidate_ids(
307 tokenizer: &Tokenizer,
308 options: &EngineOptions,
309) -> Result<(Vec<u32>, Vec<u32>)> {
310 let vocab = tokenizer.get_vocab(true);
311 let mut english_ids = Vec::new();
312 let mut chinese_ids = Vec::new();
313
314 for id in vocab.values() {
315 let Ok(text) = tokenizer.decode(&[*id], options.skip_special_tokens) else {
316 continue;
317 };
318 if !is_printable_candidate(&text) {
319 continue;
320 }
321 if matches_filter(&text, options.filters.english) {
322 english_ids.push(*id);
323 }
324 if matches_filter(&text, options.filters.chinese) {
325 chinese_ids.push(*id);
326 }
327 }
328
329 Ok((english_ids, chinese_ids))
330}
331
/// A candidate must be non-empty, free of the U+FFFD replacement character
/// (a sign of a broken byte-level decode), and contain no control chars.
fn is_printable_candidate(text: &str) -> bool {
    !text.is_empty()
        && !text.contains('\u{FFFD}')
        && text.chars().all(|c| !c.is_control())
}
338
339fn matches_filter(text: &str, filter: CandidateFilter) -> bool {
340 let mut stats = TextStats::default();
341 for c in text.chars() {
342 if c.is_control() {
343 return false;
344 }
345 if c.is_whitespace() {
346 stats.has_space = true;
347 continue;
348 }
349 if c.is_ascii() {
350 if c.is_ascii_alphabetic() {
351 stats.has_letter = true;
352 } else if c.is_ascii_digit() {
353 stats.has_digit = true;
354 } else {
355 stats.has_punct = true;
356 }
357 continue;
358 }
359
360 stats.has_non_ascii = true;
361 if is_cjk_char(c) {
362 stats.has_cjk = true;
363 }
364 }
365
366 if !filter.allow_non_ascii && stats.has_non_ascii {
367 return false;
368 }
369 if !filter.allow_digits && stats.has_digit {
370 return false;
371 }
372 if !filter.allow_punct && stats.has_punct {
373 return false;
374 }
375 if !filter.allow_space && stats.has_space {
376 return false;
377 }
378 if filter.require_letter && !stats.has_letter {
379 return false;
380 }
381 if filter.require_cjk && !stats.has_cjk {
382 return false;
383 }
384
385 true
386}
387
/// Per-token character classification gathered by `matches_filter`;
/// each flag records whether the class appeared at least once.
#[derive(Default)]
struct TextStats {
    has_letter: bool,
    has_cjk: bool,
    has_digit: bool,
    has_punct: bool,
    has_space: bool,
    has_non_ascii: bool,
}
397
/// True for CJK unified ideographs: the base block, Extension A, the
/// compatibility block, and Extensions B through E. Kana, Hangul, and CJK
/// punctuation are deliberately excluded.
fn is_cjk_char(c: char) -> bool {
    matches!(
        c,
        '\u{3400}'..='\u{4DBF}'    // Extension A
            | '\u{4E00}'..='\u{9FFF}'   // Unified Ideographs
            | '\u{F900}'..='\u{FAFF}'   // Compatibility Ideographs
            | '\u{20000}'..='\u{2A6DF}' // Extension B
            | '\u{2A700}'..='\u{2B73F}' // Extension C
            | '\u{2B740}'..='\u{2B81F}' // Extension D
            | '\u{2B820}'..='\u{2CEAF}' // Extension E
    )
}
410
#[cfg(test)]
mod tests {
    use super::{SimpleRng, TextRepeat, TextStream};

    /// A `Once` stream yields each id a single time, then is exhausted.
    #[test]
    fn text_stream_once() {
        let mut stream = TextStream::new(vec![1, 2], TextRepeat::Once);
        for expected in [Some(1), Some(2), None] {
            assert_eq!(stream.next_id(), expected);
        }
        assert!(stream.is_exhausted());
    }

    /// `Finite(2)` replays the id sequence exactly twice.
    #[test]
    fn text_stream_finite() {
        let mut stream = TextStream::new(vec![7, 8], TextRepeat::Finite(2));
        for expected in [Some(7), Some(8), Some(7), Some(8), None] {
            assert_eq!(stream.next_id(), expected);
        }
        assert!(stream.is_exhausted());
    }

    /// Identical seeds must reproduce identical output streams.
    #[test]
    fn simple_rng_is_deterministic() {
        let mut a = SimpleRng::new(42);
        let mut b = SimpleRng::new(42);
        for _ in 0..2 {
            assert_eq!(a.next_u64(), b.next_u64());
        }
    }
}