wae_authentication/oauth2/
client.rs1use crate::oauth2::{AuthorizationUrl, OAuth2ClientConfig, TokenResponse, UserInfo};
4use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use wae_request::{HttpClient, HttpClientConfig};
8use wae_types::{WaeError, WaeErrorKind};
9
10pub type OAuth2Result<T> = Result<T, WaeError>;
12
13#[derive(Debug, Clone)]
15pub struct OAuth2Client {
16 config: OAuth2ClientConfig,
17 http_client: HttpClient,
18}
19
20impl OAuth2Client {
21 pub fn new(config: OAuth2ClientConfig) -> OAuth2Result<Self> {
26 let http_config = HttpClientConfig {
27 timeout: std::time::Duration::from_millis(config.timeout_ms),
28 connect_timeout: std::time::Duration::from_secs(10),
29 user_agent: "wae-oauth2/0.1.0".to_string(),
30 max_retries: 3,
31 retry_delay: std::time::Duration::from_millis(1000),
32 default_headers: HashMap::new(),
33 };
34
35 let http_client = HttpClient::new(http_config);
36
37 Ok(Self { config, http_client })
38 }
39
40 pub fn authorization_url(&self) -> OAuth2Result<AuthorizationUrl> {
45 let state = self.generate_state();
46 let mut params: Vec<(String, String)> = vec![
47 ("client_id".to_string(), self.config.provider.client_id.clone()),
48 ("redirect_uri".to_string(), self.config.provider.redirect_uri.clone()),
49 ("response_type".to_string(), "code".to_string()),
50 ("state".to_string(), state.clone()),
51 ];
52
53 if !self.config.provider.scopes.is_empty() {
54 params.push(("scope".to_string(), self.config.provider.scopes.join(" ")));
55 }
56
57 let code_verifier = if self.config.use_pkce {
58 let verifier = self.generate_code_verifier();
59 let challenge = self.generate_code_challenge(&verifier);
60 params.push(("code_challenge".to_string(), challenge));
61 params.push(("code_challenge_method".to_string(), "S256".to_string()));
62 Some(verifier)
63 }
64 else {
65 None
66 };
67
68 for (key, value) in &self.config.provider.extra_params {
69 params.push((key.clone(), value.clone()));
70 }
71
72 let query = params
73 .iter()
74 .map(|(k, v)| format!("{}={}", wae_types::url_encode(k), wae_types::url_encode(v)))
75 .collect::<Vec<_>>()
76 .join("&");
77
78 let url = format!("{}?{}", self.config.provider.authorization_url, query);
79
80 Ok(AuthorizationUrl { url, state, code_verifier })
81 }
82
83 pub async fn exchange_code(&self, code: &str, code_verifier: Option<&str>) -> OAuth2Result<TokenResponse> {
89 let mut params = HashMap::new();
90 params.insert("grant_type", "authorization_code".to_string());
91 params.insert("code", code.to_string());
92 params.insert("redirect_uri", self.config.provider.redirect_uri.clone());
93 params.insert("client_id", self.config.provider.client_id.clone());
94 params.insert("client_secret", self.config.provider.client_secret.clone());
95
96 if let Some(verifier) = code_verifier {
97 params.insert("code_verifier", verifier.to_string());
98 }
99
100 let form_body = self.encode_form_data(¶ms);
101
102 let response = self
103 .http_client
104 .post_with_headers(&self.config.provider.token_url, form_body.into_bytes(), self.form_headers())
105 .await
106 .map_err(WaeError::from)?;
107
108 if !response.is_success() {
109 let error_text = response.text().unwrap_or_default();
110 return Err(WaeError::new(WaeErrorKind::OAuth2ProviderError { message: error_text }));
111 }
112
113 let token_response: TokenResponse = response.json().map_err(WaeError::from)?;
114 Ok(token_response)
115 }
116
117 pub async fn refresh_token(&self, refresh_token: &str) -> OAuth2Result<TokenResponse> {
122 let mut params = HashMap::new();
123 params.insert("grant_type", "refresh_token".to_string());
124 params.insert("refresh_token", refresh_token.to_string());
125 params.insert("client_id", self.config.provider.client_id.clone());
126 params.insert("client_secret", self.config.provider.client_secret.clone());
127
128 let form_body = self.encode_form_data(¶ms);
129
130 let response = self
131 .http_client
132 .post_with_headers(&self.config.provider.token_url, form_body.into_bytes(), self.form_headers())
133 .await
134 .map_err(WaeError::from)?;
135
136 if !response.is_success() {
137 let error_text = response.text().unwrap_or_default();
138 return Err(WaeError::new(WaeErrorKind::OAuth2ProviderError { message: error_text }));
139 }
140
141 let token_response: TokenResponse = response.json().map_err(WaeError::from)?;
142 Ok(token_response)
143 }
144
145 pub async fn get_user_info(&self, access_token: &str) -> OAuth2Result<UserInfo> {
150 let userinfo_url = self
151 .config
152 .provider
153 .userinfo_url
154 .as_ref()
155 .ok_or_else(|| WaeError::config_invalid("userinfo_url", "not configured"))?;
156
157 let mut headers = HashMap::new();
158 headers.insert("Authorization".to_string(), format!("Bearer {}", access_token));
159
160 let response = self.http_client.get_with_headers(userinfo_url, headers).await.map_err(WaeError::from)?;
161
162 if !response.is_success() {
163 let error_text = response.text().unwrap_or_default();
164 return Err(WaeError::new(WaeErrorKind::OAuth2ProviderError { message: error_text }));
165 }
166
167 let user_info: UserInfo = response.json().map_err(WaeError::from)?;
168 Ok(user_info)
169 }
170
171 pub async fn revoke_token(&self, token: &str, token_type_hint: Option<&str>) -> OAuth2Result<()> {
177 let revocation_url = self
178 .config
179 .provider
180 .revocation_url
181 .as_ref()
182 .ok_or_else(|| WaeError::config_invalid("revocation_url", "not configured"))?;
183
184 let mut params = HashMap::new();
185 params.insert("token", token.to_string());
186 params.insert("client_id", self.config.provider.client_id.clone());
187 params.insert("client_secret", self.config.provider.client_secret.clone());
188
189 if let Some(hint) = token_type_hint {
190 params.insert("token_type_hint", hint.to_string());
191 }
192
193 let form_body = self.encode_form_data(¶ms);
194
195 let response = self
196 .http_client
197 .post_with_headers(revocation_url, form_body.into_bytes(), self.form_headers())
198 .await
199 .map_err(WaeError::from)?;
200
201 if !response.is_success() {
202 let error_text = response.text().unwrap_or_default();
203 return Err(WaeError::new(WaeErrorKind::OAuth2ProviderError { message: error_text }));
204 }
205
206 Ok(())
207 }
208
209 pub fn validate_state(&self, expected: &str, received: &str) -> OAuth2Result<()> {
215 if !self.config.use_state {
216 return Ok(());
217 }
218
219 if expected == received { Ok(()) } else { Err(WaeError::new(WaeErrorKind::StateMismatch)) }
220 }
221
222 fn generate_state(&self) -> String {
223 uuid::Uuid::new_v4().to_string().replace('-', "")
224 }
225
226 fn generate_code_verifier(&self) -> String {
227 let random_bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
228 URL_SAFE_NO_PAD.encode(&random_bytes)
229 }
230
231 fn generate_code_challenge(&self, verifier: &str) -> String {
232 let mut hasher = Sha256::new();
233 hasher.update(verifier.as_bytes());
234 let hash = hasher.finalize();
235 URL_SAFE_NO_PAD.encode(&hash)
236 }
237
238 fn encode_form_data(&self, params: &HashMap<&str, String>) -> String {
239 params
240 .iter()
241 .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
242 .collect::<Vec<_>>()
243 .join("&")
244 }
245
246 fn form_headers(&self) -> HashMap<String, String> {
247 let mut headers = HashMap::new();
248 headers.insert("Content-Type".to_string(), "application/x-www-form-urlencoded".to_string());
249 headers
250 }
251
252 pub fn provider_name(&self) -> &str {
254 &self.config.provider.name
255 }
256
257 pub fn config(&self) -> &OAuth2ClientConfig {
259 &self.config
260 }
261}
262
263mod urlencoding {
264 pub fn encode(s: &str) -> String {
265 url::form_urlencoded::byte_serialize(s.as_bytes()).collect()
266 }
267}