Skip to main content

wechat_mp_sdk/api/
security.rs

1//! Security API
2//!
3//! Endpoints for content security checks including text, media, and user risk assessment.
4//!
5//! # Endpoints
6//!
7//! - [`SecurityApi::msg_sec_check`] - Check text content for policy violations
8//! - [`SecurityApi::media_check_async`] - Async check media for policy violations
9//! - [`SecurityApi::get_user_risk_rank`] - Get user risk rank score
10
11use std::sync::Arc;
12
13use serde::{Deserialize, Serialize};
14
15use super::{WechatApi, WechatContext};
16use crate::error::WechatError;
17
18// ============================================================================
19// Request Types (internal)
20// ============================================================================
21
22#[derive(Debug, Clone, Serialize)]
23struct MsgSecCheckRequest {
24    version: u8,
25    openid: String,
26    scene: u8,
27    content: String,
28}
29
30#[derive(Debug, Clone, Serialize)]
31struct MediaCheckAsyncRequest {
32    media_url: String,
33    media_type: u8,
34    version: u8,
35    openid: String,
36    scene: u8,
37}
38
39#[derive(Debug, Clone, Serialize)]
40struct UserRiskRankRequest {
41    appid: String,
42    openid: String,
43    scene: u8,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    client_ip: Option<String>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    mobile_no: Option<String>,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    email_address: Option<String>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    extended_info: Option<String>,
52    #[serde(skip_serializing_if = "Option::is_none")]
53    is_test: Option<bool>,
54}
55
56// ============================================================================
57// Public Response Types
58// ============================================================================
59
60/// Detail item from message security check
61#[non_exhaustive]
62#[derive(Debug, Clone, Default, Deserialize, Serialize)]
63pub struct MsgSecCheckDetail {
64    /// Strategy used
65    #[serde(default)]
66    pub strategy: String,
67    /// Error code for this detail
68    #[serde(default)]
69    pub errcode: i32,
70    /// Suggestion: "pass", "risky", or "review"
71    #[serde(default)]
72    pub suggest: String,
73    /// Label classification (100=normal, 10001=ad, etc.)
74    #[serde(default)]
75    pub label: i32,
76    /// Matched keyword (if any)
77    #[serde(default)]
78    pub keyword: String,
79    /// Confidence probability (0-100)
80    #[serde(default)]
81    pub prob: i32,
82}
83
84/// Result summary from message security check
85#[non_exhaustive]
86#[derive(Debug, Clone, Default, Deserialize, Serialize)]
87pub struct MsgSecCheckResult {
88    /// Suggestion: "pass", "risky", or "review"
89    #[serde(default)]
90    pub suggest: String,
91    /// Label classification
92    #[serde(default)]
93    pub label: i32,
94}
95
96/// Response from msgSecCheck
97#[non_exhaustive]
98#[derive(Debug, Clone, Deserialize, Serialize)]
99pub struct MsgSecCheckResponse {
100    /// Overall result
101    #[serde(default)]
102    pub result: MsgSecCheckResult,
103    /// Detailed results per strategy
104    #[serde(default)]
105    pub detail: Vec<MsgSecCheckDetail>,
106    /// Error code (0 means success)
107    #[serde(default)]
108    pub(crate) errcode: i32,
109    /// Error message
110    #[serde(default)]
111    pub(crate) errmsg: String,
112}
113
114/// Response from mediaCheckAsync
115#[non_exhaustive]
116#[derive(Debug, Clone, Deserialize, Serialize)]
117pub struct MediaCheckAsyncResponse {
118    /// Trace ID for querying result
119    #[serde(default)]
120    pub trace_id: String,
121    /// Error code (0 means success)
122    #[serde(default)]
123    pub(crate) errcode: i32,
124    /// Error message
125    #[serde(default)]
126    pub(crate) errmsg: String,
127}
128
129/// Response from getUserRiskRank
130#[non_exhaustive]
131#[derive(Debug, Clone, Deserialize, Serialize)]
132pub struct UserRiskRankResponse {
133    /// Risk rank: 0-4 (0=no risk, 4=highest risk)
134    #[serde(default)]
135    pub risk_rank: i32,
136    /// Union ID (note: WeChat API field is "unoin_id", not "union_id")
137    #[serde(default)]
138    pub unoin_id: i32,
139    /// Error code (0 means success)
140    #[serde(default)]
141    pub(crate) errcode: i32,
142    /// Error message
143    #[serde(default)]
144    pub(crate) errmsg: String,
145}
146
/// Optional extras for `SecurityApi::get_user_risk_rank`.
///
/// All fields default to `None`; a field left as `None` is not included in
/// the request sent to WeChat.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct UserRiskRankOptions {
    /// IP address of the client
    pub client_ip: Option<String>,
    /// Mobile phone number
    pub mobile_no: Option<String>,
    /// Email address
    pub email_address: Option<String>,
    /// Free-form extended info string
    pub extended_info: Option<String>,
    /// Marks the request as a test request
    pub is_test: Option<bool>,
}
162
163// ============================================================================
164// SecurityApi
165// ============================================================================
166
167/// Security API
168///
169/// Provides methods for content security checks including text, media,
170/// and user risk assessment.
171pub struct SecurityApi {
172    context: Arc<WechatContext>,
173}
174
175impl SecurityApi {
176    /// Create a new SecurityApi instance
177    pub fn new(context: Arc<WechatContext>) -> Self {
178        Self { context }
179    }
180
181    /// Check text content for policy violations
182    ///
183    /// POST /wxa/msg_sec_check?access_token=ACCESS_TOKEN
184    ///
185    /// # Arguments
186    /// * `openid` - User's OpenID
187    /// * `scene` - Scene value (1=profile, 2=comment, 3=forum, 4=social log)
188    /// * `content` - Text content to check
189    pub async fn msg_sec_check(
190        &self,
191        openid: &str,
192        scene: u8,
193        content: &str,
194    ) -> Result<MsgSecCheckResponse, WechatError> {
195        let body = MsgSecCheckRequest {
196            version: 2,
197            openid: openid.to_string(),
198            scene,
199            content: content.to_string(),
200        };
201        let response: MsgSecCheckResponse = self
202            .context
203            .authed_post("/wxa/msg_sec_check", &body)
204            .await?;
205        WechatError::check_api(response.errcode, &response.errmsg)?;
206        Ok(response)
207    }
208
209    /// Async check media (image/audio) for policy violations
210    ///
211    /// POST /wxa/media_check_async?access_token=ACCESS_TOKEN
212    ///
213    /// # Arguments
214    /// * `media_url` - URL of the media to check
215    /// * `media_type` - Media type (1=audio, 2=image)
216    /// * `openid` - User's OpenID
217    /// * `scene` - Scene value
218    pub async fn media_check_async(
219        &self,
220        media_url: &str,
221        media_type: u8,
222        openid: &str,
223        scene: u8,
224    ) -> Result<MediaCheckAsyncResponse, WechatError> {
225        let body = MediaCheckAsyncRequest {
226            media_url: media_url.to_string(),
227            media_type,
228            version: 2,
229            openid: openid.to_string(),
230            scene,
231        };
232        let response: MediaCheckAsyncResponse = self
233            .context
234            .authed_post("/wxa/media_check_async", &body)
235            .await?;
236        WechatError::check_api(response.errcode, &response.errmsg)?;
237        Ok(response)
238    }
239
240    /// Get user risk rank score
241    ///
242    /// POST /wxa/getuserriskrank?access_token=ACCESS_TOKEN
243    ///
244    /// # Arguments
245    /// * `openid` - User's OpenID
246    /// * `scene` - Scene value (0=registration, 1=marketing)
247    /// * `options` - Additional optional parameters
248    pub async fn get_user_risk_rank(
249        &self,
250        openid: &str,
251        scene: u8,
252        options: Option<UserRiskRankOptions>,
253    ) -> Result<UserRiskRankResponse, WechatError> {
254        let opts = options.unwrap_or_default();
255        let body = UserRiskRankRequest {
256            appid: self.context.client.appid().to_string(),
257            openid: openid.to_string(),
258            scene,
259            client_ip: opts.client_ip,
260            mobile_no: opts.mobile_no,
261            email_address: opts.email_address,
262            extended_info: opts.extended_info,
263            is_test: opts.is_test,
264        };
265        let response: UserRiskRankResponse = self
266            .context
267            .authed_post("/wxa/getuserriskrank", &body)
268            .await?;
269        WechatError::check_api(response.errcode, &response.errmsg)?;
270        Ok(response)
271    }
272}
273
274impl WechatApi for SecurityApi {
275    fn context(&self) -> &WechatContext {
276        &self.context
277    }
278
279    fn api_name(&self) -> &'static str {
280        "security"
281    }
282}
283
284// ============================================================================
285// Tests
286// ============================================================================
287
#[cfg(test)]
mod tests {
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;
    use crate::client::WechatClient;
    use crate::token::TokenManager;
    use crate::types::{AppId, AppSecret};

