1use crate::proxy::token_store::{self, StoredToken};
15use anyhow::{anyhow, bail, Context, Result};
16use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
17use rand::{distributions::Alphanumeric, Rng};
18use reqwest::Client;
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21use std::net::SocketAddr;
22use std::time::Duration;
23use tokio::net::TcpListener;
24use tokio::sync::oneshot;
25
/// Static description of an OAuth provider: its endpoints, requested scopes, and the
/// environment variables that hold the registered client credentials.
///
/// Every field is `'static` data, so the type is `Copy` and cheap to pass by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProviderConfig {
    /// Short provider identifier; also used as the token-store key.
    pub name: &'static str,
    /// Authorization endpoint the user is sent to in the browser.
    pub auth_url: &'static str,
    /// Token endpoint used for the code exchange and for refreshes.
    pub token_url: &'static str,
    /// Name of the environment variable holding the OAuth client id.
    pub client_id_env: &'static str,
    /// Name of the environment variable holding the client secret, or `None` when the
    /// provider is used without one.
    pub client_secret_env: Option<&'static str>,
    /// Scopes requested during authorization (joined with spaces into `scope`).
    pub scopes: &'static [&'static str],
    /// Extra `key=value` query parameters appended to the authorization URL.
    pub extra_auth_params: &'static [(&'static str, &'static str)],
}
41
42impl ProviderConfig {
43 fn client_id(&self) -> Result<String> {
44 std::env::var(self.client_id_env)
45 .map_err(|_| anyhow!("{} not set — register an OAuth app and export the client id", self.client_id_env))
46 }
47
48 fn client_secret(&self) -> Option<String> {
49 self.client_secret_env.and_then(|k| std::env::var(k).ok())
50 }
51}
52
53pub mod providers {
54 use super::ProviderConfig;
55
56 pub const ANTIGRAVITY: ProviderConfig = ProviderConfig {
59 name: "antigravity",
60 auth_url: "https://accounts.google.com/o/oauth2/v2/auth",
61 token_url: "https://oauth2.googleapis.com/token",
62 client_id_env: "ANTIGRAVITY_CLIENT_ID",
63 client_secret_env: Some("ANTIGRAVITY_CLIENT_SECRET"),
64 scopes: &[
65 "https://www.googleapis.com/auth/generative-language",
66 "openid",
67 "email",
68 ],
69 extra_auth_params: &[("access_type", "offline"), ("prompt", "consent")],
70 };
71
72 pub const CODEX: ProviderConfig = ProviderConfig {
75 name: "codex",
76 auth_url: "https://auth.openai.com/authorize",
77 token_url: "https://auth.openai.com/oauth/token",
78 client_id_env: "CODEX_CLIENT_ID",
79 client_secret_env: None,
80 scopes: &["openid", "email", "profile", "offline_access"],
81 extra_auth_params: &[],
82 };
83
84 pub fn by_name(name: &str) -> Option<ProviderConfig> {
85 match name.to_lowercase().as_str() {
86 "antigravity" | "google" => Some(ANTIGRAVITY),
87 "codex" | "openai" => Some(CODEX),
88 _ => None,
89 }
90 }
91}
92
93pub struct StartedFlow {
96 pub auth_url: String,
97 pub state: String,
98 pub redirect_uri: String,
99 done: oneshot::Receiver<Result<StoredToken>>,
100}
101
102impl StartedFlow {
103 pub async fn wait(self, timeout: Duration) -> Result<StoredToken> {
105 match tokio::time::timeout(timeout, self.done).await {
106 Ok(Ok(res)) => res,
107 Ok(Err(_)) => bail!("oauth callback channel closed"),
108 Err(_) => bail!("oauth flow timed out"),
109 }
110 }
111}
112
113pub async fn begin(provider: ProviderConfig, account: String) -> Result<StartedFlow> {
115 let client_id = provider.client_id()?;
116 let client_secret = provider.client_secret();
117
118 let verifier: String = rand::thread_rng()
119 .sample_iter(&Alphanumeric)
120 .take(64)
121 .map(char::from)
122 .collect();
123 let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
124
125 let state: String = rand::thread_rng()
126 .sample_iter(&Alphanumeric)
127 .take(24)
128 .map(char::from)
129 .collect();
130
131 let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await?;
132 let port = listener.local_addr()?.port();
133 let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
134
135 let scope = provider.scopes.join(" ");
136 let mut auth_url = format!(
137 "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&state={}&code_challenge={}&code_challenge_method=S256",
138 provider.auth_url,
139 urlencoding::encode(&client_id),
140 urlencoding::encode(&redirect_uri),
141 urlencoding::encode(&scope),
142 urlencoding::encode(&state),
143 urlencoding::encode(&challenge),
144 );
145 for (k, v) in provider.extra_auth_params {
146 auth_url.push_str(&format!("&{}={}", k, urlencoding::encode(v)));
147 }
148
149 let (tx, rx) = oneshot::channel::<Result<StoredToken>>();
150 let state_expected = state.clone();
151 let redirect_uri_cloned = redirect_uri.clone();
152 let provider_cloned = provider.clone();
153 let account_cloned = account.clone();
154
155 tokio::spawn(async move {
156 let res = run_callback(
157 listener,
158 &state_expected,
159 &redirect_uri_cloned,
160 &verifier,
161 &client_id,
162 client_secret.as_deref(),
163 &provider_cloned,
164 &account_cloned,
165 )
166 .await;
167 let _ = tx.send(res);
168 });
169
170 Ok(StartedFlow {
171 auth_url,
172 state,
173 redirect_uri,
174 done: rx,
175 })
176}
177
178async fn run_callback(
179 listener: TcpListener,
180 state_expected: &str,
181 redirect_uri: &str,
182 verifier: &str,
183 client_id: &str,
184 client_secret: Option<&str>,
185 provider: &ProviderConfig,
186 account: &str,
187) -> Result<StoredToken> {
188 let (mut stream, _) = listener.accept().await?;
189 let (code, state_got) = read_callback_query(&mut stream).await?;
190 if state_got != state_expected {
191 write_plain(&mut stream, "oauth state mismatch — possible CSRF, aborting").await;
192 bail!("oauth state mismatch");
193 }
194
195 let token = exchange_code(
196 provider,
197 client_id,
198 client_secret,
199 &code,
200 redirect_uri,
201 verifier,
202 )
203 .await;
204
205 match &token {
206 Ok(_) => {
207 write_plain(
208 &mut stream,
209 "Smart Tree proxy: sign-in complete. You can close this tab.",
210 )
211 .await
212 }
213 Err(e) => {
214 write_plain(&mut stream, &format!("sign-in failed: {}", e)).await;
215 }
216 }
217
218 let token = token?;
219 token_store::save(provider.name, account, &token)?;
220 Ok(token)
221}
222
223async fn read_callback_query(
224 stream: &mut tokio::net::TcpStream,
225) -> Result<(String, String)> {
226 use tokio::io::AsyncReadExt;
227 let mut buf = vec![0u8; 8192];
228 let n = stream.read(&mut buf).await?;
229 let req = String::from_utf8_lossy(&buf[..n]);
230 let first_line = req.lines().next().context("empty HTTP request")?;
231 let path = first_line
233 .split_whitespace()
234 .nth(1)
235 .context("malformed request line")?;
236 let query = path.split_once('?').map(|(_, q)| q).unwrap_or("");
237 let mut code = None;
238 let mut state = None;
239 let mut error = None;
240 for pair in query.split('&') {
241 if let Some((k, v)) = pair.split_once('=') {
242 let v = urlencoding::decode(v).unwrap_or_default().into_owned();
243 match k {
244 "code" => code = Some(v),
245 "state" => state = Some(v),
246 "error" => error = Some(v),
247 _ => {}
248 }
249 }
250 }
251 if let Some(e) = error {
252 bail!("oauth provider returned error: {}", e);
253 }
254 Ok((
255 code.context("missing code in callback")?,
256 state.context("missing state in callback")?,
257 ))
258}
259
260async fn write_plain(stream: &mut tokio::net::TcpStream, body: &str) {
261 use tokio::io::AsyncWriteExt;
262 let resp = format!(
263 "HTTP/1.1 200 OK\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
264 body.len(),
265 body
266 );
267 let _ = stream.write_all(resp.as_bytes()).await;
268 let _ = stream.shutdown().await;
269}
270
271#[derive(Debug, Deserialize)]
272struct TokenResponse {
273 access_token: String,
274 #[serde(default)]
275 refresh_token: Option<String>,
276 #[serde(default)]
277 expires_in: Option<i64>,
278 #[serde(default)]
279 scope: Option<String>,
280 #[serde(default)]
281 token_type: Option<String>,
282}
283
284#[derive(Serialize)]
285struct TokenExchange<'a> {
286 grant_type: &'a str,
287 code: &'a str,
288 redirect_uri: &'a str,
289 client_id: &'a str,
290 code_verifier: &'a str,
291 #[serde(skip_serializing_if = "Option::is_none")]
292 client_secret: Option<&'a str>,
293}
294
295async fn exchange_code(
296 provider: &ProviderConfig,
297 client_id: &str,
298 client_secret: Option<&str>,
299 code: &str,
300 redirect_uri: &str,
301 verifier: &str,
302) -> Result<StoredToken> {
303 let body = TokenExchange {
304 grant_type: "authorization_code",
305 code,
306 redirect_uri,
307 client_id,
308 code_verifier: verifier,
309 client_secret,
310 };
311
312 let res = Client::new()
313 .post(provider.token_url)
314 .form(&body)
315 .send()
316 .await?;
317
318 if !res.status().is_success() {
319 let text = res.text().await.unwrap_or_default();
320 bail!("token endpoint returned error: {}", text);
321 }
322
323 let t: TokenResponse = res.json().await?;
324 let expires_at = t
325 .expires_in
326 .map(|s| chrono::Utc::now() + chrono::Duration::seconds(s));
327 Ok(StoredToken {
328 access_token: t.access_token,
329 refresh_token: t.refresh_token,
330 expires_at,
331 scope: t.scope,
332 token_type: t.token_type,
333 })
334}
335
336pub async fn refresh(provider: ProviderConfig, account: &str) -> Result<StoredToken> {
338 let current = token_store::load(provider.name, account)?
339 .ok_or_else(|| anyhow!("no stored token for {}:{}", provider.name, account))?;
340 let refresh_token = current
341 .refresh_token
342 .as_deref()
343 .ok_or_else(|| anyhow!("stored token has no refresh_token"))?;
344
345 let client_id = provider.client_id()?;
346 let client_secret = provider.client_secret();
347
348 #[derive(Serialize)]
349 struct RefreshBody<'a> {
350 grant_type: &'a str,
351 refresh_token: &'a str,
352 client_id: &'a str,
353 #[serde(skip_serializing_if = "Option::is_none")]
354 client_secret: Option<&'a str>,
355 }
356
357 let res = Client::new()
358 .post(provider.token_url)
359 .form(&RefreshBody {
360 grant_type: "refresh_token",
361 refresh_token,
362 client_id: &client_id,
363 client_secret: client_secret.as_deref(),
364 })
365 .send()
366 .await?;
367
368 if !res.status().is_success() {
369 let text = res.text().await.unwrap_or_default();
370 bail!("refresh failed: {}", text);
371 }
372
373 let t: TokenResponse = res.json().await?;
374 let expires_at = t
375 .expires_in
376 .map(|s| chrono::Utc::now() + chrono::Duration::seconds(s));
377 let refreshed = StoredToken {
378 access_token: t.access_token,
379 refresh_token: t.refresh_token.or(current.refresh_token),
381 expires_at,
382 scope: t.scope.or(current.scope),
383 token_type: t.token_type.or(current.token_type),
384 };
385 token_store::save(provider.name, account, &refreshed)?;
386 Ok(refreshed)
387}
388
389pub async fn load_fresh(provider: ProviderConfig, account: &str) -> Result<StoredToken> {
391 match token_store::load(provider.name, account)? {
392 Some(t) if !t.is_expired() => Ok(t),
393 Some(_) => refresh(provider, account).await,
394 None => bail!("no stored token for {}:{}", provider.name, account),
395 }
396}
397
/// Minimal percent-encoding helpers for building the authorization URL and parsing the
/// loopback callback query without an extra crate dependency.
mod urlencoding {
    use std::borrow::Cow;

