1use crate::banned::BANNED;
2use crate::buffer_proxy_iterator::BufferProxyIterator;
3use crate::mtch::*;
4use crate::replacements::REPLACEMENTS;
5use crate::trie::*;
6use crate::Set;
7use crate::{is_whitespace, Replacements, Type};
8use std::iter::Filter;
9use std::mem;
10use std::ops::Deref;
11use std::ops::RangeInclusive;
12use std::str::Chars;
13use unicode_normalization::{Decompositions, Recompositions, UnicodeNormalization};
14
/// Streaming censor and analyzer over an iterator of `char`s.
///
/// Input is normalized (NFD decomposition, a character filter, then NFC recomposition) and
/// wrapped in a `BufferProxyIterator`, which `Match::commit` receives together with the censor
/// replacement character when a detection is committed. `Censor` itself implements
/// `Iterator<Item = char>` and yields the (possibly censored) text.
pub struct Censor<I: Iterator<Item = char>> {
    /// Normalized, filtered input, buffered so committed matches can still be censored before
    /// the characters are released to the caller.
    buffer: BufferProxyIterator<Recompositions<Filter<Decompositions<I>, fn(&char) -> bool>>>,
    /// User-configurable settings (trie, replacements, thresholds, replacement character).
    options: Options,
    /// Scalar per-run state: counters, flags and the accumulated `Type`.
    inline: InlineState,
    /// Per-run state that owns collections; cleared in place by `reset`.
    allocated: AllocatedState,
}
27
/// Configuration of a [`Censor`], set through its `with_*` builder methods.
struct Options {
    /// Word trie used to find matches.
    trie: &'static Trie,
    /// Per-character replacement table (e.g. look-alike characters).
    replacements: &'static Replacements,
    /// When `true`, exact benign words neither cancel pending detections nor mark input safe.
    ignore_false_positives: bool,
    /// When `true`, a high ratio of self-censoring characters is not reported as profanity.
    ignore_self_censoring: bool,
    /// Forwarded to `Match::commit`; presumably controls when the first character of a match is
    /// also censored (TODO confirm against `Match::commit`).
    censor_first_character_threshold: Type,
    /// Character written over censored text; also recognized as a self-censoring marker.
    censor_replacement: char,
    /// Forwarded to `Match::commit` as the censoring threshold.
    censor_threshold: Type,
}
39
40impl Default for Options {
41 fn default() -> Self {
42 Self {
43 trie: &*TRIE,
44 replacements: &*REPLACEMENTS,
45 ignore_false_positives: false,
47 ignore_self_censoring: false,
48 censor_first_character_threshold: Type::OFFENSIVE & Type::SEVERE,
49 censor_replacement: '*',
51 censor_threshold: Default::default(),
52 }
53 }
54}
55
/// Scalar state for one censoring run; replaced wholesale by `Censor::reset`.
struct InlineState {
    /// Whether the previous character was skippable (non-alphabetic or whitespace). Starts
    /// `true` so the first character may begin a word.
    separate: bool,
    /// Buffer index of the most recently processed character; `usize::MAX` until one is seen.
    last_pos: usize,
    /// Union of the types of all committed detections.
    typ: Type,
    /// Saturating count of uppercase characters.
    uppercase: u8,
    /// Saturating count of characters equal to their predecessor.
    repetitions: u8,
    /// The previously processed character.
    last: Option<char>,
    /// Saturating count of adjacent home-row-key pairs (keyboard mashing).
    gibberish: u8,
    /// Saturating count of input characters that needed a non-benign replacement.
    replacements: u8,
    /// Saturating count of replacement/block characters glued to words or to each other.
    self_censoring: u8,
    /// `true` while the input is an exact safe word starting at index 0, optionally followed by
    /// `!`, `.` or `?`.
    safe: bool,
    /// XOR of the addresses of committed trie nodes (diagnostics only).
    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
    match_ptrs: usize,
    /// Number of committed detections (diagnostics only).
    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
    total_matches: usize,
    /// Sum of `end - start` over committed detections (diagnostics only).
    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
    total_match_characters: usize,
    /// Whether the synthetic trailing space has already been fed through the matcher.
    space_appended: bool,
    /// Set once `next` has returned `None`.
    done: bool,
}
84
85impl Default for InlineState {
86 fn default() -> Self {
87 Self {
88 separate: true,
90 typ: Type::NONE,
92 uppercase: 0,
93 repetitions: 0,
94 last: None,
95 gibberish: 0,
96 replacements: 0,
97 self_censoring: 0,
98 safe: false,
99 space_appended: false,
100 done: false,
101 last_pos: usize::MAX,
102 #[cfg(any(feature = "find_false_positives", feature = "trace"))]
103 match_ptrs: 0,
104 #[cfg(any(feature = "find_false_positives", feature = "trace"))]
105 total_matches: 0,
106 #[cfg(any(feature = "find_false_positives", feature = "trace"))]
107 total_match_characters: 0,
108 }
109 }
110}
111
/// Per-run state made of collections. `Censor::reset` clears it in place (whereas
/// `InlineState` is simply replaced).
#[derive(Default)]
struct AllocatedState {
    /// Matches in flight after the current character.
    matches: Set<Match>,
    /// Scratch set: holds the previous generation of matches while `matches` is rebuilt.
    matches_tmp: Set<Match>,
    /// Completed inappropriate matches waiting to be committed or cancelled.
    pending_commit: Vec<Match>,
    /// Number of committed detections per trie node trace (only with `trace_full`).
    #[cfg(feature = "trace_full")]
    detections: crate::Map<String, usize>,
}
123
124impl AllocatedState {
125 fn clear(&mut self) {
126 let Self {
127 matches,
128 matches_tmp,
129 pending_commit,
130 #[cfg(feature = "trace_full")]
131 detections,
132 } = self;
133 matches.clear();
134 matches_tmp.clear();
135 pending_commit.clear();
136 #[cfg(feature = "trace_full")]
137 detections.clear();
138 }
139}
140
141impl<'a> Censor<Chars<'a>> {
142 pub fn from_str(s: &'a str) -> Self {
144 Self::new(s.chars())
145 }
146}
147
148impl<I: Iterator<Item = char>> Censor<I> {
149 pub fn new(text: I) -> Self {
151 Self {
152 buffer: Self::buffer_from(text),
153 options: Default::default(),
154 inline: Default::default(),
155 allocated: Default::default(),
156 }
157 }
158
159 fn buffer_from(
160 text: I,
161 ) -> BufferProxyIterator<Recompositions<Filter<Decompositions<I>, fn(&char) -> bool>>> {
162 fn filter_char(c: &char) -> bool {
165 use finl_unicode::categories::{CharacterCategories, MinorCategory};
166 let category = c.get_minor_category();
167 let preserve_japanese = matches!(*c, '\u{3099}' | '\u{309A}');
169 let nok = matches!(
170 category,
171 MinorCategory::Cn | MinorCategory::Co | MinorCategory::Mn
172 ) && !preserve_japanese;
173
174 !(nok || BANNED.deref().deref().contains(*c))
175 }
176
177 BufferProxyIterator::new(
178 text
179 .nfd()
181 .filter(filter_char as fn(&char) -> bool)
182 .nfc(),
183 )
184 }
185
186 pub fn reset(&mut self, text: I) {
189 self.inline = Default::default();
190 self.allocated.clear();
191 self.buffer = Self::buffer_from(text);
192 }
193
194 pub fn with_trie(&mut self, trie: &'static Trie) -> &mut Self {
196 self.options.trie = trie;
197 self
198 }
199
200 pub fn with_replacements(&mut self, replacements: &'static Replacements) -> &mut Self {
202 self.options.replacements = replacements;
203 self
204 }
205
206 pub fn with_censor_threshold(&mut self, censor_threshold: Type) -> &mut Self {
213 self.options.censor_threshold = censor_threshold;
214 self
215 }
216
217 pub fn with_ignore_false_positives(&mut self, ignore_false_positives: bool) -> &mut Self {
222 self.options.ignore_false_positives = ignore_false_positives;
223 self
224 }
225
226 pub fn with_ignore_self_censoring(&mut self, ignore_self_censoring: bool) -> &mut Self {
236 self.options.ignore_self_censoring = ignore_self_censoring;
237 self
238 }
239
240 pub fn with_censor_first_character_threshold(
245 &mut self,
246 censor_first_character_threshold: Type,
247 ) -> &mut Self {
248 self.options.censor_first_character_threshold = censor_first_character_threshold;
249 self
250 }
251
252 pub fn with_censor_replacement(&mut self, censor_replacement: char) -> &mut Self {
265 self.options.censor_replacement = censor_replacement;
266 self
267 }
268
269 #[cfg(feature = "find_false_positives")]
271 pub fn with_separate(&mut self, separate: bool) -> &mut Self {
272 self.inline.separate = separate;
273 self
274 }
275
276 pub fn censor(&mut self) -> String {
288 assert!(
289 !self.buffer.index().is_some(),
290 "censor must be called before any other form of processing"
291 );
292 self.collect()
293 }
294
295 pub fn analyze(&mut self) -> Type {
299 self.ensure_done();
300 self.analysis()
301 }
302
303 pub fn censor_and_analyze(&mut self) -> (String, Type) {
305 let censored = self.censor();
307 (censored, self.analysis())
309 }
310
311 fn analysis(&self) -> Type {
313 self.inline.typ | self.safe_self_censoring_and_spam_detection()
314 }
315
316 #[cfg(any(feature = "find_false_positives", feature = "trace"))]
317 pub fn match_ptrs(&self) -> usize {
318 self.inline.match_ptrs
319 }
320
321 #[cfg(any(feature = "find_false_positives", feature = "trace"))]
322 pub fn total_matches(&self) -> usize {
323 self.inline.total_matches
324 }
325
326 #[cfg(any(feature = "find_false_positives", feature = "trace"))]
327 pub fn total_match_characters(&self) -> usize {
328 self.inline.total_match_characters
329 }
330
331 #[cfg(feature = "trace_full")]
332 pub fn detections(&self) -> &crate::Map<String, usize> {
333 &self.allocated.detections
334 }
335
336 fn ensure_done(&mut self) {
337 if !self.inline.done {
338 for _ in self {}
339 }
340 }
341
342 fn safe_self_censoring_and_spam_detection(&self) -> Type {
343 let safe = if self.inline.safe && self.inline.repetitions < 4 {
344 Type::SAFE
345 } else {
346 Type::NONE
347 };
348
349 if self.inline.last_pos < 6 {
350 return safe;
352 }
353
354 let total = self
357 .inline
358 .last_pos
359 .saturating_add(6)
360 .min(u16::MAX as usize) as u16;
361
362 let spam = self
364 .inline
365 .uppercase
366 .max(self.inline.repetitions)
367 .max(self.inline.gibberish / 2)
368 .max(self.inline.replacements) as u16;
369
370 let percent_spam = 100 * spam / total;
372 let percent_self_censoring = 100 * self.inline.self_censoring as u16 / total;
373
374 let spam = if percent_spam >= 70 && self.inline.last_pos >= 20 {
376 Type::SPAM & Type::SEVERE
377 } else if percent_spam >= 50 && self.inline.last_pos >= 10 {
378 Type::SPAM & Type::MODERATE
379 } else if percent_spam >= 30 {
380 Type::SPAM & Type::MILD
381 } else {
382 Type::NONE
383 };
384
385 let self_censoring = if !self.options.ignore_self_censoring && percent_self_censoring > 20 {
387 Type::PROFANE & Type::MILD
388 } else {
389 Type::NONE
390 };
391
392 safe | spam | self_censoring
393 }
394}
395
396impl<I: Iterator<Item = char>> Iterator for Censor<I> {
397 type Item = char;
398
399 fn next(&mut self) -> Option<Self::Item> {
401 while let Some(raw_c) = self.buffer.next().or_else(|| {
402 if self.inline.space_appended {
403 None
404 } else {
405 self.inline.space_appended = true;
406 Some(' ')
407 }
408 }) {
409 if !self.inline.space_appended && raw_c != '!' && raw_c != '.' && raw_c != '?' {
410 self.inline.safe = false;
412 }
413
414 let pos = self.buffer.index();
415
416 self.inline.uppercase = self
417 .inline
418 .uppercase
419 .saturating_add(raw_c.is_uppercase() as u8);
420
421 let skippable = !raw_c.is_alphabetic() || is_whitespace(raw_c);
422 let replacement = self.options.replacements.get(raw_c);
423
424 #[cfg(feature = "trace")]
425 println!(
426 "Read '{}', skippable={}, replacing with={:?}",
427 raw_c, skippable, replacement
428 );
429
430 const BLOCK_ELEMENTS: RangeInclusive<char> = '\u{2580}'..='\u{259F}';
431
432 if (!self.inline.separate || self.inline.last == Some(self.options.censor_replacement))
433 && (raw_c == self.options.censor_replacement || BLOCK_ELEMENTS.contains(&raw_c))
434 {
435 self.inline.self_censoring = self.inline.self_censoring.saturating_add(1);
437 }
438
439 if let Some(last) = self.inline.last {
440 if raw_c == last {
441 self.inline.repetitions = self.inline.repetitions.saturating_add(1);
442 }
443
444 fn is_gibberish(c: char) -> bool {
446 matches!(c, 'a' | 's' | 'd' | 'f' | 'j' | 'k' | 'l' | ';')
447 }
448
449 if is_gibberish(raw_c) && is_gibberish(last) {
451 self.inline.gibberish = self.inline.gibberish.saturating_add(1);
452 }
453 }
454
455 if let Some(pos) = pos {
456 if !(skippable
461 && replacement.is_none()
462 && !self.options.trie.root.children.contains_key(&raw_c))
463 {
464 let begin_camel_case_word = raw_c.is_ascii_uppercase()
465 && self
466 .inline
467 .last
468 .map(|c| !c.is_ascii_uppercase())
469 .unwrap_or(false);
470
471 self.allocated.matches.insert(Match {
473 node: &self.options.trie.root,
474 start: pos, end: usize::MAX, last: 0 as char, begin_separate: self.inline.separate || begin_camel_case_word,
478 end_separate: false, spaces: 0,
480 skipped: 0,
481 replacements: 0,
482 repetitions: 0,
483 low_confidence_replacements: 0,
484 });
485 }
486 }
487
488 self.inline.separate = skippable;
489
490 if self.inline.separate {
491 for pending in self.allocated.pending_commit.iter_mut() {
492 if pending.end == self.inline.last_pos {
493 pending.end_separate = true;
494 }
495 }
496 }
497
498 let mut drain_start: Option<usize> = None;
499 let mut safety_end = usize::MAX;
500 let mut replacement_counted = false;
501 let raw_c_lower = raw_c.to_lowercase().next().unwrap();
502
503 mem::swap(&mut self.allocated.matches, &mut self.allocated.matches_tmp);
504 for c in replacement
505 .map(|a| a.as_str())
506 .unwrap_or(&&*raw_c.encode_utf8(&mut [0; 4]))
507 .chars()
508 {
509 let benign_replacement = c == raw_c || c == raw_c_lower;
511
512 let countable_replacement = !(replacement_counted
514 || benign_replacement
515 || raw_c.is_ascii_alphabetic()
516 || (raw_c.is_ascii_digit()
517 && self
518 .inline
519 .last
520 .map(|l| l.is_ascii_digit())
521 .unwrap_or(false)));
522
523 if countable_replacement {
524 self.inline.replacements = self.inline.replacements.saturating_add(1);
525 replacement_counted = true;
526 }
527
528 #[cfg(feature = "trace")]
529 println!(
530 " - Replacement '{}', benign={}, countable={}",
531 c, benign_replacement, countable_replacement
532 );
533
534 let ignore_sep = matches!(c, '-' | '\'' | '\n' | '\r');
543
544 for m in self.allocated.matches_tmp.iter() {
545 let m = m.clone();
546
547 if m.low_confidence_replacements > 5
548 || m.skipped > 5
549 || (m.node.word && m.repetitions > 20)
550 {
551 #[cfg(feature = "trace")]
552 println!("throwing out low confidence match: \"{}\"", m.node.trace);
553 }
555
556 safety_end = safety_end.min(m.start);
557
558 #[cfg(feature = "trace")]
559 println!(
560 " - Consider match \"{}\" with spaces={}, replacements={}",
561 m.node.trace, m.spaces, m.replacements
562 );
563
564 if (skippable || c == m.last || Some(c) == m.node.last)
565 && m.start != pos.unwrap_or(0)
566 {
567 let new_space =
572 matches!(c, ' ' | '.' | ',' | ':' | ';' | '…' | '(' | ')' | '_' | '-')
573 && m.node.last != Some(' ');
574 let new_repetition: bool = !new_space && c == m.last;
575 let new_skip = !new_space && skippable && !ignore_sep && !new_repetition;
576 let new_replacement = c == m.last && raw_c != c && !new_repetition;
578 let new_low_confidence_replacement =
579 new_replacement && raw_c.is_ascii_digit();
580
581 let undo_m = Match {
582 spaces: m.spaces.saturating_add(new_space as u8),
583 skipped: m.skipped.saturating_add(new_skip as u8),
584 replacements: m.replacements.saturating_add(new_replacement as u8),
585 low_confidence_replacements: m
586 .low_confidence_replacements
587 .saturating_add(new_low_confidence_replacement as u8),
588 repetitions: m.repetitions.saturating_add(new_repetition as u8),
589 last: c,
590 ..m
591 };
592 #[cfg(feature = "trace")]
593 println!(" (keep with last={}, node last={:?}, spaces={}, skip={}, repl={}, repet={})", undo_m.last, undo_m.node.last, undo_m.spaces, undo_m.skipped, undo_m.replacements, undo_m.repetitions);
594
595 if let Some(existing) = self.allocated.matches.get(&undo_m) {
596 let replacement = existing.combine(&undo_m);
597 self.allocated.matches.replace(replacement);
598 } else {
599 self.allocated.matches.insert(undo_m);
600 }
601 }
602
603 if let Some(next) = m.node.children.get(&c) {
604 let new_replacement = !benign_replacement && (c != raw_c) && c != ' ';
605 let new_low_confidence_replacement =
606 new_replacement && raw_c.is_ascii_digit();
607 let new_space =
608 !new_replacement && (raw_c != c && self.inline.separate && c != '\'');
609
610 let next_m = Match {
611 node: next,
612 spaces: m.spaces.saturating_add(new_space as u8),
613 replacements: m.replacements.saturating_add(new_replacement as u8),
614 low_confidence_replacements: m
615 .low_confidence_replacements
616 .saturating_add(new_low_confidence_replacement as u8),
617 last: c,
618 ..m
619 };
620
621 #[cfg(feature = "trace")]
622 println!(
623 " - Next is \"{}\", with spaces={}, replacements={}",
624 next.trace, next_m.spaces, next_m.replacements
625 );
626
627 if next.word {
628 if next_m.node.typ.is(Type::SAFE)
629 && next_m.start == 0
630 && next_m.spaces == 0
631 && next_m.skipped == 0
632 && next_m.replacements == 0
633 && !self.options.ignore_false_positives
634 {
635 #[cfg(feature = "trace")]
637 println!("found safe word: {}", next_m.node.trace);
638 self.inline.safe = true;
639 }
640
641 if next_m.node.typ.is(Type::ANY) {
657 self.allocated.pending_commit.push(Match {
658 end: pos.unwrap(),
659 ..next_m
660 });
661 } else if next_m.spaces == 0
662 && next_m.skipped == 0
663 && next_m.replacements == 0
664 && next_m.repetitions == 0 && !self.options.ignore_false_positives
666 {
667 #[cfg(feature = "trace")]
669 println!("Found false positive {}", next_m.node.trace);
670 drain_start = Some(
671 drain_start
672 .map(|start| start.min(next_m.start))
673 .unwrap_or(next_m.start),
674 );
675 }
676 }
677
678 if let Some(existing) = self.allocated.matches.get(&next_m) {
679 let replacement = existing.combine(&next_m);
680 self.allocated.matches.replace(replacement);
681 } else {
682 self.allocated.matches.insert(next_m);
683 }
684 }
685 }
686 }
687 self.allocated.matches_tmp.clear();
688 self.inline.last = Some(raw_c);
689 if let Some(pos) = pos {
690 self.inline.last_pos = pos;
691 }
692
693 let spy = &mut self.buffer;
694 let options = &self.options;
695 let inline = &mut self.inline;
696 let pending_commit = &mut self.allocated.pending_commit;
697 #[cfg(feature = "trace_full")]
698 let detections = &mut self.allocated.detections;
699
700 pending_commit.retain(|pending| {
701 #[cfg(feature = "trace")]
702 println!("Consider whether to cancel pending commit {} with start={} against drain_start={:?}", pending.node.trace, pending.start, drain_start);
703
704 if let Some(start) = drain_start {
706 if pending.start >= start {
707 #[cfg(feature = "trace")]
708 println!("Cancelled {}", pending.node.trace);
709 return false;
710 }
711 }
712
713 if pending.end < safety_end {
715 if pending.commit(
716 &mut inline.typ,
717 spy,
718 options.censor_threshold,
719 options.censor_first_character_threshold,
720 options.censor_replacement,
721 ) {
722 #[cfg(any(feature = "find_false_positives", feature = "trace"))]
723 {
724 inline.match_ptrs ^= pending.node as *const _ as usize;
725 inline.total_matches += 1;
726 inline.total_match_characters += pending.end - pending.start;
727 #[cfg(feature = "trace_full")]
728 {
729 *detections.entry(pending.node.trace.clone()).or_default() += 1;
730 }
731 }
732 }
733 return false;
734 }
735
736 true
739 });
740
741 if let Some(spy_next_index) = self.buffer.spy_next_index() {
743 let mut safe_until = spy_next_index < safety_end;
745
746 for pending in &self.allocated.pending_commit {
748 if pending.start <= spy_next_index {
749 safe_until = false;
750 break;
751 }
752 }
753 if safe_until {
754 return self.buffer.spy_next();
755 }
756 }
757 }
758
759 let residual = mem::take(&mut self.allocated.pending_commit);
760 #[cfg(feature = "trace")]
761 if !residual.is_empty() {
762 println!("{} residuals", residual.len());
763 }
764 for pending in residual {
765 if pending.commit(
766 &mut self.inline.typ,
767 &mut self.buffer,
768 self.options.censor_threshold,
769 self.options.censor_first_character_threshold,
770 self.options.censor_replacement,
771 ) {
772 #[cfg(any(feature = "find_false_positives", feature = "trace"))]
773 {
774 self.inline.match_ptrs ^= pending.node as *const _ as usize;
775 self.inline.total_matches += 1;
776 self.inline.total_match_characters += pending.end - pending.start;
777 #[cfg(feature = "trace_full")]
778 {
779 *self
780 .allocated
781 .detections
782 .entry(pending.node.trace.clone())
783 .or_default() += 1;
784 }
785 }
786 }
787 }
788
789 if let Some(c) = self.buffer.spy_next() {
790 return Some(c);
791 }
792
793 self.inline.done = true;
794
795 None
796 }
797}
798
/// Convenience censoring and analysis methods (implemented for `&str` in this file).
pub trait CensorStr: Sized {
    /// Returns a censored copy of the text.
    fn censor(self) -> String;

