//! supertonic_core/text.rs — text normalization, code-point id conversion, masks, latent noise sampling and text chunking.

1use ndarray::Array3;
2use regex::Regex;
3use std::fs::File;
4use std::io::BufReader;
5use std::path::Path;
6use unicode_normalization::UnicodeNormalization;
7
/// Language codes accepted by [`preprocess_text`], in the order they are reported in its error
/// message.
pub const AVAILABLE_LANGS: &[&str] = &[
    "en", "ko", "ja", "ar", "bg", "cs", "da", "de", "el", "es", "et", "fi", "fr", "hi", "hr", "hu",
    "id", "it", "lt", "lv", "nl", "pl", "pt", "ro", "ru", "sk", "sl", "sv", "tr", "uk", "vi", "na",
];

/// Returns `true` when `lang` exactly matches (case-sensitively) one of [`AVAILABLE_LANGS`].
pub fn is_valid_lang(lang: &str) -> bool {
    AVAILABLE_LANGS.iter().any(|&code| code == lang)
}
17
18pub struct UnicodeProcessor {
19    indexer: Vec<i64>,
20}
21
22impl UnicodeProcessor {
23    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, anyhow::Error> {
24        let file = File::open(path)?;
25        let reader = BufReader::new(file);
26        let indexer: Vec<i64> = serde_json::from_reader(reader)?;
27        Ok(UnicodeProcessor { indexer })
28    }
29
30    pub fn process(
31        &self,
32        text_list: &[String],
33        lang_list: &[String],
34    ) -> Result<(Vec<Vec<i64>>, Array3<f32>), anyhow::Error> {
35        let mut processed_texts: Vec<String> = Vec::new();
36        for (text, lang) in text_list.iter().zip(lang_list.iter()) {
37            processed_texts.push(preprocess_text(text, lang)?);
38        }
39
40        let text_ids_lengths: Vec<usize> = processed_texts
41            .iter()
42            .map(|t| t.chars().count())
43            .collect();
44
45        let max_len = *text_ids_lengths.iter().max().unwrap_or(&0);
46
47        let mut text_ids = Vec::new();
48        for text in &processed_texts {
49            let mut row = vec![0i64; max_len];
50            let unicode_vals = text_to_unicode_values(text);
51            for (j, &val) in unicode_vals.iter().enumerate() {
52                if val < self.indexer.len() {
53                    row[j] = self.indexer[val];
54                } else {
55                    row[j] = -1;
56                }
57            }
58            text_ids.push(row);
59        }
60
61        let text_mask = get_text_mask(&text_ids_lengths);
62
63        Ok((text_ids, text_mask))
64    }
65}
66
67pub fn preprocess_text(text: &str, lang: &str) -> Result<String, anyhow::Error> {
68    let mut text: String = text.nfkd().collect();
69
70    let emoji_pattern = Regex::new(
71        r"[\x{1F600}-\x{1F64F}\x{1F300}-\x{1F5FF}\x{1F680}-\x{1F6FF}\x{1F700}-\x{1F77F}\x{1F780}-\x{1F7FF}\x{1F800}-\x{1F8FF}\x{1F900}-\x{1F9FF}\x{1FA00}-\x{1FA6F}\x{1FA70}-\x{1FAFF}\x{2600}-\x{26FF}\x{2700}-\x{27BF}\x{1F1E6}-\x{1F1FF}]+",
72    )
73    .unwrap();
74    text = emoji_pattern.replace_all(&text, "").to_string();
75
76    let replacements = [
77        ("\u{2013}", "-"),
78        ("\u{2011}", "-"),
79        ("\u{2014}", "-"),
80        ("_", " "),
81        ("\u{201C}", "\""),
82        ("\u{201D}", "\""),
83        ("\u{2018}", "'"),
84        ("\u{2019}", "'"),
85        ("\u{00B4}", "'"),
86        ("`", "'"),
87        ("[", " "),
88        ("]", " "),
89        ("|", " "),
90        ("/", " "),
91        ("#", " "),
92        ("\u{2192}", " "),
93        ("\u{2190}", " "),
94    ];
95
96    for (from, to) in &replacements {
97        text = text.replace(from, to);
98    }
99
100    let special_symbols = ["\u{2665}", "\u{2606}", "\u{2661}", "\u{00A9}", "\\"];
101    for symbol in &special_symbols {
102        text = text.replace(symbol, "");
103    }
104
105    let expr_replacements = [("@", " at "), ("e.g.,", "for example, "), ("i.e.,", "that is, ")];
106    for (from, to) in &expr_replacements {
107        text = text.replace(from, to);
108    }
109
110    text = Regex::new(r" ,").unwrap().replace_all(&text, ",").to_string();
111    text = Regex::new(r" \.").unwrap().replace_all(&text, ".").to_string();
112    text = Regex::new(r" !").unwrap().replace_all(&text, "!").to_string();
113    text = Regex::new(r" \?").unwrap().replace_all(&text, "?").to_string();
114    text = Regex::new(r" ;").unwrap().replace_all(&text, ";").to_string();
115    text = Regex::new(r" :").unwrap().replace_all(&text, ":").to_string();
116    text = Regex::new(r" '").unwrap().replace_all(&text, "'").to_string();
117
118    while text.contains("\"\"") {
119        text = text.replace("\"\"", "\"");
120    }
121    while text.contains("''") {
122        text = text.replace("''", "'");
123    }
124    while text.contains("``") {
125        text = text.replace("``", "`");
126    }
127
128    text = Regex::new(r"\s+")
129        .unwrap()
130        .replace_all(&text, " ")
131        .to_string();
132    text = text.trim().to_string();
133
134    if !text.is_empty() {
135        let ends_with_punct = Regex::new(
136            r#"[.!?;:,'"\u{201C}\u{201D}\u{2018}\u{2019})\] »。』】〉》›»]$"#,
137        )
138        .unwrap();
139        if !ends_with_punct.is_match(&text) {
140            text.push('.');
141        }
142    }
143
144    if !is_valid_lang(lang) {
145        anyhow::bail!("Invalid language: {}. Available: {:?}", lang, AVAILABLE_LANGS);
146    }
147
148    text = format!("<{}>{}</{}>", lang, text, lang);
149
150    Ok(text)
151}
152
/// Returns the Unicode scalar value of every character in `text`, in order.
pub fn text_to_unicode_values(text: &str) -> Vec<usize> {
    // Byte length is an upper bound on the character count, so this never reallocates.
    let mut values = Vec::with_capacity(text.len());
    for ch in text.chars() {
        values.push(u32::from(ch) as usize);
    }
    values
}
156
157pub fn length_to_mask(lengths: &[usize], max_len: Option<usize>) -> Array3<f32> {
158    let bsz = lengths.len();
159    let max_len = max_len.unwrap_or_else(|| *lengths.iter().max().unwrap_or(&0));
160
161    let mut mask = Array3::<f32>::zeros((bsz, 1, max_len));
162    for (i, &len) in lengths.iter().enumerate() {
163        for j in 0..len.min(max_len) {
164            mask[[i, 0, j]] = 1.0;
165        }
166    }
167    mask
168}
169
170pub fn get_text_mask(text_ids_lengths: &[usize]) -> Array3<f32> {
171    let max_len = *text_ids_lengths.iter().max().unwrap_or(&0);
172    length_to_mask(text_ids_lengths, Some(max_len))
173}
174
175pub fn sample_noisy_latent(
176    duration: &[f32],
177    sample_rate: i32,
178    base_chunk_size: i32,
179    chunk_compress: i32,
180    latent_dim: i32,
181    rng_seed: Option<u64>,
182) -> (Array3<f32>, Array3<f32>) {
183    let bsz = duration.len();
184    let max_dur = duration.iter().fold(0.0f32, |a, &b| a.max(b));
185
186    let wav_len_max = (max_dur * sample_rate as f32) as usize;
187    let wav_lengths: Vec<usize> = duration
188        .iter()
189        .map(|&d| (d * sample_rate as f32) as usize)
190        .collect();
191
192    let chunk_size = (base_chunk_size * chunk_compress) as usize;
193    let latent_len = (wav_len_max + chunk_size - 1) / chunk_size;
194    let latent_dim_val = (latent_dim * chunk_compress) as usize;
195
196    let mut noisy_latent = Array3::<f32>::zeros((bsz, latent_dim_val, latent_len));
197
198    use rand::SeedableRng;
199    use rand_distr::{Distribution, Normal};
200    let mut rng = if let Some(seed) = rng_seed {
201        rand::rngs::StdRng::seed_from_u64(seed)
202    } else {
203        rand::rngs::StdRng::from_entropy()
204    };
205    let normal = Normal::new(0.0, 1.0).unwrap();
206
207    for b in 0..bsz {
208        for d in 0..latent_dim_val {
209            for t in 0..latent_len {
210                noisy_latent[[b, d, t]] = normal.sample(&mut rng);
211            }
212        }
213    }
214
215    let latent_lengths: Vec<usize> = wav_lengths
216        .iter()
217        .map(|&len| (len + chunk_size - 1) / chunk_size)
218        .collect();
219
220    let latent_mask = length_to_mask(&latent_lengths, Some(latent_len));
221
222    for b in 0..bsz {
223        for d in 0..latent_dim_val {
224            for t in 0..latent_len {
225                noisy_latent[[b, d, t]] *= latent_mask[[b, 0, t]];
226            }
227        }
228    }
229
230    (noisy_latent, latent_mask)
231}
232
/// Default maximum chunk length used by `chunk_text` when no limit is supplied.
const MAX_CHUNK_LENGTH: usize = 300;

