1mod error;
26pub mod message;
27
28use log::{debug, info, warn};
29use serde::{Deserialize, Serialize};
30use std::error::Error as StdError;
31use std::time::{Duration, SystemTime, UNIX_EPOCH};
32use tokio::sync::RwLock;
33
/// Cached WeCom access token together with the time it was obtained and how
/// long it stays valid.
#[derive(Debug)]
struct AccessToken {
    /// The token string; `None` until the first successful update.
    value: Option<String>,
    /// Wall-clock time at which `value` was obtained.
    timestamp: SystemTime,
    /// Validity period of `value`, measured from `timestamp`.
    lifetime: Duration,
}

impl AccessToken {
    /// Returns the cached token, or `None` if none has been stored yet.
    pub fn value(&self) -> Option<&String> {
        self.value.as_ref()
    }

    /// Stores a new token and restarts its validity window at `timestamp`.
    pub fn update(&mut self, token: &str, timestamp: SystemTime, lifetime: Duration) {
        self.value = Some(token.to_owned());
        self.timestamp = timestamp;
        self.lifetime = lifetime;
    }

    /// Returns `true` once at least `lifetime` has elapsed since `timestamp`.
    ///
    /// If `timestamp` lies in the future (e.g. the system clock moved
    /// backwards), the token is treated as not expired.
    pub fn expired(&self) -> bool {
        match SystemTime::now().duration_since(self.timestamp) {
            Ok(elapsed) => elapsed >= self.lifetime,
            Err(_) => false,
        }
    }

    /// Returns `true` if the token will be past its lifetime `n` seconds from now.
    ///
    /// As with [`AccessToken::expired`], a `timestamp` in the future yields `false`.
    pub fn expire_in(&self, n: u64) -> bool {
        match SystemTime::now().duration_since(self.timestamp) {
            // `Duration + Duration` panics on overflow; saturating keeps a very
            // large horizon `n` meaning "expires within that horizon".
            Ok(elapsed) => elapsed.saturating_add(Duration::from_secs(n)) > self.lifetime,
            Err(_) => false,
        }
    }

    /// Returns the time at which the current token was obtained.
    pub fn timestamp(&self) -> SystemTime {
        self.timestamp
    }
}

impl Default for AccessToken {
    /// An empty token stamped at the Unix epoch with a 7200-second lifetime,
    /// so it reads as expired and the first refresh is never throttled by the
    /// time-since-last-update check.
    fn default() -> Self {
        Self {
            value: None,
            timestamp: UNIX_EPOCH,
            lifetime: Duration::from_secs(7200),
        }
    }
}
88
89#[derive(Debug)]
91pub struct WecomAgent {
92 corp_id: String,
93 secret: String,
94 access_token: RwLock<AccessToken>,
95 client: reqwest::Client,
96}
97
98impl WecomAgent {
99 pub fn new(corp_id: &str, secret: &str) -> Self {
101 Self {
102 corp_id: String::from(corp_id),
103 secret: String::from(secret),
104 access_token: RwLock::new(AccessToken::default()),
105 client: reqwest::Client::new(),
106 }
107 }
108
109 pub async fn update_token(
112 &self,
113 backoff_seconds: u64,
114 ) -> Result<(), Box<dyn StdError + Send + Sync>> {
115 let mut access_token = self.access_token.write().await;
117
118 let seconds_since_last_update = SystemTime::now()
120 .duration_since(access_token.timestamp())?
121 .as_secs();
122 if seconds_since_last_update < backoff_seconds {
123 return Err(Box::new(error::Error::new(
124 -9,
125 format!("Access token更新过于频繁。上次更新于{seconds_since_last_update}秒前。"),
126 )));
127 }
128
129 let url = format!(
131 "https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={}&corpsecret={}",
132 self.corp_id, self.secret,
133 );
134 let response = reqwest::get(url)
135 .await?
136 .json::<AccessTokenResponse>()
137 .await?;
138 if response.errcode != 0 {
139 return Err(Box::<error::Error>::new(error::Error::new(
140 response.errcode,
141 response.errmsg,
142 )));
143 };
144
145 access_token.update(
147 &response.access_token,
148 SystemTime::now(),
149 Duration::from_secs(response.expires_in),
150 );
151 Ok(())
152 }
153
154 pub async fn send<T>(&self, msg: T) -> Result<MsgSendResponse, Box<dyn StdError + Send + Sync>>
156 where
157 T: Serialize,
158 {
159 let token_should_update: bool = {
161 let access_token = self.access_token.read().await;
162 access_token.value().is_none() || access_token.expire_in(300) || access_token.expired()
163 };
164 if token_should_update {
165 warn!("Token invalid. Updating...");
166 self.update_token(10).await?;
167 info!("Token updated");
168 }
169
170 let url = {
172 let access_token = self.access_token.read().await;
173 format!(
174 "https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={}",
175 access_token
176 .value()
177 .expect("Access token should not be None.")
178 )
179 };
180
181 debug!("Sending [try 1]...");
183 let mut response: MsgSendResponse = self
184 .client
185 .post(&url)
186 .json(&msg)
187 .send()
188 .await?
189 .json::<MsgSendResponse>()
190 .await?;
191
192 if response.error_code() == 40014 {
194 warn!("Token invalid. Updating...");
195 self.update_token(10).await?;
196
197 debug!("Sending [try 2]...");
199 response = self
200 .client
201 .post(&url)
202 .json(&msg)
203 .send()
204 .await?
205 .json::<MsgSendResponse>()
206 .await?;
207 };
208
209 debug!("Sending [Done]");
210 Ok(response)
211 }
212}
213
214#[derive(Deserialize)]
216pub struct MsgSendResponse {
217 errcode: i64,
218 errmsg: String,
219 #[allow(dead_code)]
220 invaliduser: Option<String>,
221 #[allow(dead_code)]
222 invalidparty: Option<String>,
223 #[allow(dead_code)]
224 invalidtag: Option<String>,
225 #[allow(dead_code)]
226 unlicenseduser: Option<String>,
227 #[allow(dead_code)]
228 msgid: Option<String>,
229 #[allow(dead_code)]
230 response_code: Option<String>,
231}
232
233impl MsgSendResponse {
234 pub fn is_error(&self) -> bool {
235 self.errcode != 0
236 }
237
238 pub fn error_code(&self) -> i64 {
239 self.errcode
240 }
241
242 pub fn error_msg(&self) -> &str {
243 &self.errmsg
244 }
245}
246
247#[derive(Deserialize)]
256struct AccessTokenResponse {
257 errcode: i64,
258 errmsg: String,
259 access_token: String,
260 expires_in: u64,
261}