    /// Shorthand for `self.is(Type::INAPPROPRIATE)`.
    fn is_inappropriate(self) -> bool {
        self.is(Type::INAPPROPRIATE)
    }

    /// Analyzes the text and reports whether the result meets `threshold` (via `Type::is`).
    fn is(self, threshold: Type) -> bool;

    /// Negation of [`CensorStr::is`].
    fn isnt(self, threshold: Type) -> bool {
        !self.is(threshold)
    }
}
817
818impl CensorStr for &str {
819 fn censor(self) -> String {
820 if should_skip_censor(self) {
821 self.to_owned()
822 } else {
823 Censor::new(self.chars()).censor()
824 }
825 }
826
827 fn is(self, threshold: Type) -> bool {
828 Censor::from_str(self).analyze().is(threshold)
829 }
830}
831
/// Censoring adapter for `char` iterators.
pub trait CensorIter {
    /// The censoring iterator returned by [`CensorIter::censor`].
    type Iterator: Iterator<Item = char>;

    /// Wraps `self` in an iterator that yields the censored characters.
    fn censor(self) -> Self::Iterator;
}
840
841impl<I: Iterator<Item = char> + Clone> CensorIter for I {
842 type Iterator = Censor<I>;
843
844 fn censor(self) -> Self::Iterator {
847 Censor::new(self)
848 }
849}
850
851pub(crate) fn should_skip_censor(string: &str) -> bool {
854 let mut some_special = false;
855 for c in string.chars() {
856 use finl_unicode::categories::CharacterCategories;
857 if ('\u{0900}'..='\u{097F}').contains(&c) {
859 some_special = true;
860 } else if !(c.is_whitespace() || c.is_separator()) {
861 return false;
862 }
863 }
864 some_special
865}
866
#[cfg(feature = "customize")]
/// Sets the type of `word` in the default (global) trie.
///
/// Passing `Type::NONE` makes the word no longer detected (as exercised by the `customize`
/// test).
///
/// # Safety
///
/// This forwards to `Trie::customize_default()`. The caller must uphold whatever contract that
/// function requires (presumably that no other thread is using the default trie concurrently;
/// TODO confirm against `Trie::customize_default`).
#[deprecated = "Use the equivalent Trie::customize_default().set(word, typ) or the safe API Censor::with_trie"]
pub unsafe fn add_word(word: &str, typ: Type) {
    Trie::customize_default().set(word, typ)
}
890
#[cfg(test)]
mod tests {
    #![allow(unused_imports)]