/// Abbreviations whose trailing `.` must not be treated as a sentence boundary by
/// `split_sentences` (matched case-sensitively as a suffix).
const ABBREVIATIONS: &[&str] = &[
    "Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.", "St.", "Ave.", "Rd.", "Blvd.", "Dept.",
    "Inc.", "Ltd.", "Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D.",
];
240
241pub fn chunk_text(text: &str, max_len: Option<usize>) -> Vec<String> {
242    let max_len = max_len.unwrap_or(MAX_CHUNK_LENGTH);
243    let text = text.trim();
244
245    if text.is_empty() {
246        return vec![String::new()];
247    }
248
249    let para_re = Regex::new(r"\n\s*\n").unwrap();
250    let paragraphs: Vec<&str> = para_re.split(text).collect();
251    let mut chunks = Vec::new();
252
253    for para in paragraphs {
254        let para = para.trim();
255        if para.is_empty() {
256            continue;
257        }
258
259        if para.len() <= max_len {
260            chunks.push(para.to_string());
261            continue;
262        }
263
264        let sentences = split_sentences(para);
265        let mut current = String::new();
266        let mut current_len = 0;
267
268        for sentence in sentences {
269            let sentence = sentence.trim();
270            if sentence.is_empty() {
271                continue;
272            }
273
274            let sentence_len = sentence.len();
275            if sentence_len > max_len {
276                if !current.is_empty() {
277                    chunks.push(current.trim().to_string());
278                    current.clear();
279                    current_len = 0;
280                }
281
282                let parts: Vec<&str> = sentence.split(',').collect();
283                for part in parts {
284                    let part = part.trim();
285                    if part.is_empty() {
286                        continue;
287                    }
288
289                    let part_len = part.len();
290                    if part_len > max_len {
291                        let words: Vec<&str> = part.split_whitespace().collect();
292                        let mut word_chunk = String::new();
293                        let mut word_chunk_len = 0;
294
295                        for word in words {
296                            let word_len = word.len();
297                            if word_chunk_len + word_len + 1 > max_len && !word_chunk.is_empty() {
298                                chunks.push(word_chunk.trim().to_string());
299                                word_chunk.clear();
300                                word_chunk_len = 0;
301                            }
302
303                            if !word_chunk.is_empty() {
304                                word_chunk.push(' ');
305                                word_chunk_len += 1;
306                            }
307                            word_chunk.push_str(word);
308                            word_chunk_len += word_len;
309                        }
310
311                        if !word_chunk.is_empty() {
312                            chunks.push(word_chunk.trim().to_string());
313                        }
314                    } else {
315                        if current_len + part_len + 1 > max_len && !current.is_empty() {
316                            chunks.push(current.trim().to_string());
317                            current.clear();
318                            current_len = 0;
319                        }
320
321                        if !current.is_empty() {
322                            current.push_str(", ");
323                            current_len += 2;
324                        }
325                        current.push_str(part);
326                        current_len += part_len;
327                    }
328                }
329                continue;
330            }
331
332            if current_len + sentence_len + 1 > max_len && !current.is_empty() {
333                chunks.push(current.trim().to_string());
334                current.clear();
335                current_len = 0;
336            }
337
338            if !current.is_empty() {
339                current.push(' ');
340                current_len += 1;
341            }
342            current.push_str(sentence);
343            current_len += sentence_len;
344        }
345
346        if !current.is_empty() {
347            chunks.push(current.trim().to_string());
348        }
349    }
350
351    if chunks.is_empty() {
352        vec![String::new()]
353    } else {
354        chunks
355    }
356}
357
358fn split_sentences(text: &str) -> Vec<String> {
359    let re = Regex::new(r"([.!?])\s+").unwrap();
360    let matches: Vec<_> = re.find_iter(text).collect();
361
362    if matches.is_empty() {
363        return vec![text.to_string()];
364    }
365
366    let mut sentences = Vec::new();
367    let mut last_end = 0;
368
369    for m in matches {
370        let before_punc = &text[last_end..m.start()];
371        let mut is_abbrev = false;
372        for abbrev in ABBREVIATIONS {
373            let combined = format!("{}{}", before_punc.trim(), &text[m.start()..m.start() + 1]);
374            if combined.ends_with(abbrev) {
375                is_abbrev = true;
376                break;
377            }
378        }
379
380        if !is_abbrev {
381            sentences.push(text[last_end..m.end()].to_string());
382            last_end = m.end();
383        }
384    }
385
386    if last_end < text.len() {
387        sentences.push(text[last_end..].to_string());
388    }
389
390    if sentences.is_empty() {
391        vec![text.to_string()]
392    } else {
393        sentences
394    }
395}
396
397pub fn max_chunk_len_for_lang(lang: &str) -> usize {
398    if lang == "ko" || lang == "ja" {
399        120
400    } else {
401        300
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `preprocess_text` and unwraps, for inputs that are expected to succeed.
    fn pp(text: &str, lang: &str) -> String {
        preprocess_text(text, lang).unwrap()
    }

    #[test]
    fn test_preprocess_text_adds_lang_tags() {
        assert_eq!(pp("Hello.", "en"), "<en>Hello.</en>");
    }

    #[test]
    fn test_preprocess_text_adds_period() {
        assert_eq!(pp("Hello", "en"), "<en>Hello.</en>");
    }

    #[test]
    fn test_preprocess_text_removes_emoji() {
        assert_eq!(pp("Hi 😊.", "en"), "<en>Hi.</en>");
    }

    #[test]
    fn test_is_valid_lang() {
        for &code in &["en", "ko"] {
            assert!(is_valid_lang(code), "{} should be accepted", code);
        }
        assert!(!is_valid_lang("zz"));
    }

    #[test]
    fn test_chunk_text_short() {
        assert_eq!(chunk_text("Hello world.", Some(300)).len(), 1);
    }

    #[test]
    fn test_text_to_unicode_values() {
        assert_eq!(text_to_unicode_values("A"), vec![65]);
    }

    #[test]
    fn test_length_to_mask() {
        let mask = length_to_mask(&[3], Some(5));
        for &(col, expected) in &[(0usize, 1.0f32), (2, 1.0), (3, 0.0)] {
            assert_eq!(mask[[0, 0, col]], expected, "column {}", col);
        }
    }

    #[test]
    fn test_max_chunk_len_for_lang() {
        for &(lang, expected) in &[("en", 300usize), ("ko", 120), ("ja", 120)] {
            assert_eq!(max_chunk_len_for_lang(lang), expected, "lang {}", lang);
        }
    }
}