1use std::sync::Arc;
12
13use serde::{Deserialize, Serialize};
14
15use super::{WechatApi, WechatContext};
16use crate::error::WechatError;
17
18#[derive(Debug, Clone, Serialize)]
23struct MsgSecCheckRequest {
24 version: u8,
25 openid: String,
26 scene: u8,
27 content: String,
28}
29
30#[derive(Debug, Clone, Serialize)]
31struct MediaCheckAsyncRequest {
32 media_url: String,
33 media_type: u8,
34 version: u8,
35 openid: String,
36 scene: u8,
37}
38
39#[derive(Debug, Clone, Serialize)]
40struct UserRiskRankRequest {
41 appid: String,
42 openid: String,
43 scene: u8,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 client_ip: Option<String>,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 mobile_no: Option<String>,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 email_address: Option<String>,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 extended_info: Option<String>,
52 #[serde(skip_serializing_if = "Option::is_none")]
53 is_test: Option<bool>,
54}
55
56#[non_exhaustive]
62#[derive(Debug, Clone, Default, Deserialize, Serialize)]
63pub struct MsgSecCheckDetail {
64 #[serde(default)]
66 pub strategy: String,
67 #[serde(default)]
69 pub errcode: i32,
70 #[serde(default)]
72 pub suggest: String,
73 #[serde(default)]
75 pub label: i32,
76 #[serde(default)]
78 pub keyword: String,
79 #[serde(default)]
81 pub prob: i32,
82}
83
84#[non_exhaustive]
86#[derive(Debug, Clone, Default, Deserialize, Serialize)]
87pub struct MsgSecCheckResult {
88 #[serde(default)]
90 pub suggest: String,
91 #[serde(default)]
93 pub label: i32,
94}
95
96#[non_exhaustive]
98#[derive(Debug, Clone, Deserialize, Serialize)]
99pub struct MsgSecCheckResponse {
100 #[serde(default)]
102 pub result: MsgSecCheckResult,
103 #[serde(default)]
105 pub detail: Vec<MsgSecCheckDetail>,
106 #[serde(default)]
108 pub(crate) errcode: i32,
109 #[serde(default)]
111 pub(crate) errmsg: String,
112}
113
114#[non_exhaustive]
116#[derive(Debug, Clone, Deserialize, Serialize)]
117pub struct MediaCheckAsyncResponse {
118 #[serde(default)]
120 pub trace_id: String,
121 #[serde(default)]
123 pub(crate) errcode: i32,
124 #[serde(default)]
126 pub(crate) errmsg: String,
127}
128
129#[non_exhaustive]
131#[derive(Debug, Clone, Deserialize, Serialize)]
132pub struct UserRiskRankResponse {
133 #[serde(default)]
135 pub risk_rank: i32,
136 #[serde(default)]
138 pub unoin_id: i32,
139 #[serde(default)]
141 pub(crate) errcode: i32,
142 #[serde(default)]
144 pub(crate) errmsg: String,
145}
146
/// Optional inputs for [`SecurityApi::get_user_risk_rank`].
///
/// Each field left as `None` is omitted from the request body. Because the
/// type is `#[non_exhaustive]`, code outside this crate constructs it with
/// `UserRiskRankOptions::default()` and then assigns the public fields.
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct UserRiskRankOptions {
    /// Client IP address of the user.
    pub client_ip: Option<String>,
    /// Mobile number of the user.
    pub mobile_no: Option<String>,
    /// E-mail address of the user.
    pub email_address: Option<String>,
    /// Free-form extra information.
    pub extended_info: Option<String>,
    /// Marks the request as a test request when `Some(true)`.
    pub is_test: Option<bool>,
}
162
163pub struct SecurityApi {
172 context: Arc<WechatContext>,
173}
174
175impl SecurityApi {
176 pub fn new(context: Arc<WechatContext>) -> Self {
178 Self { context }
179 }
180
181 pub async fn msg_sec_check(
190 &self,
191 openid: &str,
192 scene: u8,
193 content: &str,
194 ) -> Result<MsgSecCheckResponse, WechatError> {
195 let body = MsgSecCheckRequest {
196 version: 2,
197 openid: openid.to_string(),
198 scene,
199 content: content.to_string(),
200 };
201 let response: MsgSecCheckResponse = self
202 .context
203 .authed_post("/wxa/msg_sec_check", &body)
204 .await?;
205 WechatError::check_api(response.errcode, &response.errmsg)?;
206 Ok(response)
207 }
208
209 pub async fn media_check_async(
219 &self,
220 media_url: &str,
221 media_type: u8,
222 openid: &str,
223 scene: u8,
224 ) -> Result<MediaCheckAsyncResponse, WechatError> {
225 let body = MediaCheckAsyncRequest {
226 media_url: media_url.to_string(),
227 media_type,
228 version: 2,
229 openid: openid.to_string(),
230 scene,
231 };
232 let response: MediaCheckAsyncResponse = self
233 .context
234 .authed_post("/wxa/media_check_async", &body)
235 .await?;
236 WechatError::check_api(response.errcode, &response.errmsg)?;
237 Ok(response)
238 }
239
240 pub async fn get_user_risk_rank(
249 &self,
250 openid: &str,
251 scene: u8,
252 options: Option<UserRiskRankOptions>,
253 ) -> Result<UserRiskRankResponse, WechatError> {
254 let opts = options.unwrap_or_default();
255 let body = UserRiskRankRequest {
256 appid: self.context.client.appid().to_string(),
257 openid: openid.to_string(),
258 scene,
259 client_ip: opts.client_ip,
260 mobile_no: opts.mobile_no,
261 email_address: opts.email_address,
262 extended_info: opts.extended_info,
263 is_test: opts.is_test,
264 };
265 let response: UserRiskRankResponse = self
266 .context
267 .authed_post("/wxa/getuserriskrank", &body)
268 .await?;
269 WechatError::check_api(response.errcode, &response.errmsg)?;
270 Ok(response)
271 }
272}
273
274impl WechatApi for SecurityApi {
275 fn context(&self) -> &WechatContext {
276 &self.context
277 }
278
279 fn api_name(&self) -> &'static str {
280 "security"
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::WechatClient;
    use crate::token::TokenManager;
    use crate::types::{AppId, AppSecret};

    /// AppID configured on every test client; `get_user_risk_rank` must echo it.
    const TEST_APPID: &str = "wx1234567890abcdef";

    /// Builds a context whose HTTP client points at `base_url` (a wiremock server).
    fn create_test_context(base_url: &str) -> Arc<WechatContext> {
        let appid = AppId::new(TEST_APPID).unwrap();
        let secret = AppSecret::new("secret1234567890ab").unwrap();
        let client = Arc::new(
            WechatClient::builder()
                .appid(appid)
                .secret(secret)
                .base_url(base_url)
                .build()
                .unwrap(),
        );
        let token_manager = Arc::new(TokenManager::new((*client).clone()));
        Arc::new(WechatContext::new(client, token_manager))
    }

    /// Mounts a `/cgi-bin/token` mock that always returns `test_token`.
    async fn setup_token_mock(mock_server: &wiremock::MockServer) {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        Mock::given(method("GET"))
            .and(path("/cgi-bin/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "test_token",
                "expires_in": 7200,
                "errcode": 0,
                "errmsg": ""
            })))
            .mount(mock_server)
            .await;
    }

    #[test]
    fn test_msg_sec_check_response_parse() {
        let json = r#"{
            "result": {
                "suggest": "pass",
                "label": 100
            },
            "detail": [
                {
                    "strategy": "content_model",
                    "errcode": 0,
                    "suggest": "pass",
                    "label": 100,
                    "keyword": "",
                    "prob": 90
                }
            ],
            "errcode": 0,
            "errmsg": "ok"
        }"#;

        let response: MsgSecCheckResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.result.suggest, "pass");
        assert_eq!(response.result.label, 100);
        assert_eq!(response.detail.len(), 1);
        assert_eq!(response.detail[0].strategy, "content_model");
        assert_eq!(response.detail[0].prob, 90);
        assert_eq!(response.errcode, 0);
    }

    #[test]
    fn test_msg_sec_check_response_defaults() {
        let json = r#"{"errcode": 0, "errmsg": "ok"}"#;
        let response: MsgSecCheckResponse = serde_json::from_str(json).unwrap();
        assert!(response.detail.is_empty());
        assert!(response.result.suggest.is_empty());
    }

    #[test]
    fn test_media_check_async_response_parse() {
        let json = r#"{
            "trace_id": "trace_abc123",
            "errcode": 0,
            "errmsg": "ok"
        }"#;

        let response: MediaCheckAsyncResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.trace_id, "trace_abc123");
        assert_eq!(response.errcode, 0);
    }

    #[test]
    fn test_user_risk_rank_response_parse() {
        let json = r#"{
            "risk_rank": 2,
            "unoin_id": 12345,
            "errcode": 0,
            "errmsg": "ok"
        }"#;

        let response: UserRiskRankResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.risk_rank, 2);
        assert_eq!(response.unoin_id, 12345);
        assert_eq!(response.errcode, 0);
    }

    #[test]
    fn test_user_risk_rank_response_defaults() {
        let json = r#"{"errcode": 0, "errmsg": "ok"}"#;
        let response: UserRiskRankResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.risk_rank, 0);
        assert_eq!(response.unoin_id, 0);
    }

    /// Unset options must be omitted from the JSON body, not sent as `null`.
    #[test]
    fn test_user_risk_rank_request_omits_unset_options() {
        let request = UserRiskRankRequest {
            appid: TEST_APPID.to_string(),
            openid: "openid123".to_string(),
            scene: 0,
            client_ip: Some("127.0.0.1".to_string()),
            mobile_no: None,
            email_address: None,
            extended_info: None,
            is_test: None,
        };

        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(value["appid"], TEST_APPID);
        assert_eq!(value["client_ip"], "127.0.0.1");
        assert!(value.get("mobile_no").is_none());
        assert!(value.get("email_address").is_none());
        assert!(value.get("extended_info").is_none());
        assert!(value.get("is_test").is_none());
    }

    #[test]
    fn test_api_name() {
        let context = create_test_context("http://localhost:0");
        let api = SecurityApi::new(context);
        assert_eq!(api.api_name(), "security");
    }

    #[tokio::test]
    async fn test_msg_sec_check_success() {
        use wiremock::matchers::{body_partial_json, method, path, query_param};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;
        setup_token_mock(&mock_server).await;

        // The body matcher pins the request shape, including the fixed version 2.
        Mock::given(method("POST"))
            .and(path("/wxa/msg_sec_check"))
            .and(query_param("access_token", "test_token"))
            .and(body_partial_json(serde_json::json!({
                "version": 2,
                "openid": "openid123",
                "scene": 1,
                "content": "hello world"
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "result": {"suggest": "pass", "label": 100},
                "detail": [{"strategy": "content_model", "errcode": 0, "suggest": "pass", "label": 100, "keyword": "", "prob": 90}],
                "errcode": 0,
                "errmsg": "ok"
            })))
            .mount(&mock_server)
            .await;

        let context = create_test_context(&mock_server.uri());
        let api = SecurityApi::new(context);
        let response = api
            .msg_sec_check("openid123", 1, "hello world")
            .await
            .expect("msg_sec_check should succeed");
        assert_eq!(response.result.suggest, "pass");
    }

    #[tokio::test]
    async fn test_msg_sec_check_api_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;
        setup_token_mock(&mock_server).await;

        Mock::given(method("POST"))
            .and(path("/wxa/msg_sec_check"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "errcode": 87014,
                "errmsg": "risky content"
            })))
            .mount(&mock_server)
            .await;

        let context = create_test_context(&mock_server.uri());
        let api = SecurityApi::new(context);
        match api.msg_sec_check("openid123", 1, "bad content").await {
            Err(WechatError::Api { code, message }) => {
                assert_eq!(code, 87014);
                assert_eq!(message, "risky content");
            }
            other => panic!("Expected WechatError::Api, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_media_check_async_success() {
        use wiremock::matchers::{body_partial_json, method, path, query_param};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;
        setup_token_mock(&mock_server).await;

        Mock::given(method("POST"))
            .and(path("/wxa/media_check_async"))
            .and(query_param("access_token", "test_token"))
            .and(body_partial_json(serde_json::json!({
                "media_url": "https://example.com/image.jpg",
                "media_type": 2,
                "version": 2,
                "openid": "openid123",
                "scene": 1
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "trace_id": "trace_123",
                "errcode": 0,
                "errmsg": "ok"
            })))
            .mount(&mock_server)
            .await;

        let context = create_test_context(&mock_server.uri());
        let api = SecurityApi::new(context);
        let response = api
            .media_check_async("https://example.com/image.jpg", 2, "openid123", 1)
            .await
            .expect("media_check_async should succeed");
        assert_eq!(response.trace_id, "trace_123");
    }

    #[tokio::test]
    async fn test_get_user_risk_rank_success() {
        use wiremock::matchers::{body_partial_json, method, path, query_param};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;
        setup_token_mock(&mock_server).await;

        // The configured AppID must be injected into the body automatically.
        Mock::given(method("POST"))
            .and(path("/wxa/getuserriskrank"))
            .and(query_param("access_token", "test_token"))
            .and(body_partial_json(serde_json::json!({
                "appid": TEST_APPID,
                "openid": "openid123",
                "scene": 0
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "risk_rank": 1,
                "unoin_id": 99,
                "errcode": 0,
                "errmsg": "ok"
            })))
            .mount(&mock_server)
            .await;

        let context = create_test_context(&mock_server.uri());
        let api = SecurityApi::new(context);
        let response = api
            .get_user_risk_rank("openid123", 0, None)
            .await
            .expect("get_user_risk_rank should succeed");
        assert_eq!(response.risk_rank, 1);
        assert_eq!(response.unoin_id, 99);
    }

    #[tokio::test]
    async fn test_get_user_risk_rank_forwards_options() {
        use wiremock::matchers::{body_partial_json, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;
        setup_token_mock(&mock_server).await;

        Mock::given(method("POST"))
            .and(path("/wxa/getuserriskrank"))
            .and(body_partial_json(serde_json::json!({
                "appid": TEST_APPID,
                "client_ip": "127.0.0.1",
                "is_test": true
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "risk_rank": 0,
                "unoin_id": 7,
                "errcode": 0,
                "errmsg": "ok"
            })))
            .mount(&mock_server)
            .await;

        let context = create_test_context(&mock_server.uri());
        let api = SecurityApi::new(context);
        let options = UserRiskRankOptions {
            client_ip: Some("127.0.0.1".to_string()),
            is_test: Some(true),
            ..Default::default()
        };
        let response = api
            .get_user_risk_rank("openid123", 0, Some(options))
            .await
            .expect("get_user_risk_rank with options should succeed");
        assert_eq!(response.unoin_id, 7);
    }
}