Skip to main content

strike48_connector/auth/
oauth.rs

1//! OAuth PKCE flow for desktop connectors.
2//!
3//! Provides interactive browser-based login using the authorization code flow
4//! with PKCE (Proof Key for Code Exchange).
5
6use base64::Engine;
7use sha2::{Digest, Sha256};
8use std::time::{Duration, Instant};
9use thiserror::Error;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::net::TcpListener;
12use tokio::sync::{Mutex, RwLock};
13use tokio::task::JoinHandle;
14use tokio::time::timeout;
15
16/// OAuth-related errors.
17#[derive(Debug, Error)]
18pub enum OAuthError {
19    #[error("OIDC config missing required field: {0}")]
20    MissingConfig(&'static str),
21
22    #[error("HTTP request failed: {0}")]
23    HttpError(#[from] reqwest::Error),
24
25    #[error("Token exchange failed: {0}")]
26    TokenExchange(String),
27
28    #[error("Callback timeout (120s)")]
29    CallbackTimeout,
30
31    #[error("Invalid state parameter")]
32    InvalidState,
33
34    #[error("No authorization code in callback")]
35    NoCode,
36
37    #[error("No refresh token available")]
38    NoRefreshToken,
39
40    #[error("Failed to open browser: {0}")]
41    BrowserOpen(String),
42}
43
/// Tokens cached after a successful code exchange or refresh.
struct TokenSet {
    /// Bearer token returned to callers of the manager.
    access_token: String,
    /// Token used for the `refresh_token` grant; `None` when the provider issued none.
    refresh_token: Option<String>,
    /// Local deadline after which the access token should no longer be handed out.
    expires_at: Instant,
}
49
50#[derive(serde::Deserialize)]
51struct TokenResponse {
52    access_token: String,
53    #[serde(default)]
54    refresh_token: Option<String>,
55    #[serde(default)]
56    expires_in: u64,
57}
58
59/// OAuth manager for PKCE flow in desktop connectors.
60pub struct OAuthManager {
61    oidc_config: strike48_proto::proto::OidcConfig,
62    tokens: RwLock<Option<TokenSet>>,
63    #[allow(dead_code)] // Reserved for background token refresh
64    refresh_handle: Mutex<Option<JoinHandle<()>>>,
65}
66
67impl OAuthManager {
68    /// Create a new OAuth manager with the given OIDC configuration.
69    pub fn new(oidc_config: strike48_proto::proto::OidcConfig) -> Self {
70        Self {
71            oidc_config,
72            tokens: RwLock::new(None),
73            refresh_handle: Mutex::new(None),
74        }
75    }
76
77    /// Perform interactive login: open browser, bind localhost callback, exchange code for tokens.
78    pub async fn login_interactive(&self) -> Result<String, OAuthError> {
79        let auth_endpoint = Some(self.oidc_config.authorization_endpoint.as_str())
80            .filter(|s| !s.is_empty())
81            .ok_or(OAuthError::MissingConfig("authorization_endpoint"))?;
82        let token_endpoint = Some(self.oidc_config.token_endpoint.as_str())
83            .filter(|s| !s.is_empty())
84            .ok_or(OAuthError::MissingConfig("token_endpoint"))?;
85        let client_id = Some(self.oidc_config.client_id.as_str())
86            .filter(|s| !s.is_empty())
87            .ok_or(OAuthError::MissingConfig("client_id"))?;
88
89        // Generate PKCE code_verifier (43-128 random bytes, base64url-encoded)
90        let code_verifier = Self::generate_code_verifier();
91        let code_challenge = Self::compute_code_challenge(&code_verifier);
92        let state: String = (0..32)
93            .map(|_| rand::random::<u8>())
94            .map(|b| format!("{b:02x}"))
95            .collect();
96
97        // Bind callback server on random port
98        let listener = TcpListener::bind("127.0.0.1:0")
99            .await
100            .map_err(|e| OAuthError::TokenExchange(format!("Failed to bind callback: {e}")))?;
101        let port = listener
102            .local_addr()
103            .map_err(|e| OAuthError::TokenExchange(format!("Failed to get local addr: {e}")))?
104            .port();
105        let redirect_uri = format!("http://127.0.0.1:{port}/callback");
106
107        // Build authorization URL
108        let mut auth_url = format!(
109            "{}?response_type=code&client_id={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}",
110            auth_endpoint.trim_end_matches('?'),
111            urlencoding::encode(client_id),
112            urlencoding::encode(&redirect_uri),
113            urlencoding::encode(&code_challenge),
114            urlencoding::encode(&state),
115        );
116        let scope_str: String = self
117            .oidc_config
118            .scopes
119            .iter()
120            .filter(|s| !s.is_empty())
121            .cloned()
122            .collect::<Vec<_>>()
123            .join(" ");
124        if !scope_str.is_empty() {
125            auth_url.push_str("&scope=");
126            auth_url.push_str(&urlencoding::encode(&scope_str));
127        }
128
129        // Open browser
130        open::that(&auth_url).map_err(|e| OAuthError::BrowserOpen(e.to_string()))?;
131
132        // Wait for callback with 120s timeout
133        let (code, callback_state) = Self::wait_for_callback(&listener).await?;
134
135        if callback_state != state {
136            return Err(OAuthError::InvalidState);
137        }
138
139        let token_set = self
140            .exchange_code(
141                &code,
142                &redirect_uri,
143                &code_verifier,
144                token_endpoint,
145                client_id,
146            )
147            .await?;
148
149        let access_token = token_set.access_token.clone();
150        *self.tokens.write().await = Some(token_set);
151        Ok(access_token)
152    }
153
154    /// Get current valid access token. Refreshes if expired and refresh_token is available.
155    pub async fn get_token(&self) -> Result<String, OAuthError> {
156        let tokens = self.tokens.write().await;
157        if let Some(ref ts) = *tokens {
158            // Consider expired 30s before actual expiry
159            if ts
160                .expires_at
161                .saturating_duration_since(Instant::now())
162                .as_secs()
163                > 30
164            {
165                return Ok(ts.access_token.clone());
166            }
167            if ts.refresh_token.is_some() {
168                drop(tokens);
169                return self.refresh().await;
170            }
171        }
172        Err(OAuthError::NoRefreshToken)
173    }
174
175    async fn exchange_code(
176        &self,
177        code: &str,
178        redirect_uri: &str,
179        code_verifier: &str,
180        token_endpoint: &str,
181        client_id: &str,
182    ) -> Result<TokenSet, OAuthError> {
183        let client = reqwest::Client::new();
184        let params = [
185            ("grant_type", "authorization_code"),
186            ("code", code),
187            ("redirect_uri", redirect_uri),
188            ("client_id", client_id),
189            ("code_verifier", code_verifier),
190        ];
191
192        let resp = client.post(token_endpoint).form(&params).send().await?;
193
194        let status = resp.status();
195        let body = resp.text().await?;
196
197        if !status.is_success() {
198            return Err(OAuthError::TokenExchange(format!(
199                "Token exchange failed ({}): {}",
200                status, body
201            )));
202        }
203
204        let token_resp: TokenResponse = serde_json::from_str(&body)
205            .map_err(|e| OAuthError::TokenExchange(format!("Invalid token response: {e}")))?;
206
207        let expires_at =
208            Instant::now() + Duration::from_secs(token_resp.expires_in.saturating_sub(30).max(60));
209
210        Ok(TokenSet {
211            access_token: token_resp.access_token,
212            refresh_token: token_resp.refresh_token,
213            expires_at,
214        })
215    }
216
217    /// Refresh the access token using the refresh_token grant.
218    pub async fn refresh(&self) -> Result<String, OAuthError> {
219        let token_endpoint = Some(self.oidc_config.token_endpoint.as_str())
220            .filter(|s| !s.is_empty())
221            .ok_or(OAuthError::MissingConfig("token_endpoint"))?;
222        let client_id = Some(self.oidc_config.client_id.as_str())
223            .filter(|s| !s.is_empty())
224            .ok_or(OAuthError::MissingConfig("client_id"))?;
225
226        let refresh_token = {
227            let tokens = self.tokens.read().await;
228            tokens
229                .as_ref()
230                .and_then(|t| t.refresh_token.clone())
231                .ok_or(OAuthError::NoRefreshToken)?
232        };
233
234        let client = reqwest::Client::new();
235        let params = [
236            ("grant_type", "refresh_token"),
237            ("refresh_token", refresh_token.as_str()),
238            ("client_id", client_id),
239        ];
240
241        let resp = client.post(token_endpoint).form(&params).send().await?;
242
243        let status = resp.status();
244        let body = resp.text().await?;
245
246        if !status.is_success() {
247            return Err(OAuthError::TokenExchange(format!(
248                "Refresh failed ({}): {}",
249                status, body
250            )));
251        }
252
253        let token_resp: TokenResponse = serde_json::from_str(&body)
254            .map_err(|e| OAuthError::TokenExchange(format!("Invalid refresh response: {e}")))?;
255
256        let expires_at =
257            Instant::now() + Duration::from_secs(token_resp.expires_in.saturating_sub(30).max(60));
258
259        let new_tokens = TokenSet {
260            access_token: token_resp.access_token.clone(),
261            refresh_token: token_resp.refresh_token.or(Some(refresh_token)),
262            expires_at,
263        };
264
265        *self.tokens.write().await = Some(new_tokens);
266        Ok(token_resp.access_token)
267    }
268
269    fn generate_code_verifier() -> String {
270        let bytes: Vec<u8> = (0..64).map(|_| rand::random()).collect();
271        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes)
272    }
273
274    fn compute_code_challenge(verifier: &str) -> String {
275        let hash = Sha256::digest(verifier.as_bytes());
276        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
277    }
278
279    async fn wait_for_callback(listener: &TcpListener) -> Result<(String, String), OAuthError> {
280        let (stream, _) = timeout(Duration::from_secs(120), listener.accept())
281            .await
282            .map_err(|_| OAuthError::CallbackTimeout)?
283            .map_err(|e| OAuthError::TokenExchange(format!("Accept failed: {e}")))?;
284
285        let mut reader = BufReader::new(stream);
286        let mut request_line = String::new();
287        reader
288            .read_line(&mut request_line)
289            .await
290            .map_err(|e| OAuthError::TokenExchange(format!("Read failed: {e}")))?;
291
292        let mut code = None;
293        let mut state = None;
294        if let Some(path_query) = request_line.split_whitespace().nth(1) {
295            let (path, query) = path_query.split_once('?').unwrap_or((path_query, ""));
296            if path == "/callback" || path.starts_with("/callback") {
297                for pair in query.split('&') {
298                    if let Some((k, v)) = pair.split_once('=') {
299                        let v = urlencoding::decode(v).unwrap_or_default();
300                        match k {
301                            "code" => code = Some(v.into_owned()),
302                            "state" => state = Some(v.into_owned()),
303                            _ => {}
304                        }
305                    }
306                }
307            }
308        }
309
310        let code = code.ok_or(OAuthError::NoCode)?;
311        let state = state.unwrap_or_default();
312
313        let success = !request_line.contains("error=");
314        let (status, body) = if success {
315            (
316                "200 OK",
317                r#"<!DOCTYPE html><html><head><title>Success</title></head><body><h1>Login successful</h1><p>You can close this window.</p></body></html>"#,
318            )
319        } else {
320            (
321                "400 Bad Request",
322                r#"<!DOCTYPE html><html><head><title>Error</title></head><body><h1>Login failed</h1><p>Please try again.</p></body></html>"#,
323            )
324        };
325
326        let response = format!(
327            "HTTP/1.1 {status}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
328            body.len()
329        );
330
331        let mut stream = reader.into_inner();
332        stream
333            .write_all(response.as_bytes())
334            .await
335            .map_err(|e| OAuthError::TokenExchange(format!("Write response failed: {e}")))?;
336        stream
337            .flush()
338            .await
339            .map_err(|e| OAuthError::TokenExchange(format!("Flush failed: {e}")))?;
340
341        Ok((code, state))
342    }
343}