// santui_auth/lib.rs

use santui_core::auth::{AuthHandle, User};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpListener;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;
use url::Url;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14struct StoredToken {
15    id: String,
16    email: String,
17    name: String,
18    avatar_url: Option<String>,
19    provider: String,
20    access_token: String,
21    refresh_token: Option<String>,
22}
23
/// OAuth settings for one identity provider.
#[derive(Clone)]
pub struct AuthConfig {
    pub client_id: String,
    pub client_secret: Option<String>,
    pub auth_uri: String,
    pub token_uri: String,
    pub scopes: Vec<String>,
    pub redirect_port: u16,
}

// `Debug` is written by hand so `client_secret` is never printed; the output
// still shows whether a secret is configured.
impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            .field("client_id", &self.client_id)
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("auth_uri", &self.auth_uri)
            .field("token_uri", &self.token_uri)
            .field("scopes", &self.scopes)
            .field("redirect_port", &self.redirect_port)
            .finish()
    }
}

impl AuthConfig {
    /// Builds the Google configuration: Google's authorization and token
    /// endpoints, the `openid`, `email` and `profile` scopes, and redirect
    /// port 9842.
    pub fn google(client_id: String, client_secret: Option<String>) -> Self {
        AuthConfig {
            client_id,
            client_secret,
            auth_uri: "https://accounts.google.com/o/oauth2/v2/auth".into(),
            token_uri: "https://oauth2.googleapis.com/token".into(),
            scopes: vec!["openid".into(), "email".into(), "profile".into()],
            redirect_port: 9842,
        }
    }

