1use anyhow::{Context, Result};
2use axum::{
3 Router,
4 extract::{Query, State},
5 response::Html,
6 routing::get,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt;
10use std::net::SocketAddr;
11use std::str::FromStr;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::{mpsc, oneshot};
15
/// Timeout, in seconds, substituted by `callback_timeout` when a caller passes `0`.
const DEFAULT_CALLBACK_TIMEOUT_SECS: u64 = 300;
17
/// OAuth providers this module knows how to describe.
///
/// NOTE(review): serde's `rename_all = "snake_case"` serializes the variants as
/// `open_ai` / `open_router`, whereas `slug()`, `Display` and `FromStr` use
/// `openai` / `openrouter`. Confirm the two spellings are intentional before
/// relying on either in persisted data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case")]
pub enum OAuthProvider {
    /// OpenAI; the user-facing messages refer to a ChatGPT subscription.
    OpenAi,
    /// OpenRouter; the user-facing messages refer to an OpenRouter account.
    OpenRouter,
}
25
26impl OAuthProvider {
27 #[must_use]
28 pub fn slug(self) -> &'static str {
29 match self {
30 Self::OpenAi => "openai",
31 Self::OpenRouter => "openrouter",
32 }
33 }
34
35 #[must_use]
36 pub fn display_name(self) -> &'static str {
37 match self {
38 Self::OpenAi => "OpenAI",
39 Self::OpenRouter => "OpenRouter",
40 }
41 }
42
43 #[must_use]
44 pub fn subtitle(self) -> &'static str {
45 match self {
46 Self::OpenAi => "Your ChatGPT subscription is now connected.",
47 Self::OpenRouter => "Your OpenRouter account is now connected.",
48 }
49 }
50
51 #[must_use]
52 pub fn failure_subtitle(self) -> &'static str {
53 match self {
54 Self::OpenAi => "Unable to connect your ChatGPT subscription.",
55 Self::OpenRouter => "Unable to connect your OpenRouter account.",
56 }
57 }
58
59 #[must_use]
60 pub fn retry_hint(self) -> String {
61 format!("You can try again anytime using /login {}", self.slug())
62 }
63
64 #[must_use]
65 pub fn supports_manual_refresh(self) -> bool {
66 matches!(self, Self::OpenAi)
67 }
68}
69
70impl fmt::Display for OAuthProvider {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 f.write_str(self.slug())
73 }
74}
75
76impl FromStr for OAuthProvider {
77 type Err = ();
78
79 fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
80 match value.trim().to_ascii_lowercase().as_str() {
81 "openai" => Ok(Self::OpenAi),
82 "openrouter" => Ok(Self::OpenRouter),
83 _ => Err(()),
84 }
85 }
86}
87
/// Static text used when rendering the browser-facing OAuth callback pages.
///
/// All fields are `&'static str`, so the value is `Copy` and can be handed to
/// each page renderer by value.
#[derive(Debug, Clone, Copy)]
pub struct OAuthCallbackPage {
    /// Provider identifier; logged when a callback request arrives.
    provider_slug: &'static str,
    /// Subtitle of the "Authentication Successful" page.
    success_subtitle: &'static str,
    /// Subtitle of the "Authentication Failed" page.
    failure_subtitle: &'static str,
    /// Subtitle of the "Authentication Cancelled" page.
    retry_hint: &'static str,
}
95
96impl OAuthCallbackPage {
97 #[must_use]
98 pub fn new(provider: OAuthProvider) -> Self {
99 match provider {
100 OAuthProvider::OpenAi => Self {
101 provider_slug: "openai",
102 success_subtitle: "Your ChatGPT subscription is now connected.",
103 failure_subtitle: "Unable to connect your ChatGPT subscription.",
104 retry_hint: "You can try again anytime using /login openai",
105 },
106 OAuthProvider::OpenRouter => Self {
107 provider_slug: "openrouter",
108 success_subtitle: "Your OpenRouter account is now connected.",
109 failure_subtitle: "Unable to connect your OpenRouter account.",
110 retry_hint: "You can try again anytime using /login openrouter",
111 },
112 }
113 }
114
115 #[must_use]
116 pub fn custom(
117 provider_slug: &'static str,
118 success_subtitle: &'static str,
119 failure_subtitle: &'static str,
120 retry_hint: &'static str,
121 ) -> Self {
122 Self {
123 provider_slug,
124 success_subtitle,
125 failure_subtitle,
126 retry_hint,
127 }
128 }
129
130 #[must_use]
131 pub fn provider_slug(&self) -> &'static str {
132 self.provider_slug
133 }
134
135 #[must_use]
136 pub fn success_subtitle(&self) -> &'static str {
137 self.success_subtitle
138 }
139
140 #[must_use]
141 pub fn failure_subtitle(&self) -> &'static str {
142 self.failure_subtitle
143 }
144
145 #[must_use]
146 pub fn retry_hint(&self) -> &'static str {
147 self.retry_hint
148 }
149}
150
/// Result of one OAuth callback flow as reported to the waiting caller.
#[derive(Debug, Clone)]
pub enum AuthCallbackOutcome {
    /// The callback carried an authorization code (the `code` query parameter).
    Code(String),
    /// The `/cancel` route was requested.
    Cancelled,
    /// The flow failed; the message describes a provider error, a state
    /// mismatch, a missing code, or a timeout.
    Error(String),
}
157
/// Handle to a running localhost callback server.
///
/// Created by `start`; consumed by `wait`. Dropping the handle signals shutdown
/// and aborts the server task.
pub struct AuthCodeCallbackServer {
    /// How long `wait` blocks before reporting a timeout.
    timeout: Duration,
    /// Receives the outcome sent by the route handlers.
    result_rx: mpsc::Receiver<AuthCallbackOutcome>,
    /// Graceful-shutdown trigger; `None` once it has been used.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Task running the HTTP server; `None` once it has been awaited or aborted.
    server_handle: Option<tokio::task::JoinHandle<()>>,
}
164
impl AuthCodeCallbackServer {
    /// Binds `127.0.0.1:{port}` and starts serving the callback routes on a spawned task.
    ///
    /// Routes: `/callback` and `/auth/callback` (authorization result), `/cancel`,
    /// and `/health` (replies `OK`).
    ///
    /// * `timeout_secs` - how long `wait` will block; `0` selects the default.
    /// * `page` - static text for the rendered pages.
    /// * `expected_state` - when `Some`, callbacks whose `state` parameter differs
    ///   are reported as errors.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound to the requested port.
    pub async fn start(
        port: u16,
        timeout_secs: u64,
        page: OAuthCallbackPage,
        expected_state: Option<String>,
    ) -> Result<Self> {
        // Capacity 1: a single outcome is consumed per flow.
        let (result_tx, result_rx) = mpsc::channel::<AuthCallbackOutcome>(1);
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
        let state = Arc::new(AuthCallbackState {
            page,
            expected_state,
            result_tx,
        });

        let app = Router::new()
            .route("/callback", get(handle_callback))
            .route("/auth/callback", get(handle_callback))
            .route("/cancel", get(handle_cancel))
            .route("/health", get(|| async { "OK" }))
            .with_state(state);

        // Loopback only: the server is never exposed on external interfaces.
        let addr = SocketAddr::from(([127, 0, 0, 1], port));
        let listener = tokio::net::TcpListener::bind(addr)
            .await
            .with_context(|| format!("failed to bind localhost callback server on port {port}"))?;

        let server = axum::serve(listener, app).with_graceful_shutdown(async move {
            // Completes on an explicit signal or when the sender is dropped;
            // either way the server should stop, so the result is ignored.
            let _ = shutdown_rx.await;
        });
        let server_handle = tokio::spawn(async move {
            if let Err(err) = server.await {
                tracing::error!("OAuth callback server error: {}", err);
            }
        });

        Ok(Self {
            timeout: callback_timeout(timeout_secs),
            result_rx,
            shutdown_tx: Some(shutdown_tx),
            server_handle: Some(server_handle),
        })
    }