    /// Uppercase hex digits used for `%XX` escapes.
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    /// Percent-encodes `s`.
    ///
    /// Only the RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) are kept as-is.
    /// Every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex.
    pub fn encode(s: &str) -> String {
        let mut out = String::with_capacity(s.len());
        for byte in s.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                out.push(char::from(byte));
            } else {
                // Table lookup instead of a `format!` allocation per escaped byte.
                out.push('%');
                out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
        }
        out
    }

    /// Decodes a form-urlencoded value: `+` becomes a space and `%XX` becomes byte `0xXX`.
    ///
    /// A `%` followed by fewer than two bytes is kept literally. Returns `None` if a
    /// `%XX` escape contains a non-hex digit, or if the decoded bytes are not valid UTF-8.
    /// Input containing no `%` or `+` is returned borrowed, without copying.
    pub fn decode(s: &str) -> Option<Cow<'_, str>> {
        if !s.bytes().any(|b| b == b'%' || b == b'+') {
            return Some(Cow::Borrowed(s));
        }

        let bytes = s.as_bytes();
        let mut out = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            match bytes[i] {
                b'+' => {
                    out.push(b' ');
                    i += 1;
                }
                b'%' if i + 2 < bytes.len() => {
                    let hi = hex_value(bytes[i + 1])?;
                    let lo = hex_value(bytes[i + 2])?;
                    out.push((hi << 4) | lo);
                    i += 3;
                }
                other => {
                    out.push(other);
                    i += 1;
                }
            }
        }
        String::from_utf8(out).ok().map(Cow::Owned)
    }

    /// Value (0..=15) of one ASCII hex digit, or `None` if `b` is not a hex digit.
    fn hex_value(b: u8) -> Option<u8> {
        // `to_digit(16)` yields at most 15, so the narrowing cast cannot truncate.
        char::from(b).to_digit(16).map(|d| d as u8)
    }
}