    extern crate test;
    use crate::censor::should_skip_censor;
    use crate::{Censor, CensorIter, CensorStr, Trie, Type};
    use bitflags::_core::ops::Not;
    use rand::prelude::ThreadRng;
    use rand::{thread_rng, Rng};
    use serial_test::serial;
    use std::fs::File;
    use std::io::BufReader;
    use std::time::{Duration, Instant};
    use test::Bencher;

    #[test]
    #[serial]
    fn short_replacement() {
        // Short inputs made of replaceable characters must not be misclassified; the results
        // must be asserted or the test can never fail.
        assert!("99".isnt(Type::PROFANE));
        assert!("900".isnt(Type::PROFANE));
        assert!("kkk".is(Type::OFFENSIVE));
    }

    #[test]
    #[serial]
    fn unicode_whitespace() {
        assert!("fu\u{1160}ck".is(Type::PROFANE));
        assert!(!"fu\u{1161}ck".is(Type::PROFANE));
    }

    #[test]
    #[serial]
    fn unicode_abuse() {
        let mut rng = thread_rng();

        fn random_string(rng: &mut ThreadRng, len: usize) -> String {
            rng.sample_iter::<char, _>(rand::distributions::Standard)
                .take(len)
                .collect()
        }

        for _ in 0..10 {
            let input = random_string(&mut rng, 100);
            let censored = input.censor();

            assert!(censored.len() < input.len() / 2);

            println!("{} -> {}", input, censored);
        }
    }

