scribe_core/tokenization.rs

//! # Tokenization Module
//!
//! This module provides accurate token counting using OpenAI's tiktoken tokenizer,
//! replacing the simple character-based estimation used previously.
//!
//! ## Features
//!
//! - **Accurate Token Counting**: Uses tiktoken cl100k_base encoding (GPT-4 compatible)
//! - **Multiple Encoding Support**: Supports different OpenAI encodings
//! - **Content-Aware Estimation**: Handles code content more accurately than character counting
//! - **Budget Management**: Token budget allocation and tracking
//!
//! ## Usage
//!
//! ```rust
//! use scribe_core::tokenization::{TokenCounter, TokenizerConfig};
//!
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let config = TokenizerConfig::default();
//! let counter = TokenCounter::new(config)?;
//!
//! let content = "fn main() { println!(\"Hello, world!\"); }";
//! let token_count = counter.count_tokens(content)?;
//! println!("Token count: {}", token_count);
//! # Ok(())
//! # }
//! ```

use std::sync::Arc;
use tiktoken_rs::{get_bpe_from_model, CoreBPE};
use serde::{Deserialize, Serialize};
use once_cell::sync::Lazy;
use crate::Result;

/// Configuration for the tokenizer
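///
/// # Example
///
/// A minimal sketch of a hand-built configuration (the field values here are
/// illustrative, not recommendations):
///
/// ```rust
/// use scribe_core::tokenization::TokenizerConfig;
///
/// let config = TokenizerConfig {
///     encoding_model: "gpt-3.5-turbo".to_string(),
///     enable_caching: true,
///     token_budget: Some(16_000),
/// };
/// assert_eq!(config.encoding_model, "gpt-3.5-turbo");
/// ```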
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
    /// The encoding model to use (e.g., "gpt-4", "gpt-3.5-turbo")
    pub encoding_model: String,
    /// Whether to cache tokenizer instances
    pub enable_caching: bool,
    /// Token budget for content selection
    pub token_budget: Option<usize>,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            encoding_model: "gpt-4".to_string(),
            enable_caching: true,
            token_budget: Some(128000), // 128k tokens, the GPT-4 Turbo-class context window
        }
    }
}

/// Shared global instance of the default TokenCounter (GPT-4)
/// This avoids expensive re-initialization on every token-counting call
static GLOBAL_TOKEN_COUNTER: Lazy<TokenCounter> = Lazy::new(|| {
    TokenCounter::new(TokenizerConfig::default())
        .expect("Failed to initialize global token counter")
});

/// Main tokenizer interface for accurate token counting
pub struct TokenCounter {
    config: TokenizerConfig,
    bpe: Arc<CoreBPE>,
}

impl std::fmt::Debug for TokenCounter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenCounter")
            .field("config", &self.config)
            .field("bpe", &"<CoreBPE>")
            .finish()
    }
}

impl TokenCounter {
    /// Create a new token counter with the specified configuration
    pub fn new(config: TokenizerConfig) -> Result<Self> {
        let bpe = get_bpe_from_model(&config.encoding_model)
            .map_err(|e| crate::ScribeError::tokenization(format!("Failed to load tokenizer for model '{}': {}", config.encoding_model, e)))?;

        Ok(Self {
            config,
            bpe: Arc::new(bpe),
        })
    }

    /// Create a new token counter with the default configuration (GPT-4)
    // An inherent `default` rather than the `Default` trait, because
    // construction is fallible and `Default::default` cannot return a `Result`.
    #[allow(clippy::should_implement_trait)]
    pub fn default() -> Result<Self> {
        Self::new(TokenizerConfig::default())
    }

    /// Get a reference to the shared global token counter instance.
    /// Reusing it avoids paying the tokenizer initialization cost on every call.
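    ///
    /// # Example
    ///
    /// A minimal sketch using the shared instance (same error-boxing pattern as
    /// the module-level example):
    ///
    /// ```rust
    /// use scribe_core::tokenization::TokenCounter;
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let count = TokenCounter::global().count_tokens("Hello, world!")?;
    /// assert!(count > 0);
    /// # Ok(())
    /// # }
    /// ```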
    pub fn global() -> &'static TokenCounter {
        &GLOBAL_TOKEN_COUNTER
    }

    /// Count tokens in the given text content
    pub fn count_tokens(&self, content: &str) -> Result<usize> {
        let tokens = self.bpe.encode_with_special_tokens(content);
        Ok(tokens.len())
    }

    /// Count tokens in multiple content strings and return the total
    pub fn count_tokens_batch(&self, contents: &[&str]) -> Result<usize> {
        let mut total = 0;
        for content in contents {
            total += self.count_tokens(content)?;
        }
        Ok(total)
    }

    /// Estimate tokens for a file from its content, adjusted by its file extension
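    ///
    /// # Example
    ///
    /// A minimal sketch; `script.py` is a hypothetical path used only to select
    /// the Python multiplier:
    ///
    /// ```rust
    /// use scribe_core::tokenization::TokenCounter;
    /// use std::path::Path;
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let counter = TokenCounter::global();
    /// let estimate = counter.estimate_file_tokens("x = 1", Path::new("script.py"))?;
    /// assert!(estimate > 0);
    /// # Ok(())
    /// # }
    /// ```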
    pub fn estimate_file_tokens(&self, content: &str, file_path: &std::path::Path) -> Result<usize> {
        // Get base token count
        let base_tokens = self.count_tokens(content)?;

        // Apply a language-specific multiplier based on the file extension
        let multiplier = self.get_language_multiplier(file_path);

        Ok((base_tokens as f64 * multiplier).ceil() as usize)
    }

    /// Get language-specific token multiplier
    fn get_language_multiplier(&self, file_path: &std::path::Path) -> f64 {
        let extension = file_path.extension()
            .and_then(|ext| ext.to_str())
            .unwrap_or("");

        match extension {
            // Boilerplate-heavy languages carry less information per token,
            // so pad the estimate
            "java" | "cs" => 1.2,

            // Languages with compact syntax
            "py" => 0.9,
            "js" | "ts" => 0.95,
            "rs" => 1.0,
            "go" => 0.95,

            // Configuration and data files
            "json" | "yaml" | "yml" | "toml" => 0.8,
            "xml" | "html" | "htm" => 1.1,

            // Documentation
            "md" | "markdown" | "txt" => 0.7,

            // Default for unknown types
            _ => 1.0,
        }
    }

    /// Check if content fits within the token budget
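    ///
    /// # Example
    ///
    /// A minimal sketch with a deliberately tiny budget:
    ///
    /// ```rust
    /// use scribe_core::tokenization::{TokenCounter, TokenizerConfig};
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut config = TokenizerConfig::default();
    /// config.token_budget = Some(4);
    /// let counter = TokenCounter::new(config)?;
    /// assert!(counter.fits_budget("hi")?);
    /// # Ok(())
    /// # }
    /// ```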
    pub fn fits_budget(&self, content: &str) -> Result<bool> {
        if let Some(budget) = self.config.token_budget {
            let token_count = self.count_tokens(content)?;
            Ok(token_count <= budget)
        } else {
            Ok(true) // No budget limit
        }
    }

    /// Calculate the remaining budget given the number of tokens already used
    pub fn remaining_budget(&self, used_tokens: usize) -> Option<usize> {
        self.config.token_budget.map(|budget| budget.saturating_sub(used_tokens))
    }

    /// Split content into chunks that each fit within a token limit
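    ///
    /// # Example
    ///
    /// A minimal sketch; chunks are cut on token boundaries, so re-encoded
    /// chunk sizes may differ slightly from `chunk_size`:
    ///
    /// ```rust
    /// use scribe_core::tokenization::TokenCounter;
    ///
    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let counter = TokenCounter::global();
    /// let chunks = counter.chunk_content(&"word ".repeat(100), 10)?;
    /// assert!(chunks.len() > 1);
    /// # Ok(())
    /// # }
    /// ```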
    pub fn chunk_content(&self, content: &str, chunk_size: usize) -> Result<Vec<String>> {
        let tokens = self.bpe.encode_with_special_tokens(content);
        let mut chunks = Vec::new();

        for chunk_tokens in tokens.chunks(chunk_size) {
            // Note: decoding can fail if a chunk boundary splits a multi-byte
            // UTF-8 sequence; the error is surfaced rather than silently dropped.
            let chunk_text = self.bpe.decode(chunk_tokens.to_vec())
                .map_err(|e| crate::ScribeError::tokenization(format!("Failed to decode token chunk: {}", e)))?;
            chunks.push(chunk_text);
        }

        Ok(chunks)
    }

    /// Get the current tokenizer configuration
    pub fn config(&self) -> &TokenizerConfig {
        &self.config
    }

    /// Update the token budget
    pub fn set_token_budget(&mut self, budget: Option<usize>) {
        self.config.token_budget = budget;
    }
}

/// Token budget tracker for selection algorithms
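///
/// # Example
///
/// The allocate / reserve / confirm flow (numbers chosen for illustration):
///
/// ```rust
/// use scribe_core::tokenization::TokenBudget;
///
/// let mut budget = TokenBudget::new(1_000);
/// assert!(budget.allocate(300));   // 300 used outright
/// assert!(budget.reserve(200));    // 200 held back; 500 still available
/// budget.confirm_reservation(150); // 150 of the reservation becomes used
/// budget.release_reservation(50);  // the remainder returns to the pool
/// assert_eq!(budget.used(), 450);
/// assert_eq!(budget.available(), 550);
/// ```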
#[derive(Debug, Clone)]
pub struct TokenBudget {
    total_budget: usize,
    used_tokens: usize,
    reserved_tokens: usize,
}

impl TokenBudget {
    /// Create a new token budget tracker
    pub fn new(total_budget: usize) -> Self {
        Self {
            total_budget,
            used_tokens: 0,
            reserved_tokens: 0,
        }
    }

    /// Get the total budget
    pub fn total(&self) -> usize {
        self.total_budget
    }

    /// Get the number of tokens used
    pub fn used(&self) -> usize {
        self.used_tokens
    }

    /// Get the number of tokens reserved but not yet used
    pub fn reserved(&self) -> usize {
        self.reserved_tokens
    }

    /// Get the number of available tokens
    pub fn available(&self) -> usize {
        self.total_budget.saturating_sub(self.used_tokens + self.reserved_tokens)
    }

    /// Check if the budget can accommodate the specified number of tokens
    pub fn can_allocate(&self, tokens: usize) -> bool {
        self.available() >= tokens
    }

    /// Allocate tokens from the budget
    pub fn allocate(&mut self, tokens: usize) -> bool {
        if self.can_allocate(tokens) {
            self.used_tokens += tokens;
            true
        } else {
            false
        }
    }

    /// Reserve tokens without using them yet
    pub fn reserve(&mut self, tokens: usize) -> bool {
        if self.available() >= tokens {
            self.reserved_tokens += tokens;
            true
        } else {
            false
        }
    }

    /// Confirm reserved tokens as used
    pub fn confirm_reservation(&mut self, tokens: usize) {
        let to_confirm = tokens.min(self.reserved_tokens);
        self.reserved_tokens -= to_confirm;
        self.used_tokens += to_confirm;
    }

    /// Release reserved tokens back to the available pool
    pub fn release_reservation(&mut self, tokens: usize) {
        self.reserved_tokens = self.reserved_tokens.saturating_sub(tokens);
    }

    /// Get utilization as a percentage of the total budget
    pub fn utilization(&self) -> f64 {
        if self.total_budget == 0 {
            return 0.0; // Avoid 0/0 producing NaN for an empty budget
        }
        (self.used_tokens as f64 / self.total_budget as f64) * 100.0
    }

    /// Reset the budget tracker
    pub fn reset(&mut self) {
        self.used_tokens = 0;
        self.reserved_tokens = 0;
    }
}

/// Utilities for working with tokens and content
pub mod utils {
    use super::*;

    /// Estimate tokens using the legacy character-based method (for comparison)
    pub fn estimate_tokens_legacy(content: &str) -> usize {
        // Original method: ~4 characters per token for English text
        (content.chars().count() as f64 / 4.0).ceil() as usize
    }

    /// Compare tiktoken accuracy against the legacy estimation
    pub fn compare_tokenization_accuracy(content: &str, counter: &TokenCounter) -> Result<TokenizationComparison> {
        let tiktoken_count = counter.count_tokens(content)?;
        let legacy_count = estimate_tokens_legacy(content);

        let accuracy_ratio = if legacy_count > 0 {
            tiktoken_count as f64 / legacy_count as f64
        } else {
            1.0
        };

        Ok(TokenizationComparison {
            tiktoken_count,
            legacy_count,
            accuracy_ratio,
            // Treating tiktoken as ground truth: when the legacy method
            // overestimates, record how much lower the true count is (percent).
            improvement: if accuracy_ratio < 1.0 {
                Some((1.0 - accuracy_ratio) * 100.0)
            } else {
                None
            },
        })
    }

    /// Get the recommended token budget for a model and content type
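    ///
    /// # Example
    ///
    /// A minimal sketch; see `ContentType` for the adjustment applied:
    ///
    /// ```rust
    /// use scribe_core::tokenization::{utils, ContentType};
    ///
    /// let budget = utils::recommend_token_budget("gpt-4", ContentType::Code);
    /// assert!(budget > 0);
    /// ```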
    pub fn recommend_token_budget(model: &str, content_type: ContentType) -> usize {
        let base_budget = match model {
            "gpt-4" | "gpt-4-turbo" => 128000,
            "gpt-4-32k" => 32000,
            "gpt-3.5-turbo" | "gpt-3.5-turbo-16k" => 16000,
            _ => 8000, // Conservative default
        };

        // Adjust based on content type
        match content_type {
            ContentType::Code => (base_budget as f64 * 0.8) as usize, // Leave room for analysis
            ContentType::Documentation => base_budget,
            ContentType::Mixed => (base_budget as f64 * 0.9) as usize,
        }
    }
}

/// Content type for budget recommendations
#[derive(Debug, Clone, Copy)]
pub enum ContentType {
    Code,
    Documentation,
    Mixed,
}

/// Comparison between tiktoken and legacy tokenization
#[derive(Debug, Clone)]
pub struct TokenizationComparison {
    pub tiktoken_count: usize,
    pub legacy_count: usize,
    pub accuracy_ratio: f64,
    /// How much lower the tiktoken count is than the legacy estimate,
    /// as a percentage, when the legacy method overestimated
    pub improvement: Option<f64>,
}

impl TokenizationComparison {
    /// Format the comparison as a human-readable string
    pub fn format(&self) -> String {
        match self.improvement {
            Some(improvement) => format!(
                "Tiktoken: {} tokens, Legacy: {} tokens, {:.1}% lower than the legacy estimate",
                self.tiktoken_count, self.legacy_count, improvement
            ),
            None => format!(
                "Tiktoken: {} tokens, Legacy: {} tokens, {:.2}x ratio",
                self.tiktoken_count, self.legacy_count, self.accuracy_ratio
            ),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_token_counter_creation() {
        let config = TokenizerConfig::default();
        let counter = TokenCounter::new(config);
        assert!(counter.is_ok());
    }

    #[test]
    fn test_basic_token_counting() {
        let counter = TokenCounter::default().unwrap();

        let simple_text = "Hello, world!";
        let count = counter.count_tokens(simple_text).unwrap();
        assert!(count > 0);
        assert!(count < 10); // Should be a small number for this simple text
    }

    #[test]
    fn test_code_token_counting() {
        let counter = TokenCounter::default().unwrap();

        let rust_code = r#"
fn main() {
    println!("Hello, world!");
    let x = 42;
    if x > 0 {
        println!("Positive number: {}", x);
    }
}
"#;

        let count = counter.count_tokens(rust_code).unwrap();
        assert!(count > 20); // Should be more tokens for this code
        assert!(count < 100); // But not excessive
    }

    #[test]
    fn test_language_multipliers() {
        let counter = TokenCounter::default().unwrap();

        let content = "function test() { return 42; }";

        let js_tokens = counter.estimate_file_tokens(content, Path::new("test.js")).unwrap();
        let java_tokens = counter.estimate_file_tokens(content, Path::new("test.java")).unwrap();
        let py_tokens = counter.estimate_file_tokens(content, Path::new("test.py")).unwrap();

        // Java should have more tokens due to boilerplate multiplier
        assert!(java_tokens >= js_tokens);
        // Python should have fewer tokens due to compact syntax
        assert!(py_tokens <= js_tokens);
    }

    #[test]
    fn test_token_budget() {
        let mut budget = TokenBudget::new(1000);

        assert_eq!(budget.total(), 1000);
        assert_eq!(budget.used(), 0);
        assert_eq!(budget.available(), 1000);

        assert!(budget.allocate(300));
        assert_eq!(budget.used(), 300);
        assert_eq!(budget.available(), 700);

        assert!(budget.reserve(200));
        assert_eq!(budget.reserved(), 200);
        assert_eq!(budget.available(), 500);

        budget.confirm_reservation(150);
        assert_eq!(budget.used(), 450);
        assert_eq!(budget.reserved(), 50);
        assert_eq!(budget.available(), 500);
    }

    #[test]
    fn test_content_chunking() {
        let counter = TokenCounter::default().unwrap();

        let long_content = "word ".repeat(1000); // 1000 words
        let chunks = counter.chunk_content(&long_content, 100).unwrap();

        assert!(chunks.len() > 1); // Should be split into multiple chunks

        // Verify each chunk is roughly the right size
        for chunk in &chunks {
            let chunk_tokens = counter.count_tokens(chunk).unwrap();
            assert!(chunk_tokens <= 120); // Allow some margin due to token boundaries
        }
    }

    #[test]
    fn test_tokenization_comparison() {
        let counter = TokenCounter::default().unwrap();

        let code_content = r#"
use std::collections::HashMap;

fn process_data(input: &str) -> Result<HashMap<String, i32>, Box<dyn std::error::Error>> {
    let mut result = HashMap::new();
    for line in input.lines() {
        let parts: Vec<&str> = line.split(':').collect();
        if parts.len() == 2 {
            result.insert(parts[0].to_string(), parts[1].parse()?);
        }
    }
    Ok(result)
}
"#;

        let comparison = utils::compare_tokenization_accuracy(code_content, &counter).unwrap();

        assert!(comparison.tiktoken_count > 0);
        assert!(comparison.legacy_count > 0);
        assert!(comparison.accuracy_ratio > 0.0);

        let formatted = comparison.format();
        assert!(formatted.contains("Tiktoken"));
        assert!(formatted.contains("Legacy"));
    }

    #[test]
    fn test_budget_recommendations() {
        let code_budget = utils::recommend_token_budget("gpt-4", ContentType::Code);
        let doc_budget = utils::recommend_token_budget("gpt-4", ContentType::Documentation);
        let mixed_budget = utils::recommend_token_budget("gpt-4", ContentType::Mixed);

        assert!(code_budget < doc_budget); // Code should have smaller budget to leave room for analysis
        assert!(mixed_budget > code_budget);
        assert!(mixed_budget < doc_budget);
    }
}