traitclaw_core/
token_counting.rs1use crate::types::message::Message;
7
/// Approximate token counter that divides text length by a fixed ratio.
///
/// Length is measured with `str::len`, i.e. in UTF-8 bytes rather than in
/// Unicode characters, so non-ASCII text produces higher estimates than a
/// true character count would. Every counted string or message also adds one
/// extra token on top of the integer quotient.
///
/// Construct it with [`CharApproxCounter::new`], which clamps the ratio to at
/// least 1, or with [`Default`], which uses a ratio of 4.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CharApproxCounter {
    /// Length units (bytes) assumed per token. `new` clamps this to at least 1,
    /// and the divisions performed when counting rely on that.
    chars_per_token: usize,
}
30
31impl CharApproxCounter {
32 #[must_use]
36 pub fn new(chars_per_token: usize) -> Self {
37 Self {
38 chars_per_token: chars_per_token.max(1),
39 }
40 }
41
42 #[must_use]
44 pub fn count(&self, messages: &[Message]) -> usize {
45 messages
46 .iter()
47 .map(|m| m.content.len() / self.chars_per_token + 1)
48 .sum()
49 }
50
51 #[must_use]
53 pub fn count_str(&self, text: &str) -> usize {
54 text.len() / self.chars_per_token + 1
55 }
56}
57
58impl Default for CharApproxCounter {
59 fn default() -> Self {
60 Self::new(4)
61 }
62}
63
64pub trait TokenCounter: Send + Sync {
69 fn count_messages(&self, messages: &[Message]) -> usize;
71
72 fn count_str(&self, text: &str) -> usize;
74}
75
76impl TokenCounter for CharApproxCounter {
77 fn count_messages(&self, messages: &[Message]) -> usize {
78 self.count(messages)
79 }
80
81 fn count_str(&self, text: &str) -> usize {
82 CharApproxCounter::count_str(self, text)
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::message::MessageRole;

    /// Builds a user message with the given content and no tool-call id.
    fn msg(content: &str) -> Message {
        Message {
            role: MessageRole::User,
            content: content.to_string(),
            tool_call_id: None,
        }
    }

    #[test]
    fn test_char_approx_default() {
        let counter = CharApproxCounter::default();
        // 12 bytes / 4 = 3, plus the fixed +1.
        assert_eq!(counter.count_str("Hello world!"), 4);
    }

    #[test]
    fn test_char_approx_custom_ratio() {
        let counter = CharApproxCounter::new(2);
        // 12 bytes / 2 = 6, plus the fixed +1.
        assert_eq!(counter.count_str("Hello world!"), 7);
    }

    #[test]
    fn test_char_approx_messages() {
        let counter = CharApproxCounter::default();
        let messages = vec![msg("aaaa"), msg("bbbbbbbb")];
        // (4 / 4 + 1) + (8 / 4 + 1) = 2 + 3.
        assert_eq!(counter.count(&messages), 5);
    }

    #[test]
    fn test_char_approx_empty() {
        let counter = CharApproxCounter::default();
        assert_eq!(counter.count(&[]), 0);
        // An empty string still costs the fixed +1.
        assert_eq!(counter.count_str(""), 1);
    }

    #[test]
    fn test_char_approx_zero_ratio_clamped() {
        let counter = CharApproxCounter::new(0);
        // Ratio clamped to 1: 4 / 1 + 1.
        assert_eq!(counter.count_str("abcd"), 5);
    }

    #[test]
    fn test_char_approx_counts_utf8_bytes() {
        let counter = CharApproxCounter::default();
        // "日本語" is 3 characters but 9 UTF-8 bytes; length is measured in
        // bytes, so the estimate is 9 / 4 + 1 = 3 (not 3 / 4 + 1 = 1).
        assert_eq!("日本語".len(), 9);
        assert_eq!(counter.count_str("日本語"), 3);
    }

    #[test]
    fn test_token_counter_trait() {
        let counter = CharApproxCounter::default();
        let tc: &dyn TokenCounter = &counter;
        assert_eq!(tc.count_str("abcdefgh"), 3);
    }

    #[test]
    fn test_token_counter_trait_messages() {
        let counter = CharApproxCounter::default();
        let tc: &dyn TokenCounter = &counter;
        let messages = [msg("aaaa"), msg("bbbbbbbb")];
        // The trait path must agree with the inherent implementation.
        assert_eq!(tc.count_messages(&messages), counter.count(&messages));
        assert_eq!(tc.count_messages(&messages), 5);
        assert_eq!(tc.count_messages(&[]), 0);
    }
}