    /// Prints `text`. If it is flagged as anything other than spam, also underlines an
    /// approximate substring that still triggers a detection.
    #[allow(dead_code)]
    fn find_detection(text: &str) {
        let holistic = Censor::from_str(text).analyze();

        if holistic & Type::SPAM.not() != Type::NONE {
            println!("{}", text);

            let mut start = 0;
            let mut end = text.chars().count();

            while start < end
                && Censor::new(text.chars().skip(start).take(end - start))
                    .analyze()
                    .is(Type::ANY)
            {
                start += 1;
            }
            start = start.saturating_sub(1);
            while start < end
                && Censor::new(text.chars().skip(start).take(end - start))
                    .analyze()
                    .is(Type::ANY)
            {
                end -= 1;
            }
            end += 1;
            for _ in 0..start {
                print!("-");
            }
            for _ in start..end {
                print!("^");
            }
            print!(" ");
            println!(
                "(\"{}\" is {:?})",
                text.chars()
                    .skip(start)
                    .take(end - start)
                    .collect::<String>(),
                holistic
            );
        } else {
            println!("{} ({:?})", text, holistic);
        }
    }

    #[test]
    #[serial]
    fn curated() {
        let mut cases: Vec<(&str, bool, Option<bool>)> = vec![("", false, Some(false))];
        cases.extend(
            include_str!("test_positive.txt")
                .split('\n')
                .filter(|l| !l.is_empty())
                .map(|l| (l, true, Some(false))),
        );
        cases.extend(
            include_str!("test_negative.txt")
                .split('\n')
                .filter(|l| !l.is_empty())
                .map(|l| (l, false, None)),
        );
        cases.extend(
            include_str!("safe.txt")
                .split('\n')
                .filter(|l| !l.is_empty() && !l.starts_with('#'))
                .map(|l| (l, false, Some(true))),
        );
        cases.extend(
            include_str!("test_safe.txt")
                .split('\n')
                .filter(|l| !l.is_empty())
                .map(|l| (l, false, Some(true))),
        );

        let mut failures = Vec::new();

        for (case, any_truth, safe_truth) in cases {
            let typ = Censor::from_str(case).analyze();
            let any = typ.is(Type::ANY);
            let safe = typ.is(Type::SAFE);

            if any != any_truth {
                find_detection(case);
                failures.push(format!("FAIL: Predicted {:?} for: \"{}\"", typ, case));
            } else if !any_truth {
                // Text judged clean must also come out of `censor` unchanged.
                let censored = case.censor();
                if case != censored {
                    failures.push(format!("Censored: \"{case}\" -> {censored}"))
                }
            }
            if let Some(safe_truth) = safe_truth {
                if safe != safe_truth {
                    failures.push(format!("FAIL: Predicted safe={} for: \"{}\"", safe, case));
                }
            }
        }

        if !failures.is_empty() {
            for failure in failures {
                println!("{failure}");
            }
            panic!();
        }
    }

