Skip to main content

zag_agent/
auto_selector.rs

1//! Auto-selection of provider and/or model based on task analysis.
2//!
3//! Runs a lightweight LLM call to analyze the user's prompt and select
4//! the most suitable provider/model combination.
5
6use crate::config::Config;
7use crate::factory::AgentFactory;
8use anyhow::{Result, bail};
9use log::debug;
10use serde::Deserialize;
11
12const PROMPT_TEMPLATE: &str = include_str!("../prompts/auto-selector/3_1.md");
13
/// Outcome of [`resolve`]: the provider and/or model chosen for a task.
#[derive(Debug)]
pub struct AutoResult {
    /// Provider to run with (e.g. "claude", "codex", "gemini").
    ///
    /// When only the model was auto-selected, this is the caller's current
    /// provider passed straight through (and may therefore be `None`).
    pub provider: Option<String>,
    /// Model to run with (e.g. "opus", "haiku", "sonnet").
    ///
    /// `None` when the model was not auto-selected, or when the selector named a
    /// provider without a model.
    pub model: Option<String>,
}
22
23/// JSON response structure from the auto-selector LLM.
24#[derive(Debug, Deserialize)]
25struct AutoSelectorResponse {
26    provider: Option<String>,
27    model: Option<String>,
28    reason: Option<String>,
29    /// If true, the selector declined to route the task.
30    declined: Option<bool>,
31}
32
33/// Resolve provider and/or model automatically by analyzing the task prompt.
34///
35/// - `prompt`: The user's task prompt to analyze.
36/// - `auto_provider`: Whether the provider should be auto-selected.
37/// - `auto_model`: Whether the model should be auto-selected.
38/// - `current_provider`: The non-auto provider (used when only model is auto).
39/// - `config`: The loaded configuration.
40/// - `root`: Optional root directory for agent creation.
41pub async fn resolve(
42    prompt: &str,
43    auto_provider: bool,
44    auto_model: bool,
45    current_provider: Option<&str>,
46    config: &Config,
47    root: Option<&str>,
48) -> Result<AutoResult> {
49    // Build the mode description and response format
50    let (mode, response_format) =
51        build_mode_and_format(auto_provider, auto_model, current_provider);
52
53    // Build the selector prompt
54    let selector_prompt = PROMPT_TEMPLATE
55        .replace("{MODE}", &mode)
56        .replace("{RESPONSE_FORMAT}", &response_format)
57        .replace("{TASK}", prompt);
58
59    debug!("Auto-selector prompt:\n{}", selector_prompt);
60
61    // Determine which provider/model to use for auto-selection
62    let selector_provider = config.auto_provider().unwrap_or("claude").to_string();
63    let selector_model = config.auto_model().unwrap_or("sonnet").to_string();
64
65    debug!(
66        "Auto-selector using {} with model {}",
67        selector_provider, selector_model
68    );
69
70    // Create and run the selector agent
71    debug!("Selecting provider/model for task...");
72
73    let mut agent = AgentFactory::create(
74        &selector_provider,
75        Some("Respond with ONLY the JSON object, nothing else. No explanations.".to_string()),
76        Some(selector_model),
77        root.map(String::from),
78        true, // auto-approve (selector doesn't need tools)
79        vec![],
80    )?;
81
82    // Capture stdout so we can parse the response programmatically
83    agent.set_capture_output(true);
84
85    let output = agent.run(Some(&selector_prompt)).await?;
86
87    // Parse the response
88    let response = extract_response(output)?;
89    debug!("Auto-selector response: '{}'", response);
90
91    parse_response(&response, auto_provider, auto_model, current_provider)
92}
93
/// Build the `{MODE}` and `{RESPONSE_FORMAT}` substitutions for the selector prompt.
///
/// Returns `(mode, response_format)`:
/// - `mode` names what is being selected ("provider and model", "provider", or
///   "model for <provider>", where `<provider>` defaults to `"claude"`).
/// - `response_format` tells the LLM the exact single-line JSON shape to reply with,
///   plus the shape to use when declining.
///
/// Any call with `auto_provider == false` is treated as model-only selection.
fn build_mode_and_format(
    auto_provider: bool,
    auto_model: bool,
    current_provider: Option<&str>,
) -> (String, String) {
    const HEADER: &str = "Respond with ONLY a JSON object on a single line, nothing else:";
    const DECLINE_HINT: &str =
        r#"If you decline the task, respond with: {"declined": true, "reason": "..."}"#;

    let (mode, json_shape) = match (auto_provider, auto_model) {
        (true, true) => (
            String::from("provider and model"),
            r#"{"provider": "<provider>", "model": "<size>", "reason": "..."}"#,
        ),
        (true, false) => (
            String::from("provider"),
            r#"{"provider": "<provider>", "reason": "..."}"#,
        ),
        (false, _) => (
            format!("model for {}", current_provider.unwrap_or("claude")),
            r#"{"model": "<model>", "reason": "..."}"#,
        ),
    };

    let response_format = format!("{}\n{}\n\n{}", HEADER, json_shape, DECLINE_HINT);
    (mode, response_format)
}
133
134/// Extract the text response from the agent output.
135fn extract_response(output: Option<crate::output::AgentOutput>) -> Result<String> {
136    if let Some(agent_output) = output {
137        if let Some(result) = agent_output.final_result() {
138            return Ok(result.trim().to_string());
139        }
140        bail!("Auto-selector returned no result");
141    }
142
143    bail!(
144        "Auto-selector produced no parseable output. Ensure the selector agent is configured correctly."
145    )
146}
147
/// Heuristically decide whether a free-text reply is an LLM refusal rather than a
/// selection.
///
/// Matching is a case-insensitive substring search against common refusal
/// phrases. Typographic apostrophes (U+2018 / U+2019) are folded to `'` first.
/// Models frequently write "I’m sorry" / "can’t", which would otherwise never
/// match the ASCII patterns below.
fn is_refusal(response: &str) -> bool {
    const REFUSAL_PATTERNS: &[&str] = &[
        "i'm sorry",
        "i'm not able",
        "i cannot",
        "i can't",
        "i'm unable",
        "i apologize",
        "i must decline",
        "not appropriate",
        "i'm not going to",
        "i don't think i should",
        "i won't",
        "as an ai",
        "as a language model",
        "content policy",
        "against my guidelines",
    ];

    let normalized: String = response
        .to_lowercase()
        .chars()
        .map(|c| if matches!(c, '\u{2018}' | '\u{2019}') { '\'' } else { c })
        .collect();

    REFUSAL_PATTERNS
        .iter()
        .any(|pattern| normalized.contains(pattern))
}
170
171/// Parse the response into an AutoResult.
172///
173/// Tries JSON parsing first, then falls back to text-based parsing for robustness.
174fn parse_response(
175    response: &str,
176    auto_provider: bool,
177    auto_model: bool,
178    current_provider: Option<&str>,
179) -> Result<AutoResult> {
180    // Check for LLM refusal before attempting to parse
181    if is_refusal(response) {
182        bail!(
183            "Auto-selector declined to process the prompt. The task may have been \
184             filtered by the model's content policy. Try running with an explicit \
185             provider and model instead of auto."
186        );
187    }
188
189    // Try JSON parsing first
190    let cleaned = crate::json_validation::strip_markdown_fences(response);
191    if let Ok(parsed) = serde_json::from_str::<AutoSelectorResponse>(cleaned) {
192        debug!("Auto-selector parsed JSON response successfully");
193        if let Some(ref reason) = parsed.reason {
194            debug!("Auto-selector reason: {}", reason);
195        }
196
197        // Check for structured decline
198        if parsed.declined == Some(true) {
199            let reason = parsed.reason.as_deref().unwrap_or("no reason given");
200            bail!(
201                "Auto-selector declined the task: {}. \
202                 Try running with an explicit provider and model instead of auto.",
203                reason
204            );
205        }
206
207        return build_result_from_json(parsed, auto_provider, auto_model, current_provider);
208    }
209
210    // Fall back to text-based parsing
211    debug!("Auto-selector falling back to text parsing");
212    parse_response_text(response, auto_provider, auto_model, current_provider)
213}
214
215/// Build an AutoResult from a parsed JSON response.
216fn build_result_from_json(
217    parsed: AutoSelectorResponse,
218    auto_provider: bool,
219    auto_model: bool,
220    current_provider: Option<&str>,
221) -> Result<AutoResult> {
222    if auto_provider && auto_model {
223        let provider = parsed
224            .provider
225            .ok_or_else(|| anyhow::anyhow!("Auto-selector JSON missing 'provider' field"))?;
226        let provider = validate_provider(&provider)?;
227        Ok(AutoResult {
228            provider: Some(provider),
229            model: parsed.model,
230        })
231    } else if auto_provider {
232        let provider = parsed
233            .provider
234            .ok_or_else(|| anyhow::anyhow!("Auto-selector JSON missing 'provider' field"))?;
235        let provider = validate_provider(&provider)?;
236        Ok(AutoResult {
237            provider: Some(provider),
238            model: None,
239        })
240    } else {
241        // auto_model only
242        let model = parsed
243            .model
244            .ok_or_else(|| anyhow::anyhow!("Auto-selector JSON missing 'model' field"))?;
245        Ok(AutoResult {
246            provider: current_provider.map(String::from),
247            model: Some(model.to_lowercase()),
248        })
249    }
250}
251
252/// Parse a text-based response (fallback when JSON parsing fails).
253fn parse_response_text(
254    response: &str,
255    auto_provider: bool,
256    auto_model: bool,
257    current_provider: Option<&str>,
258) -> Result<AutoResult> {
259    // Clean up the response - take only the first line, trim whitespace and backticks
260    let cleaned = response
261        .lines()
262        .next()
263        .unwrap_or("")
264        .trim()
265        .trim_matches('`')
266        .trim()
267        .to_lowercase();
268
269    if cleaned.is_empty() {
270        bail!("Auto-selector returned an empty response");
271    }
272
273    let parts: Vec<&str> = cleaned.split_whitespace().collect();
274
275    if auto_provider && auto_model {
276        // Expect "<provider> <model>"
277        if parts.len() >= 2 {
278            let provider = validate_provider(parts[0])?;
279            let model = parts[1].to_string();
280            Ok(AutoResult {
281                provider: Some(provider),
282                model: Some(model),
283            })
284        } else if parts.len() == 1 {
285            // Just a provider, use default model
286            let provider = validate_provider(parts[0])?;
287            Ok(AutoResult {
288                provider: Some(provider),
289                model: None,
290            })
291        } else {
292            bail!(
293                "Auto-selector returned unparseable response: '{}'",
294                response
295            );
296        }
297    } else if auto_provider {
298        // Expect "<provider>"
299        let provider = validate_provider(parts[0])?;
300        Ok(AutoResult {
301            provider: Some(provider),
302            model: None,
303        })
304    } else {
305        // auto_model only - expect "<model>"
306        Ok(AutoResult {
307            provider: current_provider.map(String::from),
308            model: Some(parts[0].to_string()),
309        })
310    }
311}
312
313/// Validate that a provider name is known.
314fn validate_provider(name: &str) -> Result<String> {
315    let normalized = name.to_lowercase();
316    if Config::VALID_PROVIDERS.contains(&normalized.as_str()) {
317        Ok(normalized)
318    } else {
319        bail!(
320            "Auto-selector chose unknown provider '{}'. Available: {}",
321            name,
322            Config::VALID_PROVIDERS.join(", ")
323        );
324    }
325}
326
327#[cfg(test)]
328#[path = "auto_selector_tests.rs"]
329mod tests;