stakpak_shared/oauth/
flow.rs1use super::config::OAuthConfig;
4use super::error::{OAuthError, OAuthResult};
5use super::pkce::PkceChallenge;
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct TokenResponse {
11 pub access_token: String,
13 pub refresh_token: String,
15 pub expires_in: i64,
17 pub token_type: String,
19}
20
21pub struct OAuthFlow {
23 config: OAuthConfig,
24 pkce: Option<PkceChallenge>,
25}
26
27impl OAuthFlow {
28 pub fn new(config: OAuthConfig) -> Self {
30 Self { config, pkce: None }
31 }
32
33 pub fn generate_auth_url(&mut self) -> String {
38 let pkce = PkceChallenge::generate();
39
40 let url = format!(
41 "{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method={}&state={}",
42 self.config.auth_url,
43 urlencoding::encode(&self.config.client_id),
44 urlencoding::encode(&self.config.redirect_url),
45 urlencoding::encode(&self.config.scopes_string()),
46 urlencoding::encode(&pkce.challenge),
47 PkceChallenge::challenge_method(),
48 urlencoding::encode(&pkce.verifier), );
50
51 self.pkce = Some(pkce);
52 url
53 }
54
55 pub async fn exchange_code(&self, code: &str) -> OAuthResult<TokenResponse> {
59 let pkce = self.pkce.as_ref().ok_or(OAuthError::PkceNotInitialized)?;
60
61 let (auth_code, state) = parse_auth_code(code)?;
63
64 if state != pkce.verifier {
66 return Err(OAuthError::invalid_code_format(
67 "State mismatch - possible CSRF attack",
68 ));
69 }
70
71 let client =
72 crate::tls_client::create_tls_client(crate::tls_client::TlsClientConfig::default())
73 .expect("Failed to create TLS client for OAuth token exchange");
74 let response = client
75 .post(&self.config.token_url)
76 .json(&serde_json::json!({
77 "grant_type": "authorization_code",
78 "code": auth_code,
79 "state": state,
80 "client_id": self.config.client_id,
81 "redirect_uri": self.config.redirect_url,
82 "code_verifier": pkce.verifier,
83 }))
84 .send()
85 .await?;
86
87 if !response.status().is_success() {
88 let status = response.status();
89 let error_text = response.text().await.unwrap_or_default();
90 return Err(OAuthError::token_exchange_failed(format!(
91 "HTTP {}: {}",
92 status, error_text
93 )));
94 }
95
96 response.json::<TokenResponse>().await.map_err(|e| {
97 OAuthError::token_exchange_failed(format!("Failed to parse token response: {}", e))
98 })
99 }
100
101 pub async fn refresh_token(&self, refresh_token: &str) -> OAuthResult<TokenResponse> {
103 let client =
104 crate::tls_client::create_tls_client(crate::tls_client::TlsClientConfig::default())
105 .expect("Failed to create TLS client for OAuth token refresh");
106 let response = client
107 .post(&self.config.token_url)
108 .json(&serde_json::json!({
109 "grant_type": "refresh_token",
110 "refresh_token": refresh_token,
111 "client_id": self.config.client_id,
112 }))
113 .send()
114 .await?;
115
116 if !response.status().is_success() {
117 let status = response.status();
118 let error_text = response.text().await.unwrap_or_default();
119 return Err(OAuthError::token_refresh_failed(format!(
120 "HTTP {}: {}",
121 status, error_text
122 )));
123 }
124
125 response.json::<TokenResponse>().await.map_err(|e| {
126 OAuthError::token_refresh_failed(format!("Failed to parse token response: {}", e))
127 })
128 }
129
130 pub fn pkce_verifier(&self) -> Option<&str> {
132 self.pkce.as_ref().map(|p| p.verifier.as_str())
133 }
134}
135
136#[allow(clippy::string_slice)] fn parse_auth_code(code: &str) -> OAuthResult<(String, String)> {
141 let code = code.replace("%23", "#");
143
144 if let Some(pos) = code.find('#') {
145 let auth_code = code[..pos].to_string();
146 let state = code[pos + 1..].to_string();
147
148 if auth_code.is_empty() || state.is_empty() {
149 return Err(OAuthError::invalid_code_format(
150 "Authorization code or state is empty",
151 ));
152 }
153
154 Ok((auth_code, state))
155 } else {
156 Err(OAuthError::invalid_code_format(
157 "Expected format: authorization_code#state",
158 ))
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    fn test_config() -> OAuthConfig {
        OAuthConfig::new(
            "test-client-id",
            "https://example.com/auth",
            "https://example.com/token",
            "https://example.com/callback",
            vec!["scope1".to_string(), "scope2".to_string()],
        )
    }

    #[test]
    fn test_generate_auth_url() {
        let mut flow = OAuthFlow::new(test_config());
        let url = flow.generate_auth_url();

        assert!(url.starts_with("https://example.com/auth?"));
        assert!(url.contains("client_id=test-client-id"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("redirect_uri="));
        assert!(url.contains("scope=scope1%20scope2"));
        assert!(url.contains("code_challenge="));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("state="));

        assert!(flow.pkce.is_some());
    }

    #[test]
    fn test_pkce_verifier_set_after_generate() {
        let mut flow = OAuthFlow::new(test_config());
        assert!(flow.pkce_verifier().is_none());
        flow.generate_auth_url();
        assert!(flow.pkce_verifier().is_some());
    }

    #[test]
    fn test_parse_auth_code_valid() {
        let result = parse_auth_code("abc123#verifier456");
        assert!(result.is_ok());
        let (code, state) = result.unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "verifier456");
    }

    #[test]
    fn test_parse_auth_code_url_encoded() {
        let result = parse_auth_code("abc123%23verifier456");
        assert!(result.is_ok());
        let (code, state) = result.unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "verifier456");
    }

    #[test]
    fn test_parse_auth_code_missing_separator() {
        let result = parse_auth_code("abc123verifier456");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_auth_code_empty_parts() {
        assert!(parse_auth_code("#state").is_err());
        assert!(parse_auth_code("code#").is_err());
        assert!(parse_auth_code("#").is_err());
    }

    #[test]
    fn test_exchange_code_without_pkce() {
        let flow = OAuthFlow::new(test_config());
        let result = tokio_test::block_on(flow.exchange_code("code#state"));
        assert!(matches!(result, Err(OAuthError::PkceNotInitialized)));
    }

    #[test]
    fn test_exchange_code_rejects_state_mismatch() {
        let mut flow = OAuthFlow::new(test_config());
        flow.generate_auth_url();
        // The state check runs before any network request is attempted.
        let result = tokio_test::block_on(flow.exchange_code("code#not-the-verifier"));
        assert!(result.is_err());
        assert!(!matches!(result, Err(OAuthError::PkceNotInitialized)));
    }

    #[test]
    fn test_token_response_serde() {
        let json = r#"{
            "access_token": "access123",
            "refresh_token": "refresh456",
            "expires_in": 3600,
            "token_type": "Bearer"
        }"#;

        let response: TokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.access_token, "access123");
        assert_eq!(response.refresh_token, "refresh456");
        assert_eq!(response.expires_in, 3600);
        assert_eq!(response.token_type, "Bearer");
    }
}