// scribe_core/tokenization.rs

//! # Tokenization Module
//!
//! This module provides accurate token counting using OpenAI's tiktoken tokenizer,
//! replacing the simple character-based estimation used previously.
//!
//! ## Features
//!
//! - **Accurate Token Counting**: Uses tiktoken cl100k_base encoding (GPT-4 compatible)
//! - **Multiple Encoding Support**: Supports different OpenAI encodings
//! - **Content-Aware Estimation**: Handles code content more accurately than character counting
//! - **Budget Management**: Token budget allocation and tracking
//!
//! ## Usage
//!
//! ```rust
//! use scribe_core::tokenization::{TokenCounter, TokenizerConfig};
//!
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let config = TokenizerConfig::default();
//! let counter = TokenCounter::new(config)?;
//!
//! let content = "fn main() { println!(\"Hello, world!\"); }";
//! let token_count = counter.count_tokens(content)?;
//! println!("Token count: {}", token_count);
//! # Ok(())
//! # }
//! ```

use crate::Result;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tiktoken_rs::{get_bpe_from_model, CoreBPE};

/// Configuration for the tokenizer
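///
/// A minimal construction sketch; the model name and budget below are
/// illustrative values, not recommendations:
///
/// ```rust
/// use scribe_core::tokenization::TokenizerConfig;
///
/// let config = TokenizerConfig {
///     encoding_model: "gpt-3.5-turbo".to_string(),
///     enable_caching: true,
///     token_budget: Some(16_000),
/// };
/// assert_eq!(config.encoding_model, "gpt-3.5-turbo");
/// ```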
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
    /// The encoding model to use (e.g., "gpt-4", "gpt-3.5-turbo")
    pub encoding_model: String,
    /// Whether to cache tokenizer instances
    pub enable_caching: bool,
    /// Token budget for content selection
    pub token_budget: Option<usize>,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            encoding_model: "gpt-4".to_string(),
            enable_caching: true,
            token_budget: Some(128000), // 128k context window (GPT-4 Turbo class models)
        }
    }
}

/// Shared global instance of the default `TokenCounter` (GPT-4).
///
/// This avoids expensive re-initialization on every token counting call.
static GLOBAL_TOKEN_COUNTER: Lazy<TokenCounter> = Lazy::new(|| {
    TokenCounter::new(TokenizerConfig::default())
        .expect("Failed to initialize global token counter")
});

/// Main tokenizer interface for accurate token counting
pub struct TokenCounter {
    config: TokenizerConfig,
    bpe: Arc<CoreBPE>,
}

impl std::fmt::Debug for TokenCounter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenCounter")
            .field("config", &self.config)
            .field("bpe", &"<CoreBPE>")
            .finish()
    }
}

impl TokenCounter {
    /// Create a new token counter with the specified configuration
    pub fn new(config: TokenizerConfig) -> Result<Self> {
        let bpe = get_bpe_from_model(&config.encoding_model).map_err(|e| {
            crate::ScribeError::tokenization(format!(
                "Failed to load tokenizer for model '{}': {}",
                config.encoding_model, e
            ))
        })?;

        Ok(Self {
            config,
            bpe: Arc::new(bpe),
        })
    }

    /// Create a new token counter with default configuration (GPT-4)
    pub fn default() -> Result<Self> {
        Self::new(TokenizerConfig::default())
    }

    /// Get a reference to the shared global token counter instance.
    ///
    /// The underlying tokenizer is initialized exactly once, so this is
    /// cheap to call from hot paths.
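    ///
    /// A usage sketch (the sample string is illustrative):
    ///
    /// ```rust
    /// use scribe_core::tokenization::TokenCounter;
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let count = TokenCounter::global().count_tokens("fn main() {}")?;
    /// assert!(count > 0);
    /// # Ok(())
    /// # }
    /// ```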
    pub fn global() -> &'static TokenCounter {
        &GLOBAL_TOKEN_COUNTER
    }

    /// Count tokens in the given text content
    pub fn count_tokens(&self, content: &str) -> Result<usize> {
        // `encode_with_special_tokens` recognizes special-token text such as
        // "<|endoftext|>" as single special tokens rather than plain text.
        let tokens = self.bpe.encode_with_special_tokens(content);
        Ok(tokens.len())
    }

    /// Count tokens in multiple content strings and return the total
    pub fn count_tokens_batch(&self, contents: &[&str]) -> Result<usize> {
        let mut total = 0;
        for content in contents {
            total += self.count_tokens(content)?;
        }
        Ok(total)
    }

    /// Estimate tokens for a file based on its content and metadata
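    ///
    /// The exact tiktoken count is scaled by a per-language adjustment
    /// factor derived from the file extension. A brief sketch (file name
    /// illustrative):
    ///
    /// ```rust
    /// use scribe_core::tokenization::TokenCounter;
    /// use std::path::Path;
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let counter = TokenCounter::default()?;
    /// let tokens = counter.estimate_file_tokens("x = 1", Path::new("script.py"))?;
    /// assert!(tokens > 0);
    /// # Ok(())
    /// # }
    /// ```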
    pub fn estimate_file_tokens(
        &self,
        content: &str,
        file_path: &std::path::Path,
    ) -> Result<usize> {
        // Get base token count
        let base_tokens = self.count_tokens(content)?;

        // Apply language-specific multipliers based on file extension
        let multiplier = self.get_language_multiplier(file_path);

        Ok((base_tokens as f64 * multiplier).ceil() as usize)
    }

    /// Get language-specific token multiplier
    fn get_language_multiplier(&self, file_path: &std::path::Path) -> f64 {
        let extension = file_path
            .extension()
            .and_then(|ext| ext.to_str())
            .unwrap_or("");

        match extension {
            // Languages with lots of boilerplate tend to have lower token density
            "java" | "csharp" | "cs" => 1.2,

            // Languages with compact syntax
            "py" | "python" => 0.9,
            "js" | "javascript" | "ts" | "typescript" => 0.95,
            "rs" | "rust" => 1.0,
            "go" => 0.95,

            // Configuration and data files
            "json" | "yaml" | "yml" | "toml" => 0.8,
            "xml" | "html" | "htm" => 1.1,

            // Documentation
            "md" | "markdown" | "txt" => 0.7,

            // Default for unknown types
            _ => 1.0,
        }
    }

    /// Check if content fits within the token budget
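    ///
    /// Returns `Ok(true)` when no budget is configured. A small sketch
    /// (budget value illustrative):
    ///
    /// ```rust
    /// use scribe_core::tokenization::{TokenCounter, TokenizerConfig};
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut config = TokenizerConfig::default();
    /// config.token_budget = Some(8);
    /// let counter = TokenCounter::new(config)?;
    /// assert!(counter.fits_budget("short text")?);
    /// # Ok(())
    /// # }
    /// ```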
    pub fn fits_budget(&self, content: &str) -> Result<bool> {
        if let Some(budget) = self.config.token_budget {
            let token_count = self.count_tokens(content)?;
            Ok(token_count <= budget)
        } else {
            Ok(true) // No budget limit
        }
    }

    /// Calculate the remaining budget after accounting for used tokens
    pub fn remaining_budget(&self, used_tokens: usize) -> Option<usize> {
        self.config
            .token_budget
            .map(|budget| budget.saturating_sub(used_tokens))
    }

    /// Split content into chunks that fit within a token limit
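    ///
    /// Chunks are cut on token boundaries, so a chunk may end mid-word.
    /// A usage sketch (input and chunk size are illustrative):
    ///
    /// ```rust
    /// use scribe_core::tokenization::TokenCounter;
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let counter = TokenCounter::default()?;
    /// let chunks = counter.chunk_content(&"word ".repeat(50), 10)?;
    /// assert!(chunks.len() > 1);
    /// # Ok(())
    /// # }
    /// ```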
    pub fn chunk_content(&self, content: &str, chunk_size: usize) -> Result<Vec<String>> {
        let tokens = self.bpe.encode_with_special_tokens(content);
        let mut chunks = Vec::new();

        for chunk_tokens in tokens.chunks(chunk_size) {
            // Decoding can fail if a chunk boundary splits a multi-byte
            // character; surface that as a tokenization error.
            let chunk_text = self.bpe.decode(chunk_tokens.to_vec()).map_err(|e| {
                crate::ScribeError::tokenization(format!("Failed to decode token chunk: {}", e))
            })?;
            chunks.push(chunk_text);
        }

        Ok(chunks)
    }

    /// Get the current tokenizer configuration
    pub fn config(&self) -> &TokenizerConfig {
        &self.config
    }

    /// Update the token budget
    pub fn set_token_budget(&mut self, budget: Option<usize>) {
        self.config.token_budget = budget;
    }
}

/// Token budget tracker for selection algorithms
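///
/// A sketch of the allocate/reserve/confirm flow (numbers illustrative):
///
/// ```rust
/// use scribe_core::tokenization::TokenBudget;
///
/// let mut budget = TokenBudget::new(1_000);
/// assert!(budget.allocate(300));   // use 300 tokens immediately
/// assert!(budget.reserve(200));    // hold 200 more for a pending item
/// budget.confirm_reservation(150); // the item needed only 150
/// budget.release_reservation(50);  // return the unused remainder
/// assert_eq!(budget.used(), 450);
/// assert_eq!(budget.available(), 550);
/// ```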
#[derive(Debug, Clone)]
pub struct TokenBudget {
    total_budget: usize,
    used_tokens: usize,
    reserved_tokens: usize,
}

impl TokenBudget {
    /// Create a new token budget tracker
    pub fn new(total_budget: usize) -> Self {
        Self {
            total_budget,
            used_tokens: 0,
            reserved_tokens: 0,
        }
    }

    /// Get the total budget
    pub fn total(&self) -> usize {
        self.total_budget
    }

    /// Get the number of tokens used
    pub fn used(&self) -> usize {
        self.used_tokens
    }

    /// Get the number of tokens reserved but not yet used
    pub fn reserved(&self) -> usize {
        self.reserved_tokens
    }

    /// Get the number of available tokens
    pub fn available(&self) -> usize {
        self.total_budget
            .saturating_sub(self.used_tokens + self.reserved_tokens)
    }

    /// Check if the budget can accommodate the specified number of tokens
    pub fn can_allocate(&self, tokens: usize) -> bool {
        self.available() >= tokens
    }

    /// Allocate tokens from the budget
    pub fn allocate(&mut self, tokens: usize) -> bool {
        if self.can_allocate(tokens) {
            self.used_tokens += tokens;
            true
        } else {
            false
        }
    }

    /// Reserve tokens without using them yet
    pub fn reserve(&mut self, tokens: usize) -> bool {
        if self.available() >= tokens {
            self.reserved_tokens += tokens;
            true
        } else {
            false
        }
    }

    /// Confirm reserved tokens as used
    pub fn confirm_reservation(&mut self, tokens: usize) {
        let to_confirm = tokens.min(self.reserved_tokens);
        self.reserved_tokens -= to_confirm;
        self.used_tokens += to_confirm;
    }

    /// Release reserved tokens back to available pool
    pub fn release_reservation(&mut self, tokens: usize) {
        self.reserved_tokens = self.reserved_tokens.saturating_sub(tokens);
    }

    /// Get utilization as a percentage of the total budget
    pub fn utilization(&self) -> f64 {
        if self.total_budget == 0 {
            return 0.0; // Avoid dividing by zero for an empty budget
        }
        (self.used_tokens as f64 / self.total_budget as f64) * 100.0
    }

    /// Reset the budget tracker
    pub fn reset(&mut self) {
        self.used_tokens = 0;
        self.reserved_tokens = 0;
    }
}

/// Utilities for working with tokens and content
pub mod utils {
    use super::*;

    /// Estimate tokens using the legacy character-based method (for comparison)
    pub fn estimate_tokens_legacy(content: &str) -> usize {
        // Original heuristic: roughly 4 characters per token for English text
        (content.chars().count() as f64 / 4.0).ceil() as usize
    }

    /// Compare tiktoken accuracy against legacy estimation
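    ///
    /// A brief sketch (sample content illustrative):
    ///
    /// ```rust
    /// use scribe_core::tokenization::{utils, TokenCounter};
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let counter = TokenCounter::default()?;
    /// let cmp = utils::compare_tokenization_accuracy("fn main() {}", &counter)?;
    /// println!("{}", cmp.format());
    /// # Ok(())
    /// # }
    /// ```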
    pub fn compare_tokenization_accuracy(
        content: &str,
        counter: &TokenCounter,
    ) -> Result<TokenizationComparison> {
        let tiktoken_count = counter.count_tokens(content)?;
        let legacy_count = estimate_tokens_legacy(content);

        let accuracy_ratio = if legacy_count > 0 {
            tiktoken_count as f64 / legacy_count as f64
        } else {
            1.0
        };

        Ok(TokenizationComparison {
            tiktoken_count,
            legacy_count,
            accuracy_ratio,
            // A ratio below 1.0 means the legacy heuristic overestimated;
            // record by how much, as a percentage of the legacy count.
            improvement: if accuracy_ratio < 1.0 {
                Some((1.0 - accuracy_ratio) * 100.0)
            } else {
                None
            },
        })
    }

    /// Get a recommended token budget based on the model and content type
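    ///
    /// A usage sketch (model name illustrative):
    ///
    /// ```rust
    /// use scribe_core::tokenization::{utils, ContentType};
    ///
    /// let budget = utils::recommend_token_budget("gpt-4", ContentType::Code);
    /// assert!(budget > 0);
    /// ```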
    pub fn recommend_token_budget(model: &str, content_type: ContentType) -> usize {
        // Approximate context-window sizes; "gpt-4" is treated here as a
        // 128k-context (turbo-class) model, matching `TokenizerConfig::default()`.
        let base_budget = match model {
            "gpt-4" | "gpt-4-turbo" => 128000,
            "gpt-4-32k" => 32000,
            "gpt-3.5-turbo" => 16000,
            "gpt-3.5-turbo-16k" => 16000,
            _ => 8000, // Conservative default
        };

        // Adjust based on content type
        match content_type {
            ContentType::Code => (base_budget as f64 * 0.8) as usize, // Leave room for analysis
            ContentType::Documentation => base_budget,
            ContentType::Mixed => (base_budget as f64 * 0.9) as usize,
        }
    }
}


/// Content type for budget recommendations
#[derive(Debug, Clone, Copy)]
pub enum ContentType {
    Code,
    Documentation,
    Mixed,
}

/// Comparison between tiktoken and legacy tokenization
#[derive(Debug, Clone)]
pub struct TokenizationComparison {
    pub tiktoken_count: usize,
    pub legacy_count: usize,
    pub accuracy_ratio: f64,
    /// Percentage by which the legacy estimate overshot the tiktoken count, if it did
    pub improvement: Option<f64>,
}

impl TokenizationComparison {
    /// Format the comparison as a human-readable string
    pub fn format(&self) -> String {
        match self.improvement {
            Some(improvement) => format!(
                "Tiktoken: {} tokens, Legacy: {} tokens (legacy overestimated by {:.1}%)",
                self.tiktoken_count, self.legacy_count, improvement
            ),
            None => format!(
                "Tiktoken: {} tokens, Legacy: {} tokens, {:.2}x ratio",
                self.tiktoken_count, self.legacy_count, self.accuracy_ratio
            ),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_token_counter_creation() {
        let config = TokenizerConfig::default();
        let counter = TokenCounter::new(config);
        assert!(counter.is_ok());
    }

    #[test]
    fn test_basic_token_counting() {
        let counter = TokenCounter::default().unwrap();

        let simple_text = "Hello, world!";
        let count = counter.count_tokens(simple_text).unwrap();
        assert!(count > 0);
        assert!(count < 10); // Should be a small number for this simple text
    }

    #[test]
    fn test_code_token_counting() {
        let counter = TokenCounter::default().unwrap();

        let rust_code = r#"
fn main() {
    println!("Hello, world!");
    let x = 42;
    if x > 0 {
        println!("Positive number: {}", x);
    }
}
"#;

        let count = counter.count_tokens(rust_code).unwrap();
        assert!(count > 20); // Should be more tokens for this code
        assert!(count < 100); // But not excessive
    }

    #[test]
    fn test_language_multipliers() {
        let counter = TokenCounter::default().unwrap();

        let content = "function test() { return 42; }";

        let js_tokens = counter
            .estimate_file_tokens(content, Path::new("test.js"))
            .unwrap();
        let java_tokens = counter
            .estimate_file_tokens(content, Path::new("test.java"))
            .unwrap();
        let py_tokens = counter
            .estimate_file_tokens(content, Path::new("test.py"))
            .unwrap();

        // Java should have more tokens due to boilerplate multiplier
        assert!(java_tokens >= js_tokens);
        // Python should have fewer tokens due to compact syntax
        assert!(py_tokens <= js_tokens);
    }

    #[test]
    fn test_token_budget() {
        let mut budget = TokenBudget::new(1000);

        assert_eq!(budget.total(), 1000);
        assert_eq!(budget.used(), 0);
        assert_eq!(budget.available(), 1000);

        assert!(budget.allocate(300));
        assert_eq!(budget.used(), 300);
        assert_eq!(budget.available(), 700);

        assert!(budget.reserve(200));
        assert_eq!(budget.reserved(), 200);
        assert_eq!(budget.available(), 500);

        budget.confirm_reservation(150);
        assert_eq!(budget.used(), 450);
        assert_eq!(budget.reserved(), 50);
        assert_eq!(budget.available(), 500);
    }

    #[test]
    fn test_content_chunking() {
        let counter = TokenCounter::default().unwrap();

        let long_content = "word ".repeat(1000); // 1000 words
        let chunks = counter.chunk_content(&long_content, 100).unwrap();

        assert!(chunks.len() > 1); // Should be split into multiple chunks

        // Verify each chunk is roughly the right size
        for chunk in &chunks {
            let chunk_tokens = counter.count_tokens(chunk).unwrap();
            assert!(chunk_tokens <= 120); // Allow some margin due to token boundaries
        }
    }

    #[test]
    fn test_tokenization_comparison() {
        let counter = TokenCounter::default().unwrap();

        let code_content = r#"
use std::collections::HashMap;

fn process_data(input: &str) -> Result<HashMap<String, i32>, Box<dyn std::error::Error>> {
    let mut result = HashMap::new();
    for line in input.lines() {
        let parts: Vec<&str> = line.split(':').collect();
        if parts.len() == 2 {
            result.insert(parts[0].to_string(), parts[1].parse()?);
        }
    }
    Ok(result)
}
"#;

        let comparison = utils::compare_tokenization_accuracy(code_content, &counter).unwrap();

        assert!(comparison.tiktoken_count > 0);
        assert!(comparison.legacy_count > 0);
        assert!(comparison.accuracy_ratio > 0.0);

        let formatted = comparison.format();
        assert!(formatted.contains("Tiktoken"));
        assert!(formatted.contains("Legacy"));
    }

    #[test]
    fn test_budget_recommendations() {
        let code_budget = utils::recommend_token_budget("gpt-4", ContentType::Code);
        let doc_budget = utils::recommend_token_budget("gpt-4", ContentType::Documentation);
        let mixed_budget = utils::recommend_token_budget("gpt-4", ContentType::Mixed);

        assert!(code_budget < doc_budget); // Code should have smaller budget to leave room for analysis
        assert!(mixed_budget > code_budget);
        assert!(mixed_budget < doc_budget);
    }
}