Skip to main content

tool_parser/
factory.rs

1// Factory and pool for creating model-specific tool parsers with pooling support.
2
3use std::{collections::HashMap, sync::Arc};
4
5use parking_lot::RwLock;
6use tokio::sync::Mutex;
7
8use crate::{
9    parsers::{
10        CohereParser, DeepSeekParser, Glm4MoeParser, JsonParser, KimiK2Parser, LlamaParser,
11        MinimaxM2Parser, MistralParser, PassthroughParser, PythonicParser, QwenCoderParser,
12        QwenParser, Step3Parser,
13    },
14    traits::ToolParser,
15};
16
/// Type alias for pooled parser instances.
///
/// The boxed parser sits behind a `tokio::sync::Mutex` inside an `Arc`, so
/// cloning the handle shares the same underlying parser instead of copying it.
pub type PooledParser = Arc<Mutex<Box<dyn ToolParser>>>;
19
/// Type alias for parser creator functions.
///
/// A shareable (`Arc`), thread-safe (`Send + Sync`) closure that returns a
/// boxed parser each time it is called.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ToolParser> + Send + Sync>;
22
/// Registry for model-specific tool parsers with pooling support.
///
/// Every field is an `Arc`-wrapped lock, so values produced by the derived
/// `Clone` share the same creators, pool, mappings, and default parser name.
#[derive(Clone)]
pub struct ParserRegistry {
    /// Creator functions keyed by parser name (used when the pool is empty).
    creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
    /// Pooled parser instances keyed by parser name, kept for reuse.
    pool: Arc<RwLock<HashMap<String, PooledParser>>>,
    /// Model name, or prefix pattern ending in `*`, mapped to a parser name.
    model_mapping: Arc<RwLock<HashMap<String, String>>>,
    /// Name of the parser used when no model mapping resolves.
    default_parser: Arc<RwLock<String>>,
}
35
36impl ParserRegistry {
37    /// Create a new empty registry.
38    pub fn new() -> Self {
39        Self {
40            creators: Arc::new(RwLock::new(HashMap::new())),
41            pool: Arc::new(RwLock::new(HashMap::new())),
42            model_mapping: Arc::new(RwLock::new(HashMap::new())),
43            default_parser: Arc::new(RwLock::new("passthrough".to_string())),
44        }
45    }
46
47    /// Register a parser creator for a given parser type.
48    pub fn register_parser<F>(&self, name: &str, creator: F)
49    where
50        F: Fn() -> Box<dyn ToolParser> + Send + Sync + 'static,
51    {
52        let mut creators = self.creators.write();
53        creators.insert(name.to_string(), Arc::new(creator));
54    }
55
56    /// Map a model name/pattern to a parser
57    pub fn map_model(&self, model: impl Into<String>, parser: impl Into<String>) {
58        let mut mapping = self.model_mapping.write();
59        mapping.insert(model.into(), parser.into());
60    }
61
62    /// Get a pooled parser by exact name.
63    /// Returns a shared parser instance from the pool, creating one if needed.
64    pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
65        // First check if we have a pooled instance
66        {
67            let pool = self.pool.read();
68            if let Some(parser) = pool.get(name) {
69                return Some(Arc::clone(parser));
70            }
71        }
72
73        // If not in pool, create one and add to pool
74        let creators = self.creators.read();
75        if let Some(creator) = creators.get(name) {
76            let parser = Arc::new(Mutex::new(creator()));
77
78            // Add to pool for future use
79            let mut pool = self.pool.write();
80            pool.insert(name.to_string(), Arc::clone(&parser));
81
82            Some(parser)
83        } else {
84            None
85        }
86    }
87
88    /// Check if a parser with the given name is registered.
89    pub fn has_parser(&self, name: &str) -> bool {
90        let creators = self.creators.read();
91        creators.contains_key(name)
92    }
93
94    /// Create a fresh (non-pooled) parser instance by exact name.
95    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
96    pub fn create_parser(&self, name: &str) -> Option<Box<dyn ToolParser>> {
97        let creators = self.creators.read();
98        creators.get(name).map(|creator| creator())
99    }
100
101    /// Check if a parser can be created for a specific model without actually creating it.
102    /// Returns true if a parser is available (registered) for this model.
103    pub fn has_parser_for_model(&self, model: &str) -> bool {
104        // Try exact match first
105        {
106            let mapping = self.model_mapping.read();
107            if let Some(parser_name) = mapping.get(model) {
108                let creators = self.creators.read();
109                if creators.contains_key(parser_name) {
110                    return true;
111                }
112            }
113        }
114
115        // Try prefix matching
116        let model_mapping = self.model_mapping.read();
117        let best_match = model_mapping
118            .iter()
119            .filter(|(pattern, _)| {
120                pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
121            })
122            .max_by_key(|(pattern, _)| pattern.len());
123
124        if let Some((_, parser_name)) = best_match {
125            let creators = self.creators.read();
126            if creators.contains_key(parser_name) {
127                return true;
128            }
129        }
130
131        // Return false if no specific parser found for this model
132        // (get_pooled will still fall back to default parser)
133        false
134    }
135
136    /// Create a fresh (non-pooled) parser instance for a specific model.
137    /// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
138    pub fn create_for_model(&self, model: &str) -> Option<Box<dyn ToolParser>> {
139        // Try exact match first
140        {
141            let mapping = self.model_mapping.read();
142            if let Some(parser_name) = mapping.get(model) {
143                if let Some(parser) = self.create_parser(parser_name) {
144                    return Some(parser);
145                }
146            }
147        }
148
149        // Try prefix matching with more specific patterns first
150        let model_mapping = self.model_mapping.read();
151        let best_match = model_mapping
152            .iter()
153            .filter(|(pattern, _)| {
154                pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
155            })
156            .max_by_key(|(pattern, _)| pattern.len());
157
158        // Return the best matching parser
159        if let Some((_, parser_name)) = best_match {
160            if let Some(parser) = self.create_parser(parser_name) {
161                return Some(parser);
162            }
163        }
164
165        // Fall back to default parser
166        let default = self.default_parser.read().clone();
167        self.create_parser(&default)
168    }
169
170    /// Get parser for a specific model
171    pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledParser> {
172        // Try exact match first
173        {
174            let mapping = self.model_mapping.read();
175            if let Some(parser_name) = mapping.get(model) {
176                if let Some(parser) = self.get_pooled_parser(parser_name) {
177                    return Some(parser);
178                }
179            }
180        }
181
182        // Try prefix matching with more specific patterns first
183        let model_mapping = self.model_mapping.read();
184        let best_match = model_mapping
185            .iter()
186            .filter(|(pattern, _)| {
187                pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
188            })
189            .max_by_key(|(pattern, _)| pattern.len());
190
191        // Return the best matching parser
192        if let Some((_, parser_name)) = best_match {
193            if let Some(parser) = self.get_pooled_parser(parser_name) {
194                return Some(parser);
195            }
196        }
197
198        // Fall back to default parser
199        let default = self.default_parser.read().clone();
200        self.get_pooled_parser(&default)
201    }
202
203    /// Clear the parser pool, forcing new instances to be created.
204    pub fn clear_pool(&self) {
205        let mut pool = self.pool.write();
206        pool.clear();
207    }
208
209    /// Set the default parser
210    pub fn set_default_parser(&self, name: impl Into<String>) {
211        let mut default = self.default_parser.write();
212        *default = name.into();
213    }
214}
215
216impl Default for ParserRegistry {
217    fn default() -> Self {
218        Self::new()
219    }
220}
221
/// Factory for creating tool parsers based on model type.
///
/// A thin wrapper around a [`ParserRegistry`]. Because the registry's state is
/// shared between clones, cloning the factory shares the same registrations,
/// mappings, and pool.
#[derive(Clone)]
pub struct ParserFactory {
    /// Backing registry holding parser creators, model mappings, and the pool.
    registry: ParserRegistry,
}
227
228impl ParserFactory {
229    /// Create a new factory with default parsers registered.
230    pub fn new() -> Self {
231        let registry = ParserRegistry::new();
232
233        // Register default parsers
234        registry.register_parser("passthrough", || Box::new(PassthroughParser::new()));
235        registry.register_parser("json", || Box::new(JsonParser::new()));
236        registry.register_parser("mistral", || Box::new(MistralParser::new()));
237        registry.register_parser("qwen", || Box::new(QwenParser::new()));
238        registry.register_parser("qwen_coder", || Box::new(QwenCoderParser::new()));
239        registry.register_parser("pythonic", || Box::new(PythonicParser::new()));
240        registry.register_parser("llama", || Box::new(LlamaParser::new()));
241        registry.register_parser("deepseek", || Box::new(DeepSeekParser::new()));
242        registry.register_parser("glm45_moe", || Box::new(Glm4MoeParser::glm45()));
243        registry.register_parser("glm47_moe", || Box::new(Glm4MoeParser::glm47()));
244        registry.register_parser("step3", || Box::new(Step3Parser::new()));
245        registry.register_parser("kimik2", || Box::new(KimiK2Parser::new()));
246        registry.register_parser("minimax_m2", || Box::new(MinimaxM2Parser::new()));
247        registry.register_parser("cohere", || Box::new(CohereParser::new()));
248
249        // Register default model mappings
250        Self::register_default_mappings(&registry);
251
252        Self { registry }
253    }
254
255    fn register_default_mappings(registry: &ParserRegistry) {
256        // OpenAI models
257        registry.map_model("gpt-4*", "json");
258        registry.map_model("gpt-3.5*", "json");
259        registry.map_model("gpt-4o*", "json");
260
261        // Anthropic models
262        registry.map_model("claude-*", "json");
263
264        // Mistral models
265        registry.map_model("mistral-*", "mistral");
266        registry.map_model("mixtral-*", "mistral");
267
268        // Qwen models (more specific patterns first - longer patterns take precedence)
269        // Qwen Coder models use XML format: <tool_call><function=name><parameter=key>value</parameter></function></tool_call>
270        registry.map_model("Qwen/Qwen3-Coder*", "qwen_coder");
271        registry.map_model("Qwen3-Coder*", "qwen_coder");
272        registry.map_model("qwen3-coder*", "qwen_coder");
273        registry.map_model("Qwen/Qwen2.5-Coder*", "qwen_coder");
274        registry.map_model("Qwen2.5-Coder*", "qwen_coder");
275        registry.map_model("qwen2.5-coder*", "qwen_coder");
276        // Generic Qwen models use JSON format
277        registry.map_model("qwen*", "qwen");
278        registry.map_model("Qwen*", "qwen");
279
280        // Llama models
281        registry.map_model("llama-4*", "pythonic");
282        registry.map_model("meta-llama-4*", "pythonic");
283        registry.map_model("llama-3.2*", "llama");
284        registry.map_model("meta-llama-3.2*", "llama");
285        registry.map_model("llama-*", "json");
286        registry.map_model("meta-llama-*", "json");
287
288        // DeepSeek models
289        registry.map_model("deepseek-v3*", "deepseek");
290        registry.map_model("deepseek-ai/DeepSeek-V3*", "deepseek");
291        registry.map_model("deepseek-*", "pythonic");
292
293        // GLM models
294        registry.map_model("glm-4.5*", "glm45_moe");
295        registry.map_model("glm-4.6*", "glm45_moe");
296        registry.map_model("glm-4.7*", "glm47_moe");
297        registry.map_model("glm-*", "json");
298
299        // Step3 models
300        registry.map_model("step3*", "step3");
301        registry.map_model("Step-3*", "step3");
302
303        // Kimi models
304        registry.map_model("kimi-k2*", "kimik2");
305        registry.map_model("Kimi-K2*", "kimik2");
306        registry.map_model("moonshot*/Kimi-K2*", "kimik2");
307
308        // MiniMax models
309        registry.map_model("minimax*", "minimax_m2");
310        registry.map_model("MiniMax*", "minimax_m2");
311
312        // Cohere models
313        registry.map_model("command-r*", "cohere");
314        registry.map_model("command-r-plus*", "cohere");
315        registry.map_model("command-a*", "cohere");
316        registry.map_model("c4ai-command*", "cohere");
317        registry.map_model("cohere*", "cohere");
318        registry.map_model("CohereForAI*", "cohere");
319
320        // Other models
321        registry.map_model("gemini-*", "json");
322        registry.map_model("palm-*", "json");
323        registry.map_model("gemma-*", "json");
324    }
325
326    /// Get a pooled parser for the given model ID.
327    /// Returns a shared instance that can be used concurrently.
328    /// Falls back to passthrough parser if model is not recognized.
329    #[expect(
330        clippy::expect_used,
331        reason = "passthrough parser is registered in new(); None indicates a bug in registration logic"
332    )]
333    pub fn get_pooled(&self, model_id: &str) -> PooledParser {
334        self.registry
335            .get_pooled_for_model(model_id)
336            .unwrap_or_else(|| {
337                // Fallback to passthrough parser (no-op, returns text unchanged)
338                self.registry
339                    .get_pooled_parser("passthrough")
340                    .expect("Passthrough parser should always be registered")
341            })
342    }
343
344    /// Get the internal registry for custom registration.
345    pub fn registry(&self) -> &ParserRegistry {
346        &self.registry
347    }
348
349    /// Clear the parser pool.
350    pub fn clear_pool(&self) {
351        self.registry.clear_pool();
352    }
353
354    /// Get a non-pooled parser for the given model ID (creates a fresh instance each time).
355    /// This is useful for benchmarks and testing where you want independent parser instances.
356    pub fn get_parser(&self, model_id: &str) -> Option<Arc<dyn ToolParser>> {
357        // Determine which parser type to use
358        let parser_type = {
359            let mapping = self.registry.model_mapping.read();
360
361            // Try exact match first
362            if let Some(parser_name) = mapping.get(model_id) {
363                parser_name.clone()
364            } else {
365                // Try prefix matching
366                let best_match = mapping
367                    .iter()
368                    .filter(|(pattern, _)| {
369                        pattern.ends_with('*')
370                            && model_id.starts_with(&pattern[..pattern.len() - 1])
371                    })
372                    .max_by_key(|(pattern, _)| pattern.len());
373
374                if let Some((_, parser_name)) = best_match {
375                    parser_name.clone()
376                } else {
377                    // Fall back to default
378                    self.registry.default_parser.read().clone()
379                }
380            }
381        };
382
383        let creators = self.registry.creators.read();
384        creators.get(&parser_type).map(|creator| {
385            // Call the creator to get a Box<dyn ToolParser>, then convert to Arc
386            let boxed_parser = creator();
387            Arc::from(boxed_parser)
388        })
389    }
390
391    /// List all registered parsers (for compatibility with old API).
392    pub fn list_parsers(&self) -> Vec<String> {
393        self.registry.creators.read().keys().cloned().collect()
394    }
395}
396
397impl Default for ParserFactory {
398    fn default() -> Self {
399        Self::new()
400    }
401}