st/dynamic_tokenizer.rs

//! Dynamic tokenizer - "Learning each project's language!" - Omni
//! Automatically discovers and tokenizes common patterns in any codebase

use crate::scanner::FileNode;
use std::collections::HashMap;

/// Dynamic tokenizer that learns project-specific patterns
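/// # Example
///
/// A minimal usage sketch (illustrative, not a compiled doctest): it assumes
/// the scanner has already produced `nodes: Vec<FileNode>` with populated
/// `path` fields.
///
/// ```ignore
/// let mut tokenizer = DynamicTokenizer::new();
/// tokenizer.analyze(&nodes); // count patterns, then assign tokens
/// let header = tokenizer.get_token_header();
/// let short = tokenizer.compress_path("src/components/Button.tsx");
/// ```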
pub struct DynamicTokenizer {
    /// Path component frequencies
    path_components: HashMap<String, usize>,
    /// File name frequencies
    file_names: HashMap<String, usize>,
    /// Extension frequencies
    extensions: HashMap<String, usize>,
    /// Common prefixes/suffixes
    prefixes: HashMap<String, usize>,
    suffixes: HashMap<String, usize>,
    /// Generated token mappings
    tokens: HashMap<String, String>,
}

impl Default for DynamicTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl DynamicTokenizer {
    pub fn new() -> Self {
        Self {
            path_components: HashMap::new(),
            file_names: HashMap::new(),
            extensions: HashMap::new(),
            prefixes: HashMap::new(),
            suffixes: HashMap::new(),
            tokens: HashMap::new(),
        }
    }

    /// Analyze nodes to learn patterns
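    ///
    /// Two passes: frequency counting over every node, then a single
    /// `generate_tokens` call. Paths are split on '/' only, so Windows-style
    /// `\` separators would leave a whole path as one component.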
    pub fn analyze(&mut self, nodes: &[FileNode]) {
        for node in nodes {
            // Analyze path components
            let path_str = node.path.to_string_lossy();

            // Split path into components
            for component in path_str.split('/').filter(|c| !c.is_empty()) {
                *self
                    .path_components
                    .entry(component.to_string())
                    .or_insert(0) += 1;
            }

            // Analyze file name
            if let Some(file_name) = node.path.file_name() {
                let name = file_name.to_string_lossy().to_string();
                *self.file_names.entry(name.clone()).or_insert(0) += 1;

                // Extract common patterns
                self.analyze_name_patterns(&name);
            }

            // Analyze extension
            if let Some(ext) = node.path.extension() {
                let ext_str = ext.to_string_lossy().to_string();
                *self.extensions.entry(ext_str).or_insert(0) += 1;
            }
        }

        // Generate optimal tokens
        self.generate_tokens();
    }

    /// Analyze file name for common patterns
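    ///
    /// Prefix/suffix hits feed the `prefixes`/`suffixes` maps, which are
    /// collected here but not yet read by `generate_tokens`; only the
    /// camel/snake-case parts (via `path_components`) influence tokenization.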
    fn analyze_name_patterns(&mut self, name: &str) {
        // Common prefixes
        let prefix_patterns = ["test_", "Test", "_", "mock_", "stub_", "fake_"];
        for prefix in &prefix_patterns {
            if name.starts_with(prefix) {
                *self.prefixes.entry(prefix.to_string()).or_insert(0) += 1;
            }
        }

        // Common suffixes (matched anywhere in the name, not just at the end)
        let suffix_patterns = [
            "_test",
            "Test",
            "Spec",
            "_spec",
            ".test",
            ".spec",
            "Controller",
            "Service",
            "Repository",
            "Model",
            "View",
            "Component",
            "Module",
            "Config",
        ];
        for suffix in &suffix_patterns {
            if name.contains(suffix) {
                *self.suffixes.entry(suffix.to_string()).or_insert(0) += 1;
            }
        }

        // Camel/Snake case components
        if name.contains('_') {
            // Snake case - split and analyze
            for part in name.split('_') {
                if part.len() > 2 {
                    *self.path_components.entry(part.to_string()).or_insert(0) += 1;
                }
            }
        } else if name.chars().any(|c| c.is_uppercase()) && name.chars().any(|c| c.is_lowercase()) {
            // CamelCase - split and analyze
            let parts = split_camel_case(name);
            for part in parts {
                if part.len() > 2 {
                    *self.path_components.entry(part).or_insert(0) += 1;
                }
            }
        }
    }

    /// Generate optimal token assignments
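    ///
    /// IDs run from 0x80 through 0xFE (0xFF is reserved), rendered as two
    /// uppercase hex digits, so at most 127 patterns get tokens; the most
    /// frequent pattern maps to "80", the next to "81", and so on.
    /// Illustrative, assuming `src` outranks `components`:
    ///
    /// ```text
    /// src        -> 80
    /// components -> 81
    /// ```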
    fn generate_tokens(&mut self) {
        let mut token_id = 0x80; // Start from 128

        // Sort all patterns by frequency
        let mut all_patterns: Vec<(String, usize)> = Vec::new();

        // Collect all patterns with their frequencies
        for (pattern, count) in &self.path_components {
            if *count > 2 {
                // Only tokenize if it appears more than twice
                all_patterns.push((pattern.clone(), *count));
            }
        }

        for (pattern, count) in &self.file_names {
            if *count > 2 {
                all_patterns.push((pattern.clone(), *count));
            }
        }

        for (pattern, count) in &self.extensions {
            if *count > 5 {
                // Extensions need higher frequency
                all_patterns.push((format!(".{}", pattern), *count));
            }
        }

        // Sort by frequency (descending) and pattern length (descending for same frequency)
        all_patterns.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| b.0.len().cmp(&a.0.len())));

        // Assign tokens to most frequent patterns
        for (pattern, _count) in all_patterns.iter().take(127) {
            // Max 127 tokens
            self.tokens
                .insert(pattern.clone(), format!("{:02X}", token_id));
            token_id += 1;
            if token_id > 0xFE {
                // Reserve 0xFF
                break;
            }
        }
    }

    /// Compress a path using learned tokens
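    ///
    /// Illustrative sketch (actual IDs depend on learned frequencies; here
    /// assume "components" => "80" and "src" => "81"):
    ///
    /// ```ignore
    /// assert_eq!(
    ///     tokenizer.compress_path("src/components/Button.tsx"),
    ///     "{81}/{80}/Button.tsx"
    /// );
    /// ```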
    pub fn compress_path(&self, path: &str) -> String {
        let mut compressed = path.to_string();

        // Apply tokens from longest to shortest to avoid substring issues
        let mut tokens_by_length: Vec<(&String, &String)> = self.tokens.iter().collect();
        tokens_by_length.sort_by(|a, b| b.0.len().cmp(&a.0.len()));

        for (pattern, token) in tokens_by_length {
            compressed = compressed.replace(pattern, &format!("{{{}}}", token));
        }

        compressed
    }

    /// Get the token dictionary header
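    ///
    /// Output shape (token IDs and patterns are illustrative):
    ///
    /// ```text
    /// TOKENS:
    ///   80=components
    ///   81=src
    /// ```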
    pub fn get_token_header(&self) -> String {
        let mut header = String::from("TOKENS:\n");

        // Sort tokens by ID for consistent output
        let mut sorted_tokens: Vec<(&String, &String)> = self.tokens.iter().collect();
        sorted_tokens.sort_by(|a, b| a.1.cmp(b.1));

        for (pattern, token) in sorted_tokens {
            header.push_str(&format!("  {}={}\n", token, pattern));
        }

        header
    }

    /// Get compression statistics
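    ///
    /// The savings figure is a rough estimate: it weighs a single occurrence
    /// of each pattern against its 4-byte `{XX}` marker, whereas real savings
    /// scale with how often each pattern appears in the output.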
    pub fn get_stats(&self) -> TokenizerStats {
        let total_pattern_bytes: usize = self.tokens.keys().map(|k| k.len()).sum();
        let total_token_bytes = self.tokens.len() * 4; // each "{XX}" marker is 4 bytes

        TokenizerStats {
            patterns_found: self.path_components.len()
                + self.file_names.len()
                + self.extensions.len(),
            tokens_generated: self.tokens.len(),
            estimated_savings: total_pattern_bytes.saturating_sub(total_token_bytes),
        }
    }
}

#[derive(Debug)]
pub struct TokenizerStats {
    pub patterns_found: usize,
    pub tokens_generated: usize,
    pub estimated_savings: usize,
}

/// Split CamelCase into components
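///
/// Runs of capitals split into single letters (behavior pinned by the tests
/// below):
///
/// ```ignore
/// assert_eq!(split_camel_case("UserController"), vec!["user", "controller"]);
/// assert_eq!(split_camel_case("HTTPSConnection")[0], "h");
/// ```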
fn split_camel_case(s: &str) -> Vec<String> {
    let mut result = Vec::new();
    let mut current = String::new();

    for (i, ch) in s.chars().enumerate() {
        if i > 0 && ch.is_uppercase() && !current.is_empty() {
            result.push(current.clone());
            current.clear();
        }
        // to_lowercase() always yields at least one char, so next() is safe
        current.push(ch.to_lowercase().next().unwrap());
    }

    if !current.is_empty() {
        result.push(current);
    }

    result
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_camel_case_split() {
        assert_eq!(
            split_camel_case("UserController"),
            vec!["user", "controller"]
        );
        assert_eq!(
            split_camel_case("HTTPSConnection"),
            vec!["h", "t", "t", "p", "s", "connection"]
        );
    }

    #[test]
    fn test_pattern_detection() {
        let mut tokenizer = DynamicTokenizer::new();

        // Simulate a typical web project
        let patterns = vec![
            "src/components/UserList.tsx",
            "src/components/UserDetail.tsx",
            "src/components/UserForm.tsx",
            "src/services/UserService.ts",
            "src/services/AuthService.ts",
            "src/services/ApiService.ts", // Added to make services appear 3 times
            "src/controllers/UserController.ts",
            "src/controllers/AuthController.ts",
            "src/controllers/ApiController.ts", // Added to make controllers appear 3 times
            "tests/unit/UserService.test.ts",
            "tests/unit/AuthService.test.ts",
            "tests/integration/ApiService.test.ts", // Added to make tests appear 3 times
        ];

        for pattern in patterns {
            tokenizer.analyze_name_patterns(pattern);
            for component in pattern.split('/') {
                *tokenizer
                    .path_components
                    .entry(component.to_string())
                    .or_insert(0) += 1;
            }
        }

        tokenizer.generate_tokens();

        // Should tokenize frequent patterns (appear > 2 times)
        assert!(tokenizer.tokens.contains_key("src"));
        assert!(tokenizer.tokens.contains_key("components"));
        assert!(tokenizer.tokens.contains_key("services"));
        assert!(tokenizer.tokens.contains_key("tests"));
        assert!(tokenizer.tokens.contains_key("controllers"));
    }

    #[test]
    fn test_compression() {
        let mut tokenizer = DynamicTokenizer::new();

        // Add patterns
        for _ in 0..10 {
            *tokenizer
                .path_components
                .entry("src".to_string())
                .or_insert(0) += 1;
            *tokenizer
                .path_components
                .entry("components".to_string())
                .or_insert(0) += 1;
        }

        tokenizer.generate_tokens();

        // Test compression
        let original = "src/components/Button.tsx";
        let compressed = tokenizer.compress_path(original);

        // Should be shorter
        assert!(compressed.len() < original.len());
        // Should contain token markers
        assert!(compressed.contains('{'));
        assert!(compressed.contains('}'));
    }

    #[test]
    fn test_token_assignment_order() {
        let mut tokenizer = DynamicTokenizer::new();

        // Add patterns with different frequencies
        *tokenizer
            .path_components
            .entry("very_frequent".to_string())
            .or_insert(0) = 100;
        *tokenizer
            .path_components
            .entry("less_frequent".to_string())
            .or_insert(0) = 50;
        *tokenizer
            .path_components
            .entry("rare".to_string())
            .or_insert(0) = 3;
        *tokenizer
            .path_components
            .entry("too_rare".to_string())
            .or_insert(0) = 1; // Won't be tokenized

        tokenizer.generate_tokens();

        // Most frequent should get lower token IDs
        let very_frequent_token = tokenizer.tokens.get("very_frequent").unwrap();
        let less_frequent_token = tokenizer.tokens.get("less_frequent").unwrap();

        assert!(very_frequent_token < less_frequent_token);

        // Too rare shouldn't be tokenized
        assert!(!tokenizer.tokens.contains_key("too_rare"));
    }
}