wechat_minapp/client/
token_storage.rs1use super::access_token::{AccessToken, is_token_expired};
7use super::token_type::TokenType;
8use crate::Result;
9use async_trait::async_trait;
10use chrono::Utc;
11use std::sync::{
12 Arc,
13 atomic::{AtomicBool, Ordering},
14};
15use tokio::sync::{Notify, RwLock};
16use tracing::debug;
17
18
19#[async_trait]
21pub trait TokenStorage: Send + Sync {
22 async fn token(&self) -> Result<String>;
23 async fn refresh_access_token(&self) -> Result<String>;
24 fn token_type(&self) -> Arc<dyn TokenType>;
25}
26
27pub struct MemoryTokenStorage {
29 access_token: Arc<RwLock<AccessToken>>,
30 refreshing: Arc<AtomicBool>,
31 notify: Arc<Notify>,
32 token_type: Arc<dyn TokenType>,
33}
34
35impl MemoryTokenStorage {
36 pub fn new(token_type: Arc<dyn TokenType>) -> Self {
37 MemoryTokenStorage {
38 access_token: Arc::new(RwLock::new(AccessToken {
39 access_token: String::new(),
40 expired_at: Utc::now(),
41 })),
42 refreshing: Arc::new(AtomicBool::new(false)),
43 notify: Arc::new(Notify::new()),
44 token_type,
45 }
46 }
47}
48
49
50#[async_trait]
52impl TokenStorage for MemoryTokenStorage {
53 async fn token(&self) -> Result<String> {
55 {
57 let guard = self.access_token.read().await;
58 if !is_token_expired(&guard) {
59 return Ok(guard.access_token.clone());
60 }
61 }
62
63 if self
65 .refreshing
66 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
67 .is_ok()
68 {
69 match self.refresh_access_token().await {
71 Ok(token) => {
72 self.refreshing.store(false, Ordering::Release);
73 self.notify.notify_waiters();
74 Ok(token)
75 }
76 Err(e) => {
77 self.refreshing.store(false, Ordering::Release);
78 self.notify.notify_waiters();
79 Err(e)
80 }
81 }
82 } else {
83 self.notify.notified().await;
85 let guard = self.access_token.read().await;
87 Ok(guard.access_token.clone())
88 }
89 }
90
91 async fn refresh_access_token(&self) -> Result<String> {
93 let mut guard = self.access_token.write().await;
94
95 if !is_token_expired(&guard) {
96 debug!("token already refreshed by another thread");
97 return Ok(guard.access_token.clone());
98 }
99
100 debug!("performing network request to refresh token");
101
102 let builder = self.token_type.token().await?;
103
104 guard.access_token = builder.access_token.clone();
105 guard.expired_at = builder.expired_at;
106
107 debug!("fresh access token: {:#?}", guard);
108
109 Ok(guard.access_token.clone())
110 }
111
112 fn token_type(&self) -> Arc<dyn TokenType> {
113 self.token_type.clone()
114 }
115}