// zag_agent/auto_selector.rs
//! Auto-selection of provider and/or model based on task analysis.
//!
//! Runs a lightweight LLM call to analyze the user's prompt and select
//! the most suitable provider/model combination.

use anyhow::{Result, bail};
use log::debug;
use serde::Deserialize;

use crate::config::Config;
use crate::factory::AgentFactory;

/// Selector prompt template, embedded at compile time. Contains the
/// placeholders {MODE}, {RESPONSE_FORMAT} and {TASK}, which `resolve`
/// substitutes before sending the prompt to the selector agent.
const PROMPT_TEMPLATE: &str = include_str!("../prompts/auto-selector/3_1.md");

/// Result of auto-selection.
///
/// Either field may be `None` when that dimension was not auto-selected,
/// or (for `model`) when the selector left the model to the provider default.
#[derive(Debug)]
pub struct AutoResult {
    /// The selected provider (e.g., "claude", "codex", "gemini").
    pub provider: Option<String>,
    /// The selected model (e.g., "opus", "haiku", "sonnet").
    pub model: Option<String>,
}

/// JSON response structure from the auto-selector LLM.
///
/// All fields are optional because the model may omit any of them; which
/// fields are actually required for a given mode is enforced later in
/// `build_result_from_json`.
#[derive(Debug, Deserialize)]
struct AutoSelectorResponse {
    /// Provider name chosen by the selector (e.g., "claude").
    provider: Option<String>,
    /// Model name chosen by the selector (e.g., "sonnet").
    model: Option<String>,
    /// Free-text rationale from the selector; only logged for debugging.
    reason: Option<String>,
    /// If true, the selector declined to route the task.
    declined: Option<bool>,
}

33/// Resolve provider and/or model automatically by analyzing the task prompt.
34///
35/// - `prompt`: The user's task prompt to analyze.
36/// - `auto_provider`: Whether the provider should be auto-selected.
37/// - `auto_model`: Whether the model should be auto-selected.
38/// - `current_provider`: The non-auto provider (used when only model is auto).
39/// - `config`: The loaded configuration.
40/// - `root`: Optional root directory for agent creation.
41pub async fn resolve(
42    prompt: &str,
43    auto_provider: bool,
44    auto_model: bool,
45    current_provider: Option<&str>,
46    config: &Config,
47    root: Option<&str>,
48) -> Result<AutoResult> {
49    // Build the mode description and response format
50    let (mode, response_format) =
51        build_mode_and_format(auto_provider, auto_model, current_provider);
52
53    // Build the selector prompt
54    let selector_prompt = PROMPT_TEMPLATE
55        .replace("{MODE}", &mode)
56        .replace("{RESPONSE_FORMAT}", &response_format)
57        .replace("{TASK}", prompt);
58
59    debug!("Auto-selector prompt:\n{selector_prompt}");
60
61    // Determine which provider/model to use for auto-selection
62    let selector_provider = config.auto_provider().unwrap_or("claude").to_string();
63    let selector_model = config.auto_model().unwrap_or("sonnet").to_string();
64
65    debug!("Auto-selector using {selector_provider} with model {selector_model}");
66
67    // Create and run the selector agent
68    debug!("Selecting provider/model for task...");
69
70    let mut agent = AgentFactory::create(
71        &selector_provider,
72        Some("Respond with ONLY the JSON object, nothing else. No explanations.".to_string()),
73        Some(selector_model),
74        root.map(String::from),
75        true, // auto-approve (selector doesn't need tools)
76        vec![],
77    )?;
78
79    // Capture stdout so we can parse the response programmatically
80    agent.set_capture_output(true);
81
82    let output = agent.run(Some(&selector_prompt)).await?;
83
84    // Parse the response
85    let response = extract_response(output)?;
86    debug!("Auto-selector response: '{response}'");
87
88    parse_response(&response, auto_provider, auto_model, current_provider)
89}
90
/// Build the mode description and response format for the prompt template.
///
/// Returns `(mode, response_format)`. When neither flag is set this behaves
/// like the model-only case (the final `match` arm).
fn build_mode_and_format(
    auto_provider: bool,
    auto_model: bool,
    current_provider: Option<&str>,
) -> (String, String) {
    // Common pieces shared by every mode's response format.
    const HEADER: &str = "Respond with ONLY a JSON object on a single line, nothing else:";
    const DECLINED: &str =
        r#"If you decline the task, respond with: {"declined": true, "reason": "..."}"#;

    let (mode, schema) = match (auto_provider, auto_model) {
        (true, true) => (
            "provider and model".to_string(),
            r#"{"provider": "<provider>", "model": "<size>", "reason": "..."}"#,
        ),
        (true, false) => (
            "provider".to_string(),
            r#"{"provider": "<provider>", "reason": "..."}"#,
        ),
        // auto_model only
        _ => (
            format!("model for {}", current_provider.unwrap_or("claude")),
            r#"{"model": "<model>", "reason": "..."}"#,
        ),
    };

    (mode, format!("{HEADER}\n{schema}\n\n{DECLINED}"))
}

131/// Extract the text response from the agent output.
132fn extract_response(output: Option<crate::output::AgentOutput>) -> Result<String> {
133    if let Some(agent_output) = output {
134        if let Some(result) = agent_output.final_result() {
135            return Ok(result.trim().to_string());
136        }
137        bail!("Auto-selector returned no result");
138    }
139
140    bail!(
141        "Auto-selector produced no parseable output. Ensure the selector agent is configured correctly."
142    )
143}
144
/// Check if a response looks like an LLM refusal rather than a valid selection.
///
/// Purely heuristic: lowercases the response and scans for common refusal
/// phrasing. A substring match anywhere in the text counts.
fn is_refusal(response: &str) -> bool {
    // Phrases (already lowercase) that typically signal a refusal.
    const REFUSAL_PATTERNS: [&str; 15] = [
        "i'm sorry",
        "i'm not able",
        "i cannot",
        "i can't",
        "i'm unable",
        "i apologize",
        "i must decline",
        "not appropriate",
        "i'm not going to",
        "i don't think i should",
        "i won't",
        "as an ai",
        "as a language model",
        "content policy",
        "against my guidelines",
    ];

    let normalized = response.to_lowercase();
    REFUSAL_PATTERNS
        .iter()
        .any(|pattern| normalized.contains(pattern))
}

