systemprompt_cloud/oauth/
client.rs1use anyhow::{Context, Result, anyhow};
2use axum::Router;
3use axum::extract::Query;
4use axum::response::Html;
5use axum::routing::get;
6use reqwest::Client;
7use std::sync::Arc;
8use systemprompt_logging::CliService;
9use tokio::sync::{Mutex, oneshot};
10
11use crate::OAuthProvider;
12use crate::constants::oauth::{CALLBACK_PORT, CALLBACK_TIMEOUT_SECS};
13
14#[derive(serde::Deserialize)]
15struct CallbackParams {
16 access_token: Option<String>,
17 error: Option<String>,
18 error_description: Option<String>,
19}
20
21#[derive(serde::Deserialize)]
22struct AuthorizeResponse {
23 authorize_url: String,
24}
25
/// Static HTML pages served to the browser once the OAuth callback arrives.
#[derive(Clone, Copy, Debug)]
pub struct OAuthTemplates {
    /// Page returned when the callback carried a usable result.
    pub success_html: &'static str,
    /// Page returned when the callback reported a failure.
    pub error_html: &'static str,
}
31
32pub async fn run_oauth_flow(
33 api_url: &str,
34 provider: OAuthProvider,
35 templates: OAuthTemplates,
36) -> Result<String> {
37 let (tx, rx) = oneshot::channel::<Result<String>>();
38 let tx = Arc::new(Mutex::new(Some(tx)));
39
40 let success_html = templates.success_html.to_string();
41 let error_html = templates.error_html.to_string();
42
43 let callback_handler = {
44 let tx = Arc::clone(&tx);
45 let success_html = success_html.clone();
46 let error_html = error_html.clone();
47 move |Query(params): Query<CallbackParams>| {
48 let tx = Arc::clone(&tx);
49 let success_html = success_html.clone();
50 let error_html = error_html.clone();
51 async move {
52 let result = if let Some(error) = params.error {
53 let desc = params
54 .error_description
55 .unwrap_or_else(|| "(no description provided)".into());
56 Err(anyhow!("OAuth error: {} - {}", error, desc))
57 } else if let Some(token) = params.access_token {
58 Ok(token)
59 } else {
60 Err(anyhow!("No token received in callback"))
61 };
62
63 let sender = tx.lock().await.take();
64 if let Some(sender) = sender {
65 let is_success = result.is_ok();
66 if sender.send(result).is_err() {
67 tracing::warn!("OAuth result receiver dropped before result could be sent");
68 }
69
70 if is_success {
71 Html(success_html)
72 } else {
73 Html(error_html)
74 }
75 } else {
76 Html(error_html)
77 }
78 }
79 }
80 };
81
82 let app = Router::new().route("/callback", get(callback_handler));
83 let addr = format!("127.0.0.1:{CALLBACK_PORT}");
84 let listener = tokio::net::TcpListener::bind(&addr).await?;
85
86 CliService::info(&format!("Starting authentication server on http://{addr}"));
87
88 let redirect_uri = format!("http://127.0.0.1:{CALLBACK_PORT}/callback");
89
90 CliService::info("Fetching authorization URL...");
91
92 let client = Client::new();
93 let oauth_endpoint = format!(
94 "{}/api/v1/auth/oauth/{}?redirect_uri={}",
95 api_url,
96 provider.as_str(),
97 urlencoding::encode(&redirect_uri)
98 );
99
100 let response = client
101 .get(&oauth_endpoint)
102 .send()
103 .await
104 .context("Failed to connect to API")?;
105
106 if !response.status().is_success() {
107 let status = response.status();
108 let body = response.text().await.unwrap_or_else(|e| {
109 tracing::warn!(error = %e, "Failed to read OAuth error response body");
110 format!("(body unreadable: {})", e)
111 });
112 return Err(anyhow!(
113 "Failed to get authorization URL ({}): {}",
114 status,
115 body
116 ));
117 }
118
119 let auth_response: AuthorizeResponse = response
120 .json()
121 .await
122 .context("Failed to parse authorization response")?;
123
124 let auth_url = auth_response.authorize_url;
125
126 CliService::info(&format!(
127 "Opening browser for {} authentication...",
128 provider.display_name()
129 ));
130 CliService::info(&format!("URL: {auth_url}"));
131
132 if let Err(e) = open::that(&auth_url) {
133 CliService::warning(&format!("Could not open browser automatically: {e}"));
134 CliService::info("Please open this URL manually:");
135 CliService::key_value("URL", &auth_url);
136 }
137
138 CliService::info("Waiting for authentication...");
139 CliService::info(&format!("(timeout in {CALLBACK_TIMEOUT_SECS} seconds)"));
140
141 let server = axum::serve(listener, app);
142
143 tokio::select! {
144 result = rx => {
145 result.map_err(|_| anyhow!("Authentication cancelled"))?
146 }
147 _ = server => {
148 Err(anyhow!("Server stopped unexpectedly"))
149 }
150 () = tokio::time::sleep(std::time::Duration::from_secs(CALLBACK_TIMEOUT_SECS)) => {
151 Err(anyhow!("Authentication timed out after {CALLBACK_TIMEOUT_SECS} seconds"))
152 }
153 }
154}