// tool_parser/factory.rs
//
// Factory and pool for creating model-specific tool parsers with pooling support.
2
3use std::{
4    collections::HashMap,
5    sync::{Arc, RwLock},
6};
7
8use tokio::sync::Mutex;
9
10use crate::{
11    parsers::{
12        CohereParser, DeepSeekParser, Glm4MoeParser, JsonParser, KimiK2Parser, LlamaParser,
13        MinimaxM2Parser, MistralParser, PassthroughParser, PythonicParser, QwenCoderParser,
14        QwenParser, Step3Parser,
15    },
16    traits::ToolParser,
17};
18
19/// Type alias for pooled parser instances.
20pub type PooledParser = Arc<Mutex<Box<dyn ToolParser>>>;
21
22/// Type alias for parser creator functions.
23type ParserCreator = Arc<dyn Fn() -> Box<dyn ToolParser> + Send + Sync>;
24
25/// Registry for model-specific tool parsers with pooling support.
26#[derive(Clone)]
27pub struct ParserRegistry {
28    /// Creator functions for parsers (used when pool is empty)
29    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
30    /// Pooled parser instances for reuse
31    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
32    /// Model pattern to parser name mappings
33    model_mapping: Arc<RwLock<HashMap<String, String>>>,
34    /// Default parser name
35    default_parser: Arc<RwLock<String>>,
36}
37
38impl ParserRegistry {
39    /// Create a new empty registry.
40    pub fn new() -> Self {
41        Self {
42            creators: Arc::new(RwLock::new(HashMap::new())),
43            pool: Arc::new(RwLock::new(HashMap::new())),
44            model_mapping: Arc::new(RwLock::new(HashMap::new())),
45            default_parser: Arc::new(RwLock::new("passthrough".to_string())),
46        }
47    }
48
49    /// Register a parser creator for a given parser type.
50    pub fn register_parser<F>(&self, name: &str, creator: F)
51    where
52        F: Fn() -> Box<dyn ToolParser> + Send + Sync + 'static,
53    {
54        let mut creators = self.creators.write().unwrap();
55        creators.insert(name.to_string(), Arc::new(creator));
56    }
57
58    /// Map a model name/pattern to a parser
59    pub fn map_model(&self, model: impl Into<String>, parser: impl Into<String>) {
60        let mut mapping = self.model_mapping.write().unwrap();
61        mapping.insert(model.into(), parser.into());
62    }
63
64    /// Get a pooled parser by exact name.
65    /// Returns a shared parser instance from the pool, creating one if needed.
66    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
67        // First check if we have a pooled instance
68        {
69            let pool = self.pool.read().unwrap();
70            if let Some(parser) = pool.get(name) {
71                return Some(Arc::clone(parser));
72            }
73        }
74
75        // If not in pool, create one and add to pool
76        let creators = self.creators.read().unwrap();
77        if let Some(creator) = creators.get(name) {
78            let parser = Arc::new(Mutex::new(creator()));
79
80            // Add to pool for future use
81            let mut pool = self.pool.write().unwrap();
82            pool.insert(name.to_string(), Arc::clone(&parser));
83
84            Some(parser)
85        } else {
86            None
87        }
88    }
89
90    /// Check if a parser with the given name is registered.
91    pub fn has_parser(&self, name: &str) -> bool {
92        let creators = self.creators.read().unwrap();
93        creators.contains_key(name)
94    }
95
96    /// Create a fresh (non-pooled) parser instance by exact name.
97    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
98    pub fn create_parser(&self, name: &str) -> Option<Box<dyn ToolParser>> {
99        let creators = self.creators.read().unwrap();
100        creators.get(name).map(|creator| creator())
101    }
102
103    /// Check if a parser can be created for a specific model without actually creating it.
104    /// Returns true if a parser is available (registered) for this model.
105    pub fn has_parser_for_model(&self, model: &str) -> bool {
106        // Try exact match first
107        {
108            let mapping = self.model_mapping.read().unwrap();
109            if let Some(parser_name) = mapping.get(model) {
110                let creators = self.creators.read().unwrap();
111                if creators.contains_key(parser_name) {
112                    return true;
113                }
114            }
115        }
116
117        // Try prefix matching
118        let model_mapping = self.model_mapping.read().unwrap();
119        let best_match = model_mapping
120            .iter()
121            .filter(|(pattern, _)| {
122                pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
123            })
124            .max_by_key(|(pattern, _)| pattern.len());
125
126        if let Some((_, parser_name)) = best_match {
127            let creators = self.creators.read().unwrap();
128            if creators.contains_key(parser_name) {
129                return true;
130            }
131        }
132
133        // Return false if no specific parser found for this model
134        // (get_pooled will still fall back to default parser)
135        false
136    }
137
138    /// Create a fresh (non-pooled) parser instance for a specific model.
139    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
140    pub fn create_for_model(&self, model: &str) -> Option<Box<dyn ToolParser>> {
141        // Try exact match first
142        {
143            let mapping = self.model_mapping.read().unwrap();
144            if let Some(parser_name) = mapping.get(model) {
145                if let Some(parser) = self.create_parser(parser_name) {
146                    return Some(parser);
147                }
148            }
149        }
150
151        // Try prefix matching with more specific patterns first
152        let model_mapping = self.model_mapping.read().unwrap();
153        let best_match = model_mapping
154            .iter()
155            .filter(|(pattern, _)| {
156                pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
157            })
158            .max_by_key(|(pattern, _)| pattern.len());
159
160        // Return the best matching parser
161        if let Some((_, parser_name)) = best_match {
162            if let Some(parser) = self.create_parser(parser_name) {
163                return Some(parser);
164            }
165        }
166
167        // Fall back to default parser
168        let default = self.default_parser.read().unwrap().clone();
169        self.create_parser(&default)
170    }
171
172    /// Get parser for a specific model
173    pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledParser> {
174        // Try exact match first
175        {
176            let mapping = self.model_mapping.read().unwrap();
177            if let Some(parser_name) = mapping.get(model) {
178                if let Some(parser) = self.get_pooled_parser(parser_name) {
179                    return Some(parser);
180                }
181            }
182        }
183
184        // Try prefix matching with more specific patterns first
185        let model_mapping = self.model_mapping.read().unwrap();
186        let best_match = model_mapping
187            .iter()
188            .filter(|(pattern, _)| {
189                pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
190            })
191            .max_by_key(|(pattern, _)| pattern.len());
192
193        // Return the best matching parser
194        if let Some((_, parser_name)) = best_match {
195            if let Some(parser) = self.get_pooled_parser(parser_name) {
196                return Some(parser);
197            }
198        }
199
200        // Fall back to default parser
201        let default = self.default_parser.read().unwrap().clone();
202        self.get_pooled_parser(&default)
203    }
204
205    /// Clear the parser pool, forcing new instances to be created.
206    pub fn clear_pool(&self) {
207        let mut pool = self.pool.write().unwrap();
208        pool.clear();
209    }
210
211    /// Set the default parser
212    pub fn set_default_parser(&self, name: impl Into<String>) {
213        let mut default = self.default_parser.write().unwrap();
214        *default = name.into();
215    }
216}
217
218impl Default for ParserRegistry {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
224/// Factory for creating tool parsers based on model type.
225#[derive(Clone)]
226pub struct ParserFactory {
227    registry: ParserRegistry,
228}
229
230impl ParserFactory {
231    /// Create a new factory with default parsers registered.
232    pub fn new() -> Self {
233        let registry = ParserRegistry::new();
234
235        // Register default parsers
236        registry.register_parser("passthrough", || Box::new(PassthroughParser::new()));
237        registry.register_parser("json", || Box::new(JsonParser::new()));
238        registry.register_parser("mistral", || Box::new(MistralParser::new()));
239        registry.register_parser("qwen", || Box::new(QwenParser::new()));
240        registry.register_parser("qwen_coder", || Box::new(QwenCoderParser::new()));
241        registry.register_parser("pythonic", || Box::new(PythonicParser::new()));
242        registry.register_parser("llama", || Box::new(LlamaParser::new()));
243        registry.register_parser("deepseek", || Box::new(DeepSeekParser::new()));
244        registry.register_parser("glm45_moe", || Box::new(Glm4MoeParser::glm45()));
245        registry.register_parser("glm47_moe", || Box::new(Glm4MoeParser::glm47()));
246        registry.register_parser("step3", || Box::new(Step3Parser::new()));
247        registry.register_parser("kimik2", || Box::new(KimiK2Parser::new()));
248        registry.register_parser("minimax_m2", || Box::new(MinimaxM2Parser::new()));
249        registry.register_parser("cohere", || Box::new(CohereParser::new()));
250
251        // Register default model mappings
252        Self::register_default_mappings(&registry);
253
254        Self { registry }
255    }
256
257    fn register_default_mappings(registry: &ParserRegistry) {
258        // OpenAI models
259        registry.map_model("gpt-4*", "json");
260        registry.map_model("gpt-3.5*", "json");
261        registry.map_model("gpt-4o*", "json");
262
263        // Anthropic models
264        registry.map_model("claude-*", "json");
265
266        // Mistral models
267        registry.map_model("mistral-*", "mistral");
268        registry.map_model("mixtral-*", "mistral");
269
270        // Qwen models (more specific patterns first - longer patterns take precedence)
271        // Qwen Coder models use XML format: <tool_call><function=name><parameter=key>value</parameter></function></tool_call>
272        registry.map_model("Qwen/Qwen3-Coder*", "qwen_coder");
273        registry.map_model("Qwen3-Coder*", "qwen_coder");
274        registry.map_model("qwen3-coder*", "qwen_coder");
275        registry.map_model("Qwen/Qwen2.5-Coder*", "qwen_coder");
276        registry.map_model("Qwen2.5-Coder*", "qwen_coder");
277        registry.map_model("qwen2.5-coder*", "qwen_coder");
278        // Generic Qwen models use JSON format
279        registry.map_model("qwen*", "qwen");
280        registry.map_model("Qwen*", "qwen");
281
282        // Llama models
283        registry.map_model("llama-4*", "pythonic");
284        registry.map_model("meta-llama-4*", "pythonic");
285        registry.map_model("llama-3.2*", "llama");
286        registry.map_model("meta-llama-3.2*", "llama");
287        registry.map_model("llama-*", "json");
288        registry.map_model("meta-llama-*", "json");
289
290        // DeepSeek models
291        registry.map_model("deepseek-v3*", "deepseek");
292        registry.map_model("deepseek-ai/DeepSeek-V3*", "deepseek");
293        registry.map_model("deepseek-*", "pythonic");
294
295        // GLM models
296        registry.map_model("glm-4.5*", "glm45_moe");
297        registry.map_model("glm-4.6*", "glm45_moe");
298        registry.map_model("glm-4.7*", "glm47_moe");
299        registry.map_model("glm-*", "json");
300
301        // Step3 models
302        registry.map_model("step3*", "step3");
303        registry.map_model("Step-3*", "step3");
304
305        // Kimi models
306        registry.map_model("kimi-k2*", "kimik2");
307        registry.map_model("Kimi-K2*", "kimik2");
308        registry.map_model("moonshot*/Kimi-K2*", "kimik2");
309
310        // MiniMax models
311        registry.map_model("minimax*", "minimax_m2");
312        registry.map_model("MiniMax*", "minimax_m2");
313
314        // Cohere models
315        registry.map_model("command-r*", "cohere");
316        registry.map_model("command-r-plus*", "cohere");
317        registry.map_model("command-a*", "cohere");
318        registry.map_model("c4ai-command*", "cohere");
319        registry.map_model("cohere*", "cohere");
320        registry.map_model("CohereForAI*", "cohere");
321
322        // Other models
323        registry.map_model("gemini-*", "json");
324        registry.map_model("palm-*", "json");
325        registry.map_model("gemma-*", "json");
326    }
327
328    /// Get a pooled parser for the given model ID.
329    /// Returns a shared instance that can be used concurrently.
330    /// Falls back to passthrough parser if model is not recognized.
331    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
332        self.registry
333            .get_pooled_for_model(model_id)
334            .unwrap_or_else(|| {
335                // Fallback to passthrough parser (no-op, returns text unchanged)
336                self.registry
337                    .get_pooled_parser("passthrough")
338                    .expect("Passthrough parser should always be registered")
339            })
340    }
341
342    /// Get the internal registry for custom registration.
343    pub fn registry(&self) -> &ParserRegistry {
344        &self.registry
345    }
346
347    /// Clear the parser pool.
348    pub fn clear_pool(&self) {
349        self.registry.clear_pool();
350    }
351
352    /// Get a non-pooled parser for the given model ID (creates a fresh instance each time).
353    /// This is useful for benchmarks and testing where you want independent parser instances.
354    pub fn get_parser(&self, model_id: &str) -> Option<Arc<dyn ToolParser>> {
355        // Determine which parser type to use
356        let parser_type = {
357            let mapping = self.registry.model_mapping.read().unwrap();
358
359            // Try exact match first
360            if let Some(parser_name) = mapping.get(model_id) {
361                parser_name.clone()
362            } else {
363                // Try prefix matching
364                let best_match = mapping
365                    .iter()
366                    .filter(|(pattern, _)| {
367                        pattern.ends_with('*')
368                            && model_id.starts_with(&pattern[..pattern.len() - 1])
369                    })
370                    .max_by_key(|(pattern, _)| pattern.len());
371
372                if let Some((_, parser_name)) = best_match {
373                    parser_name.clone()
374                } else {
375                    // Fall back to default
376                    self.registry.default_parser.read().unwrap().clone()
377                }
378            }
379        };
380
381        let creators = self.registry.creators.read().unwrap();
382        creators.get(&parser_type).map(|creator| {
383            // Call the creator to get a Box<dyn ToolParser>, then convert to Arc
384            let boxed_parser = creator();
385            Arc::from(boxed_parser)
386        })
387    }
388
389    /// List all registered parsers (for compatibility with old API).
390    pub fn list_parsers(&self) -> Vec<String> {
391        self.registry
392            .creators
393            .read()
394            .unwrap()
395            .keys()
396            .cloned()
397            .collect()
398    }
399}
400
401impl Default for ParserFactory {
402    fn default() -> Self {
403        Self::new()
404    }
405}