1use base64::engine::general_purpose::URL_SAFE_NO_PAD;
16use base64::Engine;
17use chrono::{DateTime, Utc};
18use rand::Rng;
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21use std::collections::HashMap;
22use std::future::IntoFuture;
23use std::net::SocketAddr;
24use tokio::sync::oneshot;
25use tracing::{debug, info};
26
27use crate::credentials::{CredentialError, CredentialStore};
28use crate::error::LlmError;
29
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
34#[serde(rename_all = "snake_case")]
35pub enum AuthMethod {
36 #[default]
38 ApiKey,
39 #[serde(rename = "oauth")]
41 OAuth,
42}
43
44impl std::fmt::Display for AuthMethod {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 AuthMethod::ApiKey => write!(f, "api_key"),
48 AuthMethod::OAuth => write!(f, "oauth"),
49 }
50 }
51}
52
/// Static description of an OAuth provider: endpoints, client registration and
/// the parameters each flow in this module sends.
#[derive(Debug, Clone)]
pub struct OAuthProviderConfig {
    /// Short identifier such as `"openai"` or `"slack"`; some flows branch on it.
    pub provider_name: String,
    /// OAuth client identifier registered with the provider.
    pub client_id: String,
    /// Client secret, when the registered client has one.
    pub client_secret: Option<String>,
    /// Authorization endpoint the browser is sent to.
    pub authorization_url: String,
    /// Token endpoint (code exchange, refresh, device polling, client credentials).
    pub token_url: String,
    /// Requested scopes; joined with spaces when sent.
    pub scopes: Vec<String>,
    /// Optional `audience` parameter for authorization and device-code requests.
    pub audience: Option<String>,
    /// Whether the provider supports the device authorization grant.
    pub supports_device_code: bool,
    /// Device authorization endpoint used by `authorize_device_code_flow`.
    pub device_code_url: Option<String>,
    /// Additional query parameters appended to the authorization URL.
    pub extra_auth_params: Vec<(String, String)>,
}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct OAuthToken {
82 pub access_token: String,
84 pub refresh_token: Option<String>,
86 #[serde(default, skip_serializing_if = "Option::is_none")]
88 pub id_token: Option<String>,
89 pub expires_at: Option<DateTime<Utc>>,
91 pub token_type: String,
93 pub scopes: Vec<String>,
95}
96
/// PKCE (RFC 7636) verifier/challenge pair for a single authorization request.
struct PkcePair {
    /// Random string kept locally and sent during the code exchange.
    verifier: String,
    /// `BASE64URL(SHA256(verifier))`, sent in the authorization URL.
    challenge: String,
}
102
/// Query parameters captured from the browser redirect to `/auth/callback`.
struct CallbackData {
    /// Authorization code (empty string if the redirect carried none).
    code: String,
    /// Echoed `state` value used for CSRF checking (empty string if absent).
    state: String,
}
108
109fn generate_pkce_pair() -> PkcePair {
116 let mut rng = rand::thread_rng();
117 let verifier: String = (0..43)
118 .map(|_| {
119 const CHARSET: &[u8] =
120 b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
121 let idx = rng.gen_range(0..CHARSET.len());
122 CHARSET[idx] as char
123 })
124 .collect();
125
126 let mut hasher = Sha256::new();
127 hasher.update(verifier.as_bytes());
128 let digest = hasher.finalize();
129 let challenge = URL_SAFE_NO_PAD.encode(digest);
130
131 PkcePair {
132 verifier,
133 challenge,
134 }
135}
136
137fn generate_state() -> String {
139 let mut rng = rand::thread_rng();
140 let bytes: [u8; 32] = rng.gen();
141 URL_SAFE_NO_PAD.encode(bytes)
142}
143
/// Local port the OAuth redirect listener binds to on `127.0.0.1`.
///
/// Redirect URIs built by this module (`http(s)://localhost:<port>/auth/callback`)
/// point at this port.
pub const OAUTH_CALLBACK_PORT: u16 = 8_844;
152
153fn build_callback_router(
155 tx: std::sync::Arc<tokio::sync::Mutex<Option<oneshot::Sender<CallbackData>>>>,
156) -> axum::Router {
157 axum::Router::new().route(
158 "/auth/callback",
159 axum::routing::get({
160 let tx = tx.clone();
161 move |query: axum::extract::Query<HashMap<String, String>>| {
162 let tx = tx.clone();
163 async move {
164 let code = query.get("code").cloned().unwrap_or_default();
165 let state = query.get("state").cloned().unwrap_or_default();
166
167 if let Some(sender) = tx.lock().await.take() {
168 let _ = sender.send(CallbackData { code, state });
169 }
170
171 axum::response::Html(
172 r#"<!DOCTYPE html>
173<html>
174<head><title>Rustant</title></head>
175<body style="font-family: system-ui; text-align: center; padding-top: 80px;">
176<h2>Authentication successful!</h2>
177<p>You can close this tab and return to the terminal.</p>
178</body>
179</html>"#,
180 )
181 }
182 }
183 }),
184 )
185}
186
187async fn load_tls_config() -> Result<axum_server::tls_rustls::RustlsConfig, LlmError> {
203 if let Some(home) = directories::BaseDirs::new() {
205 let cert_dir = home.home_dir().join(".rustant").join("certs");
206 let cert_path = cert_dir.join("localhost.pem");
207 let key_path = cert_dir.join("localhost-key.pem");
208
209 if cert_path.exists() && key_path.exists() {
210 info!("Using mkcert certificates from {}", cert_dir.display());
211 return axum_server::tls_rustls::RustlsConfig::from_pem_file(cert_path, key_path)
212 .await
213 .map_err(|e| LlmError::OAuthFailed {
214 message: format!("Failed to load mkcert certificates: {}", e),
215 });
216 }
217 }
218
219 info!(
221 "No mkcert certs found in ~/.rustant/certs/. Generating self-signed certificate.\n\
222 Your browser may show a security warning. To avoid this, run:\n \
223 mkcert -install && mkdir -p ~/.rustant/certs && \
224 mkcert -cert-file ~/.rustant/certs/localhost.pem \
225 -key-file ~/.rustant/certs/localhost-key.pem localhost 127.0.0.1"
226 );
227
228 use rcgen::CertifiedKey;
229 let subject_alt_names = vec!["localhost".to_string(), "127.0.0.1".to_string()];
230 let CertifiedKey { cert, key_pair } = rcgen::generate_simple_self_signed(subject_alt_names)
231 .map_err(|e| LlmError::OAuthFailed {
232 message: format!("Failed to generate self-signed certificate: {}", e),
233 })?;
234
235 let cert_pem = cert.pem();
236 let key_pem = key_pair.serialize_pem();
237
238 axum_server::tls_rustls::RustlsConfig::from_pem(cert_pem.into_bytes(), key_pem.into_bytes())
239 .await
240 .map_err(|e| LlmError::OAuthFailed {
241 message: format!("Failed to build TLS config: {}", e),
242 })
243}
244
245async fn start_callback_server(
254 use_tls: bool,
255) -> Result<(u16, oneshot::Receiver<CallbackData>), LlmError> {
256 let (tx, rx) = oneshot::channel::<CallbackData>();
257 let tx = std::sync::Arc::new(tokio::sync::Mutex::new(Some(tx)));
258 let app = build_callback_router(tx);
259
260 let bind_addr = format!("127.0.0.1:{}", OAUTH_CALLBACK_PORT);
261
262 if use_tls {
263 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
265
266 let tls_config = load_tls_config().await?;
267
268 let addr: SocketAddr = bind_addr.parse().map_err(|e| LlmError::OAuthFailed {
269 message: format!("Invalid bind address: {}", e),
270 })?;
271
272 debug!(
273 port = OAUTH_CALLBACK_PORT,
274 "OAuth HTTPS callback server starting"
275 );
276
277 tokio::spawn(async move {
278 let server = axum_server::bind_rustls(addr, tls_config).serve(app.into_make_service());
279 let _ = tokio::time::timeout(std::time::Duration::from_secs(120), server).await;
280 });
281
282 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
284 } else {
285 let listener = tokio::net::TcpListener::bind(&bind_addr)
286 .await
287 .map_err(|e| LlmError::OAuthFailed {
288 message: format!(
289 "Failed to bind callback server on port {}: {}. \
290 Make sure no other process is using this port.",
291 OAUTH_CALLBACK_PORT, e
292 ),
293 })?;
294
295 debug!(
296 port = OAUTH_CALLBACK_PORT,
297 "OAuth HTTP callback server starting"
298 );
299
300 tokio::spawn(async move {
301 let server = axum::serve(listener, app);
302 let _ = tokio::time::timeout(std::time::Duration::from_secs(120), server.into_future())
303 .await;
304 });
305 }
306
307 Ok((OAUTH_CALLBACK_PORT, rx))
308}
309
310pub async fn authorize_browser_flow(
331 config: &OAuthProviderConfig,
332 redirect_uri_override: Option<&str>,
333) -> Result<OAuthToken, LlmError> {
334 let pkce = generate_pkce_pair();
335 let state = generate_state();
336
337 let is_channel_provider = matches!(
341 config.provider_name.as_str(),
342 "slack" | "discord" | "teams" | "whatsapp" | "gmail"
343 );
344
345 let use_tls = match redirect_uri_override {
346 Some(uri) => uri.starts_with("https://"),
347 None => is_channel_provider,
348 };
349
350 let (port, rx) = start_callback_server(use_tls).await?;
352
353 let redirect_uri = match redirect_uri_override {
354 Some(uri) => uri.to_string(),
355 None => {
356 let scheme = if use_tls { "https" } else { "http" };
357 format!("{}://localhost:{}/auth/callback", scheme, port)
358 }
359 };
360
361 let mut auth_url =
363 url::Url::parse(&config.authorization_url).map_err(|e| LlmError::OAuthFailed {
364 message: format!("Invalid authorization URL: {}", e),
365 })?;
366
367 {
368 let mut params = auth_url.query_pairs_mut();
369 params.append_pair("response_type", "code");
370 params.append_pair("client_id", &config.client_id);
371 params.append_pair("redirect_uri", &redirect_uri);
372 params.append_pair("code_challenge", &pkce.challenge);
373 params.append_pair("code_challenge_method", "S256");
374 params.append_pair("state", &state);
375
376 if !config.scopes.is_empty() {
377 params.append_pair("scope", &config.scopes.join(" "));
378 }
379 if let Some(ref audience) = config.audience {
380 params.append_pair("audience", audience);
381 }
382 for (key, value) in &config.extra_auth_params {
383 params.append_pair(key, value);
384 }
385 }
386
387 info!("Opening browser for OAuth authorization...");
388 debug!(url = %auth_url, "Authorization URL");
389 open::that(auth_url.as_str()).map_err(|e| LlmError::OAuthFailed {
390 message: format!("Failed to open browser: {}", e),
391 })?;
392
393 let callback = tokio::time::timeout(std::time::Duration::from_secs(120), rx)
395 .await
396 .map_err(|_| LlmError::OAuthFailed {
397 message: "OAuth callback timed out after 120 seconds".to_string(),
398 })?
399 .map_err(|_| LlmError::OAuthFailed {
400 message: "OAuth callback channel closed unexpectedly".to_string(),
401 })?;
402
403 if callback.state != state {
405 return Err(LlmError::OAuthFailed {
406 message: "OAuth state parameter mismatch (possible CSRF attack)".to_string(),
407 });
408 }
409
410 if callback.code.is_empty() {
411 return Err(LlmError::OAuthFailed {
412 message: "OAuth callback did not contain an authorization code".to_string(),
413 });
414 }
415
416 let mut token =
418 exchange_code_for_token(config, &callback.code, &pkce.verifier, &redirect_uri).await?;
419
420 if config.provider_name == "openai" {
425 if let Some(ref id_tok) = token.id_token {
426 if let Some(payload) = id_tok.split('.').nth(1) {
427 if let Ok(bytes) = URL_SAFE_NO_PAD.decode(payload) {
428 if let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&bytes) {
429 debug!(claims = %claims, "ID token claims");
430 }
431 }
432 }
433 debug!("Exchanging ID token for OpenAI API key...");
434 match obtain_openai_api_key(config, id_tok).await {
435 Ok(api_key) => {
436 info!("Obtained OpenAI Platform API key via token exchange");
437 token.access_token = api_key;
438 }
439 Err(e) => {
440 return Err(LlmError::OAuthFailed {
445 message: format!(
446 "Failed to exchange OAuth token for an OpenAI API key: {}\n\n\
447 This usually means your OpenAI account does not have \
448 Platform API access set up.\n\n\
449 To fix this:\n\
450 1. Visit https://platform.openai.com to create an API organization\n\
451 2. Ensure you have a billing method or active subscription\n\
452 3. Run 'rustant auth login openai' again\n\n\
453 Alternatively, use a standard API key:\n\
454 1. Get your key from https://platform.openai.com/api-keys\n\
455 2. Set the OPENAI_API_KEY environment variable\n\
456 3. Set auth_method to empty in .rustant/config.toml",
457 e
458 ),
459 });
460 }
461 }
462 }
463 }
464
465 Ok(token)
466}
467
468async fn exchange_code_for_token(
470 config: &OAuthProviderConfig,
471 code: &str,
472 code_verifier: &str,
473 redirect_uri: &str,
474) -> Result<OAuthToken, LlmError> {
475 let client = reqwest::Client::new();
476
477 let mut body = format!(
480 "grant_type={}&code={}&redirect_uri={}&client_id={}&code_verifier={}",
481 urlencoding::encode("authorization_code"),
482 urlencoding::encode(code),
483 urlencoding::encode(redirect_uri),
484 urlencoding::encode(&config.client_id),
485 urlencoding::encode(code_verifier),
486 );
487
488 if let Some(ref secret) = config.client_secret {
490 body.push_str(&format!("&client_secret={}", urlencoding::encode(secret)));
491 }
492
493 debug!(provider = %config.provider_name, "Exchanging authorization code for token");
494
495 let response = client
496 .post(&config.token_url)
497 .header("Content-Type", "application/x-www-form-urlencoded")
498 .body(body)
499 .send()
500 .await
501 .map_err(|e| LlmError::OAuthFailed {
502 message: format!("Token exchange request failed: {}", e),
503 })?;
504
505 let status = response.status();
506 let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
507 message: format!("Failed to read token response: {}", e),
508 })?;
509
510 if !status.is_success() {
511 return Err(LlmError::OAuthFailed {
512 message: format!("Token exchange failed (HTTP {}): {}", status, body_text),
513 });
514 }
515
516 parse_token_response(&body_text)
517}
518
519async fn obtain_openai_api_key(
528 config: &OAuthProviderConfig,
529 id_token: &str,
530) -> Result<String, LlmError> {
531 let client = reqwest::Client::new();
532
533 let body = format!(
537 "grant_type={}&client_id={}&requested_token={}&subject_token={}&subject_token_type={}",
538 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
539 urlencoding::encode(&config.client_id),
540 urlencoding::encode("openai-api-key"),
541 urlencoding::encode(id_token),
542 urlencoding::encode("urn:ietf:params:oauth:token-type:id_token"),
543 );
544
545 debug!(body_len = body.len(), "Token exchange request body");
546
547 let response = client
548 .post(&config.token_url)
549 .header("Content-Type", "application/x-www-form-urlencoded")
550 .body(body)
551 .send()
552 .await
553 .map_err(|e| LlmError::OAuthFailed {
554 message: format!("API key exchange request failed: {}", e),
555 })?;
556
557 let status = response.status();
558 let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
559 message: format!("Failed to read API key exchange response: {}", e),
560 })?;
561
562 if !status.is_success() {
563 return Err(LlmError::OAuthFailed {
564 message: format!("API key exchange failed (HTTP {}): {}", status, body_text),
565 });
566 }
567
568 let json: serde_json::Value =
569 serde_json::from_str(&body_text).map_err(|e| LlmError::OAuthFailed {
570 message: format!("Invalid JSON in API key exchange response: {}", e),
571 })?;
572
573 json["access_token"]
574 .as_str()
575 .map(|s| s.to_string())
576 .ok_or_else(|| LlmError::OAuthFailed {
577 message: "API key exchange response missing 'access_token'".to_string(),
578 })
579}
580
581fn parse_token_response(body: &str) -> Result<OAuthToken, LlmError> {
583 let json: serde_json::Value =
584 serde_json::from_str(body).map_err(|e| LlmError::OAuthFailed {
585 message: format!("Invalid JSON in token response: {}", e),
586 })?;
587
588 let access_token = json["access_token"]
589 .as_str()
590 .ok_or_else(|| LlmError::OAuthFailed {
591 message: "Token response missing 'access_token'".to_string(),
592 })?
593 .to_string();
594
595 let refresh_token = json["refresh_token"].as_str().map(|s| s.to_string());
596 let id_token = json["id_token"].as_str().map(|s| s.to_string());
597 let token_type = json["token_type"].as_str().unwrap_or("Bearer").to_string();
598
599 let expires_at = json["expires_in"]
600 .as_u64()
601 .map(|secs| Utc::now() + chrono::Duration::seconds(secs as i64));
602
603 let scopes = json["scope"]
604 .as_str()
605 .map(|s| s.split_whitespace().map(|s| s.to_string()).collect())
606 .unwrap_or_default();
607
608 Ok(OAuthToken {
609 access_token,
610 refresh_token,
611 id_token,
612 expires_at,
613 token_type,
614 scopes,
615 })
616}
617
618pub async fn authorize_device_code_flow(
628 config: &OAuthProviderConfig,
629) -> Result<OAuthToken, LlmError> {
630 let device_code_url =
631 config
632 .device_code_url
633 .as_deref()
634 .ok_or_else(|| LlmError::OAuthFailed {
635 message: format!(
636 "Provider '{}' does not support device code flow",
637 config.provider_name
638 ),
639 })?;
640
641 let client = reqwest::Client::new();
642
643 let mut params = HashMap::new();
645 params.insert("client_id", config.client_id.as_str());
646 if !config.scopes.is_empty() {
647 let scope_str = config.scopes.join(" ");
648 params.insert("scope", Box::leak(scope_str.into_boxed_str()));
649 }
650 if let Some(ref audience) = config.audience {
651 params.insert("audience", audience.as_str());
652 }
653
654 let response = client
655 .post(device_code_url)
656 .form(¶ms)
657 .send()
658 .await
659 .map_err(|e| LlmError::OAuthFailed {
660 message: format!("Device code request failed: {}", e),
661 })?;
662
663 let status = response.status();
664 let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
665 message: format!("Failed to read device code response: {}", e),
666 })?;
667
668 if !status.is_success() {
669 return Err(LlmError::OAuthFailed {
670 message: format!(
671 "Device code request failed (HTTP {}): {}",
672 status, body_text
673 ),
674 });
675 }
676
677 let json: serde_json::Value =
678 serde_json::from_str(&body_text).map_err(|e| LlmError::OAuthFailed {
679 message: format!("Invalid JSON in device code response: {}", e),
680 })?;
681
682 let device_code = json["device_code"]
683 .as_str()
684 .ok_or_else(|| LlmError::OAuthFailed {
685 message: "Device code response missing 'device_code'".to_string(),
686 })?;
687 let user_code = json["user_code"]
688 .as_str()
689 .ok_or_else(|| LlmError::OAuthFailed {
690 message: "Device code response missing 'user_code'".to_string(),
691 })?;
692 let verification_uri = json["verification_uri"]
693 .as_str()
694 .or_else(|| json["verification_url"].as_str())
695 .ok_or_else(|| LlmError::OAuthFailed {
696 message: "Device code response missing 'verification_uri'".to_string(),
697 })?;
698 let interval = json["interval"].as_u64().unwrap_or(5);
699 let expires_in = json["expires_in"].as_u64().unwrap_or(600);
700
701 println!();
703 println!(" To authenticate, visit: {}", verification_uri);
704 println!(" Enter this code: {}", user_code);
705 println!();
706 println!(" Waiting for authorization...");
707
708 let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(expires_in);
710 let poll_interval = std::time::Duration::from_secs(interval);
711
712 loop {
713 tokio::time::sleep(poll_interval).await;
714
715 if tokio::time::Instant::now() >= deadline {
716 return Err(LlmError::OAuthFailed {
717 message: "Device code flow timed out waiting for authorization".to_string(),
718 });
719 }
720
721 let mut poll_params = HashMap::new();
722 poll_params.insert("grant_type", "urn:ietf:params:oauth:grant-type:device_code");
723 poll_params.insert("device_code", device_code);
724 poll_params.insert("client_id", &config.client_id);
725
726 let poll_response = client
727 .post(&config.token_url)
728 .form(&poll_params)
729 .send()
730 .await
731 .map_err(|e| LlmError::OAuthFailed {
732 message: format!("Token poll request failed: {}", e),
733 })?;
734
735 let poll_status = poll_response.status();
736 let poll_body = poll_response
737 .text()
738 .await
739 .map_err(|e| LlmError::OAuthFailed {
740 message: format!("Failed to read token poll response: {}", e),
741 })?;
742
743 if poll_status.is_success() {
744 return parse_token_response(&poll_body);
745 }
746
747 if let Ok(err_json) = serde_json::from_str::<serde_json::Value>(&poll_body) {
749 let error = err_json["error"].as_str().unwrap_or("");
750 match error {
751 "authorization_pending" => {
752 debug!("Device code flow: authorization pending, polling again...");
753 continue;
754 }
755 "slow_down" => {
756 debug!("Device code flow: slow down requested");
757 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
758 continue;
759 }
760 "expired_token" => {
761 return Err(LlmError::OAuthFailed {
762 message: "Device code expired. Please try again.".to_string(),
763 });
764 }
765 "access_denied" => {
766 return Err(LlmError::OAuthFailed {
767 message: "Authorization was denied by the user.".to_string(),
768 });
769 }
770 _ => {
771 return Err(LlmError::OAuthFailed {
772 message: format!("Token poll error: {}", poll_body),
773 });
774 }
775 }
776 }
777
778 return Err(LlmError::OAuthFailed {
780 message: format!("Token poll failed (HTTP {}): {}", poll_status, poll_body),
781 });
782 }
783}
784
785pub async fn refresh_token(
789 config: &OAuthProviderConfig,
790 refresh_token_str: &str,
791) -> Result<OAuthToken, LlmError> {
792 let client = reqwest::Client::new();
793
794 let mut params = HashMap::new();
795 params.insert("grant_type", "refresh_token");
796 params.insert("refresh_token", refresh_token_str);
797 params.insert("client_id", &config.client_id);
798
799 debug!(provider = %config.provider_name, "Refreshing OAuth token");
800
801 let response = client
802 .post(&config.token_url)
803 .form(¶ms)
804 .send()
805 .await
806 .map_err(|e| LlmError::OAuthFailed {
807 message: format!("Token refresh request failed: {}", e),
808 })?;
809
810 let status = response.status();
811 let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
812 message: format!("Failed to read token refresh response: {}", e),
813 })?;
814
815 if !status.is_success() {
816 return Err(LlmError::OAuthFailed {
817 message: format!("Token refresh failed (HTTP {}): {}", status, body_text),
818 });
819 }
820
821 let mut token = parse_token_response(&body_text)?;
822
823 if token.refresh_token.is_none() {
825 token.refresh_token = Some(refresh_token_str.to_string());
826 }
827
828 Ok(token)
829}
830
831pub fn is_token_expired(token: &OAuthToken) -> bool {
835 match token.expires_at {
836 Some(expires_at) => {
837 let buffer = chrono::Duration::minutes(5);
838 Utc::now() >= (expires_at - buffer)
839 }
840 None => false,
842 }
843}
844
845pub fn store_oauth_token(
851 store: &dyn CredentialStore,
852 provider: &str,
853 token: &OAuthToken,
854) -> Result<(), LlmError> {
855 let key = format!("oauth:{}", provider);
856 let json = serde_json::to_string(token).map_err(|e| LlmError::OAuthFailed {
857 message: format!("Failed to serialize OAuth token: {}", e),
858 })?;
859 store
860 .store_key(&key, &json)
861 .map_err(|e| LlmError::OAuthFailed {
862 message: format!("Failed to store OAuth token: {}", e),
863 })
864}
865
866pub fn load_oauth_token(
868 store: &dyn CredentialStore,
869 provider: &str,
870) -> Result<OAuthToken, LlmError> {
871 let key = format!("oauth:{}", provider);
872 let json = store.get_key(&key).map_err(|e| match e {
873 CredentialError::NotFound { .. } => LlmError::OAuthFailed {
874 message: format!("No OAuth token found for provider '{}'", provider),
875 },
876 other => LlmError::OAuthFailed {
877 message: format!("Failed to load OAuth token: {}", other),
878 },
879 })?;
880 serde_json::from_str(&json).map_err(|e| LlmError::OAuthFailed {
881 message: format!("Failed to deserialize OAuth token: {}", e),
882 })
883}
884
885pub fn delete_oauth_token(store: &dyn CredentialStore, provider: &str) -> Result<(), LlmError> {
887 let key = format!("oauth:{}", provider);
888 store.delete_key(&key).map_err(|e| LlmError::OAuthFailed {
889 message: format!("Failed to delete OAuth token: {}", e),
890 })
891}
892
893pub fn has_oauth_token(store: &dyn CredentialStore, provider: &str) -> bool {
895 let key = format!("oauth:{}", provider);
896 store.has_key(&key)
897}
898
899pub fn openai_oauth_config() -> OAuthProviderConfig {
905 OAuthProviderConfig {
906 provider_name: "openai".to_string(),
907 client_id: "app_EMoamEEZ73f0CkXaXp7hrann".to_string(),
908 client_secret: None, authorization_url: "https://auth.openai.com/oauth/authorize".to_string(),
910 token_url: "https://auth.openai.com/oauth/token".to_string(),
911 scopes: vec![
912 "openid".to_string(),
913 "profile".to_string(),
914 "email".to_string(),
915 "offline_access".to_string(),
916 ],
917 audience: None,
918 supports_device_code: true,
919 device_code_url: Some("https://auth.openai.com/oauth/device/code".to_string()),
920 extra_auth_params: vec![
921 ("id_token_add_organizations".to_string(), "true".to_string()),
922 ("codex_cli_simplified_flow".to_string(), "true".to_string()),
923 ("originator".to_string(), "codex_cli_rs".to_string()),
924 ],
925 }
926}
927
928pub fn google_oauth_config() -> Option<OAuthProviderConfig> {
934 let client_id = std::env::var("GOOGLE_OAUTH_CLIENT_ID").ok()?;
935 let client_secret = std::env::var("GOOGLE_OAUTH_CLIENT_SECRET").ok();
936 Some(OAuthProviderConfig {
937 provider_name: "google".to_string(),
938 client_id,
939 client_secret,
940 authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
941 token_url: "https://oauth2.googleapis.com/token".to_string(),
942 scopes: vec!["https://www.googleapis.com/auth/generative-language".to_string()],
943 audience: None,
944 supports_device_code: false,
945 device_code_url: None,
946 extra_auth_params: vec![],
947 })
948}
949
950pub fn anthropic_oauth_config() -> Option<OAuthProviderConfig> {
956 None
959}
960
961pub fn slack_oauth_config(client_id: &str, client_secret: Option<String>) -> OAuthProviderConfig {
968 OAuthProviderConfig {
969 provider_name: "slack".to_string(),
970 client_id: client_id.to_string(),
971 client_secret,
972 authorization_url: "https://slack.com/oauth/v2/authorize".to_string(),
973 token_url: "https://slack.com/api/oauth.v2.access".to_string(),
974 scopes: vec![
975 "chat:write".to_string(),
976 "channels:history".to_string(),
977 "channels:read".to_string(),
978 "users:read".to_string(),
979 ],
980 audience: None,
981 supports_device_code: false,
982 device_code_url: None,
983 extra_auth_params: vec![],
984 }
985}
986
987pub fn discord_oauth_config(client_id: &str, client_secret: Option<String>) -> OAuthProviderConfig {
992 OAuthProviderConfig {
993 provider_name: "discord".to_string(),
994 client_id: client_id.to_string(),
995 client_secret,
996 authorization_url: "https://discord.com/api/oauth2/authorize".to_string(),
997 token_url: "https://discord.com/api/oauth2/token".to_string(),
998 scopes: vec!["bot".to_string(), "messages.read".to_string()],
999 audience: None,
1000 supports_device_code: false,
1001 device_code_url: None,
1002 extra_auth_params: vec![],
1003 }
1004}
1005
1006pub fn teams_oauth_config(
1014 client_id: &str,
1015 tenant_id: &str,
1016 client_secret: Option<String>,
1017) -> OAuthProviderConfig {
1018 OAuthProviderConfig {
1019 provider_name: "teams".to_string(),
1020 client_id: client_id.to_string(),
1021 client_secret,
1022 authorization_url: format!(
1023 "https://login.microsoftonline.com/{}/oauth2/v2.0/authorize",
1024 tenant_id
1025 ),
1026 token_url: format!(
1027 "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
1028 tenant_id
1029 ),
1030 scopes: vec!["https://graph.microsoft.com/.default".to_string()],
1031 audience: None,
1032 supports_device_code: true,
1033 device_code_url: Some(format!(
1034 "https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode",
1035 tenant_id
1036 )),
1037 extra_auth_params: vec![],
1038 }
1039}
1040
1041pub fn whatsapp_oauth_config(app_id: &str, app_secret: Option<String>) -> OAuthProviderConfig {
1046 OAuthProviderConfig {
1047 provider_name: "whatsapp".to_string(),
1048 client_id: app_id.to_string(),
1049 client_secret: app_secret,
1050 authorization_url: "https://www.facebook.com/v18.0/dialog/oauth".to_string(),
1051 token_url: "https://graph.facebook.com/v18.0/oauth/access_token".to_string(),
1052 scopes: vec![
1053 "whatsapp_business_messaging".to_string(),
1054 "whatsapp_business_management".to_string(),
1055 ],
1056 audience: None,
1057 supports_device_code: false,
1058 device_code_url: None,
1059 extra_auth_params: vec![],
1060 }
1061}
1062
1063pub fn gmail_oauth_config(client_id: &str, client_secret: Option<String>) -> OAuthProviderConfig {
1069 OAuthProviderConfig {
1070 provider_name: "gmail".to_string(),
1071 client_id: client_id.to_string(),
1072 client_secret,
1073 authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1074 token_url: "https://oauth2.googleapis.com/token".to_string(),
1075 scopes: vec!["https://mail.google.com/".to_string()],
1076 audience: None,
1077 supports_device_code: false,
1078 device_code_url: None,
1079 extra_auth_params: vec![
1080 ("access_type".to_string(), "offline".to_string()),
1081 ("prompt".to_string(), "consent".to_string()),
1082 ],
1083 }
1084}
1085
1086pub async fn authorize_client_credentials_flow(
1096 config: &OAuthProviderConfig,
1097 client_secret: &str,
1098) -> Result<OAuthToken, LlmError> {
1099 let client = reqwest::Client::new();
1100
1101 let secret = if client_secret.is_empty() {
1103 config.client_secret.as_deref().unwrap_or("")
1104 } else {
1105 client_secret
1106 };
1107
1108 let body = format!(
1109 "grant_type={}&client_id={}&client_secret={}&scope={}",
1110 urlencoding::encode("client_credentials"),
1111 urlencoding::encode(&config.client_id),
1112 urlencoding::encode(secret),
1113 urlencoding::encode(&config.scopes.join(" ")),
1114 );
1115
1116 debug!(provider = %config.provider_name, "Requesting client credentials token");
1117
1118 let response = client
1119 .post(&config.token_url)
1120 .header("Content-Type", "application/x-www-form-urlencoded")
1121 .body(body)
1122 .send()
1123 .await
1124 .map_err(|e| LlmError::OAuthFailed {
1125 message: format!("Client credentials request failed: {}", e),
1126 })?;
1127
1128 let status = response.status();
1129 let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
1130 message: format!("Failed to read client credentials response: {}", e),
1131 })?;
1132
1133 if !status.is_success() {
1134 return Err(LlmError::OAuthFailed {
1135 message: format!(
1136 "Client credentials token request failed (HTTP {}): {}",
1137 status, body_text
1138 ),
1139 });
1140 }
1141
1142 parse_token_response(&body_text)
1143}
1144
/// Builds the raw (not base64-encoded) SASL XOAUTH2 initial client response:
/// `user=<email>^Aauth=Bearer <token>^A^A`, where `^A` is the `\x01` byte.
pub fn build_xoauth2_token(email: &str, access_token: &str) -> String {
    [
        "user=",
        email,
        "\x01auth=Bearer ",
        access_token,
        "\x01\x01",
    ]
    .concat()
}
1152
1153pub fn build_xoauth2_token_base64(email: &str, access_token: &str) -> String {
1155 use base64::Engine;
1156 let raw = build_xoauth2_token(email, access_token);
1157 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
1158}
1159
1160pub fn oauth_config_for_provider(provider: &str) -> Option<OAuthProviderConfig> {
1168 match provider {
1169 "openai" => Some(openai_oauth_config()),
1170 "gemini" | "google" => google_oauth_config(),
1171 "anthropic" => anthropic_oauth_config(),
1172 "slack" => {
1173 let client_id = std::env::var("SLACK_CLIENT_ID").ok()?;
1174 let client_secret = std::env::var("SLACK_CLIENT_SECRET").ok();
1175 Some(slack_oauth_config(&client_id, client_secret))
1176 }
1177 "discord" => {
1178 let client_id = std::env::var("DISCORD_CLIENT_ID").ok()?;
1179 let client_secret = std::env::var("DISCORD_CLIENT_SECRET").ok();
1180 Some(discord_oauth_config(&client_id, client_secret))
1181 }
1182 "teams" => {
1183 let client_id = std::env::var("TEAMS_CLIENT_ID").ok()?;
1184 let tenant_id =
1185 std::env::var("TEAMS_TENANT_ID").unwrap_or_else(|_| "common".to_string());
1186 let client_secret = std::env::var("TEAMS_CLIENT_SECRET").ok();
1187 Some(teams_oauth_config(&client_id, &tenant_id, client_secret))
1188 }
1189 "whatsapp" => {
1190 let app_id = std::env::var("WHATSAPP_APP_ID").ok()?;
1191 let app_secret = std::env::var("WHATSAPP_APP_SECRET").ok();
1192 Some(whatsapp_oauth_config(&app_id, app_secret))
1193 }
1194 "gmail" => {
1195 let client_id = std::env::var("GMAIL_OAUTH_CLIENT_ID")
1196 .or_else(|_| std::env::var("GOOGLE_OAUTH_CLIENT_ID"))
1197 .ok()?;
1198 let client_secret = std::env::var("GMAIL_OAUTH_CLIENT_SECRET")
1199 .or_else(|_| std::env::var("GOOGLE_OAUTH_CLIENT_SECRET"))
1200 .ok();
1201 Some(gmail_oauth_config(&client_id, client_secret))
1202 }
1203 _ => None,
1204 }
1205}
1206
1207pub fn oauth_config_with_credentials(
1214 provider: &str,
1215 client_id: &str,
1216 client_secret: Option<&str>,
1217) -> Option<OAuthProviderConfig> {
1218 let secret = client_secret.map(String::from);
1219 match provider {
1220 "slack" => Some(slack_oauth_config(client_id, secret)),
1221 "discord" => Some(discord_oauth_config(client_id, secret)),
1222 "gmail" => Some(gmail_oauth_config(client_id, secret)),
1223 _ => None,
1224 }
1225}
1226
/// Reports whether OAuth login is available for `provider` in this environment.
///
/// OpenAI always qualifies. Other supported providers qualify only when their
/// client-ID environment variable is set. Gmail accepts either
/// `GMAIL_OAUTH_CLIENT_ID` or `GOOGLE_OAUTH_CLIENT_ID`.
pub fn provider_supports_oauth(provider: &str) -> bool {
    let env_set = |name: &str| std::env::var(name).is_ok();

    match provider {
        "openai" => true,
        "gemini" | "google" => env_set("GOOGLE_OAUTH_CLIENT_ID"),
        "slack" => env_set("SLACK_CLIENT_ID"),
        "discord" => env_set("DISCORD_CLIENT_ID"),
        "teams" => env_set("TEAMS_CLIENT_ID"),
        "whatsapp" => env_set("WHATSAPP_APP_ID"),
        "gmail" => ["GMAIL_OAUTH_CLIENT_ID", "GOOGLE_OAUTH_CLIENT_ID"]
            .iter()
            .any(|name| env_set(name)),
        _ => false,
    }
}
1243
#[cfg(test)]
mod tests {
    //! Unit tests for PKCE/state generation, token parsing and persistence,
    //! provider configuration builders and the local OAuth callback server.

