tokenizers/normalizers/
replace.rs1use crate::tokenizer::pattern::Pattern;
2use crate::tokenizer::Decoder;
3use crate::tokenizer::{NormalizedString, Normalizer, Result};
4use crate::utils::SysRegex;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
9pub enum ReplacePattern {
10 String(String),
11 Regex(String),
12}
13
14impl From<String> for ReplacePattern {
15 fn from(v: String) -> Self {
16 Self::String(v)
17 }
18}
19
20impl From<&str> for ReplacePattern {
21 fn from(v: &str) -> Self {
22 Self::String(v.to_owned())
23 }
24}
25
26#[doc(hidden)]
28#[derive(Deserialize)]
29#[serde(tag = "type")]
30struct ReplaceDeserializer {
31 pattern: ReplacePattern,
32 content: String,
33}
34
35impl std::convert::TryFrom<ReplaceDeserializer> for Replace {
36 type Error = Box<dyn std::error::Error + Send + Sync>;
37
38 fn try_from(v: ReplaceDeserializer) -> Result<Self> {
39 Self::new(v.pattern, v.content)
40 }
41}
42
43#[derive(Debug, Serialize, Deserialize)]
46#[serde(tag = "type", try_from = "ReplaceDeserializer")]
47pub struct Replace {
48 pattern: ReplacePattern,
49 content: String,
50 #[serde(skip)]
51 regex: SysRegex,
52}
53
54impl Clone for Replace {
55 fn clone(&self) -> Self {
56 Self::new(self.pattern.clone(), &self.content).unwrap()
57 }
58}
59
60impl PartialEq for Replace {
61 fn eq(&self, other: &Self) -> bool {
62 self.pattern == other.pattern && self.content == other.content
63 }
64}
65
66impl Replace {
67 pub fn new<I: Into<ReplacePattern>, C: Into<String>>(pattern: I, content: C) -> Result<Self> {
68 let pattern: ReplacePattern = pattern.into();
69 let regex = match &pattern {
70 ReplacePattern::String(s) => SysRegex::new(®ex::escape(s))?,
71 ReplacePattern::Regex(r) => SysRegex::new(r)?,
72 };
73
74 Ok(Self {
75 pattern,
76 content: content.into(),
77 regex,
78 })
79 }
80}
81
82impl Normalizer for Replace {
83 fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
84 normalized.replace(&self.regex, &self.content)
85 }
86}
87
88impl Decoder for Replace {
89 fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
90 tokens
91 .into_iter()
92 .map(|token| -> Result<String> {
93 let mut new_token = "".to_string();
94
95 for ((start, stop), is_match) in (&self.regex).find_matches(&token)? {
96 if is_match {
97 new_token.push_str(&self.content);
98 } else {
99 new_token.push_str(&token[start..stop]);
100 }
101 }
102 Ok(new_token)
103 })
104 .collect()
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// A literal pattern is replaced wherever it occurs.
    #[test]
    fn test_replace() {
        let original = "This is a ''test''";
        let normalized = "This is a \"test\"";

        let mut n = NormalizedString::from(original);
        Replace::new("''", "\"").unwrap().normalize(&mut n).unwrap();

        assert_eq!(&n.get(), &normalized);
    }

    /// A regex pattern collapses each run of whitespace to a single space.
    #[test]
    fn test_replace_regex() {
        // The input must contain runs of several whitespace characters;
        // with single spaces only, replacing `\s+` by ' ' is an identity and
        // the assertion would pass even if nothing were replaced.
        let original = "This is   a         test";
        let normalized = "This is a test";

        let mut n = NormalizedString::from(original);
        Replace::new(ReplacePattern::Regex(r"\s+".into()), ' ')
            .unwrap()
            .normalize(&mut n)
            .unwrap();

        assert_eq!(&n.get(), &normalized);
    }

    /// Round-trips both pattern variants through JSON.
    #[test]
    fn serialization() {
        let replace = Replace::new("Hello", "Hey").unwrap();
        let replace_s = r#"{"type":"Replace","pattern":{"String":"Hello"},"content":"Hey"}"#;
        assert_eq!(serde_json::to_string(&replace).unwrap(), replace_s);
        assert_eq!(serde_json::from_str::<Replace>(replace_s).unwrap(), replace);

        let replace = Replace::new(ReplacePattern::Regex(r"\s+".into()), ' ').unwrap();
        let replace_s = r#"{"type":"Replace","pattern":{"Regex":"\\s+"},"content":" "}"#;
        assert_eq!(serde_json::to_string(&replace).unwrap(), replace_s);
        assert_eq!(serde_json::from_str::<Replace>(replace_s).unwrap(), replace);
    }

    /// Decoding leaves tokens without a match untouched and rewrites the rest.
    #[test]
    fn test_replace_decode() {
        let original = vec!["hello".to_string(), "_hello".to_string()];
        let replace = Replace::new("_", " ").unwrap();
        assert_eq!(
            replace.decode_chain(original).unwrap(),
            vec!["hello", " hello"]
        );
    }
}