Skip to main content

token_count/tokenizers/claude/
estimation.rs

1//! Adaptive token estimation for Claude models
2//!
3//! This module provides content-aware estimation that adapts based on whether
4//! the input is code, prose, or mixed content.
5
6/// Classification of text content for adaptive token estimation
/// Category assigned to a piece of text when estimating its token count.
///
/// Each variant carries a different average character-per-token density,
/// exposed via [`ContentType::chars_per_token`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContentType {
    /// Predominantly code (more than 15% code indicators).
    /// Examples: source files, JSON, config files.
    /// Roughly 3.0 chars per token.
    Code,

    /// Predominantly natural language (under 5% code indicators).
    /// Examples: documentation, articles, plain prose.
    /// Roughly 4.5 chars per token.
    Prose,

    /// A blend of both (5-15% code indicators).
    /// Examples: markdown with embedded code blocks, technical docs.
    /// Roughly 3.75 chars per token.
    Mixed,
}
24
25impl ContentType {
26    /// Get the estimated characters per token for this content type
27    pub fn chars_per_token(&self) -> f64 {
28        match self {
29            Self::Code => 3.0,
30            Self::Prose => 4.5,
31            Self::Mixed => 3.75,
32        }
33    }
34}
35
36/// Detect the content type of the given text
37pub fn detect_content_type(text: &str) -> ContentType {
38    let total_chars = text.chars().count();
39    if total_chars == 0 {
40        return ContentType::Prose;
41    }
42
43    let code_indicators = count_code_indicators(text);
44    let ratio = code_indicators as f64 / total_chars as f64;
45
46    if ratio > 0.15 {
47        ContentType::Code
48    } else if ratio > 0.05 {
49        ContentType::Mixed
50    } else {
51        ContentType::Prose
52    }
53}
54
/// Tally a rough "how code-like is this?" score for `text`.
///
/// The score is the number of structural punctuation characters plus,
/// for every non-overlapping keyword occurrence, the keyword's length
/// in bytes — so longer keywords weigh more heavily.
fn count_code_indicators(text: &str) -> usize {
    // Punctuation typical of code rather than prose.
    const STRUCTURAL: &[char] = &['{', '}', '[', ']', '(', ')', ';', ':', ',', '<', '>'];

    // Keywords from common languages, matched as plain substrings.
    const KEYWORDS: &[&str] = &[
        "fn ", "def ", "function ", "const ", "let ", "var ", "import ",
        "export ", "class ", "struct ", "enum ", "impl ", "trait ",
        "interface ", "if (", "for (", "while (", "return ", "async ",
        "await ", "//", "/*", "*/", "#include", "#define",
    ];

    let structural = text.chars().filter(|c| STRUCTURAL.contains(c)).count();

    let keyword_weight: usize = KEYWORDS
        .iter()
        .map(|kw| text.matches(kw).count() * kw.len())
        .sum();

    structural + keyword_weight
}
101
102/// Estimate token count using adaptive algorithm
103pub fn estimate_tokens(text: &str) -> usize {
104    if text.is_empty() {
105        return 0;
106    }
107
108    let content_type = detect_content_type(text);
109    let char_count = text.chars().count();
110    let ratio = content_type.chars_per_token();
111
112    (char_count as f64 / ratio).ceil() as usize
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_input_is_zero_tokens() {
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn single_char_rounds_up_to_one_token() {
        assert_eq!(estimate_tokens("a"), 1);
    }

    #[test]
    fn rust_snippet_detected_as_code() {
        let snippet = "fn main() { println!(\"test\"); }";
        assert_eq!(detect_content_type(snippet), ContentType::Code);
    }

    #[test]
    fn natural_language_detected_as_prose() {
        let sentence = "The quick brown fox jumps over the lazy dog.";
        assert_eq!(detect_content_type(sentence), ContentType::Prose);
    }

    #[test]
    fn markdown_with_code_block_detected_as_mixed() {
        let doc = "## Title\n\nSome text\n\n```rust\nfn test() {}\n```\n\nMore text.";
        assert_eq!(detect_content_type(doc), ContentType::Mixed);
    }

    #[test]
    fn code_estimate_uses_three_chars_per_token() {
        // 12 chars / 3.0 chars-per-token = 4 tokens exactly.
        assert_eq!(estimate_tokens("fn main() {}"), 4);
    }

    #[test]
    fn prose_estimate_uses_four_and_a_half_chars_per_token() {
        // 12 chars / 4.5 = 2.67, which ceils to 3 tokens.
        assert_eq!(estimate_tokens("Hello world!"), 3);
    }

    #[test]
    fn json_detected_as_code() {
        let json = r#"{"key": "value", "count": 42}"#;
        assert_eq!(detect_content_type(json), ContentType::Code);
    }

    #[test]
    fn python_snippet_detected_as_code() {
        let snippet =
            "def fibonacci(n):\n    return n if n < 2 else fibonacci(n-1) + fibonacci(n-2)";
        assert_eq!(detect_content_type(snippet), ContentType::Code);
    }

    #[test]
    fn chars_per_token_ratios_match_spec() {
        assert_eq!(ContentType::Code.chars_per_token(), 3.0);
        assert_eq!(ContentType::Prose.chars_per_token(), 4.5);
        assert_eq!(ContentType::Mixed.chars_per_token(), 3.75);
    }
}