// wechat_minapp/minapp_security/msg_sec_check.rs

1use super::{Label, Suggest};
2use crate::{Result, client::Client, constants, error::Error};
3use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use tracing::debug;
7
8/// 内容安全检测场景
9#[derive(Debug, Serialize, Clone, Copy, PartialEq)]
10pub enum Scene {
11    /// 资料
12    Profile = 1,
13    /// 评论
14    Comment = 2,
15    /// 论坛
16    Forum = 3,
17    /// 社交日志
18    SocialLog = 4,
19}
20
21/// 微信内容安全检测请求参数
22#[derive(Debug, Serialize, Clone)]
23pub struct Args {
24    /// 需检测的文本内容,文本字数的上限为2500字,需使用UTF-8编码
25    pub content: String,
26    /// 接口版本号,2.0版本为固定值2
27    pub version: u32,
28    /// 场景枚举值
29    pub scene: Scene,
30    /// 用户的openid(用户需在近两小时访问过小程序)
31    pub openid: String,
32    /// 文本标题,需使用UTF-8编码
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub title: Option<String>,
35    /// 用户昵称,需使用UTF-8编码
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub nickname: Option<String>,
38    /// 个性签名,该参数仅在资料类场景有效(scene=1),需使用UTF-8编码
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub signature: Option<String>,
41}
42
43/// Args 构建器,提供链式调用和验证
44#[derive(Debug, Default)]
45pub struct ArgsBuilder {
46    content: Option<String>,
47    version: Option<u32>,
48    scene: Option<Scene>,
49    openid: Option<String>,
50    title: Option<String>,
51    nickname: Option<String>,
52    signature: Option<String>,
53}
54
55impl ArgsBuilder {
56    /// 创建新的构建器实例
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    /// 设置检测文本内容
62    pub fn content(mut self, content: impl Into<String>) -> Self {
63        self.content = Some(content.into());
64        self
65    }
66
67    /// 设置接口版本号(通常为2)
68    pub fn version(mut self, version: u32) -> Self {
69        self.version = Some(version);
70        self
71    }
72
73    /// 设置场景
74    pub fn scene(mut self, scene: Scene) -> Self {
75        self.scene = Some(scene);
76        self
77    }
78
79    /// 设置用户openid
80    pub fn openid(mut self, openid: impl Into<String>) -> Self {
81        self.openid = Some(openid.into());
82        self
83    }
84
85    /// 设置文本标题
86    pub fn title(mut self, title: impl Into<String>) -> Self {
87        self.title = Some(title.into());
88        self
89    }
90
91    /// 设置用户昵称
92    pub fn nickname(mut self, nickname: impl Into<String>) -> Self {
93        self.nickname = Some(nickname.into());
94        self
95    }
96
97    /// 设置个性签名(仅在资料场景有效)
98    pub fn signature(mut self, signature: impl Into<String>) -> Self {
99        self.signature = Some(signature.into());
100        self
101    }
102
103    /// 构建 Args,验证必填字段
104    pub fn build(self) -> Result<Args> {
105        let content = self
106            .content
107            .ok_or(Error::InvalidParameter("content 是必填参数".to_string()))?;
108        let version = self.version.unwrap_or(2); // 默认版本为2
109        let scene = self
110            .scene
111            .ok_or(Error::InvalidParameter("scene 是必填参数".to_string()))?;
112        let openid = self
113            .openid
114            .ok_or(Error::InvalidParameter("openid 是必填参数".to_string()))?;
115
116        // 内容长度验证
117        if content.len() > 2500 {
118            return Err(Error::InvalidParameter(
119                "content 长度不能超过2500字".to_string(),
120            ));
121        }
122
123        // 场景与签名的关联验证
124        if self.signature.is_some() && scene != Scene::Profile {
125            return Err(Error::InvalidParameter(
126                "signature 仅在资料场景(scene=1)下有效".to_string(),
127            ));
128        }
129
130        Ok(Args {
131            content,
132            version,
133            scene,
134            openid,
135            title: self.title,
136            nickname: self.nickname,
137            signature: self.signature,
138        })
139    }
140}
141
142// 为 Args 实现便捷的构建方法
143impl Args {
144    /// 创建构建器
145    pub fn builder() -> ArgsBuilder {
146        ArgsBuilder::new()
147    }
148
149    /// 快速创建基本参数(使用默认版本2)
150    pub fn new(content: impl Into<String>, scene: Scene, openid: impl Into<String>) -> Self {
151        Self {
152            content: content.into(),
153            version: 2,
154            scene,
155            openid: openid.into(),
156            title: None,
157            nickname: None,
158            signature: None,
159        }
160    }
161
162    /// 检查是否为资料场景
163    pub fn is_profile_scene(&self) -> bool {
164        self.scene == Scene::Profile
165    }
166
167    /// 获取内容长度
168    pub fn content_length(&self) -> usize {
169        self.content.len()
170    }
171
172    /// 验证参数是否有效
173    pub fn validate(&self) -> Result<()> {
174        if self.content.len() > 2500 {
175            return Err(Error::InvalidParameter(
176                "content 长度不能超过2500字".to_string(),
177            ));
178        }
179
180        if self.signature.is_some() && !self.is_profile_scene() {
181            return Err(Error::InvalidParameter(
182                "signature 仅在资料场景(scene=1)下有效".to_string(),
183            ));
184        }
185
186        Ok(())
187    }
188}
189
190// Scene 枚举的便捷方法
191impl Scene {
192    /// 从数值创建场景
193    pub fn from_value(value: u32) -> Option<Self> {
194        match value {
195            1 => Some(Scene::Profile),
196            2 => Some(Scene::Comment),
197            3 => Some(Scene::Forum),
198            4 => Some(Scene::SocialLog),
199            _ => None,
200        }
201    }
202
203    /// 获取场景描述
204    pub fn description(&self) -> &'static str {
205        match self {
206            Scene::Profile => "资料",
207            Scene::Comment => "评论",
208            Scene::Forum => "论坛",
209            Scene::SocialLog => "社交日志",
210        }
211    }
212}
213
214/// 详细检测结果
215#[derive(Debug, Deserialize, Serialize, Clone)]
216pub struct DetailResult {
217    /// 策略类型
218    pub strategy: String,
219    /// 错误码,仅当该值为0时,该项结果有效
220    pub errcode: i32,
221    /// 建议
222    #[serde(skip_serializing_if = "Option::is_none")]
223    pub suggest: Option<Suggest>,
224    /// 命中标签枚举值(可能不存在)
225    #[serde(skip_serializing_if = "Option::is_none")]
226    pub label: Option<Label>,
227    /// 命中的自定义关键词(可能不存在)
228    #[serde(skip_serializing_if = "Option::is_none")]
229    pub keyword: Option<String>,
230    /// 0-100,代表置信度,越高代表越有可能属于当前返回的标签(label)(可能不存在)
231    #[serde(skip_serializing_if = "Option::is_none")]
232    pub prob: Option<f64>,
233}
234
235/// 综合结果
236#[derive(Debug, Deserialize, Serialize, Clone)]
237pub struct ComprehensiveResult {
238    /// 建议
239    pub suggest: Suggest,
240    /// 命中标签枚举值
241    pub label: Label,
242}
243
244/// 内容安全检测返回结果
245#[derive(Debug, Deserialize, Serialize, Clone)]
246pub struct MsgSecCheckResult {
247    /// 错误码
248    pub errcode: i32,
249    /// 错误信息
250    pub errmsg: String,
251    /// 详细检测结果
252    #[serde(skip_serializing_if = "Option::is_none")]
253    pub detail: Option<Vec<DetailResult>>,
254    /// 综合结果
255    #[serde(skip_serializing_if = "Option::is_none")]
256    pub result: Option<ComprehensiveResult>,
257    /// 唯一请求标识,标记单次请求
258    #[serde(skip_serializing_if = "Option::is_none")]
259    pub trace_id: Option<String>,
260}
261
262// 为 MsgSecCheckResult 实现一些便捷方法
263impl MsgSecCheckResult {
264    /// 检查请求是否成功(errcode 为 0)
265    pub fn is_success(&self) -> bool {
266        self.errcode == 0
267    }
268
269    /// 获取综合建议
270    pub fn get_suggest(&self) -> Option<&Suggest> {
271        self.result.as_ref().map(|r| &r.suggest)
272    }
273
274    /// 获取综合标签
275    pub fn get_label(&self) -> Option<&Label> {
276        self.result.as_ref().map(|r| &r.label)
277    }
278
279    /// 检查是否通过
280    pub fn is_pass(&self) -> bool {
281        self.get_suggest().map(|s| s.is_pass()).unwrap_or(false)
282    }
283
284    /// 检查是否有风险
285    pub fn is_risky(&self) -> bool {
286        self.get_suggest().map(|s| s.is_risky()).unwrap_or(false)
287    }
288
289    /// 检查是否需要审核
290    pub fn needs_review(&self) -> bool {
291        self.get_suggest()
292            .map(|s| s.needs_review())
293            .unwrap_or(false)
294    }
295
296    /// 获取有效的详细检测结果(errcode 为 0 的项)
297    pub fn get_valid_details(&self) -> Vec<&DetailResult> {
298        self.detail
299            .as_ref()
300            .map(|details| details.iter().filter(|d| d.errcode == 0).collect())
301            .unwrap_or_default()
302    }
303}
304
305impl Client {
306    /// 内容安全检测
307    ///
308    /// # 示例
309    /// ```ignore
310    /// use wechat_minapp::minapp_security::{Args, Scene};
311    /// use wechant_minapp::Client;
312    ///
313    /// #[tokio::main]
314    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
315    ///     let app_id = "your app id";
316    ///     let secret = "your app secret";
317    ///     
318    ///     let client = Client::new(app_id, secret);
319    ///     let args = Args::builder()
320    ///         .content("需要检测的文本内容")
321    ///         .scene(Scene::Comment)
322    ///         .openid("user_openid")
323    ///         .build()?;
324    ///     
325    ///     let result = client.msg_sec_check(args).await?;
326    ///     
327    ///     if result.is_pass() {
328    ///         println!("内容安全,可以发布");
329    ///     } else if result.needs_review() {
330    ///         println!("内容需要人工审核");
331    ///     } else {
332    ///         println!("内容有风险,建议修改");
333    ///     }
334    ///     
335    ///     Ok(())
336    /// }
337    /// ```
338    pub async fn msg_sec_check(&self, args: &Args) -> Result<MsgSecCheckResult> {
339        debug!("msg_sec_check args: {:?}", &args);
340
341        // 验证参数
342        args.validate()?;
343        let access_token = self.access_token().await?;
344        let mut query = HashMap::new();
345        let mut body = HashMap::new();
346        let version = args.version.to_string();
347        let scene = (args.scene as u32).to_string();
348        // URL 参数:access_token
349        query.insert("access_token", &access_token);
350
351        // Body 参数
352        body.insert("content", &args.content);
353        body.insert("version", &version);
354        body.insert("scene", &scene);
355        body.insert("openid", &args.openid);
356
357        if let Some(title) = &args.title {
358            body.insert("title", title);
359        }
360
361        if let Some(nickname) = &args.nickname {
362            body.insert("nickname", nickname);
363        }
364
365        if let Some(signature) = &args.signature {
366            body.insert("signature", signature);
367        }
368
369        let mut headers = HeaderMap::new();
370        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
371
372        let response = self
373            .request()
374            .post(constants::MSG_SEC_CHECK_END_POINT)
375            .headers(headers)
376            .query(&query)
377            .json(&body)
378            .send()
379            .await?;
380
381        debug!("msg_sec_check response: {:#?}", response);
382
383        if response.status().is_success() {
384            let response_text = response.text().await?;
385            debug!("msg_sec_check response body: {}", response_text);
386
387            let result: MsgSecCheckResult = serde_json::from_str(&response_text)?;
388
389            if result.is_success() {
390                Ok(result)
391            } else {
392                // 微信API返回错误
393                Err(Error::InternalServer(format!(
394                    "微信内容安全检测API错误: {} - {}",
395                    result.errcode, result.errmsg
396                )))
397            }
398        } else {
399            // HTTP 请求错误
400            Err(Error::InternalServer(response.text().await?))
401        }
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_args_builder() {
        let args = Args::builder()
            .content("测试内容")
            .scene(Scene::Comment)
            .openid("test_openid")
            .build()
            .unwrap();

        assert_eq!(args.content, "测试内容");
        assert_eq!(args.version, 2);
        assert_eq!(args.scene, Scene::Comment);
        assert_eq!(args.openid, "test_openid");
    }

    #[test]
    fn test_args_builder_validation() {
        // Missing required parameter.
        let result = Args::builder()
            .scene(Scene::Comment)
            .openid("test_openid")
            .build();
        assert!(result.is_err());

        // Content over the length limit.
        let long_content = "a".repeat(2501);
        let result = Args::builder()
            .content(long_content)
            .scene(Scene::Comment)
            .openid("openid")
            .build();
        assert!(result.is_err());

        // Signature outside the profile scene.
        let result = Args::builder()
            .content("内容")
            .scene(Scene::Comment)
            .openid("openid")
            .signature("签名")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_args_builder_accepts_boundaries() {
        // Exactly at the limit is allowed.
        let result = Args::builder()
            .content("a".repeat(2500))
            .scene(Scene::Comment)
            .openid("openid")
            .build();
        assert!(result.is_ok());

        // A signature is allowed in the profile scene.
        let args = Args::builder()
            .content("内容")
            .scene(Scene::Profile)
            .openid("openid")
            .signature("签名")
            .build()
            .unwrap();
        assert!(args.is_profile_scene());
        assert_eq!(args.signature.as_deref(), Some("签名"));

        // An explicit version overrides the default.
        let args = Args::builder()
            .content("内容")
            .version(3)
            .scene(Scene::Forum)
            .openid("openid")
            .build()
            .unwrap();
        assert_eq!(args.version, 3);
    }

    #[test]
    fn test_args_new_and_validate() {
        let mut args = Args::new("内容", Scene::Forum, "openid");
        assert_eq!(args.version, 2);
        assert!(args.title.is_none());
        assert!(args.nickname.is_none());
        assert!(args.signature.is_none());
        assert!(args.validate().is_ok());

        args.signature = Some("签名".to_string());
        assert!(args.validate().is_err());

        args.scene = Scene::Profile;
        assert!(args.validate().is_ok());
    }

    #[test]
    fn test_scene_enum() {
        assert_eq!(Scene::from_value(1), Some(Scene::Profile));
        assert_eq!(Scene::from_value(4), Some(Scene::SocialLog));
        assert_eq!(Scene::from_value(0), None);
        assert_eq!(Scene::from_value(5), None);
        assert_eq!(Scene::Profile.description(), "资料");
        assert_eq!(Scene::Profile as u32, 1);
    }

    #[test]
    fn test_msg_sec_check_result() {
        let json = r#"
        {
            "errcode": 0,
            "errmsg": "ok",
            "detail": [
                {
                    "strategy": "content_model",
                    "errcode": 0,
                    "suggest": "pass",
                    "label": 100,
                    "prob": 90.5
                }
            ],
            "result": {
                "suggest": "pass",
                "label": 100
            },
            "trace_id": "test_trace_id"
        }"#;

        let result: MsgSecCheckResult = serde_json::from_str(json).unwrap();

        assert!(result.is_success());
        assert!(result.is_pass());
        assert!(!result.is_risky());
        assert!(!result.needs_review());
        assert_eq!(result.get_valid_details().len(), 1);
        assert_eq!(result.trace_id, Some("test_trace_id".to_string()));
    }

    #[test]
    fn test_msg_sec_check_result_with_risk() {
        let json = r#"
        {
            "errcode": 0,
            "errmsg": "ok",
            "detail": [
                {
                    "strategy": "content_model",
                    "errcode": 0,
                    "suggest": "risky",
                    "label": 20001,
                    "keyword": "敏感词",
                    "prob": 95.0
                }
            ],
            "result": {
                "suggest": "risky",
                "label": 20001
            }
        }"#;

        let result: MsgSecCheckResult = serde_json::from_str(json).unwrap();

        assert!(result.is_success());
        assert!(!result.is_pass());
        assert!(result.is_risky());
        assert!(!result.needs_review());
        assert_eq!(
            result.get_valid_details()[0].keyword,
            Some("敏感词".to_string())
        );
    }

    #[test]
    fn test_msg_sec_check_result_without_optional_parts() {
        let json = r#"{"errcode": 0, "errmsg": "ok"}"#;

        let result: MsgSecCheckResult = serde_json::from_str(json).unwrap();

        assert!(result.is_success());
        assert!(result.get_suggest().is_none());
        assert!(!result.is_pass());
        assert!(!result.is_risky());
        assert!(!result.needs_review());
        assert!(result.get_valid_details().is_empty());
        assert!(result.trace_id.is_none());
    }

    #[test]
    fn test_get_valid_details_skips_failed_entries() {
        let json = r#"
        {
            "errcode": 0,
            "errmsg": "ok",
            "detail": [
                {
                    "strategy": "content_model",
                    "errcode": 0,
                    "suggest": "pass",
                    "label": 100
                },
                {
                    "strategy": "keyword",
                    "errcode": -1
                }
            ],
            "result": {
                "suggest": "pass",
                "label": 100
            }
        }"#;

        let result: MsgSecCheckResult = serde_json::from_str(json).unwrap();
        let valid = result.get_valid_details();

        assert_eq!(valid.len(), 1);
        assert_eq!(valid[0].strategy, "content_model");
    }
}