//! OAuth helpers for authenticating rusty_commit with Claude (`rusty_commit/auth/oauth.rs`).

use anyhow::{Context, Result};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::Mutex;
use tokio::time::sleep;

// Claude OAuth endpoints (similar to Claude Code).

/// Authorization endpoint the user's browser is sent to.
pub const AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";
/// Token endpoint used for the authorization-code exchange and for refreshes.
pub const TOKEN_URL: &str = "https://claude.ai/oauth/token";
/// Public client ID for CLI apps.
pub const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
/// Redirect target handled by the local callback server.
pub const REDIRECT_URI: &str = "http://localhost:8989/callback";

17#[derive(Debug, Serialize)]
18#[allow(dead_code)]
19struct DeviceCodeRequest {
20    client_id: String,
21    scope: String,
22}
23
24#[derive(Debug, Deserialize)]
25#[allow(dead_code)]
26pub struct DeviceCodeResponse {
27    pub device_code: String,
28    pub user_code: String,
29    pub verification_uri: String,
30    pub verification_uri_complete: Option<String>,
31    pub expires_in: u64,
32    pub interval: u64,
33}
34
35#[derive(Debug, Serialize)]
36#[allow(dead_code)]
37struct TokenRequest {
38    grant_type: String,
39    device_code: String,
40    client_id: String,
41}
42
43#[derive(Debug, Serialize)]
44#[allow(dead_code)]
45struct RefreshTokenRequest {
46    grant_type: String,
47    refresh_token: String,
48    client_id: String,
49}
50
51#[derive(Debug, Deserialize)]
52#[allow(dead_code)]
53pub struct TokenResponse {
54    pub access_token: String,
55    pub token_type: String,
56    pub expires_in: Option<u64>,
57    pub refresh_token: Option<String>,
58    pub scope: Option<String>,
59}
60
61#[derive(Debug, Deserialize)]
62struct ErrorResponse {
63    error: String,
64    error_description: Option<String>,
65}
66
67/// OAuth client for Claude authentication
68pub struct OAuthClient {
69    client: Client,
70    client_id: String,
71    redirect_uri: String,
72}
73
74impl Default for OAuthClient {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl OAuthClient {
81    pub fn new() -> Self {
82        Self {
83            client: Client::new(),
84            client_id: CLIENT_ID.to_string(),
85            redirect_uri: REDIRECT_URI.to_string(),
86        }
87    }
88
89    /// Generate PKCE challenge and verifier
90    fn generate_pkce() -> Result<(String, String)> {
91        // Generate random verifier
92        let mut bytes = [0u8; 32];
93        generate_random_bytes(&mut bytes)?;
94        let verifier = URL_SAFE_NO_PAD.encode(bytes);
95
96        // Generate challenge from verifier
97        let mut hasher = Sha256::new();
98        hasher.update(verifier.as_bytes());
99        let challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
100
101        Ok((verifier, challenge))
102    }
103
104    /// Build authorization URL with PKCE
105    pub fn get_authorization_url(&self) -> Result<(String, String)> {
106        let (verifier, challenge) = Self::generate_pkce()?;
107
108        let state = URL_SAFE_NO_PAD.encode(uuid::Uuid::new_v4().as_bytes());
109
110        let params = [
111            ("client_id", &self.client_id),
112            ("redirect_uri", &self.redirect_uri),
113            ("response_type", &"code".to_string()),
114            ("scope", &"openid profile email".to_string()),
115            ("state", &state),
116            ("code_challenge", &challenge),
117            ("code_challenge_method", &"S256".to_string()),
118        ];
119
120        let query = serde_urlencoded::to_string(params).context("Failed to encode OAuth params")?;
121        let auth_url = format!("{AUTHORIZE_URL}?{query}");
122
123        Ok((auth_url, verifier))
124    }
125
126    /// Start local server to receive OAuth callback
127    pub async fn start_callback_server(&self, verifier: String) -> Result<TokenResponse> {
128        use warp::Filter;
129
130        let code = Arc::new(Mutex::new(None));
131        let code_clone = code.clone();
132
133        // Create callback route
134        let callback = warp::path("callback")
135            .and(warp::query::<std::collections::HashMap<String, String>>())
136            .map(move |params: std::collections::HashMap<String, String>| {
137                if let Some(auth_code) = params.get("code") {
138                    let mut code_guard = code_clone.blocking_lock();
139                    *code_guard = Some(auth_code.clone());
140                }
141
142                warp::reply::html(
143                    r#"
144                    <!DOCTYPE html>
145                    <html>
146                    <head>
147                        <title>Authentication Successful</title>
148                        <style>
149                            body {
150                                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
151                                display: flex;
152                                justify-content: center;
153                                align-items: center;
154                                height: 100vh;
155                                margin: 0;
156                                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
157                            }
158                            .container {
159                                background: white;
160                                padding: 3rem;
161                                border-radius: 12px;
162                                box-shadow: 0 20px 60px rgba(0,0,0,0.3);
163                                text-align: center;
164                                max-width: 400px;
165                            }
166                            h1 { color: #2d3748; margin-bottom: 1rem; }
167                            p { color: #718096; line-height: 1.6; }
168                            .check {
169                                width: 60px;
170                                height: 60px;
171                                margin: 0 auto 1.5rem;
172                                background: #48bb78;
173                                border-radius: 50%;
174                                display: flex;
175                                align-items: center;
176                                justify-content: center;
177                            }
178                            .check::after {
179                                content: '✓';
180                                color: white;
181                                font-size: 30px;
182                                font-weight: bold;
183                            }
184                        </style>
185                    </head>
186                    <body>
187                        <div class="container">
188                            <div class="check"></div>
189                            <h1>Authentication Successful!</h1>
190                            <p>You can now close this window and return to your terminal.</p>
191                        </div>
192                    </body>
193                    </html>
194                    "#
195                )
196            });
197
198        // Start server in background
199        let server = warp::serve(callback).bind(([127, 0, 0, 1], 8989));
200        let server_handle = tokio::spawn(server);
201
202        // Wait for code (with timeout)
203        let start = SystemTime::now();
204        let timeout = Duration::from_secs(300); // 5 minutes
205
206        loop {
207            if let Some(auth_code) = &*code.lock().await {
208                // Exchange code for token
209                let token = self.exchange_code_for_token(auth_code, &verifier).await?;
210                server_handle.abort();
211                return Ok(token);
212            }
213
214            if SystemTime::now().duration_since(start)? > timeout {
215                server_handle.abort();
216                anyhow::bail!("Authentication timeout - no response received");
217            }
218
219            sleep(Duration::from_millis(100)).await;
220        }
221    }
222
223    /// Exchange authorization code for access token
224    async fn exchange_code_for_token(&self, code: &str, verifier: &str) -> Result<TokenResponse> {
225        let params = [
226            ("grant_type", "authorization_code"),
227            ("code", code),
228            ("redirect_uri", &self.redirect_uri),
229            ("client_id", &self.client_id),
230            ("code_verifier", verifier),
231        ];
232
233        let response = self
234            .client
235            .post(TOKEN_URL)
236            .form(&params)
237            .send()
238            .await
239            .context("Failed to exchange code for token")?;
240
241        if response.status().is_success() {
242            response
243                .json::<TokenResponse>()
244                .await
245                .context("Failed to parse token response")
246        } else {
247            let error: ErrorResponse = response.json().await?;
248            anyhow::bail!(
249                "Token exchange failed: {} - {}",
250                error.error,
251                error.error_description.unwrap_or_default()
252            )
253        }
254    }
255
256    /// Refresh an access token
257    pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
258        let request = RefreshTokenRequest {
259            grant_type: "refresh_token".to_string(),
260            refresh_token: refresh_token.to_string(),
261            client_id: self.client_id.clone(),
262        };
263
264        let response = self
265            .client
266            .post(TOKEN_URL)
267            .json(&request)
268            .send()
269            .await
270            .context("Failed to refresh token")?;
271
272        if response.status().is_success() {
273            response
274                .json::<TokenResponse>()
275                .await
276                .context("Failed to parse refresh token response")
277        } else {
278            let error: ErrorResponse = response.json().await?;
279            anyhow::bail!(
280                "Token refresh failed: {} - {}",
281                error.error,
282                error.error_description.unwrap_or_default()
283            )
284        }
285    }
286
287    /// Check if a token is expired
288    #[allow(dead_code)]
289    pub fn is_token_expired(expires_at: u64) -> bool {
290        let now = SystemTime::now()
291            .duration_since(UNIX_EPOCH)
292            .expect("System time before Unix epoch")
293            .as_secs();
294        now >= expires_at
295    }
296}
297
298// Generate random bytes for PKCE
299fn generate_random_bytes(dest: &mut [u8]) -> Result<()> {
300    use rand::RngCore;
301    let mut rng = rand::rng();
302    rng.fill_bytes(dest);
303    Ok(())
304}