    #[test]
    #[serial]
    fn censor() {
        let censored = Censor::from_str("HELLO fučk Shit nudes WORLD!")
            .with_censor_replacement('#')
            .with_censor_first_character_threshold(Type::SEXUAL & Type::SEVERE)
            .censor();

        assert_eq!(censored, "HELLO f### S### ##### WORLD!");

        assert_eq!("fcking coward".censor(), "f***** coward");

        let censored = Censor::from_str("卍")
            .with_censor_first_character_threshold(Type::NONE)
            .censor();

        assert_eq!(censored, "*");
    }

    #[test]
    #[serial]
    fn bidirectional() {
        assert_eq!("an toidi", "an \u{202e}toidi".censor());
    }

    #[test]
    #[serial]
    fn analyze() {
        let analysis = Censor::from_str("HELLO fuck shit WORLD!").analyze();

        assert_ne!(analysis, Type::NONE);
        assert!(analysis.is(Type::INAPPROPRIATE));
        assert!(analysis.is(Type::PROFANE));
        assert!(analysis.isnt(Type::SEXUAL & Type::SEVERE));
        assert!(analysis.isnt(Type::OFFENSIVE));
        assert!(analysis.isnt(Type::MEAN));
    }

    #[test]
    #[serial]
    fn apis() {
        "abcd".censor();
        String::from("abcd").censor();
        let _ = "abcd".chars().censor().collect::<String>();
        let (_, _) = Censor::new("abcd".chars())
            .with_censor_replacement('?')
            .censor_and_analyze();
        let mut censor = Censor::from_str("abcd");
        let _ = censor.censor();
        let _ = censor.analyze();
        let (_, _) = Censor::from_str("HELLO crap WORLD!").censor_and_analyze();
    }

