Skip to main content

rustant_core/
oauth.rs

1//! OAuth 2.0 + PKCE authentication for LLM providers.
2//!
3//! Provides browser-based login flows for OpenAI, Google Gemini, and (when available)
4//! Anthropic. Supports both the standard authorization code flow with PKCE and a
5//! device code flow for headless/SSH environments.
6//!
7//! # Supported providers
8//!
9//! | Provider | Status | Flow |
10//! |----------|--------|------|
11//! | OpenAI | Fully supported | OAuth 2.0 + PKCE |
12//! | Google Gemini | Supported | Google OAuth 2.0 |
13//! | Anthropic | Blocked for 3rd-party | API key only |
14
15use base64::Engine;
16use base64::engine::general_purpose::URL_SAFE_NO_PAD;
17use chrono::{DateTime, Utc};
18use rand::Rng;
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21use std::collections::HashMap;
22use std::future::IntoFuture;
23use std::net::SocketAddr;
24use tokio::sync::oneshot;
25use tracing::{debug, info};
26
27use crate::credentials::{CredentialError, CredentialStore};
28use crate::error::LlmError;
29
30// ── Types ───────────────────────────────────────────────────────────────────
31
32/// Authentication method for a provider.
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
34#[serde(rename_all = "snake_case")]
35pub enum AuthMethod {
36    /// Traditional API key authentication.
37    #[default]
38    ApiKey,
39    /// OAuth 2.0 browser-based login.
40    #[serde(rename = "oauth")]
41    OAuth,
42}
43
44impl std::fmt::Display for AuthMethod {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        match self {
47            AuthMethod::ApiKey => write!(f, "api_key"),
48            AuthMethod::OAuth => write!(f, "oauth"),
49        }
50    }
51}
52
/// OAuth 2.0 provider configuration.
///
/// `Debug` is implemented by hand so that `client_secret` never ends up in
/// logs or panic messages: a configured secret is rendered as
/// `Some("<redacted>")`, an absent one as `None`.
#[derive(Clone)]
pub struct OAuthProviderConfig {
    /// Internal provider name (e.g., "openai", "google").
    pub provider_name: String,
    /// OAuth client ID.
    pub client_id: String,
    /// OAuth client secret (required by confidential clients like Slack, Discord, Teams).
    /// Public clients (e.g., OpenAI PKCE-only) leave this as `None`.
    pub client_secret: Option<String>,
    /// Authorization endpoint URL.
    pub authorization_url: String,
    /// Token exchange endpoint URL.
    pub token_url: String,
    /// Requested scopes.
    pub scopes: Vec<String>,
    /// Optional audience parameter (used by OpenAI).
    pub audience: Option<String>,
    /// Whether the provider supports device code flow (for headless environments).
    pub supports_device_code: bool,
    /// Device code endpoint URL (if supported).
    pub device_code_url: Option<String>,
    /// Extra query parameters to include in the authorization URL.
    pub extra_auth_params: Vec<(String, String)>,
}

