wae_authentication/oauth2/
client.rs1use crate::oauth2::{AuthorizationUrl, OAuth2ClientConfig, OAuth2Error, OAuth2Result, TokenResponse, UserInfo};
4use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use wae_request::{HttpClient, HttpClientConfig, HttpError};
8
9#[derive(Debug, Clone)]
11pub struct OAuth2Client {
12 config: OAuth2ClientConfig,
13 http_client: HttpClient,
14}
15
16impl OAuth2Client {
17 pub fn new(config: OAuth2ClientConfig) -> OAuth2Result<Self> {
22 let http_config = HttpClientConfig {
23 timeout: std::time::Duration::from_millis(config.timeout_ms),
24 connect_timeout: std::time::Duration::from_secs(10),
25 user_agent: "wae-oauth2/0.1.0".to_string(),
26 max_retries: 3,
27 retry_delay: std::time::Duration::from_millis(1000),
28 default_headers: HashMap::new(),
29 };
30
31 let http_client = HttpClient::new(http_config);
32
33 Ok(Self { config, http_client })
34 }
35
36 pub fn authorization_url(&self) -> OAuth2Result<AuthorizationUrl> {
41 let state = self.generate_state();
42 let mut params: Vec<(String, String)> = vec![
43 ("client_id".to_string(), self.config.provider.client_id.clone()),
44 ("redirect_uri".to_string(), self.config.provider.redirect_uri.clone()),
45 ("response_type".to_string(), "code".to_string()),
46 ("state".to_string(), state.clone()),
47 ];
48
49 if !self.config.provider.scopes.is_empty() {
50 params.push(("scope".to_string(), self.config.provider.scopes.join(" ")));
51 }
52
53 let code_verifier = if self.config.use_pkce {
54 let verifier = self.generate_code_verifier();
55 let challenge = self.generate_code_challenge(&verifier);
56 params.push(("code_challenge".to_string(), challenge));
57 params.push(("code_challenge_method".to_string(), "S256".to_string()));
58 Some(verifier)
59 }
60 else {
61 None
62 };
63
64 for (key, value) in &self.config.provider.extra_params {
65 params.push((key.clone(), value.clone()));
66 }
67
68 let query = params
69 .iter()
70 .map(|(k, v)| format!("{}={}", wae_types::url_encode(k), wae_types::url_encode(v)))
71 .collect::<Vec<_>>()
72 .join("&");
73
74 let url = format!("{}?{}", self.config.provider.authorization_url, query);
75
76 Ok(AuthorizationUrl { url, state, code_verifier })
77 }
78
79 pub async fn exchange_code(&self, code: &str, code_verifier: Option<&str>) -> OAuth2Result<TokenResponse> {
85 let mut params = HashMap::new();
86 params.insert("grant_type", "authorization_code".to_string());
87 params.insert("code", code.to_string());
88 params.insert("redirect_uri", self.config.provider.redirect_uri.clone());
89 params.insert("client_id", self.config.provider.client_id.clone());
90 params.insert("client_secret", self.config.provider.client_secret.clone());
91
92 if let Some(verifier) = code_verifier {
93 params.insert("code_verifier", verifier.to_string());
94 }
95
96 let form_body = self.encode_form_data(¶ms);
97
98 let response = self
99 .http_client
100 .post_with_headers(&self.config.provider.token_url, form_body.into_bytes(), self.form_headers())
101 .await
102 .map_err(OAuth2Error::from)?;
103
104 if !response.is_success() {
105 let error_text = response.text().unwrap_or_default();
106 return Err(OAuth2Error::ProviderError(error_text));
107 }
108
109 let token_response: TokenResponse = response.json().map_err(OAuth2Error::from)?;
110 Ok(token_response)
111 }
112
113 pub async fn refresh_token(&self, refresh_token: &str) -> OAuth2Result<TokenResponse> {
118 let mut params = HashMap::new();
119 params.insert("grant_type", "refresh_token".to_string());
120 params.insert("refresh_token", refresh_token.to_string());
121 params.insert("client_id", self.config.provider.client_id.clone());
122 params.insert("client_secret", self.config.provider.client_secret.clone());
123
124 let form_body = self.encode_form_data(¶ms);
125
126 let response = self
127 .http_client
128 .post_with_headers(&self.config.provider.token_url, form_body.into_bytes(), self.form_headers())
129 .await
130 .map_err(OAuth2Error::from)?;
131
132 if !response.is_success() {
133 let error_text = response.text().unwrap_or_default();
134 return Err(OAuth2Error::ProviderError(error_text));
135 }
136
137 let token_response: TokenResponse = response.json().map_err(OAuth2Error::from)?;
138 Ok(token_response)
139 }
140
141 pub async fn get_user_info(&self, access_token: &str) -> OAuth2Result<UserInfo> {
146 let userinfo_url = self
147 .config
148 .provider
149 .userinfo_url
150 .as_ref()
151 .ok_or_else(|| OAuth2Error::ConfigurationError("userinfo_url not configured".into()))?;
152
153 let mut headers = HashMap::new();
154 headers.insert("Authorization".to_string(), format!("Bearer {}", access_token));
155
156 let response = self.http_client.get_with_headers(userinfo_url, headers).await.map_err(OAuth2Error::from)?;
157
158 if !response.is_success() {
159 let error_text = response.text().unwrap_or_default();
160 return Err(OAuth2Error::ProviderError(error_text));
161 }
162
163 let user_info: UserInfo = response.json().map_err(OAuth2Error::from)?;
164 Ok(user_info)
165 }
166
167 pub async fn revoke_token(&self, token: &str, token_type_hint: Option<&str>) -> OAuth2Result<()> {
173 let revocation_url = self
174 .config
175 .provider
176 .revocation_url
177 .as_ref()
178 .ok_or_else(|| OAuth2Error::ConfigurationError("revocation_url not configured".into()))?;
179
180 let mut params = HashMap::new();
181 params.insert("token", token.to_string());
182 params.insert("client_id", self.config.provider.client_id.clone());
183 params.insert("client_secret", self.config.provider.client_secret.clone());
184
185 if let Some(hint) = token_type_hint {
186 params.insert("token_type_hint", hint.to_string());
187 }
188
189 let form_body = self.encode_form_data(¶ms);
190
191 let response = self
192 .http_client
193 .post_with_headers(revocation_url, form_body.into_bytes(), self.form_headers())
194 .await
195 .map_err(OAuth2Error::from)?;
196
197 if !response.is_success() {
198 let error_text = response.text().unwrap_or_default();
199 return Err(OAuth2Error::ProviderError(error_text));
200 }
201
202 Ok(())
203 }
204
205 pub fn validate_state(&self, expected: &str, received: &str) -> OAuth2Result<()> {
211 if !self.config.use_state {
212 return Ok(());
213 }
214
215 if expected == received { Ok(()) } else { Err(OAuth2Error::StateMismatch) }
216 }
217
218 fn generate_state(&self) -> String {
219 uuid::Uuid::new_v4().to_string().replace('-', "")
220 }
221
222 fn generate_code_verifier(&self) -> String {
223 let random_bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
224 URL_SAFE_NO_PAD.encode(&random_bytes)
225 }
226
227 fn generate_code_challenge(&self, verifier: &str) -> String {
228 let mut hasher = Sha256::new();
229 hasher.update(verifier.as_bytes());
230 let hash = hasher.finalize();
231 URL_SAFE_NO_PAD.encode(&hash)
232 }
233
234 fn encode_form_data(&self, params: &HashMap<&str, String>) -> String {
235 params
236 .iter()
237 .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
238 .collect::<Vec<_>>()
239 .join("&")
240 }
241
242 fn form_headers(&self) -> HashMap<String, String> {
243 let mut headers = HashMap::new();
244 headers.insert("Content-Type".to_string(), "application/x-www-form-urlencoded".to_string());
245 headers
246 }
247
248 pub fn provider_name(&self) -> &str {
250 &self.config.provider.name
251 }
252
253 pub fn config(&self) -> &OAuth2ClientConfig {
255 &self.config
256 }
257}
258
259impl From<HttpError> for OAuth2Error {
260 fn from(err: HttpError) -> Self {
261 match err {
262 HttpError::InvalidUrl(msg) => OAuth2Error::ConfigurationError(msg),
263 HttpError::Timeout => OAuth2Error::RequestError("Request timeout".into()),
264 HttpError::ConnectionFailed(msg) => OAuth2Error::RequestError(msg),
265 HttpError::DnsFailed(msg) => OAuth2Error::RequestError(msg),
266 HttpError::TlsError(msg) => OAuth2Error::RequestError(msg),
267 HttpError::StatusError { status, body } => match status {
268 401 => OAuth2Error::AccessDenied(body),
269 403 => OAuth2Error::AccessDenied(body),
270 _ => OAuth2Error::ProviderError(format!("HTTP {}: {}", status, body)),
271 },
272 _ => OAuth2Error::RequestError(err.to_string()),
273 }
274 }
275}
276
277mod urlencoding {
278 pub fn encode(s: &str) -> String {
279 url::form_urlencoded::byte_serialize(s.as_bytes()).collect()
280 }
281}