Skip to main content

vtcode_auth/
oauth_server.rs

1use anyhow::{Context, Result};
2use axum::{
3    Router,
4    extract::{Query, State},
5    response::Html,
6    routing::get,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt;
10use std::net::SocketAddr;
11use std::str::FromStr;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::{mpsc, oneshot};
15
/// Callback timeout, in seconds, used when the caller passes `0` (five minutes).
const DEFAULT_CALLBACK_TIMEOUT_SECS: u64 = 5 * 60;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
19#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
20#[serde(rename_all = "snake_case")]
21pub enum OAuthProvider {
22    OpenAi,
23    OpenRouter,
24}
25
26impl OAuthProvider {
27    #[must_use]
28    pub fn slug(self) -> &'static str {
29        match self {
30            Self::OpenAi => "openai",
31            Self::OpenRouter => "openrouter",
32        }
33    }
34
35    #[must_use]
36    pub fn display_name(self) -> &'static str {
37        match self {
38            Self::OpenAi => "OpenAI",
39            Self::OpenRouter => "OpenRouter",
40        }
41    }
42
43    #[must_use]
44    pub fn subtitle(self) -> &'static str {
45        match self {
46            Self::OpenAi => "Your ChatGPT subscription is now connected.",
47            Self::OpenRouter => "Your OpenRouter account is now connected.",
48        }
49    }
50
51    #[must_use]
52    pub fn failure_subtitle(self) -> &'static str {
53        match self {
54            Self::OpenAi => "Unable to connect your ChatGPT subscription.",
55            Self::OpenRouter => "Unable to connect your OpenRouter account.",
56        }
57    }
58
59    #[must_use]
60    pub fn retry_hint(self) -> String {
61        format!("You can try again anytime using /login {}", self.slug())
62    }
63
64    #[must_use]
65    pub fn supports_manual_refresh(self) -> bool {
66        matches!(self, Self::OpenAi)
67    }
68}
69
70impl fmt::Display for OAuthProvider {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        f.write_str(self.slug())
73    }
74}
75
76impl FromStr for OAuthProvider {
77    type Err = ();
78
79    fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
80        match value.trim().to_ascii_lowercase().as_str() {
81            "openai" => Ok(Self::OpenAi),
82            "openrouter" => Ok(Self::OpenRouter),
83            _ => Err(()),
84        }
85    }
86}
87
88#[derive(Debug, Clone, Copy)]
89pub struct OAuthCallbackPage {
90    provider_slug: &'static str,
91    success_subtitle: &'static str,
92    failure_subtitle: &'static str,
93    retry_hint: &'static str,
94}
95
96impl OAuthCallbackPage {
97    #[must_use]
98    pub fn new(provider: OAuthProvider) -> Self {
99        match provider {
100            OAuthProvider::OpenAi => Self {
101                provider_slug: "openai",
102                success_subtitle: "Your ChatGPT subscription is now connected.",
103                failure_subtitle: "Unable to connect your ChatGPT subscription.",
104                retry_hint: "You can try again anytime using /login openai",
105            },
106            OAuthProvider::OpenRouter => Self {
107                provider_slug: "openrouter",
108                success_subtitle: "Your OpenRouter account is now connected.",
109                failure_subtitle: "Unable to connect your OpenRouter account.",
110                retry_hint: "You can try again anytime using /login openrouter",
111            },
112        }
113    }
114
115    #[must_use]
116    pub fn custom(
117        provider_slug: &'static str,
118        success_subtitle: &'static str,
119        failure_subtitle: &'static str,
120        retry_hint: &'static str,
121    ) -> Self {
122        Self {
123            provider_slug,
124            success_subtitle,
125            failure_subtitle,
126            retry_hint,
127        }
128    }
129
130    #[must_use]
131    pub fn provider_slug(&self) -> &'static str {
132        self.provider_slug
133    }
134
135    #[must_use]
136    pub fn success_subtitle(&self) -> &'static str {
137        self.success_subtitle
138    }
139
140    #[must_use]
141    pub fn failure_subtitle(&self) -> &'static str {
142        self.failure_subtitle
143    }
144
145    #[must_use]
146    pub fn retry_hint(&self) -> &'static str {
147        self.retry_hint
148    }
149}
150
/// Result of an OAuth authorization-code flow, as reported by the callback server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthCallbackOutcome {
    /// The provider redirected back with this authorization code.
    Code(String),
    /// The user opened the `/cancel` endpoint.
    Cancelled,
    /// The flow failed; the message is suitable for display to the user.
    Error(String),
}
157
158pub struct AuthCodeCallbackServer {
159    timeout: Duration,
160    result_rx: mpsc::Receiver<AuthCallbackOutcome>,
161    shutdown_tx: Option<oneshot::Sender<()>>,
162    server_handle: Option<tokio::task::JoinHandle<()>>,
163}
164
165impl AuthCodeCallbackServer {
166    pub async fn start(
167        port: u16,
168        timeout_secs: u64,
169        page: OAuthCallbackPage,
170        expected_state: Option<String>,
171    ) -> Result<Self> {
172        let (result_tx, result_rx) = mpsc::channel::<AuthCallbackOutcome>(1);
173        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
174        let state = Arc::new(AuthCallbackState {
175            page,
176            expected_state,
177            result_tx,
178        });
179
180        let app = Router::new()
181            .route("/callback", get(handle_callback))
182            .route("/auth/callback", get(handle_callback))
183            .route("/cancel", get(handle_cancel))
184            .route("/health", get(|| async { "OK" }))
185            .with_state(state);
186
187        let addr = SocketAddr::from(([127, 0, 0, 1], port));
188        let listener = tokio::net::TcpListener::bind(addr)
189            .await
190            .with_context(|| format!("failed to bind localhost callback server on port {port}"))?;
191
192        let server = axum::serve(listener, app).with_graceful_shutdown(async move {
193            let _ = shutdown_rx.await;
194        });
195        let server_handle = tokio::spawn(async move {
196            if let Err(err) = server.await {
197                tracing::error!("OAuth callback server error: {}", err);
198            }
199        });
200
201        Ok(Self {
202            timeout: callback_timeout(timeout_secs),
203            result_rx,
204            shutdown_tx: Some(shutdown_tx),
205            server_handle: Some(server_handle),
206        })
207    }
208
209    pub async fn wait(mut self) -> Result<AuthCallbackOutcome> {
210        let result = tokio::select! {
211            Some(result) = self.result_rx.recv() => result,
212            _ = tokio::time::sleep(self.timeout) => {
213                AuthCallbackOutcome::Error(format!(
214                    "OAuth flow timed out after {} seconds",
215                    self.timeout.as_secs()
216                ))
217            }
218        };
219
220        self.shutdown().await;
221        Ok(result)
222    }
223
224    async fn shutdown(&mut self) {
225        if let Some(shutdown_tx) = self.shutdown_tx.take() {
226            let _ = shutdown_tx.send(());
227        }
228        if let Some(server_handle) = self.server_handle.take() {
229            let _ = server_handle.await;
230        }
231    }
232}
233
234impl Drop for AuthCodeCallbackServer {
235    fn drop(&mut self) {
236        if let Some(shutdown_tx) = self.shutdown_tx.take() {
237            let _ = shutdown_tx.send(());
238        }
239        if let Some(server_handle) = self.server_handle.take() {
240            server_handle.abort();
241        }
242    }
243}
244
245#[derive(Debug, Deserialize)]
246struct AuthCallbackParams {
247    code: Option<String>,
248    error: Option<String>,
249    error_description: Option<String>,
250    state: Option<String>,
251}
252
253struct AuthCallbackState {
254    page: OAuthCallbackPage,
255    expected_state: Option<String>,
256    result_tx: mpsc::Sender<AuthCallbackOutcome>,
257}
258
259pub async fn start_auth_code_callback_server(
260    port: u16,
261    timeout_secs: u64,
262    page: OAuthCallbackPage,
263    expected_state: Option<String>,
264) -> Result<AuthCodeCallbackServer> {
265    AuthCodeCallbackServer::start(port, timeout_secs, page, expected_state).await
266}
267
268pub async fn run_auth_code_callback_server(
269    port: u16,
270    timeout_secs: u64,
271    page: OAuthCallbackPage,
272    expected_state: Option<String>,
273) -> Result<AuthCallbackOutcome> {
274    start_auth_code_callback_server(port, timeout_secs, page, expected_state)
275        .await?
276        .wait()
277        .await
278}
279
280async fn handle_callback(
281    State(state): State<Arc<AuthCallbackState>>,
282    Query(params): Query<AuthCallbackParams>,
283) -> Html<String> {
284    tracing::info!(
285        provider = state.page.provider_slug(),
286        has_code = params.code.is_some(),
287        has_error = params.error.is_some(),
288        "received oauth callback"
289    );
290    if let Some(expected_state) = state.expected_state.as_deref() {
291        match params.state.as_deref() {
292            Some(actual_state) if actual_state == expected_state => {}
293            _ => {
294                let message = "OAuth error: state mismatch".to_string();
295                let _ = state
296                    .result_tx
297                    .send(AuthCallbackOutcome::Error(message.clone()))
298                    .await;
299                return Html(error_html(state.page, &message));
300            }
301        }
302    }
303
304    if let Some(error) = params.error {
305        let message = match params.error_description {
306            Some(description) if !description.trim().is_empty() => {
307                format!("OAuth error: {error} - {description}")
308            }
309            _ => format!("OAuth error: {error}"),
310        };
311        let _ = state
312            .result_tx
313            .send(AuthCallbackOutcome::Error(message.clone()))
314            .await;
315        return Html(error_html(state.page, &message));
316    }
317
318    let Some(code) = params.code else {
319        let message = "Missing authorization code".to_string();
320        let _ = state
321            .result_tx
322            .send(AuthCallbackOutcome::Error(message.clone()))
323            .await;
324        return Html(error_html(state.page, &message));
325    };
326
327    let _ = state.result_tx.send(AuthCallbackOutcome::Code(code)).await;
328    Html(success_html(state.page))
329}
330
331async fn handle_cancel(State(state): State<Arc<AuthCallbackState>>) -> Html<String> {
332    let _ = state.result_tx.send(AuthCallbackOutcome::Cancelled).await;
333    Html(cancelled_html(state.page))
334}
335
336fn success_html(page: OAuthCallbackPage) -> String {
337    base_html(
338        "Authentication Successful",
339        page.success_subtitle(),
340        Some("You may now close this window and return to VT Code."),
341        "✓",
342        "#22c55e",
343        None,
344    )
345}
346
347fn error_html(page: OAuthCallbackPage, error: &str) -> String {
348    base_html(
349        "Authentication Failed",
350        page.failure_subtitle(),
351        None,
352        "✕",
353        "#ef4444",
354        Some(error),
355    )
356}
357
358fn cancelled_html(page: OAuthCallbackPage) -> String {
359    base_html(
360        "Authentication Cancelled",
361        page.retry_hint(),
362        None,
363        "—",
364        "#71717a",
365        None,
366    )
367}
368
369fn base_html(
370    title: &str,
371    subtitle: &str,
372    close_note: Option<&str>,
373    icon: &str,
374    accent: &str,
375    error: Option<&str>,
376) -> String {
377    let close_note_html = close_note
378        .map(|value| format!(r#"<p class="close-note">{}</p>"#, html_escape(value)))
379        .unwrap_or_default();
380    let error_html = error
381        .map(|value| format!(r#"<div class="error">{}</div>"#, html_escape(value)))
382        .unwrap_or_default();
383    let auto_close = if close_note.is_some() {
384        r#"<script>setTimeout(() => window.close(), 3000);</script>"#
385    } else {
386        ""
387    };
388
389    format!(
390        r##"<!DOCTYPE html>
391<html>
392<head>
393    <title>VT Code - {title}</title>
394    <style>
395        @font-face {{
396            font-family: 'SF Pro Display';
397            src: local('SF Pro Display'), local('.SF NS Display'), local('Helvetica Neue');
398        }}
399        @font-face {{
400            font-family: 'SF Mono';
401            src: local('SF Mono'), local('Menlo'), local('Monaco');
402        }}
403        :root {{
404            color-scheme: dark;
405            --bg: #0a0a0a;
406            --panel: #111111;
407            --panel-border: #262626;
408            --text: #fafafa;
409            --muted: #a1a1aa;
410            --subtle: #52525b;
411            --code-bg: #18181b;
412            --code-border: #27272a;
413            --accent: {accent};
414        }}
415        * {{ box-sizing: border-box; }}
416        body {{
417            font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
418            display: flex;
419            justify-content: center;
420            align-items: center;
421            min-height: 100vh;
422            margin: 0;
423            background:
424                radial-gradient(circle at top, rgba(255,255,255,0.04), transparent 32%),
425                linear-gradient(180deg, var(--bg), #050505);
426            color: var(--text);
427            padding: 24px;
428        }}
429        .container {{
430            text-align: center;
431            padding: 2.75rem 3rem;
432            border: 1px solid var(--panel-border);
433            border-radius: 14px;
434            background: rgba(17, 17, 17, 0.92);
435            max-width: 460px;
436            width: 100%;
437            box-shadow: 0 30px 90px rgba(0, 0, 0, 0.35);
438        }}
439        .logo {{
440            margin-bottom: 1.5rem;
441        }}
442        .logo-mark {{
443            display: inline-flex;
444            align-items: center;
445            justify-content: center;
446            font-size: 0.95rem;
447            letter-spacing: 0.24em;
448            text-transform: uppercase;
449            color: var(--muted);
450        }}
451        .status-icon {{
452            width: 52px;
453            height: 52px;
454            margin: 0 auto 1.25rem;
455            border: 2px solid var(--accent);
456            border-radius: 50%;
457            display: flex;
458            align-items: center;
459            justify-content: center;
460            font-size: 1.25rem;
461            color: var(--accent);
462        }}
463        h1 {{
464            margin: 0 0 0.75rem 0;
465            font-size: 1.25rem;
466            font-weight: 600;
467            letter-spacing: -0.02em;
468        }}
469        p {{
470            color: var(--muted);
471            margin: 0;
472            font-size: 0.92rem;
473            line-height: 1.55;
474        }}
475        .close-note {{
476            margin-top: 1.25rem;
477            font-size: 0.78rem;
478            color: var(--subtle);
479        }}
480        .error {{
481            margin-top: 1.35rem;
482            padding: 0.95rem 1rem;
483            background: var(--code-bg);
484            border: 1px solid var(--code-border);
485            border-radius: 10px;
486            font-family: 'SF Mono', Menlo, Monaco, monospace;
487            font-size: 0.75rem;
488            color: #d4d4d8;
489            word-break: break-word;
490            text-align: left;
491        }}
492    </style>
493</head>
494<body>
495    <div class="container">
496        <div class="logo">
497            <div class="logo-mark">&gt; VT Code</div>
498        </div>
499        <div class="status-icon">{icon}</div>
500        <h1>{title}</h1>
501        <p>{subtitle}</p>
502        {close_note_html}
503        {error_html}
504    </div>
505    {auto_close}
506</body>
507</html>"##,
508        title = html_escape(title),
509        subtitle = html_escape(subtitle),
510        icon = icon,
511        accent = accent,
512        close_note_html = close_note_html,
513        error_html = error_html,
514        auto_close = auto_close,
515    )
516}
517
/// Escapes `value` for safe inclusion in HTML text content or quoted attribute values.
///
/// Replaces `&`, `<`, `>`, `"` and `'` with character references in a single pass.
/// It receives untrusted text: the callback's `error`/`error_description` query
/// parameters end up on the failure page through this function.
fn html_escape(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            other => escaped.push(other),
        }
    }
    escaped
}
524
525fn callback_timeout(timeout_secs: u64) -> Duration {
526    Duration::from_secs(if timeout_secs == 0 {
527        DEFAULT_CALLBACK_TIMEOUT_SECS
528    } else {
529        timeout_secs
530    })
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use axum::extract::{Query, State};
    use reqwest::Client;

    /// Builds handler state plus the receiving end of its outcome channel.
    fn test_state(
        provider: OAuthProvider,
        expected_state: Option<&str>,
    ) -> (Arc<AuthCallbackState>, mpsc::Receiver<AuthCallbackOutcome>) {
        let (result_tx, result_rx) = mpsc::channel(1);
        let state = Arc::new(AuthCallbackState {
            page: OAuthCallbackPage::new(provider),
            expected_state: expected_state.map(str::to_string),
            result_tx,
        });
        (state, result_rx)
    }

    #[test]
    fn oauth_provider_parses_known_providers() {
        assert_eq!("openai".parse::<OAuthProvider>(), Ok(OAuthProvider::OpenAi));
        assert_eq!(
            "openrouter".parse::<OAuthProvider>(),
            Ok(OAuthProvider::OpenRouter)
        );
        assert!("other".parse::<OAuthProvider>().is_err());
    }

    #[test]
    fn oauth_provider_parse_trims_ignores_case_and_round_trips() {
        assert_eq!("  OpenAI ".parse::<OAuthProvider>(), Ok(OAuthProvider::OpenAi));
        assert_eq!(
            "OPENROUTER".parse::<OAuthProvider>(),
            Ok(OAuthProvider::OpenRouter)
        );
        assert!("".parse::<OAuthProvider>().is_err());
        for provider in [OAuthProvider::OpenAi, OAuthProvider::OpenRouter] {
            assert_eq!(provider.to_string().parse::<OAuthProvider>(), Ok(provider));
        }
    }

    #[test]
    fn callback_timeout_uses_default_for_zero() {
        assert_eq!(
            callback_timeout(0),
            Duration::from_secs(DEFAULT_CALLBACK_TIMEOUT_SECS)
        );
        assert_eq!(callback_timeout(42), Duration::from_secs(42));
    }

    #[test]
    fn success_html_mentions_vtcode_and_autoclose() {
        let html = success_html(OAuthCallbackPage::new(OAuthProvider::OpenAi));
        assert!(html.contains("VT Code"));
        assert!(html.contains("Authentication Successful"));
        assert!(html.contains("window.close"));
    }

    #[test]
    fn error_html_escapes_untrusted_error_text() {
        let html = error_html(
            OAuthCallbackPage::new(OAuthProvider::OpenRouter),
            "<script>alert(1)</script>",
        );
        assert!(html.contains("&lt;script&gt;alert(1)&lt;/script&gt;"));
        assert!(!html.contains("<script>alert(1)</script>"));
    }

    #[tokio::test]
    async fn callback_rejects_state_mismatch() {
        let (state, mut result_rx) = test_state(OAuthProvider::OpenAi, Some("expected-state"));

        let html = handle_callback(
            State(state),
            Query(AuthCallbackParams {
                code: Some("auth-code".to_string()),
                error: None,
                error_description: None,
                state: Some("wrong-state".to_string()),
            }),
        )
        .await;

        match result_rx.recv().await.expect("callback outcome") {
            AuthCallbackOutcome::Error(message) => {
                assert!(message.contains("state mismatch"));
            }
            other => panic!("expected error outcome, got {other:?}"),
        }
        assert!(html.0.contains("Authentication Failed"));
    }

    #[tokio::test]
    async fn callback_reports_code_when_state_matches() {
        let (state, mut result_rx) = test_state(OAuthProvider::OpenAi, Some("expected-state"));

        let html = handle_callback(
            State(state),
            Query(AuthCallbackParams {
                code: Some("auth-code".to_string()),
                error: None,
                error_description: None,
                state: Some("expected-state".to_string()),
            }),
        )
        .await;

        match result_rx.recv().await.expect("callback outcome") {
            AuthCallbackOutcome::Code(code) => assert_eq!(code, "auth-code"),
            other => panic!("expected code outcome, got {other:?}"),
        }
        assert!(html.0.contains("Authentication Successful"));
    }

    #[tokio::test]
    async fn callback_reports_provider_error_with_description() {
        let (state, mut result_rx) = test_state(OAuthProvider::OpenRouter, None);

        let html = handle_callback(
            State(state),
            Query(AuthCallbackParams {
                code: None,
                error: Some("access_denied".to_string()),
                error_description: Some("User denied access".to_string()),
                state: None,
            }),
        )
        .await;

        match result_rx.recv().await.expect("callback outcome") {
            AuthCallbackOutcome::Error(message) => {
                assert_eq!(message, "OAuth error: access_denied - User denied access");
            }
            other => panic!("expected error outcome, got {other:?}"),
        }
        assert!(html.0.contains("Authentication Failed"));
        assert!(html.0.contains("OAuth error: access_denied - User denied access"));
    }

    #[tokio::test]
    async fn callback_without_code_reports_error() {
        let (state, mut result_rx) = test_state(OAuthProvider::OpenAi, None);

        let html = handle_callback(
            State(state),
            Query(AuthCallbackParams {
                code: None,
                error: None,
                error_description: None,
                state: None,
            }),
        )
        .await;

        match result_rx.recv().await.expect("callback outcome") {
            AuthCallbackOutcome::Error(message) => {
                assert_eq!(message, "Missing authorization code");
            }
            other => panic!("expected error outcome, got {other:?}"),
        }
        assert!(html.0.contains("Authentication Failed"));
    }

    #[tokio::test]
    async fn callback_server_starts_listening_before_wait() {
        let listener = std::net::TcpListener::bind(("127.0.0.1", 0)).expect("bind temp port");
        let port = listener.local_addr().expect("local addr").port();
        drop(listener);

        let server = start_auth_code_callback_server(
            port,
            5,
            OAuthCallbackPage::new(OAuthProvider::OpenAi),
            None,
        )
        .await
        .expect("start callback server");
        let client = Client::builder()
            .no_proxy()
            .build()
            .expect("build http client");

        let health = client
            .get(format!("http://127.0.0.1:{port}/health"))
            .send()
            .await
            .expect("health request should succeed");
        assert!(health.status().is_success());

        let cancel = client
            .get(format!("http://127.0.0.1:{port}/cancel"))
            .send()
            .await
            .expect("cancel request should succeed");
        assert!(cancel.status().is_success());

        assert!(matches!(
            server.wait().await.expect("wait for callback outcome"),
            AuthCallbackOutcome::Cancelled
        ));
    }
}