// sparrow/provider/detect.rs

//! Provider auto-detection: scan the environment for API keys, test
//! connectivity with lightweight API calls, rank providers by cost tier
//! (free > paid), and return a list of ready-to-use providers.
//!
//! Integrates with the first-run wizard in [`crate::onboarding::wizard`].

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::config::providers::{find_provider, provider_registry, ProviderDef};

10// ─── Detection result types ──────────────────────────────────────────────────
11
12/// The result of scanning for one provider.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct DetectedProvider {
15    /// Registry id (e.g. "anthropic", "nvidia")
16    pub id: String,
17    /// Human label (e.g. "Anthropic", "NVIDIA NIM")
18    pub label: String,
19    /// Whether we found an API key in the environment
20    pub key_found: bool,
21    /// The env var name if applicable
22    pub env_var: Option<String>,
23    /// Cost tier
24    pub tier: ProviderTier,
25    /// Whether we successfully validated the key with a lightweight API call
26    pub validated: Option<bool>,
27    /// Error message if validation failed
28    pub validation_error: Option<String>,
29    /// Signup URL for getting a key
30    pub signup_url: Option<String>,
31    /// Whether this provider is recommended for the user
32    pub recommended: bool,
33    /// Short description for the wizard UI
34    pub description: String,
35}
36
37/// Cost tier for ranking.
38#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
39pub enum ProviderTier {
40    /// Completely free (NVIDIA NIM, Groq free tier, Gemini free tier)
41    Free = 0,
42    /// Has a generous free tier (some paid models but free tier exists)
43    FreeTier = 1,
44    /// Paid but cheap (DeepSeek, etc.)
45    Cheap = 2,
46    /// Paid, standard pricing
47    Paid = 3,
48    /// Requires signup / no key found
49    RequiresSignup = 4,
50}
51
// ─── Known API key environment variables ─────────────────────────────────────

/// Every environment variable inspected during detection, as
/// `(env_var, provider_id)` where `provider_id` is the registry id it unlocks.
///
/// Order matters: `scan_environment` reports matches in exactly this order.
const KNOWN_API_KEY_ENVS: &[(&str, &str)] = &[
    // Major hosted model vendors.
    ("OPENAI_API_KEY", "openai-codex"),
    ("ANTHROPIC_API_KEY", "anthropic"),
    ("GEMINI_API_KEY", "gemini"),
    ("GROQ_API_KEY", "groq"),
    ("NVIDIA_API_KEY", "nvidia"),
    ("DEEPSEEK_API_KEY", "deepseek"),
    // Aggregators and other hosted providers.
    ("OPENROUTER_API_KEY", "openrouter"),
    ("XAI_API_KEY", "xai"),
    ("HF_TOKEN", "huggingface"),
    ("NOUS_API_KEY", "nous"),
    ("NOVITA_API_KEY", "novita"),
    ("DASHSCOPE_API_KEY", "alibaba"),
    ("MOONSHOT_API_KEY", "kimi-coding"),
    ("MISTRAL_API_KEY", "mistral"),
    ("TOGETHER_API_KEY", "together"),
    ("CEREBRAS_API_KEY", "cerebras"),
    ("FIREWORKS_API_KEY", "fireworks"),
    ("PERPLEXITY_API_KEY", "perplexity"),
    ("COHERE_API_KEY", "cohere"),
    // Credentials that are not a plain "API key" variable.
    ("AWS_ACCESS_KEY_ID", "bedrock"),
    ("COPILOT_TOKEN", "copilot"),
];

