wx_work/
client.rs

1use std::fs::File;
2use std::io::Read;
3use std::path::Path;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::mpsc::{self, Sender};
6use std::sync::{Arc, RwLock};
7use std::thread::{self, JoinHandle};
8use std::time::Duration;
9
10use log::{error, info};
11use reqwest::multipart::{Form, Part};
12use serde::de::DeserializeOwned;
13use serde::{Deserialize, Serialize};
14
15use crate::media::*;
16use crate::message::*;
17use crate::{Error, Result};
18
/// Base URL of the WeChat Work server API; endpoint paths such as `/cgi-bin/...` are appended to it.
const WX_URL: &str = "https://qyapi.weixin.qq.com";
20
21pub struct Client {
22    access_token: Arc<RwLock<String>>,
23    http_client: reqwest::Client,
24    refresh_token_thread: Option<JoinHandle<()>>,
25    is_exit: Arc<AtomicBool>,
26}
27
28#[derive(Debug, Serialize, Deserialize)]
29struct AccessTokenResponse {
30    errcode: u64,
31    errmsg: String,
32    access_token: String,
33    expires_in: u64,
34}
35
36fn get_access_token(client: &reqwest::blocking::Client, url: &str) -> Result<AccessTokenResponse> {
37    let resp = client.get(url).send()?.json::<AccessTokenResponse>()?;
38
39    if resp.errcode != 0 {
40        return Err(Error::GetAccessTokenFailed(resp.errcode, resp.errmsg));
41    }
42
43    Ok(resp)
44}
45
46fn start_refresh_token_thread(
47    url: String,
48    access_token: Arc<RwLock<String>>,
49    sender: Sender<Result<()>>,
50    is_exit: Arc<AtomicBool>,
51) -> JoinHandle<()> {
52    thread::Builder::new()
53        .name("wx work client".to_string())
54        .spawn(move || {
55            let client = reqwest::blocking::Client::new();
56
57            let d = match get_access_token(&client, &url) {
58                Ok(d) => d,
59                Err(e) => {
60                    sender.send(Err(e)).unwrap();
61                    return;
62                }
63            };
64
65            let mut expires_in = d.expires_in;
66            {
67                let mut token = access_token.write().unwrap();
68                *token = d.access_token;
69            }
70            info!("init token success, expires_in {}", d.expires_in);
71            sender.send(Ok(())).unwrap();
72
73            loop {
74                //let delay_time = expires_in / 2;
75                let delay_time = expires_in / 2;
76
77                thread::park_timeout(Duration::from_secs(delay_time));
78                if is_exit.load(Ordering::Acquire) {
79                    info!("detect exit signal, exit thread");
80                    break;
81                }
82
83                match get_access_token(&client, &url) {
84                    Ok(d) => {
85                        expires_in = d.expires_in;
86                        let mut token = access_token.write().unwrap();
87                        *token = d.access_token;
88                        info!("update token success, expires_in {}", d.expires_in);
89                    }
90                    Err(e) => error!("refresh token failed, reason: {}", e),
91                }
92            }
93        })
94        .unwrap()
95}
96
97impl Client {
98    pub fn new(corp_id: &str, corp_secret: &str) -> Result<Self> {
99        let url = format!(
100            "{}/cgi-bin/gettoken?corpid={}&corpsecret={}",
101            WX_URL, corp_id, corp_secret
102        );
103
104        let http_client = reqwest::Client::new();
105        let (tx, rx) = mpsc::channel();
106
107        let access_token = Arc::new(RwLock::new("".to_string()));
108        let is_exit = Arc::new(AtomicBool::new(false));
109
110        let refresh_token_thread = Some(start_refresh_token_thread(
111            url,
112            access_token.clone(),
113            tx,
114            is_exit.clone(),
115        ));
116
117        rx.recv().unwrap()?;
118
119        info!("construct Client success");
120
121        let ret = Client {
122            access_token,
123            http_client,
124            refresh_token_thread,
125            is_exit,
126        };
127
128        Ok(ret)
129    }
130}
131
132/// 素材管理
133impl Client {
134    pub async fn upload_file(&self, ty: FileType, path: &str) -> Result<UploadFileResponse> {
135        let url = format!(
136            "{}/cgi-bin/media/upload?access_token={}&type={}",
137            WX_URL,
138            self.access_token.read().unwrap(),
139            ty.type_desc()
140        );
141
142        let mut f = File::open(path)?;
143        let file_name = Path::new(path)
144            .file_name()
145            .unwrap()
146            .to_str()
147            .unwrap()
148            .to_string(); // TODO need handle unwrap
149        let mut buf = vec![];
150        f.read_to_end(&mut buf)?;
151
152        let ret = self
153            .upload_media::<UploadFileResponse>(&url, buf, file_name)
154            .await?;
155        if ret.errcode != 0 {
156            Err(Error::UploadMediaFailed(ret.errcode, ret.errmsg))
157        } else {
158            Ok(ret)
159        }
160    }
161
162    pub async fn upload_image(&self, path: &str) -> Result<UploadImageResponse> {
163        let url = format!(
164            "{}/cgi-bin/media/uploadimg?access_token={}",
165            WX_URL,
166            self.access_token.read().unwrap(),
167        );
168
169        let mut f = File::open(path)?;
170        let file_name = Path::new(path)
171            .file_name()
172            .unwrap()
173            .to_str()
174            .unwrap()
175            .to_string(); // TODO need handle unwrap
176        let mut buf = vec![];
177        f.read_to_end(&mut buf)?;
178
179        let ret = self
180            .upload_media::<UploadImageResponse>(&url, buf, file_name)
181            .await?;
182        if ret.errcode != 0 {
183            Err(Error::UploadMediaFailed(ret.errcode, ret.errmsg))
184        } else {
185            Ok(ret)
186        }
187    }
188
189    async fn upload_media<T: DeserializeOwned>(
190        &self,
191        url: &str,
192        data: Vec<u8>,
193        file_name: String,
194    ) -> Result<T> {
195        let part = Part::bytes(data).file_name(file_name);
196        let form = Form::new().part("media", part);
197
198        let ret = self
199            .http_client
200            .post(url)
201            .multipart(form)
202            .send()
203            .await?
204            .json()
205            .await?;
206
207        Ok(ret)
208    }
209}
210
211/// 发送应用消息
212impl Client {
213    pub async fn send_msg(&self, msg: &Message) -> Result<MessageResponse> {
214        let url = format!(
215            "{}/cgi-bin/message/send?access_token={}",
216            WX_URL,
217            self.access_token.read().unwrap(),
218        );
219
220        let ret = self
221            .http_client
222            .post(&url)
223            .json(&msg)
224            .send()
225            .await?
226            .json()
227            .await?;
228
229        Ok(ret)
230    }
231}
232
233impl Drop for Client {
234    fn drop(&mut self) {
235        self.is_exit.store(true, Ordering::Release);
236        let handle = self.refresh_token_thread.take().unwrap();
237        handle.thread().unpark();
238        handle
239            .join()
240            .expect("can not join the refresh token thread");
241        info!("join refresh token thread success");
242    }
243}
244
// For manual testing only (needs real CORP_ID / CORP_SECRET credentials in the environment):
246//#[cfg(test)]
247//mod tests {
248//    use super::*;
249//
250//    use dotenv::dotenv;
251//    use std::env::var;
252//
253//    #[test]
254//    fn test_drop() {
255//        dotenv().ok();
256//        env_logger::init();
257//
258//        let corp_id = var("CORP_ID").unwrap();
259//        let corp_secret = var("CORP_SECRET").unwrap();
260//        let client = Client::new(&corp_id, &corp_secret).unwrap();
261//        drop(client);
262//    }
263//}