wecom_agent/
lib.rs

1//! # wecom-agent
2//!
3//! `wecom-agent`封装了企业微信API的消息发送功能。
4//!
5//! ## 使用方法
6//! ```rust
7//! use wecom_agent::{
8//!     message::{MessageBuilder, Text},
9//!     MsgSendResponse, WecomAgent,
10//! };
11//! async fn example() {
12//!     let content = Text::new("Hello from Wandering AI!".to_string());
13//!     let msg = MessageBuilder::default()
14//!         .to_users(vec!["robin", "tom"])
15//!         .from_agent(42)
16//!         .build(content)
17//!         .expect("Message should be built");
18//!     let handle = tokio::spawn(async move {
19//!         let wecom_agent = WecomAgent::new("your_corpid", "your_secret");
20//!         let response = wecom_agent.send(msg).await;
21//!     });
22//! }
23//! ```
24
25mod error;
26pub mod message;
27
28use log::{debug, info, warn};
29use serde::{Deserialize, Serialize};
30use std::error::Error as StdError;
31use std::time::{Duration, SystemTime, UNIX_EPOCH};
32use tokio::sync::RwLock;
33
/// WeCom access credential together with the bookkeeping needed to decide
/// when it has to be refreshed.
#[derive(Debug)]
struct AccessToken {
    /// The token string; `None` until the first successful update.
    value: Option<String>,
    /// Wall-clock time of the last update.
    timestamp: SystemTime,
    /// How long the token stays valid, counted from `timestamp`.
    lifetime: Duration,
}

impl AccessToken {
    /// Returns the current token, or `None` if it has never been set.
    pub fn value(&self) -> Option<&String> {
        self.value.as_ref()
    }

    /// Replaces the token and records when it was issued and how long it lives.
    pub fn update(&mut self, token: &str, timestamp: SystemTime, lifetime: Duration) {
        self.value = Some(token.to_owned());
        self.timestamp = timestamp;
        self.lifetime = lifetime;
    }

    /// Returns `true` once at least `lifetime` has elapsed since the last update.
    ///
    /// If the recorded timestamp lies in the future (the system clock moved
    /// backwards), the token is reported as not expired.
    pub fn expired(&self) -> bool {
        match SystemTime::now().duration_since(self.timestamp) {
            Ok(elapsed) => elapsed >= self.lifetime,
            // Timestamp is ahead of the current clock: no elapsed time can be computed.
            Err(_) => false,
        }
    }

    /// Returns `true` if the time elapsed since the last update plus `n` seconds
    /// exceeds the lifetime, i.e. the token will be expired `n` seconds from now.
    ///
    /// A token that is already past its lifetime also yields `true` for any
    /// `n > 0`. If the recorded timestamp lies in the future, `false` is returned.
    pub fn expire_in(&self, n: u64) -> bool {
        match SystemTime::now().duration_since(self.timestamp) {
            // `saturating_add` keeps a huge `n` from panicking on `Duration` overflow.
            Ok(elapsed) => elapsed.saturating_add(Duration::from_secs(n)) > self.lifetime,
            Err(_) => false,
        }
    }

    /// Returns the time of the last update.
    pub fn timestamp(&self) -> SystemTime {
        self.timestamp
    }
}

