1use santui_core::auth::{AuthHandle, User};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::io::{BufRead, BufReader, Write};
5use std::net::TcpListener;
6use std::path::PathBuf;
7use std::sync::Mutex;
8use std::time::Duration;
9use url::Url;
10
/// On-disk form of a signed-in session, serialized as JSON into the client's
/// token file and read back to restore the user on startup.
///
/// NOTE(review): this derives `Debug` while holding `access_token` and
/// `refresh_token`; formatting a value with `{:?}` would print the secrets.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StoredToken {
    id: String,
    email: String,
    name: String,
    avatar_url: Option<String>,
    /// Provider the session came from (`"google"` or `"github"` in this file).
    provider: String,
    /// Bearer token obtained during sign-in.
    access_token: String,
    /// The sign-in flows in this file always store `None` here.
    refresh_token: Option<String>,
}
21
/// OAuth client and endpoint settings for one sign-in provider.
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Client identifier registered with the provider.
    pub client_id: String,
    /// Client secret, for providers whose flow uses one.
    pub client_secret: Option<String>,
    /// Authorization endpoint; the GitHub preset leaves it empty.
    pub auth_uri: String,
    /// Token endpoint.
    pub token_uri: String,
    /// Scopes requested during sign-in.
    pub scopes: Vec<String>,
    /// Port for the local redirect listener; the GitHub preset sets 0.
    pub redirect_port: u16,
}

impl AuthConfig {
    /// Preset for Google: Google's OAuth endpoints, the `openid email profile`
    /// scopes and redirect port 9842.
    pub fn google(client_id: String, client_secret: Option<String>) -> Self {
        let scopes = Vec::from(["openid", "email", "profile"].map(String::from));
        Self {
            client_id,
            client_secret,
            auth_uri: String::from("https://accounts.google.com/o/oauth2/v2/auth"),
            token_uri: String::from("https://oauth2.googleapis.com/token"),
            scopes,
            redirect_port: 9842,
        }
    }

    /// Preset for GitHub's device flow: no client secret, no authorization
    /// URI and no redirect port.
    pub fn github(client_id: String) -> Self {
        Self {
            scopes: Vec::from(["read:user", "user:email"].map(String::from)),
            token_uri: String::from("https://github.com/login/oauth/access_token"),
            redirect_port: 0,
            auth_uri: String::new(),
            client_secret: None,
            client_id,
        }
    }
}
55
/// Asks Windows to open `url` in the default browser via `cmd /c start`.
///
/// Best-effort: a failure to spawn is ignored and the child is not waited on.
#[cfg(target_os = "windows")]
fn open_browser(url: &str) {
    // `start` treats its first quoted argument as the window title, so an
    // explicit empty title keeps a quoted URL from being swallowed. `&` must be
    // caret-escaped or cmd.exe splits the command line at it.
    let _ = std::process::Command::new("cmd")
        .args(["/c", "start", "", &url.replace('&', "^&")])
        .spawn();
}

/// Asks macOS to open `url` in the default browser via `open`.
///
/// Best-effort: a failure to spawn is ignored and the child is not waited on.
#[cfg(target_os = "macos")]
fn open_browser(url: &str) {
    let _ = std::process::Command::new("open").arg(url).spawn();
}

