Skip to main content

systemprompt_cloud/oauth/
client.rs

1use anyhow::{anyhow, Context, Result};
2use axum::extract::Query;
3use axum::response::Html;
4use axum::routing::get;
5use axum::Router;
6use reqwest::Client;
7use std::sync::Arc;
8use systemprompt_logging::CliService;
9use tokio::sync::{oneshot, Mutex};
10
11use crate::constants::oauth::{CALLBACK_PORT, CALLBACK_TIMEOUT_SECS};
12use crate::OAuthProvider;
13
/// Query parameters delivered by the OAuth redirect to the local `/callback` route.
///
/// Every field is optional. The callback handler treats a present `error` as a
/// failed login (using `error_description` as detail when given), and otherwise
/// expects an `access_token`.
#[derive(serde::Deserialize)]
struct CallbackParams {
    /// Access token issued on successful authentication.
    access_token: Option<String>,
    /// Error value reported by the authorization server, if authentication failed.
    error: Option<String>,
    /// Optional human-readable detail accompanying `error`.
    error_description: Option<String>,
}
20
/// JSON body returned by the API's `/api/v1/auth/oauth/{provider}` endpoint.
///
/// Only `authorize_url` is read. No `deny_unknown_fields` is set, so serde ignores
/// any other fields in the response.
#[derive(serde::Deserialize)]
struct AuthorizeResponse {
    /// Authorization URL the user's browser is sent to.
    authorize_url: String,
}
25
/// Static HTML pages served to the browser by the local OAuth callback listener.
///
/// The type is `Copy`, so it is cheap to pass by value into `run_oauth_flow`.
#[derive(Debug, Clone, Copy)]
pub struct OAuthTemplates {
    /// Page shown when the callback delivers an access token.
    pub success_html: &'static str,
    /// Page shown when the callback reports an error, carries no token, or arrives
    /// after an earlier callback has already been handled.
    pub error_html: &'static str,
}
31
32pub async fn run_oauth_flow(
33    api_url: &str,
34    provider: OAuthProvider,
35    templates: OAuthTemplates,
36) -> Result<String> {
37    let (tx, rx) = oneshot::channel::<Result<String>>();
38    let tx = Arc::new(Mutex::new(Some(tx)));
39
40    let success_html = templates.success_html.to_string();
41    let error_html = templates.error_html.to_string();
42
43    let callback_handler = {
44        let tx = tx.clone();
45        let success_html = success_html.clone();
46        let error_html = error_html.clone();
47        move |Query(params): Query<CallbackParams>| {
48            let tx = tx.clone();
49            let success_html = success_html.clone();
50            let error_html = error_html.clone();
51            async move {
52                let result = if let Some(error) = params.error {
53                    let desc = params
54                        .error_description
55                        .unwrap_or_else(|| "(no description provided)".into());
56                    Err(anyhow!("OAuth error: {} - {}", error, desc))
57                } else if let Some(token) = params.access_token {
58                    Ok(token)
59                } else {
60                    Err(anyhow!("No token received in callback"))
61                };
62
63                let sender = tx.lock().await.take();
64                if let Some(sender) = sender {
65                    let is_success = result.is_ok();
66                    let _ = sender.send(result);
67
68                    if is_success {
69                        Html(success_html)
70                    } else {
71                        Html(error_html)
72                    }
73                } else {
74                    Html(error_html)
75                }
76            }
77        }
78    };
79
80    let app = Router::new().route("/callback", get(callback_handler));
81    let addr = format!("127.0.0.1:{CALLBACK_PORT}");
82    let listener = tokio::net::TcpListener::bind(&addr).await?;
83
84    CliService::info(&format!("Starting authentication server on http://{addr}"));
85
86    let redirect_uri = format!("http://127.0.0.1:{CALLBACK_PORT}/callback");
87
88    CliService::info("Fetching authorization URL...");
89
90    let client = Client::new();
91    let oauth_endpoint = format!(
92        "{}/api/v1/auth/oauth/{}?redirect_uri={}",
93        api_url,
94        provider.as_str(),
95        urlencoding::encode(&redirect_uri)
96    );
97
98    let response = client
99        .get(&oauth_endpoint)
100        .send()
101        .await
102        .context("Failed to connect to API")?;
103
104    if !response.status().is_success() {
105        let status = response.status();
106        let body = response.text().await.unwrap_or_else(|e| {
107            tracing::warn!(error = %e, "Failed to read OAuth error response body");
108            format!("(body unreadable: {})", e)
109        });
110        return Err(anyhow!(
111            "Failed to get authorization URL ({}): {}",
112            status,
113            body
114        ));
115    }
116
117    let auth_response: AuthorizeResponse = response
118        .json()
119        .await
120        .context("Failed to parse authorization response")?;
121
122    let auth_url = auth_response.authorize_url;
123
124    CliService::info(&format!(
125        "Opening browser for {} authentication...",
126        provider.display_name()
127    ));
128    CliService::info(&format!("URL: {auth_url}"));
129
130    if let Err(e) = open::that(&auth_url) {
131        CliService::warning(&format!("Could not open browser automatically: {e}"));
132        CliService::info("Please open this URL manually:");
133        CliService::key_value("URL", &auth_url);
134    }
135
136    CliService::info("Waiting for authentication...");
137    CliService::info(&format!("(timeout in {CALLBACK_TIMEOUT_SECS} seconds)"));
138
139    let server = axum::serve(listener, app);
140
141    tokio::select! {
142        result = rx => {
143            result.map_err(|_| anyhow!("Authentication cancelled"))?
144        }
145        _ = server => {
146            Err(anyhow!("Server stopped unexpectedly"))
147        }
148        () = tokio::time::sleep(std::time::Duration::from_secs(CALLBACK_TIMEOUT_SECS)) => {
149            Err(anyhow!("Authentication timed out after {CALLBACK_TIMEOUT_SECS} seconds"))
150        }
151    }
152}