    #[test]
    #[serial]
    fn levels() {
        assert!("poo".is(Type::PROFANE & Type::MILD));
        assert!("poo".is(Type::PROFANE & Type::MILD_OR_HIGHER));
        assert!("poo".isnt(Type::PROFANE & Type::MODERATE));
        assert!("poo".isnt(Type::PROFANE & Type::MODERATE_OR_HIGHER));
        assert!("poo".isnt(Type::PROFANE & Type::SEVERE));
        assert!("arse".is(Type::PROFANE & Type::MODERATE));
        assert!("arse".is(Type::PROFANE & Type::MILD_OR_HIGHER));
        assert!("arse".is(Type::PROFANE & Type::MODERATE_OR_HIGHER));
        assert!("arse".isnt(Type::PROFANE & Type::MILD));
        assert!("arse".isnt(Type::PROFANE & Type::SEVERE));
        assert!("i hope you die".is(Type::MEAN & Type::SEVERE));
        assert!("i hope you die".is(Type::MEAN & Type::MILD_OR_HIGHER));
        assert!("i hope you die".is(Type::MEAN & Type::MODERATE_OR_HIGHER));
        assert!("i hope you die".isnt(Type::MEAN & Type::MILD));
        assert!("i hope you die".isnt(Type::MEAN & Type::MODERATE));
        assert!("You said your mother only smiled on her TV show".isnt(
            Type::PROFANE
                | Type::OFFENSIVE
                | Type::SEXUAL & Type::MODERATE_OR_HIGHER
                | Type::MEAN & Type::SEVERE
        ));
    }

