tokstream_core/
lib.rs

1use anyhow::Result;
2use tokenizers::Tokenizer;
3
4#[cfg(feature = "hf-http")]
5use tokenizers::FromPretrainedParameters;
6
/// Selects which pool of candidate tokens random sampling draws from.
///
/// The enum is a plain tag with no payload, so it is cheap to copy and can be
/// compared or used as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RandomMode {
    /// Draw from the candidates accepted by the English filter.
    English,
    /// Draw from the candidates accepted by the Chinese filter.
    Chinese,
}
12
/// Rules deciding whether a decoded token text is an acceptable candidate.
///
/// The `allow_*` flags reject a text when they are `false` and the text
/// contains a character of that class; the `require_*` flags reject a text
/// when they are `true` and the text contains no character of that class.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CandidateFilter {
    /// Accept texts containing ASCII digits.
    pub allow_digits: bool,
    /// Accept texts containing ASCII characters that are neither letters nor digits.
    pub allow_punct: bool,
    /// Accept texts containing whitespace.
    pub allow_space: bool,
    /// Accept texts containing non-ASCII characters.
    pub allow_non_ascii: bool,
    /// Only accept texts with at least one ASCII letter.
    pub require_letter: bool,
    /// Only accept texts with at least one CJK ideograph.
    pub require_cjk: bool,
}

impl CandidateFilter {
    /// Returns the stock filter for English: ASCII-only texts that contain a
    /// letter and no digits, punctuation, or whitespace.
    pub fn english_default() -> Self {
        Self {
            allow_digits: false,
            allow_punct: false,
            allow_space: false,
            allow_non_ascii: false,
            require_letter: true,
            require_cjk: false,
        }
    }

