1use super::config::{AuthorizationRequestMode, OAuthConfig, TokenRequestMode};
4use super::error::{OAuthError, OAuthResult};
5use super::pkce::PkceChallenge;
6use serde::{Deserialize, Serialize};
7
// NOTE(review): RFC 6749 §6 allows a refresh response to omit `refresh_token`; with a
// required `String` field such a response fails to deserialize — confirm the target
// providers always return one.
// NOTE(review): the derived `Debug` prints both tokens verbatim; avoid logging this value.
/// Token endpoint response, used for both the authorization-code exchange and the
/// refresh grant.
///
/// All four fields are required when deserializing (none is optional or defaulted).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
    /// Access token issued by the provider.
    pub access_token: String,
    /// Refresh token to pass to `OAuthFlow::refresh_token` later.
    pub refresh_token: String,
    /// Access-token lifetime as reported by the server (presumably seconds, per RFC 6749 —
    /// confirm for the provider).
    pub expires_in: i64,
    /// Token type reported by the server, e.g. `"Bearer"`.
    pub token_type: String,
}
20
21enum TokenRequest {
22 Json(serde_json::Value),
23 Form(Vec<(String, String)>),
24}
25
26pub struct OAuthFlow {
28 config: OAuthConfig,
29 pkce: Option<PkceChallenge>,
30 state: Option<String>,
31}
32
33impl OAuthFlow {
34 pub fn new(config: OAuthConfig) -> Self {
36 Self {
37 config,
38 pkce: None,
39 state: None,
40 }
41 }
42
43 pub fn generate_auth_url(&mut self) -> String {
48 let pkce = PkceChallenge::generate();
49 let state = uuid::Uuid::new_v4().simple().to_string();
50
51 let mut query = vec![
52 format!("client_id={}", urlencoding::encode(&self.config.client_id)),
53 "response_type=code".to_string(),
54 format!(
55 "redirect_uri={}",
56 urlencoding::encode(&self.config.redirect_url)
57 ),
58 format!(
59 "scope={}",
60 urlencoding::encode(&self.config.scopes_string())
61 ),
62 format!("code_challenge={}", urlencoding::encode(&pkce.challenge)),
63 format!(
64 "code_challenge_method={}",
65 PkceChallenge::challenge_method()
66 ),
67 format!("state={}", urlencoding::encode(&state)),
68 ];
69
70 if self.config.authorization_request_mode == AuthorizationRequestMode::LegacyCode {
71 query.insert(0, "code=true".to_string());
72 }
73
74 query.extend(self.config.authorization_params.iter().map(|(key, value)| {
75 format!(
76 "{}={}",
77 urlencoding::encode(key),
78 urlencoding::encode(value)
79 )
80 }));
81
82 let url = format!("{}?{}", self.config.auth_url, query.join("&"));
83
84 self.pkce = Some(pkce);
85 self.state = Some(state);
86 url
87 }
88
89 fn build_token_exchange_request(
90 &self,
91 auth_code: String,
92 state: String,
93 ) -> OAuthResult<TokenRequest> {
94 let pkce = self.pkce.as_ref().ok_or(OAuthError::PkceNotInitialized)?;
95
96 Ok(match self.config.token_request_mode {
97 TokenRequestMode::Json => TokenRequest::Json(serde_json::json!({
98 "grant_type": "authorization_code",
99 "code": auth_code,
100 "state": state,
101 "client_id": self.config.client_id,
102 "redirect_uri": self.config.redirect_url,
103 "code_verifier": pkce.verifier,
104 })),
105 TokenRequestMode::FormUrlEncoded => TokenRequest::Form(vec![
106 ("grant_type".to_string(), "authorization_code".to_string()),
110 ("code".to_string(), auth_code),
111 ("client_id".to_string(), self.config.client_id.clone()),
112 ("redirect_uri".to_string(), self.config.redirect_url.clone()),
113 ("code_verifier".to_string(), pkce.verifier.clone()),
114 ]),
115 })
116 }
117
118 fn build_token_refresh_request(&self, refresh_token: String) -> TokenRequest {
119 match self.config.token_request_mode {
120 TokenRequestMode::Json => TokenRequest::Json(serde_json::json!({
121 "grant_type": "refresh_token",
122 "refresh_token": refresh_token,
123 "client_id": self.config.client_id,
124 })),
125 TokenRequestMode::FormUrlEncoded => TokenRequest::Form(vec![
126 ("grant_type".to_string(), "refresh_token".to_string()),
127 ("refresh_token".to_string(), refresh_token),
128 ("client_id".to_string(), self.config.client_id.clone()),
129 ]),
130 }
131 }
132
133 pub async fn exchange_code(&self, code: &str) -> OAuthResult<TokenResponse> {
139 let (auth_code, state) = parse_auth_code(code)?;
140 self.exchange_code_with_state(&auth_code, &state).await
141 }
142
143 pub async fn exchange_code_with_state(
145 &self,
146 auth_code: &str,
147 state: &str,
148 ) -> OAuthResult<TokenResponse> {
149 let _pkce = self.pkce.as_ref().ok_or(OAuthError::PkceNotInitialized)?;
150
151 let expected_state = self
152 .state
153 .as_deref()
154 .ok_or(OAuthError::PkceNotInitialized)?;
155
156 if state != expected_state {
159 return Err(OAuthError::invalid_code_format(
160 "State mismatch - possible CSRF attack",
161 ));
162 }
163
164 let token_request =
165 self.build_token_exchange_request(auth_code.to_string(), state.to_string())?;
166
167 let client =
168 crate::tls_client::create_tls_client(crate::tls_client::TlsClientConfig::default())
169 .expect("Failed to create TLS client for OAuth token exchange");
170 let response = match token_request {
171 TokenRequest::Json(body) => client.post(&self.config.token_url).json(&body),
172 TokenRequest::Form(body) => client.post(&self.config.token_url).form(&body),
173 }
174 .send()
175 .await?;
176
177 if !response.status().is_success() {
178 let status = response.status();
179 let error_text = response.text().await.unwrap_or_default();
180 return Err(OAuthError::token_exchange_failed(format!(
181 "HTTP {}: {}",
182 status, error_text
183 )));
184 }
185
186 response.json::<TokenResponse>().await.map_err(|e| {
187 OAuthError::token_exchange_failed(format!("Failed to parse token response: {}", e))
188 })
189 }
190
191 pub async fn refresh_token(&self, refresh_token: &str) -> OAuthResult<TokenResponse> {
193 let token_request = self.build_token_refresh_request(refresh_token.to_string());
194 let client =
195 crate::tls_client::create_tls_client(crate::tls_client::TlsClientConfig::default())
196 .expect("Failed to create TLS client for OAuth token refresh");
197 let response = match token_request {
198 TokenRequest::Json(body) => client.post(&self.config.token_url).json(&body),
199 TokenRequest::Form(body) => client.post(&self.config.token_url).form(&body),
200 }
201 .send()
202 .await?;
203
204 if !response.status().is_success() {
205 let status = response.status();
206 let error_text = response.text().await.unwrap_or_default();
207 return Err(OAuthError::token_refresh_failed(format!(
208 "HTTP {}: {}",
209 status, error_text
210 )));
211 }
212
213 response.json::<TokenResponse>().await.map_err(|e| {
214 OAuthError::token_refresh_failed(format!("Failed to parse token response: {}", e))
215 })
216 }
217
218 pub fn pkce_verifier(&self) -> Option<&str> {
220 self.pkce.as_ref().map(|p| p.verifier.as_str())
221 }
222}
223
224#[allow(clippy::string_slice)] fn parse_auth_code(code: &str) -> OAuthResult<(String, String)> {
229 let code = code.replace("%23", "#");
231
232 if let Some(pos) = code.find('#') {
233 let auth_code = code[..pos].to_string();
234 let state = code[pos + 1..].to_string();
235
236 if auth_code.is_empty() || state.is_empty() {
237 return Err(OAuthError::invalid_code_format(
238 "Authorization code or state is empty",
239 ));
240 }
241
242 Ok((auth_code, state))
243 } else {
244 Err(OAuthError::invalid_code_format(
245 "Expected format: authorization_code#state",
246 ))
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oauth::config::AuthorizationRequestMode;

    fn test_config() -> OAuthConfig {
        OAuthConfig::new(
            "test-client-id",
            "https://example.com/auth",
            "https://example.com/token",
            "https://example.com/callback",
            vec!["scope1".to_string(), "scope2".to_string()],
        )
    }

    /// Returns the decoded value of query parameter `name` in `url`, if present.
    fn query_param(url: &str, name: &str) -> Option<String> {
        reqwest::Url::parse(url)
            .expect("parse auth url")
            .query_pairs()
            .find(|(key, _)| key == name)
            .map(|(_, value)| value.into_owned())
    }

    #[test]
    fn test_generate_auth_url_standard_pkce() {
        let mut flow = OAuthFlow::new(test_config());
        let url = flow.generate_auth_url();

        assert!(url.starts_with("https://example.com/auth?"));
        assert!(url.contains("client_id=test-client-id"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("redirect_uri="));
        assert!(url.contains("scope=scope1%20scope2"));
        assert!(url.contains("code_challenge="));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("state="));
        assert!(!url.contains("code=true"));

        assert!(flow.pkce.is_some());
        assert!(flow.state.is_some());
    }

    #[test]
    fn test_generate_auth_url_legacy_mode_includes_code_param() {
        let mut flow = OAuthFlow::new(
            test_config().with_authorization_request_mode(AuthorizationRequestMode::LegacyCode),
        );
        let url = flow.generate_auth_url();

        // Legacy mode prepends `code=true` ahead of every other parameter.
        assert!(url.starts_with("https://example.com/auth?code=true&"));
        assert!(url.contains("response_type=code"));
    }

    #[test]
    fn test_generate_auth_url_includes_provider_specific_params() {
        let mut flow = OAuthFlow::new(test_config().with_authorization_params(vec![
            ("id_token_add_organizations", "true"),
            ("codex_cli_simplified_flow", "true"),
            ("originator", "stakpak"),
        ]));
        let url = flow.generate_auth_url();

        assert!(url.contains("id_token_add_organizations=true"));
        assert!(url.contains("codex_cli_simplified_flow=true"));
        assert!(url.contains("originator=stakpak"));
    }

    #[test]
    fn test_generate_auth_url_uses_separate_state_from_pkce_verifier() {
        let mut flow = OAuthFlow::new(test_config());
        let url = flow.generate_auth_url();
        let state = query_param(&url, "state").expect("state param");

        assert_ne!(Some(state.as_str()), flow.pkce_verifier());
        // The URL's state is exactly the value later checked by the exchange.
        assert_eq!(Some(state.as_str()), flow.state.as_deref());
    }

    #[test]
    fn test_openai_token_exchange_request_uses_form_encoding_without_state() {
        let mut flow = OAuthFlow::new(
            test_config()
                .with_token_request_mode(crate::oauth::config::TokenRequestMode::FormUrlEncoded),
        );
        let _ = flow.generate_auth_url();
        let request = flow
            .build_token_exchange_request("auth-code".to_string(), "callback-state".to_string())
            .expect("token exchange request");

        match request {
            TokenRequest::Form(params) => {
                assert!(
                    params.contains(&("grant_type".to_string(), "authorization_code".to_string()))
                );
                assert!(params.contains(&("code".to_string(), "auth-code".to_string())));
                assert!(params.contains(&("client_id".to_string(), "test-client-id".to_string())));
                assert!(params.iter().all(|(key, _)| key != "state"));
            }
            TokenRequest::Json(_) => panic!("expected form request"),
        }
    }

    #[test]
    fn test_json_token_exchange_request_includes_state_and_verifier() {
        let mut flow = OAuthFlow::new(
            test_config().with_token_request_mode(crate::oauth::config::TokenRequestMode::Json),
        );
        let _ = flow.generate_auth_url();
        let verifier = flow.pkce_verifier().expect("pkce verifier").to_string();
        let request = flow
            .build_token_exchange_request("auth-code".to_string(), "callback-state".to_string())
            .expect("token exchange request");

        match request {
            TokenRequest::Json(body) => {
                assert_eq!(body["grant_type"], "authorization_code");
                assert_eq!(body["code"], "auth-code");
                assert_eq!(body["state"], "callback-state");
                assert_eq!(body["client_id"], "test-client-id");
                assert_eq!(body["redirect_uri"], flow.config.redirect_url.as_str());
                assert_eq!(body["code_verifier"], verifier.as_str());
            }
            TokenRequest::Form(_) => panic!("expected JSON request"),
        }
    }

    #[test]
    fn test_token_exchange_request_without_pkce() {
        let flow = OAuthFlow::new(test_config());
        let result = flow.build_token_exchange_request("code".to_string(), "state".to_string());
        assert!(matches!(result, Err(OAuthError::PkceNotInitialized)));
    }

    #[test]
    fn test_openai_token_refresh_request_uses_form_encoding() {
        let flow = OAuthFlow::new(
            test_config()
                .with_token_request_mode(crate::oauth::config::TokenRequestMode::FormUrlEncoded),
        );
        let request = flow.build_token_refresh_request("refresh-token".to_string());

        match request {
            TokenRequest::Form(params) => {
                assert!(params.contains(&("grant_type".to_string(), "refresh_token".to_string())));
                assert!(
                    params.contains(&("refresh_token".to_string(), "refresh-token".to_string()))
                );
                assert!(params.contains(&("client_id".to_string(), "test-client-id".to_string())));
            }
            TokenRequest::Json(_) => panic!("expected form request"),
        }
    }

    #[test]
    fn test_json_token_refresh_request() {
        let flow = OAuthFlow::new(
            test_config().with_token_request_mode(crate::oauth::config::TokenRequestMode::Json),
        );
        let request = flow.build_token_refresh_request("refresh-token".to_string());

        match request {
            TokenRequest::Json(body) => {
                assert_eq!(body["grant_type"], "refresh_token");
                assert_eq!(body["refresh_token"], "refresh-token");
                assert_eq!(body["client_id"], "test-client-id");
            }
            TokenRequest::Form(_) => panic!("expected JSON request"),
        }
    }

    #[test]
    fn test_parse_auth_code_valid() {
        let (code, state) = parse_auth_code("abc123#state456").expect("valid code");
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_parse_auth_code_url_encoded() {
        let (code, state) = parse_auth_code("abc123%23state456").expect("valid encoded code");
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_parse_auth_code_splits_on_first_separator() {
        let (code, state) = parse_auth_code("abc#st#ate").expect("valid code");
        assert_eq!(code, "abc");
        assert_eq!(state, "st#ate");
    }

    #[test]
    fn test_parse_auth_code_missing_separator() {
        assert!(parse_auth_code("abc123state456").is_err());
    }

    #[test]
    fn test_parse_auth_code_empty_parts() {
        assert!(parse_auth_code("#state").is_err());
        assert!(parse_auth_code("code#").is_err());
        assert!(parse_auth_code("#").is_err());
    }

    #[test]
    fn test_exchange_code_without_pkce() {
        let flow = OAuthFlow::new(test_config());
        let result = tokio_test::block_on(flow.exchange_code("code#state"));
        assert!(matches!(result, Err(OAuthError::PkceNotInitialized)));
    }

    #[test]
    fn test_exchange_code_rejects_malformed_code_before_pkce_check() {
        let flow = OAuthFlow::new(test_config());
        let result = tokio_test::block_on(flow.exchange_code("no-separator"));
        assert!(result.is_err());
        assert!(!matches!(result, Err(OAuthError::PkceNotInitialized)));
    }

    #[test]
    fn test_exchange_code_with_state_without_pkce() {
        let flow = OAuthFlow::new(test_config());
        let result = tokio_test::block_on(flow.exchange_code_with_state("code", "state"));
        assert!(matches!(result, Err(OAuthError::PkceNotInitialized)));
    }

    #[test]
    fn test_exchange_code_with_state_rejects_mismatched_state() {
        let mut flow = OAuthFlow::new(test_config());
        let _ = flow.generate_auth_url();
        // The state check happens before any network request, so this runs offline.
        let result = tokio_test::block_on(flow.exchange_code_with_state("code", "not-the-state"));
        assert!(result.is_err());
        assert!(!matches!(result, Err(OAuthError::PkceNotInitialized)));
    }

    #[test]
    fn test_token_response_serde() {
        let json = r#"{
            "access_token": "access123",
            "refresh_token": "refresh456",
            "expires_in": 3600,
            "token_type": "Bearer"
        }"#;

        let response: TokenResponse = serde_json::from_str(json).expect("valid token json");
        assert_eq!(response.access_token, "access123");
        assert_eq!(response.refresh_token, "refresh456");
        assert_eq!(response.expires_in, 3600);
        assert_eq!(response.token_type, "Bearer");
    }
}