79// ─── Environment scanning ────────────────────────────────────────────────────
80
81/// Scan the environment for all known API keys.
82///
83/// Returns a map of provider_id → (env_var_name, key_value).
84pub fn scan_environment() -> Vec<(&'static str, &'static str, String)> {
85    KNOWN_API_KEY_ENVS
86        .iter()
87        .filter_map(|&(env_var, provider_id)| {
88            std::env::var(env_var)
89                .ok()
90                .filter(|v| !v.trim().is_empty())
91                .map(|key| (provider_id, env_var, key))
92        })
93        .collect()
94}
95
96/// Detect all providers with their status.
97///
98/// Returns a vector of [`DetectedProvider`] sorted by tier (free first), then
99/// by whether a key was found.
100pub fn detect_all_providers() -> Vec<DetectedProvider> {
101    let env_keys = scan_environment();
102    let env_keys_map: std::collections::HashMap<&str, (&str, String)> = env_keys
103        .iter()
104        .map(|(pid, env, key)| (*pid, (*env, key.clone())))
105        .collect();
106
107    let mut providers: Vec<DetectedProvider> = Vec::new();
108
109    for def in provider_registry() {
110        let key_info = env_keys_map.get(def.id.as_str());
111
112        let tier = classify_tier(&def);
113        let signup_url = signup_url_for(&def);
114
115        let description = match tier {
116            ProviderTier::Free => format!(
117                "Gratuit — {}. Modèle recommandé : {}",
118                def.notes.trim_end_matches('.'),
119                def.models
120                    .iter()
121                    .find(|m| m.recommended)
122                    .map(|m| m.name.as_str())
123                    .unwrap_or(def.models.first().map(|m| m.name.as_str()).unwrap_or("N/A")),
124            ),
125            _ => def.notes.clone(),
126        };
127
128        providers.push(DetectedProvider {
129            id: def.id.clone(),
130            label: def.label.clone(),
131            key_found: key_info.is_some(),
132            env_var: key_info.map(|(env, _)| env.to_string()),
133            tier,
134            validated: None,
135            validation_error: None,
136            signup_url,
137            recommended: def.models.iter().any(|m| m.recommended),
138            description,
139        });
140    }
141
142    // Sort: free first, then by key found
143    providers.sort_by(|a, b| {
144        a.tier
145            .cmp(&b.tier)
146            .then_with(|| b.key_found.cmp(&a.key_found)) // key found first within same tier
147            .then_with(|| a.label.cmp(&b.label))
148    });
149
150    providers
151}
152
153/// Classify a provider definition into a cost tier.
154fn classify_tier(def: &ProviderDef) -> ProviderTier {
155    // Check tags first
156    if def.tags.iter().any(|t| t == "free") {
157        return ProviderTier::Free;
158    }
159
160    // Check if any model has zero cost (free tier)
161    let all_free = !def.models.is_empty()
162        && def
163            .models
164            .iter()
165            .all(|m| m.cost_input_per_mtok == 0.0 && m.cost_output_per_mtok == 0.0);
166
167    if all_free {
168        return ProviderTier::Free;
169    }
170
171    let has_free_models = def
172        .models
173        .iter()
174        .any(|m| m.cost_input_per_mtok == 0.0 && m.cost_output_per_mtok == 0.0);
175
176    if has_free_models {
177        return ProviderTier::FreeTier;
178    }
179
180    // Cheap if cheapest model < $1/M input tokens
181    let cheapest_input = def
182        .models
183        .iter()
184        .map(|m| m.cost_input_per_mtok)
185        .fold(f64::MAX, f64::min);
186
187    if cheapest_input < 1.0 {
188        return ProviderTier::Cheap;
189    }
190
191    ProviderTier::Paid
192}
193
194/// Get the signup URL for a provider.
195fn signup_url_for(def: &ProviderDef) -> Option<String> {
196    match def.id.as_str() {
197        "anthropic" => Some("https://console.anthropic.com/settings/keys".into()),
198        "openai-codex" => Some("https://platform.openai.com/api-keys".into()),
199        "gemini" => Some("https://aistudio.google.com/app/apikey".into()),
200        "groq" => Some("https://console.groq.com/keys".into()),
201        "nvidia" => Some("https://build.nvidia.com/explore/discover".into()),
202        "deepseek" => Some("https://platform.deepseek.com/api_keys".into()),
203        "openrouter" => Some("https://openrouter.ai/keys".into()),
204        "xai" => Some("https://console.x.ai/".into()),
205        "huggingface" => Some("https://huggingface.co/settings/tokens".into()),
206        "nous" => Some("https://portal.nousresearch.com".into()),
207        "novita" => Some("https://novita.ai/dashboard/key".into()),
208        "alibaba" => Some("https://bailian.console.aliyun.com".into()),
209        "kimi-coding" => Some("https://platform.moonshot.cn/console".into()),
210        "mistral" => Some("https://console.mistral.ai/api-keys/".into()),
211        "together" => Some("https://api.together.xyz/settings/api-keys".into()),
212        "cerebras" => Some("https://cloud.cerebras.ai/".into()),
213        "fireworks" => Some("https://fireworks.ai/api-keys".into()),
214        "perplexity" => Some("https://www.perplexity.ai/settings/api".into()),
215        "cohere" => Some("https://dashboard.cohere.com/api-keys".into()),
216        _ => None,
217    }
218}
219
/// Check if the `gh` CLI is installed (for GitHub Copilot integration).
///
/// Runs `gh --version` with stdout and stderr discarded. Returns `true` only
/// when the process could be spawned and exited successfully. This blocks the
/// calling thread until the command finishes.
pub fn gh_cli_installed() -> bool {
    use std::process::{Command, Stdio};

    let outcome = Command::new("gh")
        .arg("--version")
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();

    match outcome {
        Ok(exit) => exit.success(),
        // Spawn failure (typically: `gh` not on PATH) means "not installed".
        Err(_) => false,
    }
}