    /// Waits for the first callback outcome or the timeout, then shuts the server down.
    ///
    /// A timeout is reported as `AuthCallbackOutcome::Error`, not as `Err`; as
    /// written this method always returns `Ok`.
    pub async fn wait(mut self) -> Result<AuthCallbackOutcome> {
        let result = tokio::select! {
            // If `recv` yields `None` the pattern does not match, which disables
            // this branch and leaves the timer to decide the outcome.
            Some(result) = self.result_rx.recv() => result,
            _ = tokio::time::sleep(self.timeout) => {
                AuthCallbackOutcome::Error(format!(
                    "OAuth flow timed out after {} seconds",
                    self.timeout.as_secs()
                ))
            }
        };

        self.shutdown().await;
        Ok(result)
    }

    /// Signals graceful shutdown and waits for the server task to finish.
    ///
    /// Both handles are `take`n, so a later call (or `Drop`) finds nothing to do.
    async fn shutdown(&mut self) {
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            let _ = shutdown_tx.send(());
        }
        if let Some(server_handle) = self.server_handle.take() {
            let _ = server_handle.await;
        }
    }
}
233
234impl Drop for AuthCodeCallbackServer {
235 fn drop(&mut self) {
236 if let Some(shutdown_tx) = self.shutdown_tx.take() {
237 let _ = shutdown_tx.send(());
238 }
239 if let Some(server_handle) = self.server_handle.take() {
240 server_handle.abort();
241 }
242 }
243}
244
/// Query parameters read from the callback request; every one is optional.
#[derive(Debug, Deserialize)]
struct AuthCallbackParams {
    /// Authorization code, present on success.
    code: Option<String>,
    /// Error identifier; when present the callback is treated as a failure.
    error: Option<String>,
    /// Optional detail appended to the error message when non-blank.
    error_description: Option<String>,
    /// Value compared against the expected state, when one is configured.
    state: Option<String>,
}
252
/// State shared with the route handlers through an `Arc`.
struct AuthCallbackState {
    /// Static text used to render the response pages.
    page: OAuthCallbackPage,
    /// When `Some`, the `state` query parameter must equal this value.
    expected_state: Option<String>,
    /// Channel on which handlers report the flow's outcome.
    result_tx: mpsc::Sender<AuthCallbackOutcome>,
}
258
/// Free-function form of [`AuthCodeCallbackServer::start`]; forwards all arguments unchanged.
///
/// # Errors
///
/// Returns whatever error `AuthCodeCallbackServer::start` returns.
pub async fn start_auth_code_callback_server(
    port: u16,
    timeout_secs: u64,
    page: OAuthCallbackPage,
    expected_state: Option<String>,
) -> Result<AuthCodeCallbackServer> {
    AuthCodeCallbackServer::start(port, timeout_secs, page, expected_state).await
}
267
268pub async fn run_auth_code_callback_server(
269 port: u16,
270 timeout_secs: u64,
271 page: OAuthCallbackPage,
272 expected_state: Option<String>,
273) -> Result<AuthCallbackOutcome> {
274 start_auth_code_callback_server(port, timeout_secs, page, expected_state)
275 .await?
276 .wait()
277 .await
278}
279
280async fn handle_callback(
281 State(state): State<Arc<AuthCallbackState>>,
282 Query(params): Query<AuthCallbackParams>,
283) -> Html<String> {
284 tracing::info!(
285 provider = state.page.provider_slug(),
286 has_code = params.code.is_some(),
287 has_error = params.error.is_some(),
288 "received oauth callback"
289 );
290 if let Some(expected_state) = state.expected_state.as_deref() {
291 match params.state.as_deref() {
292 Some(actual_state) if actual_state == expected_state => {}
293 _ => {
294 let message = "OAuth error: state mismatch".to_string();
295 let _ = state
296 .result_tx
297 .send(AuthCallbackOutcome::Error(message.clone()))
298 .await;
299 return Html(error_html(state.page, &message));
300 }
301 }
302 }
303
304 if let Some(error) = params.error {
305 let message = match params.error_description {
306 Some(description) if !description.trim().is_empty() => {
307 format!("OAuth error: {error} - {description}")
308 }
309 _ => format!("OAuth error: {error}"),
310 };
311 let _ = state
312 .result_tx
313 .send(AuthCallbackOutcome::Error(message.clone()))
314 .await;
315 return Html(error_html(state.page, &message));
316 }
317
318 let Some(code) = params.code else {
319 let message = "Missing authorization code".to_string();
320 let _ = state
321 .result_tx
322 .send(AuthCallbackOutcome::Error(message.clone()))
323 .await;
324 return Html(error_html(state.page, &message));
325 };
326
327 let _ = state.result_tx.send(AuthCallbackOutcome::Code(code)).await;
328 Html(success_html(state.page))
329}
330
331async fn handle_cancel(State(state): State<Arc<AuthCallbackState>>) -> Html<String> {
332 let _ = state.result_tx.send(AuthCallbackOutcome::Cancelled).await;
333 Html(cancelled_html(state.page))
334}
335
336fn success_html(page: OAuthCallbackPage) -> String {
337 base_html(
338 "Authentication Successful",
339 page.success_subtitle(),
340 Some("You may now close this window and return to VT Code."),
341 "✓",
342 "#22c55e",
343 None,
344 )
345}
346
347fn error_html(page: OAuthCallbackPage, error: &str) -> String {
348 base_html(
349 "Authentication Failed",
350 page.failure_subtitle(),
351 None,
352 "✕",
353 "#ef4444",
354 Some(error),
355 )
356}
357
358fn cancelled_html(page: OAuthCallbackPage) -> String {
359 base_html(
360 "Authentication Cancelled",
361 page.retry_hint(),
362 None,
363 "—",
364 "#71717a",
365 None,
366 )
367}
368
/// Renders the shared HTML page used by the success, failure and cancellation views.
///
/// * `title` - page heading, also appended to the `<title>`; passed through `html_escape`.
/// * `subtitle` - paragraph under the heading; passed through `html_escape`.
/// * `close_note` - optional footnote; when present, a script that closes the
///   window after 3000 ms is also emitted.
/// * `icon` - text placed inside the status circle; inserted as-is.
/// * `accent` - CSS color assigned to `--accent`; inserted as-is.
/// * `error` - optional detail text rendered in the error box; passed through `html_escape`.
///
/// `icon` and `accent` are not escaped, so callers must only pass trusted
/// values (every caller in this file passes literals).
fn base_html(
    title: &str,
    subtitle: &str,
    close_note: Option<&str>,
    icon: &str,
    accent: &str,
    error: Option<&str>,
) -> String {
    // Optional fragments collapse to an empty string when absent.
    let close_note_html = close_note
        .map(|value| format!(r#"<p class="close-note">{}</p>"#, html_escape(value)))
        .unwrap_or_default();
    let error_html = error
        .map(|value| format!(r#"<div class="error">{}</div>"#, html_escape(value)))
        .unwrap_or_default();
    // Only pages that carry a close note try to close the window automatically.
    let auto_close = if close_note.is_some() {
        r#"<script>setTimeout(() => window.close(), 3000);</script>"#
    } else {
        ""
    };

    // Literal CSS braces are doubled (`{{` / `}}`) because this is a format string.
    format!(
        r##"<!DOCTYPE html>
<html>
<head>
 <title>VT Code - {title}</title>
 <style>
 @font-face {{
 font-family: 'SF Pro Display';
 src: local('SF Pro Display'), local('.SF NS Display'), local('Helvetica Neue');
 }}
 @font-face {{
 font-family: 'SF Mono';
 src: local('SF Mono'), local('Menlo'), local('Monaco');
 }}
 :root {{
 color-scheme: dark;
 --bg: #0a0a0a;
 --panel: #111111;
 --panel-border: #262626;
 --text: #fafafa;
 --muted: #a1a1aa;
 --subtle: #52525b;
 --code-bg: #18181b;
 --code-border: #27272a;
 --accent: {accent};
 }}
 * {{ box-sizing: border-box; }}
 body {{
 font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
 display: flex;
 justify-content: center;
 align-items: center;
 min-height: 100vh;
 margin: 0;
 background:
 radial-gradient(circle at top, rgba(255,255,255,0.04), transparent 32%),
 linear-gradient(180deg, var(--bg), #050505);
 color: var(--text);
 padding: 24px;
 }}
 .container {{
 text-align: center;
 padding: 2.75rem 3rem;
 border: 1px solid var(--panel-border);
 border-radius: 14px;
 background: rgba(17, 17, 17, 0.92);
 max-width: 460px;
 width: 100%;
 box-shadow: 0 30px 90px rgba(0, 0, 0, 0.35);
 }}
 .logo {{
 margin-bottom: 1.5rem;
 }}
 .logo-mark {{
 display: inline-flex;
 align-items: center;
 justify-content: center;
 font-size: 0.95rem;
 letter-spacing: 0.24em;
 text-transform: uppercase;
 color: var(--muted);
 }}
 .status-icon {{
 width: 52px;
 height: 52px;
 margin: 0 auto 1.25rem;
 border: 2px solid var(--accent);
 border-radius: 50%;
 display: flex;
 align-items: center;
 justify-content: center;
 font-size: 1.25rem;
 color: var(--accent);
 }}
 h1 {{
 margin: 0 0 0.75rem 0;
 font-size: 1.25rem;
 font-weight: 600;
 letter-spacing: -0.02em;
 }}
 p {{
 color: var(--muted);
 margin: 0;
 font-size: 0.92rem;
 line-height: 1.55;
 }}
 .close-note {{
 margin-top: 1.25rem;
 font-size: 0.78rem;
 color: var(--subtle);
 }}
 .error {{
 margin-top: 1.35rem;
 padding: 0.95rem 1rem;
 background: var(--code-bg);
 border: 1px solid var(--code-border);
 border-radius: 10px;
 font-family: 'SF Mono', Menlo, Monaco, monospace;
 font-size: 0.75rem;
 color: #d4d4d8;
 word-break: break-word;
 text-align: left;
 }}
 </style>
</head>
<body>
 <div class="container">
 <div class="logo">
 <div class="logo-mark">> VT Code</div>
 </div>
 <div class="status-icon">{icon}</div>
 <h1>{title}</h1>
 <p>{subtitle}</p>
 {close_note_html}
 {error_html}
 </div>
 {auto_close}
</body>
</html>"##,
        title = html_escape(title),
        subtitle = html_escape(subtitle),
        icon = icon,
        accent = accent,
        close_note_html = close_note_html,
        error_html = error_html,
        auto_close = auto_close,
    )
}
517
/// Escapes `value` so it can be embedded in HTML text or a quoted attribute.
///
/// Replaces `&`, `<`, `>`, `"` and `'` with their entity forms. The callback
/// pages embed text taken from the request's query string (for example the
/// provider's `error_description`), so leaving these characters intact would
/// let a crafted callback URL inject markup into the page.
///
/// The input is scanned once, so an `&` that is already part of an entity is
/// escaped again (`&lt;` becomes `&amp;lt;`) and is displayed literally.
fn html_escape(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            // Quotes are harmless in text nodes but matter inside attributes.
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}
524
525fn callback_timeout(timeout_secs: u64) -> Duration {
526 Duration::from_secs(if timeout_secs == 0 {
527 DEFAULT_CALLBACK_TIMEOUT_SECS
528 } else {
529 timeout_secs
530 })
531}
532
// Unit and loopback integration tests for the OAuth callback server.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::extract::{Query, State};
    use reqwest::Client;

    // Known slugs parse to their variants; anything else is rejected.
    #[test]
    fn oauth_provider_parses_known_providers() {
        assert_eq!("openai".parse::<OAuthProvider>(), Ok(OAuthProvider::OpenAi));
        assert_eq!(
            "openrouter".parse::<OAuthProvider>(),
            Ok(OAuthProvider::OpenRouter)
        );
        assert!("other".parse::<OAuthProvider>().is_err());
    }

    // The success page carries the product name, the heading and the auto-close script.
    #[test]
    fn success_html_mentions_vtcode_and_autoclose() {
        let html = success_html(OAuthCallbackPage::new(OAuthProvider::OpenAi));
        assert!(html.contains("VT Code"));
        assert!(html.contains("Authentication Successful"));
        assert!(html.contains("window.close"));
    }

    // Calls the handler directly (no HTTP) with a `state` that differs from the
    // expected one, even though a code is present.
    #[tokio::test]
    async fn callback_rejects_state_mismatch() {
        let (result_tx, mut result_rx) = mpsc::channel(1);
        let state = Arc::new(AuthCallbackState {
            page: OAuthCallbackPage::new(OAuthProvider::OpenAi),
            expected_state: Some("expected-state".to_string()),
            result_tx,
        });

        let html = handle_callback(
            State(state),
            Query(AuthCallbackParams {
                code: Some("auth-code".to_string()),
                error: None,
                error_description: None,
                state: Some("wrong-state".to_string()),
            }),
        )
        .await;

        let outcome = result_rx.recv().await.expect("callback outcome");
        match outcome {
            AuthCallbackOutcome::Error(message) => {
                assert!(message.contains("state mismatch"));
            }
            _ => panic!("expected error outcome"),
        }
        assert!(html.0.contains("Authentication Failed"));
    }

    // The server must accept requests as soon as `start` returns, before `wait` is called.
    #[tokio::test]
    async fn callback_server_starts_listening_before_wait() {
        // Ask the OS for a free port, then release it so the server can bind it.
        // NOTE(review): another process could claim the port between `drop` and
        // the server's bind, which would make this test flaky.
        let listener = std::net::TcpListener::bind(("127.0.0.1", 0)).expect("bind temp port");
        let port = listener.local_addr().expect("local addr").port();
        drop(listener);

        let server = start_auth_code_callback_server(
            port,
            5,
            OAuthCallbackPage::new(OAuthProvider::OpenAi),
            None,
        )
        .await
        .expect("start callback server");
        // Proxies are disabled so the request goes straight to the loopback address.
        let client = Client::builder()
            .no_proxy()
            .build()
            .expect("build http client");

        let health = client
            .get(format!("http://127.0.0.1:{port}/health"))
            .send()
            .await
            .expect("health request should succeed");
        assert!(health.status().is_success());

        let cancel = client
            .get(format!("http://127.0.0.1:{port}/cancel"))
            .send()
            .await
            .expect("cancel request should succeed");
        assert!(cancel.status().is_success());

        // The outcome sent by `/cancel` is buffered until `wait` picks it up.
        assert!(matches!(
            server.wait().await.expect("wait for callback outcome"),
            AuthCallbackOutcome::Cancelled
        ));
    }
}