1use base64::Engine;
16use base64::engine::general_purpose::URL_SAFE_NO_PAD;
17use chrono::{DateTime, Utc};
18use rand::Rng;
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21use std::collections::HashMap;
22use std::future::IntoFuture;
23use std::net::SocketAddr;
24use tokio::sync::oneshot;
25use tracing::{debug, info};
26
27use crate::credentials::{CredentialError, CredentialStore};
28use crate::error::LlmError;
29
/// How credentials for a provider are obtained.
///
/// Serialized in snake_case: `ApiKey` as `"api_key"` and `OAuth` as `"oauth"`
/// (the same strings produced by the `Display` impl).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum AuthMethod {
    /// API-key based authentication; this is the `Default` variant.
    #[default]
    ApiKey,
    /// OAuth-based authentication. Explicitly renamed so serde emits "oauth"
    /// rather than the snake_case form of the variant name.
    #[serde(rename = "oauth")]
    OAuth,
}
43
44impl std::fmt::Display for AuthMethod {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 AuthMethod::ApiKey => write!(f, "api_key"),
48 AuthMethod::OAuth => write!(f, "oauth"),
49 }
50 }
51}
52
/// Endpoints and client registration details for one OAuth provider.
#[derive(Debug, Clone)]
pub struct OAuthProviderConfig {
    /// Short provider identifier (e.g. "openai", "slack"). The browser flow also
    /// uses it to choose TLS for channel providers and to trigger the OpenAI
    /// API-key exchange.
    pub provider_name: String,
    /// OAuth client ID, sent with the authorization and token requests.
    pub client_id: String,
    /// Optional client secret, added to the authorization-code exchange when present.
    pub client_secret: Option<String>,
    /// Authorization endpoint that the browser is sent to.
    pub authorization_url: String,
    /// Token endpoint used for code exchange, refresh, device polling and
    /// client-credentials requests.
    pub token_url: String,
    /// Requested scopes, joined with spaces into the `scope` parameter.
    pub scopes: Vec<String>,
    /// Optional `audience` parameter for the authorization and device-code requests.
    pub audience: Option<String>,
    /// Whether the provider supports the device code flow.
    /// NOTE(review): not read by the flows in this file; `authorize_device_code_flow`
    /// only checks `device_code_url`.
    pub supports_device_code: bool,
    /// Device authorization endpoint, if the provider has one.
    pub device_code_url: Option<String>,
    /// Extra query parameters appended (URL-encoded) to the authorization URL.
    pub extra_auth_params: Vec<(String, String)>,
}
78
/// Token set parsed from a provider's token endpoint response.
///
/// Persisted as JSON by `store_oauth_token`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthToken {
    /// Credential used for API calls. For OpenAI the browser flow replaces it
    /// with an API key obtained through token exchange.
    pub access_token: String,
    /// Refresh token, if the provider issued one.
    pub refresh_token: Option<String>,
    /// ID token (a JWT), if issued. Omitted from the serialized JSON when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id_token: Option<String>,
    /// Absolute expiry computed from `expires_in`; `None` when no expiry was reported.
    pub expires_at: Option<DateTime<Utc>>,
    /// Token type reported by the provider ("Bearer" when the response omits it).
    pub token_type: String,
    /// Granted scopes, split from the space-separated `scope` field.
    pub scopes: Vec<String>,
}
96
/// PKCE (RFC 7636) code verifier together with its S256 challenge.
struct PkcePair {
    /// Random secret that is sent only in the token exchange.
    verifier: String,
    /// Unpadded base64url SHA-256 of `verifier`, sent in the authorization URL.
    challenge: String,
}
102
/// Query parameters captured from the OAuth redirect to `/auth/callback`.
struct CallbackData {
    /// Authorization code (empty string when the redirect carried none).
    code: String,
    /// `state` value echoed back by the provider (empty string when absent).
    state: String,
}
108
109fn generate_pkce_pair() -> PkcePair {
116 let mut rng = rand::thread_rng();
117 let verifier: String = (0..43)
118 .map(|_| {
119 const CHARSET: &[u8] =
120 b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
121 let idx = rng.gen_range(0..CHARSET.len());
122 CHARSET[idx] as char
123 })
124 .collect();
125
126 let mut hasher = Sha256::new();
127 hasher.update(verifier.as_bytes());
128 let digest = hasher.finalize();
129 let challenge = URL_SAFE_NO_PAD.encode(digest);
130
131 PkcePair {
132 verifier,
133 challenge,
134 }
135}
136
137fn generate_state() -> String {
139 let mut rng = rand::thread_rng();
140 let bytes: [u8; 32] = rng.r#gen();
141 URL_SAFE_NO_PAD.encode(bytes)
142}
143
/// Port on 127.0.0.1 where the local OAuth redirect listener binds. The default
/// redirect URI is `http(s)://localhost:<port>/auth/callback`.
pub const OAUTH_CALLBACK_PORT: u16 = 8844;
152
153fn build_callback_router(
155 tx: std::sync::Arc<tokio::sync::Mutex<Option<oneshot::Sender<CallbackData>>>>,
156) -> axum::Router {
157 axum::Router::new().route(
158 "/auth/callback",
159 axum::routing::get({
160 let tx = tx.clone();
161 move |query: axum::extract::Query<HashMap<String, String>>| {
162 let tx = tx.clone();
163 async move {
164 let code = query.get("code").cloned().unwrap_or_default();
165 let state = query.get("state").cloned().unwrap_or_default();
166
167 if let Some(sender) = tx.lock().await.take() {
168 let _ = sender.send(CallbackData { code, state });
169 }
170
171 axum::response::Html(
172 r#"<!DOCTYPE html>
173<html>
174<head><title>Rustant</title></head>
175<body style="font-family: system-ui; text-align: center; padding-top: 80px;">
176<h2>Authentication successful!</h2>
177<p>You can close this tab and return to the terminal.</p>
178</body>
179</html>"#,
180 )
181 }
182 }
183 }),
184 )
185}
186
187async fn load_tls_config() -> Result<axum_server::tls_rustls::RustlsConfig, LlmError> {
203 if let Some(home) = directories::BaseDirs::new() {
205 let cert_dir = home.home_dir().join(".rustant").join("certs");
206 let cert_path = cert_dir.join("localhost.pem");
207 let key_path = cert_dir.join("localhost-key.pem");
208
209 if cert_path.exists() && key_path.exists() {
210 info!("Using mkcert certificates from {}", cert_dir.display());
211 return axum_server::tls_rustls::RustlsConfig::from_pem_file(cert_path, key_path)
212 .await
213 .map_err(|e| LlmError::OAuthFailed {
214 message: format!("Failed to load mkcert certificates: {}", e),
215 });
216 }
217 }
218
219 info!(
221 "No mkcert certs found in ~/.rustant/certs/. Generating self-signed certificate.\n\
222 Your browser may show a security warning. To avoid this, run:\n \
223 mkcert -install && mkdir -p ~/.rustant/certs && \
224 mkcert -cert-file ~/.rustant/certs/localhost.pem \
225 -key-file ~/.rustant/certs/localhost-key.pem localhost 127.0.0.1"
226 );
227
228 use rcgen::CertifiedKey;
229 let subject_alt_names = vec!["localhost".to_string(), "127.0.0.1".to_string()];
230 let CertifiedKey { cert, key_pair } = rcgen::generate_simple_self_signed(subject_alt_names)
231 .map_err(|e| LlmError::OAuthFailed {
232 message: format!("Failed to generate self-signed certificate: {}", e),
233 })?;
234
235 let cert_pem = cert.pem();
236 let key_pem = key_pair.serialize_pem();
237
238 axum_server::tls_rustls::RustlsConfig::from_pem(cert_pem.into_bytes(), key_pem.into_bytes())
239 .await
240 .map_err(|e| LlmError::OAuthFailed {
241 message: format!("Failed to build TLS config: {}", e),
242 })
243}
244
245async fn start_callback_server(
254 use_tls: bool,
255) -> Result<(u16, oneshot::Receiver<CallbackData>), LlmError> {
256 let (tx, rx) = oneshot::channel::<CallbackData>();
257 let tx = std::sync::Arc::new(tokio::sync::Mutex::new(Some(tx)));
258 let app = build_callback_router(tx);
259
260 let bind_addr = format!("127.0.0.1:{}", OAUTH_CALLBACK_PORT);
261
262 if use_tls {
263 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
265
266 let tls_config = load_tls_config().await?;
267
268 let addr: SocketAddr = bind_addr.parse().map_err(|e| LlmError::OAuthFailed {
269 message: format!("Invalid bind address: {}", e),
270 })?;
271
272 debug!(
273 port = OAUTH_CALLBACK_PORT,
274 "OAuth HTTPS callback server starting"
275 );
276
277 tokio::spawn(async move {
278 let server = axum_server::bind_rustls(addr, tls_config).serve(app.into_make_service());
279 let _ = tokio::time::timeout(std::time::Duration::from_secs(120), server).await;
280 });
281
282 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
284 } else {
285 let listener = tokio::net::TcpListener::bind(&bind_addr)
286 .await
287 .map_err(|e| LlmError::OAuthFailed {
288 message: format!(
289 "Failed to bind callback server on port {}: {}. \
290 Make sure no other process is using this port.",
291 OAUTH_CALLBACK_PORT, e
292 ),
293 })?;
294
295 debug!(
296 port = OAUTH_CALLBACK_PORT,
297 "OAuth HTTP callback server starting"
298 );
299
300 tokio::spawn(async move {
301 let server = axum::serve(listener, app);
302 let _ = tokio::time::timeout(std::time::Duration::from_secs(120), server.into_future())
303 .await;
304 });
305 }
306
307 Ok((OAUTH_CALLBACK_PORT, rx))
308}
309
310pub async fn authorize_browser_flow(
331 config: &OAuthProviderConfig,
332 redirect_uri_override: Option<&str>,
333) -> Result<OAuthToken, LlmError> {
334 let pkce = generate_pkce_pair();
335 let state = generate_state();
336
337 let is_channel_provider = matches!(
341 config.provider_name.as_str(),
342 "slack" | "discord" | "teams" | "whatsapp" | "gmail"
343 );
344
345 let use_tls = match redirect_uri_override {
346 Some(uri) => uri.starts_with("https://"),
347 None => is_channel_provider,
348 };
349
350 let (port, rx) = start_callback_server(use_tls).await?;
352
353 let redirect_uri = match redirect_uri_override {
354 Some(uri) => uri.to_string(),
355 None => {
356 let scheme = if use_tls { "https" } else { "http" };
357 format!("{}://localhost:{}/auth/callback", scheme, port)
358 }
359 };
360
361 let mut auth_url =
363 url::Url::parse(&config.authorization_url).map_err(|e| LlmError::OAuthFailed {
364 message: format!("Invalid authorization URL: {}", e),
365 })?;
366
367 {
368 let mut params = auth_url.query_pairs_mut();
369 params.append_pair("response_type", "code");
370 params.append_pair("client_id", &config.client_id);
371 params.append_pair("redirect_uri", &redirect_uri);
372 params.append_pair("code_challenge", &pkce.challenge);
373 params.append_pair("code_challenge_method", "S256");
374 params.append_pair("state", &state);
375
376 if !config.scopes.is_empty() {
377 params.append_pair("scope", &config.scopes.join(" "));
378 }
379 if let Some(ref audience) = config.audience {
380 params.append_pair("audience", audience);
381 }
382 for (key, value) in &config.extra_auth_params {
383 params.append_pair(key, value);
384 }
385 }
386
387 info!("Opening browser for OAuth authorization...");
388 debug!(url = %auth_url, "Authorization URL");
389 open::that(auth_url.as_str()).map_err(|e| LlmError::OAuthFailed {
390 message: format!("Failed to open browser: {}", e),
391 })?;
392
393 let callback = tokio::time::timeout(std::time::Duration::from_secs(120), rx)
395 .await
396 .map_err(|_| LlmError::OAuthFailed {
397 message: "OAuth callback timed out after 120 seconds".to_string(),
398 })?
399 .map_err(|_| LlmError::OAuthFailed {
400 message: "OAuth callback channel closed unexpectedly".to_string(),
401 })?;
402
403 if callback.state != state {
405 return Err(LlmError::OAuthFailed {
406 message: "OAuth state parameter mismatch (possible CSRF attack)".to_string(),
407 });
408 }
409
410 if callback.code.is_empty() {
411 return Err(LlmError::OAuthFailed {
412 message: "OAuth callback did not contain an authorization code".to_string(),
413 });
414 }
415
416 let mut token =
418 exchange_code_for_token(config, &callback.code, &pkce.verifier, &redirect_uri).await?;
419
420 if config.provider_name == "openai"
425 && let Some(ref id_tok) = token.id_token
426 {
427 if let Some(payload) = id_tok.split('.').nth(1)
428 && let Ok(bytes) = URL_SAFE_NO_PAD.decode(payload)
429 && let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&bytes)
430 {
431 debug!(claims = %claims, "ID token claims");
432 }
433 debug!("Exchanging ID token for OpenAI API key...");
434 match obtain_openai_api_key(config, id_tok).await {
435 Ok(api_key) => {
436 info!("Obtained OpenAI Platform API key via token exchange");
437 token.access_token = api_key;
438 }
439 Err(e) => {
440 return Err(LlmError::OAuthFailed {
445 message: format!(
446 "Failed to exchange OAuth token for an OpenAI API key: {}\n\n\
447 This usually means your OpenAI account does not have \
448 Platform API access set up.\n\n\
449 To fix this:\n\
450 1. Visit https://platform.openai.com to create an API organization\n\
451 2. Ensure you have a billing method or active subscription\n\
452 3. Run 'rustant auth login openai' again\n\n\
453 Alternatively, use a standard API key:\n\
454 1. Get your key from https://platform.openai.com/api-keys\n\
455 2. Set the OPENAI_API_KEY environment variable\n\
456 3. Set auth_method to empty in .rustant/config.toml",
457 e
458 ),
459 });
460 }
461 }
462 }
463
464 Ok(token)
465}
466
467async fn exchange_code_for_token(
469 config: &OAuthProviderConfig,
470 code: &str,
471 code_verifier: &str,
472 redirect_uri: &str,
473) -> Result<OAuthToken, LlmError> {
474 let client = reqwest::Client::new();
475
476 let mut body = format!(
479 "grant_type={}&code={}&redirect_uri={}&client_id={}&code_verifier={}",
480 urlencoding::encode("authorization_code"),
481 urlencoding::encode(code),
482 urlencoding::encode(redirect_uri),
483 urlencoding::encode(&config.client_id),
484 urlencoding::encode(code_verifier),
485 );
486
487 if let Some(ref secret) = config.client_secret {
489 body.push_str(&format!("&client_secret={}", urlencoding::encode(secret)));
490 }
491
492 debug!(provider = %config.provider_name, "Exchanging authorization code for token");
493
494 let response = client
495 .post(&config.token_url)
496 .header("Content-Type", "application/x-www-form-urlencoded")
497 .body(body)
498 .send()
499 .await
500 .map_err(|e| LlmError::OAuthFailed {
501 message: format!("Token exchange request failed: {}", e),
502 })?;
503
504 let status = response.status();
505 let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
506 message: format!("Failed to read token response: {}", e),
507 })?;
508
509 if !status.is_success() {
510 return Err(LlmError::OAuthFailed {
511 message: format!("Token exchange failed (HTTP {}): {}", status, body_text),
512 });
513 }
514
515 parse_token_response(&body_text)
516}
517
518async fn obtain_openai_api_key(
527 config: &OAuthProviderConfig,
528 id_token: &str,
529) -> Result<String, LlmError> {
530 let client = reqwest::Client::new();
531
532 let body = format!(
536 "grant_type={}&client_id={}&requested_token={}&subject_token={}&subject_token_type={}",
537 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
538 urlencoding::encode(&config.client_id),
539 urlencoding::encode("openai-api-key"),
540 urlencoding::encode(id_token),
541 urlencoding::encode("urn:ietf:params:oauth:token-type:id_token"),
542 );
543
544 debug!(body_len = body.len(), "Token exchange request body");
545
546 let response = client
547 .post(&config.token_url)
548 .header("Content-Type", "application/x-www-form-urlencoded")
549 .body(body)
550 .send()
551 .await
552 .map_err(|e| LlmError::OAuthFailed {
553 message: format!("API key exchange request failed: {}", e),
554 })?;
555
556 let status = response.status();
557 let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
558 message: format!("Failed to read API key exchange response: {}", e),
559 })?;
560
561 if !status.is_success() {
562 return Err(LlmError::OAuthFailed {
563 message: format!("API key exchange failed (HTTP {}): {}", status, body_text),
564 });
565 }
566
567 let json: serde_json::Value =
568 serde_json::from_str(&body_text).map_err(|e| LlmError::OAuthFailed {
569 message: format!("Invalid JSON in API key exchange response: {}", e),
570 })?;
571
572 json["access_token"]
573 .as_str()
574 .map(|s| s.to_string())
575 .ok_or_else(|| LlmError::OAuthFailed {
576 message: "API key exchange response missing 'access_token'".to_string(),
577 })
578}
579
580fn parse_token_response(body: &str) -> Result<OAuthToken, LlmError> {
582 let json: serde_json::Value =
583 serde_json::from_str(body).map_err(|e| LlmError::OAuthFailed {
584 message: format!("Invalid JSON in token response: {}", e),
585 })?;
586
587 let access_token = json["access_token"]
588 .as_str()
589 .ok_or_else(|| LlmError::OAuthFailed {
590 message: "Token response missing 'access_token'".to_string(),
591 })?
592 .to_string();
593
594 let refresh_token = json["refresh_token"].as_str().map(|s| s.to_string());
595 let id_token = json["id_token"].as_str().map(|s| s.to_string());
596 let token_type = json["token_type"].as_str().unwrap_or("Bearer").to_string();
597
598 let expires_at = json["expires_in"]
599 .as_u64()
600 .map(|secs| Utc::now() + chrono::Duration::seconds(secs as i64));
601
602 let scopes = json["scope"]
603 .as_str()
604 .map(|s| s.split_whitespace().map(|s| s.to_string()).collect())
605 .unwrap_or_default();
606
607 Ok(OAuthToken {
608 access_token,
609 refresh_token,
610 id_token,
611 expires_at,
612 token_type,
613 scopes,
614 })
615}
616
617pub async fn authorize_device_code_flow(
627 config: &OAuthProviderConfig,
628) -> Result<OAuthToken, LlmError> {
629 let device_code_url =
630 config
631 .device_code_url
632 .as_deref()
633 .ok_or_else(|| LlmError::OAuthFailed {
634 message: format!(
635 "Provider '{}' does not support device code flow",
636 config.provider_name
637 ),
638 })?;
639
640 let client = reqwest::Client::new();
641
642 let mut params = HashMap::new();
644 params.insert("client_id", config.client_id.as_str());
645 if !config.scopes.is_empty() {
646 let scope_str = config.scopes.join(" ");
647 params.insert("scope", Box::leak(scope_str.into_boxed_str()));
648 }
649 if let Some(ref audience) = config.audience {
650 params.insert("audience", audience.as_str());
651 }
652
653 let response = client
654 .post(device_code_url)
655 .form(¶ms)
656 .send()
657 .await
658 .map_err(|e| LlmError::OAuthFailed {
659 message: format!("Device code request failed: {}", e),
660 })?;
661
662 let status = response.status();
663 let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
664 message: format!("Failed to read device code response: {}", e),
665 })?;
666
667 if !status.is_success() {
668 return Err(LlmError::OAuthFailed {
669 message: format!(
670 "Device code request failed (HTTP {}): {}",
671 status, body_text
672 ),
673 });
674 }
675
676 let json: serde_json::Value =
677 serde_json::from_str(&body_text).map_err(|e| LlmError::OAuthFailed {
678 message: format!("Invalid JSON in device code response: {}", e),
679 })?;
680
681 let device_code = json["device_code"]
682 .as_str()
683 .ok_or_else(|| LlmError::OAuthFailed {
684 message: "Device code response missing 'device_code'".to_string(),
685 })?;
686 let user_code = json["user_code"]
687 .as_str()
688 .ok_or_else(|| LlmError::OAuthFailed {
689 message: "Device code response missing 'user_code'".to_string(),
690 })?;
691 let verification_uri = json["verification_uri"]
692 .as_str()
693 .or_else(|| json["verification_url"].as_str())
694 .ok_or_else(|| LlmError::OAuthFailed {
695 message: "Device code response missing 'verification_uri'".to_string(),
696 })?;
697 let interval = json["interval"].as_u64().unwrap_or(5);
698 let expires_in = json["expires_in"].as_u64().unwrap_or(600);
699
700 println!();
702 println!(" To authenticate, visit: {}", verification_uri);
703 println!(" Enter this code: {}", user_code);
704 println!();
705 println!(" Waiting for authorization...");
706
707 let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(expires_in);
709 let poll_interval = std::time::Duration::from_secs(interval);
710
711 loop {
712 tokio::time::sleep(poll_interval).await;
713
714 if tokio::time::Instant::now() >= deadline {
715 return Err(LlmError::OAuthFailed {
716 message: "Device code flow timed out waiting for authorization".to_string(),
717 });
718 }
719
720 let mut poll_params = HashMap::new();
721 poll_params.insert("grant_type", "urn:ietf:params:oauth:grant-type:device_code");
722 poll_params.insert("device_code", device_code);
723 poll_params.insert("client_id", &config.client_id);
724
725 let poll_response = client
726 .post(&config.token_url)
727 .form(&poll_params)
728 .send()
729 .await
730 .map_err(|e| LlmError::OAuthFailed {
731 message: format!("Token poll request failed: {}", e),
732 })?;
733
734 let poll_status = poll_response.status();
735 let poll_body = poll_response
736 .text()
737 .await
738 .map_err(|e| LlmError::OAuthFailed {
739 message: format!("Failed to read token poll response: {}", e),
740 })?;
741
742 if poll_status.is_success() {
743 return parse_token_response(&poll_body);
744 }
745
746 if let Ok(err_json) = serde_json::from_str::<serde_json::Value>(&poll_body) {
748 let error = err_json["error"].as_str().unwrap_or("");
749 match error {
750 "authorization_pending" => {
751 debug!("Device code flow: authorization pending, polling again...");
752 continue;
753 }
754 "slow_down" => {
755 debug!("Device code flow: slow down requested");
756 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
757 continue;
758 }
759 "expired_token" => {
760 return Err(LlmError::OAuthFailed {
761 message: "Device code expired. Please try again.".to_string(),
762 });
763 }
764 "access_denied" => {
765 return Err(LlmError::OAuthFailed {
766 message: "Authorization was denied by the user.".to_string(),
767 });
768 }
769 _ => {
770 return Err(LlmError::OAuthFailed {
771 message: format!("Token poll error: {}", poll_body),
772 });
773 }
774 }
775 }
776
777 return Err(LlmError::OAuthFailed {
779 message: format!("Token poll failed (HTTP {}): {}", poll_status, poll_body),
780 });
781 }
782}
783
784pub async fn refresh_token(
788 config: &OAuthProviderConfig,
789 refresh_token_str: &str,
790) -> Result<OAuthToken, LlmError> {
791 let client = reqwest::Client::new();
792
793 let mut params = HashMap::new();
794 params.insert("grant_type", "refresh_token");
795 params.insert("refresh_token", refresh_token_str);
796 params.insert("client_id", &config.client_id);
797
798 debug!(provider = %config.provider_name, "Refreshing OAuth token");
799
800 let response = client
801 .post(&config.token_url)
802 .form(¶ms)
803 .send()
804 .await
805 .map_err(|e| LlmError::OAuthFailed {
806 message: format!("Token refresh request failed: {}", e),
807 })?;
808
809 let status = response.status();
810 let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
811 message: format!("Failed to read token refresh response: {}", e),
812 })?;
813
814 if !status.is_success() {
815 return Err(LlmError::OAuthFailed {
816 message: format!("Token refresh failed (HTTP {}): {}", status, body_text),
817 });
818 }
819
820 let mut token = parse_token_response(&body_text)?;
821
822 if token.refresh_token.is_none() {
824 token.refresh_token = Some(refresh_token_str.to_string());
825 }
826
827 Ok(token)
828}
829
830pub fn is_token_expired(token: &OAuthToken) -> bool {
834 match token.expires_at {
835 Some(expires_at) => {
836 let buffer = chrono::Duration::minutes(5);
837 Utc::now() >= (expires_at - buffer)
838 }
839 None => false,
841 }
842}
843
844pub fn store_oauth_token(
850 store: &dyn CredentialStore,
851 provider: &str,
852 token: &OAuthToken,
853) -> Result<(), LlmError> {
854 let key = format!("oauth:{}", provider);
855 let json = serde_json::to_string(token).map_err(|e| LlmError::OAuthFailed {
856 message: format!("Failed to serialize OAuth token: {}", e),
857 })?;
858 store
859 .store_key(&key, &json)
860 .map_err(|e| LlmError::OAuthFailed {
861 message: format!("Failed to store OAuth token: {}", e),
862 })
863}
864
865pub fn load_oauth_token(
867 store: &dyn CredentialStore,
868 provider: &str,
869) -> Result<OAuthToken, LlmError> {
870 let key = format!("oauth:{}", provider);
871 let json = store.get_key(&key).map_err(|e| match e {
872 CredentialError::NotFound { .. } => LlmError::OAuthFailed {
873 message: format!("No OAuth token found for provider '{}'", provider),
874 },
875 other => LlmError::OAuthFailed {
876 message: format!("Failed to load OAuth token: {}", other),
877 },
878 })?;
879 serde_json::from_str(&json).map_err(|e| LlmError::OAuthFailed {
880 message: format!("Failed to deserialize OAuth token: {}", e),
881 })
882}
883
884pub fn delete_oauth_token(store: &dyn CredentialStore, provider: &str) -> Result<(), LlmError> {
886 let key = format!("oauth:{}", provider);
887 store.delete_key(&key).map_err(|e| LlmError::OAuthFailed {
888 message: format!("Failed to delete OAuth token: {}", e),
889 })
890}
891
892pub fn has_oauth_token(store: &dyn CredentialStore, provider: &str) -> bool {
894 let key = format!("oauth:{}", provider);
895 store.has_key(&key)
896}
897
898pub fn openai_oauth_config() -> OAuthProviderConfig {
904 OAuthProviderConfig {
905 provider_name: "openai".to_string(),
906 client_id: "app_EMoamEEZ73f0CkXaXp7hrann".to_string(),
907 client_secret: None, authorization_url: "https://auth.openai.com/oauth/authorize".to_string(),
909 token_url: "https://auth.openai.com/oauth/token".to_string(),
910 scopes: vec![
911 "openid".to_string(),
912 "profile".to_string(),
913 "email".to_string(),
914 "offline_access".to_string(),
915 ],
916 audience: None,
917 supports_device_code: true,
918 device_code_url: Some("https://auth.openai.com/oauth/device/code".to_string()),
919 extra_auth_params: vec![
920 ("id_token_add_organizations".to_string(), "true".to_string()),
921 ("codex_cli_simplified_flow".to_string(), "true".to_string()),
922 ("originator".to_string(), "codex_cli_rs".to_string()),
923 ],
924 }
925}
926
927pub fn google_oauth_config() -> Option<OAuthProviderConfig> {
933 let client_id = std::env::var("GOOGLE_OAUTH_CLIENT_ID").ok()?;
934 let client_secret = std::env::var("GOOGLE_OAUTH_CLIENT_SECRET").ok();
935 Some(OAuthProviderConfig {
936 provider_name: "google".to_string(),
937 client_id,
938 client_secret,
939 authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
940 token_url: "https://oauth2.googleapis.com/token".to_string(),
941 scopes: vec!["https://www.googleapis.com/auth/generative-language".to_string()],
942 audience: None,
943 supports_device_code: false,
944 device_code_url: None,
945 extra_auth_params: vec![],
946 })
947}
948
/// OAuth configuration for Anthropic.
///
/// Always returns `None`: this module defines no OAuth endpoints for
/// Anthropic, so `oauth_config_for_provider("anthropic")` yields no config.
pub fn anthropic_oauth_config() -> Option<OAuthProviderConfig> {
    None
}
959
960pub fn slack_oauth_config(client_id: &str, client_secret: Option<String>) -> OAuthProviderConfig {
967 OAuthProviderConfig {
968 provider_name: "slack".to_string(),
969 client_id: client_id.to_string(),
970 client_secret,
971 authorization_url: "https://slack.com/oauth/v2/authorize".to_string(),
972 token_url: "https://slack.com/api/oauth.v2.access".to_string(),
973 scopes: vec![
974 "chat:write".to_string(),
975 "channels:history".to_string(),
976 "channels:read".to_string(),
977 "users:read".to_string(),
978 ],
979 audience: None,
980 supports_device_code: false,
981 device_code_url: None,
982 extra_auth_params: vec![],
983 }
984}
985
986pub fn discord_oauth_config(client_id: &str, client_secret: Option<String>) -> OAuthProviderConfig {
991 OAuthProviderConfig {
992 provider_name: "discord".to_string(),
993 client_id: client_id.to_string(),
994 client_secret,
995 authorization_url: "https://discord.com/api/oauth2/authorize".to_string(),
996 token_url: "https://discord.com/api/oauth2/token".to_string(),
997 scopes: vec!["bot".to_string(), "messages.read".to_string()],
998 audience: None,
999 supports_device_code: false,
1000 device_code_url: None,
1001 extra_auth_params: vec![],
1002 }
1003}
1004
1005pub fn teams_oauth_config(
1013 client_id: &str,
1014 tenant_id: &str,
1015 client_secret: Option<String>,
1016) -> OAuthProviderConfig {
1017 OAuthProviderConfig {
1018 provider_name: "teams".to_string(),
1019 client_id: client_id.to_string(),
1020 client_secret,
1021 authorization_url: format!(
1022 "https://login.microsoftonline.com/{}/oauth2/v2.0/authorize",
1023 tenant_id
1024 ),
1025 token_url: format!(
1026 "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
1027 tenant_id
1028 ),
1029 scopes: vec!["https://graph.microsoft.com/.default".to_string()],
1030 audience: None,
1031 supports_device_code: true,
1032 device_code_url: Some(format!(
1033 "https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode",
1034 tenant_id
1035 )),
1036 extra_auth_params: vec![],
1037 }
1038}
1039
1040pub fn whatsapp_oauth_config(app_id: &str, app_secret: Option<String>) -> OAuthProviderConfig {
1045 OAuthProviderConfig {
1046 provider_name: "whatsapp".to_string(),
1047 client_id: app_id.to_string(),
1048 client_secret: app_secret,
1049 authorization_url: "https://www.facebook.com/v18.0/dialog/oauth".to_string(),
1050 token_url: "https://graph.facebook.com/v18.0/oauth/access_token".to_string(),
1051 scopes: vec![
1052 "whatsapp_business_messaging".to_string(),
1053 "whatsapp_business_management".to_string(),
1054 ],
1055 audience: None,
1056 supports_device_code: false,
1057 device_code_url: None,
1058 extra_auth_params: vec![],
1059 }
1060}
1061
1062pub fn gmail_oauth_config(client_id: &str, client_secret: Option<String>) -> OAuthProviderConfig {
1068 OAuthProviderConfig {
1069 provider_name: "gmail".to_string(),
1070 client_id: client_id.to_string(),
1071 client_secret,
1072 authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1073 token_url: "https://oauth2.googleapis.com/token".to_string(),
1074 scopes: vec!["https://mail.google.com/".to_string()],
1075 audience: None,
1076 supports_device_code: false,
1077 device_code_url: None,
1078 extra_auth_params: vec![
1079 ("access_type".to_string(), "offline".to_string()),
1080 ("prompt".to_string(), "consent".to_string()),
1081 ],
1082 }
1083}
1084
1085pub async fn authorize_client_credentials_flow(
1095 config: &OAuthProviderConfig,
1096 client_secret: &str,
1097) -> Result<OAuthToken, LlmError> {
1098 let client = reqwest::Client::new();
1099
1100 let secret = if client_secret.is_empty() {
1102 config.client_secret.as_deref().unwrap_or("")
1103 } else {
1104 client_secret
1105 };
1106
1107 let body = format!(
1108 "grant_type={}&client_id={}&client_secret={}&scope={}",
1109 urlencoding::encode("client_credentials"),
1110 urlencoding::encode(&config.client_id),
1111 urlencoding::encode(secret),
1112 urlencoding::encode(&config.scopes.join(" ")),
1113 );
1114
1115 debug!(provider = %config.provider_name, "Requesting client credentials token");
1116
1117 let response = client
1118 .post(&config.token_url)
1119 .header("Content-Type", "application/x-www-form-urlencoded")
1120 .body(body)
1121 .send()
1122 .await
1123 .map_err(|e| LlmError::OAuthFailed {
1124 message: format!("Client credentials request failed: {}", e),
1125 })?;
1126
1127 let status = response.status();
1128 let body_text = response.text().await.map_err(|e| LlmError::OAuthFailed {
1129 message: format!("Failed to read client credentials response: {}", e),
1130 })?;
1131
1132 if !status.is_success() {
1133 return Err(LlmError::OAuthFailed {
1134 message: format!(
1135 "Client credentials token request failed (HTTP {}): {}",
1136 status, body_text
1137 ),
1138 });
1139 }
1140
1141 parse_token_response(&body_text)
1142}
1143
/// Builds the raw SASL XOAUTH2 string:
/// `user=<email>\x01auth=Bearer <token>\x01\x01`.
pub fn build_xoauth2_token(email: &str, access_token: &str) -> String {
    [
        "user=",
        email,
        "\x01auth=Bearer ",
        access_token,
        "\x01\x01",
    ]
    .concat()
}
1151
1152pub fn build_xoauth2_token_base64(email: &str, access_token: &str) -> String {
1154 use base64::Engine;
1155 let raw = build_xoauth2_token(email, access_token);
1156 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
1157}
1158
1159pub fn oauth_config_for_provider(provider: &str) -> Option<OAuthProviderConfig> {
1167 match provider {
1168 "openai" => Some(openai_oauth_config()),
1169 "gemini" | "google" => google_oauth_config(),
1170 "anthropic" => anthropic_oauth_config(),
1171 "slack" => {
1172 let client_id = std::env::var("SLACK_CLIENT_ID").ok()?;
1173 let client_secret = std::env::var("SLACK_CLIENT_SECRET").ok();
1174 Some(slack_oauth_config(&client_id, client_secret))
1175 }
1176 "discord" => {
1177 let client_id = std::env::var("DISCORD_CLIENT_ID").ok()?;
1178 let client_secret = std::env::var("DISCORD_CLIENT_SECRET").ok();
1179 Some(discord_oauth_config(&client_id, client_secret))
1180 }
1181 "teams" => {
1182 let client_id = std::env::var("TEAMS_CLIENT_ID").ok()?;
1183 let tenant_id =
1184 std::env::var("TEAMS_TENANT_ID").unwrap_or_else(|_| "common".to_string());
1185 let client_secret = std::env::var("TEAMS_CLIENT_SECRET").ok();
1186 Some(teams_oauth_config(&client_id, &tenant_id, client_secret))
1187 }
1188 "whatsapp" => {
1189 let app_id = std::env::var("WHATSAPP_APP_ID").ok()?;
1190 let app_secret = std::env::var("WHATSAPP_APP_SECRET").ok();
1191 Some(whatsapp_oauth_config(&app_id, app_secret))
1192 }
1193 "gmail" => {
1194 let client_id = std::env::var("GMAIL_OAUTH_CLIENT_ID")
1195 .or_else(|_| std::env::var("GOOGLE_OAUTH_CLIENT_ID"))
1196 .ok()?;
1197 let client_secret = std::env::var("GMAIL_OAUTH_CLIENT_SECRET")
1198 .or_else(|_| std::env::var("GOOGLE_OAUTH_CLIENT_SECRET"))
1199 .ok();
1200 Some(gmail_oauth_config(&client_id, client_secret))
1201 }
1202 _ => None,
1203 }
1204}
1205
1206pub fn oauth_config_with_credentials(
1213 provider: &str,
1214 client_id: &str,
1215 client_secret: Option<&str>,
1216) -> Option<OAuthProviderConfig> {
1217 let secret = client_secret.map(String::from);
1218 match provider {
1219 "slack" => Some(slack_oauth_config(client_id, secret)),
1220 "discord" => Some(discord_oauth_config(client_id, secret)),
1221 "gmail" => Some(gmail_oauth_config(client_id, secret)),
1222 _ => None,
1223 }
1224}
1225
/// Reports whether `provider` can use OAuth in the current environment.
///
/// `"openai"` is always supported. Every other known provider is supported
/// only when the environment variable holding its OAuth client ID is set and
/// readable as Unicode. For `"gmail"`, either `GMAIL_OAUTH_CLIENT_ID` or
/// `GOOGLE_OAUTH_CLIENT_ID` suffices. Unknown provider names return `false`.
pub fn provider_supports_oauth(provider: &str) -> bool {
    // OpenAI needs no environment configuration here.
    if provider == "openai" {
        return true;
    }

    // Candidate env vars per provider; any one of them being set is enough.
    let required_vars: &[&str] = match provider {
        "gemini" | "google" => &["GOOGLE_OAUTH_CLIENT_ID"],
        "slack" => &["SLACK_CLIENT_ID"],
        "discord" => &["DISCORD_CLIENT_ID"],
        "teams" => &["TEAMS_CLIENT_ID"],
        "whatsapp" => &["WHATSAPP_APP_ID"],
        "gmail" => &["GMAIL_OAUTH_CLIENT_ID", "GOOGLE_OAUTH_CLIENT_ID"],
        _ => return false,
    };

    required_vars
        .iter()
        .any(|&var_name| std::env::var(var_name).is_ok())
}
1242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::InMemoryCredentialStore;

    /// Minimal bearer-token fixture: no refresh/id token, no expiry, no scopes.
    /// Tests override the fields they care about via struct-update syntax.
    fn bearer_token(access_token: &str) -> OAuthToken {
        OAuthToken {
            access_token: access_token.to_string(),
            refresh_token: None,
            id_token: None,
            expires_at: None,
            token_type: "Bearer".to_string(),
            scopes: vec![],
        }
    }

    #[test]
    fn test_generate_pkce_pair() {
        let pair = generate_pkce_pair();
        assert_eq!(pair.verifier.len(), 43);
        assert!(!pair.challenge.is_empty());

        // The challenge must be the unpadded base64url SHA-256 digest of the verifier.
        let decoded = URL_SAFE_NO_PAD.decode(&pair.challenge).unwrap();
        assert_eq!(decoded.len(), 32);

        let mut hasher = Sha256::new();
        hasher.update(pair.verifier.as_bytes());
        let expected = hasher.finalize();
        assert_eq!(decoded, expected.as_slice());
    }

    #[test]
    fn test_generate_pkce_pair_uniqueness() {
        let pair1 = generate_pkce_pair();
        let pair2 = generate_pkce_pair();
        assert_ne!(pair1.verifier, pair2.verifier);
        assert_ne!(pair1.challenge, pair2.challenge);
    }

    #[test]
    fn test_generate_state() {
        let state = generate_state();
        assert!(!state.is_empty());
        assert_eq!(state.len(), 43);
    }

    #[test]
    fn test_generate_state_uniqueness() {
        let s1 = generate_state();
        let s2 = generate_state();
        assert_ne!(s1, s2);
    }

    #[test]
    fn test_parse_token_response_full() {
        let body = serde_json::json!({
            "access_token": "at-12345",
            "refresh_token": "rt-67890",
            "token_type": "Bearer",
            "expires_in": 3600,
            "scope": "openai.public"
        })
        .to_string();

        let token = parse_token_response(&body).unwrap();
        assert_eq!(token.access_token, "at-12345");
        assert_eq!(token.refresh_token, Some("rt-67890".to_string()));
        assert_eq!(token.token_type, "Bearer");
        assert!(token.expires_at.is_some());
        assert_eq!(token.scopes, vec!["openai.public"]);
    }

    #[test]
    fn test_parse_token_response_minimal() {
        let body = serde_json::json!({
            "access_token": "at-minimal"
        })
        .to_string();

        let token = parse_token_response(&body).unwrap();
        assert_eq!(token.access_token, "at-minimal");
        assert!(token.refresh_token.is_none());
        // A missing token_type falls back to "Bearer".
        assert_eq!(token.token_type, "Bearer");
        assert!(token.expires_at.is_none());
        assert!(token.scopes.is_empty());
    }

    #[test]
    fn test_parse_token_response_missing_access_token() {
        let body = serde_json::json!({
            "token_type": "Bearer"
        })
        .to_string();

        let result = parse_token_response(&body);
        assert!(result.is_err());
        match result.unwrap_err() {
            LlmError::OAuthFailed { message } => {
                assert!(message.contains("access_token"));
            }
            other => panic!("Expected OAuthFailed, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_token_response_invalid_json() {
        let result = parse_token_response("not json");
        assert!(result.is_err());
    }

    #[test]
    fn test_is_token_expired_future() {
        let token = OAuthToken {
            expires_at: Some(Utc::now() + chrono::Duration::hours(1)),
            ..bearer_token("test")
        };
        assert!(!is_token_expired(&token));
    }

    #[test]
    fn test_is_token_expired_past() {
        let token = OAuthToken {
            expires_at: Some(Utc::now() - chrono::Duration::hours(1)),
            ..bearer_token("test")
        };
        assert!(is_token_expired(&token));
    }

    #[test]
    fn test_is_token_expired_within_buffer() {
        // A token that is still valid for 3 more minutes is already treated as
        // expired, so is_token_expired must apply a safety buffer longer than that.
        let token = OAuthToken {
            expires_at: Some(Utc::now() + chrono::Duration::minutes(3)),
            ..bearer_token("test")
        };
        assert!(is_token_expired(&token));
    }

    #[test]
    fn test_is_token_expired_no_expiry() {
        // Tokens without an expiry timestamp are never considered expired.
        assert!(!is_token_expired(&bearer_token("test")));
    }

    #[test]
    fn test_store_and_load_oauth_token() {
        let store = InMemoryCredentialStore::new();
        let token = OAuthToken {
            refresh_token: Some("rt-test-store".to_string()),
            scopes: vec!["openai.public".to_string()],
            ..bearer_token("at-test-store")
        };

        store_oauth_token(&store, "openai", &token).unwrap();
        let loaded = load_oauth_token(&store, "openai").unwrap();
        assert_eq!(loaded.access_token, "at-test-store");
        assert_eq!(loaded.refresh_token, Some("rt-test-store".to_string()));
        assert_eq!(loaded.scopes, vec!["openai.public"]);
    }

    #[test]
    fn test_load_oauth_token_not_found() {
        let store = InMemoryCredentialStore::new();
        let result = load_oauth_token(&store, "nonexistent");
        assert!(result.is_err());
    }

    #[test]
    fn test_delete_oauth_token() {
        let store = InMemoryCredentialStore::new();
        let token = bearer_token("at-delete");

        store_oauth_token(&store, "openai", &token).unwrap();
        assert!(has_oauth_token(&store, "openai"));

        delete_oauth_token(&store, "openai").unwrap();
        assert!(!has_oauth_token(&store, "openai"));
    }

    #[test]
    fn test_has_oauth_token() {
        let store = InMemoryCredentialStore::new();
        assert!(!has_oauth_token(&store, "openai"));

        let token = bearer_token("at-has");
        store_oauth_token(&store, "openai", &token).unwrap();
        assert!(has_oauth_token(&store, "openai"));
    }

    #[test]
    fn test_openai_oauth_config() {
        let config = openai_oauth_config();
        assert_eq!(config.provider_name, "openai");
        assert_eq!(config.client_id, "app_EMoamEEZ73f0CkXaXp7hrann");
        assert!(config.authorization_url.contains("auth.openai.com"));
        assert!(config.token_url.contains("auth.openai.com"));
        assert!(config.supports_device_code);
        assert!(config.device_code_url.is_some());
        assert_eq!(
            config.scopes,
            vec!["openid", "profile", "email", "offline_access"]
        );
        assert_eq!(config.audience, None);
        assert_eq!(config.extra_auth_params.len(), 3);
    }

    #[test]
    fn test_anthropic_oauth_config_returns_none() {
        assert!(anthropic_oauth_config().is_none());
    }

    #[test]
    fn test_oauth_config_for_provider() {
        assert!(oauth_config_for_provider("openai").is_some());
        assert!(oauth_config_for_provider("anthropic").is_none());
        assert!(oauth_config_for_provider("unknown").is_none());
    }

    #[test]
    fn test_provider_supports_oauth() {
        assert!(provider_supports_oauth("openai"));
        assert!(!provider_supports_oauth("anthropic"));
        assert!(!provider_supports_oauth("unknown"));
    }

    #[test]
    fn test_auth_method_serde() {
        let json = serde_json::to_string(&AuthMethod::OAuth).unwrap();
        assert_eq!(json, "\"oauth\"");
        let method: AuthMethod = serde_json::from_str("\"api_key\"").unwrap();
        assert_eq!(method, AuthMethod::ApiKey);

        // The serde representation and the Display output must stay in sync.
        for method in [AuthMethod::ApiKey, AuthMethod::OAuth] {
            let json = serde_json::to_string(&method).unwrap();
            assert_eq!(json, format!("\"{}\"", method));
            let parsed: AuthMethod = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, method);
        }
    }

    #[test]
    fn test_auth_method_default() {
        assert_eq!(AuthMethod::default(), AuthMethod::ApiKey);
    }

    #[test]
    fn test_auth_method_display() {
        assert_eq!(AuthMethod::ApiKey.to_string(), "api_key");
        assert_eq!(AuthMethod::OAuth.to_string(), "oauth");
    }

    #[test]
    fn test_oauth_token_serde_roundtrip() {
        let token = OAuthToken {
            refresh_token: Some("rt-roundtrip".to_string()),
            expires_at: Some(Utc::now()),
            scopes: vec!["scope1".to_string(), "scope2".to_string()],
            ..bearer_token("at-roundtrip")
        };
        let json = serde_json::to_string(&token).unwrap();
        // `id_token` is `skip_serializing_if = "Option::is_none"`, so it must be omitted.
        assert!(!json.contains("id_token"));

        let parsed: OAuthToken = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.access_token, token.access_token);
        assert_eq!(parsed.refresh_token, token.refresh_token);
        assert!(parsed.id_token.is_none());
        assert_eq!(parsed.expires_at, token.expires_at);
        assert_eq!(parsed.token_type, token.token_type);
        assert_eq!(parsed.scopes, token.scopes);
    }

    #[test]
    fn test_oauth_token_deserialize_without_id_token() {
        // `id_token` carries `#[serde(default)]`, so JSON that lacks the field
        // entirely must still deserialize.
        let json = r#"{
            "access_token": "at-legacy",
            "refresh_token": null,
            "expires_at": null,
            "token_type": "Bearer",
            "scopes": []
        }"#;
        let parsed: OAuthToken = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.access_token, "at-legacy");
        assert!(parsed.id_token.is_none());
    }

    #[tokio::test]
    async fn test_callback_server_http_receives_code() {
        // The server binds the fixed OAUTH_CALLBACK_PORT, so this test fails if
        // something else on the machine already holds that port.
        let (port, rx) = start_callback_server(false).await.unwrap();
        assert_eq!(port, OAUTH_CALLBACK_PORT);

        let client = reqwest::Client::new();
        let url = format!(
            "http://127.0.0.1:{}/auth/callback?code=test-http&state=test-state-http",
            port
        );
        let response = client.get(&url).send().await.unwrap();
        assert!(response.status().is_success());

        let callback = rx.await.unwrap();
        assert_eq!(callback.code, "test-http");
        assert_eq!(callback.state, "test-state-http");
    }

    #[tokio::test]
    async fn test_tls_config_loading() {
        // Installing the process-wide crypto provider fails if one is already
        // installed (e.g. by another test); either way a provider is available.
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let config = load_tls_config().await;
        assert!(config.is_ok(), "TLS config loading should succeed");
    }

    #[test]
    fn test_slack_oauth_config() {
        let config = slack_oauth_config("slack-client-123", Some("slack-secret".into()));
        assert_eq!(config.provider_name, "slack");
        assert_eq!(config.client_id, "slack-client-123");
        assert!(
            config
                .authorization_url
                .contains("slack.com/oauth/v2/authorize")
        );
        assert!(config.token_url.contains("slack.com/api/oauth.v2.access"));
        assert!(config.scopes.contains(&"chat:write".to_string()));
        assert!(config.scopes.contains(&"channels:history".to_string()));
        assert!(config.scopes.contains(&"channels:read".to_string()));
        assert!(config.scopes.contains(&"users:read".to_string()));
        assert!(!config.supports_device_code);
    }

    #[test]
    fn test_discord_oauth_config() {
        let config = discord_oauth_config("discord-client-456", Some("discord-secret".into()));
        assert_eq!(config.provider_name, "discord");
        assert_eq!(config.client_id, "discord-client-456");
        assert!(
            config
                .authorization_url
                .contains("discord.com/api/oauth2/authorize")
        );
        assert!(config.token_url.contains("discord.com/api/oauth2/token"));
        assert!(config.scopes.contains(&"bot".to_string()));
        assert!(config.scopes.contains(&"messages.read".to_string()));
        assert!(!config.supports_device_code);
    }

    #[test]
    fn test_teams_oauth_config() {
        let config = teams_oauth_config(
            "teams-client-789",
            "my-tenant-id",
            Some("teams-secret".into()),
        );
        assert_eq!(config.provider_name, "teams");
        assert_eq!(config.client_id, "teams-client-789");
        assert!(
            config
                .authorization_url
                .contains("login.microsoftonline.com/my-tenant-id")
        );
        assert!(
            config
                .token_url
                .contains("login.microsoftonline.com/my-tenant-id")
        );
        assert!(
            config
                .scopes
                .contains(&"https://graph.microsoft.com/.default".to_string())
        );
        assert!(config.supports_device_code);
        assert!(
            config
                .device_code_url
                .as_ref()
                .unwrap()
                .contains("my-tenant-id")
        );
    }

    #[test]
    fn test_teams_oauth_config_common_tenant() {
        let config = teams_oauth_config("teams-client", "common", None);
        assert!(
            config
                .authorization_url
                .contains("common/oauth2/v2.0/authorize")
        );
        assert!(config.token_url.contains("common/oauth2/v2.0/token"));
    }

    #[test]
    fn test_whatsapp_oauth_config() {
        let config = whatsapp_oauth_config("meta-app-123", Some("meta-secret".into()));
        assert_eq!(config.provider_name, "whatsapp");
        assert_eq!(config.client_id, "meta-app-123");
        assert!(
            config
                .authorization_url
                .contains("facebook.com/v18.0/dialog/oauth")
        );
        assert!(
            config
                .token_url
                .contains("graph.facebook.com/v18.0/oauth/access_token")
        );
        assert!(
            config
                .scopes
                .contains(&"whatsapp_business_messaging".to_string())
        );
        assert!(
            config
                .scopes
                .contains(&"whatsapp_business_management".to_string())
        );
        assert!(!config.supports_device_code);
    }

    #[test]
    fn test_gmail_oauth_config() {
        let config = gmail_oauth_config("gmail-client-id", Some("gmail-secret".into()));
        assert_eq!(config.provider_name, "gmail");
        assert_eq!(config.client_id, "gmail-client-id");
        assert!(config.authorization_url.contains("accounts.google.com"));
        assert!(config.token_url.contains("oauth2.googleapis.com"));
        assert!(
            config
                .scopes
                .contains(&"https://mail.google.com/".to_string())
        );
        // Offline access must be requested so the token response can include a refresh token.
        assert!(
            config
                .extra_auth_params
                .iter()
                .any(|(k, v)| k == "access_type" && v == "offline")
        );
    }

    #[test]
    fn test_xoauth2_token_format() {
        let token = build_xoauth2_token("user@gmail.com", "ya29.access-token");
        assert_eq!(
            token,
            "user=user@gmail.com\x01auth=Bearer ya29.access-token\x01\x01"
        );
    }

    #[test]
    fn test_xoauth2_token_base64() {
        let b64 = build_xoauth2_token_base64("user@gmail.com", "token123");
        // XOAUTH2 uses standard (padded) base64, not the URL-safe variant.
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(&b64)
            .unwrap();
        let decoded_str = String::from_utf8(decoded).unwrap();
        assert!(decoded_str.starts_with("user=user@gmail.com\x01"));
        assert!(decoded_str.contains("auth=Bearer token123"));
    }

    #[test]
    fn test_oauth_config_for_channel_providers_without_env() {
        // The outcome depends on the ambient environment, so instead of asserting
        // a fixed value, check that config lookup and the support check agree:
        // both key off the same client-ID environment variables.
        let _ = oauth_config_for_provider("slack");
        for provider in ["discord", "teams", "whatsapp", "gmail"] {
            assert_eq!(
                oauth_config_for_provider(provider).is_some(),
                provider_supports_oauth(provider),
                "oauth_config_for_provider and provider_supports_oauth disagree for {}",
                provider
            );
        }
    }

    #[test]
    fn test_store_and_load_channel_oauth_token() {
        let store = InMemoryCredentialStore::new();
        let token = OAuthToken {
            refresh_token: Some("xoxr-refresh".to_string()),
            scopes: vec!["chat:write".to_string(), "channels:history".to_string()],
            ..bearer_token("xoxb-slack-token")
        };

        store_oauth_token(&store, "slack", &token).unwrap();
        let loaded = load_oauth_token(&store, "slack").unwrap();
        assert_eq!(loaded.access_token, "xoxb-slack-token");
        assert_eq!(loaded.scopes.len(), 2);

        // Storing a token for a second provider must not disturb the first.
        let teams_token = OAuthToken {
            scopes: vec!["https://graph.microsoft.com/.default".to_string()],
            ..bearer_token("eyJ-teams-token")
        };
        store_oauth_token(&store, "teams", &teams_token).unwrap();
        let loaded_teams = load_oauth_token(&store, "teams").unwrap();
        assert_eq!(loaded_teams.access_token, "eyJ-teams-token");

        let loaded_slack = load_oauth_token(&store, "slack").unwrap();
        assert_eq!(loaded_slack.access_token, "xoxb-slack-token");
    }
}