231/// Test a provider's API key with a minimal request.
232///
233/// Returns `Ok(())` if the key is valid, `Err(msg)` with a French error
234/// message otherwise.
235pub async fn validate_api_key(provider_id: &str, api_key: &str) -> Result<(), String> {
236    let def = match find_provider(provider_id) {
237        Some(d) => d,
238        None => {
239            return Err(format!(
240                "Provider \"{provider_id}\" inconnu dans le registre Sparrow."
241            ));
242        }
243    };
244
245    // Build a minimal request based on the adapter type
246    match def.adapter.as_str() {
247        "anthropic-messages" => validate_anthropic_key(api_key).await,
248        "openai-compatible" => validate_openai_compatible_key(&def.base_url, api_key).await,
249        "ollama" => {
250            // Ollama doesn't need an API key — just check if it's reachable
251            validate_ollama_connection(&def.base_url).await
252        }
253        _ => validate_openai_compatible_key(&def.base_url, api_key).await,
254    }
255}
256
257/// Validate an Anthropic API key with a GET to /v1/models.
258async fn validate_anthropic_key(api_key: &str) -> Result<(), String> {
259    let client = reqwest::Client::builder()
260        .timeout(std::time::Duration::from_secs(10))
261        .build()
262        .map_err(|e| format!("Erreur client HTTP: {e}"))?;
263
264    let resp = client
265        .get("https://api.anthropic.com/v1/models?limit=1")
266        .header("x-api-key", api_key)
267        .header("anthropic-version", "2023-06-01")
268        .send()
269        .await
270        .map_err(|e| {
271            if e.is_timeout() {
272                "Timeout — le serveur Anthropic ne répond pas. Check ta connexion.".into()
273            } else if e.is_connect() {
274                "Impossible de contacter api.anthropic.com. Vérifie ta connexion ou VPN.".into()
275            } else {
276                format!("Erreur réseau : {e}")
277            }
278        })?;
279
280    match resp.status().as_u16() {
281        200 => Ok(()),
282        401 | 403 => Err("Clé API Anthropic invalide. Vérifie ta clé sur https://console.anthropic.com/settings/keys".into()),
283        429 => Err("Rate limit Anthropic — trop de requêtes. Réessaie dans quelques secondes.".into()),
284        s => Err(format!("Erreur HTTP {s} du serveur Anthropic.")),
285    }
286}
287
288/// Validate an OpenAI-compatible API key with a GET to /v1/models?limit=1.
289async fn validate_openai_compatible_key(base_url: &str, api_key: &str) -> Result<(), String> {
290    let client = reqwest::Client::builder()
291        .timeout(std::time::Duration::from_secs(10))
292        .build()
293        .map_err(|e| format!("Erreur client HTTP: {e}"))?;
294
295    let url = format!("{}/models?limit=1", base_url.trim_end_matches('/'));
296
297    let resp = client
298        .get(&url)
299        .bearer_auth(api_key)
300        .send()
301        .await
302        .map_err(|e| {
303            if e.is_timeout() {
304                format!("Timeout — le serveur à {url} ne répond pas. Check ta connexion.")
305            } else if e.is_connect() {
306                format!("Impossible de contacter {url}. Vérifie ta connexion ou VPN.")
307            } else {
308                format!("Erreur réseau : {e}")
309            }
310        })?;
311
312    match resp.status().as_u16() {
313        200 => Ok(()),
314        401 | 403 => Err("Clé API invalide. Vérifie ta clé.".into()),
315        404 => {
316            // Some providers don't have /v1/models — try a chat completions endpoint instead
317            validate_with_chat_request(base_url, api_key).await
318        }
319        429 => Err("Rate limit — trop de requêtes. Réessaie dans quelques secondes.".into()),
320        s => Err(format!("Erreur HTTP {s}.")),
321    }
322}
323
324/// Fallback validation: send a minimal chat completion request (1 token max).
325async fn validate_with_chat_request(base_url: &str, api_key: &str) -> Result<(), String> {
326    let client = reqwest::Client::builder()
327        .timeout(std::time::Duration::from_secs(10))
328        .build()
329        .map_err(|e| format!("Erreur client HTTP: {e}"))?;
330
331    let url = format!(
332        "{}/chat/completions",
333        base_url.trim_end_matches('/')
334    );
335
336    let body = serde_json::json!({
337        "model": "gpt-3.5-turbo",  // widely supported model name for testing
338        "messages": [{"role": "user", "content": "hi"}],
339        "max_tokens": 1,
340        "temperature": 0.0,
341    });
342
343    let resp = client
344        .post(&url)
345        .bearer_auth(api_key)
346        .json(&body)
347        .send()
348        .await
349        .map_err(|e| {
350            if e.is_timeout() {
351                "Timeout — le serveur ne répond pas.".into()
352            } else if e.is_connect() {
353                format!("Impossible de contacter {url}.")
354            } else {
355                format!("Erreur réseau : {e}")
356            }
357        })?;
358
359    match resp.status().as_u16() {
360        200 => Ok(()),
361        401 | 403 => Err("Clé API invalide.".into()),
362        404 => Err("Endpoint chat/completions introuvable. L'URL de base est peut-être incorrecte.".into()),
363        429 => Err("Rate limit — trop de requêtes.".into()),
364        s => {
365            // Even 400/422 is "good" — it means the key was accepted, just the model
366            // name was wrong.
367            if s == 400 || s == 422 {
368                Ok(())
369            } else {
370                Err(format!("Erreur HTTP {s}."))
371            }
372        }
373    }
374}
375
376/// Validate an Ollama connection (no API key needed).
377async fn validate_ollama_connection(base_url: &str) -> Result<(), String> {
378    let client = reqwest::Client::builder()
379        .timeout(std::time::Duration::from_secs(5))
380        .build()
381        .map_err(|e| format!("Erreur client HTTP: {e}"))?;
382
383    let root = base_url
384        .trim_end_matches('/')
385        .trim_end_matches("/v1");
386    let url = format!("{root}/api/tags");
387
388    let resp = client.get(&url).send().await.map_err(|e| {
389        if e.is_connect() {
390            format!(
391                "Ollama ne tourne pas sur {root}.\n\
392                 → Lance `ollama serve` dans un autre terminal.\n\
393                 → Ou installe Ollama : https://ollama.com"
394            )
395        } else {
396            format!("Erreur réseau : {e}")
397        }
398    })?;
399
400    match resp.status().as_u16() {
401        200 => Ok(()),
402        s => Err(format!("Ollama a répondu HTTP {s}. Vérifie que le serveur tourne.")),
403    }
404}
405
406/// Run validation for all detected providers with keys.
407///
408/// Returns the list of providers with `validated` fields populated.
409pub async fn validate_detected_providers(
410    providers: &mut [DetectedProvider],
411) {
412    for p in providers.iter_mut() {
413        if !p.key_found {
414            p.validated = Some(false);
415            p.validation_error = Some("Aucune clé API trouvée dans l'environnement.".into());
416            continue;
417        }
418
419        let env_var = match &p.env_var {
420            Some(env) => env.clone(),
421            None => {
422                p.validated = Some(false);
423                p.validation_error = Some("Variable d'environnement inconnue.".into());
424                continue;
425            }
426        };
427
428        let api_key = match std::env::var(&env_var) {
429            Ok(k) if !k.trim().is_empty() => k,
430            _ => {
431                p.validated = Some(false);
432                p.validation_error = Some(format!("Variable {env_var} vide."));
433                continue;
434            }
435        };
436
437        match validate_api_key(&p.id, &api_key).await {
438            Ok(()) => {
439                p.validated = Some(true);
440                p.validation_error = None;
441            }
442            Err(e) => {
443                p.validated = Some(false);
444                p.validation_error = Some(e);
445            }
446        }
447    }
448}
449
450/// Get a summary list of ready-to-use providers (key found + validated).
451pub fn ready_providers(providers: &[DetectedProvider]) -> Vec<&DetectedProvider> {
452    providers
453        .iter()
454        .filter(|p| p.key_found && p.validated == Some(true))
455        .collect()
456}
457
458/// Get a list of free providers (regardless of key status), for suggestions.
459pub fn free_providers(providers: &[DetectedProvider]) -> Vec<&DetectedProvider> {
460    providers
461        .iter()
462        .filter(|p| matches!(p.tier, ProviderTier::Free))
463        .collect()
464}