Skip to main content

tokenx_rs/
estimator.rs

1//! Core token estimation logic.
2//!
3//! Single-pass scanner that classifies characters inline using simple
4//! char-level checks. Zero allocations, zero regex, zero dependencies.
5
6use crate::config::EstimationOptions;
7
8// ── char classification helpers ─────────────────────────────────────────
9
/// Returns `true` when `c` falls in a CJK-related Unicode block
/// (Han ideographs, kana, Hangul, and associated symbol/compatibility
/// blocks). CJK segments are scored one token per character.
#[inline(always)]
fn is_cjk(c: char) -> bool {
    matches!(c,
        '\u{4E00}'..='\u{9FFF}'   // CJK Unified Ideographs
        | '\u{3400}'..='\u{4DBF}' // CJK Unified Ideographs Extension A
        | '\u{3000}'..='\u{303F}' // CJK Symbols and Punctuation
        | '\u{FF00}'..='\u{FFEF}' // Halfwidth and Fullwidth Forms
        | '\u{3040}'..='\u{309F}' // Hiragana (was missing while Katakana was covered)
        | '\u{30A0}'..='\u{30FF}' // Katakana
        | '\u{2E80}'..='\u{2EFF}' // CJK Radicals Supplement
        | '\u{31C0}'..='\u{31EF}' // CJK Strokes
        | '\u{3200}'..='\u{32FF}' // Enclosed CJK Letters and Months
        | '\u{3300}'..='\u{33FF}' // CJK Compatibility
        | '\u{AC00}'..='\u{D7AF}' // Hangul Syllables
        | '\u{1100}'..='\u{11FF}' // Hangul Jamo
        | '\u{3130}'..='\u{318F}' // Hangul Compatibility Jamo
        | '\u{A960}'..='\u{A97F}' // Hangul Jamo Extended-A
        | '\u{D7B0}'..='\u{D7FF}' // Hangul Jamo Extended-B
    )
}
29
/// Returns `true` for characters the splitter treats as punctuation:
/// common ASCII punctuation/symbols plus five Unicode quotation marks
/// (low-9 quote, curly double quotes, curly single quotes).
#[inline(always)]
fn is_punctuation(c: char) -> bool {
    if c.is_ascii() {
        return matches!(
            c,
            '.' | ',' | '!' | '?' | ';' | '\'' | '"' | '-' | '(' | ')' | '{' | '}'
                | '[' | ']' | '<' | '>' | ':' | '/' | '\\' | '|' | '@' | '#' | '$'
                | '%' | '^' | '&' | '*' | '+' | '=' | '`' | '~'
        );
    }
    // Non-ASCII: only the curly/low quote characters „ “ ” ‘ ’ count.
    matches!(
        c,
        '\u{201E}' | '\u{201C}' | '\u{201D}' | '\u{2018}' | '\u{2019}'
    )
}
71
/// Returns `true` for ASCII alphanumerics and Latin-1 letters
/// (À–Ö, Ø–ö, ø–ÿ). The multiplication (×, U+00D7) and division
/// (÷, U+00F7) signs inside that range are excluded.
#[inline(always)]
fn is_alphanumeric_latin(c: char) -> bool {
    match c {
        'a'..='z' | 'A'..='Z' | '0'..='9' => true,
        '\u{00C0}'..='\u{00FF}' => c != '\u{00D7}' && c != '\u{00F7}',
        _ => false,
    }
}
77
78// ── split boundary classification ───────────────────────────────────────
79
/// Category a character belongs to when splitting text into segments.
/// Consecutive characters of the same kind form one segment.
#[derive(Clone, Copy, PartialEq, Eq)]
enum SplitKind {
    /// Any `char::is_whitespace` character; the segment scores 0 tokens.
    Whitespace,
    /// Characters accepted by `is_punctuation`.
    Punctuation,
    /// Everything else: letters, digits, CJK, and unlisted symbols.
    Word,
}
86
87#[inline(always)]
88fn split_classify(c: char) -> SplitKind {
89    if c.is_whitespace() {
90        SplitKind::Whitespace
91    } else if c.is_ascii() {
92        if is_punctuation(c) {
93            SplitKind::Punctuation
94        } else {
95            SplitKind::Word
96        }
97    } else if is_punctuation(c) {
98        SplitKind::Punctuation
99    } else {
100        SplitKind::Word
101    }
102}
103
104// ── segment scoring ─────────────────────────────────────────────────────
105
/// Scores a word segment in estimated tokens.
///
/// Rule order matters:
/// 1. CJK segments cost one token per character.
/// 2. All-digit segments cost one token.
/// 3. Very short segments (≤ 3 bytes) cost one token.
/// 4. Pure-alphanumeric or language-matched segments are scored by a
///    chars-per-token ratio (`lang_cpt` overrides `default_cpt`).
/// 5. Anything else falls back to one token per character.
#[inline(always)]
fn score_word(
    byte_len: usize,
    char_count: usize,
    has_cjk: bool,
    all_alphanum: bool,
    all_digits: bool,
    lang_cpt: Option<f64>,
    default_cpt: f64,
) -> usize {
    match (has_cjk, all_digits, byte_len) {
        (true, _, _) => char_count,
        (_, true, _) => 1,
        (_, _, 0..=3) => 1,
        _ if all_alphanum || lang_cpt.is_some() => {
            let cpt = lang_cpt.unwrap_or(default_cpt);
            (byte_len as f64 / cpt).ceil() as usize
        }
        _ => char_count,
    }
}
131
/// Scores a punctuation run: short runs (≤ 3 bytes) are one token;
/// longer runs cost one token per two bytes, rounded up.
#[inline(always)]
fn score_punctuation(byte_len: usize) -> usize {
    match byte_len {
        0..=3 => 1,
        n => (n + 1) / 2,
    }
}
140
141// ── public API ──────────────────────────────────────────────────────────
142
143/// Estimates the number of tokens in `text` using default options.
144///
145/// # Examples
146///
147/// ```
148/// use tokenx_rs::estimate_token_count;
149///
150/// assert_eq!(estimate_token_count(""), 0);
151/// assert!(estimate_token_count("Hello, world!") > 0);
152/// ```
153pub fn estimate_token_count(text: &str) -> usize {
154    estimate_token_count_with_options(text, &EstimationOptions::default())
155}
156
/// Estimates the number of tokens in `text` using custom options.
///
/// Single pass over `text`: consecutive characters of the same
/// [`SplitKind`] are accumulated into a segment, and each segment is
/// scored when it closes. Whitespace segments contribute nothing;
/// punctuation and word segments are scored by `score_punctuation` /
/// `score_word` respectively.
///
/// # Examples
///
/// ```
/// use tokenx_rs::{estimate_token_count_with_options, EstimationOptions};
///
/// let opts = EstimationOptions::default();
/// let tokens = estimate_token_count_with_options("Hello, world!", &opts);
/// assert!(tokens > 0);
/// ```
pub fn estimate_token_count_with_options(text: &str, options: &EstimationOptions) -> usize {
    if text.is_empty() {
        return 0;
    }

    let mut total_tokens: usize = 0;

    // State of the segment currently being scanned. The `seg_all_*`
    // flags start `true` and are knocked down as disqualifying
    // characters appear; they only matter for `SplitKind::Word`.
    let mut seg_split_kind = SplitKind::Word;
    let mut seg_byte_len: usize = 0;
    let mut seg_char_count: usize = 0;
    let mut seg_has_cjk = false; // segment contains at least one CJK char
    let mut seg_all_alphanum = true; // every char passed `is_alphanumeric_latin`
    let mut seg_all_digits = true; // every char was an ASCII digit
    let mut seg_lang_cpt: Option<f64> = None; // cpt from first matching language config
    let mut in_segment = false; // false only before the first char is consumed

    let default_cpt = options.default_chars_per_token;

    // Scores the open segment and adds it to `total_tokens`. A macro
    // rather than a closure/helper so it can read all the `seg_*`
    // locals without fighting the borrow checker.
    macro_rules! flush {
        () => {
            total_tokens += match seg_split_kind {
                SplitKind::Whitespace => 0,
                SplitKind::Punctuation => score_punctuation(seg_byte_len),
                SplitKind::Word => score_word(
                    seg_byte_len,
                    seg_char_count,
                    seg_has_cjk,
                    seg_all_alphanum,
                    seg_all_digits,
                    seg_lang_cpt,
                    default_cpt,
                ),
            };
        };
    }

    for c in text.chars() {
        let kind = split_classify(c);

        if in_segment && kind == seg_split_kind {
            // Same kind as the open segment: extend it in place.
            seg_byte_len += c.len_utf8();
            seg_char_count += 1;
            if kind == SplitKind::Word {
                if is_cjk(c) {
                    seg_has_cjk = true;
                }
                if !is_alphanumeric_latin(c) {
                    seg_all_alphanum = false;
                    seg_all_digits = false;
                } else if !c.is_ascii_digit() {
                    seg_all_digits = false;
                }
                // First matching language config wins for the segment.
                if seg_lang_cpt.is_none() {
                    seg_lang_cpt = detect_language_cpt(c, options);
                }
            }
        } else {
            // Kind changed (or this is the first char): close the
            // previous segment, then seed a new one from `c`.
            if in_segment {
                flush!();
            }
            seg_split_kind = kind;
            seg_byte_len = c.len_utf8();
            seg_char_count = 1;
            seg_has_cjk = kind == SplitKind::Word && is_cjk(c);
            seg_all_alphanum = kind != SplitKind::Word || is_alphanumeric_latin(c);
            seg_all_digits = kind == SplitKind::Word && c.is_ascii_digit();
            seg_lang_cpt = if kind == SplitKind::Word {
                detect_language_cpt(c, options)
            } else {
                None
            };
            in_segment = true;
        }
    }

    // Close the final segment (text was non-empty, so one is open).
    if in_segment {
        flush!();
    }

    total_tokens
}
249
250#[inline(always)]
251fn detect_language_cpt(c: char, options: &EstimationOptions) -> Option<f64> {
252    for lc in &options.language_configs {
253        if (lc.matcher)(c) {
254            return Some(lc.chars_per_token);
255        }
256    }
257    None
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DEFAULT_CHARS_PER_TOKEN;

    #[test]
    fn empty_string() {
        assert_eq!(estimate_token_count(""), 0);
    }

    #[test]
    fn pure_whitespace() {
        // Whitespace-only segments contribute zero tokens.
        assert_eq!(estimate_token_count("   "), 0);
        assert_eq!(estimate_token_count("\t\n"), 0);
    }

    #[test]
    fn pure_cjk() {
        // CJK is scored one token per character.
        assert_eq!(estimate_token_count("你好世界"), 4);
    }

    #[test]
    fn pure_punctuation() {
        // Punctuation runs of <= 3 bytes collapse to one token.
        assert_eq!(estimate_token_count("..."), 1);
        assert_eq!(estimate_token_count(","), 1);
    }

    #[test]
    fn numeric_string() {
        // All-digit segment: one token regardless of length.
        assert_eq!(estimate_token_count("12345"), 1);
        // "3.14" splits into "3", ".", "14" -> three segments, one token each.
        assert_eq!(estimate_token_count("3.14"), 3);
    }

    #[test]
    fn short_words() {
        // Words of <= 3 bytes cost one token each.
        assert_eq!(estimate_token_count("Hi Bob"), 2);
    }

    #[test]
    fn mixed_content() {
        let count = estimate_token_count("Hello, world!");
        assert!(count >= 2, "Expected at least 2 tokens, got {count}");
    }

    #[test]
    fn german_text() {
        // Latin-1 letters (Ä) are treated as alphanumeric.
        let count = estimate_token_count("Ärgerlich");
        assert!(count > 0);
    }

    #[test]
    fn french_text() {
        let count = estimate_token_count("résumé");
        assert!(count > 0);
    }

    #[test]
    fn english_sentence() {
        let count = estimate_token_count("The quick brown fox jumps over the lazy dog");
        assert!(count >= 9, "Expected at least 9 tokens, got {count}");
    }

    #[test]
    fn default_chars_per_token_constant() {
        assert_eq!(DEFAULT_CHARS_PER_TOKEN, 6.0);
    }

    #[test]
    fn underscore_identifiers() {
        let count = estimate_token_count("process_items");
        assert_eq!(count, 13); // 1 per char (not alphanumeric due to _)
    }

    #[test]
    fn custom_options() {
        let opts = EstimationOptions {
            default_chars_per_token: 4.0,
            language_configs: vec![],
        };
        let count = estimate_token_count_with_options("abcdefgh", &opts);
        // ceil(8 / 4.0) = 2
        assert_eq!(count, 2);
    }
}