    use super::*;
    use crate::credentials::InMemoryCredentialStore;

    /// Builds a Bearer token with no refresh token, no id token and no scopes.
    /// Tests override individual fields with struct-update syntax.
    fn bearer_token(access_token: &str, expires_at: Option<DateTime<Utc>>) -> OAuthToken {
        OAuthToken {
            access_token: access_token.to_string(),
            refresh_token: None,
            id_token: None,
            expires_at,
            token_type: "Bearer".to_string(),
            scopes: vec![],
        }
    }

    #[test]
    fn test_generate_pkce_pair() {
        let pair = generate_pkce_pair();
        // 43 is the minimum verifier length allowed by RFC 7636.
        assert_eq!(pair.verifier.len(), 43);
        assert!(!pair.challenge.is_empty());

        // The S256 challenge must be base64url(SHA-256(verifier)) without padding.
        let decoded = URL_SAFE_NO_PAD.decode(&pair.challenge).unwrap();
        assert_eq!(decoded.len(), 32);

        let mut hasher = Sha256::new();
        hasher.update(pair.verifier.as_bytes());
        let expected = hasher.finalize();
        assert_eq!(decoded, expected.as_slice());
    }

    #[test]
    fn test_generate_pkce_pair_uniqueness() {
        let pair1 = generate_pkce_pair();
        let pair2 = generate_pkce_pair();
        assert_ne!(pair1.verifier, pair2.verifier);
        assert_ne!(pair1.challenge, pair2.challenge);
    }

