Skip to main content

winx_code_agent/utils/
encoder.rs

1//! Claude-compatible token counting.
2//!
3//! WCGW counts tokens with the `Xenova/claude-tokenizer` (Hugging Face `tokenizers`).
4//! We embed that same tokenizer definition in the binary and load it lazily, so token
5//! budgets and truncation match the model that actually runs the agent. If the
6//! tokenizer fails to load we fall back to a cheap character/word estimate.
7
8use std::sync::OnceLock;
9use tokenizers::Tokenizer;
10
11/// Embedded `Xenova/claude-tokenizer` definition (Hugging Face `tokenizer.json`).
12static CLAUDE_TOKENIZER_JSON: &[u8] =
13    include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/claude-tokenizer.json"));
14
15fn tokenizer() -> Option<&'static Tokenizer> {
16    static TOKENIZER: OnceLock<Option<Tokenizer>> = OnceLock::new();
17    TOKENIZER
18        .get_or_init(|| match Tokenizer::from_bytes(CLAUDE_TOKENIZER_JSON) {
19            Ok(tokenizer) => Some(tokenizer),
20            Err(error) => {
21                tracing::warn!("Failed to load embedded Claude tokenizer, using estimate: {error}");
22                None
23            }
24        })
25        .as_ref()
26}
27
28/// Count tokens the way Claude does. Falls back to [`estimate_tokens`] on failure.
29pub fn count_tokens(text: &str) -> usize {
30    match encode_ids(text) {
31        Some(ids) => ids.len(),
32        None => estimate_tokens(text),
33    }
34}
35
36/// Encode `text` into Claude token ids. Returns `None` if the tokenizer is
37/// unavailable so callers can pick a byte-based fallback.
38pub fn encode_ids(text: &str) -> Option<Vec<u32>> {
39    let tokenizer = tokenizer()?;
40    tokenizer.encode(text, false).ok().map(|encoding| encoding.get_ids().to_vec())
41}
42
43/// Decode Claude token ids back into text. Returns `None` on failure.
44pub fn decode_ids(ids: &[u32]) -> Option<String> {
45    let tokenizer = tokenizer()?;
46    tokenizer.decode(ids, false).ok()
47}
48
/// Rough token estimate for when the real tokenizer is unavailable.
///
/// Takes the larger of two guesses: one token per four characters (rounded up)
/// and one token per whitespace-separated word.
pub fn estimate_tokens(text: &str) -> usize {
    let by_chars = text.chars().count().div_ceil(4);
    let by_words = text.split_whitespace().count();
    std::cmp::max(by_chars, by_words)
}
53
#[cfg(test)]
mod tests {
    use super::{count_tokens, estimate_tokens};

    /// Holds for both the real tokenizer and the fallback estimate: non-empty
    /// text yields at least one token, empty text yields none.
    #[test]
    fn counts_tokens_for_simple_text() {
        let greeting = count_tokens("hello world");
        assert!(greeting >= 1);

        let empty = count_tokens("");
        assert_eq!(empty, 0);
    }

    /// Four single-letter words must be estimated as at least four tokens.
    #[test]
    fn estimate_is_nonzero_for_words() {
        let estimate = estimate_tokens("a b c d");
        assert!(estimate >= 4);
    }
}