wechat_pub_rs/
auth.rs

1//! Authentication module for managing WeChat access tokens.
2//!
3//! This module handles the complex process of WeChat access token management,
4//! including automatic refresh, caching, and thread-safe access.
5//!
6//! ## Features
7//!
8//! - **Automatic Token Refresh**: Tokens are refreshed before expiration
9//! - **Thread-Safe Caching**: Multiple threads can safely access tokens
10//! - **Expiration Handling**: Built-in buffer time to prevent edge cases
11//! - **Concurrent Protection**: Prevents multiple simultaneous refresh requests
12//! - **Error Recovery**: Comprehensive error handling for auth failures
13//!
14//! ## Token Lifecycle
15//!
16//! 1. **Initial Request**: Token requested on first API call
17//! 2. **Caching**: Token cached with expiration time
18//! 3. **Validation**: Each use checks if token is still valid
//! 4. **Refresh**: Automatic refresh shortly before expiration (60-second buffer)
//! 5. **Replacement**: Expired tokens are never served; the next refresh overwrites them
21//!
22//! ## Usage
23//!
24//! ```rust
25//! use wechat_pub_rs::auth::TokenManager;
26//! use wechat_pub_rs::http::WeChatHttpClient;
27//! use std::sync::Arc;
28//!
29//! # async fn example() -> wechat_pub_rs::Result<()> {
30//! let http_client = Arc::new(WeChatHttpClient::new()?);
31//! let token_manager = TokenManager::new(
32//!     "your_app_id".to_string(),
33//!     "your_app_secret".to_string(),
34//!     http_client
35//! );
36//!
37//! // Get a valid access token (handles caching and refresh automatically)
38//! let token = token_manager.get_access_token().await?;
39//! println!("Access token: {}", token);
40//!
41//! // Force refresh if needed
42//! let new_token = token_manager.force_refresh().await?;
43//! # Ok(())
44//! # }
45//! ```
46//!
47//! ## Thread Safety
48//!
49//! The token manager is designed to be shared across multiple threads safely:
50//!
51//! ```rust
52//! use std::sync::Arc;
53//! # use wechat_pub_rs::auth::TokenManager;
54//! # use wechat_pub_rs::http::WeChatHttpClient;
55//!
56//! # async fn example() -> wechat_pub_rs::Result<()> {
57//! # let http_client = Arc::new(WeChatHttpClient::new()?);
58//! let token_manager = Arc::new(TokenManager::new(
59//!     "app_id".to_string(),
60//!     "app_secret".to_string(),
61//!     http_client
62//! ));
63//!
64//! // Share across threads
65//! let manager_clone = Arc::clone(&token_manager);
66//! tokio::spawn(async move {
67//!     let token = manager_clone.get_access_token().await.unwrap();
68//!     // Use token...
69//! });
70//! # Ok(())
71//! # }
72//! ```
73
74use crate::error::Result;
75use crate::http::{AccessTokenResponse, WeChatHttpClient, WeChatResponse};
76use chrono::{DateTime, Duration, Utc};
77use serde::{Deserialize, Serialize};
78use std::sync::Arc;
79use tokio::sync::RwLock;
80use tracing::info;
81
82/// Access token with expiration information.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct AccessToken {
85    /// The access token string
86    pub token: String,
87    /// When the token expires
88    pub expires_at: DateTime<Utc>,
89}
90
91impl AccessToken {
92    /// Creates a new access token with expiration time.
93    pub fn new(token: String, expires_in_seconds: u64) -> Self {
94        let expires_at = Utc::now() + Duration::seconds(expires_in_seconds as i64);
95        Self { token, expires_at }
96    }
97
98    /// Checks if the token is expired or will expire within the buffer time.
99    pub fn is_expired(&self, buffer_seconds: i64) -> bool {
100        let buffer_time = Duration::seconds(buffer_seconds);
101        Utc::now() + buffer_time >= self.expires_at
102    }
103
104    /// Gets the remaining time until expiration.
105    pub fn time_until_expiry(&self) -> Duration {
106        self.expires_at - Utc::now()
107    }
108}
109
110/// Token manager responsible for obtaining and caching access tokens.
111#[derive(Debug)]
112pub struct TokenManager {
113    app_id: String,
114    app_secret: String,
115    http_client: Arc<WeChatHttpClient>,
116    token_cache: Arc<RwLock<Option<AccessToken>>>,
117    refresh_lock: Arc<tokio::sync::Mutex<()>>,
118}
119
120impl TokenManager {
121    /// Creates a new token manager.
122    pub fn new(
123        app_id: impl Into<String>,
124        app_secret: impl Into<String>,
125        http_client: Arc<WeChatHttpClient>,
126    ) -> Self {
127        Self {
128            app_id: app_id.into(),
129            app_secret: app_secret.into(),
130            http_client,
131            token_cache: Arc::new(RwLock::new(None)),
132            refresh_lock: Arc::new(tokio::sync::Mutex::new(())),
133        }
134    }
135
136    /// Gets a valid access token, refreshing if necessary.
137    ///
138    /// This method is thread-safe and will prevent concurrent token refreshes.
139    pub async fn get_access_token(&self) -> Result<String> {
140        // Check cache first (fast path)
141        if let Some(token) = self.get_cached_token().await {
142            return Ok(token);
143        }
144
145        // Slow path: need to refresh token
146        self.refresh_token().await
147    }
148
149    /// Gets a cached token if it's still valid.
150    async fn get_cached_token(&self) -> Option<String> {
151        let cache = self.token_cache.read().await;
152        if let Some(ref token) = *cache {
153            // Use 60-second buffer to avoid edge cases
154            if !token.is_expired(60) {
155                return Some(token.token.clone());
156            }
157        }
158        None
159    }
160
161    /// Refreshes the access token from WeChat API.
162    async fn refresh_token(&self) -> Result<String> {
163        // Prevent concurrent refreshes
164        let _guard = self.refresh_lock.lock().await;
165
166        // Double-check after acquiring lock
167        if let Some(token) = self.get_cached_token().await {
168            return Ok(token);
169        }
170
171        info!("Refreshing WeChat access token");
172
173        // Make API call to get new token
174        let url = format!(
175            "https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid={}&secret={}",
176            self.app_id, self.app_secret
177        );
178
179        let response_bytes = self.http_client.download(&url).await?;
180
181        let api_response: WeChatResponse<AccessTokenResponse> =
182            serde_json::from_slice(&response_bytes)?;
183
184        let token_response = api_response.into_result()?;
185
186        // Create and cache the new token
187        let new_token = AccessToken::new(token_response.access_token, token_response.expires_in);
188        let token_string = new_token.token.clone();
189
190        // Update cache
191        {
192            let mut cache = self.token_cache.write().await;
193            *cache = Some(new_token);
194        }
195
196        info!("Successfully refreshed WeChat access token");
197        Ok(token_string)
198    }
199
200    /// Forces a token refresh (useful for testing or when token is known to be invalid).
201    pub async fn force_refresh(&self) -> Result<String> {
202        // Clear cache first
203        {
204            let mut cache = self.token_cache.write().await;
205            *cache = None;
206        }
207
208        self.refresh_token().await
209    }
210
211    /// Gets token information for debugging purposes.
212    pub async fn get_token_info(&self) -> Option<TokenInfo> {
213        let cache = self.token_cache.read().await;
214        cache.as_ref().map(|token| TokenInfo {
215            is_expired: token.is_expired(0),
216            expires_at: token.expires_at,
217            time_until_expiry: token.time_until_expiry(),
218        })
219    }
220
221    /// Clears the token cache.
222    pub async fn clear_cache(&self) {
223        let mut cache = self.token_cache.write().await;
224        *cache = None;
225    }
226}
227
228/// Token information for debugging and monitoring.
229#[derive(Debug, Clone)]
230pub struct TokenInfo {
231    pub is_expired: bool,
232    pub expires_at: DateTime<Utc>,
233    pub time_until_expiry: Duration,
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager with dummy credentials. None of these tests trigger a
    /// token refresh; they populate the cache directly.
    fn new_test_manager() -> TokenManager {
        let http_client = Arc::new(WeChatHttpClient::new().expect("HTTP client should build"));
        TokenManager::new("test_app_id", "test_app_secret", http_client)
    }

    #[test]
    fn test_access_token_expiry() {
        // Create a token that expires in 1 hour
        let token = AccessToken::new("test_token".to_string(), 3600);

        // Should not be expired without buffer
        assert!(!token.is_expired(0));

        // Should not be expired with 30-minute buffer
        assert!(!token.is_expired(1800));

        // Should be considered expired with 2-hour buffer
        assert!(token.is_expired(7200));
    }

    #[test]
    fn test_access_token_time_until_expiry() {
        let token = AccessToken::new("test_token".to_string(), 3600);
        let time_until_expiry = token.time_until_expiry();

        // Should be approximately 1 hour (allowing for test execution time)
        assert!(time_until_expiry.num_seconds() > 3590);
        assert!(time_until_expiry.num_seconds() <= 3600);
    }

    #[tokio::test]
    async fn test_token_manager_creation() {
        let manager = new_test_manager();

        assert_eq!(manager.app_id, "test_app_id");
        assert_eq!(manager.app_secret, "test_app_secret");

        // Cache should be empty initially
        let cache = manager.token_cache.read().await;
        assert!(cache.is_none());
    }

    #[tokio::test]
    async fn test_cached_token_retrieval() {
        let manager = new_test_manager();

        // No cached token initially
        assert!(manager.get_cached_token().await.is_none());

        // Add a valid token to cache
        {
            let mut cache = manager.token_cache.write().await;
            *cache = Some(AccessToken::new("cached_token".to_string(), 3600));
        }

        // Should return cached token
        let cached = manager.get_cached_token().await;
        assert_eq!(cached, Some("cached_token".to_string()));

        // Clear cache
        manager.clear_cache().await;
        assert!(manager.get_cached_token().await.is_none());
    }

    #[tokio::test]
    async fn test_token_info() {
        let manager = new_test_manager();

        // No token info initially
        assert!(manager.get_token_info().await.is_none());

        // Add a token
        {
            let mut cache = manager.token_cache.write().await;
            *cache = Some(AccessToken::new("test_token".to_string(), 3600));
        }

        // Should have token info
        let info = manager
            .get_token_info()
            .await
            .expect("token info should be present");
        assert!(!info.is_expired);
        assert!(info.time_until_expiry.num_seconds() > 3590);
    }

    #[tokio::test]
    async fn test_expired_token_is_not_served_from_cache() {
        let manager = new_test_manager();
        {
            let mut cache = manager.token_cache.write().await;
            *cache = Some(AccessToken {
                token: "stale_token".to_string(),
                expires_at: Utc::now() - Duration::seconds(10),
            });
        }

        // An expired token must never be handed out from the cache...
        assert!(manager.get_cached_token().await.is_none());

        // ...but it remains visible for diagnostics.
        let info = manager
            .get_token_info()
            .await
            .expect("token info should be present");
        assert!(info.is_expired);
        assert!(info.time_until_expiry.num_seconds() < 0);
    }

    #[tokio::test]
    async fn test_token_inside_refresh_buffer_is_not_served_from_cache() {
        let manager = new_test_manager();
        {
            let mut cache = manager.token_cache.write().await;
            // 30 seconds of validity is inside the 60-second refresh buffer.
            *cache = Some(AccessToken::new("expiring_token".to_string(), 30));
        }

        assert!(manager.get_cached_token().await.is_none());

        // Not yet expired in absolute terms, only within the safety margin.
        let info = manager
            .get_token_info()
            .await
            .expect("token info should be present");
        assert!(!info.is_expired);
    }
}