Skip to main content

traitclaw_core/
token_counting.rs

//! Token counting utilities for context management.
//!
//! Provides an approximate (length-based) token counter for message lists and
//! a trait through which accurate (e.g. tiktoken-based) backends can be plugged in.
5
6use crate::types::message::Message;
7
/// Approximate token counter using the "4 characters ≈ 1 token" heuristic
/// (the ratio is configurable).
///
/// Fast and allocation-free. Suitable for most use cases where exact
/// token counts are not required.
///
/// Text length is taken from `str::len`, i.e. the number of UTF-8 bytes, so
/// non-ASCII text weighs more than its visible character count.
///
/// # Examples
///
/// ```rust
/// use traitclaw_core::token_counting::CharApproxCounter;
/// use traitclaw_core::types::message::{Message, MessageRole};
///
/// let counter = CharApproxCounter::new(4);
/// let messages = vec![
///     Message { role: MessageRole::User, content: "Hello world!".to_string(), tool_call_id: None },
/// ];
/// let tokens = counter.count(&messages);
/// assert_eq!(tokens, 4); // 12 bytes / 4 + 1 = 4
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CharApproxCounter {
    /// Units of text length that count as one token (kept at 1 or more by `new`).
    chars_per_token: usize,
}
30
31impl CharApproxCounter {
32    /// Create a new counter with the given characters-per-token ratio.
33    ///
34    /// Common values: 4 (English), 3 (CJK-heavy), 2 (code-heavy).
35    #[must_use]
36    pub fn new(chars_per_token: usize) -> Self {
37        Self {
38            chars_per_token: chars_per_token.max(1),
39        }
40    }
41
42    /// Count tokens in a message list.
43    #[must_use]
44    pub fn count(&self, messages: &[Message]) -> usize {
45        messages
46            .iter()
47            .map(|m| m.content.len() / self.chars_per_token + 1)
48            .sum()
49    }
50
51    /// Count tokens in a single string.
52    #[must_use]
53    pub fn count_str(&self, text: &str) -> usize {
54        text.len() / self.chars_per_token + 1
55    }
56}
57
58impl Default for CharApproxCounter {
59    fn default() -> Self {
60        Self::new(4)
61    }
62}
63
64/// Trait for pluggable token counting backends.
65///
66/// Implement this trait to provide accurate token counting for specific
67/// models (e.g., tiktoken for OpenAI models).
68pub trait TokenCounter: Send + Sync {
69    /// Count tokens in a message list.
70    fn count_messages(&self, messages: &[Message]) -> usize;
71
72    /// Count tokens in a single string.
73    fn count_str(&self, text: &str) -> usize;
74}
75
76impl TokenCounter for CharApproxCounter {
77    fn count_messages(&self, messages: &[Message]) -> usize {
78        self.count(messages)
79    }
80
81    fn count_str(&self, text: &str) -> usize {
82        CharApproxCounter::count_str(self, text)
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::message::MessageRole;

    /// Builds a user message with the given content and no tool call id.
    fn msg(content: &str) -> Message {
        Message {
            role: MessageRole::User,
            content: content.to_string(),
            tool_call_id: None,
        }
    }

    #[test]
    fn test_char_approx_default() {
        let counter = CharApproxCounter::default();
        // 12 bytes / 4 + 1 = 4
        assert_eq!(counter.count_str("Hello world!"), 4);
    }

    #[test]
    fn test_char_approx_custom_ratio() {
        let counter = CharApproxCounter::new(2);
        // 12 bytes / 2 + 1 = 7
        assert_eq!(counter.count_str("Hello world!"), 7);
    }

    #[test]
    fn test_char_approx_messages() {
        let counter = CharApproxCounter::default();
        let messages = vec![msg("aaaa"), msg("bbbbbbbb")]; // 4/4+1=2, 8/4+1=3 → 5
        assert_eq!(counter.count(&messages), 5);
    }

    #[test]
    fn test_char_approx_empty() {
        let counter = CharApproxCounter::default();
        assert_eq!(counter.count(&[]), 0);
        // empty string: 0/4 + 1 = 1
        assert_eq!(counter.count_str(""), 1);
        // a message with empty content still contributes the +1
        assert_eq!(counter.count(&[msg("")]), 1);
    }

    #[test]
    fn test_char_approx_zero_ratio_clamped() {
        let counter = CharApproxCounter::new(0);
        // Should clamp to 1
        assert_eq!(counter.count_str("abcd"), 5); // 4/1 + 1 = 5
    }

    #[test]
    fn test_char_approx_counts_utf8_bytes() {
        let counter = CharApproxCounter::new(3);
        // Length is measured in UTF-8 bytes: "日本語" is 3 chars but 9 bytes.
        assert_eq!(counter.count_str("日本語"), 4); // 9/3 + 1 = 4
        assert_eq!(counter.count(&[msg("日本語")]), 4);
    }

    #[test]
    fn test_token_counter_trait() {
        let counter = CharApproxCounter::default();
        let tc: &dyn TokenCounter = &counter;
        assert_eq!(tc.count_str("abcdefgh"), 3); // 8/4 + 1 = 3

        // The trait's message-list entry point must agree with the inherent one.
        let messages = vec![msg("aaaa"), msg("bbbbbbbb")];
        assert_eq!(tc.count_messages(&messages), 5);
        assert_eq!(tc.count_messages(&messages), counter.count(&messages));
        assert_eq!(tc.count_messages(&[]), 0);
    }
}