// systemprompt_cloud/oauth/client.rs
use anyhow::{anyhow, Context, Result};
use axum::extract::Query;
use axum::response::Html;
use axum::routing::get;
use axum::Router;
use reqwest::Client;
use std::sync::Arc;
use systemprompt_logging::CliService;
use tokio::sync::{oneshot, Mutex};

use crate::constants::oauth::{CALLBACK_PORT, CALLBACK_TIMEOUT_SECS};
use crate::OAuthProvider;

14#[derive(serde::Deserialize)]
15struct CallbackParams {
16 access_token: Option<String>,
17 error: Option<String>,
18 error_description: Option<String>,
19}
20
21#[derive(serde::Deserialize)]
22struct AuthorizeResponse {
23 authorize_url: String,
24}
25
/// Static HTML pages served to the browser by the local callback route.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OAuthTemplates {
    /// Page returned when the callback delivered a token.
    pub success_html: &'static str,
    /// Page returned when the callback reported an error, carried no token,
    /// or arrived after an earlier callback was already handled.
    pub error_html: &'static str,
}

32pub async fn run_oauth_flow(
33 api_url: &str,
34 provider: OAuthProvider,
35 templates: OAuthTemplates,
36) -> Result<String> {
37 let (tx, rx) = oneshot::channel::<Result<String>>();
38 let tx = Arc::new(Mutex::new(Some(tx)));
39
40 let success_html = templates.success_html.to_string();
41 let error_html = templates.error_html.to_string();
42
43 let callback_handler = {
44 let tx = tx.clone();
45 let success_html = success_html.clone();
46 let error_html = error_html.clone();
47 move |Query(params): Query<CallbackParams>| {
48 let tx = tx.clone();
49 let success_html = success_html.clone();
50 let error_html = error_html.clone();
51 async move {
52 let result = if let Some(error) = params.error {
53 let desc = params
54 .error_description
55 .unwrap_or_else(|| "(no description provided)".into());
56 Err(anyhow!("OAuth error: {} - {}", error, desc))
57 } else if let Some(token) = params.access_token {
58 Ok(token)
59 } else {
60 Err(anyhow!("No token received in callback"))
61 };
62
63 let sender = tx.lock().await.take();
64 if let Some(sender) = sender {
65 let is_success = result.is_ok();
66 let _ = sender.send(result);
67
68 if is_success {
69 Html(success_html)
70 } else {
71 Html(error_html)
72 }
73 } else {
74 Html(error_html)
75 }
76 }
77 }
78 };
79
80 let app = Router::new().route("/callback", get(callback_handler));
81 let addr = format!("127.0.0.1:{CALLBACK_PORT}");
82 let listener = tokio::net::TcpListener::bind(&addr).await?;
83
84 CliService::info(&format!("Starting authentication server on http://{addr}"));
85
86 let redirect_uri = format!("http://127.0.0.1:{CALLBACK_PORT}/callback");
87
88 CliService::info("Fetching authorization URL...");
89
90 let client = Client::new();
91 let oauth_endpoint = format!(
92 "{}/api/v1/auth/oauth/{}?redirect_uri={}",
93 api_url,
94 provider.as_str(),
95 urlencoding::encode(&redirect_uri)
96 );
97
98 let response = client
99 .get(&oauth_endpoint)
100 .send()
101 .await
102 .context("Failed to connect to API")?;
103
104 if !response.status().is_success() {
105 let status = response.status();
106 let body = response.text().await.unwrap_or_else(|e| {
107 tracing::warn!(error = %e, "Failed to read OAuth error response body");
108 format!("(body unreadable: {})", e)
109 });
110 return Err(anyhow!(
111 "Failed to get authorization URL ({}): {}",
112 status,
113 body
114 ));
115 }
116
117 let auth_response: AuthorizeResponse = response
118 .json()
119 .await
120 .context("Failed to parse authorization response")?;
121
122 let auth_url = auth_response.authorize_url;
123
124 CliService::info(&format!(
125 "Opening browser for {} authentication...",
126 provider.display_name()
127 ));
128 CliService::info(&format!("URL: {auth_url}"));
129
130 if let Err(e) = open::that(&auth_url) {
131 CliService::warning(&format!("Could not open browser automatically: {e}"));
132 CliService::info("Please open this URL manually:");
133 CliService::key_value("URL", &auth_url);
134 }
135
136 CliService::info("Waiting for authentication...");
137 CliService::info(&format!("(timeout in {CALLBACK_TIMEOUT_SECS} seconds)"));
138
139 let server = axum::serve(listener, app);
140
141 tokio::select! {
142 result = rx => {
143 result.map_err(|_| anyhow!("Authentication cancelled"))?
144 }
145 _ = server => {
146 Err(anyhow!("Server stopped unexpectedly"))
147 }
148 () = tokio::time::sleep(std::time::Duration::from_secs(CALLBACK_TIMEOUT_SECS)) => {
149 Err(anyhow!("Authentication timed out after {CALLBACK_TIMEOUT_SECS} seconds"))
150 }
151 }
152}