wechat_minapp/client/
token_storage.rs1use super::access_token::{AccessToken, is_token_expired};
7use super::token_type::TokenType;
8use crate::Result;
9use async_trait::async_trait;
10use chrono::Utc;
11use std::sync::{
12 Arc,
13 atomic::{AtomicBool, Ordering},
14};
15use tokio::sync::{Notify, RwLock};
16use tracing::debug;
17
18#[async_trait]
20pub trait TokenStorage: Send + Sync {
21 async fn token(&self) -> Result<String>;
22 async fn refresh_access_token(&self) -> Result<String>;
23 fn token_type(&self) -> Arc<dyn TokenType>;
24}
25
26pub struct MemoryTokenStorage {
28 access_token: Arc<RwLock<AccessToken>>,
29 refreshing: Arc<AtomicBool>,
30 notify: Arc<Notify>,
31 token_type: Arc<dyn TokenType>,
32}
33
34impl MemoryTokenStorage {
35 pub fn new(token_type: Arc<dyn TokenType>) -> Self {
36 MemoryTokenStorage {
37 access_token: Arc::new(RwLock::new(AccessToken {
38 access_token: String::new(),
39 expired_at: Utc::now(),
40 })),
41 refreshing: Arc::new(AtomicBool::new(false)),
42 notify: Arc::new(Notify::new()),
43 token_type,
44 }
45 }
46}
47
48#[async_trait]
50impl TokenStorage for MemoryTokenStorage {
51 async fn token(&self) -> Result<String> {
53 {
55 let guard = self.access_token.read().await;
56 if !is_token_expired(&guard) {
57 return Ok(guard.access_token.clone());
58 }
59 }
60
61 if self
63 .refreshing
64 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
65 .is_ok()
66 {
67 match self.refresh_access_token().await {
69 Ok(token) => {
70 self.refreshing.store(false, Ordering::Release);
71 self.notify.notify_waiters();
72 Ok(token)
73 }
74 Err(e) => {
75 self.refreshing.store(false, Ordering::Release);
76 self.notify.notify_waiters();
77 Err(e)
78 }
79 }
80 } else {
81 self.notify.notified().await;
83 let guard = self.access_token.read().await;
85 Ok(guard.access_token.clone())
86 }
87 }
88
89 async fn refresh_access_token(&self) -> Result<String> {
91 let mut guard = self.access_token.write().await;
92
93 if !is_token_expired(&guard) {
94 debug!("token already refreshed by another thread");
95 return Ok(guard.access_token.clone());
96 }
97
98 debug!("performing network request to refresh token");
99
100 let builder = self.token_type.token().await?;
101
102 guard.access_token = builder.access_token.clone();
103 guard.expired_at = builder.expired_at;
104
105 debug!("fresh access token: {:#?}", guard);
106
107 Ok(guard.access_token.clone())
108 }
109
110 fn token_type(&self) -> Arc<dyn TokenType> {
111 self.token_type.clone()
112 }
113}