    #[test]
    #[serial]
    fn repetitions_non_safe() {
        assert!("hello".is(Type::SAFE));
        assert!("helllo".is(Type::SAFE));
        assert!("hellllllllo".isnt(Type::SAFE));
    }

    #[test]
    #[serial]
    #[cfg(not(debug_assertions))]
    fn accuracy() {
        fn rustrict(s: &str) -> bool {
            s.is(Type::ANY)
        }

        #[allow(dead_code)]
        fn rustrict_old(s: &str) -> bool {
            rustrict_old::CensorStr::is(s, rustrict_old::Type::ANY)
        }

        fn censor(s: &str) -> bool {
            use censor_crate::*;
            let filter = Standard + Sex + Zealous;
            filter.check(s)
        }

        let mut stfu_filter = stfu_crate::types::OwnedFilter::default();
        use stfu_crate::word_lists::severity::{MILD, SEVERE, STRONG};
        stfu_filter.add_slice(&MILD);
        stfu_filter.add_slice(&STRONG);
        stfu_filter.add_slice(&SEVERE);

        let stfu = |s: &str| -> bool { stfu_filter.filter_string(s).is_some() };

        println!("| Crate | Accuracy | Positive Accuracy | Negative Accuracy | Time |");
        println!("|-------|----------|-------------------|-------------------|------|");
        print_accuracy(
            "https://crates.io/crates/rustrict",
            rustrict,
            false,
            Some(rustrict_old as fn(&str) -> bool).filter(|_| std::env::var("COMPARE").is_ok()),
        );
        print_accuracy("https://crates.io/crates/censor", censor, false, None);
        print_accuracy("https://crates.io/crates/stfu", stfu, false, None);
    }

    /// Prints one markdown table row with the accuracy of `checker` on `test.csv` and the
    /// elapsed wall-clock time in (fractional) seconds.
    #[allow(dead_code)]
    fn print_accuracy(
        link: &str,
        checker: impl Fn(&str) -> bool,
        find_detections: bool,
        compare_to: Option<fn(&str) -> bool>,
    ) {
        let start = Instant::now();
        let (total, positive, negative) = accuracy_of(checker, find_detections, compare_to);
        println!(
            "| [{}]({}) | {:.2}% | {:.2}% | {:.2}% | {:.2}s |",
            link.split('/').last().unwrap(),
            link,
            total * 100.0,
            positive * 100.0,
            negative * 100.0,
            // `as_secs()` is an integer, which would make the `.2` precision meaningless.
            start.elapsed().as_secs_f32()
        );
    }

