wetext_rs/normalizer.rs

//! Main Normalizer implementation
//!
//! This module provides the main Normalizer struct that orchestrates
//! the text normalization pipeline.

use std::collections::HashMap;
use std::path::{Path, PathBuf};

use crate::config::{Language, NormalizerConfig, Operator};
use crate::contractions::fix_contractions;
use crate::error::{Result, WeTextError};
use crate::text_normalizer::FstTextNormalizer;
use crate::token_parser::TokenParser;
/// FST file cache for lazy loading
struct FstCache {
    fsts: HashMap<String, FstTextNormalizer>,
    fst_dir: PathBuf,
}

impl FstCache {
    fn new<P: AsRef<Path>>(fst_dir: P) -> Self {
        Self {
            fsts: HashMap::new(),
            fst_dir: fst_dir.as_ref().to_path_buf(),
        }
    }

    fn get_or_load(&mut self, relative_path: &str) -> Result<&FstTextNormalizer> {
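        // Double lookup (contains_key + insert + get) rather than the entry
        // API: `FstTextNormalizer::from_file(...)?` is fallible, and the
        // entry API has no stable fallible `or_insert_with`.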
        if !self.fsts.contains_key(relative_path) {
            let full_path = self.fst_dir.join(relative_path);
            let normalizer = FstTextNormalizer::from_file(&full_path)?;
            self.fsts.insert(relative_path.to_string(), normalizer);
        }
        Ok(self.fsts.get(relative_path).unwrap())
    }
}

/// WeText Normalizer
///
/// Main entry point for text normalization functionality.
/// Supports Text Normalization (TN) and Inverse Text Normalization (ITN)
/// for Chinese, English, and Japanese.
///
/// # Example
/// ```rust,ignore
/// use wetext_rs::{Normalizer, NormalizerConfig, Language};
///
/// let config = NormalizerConfig::new().with_lang(Language::Zh);
/// let mut normalizer = Normalizer::new("path/to/fsts", config);
/// let result = normalizer.normalize("2024年").unwrap();
/// // Result: "二零二四年"
/// ```
pub struct Normalizer {
    config: NormalizerConfig,
    cache: FstCache,
}

impl Normalizer {
    /// Create a new Normalizer
    ///
    /// # Arguments
    /// * `fst_dir` - Directory containing FST weight files
    /// * `config` - Normalizer configuration
    pub fn new<P: AsRef<Path>>(fst_dir: P, config: NormalizerConfig) -> Self {
        Self {
            config,
            cache: FstCache::new(fst_dir),
        }
    }

    /// Create a Normalizer with default configuration
    pub fn with_defaults<P: AsRef<Path>>(fst_dir: P) -> Self {
        Self::new(fst_dir, NormalizerConfig::default())
    }

    /// Normalize text using the configured settings
    pub fn normalize(&mut self, text: &str) -> Result<String> {
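        // Pass a clone of the config: `normalize_with_config` takes
        // `&mut self`, so it cannot also borrow `self.config` directly.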
        self.normalize_with_config(text, &self.config.clone())
    }

    /// Normalize text with a specific configuration
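    ///
    /// Pipeline: fix contractions → preprocess → detect language →
    /// (tag → reorder → verbalize, when normalization is needed) → postprocess.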
    pub fn normalize_with_config(
        &mut self,
        text: &str,
        config: &NormalizerConfig,
    ) -> Result<String> {
        let mut text = text.to_string();

        // 1. Fix English contractions
        if config.fix_contractions && text.contains('\'') {
            text = fix_contractions(&text);
        }

        // 2. Preprocessing
        text = self.preprocess(&text, config)?;

        // 3. Detect language
        let lang = if config.lang == Language::Auto {
            Self::detect_language(&text)
        } else {
            config.lang
        };

        // 4. Check if normalization is needed
        if self.should_normalize(&text, config.operator, config.remove_erhua) {
            // English ITN is not supported upstream: Python wetext raises
            // NotImplementedError. This port instead falls back to Chinese
            // ITN rather than returning an error.
            let lang = if lang == Language::En && config.operator == Operator::Itn {
                Language::Zh
            } else {
                lang
            };

            // 4.1 Tagger: tag entities
            text = self.tag(&text, lang, config)?;

            // 4.2 Reorder: reorder token fields
            text = self.reorder(&text, lang, config.operator)?;

            // 4.3 Verbalizer: convert to spoken form
            text = self.verbalize(&text, lang, config)?;
        }

        // 5. Postprocessing
        text = self.postprocess(&text, config)?;

        Ok(text)
    }

    /// Detect text language
    ///
    /// **Note:** This implementation extends the original Python version with Japanese detection.
    /// Python wetext only detects Chinese vs English. This Rust version adds Japanese support
    /// by detecting Hiragana/Katakana characters.
    ///
    /// Detection priority:
    /// 1. Japanese (Hiragana/Katakana) - Rust extension, not in the Python version
    /// 2. Chinese (CJK Unified Ideographs)
    /// 3. Text with no ASCII letters (digits, punctuation, symbols) - treated as Chinese
    /// 4. Default to English
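    ///
    /// # Examples
    /// Illustrative only; these mirror the unit tests below.
    /// ```rust,ignore
    /// assert_eq!(Normalizer::detect_language("hello world"), Language::En);
    /// assert_eq!(Normalizer::detect_language("你好世界"), Language::Zh);
    /// assert_eq!(Normalizer::detect_language("こんにちは"), Language::Ja);
    /// assert_eq!(Normalizer::detect_language("123"), Language::Zh);
    /// ```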
    fn detect_language(text: &str) -> Language {
        let mut has_cjk = false;
        let mut has_alpha = false;

        for ch in text.chars() {
            // [Rust Extension] Japanese detection via Hiragana/Katakana
            // Japanese Hiragana: U+3040 - U+309F
            // Japanese Katakana: U+30A0 - U+30FF
            // Note: Python wetext does NOT have this detection - it would return "zh" for Japanese text
            if ('\u{3040}'..='\u{309f}').contains(&ch) || ('\u{30a0}'..='\u{30ff}').contains(&ch) {
                return Language::Ja;
            }

            // CJK Unified Ideographs: U+4E00 - U+9FFF
            // Note: these are shared between Chinese and Japanese.
            // If we find Hiragana/Katakana, it's Japanese; otherwise treat as Chinese
            if ('\u{4e00}'..='\u{9fff}').contains(&ch) {
                has_cjk = true;
            }

            // Track whether there are any ASCII alphabetic characters
            if ch.is_ascii_alphabetic() {
                has_alpha = true;
            }
        }

        // If the text contains CJK but no Japanese-specific characters, treat it as Chinese
        if has_cjk {
            return Language::Zh;
        }

        // Text with no ASCII alphabetic characters (digits, punctuation,
        // symbols only) is treated as Chinese. This covers inputs like
        // "123", "3/4", and "1.5". (Inputs like "2024年" are already caught
        // by the CJK check above.)
        if !text.is_empty() && !has_alpha {
            return Language::Zh;
        }

        Language::En
    }

    /// Check if normalization is needed
    fn should_normalize(&self, text: &str, operator: Operator, remove_erhua: bool) -> bool {
        if operator == Operator::Tn {
            // TN: normalization is needed if the text contains ASCII digits
            if text.chars().any(|c| c.is_ascii_digit()) {
                return true;
            }
            // ... or if erhua removal is requested and erhua characters are present
            if remove_erhua && (text.contains('儿') || text.contains('兒')) {
                return true;
            }
            false
        } else {
            // ITN: any non-empty text needs processing
            !text.is_empty()
        }
    }

    /// Preprocessing step
    fn preprocess(&mut self, text: &str, config: &NormalizerConfig) -> Result<String> {
        let mut result = text.trim().to_string();

        if config.traditional_to_simple {
            let fst = self.cache.get_or_load("traditional_to_simple.fst")?;
            result = fst.normalize(&result)?;
        }

        Ok(result)
    }

    /// Postprocessing step
    fn postprocess(&mut self, text: &str, config: &NormalizerConfig) -> Result<String> {
        let mut result = text.to_string();

        if config.full_to_half {
            let fst = self.cache.get_or_load("full_to_half.fst")?;
            result = fst.normalize(&result)?;
        }

        if config.remove_interjections {
            let fst = self.cache.get_or_load("remove_interjections.fst")?;
            result = fst.normalize(&result)?;
        }

        if config.remove_puncts {
            let fst = self.cache.get_or_load("remove_puncts.fst")?;
            result = fst.normalize(&result)?;
        }

        if config.tag_oov {
            let fst = self.cache.get_or_load("tag_oov.fst")?;
            result = fst.normalize(&result)?;
        }

        Ok(result.trim().to_string())
    }

    /// Tag entities using tagger FST
    fn tag(&mut self, text: &str, lang: Language, config: &NormalizerConfig) -> Result<String> {
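        // FST files are laid out under `fst_dir` as `<lang>/<tn|itn>/...`,
        // with tagger variants selected by configuration flags.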
        let fst_path = match (lang, config.operator) {
            (Language::En, Operator::Tn) => "en/tn/tagger.fst",
            (Language::Zh, Operator::Tn) => "zh/tn/tagger.fst",
            (Language::Zh, Operator::Itn) => {
                if config.enable_0_to_9 {
                    "zh/itn/tagger_enable_0_to_9.fst"
                } else {
                    "zh/itn/tagger.fst"
                }
            }
            (Language::Ja, Operator::Tn) => "ja/tn/tagger.fst",
            (Language::Ja, Operator::Itn) => {
                if config.enable_0_to_9 {
                    "ja/itn/tagger_enable_0_to_9.fst"
                } else {
                    "ja/itn/tagger.fst"
                }
            }
            // (En, Itn) is remapped to Zh before tagging and Auto is resolved
            // by detect_language, so this arm only guards exhaustiveness.
            _ => return Err(WeTextError::InvalidLanguage(format!("{:?}", lang))),
        };

        let fst = self.cache.get_or_load(fst_path)?;
        let result = fst.normalize(text)?;
        Ok(result.trim().to_string())
    }

    /// Reorder token fields
    fn reorder(&self, text: &str, lang: Language, operator: Operator) -> Result<String> {
        let parser = TokenParser::new(lang, operator);
        parser.reorder(text)
    }

    /// Verbalize using verbalizer FST
    fn verbalize(
        &mut self,
        text: &str,
        lang: Language,
        config: &NormalizerConfig,
    ) -> Result<String> {
        let fst_path = match (lang, config.operator) {
            (Language::En, Operator::Tn) => "en/tn/verbalizer.fst",
            (Language::Zh, Operator::Tn) => {
                if config.remove_erhua {
                    "zh/tn/verbalizer_remove_erhua.fst"
                } else {
                    "zh/tn/verbalizer.fst"
                }
            }
            (Language::Zh, Operator::Itn) => "zh/itn/verbalizer.fst",
            (Language::Ja, Operator::Tn) => "ja/tn/verbalizer.fst",
            (Language::Ja, Operator::Itn) => "ja/itn/verbalizer.fst",
            _ => return Err(WeTextError::InvalidLanguage(format!("{:?}", lang))),
        };

        let fst = self.cache.get_or_load(fst_path)?;
        let result = fst.normalize(text)?;
        Ok(result.trim().to_string())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_language() {
        // English
        assert_eq!(Normalizer::detect_language("hello world"), Language::En);
        assert_eq!(Normalizer::detect_language("Hello, World!"), Language::En);

        // Chinese
        assert_eq!(Normalizer::detect_language("你好世界"), Language::Zh);
        assert_eq!(Normalizer::detect_language("今天是2024年"), Language::Zh);

        // Japanese (Hiragana/Katakana triggers Japanese detection)
        assert_eq!(Normalizer::detect_language("こんにちは"), Language::Ja); // Hiragana
        assert_eq!(Normalizer::detect_language("カタカナ"), Language::Ja); // Katakana
        assert_eq!(Normalizer::detect_language("東京タワー"), Language::Ja); // Mixed Kanji + Katakana

        // Pure digits treated as Chinese (common TTS use case)
        assert_eq!(Normalizer::detect_language("123"), Language::Zh);
        assert_eq!(Normalizer::detect_language("2024"), Language::Zh);

        // Edge cases
        assert_eq!(Normalizer::detect_language(""), Language::En); // Empty defaults to English
    }
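
    // Extra coverage for should_normalize (not in the original suite); the
    // expected values follow directly from the logic above: digits or erhua
    // trigger TN, and any non-empty text triggers ITN.
    #[test]
    fn test_should_normalize() {
        let normalizer = Normalizer::with_defaults("fsts");

        // TN: digits trigger normalization
        assert!(normalizer.should_normalize("123", Operator::Tn, false));
        assert!(!normalizer.should_normalize("你好", Operator::Tn, false));

        // TN: erhua characters only matter when remove_erhua is set
        assert!(normalizer.should_normalize("花儿", Operator::Tn, true));
        assert!(!normalizer.should_normalize("花儿", Operator::Tn, false));

        // ITN: any non-empty text is processed
        assert!(normalizer.should_normalize("abc", Operator::Itn, false));
        assert!(!normalizer.should_normalize("", Operator::Itn, false));
    }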
}