// rusty_commit/auth/gitlab_oauth.rs

1use anyhow::{Context, Result};
2use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::sync::Arc;
7use std::time::{Duration, SystemTime};
8use tokio::sync::Mutex;
9use tokio::time::sleep;
10
// GitLab OAuth endpoints
/// Authorization endpoint; `get_authorization_url` appends the encoded query
/// parameters to this URL.
#[allow(dead_code)]
pub const GITLAB_AUTHORIZE_URL: &str = "https://gitlab.com/oauth/authorize";
/// Token endpoint; both the authorization-code exchange and the refresh-token
/// request are POSTed here.
pub const GITLAB_TOKEN_URL: &str = "https://gitlab.com/oauth/token";
/// Base URL of the GitLab v4 REST API. Not referenced by the code visible in
/// this file.
#[allow(dead_code)]
pub const GITLAB_API_URL: &str = "https://gitlab.com/api/v4";

// GitLab AI Gateway for Claude models
/// Base URL of the AI gateway. Not referenced by the code visible in this file.
#[allow(dead_code)]
pub const GITLAB_AI_GATEWAY_URL: &str = "https://gitlab.ai/api/v1";
21
/// Serializable shape of an authorization-code token request.
///
/// The field names mirror the form keys that `exchange_code_for_token` sends.
/// That method builds its form from a tuple array, so nothing visible in this
/// file constructs this struct (hence the `dead_code` allowance).
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct GitLabTokenRequest {
    grant_type: String,
    code: String,
    redirect_uri: String,
    client_id: String,
    // PKCE verifier matching the challenge sent in the authorization URL.
    code_verifier: String,
}
31
/// Serializable shape of a refresh-token request.
///
/// The field names mirror the form keys that `refresh_token` sends. That method
/// builds its form from a tuple array, so nothing visible in this file
/// constructs this struct (hence the `dead_code` allowance).
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct GitLabRefreshTokenRequest {
    grant_type: String,
    refresh_token: String,
    client_id: String,
}
39
/// Successful response of the token endpoint, deserialized from JSON by
/// `exchange_code_for_token` and `refresh_token`.
///
/// NOTE(review): the derived `Debug` prints `access_token` and `refresh_token`
/// verbatim; avoid logging this value with `{:?}`.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct GitLabTokenResponse {
    /// Access token issued by the token endpoint.
    pub access_token: String,
    /// Refresh token, when the response contains one.
    pub refresh_token: Option<String>,
    /// Token type string as reported by the server.
    pub token_type: String,
    /// Token lifetime, when reported. Presumably seconds, per OAuth
    /// convention — TODO confirm against the GitLab docs.
    pub expires_in: Option<u64>,
    /// Granted scope string, when reported.
    pub scope: Option<String>,
}
49
/// JSON error body that the token endpoint is expected to return on a
/// non-success status; its fields are folded into the error message reported
/// to the caller.
#[derive(Debug, Deserialize)]
struct GitLabErrorResponse {
    // Error code string.
    error: String,
    // Optional human-readable detail; rendered as an empty string when absent.
    error_description: Option<String>,
}
55
/// OAuth client for GitLab AI (Claude models on GitLab)
///
/// Holds the HTTP client plus the OAuth application settings used to build the
/// authorization URL and to call the token endpoint.
#[allow(dead_code)]
pub struct GitLabOAuthClient {
    // HTTP client used for every token-endpoint request.
    client: Client,
    // OAuth application (client) id sent with every request.
    client_id: String,
    // Redirect URI sent in the authorization URL and in the code exchange.
    #[allow(dead_code)]
    redirect_uri: String,
}
64
65impl Default for GitLabOAuthClient {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71#[allow(dead_code)]
72impl GitLabOAuthClient {
73    pub fn new() -> Self {
74        // GitLab CLI OAuth app for rusty-commit
75        Self {
76            client: Client::new(),
77            client_id: "cde3d0a736a6f9d9e9b9e9b9e9b9e9b9e9b9e9b9e".to_string(), // Placeholder
78            redirect_uri: "http://localhost:8989/auth/callback".to_string(),
79        }
80    }
81
82    /// Generate PKCE challenge and verifier
83    fn generate_pkce() -> Result<(String, String)> {
84        let mut bytes = [0u8; 32];
85        generate_random_bytes(&mut bytes)?;
86        let verifier = URL_SAFE_NO_PAD.encode(bytes);
87
88        let mut hasher = Sha256::new();
89        hasher.update(verifier.as_bytes());
90        let challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
91
92        Ok((verifier, challenge))
93    }
94
95    /// Generate random state for CSRF protection
96    fn generate_state() -> String {
97        let mut bytes = [0u8; 32];
98        // fill_bytes never fails, so we don't need unwrap_or_default
99        let _ = generate_random_bytes(&mut bytes);
100        URL_SAFE_NO_PAD.encode(bytes)
101    }
102
103    /// Build authorization URL with PKCE
104    pub fn get_authorization_url(&self) -> Result<(String, String)> {
105        let (verifier, challenge) = Self::generate_pkce()?;
106        let state = Self::generate_state();
107
108        let params = [
109            ("client_id", self.client_id.as_str()),
110            ("redirect_uri", self.redirect_uri.as_str()),
111            ("response_type", "code"),
112            ("scope", "read_user api openid"),
113            ("state", state.as_str()),
114            ("code_challenge", challenge.as_str()),
115            ("code_challenge_method", "S256"),
116        ];
117
118        let query = serde_urlencoded::to_string(params).context("Failed to encode OAuth params")?;
119        let auth_url = format!("{}?{}", GITLAB_AUTHORIZE_URL, query);
120
121        Ok((auth_url, verifier))
122    }
123
124    /// Start local server to receive OAuth callback
125    pub async fn start_callback_server(&self, verifier: String) -> Result<GitLabTokenResponse> {
126        use warp::Filter;
127
128        let code = Arc::new(Mutex::new(None));
129        let code_clone = code.clone();
130
131        let callback = warp::path("auth")
132            .and(warp::path("callback"))
133            .and(warp::query::<std::collections::HashMap<String, String>>())
134            .map(move |params: std::collections::HashMap<String, String>| {
135                if let Some(auth_code) = params.get("code") {
136                    let mut code_lock = code_clone.blocking_lock();
137                    *code_lock = Some(auth_code.clone());
138                }
139                warp::reply::html(r#"<!DOCTYPE html><html><head><title>Authenticated!</title></head><body style="font-family: system-ui; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0; background: #fc6d26;"><div style="background: white; padding: 2rem; border-radius: 8px; text-align: center;"><h1 style="color: #fc6d26;">Authentication Successful!</h1><p>You can close this window.</p></div></body></html>"#)
140            });
141
142        let server = warp::serve(callback).bind(([127, 0, 0, 1], 8989));
143        let server_handle = tokio::spawn(server);
144
145        let start = std::time::SystemTime::now();
146        let timeout = Duration::from_secs(300);
147
148        loop {
149            if let Some(auth_code) = &*code.lock().await {
150                let token = self.exchange_code_for_token(auth_code, &verifier).await?;
151                server_handle.abort();
152                return Ok(token);
153            }
154
155            if SystemTime::now().duration_since(start)? > timeout {
156                server_handle.abort();
157                anyhow::bail!("Authentication timeout");
158            }
159
160            sleep(Duration::from_millis(100)).await;
161        }
162    }
163
164    /// Exchange authorization code for access token
165    async fn exchange_code_for_token(
166        &self,
167        code: &str,
168        verifier: &str,
169    ) -> Result<GitLabTokenResponse> {
170        let params = [
171            ("grant_type", "authorization_code"),
172            ("code", code),
173            ("redirect_uri", self.redirect_uri.as_str()),
174            ("client_id", self.client_id.as_str()),
175            ("code_verifier", verifier),
176        ];
177
178        let response = self
179            .client
180            .post(GITLAB_TOKEN_URL)
181            .form(&params)
182            .send()
183            .await
184            .context("Failed to exchange code for token")?;
185
186        if response.status().is_success() {
187            response
188                .json::<GitLabTokenResponse>()
189                .await
190                .context("Failed to parse token response")
191        } else {
192            let error: GitLabErrorResponse = response.json().await?;
193            anyhow::bail!(
194                "Token exchange failed: {} - {}",
195                error.error,
196                error.error_description.unwrap_or_default()
197            )
198        }
199    }
200
201    /// Refresh an access token
202    #[allow(dead_code)]
203    pub async fn refresh_token(&self, refresh_token: &str) -> Result<GitLabTokenResponse> {
204        let params = [
205            ("grant_type", "refresh_token"),
206            ("refresh_token", refresh_token),
207            ("client_id", self.client_id.as_str()),
208        ];
209
210        let response = self
211            .client
212            .post(GITLAB_TOKEN_URL)
213            .form(&params)
214            .send()
215            .await
216            .context("Failed to refresh token")?;
217
218        if response.status().is_success() {
219            response
220                .json::<GitLabTokenResponse>()
221                .await
222                .context("Failed to parse refresh token response")
223        } else {
224            let error: GitLabErrorResponse = response.json().await?;
225            anyhow::bail!(
226                "Token refresh failed: {} - {}",
227                error.error,
228                error.error_description.unwrap_or_default()
229            )
230        }
231    }
232}
233
234/// Generate random bytes
235#[allow(dead_code)]
236fn generate_random_bytes(dest: &mut [u8]) -> Result<()> {
237    use rand::RngCore;
238    let mut rng = rand::rng();
239    rng.fill_bytes(dest);
240    Ok(())
241}