    // ---- Shared fixtures ----

    /// Build a context whose client sends its requests to `base_url`.
    fn create_test_context(base_url: &str) -> Arc<WechatContext> {
        let client = WechatClient::builder()
            .appid(AppId::new("wx1234567890abcdef").unwrap())
            .secret(AppSecret::new("secret1234567890ab").unwrap())
            .base_url(base_url)
            .build()
            .unwrap();
        let client = Arc::new(client);
        let token_manager = Arc::new(TokenManager::new((*client).clone()));
        Arc::new(WechatContext::new(client, token_manager))
    }

    /// Mount a `GET /cgi-bin/token` mock that hands out `test_token`.
    async fn setup_token_mock(mock_server: &MockServer) {
        let token_body = serde_json::json!({
            "access_token": "test_token",
            "expires_in": 7200,
            "errcode": 0,
            "errmsg": ""
        });

        Mock::given(method("GET"))
            .and(path("/cgi-bin/token"))
            .respond_with(ok_json(token_body))
            .mount(mock_server)
            .await;
    }

    /// Start a mock server that already serves the token endpoint.
    async fn start_server() -> MockServer {
        let server = MockServer::start().await;
        setup_token_mock(&server).await;
        server
    }

    /// Build a `SecurityApi` pointed at `server`.
    fn api_for(server: &MockServer) -> SecurityApi {
        SecurityApi::new(create_test_context(&server.uri()))
    }

