tokenizers/normalizers/replace.rs

1use crate::tokenizer::pattern::Pattern;
2use crate::tokenizer::Decoder;
3use crate::tokenizer::{NormalizedString, Normalizer, Result};
4use crate::utils::SysRegex;
5use serde::{Deserialize, Serialize};
6
7/// Represents the different patterns that `Replace` can use
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
9pub enum ReplacePattern {
10    String(String),
11    Regex(String),
12}
13
14impl From<String> for ReplacePattern {
15    fn from(v: String) -> Self {
16        Self::String(v)
17    }
18}
19
20impl From<&str> for ReplacePattern {
21    fn from(v: &str) -> Self {
22        Self::String(v.to_owned())
23    }
24}
25
26/// We use this custom deserializer to provide the value for `regex` for `Replace`
27#[doc(hidden)]
28#[derive(Deserialize)]
29#[serde(tag = "type")]
30struct ReplaceDeserializer {
31    pattern: ReplacePattern,
32    content: String,
33}
34
35impl std::convert::TryFrom<ReplaceDeserializer> for Replace {
36    type Error = Box<dyn std::error::Error + Send + Sync>;
37
38    fn try_from(v: ReplaceDeserializer) -> Result<Self> {
39        Self::new(v.pattern, v.content)
40    }
41}
42
43/// This normalizer will take a `pattern` (for now only a String)
44/// and replace every occurrence with `content`.
45#[derive(Debug, Serialize, Deserialize)]
46#[serde(tag = "type", try_from = "ReplaceDeserializer")]
47pub struct Replace {
48    pattern: ReplacePattern,
49    content: String,
50    #[serde(skip)]
51    regex: SysRegex,
52}
53
54impl Clone for Replace {
55    fn clone(&self) -> Self {
56        Self::new(self.pattern.clone(), &self.content).unwrap()
57    }
58}
59
60impl PartialEq for Replace {
61    fn eq(&self, other: &Self) -> bool {
62        self.pattern == other.pattern && self.content == other.content
63    }
64}
65
66impl Replace {
67    pub fn new<I: Into<ReplacePattern>, C: Into<String>>(pattern: I, content: C) -> Result<Self> {
68        let pattern: ReplacePattern = pattern.into();
69        let regex = match &pattern {
70            ReplacePattern::String(s) => SysRegex::new(&regex::escape(s))?,
71            ReplacePattern::Regex(r) => SysRegex::new(r)?,
72        };
73
74        Ok(Self {
75            pattern,
76            content: content.into(),
77            regex,
78        })
79    }
80}
81
82impl Normalizer for Replace {
83    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
84        normalized.replace(&self.regex, &self.content)
85    }
86}
87
88impl Decoder for Replace {
89    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
90        tokens
91            .into_iter()
92            .map(|token| -> Result<String> {
93                let mut new_token = "".to_string();
94
95                for ((start, stop), is_match) in (&self.regex).find_matches(&token)? {
96                    if is_match {
97                        new_token.push_str(&self.content);
98                    } else {
99                        new_token.push_str(&token[start..stop]);
100                    }
101                }
102                Ok(new_token)
103            })
104            .collect()
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_replace() {
        let original = "This is a ''test''";
        let normalized = "This is a \"test\"";

        let mut n = NormalizedString::from(original);
        Replace::new("''", "\"").unwrap().normalize(&mut n).unwrap();

        assert_eq!(&n.get(), &normalized);
    }

    #[test]
    fn test_replace_regex() {
        let original = "This     is   a         test";
        let normalized = "This is a test";

        let mut n = NormalizedString::from(original);
        Replace::new(ReplacePattern::Regex(r"\s+".into()), ' ')
            .unwrap()
            .normalize(&mut n)
            .unwrap();

        assert_eq!(&n.get(), &normalized);
    }

    #[test]
    fn test_replace_string_pattern_is_literal() {
        // Regex metacharacters in a `String` pattern must be escaped, not interpreted.
        let mut n = NormalizedString::from("a(b");
        Replace::new("(", "[").unwrap().normalize(&mut n).unwrap();

        assert_eq!(n.get(), "a[b");
    }

    #[test]
    fn test_invalid_regex_is_an_error() {
        assert!(Replace::new(ReplacePattern::Regex("(".into()), "").is_err());
    }

    #[test]
    fn test_clone_and_eq() {
        let replace = Replace::new(ReplacePattern::Regex(r"\s+".into()), ' ').unwrap();
        let cloned = replace.clone();
        assert_eq!(cloned, replace);

        // The clone carries a working compiled regex, not just equal fields.
        let mut n = NormalizedString::from("a   b");
        cloned.normalize(&mut n).unwrap();
        assert_eq!(n.get(), "a b");

        // The same text used as a literal versus as a regex is a different pattern.
        assert_ne!(
            Replace::new("a", "b").unwrap(),
            Replace::new(ReplacePattern::Regex("a".into()), "b").unwrap()
        );
    }

    #[test]
    fn serialization() {
        let replace = Replace::new("Hello", "Hey").unwrap();
        let replace_s = r#"{"type":"Replace","pattern":{"String":"Hello"},"content":"Hey"}"#;
        assert_eq!(serde_json::to_string(&replace).unwrap(), replace_s);
        assert_eq!(serde_json::from_str::<Replace>(replace_s).unwrap(), replace);

        let replace = Replace::new(ReplacePattern::Regex(r"\s+".into()), ' ').unwrap();
        let replace_s = r#"{"type":"Replace","pattern":{"Regex":"\\s+"},"content":" "}"#;
        assert_eq!(serde_json::to_string(&replace).unwrap(), replace_s);
        assert_eq!(serde_json::from_str::<Replace>(replace_s).unwrap(), replace);
    }

    #[test]
    fn test_replace_decode() {
        let original = vec!["hello".to_string(), "_hello".to_string()];
        let replace = Replace::new("_", " ").unwrap();
        assert_eq!(
            replace.decode_chain(original).unwrap(),
            vec!["hello", " hello"]
        );
    }

    #[test]
    fn test_replace_decode_regex() {
        let original = vec!["a   b".to_string(), "c".to_string()];
        let replace = Replace::new(ReplacePattern::Regex(r"\s+".into()), " ").unwrap();
        assert_eq!(replace.decode_chain(original).unwrap(), vec!["a b", "c"]);
    }
}