    #[test]
    fn test_generate_state() {
        let state = generate_state();
        assert!(!state.is_empty());
        // 43 chars is consistent with 32 random bytes as unpadded base64url
        // (presumably how `generate_state` encodes it).
        assert_eq!(state.len(), 43);
    }

    #[test]
    fn test_generate_state_uniqueness() {
        let s1 = generate_state();
        let s2 = generate_state();
        assert_ne!(s1, s2);
    }

    #[test]
    fn test_parse_token_response_full() {
        let body = serde_json::json!({
            "access_token": "at-12345",
            "refresh_token": "rt-67890",
            "token_type": "Bearer",
            "expires_in": 3600,
            "scope": "openai.public"
        })
        .to_string();

        let token = parse_token_response(&body).unwrap();
        assert_eq!(token.access_token, "at-12345");
        assert_eq!(token.refresh_token, Some("rt-67890".to_string()));
        assert_eq!(token.token_type, "Bearer");
        assert!(token.expires_at.is_some());
        assert_eq!(token.scopes, vec!["openai.public"]);
    }

    #[test]
    fn test_parse_token_response_minimal() {
        let body = serde_json::json!({
            "access_token": "at-minimal"
        })
        .to_string();

        let token = parse_token_response(&body).unwrap();
        assert_eq!(token.access_token, "at-minimal");
        assert!(token.refresh_token.is_none());
        // A missing `token_type` falls back to "Bearer".
        assert_eq!(token.token_type, "Bearer");
        assert!(token.expires_at.is_none());
        assert!(token.scopes.is_empty());
    }