168/// Parse the response into an AutoResult.
169///
170/// Tries JSON parsing first, then falls back to text-based parsing for robustness.
171fn parse_response(
172    response: &str,
173    auto_provider: bool,
174    auto_model: bool,
175    current_provider: Option<&str>,
176) -> Result<AutoResult> {
177    // Check for LLM refusal before attempting to parse
178    if is_refusal(response) {
179        bail!(
180            "Auto-selector declined to process the prompt. The task may have been \
181             filtered by the model's content policy. Try running with an explicit \
182             provider and model instead of auto."
183        );
184    }
185
186    // Try JSON parsing first
187    let cleaned = crate::json_validation::strip_markdown_fences(response);
188    if let Ok(parsed) = serde_json::from_str::<AutoSelectorResponse>(cleaned) {
189        debug!("Auto-selector parsed JSON response successfully");
190        if let Some(ref reason) = parsed.reason {
191            debug!("Auto-selector reason: {reason}");
192        }
193
194        // Check for structured decline
195        if parsed.declined == Some(true) {
196            let reason = parsed.reason.as_deref().unwrap_or("no reason given");
197            bail!(
198                "Auto-selector declined the task: {reason}. \
199                 Try running with an explicit provider and model instead of auto."
200            );
201        }
202
203        return build_result_from_json(parsed, auto_provider, auto_model, current_provider);
204    }
205
206    // Fall back to text-based parsing
207    debug!("Auto-selector falling back to text parsing");
208    parse_response_text(response, auto_provider, auto_model, current_provider)
209}
210
211/// Build an AutoResult from a parsed JSON response.
212fn build_result_from_json(
213    parsed: AutoSelectorResponse,
214    auto_provider: bool,
215    auto_model: bool,
216    current_provider: Option<&str>,
217) -> Result<AutoResult> {
218    if auto_provider && auto_model {
219        let provider = parsed
220            .provider
221            .ok_or_else(|| anyhow::anyhow!("Auto-selector JSON missing 'provider' field"))?;
222        let provider = validate_provider(&provider)?;
223        Ok(AutoResult {
224            provider: Some(provider),
225            model: parsed.model,
226        })
227    } else if auto_provider {
228        let provider = parsed
229            .provider
230            .ok_or_else(|| anyhow::anyhow!("Auto-selector JSON missing 'provider' field"))?;
231        let provider = validate_provider(&provider)?;
232        Ok(AutoResult {
233            provider: Some(provider),
234            model: None,
235        })
236    } else {
237        // auto_model only
238        let model = parsed
239            .model
240            .ok_or_else(|| anyhow::anyhow!("Auto-selector JSON missing 'model' field"))?;
241        Ok(AutoResult {
242            provider: current_provider.map(String::from),
243            model: Some(model.to_lowercase()),
244        })
245    }
246}
247
248/// Parse a text-based response (fallback when JSON parsing fails).
249fn parse_response_text(
250    response: &str,
251    auto_provider: bool,
252    auto_model: bool,
253    current_provider: Option<&str>,
254) -> Result<AutoResult> {
255    // Clean up the response - take only the first line, trim whitespace and backticks
256    let cleaned = response
257        .lines()
258        .next()
259        .unwrap_or("")
260        .trim()
261        .trim_matches('`')
262        .trim()
263        .to_lowercase();
264
265    if cleaned.is_empty() {
266        bail!("Auto-selector returned an empty response");
267    }
268
269    let parts: Vec<&str> = cleaned.split_whitespace().collect();
270
271    if auto_provider && auto_model {
272        // Expect "<provider> <model>"
273        if parts.len() >= 2 {
274            let provider = validate_provider(parts[0])?;
275            let model = parts[1].to_string();
276            Ok(AutoResult {
277                provider: Some(provider),
278                model: Some(model),
279            })
280        } else if parts.len() == 1 {
281            // Just a provider, use default model
282            let provider = validate_provider(parts[0])?;
283            Ok(AutoResult {
284                provider: Some(provider),
285                model: None,
286            })
287        } else {
288            bail!("Auto-selector returned unparseable response: '{response}'");
289        }
290    } else if auto_provider {
291        // Expect "<provider>"
292        let provider = validate_provider(parts[0])?;
293        Ok(AutoResult {
294            provider: Some(provider),
295            model: None,
296        })
297    } else {
298        // auto_model only - expect "<model>"
299        Ok(AutoResult {
300            provider: current_provider.map(String::from),
301            model: Some(parts[0].to_string()),
302        })
303    }
304}
305
306/// Validate that a provider name is known.
307fn validate_provider(name: &str) -> Result<String> {
308    let normalized = name.to_lowercase();
309    if Config::VALID_PROVIDERS.contains(&normalized.as_str()) {
310        Ok(normalized)
311    } else {
312        bail!(
313            "Auto-selector chose unknown provider '{}'. Available: {}",
314            name,
315            Config::VALID_PROVIDERS.join(", ")
316        );
317    }
318}
319
// Unit tests live in a sibling file (auto_selector_tests.rs) rather than an
// inline `mod tests` block, keeping this module focused on implementation.
#[cfg(test)]
#[path = "auto_selector_tests.rs"]
mod tests;