Skip to main content

zag_agent/
auto_selector.rs

1//! Auto-selection of provider and/or model based on task analysis.
2//!
3//! Runs a lightweight LLM call to analyze the user's prompt and select
4//! the most suitable provider/model combination.
5
6use crate::config::Config;
7use crate::factory::AgentFactory;
8use anyhow::{Result, bail};
9use log::debug;
10use serde::Deserialize;
11
12const PROMPT_TEMPLATE_SOURCE: &str = include_str!("../prompts/auto-selector/3_1_0.md");
13
14fn prompt_template() -> &'static str {
15    crate::prompts::strip_front_matter(PROMPT_TEMPLATE_SOURCE)
16}
17
/// Outcome of an auto-selection run.
///
/// Either field may be `None` when that half of the selection was not
/// produced. The type is cloneable and comparable so callers and tests can
/// copy a selection or check it against an expected value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutoResult {
    /// The selected provider (e.g., "claude", "codex", "gemini").
    pub provider: Option<String>,
    /// The selected model (e.g., "opus", "haiku", "sonnet").
    pub model: Option<String>,
}
26
27/// JSON response structure from the auto-selector LLM.
28#[derive(Debug, Deserialize)]
29struct AutoSelectorResponse {
30    provider: Option<String>,
31    model: Option<String>,
32    reason: Option<String>,
33    /// If true, the selector declined to route the task.
34    declined: Option<bool>,
35}
36
37/// Resolve provider and/or model automatically by analyzing the task prompt.
38///
39/// - `prompt`: The user's task prompt to analyze.
40/// - `auto_provider`: Whether the provider should be auto-selected.
41/// - `auto_model`: Whether the model should be auto-selected.
42/// - `current_provider`: The non-auto provider (used when only model is auto).
43/// - `config`: The loaded configuration.
44/// - `root`: Optional root directory for agent creation.
45pub async fn resolve(
46    prompt: &str,
47    auto_provider: bool,
48    auto_model: bool,
49    current_provider: Option<&str>,
50    config: &Config,
51    root: Option<&str>,
52) -> Result<AutoResult> {
53    // Build the mode description and response format
54    let (mode, response_format) =
55        build_mode_and_format(auto_provider, auto_model, current_provider);
56
57    // Build the selector prompt
58    let selector_prompt = prompt_template()
59        .replace("{MODE}", &mode)
60        .replace("{RESPONSE_FORMAT}", &response_format)
61        .replace("{TASK}", prompt);
62
63    debug!("Auto-selector prompt:\n{selector_prompt}");
64
65    // Determine which provider/model to use for auto-selection
66    let selector_provider = config.auto_provider().unwrap_or("claude").to_string();
67    let selector_model = config.auto_model().unwrap_or("sonnet").to_string();
68
69    debug!("Auto-selector using {selector_provider} with model {selector_model}");
70
71    // Create and run the selector agent
72    debug!("Selecting provider/model for task...");
73
74    let mut agent = AgentFactory::create(
75        &selector_provider,
76        Some("Respond with ONLY the JSON object, nothing else. No explanations.".to_string()),
77        Some(selector_model),
78        root.map(String::from),
79        true, // auto-approve (selector doesn't need tools)
80        vec![],
81    )?;
82
83    // Capture stdout so we can parse the response programmatically
84    agent.set_capture_output(true);
85
86    let output = agent.run(Some(&selector_prompt)).await?;
87
88    // Parse the response
89    let response = extract_response(output)?;
90    debug!("Auto-selector response: '{response}'");
91
92    parse_response(&response, auto_provider, auto_model, current_provider)
93}
94
/// Produce the `{MODE}` and `{RESPONSE_FORMAT}` substitutions for the prompt
/// template.
///
/// Returns `(mode, response_format)`. Three modes exist:
/// provider and model, provider only, and model only. Model-only is also
/// what you get when both flags are `false`; it names `current_provider`,
/// or `"claude"` when none is given.
fn build_mode_and_format(
    auto_provider: bool,
    auto_model: bool,
    current_provider: Option<&str>,
) -> (String, String) {
    const HEADER: &str = "Respond with ONLY a JSON object on a single line, nothing else:";
    const DECLINE_HINT: &str =
        r#"If you decline the task, respond with: {"declined": true, "reason": "..."}"#;

    // Only the mode label and the example JSON object differ between modes.
    let (mode, example) = match (auto_provider, auto_model) {
        (true, true) => (
            String::from("provider and model"),
            r#"{"provider": "<provider>", "model": "<size>", "reason": "..."}"#,
        ),
        (true, false) => (
            String::from("provider"),
            r#"{"provider": "<provider>", "reason": "..."}"#,
        ),
        (false, _) => (
            format!("model for {}", current_provider.unwrap_or("claude")),
            r#"{"model": "<model>", "reason": "..."}"#,
        ),
    };

    let response_format = format!("{}\n{}\n\n{}", HEADER, example, DECLINE_HINT);
    (mode, response_format)
}
134
135/// Extract the text response from the agent output.
136fn extract_response(output: Option<crate::output::AgentOutput>) -> Result<String> {
137    if let Some(agent_output) = output {
138        if let Some(result) = agent_output.final_result() {
139            return Ok(result.trim().to_string());
140        }
141        bail!("Auto-selector returned no result");
142    }
143
144    bail!(
145        "Auto-selector produced no parseable output. Ensure the selector agent is configured correctly."
146    )
147}
148
/// Check if a response looks like an LLM refusal rather than a valid selection.
///
/// The match is a case-insensitive substring search against a fixed list of
/// refusal phrases. Typographic apostrophes (`‘`, `’`, `ʼ`) are folded to the
/// ASCII `'` first, so "I’m sorry" is recognised just like "I'm sorry".
fn is_refusal(response: &str) -> bool {
    const REFUSAL_PATTERNS: [&str; 15] = [
        "i'm sorry",
        "i'm not able",
        "i cannot",
        "i can't",
        "i'm unable",
        "i apologize",
        "i must decline",
        "not appropriate",
        "i'm not going to",
        "i don't think i should",
        "i won't",
        "as an ai",
        "as a language model",
        "content policy",
        "against my guidelines",
    ];

    // Models often emit curly apostrophes; the patterns above use ASCII only.
    let normalized = response
        .chars()
        .map(|c| match c {
            '\u{2018}' | '\u{2019}' | '\u{02BC}' => '\'',
            other => other,
        })
        .collect::<String>()
        .to_lowercase();

    REFUSAL_PATTERNS
        .iter()
        .any(|pattern| normalized.contains(pattern))
}
171
172/// Parse the response into an AutoResult.
173///
174/// Tries JSON parsing first, then falls back to text-based parsing for robustness.
175fn parse_response(
176    response: &str,
177    auto_provider: bool,
178    auto_model: bool,
179    current_provider: Option<&str>,
180) -> Result<AutoResult> {
181    // Check for LLM refusal before attempting to parse
182    if is_refusal(response) {
183        bail!(
184            "Auto-selector declined to process the prompt. The task may have been \
185             filtered by the model's content policy. Try running with an explicit \
186             provider and model instead of auto."
187        );
188    }
189
190    // Try JSON parsing first
191    let cleaned = crate::json_validation::strip_markdown_fences(response);
192    if let Ok(parsed) = serde_json::from_str::<AutoSelectorResponse>(cleaned) {
193        debug!("Auto-selector parsed JSON response successfully");
194        if let Some(ref reason) = parsed.reason {
195            debug!("Auto-selector reason: {reason}");
196        }
197
198        // Check for structured decline
199        if parsed.declined == Some(true) {
200            let reason = parsed.reason.as_deref().unwrap_or("no reason given");
201            bail!(
202                "Auto-selector declined the task: {reason}. \
203                 Try running with an explicit provider and model instead of auto."
204            );
205        }
206
207        return build_result_from_json(parsed, auto_provider, auto_model, current_provider);
208    }
209
210    // Fall back to text-based parsing
211    debug!("Auto-selector falling back to text parsing");
212    parse_response_text(response, auto_provider, auto_model, current_provider)
213}
214
215/// Build an AutoResult from a parsed JSON response.
216fn build_result_from_json(
217    parsed: AutoSelectorResponse,
218    auto_provider: bool,
219    auto_model: bool,
220    current_provider: Option<&str>,
221) -> Result<AutoResult> {
222    if auto_provider && auto_model {
223        let provider = parsed
224            .provider
225            .ok_or_else(|| anyhow::anyhow!("Auto-selector JSON missing 'provider' field"))?;
226        let provider = validate_provider(&provider)?;
227        Ok(AutoResult {
228            provider: Some(provider),
229            model: parsed.model,
230        })
231    } else if auto_provider {
232        let provider = parsed
233            .provider
234            .ok_or_else(|| anyhow::anyhow!("Auto-selector JSON missing 'provider' field"))?;
235        let provider = validate_provider(&provider)?;
236        Ok(AutoResult {
237            provider: Some(provider),
238            model: None,
239        })
240    } else {
241        // auto_model only
242        let model = parsed
243            .model
244            .ok_or_else(|| anyhow::anyhow!("Auto-selector JSON missing 'model' field"))?;
245        Ok(AutoResult {
246            provider: current_provider.map(String::from),
247            model: Some(model.to_lowercase()),
248        })
249    }
250}
251
252/// Parse a text-based response (fallback when JSON parsing fails).
253fn parse_response_text(
254    response: &str,
255    auto_provider: bool,
256    auto_model: bool,
257    current_provider: Option<&str>,
258) -> Result<AutoResult> {
259    // Clean up the response - take only the first line, trim whitespace and backticks
260    let cleaned = response
261        .lines()
262        .next()
263        .unwrap_or("")
264        .trim()
265        .trim_matches('`')
266        .trim()
267        .to_lowercase();
268
269    if cleaned.is_empty() {
270        bail!("Auto-selector returned an empty response");
271    }
272
273    let parts: Vec<&str> = cleaned.split_whitespace().collect();
274
275    if auto_provider && auto_model {
276        // Expect "<provider> <model>"
277        if parts.len() >= 2 {
278            let provider = validate_provider(parts[0])?;
279            let model = parts[1].to_string();
280            Ok(AutoResult {
281                provider: Some(provider),
282                model: Some(model),
283            })
284        } else if parts.len() == 1 {
285            // Just a provider, use default model
286            let provider = validate_provider(parts[0])?;
287            Ok(AutoResult {
288                provider: Some(provider),
289                model: None,
290            })
291        } else {
292            bail!("Auto-selector returned unparseable response: '{response}'");
293        }
294    } else if auto_provider {
295        // Expect "<provider>"
296        let provider = validate_provider(parts[0])?;
297        Ok(AutoResult {
298            provider: Some(provider),
299            model: None,
300        })
301    } else {
302        // auto_model only - expect "<model>"
303        Ok(AutoResult {
304            provider: current_provider.map(String::from),
305            model: Some(parts[0].to_string()),
306        })
307    }
308}
309
310/// Validate that a provider name is known.
311fn validate_provider(name: &str) -> Result<String> {
312    let normalized = name.to_lowercase();
313    if Config::VALID_PROVIDERS.contains(&normalized.as_str()) {
314        Ok(normalized)
315    } else {
316        bail!(
317            "Auto-selector chose unknown provider '{}'. Available: {}",
318            name,
319            Config::VALID_PROVIDERS.join(", ")
320        );
321    }
322}
323
324#[cfg(test)]
325#[path = "auto_selector_tests.rs"]
326mod tests;