impl std::fmt::Debug for OAuthProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthProviderConfig")
            .field("provider_name", &self.provider_name)
            .field("client_id", &self.client_id)
            // Only reveal whether a secret is configured, never its value.
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("authorization_url", &self.authorization_url)
            .field("token_url", &self.token_url)
            .field("scopes", &self.scopes)
            .field("audience", &self.audience)
            .field("supports_device_code", &self.supports_device_code)
            .field("device_code_url", &self.device_code_url)
            .field("extra_auth_params", &self.extra_auth_params)
            .finish()
    }
}
78
79/// Stored OAuth token data.
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct OAuthToken {
82    /// The access token used for API requests.
83    pub access_token: String,
84    /// Optional refresh token for obtaining new access tokens.
85    pub refresh_token: Option<String>,
86    /// Optional ID token (OpenID Connect).
87    #[serde(default, skip_serializing_if = "Option::is_none")]
88    pub id_token: Option<String>,
89    /// When the access token expires (if known).
90    pub expires_at: Option<DateTime<Utc>>,
91    /// Token type (usually "Bearer").
92    pub token_type: String,
93    /// Scopes granted by the authorization server.
94    pub scopes: Vec<String>,
95}
96
/// The PKCE secret and its derived challenge for one authorization attempt.
///
/// `challenge` travels with the authorization request; `verifier` stays local
/// until the authorization code is exchanged for tokens.
struct PkcePair {
    /// Base64url-encoded SHA-256 digest of `verifier` (the S256 method).
    challenge: String,
    /// Random secret string proving possession at token-exchange time.
    verifier: String,
}
102
/// Query parameters captured from the provider's redirect to `/auth/callback`.
///
/// A field holds an empty string when the redirect omitted that parameter.
struct CallbackData {
    /// Value of the `state` parameter, checked against the CSRF token.
    state: String,
    /// Value of the `code` parameter (the authorization code).
    code: String,
}
108
109// ── PKCE ────────────────────────────────────────────────────────────────────
110
111/// Generate a PKCE code verifier and S256 challenge.
112///
113/// The verifier is a random 43-character string using unreserved URI characters.
114/// The challenge is the base64url-encoded SHA-256 hash of the verifier.
115fn generate_pkce_pair() -> PkcePair {
116    let mut rng = rand::thread_rng();
117    let verifier: String = (0..43)
118        .map(|_| {
119            const CHARSET: &[u8] =
120                b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
121            let idx = rng.gen_range(0..CHARSET.len());
122            CHARSET[idx] as char
123        })
124        .collect();
125
126    let mut hasher = Sha256::new();
127    hasher.update(verifier.as_bytes());
128    let digest = hasher.finalize();
129    let challenge = URL_SAFE_NO_PAD.encode(digest);
130
131    PkcePair {
132        verifier,
133        challenge,
134    }
135}
136
137/// Generate a random state parameter for CSRF protection.
138fn generate_state() -> String {
139    let mut rng = rand::thread_rng();
140    let bytes: [u8; 32] = rng.r#gen();
141    URL_SAFE_NO_PAD.encode(bytes)
142}
143
144// ── Callback Server ─────────────────────────────────────────────────────────
145
/// Fixed local port on which the OAuth callback server listens.
///
/// Some providers (Slack, for example) only accept a redirect URI that exactly
/// matches one registered in the app settings, so the port cannot be random.
/// With this value the default HTTPS redirect URI is always
/// `https://localhost:8844/auth/callback`, which can be registered in advance.
pub const OAUTH_CALLBACK_PORT: u16 = 8_844;
152
153/// Build the axum router used by the OAuth callback server.
154fn build_callback_router(
155    tx: std::sync::Arc<tokio::sync::Mutex<Option<oneshot::Sender<CallbackData>>>>,
156) -> axum::Router {
157    axum::Router::new().route(
158        "/auth/callback",
159        axum::routing::get({
160            let tx = tx.clone();
161            move |query: axum::extract::Query<HashMap<String, String>>| {
162                let tx = tx.clone();
163                async move {
164                    let code = query.get("code").cloned().unwrap_or_default();
165                    let state = query.get("state").cloned().unwrap_or_default();
166
167                    if let Some(sender) = tx.lock().await.take() {
168                        let _ = sender.send(CallbackData { code, state });
169                    }
170
171                    axum::response::Html(
172                        r#"<!DOCTYPE html>
173<html>
174<head><title>Rustant</title></head>
175<body style="font-family: system-ui; text-align: center; padding-top: 80px;">
176<h2>Authentication successful!</h2>
177<p>You can close this tab and return to the terminal.</p>
178</body>
179</html>"#,
180                    )
181                }
182            }
183        }),
184    )
185}
186
187/// Load TLS config for the OAuth callback server.
188///
189/// Tries the following in order:
190/// 1. `mkcert`-generated certs in `~/.rustant/certs/` (browser-trusted)
191/// 2. Falls back to a self-signed cert generated at runtime via `rcgen`
192///    (the browser will show a warning on first redirect)
193///
194/// To generate trusted certs, run:
195/// ```sh
196/// mkcert -install            # installs the root CA (needs sudo)
197/// mkdir -p ~/.rustant/certs
198/// mkcert -cert-file ~/.rustant/certs/localhost.pem \
199///        -key-file ~/.rustant/certs/localhost-key.pem \
200///        localhost 127.0.0.1
201/// ```
202async fn load_tls_config() -> Result<axum_server::tls_rustls::RustlsConfig, LlmError> {
203    // Check for mkcert certs first.
204    if let Some(home) = directories::BaseDirs::new() {
205        let cert_dir = home.home_dir().join(".rustant").join("certs");
206        let cert_path = cert_dir.join("localhost.pem");
207        let key_path = cert_dir.join("localhost-key.pem");
208
209        if cert_path.exists() && key_path.exists() {
210            info!("Using mkcert certificates from {}", cert_dir.display());
211            return axum_server::tls_rustls::RustlsConfig::from_pem_file(cert_path, key_path)
212                .await
213                .map_err(|e| LlmError::OAuthFailed {
214                    message: format!("Failed to load mkcert certificates: {}", e),
215                });
216        }
217    }
218
219    // Fall back to self-signed cert.
220    info!(
221        "No mkcert certs found in ~/.rustant/certs/. Generating self-signed certificate.\n\
222         Your browser may show a security warning. To avoid this, run:\n  \
223         mkcert -install && mkdir -p ~/.rustant/certs && \
224         mkcert -cert-file ~/.rustant/certs/localhost.pem \
225         -key-file ~/.rustant/certs/localhost-key.pem localhost 127.0.0.1"
226    );
227
228    use rcgen::CertifiedKey;
229    let subject_alt_names = vec!["localhost".to_string(), "127.0.0.1".to_string()];
230    let CertifiedKey { cert, key_pair } = rcgen::generate_simple_self_signed(subject_alt_names)
231        .map_err(|e| LlmError::OAuthFailed {
232            message: format!("Failed to generate self-signed certificate: {}", e),
233        })?;
234
235    let cert_pem = cert.pem();
236    let key_pem = key_pair.serialize_pem();
237
238    axum_server::tls_rustls::RustlsConfig::from_pem(cert_pem.into_bytes(), key_pem.into_bytes())
239        .await
240        .map_err(|e| LlmError::OAuthFailed {
241            message: format!("Failed to build TLS config: {}", e),
242        })
243}
244
245/// Start a local callback server on the fixed port.
246///
247/// When `use_tls` is true, the server runs HTTPS with a self-signed certificate
248/// (for providers like Slack that require HTTPS redirect URIs). When false, it
249/// runs plain HTTP (suitable for OpenAI and other providers that accept HTTP on
250/// localhost).
251///
252/// Returns the server's port and a receiver that will yield the callback data.
253async fn start_callback_server(
254    use_tls: bool,
255) -> Result<(u16, oneshot::Receiver<CallbackData>), LlmError> {
256    let (tx, rx) = oneshot::channel::<CallbackData>();
257    let tx = std::sync::Arc::new(tokio::sync::Mutex::new(Some(tx)));
258    let app = build_callback_router(tx);
259
260    let bind_addr = format!("127.0.0.1:{}", OAUTH_CALLBACK_PORT);
261
262    if use_tls {
263        // Ensure the rustls CryptoProvider is installed (idempotent).
264        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
265
266        let tls_config = load_tls_config().await?;
267
268        let addr: SocketAddr = bind_addr.parse().map_err(|e| LlmError::OAuthFailed {
269            message: format!("Invalid bind address: {}", e),
270        })?;
271
272        debug!(
273            port = OAUTH_CALLBACK_PORT,
274            "OAuth HTTPS callback server starting"
275        );
276
277        tokio::spawn(async move {
278            let server = axum_server::bind_rustls(addr, tls_config).serve(app.into_make_service());
279            let _ = tokio::time::timeout(std::time::Duration::from_secs(120), server).await;
280        });
281
282        // Give the TLS server a moment to bind.
283        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
284    } else {
285        let listener = tokio::net::TcpListener::bind(&bind_addr)
286            .await
287            .map_err(|e| LlmError::OAuthFailed {
288                message: format!(
289                    "Failed to bind callback server on port {}: {}. \
290                     Make sure no other process is using this port.",
291                    OAUTH_CALLBACK_PORT, e
292                ),
293            })?;
294
295        debug!(
296            port = OAUTH_CALLBACK_PORT,
297            "OAuth HTTP callback server starting"
298        );
299
300        tokio::spawn(async move {
301            let server = axum::serve(listener, app);
302            let _ = tokio::time::timeout(std::time::Duration::from_secs(120), server.into_future())
303                .await;
304        });
305    }
306
307    Ok((OAUTH_CALLBACK_PORT, rx))
308}
309
310// ── Browser Flow ────────────────────────────────────────────────────────────
311
312/// Run the full OAuth 2.0 Authorization Code flow with PKCE.
313///
314/// 1. Generate PKCE pair and state
315/// 2. Start local callback server
316/// 3. Build authorization URL and open the user's browser
317/// 4. Wait for the callback with the authorization code
318/// 5. Exchange the code for tokens
319///
320/// Returns the obtained `OAuthToken` on success.
321///
322/// If `redirect_uri_override` is `Some`, that URI is sent to the OAuth provider
323/// instead of the default. When the redirect URI starts with `https://`, the
324/// local callback server will use TLS with a self-signed certificate; otherwise
325/// it runs plain HTTP.
326///
327/// Channel providers that require HTTPS (e.g. Slack) will automatically get a
328/// TLS-enabled callback server via the `https://localhost:8844/auth/callback`
329/// default.
330pub async fn authorize_browser_flow(
331    config: &OAuthProviderConfig,
332    redirect_uri_override: Option<&str>,
333) -> Result<OAuthToken, LlmError> {
334    let pkce = generate_pkce_pair();
335    let state = generate_state();
336
337    // Determine the redirect URI and whether we need TLS.
338    // Channel providers (Slack, Discord, etc.) require HTTPS; LLM providers
339    // (OpenAI, Google) typically accept HTTP on localhost.
340    let is_channel_provider = matches!(
341        config.provider_name.as_str(),
342        "slack" | "discord" | "teams" | "whatsapp" | "gmail"
343    );
344
345    let use_tls = match redirect_uri_override {
346        Some(uri) => uri.starts_with("https://"),
347        None => is_channel_provider,
348    };
349
350    // Start callback server (HTTP or HTTPS depending on use_tls).
351    let (port, rx) = start_callback_server(use_tls).await?;
352
353    let redirect_uri = match redirect_uri_override {
354        Some(uri) => uri.to_string(),
355        None => {
356            let scheme = if use_tls { "https" } else { "http" };
357            format!("{}://localhost:{}/auth/callback", scheme, port)
358        }
359    };
360
361    // Build authorization URL.
362    let mut auth_url =
363        url::Url::parse(&config.authorization_url).map_err(|e| LlmError::OAuthFailed {
364            message: format!("Invalid authorization URL: {}", e),
365        })?;
366
367    {
368        let mut params = auth_url.query_pairs_mut();
369        params.append_pair("response_type", "code");
370        params.append_pair("client_id", &config.client_id);
371        params.append_pair("redirect_uri", &redirect_uri);
372        params.append_pair("code_challenge", &pkce.challenge);
373        params.append_pair("code_challenge_method", "S256");
374        params.append_pair("state", &state);
375
376        if !config.scopes.is_empty() {
377            params.append_pair("scope", &config.scopes.join(" "));
378        }
379        if let Some(ref audience) = config.audience {
380            params.append_pair("audience", audience);
381        }
382        for (key, value) in &config.extra_auth_params {
383            params.append_pair(key, value);
384        }
385    }
386
387    info!("Opening browser for OAuth authorization...");
388    debug!(url = %auth_url, "Authorization URL");
389    open::that(auth_url.as_str()).map_err(|e| LlmError::OAuthFailed {
390        message: format!("Failed to open browser: {}", e),
391    })?;
392
393    // Wait for the callback.
394    let callback = tokio::time::timeout(std::time::Duration::from_secs(120), rx)
395        .await
396        .map_err(|_| LlmError::OAuthFailed {
397            message: "OAuth callback timed out after 120 seconds".to_string(),
398        })?
399        .map_err(|_| LlmError::OAuthFailed {
400            message: "OAuth callback channel closed unexpectedly".to_string(),
401        })?;
402
403    // Verify state parameter.
404    if callback.state != state {
405        return Err(LlmError::OAuthFailed {
406            message: "OAuth state parameter mismatch (possible CSRF attack)".to_string(),
407        });
408    }
409
410    if callback.code.is_empty() {
411        return Err(LlmError::OAuthFailed {
412            message: "OAuth callback did not contain an authorization code".to_string(),
413        });
414    }
415
416    // Exchange authorization code for tokens.
417    let mut token =
418        exchange_code_for_token(config, &callback.code, &pkce.verifier, &redirect_uri).await?;
419
420    // For OpenAI: try to exchange the ID token for a Platform API key.
421    // This succeeds for accounts with Platform org/project setup. For Personal/
422    // ChatGPT-only accounts it may fail — in that case we fall back to using the
423    // OAuth access token directly as a Bearer token (same as Codex CLI).
424    if config.provider_name == "openai"
425        && let Some(ref id_tok) = token.id_token
426    {
427        if let Some(payload) = id_tok.split('.').nth(1)
428            && let Ok(bytes) = URL_SAFE_NO_PAD.decode(payload)
429            && let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&bytes)
430        {
431            debug!(claims = %claims, "ID token claims");
432        }
433        debug!("Exchanging ID token for OpenAI API key...");
434        match obtain_openai_api_key(config, id_tok).await {
435            Ok(api_key) => {
436                info!("Obtained OpenAI Platform API key via token exchange");
437                token.access_token = api_key;
438            }
439            Err(e) => {
440                // The token exchange typically fails for accounts without a
441                // Platform API organization. The standard Chat Completions
442                // endpoint requires a Platform API key — the raw OAuth
443                // access token won't work.
444                return Err(LlmError::OAuthFailed {
445                    message: format!(
446                        "Failed to exchange OAuth token for an OpenAI API key: {}\n\n\
447                             This usually means your OpenAI account does not have \
448                             Platform API access set up.\n\n\
449                             To fix this:\n\
450                             1. Visit https://platform.openai.com to create an API organization\n\
451                             2. Ensure you have a billing method or active subscription\n\
452                             3. Run 'rustant auth login openai' again\n\n\
453                             Alternatively, use a standard API key:\n\
454                             1. Get your key from https://platform.openai.com/api-keys\n\
455                             2. Set the OPENAI_API_KEY environment variable\n\
456                             3. Set auth_method to empty in .rustant/config.toml",
457                        e
458                    ),
459                });
460            }
461        }
462    }
463
464    Ok(token)
465}
466
467/// Exchange an authorization code for an access token.
468async fn exchange_code_for_token(
469    config: &OAuthProviderConfig,
470    code: &str,
471    code_verifier: &str,
472    redirect_uri: &str,
473) -> Result<OAuthToken, LlmError> {
474    let client = reqwest::Client::new();
475
476    // Build the body exactly like the Codex CLI: using urlencoding::encode()
477    // with format!() and .body() for consistent percent-encoding.
478    let mut body = format!(
479        "grant_type={}&code={}&redirect_uri={}&client_id={}&code_verifier={}",
480        urlencoding::encode("authorization_code"),
481        urlencoding::encode(code),
482        urlencoding::encode(redirect_uri),
483        urlencoding::encode(&config.client_id),
484        urlencoding::encode(code_verifier),
485    );
486
487    // Confidential clients (Slack, Discord, Teams, etc.) require a client_secret.
488    if let Some(ref secret) = config.client_secret {
489        body.push_str(&format!("&client_secret={}", urlencoding::encode(secret)));
490    }
491
492    debug!(provider = %config.provider_name, "Exchanging authorization code for token");
493
494    let response = client
495        .post(&config.token_url)
496        .header("Content-Type", "application/x-www-form-urlencoded")
497        .body(body)
498        .send()
499        .await
500        .map_err(|e| LlmError::OAuthFailed {
501            message: format!("Token exchange request failed: {}", e),
502        })?;
503
504    let status = response.status();
505    let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
506        message: format!("Failed to read token response: {}", e),
507    })?;
508
509    if !status.is_success() {
510        return Err(LlmError::OAuthFailed {
511            message: format!("Token exchange failed (HTTP {}): {}", status, body_text),
512        });
513    }
514
515    parse_token_response(&body_text)
516}
517
518/// Exchange an OpenAI ID token for an actual OpenAI API key via the
519/// RFC 8693 token-exchange grant type.
520///
521/// This is the second step of the OpenAI Codex OAuth flow: after the standard
522/// PKCE code exchange, the ID token is exchanged for a usable API key.
523///
524/// Uses manual URL-encoded body construction (matching Codex CLI) instead of
525/// `reqwest .form()` to avoid potential double-encoding of the JWT ID token.
526async fn obtain_openai_api_key(
527    config: &OAuthProviderConfig,
528    id_token: &str,
529) -> Result<String, LlmError> {
530    let client = reqwest::Client::new();
531
532    // Build the body exactly like the Codex CLI: using urlencoding::encode()
533    // with format!() and .body(). This ensures identical percent-encoding
534    // behavior (RFC 3986 unreserved chars, %20 for spaces).
535    let body = format!(
536        "grant_type={}&client_id={}&requested_token={}&subject_token={}&subject_token_type={}",
537        urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
538        urlencoding::encode(&config.client_id),
539        urlencoding::encode("openai-api-key"),
540        urlencoding::encode(id_token),
541        urlencoding::encode("urn:ietf:params:oauth:token-type:id_token"),
542    );
543
544    debug!(body_len = body.len(), "Token exchange request body");
545
546    let response = client
547        .post(&config.token_url)
548        .header("Content-Type", "application/x-www-form-urlencoded")
549        .body(body)
550        .send()
551        .await
552        .map_err(|e| LlmError::OAuthFailed {
553            message: format!("API key exchange request failed: {}", e),
554        })?;
555
556    let status = response.status();
557    let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
558        message: format!("Failed to read API key exchange response: {}", e),
559    })?;
560
561    if !status.is_success() {
562        return Err(LlmError::OAuthFailed {
563            message: format!("API key exchange failed (HTTP {}): {}", status, body_text),
564        });
565    }
566
567    let json: serde_json::Value =
568        serde_json::from_str(&body_text).map_err(|e| LlmError::OAuthFailed {
569            message: format!("Invalid JSON in API key exchange response: {}", e),
570        })?;
571
572    json["access_token"]
573        .as_str()
574        .map(|s| s.to_string())
575        .ok_or_else(|| LlmError::OAuthFailed {
576            message: "API key exchange response missing 'access_token'".to_string(),
577        })
578}
579
580/// Parse a token endpoint response into an `OAuthToken`.
581fn parse_token_response(body: &str) -> Result<OAuthToken, LlmError> {
582    let json: serde_json::Value =
583        serde_json::from_str(body).map_err(|e| LlmError::OAuthFailed {
584            message: format!("Invalid JSON in token response: {}", e),
585        })?;
586
587    let access_token = json["access_token"]
588        .as_str()
589        .ok_or_else(|| LlmError::OAuthFailed {
590            message: "Token response missing 'access_token'".to_string(),
591        })?
592        .to_string();
593
594    let refresh_token = json["refresh_token"].as_str().map(|s| s.to_string());
595    let id_token = json["id_token"].as_str().map(|s| s.to_string());
596    let token_type = json["token_type"].as_str().unwrap_or("Bearer").to_string();
597
598    let expires_at = json["expires_in"]
599        .as_u64()
600        .map(|secs| Utc::now() + chrono::Duration::seconds(secs as i64));
601
602    let scopes = json["scope"]
603        .as_str()
604        .map(|s| s.split_whitespace().map(|s| s.to_string()).collect())
605        .unwrap_or_default();
606
607    Ok(OAuthToken {
608        access_token,
609        refresh_token,
610        id_token,
611        expires_at,
612        token_type,
613        scopes,
614    })
615}
616
617// ── Device Code Flow ────────────────────────────────────────────────────────
618
619/// Run the OAuth 2.0 Device Code flow for headless environments.
620///
621/// 1. Request a device code from the provider
622/// 2. Display the user code and verification URI
623/// 3. Poll the token endpoint until the user completes authorization
624///
625/// Returns the obtained `OAuthToken` on success.
626pub async fn authorize_device_code_flow(
627    config: &OAuthProviderConfig,
628) -> Result<OAuthToken, LlmError> {
629    let device_code_url =
630        config
631            .device_code_url
632            .as_deref()
633            .ok_or_else(|| LlmError::OAuthFailed {
634                message: format!(
635                    "Provider '{}' does not support device code flow",
636                    config.provider_name
637                ),
638            })?;
639
640    let client = reqwest::Client::new();
641
642    // Step 1: Request device code.
643    let mut params = HashMap::new();
644    params.insert("client_id", config.client_id.as_str());
645    if !config.scopes.is_empty() {
646        let scope_str = config.scopes.join(" ");
647        params.insert("scope", Box::leak(scope_str.into_boxed_str()));
648    }
649    if let Some(ref audience) = config.audience {
650        params.insert("audience", audience.as_str());
651    }
652
653    let response = client
654        .post(device_code_url)
655        .form(&params)
656        .send()
657        .await
658        .map_err(|e| LlmError::OAuthFailed {
659            message: format!("Device code request failed: {}", e),
660        })?;
661
662    let status = response.status();
663    let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
664        message: format!("Failed to read device code response: {}", e),
665    })?;
666
667    if !status.is_success() {
668        return Err(LlmError::OAuthFailed {
669            message: format!(
670                "Device code request failed (HTTP {}): {}",
671                status, body_text
672            ),
673        });
674    }
675
676    let json: serde_json::Value =
677        serde_json::from_str(&body_text).map_err(|e| LlmError::OAuthFailed {
678            message: format!("Invalid JSON in device code response: {}", e),
679        })?;
680
681    let device_code = json["device_code"]
682        .as_str()
683        .ok_or_else(|| LlmError::OAuthFailed {
684            message: "Device code response missing 'device_code'".to_string(),
685        })?;
686    let user_code = json["user_code"]
687        .as_str()
688        .ok_or_else(|| LlmError::OAuthFailed {
689            message: "Device code response missing 'user_code'".to_string(),
690        })?;
691    let verification_uri = json["verification_uri"]
692        .as_str()
693        .or_else(|| json["verification_url"].as_str())
694        .ok_or_else(|| LlmError::OAuthFailed {
695            message: "Device code response missing 'verification_uri'".to_string(),
696        })?;
697    let interval = json["interval"].as_u64().unwrap_or(5);
698    let expires_in = json["expires_in"].as_u64().unwrap_or(600);
699
700    // Step 2: Display instructions.
701    println!();
702    println!("  To authenticate, visit: {}", verification_uri);
703    println!("  Enter this code: {}", user_code);
704    println!();
705    println!("  Waiting for authorization...");
706
707    // Step 3: Poll for token.
708    let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(expires_in);
709    let poll_interval = std::time::Duration::from_secs(interval);
710
711    loop {
712        tokio::time::sleep(poll_interval).await;
713
714        if tokio::time::Instant::now() >= deadline {
715            return Err(LlmError::OAuthFailed {
716                message: "Device code flow timed out waiting for authorization".to_string(),
717            });
718        }
719
720        let mut poll_params = HashMap::new();
721        poll_params.insert("grant_type", "urn:ietf:params:oauth:grant-type:device_code");
722        poll_params.insert("device_code", device_code);
723        poll_params.insert("client_id", &config.client_id);
724
725        let poll_response = client
726            .post(&config.token_url)
727            .form(&poll_params)
728            .send()
729            .await
730            .map_err(|e| LlmError::OAuthFailed {
731                message: format!("Token poll request failed: {}", e),
732            })?;
733
734        let poll_status = poll_response.status();
735        let poll_body = poll_response
736            .text()
737            .await
738            .map_err(|e| LlmError::OAuthFailed {
739                message: format!("Failed to read token poll response: {}", e),
740            })?;
741
742        if poll_status.is_success() {
743            return parse_token_response(&poll_body);
744        }
745
746        // Check for "authorization_pending" or "slow_down" errors.
747        if let Ok(err_json) = serde_json::from_str::<serde_json::Value>(&poll_body) {
748            let error = err_json["error"].as_str().unwrap_or("");
749            match error {
750                "authorization_pending" => {
751                    debug!("Device code flow: authorization pending, polling again...");
752                    continue;
753                }
754                "slow_down" => {
755                    debug!("Device code flow: slow down requested");
756                    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
757                    continue;
758                }
759                "expired_token" => {
760                    return Err(LlmError::OAuthFailed {
761                        message: "Device code expired. Please try again.".to_string(),
762                    });
763                }
764                "access_denied" => {
765                    return Err(LlmError::OAuthFailed {
766                        message: "Authorization was denied by the user.".to_string(),
767                    });
768                }
769                _ => {
770                    return Err(LlmError::OAuthFailed {
771                        message: format!("Token poll error: {}", poll_body),
772                    });
773                }
774            }
775        }
776
777        // Non-JSON error response.
778        return Err(LlmError::OAuthFailed {
779            message: format!("Token poll failed (HTTP {}): {}", poll_status, poll_body),
780        });
781    }
782}
783
784// ── Token Refresh ───────────────────────────────────────────────────────────
785
786/// Refresh an OAuth access token using a refresh token.
787pub async fn refresh_token(
788    config: &OAuthProviderConfig,
789    refresh_token_str: &str,
790) -> Result<OAuthToken, LlmError> {
791    let client = reqwest::Client::new();
792
793    let mut params = HashMap::new();
794    params.insert("grant_type", "refresh_token");
795    params.insert("refresh_token", refresh_token_str);
796    params.insert("client_id", &config.client_id);
797
798    debug!(provider = %config.provider_name, "Refreshing OAuth token");
799
800    let response = client
801        .post(&config.token_url)
802        .form(&params)
803        .send()
804        .await
805        .map_err(|e| LlmError::OAuthFailed {
806            message: format!("Token refresh request failed: {}", e),
807        })?;
808
809    let status = response.status();
810    let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
811        message: format!("Failed to read token refresh response: {}", e),
812    })?;
813
814    if !status.is_success() {
815        return Err(LlmError::OAuthFailed {
816            message: format!("Token refresh failed (HTTP {}): {}", status, body_text),
817        });
818    }
819
820    let mut token = parse_token_response(&body_text)?;
821
822    // Some providers don't return a new refresh_token; preserve the old one.
823    if token.refresh_token.is_none() {
824        token.refresh_token = Some(refresh_token_str.to_string());
825    }
826
827    Ok(token)
828}
829
830// ── Token Expiration ────────────────────────────────────────────────────────
831
832/// Check whether an OAuth token has expired (with a 5-minute safety buffer).
833pub fn is_token_expired(token: &OAuthToken) -> bool {
834    match token.expires_at {
835        Some(expires_at) => {
836            let buffer = chrono::Duration::minutes(5);
837            Utc::now() >= (expires_at - buffer)
838        }
839        // No expiration info — assume it's still valid.
840        None => false,
841    }
842}
843
844// ── Token Storage ───────────────────────────────────────────────────────────
845
846/// Store an OAuth token in the credential store.
847///
848/// The token is serialized as JSON and stored under the key `oauth:{provider}`.
849pub fn store_oauth_token(
850    store: &dyn CredentialStore,
851    provider: &str,
852    token: &OAuthToken,
853) -> Result<(), LlmError> {
854    let key = format!("oauth:{}", provider);
855    let json = serde_json::to_string(token).map_err(|e| LlmError::OAuthFailed {
856        message: format!("Failed to serialize OAuth token: {}", e),
857    })?;
858    store
859        .store_key(&key, &json)
860        .map_err(|e| LlmError::OAuthFailed {
861            message: format!("Failed to store OAuth token: {}", e),
862        })
863}
864
865/// Load an OAuth token from the credential store.
866pub fn load_oauth_token(
867    store: &dyn CredentialStore,
868    provider: &str,
869) -> Result<OAuthToken, LlmError> {
870    let key = format!("oauth:{}", provider);
871    let json = store.get_key(&key).map_err(|e| match e {
872        CredentialError::NotFound { .. } => LlmError::OAuthFailed {
873            message: format!("No OAuth token found for provider '{}'", provider),
874        },
875        other => LlmError::OAuthFailed {
876            message: format!("Failed to load OAuth token: {}", other),
877        },
878    })?;
879    serde_json::from_str(&json).map_err(|e| LlmError::OAuthFailed {
880        message: format!("Failed to deserialize OAuth token: {}", e),
881    })
882}
883
884/// Delete an OAuth token from the credential store.
885pub fn delete_oauth_token(store: &dyn CredentialStore, provider: &str) -> Result<(), LlmError> {
886    let key = format!("oauth:{}", provider);
887    store.delete_key(&key).map_err(|e| LlmError::OAuthFailed {
888        message: format!("Failed to delete OAuth token: {}", e),
889    })
890}
891
892/// Check whether an OAuth token exists in the credential store.
893pub fn has_oauth_token(store: &dyn CredentialStore, provider: &str) -> bool {
894    let key = format!("oauth:{}", provider);
895    store.has_key(&key)
896}
897
898// ── Provider Configs ────────────────────────────────────────────────────────
899
900/// OAuth configuration for OpenAI.
901///
902/// Uses the Codex public client ID. Supports both browser and device code flows.
903pub fn openai_oauth_config() -> OAuthProviderConfig {
904    OAuthProviderConfig {
905        provider_name: "openai".to_string(),
906        client_id: "app_EMoamEEZ73f0CkXaXp7hrann".to_string(),
907        client_secret: None, // public PKCE client
908        authorization_url: "https://auth.openai.com/oauth/authorize".to_string(),
909        token_url: "https://auth.openai.com/oauth/token".to_string(),
910        scopes: vec![
911            "openid".to_string(),
912            "profile".to_string(),
913            "email".to_string(),
914            "offline_access".to_string(),
915        ],
916        audience: None,
917        supports_device_code: true,
918        device_code_url: Some("https://auth.openai.com/oauth/device/code".to_string()),
919        extra_auth_params: vec![
920            ("id_token_add_organizations".to_string(), "true".to_string()),
921            ("codex_cli_simplified_flow".to_string(), "true".to_string()),
922            ("originator".to_string(), "codex_cli_rs".to_string()),
923        ],
924    }
925}
926
927/// OAuth configuration for Google (Gemini).
928///
929/// Requires a GCP OAuth client ID configured via the `GOOGLE_OAUTH_CLIENT_ID`
930/// environment variable. Users must create an OAuth 2.0 client in the GCP Console
931/// (application type: Desktop) with the Generative Language API scope enabled.
932pub fn google_oauth_config() -> Option<OAuthProviderConfig> {
933    let client_id = std::env::var("GOOGLE_OAUTH_CLIENT_ID").ok()?;
934    let client_secret = std::env::var("GOOGLE_OAUTH_CLIENT_SECRET").ok();
935    Some(OAuthProviderConfig {
936        provider_name: "google".to_string(),
937        client_id,
938        client_secret,
939        authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
940        token_url: "https://oauth2.googleapis.com/token".to_string(),
941        scopes: vec!["https://www.googleapis.com/auth/generative-language".to_string()],
942        audience: None,
943        supports_device_code: false,
944        device_code_url: None,
945        extra_auth_params: vec![],
946    })
947}
948
949/// OAuth configuration for Anthropic.
950///
951/// Currently returns `None` because Anthropic has blocked third-party tools from
952/// using their OAuth endpoints as of January 2026. When/if Anthropic opens a
953/// third-party OAuth program, this function will be updated to return a config.
954pub fn anthropic_oauth_config() -> Option<OAuthProviderConfig> {
955    // Anthropic OAuth is not available for third-party tools.
956    // The infrastructure is ready; only a valid client_id is needed.
957    None
958}
959
960// ── Channel OAuth Configs ──────────────────────────────────────────────────
961
962/// OAuth configuration for Slack.
963///
964/// Uses Slack's OAuth 2.0 V2 flow with bot scopes for channel messaging,
965/// history reading, and user info. Requires a Slack App client ID.
966pub fn slack_oauth_config(client_id: &str, client_secret: Option<String>) -> OAuthProviderConfig {
967    OAuthProviderConfig {
968        provider_name: "slack".to_string(),
969        client_id: client_id.to_string(),
970        client_secret,
971        authorization_url: "https://slack.com/oauth/v2/authorize".to_string(),
972        token_url: "https://slack.com/api/oauth.v2.access".to_string(),
973        scopes: vec![
974            "chat:write".to_string(),
975            "channels:history".to_string(),
976            "channels:read".to_string(),
977            "users:read".to_string(),
978        ],
979        audience: None,
980        supports_device_code: false,
981        device_code_url: None,
982        extra_auth_params: vec![],
983    }
984}
985
986/// OAuth configuration for Discord.
987///
988/// Uses Discord's OAuth 2.0 flow with bot scope for messaging and reading.
989/// Requires a Discord Application client ID.
990pub fn discord_oauth_config(client_id: &str, client_secret: Option<String>) -> OAuthProviderConfig {
991    OAuthProviderConfig {
992        provider_name: "discord".to_string(),
993        client_id: client_id.to_string(),
994        client_secret,
995        authorization_url: "https://discord.com/api/oauth2/authorize".to_string(),
996        token_url: "https://discord.com/api/oauth2/token".to_string(),
997        scopes: vec!["bot".to_string(), "messages.read".to_string()],
998        audience: None,
999        supports_device_code: false,
1000        device_code_url: None,
1001        extra_auth_params: vec![],
1002    }
1003}
1004
1005/// OAuth configuration for Microsoft Teams via Azure AD.
1006///
1007/// Uses Azure AD's OAuth 2.0 flow with Microsoft Graph scopes.
1008/// The `tenant_id` can be "common" for multi-tenant apps or a specific
1009/// Azure AD tenant ID. Teams bots typically use the client credentials
1010/// grant (server-to-server), but this config also supports the authorization
1011/// code flow for user-delegated access.
1012pub fn teams_oauth_config(
1013    client_id: &str,
1014    tenant_id: &str,
1015    client_secret: Option<String>,
1016) -> OAuthProviderConfig {
1017    OAuthProviderConfig {
1018        provider_name: "teams".to_string(),
1019        client_id: client_id.to_string(),
1020        client_secret,
1021        authorization_url: format!(
1022            "https://login.microsoftonline.com/{}/oauth2/v2.0/authorize",
1023            tenant_id
1024        ),
1025        token_url: format!(
1026            "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
1027            tenant_id
1028        ),
1029        scopes: vec!["https://graph.microsoft.com/.default".to_string()],
1030        audience: None,
1031        supports_device_code: true,
1032        device_code_url: Some(format!(
1033            "https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode",
1034            tenant_id
1035        )),
1036        extra_auth_params: vec![],
1037    }
1038}
1039
1040/// OAuth configuration for WhatsApp via Meta Business Platform.
1041///
1042/// Uses Meta's OAuth 2.0 flow for WhatsApp Business API access.
1043/// Requires a Meta App ID as the client ID.
1044pub fn whatsapp_oauth_config(app_id: &str, app_secret: Option<String>) -> OAuthProviderConfig {
1045    OAuthProviderConfig {
1046        provider_name: "whatsapp".to_string(),
1047        client_id: app_id.to_string(),
1048        client_secret: app_secret,
1049        authorization_url: "https://www.facebook.com/v18.0/dialog/oauth".to_string(),
1050        token_url: "https://graph.facebook.com/v18.0/oauth/access_token".to_string(),
1051        scopes: vec![
1052            "whatsapp_business_messaging".to_string(),
1053            "whatsapp_business_management".to_string(),
1054        ],
1055        audience: None,
1056        supports_device_code: false,
1057        device_code_url: None,
1058        extra_auth_params: vec![],
1059    }
1060}
1061
1062/// OAuth configuration for Gmail (IMAP/SMTP with XOAUTH2).
1063///
1064/// Reuses Google's OAuth 2.0 endpoints with the Gmail-specific scope for
1065/// full mailbox access via IMAP and SMTP XOAUTH2 SASL authentication.
1066/// Requires a GCP OAuth client ID.
1067pub fn gmail_oauth_config(client_id: &str, client_secret: Option<String>) -> OAuthProviderConfig {
1068    OAuthProviderConfig {
1069        provider_name: "gmail".to_string(),
1070        client_id: client_id.to_string(),
1071        client_secret,
1072        authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1073        token_url: "https://oauth2.googleapis.com/token".to_string(),
1074        scopes: vec!["https://mail.google.com/".to_string()],
1075        audience: None,
1076        supports_device_code: false,
1077        device_code_url: None,
1078        extra_auth_params: vec![
1079            ("access_type".to_string(), "offline".to_string()),
1080            ("prompt".to_string(), "consent".to_string()),
1081        ],
1082    }
1083}
1084
1085// ── Client Credentials Flow ────────────────────────────────────────────────
1086
1087/// Run the OAuth 2.0 Client Credentials flow (server-to-server).
1088///
1089/// This flow is used by services like Microsoft Teams bots that authenticate
1090/// as the application itself rather than a user. It requires both the client ID
1091/// (in the config) and a client secret.
1092///
1093/// Returns an `OAuthToken` with an access token and expiration.
1094pub async fn authorize_client_credentials_flow(
1095    config: &OAuthProviderConfig,
1096    client_secret: &str,
1097) -> Result<OAuthToken, LlmError> {
1098    let client = reqwest::Client::new();
1099
1100    // Use the explicit parameter, falling back to config.client_secret if empty.
1101    let secret = if client_secret.is_empty() {
1102        config.client_secret.as_deref().unwrap_or("")
1103    } else {
1104        client_secret
1105    };
1106
1107    let body = format!(
1108        "grant_type={}&client_id={}&client_secret={}&scope={}",
1109        urlencoding::encode("client_credentials"),
1110        urlencoding::encode(&config.client_id),
1111        urlencoding::encode(secret),
1112        urlencoding::encode(&config.scopes.join(" ")),
1113    );
1114
1115    debug!(provider = %config.provider_name, "Requesting client credentials token");
1116
1117    let response = client
1118        .post(&config.token_url)
1119        .header("Content-Type", "application/x-www-form-urlencoded")
1120        .body(body)
1121        .send()
1122        .await
1123        .map_err(|e| LlmError::OAuthFailed {
1124            message: format!("Client credentials request failed: {}", e),
1125        })?;
1126
1127    let status = response.status();
1128    let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
1129        message: format!("Failed to read client credentials response: {}", e),
1130    })?;
1131
1132    if !status.is_success() {
1133        return Err(LlmError::OAuthFailed {
1134            message: format!(
1135                "Client credentials token request failed (HTTP {}): {}",
1136                status, body_text
1137            ),
1138        });
1139    }
1140
1141    parse_token_response(&body_text)
1142}
1143
/// Build the raw (not yet base64-encoded) XOAUTH2 SASL string for IMAP/SMTP.
///
/// Layout: `user=<email>\x01auth=Bearer <token>\x01\x01`. The fields are
/// separated by a `\x01` (Ctrl-A) byte, and the string ends with two of them.
/// Gmail and other providers that support XOAUTH2 use this format.
pub fn build_xoauth2_token(email: &str, access_token: &str) -> String {
    [
        "user=",
        email,
        "\x01auth=Bearer ",
        access_token,
        "\x01\x01",
    ]
    .concat()
}
1151
1152/// Base64-encode an XOAUTH2 token for SASL AUTH.
1153pub fn build_xoauth2_token_base64(email: &str, access_token: &str) -> String {
1154    use base64::Engine;
1155    let raw = build_xoauth2_token(email, access_token);
1156    base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
1157}
1158
1159/// Look up the OAuth configuration for a provider by name.
1160///
1161/// Returns `None` if the provider does not support OAuth or if required
1162/// environment variables (e.g., `GOOGLE_OAUTH_CLIENT_ID`) are not set.
1163///
1164/// For channel providers (slack, discord, teams, whatsapp, gmail), the relevant
1165/// client ID / app ID environment variables must be set.
1166pub fn oauth_config_for_provider(provider: &str) -> Option<OAuthProviderConfig> {
1167    match provider {
1168        "openai" => Some(openai_oauth_config()),
1169        "gemini" | "google" => google_oauth_config(),
1170        "anthropic" => anthropic_oauth_config(),
1171        "slack" => {
1172            let client_id = std::env::var("SLACK_CLIENT_ID").ok()?;
1173            let client_secret = std::env::var("SLACK_CLIENT_SECRET").ok();
1174            Some(slack_oauth_config(&client_id, client_secret))
1175        }
1176        "discord" => {
1177            let client_id = std::env::var("DISCORD_CLIENT_ID").ok()?;
1178            let client_secret = std::env::var("DISCORD_CLIENT_SECRET").ok();
1179            Some(discord_oauth_config(&client_id, client_secret))
1180        }
1181        "teams" => {
1182            let client_id = std::env::var("TEAMS_CLIENT_ID").ok()?;
1183            let tenant_id =
1184                std::env::var("TEAMS_TENANT_ID").unwrap_or_else(|_| "common".to_string());
1185            let client_secret = std::env::var("TEAMS_CLIENT_SECRET").ok();
1186            Some(teams_oauth_config(&client_id, &tenant_id, client_secret))
1187        }
1188        "whatsapp" => {
1189            let app_id = std::env::var("WHATSAPP_APP_ID").ok()?;
1190            let app_secret = std::env::var("WHATSAPP_APP_SECRET").ok();
1191            Some(whatsapp_oauth_config(&app_id, app_secret))
1192        }
1193        "gmail" => {
1194            let client_id = std::env::var("GMAIL_OAUTH_CLIENT_ID")
1195                .or_else(|_| std::env::var("GOOGLE_OAUTH_CLIENT_ID"))
1196                .ok()?;
1197            let client_secret = std::env::var("GMAIL_OAUTH_CLIENT_SECRET")
1198                .or_else(|_| std::env::var("GOOGLE_OAUTH_CLIENT_SECRET"))
1199                .ok();
1200            Some(gmail_oauth_config(&client_id, client_secret))
1201        }
1202        _ => None,
1203    }
1204}
1205
1206/// Build an OAuth configuration using directly-provided credentials.
1207///
1208/// Unlike [`oauth_config_for_provider`] which reads client credentials from
1209/// environment variables, this function accepts them as parameters. This is
1210/// used by the interactive `channel setup` wizard where the user enters
1211/// credentials at a prompt rather than setting env vars.
1212pub fn oauth_config_with_credentials(
1213    provider: &str,
1214    client_id: &str,
1215    client_secret: Option<&str>,
1216) -> Option<OAuthProviderConfig> {
1217    let secret = client_secret.map(String::from);
1218    match provider {
1219        "slack" => Some(slack_oauth_config(client_id, secret)),
1220        "discord" => Some(discord_oauth_config(client_id, secret)),
1221        "gmail" => Some(gmail_oauth_config(client_id, secret)),
1222        _ => None,
1223    }
1224}
1225
1226/// Check whether a provider supports OAuth login.
1227pub fn provider_supports_oauth(provider: &str) -> bool {
1228    match provider {
1229        "openai" => true,
1230        "gemini" | "google" => std::env::var("GOOGLE_OAUTH_CLIENT_ID").is_ok(),
1231        "slack" => std::env::var("SLACK_CLIENT_ID").is_ok(),
1232        "discord" => std::env::var("DISCORD_CLIENT_ID").is_ok(),
1233        "teams" => std::env::var("TEAMS_CLIENT_ID").is_ok(),
1234        "whatsapp" => std::env::var("WHATSAPP_APP_ID").is_ok(),
1235        "gmail" => {
1236            std::env::var("GMAIL_OAUTH_CLIENT_ID").is_ok()
1237                || std::env::var("GOOGLE_OAUTH_CLIENT_ID").is_ok()
1238        }
1239        _ => false,
1240    }
1241}
1242
1243// ── Tests ───────────────────────────────────────────────────────────────────
1244
1245#[cfg(test)]
1246mod tests {
1247    use super::*;
1248    use crate::credentials::InMemoryCredentialStore;
1249
1250    #[test]
1251    fn test_generate_pkce_pair() {
1252        let pair = generate_pkce_pair();
1253        assert_eq!(pair.verifier.len(), 43);
1254        assert!(!pair.challenge.is_empty());
1255
1256        // Verify the challenge is a valid base64url-encoded SHA-256 hash.
1257        let decoded = URL_SAFE_NO_PAD.decode(&pair.challenge).unwrap();
1258        assert_eq!(decoded.len(), 32); // SHA-256 produces 32 bytes
1259
1260        // Verify the challenge matches the verifier.
1261        let mut hasher = Sha256::new();
1262        hasher.update(pair.verifier.as_bytes());
1263        let expected = hasher.finalize();
1264        assert_eq!(decoded, expected.as_slice());
1265    }
1266
1267    #[test]
1268    fn test_generate_pkce_pair_uniqueness() {
1269        let pair1 = generate_pkce_pair();
1270        let pair2 = generate_pkce_pair();
1271        assert_ne!(pair1.verifier, pair2.verifier);
1272        assert_ne!(pair1.challenge, pair2.challenge);
1273    }
1274
1275    #[test]
1276    fn test_generate_state() {
1277        let state = generate_state();
1278        assert!(!state.is_empty());
1279        // base64url of 32 bytes = 43 characters
1280        assert_eq!(state.len(), 43);
1281    }
1282
1283    #[test]
1284    fn test_generate_state_uniqueness() {
1285        let s1 = generate_state();
1286        let s2 = generate_state();
1287        assert_ne!(s1, s2);
1288    }
1289
1290    #[test]
1291    fn test_parse_token_response_full() {
1292        let body = serde_json::json!({
1293            "access_token": "at-12345",
1294            "refresh_token": "rt-67890",
1295            "token_type": "Bearer",
1296            "expires_in": 3600,
1297            "scope": "openai.public"
1298        })
1299        .to_string();
1300
1301        let token = parse_token_response(&body).unwrap();
1302        assert_eq!(token.access_token, "at-12345");
1303        assert_eq!(token.refresh_token, Some("rt-67890".to_string()));
1304        assert_eq!(token.token_type, "Bearer");
1305        assert!(token.expires_at.is_some());
1306        assert_eq!(token.scopes, vec!["openai.public"]);
1307    }
1308
1309    #[test]
1310    fn test_parse_token_response_minimal() {
1311        let body = serde_json::json!({
1312            "access_token": "at-minimal"
1313        })
1314        .to_string();
1315
1316        let token = parse_token_response(&body).unwrap();
1317        assert_eq!(token.access_token, "at-minimal");
1318        assert!(token.refresh_token.is_none());
1319        assert_eq!(token.token_type, "Bearer");
1320        assert!(token.expires_at.is_none());
1321        assert!(token.scopes.is_empty());
1322    }
1323
1324    #[test]
1325    fn test_parse_token_response_missing_access_token() {
1326        let body = serde_json::json!({
1327            "token_type": "Bearer"
1328        })
1329        .to_string();
1330
1331        let result = parse_token_response(&body);
1332        assert!(result.is_err());
1333        match result.unwrap_err() {
1334            LlmError::OAuthFailed { message } => {
1335                assert!(message.contains("access_token"));
1336            }
1337            other => panic!("Expected OAuthFailed, got {:?}", other),
1338        }
1339    }
1340
1341    #[test]
1342    fn test_parse_token_response_invalid_json() {
1343        let result = parse_token_response("not json");
1344        assert!(result.is_err());
1345    }
1346
1347    #[test]
1348    fn test_is_token_expired_future() {
1349        let token = OAuthToken {
1350            access_token: "test".to_string(),
1351            refresh_token: None,
1352            id_token: None,
1353            expires_at: Some(Utc::now() + chrono::Duration::hours(1)),
1354            token_type: "Bearer".to_string(),
1355            scopes: vec![],
1356        };
1357        assert!(!is_token_expired(&token));
1358    }
1359
1360    #[test]
1361    fn test_is_token_expired_past() {
1362        let token = OAuthToken {
1363            access_token: "test".to_string(),
1364            refresh_token: None,
1365            id_token: None,
1366            expires_at: Some(Utc::now() - chrono::Duration::hours(1)),
1367            token_type: "Bearer".to_string(),
1368            scopes: vec![],
1369        };
1370        assert!(is_token_expired(&token));
1371    }
1372
1373    #[test]
1374    fn test_is_token_expired_within_buffer() {
1375        // Token expires in 3 minutes — within the 5-minute buffer.
1376        let token = OAuthToken {
1377            access_token: "test".to_string(),
1378            refresh_token: None,
1379            id_token: None,
1380            expires_at: Some(Utc::now() + chrono::Duration::minutes(3)),
1381            token_type: "Bearer".to_string(),
1382            scopes: vec![],
1383        };
1384        assert!(is_token_expired(&token));
1385    }
1386
1387    #[test]
1388    fn test_is_token_expired_no_expiry() {
1389        let token = OAuthToken {
1390            access_token: "test".to_string(),
1391            refresh_token: None,
1392            id_token: None,
1393            expires_at: None,
1394            token_type: "Bearer".to_string(),
1395            scopes: vec![],
1396        };
1397        assert!(!is_token_expired(&token));
1398    }
1399
1400    #[test]
1401    fn test_store_and_load_oauth_token() {
1402        let store = InMemoryCredentialStore::new();
1403        let token = OAuthToken {
1404            access_token: "at-test-store".to_string(),
1405            refresh_token: Some("rt-test-store".to_string()),
1406            id_token: None,
1407            expires_at: None,
1408            token_type: "Bearer".to_string(),
1409            scopes: vec!["openai.public".to_string()],
1410        };
1411
1412        store_oauth_token(&store, "openai", &token).unwrap();
1413        let loaded = load_oauth_token(&store, "openai").unwrap();
1414        assert_eq!(loaded.access_token, "at-test-store");
1415        assert_eq!(loaded.refresh_token, Some("rt-test-store".to_string()));
1416        assert_eq!(loaded.scopes, vec!["openai.public"]);
1417    }
1418
1419    #[test]
1420    fn test_load_oauth_token_not_found() {
1421        let store = InMemoryCredentialStore::new();
1422        let result = load_oauth_token(&store, "nonexistent");
1423        assert!(result.is_err());
1424    }
1425
1426    #[test]
1427    fn test_delete_oauth_token() {
1428        let store = InMemoryCredentialStore::new();
1429        let token = OAuthToken {
1430            access_token: "at-delete".to_string(),
1431            refresh_token: None,
1432            id_token: None,
1433            expires_at: None,
1434            token_type: "Bearer".to_string(),
1435            scopes: vec![],
1436        };
1437
1438        store_oauth_token(&store, "openai", &token).unwrap();
1439        assert!(has_oauth_token(&store, "openai"));
1440
1441        delete_oauth_token(&store, "openai").unwrap();
1442        assert!(!has_oauth_token(&store, "openai"));
1443    }
1444
1445    #[test]
1446    fn test_has_oauth_token() {
1447        let store = InMemoryCredentialStore::new();
1448        assert!(!has_oauth_token(&store, "openai"));
1449
1450        let token = OAuthToken {
1451            access_token: "at-has".to_string(),
1452            refresh_token: None,
1453            id_token: None,
1454            expires_at: None,
1455            token_type: "Bearer".to_string(),
1456            scopes: vec![],
1457        };
1458        store_oauth_token(&store, "openai", &token).unwrap();
1459        assert!(has_oauth_token(&store, "openai"));
1460    }
1461
1462    #[test]
1463    fn test_openai_oauth_config() {
1464        let config = openai_oauth_config();
1465        assert_eq!(config.provider_name, "openai");
1466        assert_eq!(config.client_id, "app_EMoamEEZ73f0CkXaXp7hrann");
1467        assert!(config.authorization_url.contains("auth.openai.com"));
1468        assert!(config.token_url.contains("auth.openai.com"));
1469        assert!(config.supports_device_code);
1470        assert!(config.device_code_url.is_some());
1471        assert_eq!(
1472            config.scopes,
1473            vec!["openid", "profile", "email", "offline_access"]
1474        );
1475        assert_eq!(config.audience, None);
1476        assert_eq!(config.extra_auth_params.len(), 3);
1477    }
1478
1479    #[test]
1480    fn test_anthropic_oauth_config_returns_none() {
1481        assert!(anthropic_oauth_config().is_none());
1482    }
1483
1484    #[test]
1485    fn test_oauth_config_for_provider() {
1486        assert!(oauth_config_for_provider("openai").is_some());
1487        assert!(oauth_config_for_provider("anthropic").is_none());
1488        assert!(oauth_config_for_provider("unknown").is_none());
1489    }
1490
1491    #[test]
1492    fn test_provider_supports_oauth() {
1493        assert!(provider_supports_oauth("openai"));
1494        assert!(!provider_supports_oauth("anthropic"));
1495        assert!(!provider_supports_oauth("unknown"));
1496    }
1497
1498    #[test]
1499    fn test_auth_method_serde() {
1500        let json = serde_json::to_string(&AuthMethod::OAuth).unwrap();
1501        assert_eq!(json, "\"oauth\"");
1502        let method: AuthMethod = serde_json::from_str("\"api_key\"").unwrap();
1503        assert_eq!(method, AuthMethod::ApiKey);
1504    }
1505
1506    #[test]
1507    fn test_auth_method_default() {
1508        assert_eq!(AuthMethod::default(), AuthMethod::ApiKey);
1509    }
1510
1511    #[test]
1512    fn test_auth_method_display() {
1513        assert_eq!(AuthMethod::ApiKey.to_string(), "api_key");
1514        assert_eq!(AuthMethod::OAuth.to_string(), "oauth");
1515    }
1516
1517    #[test]
1518    fn test_oauth_token_serde_roundtrip() {
1519        let token = OAuthToken {
1520            access_token: "at-roundtrip".to_string(),
1521            refresh_token: Some("rt-roundtrip".to_string()),
1522            id_token: None,
1523            expires_at: Some(Utc::now()),
1524            token_type: "Bearer".to_string(),
1525            scopes: vec!["scope1".to_string(), "scope2".to_string()],
1526        };
1527        let json = serde_json::to_string(&token).unwrap();
1528        let parsed: OAuthToken = serde_json::from_str(&json).unwrap();
1529        assert_eq!(parsed.access_token, token.access_token);
1530        assert_eq!(parsed.refresh_token, token.refresh_token);
1531        assert_eq!(parsed.scopes.len(), 2);
1532    }
1533
1534    #[tokio::test]
1535    async fn test_callback_server_http_receives_code() {
1536        let (port, rx) = start_callback_server(false).await.unwrap();
1537        assert_eq!(port, OAUTH_CALLBACK_PORT);
1538
1539        // Simulate the OAuth callback using plain HTTP.
1540        let client = reqwest::Client::new();
1541        let url = format!(
1542            "http://127.0.0.1:{}/auth/callback?code=test-http&state=test-state-http",
1543            port
1544        );
1545        let response = client.get(&url).send().await.unwrap();
1546        assert!(response.status().is_success());
1547
1548        let callback = rx.await.unwrap();
1549        assert_eq!(callback.code, "test-http");
1550        assert_eq!(callback.state, "test-state-http");
1551    }
1552
1553    #[tokio::test]
1554    async fn test_tls_config_loading() {
1555        // Verify we can load/generate a TLS config without errors.
1556        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
1557        let config = load_tls_config().await;
1558        assert!(config.is_ok(), "TLS config loading should succeed");
1559    }
1560
1561    // ── Channel OAuth Config Tests ─────────────────────────────────────────
1562
1563    #[test]
1564    fn test_slack_oauth_config() {
1565        let config = slack_oauth_config("slack-client-123", Some("slack-secret".into()));
1566        assert_eq!(config.provider_name, "slack");
1567        assert_eq!(config.client_id, "slack-client-123");
1568        assert!(
1569            config
1570                .authorization_url
1571                .contains("slack.com/oauth/v2/authorize")
1572        );
1573        assert!(config.token_url.contains("slack.com/api/oauth.v2.access"));
1574        assert!(config.scopes.contains(&"chat:write".to_string()));
1575        assert!(config.scopes.contains(&"channels:history".to_string()));
1576        assert!(config.scopes.contains(&"channels:read".to_string()));
1577        assert!(config.scopes.contains(&"users:read".to_string()));
1578        assert!(!config.supports_device_code);
1579    }
1580
1581    #[test]
1582    fn test_discord_oauth_config() {
1583        let config = discord_oauth_config("discord-client-456", Some("discord-secret".into()));
1584        assert_eq!(config.provider_name, "discord");
1585        assert_eq!(config.client_id, "discord-client-456");
1586        assert!(
1587            config
1588                .authorization_url
1589                .contains("discord.com/api/oauth2/authorize")
1590        );
1591        assert!(config.token_url.contains("discord.com/api/oauth2/token"));
1592        assert!(config.scopes.contains(&"bot".to_string()));
1593        assert!(config.scopes.contains(&"messages.read".to_string()));
1594        assert!(!config.supports_device_code);
1595    }
1596
1597    #[test]
1598    fn test_teams_oauth_config() {
1599        let config = teams_oauth_config(
1600            "teams-client-789",
1601            "my-tenant-id",
1602            Some("teams-secret".into()),
1603        );
1604        assert_eq!(config.provider_name, "teams");
1605        assert_eq!(config.client_id, "teams-client-789");
1606        assert!(
1607            config
1608                .authorization_url
1609                .contains("login.microsoftonline.com/my-tenant-id")
1610        );
1611        assert!(
1612            config
1613                .token_url
1614                .contains("login.microsoftonline.com/my-tenant-id")
1615        );
1616        assert!(
1617            config
1618                .scopes
1619                .contains(&"https://graph.microsoft.com/.default".to_string())
1620        );
1621        assert!(config.supports_device_code);
1622        assert!(
1623            config
1624                .device_code_url
1625                .as_ref()
1626                .unwrap()
1627                .contains("my-tenant-id")
1628        );
1629    }
1630
1631    #[test]
1632    fn test_teams_oauth_config_common_tenant() {
1633        let config = teams_oauth_config("teams-client", "common", None);
1634        assert!(
1635            config
1636                .authorization_url
1637                .contains("common/oauth2/v2.0/authorize")
1638        );
1639        assert!(config.token_url.contains("common/oauth2/v2.0/token"));
1640    }
1641
1642    #[test]
1643    fn test_whatsapp_oauth_config() {
1644        let config = whatsapp_oauth_config("meta-app-123", Some("meta-secret".into()));
1645        assert_eq!(config.provider_name, "whatsapp");
1646        assert_eq!(config.client_id, "meta-app-123");
1647        assert!(
1648            config
1649                .authorization_url
1650                .contains("facebook.com/v18.0/dialog/oauth")
1651        );
1652        assert!(
1653            config
1654                .token_url
1655                .contains("graph.facebook.com/v18.0/oauth/access_token")
1656        );
1657        assert!(
1658            config
1659                .scopes
1660                .contains(&"whatsapp_business_messaging".to_string())
1661        );
1662        assert!(
1663            config
1664                .scopes
1665                .contains(&"whatsapp_business_management".to_string())
1666        );
1667        assert!(!config.supports_device_code);
1668    }
1669
1670    #[test]
1671    fn test_gmail_oauth_config() {
1672        let config = gmail_oauth_config("gmail-client-id", Some("gmail-secret".into()));
1673        assert_eq!(config.provider_name, "gmail");
1674        assert_eq!(config.client_id, "gmail-client-id");
1675        assert!(config.authorization_url.contains("accounts.google.com"));
1676        assert!(config.token_url.contains("oauth2.googleapis.com"));
1677        assert!(
1678            config
1679                .scopes
1680                .contains(&"https://mail.google.com/".to_string())
1681        );
1682        // Gmail config should request offline access
1683        assert!(
1684            config
1685                .extra_auth_params
1686                .iter()
1687                .any(|(k, v)| k == "access_type" && v == "offline")
1688        );
1689    }
1690
1691    #[test]
1692    fn test_xoauth2_token_format() {
1693        let token = build_xoauth2_token("user@gmail.com", "ya29.access-token");
1694        assert_eq!(
1695            token,
1696            "user=user@gmail.com\x01auth=Bearer ya29.access-token\x01\x01"
1697        );
1698    }
1699
1700    #[test]
1701    fn test_xoauth2_token_base64() {
1702        let b64 = build_xoauth2_token_base64("user@gmail.com", "token123");
1703        // Should be valid base64
1704        let decoded = base64::engine::general_purpose::STANDARD
1705            .decode(&b64)
1706            .unwrap();
1707        let decoded_str = String::from_utf8(decoded).unwrap();
1708        assert!(decoded_str.starts_with("user=user@gmail.com\x01"));
1709        assert!(decoded_str.contains("auth=Bearer token123"));
1710    }
1711
1712    #[test]
1713    fn test_oauth_config_for_channel_providers_without_env() {
1714        // Without env vars set, channel providers should return None
1715        // (unless env vars happen to be set in CI)
1716        let _ = oauth_config_for_provider("slack");
1717        let _ = oauth_config_for_provider("discord");
1718        let _ = oauth_config_for_provider("teams");
1719        let _ = oauth_config_for_provider("whatsapp");
1720        let _ = oauth_config_for_provider("gmail");
1721        // Just verifying they don't panic
1722    }
1723
1724    #[test]
1725    fn test_store_and_load_channel_oauth_token() {
1726        let store = InMemoryCredentialStore::new();
1727        let token = OAuthToken {
1728            access_token: "xoxb-slack-token".to_string(),
1729            refresh_token: Some("xoxr-refresh".to_string()),
1730            id_token: None,
1731            expires_at: None,
1732            token_type: "Bearer".to_string(),
1733            scopes: vec!["chat:write".to_string(), "channels:history".to_string()],
1734        };
1735
1736        store_oauth_token(&store, "slack", &token).unwrap();
1737        let loaded = load_oauth_token(&store, "slack").unwrap();
1738        assert_eq!(loaded.access_token, "xoxb-slack-token");
1739        assert_eq!(loaded.scopes.len(), 2);
1740
1741        // Store a second provider token
1742        let teams_token = OAuthToken {
1743            access_token: "eyJ-teams-token".to_string(),
1744            refresh_token: None,
1745            id_token: None,
1746            expires_at: None,
1747            token_type: "Bearer".to_string(),
1748            scopes: vec!["https://graph.microsoft.com/.default".to_string()],
1749        };
1750        store_oauth_token(&store, "teams", &teams_token).unwrap();
1751        let loaded_teams = load_oauth_token(&store, "teams").unwrap();
1752        assert_eq!(loaded_teams.access_token, "eyJ-teams-token");
1753
1754        // Original slack token should still be there
1755        let loaded_slack = load_oauth_token(&store, "slack").unwrap();
1756        assert_eq!(loaded_slack.access_token, "xoxb-slack-token");
1757    }
1758}