rustrict/
censor.rs

1use crate::banned::BANNED;
2use crate::buffer_proxy_iterator::BufferProxyIterator;
3use crate::mtch::*;
4use crate::replacements::REPLACEMENTS;
5use crate::trie::*;
6use crate::Set;
7use crate::{is_whitespace, Replacements, Type};
8use std::iter::Filter;
9use std::mem;
10use std::ops::Deref;
11use std::ops::RangeInclusive;
12use std::str::Chars;
13use unicode_normalization::{Decompositions, Recompositions, UnicodeNormalization};
14
15/// Censor is a flexible profanity filter that can analyze and/or censor arbitrary text.
16///
17/// You can also make use of `Censor` via traits `CensorStr` and `CensorIter`, which allow inline
18/// checking and censoring of `&str` and `Iterator<Item = char>` respectively.
19pub struct Censor<I: Iterator<Item = char>> {
20    /// A buffer of the input that stores unconfirmed characters (may need to censor before flushing).
21    /// This is so the censored output is unaffected by the subsequent iterator machinery.
22    buffer: BufferProxyIterator<Recompositions<Filter<Decompositions<I>, fn(&char) -> bool>>>,
23    options: Options,
24    inline: InlineState,
25    allocated: AllocatedState,
26}
27
28struct Options {
29    trie: &'static Trie,
30    replacements: &'static Replacements,
31    //banned: &'static Banned,
32    ignore_false_positives: bool,
33    ignore_self_censoring: bool,
34    censor_first_character_threshold: Type,
35    //preserve_accents: bool,
36    censor_replacement: char,
37    censor_threshold: Type,
38}
39
40impl Default for Options {
41    fn default() -> Self {
42        Self {
43            trie: &*TRIE,
44            replacements: &*REPLACEMENTS,
45            //banned: &*BANNED,
46            ignore_false_positives: false,
47            ignore_self_censoring: false,
48            censor_first_character_threshold: Type::OFFENSIVE & Type::SEVERE,
49            //preserve_accents: false,
50            censor_replacement: '*',
51            censor_threshold: Default::default(),
52        }
53    }
54}
55
56struct InlineState {
57    /// Whether the last character can be considered a separator.
58    separate: bool,
59    /// The last position matched against.
60    last_pos: usize,
61    /// An accumulation of the different types of inappropriateness.
62    typ: Type,
63    /// Counters (mainly for spam detection).
64    uppercase: u8,
65    repetitions: u8,
66    last: Option<char>,
67    gibberish: u8,
68    replacements: u8,
69    /// How many instances of censor replacement in the raw text?
70    self_censoring: u8,
71    /// Is the input completely safe.
72    safe: bool,
73    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
74    match_ptrs: usize,
75    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
76    total_matches: usize,
77    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
78    total_match_characters: usize,
79    /// Whether already appended a space at the end.
80    space_appended: bool,
81    /// Whether all processing of characters has completed.
82    done: bool,
83}
84
85impl Default for InlineState {
86    fn default() -> Self {
87        Self {
88            // The beginning of the sequence is a separator.
89            separate: true,
90            // Nothing was detected yet.
91            typ: Type::NONE,
92            uppercase: 0,
93            repetitions: 0,
94            last: None,
95            gibberish: 0,
96            replacements: 0,
97            self_censoring: 0,
98            safe: false,
99            space_appended: false,
100            done: false,
101            last_pos: usize::MAX,
102            #[cfg(any(feature = "find_false_positives", feature = "trace"))]
103            match_ptrs: 0,
104            #[cfg(any(feature = "find_false_positives", feature = "trace"))]
105            total_matches: 0,
106            #[cfg(any(feature = "find_false_positives", feature = "trace"))]
107            total_match_characters: 0,
108        }
109    }
110}
111
112#[derive(Default)]
113struct AllocatedState {
114    /// Where potential matches are kept between calls to Self::next.
115    matches: Set<Match>,
116    /// Where potential matches are temporarily shuffled. Only allocate this once.
117    matches_tmp: Set<Match>,
118    /// Where matches are kept after they are complete but may be cancelled due to false positives.
119    pending_commit: Vec<Match>,
120    #[cfg(feature = "trace_full")]
121    detections: crate::Map<String, usize>,
122}
123
124impl AllocatedState {
125    fn clear(&mut self) {
126        let Self {
127            matches,
128            matches_tmp,
129            pending_commit,
130            #[cfg(feature = "trace_full")]
131            detections,
132        } = self;
133        matches.clear();
134        matches_tmp.clear();
135        pending_commit.clear();
136        #[cfg(feature = "trace_full")]
137        detections.clear();
138    }
139}
140
141impl<'a> Censor<Chars<'a>> {
142    /// Creates a `Censor` from a `&str`, ready to censor or analyze it.
143    pub fn from_str(s: &'a str) -> Self {
144        Self::new(s.chars())
145    }
146}
147
148impl<I: Iterator<Item = char>> Censor<I> {
149    /// Allocates a new `Censor` for analyzing and/or censoring text.
150    pub fn new(text: I) -> Self {
151        Self {
152            buffer: Self::buffer_from(text),
153            options: Default::default(),
154            inline: Default::default(),
155            allocated: Default::default(),
156        }
157    }
158
159    fn buffer_from(
160        text: I,
161    ) -> BufferProxyIterator<Recompositions<Filter<Decompositions<I>, fn(&char) -> bool>>> {
162        // Detects if a char isn't a diacritical mark (accent) or banned, such that such characters may be
163        // filtered on that basis.
164        fn filter_char(c: &char) -> bool {
165            use finl_unicode::categories::{CharacterCategories, MinorCategory};
166            let category = c.get_minor_category();
167            // Preserve Japanese dakuten/handakuten so kana aren't turned into their unvoiced forms.
168            let preserve_japanese = matches!(*c, '\u{3099}' | '\u{309A}');
169            let nok = matches!(
170                category,
171                MinorCategory::Cn | MinorCategory::Co | MinorCategory::Mn
172            ) && !preserve_japanese;
173
174            !(nok || BANNED.deref().deref().contains(*c))
175        }
176
177        BufferProxyIterator::new(
178            text
179                // The following three transformers are to ignore diacritical marks.
180                .nfd()
181                .filter(filter_char as fn(&char) -> bool)
182                .nfc(),
183        )
184    }
185
186    /// Resets the `Censor` with new text. Does not change any configured options.
187    /// This avoids reallocation of internal buffers on the heap.
188    pub fn reset(&mut self, text: I) {
189        self.inline = Default::default();
190        self.allocated.clear();
191        self.buffer = Self::buffer_from(text);
192    }
193
194    /// Replaces the trie containing profanity, false positives, and safe words.
195    pub fn with_trie(&mut self, trie: &'static Trie) -> &mut Self {
196        self.options.trie = trie;
197        self
198    }
199
200    /// Replaces the set of character replacements.
201    pub fn with_replacements(&mut self, replacements: &'static Replacements) -> &mut Self {
202        self.options.replacements = replacements;
203        self
204    }
205
206    /// Selects a threshold to apply while censoring. Only words that meet or exceed the threshold
207    /// are censored.
208    ///
209    /// At present, [`Type::SPAM`] cannot be censored.
210    ///
211    /// The default is [`Type::INAPPROPRIATE`].
212    pub fn with_censor_threshold(&mut self, censor_threshold: Type) -> &mut Self {
213        self.options.censor_threshold = censor_threshold;
214        self
215    }
216
217    /// Censor words like "sh*t" in "push it," which heavily increases false positives, but
218    /// slightly decreases false negatives.
219    ///
220    /// The default is `false`.
221    pub fn with_ignore_false_positives(&mut self, ignore_false_positives: bool) -> &mut Self {
222        self.options.ignore_false_positives = ignore_false_positives;
223        self
224    }
225
226    /// Do not count instances of censor replacement in the input text as possible profanity.
227    ///
228    /// If `false`, the input `"****"` will be assumed to be profane since if censor replacement is
229    /// set to `'*'`. This can help in cases like `"mother******"` where, if the user hadn't self
230    /// censored, the censored version would have been `"m***********"`.
231    ///
232    /// At present, only affects analysis and not censoring.
233    ///
234    /// The default is `false`.
235    pub fn with_ignore_self_censoring(&mut self, ignore_self_censoring: bool) -> &mut Self {
236        self.options.ignore_self_censoring = ignore_self_censoring;
237        self
238    }
239
240    /// Censor all characters e.g. "xxxx," instead of all but the first e.g. "fxxx," if the word
241    /// meets this threshold.
242    ///
243    /// The default is `false`.
244    pub fn with_censor_first_character_threshold(
245        &mut self,
246        censor_first_character_threshold: Type,
247    ) -> &mut Self {
248        self.options.censor_first_character_threshold = censor_first_character_threshold;
249        self
250    }
251
252    /*
253    /// Preserve diacritics/accents, at the cost of detecting accented words such as f̸̪͇͘ų̷̖̽c̸͙̎̚k̶͚̗͛.
254    ///
255    /// The default is false.
256    pub fn with_preserve_accents(&mut self, preserve_accents: bool) {
257        self.options.preserve_accents = preserve_accents;
258    }
259     */
260
261    /// Sets the character used to censor detected words.
262    ///
263    /// The default is `'*'`.
264    pub fn with_censor_replacement(&mut self, censor_replacement: char) -> &mut Self {
265        self.options.censor_replacement = censor_replacement;
266        self
267    }
268
269    /// Useful for processing sub-slices of profanity.
270    #[cfg(feature = "find_false_positives")]
271    pub fn with_separate(&mut self, separate: bool) -> &mut Self {
272        self.inline.separate = separate;
273        self
274    }
275
276    /// Produces a censored string. If called, it must be the first form of processing. It
277    /// entirely consumes and censors the input characters.
278    ///
279    /// # Unfortunate Side Effects
280    ///
281    /// All diacritical marks (accents) are removed by the current implementation. This is subject
282    /// to change, as a better implementation would make this optional.
283    ///
284    /// # Panics
285    ///
286    /// If called after analyze or a previous call to censor (except if reset is called in between).
287    pub fn censor(&mut self) -> String {
288        assert!(
289            !self.buffer.index().is_some(),
290            "censor must be called before any other form of processing"
291        );
292        self.collect()
293    }
294
295    /// Fully analyzes a the input characters, to determine the type of inappropriateness present, if any.
296    ///
297    /// The return value can be introspected with `Type::is`.
298    pub fn analyze(&mut self) -> Type {
299        self.ensure_done();
300        self.analysis()
301    }
302
303    /// Equivalent to `censor` and `analyze`, but in one pass through the input.
304    pub fn censor_and_analyze(&mut self) -> (String, Type) {
305        // It is important that censor is called first, so that the input is processed.
306        let censored = self.censor();
307        // After that, analysis is ready to call.
308        (censored, self.analysis())
309    }
310
311    /// Converts internal weights to a `Type`.
312    fn analysis(&self) -> Type {
313        self.inline.typ | self.safe_self_censoring_and_spam_detection()
314    }
315
316    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
317    pub fn match_ptrs(&self) -> usize {
318        self.inline.match_ptrs
319    }
320
321    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
322    pub fn total_matches(&self) -> usize {
323        self.inline.total_matches
324    }
325
326    #[cfg(any(feature = "find_false_positives", feature = "trace"))]
327    pub fn total_match_characters(&self) -> usize {
328        self.inline.total_match_characters
329    }
330
331    #[cfg(feature = "trace_full")]
332    pub fn detections(&self) -> &crate::Map<String, usize> {
333        &self.allocated.detections
334    }
335
336    fn ensure_done(&mut self) {
337        if !self.inline.done {
338            for _ in self {}
339        }
340    }
341
342    fn safe_self_censoring_and_spam_detection(&self) -> Type {
343        let safe = if self.inline.safe && self.inline.repetitions < 4 {
344            Type::SAFE
345        } else {
346            Type::NONE
347        };
348
349        if self.inline.last_pos < 6 {
350            // Short strings consisting of a single acronym are problematic percentage-wise.
351            return safe;
352        }
353
354        // Total opportunities for spam and self censoring. A bias is added so that a few words in a
355        // relatively short string won't create massive percentages.
356        let total = self
357            .inline
358            .last_pos
359            .saturating_add(6)
360            .min(u16::MAX as usize) as u16;
361
362        // Total spam.
363        let spam = self
364            .inline
365            .uppercase
366            .max(self.inline.repetitions)
367            .max(self.inline.gibberish / 2)
368            .max(self.inline.replacements) as u16;
369
370        // Calculate percents.
371        let percent_spam = 100 * spam / total;
372        let percent_self_censoring = 100 * self.inline.self_censoring as u16 / total;
373
374        // Assess amount of spam.
375        let spam = if percent_spam >= 70 && self.inline.last_pos >= 20 {
376            Type::SPAM & Type::SEVERE
377        } else if percent_spam >= 50 && self.inline.last_pos >= 10 {
378            Type::SPAM & Type::MODERATE
379        } else if percent_spam >= 30 {
380            Type::SPAM & Type::MILD
381        } else {
382            Type::NONE
383        };
384
385        // Assess amount of self-censoring.
386        let self_censoring = if !self.options.ignore_self_censoring && percent_self_censoring > 20 {
387            Type::PROFANE & Type::MILD
388        } else {
389            Type::NONE
390        };
391
392        safe | spam | self_censoring
393    }
394}
395
396impl<I: Iterator<Item = char>> Iterator for Censor<I> {
397    type Item = char;
398
399    /// Retrieves the next (potentially censored) character.
400    fn next(&mut self) -> Option<Self::Item> {
401        while let Some(raw_c) = self.buffer.next().or_else(|| {
402            if self.inline.space_appended {
403                None
404            } else {
405                self.inline.space_appended = true;
406                Some(' ')
407            }
408        }) {
409            if !self.inline.space_appended && raw_c != '!' && raw_c != '.' && raw_c != '?' {
410                // The input is not over yet, so any previous notion of safety is irrelevant.
411                self.inline.safe = false;
412            }
413
414            let pos = self.buffer.index();
415
416            self.inline.uppercase = self
417                .inline
418                .uppercase
419                .saturating_add(raw_c.is_uppercase() as u8);
420
421            let skippable = !raw_c.is_alphabetic() || is_whitespace(raw_c);
422            let replacement = self.options.replacements.get(raw_c);
423
424            #[cfg(feature = "trace")]
425            println!(
426                "Read '{}', skippable={}, replacing with={:?}",
427                raw_c, skippable, replacement
428            );
429
430            const BLOCK_ELEMENTS: RangeInclusive<char> = '\u{2580}'..='\u{259F}';
431
432            if (!self.inline.separate || self.inline.last == Some(self.options.censor_replacement))
433                && (raw_c == self.options.censor_replacement || BLOCK_ELEMENTS.contains(&raw_c))
434            {
435                // Censor replacement found but not beginning of word.
436                self.inline.self_censoring = self.inline.self_censoring.saturating_add(1);
437            }
438
439            if let Some(last) = self.inline.last {
440                if raw_c == last {
441                    self.inline.repetitions = self.inline.repetitions.saturating_add(1);
442                }
443
444                // Characters on the home-row of a QWERTY keyboard.
445                fn is_gibberish(c: char) -> bool {
446                    matches!(c, 'a' | 's' | 'd' | 'f' | 'j' | 'k' | 'l' | ';')
447                }
448
449                // Single gibberish characters don't count. Must have been preceded by another gibberish character.
450                if is_gibberish(raw_c) && is_gibberish(last) {
451                    self.inline.gibberish = self.inline.gibberish.saturating_add(1);
452                }
453            }
454
455            if let Some(pos) = pos {
456                // Must special-case all skippable, non-replaced characters that may start
457                // a profanity, so that these profanities are detected.
458                //
459                // Not adding a match is mainly an optimization.
460                if !(skippable
461                    && replacement.is_none()
462                    && !self.options.trie.root.children.contains_key(&raw_c))
463                {
464                    let begin_camel_case_word = raw_c.is_ascii_uppercase()
465                        && self
466                            .inline
467                            .last
468                            .map(|c| !c.is_ascii_uppercase())
469                            .unwrap_or(false);
470
471                    // Seed a new match for every character read.
472                    self.allocated.matches.insert(Match {
473                        node: &self.options.trie.root,
474                        start: pos, // will immediately be incremented if match is kept.
475                        end: usize::MAX, // sentinel.
476                        last: 0 as char, // sentinel.
477                        begin_separate: self.inline.separate || begin_camel_case_word,
478                        end_separate: false, // unknown at this time.
479                        spaces: 0,
480                        skipped: 0,
481                        replacements: 0,
482                        repetitions: 0,
483                        low_confidence_replacements: 0,
484                    });
485                }
486            }
487
488            self.inline.separate = skippable;
489
490            if self.inline.separate {
491                for pending in self.allocated.pending_commit.iter_mut() {
492                    if pending.end == self.inline.last_pos {
493                        pending.end_separate = true;
494                    }
495                }
496            }
497
498            let mut drain_start: Option<usize> = None;
499            let mut safety_end = usize::MAX;
500            let mut replacement_counted = false;
501            let raw_c_lower = raw_c.to_lowercase().next().unwrap();
502
503            mem::swap(&mut self.allocated.matches, &mut self.allocated.matches_tmp);
504            for c in replacement
505                .map(|a| a.as_str())
506                .unwrap_or(&&*raw_c.encode_utf8(&mut [0; 4]))
507                .chars()
508            {
509                // This replacement (uppercase to lower case) raises absolutely zero suspicion.
510                let benign_replacement = c == raw_c || c == raw_c_lower;
511
512                // This counts as a replacement, mainly for spam detection purposes.
513                let countable_replacement = !(replacement_counted
514                    || benign_replacement
515                    || raw_c.is_ascii_alphabetic()
516                    || (raw_c.is_ascii_digit()
517                        && self
518                            .inline
519                            .last
520                            .map(|l| l.is_ascii_digit())
521                            .unwrap_or(false)));
522
523                if countable_replacement {
524                    self.inline.replacements = self.inline.replacements.saturating_add(1);
525                    replacement_counted = true;
526                }
527
528                #[cfg(feature = "trace")]
529                println!(
530                    " - Replacement '{}', benign={}, countable={}",
531                    c, benign_replacement, countable_replacement
532                );
533
534                // These separators don't invalidate a false-positive match.
535                //
536                // -
537                // half-right =/= frig
538                //
539                // '
540                // invalidating false positives in cases like didn't (it where ( is a space.
541                // also, so "i'm fine" matches "im fine" for safety purposes.
542                let ignore_sep = matches!(c, '-' | '\'' | '\n' | '\r');
543
544                for m in self.allocated.matches_tmp.iter() {
545                    let m = m.clone();
546
547                    if m.low_confidence_replacements > 5
548                        || m.skipped > 5
549                        || (m.node.word && m.repetitions > 20)
550                    {
551                        #[cfg(feature = "trace")]
552                        println!("throwing out low confidence match: \"{}\"", m.node.trace);
553                        //continue;
554                    }
555
556                    safety_end = safety_end.min(m.start);
557
558                    #[cfg(feature = "trace")]
559                    println!(
560                        "  - Consider match \"{}\" with spaces={}, replacements={}",
561                        m.node.trace, m.spaces, m.replacements
562                    );
563
564                    if (skippable || c == m.last || Some(c) == m.node.last)
565                        && m.start != pos.unwrap_or(0)
566                    {
567                        // Here, '.' is primarily for allowing ellipsis ("...") as a form of
568                        // space.
569                        // ( and ) are for ignoring appositive phrases.
570                        // Checking node.last is to collapse multiple spaces into one
571                        let new_space =
572                            matches!(c, ' ' | '.' | ',' | ':' | ';' | '…' | '(' | ')' | '_' | '-')
573                                && m.node.last != Some(' ');
574                        let new_repetition: bool = !new_space && c == m.last;
575                        let new_skip = !new_space && skippable && !ignore_sep && !new_repetition;
576                        // dil -> dii
577                        let new_replacement = c == m.last && raw_c != c && !new_repetition;
578                        let new_low_confidence_replacement =
579                            new_replacement && raw_c.is_ascii_digit();
580
581                        let undo_m = Match {
582                            spaces: m.spaces.saturating_add(new_space as u8),
583                            skipped: m.skipped.saturating_add(new_skip as u8),
584                            replacements: m.replacements.saturating_add(new_replacement as u8),
585                            low_confidence_replacements: m
586                                .low_confidence_replacements
587                                .saturating_add(new_low_confidence_replacement as u8),
588                            repetitions: m.repetitions.saturating_add(new_repetition as u8),
589                            last: c,
590                            ..m
591                        };
592                        #[cfg(feature = "trace")]
593                        println!("    (keep with last={}, node last={:?}, spaces={}, skip={}, repl={}, repet={})", undo_m.last, undo_m.node.last, undo_m.spaces, undo_m.skipped, undo_m.replacements, undo_m.repetitions);
594
595                        if let Some(existing) = self.allocated.matches.get(&undo_m) {
596                            let replacement = existing.combine(&undo_m);
597                            self.allocated.matches.replace(replacement);
598                        } else {
599                            self.allocated.matches.insert(undo_m);
600                        }
601                    }
602
603                    if let Some(next) = m.node.children.get(&c) {
604                        let new_replacement = !benign_replacement && (c != raw_c) && c != ' ';
605                        let new_low_confidence_replacement =
606                            new_replacement && raw_c.is_ascii_digit();
607                        let new_space =
608                            !new_replacement && (raw_c != c && self.inline.separate && c != '\'');
609
610                        let next_m = Match {
611                            node: next,
612                            spaces: m.spaces.saturating_add(new_space as u8),
613                            replacements: m.replacements.saturating_add(new_replacement as u8),
614                            low_confidence_replacements: m
615                                .low_confidence_replacements
616                                .saturating_add(new_low_confidence_replacement as u8),
617                            last: c,
618                            ..m
619                        };
620
621                        #[cfg(feature = "trace")]
622                        println!(
623                            "     - Next is \"{}\", with spaces={}, replacements={}",
624                            next.trace, next_m.spaces, next_m.replacements
625                        );
626
627                        if next.word {
628                            if next_m.node.typ.is(Type::SAFE)
629                                && next_m.start == 0
630                                && next_m.spaces == 0
631                                && next_m.skipped == 0
632                                && next_m.replacements == 0
633                                && !self.options.ignore_false_positives
634                            {
635                                // Everything in the input until now is safe.
636                                #[cfg(feature = "trace")]
637                                println!("found safe word: {}", next_m.node.trace);
638                                self.inline.safe = true;
639                            }
640
641                            /*
642                            #[cfg(feature = "trace")]
643                            if !next_m.node.typ.is(Type::ANY) {
644                                if self.options.ignore_false_positives {
645                                    print!("ignoring");
646                                } else {
647                                    print!("found");
648                                }
649                                println!(
650                                    " false positive \"{}\", spaces={}, skipped={}, replacements={}",
651                                    next_m.node.trace, next_m.spaces, next_m.skipped, next_m.replacements
652                                );
653                            }
654                            */
655
656                            if next_m.node.typ.is(Type::ANY) {
657                                self.allocated.pending_commit.push(Match {
658                                    end: pos.unwrap(),
659                                    ..next_m
660                                });
661                            } else if next_m.spaces == 0
662                                && next_m.skipped == 0
663                                && next_m.replacements == 0
664                                && next_m.repetitions == 0 // as se
665                                && !self.options.ignore_false_positives
666                            {
667                                // Is false positive, so invalidate internal matches.
668                                #[cfg(feature = "trace")]
669                                println!("Found false positive {}", next_m.node.trace);
670                                drain_start = Some(
671                                    drain_start
672                                        .map(|start| start.min(next_m.start))
673                                        .unwrap_or(next_m.start),
674                                );
675                            }
676                        }
677
678                        if let Some(existing) = self.allocated.matches.get(&next_m) {
679                            let replacement = existing.combine(&next_m);
680                            self.allocated.matches.replace(replacement);
681                        } else {
682                            self.allocated.matches.insert(next_m);
683                        }
684                    }
685                }
686            }
687            self.allocated.matches_tmp.clear();
688            self.inline.last = Some(raw_c);
689            if let Some(pos) = pos {
690                self.inline.last_pos = pos;
691            }
692
693            let spy = &mut self.buffer;
694            let options = &self.options;
695            let inline = &mut self.inline;
696            let pending_commit = &mut self.allocated.pending_commit;
697            #[cfg(feature = "trace_full")]
698            let detections = &mut self.allocated.detections;
699
700            pending_commit.retain(|pending| {
701                #[cfg(feature = "trace")]
702                println!("Consider whether to cancel pending commit {} with start={} against drain_start={:?}", pending.node.trace, pending.start, drain_start);
703
704                // Cancel due to false positive.
705                if let Some(start) = drain_start {
706                    if pending.start >= start {
707                        #[cfg(feature = "trace")]
708                        println!("Cancelled {}", pending.node.trace);
709                        return false;
710                    }
711                }
712
713                // Can pre-commit due to lack of false positive matches.
714                if pending.end < safety_end {
715                    if pending.commit(
716                        &mut inline.typ,
717                        spy,
718                        options.censor_threshold,
719                        options.censor_first_character_threshold,
720                        options.censor_replacement,
721                    ) {
722                        #[cfg(any(feature = "find_false_positives", feature = "trace"))]
723                        {
724                            inline.match_ptrs ^= pending.node as *const _ as usize;
725                            inline.total_matches += 1;
726                            inline.total_match_characters += pending.end - pending.start;
727                            #[cfg(feature = "trace_full")]
728                            {
729                                *detections.entry(pending.node.trace.clone()).or_default() += 1;
730                            }
731                        }
732                    }
733                    return false;
734                }
735
736                // At this point, don't know whether this match will be committed or cancelled, so
737                // return.
738                true
739            });
740
741            // Yield one character if possible.
742            if let Some(spy_next_index) = self.buffer.spy_next_index() {
743                // This covers all in-flight matches.
744                let mut safe_until = spy_next_index < safety_end;
745
746                // This covers all pending commit matches.
747                for pending in &self.allocated.pending_commit {
748                    if pending.start <= spy_next_index {
749                        safe_until = false;
750                        break;
751                    }
752                }
753                if safe_until {
754                    return self.buffer.spy_next();
755                }
756            }
757        }
758
759        let residual = mem::take(&mut self.allocated.pending_commit);
760        #[cfg(feature = "trace")]
761        if !residual.is_empty() {
762            println!("{} residuals", residual.len());
763        }
764        for pending in residual {
765            if pending.commit(
766                &mut self.inline.typ,
767                &mut self.buffer,
768                self.options.censor_threshold,
769                self.options.censor_first_character_threshold,
770                self.options.censor_replacement,
771            ) {
772                #[cfg(any(feature = "find_false_positives", feature = "trace"))]
773                {
774                    self.inline.match_ptrs ^= pending.node as *const _ as usize;
775                    self.inline.total_matches += 1;
776                    self.inline.total_match_characters += pending.end - pending.start;
777                    #[cfg(feature = "trace_full")]
778                    {
779                        *self
780                            .allocated
781                            .detections
782                            .entry(pending.node.trace.clone())
783                            .or_default() += 1;
784                    }
785                }
786            }
787        }
788
789        if let Some(c) = self.buffer.spy_next() {
790            return Some(c);
791        }
792
793        self.inline.done = true;
794
795        None
796    }
797}
798
799/// CensorStr makes it easy to sanitize a `String` or `&str` by calling `.censor()`.
800pub trait CensorStr: Sized {
801    /// The output is a newly allocated, censored string.
802    fn censor(self) -> String;
803
804    /// Returns `true` if the text is inappropriate.
805    fn is_inappropriate(self) -> bool {
806        self.is(Type::INAPPROPRIATE)
807    }
808
809    /// Returns `true` if text meets the provided threshold.
810    fn is(self, threshold: Type) -> bool;
811
812    /// Returns `true` if text **does not** meet the provided threshold.
813    fn isnt(self, threshold: Type) -> bool {
814        !self.is(threshold)
815    }
816}
817
818impl CensorStr for &str {
819    fn censor(self) -> String {
820        if should_skip_censor(self) {
821            self.to_owned()
822        } else {
823            Censor::new(self.chars()).censor()
824        }
825    }
826
827    fn is(self, threshold: Type) -> bool {
828        Censor::from_str(self).analyze().is(threshold)
829    }
830}
831
832/// CensorIter makes it easy to sanitize an arbitrary `Iterator<Item=char>` by calling `.censor()`.
833pub trait CensorIter {
834    type Iterator: Iterator<Item = char>;
835
836    /// Iteratively censor characters, yielding (except accents) those that are not inappropriate, and replacing
837    /// those that are with `'*'`.
838    fn censor(self) -> Self::Iterator;
839}
840
841impl<I: Iterator<Item = char> + Clone> CensorIter for I {
842    type Iterator = Censor<I>;
843
844    /// Censors text, keeping (except accents) those that are not inappropriate, and replacing
845    /// those that are with `'*'`.
846    fn censor(self) -> Self::Iterator {
847        Censor::new(self)
848    }
849}
850
851/// Returns true if censoring won't work but will likely damage the input (e.g. by removing
852/// diacritics). Will consider the entire input.
853pub(crate) fn should_skip_censor(string: &str) -> bool {
854    let mut some_special = false;
855    for c in string.chars() {
856        use finl_unicode::categories::CharacterCategories;
857        // Devanagari is compromised by normalization and diacritic removal.
858        if ('\u{0900}'..='\u{097F}').contains(&c) {
859            some_special = true;
860        } else if !(c.is_whitespace() || c.is_separator()) {
861            return false;
862        }
863    }
864    some_special
865}
866
867/// Adds a word, with the given type. The type can be `Type::SAFE`, or a combination of `Type::PROFANE`,
868/// `Type::Sexual`, `Type::Offensive`, `Type::Mean`, `Type::Mild`, `Type::Moderate`, and `Type::Severe`,
869/// but NOT both (can't be safe and unsafe).
870///
871/// It is recommended to use all lower-case, which will match both cases. Upper-case characters will
872/// only match upper-case.
873///
874/// Prefer the safe API `Censor::with_trie`, using a modified `Trie::default()`.
875///
876/// # Warning
877///
878/// Any profanity words added this way will not support false positives. For example, if you add the word
879/// "field," you can expect "cornfield" to be detected as well, unless you call `add_word("cornfield", Type::None)`.
880///
881/// # Safety
882///
883/// This must not be called when the crate is being used in any other way. It is best to call this
884/// from the main thread, near the beginning of the program.
885#[cfg(feature = "customize")]
886#[deprecated = "Use the equivalent Trie::customize_default().set(word, typ) or the safe API Censor::with_trie"]
887pub unsafe fn add_word(word: &str, typ: Type) {
888    Trie::customize_default().set(word, typ)
889}
890
891#[cfg(test)]
892mod tests {
893    #![allow(unused_imports)]
894
895    extern crate test;
896    use crate::censor::should_skip_censor;
897    use crate::{Censor, CensorIter, CensorStr, Trie, Type};
898    use bitflags::_core::ops::Not;
899    use rand::prelude::ThreadRng;
900    use rand::{thread_rng, Rng};
901    use serial_test::serial;
902    use std::fs::File;
903    use std::io::BufReader;
904    use std::time::{Duration, Instant};
905    use test::Bencher;
906
907    #[test]
908    #[serial]
909    fn short_replacement() {
910        "99".isnt(Type::PROFANE);
911        "900".isnt(Type::PROFANE);
912        "kkk".is(Type::OFFENSIVE);
913    }
914
915    #[test]
916    #[serial]
917    fn unicode_whitespace() {
918        assert!("fu\u{1160}ck".is(Type::PROFANE));
919        assert!(!"fu\u{1161}ck".is(Type::PROFANE));
920    }
921
922    #[test]
923    #[serial]
924    fn unicode_abuse() {
925        let mut rng = thread_rng();
926
927        fn random_string(rng: &mut ThreadRng, len: usize) -> String {
928            rng.sample_iter::<char, _>(rand::distributions::Standard)
929                .take(len)
930                .collect()
931        }
932
933        for _ in 0..10 {
934            let input = random_string(&mut rng, 100);
935            let censored = input.censor();
936
937            // Most of the characters should be removed for being invalid.
938            assert!(censored.len() < input.len() / 2);
939
940            println!("{} -> {}", input, censored);
941        }
942    }
943
944    #[allow(dead_code)]
945    fn find_detection(text: &str) {
946        let holistic = Censor::from_str(text).analyze();
947
948        if holistic & Type::SPAM.not() != Type::NONE {
949            println!("{}", text);
950
951            // There was some non-spam detection.
952            let mut start = 0;
953            let mut end = text.chars().count();
954
955            while start < end
956                && Censor::new(text.chars().skip(start).take(end - start))
957                    .analyze()
958                    .is(Type::ANY)
959            {
960                start += 1;
961            }
962            start = start.saturating_sub(1);
963            while start < end
964                && Censor::new(text.chars().skip(start).take(end - start))
965                    .analyze()
966                    .is(Type::ANY)
967            {
968                end -= 1;
969            }
970            end += 1;
971            for _ in 0..start {
972                print!("-");
973            }
974            for _ in start..end {
975                print!("^");
976            }
977            print!(" ");
978            println!(
979                "(\"{}\" is {:?})",
980                text.chars()
981                    .skip(start)
982                    .take(end - start)
983                    .collect::<String>(),
984                holistic
985            );
986        } else {
987            println!("{} ({:?})", text, holistic);
988        }
989    }
990
991    #[test]
992    #[serial]
993    fn curated() {
994        let mut cases: Vec<(&str, bool, Option<bool>)> = vec![("", false, Some(false))];
995        cases.extend(
996            include_str!("test_positive.txt")
997                .split('\n')
998                .filter(|l| !l.is_empty())
999                .map(|l| (l, true, Some(false))),
1000        );
1001        cases.extend(
1002            include_str!("test_negative.txt")
1003                .split('\n')
1004                .filter(|l| !l.is_empty())
1005                .map(|l| (l, false, None)),
1006        );
1007        cases.extend(
1008            include_str!("safe.txt")
1009                .split('\n')
1010                .filter(|l| !l.is_empty() && !l.starts_with('#'))
1011                .map(|l| (l, false, Some(true))),
1012        );
1013        cases.extend(
1014            include_str!("test_safe.txt")
1015                .split('\n')
1016                .filter(|l| !l.is_empty())
1017                .map(|l| (l, false, Some(true))),
1018        );
1019
1020        let mut failures = Vec::new();
1021
1022        for (case, any_truth, safe_truth) in cases {
1023            /*
1024            #[cfg(debug_assertions)]
1025            println!("Case: \"{}\"", case);
1026             */
1027
1028            let typ = Censor::from_str(case).analyze();
1029            let any = typ.is(Type::ANY);
1030            let safe = typ.is(Type::SAFE);
1031
1032            //let (censored, analysis) = Censor::from_str(case).with_censor_threshold(Type::ANY).censor_and_analyze();
1033            //println!("\"{}\" -> \"{}\" ({}, {})", case, censored, prediction, analysis.is(Type::ANY));
1034
1035            if any != any_truth {
1036                find_detection(case);
1037                failures.push(format!("FAIL: Predicted {:?} for: \"{}\"", typ, case));
1038            } else if !any_truth {
1039                // None of the current test cases contain any abusive Unicode characters.
1040                let censored = case.censor();
1041                if case != censored {
1042                    failures.push(format!("Censored: : \"{case}\" -> {censored}"))
1043                }
1044            }
1045            if let Some(safe_truth) = safe_truth {
1046                if safe != safe_truth {
1047                    failures.push(format!("FAIL: Predicted safe={} for: \"{}\"", safe, case));
1048                }
1049            }
1050        }
1051
1052        if !failures.is_empty() {
1053            for failure in failures {
1054                println!("{failure}");
1055            }
1056            panic!();
1057        }
1058    }
1059
1060    #[test]
1061    #[serial]
1062    fn censor() {
1063        let censored = Censor::from_str("HELLO fučk Shit nudes WORLD!")
1064            .with_censor_replacement('#')
1065            .with_censor_first_character_threshold(Type::SEXUAL & Type::SEVERE)
1066            .censor();
1067
1068        assert_eq!(censored, "HELLO f### S### ##### WORLD!");
1069
1070        // Minor mean-ness is not considered inappropriate
1071        assert_eq!("fcking coward".censor(), "f***** coward");
1072
1073        let censored = Censor::from_str("卍")
1074            .with_censor_first_character_threshold(Type::NONE)
1075            .censor();
1076
1077        assert_eq!(censored, "*");
1078    }
1079
1080    #[test]
1081    #[serial]
1082    fn bidirectional() {
1083        // Censoring removes direction overrides, so that the text output is the text that was analyzed.
1084        assert_eq!("an toidi", "an \u{202e}toidi".censor());
1085    }
1086
1087    #[test]
1088    #[serial]
1089    fn analyze() {
1090        let analysis = Censor::from_str("HELLO fuck shit WORLD!").analyze();
1091
1092        assert_ne!(analysis, Type::NONE);
1093        assert!(analysis.is(Type::INAPPROPRIATE));
1094        assert!(analysis.is(Type::PROFANE));
1095        assert!(analysis.isnt(Type::SEXUAL & Type::SEVERE));
1096        assert!(analysis.isnt(Type::OFFENSIVE));
1097        assert!(analysis.isnt(Type::MEAN));
1098    }
1099
1100    /// This exists purely to ensure all the APIs keep compiling.
1101    #[test]
1102    #[serial]
1103    fn apis() {
1104        "abcd".censor();
1105        String::from("abcd").censor();
1106        let _ = "abcd".chars().censor().collect::<String>();
1107        let (_, _) = Censor::new("abcd".chars())
1108            .with_censor_replacement('?')
1109            .censor_and_analyze();
1110        let mut censor = Censor::from_str("abcd");
1111        let _ = censor.censor();
1112        let _ = censor.analyze();
1113        let (_, _) = Censor::from_str("HELLO crap WORLD!").censor_and_analyze();
1114    }
1115
1116    #[test]
1117    #[serial]
1118    fn levels() {
1119        assert!("poo".is(Type::PROFANE & Type::MILD));
1120        assert!("poo".is(Type::PROFANE & Type::MILD_OR_HIGHER));
1121        assert!("poo".isnt(Type::PROFANE & Type::MODERATE));
1122        assert!("poo".isnt(Type::PROFANE & Type::MODERATE_OR_HIGHER));
1123        assert!("poo".isnt(Type::PROFANE & Type::SEVERE));
1124        assert!("arse".is(Type::PROFANE & Type::MODERATE));
1125        assert!("arse".is(Type::PROFANE & Type::MILD_OR_HIGHER));
1126        assert!("arse".is(Type::PROFANE & Type::MODERATE_OR_HIGHER));
1127        assert!("arse".isnt(Type::PROFANE & Type::MILD));
1128        assert!("arse".isnt(Type::PROFANE & Type::SEVERE));
1129        assert!("i hope you die".is(Type::MEAN & Type::SEVERE));
1130        assert!("i hope you die".is(Type::MEAN & Type::MILD_OR_HIGHER));
1131        assert!("i hope you die".is(Type::MEAN & Type::MODERATE_OR_HIGHER));
1132        assert!("i hope you die".isnt(Type::MEAN & Type::MILD));
1133        assert!("i hope you die".isnt(Type::MEAN & Type::MODERATE));
1134        assert!("You said your mother only smiled on her TV show".isnt(
1135            Type::PROFANE
1136                | Type::OFFENSIVE
1137                | Type::SEXUAL & Type::MODERATE_OR_HIGHER
1138                | Type::MEAN & Type::SEVERE
1139        ));
1140    }
1141
1142    #[test]
1143    #[serial]
1144    fn repetitions_non_safe() {
1145        assert!("hello".is(Type::SAFE));
1146        assert!("helllo".is(Type::SAFE));
1147        assert!("hellllllllo".isnt(Type::SAFE));
1148    }
1149
1150    #[test]
1151    #[serial]
1152    #[cfg(not(debug_assertions))]
1153    fn accuracy() {
1154        fn rustrict(s: &str) -> bool {
1155            s.is(Type::ANY)
1156        }
1157
1158        #[allow(dead_code)]
1159        fn rustrict_old(s: &str) -> bool {
1160            rustrict_old::CensorStr::is(s, rustrict_old::Type::ANY)
1161        }
1162
1163        fn censor(s: &str) -> bool {
1164            use censor_crate::*;
1165            let filter = Standard + Sex + Zealous;
1166            filter.check(s)
1167        }
1168
1169        let mut stfu_filter = stfu_crate::types::OwnedFilter::default();
1170        use stfu_crate::word_lists::severity::{MILD, SEVERE, STRONG};
1171        stfu_filter.add_slice(&MILD);
1172        stfu_filter.add_slice(&STRONG);
1173        stfu_filter.add_slice(&SEVERE);
1174
1175        let stfu = |s: &str| -> bool { stfu_filter.filter_string(s).is_some() };
1176
1177        println!("| Crate | Accuracy | Positive Accuracy | Negative Accuracy | Time |");
1178        println!("|-------|----------|-------------------|-------------------|------|");
1179        print_accuracy(
1180            "https://crates.io/crates/rustrict",
1181            rustrict,
1182            false, // true,
1183            Some(rustrict_old as fn(&str) -> bool).filter(|_| std::env::var("COMPARE").is_ok()),
1184        );
1185        print_accuracy("https://crates.io/crates/censor", censor, false, None);
1186        print_accuracy("https://crates.io/crates/stfu", stfu, false, None);
1187    }
1188
1189    #[allow(dead_code)]
1190    fn print_accuracy(
1191        link: &str,
1192        checker: impl Fn(&str) -> bool,
1193        find_detections: bool,
1194        compare_to: Option<fn(&str) -> bool>,
1195    ) {
1196        let start = Instant::now();
1197        let (total, positive, negative) = accuracy_of(checker, find_detections, compare_to);
1198        println!(
1199            "| [{}]({}) | {:.2}% | {:.2}% | {:.2}% | {:.2}s |",
1200            link.split('/').last().unwrap(),
1201            link,
1202            total * 100.0,
1203            positive * 100.0,
1204            negative * 100.0,
1205            start.elapsed().as_secs()
1206        );
1207    }
1208
1209    #[allow(dead_code)]
1210    fn accuracy_of(
1211        checker: impl Fn(&str) -> bool,
1212        find_detections: bool,
1213        compare_to: Option<fn(&str) -> bool>,
1214    ) -> (f32, f32, f32) {
1215        let file = File::open("test.csv").unwrap();
1216        let reader = BufReader::new(file);
1217        let mut csv = csv::Reader::from_reader(reader);
1218
1219        let mut correct_positive = 0;
1220        let mut correct_negative = 0;
1221        let mut total_positive = 0;
1222        let mut total_negative = 0;
1223
1224        for line in csv.records().take(100000) {
1225            let record = line.unwrap();
1226            let truth = record[0].parse::<i8>().unwrap() == 1;
1227            let text = &record[1];
1228            let prediction = checker(text);
1229            //assert_eq!(is(text), is(text), "With ({})", text);
1230            if prediction == truth {
1231                if truth {
1232                    correct_positive += 1;
1233                } else {
1234                    correct_negative += 1;
1235                }
1236            } else if find_detections && text.len() < 100 {
1237                println!("{}: {}", truth, text);
1238                if prediction {
1239                    find_detection(text);
1240                }
1241            }
1242            if let Some(checker) = compare_to {
1243                let compare_prediction = checker(text);
1244                if prediction != compare_prediction && text.len() < 100 {
1245                    println!("COMPARISON: On \"{}\", output {} instead", text, prediction);
1246                }
1247            }
1248            if truth {
1249                total_positive += 1;
1250            } else {
1251                total_negative += 1;
1252            }
1253        }
1254
1255        (
1256            (correct_positive + correct_negative) as f32 / (total_positive + total_negative) as f32,
1257            correct_positive as f32 / total_positive as f32,
1258            correct_negative as f32 / total_negative as f32,
1259        )
1260    }
1261
1262    #[test]
1263    #[serial]
1264    fn devanagari() {
1265        println!("f\u{0900}u\u{0900}c\u{0900}k");
1266        const TEST: &'static str = "हत्यारा मकसहूद भाई तुम बड़ा मस्त काम करती।";
1267        assert!(should_skip_censor(TEST));
1268        assert_eq!(TEST, TEST.censor());
1269    }
1270
1271    #[test]
1272    #[serial]
1273    fn pancakes() {
1274        assert_eq!(
1275            "🥞",
1276            std::str::from_utf8(&[240, 159, 165, 158]).unwrap().censor()
1277        );
1278    }
1279
1280    #[test]
1281    #[serial]
1282    fn japanese_diacritics_preserved() {
1283        assert_eq!("パピプペポ", "パピプペポ".censor());
1284        assert_eq!("バビブベボ", "バビブベボ".censor());
1285        assert_eq!("ぱぴぷぺぽ", "ぱぴぷぺぽ".censor());
1286        assert_eq!("ばびぶべぼ", "ばびぶべぼ".censor());
1287    }
1288
1289    #[test]
1290    #[serial]
1291    fn bandwidth() {
1292        let file = File::open("test.csv").unwrap();
1293        let total_len = file.metadata().unwrap().len() as usize;
1294        let reader = BufReader::new(file);
1295        let mut csv = csv::Reader::from_reader(reader);
1296
1297        let mut text = String::with_capacity(total_len);
1298
1299        for line in csv.records().take(100000) {
1300            let record = line.unwrap();
1301            text.push_str(&record[1]);
1302        }
1303
1304        for power in 1..16 {
1305            let len = 2usize.pow(power);
1306
1307            if len > text.len() {
1308                break;
1309            }
1310
1311            let now = Instant::now();
1312
1313            let (_, _) = Censor::from_str(&text[0..len]).censor_and_analyze();
1314
1315            let elapsed = now.elapsed();
1316
1317            println!(
1318                "{}, {}, {}",
1319                len,
1320                elapsed.as_secs_f32(),
1321                len as f32 / elapsed.as_secs_f32() / 1000.0 / 1000.0
1322            );
1323        }
1324    }
1325
1326    #[cfg(feature = "customize")]
1327    #[test]
1328    #[serial]
1329    #[allow(deprecated)]
1330    fn customize() {
1331        use crate::add_word;
1332
1333        let test_profanity = "thisisafakeprofanityfortesting";
1334        let test_profanity_issue_7 = "плохоеслово";
1335        let test_safe = "thisisafakesafewordfortesting";
1336
1337        // SAFETY: Tests are run serially, so concurrent mutation is avoided.
1338        unsafe {
1339            add_word(test_profanity, Type::PROFANE & Type::SEVERE);
1340            add_word(test_profanity_issue_7, Type::PROFANE & Type::SEVERE);
1341            add_word(test_safe, Type::SAFE);
1342        }
1343
1344        assert!(test_profanity.is(Type::PROFANE & Type::SEVERE));
1345        assert!(test_profanity_issue_7.is(Type::PROFANE & Type::SEVERE));
1346        assert!(test_safe.is(Type::SAFE));
1347
1348        unsafe {
1349            add_word(test_profanity, Type::NONE);
1350        }
1351
1352        assert!(test_profanity.isnt(Type::PROFANE));
1353    }
1354
1355    #[cfg(feature = "serde")]
1356    #[test]
1357    #[serial]
1358    fn serde() {
1359        let large = Trie::default();
1360        let bc = bincode::serialize(&large).unwrap();
1361        let json = serde_json::to_string(&large).unwrap();
1362        println!("large bincode {}, large json {}", bc.len(), json.len());
1363
1364        let mut trie = Trie::new();
1365        trie.set("squeak", Type::SPAM & Type::MILD);
1366        trie.set("squirrel", Type::SAFE);
1367
1368        let bc = bincode::serialize(&trie).unwrap();
1369        println!("smol bincode (len {}): {bc:?}", bc.len());
1370        let json = serde_json::to_string(&trie).unwrap();
1371        println!("smol json (len {}): {json}", json.len());
1372    }
1373
1374    #[allow(soft_unstable)]
1375    #[bench]
1376    fn bench_is_inappropriate(b: &mut Bencher) {
1377        b.iter(|| test::black_box("hello fuck world shit").is_inappropriate());
1378    }
1379
1380    #[allow(soft_unstable)]
1381    #[bench]
1382    fn bench_is_inappropriate_long(b: &mut Bencher) {
1383        b.iter(|| test::black_box("hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit hello fuck world shit").is_inappropriate());
1384    }
1385
1386    #[allow(soft_unstable)]
1387    #[bench]
1388    fn bench_censor(b: &mut Bencher) {
1389        b.iter(|| test::black_box("hello fuck world shit").censor());
1390    }
1391}