Skip to main content

token_count/tokenizers/
registry.rs

1//! Model registry for managing supported models
2
3use crate::error::TokenError;
4use crate::tokenizers::{openai::OpenAITokenizer, ModelInfo, Tokenizer};
5use anyhow::Result;
6use std::collections::HashMap;
7use std::sync::OnceLock;
8
/// Configuration for a specific model
///
/// `PartialEq`/`Eq` are derived so configurations can be compared and
/// asserted on directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelConfig {
    /// Canonical model name, e.g. `"gpt-4"`.
    pub name: String,
    /// Encoding name (e.g. `"cl100k_base"`) handed to the tokenizer backend.
    pub encoding: String,
    /// Maximum context window for the model, in tokens.
    pub context_window: usize,
    /// Human-readable description shown when listing models.
    pub description: String,
    /// Alternative spellings that resolve to this model.
    pub aliases: Vec<String>,
}
18
/// Registry of all supported models
///
/// Lookup is two-level: `aliases` maps every accepted spelling (lowercased)
/// to a canonical model name, and `models` maps canonical names to their
/// configuration. Both maps are populated together by `add_model`.
pub struct ModelRegistry {
    models: HashMap<String, ModelConfig>, // canonical name → config
    aliases: HashMap<String, String>, // alias → canonical name
}
24
25impl ModelRegistry {
26    /// Create a new model registry with all supported models
27    pub fn new() -> Self {
28        let mut registry = Self { models: HashMap::new(), aliases: HashMap::new() };
29
30        // GPT-3.5-turbo (default)
31        registry.add_model(ModelConfig {
32            name: "gpt-3.5-turbo".to_string(),
33            encoding: "cl100k_base".to_string(),
34            context_window: 16385,
35            description: "GPT-3.5 Turbo (16K context)".to_string(),
36            aliases: vec![
37                "gpt-3.5".to_string(),
38                "gpt35".to_string(),
39                "gpt-35-turbo".to_string(),
40                "openai/gpt-3.5-turbo".to_string(),
41            ],
42        });
43
44        // GPT-4
45        registry.add_model(ModelConfig {
46            name: "gpt-4".to_string(),
47            encoding: "cl100k_base".to_string(),
48            context_window: 128000,
49            description: "GPT-4 (128K context)".to_string(),
50            aliases: vec!["gpt4".to_string(), "openai/gpt-4".to_string()],
51        });
52
53        // GPT-4-turbo
54        registry.add_model(ModelConfig {
55            name: "gpt-4-turbo".to_string(),
56            encoding: "cl100k_base".to_string(),
57            context_window: 128000,
58            description: "GPT-4 Turbo (128K context)".to_string(),
59            aliases: vec![
60                "gpt4-turbo".to_string(),
61                "gpt-4turbo".to_string(),
62                "openai/gpt-4-turbo".to_string(),
63            ],
64        });
65
66        // GPT-4o
67        registry.add_model(ModelConfig {
68            name: "gpt-4o".to_string(),
69            encoding: "o200k_base".to_string(),
70            context_window: 128000,
71            description: "GPT-4o (128K context)".to_string(),
72            aliases: vec!["gpt4o".to_string(), "openai/gpt-4o".to_string()],
73        });
74
75        registry
76    }
77
78    /// Get the global registry instance
79    pub fn global() -> &'static Self {
80        static REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
81        REGISTRY.get_or_init(Self::new)
82    }
83
84    /// Add a model to the registry
85    fn add_model(&mut self, config: ModelConfig) {
86        let canonical_name = config.name.clone();
87
88        // Add aliases
89        for alias in &config.aliases {
90            self.aliases.insert(alias.to_lowercase(), canonical_name.clone());
91        }
92
93        // Add model itself as an alias (case-insensitive)
94        self.aliases.insert(canonical_name.to_lowercase(), canonical_name.clone());
95
96        self.models.insert(canonical_name, config);
97    }
98
99    /// Resolve a model name (canonical or alias) to its canonical name
100    pub fn resolve_model_name(&self, name: &str) -> Result<String, TokenError> {
101        let normalized = name.trim().to_lowercase();
102
103        if let Some(canonical) = self.aliases.get(&normalized) {
104            return Ok(canonical.clone());
105        }
106
107        // Model not found - generate suggestions
108        let suggestion = self.generate_suggestions(&normalized);
109        Err(TokenError::UnknownModel { model: name.to_string(), suggestion })
110    }
111
112    /// Get a model configuration by name (canonical or alias)
113    pub fn get_model(&self, name: &str) -> Result<&ModelConfig, TokenError> {
114        let canonical = self.resolve_model_name(name)?;
115        Ok(self.models.get(&canonical).expect("Canonical name must exist"))
116    }
117
118    /// Create a tokenizer for the given model
119    pub fn get_tokenizer(&self, name: &str) -> Result<Box<dyn Tokenizer>, TokenError> {
120        let config = self.get_model(name)?;
121
122        let model_info = ModelInfo {
123            name: config.name.clone(),
124            encoding: config.encoding.clone(),
125            context_window: config.context_window,
126            description: config.description.clone(),
127        };
128
129        let tokenizer = OpenAITokenizer::new(&config.encoding, model_info)
130            .map_err(|e| TokenError::Tokenization(e.to_string()))?;
131
132        Ok(Box::new(tokenizer))
133    }
134
135    /// List all supported models
136    pub fn list_models(&self) -> Vec<&ModelConfig> {
137        let mut models: Vec<&ModelConfig> = self.models.values().collect();
138        models.sort_by(|a, b| a.name.cmp(&b.name));
139        models
140    }
141
142    /// Generate fuzzy suggestions for unknown model names
143    fn generate_suggestions(&self, name: &str) -> String {
144        let mut suggestions: Vec<(&str, usize)> = self
145            .models
146            .keys()
147            .map(|model_name| {
148                let distance = strsim::levenshtein(name, &model_name.to_lowercase());
149                (model_name.as_str(), distance)
150            })
151            .collect();
152
153        suggestions.sort_by_key(|&(_, dist)| dist);
154
155        let close_matches: Vec<&str> = suggestions
156            .iter()
157            .take(3)
158            .filter(|&&(_, dist)| dist <= 3)
159            .map(|&(name, _)| name)
160            .collect();
161
162        if close_matches.is_empty() {
163            "Use --list-models to see all supported models".to_string()
164        } else {
165            format!("Did you mean: {}?", close_matches.join(", "))
166        }
167    }
168}
169
170impl Default for ModelRegistry {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical names resolve to themselves, case-insensitively.
    #[test]
    fn test_resolve_canonical_name() {
        let registry = ModelRegistry::new();
        for spelling in ["gpt-4", "GPT-4"] {
            assert_eq!(registry.resolve_model_name(spelling).unwrap(), "gpt-4");
        }
    }

    /// Aliases map to their canonical model names.
    #[test]
    fn test_resolve_alias() {
        let registry = ModelRegistry::new();
        let cases = [("gpt4", "gpt-4"), ("gpt35", "gpt-3.5-turbo")];
        for (alias, canonical) in cases {
            assert_eq!(registry.resolve_model_name(alias).unwrap(), canonical);
        }
    }

    /// Unknown models produce an error whose message mentions the model.
    #[test]
    fn test_unknown_model() {
        let registry = ModelRegistry::new();
        match registry.resolve_model_name("gpt-5") {
            Ok(name) => panic!("expected an error, resolved to {name}"),
            Err(err) => assert!(err.to_string().contains("gpt")),
        }
    }

    /// All four built-in models appear in the listing.
    #[test]
    fn test_list_models() {
        let registry = ModelRegistry::new();
        let names: Vec<&str> =
            registry.list_models().iter().map(|m| m.name.as_str()).collect();
        assert_eq!(names.len(), 4);
        for expected in ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-4o"] {
            assert!(names.contains(&expected));
        }
    }

    /// A tokenizer built from the registry tokenizes a simple phrase.
    #[test]
    fn test_get_tokenizer() {
        let registry = ModelRegistry::new();
        let tokenizer = registry.get_tokenizer("gpt-4").unwrap();
        assert_eq!(tokenizer.count_tokens("Hello world").unwrap(), 2);
    }

    /// Near-miss spellings yield a "Did you mean" suggestion.
    #[test]
    fn test_fuzzy_suggestions() {
        let registry = ModelRegistry::new();
        let message = registry
            .resolve_model_name("gpt4-tubro")
            .expect_err("typo should not resolve")
            .to_string();
        assert!(message.contains("Did you mean"));
        assert!(message.contains("gpt-4-turbo"));
    }
}