/// Asks the desktop environment to open `url` via `xdg-open`. This covers Linux
/// and is also used as the fallback on every other non-Windows, non-macOS target.
///
/// Best-effort: a failure to spawn is ignored and the child is not waited on.
#[cfg(not(any(target_os = "windows", target_os = "macos")))]
fn open_browser(url: &str) {
    let _ = std::process::Command::new("xdg-open").arg(url).spawn();
}
77
78fn handle_redirect(
79 listener: TcpListener,
80) -> Result<HashMap<String, String>, Box<dyn std::error::Error>> {
81 let (stream, _) = listener.accept()?;
82 stream.set_read_timeout(Some(Duration::from_secs(120)))?;
83 let mut reader = BufReader::new(&stream);
84 let mut request_line = String::new();
85 reader.read_line(&mut request_line)?;
86
87 let params = request_line
88 .split_whitespace()
89 .nth(1)
90 .and_then(|path| {
91 let full_url = format!("http://localhost{path}");
92 Url::parse(&full_url).ok().map(|u| {
93 u.query_pairs()
94 .map(|(k, v)| (k.into_owned(), v.into_owned()))
95 .collect::<HashMap<String, String>>()
96 })
97 })
98 .ok_or_else(|| "No query parameters in redirect".to_string())?;
99
100 let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<!DOCTYPE html><html lang=\"en\"><head><meta charset=\"UTF-8\"><script src=\"https://cdn.tailwindcss.com\"></script><title>Santui — Signed In</title></head><body class=\"bg-gradient-to-br from-gray-900 via-slate-800 to-gray-900 min-h-screen flex items-center justify-center font-sans\"><div class=\"bg-white/10 backdrop-blur-lg rounded-lg shadow-2xl border border-white/20 p-8 max-w-md w-full mx-4 text-center\"><div class=\"text-emerald-400 mb-4\"><svg class=\"w-16 h-16 mx-auto mb-4\" fill=\"none\" stroke=\"currentColor\" viewBox=\"0 0 24 24\"><path stroke-linecap=\"round\" stroke-linejoin=\"round\" stroke-width=\"1.5\" d=\"M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z\"/></svg><h1 class=\"text-2xl font-bold mb-1\">Signed In!</h1><p class=\"text-gray-400 text-sm\">You can close this window.</p></div></div></body></html>";
101 let mut stream = stream;
102 let _ = stream.write_all(response.as_bytes());
103
104 if let Some(err) = params.get("error") {
105 return Err(format!("OAuth error from server: {err}").into());
106 }
107
108 Ok(params)
109}
110
/// Response from GitHub's device-authorization endpoint
/// (`POST https://github.com/login/device/code`).
#[derive(Deserialize)]
struct DeviceCodeResponse {
    /// Code sent back to the token endpoint while polling for approval.
    device_code: String,
    /// User-facing code, passed to the activation page URL.
    user_code: String,
    // Deserialized but not currently read; the activation URL is built from a
    // hard-coded GitHub address instead, hence the lint allowance.
    #[allow(dead_code)]
    verification_uri: String,
    /// Polling interval in seconds; callers fall back to 5 when it is absent.
    interval: Option<u64>,
}
119
/// One polling response from the device-flow token endpoint. It carries either
/// an `access_token` or an `error` code such as `authorization_pending`.
#[derive(Deserialize)]
struct DeviceTokenResponse {
    access_token: Option<String>,
    error: Option<String>,
}
125
126fn request_device_code(
127 config: &AuthConfig,
128) -> Result<DeviceCodeResponse, Box<dyn std::error::Error>> {
129 let scope = config.scopes.join(" ");
130 let mut resp = ureq::post("https://github.com/login/device/code")
131 .header("Accept", "application/json")
132 .send_form([
133 ("client_id", config.client_id.as_str()),
134 ("scope", scope.as_str()),
135 ])?;
136 let text = resp.body_mut().read_to_string()?;
137 Ok(serde_json::from_str(&text)?)
138}
139
140fn poll_device_token(
141 config: &AuthConfig,
142 device_code: &str,
143 interval: u64,
144) -> Result<String, Box<dyn std::error::Error>> {
145 loop {
146 std::thread::sleep(std::time::Duration::from_secs(interval));
147 let mut resp = ureq::post(&config.token_uri)
148 .header("Accept", "application/json")
149 .send_form([
150 ("client_id", config.client_id.as_str()),
151 ("device_code", device_code),
152 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
153 ])?;
154 let text = resp.body_mut().read_to_string()?;
155 let body: DeviceTokenResponse = serde_json::from_str(&text)?;
156 if let Some(token) = body.access_token {
157 return Ok(token);
158 }
159 match body.error.as_deref() {
160 Some("authorization_pending") => continue,
161 Some("slow_down") => continue,
162 Some(err) => return Err(format!("device flow error: {err}").into()),
163 None => return Err("unexpected device flow response".into()),
164 }
165 }
166}
167
168fn user_from_token(provider: &str, access_token: &str) -> Result<User, Box<dyn std::error::Error>> {
169 match provider {
170 "github" => {
171 let mut resp = ureq::get("https://api.github.com/user")
172 .header("Authorization", &format!("Bearer {access_token}"))
173 .header("Accept", "application/vnd.github.v3+json")
174 .call()?;
175 let body: serde_json::Value = serde_json::from_str(&resp.body_mut().read_to_string()?)?;
176 Ok(User {
177 id: body["id"].to_string(),
178 email: body["email"].as_str().unwrap_or("").into(),
179 name: body["login"].as_str().unwrap_or("").into(),
180 avatar_url: body["avatar_url"].as_str().map(|s| s.into()),
181 provider: provider.into(),
182 })
183 }
184 _ => Err("unsupported provider".into()),
185 }
186}
187
/// Desktop sign-in client implementing [`AuthHandle`].
///
/// Holds the provider configs, the in-memory current user, and the path of the
/// JSON file where the session (including its access token) is persisted.
pub struct AuthClient {
    /// Provider configs keyed by provider name (e.g. `"github"`).
    providers: HashMap<String, AuthConfig>,
    /// Currently signed-in user; `None` when signed out.
    user: Mutex<Option<User>>,
    /// Location of the persisted session file.
    token_path: PathBuf,
    /// Base URL of the web service that brokers Google sign-in; an empty
    /// string means the built-in default is used.
    vercel_url: String,
}
194
195impl AuthClient {
196 pub fn new(providers: Vec<(String, AuthConfig)>) -> Self {
197 let token_path = dirs::data_dir()
198 .unwrap_or_else(|| PathBuf::from("."))
199 .join("santui")
200 .join("auth-tokens.json");
201 let user = Self::load_tokens(&token_path);
202 AuthClient {
203 providers: providers.into_iter().collect(),
204 user: Mutex::new(user),
205 token_path,
206 vercel_url: String::new(),
207 }
208 }
209
210 pub fn with_vercel(mut self, url: String) -> Self {
211 self.vercel_url = url;
212 self
213 }
214
215 fn load_tokens(path: &PathBuf) -> Option<User> {
216 let data = std::fs::read_to_string(path).ok()?;
217 let stored: StoredToken = serde_json::from_str(&data).ok()?;
218 Some(User {
219 id: stored.id,
220 email: stored.email,
221 name: stored.name,
222 avatar_url: stored.avatar_url,
223 provider: stored.provider,
224 })
225 }
226
227 fn save_tokens(&self, stored: &StoredToken) {
228 if let Some(parent) = self.token_path.parent() {
229 let _ = std::fs::create_dir_all(parent);
230 }
231 if let Ok(data) = serde_json::to_string_pretty(stored) {
232 let _ = std::fs::write(&self.token_path, data);
233 }
234 }
235
236 fn clear_tokens(&self) {
237 let _ = std::fs::remove_file(&self.token_path);
238 }
239
240 fn sign_in_google(&self) -> Result<User, Box<dyn std::error::Error>> {
241 let port = 9842;
242 let listener = TcpListener::bind(("127.0.0.1", port))?;
243
244 let vercel = if self.vercel_url.is_empty() {
245 "https://santuiapp.vercel.app".to_string()
246 } else {
247 self.vercel_url.clone()
248 };
249 let auth_url = format!("{vercel}/api/auth/google?port={port}");
250 open_browser(&auth_url);
251
252 let params = handle_redirect(listener)?;
253
254 let access_token = params
255 .get("access_token")
256 .ok_or_else(|| "No access_token in redirect".to_string())?;
257 let user = User {
258 id: params.get("id").cloned().unwrap_or_default(),
259 email: params.get("email").cloned().unwrap_or_default(),
260 name: params.get("name").cloned().unwrap_or_default(),
261 avatar_url: params.get("avatar_url").cloned(),
262 provider: "google".into(),
263 };
264
265 self.save_tokens(&StoredToken {
266 id: user.id.clone(),
267 email: user.email.clone(),
268 name: user.name.clone(),
269 avatar_url: user.avatar_url.clone(),
270 provider: user.provider.clone(),
271 access_token: access_token.clone(),
272 refresh_token: None,
273 });
274 *self.user.lock().unwrap_or_else(|e| e.into_inner()) = Some(user.clone());
275
276 Ok(user)
277 }
278
279 fn sign_in_github(&self) -> Result<User, Box<dyn std::error::Error>> {
280 let config = self
281 .providers
282 .get("github")
283 .ok_or_else(|| "GitHub auth not configured".to_string())?;
284
285 let device = request_device_code(config)?;
286 let user_code = &device.user_code;
287 let interval = device.interval.unwrap_or(5);
288
289 let activation_url = format!("https://github.com/login/device?user_code={user_code}");
290 open_browser(&activation_url);
291
292 let access_token = poll_device_token(config, &device.device_code, interval)?;
293 let user = user_from_token("github", &access_token)?;
294
295 self.save_tokens(&StoredToken {
296 id: user.id.clone(),
297 email: user.email.clone(),
298 name: user.name.clone(),
299 avatar_url: user.avatar_url.clone(),
300 provider: user.provider.clone(),
301 access_token,
302 refresh_token: None,
303 });
304 *self.user.lock().unwrap_or_else(|e| e.into_inner()) = Some(user.clone());
305
306 Ok(user)
307 }
308}
309
310impl AuthHandle for AuthClient {
311 fn current_user(&self) -> Option<User> {
312 self.user.lock().unwrap_or_else(|e| e.into_inner()).clone()
313 }
314
315 fn bearer_token(&self) -> Option<String> {
316 let data = std::fs::read_to_string(&self.token_path).ok()?;
317 let stored: StoredToken = serde_json::from_str(&data).ok()?;
318 Some(stored.access_token)
319 }
320
321 fn sign_in(&self, provider: &str) -> Result<User, Box<dyn std::error::Error>> {
322 match provider {
323 "google" => self.sign_in_google(),
324 "github" => self.sign_in_github(),
325 _ => Err("unsupported provider".into()),
326 }
327 }
328
329 fn sign_out(&self) {
330 self.clear_tokens();
331 *self.user.lock().unwrap_or_else(|e| e.into_inner()) = None;
332 }
333}