    /// Returns (overall, positive, negative) accuracy of `checker` over up to 100 000 labelled
    /// rows of `test.csv` (column 0: `1` for inappropriate, column 1: text).
    #[allow(dead_code)]
    fn accuracy_of(
        checker: impl Fn(&str) -> bool,
        find_detections: bool,
        compare_to: Option<fn(&str) -> bool>,
    ) -> (f32, f32, f32) {
        let file = File::open("test.csv").unwrap();
        let reader = BufReader::new(file);
        let mut csv = csv::Reader::from_reader(reader);

        let mut correct_positive = 0;
        let mut correct_negative = 0;
        let mut total_positive = 0;
        let mut total_negative = 0;

        for line in csv.records().take(100000) {
            let record = line.unwrap();
            let truth = record[0].parse::<i8>().unwrap() == 1;
            let text = &record[1];
            let prediction = checker(text);
            if prediction == truth {
                if truth {
                    correct_positive += 1;
                } else {
                    correct_negative += 1;
                }
            } else if find_detections && text.len() < 100 {
                println!("{}: {}", truth, text);
                if prediction {
                    find_detection(text);
                }
            }
            if let Some(checker) = compare_to {
                let compare_prediction = checker(text);
                if prediction != compare_prediction && text.len() < 100 {
                    println!("COMPARISON: On \"{}\", output {} instead", text, prediction);
                }
            }
            if truth {
                total_positive += 1;
            } else {
                total_negative += 1;
            }
        }

        (
            (correct_positive + correct_negative) as f32 / (total_positive + total_negative) as f32,
            correct_positive as f32 / total_positive as f32,
            correct_negative as f32 / total_negative as f32,
        )
    }

    #[test]
    #[serial]
    fn devanagari() {
        println!("f\u{0900}u\u{0900}c\u{0900}k");
        const TEST: &str = "हत्यारा मकसहूद भाई तुम बड़ा मस्त काम करती।";
        assert!(should_skip_censor(TEST));
        assert_eq!(TEST, TEST.censor());
    }

    #[test]
    #[serial]
    fn pancakes() {
        assert_eq!(
            "🥞",
            std::str::from_utf8(&[240, 159, 165, 158]).unwrap().censor()
        );
    }

    #[test]
    #[serial]
    fn japanese_diacritics_preserved() {
        assert_eq!("パピプペポ", "パピプペポ".censor());
        assert_eq!("バビブベボ", "バビブベボ".censor());
        assert_eq!("ぱぴぷぺぽ", "ぱぴぷぺぽ".censor());
        assert_eq!("ばびぶべぼ", "ばびぶべぼ".censor());
    }

    #[test]
    #[serial]
    fn bandwidth() {
        let file = File::open("test.csv").unwrap();
        let total_len = file.metadata().unwrap().len() as usize;
        let reader = BufReader::new(file);
        let mut csv = csv::Reader::from_reader(reader);

        let mut text = String::with_capacity(total_len);

        for line in csv.records().take(100000) {
            let record = line.unwrap();
            text.push_str(&record[1]);
        }

        for power in 1..16 {
            let len = 2usize.pow(power);

            if len > text.len() {
                break;
            }

            // `len` is a byte count; back off to a char boundary so slicing cannot panic in the
            // middle of a multi-byte UTF-8 character.
            let mut end = len;
            while !text.is_char_boundary(end) {
                end -= 1;
            }

            let now = Instant::now();

            let (_, _) = Censor::from_str(&text[..end]).censor_and_analyze();

            let elapsed = now.elapsed();

            println!(
                "{}, {}, {}",
                end,
                elapsed.as_secs_f32(),
                end as f32 / elapsed.as_secs_f32() / 1000.0 / 1000.0
            );
        }
    }

    #[cfg(feature = "customize")]
    #[test]
    #[serial]
    #[allow(deprecated)]
    fn customize() {
        use crate::add_word;

        let test_profanity = "thisisafakeprofanityfortesting";
        let test_profanity_issue_7 = "плохоеслово";
        let test_safe = "thisisafakesafewordfortesting";

        // SAFETY: tests are `#[serial]`, so nothing else touches the default trie concurrently.
        unsafe {
            add_word(test_profanity, Type::PROFANE & Type::SEVERE);
            add_word(test_profanity_issue_7, Type::PROFANE & Type::SEVERE);
            add_word(test_safe, Type::SAFE);
        }

        assert!(test_profanity.is(Type::PROFANE & Type::SEVERE));
        assert!(test_profanity_issue_7.is(Type::PROFANE & Type::SEVERE));
        assert!(test_safe.is(Type::SAFE));

        // SAFETY: as above.
        unsafe {
            add_word(test_profanity, Type::NONE);
        }

        assert!(test_profanity.isnt(Type::PROFANE));
    }

    #[cfg(feature = "serde")]
    #[test]
    #[serial]
    fn serde() {
        let large = Trie::default();
        let bc = bincode::serialize(&large).unwrap();
        let json = serde_json::to_string(&large).unwrap();
        println!("large bincode {}, large json {}", bc.len(), json.len());

        let mut trie = Trie::new();
        trie.set("squeak", Type::SPAM & Type::MILD);
        trie.set("squirrel", Type::SAFE);

        let bc = bincode::serialize(&trie).unwrap();
        println!("smol bincode (len {}): {bc:?}", bc.len());
        let json = serde_json::to_string(&trie).unwrap();
        println!("smol json (len {}): {json}", json.len());
    }

    #[allow(soft_unstable)]
    #[bench]
    fn bench_is_inappropriate(b: &mut Bencher) {
        b.iter(|| test::black_box("hello fuck world shit").is_inappropriate());
    }

    #[allow(soft_unstable)]
    #[bench]
    fn bench_is_inappropriate_long(b: &mut Bencher) {
        b.iter(|| test::black_box("hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit").is_inappropriate());
    }

    #[allow(soft_unstable)]
    #[bench]
    fn bench_censor(b: &mut Bencher) {
        b.iter(|| test::black_box("hello fuck world shit").censor());
    }
}
1389 b.iter(|| test::black_box("hello fuck world shit").censor());
1390 }
1391}