    /// Builds the GitHub configuration with the `read:user` and `user:email`
    /// scopes and GitHub's access-token endpoint.
    ///
    /// `auth_uri` is left empty, `redirect_port` is 0 and no client secret is
    /// set: the GitHub sign-in path in this file uses the device flow, which
    /// reads only `client_id`, `scopes` and `token_uri`.
    pub fn github(client_id: String) -> Self {
        AuthConfig {
            client_id,
            client_secret: None,
            auth_uri: String::new(),
            token_uri: "https://github.com/login/oauth/access_token".into(),
            scopes: vec!["read:user".into(), "user:email".into()],
            redirect_port: 0,
        }
    }
}
57
/// Asks the operating system to open `url` with its default handler.
///
/// Windows goes through `cmd /c start` (with every `&` rewritten to `^&`,
/// presumably so `cmd` does not treat it as a command separator), macOS uses
/// `open`, and every other target uses `xdg-open`.
///
/// Best-effort: a failure to spawn the launcher is discarded.
fn open_browser(url: &str) {
    let mut launcher = if cfg!(target_os = "windows") {
        let mut cmd = std::process::Command::new("cmd");
        cmd.args(["/c", "start", &url.replace('&', "^&")]);
        cmd
    } else if cfg!(target_os = "macos") {
        let mut cmd = std::process::Command::new("open");
        cmd.arg(url);
        cmd
    } else {
        let mut cmd = std::process::Command::new("xdg-open");
        cmd.arg(url);
        cmd
    };
    let _ = launcher.spawn();
}
79
/// Binds a loopback listener for the OAuth redirect.
///
/// Ports 9842 through 9849 are tried in order; if none can be bound the OS
/// picks a free port (bind to port 0). Returns the listener together with the
/// port it actually listens on.
///
/// # Errors
/// Fails only when the OS-assigned fallback bind, or reading back its local
/// address, fails.
fn bind_with_fallback() -> Result<(TcpListener, u16), Box<dyn std::error::Error>> {
    let preferred = (9842u16..9850).find_map(|candidate| {
        TcpListener::bind(("127.0.0.1", candidate))
            .ok()
            .map(|bound| (bound, candidate))
    });

    match preferred {
        Some(pair) => Ok(pair),
        None => {
            let any_port = TcpListener::bind(("127.0.0.1", 0))?;
            let assigned = any_port.local_addr()?.port();
            Ok((any_port, assigned))
        }
    }
}
90
91fn handle_redirect(
92    listener: TcpListener,
93) -> Result<HashMap<String, String>, Box<dyn std::error::Error>> {
94    let (stream, _) = listener.accept()?;
95    stream.set_read_timeout(Some(Duration::from_secs(120)))?;
96    let mut reader = BufReader::new(&stream);
97    let mut request_line = String::new();
98    reader.read_line(&mut request_line)?;
99
100    let params = request_line
101        .split_whitespace()
102        .nth(1)
103        .and_then(|path| {
104            let full_url = format!("http://localhost{path}");
105            Url::parse(&full_url).ok().map(|u| {
106                u.query_pairs()
107                    .map(|(k, v)| (k.into_owned(), v.into_owned()))
108                    .collect::<HashMap<String, String>>()
109            })
110        })
111        .ok_or_else(|| "No query parameters in redirect".to_string())?;
112
113    let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<!DOCTYPE html><html lang=\"en\"><head><meta charset=\"UTF-8\"><script src=\"https://cdn.tailwindcss.com\"></script><title>Santui — Signed In</title></head><body class=\"bg-gradient-to-br from-gray-900 via-slate-800 to-gray-900 min-h-screen flex items-center justify-center font-sans\"><div class=\"bg-white/10 backdrop-blur-lg rounded-lg shadow-2xl border border-white/20 p-8 max-w-md w-full mx-4 text-center\"><div class=\"text-emerald-400 mb-4\"><svg class=\"w-16 h-16 mx-auto mb-4\" fill=\"none\" stroke=\"currentColor\" viewBox=\"0 0 24 24\"><path stroke-linecap=\"round\" stroke-linejoin=\"round\" stroke-width=\"1.5\" d=\"M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z\"/></svg><h1 class=\"text-2xl font-bold mb-1\">Signed In!</h1><p class=\"text-gray-400 text-sm\">You can close this window.</p></div></div></body></html>";
114    let mut stream = stream;
115    let _ = stream.write_all(response.as_bytes());
116
117    if let Some(err) = params.get("error") {
118        return Err(format!("OAuth error from server: {err}").into());
119    }
120
121    Ok(params)
122}
123
124#[derive(Deserialize)]
125struct DeviceCodeResponse {
126    device_code: String,
127    user_code: String,
128    #[allow(dead_code)]
129    verification_uri: String,
130    interval: Option<u64>,
131}
132
133#[derive(Deserialize)]
134struct DeviceTokenResponse {
135    access_token: Option<String>,
136    error: Option<String>,
137}
138
139fn request_device_code(
140    config: &AuthConfig,
141) -> Result<DeviceCodeResponse, Box<dyn std::error::Error>> {
142    let scope = config.scopes.join(" ");
143    let mut resp = ureq::post("https://github.com/login/device/code")
144        .header("Accept", "application/json")
145        .send_form([
146            ("client_id", config.client_id.as_str()),
147            ("scope", scope.as_str()),
148        ])?;
149    let text = resp.body_mut().read_to_string()?;
150    Ok(serde_json::from_str(&text)?)
151}
152
153fn poll_device_token(
154    config: &AuthConfig,
155    device_code: &str,
156    interval: u64,
157) -> Result<String, Box<dyn std::error::Error>> {
158    loop {
159        std::thread::sleep(std::time::Duration::from_secs(interval));
160        let mut resp = ureq::post(&config.token_uri)
161            .header("Accept", "application/json")
162            .send_form([
163                ("client_id", config.client_id.as_str()),
164                ("device_code", device_code),
165                ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
166            ])?;
167        let text = resp.body_mut().read_to_string()?;
168        let body: DeviceTokenResponse = serde_json::from_str(&text)?;
169        if let Some(token) = body.access_token {
170            return Ok(token);
171        }
172        match body.error.as_deref() {
173            Some("authorization_pending") => continue,
174            Some("slow_down") => continue,
175            Some(err) => return Err(format!("device flow error: {err}").into()),
176            None => return Err("unexpected device flow response".into()),
177        }
178    }
179}
180
181fn user_from_token(provider: &str, access_token: &str) -> Result<User, Box<dyn std::error::Error>> {
182    match provider {
183        "github" => {
184            let mut resp = ureq::get("https://api.github.com/user")
185                .header("Authorization", &format!("Bearer {access_token}"))
186                .header("Accept", "application/vnd.github.v3+json")
187                .call()?;
188            let body: serde_json::Value = serde_json::from_str(&resp.body_mut().read_to_string()?)?;
189            Ok(User {
190                id: body["id"].to_string(),
191                email: body["email"].as_str().unwrap_or("").into(),
192                name: body["login"].as_str().unwrap_or("").into(),
193                avatar_url: body["avatar_url"].as_str().map(|s| s.into()),
194                provider: provider.into(),
195            })
196        }
197        _ => Err("unsupported provider".into()),
198    }
199}
200
201pub struct AuthClient {
202    providers: HashMap<String, AuthConfig>,
203    user: Arc<Mutex<Option<User>>>,
204    pending_sign_in: Arc<Mutex<Option<Result<User, String>>>>,
205    auth_msg: Arc<Mutex<Option<String>>>,
206    token_path: PathBuf,
207    vercel_url: String,
208}
209
210impl AuthClient {
211    pub fn new(providers: Vec<(String, AuthConfig)>) -> Self {
212        let token_path = dirs::data_dir()
213            .unwrap_or_else(|| PathBuf::from("."))
214            .join("santui")
215            .join("auth-tokens.json");
216        let user = Self::load_tokens(&token_path);
217        AuthClient {
218            providers: providers.into_iter().collect(),
219            user: Arc::new(Mutex::new(user)),
220            pending_sign_in: Arc::new(Mutex::new(None)),
221            auth_msg: Arc::new(Mutex::new(None)),
222            token_path,
223            vercel_url: String::new(),
224        }
225    }
226
227    pub fn with_vercel(mut self, url: String) -> Self {
228        self.vercel_url = url;
229        self
230    }
231
232    fn load_tokens(path: &PathBuf) -> Option<User> {
233        let data = std::fs::read_to_string(path).ok()?;
234        let stored: StoredToken = serde_json::from_str(&data).ok()?;
235        Some(User {
236            id: stored.id,
237            email: stored.email,
238            name: stored.name,
239            avatar_url: stored.avatar_url,
240            provider: stored.provider,
241        })
242    }
243
244    fn clear_tokens(&self) {
245        let _ = std::fs::remove_file(&self.token_path);
246    }
247
248    fn run_google_redirect_flow(
249        vercel_url: &str,
250        token_path: &PathBuf,
251        user_lock: &Arc<Mutex<Option<User>>>,
252        pending: &Arc<Mutex<Option<Result<User, String>>>>,
253        auth_msg: &Arc<Mutex<Option<String>>>,
254    ) {
255        let vercel = if vercel_url.is_empty() {
256            "https://santuiapp.vercel.app".to_string()
257        } else {
258            vercel_url.to_string()
259        };
260
261        let (listener, port) = match bind_with_fallback() {
262            Ok(v) => v,
263            Err(e) => {
264                *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Err(e.to_string()));
265                *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
266                return;
267            }
268        };
269        let auth_url = format!("{vercel}/api/auth/google?port={port}");
270        *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) =
271            Some("Google: waiting for browser…".into());
272        open_browser(&auth_url);
273
274        let params = match handle_redirect(listener) {
275            Ok(p) => p,
276            Err(e) => {
277                *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Err(e.to_string()));
278                *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
279                return;
280            }
281        };
282
283        let access_token = match params.get("access_token") {
284            Some(t) => t.clone(),
285            None => {
286                *pending.lock().unwrap_or_else(|e| e.into_inner()) =
287                    Some(Err("No access_token in redirect".into()));
288                *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
289                return;
290            }
291        };
292
293        let user = User {
294            id: params.get("id").cloned().unwrap_or_default(),
295            email: params.get("email").cloned().unwrap_or_default(),
296            name: params.get("name").cloned().unwrap_or_default(),
297            avatar_url: params.get("avatar_url").cloned(),
298            provider: "google".into(),
299        };
300
301        let stored = StoredToken {
302            id: user.id.clone(),
303            email: user.email.clone(),
304            name: user.name.clone(),
305            avatar_url: user.avatar_url.clone(),
306            provider: user.provider.clone(),
307            access_token,
308            refresh_token: None,
309        };
310        save_tokens_to_path(token_path, &stored);
311        *user_lock.lock().unwrap_or_else(|e| e.into_inner()) = Some(user.clone());
312        *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
313        *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Ok(user));
314    }
315
316    fn sign_in_google(&self) -> Result<User, Box<dyn std::error::Error>> {
317        let vercel_url = self.vercel_url.clone();
318        Self::run_google_redirect_flow(
319            &vercel_url,
320            &self.token_path,
321            &self.user,
322            &self.pending_sign_in,
323            &self.auth_msg,
324        );
325
326        // Block until the flow completes
327        loop {
328            if let Some(result) = self
329                .pending_sign_in
330                .lock()
331                .unwrap_or_else(|e| e.into_inner())
332                .take()
333            {
334                *self.auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
335                return result.map_err(|e| e.into());
336            }
337            thread::sleep(Duration::from_millis(100));
338        }
339    }
340
341    fn start_sign_in_google(&self) -> Result<(), Box<dyn std::error::Error>> {
342        let vercel_url = self.vercel_url.clone();
343        let token_path = self.token_path.clone();
344        let user_lock = Arc::clone(&self.user);
345        let pending = Arc::clone(&self.pending_sign_in);
346        let auth_msg = Arc::clone(&self.auth_msg);
347
348        thread::spawn(move || {
349            Self::run_google_redirect_flow(
350                &vercel_url,
351                &token_path,
352                &user_lock,
353                &pending,
354                &auth_msg,
355            );
356        });
357
358        Ok(())
359    }
360
361    fn run_github_device_flow(
362        config: &AuthConfig,
363        token_path: &PathBuf,
364        user_lock: &Arc<Mutex<Option<User>>>,
365        pending: &Arc<Mutex<Option<Result<User, String>>>>,
366        auth_msg: &Arc<Mutex<Option<String>>>,
367    ) {
368        let device = match request_device_code(config) {
369            Ok(d) => d,
370            Err(e) => {
371                *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Err(e.to_string()));
372                *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
373                return;
374            }
375        };
376        let user_code = device.user_code.clone();
377        let interval = device.interval.unwrap_or(5);
378        let activation_url = format!("https://github.com/login/device?user_code={user_code}");
379        *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = Some(format!(
380            "GitHub: enter code {user_code} at github.com/login/device"
381        ));
382        open_browser(&activation_url);
383
384        let access_token = match poll_device_token(config, &device.device_code, interval) {
385            Ok(t) => t,
386            Err(e) => {
387                *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Err(e.to_string()));
388                *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
389                return;
390            }
391        };
392
393        let user = match user_from_token("github", &access_token) {
394            Ok(u) => u,
395            Err(e) => {
396                *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Err(e.to_string()));
397                *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
398                return;
399            }
400        };
401
402        let stored = StoredToken {
403            id: user.id.clone(),
404            email: user.email.clone(),
405            name: user.name.clone(),
406            avatar_url: user.avatar_url.clone(),
407            provider: user.provider.clone(),
408            access_token,
409            refresh_token: None,
410        };
411        save_tokens_to_path(token_path, &stored);
412        *user_lock.lock().unwrap_or_else(|e| e.into_inner()) = Some(user.clone());
413        *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
414        *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Ok(user));
415    }
416
417    fn sign_in_github(&self) -> Result<User, Box<dyn std::error::Error>> {
418        let config = self
419            .providers
420            .get("github")
421            .ok_or_else(|| "GitHub auth not configured".to_string())?;
422
423        let clone = config.clone();
424        Self::run_github_device_flow(
425            &clone,
426            &self.token_path,
427            &self.user,
428            &self.pending_sign_in,
429            &self.auth_msg,
430        );
431
432        // Block until the flow completes (read from pending)
433        loop {
434            if let Some(result) = self
435                .pending_sign_in
436                .lock()
437                .unwrap_or_else(|e| e.into_inner())
438                .take()
439            {
440                return result.map_err(|e| e.into());
441            }
442            thread::sleep(Duration::from_millis(100));
443        }
444    }
445
446    fn start_sign_in_github(&self) -> Result<(), Box<dyn std::error::Error>> {
447        let config = self
448            .providers
449            .get("github")
450            .ok_or_else(|| "GitHub auth not configured".to_string())?
451            .clone();
452        let token_path = self.token_path.clone();
453        let user_lock = Arc::clone(&self.user);
454        let pending = Arc::clone(&self.pending_sign_in);
455        let msg = Arc::clone(&self.auth_msg);
456
457        thread::spawn(move || {
458            Self::run_github_device_flow(&config, &token_path, &user_lock, &pending, &msg);
459        });
460
461        Ok(())
462    }
463}
464
465fn save_tokens_to_path(token_path: &PathBuf, stored: &StoredToken) {
466    if let Some(parent) = token_path.parent() {
467        let _ = std::fs::create_dir_all(parent);
468    }
469    if let Ok(data) = serde_json::to_string_pretty(stored) {
470        let _ = std::fs::write(token_path, data);
471    }
472}
473
474impl AuthHandle for AuthClient {
475    fn current_user(&self) -> Option<User> {
476        self.user.lock().unwrap_or_else(|e| e.into_inner()).clone()
477    }
478
479    fn bearer_token(&self) -> Option<String> {
480        let data = std::fs::read_to_string(&self.token_path).ok()?;
481        let stored: StoredToken = serde_json::from_str(&data).ok()?;
482        Some(stored.access_token)
483    }
484
485    fn sign_in(&self, provider: &str) -> Result<User, Box<dyn std::error::Error>> {
486        match provider {
487            "google" => self.sign_in_google(),
488            "github" => self.sign_in_github(),
489            _ => Err("unsupported provider".into()),
490        }
491    }
492
493    fn start_sign_in(&self, provider: &str) -> Result<(), Box<dyn std::error::Error>> {
494        match provider {
495            "github" => self.start_sign_in_github(),
496            "google" => self.start_sign_in_google(),
497            _ => Err("unsupported provider".into()),
498        }
499    }
500
501    fn drain_pending_sign_in(&self) -> Option<Result<User, Box<dyn std::error::Error>>> {
502        let mut guard = self
503            .pending_sign_in
504            .lock()
505            .unwrap_or_else(|e| e.into_inner());
506        guard.take().map(|r| r.map_err(|e| e.into()))
507    }
508
509    fn auth_message(&self) -> Option<String> {
510        self.auth_msg
511            .lock()
512            .unwrap_or_else(|e| e.into_inner())
513            .clone()
514    }
515
516    fn sign_out(&self) {
517        self.clear_tokens();
518        *self.auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
519        *self.user.lock().unwrap_or_else(|e| e.into_inner()) = None;
520    }
521}