1use crate::error::Result;
75use crate::http::{AccessTokenResponse, WeChatHttpClient, WeChatResponse};
76use chrono::{DateTime, Duration, Utc};
77use serde::{Deserialize, Serialize};
78use std::sync::Arc;
79use tokio::sync::RwLock;
80use tracing::info;
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct AccessToken {
85 pub token: String,
87 pub expires_at: DateTime<Utc>,
89}
90
91impl AccessToken {
92 pub fn new(token: String, expires_in_seconds: u64) -> Self {
94 let expires_at = Utc::now() + Duration::seconds(expires_in_seconds as i64);
95 Self { token, expires_at }
96 }
97
98 pub fn is_expired(&self, buffer_seconds: i64) -> bool {
100 let buffer_time = Duration::seconds(buffer_seconds);
101 Utc::now() + buffer_time >= self.expires_at
102 }
103
104 pub fn time_until_expiry(&self) -> Duration {
106 self.expires_at - Utc::now()
107 }
108}
109
110#[derive(Debug)]
112pub struct TokenManager {
113 app_id: String,
114 app_secret: String,
115 http_client: Arc<WeChatHttpClient>,
116 token_cache: Arc<RwLock<Option<AccessToken>>>,
117 refresh_lock: Arc<tokio::sync::Mutex<()>>,
118}
119
120impl TokenManager {
121 pub fn new(
123 app_id: impl Into<String>,
124 app_secret: impl Into<String>,
125 http_client: Arc<WeChatHttpClient>,
126 ) -> Self {
127 Self {
128 app_id: app_id.into(),
129 app_secret: app_secret.into(),
130 http_client,
131 token_cache: Arc::new(RwLock::new(None)),
132 refresh_lock: Arc::new(tokio::sync::Mutex::new(())),
133 }
134 }
135
136 pub async fn get_access_token(&self) -> Result<String> {
140 if let Some(token) = self.get_cached_token().await {
142 return Ok(token);
143 }
144
145 self.refresh_token().await
147 }
148
149 async fn get_cached_token(&self) -> Option<String> {
151 let cache = self.token_cache.read().await;
152 if let Some(ref token) = *cache {
153 if !token.is_expired(60) {
155 return Some(token.token.clone());
156 }
157 }
158 None
159 }
160
161 async fn refresh_token(&self) -> Result<String> {
163 let _guard = self.refresh_lock.lock().await;
165
166 if let Some(token) = self.get_cached_token().await {
168 return Ok(token);
169 }
170
171 info!("Refreshing WeChat access token");
172
173 let url = format!(
175 "https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid={}&secret={}",
176 self.app_id, self.app_secret
177 );
178
179 let response_bytes = self.http_client.download(&url).await?;
180
181 let api_response: WeChatResponse<AccessTokenResponse> =
182 serde_json::from_slice(&response_bytes)?;
183
184 let token_response = api_response.into_result()?;
185
186 let new_token = AccessToken::new(token_response.access_token, token_response.expires_in);
188 let token_string = new_token.token.clone();
189
190 {
192 let mut cache = self.token_cache.write().await;
193 *cache = Some(new_token);
194 }
195
196 info!("Successfully refreshed WeChat access token");
197 Ok(token_string)
198 }
199
200 pub async fn force_refresh(&self) -> Result<String> {
202 {
204 let mut cache = self.token_cache.write().await;
205 *cache = None;
206 }
207
208 self.refresh_token().await
209 }
210
211 pub async fn get_token_info(&self) -> Option<TokenInfo> {
213 let cache = self.token_cache.read().await;
214 cache.as_ref().map(|token| TokenInfo {
215 is_expired: token.is_expired(0),
216 expires_at: token.expires_at,
217 time_until_expiry: token.time_until_expiry(),
218 })
219 }
220
221 pub async fn clear_cache(&self) {
223 let mut cache = self.token_cache.write().await;
224 *cache = None;
225 }
226}
227
228#[derive(Debug, Clone)]
230pub struct TokenInfo {
231 pub is_expired: bool,
232 pub expires_at: DateTime<Utc>,
233 pub time_until_expiry: Duration,
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager with fixed test credentials and an empty cache.
    fn make_manager() -> TokenManager {
        let client = Arc::new(WeChatHttpClient::new().unwrap());
        TokenManager::new("test_app_id", "test_app_secret", client)
    }

    /// Stores a token valid for one hour directly in the manager's cache.
    async fn seed_cache(manager: &TokenManager, value: &str) {
        let mut slot = manager.token_cache.write().await;
        *slot = Some(AccessToken::new(value.to_string(), 3600));
    }

    #[test]
    fn test_access_token_expiry() {
        let token = AccessToken::new("test_token".to_string(), 3600);

        // (buffer in seconds, expected `is_expired` result) for a one-hour token.
        for (buffer_seconds, expected) in [(0, false), (1800, false), (7200, true)] {
            assert_eq!(token.is_expired(buffer_seconds), expected);
        }
    }

    #[test]
    fn test_access_token_time_until_expiry() {
        let remaining = AccessToken::new("test_token".to_string(), 3600)
            .time_until_expiry()
            .num_seconds();

        // Strictly more than 3590 and at most 3600 seconds remain.
        assert!((3591..=3600).contains(&remaining));
    }

    #[tokio::test]
    async fn test_token_manager_creation() {
        let manager = make_manager();

        assert_eq!(manager.app_id, "test_app_id");
        assert_eq!(manager.app_secret, "test_app_secret");

        // A freshly built manager has nothing cached.
        let slot = manager.token_cache.read().await;
        assert!(slot.is_none());
    }

    #[tokio::test]
    async fn test_cached_token_retrieval() {
        let manager = make_manager();

        // Empty cache yields nothing.
        assert!(manager.get_cached_token().await.is_none());

        // A seeded, unexpired token is returned as-is.
        seed_cache(&manager, "cached_token").await;
        assert_eq!(
            manager.get_cached_token().await,
            Some("cached_token".to_string())
        );

        // Clearing the cache removes it again.
        manager.clear_cache().await;
        assert!(manager.get_cached_token().await.is_none());
    }

    #[tokio::test]
    async fn test_token_info() {
        let manager = make_manager();

        // No info is available while the cache is empty.
        assert!(manager.get_token_info().await.is_none());

        seed_cache(&manager, "test_token").await;

        let snapshot = manager.get_token_info().await;
        assert!(snapshot.is_some());

        let snapshot = snapshot.unwrap();
        assert!(!snapshot.is_expired);
        assert!(snapshot.time_until_expiry.num_seconds() > 3590);
    }
}