Skip to main content

st/proxy/
oauth.rs

//! OAuth 2.0 + PKCE framework for proxy providers.
//!
//! No high-level crate — just reqwest + sha2 + base64 + rand. The flow:
//!
//!   1. `begin(provider, account)` generates the code_verifier/challenge,
//!      binds a loopback listener on 127.0.0.1:<random>, and returns the
//!      authorization URL.
//!   2. The user opens the URL and completes login; the provider redirects
//!      back to the loopback URL with `?code=...&state=...`.
//!   3. The loopback task exchanges the code for tokens and stores them via
//!      `token_store::save`.
//!
//! Designed to be driven from the proxy admin API and/or the CLI.
13
14use crate::proxy::token_store::{self, StoredToken};
15use anyhow::{anyhow, bail, Context, Result};
16use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
17use rand::{distributions::Alphanumeric, Rng};
18use reqwest::Client;
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21use std::net::SocketAddr;
22use std::time::Duration;
23use tokio::net::TcpListener;
24use tokio::sync::oneshot;
25
26/// Static configuration for a single OAuth provider.
27#[derive(Debug, Clone)]
28pub struct ProviderConfig {
29    pub name: &'static str,
30    pub auth_url: &'static str,
31    pub token_url: &'static str,
32    /// Only filled from env / user config. We do not ship client_ids.
33    pub client_id_env: &'static str,
34    /// Some providers (Google desktop apps, GitHub) require a client_secret;
35    /// others (public native apps w/ PKCE) don't. None => PKCE-only.
36    pub client_secret_env: Option<&'static str>,
37    pub scopes: &'static [&'static str],
38    /// Extra query params appended to the auth URL (e.g. `access_type=offline`).
39    pub extra_auth_params: &'static [(&'static str, &'static str)],
40}
41
42impl ProviderConfig {
43    fn client_id(&self) -> Result<String> {
44        std::env::var(self.client_id_env)
45            .map_err(|_| anyhow!("{} not set — register an OAuth app and export the client id", self.client_id_env))
46    }
47
48    fn client_secret(&self) -> Option<String> {
49        self.client_secret_env.and_then(|k| std::env::var(k).ok())
50    }
51}
52
53pub mod providers {
54    use super::ProviderConfig;
55
56    /// Google AI Studio / Antigravity. Uses standard Google OAuth2.
57    /// Register a "Desktop app" OAuth client in GCP and export client id/secret.
58    pub const ANTIGRAVITY: ProviderConfig = ProviderConfig {
59        name: "antigravity",
60        auth_url: "https://accounts.google.com/o/oauth2/v2/auth",
61        token_url: "https://oauth2.googleapis.com/token",
62        client_id_env: "ANTIGRAVITY_CLIENT_ID",
63        client_secret_env: Some("ANTIGRAVITY_CLIENT_SECRET"),
64        scopes: &[
65            "https://www.googleapis.com/auth/generative-language",
66            "openid",
67            "email",
68        ],
69        extra_auth_params: &[("access_type", "offline"), ("prompt", "consent")],
70    };
71
72    /// OpenAI Codex CLI — same public OAuth endpoints as ChatGPT desktop.
73    /// Placeholder scopes; refine once we verify against Codex CLI source.
74    pub const CODEX: ProviderConfig = ProviderConfig {
75        name: "codex",
76        auth_url: "https://auth.openai.com/authorize",
77        token_url: "https://auth.openai.com/oauth/token",
78        client_id_env: "CODEX_CLIENT_ID",
79        client_secret_env: None,
80        scopes: &["openid", "email", "profile", "offline_access"],
81        extra_auth_params: &[],
82    };
83
84    pub fn by_name(name: &str) -> Option<ProviderConfig> {
85        match name.to_lowercase().as_str() {
86            "antigravity" | "google" => Some(ANTIGRAVITY),
87            "codex" | "openai" => Some(CODEX),
88            _ => None,
89        }
90    }
91}
92
93/// A flow-in-progress: the caller opens `auth_url`, we wait for the callback,
94/// exchange the code, and store the resulting token.
95pub struct StartedFlow {
96    pub auth_url: String,
97    pub state: String,
98    pub redirect_uri: String,
99    done: oneshot::Receiver<Result<StoredToken>>,
100}
101
102impl StartedFlow {
103    /// Wait up to `timeout` for the user to complete the flow.
104    pub async fn wait(self, timeout: Duration) -> Result<StoredToken> {
105        match tokio::time::timeout(timeout, self.done).await {
106            Ok(Ok(res)) => res,
107            Ok(Err(_)) => bail!("oauth callback channel closed"),
108            Err(_) => bail!("oauth flow timed out"),
109        }
110    }
111}
112
113/// Kick off an OAuth flow. Returns a StartedFlow with the auth URL to open.
114pub async fn begin(provider: ProviderConfig, account: String) -> Result<StartedFlow> {
115    let client_id = provider.client_id()?;
116    let client_secret = provider.client_secret();
117
118    let verifier: String = rand::thread_rng()
119        .sample_iter(&Alphanumeric)
120        .take(64)
121        .map(char::from)
122        .collect();
123    let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
124
125    let state: String = rand::thread_rng()
126        .sample_iter(&Alphanumeric)
127        .take(24)
128        .map(char::from)
129        .collect();
130
131    let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?;
132    let port = listener.local_addr()?.port();
133    let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
134
135    let scope = provider.scopes.join(" ");
136    let mut auth_url = format!(
137        "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&state={}&code_challenge={}&code_challenge_method=S256",
138        provider.auth_url,
139        urlencoding::encode(&client_id),
140        urlencoding::encode(&redirect_uri),
141        urlencoding::encode(&scope),
142        urlencoding::encode(&state),
143        urlencoding::encode(&challenge),
144    );
145    for (k, v) in provider.extra_auth_params {
146        auth_url.push_str(&format!("&{}={}", k, urlencoding::encode(v)));
147    }
148
149    let (tx, rx) = oneshot::channel::<Result<StoredToken>>();
150    let state_expected = state.clone();
151    let redirect_uri_cloned = redirect_uri.clone();
152    let provider_cloned = provider.clone();
153    let account_cloned = account.clone();
154
155    tokio::spawn(async move {
156        let res = run_callback(
157            listener,
158            &state_expected,
159            &redirect_uri_cloned,
160            &verifier,
161            &client_id,
162            client_secret.as_deref(),
163            &provider_cloned,
164            &account_cloned,
165        )
166        .await;
167        let _ = tx.send(res);
168    });
169
170    Ok(StartedFlow {
171        auth_url,
172        state,
173        redirect_uri,
174        done: rx,
175    })
176}
177
178async fn run_callback(
179    listener: TcpListener,
180    state_expected: &str,
181    redirect_uri: &str,
182    verifier: &str,
183    client_id: &str,
184    client_secret: Option<&str>,
185    provider: &ProviderConfig,
186    account: &str,
187) -> Result<StoredToken> {
188    let (mut stream, _) = listener.accept().await?;
189    let (code, state_got) = read_callback_query(&mut stream).await?;
190    if state_got != state_expected {
191        write_plain(&mut stream, "oauth state mismatch — possible CSRF, aborting").await;
192        bail!("oauth state mismatch");
193    }
194
195    let token = exchange_code(
196        provider,
197        client_id,
198        client_secret,
199        &code,
200        redirect_uri,
201        verifier,
202    )
203    .await;
204
205    match &token {
206        Ok(_) => {
207            write_plain(
208                &mut stream,
209                "Smart Tree proxy: sign-in complete. You can close this tab.",
210            )
211            .await
212        }
213        Err(e) => {
214            write_plain(&mut stream, &format!("sign-in failed: {}", e)).await;
215        }
216    }
217
218    let token = token?;
219    token_store::save(provider.name, account, &token)?;
220    Ok(token)
221}
222
223async fn read_callback_query(
224    stream: &mut tokio::net::TcpStream,
225) -> Result<(String, String)> {
226    use tokio::io::AsyncReadExt;
227    let mut buf = vec![0u8; 8192];
228    let n = stream.read(&mut buf).await?;
229    let req = String::from_utf8_lossy(&buf[..n]);
230    let first_line = req.lines().next().context("empty HTTP request")?;
231    // "GET /callback?code=...&state=... HTTP/1.1"
232    let path = first_line
233        .split_whitespace()
234        .nth(1)
235        .context("malformed request line")?;
236    let query = path.split_once('?').map(|(_, q)| q).unwrap_or("");
237    let mut code = None;
238    let mut state = None;
239    let mut error = None;
240    for pair in query.split('&') {
241        if let Some((k, v)) = pair.split_once('=') {
242            let v = urlencoding::decode(v).unwrap_or_default().into_owned();
243            match k {
244                "code" => code = Some(v),
245                "state" => state = Some(v),
246                "error" => error = Some(v),
247                _ => {}
248            }
249        }
250    }
251    if let Some(e) = error {
252        bail!("oauth provider returned error: {}", e);
253    }
254    Ok((
255        code.context("missing code in callback")?,
256        state.context("missing state in callback")?,
257    ))
258}
259
260async fn write_plain(stream: &mut tokio::net::TcpStream, body: &str) {
261    use tokio::io::AsyncWriteExt;
262    let resp = format!(
263        "HTTP/1.1 200 OK\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
264        body.len(),
265        body
266    );
267    let _ = stream.write_all(resp.as_bytes()).await;
268    let _ = stream.shutdown().await;
269}
270
/// JSON body returned by a provider's token endpoint, as deserialized after
/// a code exchange or a refresh request.
///
/// Only `access_token` is mandatory; every other field becomes `None` when
/// the provider leaves it out.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// The issued access token.
    access_token: String,
    /// Refresh token, when the provider returns one.
    #[serde(default)]
    refresh_token: Option<String>,
    /// Token lifetime in seconds; callers turn it into an absolute expiry.
    #[serde(default)]
    expires_in: Option<i64>,
    /// Scope string reported by the provider, if any.
    #[serde(default)]
    scope: Option<String>,
    /// Token type reported by the provider, if any.
    #[serde(default)]
    token_type: Option<String>,
}
283
284#[derive(Serialize)]
285struct TokenExchange<'a> {
286    grant_type: &'a str,
287    code: &'a str,
288    redirect_uri: &'a str,
289    client_id: &'a str,
290    code_verifier: &'a str,
291    #[serde(skip_serializing_if = "Option::is_none")]
292    client_secret: Option<&'a str>,
293}
294
295async fn exchange_code(
296    provider: &ProviderConfig,
297    client_id: &str,
298    client_secret: Option<&str>,
299    code: &str,
300    redirect_uri: &str,
301    verifier: &str,
302) -> Result<StoredToken> {
303    let body = TokenExchange {
304        grant_type: "authorization_code",
305        code,
306        redirect_uri,
307        client_id,
308        code_verifier: verifier,
309        client_secret,
310    };
311
312    let res = Client::new()
313        .post(provider.token_url)
314        .form(&body)
315        .send()
316        .await?;
317
318    if !res.status().is_success() {
319        let text = res.text().await.unwrap_or_default();
320        bail!("token endpoint returned error: {}", text);
321    }
322
323    let t: TokenResponse = res.json().await?;
324    let expires_at = t
325        .expires_in
326        .map(|s| chrono::Utc::now() + chrono::Duration::seconds(s));
327    Ok(StoredToken {
328        access_token: t.access_token,
329        refresh_token: t.refresh_token,
330        expires_at,
331        scope: t.scope,
332        token_type: t.token_type,
333    })
334}
335
336/// Refresh an existing stored token in place. Returns the refreshed token.
337pub async fn refresh(provider: ProviderConfig, account: &str) -> Result<StoredToken> {
338    let current = token_store::load(provider.name, account)?
339        .ok_or_else(|| anyhow!("no stored token for {}:{}", provider.name, account))?;
340    let refresh_token = current
341        .refresh_token
342        .as_deref()
343        .ok_or_else(|| anyhow!("stored token has no refresh_token"))?;
344
345    let client_id = provider.client_id()?;
346    let client_secret = provider.client_secret();
347
348    #[derive(Serialize)]
349    struct RefreshBody<'a> {
350        grant_type: &'a str,
351        refresh_token: &'a str,
352        client_id: &'a str,
353        #[serde(skip_serializing_if = "Option::is_none")]
354        client_secret: Option<&'a str>,
355    }
356
357    let res = Client::new()
358        .post(provider.token_url)
359        .form(&RefreshBody {
360            grant_type: "refresh_token",
361            refresh_token,
362            client_id: &client_id,
363            client_secret: client_secret.as_deref(),
364        })
365        .send()
366        .await?;
367
368    if !res.status().is_success() {
369        let text = res.text().await.unwrap_or_default();
370        bail!("refresh failed: {}", text);
371    }
372
373    let t: TokenResponse = res.json().await?;
374    let expires_at = t
375        .expires_in
376        .map(|s| chrono::Utc::now() + chrono::Duration::seconds(s));
377    let refreshed = StoredToken {
378        access_token: t.access_token,
379        // Google omits refresh_token on refresh responses; keep the original.
380        refresh_token: t.refresh_token.or(current.refresh_token),
381        expires_at,
382        scope: t.scope.or(current.scope),
383        token_type: t.token_type.or(current.token_type),
384    };
385    token_store::save(provider.name, account, &refreshed)?;
386    Ok(refreshed)
387}
388
389/// Load a token, transparently refreshing if it's expired.
390pub async fn load_fresh(provider: ProviderConfig, account: &str) -> Result<StoredToken> {
391    match token_store::load(provider.name, account)? {
392        Some(t) if !t.is_expired() => Ok(t),
393        Some(_) => refresh(provider, account).await,
394        None => bail!("no stored token for {}:{}", provider.name, account),
395    }
396}
397
// Small dependency-free urlencoding shim so we don't pull another crate.
mod urlencoding {
    /// Upper-case hex digits used when emitting `%XX` escapes.
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    /// Percent-encodes `s`, leaving only ASCII letters, digits and
    /// `-`, `_`, `.`, `~` untouched. Every other byte becomes `%XX`.
    pub fn encode(s: &str) -> String {
        let mut encoded = String::with_capacity(s.len());
        for &byte in s.as_bytes() {
            let unreserved =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
            if unreserved {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
        }
        encoded
    }

    /// Decodes a percent-encoded query value; `+` is decoded as a space.
    ///
    /// Returns `None` when a `%` is followed by two bytes that are not both
    /// hex digits, or when the decoded bytes are not valid UTF-8. A `%` with
    /// fewer than two bytes after it is kept literally.
    pub fn decode(s: &str) -> Option<std::borrow::Cow<'_, str>> {
        let mut decoded = Vec::with_capacity(s.len());
        let mut rest = s.as_bytes();
        while let Some((&head, tail)) = rest.split_first() {
            match (head, tail) {
                (b'+', _) => {
                    decoded.push(b' ');
                    rest = tail;
                }
                (b'%', [hi, lo, after @ ..]) => {
                    let hi = char::from(*hi).to_digit(16)?;
                    let lo = char::from(*lo).to_digit(16)?;
                    // Both nibbles are below 16, so the sum always fits a byte.
                    decoded.push((hi * 16 + lo) as u8);
                    rest = after;
                }
                _ => {
                    decoded.push(head);
                    rest = tail;
                }
            }
        }
        String::from_utf8(decoded)
            .ok()
            .map(std::borrow::Cow::Owned)
    }
}
433            }
434        }
435        Some(String::from_utf8(out).ok()?.into())
436    }
437}