Skip to main content

weixin_agent/messaging/
markdown_filter.rs

//! Streaming markdown filter — character-level state machine that strips
//! unsupported markdown syntax on-the-fly for `WeChat` output.
//!
//! Preserves: code fences, inline code, tables, horizontal rules, bold,
//! blockquotes, leading indent, italic/bold-italic wrapping non-CJK content.
//!
//! Strips markers but keeps content: italic/bold-italic wrapping CJK, H5/H6 headings.
//!
//! Removes entirely: images `![alt](url)`.
10
11#![allow(clippy::module_name_repetitions)]
12
/// Kind of inline construct whose content is currently being buffered.
#[derive(Clone, Debug, Eq, PartialEq)]
enum InlineType {
    /// Image syntax, opened by `![`.
    Image,
    /// Bold-italic span opened by `***`.
    Bold3,
    /// Italic span opened by a single `*`.
    Italic,
    /// Bold-italic span opened by `___`.
    UBold3,
    /// Italic span opened by a single `_`.
    UItalic,
}
21
22#[derive(Debug, Clone)]
23struct InlineState {
24    typ: InlineType,
25    acc: String,
26}
27
28/// A character-level streaming markdown filter that processes text incrementally.
29///
30/// Outputs as much filtered text as possible on each `feed()` call, only
31/// holding back the minimum characters needed for pattern disambiguation.
32#[derive(Debug, Clone)]
33pub struct StreamingMarkdownFilter {
34    buf: String,
35    fence: bool,
36    sol: bool,
37    inl: Option<InlineState>,
38}
39
40impl Default for StreamingMarkdownFilter {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46#[allow(
47    clippy::too_many_lines,
48    clippy::range_plus_one,
49    clippy::manual_range_contains,
50    clippy::naive_bytecount,
51    clippy::manual_strip,
52    clippy::let_and_return,
53    clippy::doc_markdown,
54    clippy::items_after_statements,
55    clippy::cast_possible_truncation
56)]
57impl StreamingMarkdownFilter {
58    /// Create a new filter in its initial state.
59    pub fn new() -> Self {
60        Self {
61            buf: String::new(),
62            fence: false,
63            sol: true,
64            inl: None,
65        }
66    }
67
68    /// Feed a chunk of text into the filter, returning any output that can be emitted.
69    pub fn feed(&mut self, delta: &str) -> String {
70        self.buf.push_str(delta);
71        self.pump(false)
72    }
73
74    /// Flush remaining buffered content, emitting everything.
75    pub fn flush(&mut self) -> String {
76        self.pump(true)
77    }
78
79    fn pump(&mut self, eof: bool) -> String {
80        let mut out = String::new();
81        loop {
82            if self.buf.is_empty() {
83                break;
84            }
85            let s_len = self.buf.len();
86            let s_sol = self.sol;
87            let s_fence = self.fence;
88            let s_inl_is_some = self.inl.is_some();
89
90            if self.fence {
91                out.push_str(&self.pump_fence(eof));
92            } else if self.inl.is_some() {
93                out.push_str(&self.pump_inline(eof));
94            } else if self.sol {
95                out.push_str(&self.pump_sol(eof));
96            } else {
97                out.push_str(&self.pump_body(eof));
98            }
99
100            if self.buf.len() == s_len
101                && self.sol == s_sol
102                && self.fence == s_fence
103                && self.inl.is_some() == s_inl_is_some
104            {
105                break;
106            }
107        }
108
109        if eof {
110            if let Some(inl) = self.inl.take() {
111                let marker = match inl.typ {
112                    InlineType::Image => "![",
113                    InlineType::Bold3 => "***",
114                    InlineType::Italic => "*",
115                    InlineType::UBold3 => "___",
116                    InlineType::UItalic => "_",
117                };
118                out.push_str(marker);
119                out.push_str(&inl.acc);
120            }
121        }
122        out
123    }
124
125    fn pump_fence(&mut self, eof: bool) -> String {
126        if self.sol {
127            if self.buf.len() < 3 && !eof {
128                return String::new();
129            }
130            if self.buf.starts_with("```") {
131                if let Some(nl) = self.buf[3..].find('\n') {
132                    let nl = nl + 3;
133                    self.fence = false;
134                    let line = self.buf[..nl + 1].to_string();
135                    self.buf = self.buf[nl + 1..].to_string();
136                    self.sol = true;
137                    return line;
138                }
139                if eof {
140                    self.fence = false;
141                    let line = std::mem::take(&mut self.buf);
142                    return line;
143                }
144                return String::new();
145            }
146            self.sol = false;
147        }
148        if let Some(nl) = self.buf.find('\n') {
149            let chunk = self.buf[..nl + 1].to_string();
150            self.buf = self.buf[nl + 1..].to_string();
151            self.sol = true;
152            return chunk;
153        }
154        let chunk = std::mem::take(&mut self.buf);
155        chunk
156    }
157
158    fn pump_sol(&mut self, eof: bool) -> String {
159        let b = self.buf.clone();
160        let bytes = b.as_bytes();
161
162        if bytes[0] == b'\n' {
163            self.buf = b[1..].to_string();
164            return "\n".to_string();
165        }
166
167        if bytes[0] == b'`' {
168            if b.len() < 3 && !eof {
169                return String::new();
170            }
171            if b.starts_with("```") {
172                if let Some(nl) = b[3..].find('\n') {
173                    let nl = nl + 3;
174                    self.fence = true;
175                    let line = b[..nl + 1].to_string();
176                    self.buf = b[nl + 1..].to_string();
177                    self.sol = true;
178                    return line;
179                }
180                if eof {
181                    self.buf = String::new();
182                    return b;
183                }
184                return String::new();
185            }
186            self.sol = false;
187            return String::new();
188        }
189
190        if bytes[0] == b'>' {
191            self.sol = false;
192            return String::new();
193        }
194
195        if bytes[0] == b'#' {
196            let mut n = 0;
197            while n < bytes.len() && bytes[n] == b'#' {
198                n += 1;
199            }
200            if n == b.len() && !eof {
201                return String::new();
202            }
203            if n >= 5 && n <= 6 && n < b.len() && bytes[n] == b' ' {
204                self.buf = b[n + 1..].to_string();
205                self.sol = false;
206                return String::new();
207            }
208            self.sol = false;
209            return String::new();
210        }
211
212        if bytes[0] == b' ' || bytes[0] == b'\t' {
213            let non_ws = b.find(|c: char| c != ' ' && c != '\t');
214            if non_ws.is_none() && !eof {
215                return String::new();
216            }
217            self.sol = false;
218            return String::new();
219        }
220
221        if bytes[0] == b'-' || bytes[0] == b'*' || bytes[0] == b'_' {
222            let ch = bytes[0];
223            let mut j = 0;
224            while j < bytes.len() && (bytes[j] == ch || bytes[j] == b' ') {
225                j += 1;
226            }
227            if j == b.len() && !eof {
228                return String::new();
229            }
230            if j == b.len() || bytes[j] == b'\n' {
231                let count = bytes[..j].iter().filter(|&&x| x == ch).count();
232                if count >= 3 {
233                    if j < b.len() {
234                        self.buf = b[j + 1..].to_string();
235                        self.sol = true;
236                        return b[..j + 1].to_string();
237                    }
238                    self.buf = String::new();
239                    return b;
240                }
241            }
242            self.sol = false;
243            return String::new();
244        }
245
246        self.sol = false;
247        String::new()
248    }
249
250    fn pump_body(&mut self, eof: bool) -> String {
251        let mut out = String::new();
252        let chars: Vec<char> = self.buf.chars().collect();
253        let mut i = 0;
254
255        while i < chars.len() {
256            let c = chars[i];
257            if c == '\n' {
258                out.push_str(&chars[..i + 1].iter().collect::<String>());
259                self.buf = chars[i + 1..].iter().collect();
260                self.sol = true;
261                return out;
262            }
263            if c == '!' && i + 1 < chars.len() && chars[i + 1] == '[' {
264                out.push_str(&chars[..i].iter().collect::<String>());
265                self.buf = chars[i + 2..].iter().collect();
266                self.inl = Some(InlineState {
267                    typ: InlineType::Image,
268                    acc: String::new(),
269                });
270                return out;
271            }
272            if c == '~' {
273                i += 1;
274                continue;
275            }
276            if c == '*' {
277                if i + 2 < chars.len() && chars[i + 1] == '*' && chars[i + 2] == '*' {
278                    out.push_str(&chars[..i].iter().collect::<String>());
279                    self.buf = chars[i + 3..].iter().collect();
280                    self.inl = Some(InlineState {
281                        typ: InlineType::Bold3,
282                        acc: String::new(),
283                    });
284                    return out;
285                }
286                if i + 1 < chars.len() && chars[i + 1] == '*' {
287                    i += 2;
288                    continue;
289                }
290                if i + 1 < chars.len() && chars[i + 1] != ' ' && chars[i + 1] != '\n' {
291                    out.push_str(&chars[..i].iter().collect::<String>());
292                    self.buf = chars[i + 1..].iter().collect();
293                    self.inl = Some(InlineState {
294                        typ: InlineType::Italic,
295                        acc: String::new(),
296                    });
297                    return out;
298                }
299                i += 1;
300                continue;
301            }
302            if c == '_' {
303                if i + 2 < chars.len() && chars[i + 1] == '_' && chars[i + 2] == '_' {
304                    out.push_str(&chars[..i].iter().collect::<String>());
305                    self.buf = chars[i + 3..].iter().collect();
306                    self.inl = Some(InlineState {
307                        typ: InlineType::UBold3,
308                        acc: String::new(),
309                    });
310                    return out;
311                }
312                if i + 1 < chars.len() && chars[i + 1] == '_' {
313                    i += 2;
314                    continue;
315                }
316                if i + 1 < chars.len() && chars[i + 1] != ' ' && chars[i + 1] != '\n' {
317                    out.push_str(&chars[..i].iter().collect::<String>());
318                    self.buf = chars[i + 1..].iter().collect();
319                    self.inl = Some(InlineState {
320                        typ: InlineType::UItalic,
321                        acc: String::new(),
322                    });
323                    return out;
324                }
325                i += 1;
326                continue;
327            }
328            i += 1;
329        }
330
331        let mut hold = 0;
332        if !eof {
333            let s: String = chars.iter().collect();
334            if s.ends_with("**") || s.ends_with("__") {
335                hold = 2;
336            } else if s.ends_with('*') || s.ends_with('_') || s.ends_with('!') {
337                hold = 1;
338            }
339        }
340        let emit_len = chars.len() - hold;
341        out.push_str(&chars[..emit_len].iter().collect::<String>());
342        self.buf = if hold > 0 {
343            chars[chars.len() - hold..].iter().collect()
344        } else {
345            String::new()
346        };
347        out
348    }
349
350    fn pump_inline(&mut self, _eof: bool) -> String {
351        let Some(inl) = self.inl.as_mut() else {
352            return String::new();
353        };
354        inl.acc.push_str(&self.buf);
355        self.buf = String::new();
356
357        let typ = inl.typ.clone();
358        let acc = inl.acc.clone();
359
360        match typ {
361            InlineType::Bold3 => {
362                if let Some(idx) = acc.find("***") {
363                    let content = &acc[..idx];
364                    self.buf = acc[idx + 3..].to_string();
365                    let result = if Self::contains_cjk(content) {
366                        content.to_string()
367                    } else {
368                        format!("***{content}***")
369                    };
370                    self.inl = None;
371                    return result;
372                }
373                String::new()
374            }
375            InlineType::UBold3 => {
376                if let Some(idx) = acc.find("___") {
377                    let content = &acc[..idx];
378                    self.buf = acc[idx + 3..].to_string();
379                    let result = if Self::contains_cjk(content) {
380                        content.to_string()
381                    } else {
382                        format!("___{content}___")
383                    };
384                    self.inl = None;
385                    return result;
386                }
387                String::new()
388            }
389            InlineType::Italic => {
390                let chars: Vec<char> = acc.chars().collect();
391                for j in 0..chars.len() {
392                    if chars[j] == '\n' {
393                        let before: String = chars[..j + 1].iter().collect();
394                        let after: String = chars[j + 1..].iter().collect();
395                        self.buf = after;
396                        self.inl = None;
397                        self.sol = true;
398                        return format!("*{before}");
399                    }
400                    if chars[j] == '*' {
401                        if j + 1 < chars.len() && chars[j + 1] == '*' {
402                            continue;
403                        }
404                        let content: String = chars[..j].iter().collect();
405                        self.buf = chars[j + 1..].iter().collect();
406                        self.inl = None;
407                        return if Self::contains_cjk(&content) {
408                            content
409                        } else {
410                            format!("*{content}*")
411                        };
412                    }
413                }
414                String::new()
415            }
416            InlineType::UItalic => {
417                let chars: Vec<char> = acc.chars().collect();
418                for j in 0..chars.len() {
419                    if chars[j] == '\n' {
420                        let before: String = chars[..j + 1].iter().collect();
421                        let after: String = chars[j + 1..].iter().collect();
422                        self.buf = after;
423                        self.inl = None;
424                        self.sol = true;
425                        return format!("_{before}");
426                    }
427                    if chars[j] == '_' {
428                        if j + 1 < chars.len() && chars[j + 1] == '_' {
429                            continue;
430                        }
431                        let content: String = chars[..j].iter().collect();
432                        self.buf = chars[j + 1..].iter().collect();
433                        self.inl = None;
434                        return if Self::contains_cjk(&content) {
435                            content
436                        } else {
437                            format!("_{content}_")
438                        };
439                    }
440                }
441                String::new()
442            }
443            InlineType::Image => {
444                if let Some(cb) = acc.find(']') {
445                    if cb + 1 >= acc.len() {
446                        return String::new();
447                    }
448                    if acc.as_bytes()[cb + 1] != b'(' {
449                        let r = format!("![{}", &acc[..cb + 1]);
450                        self.buf = acc[cb + 1..].to_string();
451                        self.inl = None;
452                        return r;
453                    }
454                    if let Some(cp) = acc[cb + 2..].find(')') {
455                        let cp = cp + cb + 2;
456                        self.buf = acc[cp + 1..].to_string();
457                        self.inl = None;
458                        return String::new();
459                    }
460                }
461                String::new()
462            }
463        }
464    }
465
466    fn contains_cjk(text: &str) -> bool {
467        text.chars().any(|c| {
468            ('\u{2E80}'..='\u{9FFF}').contains(&c)
469                || ('\u{AC00}'..='\u{D7AF}').contains(&c)
470                || ('\u{F900}'..='\u{FAFF}').contains(&c)
471        })
472    }
473}
474
475/// Filter markdown from a complete text string, stripping unsupported syntax for `WeChat`.
476pub fn filter_markdown(text: &str) -> String {
477    let mut f = StreamingMarkdownFilter::new();
478    let mut out = f.feed(text);
479    out.push_str(&f.flush());
480    out
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that filtering `input` in one shot yields `expected`.
    fn check(input: &str, expected: &str) {
        assert_eq!(filter_markdown(input), expected);
    }

    /// Assert that `input` passes through the filter untouched.
    fn check_unchanged(input: &str) {
        check(input, input);
    }

    #[test]
    fn plain_text() {
        check_unchanged("hello world");
    }

    #[test]
    fn code_fence() {
        check_unchanged("```rust\nfn main() {}\n```\n");
    }

    #[test]
    fn bold_preserved() {
        check_unchanged("**bold**");
    }

    #[test]
    fn image_stripping() {
        check("before ![alt](http://img.png) after", "before  after");
    }

    #[test]
    fn cjk_italic() {
        check("*你好*", "你好");
    }

    #[test]
    fn non_cjk_italic() {
        check_unchanged("*hello*");
    }

    #[test]
    fn cjk_bold_italic() {
        check("***你好***", "你好");
    }

    #[test]
    fn non_cjk_bold_italic() {
        check_unchanged("***hello***");
    }

    #[test]
    fn underscore_italic_cjk() {
        check("_你好_", "你好");
    }

    #[test]
    fn underscore_bold_italic_cjk() {
        check("___你好___", "你好");
    }

    #[test]
    fn non_cjk_underscore_italic() {
        check_unchanged("_hello_");
    }

    #[test]
    fn h5_heading() {
        check("##### Title", "Title");
    }

    #[test]
    fn h6_heading() {
        check("###### Title", "Title");
    }

    #[test]
    fn table_preserved() {
        check_unchanged("| a | b |\n| - | - |\n| 1 | 2 |\n");
    }

    #[test]
    fn horizontal_rule() {
        for rule in ["---\n", "***\n", "___\n"] {
            check_unchanged(rule);
        }
    }

    #[test]
    fn streaming_incremental() {
        let mut filter = StreamingMarkdownFilter::new();
        let pieces = ["hel", "lo world"];
        let mut collected: String = pieces.iter().map(|piece| filter.feed(piece)).collect();
        collected.push_str(&filter.flush());
        assert_eq!(collected, "hello world");
    }

    #[test]
    fn blockquote_preservation() {
        // The `>` marker only switches the line into body mode, so the whole
        // line comes out unchanged.
        check_unchanged("> quote text");
    }

    #[test]
    fn indent_preservation() {
        // Leading whitespace is not consumed at start-of-line; it is emitted
        // together with the rest of the line.
        check_unchanged("    indented");
    }
}