    /// Returns the stock filter for Chinese: texts that contain a CJK
    /// ideograph and no digits, punctuation, or whitespace.
    pub fn chinese_default() -> Self {
        Self {
            allow_digits: false,
            allow_punct: false,
            allow_space: false,
            allow_non_ascii: true,
            require_letter: false,
            require_cjk: true,
        }
    }
}
46
47#[derive(Clone, Copy, Debug)]
48pub struct FilterConfig {
49    pub english: CandidateFilter,
50    pub chinese: CandidateFilter,
51}
52
53impl FilterConfig {
54    pub fn filter(&self, mode: RandomMode) -> CandidateFilter {
55        match mode {
56            RandomMode::English => self.english,
57            RandomMode::Chinese => self.chinese,
58        }
59    }
60
61    pub fn filter_mut(&mut self, mode: RandomMode) -> &mut CandidateFilter {
62        match mode {
63            RandomMode::English => &mut self.english,
64            RandomMode::Chinese => &mut self.chinese,
65        }
66    }
67}
68
69impl Default for FilterConfig {
70    fn default() -> Self {
71        Self {
72            english: CandidateFilter::english_default(),
73            chinese: CandidateFilter::chinese_default(),
74        }
75    }
76}
77
78#[derive(Clone, Copy, Debug)]
79pub struct EngineOptions {
80    pub skip_special_tokens: bool,
81    pub filters: FilterConfig,
82}
83
84impl Default for EngineOptions {
85    fn default() -> Self {
86        Self {
87            skip_special_tokens: true,
88            filters: FilterConfig::default(),
89        }
90    }
91}
92
/// How many times a text stream replays its token sequence.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TextRepeat {
    /// Play the sequence a single time.
    Once,
    /// Play the sequence the given number of times.
    Finite(usize),
    /// Keep replaying the sequence without end.
    Infinite,
}
99
/// A token id together with the text it decodes to.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TokenPiece {
    /// Vocabulary id of the token.
    pub id: u32,
    /// Decoded text of the token.
    pub text: String,
}
105
106#[derive(Debug)]
107pub struct TokenizerEngine {
108    tokenizer: Tokenizer,
109    english_ids: Vec<u32>,
110    chinese_ids: Vec<u32>,
111    options: EngineOptions,
112}
113
114impl TokenizerEngine {
115    pub fn new(tokenizer: Tokenizer) -> Result<Self> {
116        Self::new_with_options(tokenizer, EngineOptions::default())
117    }
118
119    pub fn new_with_options(tokenizer: Tokenizer, options: EngineOptions) -> Result<Self> {
120        let (english_ids, chinese_ids) = build_candidate_ids(&tokenizer, &options)?;
121        Ok(Self {
122            tokenizer,
123            english_ids,
124            chinese_ids,
125            options,
126        })
127    }
128
129    pub fn tokenizer(&self) -> &Tokenizer {
130        &self.tokenizer
131    }
132
133    pub fn set_options(&mut self, options: EngineOptions) -> Result<()> {
134        let (english_ids, chinese_ids) = build_candidate_ids(&self.tokenizer, &options)?;
135        self.english_ids = english_ids;
136        self.chinese_ids = chinese_ids;
137        self.options = options;
138        Ok(())
139    }
140
141    pub fn decode_id(&self, id: u32) -> Result<String> {
142        self.tokenizer
143            .decode(&[id], self.options.skip_special_tokens)
144            .map_err(|err| anyhow::anyhow!("decode token id failed: {err}"))
145    }
146
147    pub fn encode_text(&self, text: &str) -> Result<Vec<u32>> {
148        let encoding = self
149            .tokenizer
150            .encode(text, false)
151            .map_err(|err| anyhow::anyhow!("encode text failed: {err}"))?;
152        Ok(encoding.get_ids().to_vec())
153    }
154
155    pub fn random_token(&self, rng: &mut SimpleRng, mode: RandomMode) -> Option<TokenPiece> {
156        let ids = match mode {
157            RandomMode::English => &self.english_ids,
158            RandomMode::Chinese => &self.chinese_ids,
159        };
160        if ids.is_empty() {
161            return None;
162        }
163
164        for _ in 0..8 {
165            let index = rng.gen_usize(ids.len());
166            let id = ids[index];
167            if let Ok(text) = self.decode_id(id)
168                && !text.is_empty()
169            {
170                return Some(TokenPiece { id, text });
171            }
172        }
173        None
174    }
175
176    pub fn text_stream(&self, text: &str, repeat: TextRepeat) -> Result<TextStream> {
177        let ids = self.encode_text(text)?;
178        Ok(TextStream::new(ids, repeat))
179    }
180}
181
182pub fn tokenizer_from_json_bytes(bytes: &[u8]) -> Result<Tokenizer> {
183    Tokenizer::from_bytes(bytes)
184        .map_err(|err| anyhow::anyhow!("load tokenizer from bytes failed: {err}"))
185}
186
/// Loads the tokenizer for `model` via `Tokenizer::from_pretrained`.
///
/// Only compiled when the `hf-http` feature is enabled. `revision` and
/// `token`, when given, override the corresponding default request
/// parameters; otherwise the library defaults are kept.
///
/// # Errors
/// Returns an error when the tokenizer library fails to load the model.
#[cfg(feature = "hf-http")]
pub fn tokenizer_from_hub(
    model: &str,
    revision: Option<&str>,
    token: Option<&str>,
) -> Result<Tokenizer> {
    let mut params = FromPretrainedParameters::default();
    if let Some(rev) = revision {
        params.revision = rev.to_owned();
    }
    if let Some(secret) = token {
        params.token = Some(secret.to_owned());
    }

    match Tokenizer::from_pretrained(model, Some(params)) {
        Ok(tokenizer) => Ok(tokenizer),
        Err(err) => Err(anyhow::anyhow!("load tokenizer from hub failed: {err}")),
    }
}
203
204#[derive(Clone, Debug)]
205pub struct TextStream {
206    ids: Vec<u32>,
207    index: usize,
208    remaining_loops: Option<usize>,
209}
210
211impl TextStream {
212    fn new(ids: Vec<u32>, repeat: TextRepeat) -> Self {
213        let remaining_loops = match repeat {
214            TextRepeat::Once => Some(1),
215            TextRepeat::Finite(times) => Some(times),
216            TextRepeat::Infinite => None,
217        };
218        Self {
219            ids,
220            index: 0,
221            remaining_loops,
222        }
223    }
224
225    pub fn next_id(&mut self) -> Option<u32> {
226        if self.ids.is_empty() {
227            self.remaining_loops = Some(0);
228            return None;
229        }
230
231        if self.index >= self.ids.len() {
232            match self.remaining_loops {
233                Some(0) => return None,
234                Some(1) => {
235                    self.remaining_loops = Some(0);
236                    return None;
237                }
238                Some(left) => {
239                    self.remaining_loops = Some(left.saturating_sub(1));
240                    self.index = 0;
241                }
242                None => {
243                    self.index = 0;
244                }
245            }
246        }
247
248        if self.index >= self.ids.len() {
249            return None;
250        }
251        let id = self.ids[self.index];
252        self.index += 1;
253        Some(id)
254    }
255
256    pub fn is_exhausted(&self) -> bool {
257        match self.remaining_loops {
258            Some(0) => true,
259            Some(1) => self.index >= self.ids.len(),
260            Some(_) => false,
261            None => false,
262        }
263    }
264}
265
/// Small deterministic pseudo-random generator (xorshift with a multiplicative
/// output scramble). The same seed always produces the same sequence.
#[derive(Clone, Debug)]
pub struct SimpleRng {
    /// Current generator state; never zero after construction.
    state: u64,
}