impl Default for AccessToken {
    /// An empty token stamped at the Unix epoch with a 7200-second lifetime,
    /// so it reads as expired and never trips an update backoff check.
    fn default() -> Self {
        Self {
            value: None,
            timestamp: UNIX_EPOCH,
            lifetime: Duration::from_secs(7200),
        }
    }
}
88
89/// 企业微信API的轻量封装
90#[derive(Debug)]
91pub struct WecomAgent {
92    corp_id: String,
93    secret: String,
94    access_token: RwLock<AccessToken>,
95    client: reqwest::Client,
96}
97
98impl WecomAgent {
99    /// 创建一个Agent。注意此过程不会自动初始化access token。
100    pub fn new(corp_id: &str, secret: &str) -> Self {
101        Self {
102            corp_id: String::from(corp_id),
103            secret: String::from(secret),
104            access_token: RwLock::new(AccessToken::default()),
105            client: reqwest::Client::new(),
106        }
107    }
108
109    /// 更新access_token。使用`backoff_seconds`设定休止时段。若距离上次更新时间短于此时长,
110    /// 将返回频繁更新错误。
111    pub async fn update_token(
112        &self,
113        backoff_seconds: u64,
114    ) -> Result<(), Box<dyn StdError + Send + Sync>> {
115        // 获取token写权限
116        let mut access_token = self.access_token.write().await;
117
118        // 企业微信服务器对高频的接口调用存在风控措施。因此需要管制接口调用频率。
119        let seconds_since_last_update = SystemTime::now()
120            .duration_since(access_token.timestamp())?
121            .as_secs();
122        if seconds_since_last_update < backoff_seconds {
123            return Err(Box::new(error::Error::new(
124                -9,
125                format!("Access token更新过于频繁。上次更新于{seconds_since_last_update}秒前。"),
126            )));
127        }
128
129        // Fetch a new token
130        let url = format!(
131            "https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={}&corpsecret={}",
132            self.corp_id, self.secret,
133        );
134        let response = reqwest::get(url)
135            .await?
136            .json::<AccessTokenResponse>()
137            .await?;
138        if response.errcode != 0 {
139            return Err(Box::<error::Error>::new(error::Error::new(
140                response.errcode,
141                response.errmsg,
142            )));
143        };
144
145        // Update token with a write lock
146        access_token.update(
147            &response.access_token,
148            SystemTime::now(),
149            Duration::from_secs(response.expires_in),
150        );
151        Ok(())
152    }
153
154    /// 发送应用消息
155    pub async fn send<T>(&self, msg: T) -> Result<MsgSendResponse, Box<dyn StdError + Send + Sync>>
156    where
157        T: Serialize,
158    {
159        // 需要更新Token?
160        let token_should_update: bool = {
161            let access_token = self.access_token.read().await;
162            access_token.value().is_none() || access_token.expire_in(300) || access_token.expired()
163        };
164        if token_should_update {
165            warn!("Token invalid. Updating...");
166            self.update_token(10).await?;
167            info!("Token updated");
168        }
169
170        // API地址
171        let url = {
172            let access_token = self.access_token.read().await;
173            format!(
174                "https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={}",
175                access_token
176                    .value()
177                    .expect("Access token should not be None.")
178            )
179        };
180
181        // 第一次发送
182        debug!("Sending [try 1]...");
183        let mut response: MsgSendResponse = self
184            .client
185            .post(&url)
186            .json(&msg)
187            .send()
188            .await?
189            .json::<MsgSendResponse>()
190            .await?;
191
192        // 微信服务器主动弃用了当前token?
193        if response.error_code() == 40014 {
194            warn!("Token invalid. Updating...");
195            self.update_token(10).await?;
196
197            // 第二次发送
198            debug!("Sending [try 2]...");
199            response = self
200                .client
201                .post(&url)
202                .json(&msg)
203                .send()
204                .await?
205                .json::<MsgSendResponse>()
206                .await?;
207        };
208
209        debug!("Sending [Done]");
210        Ok(response)
211    }
212}
213
214// 应用消息发送结果
215#[derive(Deserialize)]
216pub struct MsgSendResponse {
217    errcode: i64,
218    errmsg: String,
219    #[allow(dead_code)]
220    invaliduser: Option<String>,
221    #[allow(dead_code)]
222    invalidparty: Option<String>,
223    #[allow(dead_code)]
224    invalidtag: Option<String>,
225    #[allow(dead_code)]
226    unlicenseduser: Option<String>,
227    #[allow(dead_code)]
228    msgid: Option<String>,
229    #[allow(dead_code)]
230    response_code: Option<String>,
231}
232
233impl MsgSendResponse {
234    pub fn is_error(&self) -> bool {
235        self.errcode != 0
236    }
237
238    pub fn error_code(&self) -> i64 {
239        self.errcode
240    }
241
242    pub fn error_msg(&self) -> &str {
243        &self.errmsg
244    }
245}
246
247// 获取Access Token时的返回结果
248// 示例
249// {
250//     "errcode": 0,
251//     "errmsg": "ok",
252//     "access_token": "accesstoken000001",
253//     "expires_in": 7200
254// }
255#[derive(Deserialize)]
256struct AccessTokenResponse {
257    errcode: i64,
258    errmsg: String,
259    access_token: String,
260    expires_in: u64,
261}