// weixin_agent/messaging/markdown_filter.rs
#![allow(clippy::module_name_repetitions)]
12
/// The kinds of inline construct the filter tracks once an opening marker has been seen.
#[derive(Clone, Debug, Eq, PartialEq)]
enum InlineType {
    /// Opened by `![`.
    Image,
    /// Opened by `***`.
    Bold3,
    /// Opened by a single `*`.
    Italic,
    /// Opened by `___`.
    UBold3,
    /// Opened by a single `_`.
    UItalic,
}
21
22#[derive(Debug, Clone)]
23struct InlineState {
24 typ: InlineType,
25 acc: String,
26}
27
28#[derive(Debug, Clone)]
33pub struct StreamingMarkdownFilter {
34 buf: String,
35 fence: bool,
36 sol: bool,
37 inl: Option<InlineState>,
38}
39
40impl Default for StreamingMarkdownFilter {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46#[allow(
47 clippy::too_many_lines,
48 clippy::range_plus_one,
49 clippy::manual_range_contains,
50 clippy::naive_bytecount,
51 clippy::manual_strip,
52 clippy::let_and_return,
53 clippy::doc_markdown,
54 clippy::items_after_statements,
55 clippy::cast_possible_truncation
56)]
57impl StreamingMarkdownFilter {
58 pub fn new() -> Self {
60 Self {
61 buf: String::new(),
62 fence: false,
63 sol: true,
64 inl: None,
65 }
66 }
67
68 pub fn feed(&mut self, delta: &str) -> String {
70 self.buf.push_str(delta);
71 self.pump(false)
72 }
73
74 pub fn flush(&mut self) -> String {
76 self.pump(true)
77 }
78
79 fn pump(&mut self, eof: bool) -> String {
80 let mut out = String::new();
81 loop {
82 if self.buf.is_empty() {
83 break;
84 }
85 let s_len = self.buf.len();
86 let s_sol = self.sol;
87 let s_fence = self.fence;
88 let s_inl_is_some = self.inl.is_some();
89
90 if self.fence {
91 out.push_str(&self.pump_fence(eof));
92 } else if self.inl.is_some() {
93 out.push_str(&self.pump_inline(eof));
94 } else if self.sol {
95 out.push_str(&self.pump_sol(eof));
96 } else {
97 out.push_str(&self.pump_body(eof));
98 }
99
100 if self.buf.len() == s_len
101 && self.sol == s_sol
102 && self.fence == s_fence
103 && self.inl.is_some() == s_inl_is_some
104 {
105 break;
106 }
107 }
108
109 if eof {
110 if let Some(inl) = self.inl.take() {
111 let marker = match inl.typ {
112 InlineType::Image => "![",
113 InlineType::Bold3 => "***",
114 InlineType::Italic => "*",
115 InlineType::UBold3 => "___",
116 InlineType::UItalic => "_",
117 };
118 out.push_str(marker);
119 out.push_str(&inl.acc);
120 }
121 }
122 out
123 }
124
125 fn pump_fence(&mut self, eof: bool) -> String {
126 if self.sol {
127 if self.buf.len() < 3 && !eof {
128 return String::new();
129 }
130 if self.buf.starts_with("```") {
131 if let Some(nl) = self.buf[3..].find('\n') {
132 let nl = nl + 3;
133 self.fence = false;
134 let line = self.buf[..nl + 1].to_string();
135 self.buf = self.buf[nl + 1..].to_string();
136 self.sol = true;
137 return line;
138 }
139 if eof {
140 self.fence = false;
141 let line = std::mem::take(&mut self.buf);
142 return line;
143 }
144 return String::new();
145 }
146 self.sol = false;
147 }
148 if let Some(nl) = self.buf.find('\n') {
149 let chunk = self.buf[..nl + 1].to_string();
150 self.buf = self.buf[nl + 1..].to_string();
151 self.sol = true;
152 return chunk;
153 }
154 let chunk = std::mem::take(&mut self.buf);
155 chunk
156 }
157
158 fn pump_sol(&mut self, eof: bool) -> String {
159 let b = self.buf.clone();
160 let bytes = b.as_bytes();
161
162 if bytes[0] == b'\n' {
163 self.buf = b[1..].to_string();
164 return "\n".to_string();
165 }
166
167 if bytes[0] == b'`' {
168 if b.len() < 3 && !eof {
169 return String::new();
170 }
171 if b.starts_with("```") {
172 if let Some(nl) = b[3..].find('\n') {
173 let nl = nl + 3;
174 self.fence = true;
175 let line = b[..nl + 1].to_string();
176 self.buf = b[nl + 1..].to_string();
177 self.sol = true;
178 return line;
179 }
180 if eof {
181 self.buf = String::new();
182 return b;
183 }
184 return String::new();
185 }
186 self.sol = false;
187 return String::new();
188 }
189
190 if bytes[0] == b'>' {
191 self.sol = false;
192 return String::new();
193 }
194
195 if bytes[0] == b'#' {
196 let mut n = 0;
197 while n < bytes.len() && bytes[n] == b'#' {
198 n += 1;
199 }
200 if n == b.len() && !eof {
201 return String::new();
202 }
203 if n >= 5 && n <= 6 && n < b.len() && bytes[n] == b' ' {
204 self.buf = b[n + 1..].to_string();
205 self.sol = false;
206 return String::new();
207 }
208 self.sol = false;
209 return String::new();
210 }
211
212 if bytes[0] == b' ' || bytes[0] == b'\t' {
213 let non_ws = b.find(|c: char| c != ' ' && c != '\t');
214 if non_ws.is_none() && !eof {
215 return String::new();
216 }
217 self.sol = false;
218 return String::new();
219 }
220
221 if bytes[0] == b'-' || bytes[0] == b'*' || bytes[0] == b'_' {
222 let ch = bytes[0];
223 let mut j = 0;
224 while j < bytes.len() && (bytes[j] == ch || bytes[j] == b' ') {
225 j += 1;
226 }
227 if j == b.len() && !eof {
228 return String::new();
229 }
230 if j == b.len() || bytes[j] == b'\n' {
231 let count = bytes[..j].iter().filter(|&&x| x == ch).count();
232 if count >= 3 {
233 if j < b.len() {
234 self.buf = b[j + 1..].to_string();
235 self.sol = true;
236 return b[..j + 1].to_string();
237 }
238 self.buf = String::new();
239 return b;
240 }
241 }
242 self.sol = false;
243 return String::new();
244 }
245
246 self.sol = false;
247 String::new()
248 }
249
250 fn pump_body(&mut self, eof: bool) -> String {
251 let mut out = String::new();
252 let chars: Vec<char> = self.buf.chars().collect();
253 let mut i = 0;
254
255 while i < chars.len() {
256 let c = chars[i];
257 if c == '\n' {
258 out.push_str(&chars[..i + 1].iter().collect::<String>());
259 self.buf = chars[i + 1..].iter().collect();
260 self.sol = true;
261 return out;
262 }
263 if c == '!' && i + 1 < chars.len() && chars[i + 1] == '[' {
264 out.push_str(&chars[..i].iter().collect::<String>());
265 self.buf = chars[i + 2..].iter().collect();
266 self.inl = Some(InlineState {
267 typ: InlineType::Image,
268 acc: String::new(),
269 });
270 return out;
271 }
272 if c == '~' {
273 i += 1;
274 continue;
275 }
276 if c == '*' {
277 if i + 2 < chars.len() && chars[i + 1] == '*' && chars[i + 2] == '*' {
278 out.push_str(&chars[..i].iter().collect::<String>());
279 self.buf = chars[i + 3..].iter().collect();
280 self.inl = Some(InlineState {
281 typ: InlineType::Bold3,
282 acc: String::new(),
283 });
284 return out;
285 }
286 if i + 1 < chars.len() && chars[i + 1] == '*' {
287 i += 2;
288 continue;
289 }
290 if i + 1 < chars.len() && chars[i + 1] != ' ' && chars[i + 1] != '\n' {
291 out.push_str(&chars[..i].iter().collect::<String>());
292 self.buf = chars[i + 1..].iter().collect();
293 self.inl = Some(InlineState {
294 typ: InlineType::Italic,
295 acc: String::new(),
296 });
297 return out;
298 }
299 i += 1;
300 continue;
301 }
302 if c == '_' {
303 if i + 2 < chars.len() && chars[i + 1] == '_' && chars[i + 2] == '_' {
304 out.push_str(&chars[..i].iter().collect::<String>());
305 self.buf = chars[i + 3..].iter().collect();
306 self.inl = Some(InlineState {
307 typ: InlineType::UBold3,
308 acc: String::new(),
309 });
310 return out;
311 }
312 if i + 1 < chars.len() && chars[i + 1] == '_' {
313 i += 2;
314 continue;
315 }
316 if i + 1 < chars.len() && chars[i + 1] != ' ' && chars[i + 1] != '\n' {
317 out.push_str(&chars[..i].iter().collect::<String>());
318 self.buf = chars[i + 1..].iter().collect();
319 self.inl = Some(InlineState {
320 typ: InlineType::UItalic,
321 acc: String::new(),
322 });
323 return out;
324 }
325 i += 1;
326 continue;
327 }
328 i += 1;
329 }
330
331 let mut hold = 0;
332 if !eof {
333 let s: String = chars.iter().collect();
334 if s.ends_with("**") || s.ends_with("__") {
335 hold = 2;
336 } else if s.ends_with('*') || s.ends_with('_') || s.ends_with('!') {
337 hold = 1;
338 }
339 }
340 let emit_len = chars.len() - hold;
341 out.push_str(&chars[..emit_len].iter().collect::<String>());
342 self.buf = if hold > 0 {
343 chars[chars.len() - hold..].iter().collect()
344 } else {
345 String::new()
346 };
347 out
348 }
349
350 fn pump_inline(&mut self, _eof: bool) -> String {
351 let Some(inl) = self.inl.as_mut() else {
352 return String::new();
353 };
354 inl.acc.push_str(&self.buf);
355 self.buf = String::new();
356
357 let typ = inl.typ.clone();
358 let acc = inl.acc.clone();
359
360 match typ {
361 InlineType::Bold3 => {
362 if let Some(idx) = acc.find("***") {
363 let content = &acc[..idx];
364 self.buf = acc[idx + 3..].to_string();
365 let result = if Self::contains_cjk(content) {
366 content.to_string()
367 } else {
368 format!("***{content}***")
369 };
370 self.inl = None;
371 return result;
372 }
373 String::new()
374 }
375 InlineType::UBold3 => {
376 if let Some(idx) = acc.find("___") {
377 let content = &acc[..idx];
378 self.buf = acc[idx + 3..].to_string();
379 let result = if Self::contains_cjk(content) {
380 content.to_string()
381 } else {
382 format!("___{content}___")
383 };
384 self.inl = None;
385 return result;
386 }
387 String::new()
388 }
389 InlineType::Italic => {
390 let chars: Vec<char> = acc.chars().collect();
391 for j in 0..chars.len() {
392 if chars[j] == '\n' {
393 let before: String = chars[..j + 1].iter().collect();
394 let after: String = chars[j + 1..].iter().collect();
395 self.buf = after;
396 self.inl = None;
397 self.sol = true;
398 return format!("*{before}");
399 }
400 if chars[j] == '*' {
401 if j + 1 < chars.len() && chars[j + 1] == '*' {
402 continue;
403 }
404 let content: String = chars[..j].iter().collect();
405 self.buf = chars[j + 1..].iter().collect();
406 self.inl = None;
407 return if Self::contains_cjk(&content) {
408 content
409 } else {
410 format!("*{content}*")
411 };
412 }
413 }
414 String::new()
415 }
416 InlineType::UItalic => {
417 let chars: Vec<char> = acc.chars().collect();
418 for j in 0..chars.len() {
419 if chars[j] == '\n' {
420 let before: String = chars[..j + 1].iter().collect();
421 let after: String = chars[j + 1..].iter().collect();
422 self.buf = after;
423 self.inl = None;
424 self.sol = true;
425 return format!("_{before}");
426 }
427 if chars[j] == '_' {
428 if j + 1 < chars.len() && chars[j + 1] == '_' {
429 continue;
430 }
431 let content: String = chars[..j].iter().collect();
432 self.buf = chars[j + 1..].iter().collect();
433 self.inl = None;
434 return if Self::contains_cjk(&content) {
435 content
436 } else {
437 format!("_{content}_")
438 };
439 }
440 }
441 String::new()
442 }
443 InlineType::Image => {
444 if let Some(cb) = acc.find(']') {
445 if cb + 1 >= acc.len() {
446 return String::new();
447 }
448 if acc.as_bytes()[cb + 1] != b'(' {
449 let r = format!("![{}", &acc[..cb + 1]);
450 self.buf = acc[cb + 1..].to_string();
451 self.inl = None;
452 return r;
453 }
454 if let Some(cp) = acc[cb + 2..].find(')') {
455 let cp = cp + cb + 2;
456 self.buf = acc[cp + 1..].to_string();
457 self.inl = None;
458 return String::new();
459 }
460 }
461 String::new()
462 }
463 }
464 }
465
466 fn contains_cjk(text: &str) -> bool {
467 text.chars().any(|c| {
468 ('\u{2E80}'..='\u{9FFF}').contains(&c)
469 || ('\u{AC00}'..='\u{D7AF}').contains(&c)
470 || ('\u{F900}'..='\u{FAFF}').contains(&c)
471 })
472 }
473}
474
475pub fn filter_markdown(text: &str) -> String {
477 let mut f = StreamingMarkdownFilter::new();
478 let mut out = f.feed(text);
479 out.push_str(&f.flush());
480 out
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn plain_text() {
        assert_eq!(filter_markdown("hello world"), "hello world");
    }

    #[test]
    fn code_fence() {
        let input = "```rust\nfn main() {}\n```\n";
        assert_eq!(filter_markdown(input), input);
    }

    #[test]
    fn bold_preserved() {
        assert_eq!(filter_markdown("**bold**"), "**bold**");
    }

    // The image markup is removed but the spaces on either side of it are kept, so
    // two spaces remain between the words.
    #[test]
    fn image_stripping() {
        assert_eq!(
            filter_markdown("before ![alt](https://example.com/img.png) after"),
            "before  after"
        );
    }

    #[test]
    fn image_split_across_deltas() {
        let mut f = StreamingMarkdownFilter::new();
        let mut out = String::new();
        out.push_str(&f.feed("see ![al"));
        out.push_str(&f.feed("t](u"));
        out.push_str(&f.feed("rl) ok"));
        out.push_str(&f.flush());
        assert_eq!(out, "see  ok");
    }

    #[test]
    fn cjk_italic() {
        assert_eq!(filter_markdown("*你好*"), "你好");
    }

    #[test]
    fn non_cjk_italic() {
        assert_eq!(filter_markdown("*hello*"), "*hello*");
    }

    #[test]
    fn unclosed_italic_is_restored_on_flush() {
        assert_eq!(filter_markdown("a *b"), "a *b");
    }

    #[test]
    fn cjk_bold_italic() {
        assert_eq!(filter_markdown("***你好***"), "你好");
    }

    #[test]
    fn non_cjk_bold_italic() {
        assert_eq!(filter_markdown("***hello***"), "***hello***");
    }

    #[test]
    fn underscore_italic_cjk() {
        assert_eq!(filter_markdown("_你好_"), "你好");
    }

    #[test]
    fn underscore_bold_italic_cjk() {
        assert_eq!(filter_markdown("___你好___"), "你好");
    }

    #[test]
    fn non_cjk_underscore_italic() {
        assert_eq!(filter_markdown("_hello_"), "_hello_");
    }

    #[test]
    fn h5_heading() {
        assert_eq!(filter_markdown("##### Title"), "Title");
    }

    #[test]
    fn h6_heading() {
        assert_eq!(filter_markdown("###### Title"), "Title");
    }

    #[test]
    fn shallow_heading_preserved() {
        assert_eq!(filter_markdown("# Title"), "# Title");
    }

    #[test]
    fn table_preserved() {
        let input = "| a | b |\n| - | - |\n| 1 | 2 |\n";
        assert_eq!(filter_markdown(input), input);
    }

    #[test]
    fn horizontal_rule() {
        assert_eq!(filter_markdown("---\n"), "---\n");
        assert_eq!(filter_markdown("***\n"), "***\n");
        assert_eq!(filter_markdown("___\n"), "___\n");
    }

    #[test]
    fn streaming_incremental() {
        let mut f = StreamingMarkdownFilter::new();
        let mut out = String::new();
        out.push_str(&f.feed("hel"));
        out.push_str(&f.feed("lo world"));
        out.push_str(&f.flush());
        assert_eq!(out, "hello world");
    }

    #[test]
    fn blockquote_preservation() {
        let result = filter_markdown("> quote text");
        assert_eq!(result, "> quote text");
    }

    #[test]
    fn indent_preservation() {
        let result = filter_markdown(" indented");
        assert_eq!(result, " indented");
    }
}