Skip to main content

wae_authentication/oauth2/
client.rs

//! OAuth2 client implementation
2
3use crate::oauth2::{AuthorizationUrl, OAuth2ClientConfig, TokenResponse, UserInfo};
4use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use wae_request::{HttpClient, HttpClientConfig};
8use wae_types::{WaeError, WaeErrorKind};
9
10/// OAuth2 结果类型
11pub type OAuth2Result<T> = Result<T, WaeError>;
12
13/// OAuth2 客户端
14#[derive(Debug, Clone)]
15pub struct OAuth2Client {
16    config: OAuth2ClientConfig,
17    http_client: HttpClient,
18}
19
20impl OAuth2Client {
21    /// 创建新的 OAuth2 客户端
22    ///
23    /// # Arguments
24    /// * `config` - 客户端配置
25    pub fn new(config: OAuth2ClientConfig) -> OAuth2Result<Self> {
26        let http_config = HttpClientConfig {
27            timeout: std::time::Duration::from_millis(config.timeout_ms),
28            connect_timeout: std::time::Duration::from_secs(10),
29            user_agent: "wae-oauth2/0.1.0".to_string(),
30            max_retries: 3,
31            retry_delay: std::time::Duration::from_millis(1000),
32            default_headers: HashMap::new(),
33        };
34
35        let http_client = HttpClient::new(http_config);
36
37        Ok(Self { config, http_client })
38    }
39
40    /// 生成授权 URL
41    ///
42    /// # Returns
43    /// 返回授权 URL、状态参数和 PKCE code verifier
44    pub fn authorization_url(&self) -> OAuth2Result<AuthorizationUrl> {
45        let state = self.generate_state();
46        let mut params: Vec<(String, String)> = vec![
47            ("client_id".to_string(), self.config.provider.client_id.clone()),
48            ("redirect_uri".to_string(), self.config.provider.redirect_uri.clone()),
49            ("response_type".to_string(), "code".to_string()),
50            ("state".to_string(), state.clone()),
51        ];
52
53        if !self.config.provider.scopes.is_empty() {
54            params.push(("scope".to_string(), self.config.provider.scopes.join(" ")));
55        }
56
57        let code_verifier = if self.config.use_pkce {
58            let verifier = self.generate_code_verifier();
59            let challenge = self.generate_code_challenge(&verifier);
60            params.push(("code_challenge".to_string(), challenge));
61            params.push(("code_challenge_method".to_string(), "S256".to_string()));
62            Some(verifier)
63        }
64        else {
65            None
66        };
67
68        for (key, value) in &self.config.provider.extra_params {
69            params.push((key.clone(), value.clone()));
70        }
71
72        let query = params
73            .iter()
74            .map(|(k, v)| format!("{}={}", wae_types::url_encode(k), wae_types::url_encode(v)))
75            .collect::<Vec<_>>()
76            .join("&");
77
78        let url = format!("{}?{}", self.config.provider.authorization_url, query);
79
80        Ok(AuthorizationUrl { url, state, code_verifier })
81    }
82
83    /// 使用授权码交换令牌
84    ///
85    /// # Arguments
86    /// * `code` - 授权码
87    /// * `code_verifier` - PKCE code verifier (如果启用了 PKCE)
88    pub async fn exchange_code(&self, code: &str, code_verifier: Option<&str>) -> OAuth2Result<TokenResponse> {
89        let mut params = HashMap::new();
90        params.insert("grant_type", "authorization_code".to_string());
91        params.insert("code", code.to_string());
92        params.insert("redirect_uri", self.config.provider.redirect_uri.clone());
93        params.insert("client_id", self.config.provider.client_id.clone());
94        params.insert("client_secret", self.config.provider.client_secret.clone());
95
96        if let Some(verifier) = code_verifier {
97            params.insert("code_verifier", verifier.to_string());
98        }
99
100        let form_body = self.encode_form_data(&params);
101
102        let response = self
103            .http_client
104            .post_with_headers(&self.config.provider.token_url, form_body.into_bytes(), self.form_headers())
105            .await
106            .map_err(WaeError::from)?;
107
108        if !response.is_success() {
109            let error_text = response.text().unwrap_or_default();
110            return Err(WaeError::new(WaeErrorKind::OAuth2ProviderError { message: error_text }));
111        }
112
113        let token_response: TokenResponse = response.json().map_err(WaeError::from)?;
114        Ok(token_response)
115    }
116
117    /// 刷新访问令牌
118    ///
119    /// # Arguments
120    /// * `refresh_token` - 刷新令牌
121    pub async fn refresh_token(&self, refresh_token: &str) -> OAuth2Result<TokenResponse> {
122        let mut params = HashMap::new();
123        params.insert("grant_type", "refresh_token".to_string());
124        params.insert("refresh_token", refresh_token.to_string());
125        params.insert("client_id", self.config.provider.client_id.clone());
126        params.insert("client_secret", self.config.provider.client_secret.clone());
127
128        let form_body = self.encode_form_data(&params);
129
130        let response = self
131            .http_client
132            .post_with_headers(&self.config.provider.token_url, form_body.into_bytes(), self.form_headers())
133            .await
134            .map_err(WaeError::from)?;
135
136        if !response.is_success() {
137            let error_text = response.text().unwrap_or_default();
138            return Err(WaeError::new(WaeErrorKind::OAuth2ProviderError { message: error_text }));
139        }
140
141        let token_response: TokenResponse = response.json().map_err(WaeError::from)?;
142        Ok(token_response)
143    }
144
145    /// 获取用户信息
146    ///
147    /// # Arguments
148    /// * `access_token` - 访问令牌
149    pub async fn get_user_info(&self, access_token: &str) -> OAuth2Result<UserInfo> {
150        let userinfo_url = self
151            .config
152            .provider
153            .userinfo_url
154            .as_ref()
155            .ok_or_else(|| WaeError::config_invalid("userinfo_url", "not configured"))?;
156
157        let mut headers = HashMap::new();
158        headers.insert("Authorization".to_string(), format!("Bearer {}", access_token));
159
160        let response = self.http_client.get_with_headers(userinfo_url, headers).await.map_err(WaeError::from)?;
161
162        if !response.is_success() {
163            let error_text = response.text().unwrap_or_default();
164            return Err(WaeError::new(WaeErrorKind::OAuth2ProviderError { message: error_text }));
165        }
166
167        let user_info: UserInfo = response.json().map_err(WaeError::from)?;
168        Ok(user_info)
169    }
170
171    /// 撤销令牌
172    ///
173    /// # Arguments
174    /// * `token` - 要撤销的令牌
175    /// * `token_type_hint` - 令牌类型提示 (access_token 或 refresh_token)
176    pub async fn revoke_token(&self, token: &str, token_type_hint: Option<&str>) -> OAuth2Result<()> {
177        let revocation_url = self
178            .config
179            .provider
180            .revocation_url
181            .as_ref()
182            .ok_or_else(|| WaeError::config_invalid("revocation_url", "not configured"))?;
183
184        let mut params = HashMap::new();
185        params.insert("token", token.to_string());
186        params.insert("client_id", self.config.provider.client_id.clone());
187        params.insert("client_secret", self.config.provider.client_secret.clone());
188
189        if let Some(hint) = token_type_hint {
190            params.insert("token_type_hint", hint.to_string());
191        }
192
193        let form_body = self.encode_form_data(&params);
194
195        let response = self
196            .http_client
197            .post_with_headers(revocation_url, form_body.into_bytes(), self.form_headers())
198            .await
199            .map_err(WaeError::from)?;
200
201        if !response.is_success() {
202            let error_text = response.text().unwrap_or_default();
203            return Err(WaeError::new(WaeErrorKind::OAuth2ProviderError { message: error_text }));
204        }
205
206        Ok(())
207    }
208
209    /// 验证状态参数
210    ///
211    /// # Arguments
212    /// * `expected` - 期望的状态值
213    /// * `received` - 接收到的状态值
214    pub fn validate_state(&self, expected: &str, received: &str) -> OAuth2Result<()> {
215        if !self.config.use_state {
216            return Ok(());
217        }
218
219        if expected == received { Ok(()) } else { Err(WaeError::new(WaeErrorKind::StateMismatch)) }
220    }
221
222    fn generate_state(&self) -> String {
223        uuid::Uuid::new_v4().to_string().replace('-', "")
224    }
225
226    fn generate_code_verifier(&self) -> String {
227        let random_bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
228        URL_SAFE_NO_PAD.encode(&random_bytes)
229    }
230
231    fn generate_code_challenge(&self, verifier: &str) -> String {
232        let mut hasher = Sha256::new();
233        hasher.update(verifier.as_bytes());
234        let hash = hasher.finalize();
235        URL_SAFE_NO_PAD.encode(&hash)
236    }
237
238    fn encode_form_data(&self, params: &HashMap<&str, String>) -> String {
239        params
240            .iter()
241            .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
242            .collect::<Vec<_>>()
243            .join("&")
244    }
245
246    fn form_headers(&self) -> HashMap<String, String> {
247        let mut headers = HashMap::new();
248        headers.insert("Content-Type".to_string(), "application/x-www-form-urlencoded".to_string());
249        headers
250    }
251
252    /// 获取提供者名称
253    pub fn provider_name(&self) -> &str {
254        &self.config.provider.name
255    }
256
257    /// 获取配置
258    pub fn config(&self) -> &OAuth2ClientConfig {
259        &self.config
260    }
261}
262
263mod urlencoding {
264    pub fn encode(s: &str) -> String {
265        url::form_urlencoded::byte_serialize(s.as_bytes()).collect()
266    }
267}