1use anyhow::{Context, Result};
2use axum::{
3 Router,
4 extract::{Query, State},
5 response::Html,
6 routing::get,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt;
10use std::net::SocketAddr;
11use std::str::FromStr;
12use std::sync::Arc;
13use tokio::sync::{mpsc, oneshot};
14
/// Callback wait time (five minutes) used when the caller passes a timeout of `0`.
const DEFAULT_CALLBACK_TIMEOUT_SECS: u64 = 5 * 60;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
19#[serde(rename_all = "snake_case")]
20pub enum OAuthProvider {
21 OpenAi,
22 OpenRouter,
23}
24
25impl OAuthProvider {
26 #[must_use]
27 pub fn slug(self) -> &'static str {
28 match self {
29 Self::OpenAi => "openai",
30 Self::OpenRouter => "openrouter",
31 }
32 }
33
34 #[must_use]
35 pub fn display_name(self) -> &'static str {
36 match self {
37 Self::OpenAi => "OpenAI",
38 Self::OpenRouter => "OpenRouter",
39 }
40 }
41
42 #[must_use]
43 pub fn subtitle(self) -> &'static str {
44 match self {
45 Self::OpenAi => "Your ChatGPT subscription is now connected.",
46 Self::OpenRouter => "Your OpenRouter account is now connected.",
47 }
48 }
49
50 #[must_use]
51 pub fn failure_subtitle(self) -> &'static str {
52 match self {
53 Self::OpenAi => "Unable to connect your ChatGPT subscription.",
54 Self::OpenRouter => "Unable to connect your OpenRouter account.",
55 }
56 }
57
58 #[must_use]
59 pub fn retry_hint(self) -> String {
60 format!("You can try again anytime using /login {}", self.slug())
61 }
62
63 #[must_use]
64 pub fn supports_manual_refresh(self) -> bool {
65 matches!(self, Self::OpenAi)
66 }
67}
68
69impl fmt::Display for OAuthProvider {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 f.write_str(self.slug())
72 }
73}
74
75impl FromStr for OAuthProvider {
76 type Err = ();
77
78 fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
79 match value.trim().to_ascii_lowercase().as_str() {
80 "openai" => Ok(Self::OpenAi),
81 "openrouter" => Ok(Self::OpenRouter),
82 _ => Err(()),
83 }
84 }
85}
86
87#[derive(Debug, Clone, Copy)]
88pub struct OAuthCallbackPage {
89 provider: OAuthProvider,
90}
91
92impl OAuthCallbackPage {
93 #[must_use]
94 pub fn new(provider: OAuthProvider) -> Self {
95 Self { provider }
96 }
97}
98
/// Final result of one OAuth callback flow, as reported by the local server.
#[derive(Clone, Debug)]
pub enum AuthCallbackOutcome {
    /// The provider redirected back with this authorization code.
    Code(String),
    /// The user hit the `/cancel` endpoint.
    Cancelled,
    /// The flow failed (provider error, state mismatch, missing code, or
    /// timeout); the string is a human-readable reason.
    Error(String),
}
105
106#[derive(Debug, Deserialize)]
107struct AuthCallbackParams {
108 code: Option<String>,
109 error: Option<String>,
110 error_description: Option<String>,
111 state: Option<String>,
112}
113
114struct AuthCallbackState {
115 page: OAuthCallbackPage,
116 expected_state: Option<String>,
117 result_tx: mpsc::Sender<AuthCallbackOutcome>,
118}
119
120pub async fn run_auth_code_callback_server(
121 port: u16,
122 timeout_secs: u64,
123 page: OAuthCallbackPage,
124 expected_state: Option<String>,
125) -> Result<AuthCallbackOutcome> {
126 let timeout = if timeout_secs == 0 {
127 DEFAULT_CALLBACK_TIMEOUT_SECS
128 } else {
129 timeout_secs
130 };
131 let (result_tx, mut result_rx) = mpsc::channel::<AuthCallbackOutcome>(1);
132 let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
133 let state = Arc::new(AuthCallbackState {
134 page,
135 expected_state,
136 result_tx,
137 });
138
139 let app = Router::new()
140 .route("/callback", get(handle_callback))
141 .route("/auth/callback", get(handle_callback))
142 .route("/cancel", get(handle_cancel))
143 .route("/health", get(|| async { "OK" }))
144 .with_state(state);
145
146 let addr = SocketAddr::from(([127, 0, 0, 1], port));
147 let listener = tokio::net::TcpListener::bind(addr)
148 .await
149 .with_context(|| format!("failed to bind localhost callback server on port {port}"))?;
150
151 let server = axum::serve(listener, app).with_graceful_shutdown(async move {
152 let _ = shutdown_rx.await;
153 });
154 let server_handle = tokio::spawn(async move {
155 if let Err(err) = server.await {
156 tracing::error!("OAuth callback server error: {}", err);
157 }
158 });
159
160 let result = tokio::select! {
161 Some(result) = result_rx.recv() => result,
162 _ = tokio::time::sleep(std::time::Duration::from_secs(timeout)) => {
163 AuthCallbackOutcome::Error(format!("OAuth flow timed out after {timeout} seconds"))
164 }
165 };
166
167 let _ = shutdown_tx.send(());
168 let _ = server_handle.await;
169 Ok(result)
170}
171
172async fn handle_callback(
173 State(state): State<Arc<AuthCallbackState>>,
174 Query(params): Query<AuthCallbackParams>,
175) -> Html<String> {
176 tracing::info!(
177 provider = %state.page.provider,
178 has_code = params.code.is_some(),
179 has_error = params.error.is_some(),
180 "received oauth callback"
181 );
182 if let Some(expected_state) = state.expected_state.as_deref() {
183 match params.state.as_deref() {
184 Some(actual_state) if actual_state == expected_state => {}
185 _ => {
186 let message = "OAuth error: state mismatch".to_string();
187 let _ = state
188 .result_tx
189 .send(AuthCallbackOutcome::Error(message.clone()))
190 .await;
191 return Html(error_html(state.page.provider, &message));
192 }
193 }
194 }
195
196 if let Some(error) = params.error {
197 let message = match params.error_description {
198 Some(description) if !description.trim().is_empty() => {
199 format!("OAuth error: {error} - {description}")
200 }
201 _ => format!("OAuth error: {error}"),
202 };
203 let _ = state
204 .result_tx
205 .send(AuthCallbackOutcome::Error(message.clone()))
206 .await;
207 return Html(error_html(state.page.provider, &message));
208 }
209
210 let Some(code) = params.code else {
211 let message = "Missing authorization code".to_string();
212 let _ = state
213 .result_tx
214 .send(AuthCallbackOutcome::Error(message.clone()))
215 .await;
216 return Html(error_html(state.page.provider, &message));
217 };
218
219 let _ = state.result_tx.send(AuthCallbackOutcome::Code(code)).await;
220 Html(success_html(state.page.provider))
221}
222
223async fn handle_cancel(State(state): State<Arc<AuthCallbackState>>) -> Html<String> {
224 let _ = state.result_tx.send(AuthCallbackOutcome::Cancelled).await;
225 Html(cancelled_html(state.page.provider))
226}
227
228fn success_html(provider: OAuthProvider) -> String {
229 base_html(
230 "Authentication Successful",
231 provider.subtitle(),
232 Some("You may now close this window and return to VT Code."),
233 "✓",
234 "#22c55e",
235 None,
236 )
237}
238
239fn error_html(provider: OAuthProvider, error: &str) -> String {
240 base_html(
241 "Authentication Failed",
242 provider.failure_subtitle(),
243 None,
244 "✕",
245 "#ef4444",
246 Some(error),
247 )
248}
249
250fn cancelled_html(provider: OAuthProvider) -> String {
251 base_html(
252 "Authentication Cancelled",
253 &provider.retry_hint(),
254 None,
255 "—",
256 "#71717a",
257 None,
258 )
259}
260
/// Renders the shared VT Code OAuth status page.
///
/// * `title` – heading text, also used inside the `<title>` element.
/// * `subtitle` – paragraph shown under the heading.
/// * `close_note` – optional small-print note. When present, the page also
///   includes a script that calls `window.close()` after 3000 ms.
/// * `icon` – glyph shown inside the status circle.
/// * `accent` – CSS color assigned to `--accent` (status ring and glyph).
/// * `error` – optional detail text rendered in a monospace box.
///
/// `title`, `subtitle`, `close_note` and `error` are passed through
/// `html_escape`; `icon` and `accent` are interpolated verbatim.
// NOTE(review): because `icon` and `accent` are not escaped, only pass trusted
// constants for them.
fn base_html(
    title: &str,
    subtitle: &str,
    close_note: Option<&str>,
    icon: &str,
    accent: &str,
    error: Option<&str>,
) -> String {
    let close_note_html = close_note
        .map(|value| format!(r#"<p class="close-note">{}</p>"#, html_escape(value)))
        .unwrap_or_default();
    let error_html = error
        .map(|value| format!(r#"<div class="error">{}</div>"#, html_escape(value)))
        .unwrap_or_default();
    // Only pages that tell the user they may close the window try to close
    // it automatically.
    let auto_close = if close_note.is_some() {
        r#"<script>setTimeout(() => window.close(), 3000);</script>"#
    } else {
        ""
    };

    // This template goes through `format!`, so literal CSS braces are doubled
    // (`{{` / `}}`); single-brace names are the named arguments listed below.
    format!(
        r##"<!DOCTYPE html>
<html>
<head>
 <title>VT Code - {title}</title>
 <style>
 @font-face {{
 font-family: 'SF Pro Display';
 src: local('SF Pro Display'), local('.SF NS Display'), local('Helvetica Neue');
 }}
 @font-face {{
 font-family: 'SF Mono';
 src: local('SF Mono'), local('Menlo'), local('Monaco');
 }}
 :root {{
 color-scheme: dark;
 --bg: #0a0a0a;
 --panel: #111111;
 --panel-border: #262626;
 --text: #fafafa;
 --muted: #a1a1aa;
 --subtle: #52525b;
 --code-bg: #18181b;
 --code-border: #27272a;
 --accent: {accent};
 }}
 * {{ box-sizing: border-box; }}
 body {{
 font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
 display: flex;
 justify-content: center;
 align-items: center;
 min-height: 100vh;
 margin: 0;
 background:
 radial-gradient(circle at top, rgba(255,255,255,0.04), transparent 32%),
 linear-gradient(180deg, var(--bg), #050505);
 color: var(--text);
 padding: 24px;
 }}
 .container {{
 text-align: center;
 padding: 2.75rem 3rem;
 border: 1px solid var(--panel-border);
 border-radius: 14px;
 background: rgba(17, 17, 17, 0.92);
 max-width: 460px;
 width: 100%;
 box-shadow: 0 30px 90px rgba(0, 0, 0, 0.35);
 }}
 .logo {{
 margin-bottom: 1.5rem;
 }}
 .logo-mark {{
 display: inline-flex;
 align-items: center;
 justify-content: center;
 font-size: 0.95rem;
 letter-spacing: 0.24em;
 text-transform: uppercase;
 color: var(--muted);
 }}
 .status-icon {{
 width: 52px;
 height: 52px;
 margin: 0 auto 1.25rem;
 border: 2px solid var(--accent);
 border-radius: 50%;
 display: flex;
 align-items: center;
 justify-content: center;
 font-size: 1.25rem;
 color: var(--accent);
 }}
 h1 {{
 margin: 0 0 0.75rem 0;
 font-size: 1.25rem;
 font-weight: 600;
 letter-spacing: -0.02em;
 }}
 p {{
 color: var(--muted);
 margin: 0;
 font-size: 0.92rem;
 line-height: 1.55;
 }}
 .close-note {{
 margin-top: 1.25rem;
 font-size: 0.78rem;
 color: var(--subtle);
 }}
 .error {{
 margin-top: 1.35rem;
 padding: 0.95rem 1rem;
 background: var(--code-bg);
 border: 1px solid var(--code-border);
 border-radius: 10px;
 font-family: 'SF Mono', Menlo, Monaco, monospace;
 font-size: 0.75rem;
 color: #d4d4d8;
 word-break: break-word;
 text-align: left;
 }}
 </style>
</head>
<body>
 <div class="container">
 <div class="logo">
 <div class="logo-mark">> VT Code</div>
 </div>
 <div class="status-icon">{icon}</div>
 <h1>{title}</h1>
 <p>{subtitle}</p>
 {close_note_html}
 {error_html}
 </div>
 {auto_close}
</body>
</html>"##,
        title = html_escape(title),
        subtitle = html_escape(subtitle),
        icon = icon,
        accent = accent,
        close_note_html = close_note_html,
        error_html = error_html,
        auto_close = auto_close,
    )
}
409
/// Escapes `value` for safe interpolation into HTML element content or a
/// quoted attribute value.
///
/// This matters because the error page embeds the request's `error` and
/// `error_description` query parameters. Any page can set them by navigating
/// the user's browser to the local callback URL.
///
/// A single pass over the characters means the `&` in generated entities is
/// never escaped twice.
fn html_escape(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}
416
#[cfg(test)]
mod tests {
    use super::*;
    use axum::extract::{Query, State};

    /// Known slugs parse to their providers; unknown slugs are rejected.
    #[test]
    fn oauth_provider_parses_known_providers() {
        let cases = [
            ("openai", OAuthProvider::OpenAi),
            ("openrouter", OAuthProvider::OpenRouter),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<OAuthProvider>(), Ok(expected));
        }
        assert!("other".parse::<OAuthProvider>().is_err());
    }

    /// The success page is branded and schedules its own close.
    #[test]
    fn success_html_mentions_vtcode_and_autoclose() {
        let html = success_html(OAuthProvider::OpenAi);
        for needle in ["VT Code", "Authentication Successful", "window.close"] {
            assert!(html.contains(needle), "missing {needle:?}");
        }
    }

    /// A callback whose `state` does not match is reported as an error.
    #[tokio::test]
    async fn callback_rejects_state_mismatch() {
        let (result_tx, mut result_rx) = mpsc::channel(1);
        let state = Arc::new(AuthCallbackState {
            page: OAuthCallbackPage::new(OAuthProvider::OpenAi),
            expected_state: Some("expected-state".to_string()),
            result_tx,
        });
        let params = AuthCallbackParams {
            code: Some("auth-code".to_string()),
            error: None,
            error_description: None,
            state: Some("wrong-state".to_string()),
        };

        let Html(body) = handle_callback(State(state), Query(params)).await;

        let outcome = result_rx.recv().await.expect("callback outcome");
        assert!(
            matches!(
                &outcome,
                AuthCallbackOutcome::Error(message) if message.contains("state mismatch")
            ),
            "expected error outcome, got {outcome:?}"
        );
        assert!(body.contains("Authentication Failed"));
    }
}