Skip to main content

vtcode_auth/
oauth_server.rs

1use anyhow::{Context, Result};
2use axum::{
3    Router,
4    extract::{Query, State},
5    response::Html,
6    routing::get,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt;
10use std::net::SocketAddr;
11use std::str::FromStr;
12use std::sync::Arc;
13use tokio::sync::{mpsc, oneshot};
14
/// Number of seconds to wait for the browser callback when the caller passes a
/// timeout of zero (five minutes).
const DEFAULT_CALLBACK_TIMEOUT_SECS: u64 = 5 * 60;
16
/// Providers for which this module can run a browser-based OAuth login.
///
/// Serialized with `snake_case` variant names (`open_ai`, `open_router`), which are
/// not the same strings as the slugs returned by `OAuthProvider::slug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(rename_all = "snake_case")]
pub enum OAuthProvider {
    /// OpenAI; the callback pages describe this as a ChatGPT subscription.
    OpenAi,
    /// OpenRouter account.
    OpenRouter,
}
24
25impl OAuthProvider {
26    #[must_use]
27    pub fn slug(self) -> &'static str {
28        match self {
29            Self::OpenAi => "openai",
30            Self::OpenRouter => "openrouter",
31        }
32    }
33
34    #[must_use]
35    pub fn display_name(self) -> &'static str {
36        match self {
37            Self::OpenAi => "OpenAI",
38            Self::OpenRouter => "OpenRouter",
39        }
40    }
41
42    #[must_use]
43    pub fn subtitle(self) -> &'static str {
44        match self {
45            Self::OpenAi => "Your ChatGPT subscription is now connected.",
46            Self::OpenRouter => "Your OpenRouter account is now connected.",
47        }
48    }
49
50    #[must_use]
51    pub fn failure_subtitle(self) -> &'static str {
52        match self {
53            Self::OpenAi => "Unable to connect your ChatGPT subscription.",
54            Self::OpenRouter => "Unable to connect your OpenRouter account.",
55        }
56    }
57
58    #[must_use]
59    pub fn retry_hint(self) -> String {
60        format!("You can try again anytime using /login {}", self.slug())
61    }
62
63    #[must_use]
64    pub fn supports_manual_refresh(self) -> bool {
65        matches!(self, Self::OpenAi)
66    }
67}
68
69impl fmt::Display for OAuthProvider {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.write_str(self.slug())
72    }
73}
74
75impl FromStr for OAuthProvider {
76    type Err = ();
77
78    fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
79        match value.trim().to_ascii_lowercase().as_str() {
80            "openai" => Ok(Self::OpenAi),
81            "openrouter" => Ok(Self::OpenRouter),
82            _ => Err(()),
83        }
84    }
85}
86
/// Configuration for the browser-facing callback pages.
#[derive(Debug, Clone, Copy)]
pub struct OAuthCallbackPage {
    /// Provider whose wording is rendered into the success, failure and cancel pages.
    provider: OAuthProvider,
}
91
92impl OAuthCallbackPage {
93    #[must_use]
94    pub fn new(provider: OAuthProvider) -> Self {
95        Self { provider }
96    }
97}
98
/// Final result of waiting for the OAuth redirect on the local callback server.
#[derive(Debug, Clone)]
pub enum AuthCallbackOutcome {
    /// The callback carried an authorization code (the `code` query parameter).
    Code(String),
    /// The `/cancel` route was requested.
    Cancelled,
    /// The flow failed; the string is a human-readable message (state mismatch,
    /// provider-reported error, missing code, or timeout).
    Error(String),
}
105
/// Query-string parameters read from a request to the callback routes.
///
/// Every field is optional so that a request missing any of them still reaches
/// the handler, which decides how to report the problem.
#[derive(Debug, Deserialize)]
struct AuthCallbackParams {
    /// Authorization code on success.
    code: Option<String>,
    /// Error identifier reported by the provider.
    error: Option<String>,
    /// Optional human-readable detail accompanying `error`.
    error_description: Option<String>,
    /// `state` value echoed back; compared with the expected value when one is configured.
    state: Option<String>,
}
113
/// State shared (behind an `Arc`) by the callback route handlers.
struct AuthCallbackState {
    /// Determines which provider's wording the rendered pages use.
    page: OAuthCallbackPage,
    /// When `Some`, a callback whose `state` parameter differs is rejected.
    expected_state: Option<String>,
    /// Channel on which handlers report the outcome to the waiting server function.
    result_tx: mpsc::Sender<AuthCallbackOutcome>,
}
119
120pub async fn run_auth_code_callback_server(
121    port: u16,
122    timeout_secs: u64,
123    page: OAuthCallbackPage,
124    expected_state: Option<String>,
125) -> Result<AuthCallbackOutcome> {
126    let timeout = if timeout_secs == 0 {
127        DEFAULT_CALLBACK_TIMEOUT_SECS
128    } else {
129        timeout_secs
130    };
131    let (result_tx, mut result_rx) = mpsc::channel::<AuthCallbackOutcome>(1);
132    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
133    let state = Arc::new(AuthCallbackState {
134        page,
135        expected_state,
136        result_tx,
137    });
138
139    let app = Router::new()
140        .route("/callback", get(handle_callback))
141        .route("/auth/callback", get(handle_callback))
142        .route("/cancel", get(handle_cancel))
143        .route("/health", get(|| async { "OK" }))
144        .with_state(state);
145
146    let addr = SocketAddr::from(([127, 0, 0, 1], port));
147    let listener = tokio::net::TcpListener::bind(addr)
148        .await
149        .with_context(|| format!("failed to bind localhost callback server on port {port}"))?;
150
151    let server = axum::serve(listener, app).with_graceful_shutdown(async move {
152        let _ = shutdown_rx.await;
153    });
154    let server_handle = tokio::spawn(async move {
155        if let Err(err) = server.await {
156            tracing::error!("OAuth callback server error: {}", err);
157        }
158    });
159
160    let result = tokio::select! {
161        Some(result) = result_rx.recv() => result,
162        _ = tokio::time::sleep(std::time::Duration::from_secs(timeout)) => {
163            AuthCallbackOutcome::Error(format!("OAuth flow timed out after {timeout} seconds"))
164        }
165    };
166
167    let _ = shutdown_tx.send(());
168    let _ = server_handle.await;
169    Ok(result)
170}
171
172async fn handle_callback(
173    State(state): State<Arc<AuthCallbackState>>,
174    Query(params): Query<AuthCallbackParams>,
175) -> Html<String> {
176    tracing::info!(
177        provider = %state.page.provider,
178        has_code = params.code.is_some(),
179        has_error = params.error.is_some(),
180        "received oauth callback"
181    );
182    if let Some(expected_state) = state.expected_state.as_deref() {
183        match params.state.as_deref() {
184            Some(actual_state) if actual_state == expected_state => {}
185            _ => {
186                let message = "OAuth error: state mismatch".to_string();
187                let _ = state
188                    .result_tx
189                    .send(AuthCallbackOutcome::Error(message.clone()))
190                    .await;
191                return Html(error_html(state.page.provider, &message));
192            }
193        }
194    }
195
196    if let Some(error) = params.error {
197        let message = match params.error_description {
198            Some(description) if !description.trim().is_empty() => {
199                format!("OAuth error: {error} - {description}")
200            }
201            _ => format!("OAuth error: {error}"),
202        };
203        let _ = state
204            .result_tx
205            .send(AuthCallbackOutcome::Error(message.clone()))
206            .await;
207        return Html(error_html(state.page.provider, &message));
208    }
209
210    let Some(code) = params.code else {
211        let message = "Missing authorization code".to_string();
212        let _ = state
213            .result_tx
214            .send(AuthCallbackOutcome::Error(message.clone()))
215            .await;
216        return Html(error_html(state.page.provider, &message));
217    };
218
219    let _ = state.result_tx.send(AuthCallbackOutcome::Code(code)).await;
220    Html(success_html(state.page.provider))
221}
222
223async fn handle_cancel(State(state): State<Arc<AuthCallbackState>>) -> Html<String> {
224    let _ = state.result_tx.send(AuthCallbackOutcome::Cancelled).await;
225    Html(cancelled_html(state.page.provider))
226}
227
228fn success_html(provider: OAuthProvider) -> String {
229    base_html(
230        "Authentication Successful",
231        provider.subtitle(),
232        Some("You may now close this window and return to VT Code."),
233        "✓",
234        "#22c55e",
235        None,
236    )
237}
238
239fn error_html(provider: OAuthProvider, error: &str) -> String {
240    base_html(
241        "Authentication Failed",
242        provider.failure_subtitle(),
243        None,
244        "✕",
245        "#ef4444",
246        Some(error),
247    )
248}
249
250fn cancelled_html(provider: OAuthProvider) -> String {
251    base_html(
252        "Authentication Cancelled",
253        &provider.retry_hint(),
254        None,
255        "—",
256        "#71717a",
257        None,
258    )
259}
260
261fn base_html(
262    title: &str,
263    subtitle: &str,
264    close_note: Option<&str>,
265    icon: &str,
266    accent: &str,
267    error: Option<&str>,
268) -> String {
269    let close_note_html = close_note
270        .map(|value| format!(r#"<p class="close-note">{}</p>"#, html_escape(value)))
271        .unwrap_or_default();
272    let error_html = error
273        .map(|value| format!(r#"<div class="error">{}</div>"#, html_escape(value)))
274        .unwrap_or_default();
275    let auto_close = if close_note.is_some() {
276        r#"<script>setTimeout(() => window.close(), 3000);</script>"#
277    } else {
278        ""
279    };
280
281    format!(
282        r##"<!DOCTYPE html>
283<html>
284<head>
285    <title>VT Code - {title}</title>
286    <style>
287        @font-face {{
288            font-family: 'SF Pro Display';
289            src: local('SF Pro Display'), local('.SF NS Display'), local('Helvetica Neue');
290        }}
291        @font-face {{
292            font-family: 'SF Mono';
293            src: local('SF Mono'), local('Menlo'), local('Monaco');
294        }}
295        :root {{
296            color-scheme: dark;
297            --bg: #0a0a0a;
298            --panel: #111111;
299            --panel-border: #262626;
300            --text: #fafafa;
301            --muted: #a1a1aa;
302            --subtle: #52525b;
303            --code-bg: #18181b;
304            --code-border: #27272a;
305            --accent: {accent};
306        }}
307        * {{ box-sizing: border-box; }}
308        body {{
309            font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
310            display: flex;
311            justify-content: center;
312            align-items: center;
313            min-height: 100vh;
314            margin: 0;
315            background:
316                radial-gradient(circle at top, rgba(255,255,255,0.04), transparent 32%),
317                linear-gradient(180deg, var(--bg), #050505);
318            color: var(--text);
319            padding: 24px;
320        }}
321        .container {{
322            text-align: center;
323            padding: 2.75rem 3rem;
324            border: 1px solid var(--panel-border);
325            border-radius: 14px;
326            background: rgba(17, 17, 17, 0.92);
327            max-width: 460px;
328            width: 100%;
329            box-shadow: 0 30px 90px rgba(0, 0, 0, 0.35);
330        }}
331        .logo {{
332            margin-bottom: 1.5rem;
333        }}
334        .logo-mark {{
335            display: inline-flex;
336            align-items: center;
337            justify-content: center;
338            font-size: 0.95rem;
339            letter-spacing: 0.24em;
340            text-transform: uppercase;
341            color: var(--muted);
342        }}
343        .status-icon {{
344            width: 52px;
345            height: 52px;
346            margin: 0 auto 1.25rem;
347            border: 2px solid var(--accent);
348            border-radius: 50%;
349            display: flex;
350            align-items: center;
351            justify-content: center;
352            font-size: 1.25rem;
353            color: var(--accent);
354        }}
355        h1 {{
356            margin: 0 0 0.75rem 0;
357            font-size: 1.25rem;
358            font-weight: 600;
359            letter-spacing: -0.02em;
360        }}
361        p {{
362            color: var(--muted);
363            margin: 0;
364            font-size: 0.92rem;
365            line-height: 1.55;
366        }}
367        .close-note {{
368            margin-top: 1.25rem;
369            font-size: 0.78rem;
370            color: var(--subtle);
371        }}
372        .error {{
373            margin-top: 1.35rem;
374            padding: 0.95rem 1rem;
375            background: var(--code-bg);
376            border: 1px solid var(--code-border);
377            border-radius: 10px;
378            font-family: 'SF Mono', Menlo, Monaco, monospace;
379            font-size: 0.75rem;
380            color: #d4d4d8;
381            word-break: break-word;
382            text-align: left;
383        }}
384    </style>
385</head>
386<body>
387    <div class="container">
388        <div class="logo">
389            <div class="logo-mark">&gt; VT Code</div>
390        </div>
391        <div class="status-icon">{icon}</div>
392        <h1>{title}</h1>
393        <p>{subtitle}</p>
394        {close_note_html}
395        {error_html}
396    </div>
397    {auto_close}
398</body>
399</html>"##,
400        title = html_escape(title),
401        subtitle = html_escape(subtitle),
402        icon = icon,
403        accent = accent,
404        close_note_html = close_note_html,
405        error_html = error_html,
406        auto_close = auto_close,
407    )
408}
409
/// Escapes `value` so it can be embedded in HTML as text or inside a quoted attribute.
///
/// Replaces `&`, `<`, `>`, `"` and `'` with their entity forms. Quotes are
/// included because this helper receives text taken from the callback's query
/// string; escaping them keeps it safe even if a call site later places the
/// value inside an attribute. The input is walked once, so already-produced
/// entities are never escaped a second time.
fn html_escape(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}
416
#[cfg(test)]
mod tests {
    use super::*;
    use axum::extract::{Query, State};

    /// Known slugs parse to their variants; anything else is rejected.
    #[test]
    fn oauth_provider_parses_known_providers() {
        let openai = "openai".parse::<OAuthProvider>();
        let openrouter = "openrouter".parse::<OAuthProvider>();
        let unknown = "other".parse::<OAuthProvider>();

        assert_eq!(openai, Ok(OAuthProvider::OpenAi));
        assert_eq!(openrouter, Ok(OAuthProvider::OpenRouter));
        assert!(unknown.is_err());
    }

    /// The success page carries the product name, the title and the auto-close script.
    #[test]
    fn success_html_mentions_vtcode_and_autoclose() {
        let page = success_html(OAuthProvider::OpenAi);
        for needle in ["VT Code", "Authentication Successful", "window.close"] {
            assert!(page.contains(needle));
        }
    }

    /// A callback with the wrong `state` is reported as an error even though it
    /// carries a code, and the browser is shown the failure page.
    #[tokio::test]
    async fn callback_rejects_state_mismatch() {
        let (result_tx, mut result_rx) = mpsc::channel(1);
        let shared = Arc::new(AuthCallbackState {
            page: OAuthCallbackPage::new(OAuthProvider::OpenAi),
            expected_state: Some("expected-state".to_string()),
            result_tx,
        });
        let params = AuthCallbackParams {
            code: Some("auth-code".to_string()),
            error: None,
            error_description: None,
            state: Some("wrong-state".to_string()),
        };

        let Html(body) = handle_callback(State(shared), Query(params)).await;

        let outcome = result_rx.recv().await.expect("callback outcome");
        let AuthCallbackOutcome::Error(message) = outcome else {
            panic!("expected error outcome");
        };
        assert!(message.contains("state mismatch"));
        assert!(body.contains("Authentication Failed"));
    }
}