    #[test]
    fn test_parse_token_response_missing_access_token() {
        let body = serde_json::json!({
            "token_type": "Bearer"
        })
        .to_string();

        let result = parse_token_response(&body);
        assert!(result.is_err());
        match result.unwrap_err() {
            LlmError::OAuthFailed { message } => {
                assert!(message.contains("access_token"));
            }
            other => panic!("Expected OAuthFailed, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_token_response_invalid_json() {
        let result = parse_token_response("not json");
        assert!(result.is_err());
    }

    #[test]
    fn test_is_token_expired_future() {
        let token = bearer_token("test", Some(Utc::now() + chrono::Duration::hours(1)));
        assert!(!is_token_expired(&token));
    }

    #[test]
    fn test_is_token_expired_past() {
        let token = bearer_token("test", Some(Utc::now() - chrono::Duration::hours(1)));
        assert!(is_token_expired(&token));
    }

    #[test]
    fn test_is_token_expired_within_buffer() {
        // Still valid for 3 more minutes, but `is_token_expired` must already
        // report it as expired. This pins an early-refresh safety buffer of at
        // least 3 minutes.
        let token = bearer_token("test", Some(Utc::now() + chrono::Duration::minutes(3)));
        assert!(is_token_expired(&token));
    }

    #[test]
    fn test_is_token_expired_no_expiry() {
        let token = bearer_token("test", None);
        assert!(!is_token_expired(&token));
    }

    #[test]
    fn test_store_and_load_oauth_token() {
        let store = InMemoryCredentialStore::new();
        let token = OAuthToken {
            refresh_token: Some("rt-test-store".to_string()),
            scopes: vec!["openai.public".to_string()],
            ..bearer_token("at-test-store", None)
        };

        store_oauth_token(&store, "openai", &token).unwrap();
        let loaded = load_oauth_token(&store, "openai").unwrap();
        assert_eq!(loaded.access_token, "at-test-store");
        assert_eq!(loaded.refresh_token, Some("rt-test-store".to_string()));
        assert_eq!(loaded.scopes, vec!["openai.public"]);
    }

    #[test]
    fn test_load_oauth_token_not_found() {
        let store = InMemoryCredentialStore::new();
        let result = load_oauth_token(&store, "nonexistent");
        assert!(result.is_err());
    }

    #[test]
    fn test_delete_oauth_token() {
        let store = InMemoryCredentialStore::new();
        let token = bearer_token("at-delete", None);

        store_oauth_token(&store, "openai", &token).unwrap();
        assert!(has_oauth_token(&store, "openai"));

        delete_oauth_token(&store, "openai").unwrap();
        assert!(!has_oauth_token(&store, "openai"));
    }

    #[test]
    fn test_has_oauth_token() {
        let store = InMemoryCredentialStore::new();
        assert!(!has_oauth_token(&store, "openai"));

        let token = bearer_token("at-has", None);
        store_oauth_token(&store, "openai", &token).unwrap();
        assert!(has_oauth_token(&store, "openai"));
    }

    #[test]
    fn test_openai_oauth_config() {
        let config = openai_oauth_config();
        assert_eq!(config.provider_name, "openai");
        assert_eq!(config.client_id, "app_EMoamEEZ73f0CkXaXp7hrann");
        assert!(config.authorization_url.contains("auth.openai.com"));
        assert!(config.token_url.contains("auth.openai.com"));
        assert!(config.supports_device_code);
        assert!(config.device_code_url.is_some());
        assert_eq!(
            config.scopes,
            vec!["openid", "profile", "email", "offline_access"]
        );
        assert_eq!(config.audience, None);
        assert_eq!(config.extra_auth_params.len(), 3);
    }

    #[test]
    fn test_anthropic_oauth_config_returns_none() {
        assert!(anthropic_oauth_config().is_none());
    }

    #[test]
    fn test_oauth_config_for_provider() {
        assert!(oauth_config_for_provider("openai").is_some());
        assert!(oauth_config_for_provider("anthropic").is_none());
        assert!(oauth_config_for_provider("unknown").is_none());
    }

    #[test]
    fn test_oauth_config_with_credentials() {
        let slack = oauth_config_with_credentials("slack", "cid", Some("secret")).unwrap();
        assert_eq!(slack.provider_name, "slack");
        assert_eq!(slack.client_id, "cid");

        let gmail = oauth_config_with_credentials("gmail", "gid", None).unwrap();
        assert_eq!(gmail.provider_name, "gmail");
        assert_eq!(gmail.client_id, "gid");

        let discord = oauth_config_with_credentials("discord", "did", None).unwrap();
        assert_eq!(discord.provider_name, "discord");

        assert!(oauth_config_with_credentials("unknown", "x", None).is_none());
    }

    #[test]
    fn test_provider_supports_oauth() {
        assert!(provider_supports_oauth("openai"));
        assert!(!provider_supports_oauth("anthropic"));
        assert!(!provider_supports_oauth("unknown"));
    }

    #[test]
    fn test_auth_method_serde() {
        let json = serde_json::to_string(&AuthMethod::OAuth).unwrap();
        assert_eq!(json, "\"oauth\"");
        let method: AuthMethod = serde_json::from_str("\"api_key\"").unwrap();
        assert_eq!(method, AuthMethod::ApiKey);
    }

    #[test]
    fn test_auth_method_default() {
        assert_eq!(AuthMethod::default(), AuthMethod::ApiKey);
    }

    #[test]
    fn test_auth_method_display() {
        assert_eq!(AuthMethod::ApiKey.to_string(), "api_key");
        assert_eq!(AuthMethod::OAuth.to_string(), "oauth");
    }

    #[test]
    fn test_oauth_token_serde_roundtrip() {
        let token = OAuthToken {
            refresh_token: Some("rt-roundtrip".to_string()),
            scopes: vec!["scope1".to_string(), "scope2".to_string()],
            ..bearer_token("at-roundtrip", Some(Utc::now()))
        };
        let json = serde_json::to_string(&token).unwrap();
        // `id_token` is None, so `skip_serializing_if` must omit the key entirely...
        assert!(!json.contains("id_token"));

        let parsed: OAuthToken = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.access_token, token.access_token);
        assert_eq!(parsed.refresh_token, token.refresh_token);
        // ...and `#[serde(default)]` must restore it as None when it is absent.
        assert_eq!(parsed.id_token, None);
        assert_eq!(parsed.expires_at, token.expires_at);
        assert_eq!(parsed.token_type, token.token_type);
        assert_eq!(parsed.scopes, token.scopes);
    }

    #[tokio::test]
    async fn test_callback_server_http_receives_code() {
        // `false` appears to select the plain-HTTP listener; the request below uses http://.
        let (port, rx) = start_callback_server(false).await.unwrap();
        assert_eq!(port, OAUTH_CALLBACK_PORT);

        // Simulate the provider redirecting the browser back to the callback URL.
        let client = reqwest::Client::new();
        let url = format!(
            "http://127.0.0.1:{}/auth/callback?code=test-http&state=test-state-http",
            port
        );
        let response = client.get(&url).send().await.unwrap();
        assert!(response.status().is_success());

        let callback = rx.await.unwrap();
        assert_eq!(callback.code, "test-http");
        assert_eq!(callback.state, "test-state-http");
    }

    #[tokio::test]
    async fn test_tls_config_loading() {
        // Install a process-wide rustls crypto provider. The result is ignored
        // because another test may already have installed one.
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let config = load_tls_config().await;
        assert!(config.is_ok(), "TLS config loading should succeed");
    }

    #[test]
    fn test_slack_oauth_config() {
        let config = slack_oauth_config("slack-client-123", Some("slack-secret".into()));
        assert_eq!(config.provider_name, "slack");
        assert_eq!(config.client_id, "slack-client-123");
        assert!(config
            .authorization_url
            .contains("slack.com/oauth/v2/authorize"));
        assert!(config.token_url.contains("slack.com/api/oauth.v2.access"));
        assert!(config.scopes.contains(&"chat:write".to_string()));
        assert!(config.scopes.contains(&"channels:history".to_string()));
        assert!(config.scopes.contains(&"channels:read".to_string()));
        assert!(config.scopes.contains(&"users:read".to_string()));
        assert!(!config.supports_device_code);
    }

    #[test]
    fn test_discord_oauth_config() {
        let config = discord_oauth_config("discord-client-456", Some("discord-secret".into()));
        assert_eq!(config.provider_name, "discord");
        assert_eq!(config.client_id, "discord-client-456");
        assert!(config
            .authorization_url
            .contains("discord.com/api/oauth2/authorize"));
        assert!(config.token_url.contains("discord.com/api/oauth2/token"));
        assert!(config.scopes.contains(&"bot".to_string()));
        assert!(config.scopes.contains(&"messages.read".to_string()));
        assert!(!config.supports_device_code);
    }

    #[test]
    fn test_teams_oauth_config() {
        let config = teams_oauth_config(
            "teams-client-789",
            "my-tenant-id",
            Some("teams-secret".into()),
        );
        assert_eq!(config.provider_name, "teams");
        assert_eq!(config.client_id, "teams-client-789");
        assert!(config
            .authorization_url
            .contains("login.microsoftonline.com/my-tenant-id"));
        assert!(config
            .token_url
            .contains("login.microsoftonline.com/my-tenant-id"));
        assert!(config
            .scopes
            .contains(&"https://graph.microsoft.com/.default".to_string()));
        assert!(config.supports_device_code);
        assert!(config
            .device_code_url
            .as_ref()
            .unwrap()
            .contains("my-tenant-id"));
    }

    #[test]
    fn test_teams_oauth_config_common_tenant() {
        let config = teams_oauth_config("teams-client", "common", None);
        assert!(config
            .authorization_url
            .contains("common/oauth2/v2.0/authorize"));
        assert!(config.token_url.contains("common/oauth2/v2.0/token"));
    }

    #[test]
    fn test_whatsapp_oauth_config() {
        let config = whatsapp_oauth_config("meta-app-123", Some("meta-secret".into()));
        assert_eq!(config.provider_name, "whatsapp");
        assert_eq!(config.client_id, "meta-app-123");
        assert!(config
            .authorization_url
            .contains("facebook.com/v18.0/dialog/oauth"));
        assert!(config
            .token_url
            .contains("graph.facebook.com/v18.0/oauth/access_token"));
        assert!(config
            .scopes
            .contains(&"whatsapp_business_messaging".to_string()));
        assert!(config
            .scopes
            .contains(&"whatsapp_business_management".to_string()));
        assert!(!config.supports_device_code);
    }

    #[test]
    fn test_gmail_oauth_config() {
        let config = gmail_oauth_config("gmail-client-id", Some("gmail-secret".into()));
        assert_eq!(config.provider_name, "gmail");
        assert_eq!(config.client_id, "gmail-client-id");
        assert!(config.authorization_url.contains("accounts.google.com"));
        assert!(config.token_url.contains("oauth2.googleapis.com"));
        assert!(config
            .scopes
            .contains(&"https://mail.google.com/".to_string()));
        // Offline access is requested so that Google issues a refresh token.
        assert!(config
            .extra_auth_params
            .iter()
            .any(|(k, v)| k == "access_type" && v == "offline"));
    }

    #[test]
    fn test_xoauth2_token_format() {
        let token = build_xoauth2_token("user@gmail.com", "ya29.access-token");
        assert_eq!(
            token,
            "user=user@gmail.com\x01auth=Bearer ya29.access-token\x01\x01"
        );
    }

    #[test]
    fn test_xoauth2_token_base64() {
        let b64 = build_xoauth2_token_base64("user@gmail.com", "token123");
        // Standard (padded) base64, not the URL-safe alphabet used for PKCE.
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(&b64)
            .unwrap();
        let decoded_str = String::from_utf8(decoded).unwrap();
        assert!(decoded_str.starts_with("user=user@gmail.com\x01"));
        assert!(decoded_str.contains("auth=Bearer token123"));
    }

    #[test]
    fn test_oauth_config_for_channel_providers_without_env() {
        // Slack is only smoke-tested: the call must not panic in any environment.
        let _ = oauth_config_for_provider("slack");

        // For these providers, a config must be available exactly when
        // `provider_supports_oauth` says OAuth is configured via the environment.
        for provider in ["discord", "teams", "whatsapp", "gmail"] {
            assert_eq!(
                oauth_config_for_provider(provider).is_some(),
                provider_supports_oauth(provider),
                "config availability and provider_supports_oauth disagree for {}",
                provider
            );
        }
    }

    #[test]
    fn test_store_and_load_channel_oauth_token() {
        let store = InMemoryCredentialStore::new();
        let token = OAuthToken {
            refresh_token: Some("xoxr-refresh".to_string()),
            scopes: vec!["chat:write".to_string(), "channels:history".to_string()],
            ..bearer_token("xoxb-slack-token", None)
        };

        store_oauth_token(&store, "slack", &token).unwrap();
        let loaded = load_oauth_token(&store, "slack").unwrap();
        assert_eq!(loaded.access_token, "xoxb-slack-token");
        assert_eq!(loaded.scopes.len(), 2);

        let teams_token = OAuthToken {
            scopes: vec!["https://graph.microsoft.com/.default".to_string()],
            ..bearer_token("eyJ-teams-token", None)
        };
        store_oauth_token(&store, "teams", &teams_token).unwrap();
        let loaded_teams = load_oauth_token(&store, "teams").unwrap();
        assert_eq!(loaded_teams.access_token, "eyJ-teams-token");

        // Storing the Teams token must not overwrite the Slack entry.
        let loaded_slack = load_oauth_token(&store, "slack").unwrap();
        assert_eq!(loaded_slack.access_token, "xoxb-slack-token");
    }
}