1use base64::Engine;
7use sha2::{Digest, Sha256};
8use std::time::{Duration, Instant};
9use thiserror::Error;
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::net::TcpListener;
12use tokio::sync::{Mutex, RwLock};
13use tokio::task::JoinHandle;
14use tokio::time::timeout;
15
/// Errors produced by [`OAuthManager`] during interactive login, code
/// exchange, and token refresh.
#[derive(Debug, Error)]
pub enum OAuthError {
    /// A required OIDC configuration field (`authorization_endpoint`,
    /// `token_endpoint` or `client_id`) is empty; the payload names the field.
    #[error("OIDC config missing required field: {0}")]
    MissingConfig(&'static str),

    /// Sending a request to the token endpoint, or reading its response body,
    /// failed inside the HTTP client.
    #[error("HTTP request failed: {0}")]
    HttpError(#[from] reqwest::Error),

    /// The token endpoint answered with a non-success status or a body that
    /// could not be decoded. Also used for I/O failures of the local callback
    /// listener (bind, local address lookup, accept, write, ...); the message
    /// names the failing step.
    #[error("Token exchange failed: {0}")]
    TokenExchange(String),

    /// The browser redirect to the local callback listener did not arrive in time.
    #[error("Callback timeout (120s)")]
    CallbackTimeout,

    /// The `state` value returned in the callback did not match the one sent
    /// in the authorization request.
    #[error("Invalid state parameter")]
    InvalidState,

    /// The callback request carried no `code` query parameter.
    #[error("No authorization code in callback")]
    NoCode,

    /// A refresh was needed or requested, but no refresh token is cached
    /// (or no tokens are cached at all).
    #[error("No refresh token available")]
    NoRefreshToken,

    /// Launching the system browser with the authorization URL failed.
    #[error("Failed to open browser: {0}")]
    BrowserOpen(String),
}
43
/// Tokens obtained from the token endpoint and cached by `OAuthManager`.
struct TokenSet {
    /// Access token string returned to callers of the manager.
    access_token: String,
    /// Refresh token, if the server issued one. A refresh whose response omits
    /// the field keeps the previously cached value.
    refresh_token: Option<String>,
    /// Local instant derived from the response's `expires_in` when this set
    /// was stored; used to decide whether `access_token` is still fresh.
    expires_at: Instant,
}
49
/// The subset of the token endpoint's JSON response that this module reads.
/// Other fields in the response are ignored by serde; a missing
/// `access_token` makes deserialization fail.
#[derive(serde::Deserialize)]
struct TokenResponse {
    access_token: String,
    /// `None` when the response has no `refresh_token` field.
    #[serde(default)]
    refresh_token: Option<String>,
    /// Token lifetime in seconds as reported by the server; `0` when the field
    /// is absent.
    #[serde(default)]
    expires_in: u64,
}
58
/// Runs an OAuth 2.0 authorization-code login with PKCE (S256) against the
/// configured OIDC provider and caches the resulting tokens.
pub struct OAuthManager {
    /// Provider endpoints, client id and scopes used to build requests.
    oidc_config: strike48_proto::proto::OidcConfig,
    /// Most recently obtained tokens; `None` until a login succeeds.
    tokens: RwLock<Option<TokenSet>>,
    // NOTE(review): this field is only ever initialised to `None` and never
    // read, hence the `dead_code` allow. It looks reserved for a background
    // refresh task; wire it up or remove it (TODO confirm intent).
    #[allow(dead_code)]
    refresh_handle: Mutex<Option<JoinHandle<()>>>,
}
66
67impl OAuthManager {
68 pub fn new(oidc_config: strike48_proto::proto::OidcConfig) -> Self {
70 Self {
71 oidc_config,
72 tokens: RwLock::new(None),
73 refresh_handle: Mutex::new(None),
74 }
75 }
76
77 pub async fn login_interactive(&self) -> Result<String, OAuthError> {
79 let auth_endpoint = Some(self.oidc_config.authorization_endpoint.as_str())
80 .filter(|s| !s.is_empty())
81 .ok_or(OAuthError::MissingConfig("authorization_endpoint"))?;
82 let token_endpoint = Some(self.oidc_config.token_endpoint.as_str())
83 .filter(|s| !s.is_empty())
84 .ok_or(OAuthError::MissingConfig("token_endpoint"))?;
85 let client_id = Some(self.oidc_config.client_id.as_str())
86 .filter(|s| !s.is_empty())
87 .ok_or(OAuthError::MissingConfig("client_id"))?;
88
89 let code_verifier = Self::generate_code_verifier();
91 let code_challenge = Self::compute_code_challenge(&code_verifier);
92 let state: String = (0..32)
93 .map(|_| rand::random::<u8>())
94 .map(|b| format!("{b:02x}"))
95 .collect();
96
97 let listener = TcpListener::bind("127.0.0.1:0")
99 .await
100 .map_err(|e| OAuthError::TokenExchange(format!("Failed to bind callback: {e}")))?;
101 let port = listener
102 .local_addr()
103 .map_err(|e| OAuthError::TokenExchange(format!("Failed to get local addr: {e}")))?
104 .port();
105 let redirect_uri = format!("http://127.0.0.1:{port}/callback");
106
107 let mut auth_url = format!(
109 "{}?response_type=code&client_id={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&state={}",
110 auth_endpoint.trim_end_matches('?'),
111 urlencoding::encode(client_id),
112 urlencoding::encode(&redirect_uri),
113 urlencoding::encode(&code_challenge),
114 urlencoding::encode(&state),
115 );
116 let scope_str: String = self
117 .oidc_config
118 .scopes
119 .iter()
120 .filter(|s| !s.is_empty())
121 .cloned()
122 .collect::<Vec<_>>()
123 .join(" ");
124 if !scope_str.is_empty() {
125 auth_url.push_str("&scope=");
126 auth_url.push_str(&urlencoding::encode(&scope_str));
127 }
128
129 open::that(&auth_url).map_err(|e| OAuthError::BrowserOpen(e.to_string()))?;
131
132 let (code, callback_state) = Self::wait_for_callback(&listener).await?;
134
135 if callback_state != state {
136 return Err(OAuthError::InvalidState);
137 }
138
139 let token_set = self
140 .exchange_code(
141 &code,
142 &redirect_uri,
143 &code_verifier,
144 token_endpoint,
145 client_id,
146 )
147 .await?;
148
149 let access_token = token_set.access_token.clone();
150 *self.tokens.write().await = Some(token_set);
151 Ok(access_token)
152 }
153
154 pub async fn get_token(&self) -> Result<String, OAuthError> {
156 let tokens = self.tokens.write().await;
157 if let Some(ref ts) = *tokens {
158 if ts
160 .expires_at
161 .saturating_duration_since(Instant::now())
162 .as_secs()
163 > 30
164 {
165 return Ok(ts.access_token.clone());
166 }
167 if ts.refresh_token.is_some() {
168 drop(tokens);
169 return self.refresh().await;
170 }
171 }
172 Err(OAuthError::NoRefreshToken)
173 }
174
175 async fn exchange_code(
176 &self,
177 code: &str,
178 redirect_uri: &str,
179 code_verifier: &str,
180 token_endpoint: &str,
181 client_id: &str,
182 ) -> Result<TokenSet, OAuthError> {
183 let client = reqwest::Client::new();
184 let params = [
185 ("grant_type", "authorization_code"),
186 ("code", code),
187 ("redirect_uri", redirect_uri),
188 ("client_id", client_id),
189 ("code_verifier", code_verifier),
190 ];
191
192 let resp = client.post(token_endpoint).form(¶ms).send().await?;
193
194 let status = resp.status();
195 let body = resp.text().await?;
196
197 if !status.is_success() {
198 return Err(OAuthError::TokenExchange(format!(
199 "Token exchange failed ({}): {}",
200 status, body
201 )));
202 }
203
204 let token_resp: TokenResponse = serde_json::from_str(&body)
205 .map_err(|e| OAuthError::TokenExchange(format!("Invalid token response: {e}")))?;
206
207 let expires_at =
208 Instant::now() + Duration::from_secs(token_resp.expires_in.saturating_sub(30).max(60));
209
210 Ok(TokenSet {
211 access_token: token_resp.access_token,
212 refresh_token: token_resp.refresh_token,
213 expires_at,
214 })
215 }
216
217 pub async fn refresh(&self) -> Result<String, OAuthError> {
219 let token_endpoint = Some(self.oidc_config.token_endpoint.as_str())
220 .filter(|s| !s.is_empty())
221 .ok_or(OAuthError::MissingConfig("token_endpoint"))?;
222 let client_id = Some(self.oidc_config.client_id.as_str())
223 .filter(|s| !s.is_empty())
224 .ok_or(OAuthError::MissingConfig("client_id"))?;
225
226 let refresh_token = {
227 let tokens = self.tokens.read().await;
228 tokens
229 .as_ref()
230 .and_then(|t| t.refresh_token.clone())
231 .ok_or(OAuthError::NoRefreshToken)?
232 };
233
234 let client = reqwest::Client::new();
235 let params = [
236 ("grant_type", "refresh_token"),
237 ("refresh_token", refresh_token.as_str()),
238 ("client_id", client_id),
239 ];
240
241 let resp = client.post(token_endpoint).form(¶ms).send().await?;
242
243 let status = resp.status();
244 let body = resp.text().await?;
245
246 if !status.is_success() {
247 return Err(OAuthError::TokenExchange(format!(
248 "Refresh failed ({}): {}",
249 status, body
250 )));
251 }
252
253 let token_resp: TokenResponse = serde_json::from_str(&body)
254 .map_err(|e| OAuthError::TokenExchange(format!("Invalid refresh response: {e}")))?;
255
256 let expires_at =
257 Instant::now() + Duration::from_secs(token_resp.expires_in.saturating_sub(30).max(60));
258
259 let new_tokens = TokenSet {
260 access_token: token_resp.access_token.clone(),
261 refresh_token: token_resp.refresh_token.or(Some(refresh_token)),
262 expires_at,
263 };
264
265 *self.tokens.write().await = Some(new_tokens);
266 Ok(token_resp.access_token)
267 }
268
269 fn generate_code_verifier() -> String {
270 let bytes: Vec<u8> = (0..64).map(|_| rand::random()).collect();
271 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes)
272 }
273
274 fn compute_code_challenge(verifier: &str) -> String {
275 let hash = Sha256::digest(verifier.as_bytes());
276 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
277 }
278
279 async fn wait_for_callback(listener: &TcpListener) -> Result<(String, String), OAuthError> {
280 let (stream, _) = timeout(Duration::from_secs(120), listener.accept())
281 .await
282 .map_err(|_| OAuthError::CallbackTimeout)?
283 .map_err(|e| OAuthError::TokenExchange(format!("Accept failed: {e}")))?;
284
285 let mut reader = BufReader::new(stream);
286 let mut request_line = String::new();
287 reader
288 .read_line(&mut request_line)
289 .await
290 .map_err(|e| OAuthError::TokenExchange(format!("Read failed: {e}")))?;
291
292 let mut code = None;
293 let mut state = None;
294 if let Some(path_query) = request_line.split_whitespace().nth(1) {
295 let (path, query) = path_query.split_once('?').unwrap_or((path_query, ""));
296 if path == "/callback" || path.starts_with("/callback") {
297 for pair in query.split('&') {
298 if let Some((k, v)) = pair.split_once('=') {
299 let v = urlencoding::decode(v).unwrap_or_default();
300 match k {
301 "code" => code = Some(v.into_owned()),
302 "state" => state = Some(v.into_owned()),
303 _ => {}
304 }
305 }
306 }
307 }
308 }
309
310 let code = code.ok_or(OAuthError::NoCode)?;
311 let state = state.unwrap_or_default();
312
313 let success = !request_line.contains("error=");
314 let (status, body) = if success {
315 (
316 "200 OK",
317 r#"<!DOCTYPE html><html><head><title>Success</title></head><body><h1>Login successful</h1><p>You can close this window.</p></body></html>"#,
318 )
319 } else {
320 (
321 "400 Bad Request",
322 r#"<!DOCTYPE html><html><head><title>Error</title></head><body><h1>Login failed</h1><p>Please try again.</p></body></html>"#,
323 )
324 };
325
326 let response = format!(
327 "HTTP/1.1 {status}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
328 body.len()
329 );
330
331 let mut stream = reader.into_inner();
332 stream
333 .write_all(response.as_bytes())
334 .await
335 .map_err(|e| OAuthError::TokenExchange(format!("Write response failed: {e}")))?;
336 stream
337 .flush()
338 .await
339 .map_err(|e| OAuthError::TokenExchange(format!("Flush failed: {e}")))?;
340
341 Ok((code, state))
342 }
343}