impl SimpleRng {
    /// Creates a generator from `seed`.
    ///
    /// A zero seed is replaced by a fixed non-zero constant, because an
    /// all-zero state would stay zero under the xorshift steps.
    pub fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 0x9e3779b97f4a7c15,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the state and returns the next 64-bit value.
    pub fn next_u64(&mut self) -> u64 {
        let step1 = self.state ^ (self.state >> 12);
        let step2 = step1 ^ (step1 << 25);
        let step3 = step2 ^ (step2 >> 27);
        self.state = step3;
        step3.wrapping_mul(0x2545f4914f6cdd1d)
    }

    /// Returns a value in `0..upper` by reducing the next output modulo `upper`.
    ///
    /// Returns `0` without advancing the generator when `upper` is `0`.
    pub fn gen_usize(&mut self, upper: usize) -> usize {
        match upper {
            0 => 0,
            bound => (self.next_u64() % bound as u64) as usize,
        }
    }

    /// Returns a float in `[0, 1)` built from the top 53 bits of the next output.
    pub fn gen_f64(&mut self) -> f64 {
        let mantissa = self.next_u64() >> 11;
        let scale = (1u64 << 53) as f64;
        (mantissa as f64) / scale
    }

    /// Returns a float in `[min, max)`.
    ///
    /// Returns `min` without advancing the generator when `min >= max`.
    pub fn gen_f64_range(&mut self, min: f64, max: f64) -> f64 {
        if min >= max {
            min
        } else {
            let span = max - min;
            min + span * self.gen_f64()
        }
    }
}
305
306fn build_candidate_ids(
307    tokenizer: &Tokenizer,
308    options: &EngineOptions,
309) -> Result<(Vec<u32>, Vec<u32>)> {
310    let vocab = tokenizer.get_vocab(true);
311    let mut english_ids = Vec::new();
312    let mut chinese_ids = Vec::new();
313
314    for id in vocab.values() {
315        let Ok(text) = tokenizer.decode(&[*id], options.skip_special_tokens) else {
316            continue;
317        };
318        if !is_printable_candidate(&text) {
319            continue;
320        }
321        if matches_filter(&text, options.filters.english) {
322            english_ids.push(*id);
323        }
324        if matches_filter(&text, options.filters.chinese) {
325            chinese_ids.push(*id);
326        }
327    }
328
329    Ok((english_ids, chinese_ids))
330}
331
/// Returns `true` when `text` is non-empty and contains neither the Unicode
/// replacement character (U+FFFD) nor any control character.
fn is_printable_candidate(text: &str) -> bool {
    let has_bad_char = text
        .chars()
        .any(|ch| ch == '\u{FFFD}' || ch.is_control());
    !(text.is_empty() || has_bad_char)
}
338
339fn matches_filter(text: &str, filter: CandidateFilter) -> bool {
340    let mut stats = TextStats::default();
341    for c in text.chars() {
342        if c.is_control() {
343            return false;
344        }
345        if c.is_whitespace() {
346            stats.has_space = true;
347            continue;
348        }
349        if c.is_ascii() {
350            if c.is_ascii_alphabetic() {
351                stats.has_letter = true;
352            } else if c.is_ascii_digit() {
353                stats.has_digit = true;
354            } else {
355                stats.has_punct = true;
356            }
357            continue;
358        }
359
360        stats.has_non_ascii = true;
361        if is_cjk_char(c) {
362            stats.has_cjk = true;
363        }
364    }
365
366    if !filter.allow_non_ascii && stats.has_non_ascii {
367        return false;
368    }
369    if !filter.allow_digits && stats.has_digit {
370        return false;
371    }
372    if !filter.allow_punct && stats.has_punct {
373        return false;
374    }
375    if !filter.allow_space && stats.has_space {
376        return false;
377    }
378    if filter.require_letter && !stats.has_letter {
379        return false;
380    }
381    if filter.require_cjk && !stats.has_cjk {
382        return false;
383    }
384
385    true
386}
387
/// Records which character classes were seen while scanning a text.
///
/// Every flag starts out `false`.
#[derive(Default)]
struct TextStats {
    /// At least one ASCII letter was seen.
    has_letter: bool,
    /// At least one CJK ideograph was seen.
    has_cjk: bool,
    /// At least one ASCII digit was seen.
    has_digit: bool,
    /// At least one ASCII character that is neither a letter nor a digit was seen.
    has_punct: bool,
    /// At least one whitespace character was seen.
    has_space: bool,
    /// At least one non-ASCII, non-whitespace character was seen.
    has_non_ascii: bool,
}
397
/// Returns `true` when `c` lies in one of the Unicode CJK ideograph blocks.
///
/// Covers the unified ideographs and their extensions A through I, plus the
/// two compatibility ideograph blocks. Kana, Hangul, and CJK punctuation are
/// not ideographs and are not matched.
fn is_cjk_char(c: char) -> bool {
    matches!(
        c as u32,
        0x3400..=0x4DBF          // CJK Unified Ideographs Extension A
            | 0x4E00..=0x9FFF    // CJK Unified Ideographs
            | 0xF900..=0xFAFF    // CJK Compatibility Ideographs
            | 0x20000..=0x2A6DF  // Extension B
            | 0x2A700..=0x2B73F  // Extension C
            | 0x2B740..=0x2B81F  // Extension D
            | 0x2B820..=0x2CEAF  // Extension E
            | 0x2CEB0..=0x2EBEF  // Extension F
            | 0x2EBF0..=0x2EE5F  // Extension I
            | 0x2F800..=0x2FA1F  // CJK Compatibility Ideographs Supplement
            | 0x30000..=0x3134F  // Extension G
            | 0x31350..=0x323AF // Extension H
    )
}
410
#[cfg(test)]
mod tests {
    use super::{SimpleRng, TextRepeat, TextStream};

    /// Pulls ids from `stream` until it first returns `None`, collecting
    /// everything yielded before that.
    fn drain(stream: &mut TextStream) -> Vec<u32> {
        std::iter::from_fn(|| stream.next_id()).collect()
    }

    /// A `Once` stream yields its ids a single time and is then exhausted.
    #[test]
    fn text_stream_once() {
        let mut stream = TextStream::new(vec![1, 2], TextRepeat::Once);
        assert_eq!(drain(&mut stream), vec![1, 2]);
        assert!(stream.is_exhausted());
    }

    /// A `Finite(2)` stream yields its ids twice and is then exhausted.
    #[test]
    fn text_stream_finite() {
        let mut stream = TextStream::new(vec![7, 8], TextRepeat::Finite(2));
        assert_eq!(drain(&mut stream), vec![7, 8, 7, 8]);
        assert!(stream.is_exhausted());
    }

    /// Two generators built from the same seed produce the same outputs.
    #[test]
    fn simple_rng_is_deterministic() {
        let first_two = |seed: u64| {
            let mut rng = SimpleRng::new(seed);
            [rng.next_u64(), rng.next_u64()]
        };
        assert_eq!(first_two(42), first_two(42));
    }
}