tool_parser/
factory.rs

1// Factory and pool for creating model-specific tool parsers with pooling support.
2
3use std::{
4    collections::HashMap,
5    sync::{Arc, RwLock},
6};
7
8use tokio::sync::Mutex;
9
10use crate::{
11    parsers::{
12        DeepSeekParser, Glm4MoeParser, JsonParser, KimiK2Parser, LlamaParser, MinimaxM2Parser,
13        MistralParser, PassthroughParser, PythonicParser, QwenCoderParser, QwenParser, Step3Parser,
14    },
15    traits::ToolParser,
16};
17
18/// Type alias for pooled parser instances.
///
/// Cloning the `Arc` hands out another handle to the same boxed parser; access to the
/// parser itself goes through the `tokio::sync::Mutex` imported at the top of this file.
19pub type PooledParser = Arc<Mutex<Box<dyn ToolParser>>>;
20
21/// Type alias for parser creator functions.
///
/// A creator takes no arguments and returns a boxed parser. It is `Send + Sync` and
/// wrapped in an `Arc` so it can be stored in the shared registry and cloned cheaply.
22type ParserCreator = Arc<dyn Fn() -> Box<dyn ToolParser> + Send + Sync>;
23
24/// Registry for model-specific tool parsers with pooling support.
///
/// Every field sits behind an `Arc`, so a clone of the registry (via the derived `Clone`)
/// shares the same creators, pool, model mappings and default parser name as the original.
25#[derive(Clone)]
26pub struct ParserRegistry {
27    /// Creator functions keyed by parser name (used to build an instance when the pool has none)
28    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
29    /// Pooled parser instances keyed by parser name, handed out for reuse
30    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
31    /// Model name or `prefix*` pattern mapped to the name of the parser to use
32    model_mapping: Arc<RwLock<HashMap<String, String>>>,
33    /// Name of the parser used when no model mapping resolves to a registered parser
34    default_parser: Arc<RwLock<String>>,
35}
36
37impl ParserRegistry {
38    /// Create a new empty registry.
39    pub fn new() -> Self {
40        Self {
41            creators: Arc::new(RwLock::new(HashMap::new())),
42            pool: Arc::new(RwLock::new(HashMap::new())),
43            model_mapping: Arc::new(RwLock::new(HashMap::new())),
44            default_parser: Arc::new(RwLock::new("passthrough".to_string())),
45        }
46    }
47
48    /// Register a parser creator for a given parser type.
49    pub fn register_parser<F>(&self, name: &str, creator: F)
50    where
51        F: Fn() -> Box<dyn ToolParser> + Send + Sync + 'static,
52    {
53        let mut creators = self.creators.write().unwrap();
54        creators.insert(name.to_string(), Arc::new(creator));
55    }
56
57    /// Map a model name/pattern to a parser
58    pub fn map_model(&self, model: impl Into<String>, parser: impl Into<String>) {
59        let mut mapping = self.model_mapping.write().unwrap();
60        mapping.insert(model.into(), parser.into());
61    }
62
63    /// Get a pooled parser by exact name.
64    /// Returns a shared parser instance from the pool, creating one if needed.
65    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
66        // First check if we have a pooled instance
67        {
68            let pool = self.pool.read().unwrap();
69            if let Some(parser) = pool.get(name) {
70                return Some(Arc::clone(parser));
71            }
72        }
73
74        // If not in pool, create one and add to pool
75        let creators = self.creators.read().unwrap();
76        if let Some(creator) = creators.get(name) {
77            let parser = Arc::new(Mutex::new(creator()));
78
79            // Add to pool for future use
80            let mut pool = self.pool.write().unwrap();
81            pool.insert(name.to_string(), Arc::clone(&parser));
82
83            Some(parser)
84        } else {
85            None
86        }
87    }
88
89    /// Check if a parser with the given name is registered.
90    pub fn has_parser(&self, name: &str) -> bool {
91        let creators = self.creators.read().unwrap();
92        creators.contains_key(name)
93    }
94
95    /// Create a fresh (non-pooled) parser instance by exact name.
96    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
97    pub fn create_parser(&self, name: &str) -> Option<Box<dyn ToolParser>> {
98        let creators = self.creators.read().unwrap();
99        creators.get(name).map(|creator| creator())
100    }
101
102    /// Check if a parser can be created for a specific model without actually creating it.
103    /// Returns true if a parser is available (registered) for this model.
104    pub fn has_parser_for_model(&self, model: &str) -> bool {
105        // Try exact match first
106        {
107            let mapping = self.model_mapping.read().unwrap();
108            if let Some(parser_name) = mapping.get(model) {
109                let creators = self.creators.read().unwrap();
110                if creators.contains_key(parser_name) {
111                    return true;
112                }
113            }
114        }
115
116        // Try prefix matching
117        let model_mapping = self.model_mapping.read().unwrap();
118        let best_match = model_mapping
119            .iter()
120            .filter(|(pattern, _)| {
121                pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
122            })
123            .max_by_key(|(pattern, _)| pattern.len());
124
125        if let Some((_, parser_name)) = best_match {
126            let creators = self.creators.read().unwrap();
127            if creators.contains_key(parser_name) {
128                return true;
129            }
130        }
131
132        // Return false if no specific parser found for this model
133        // (get_pooled will still fall back to default parser)
134        false
135    }
136
137    /// Create a fresh (non-pooled) parser instance for a specific model.
138    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
139    pub fn create_for_model(&self, model: &str) -> Option<Box<dyn ToolParser>> {
140        // Try exact match first
141        {
142            let mapping = self.model_mapping.read().unwrap();
143            if let Some(parser_name) = mapping.get(model) {
144                if let Some(parser) = self.create_parser(parser_name) {
145                    return Some(parser);
146                }
147            }
148        }
149
150        // Try prefix matching with more specific patterns first
151        let model_mapping = self.model_mapping.read().unwrap();
152        let best_match = model_mapping
153            .iter()
154            .filter(|(pattern, _)| {
155                pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
156            })
157            .max_by_key(|(pattern, _)| pattern.len());
158
159        // Return the best matching parser
160        if let Some((_, parser_name)) = best_match {
161            if let Some(parser) = self.create_parser(parser_name) {
162                return Some(parser);
163            }
164        }
165
166        // Fall back to default parser
167        let default = self.default_parser.read().unwrap().clone();
168        self.create_parser(&default)
169    }
170
171    /// Get parser for a specific model
172    pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledParser> {
173        // Try exact match first
174        {
175            let mapping = self.model_mapping.read().unwrap();
176            if let Some(parser_name) = mapping.get(model) {
177                if let Some(parser) = self.get_pooled_parser(parser_name) {
178                    return Some(parser);
179                }
180            }
181        }
182
183        // Try prefix matching with more specific patterns first
184        let model_mapping = self.model_mapping.read().unwrap();
185        let best_match = model_mapping
186            .iter()
187            .filter(|(pattern, _)| {
188                pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
189            })
190            .max_by_key(|(pattern, _)| pattern.len());
191
192        // Return the best matching parser
193        if let Some((_, parser_name)) = best_match {
194            if let Some(parser) = self.get_pooled_parser(parser_name) {
195                return Some(parser);
196            }
197        }
198
199        // Fall back to default parser
200        let default = self.default_parser.read().unwrap().clone();
201        self.get_pooled_parser(&default)
202    }
203
204    /// Clear the parser pool, forcing new instances to be created.
205    pub fn clear_pool(&self) {
206        let mut pool = self.pool.write().unwrap();
207        pool.clear();
208    }
209
210    /// Set the default parser
211    pub fn set_default_parser(&self, name: impl Into<String>) {
212        let mut default = self.default_parser.write().unwrap();
213        *default = name.into();
214    }
215}
216
217impl Default for ParserRegistry {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
223/// Factory for creating tool parsers based on model type.
///
/// Wraps a [`ParserRegistry`]. The registry keeps its state behind `Arc`s, so clones of
/// the factory (via the derived `Clone`) share the same registrations and parser pool.
224#[derive(Clone)]
225pub struct ParserFactory {
    /// Registry holding the parser creators, the pool and the model mappings.
226    registry: ParserRegistry,
227}
228
229impl ParserFactory {
230    /// Create a new factory with default parsers registered.
231    pub fn new() -> Self {
232        let registry = ParserRegistry::new();
233
234        // Register default parsers
235        registry.register_parser("passthrough", || Box::new(PassthroughParser::new()));
236        registry.register_parser("json", || Box::new(JsonParser::new()));
237        registry.register_parser("mistral", || Box::new(MistralParser::new()));
238        registry.register_parser("qwen", || Box::new(QwenParser::new()));
239        registry.register_parser("qwen_coder", || Box::new(QwenCoderParser::new()));
240        registry.register_parser("pythonic", || Box::new(PythonicParser::new()));
241        registry.register_parser("llama", || Box::new(LlamaParser::new()));
242        registry.register_parser("deepseek", || Box::new(DeepSeekParser::new()));
243        registry.register_parser("glm45_moe", || Box::new(Glm4MoeParser::glm45()));
244        registry.register_parser("glm47_moe", || Box::new(Glm4MoeParser::glm47()));
245        registry.register_parser("step3", || Box::new(Step3Parser::new()));
246        registry.register_parser("kimik2", || Box::new(KimiK2Parser::new()));
247        registry.register_parser("minimax_m2", || Box::new(MinimaxM2Parser::new()));
248
249        // Register default model mappings
250        Self::register_default_mappings(&registry);
251
252        Self { registry }
253    }
254
255    fn register_default_mappings(registry: &ParserRegistry) {
256        // OpenAI models
257        registry.map_model("gpt-4*", "json");
258        registry.map_model("gpt-3.5*", "json");
259        registry.map_model("gpt-4o*", "json");
260
261        // Anthropic models
262        registry.map_model("claude-*", "json");
263
264        // Mistral models
265        registry.map_model("mistral-*", "mistral");
266        registry.map_model("mixtral-*", "mistral");
267
268        // Qwen models (more specific patterns first - longer patterns take precedence)
269        // Qwen Coder models use XML format: <tool_call><function=name><parameter=key>value</parameter></function></tool_call>
270        registry.map_model("Qwen/Qwen3-Coder*", "qwen_coder");
271        registry.map_model("Qwen3-Coder*", "qwen_coder");
272        registry.map_model("qwen3-coder*", "qwen_coder");
273        registry.map_model("Qwen/Qwen2.5-Coder*", "qwen_coder");
274        registry.map_model("Qwen2.5-Coder*", "qwen_coder");
275        registry.map_model("qwen2.5-coder*", "qwen_coder");
276        // Generic Qwen models use JSON format
277        registry.map_model("qwen*", "qwen");
278        registry.map_model("Qwen*", "qwen");
279
280        // Llama models
281        registry.map_model("llama-4*", "pythonic");
282        registry.map_model("meta-llama-4*", "pythonic");
283        registry.map_model("llama-3.2*", "llama");
284        registry.map_model("meta-llama-3.2*", "llama");
285        registry.map_model("llama-*", "json");
286        registry.map_model("meta-llama-*", "json");
287
288        // DeepSeek models
289        registry.map_model("deepseek-v3*", "deepseek");
290        registry.map_model("deepseek-ai/DeepSeek-V3*", "deepseek");
291        registry.map_model("deepseek-*", "pythonic");
292
293        // GLM models
294        registry.map_model("glm-4.5*", "glm45_moe");
295        registry.map_model("glm-4.6*", "glm45_moe");
296        registry.map_model("glm-4.7*", "glm47_moe");
297        registry.map_model("glm-*", "json");
298
299        // Step3 models
300        registry.map_model("step3*", "step3");
301        registry.map_model("Step-3*", "step3");
302
303        // Kimi models
304        registry.map_model("kimi-k2*", "kimik2");
305        registry.map_model("Kimi-K2*", "kimik2");
306        registry.map_model("moonshot*/Kimi-K2*", "kimik2");
307
308        // MiniMax models
309        registry.map_model("minimax*", "minimax_m2");
310        registry.map_model("MiniMax*", "minimax_m2");
311
312        // Other models
313        registry.map_model("gemini-*", "json");
314        registry.map_model("palm-*", "json");
315        registry.map_model("gemma-*", "json");
316    }
317
318    /// Get a pooled parser for the given model ID.
319    /// Returns a shared instance that can be used concurrently.
320    /// Falls back to passthrough parser if model is not recognized.
321    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
322        self.registry
323            .get_pooled_for_model(model_id)
324            .unwrap_or_else(|| {
325                // Fallback to passthrough parser (no-op, returns text unchanged)
326                self.registry
327                    .get_pooled_parser("passthrough")
328                    .expect("Passthrough parser should always be registered")
329            })
330    }
331
332    /// Get the internal registry for custom registration.
333    pub fn registry(&self) -> &ParserRegistry {
334        &self.registry
335    }
336
337    /// Clear the parser pool.
338    pub fn clear_pool(&self) {
339        self.registry.clear_pool();
340    }
341
342    /// Get a non-pooled parser for the given model ID (creates a fresh instance each time).
343    /// This is useful for benchmarks and testing where you want independent parser instances.
344    pub fn get_parser(&self, model_id: &str) -> Option<Arc<dyn ToolParser>> {
345        // Determine which parser type to use
346        let parser_type = {
347            let mapping = self.registry.model_mapping.read().unwrap();
348
349            // Try exact match first
350            if let Some(parser_name) = mapping.get(model_id) {
351                parser_name.clone()
352            } else {
353                // Try prefix matching
354                let best_match = mapping
355                    .iter()
356                    .filter(|(pattern, _)| {
357                        pattern.ends_with('*')
358                            && model_id.starts_with(&pattern[..pattern.len() - 1])
359                    })
360                    .max_by_key(|(pattern, _)| pattern.len());
361
362                if let Some((_, parser_name)) = best_match {
363                    parser_name.clone()
364                } else {
365                    // Fall back to default
366                    self.registry.default_parser.read().unwrap().clone()
367                }
368            }
369        };
370
371        let creators = self.registry.creators.read().unwrap();
372        creators.get(&parser_type).map(|creator| {
373            // Call the creator to get a Box<dyn ToolParser>, then convert to Arc
374            let boxed_parser = creator();
375            Arc::from(boxed_parser)
376        })
377    }
378
379    /// List all registered parsers (for compatibility with old API).
380    pub fn list_parsers(&self) -> Vec<String> {
381        self.registry
382            .creators
383            .read()
384            .unwrap()
385            .keys()
386            .cloned()
387            .collect()
388    }
389}
390
391impl Default for ParserFactory {
392    fn default() -> Self {
393        Self::new()
394    }
395}