    /// HTTP 200 response carrying `body` as JSON.
    fn ok_json(body: serde_json::Value) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(body)
    }

    // ---- Deserialization Tests ----

    #[test]
    fn test_msg_sec_check_response_parse() {
        let raw = r#"{
            "result": {
                "suggest": "pass",
                "label": 100
            },
            "detail": [
                {
                    "strategy": "content_model",
                    "errcode": 0,
                    "suggest": "pass",
                    "label": 100,
                    "keyword": "",
                    "prob": 90
                }
            ],
            "errcode": 0,
            "errmsg": "ok"
        }"#;

        let parsed: MsgSecCheckResponse = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.result.suggest, "pass");
        assert_eq!(parsed.result.label, 100);
        assert_eq!(parsed.detail.len(), 1);

        let first = &parsed.detail[0];
        assert_eq!(first.strategy, "content_model");
        assert_eq!(first.prob, 90);
        assert_eq!(parsed.errcode, 0);
    }

    #[test]
    fn test_msg_sec_check_response_defaults() {
        let parsed: MsgSecCheckResponse =
            serde_json::from_str(r#"{"errcode": 0, "errmsg": "ok"}"#).unwrap();
        assert!(parsed.detail.is_empty());
        assert!(parsed.result.suggest.is_empty());
    }

    #[test]
    fn test_media_check_async_response_parse() {
        let raw = r#"{
            "trace_id": "trace_abc123",
            "errcode": 0,
            "errmsg": "ok"
        }"#;

        let parsed: MediaCheckAsyncResponse = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.trace_id, "trace_abc123");
        assert_eq!(parsed.errcode, 0);
    }

    #[test]
    fn test_user_risk_rank_response_parse() {
        let raw = r#"{
            "risk_rank": 2,
            "unoin_id": 12345,
            "errcode": 0,
            "errmsg": "ok"
        }"#;

        let parsed: UserRiskRankResponse = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.risk_rank, 2);
        assert_eq!(parsed.unoin_id, 12345);
        assert_eq!(parsed.errcode, 0);
    }

    #[test]
    fn test_user_risk_rank_response_defaults() {
        let parsed: UserRiskRankResponse =
            serde_json::from_str(r#"{"errcode": 0, "errmsg": "ok"}"#).unwrap();
        assert_eq!(parsed.risk_rank, 0);
        assert_eq!(parsed.unoin_id, 0);
    }

    #[test]
    fn test_api_name() {
        let api = SecurityApi::new(create_test_context("http://localhost:0"));
        assert_eq!(api.api_name(), "security");
    }

    // ---- Wiremock Integration Tests ----

    #[tokio::test]
    async fn test_msg_sec_check_success() {
        let server = start_server().await;

        Mock::given(method("POST"))
            .and(path("/wxa/msg_sec_check"))
            .and(query_param("access_token", "test_token"))
            .respond_with(ok_json(serde_json::json!({
                "result": {"suggest": "pass", "label": 100},
                "detail": [{"strategy": "content_model", "errcode": 0, "suggest": "pass", "label": 100, "keyword": "", "prob": 90}],
                "errcode": 0,
                "errmsg": "ok"
            })))
            .mount(&server)
            .await;

        let checked = api_for(&server)
            .msg_sec_check("openid123", 1, "hello world")
            .await
            .expect("msg_sec_check should succeed");
        assert_eq!(checked.result.suggest, "pass");
    }

    #[tokio::test]
    async fn test_msg_sec_check_api_error() {
        let server = start_server().await;

        // No access_token matcher here: any POST to the path gets the error body.
        Mock::given(method("POST"))
            .and(path("/wxa/msg_sec_check"))
            .respond_with(ok_json(serde_json::json!({
                "errcode": 87014,
                "errmsg": "risky content"
            })))
            .mount(&server)
            .await;

        let outcome = api_for(&server)
            .msg_sec_check("openid123", 1, "bad content")
            .await;
        match outcome {
            Err(WechatError::Api { code, message }) => {
                assert_eq!(code, 87014);
                assert_eq!(message, "risky content");
            }
            _ => panic!("Expected WechatError::Api"),
        }
    }

    #[tokio::test]
    async fn test_media_check_async_success() {
        let server = start_server().await;

        Mock::given(method("POST"))
            .and(path("/wxa/media_check_async"))
            .and(query_param("access_token", "test_token"))
            .respond_with(ok_json(serde_json::json!({
                "trace_id": "trace_123",
                "errcode": 0,
                "errmsg": "ok"
            })))
            .mount(&server)
            .await;

        let submitted = api_for(&server)
            .media_check_async("https://example.com/image.jpg", 2, "openid123", 1)
            .await
            .expect("media_check_async should succeed");
        assert_eq!(submitted.trace_id, "trace_123");
    }

    #[tokio::test]
    async fn test_get_user_risk_rank_success() {
        let server = start_server().await;

        Mock::given(method("POST"))
            .and(path("/wxa/getuserriskrank"))
            .and(query_param("access_token", "test_token"))
            .respond_with(ok_json(serde_json::json!({
                "risk_rank": 1,
                "unoin_id": 99,
                "errcode": 0,
                "errmsg": "ok"
            })))
            .mount(&server)
            .await;

        let ranked = api_for(&server)
            .get_user_risk_rank("openid123", 0, None)
            .await
            .expect("get_user_risk_rank should succeed");
        assert_eq!(ranked.risk_rank, 1);
        assert_eq!(ranked.unoin_id, 99);
    }
}