Skip to main content

systemprompt_cloud/oauth/
client.rs

1use anyhow::{Context, Result, anyhow};
2use axum::Router;
3use axum::extract::Query;
4use axum::response::Html;
5use axum::routing::get;
6use reqwest::Client;
7use std::sync::Arc;
8use systemprompt_logging::CliService;
9use tokio::sync::{Mutex, oneshot};
10
11use crate::OAuthProvider;
12use crate::constants::oauth::{CALLBACK_PORT, CALLBACK_TIMEOUT_SECS};
13
14#[derive(serde::Deserialize)]
15struct CallbackParams {
16    access_token: Option<String>,
17    error: Option<String>,
18    error_description: Option<String>,
19}
20
21#[derive(serde::Deserialize)]
22struct AuthorizeResponse {
23    authorize_url: String,
24}
25
/// Static HTML pages shown in the browser once the OAuth callback has been handled.
///
/// `success_html` is served for a callback that carries a token; `error_html` is
/// served for every other outcome.
#[derive(Clone, Copy, Debug)]
pub struct OAuthTemplates {
    /// Page rendered after a token has been received.
    pub success_html: &'static str,
    /// Page rendered when the callback reports an error, carries no token, or
    /// arrives after a result has already been recorded.
    pub error_html: &'static str,
}
31
32pub async fn run_oauth_flow(
33    api_url: &str,
34    provider: OAuthProvider,
35    templates: OAuthTemplates,
36) -> Result<String> {
37    let (tx, rx) = oneshot::channel::<Result<String>>();
38    let tx = Arc::new(Mutex::new(Some(tx)));
39
40    let success_html = templates.success_html.to_string();
41    let error_html = templates.error_html.to_string();
42
43    let callback_handler = {
44        let tx = Arc::clone(&tx);
45        let success_html = success_html.clone();
46        let error_html = error_html.clone();
47        move |Query(params): Query<CallbackParams>| {
48            let tx = Arc::clone(&tx);
49            let success_html = success_html.clone();
50            let error_html = error_html.clone();
51            async move {
52                let result = if let Some(error) = params.error {
53                    let desc = params
54                        .error_description
55                        .unwrap_or_else(|| "(no description provided)".into());
56                    Err(anyhow!("OAuth error: {} - {}", error, desc))
57                } else if let Some(token) = params.access_token {
58                    Ok(token)
59                } else {
60                    Err(anyhow!("No token received in callback"))
61                };
62
63                let sender = tx.lock().await.take();
64                if let Some(sender) = sender {
65                    let is_success = result.is_ok();
66                    if sender.send(result).is_err() {
67                        tracing::warn!("OAuth result receiver dropped before result could be sent");
68                    }
69
70                    if is_success {
71                        Html(success_html)
72                    } else {
73                        Html(error_html)
74                    }
75                } else {
76                    Html(error_html)
77                }
78            }
79        }
80    };
81
82    let app = Router::new().route("/callback", get(callback_handler));
83    let addr = format!("127.0.0.1:{CALLBACK_PORT}");
84    let listener = tokio::net::TcpListener::bind(&addr).await?;
85
86    CliService::info(&format!("Starting authentication server on http://{addr}"));
87
88    let redirect_uri = format!("http://127.0.0.1:{CALLBACK_PORT}/callback");
89
90    CliService::info("Fetching authorization URL...");
91
92    let client = Client::new();
93    let oauth_endpoint = format!(
94        "{}/api/v1/auth/oauth/{}?redirect_uri={}",
95        api_url,
96        provider.as_str(),
97        urlencoding::encode(&redirect_uri)
98    );
99
100    let response = client
101        .get(&oauth_endpoint)
102        .send()
103        .await
104        .context("Failed to connect to API")?;
105
106    if !response.status().is_success() {
107        let status = response.status();
108        let body = response.text().await.unwrap_or_else(|e| {
109            tracing::warn!(error = %e, "Failed to read OAuth error response body");
110            format!("(body unreadable: {})", e)
111        });
112        return Err(anyhow!(
113            "Failed to get authorization URL ({}): {}",
114            status,
115            body
116        ));
117    }
118
119    let auth_response: AuthorizeResponse = response
120        .json()
121        .await
122        .context("Failed to parse authorization response")?;
123
124    let auth_url = auth_response.authorize_url;
125
126    CliService::info(&format!(
127        "Opening browser for {} authentication...",
128        provider.display_name()
129    ));
130    CliService::info(&format!("URL: {auth_url}"));
131
132    if let Err(e) = open::that(&auth_url) {
133        CliService::warning(&format!("Could not open browser automatically: {e}"));
134        CliService::info("Please open this URL manually:");
135        CliService::key_value("URL", &auth_url);
136    }
137
138    CliService::info("Waiting for authentication...");
139    CliService::info(&format!("(timeout in {CALLBACK_TIMEOUT_SECS} seconds)"));
140
141    let server = axum::serve(listener, app);
142
143    tokio::select! {
144        result = rx => {
145            result.map_err(|_| anyhow!("Authentication cancelled"))?
146        }
147        _ = server => {
148            Err(anyhow!("Server stopped unexpectedly"))
149        }
150        () = tokio::time::sleep(std::time::Duration::from_secs(CALLBACK_TIMEOUT_SECS)) => {
151            Err(anyhow!("Authentication timed out after {CALLBACK_TIMEOUT_SECS} seconds"))
152        }
153    }
154}