1use ndarray::Array3;
2use regex::Regex;
3use std::fs::File;
4use std::io::BufReader;
5use std::path::Path;
6use unicode_normalization::UnicodeNormalization;
7
/// Language codes accepted by [`preprocess_text`], in the order they are reported
/// in its error message.
///
/// NOTE(review): the meaning of `"na"` is not evident from this file — confirm.
pub const AVAILABLE_LANGS: &[&str] = &[
    "en", "ko", "ja", "ar", "bg", "cs", "da", "de",
    "el", "es", "et", "fi", "fr", "hi", "hr", "hu",
    "id", "it", "lt", "lv", "nl", "pl", "pt", "ro",
    "ru", "sk", "sl", "sv", "tr", "uk", "vi", "na",
];

/// Returns `true` when `lang` exactly (case-sensitively) equals one of [`AVAILABLE_LANGS`].
pub fn is_valid_lang(lang: &str) -> bool {
    AVAILABLE_LANGS.iter().any(|&code| code == lang)
}
17
18pub struct UnicodeProcessor {
19 indexer: Vec<i64>,
20}
21
22impl UnicodeProcessor {
23 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, anyhow::Error> {
24 let file = File::open(path)?;
25 let reader = BufReader::new(file);
26 let indexer: Vec<i64> = serde_json::from_reader(reader)?;
27 Ok(UnicodeProcessor { indexer })
28 }
29
30 pub fn process(
31 &self,
32 text_list: &[String],
33 lang_list: &[String],
34 ) -> Result<(Vec<Vec<i64>>, Array3<f32>), anyhow::Error> {
35 let mut processed_texts: Vec<String> = Vec::new();
36 for (text, lang) in text_list.iter().zip(lang_list.iter()) {
37 processed_texts.push(preprocess_text(text, lang)?);
38 }
39
40 let text_ids_lengths: Vec<usize> = processed_texts
41 .iter()
42 .map(|t| t.chars().count())
43 .collect();
44
45 let max_len = *text_ids_lengths.iter().max().unwrap_or(&0);
46
47 let mut text_ids = Vec::new();
48 for text in &processed_texts {
49 let mut row = vec![0i64; max_len];
50 let unicode_vals = text_to_unicode_values(text);
51 for (j, &val) in unicode_vals.iter().enumerate() {
52 if val < self.indexer.len() {
53 row[j] = self.indexer[val];
54 } else {
55 row[j] = -1;
56 }
57 }
58 text_ids.push(row);
59 }
60
61 let text_mask = get_text_mask(&text_ids_lengths);
62
63 Ok((text_ids, text_mask))
64 }
65}
66
67pub fn preprocess_text(text: &str, lang: &str) -> Result<String, anyhow::Error> {
68 let mut text: String = text.nfkd().collect();
69
70 let emoji_pattern = Regex::new(
71 r"[\x{1F600}-\x{1F64F}\x{1F300}-\x{1F5FF}\x{1F680}-\x{1F6FF}\x{1F700}-\x{1F77F}\x{1F780}-\x{1F7FF}\x{1F800}-\x{1F8FF}\x{1F900}-\x{1F9FF}\x{1FA00}-\x{1FA6F}\x{1FA70}-\x{1FAFF}\x{2600}-\x{26FF}\x{2700}-\x{27BF}\x{1F1E6}-\x{1F1FF}]+",
72 )
73 .unwrap();
74 text = emoji_pattern.replace_all(&text, "").to_string();
75
76 let replacements = [
77 ("\u{2013}", "-"),
78 ("\u{2011}", "-"),
79 ("\u{2014}", "-"),
80 ("_", " "),
81 ("\u{201C}", "\""),
82 ("\u{201D}", "\""),
83 ("\u{2018}", "'"),
84 ("\u{2019}", "'"),
85 ("\u{00B4}", "'"),
86 ("`", "'"),
87 ("[", " "),
88 ("]", " "),
89 ("|", " "),
90 ("/", " "),
91 ("#", " "),
92 ("\u{2192}", " "),
93 ("\u{2190}", " "),
94 ];
95
96 for (from, to) in &replacements {
97 text = text.replace(from, to);
98 }
99
100 let special_symbols = ["\u{2665}", "\u{2606}", "\u{2661}", "\u{00A9}", "\\"];
101 for symbol in &special_symbols {
102 text = text.replace(symbol, "");
103 }
104
105 let expr_replacements = [("@", " at "), ("e.g.,", "for example, "), ("i.e.,", "that is, ")];
106 for (from, to) in &expr_replacements {
107 text = text.replace(from, to);
108 }
109
110 text = Regex::new(r" ,").unwrap().replace_all(&text, ",").to_string();
111 text = Regex::new(r" \.").unwrap().replace_all(&text, ".").to_string();
112 text = Regex::new(r" !").unwrap().replace_all(&text, "!").to_string();
113 text = Regex::new(r" \?").unwrap().replace_all(&text, "?").to_string();
114 text = Regex::new(r" ;").unwrap().replace_all(&text, ";").to_string();
115 text = Regex::new(r" :").unwrap().replace_all(&text, ":").to_string();
116 text = Regex::new(r" '").unwrap().replace_all(&text, "'").to_string();
117
118 while text.contains("\"\"") {
119 text = text.replace("\"\"", "\"");
120 }
121 while text.contains("''") {
122 text = text.replace("''", "'");
123 }
124 while text.contains("``") {
125 text = text.replace("``", "`");
126 }
127
128 text = Regex::new(r"\s+")
129 .unwrap()
130 .replace_all(&text, " ")
131 .to_string();
132 text = text.trim().to_string();
133
134 if !text.is_empty() {
135 let ends_with_punct = Regex::new(
136 r#"[.!?;:,'"\u{201C}\u{201D}\u{2018}\u{2019})\] »。』】〉》›»]$"#,
137 )
138 .unwrap();
139 if !ends_with_punct.is_match(&text) {
140 text.push('.');
141 }
142 }
143
144 if !is_valid_lang(lang) {
145 anyhow::bail!("Invalid language: {}. Available: {:?}", lang, AVAILABLE_LANGS);
146 }
147
148 text = format!("<{}>{}</{}>", lang, text, lang);
149
150 Ok(text)
151}
152
/// Returns the Unicode scalar value of each `char` in `text`, in order, widened to `usize`.
pub fn text_to_unicode_values(text: &str) -> Vec<usize> {
    text.chars()
        .map(u32::from)
        .map(|code_point| code_point as usize)
        .collect()
}
156
157pub fn length_to_mask(lengths: &[usize], max_len: Option<usize>) -> Array3<f32> {
158 let bsz = lengths.len();
159 let max_len = max_len.unwrap_or_else(|| *lengths.iter().max().unwrap_or(&0));
160
161 let mut mask = Array3::<f32>::zeros((bsz, 1, max_len));
162 for (i, &len) in lengths.iter().enumerate() {
163 for j in 0..len.min(max_len) {
164 mask[[i, 0, j]] = 1.0;
165 }
166 }
167 mask
168}
169
170pub fn get_text_mask(text_ids_lengths: &[usize]) -> Array3<f32> {
171 let max_len = *text_ids_lengths.iter().max().unwrap_or(&0);
172 length_to_mask(text_ids_lengths, Some(max_len))
173}
174
175pub fn sample_noisy_latent(
176 duration: &[f32],
177 sample_rate: i32,
178 base_chunk_size: i32,
179 chunk_compress: i32,
180 latent_dim: i32,
181 rng_seed: Option<u64>,
182) -> (Array3<f32>, Array3<f32>) {
183 let bsz = duration.len();
184 let max_dur = duration.iter().fold(0.0f32, |a, &b| a.max(b));
185
186 let wav_len_max = (max_dur * sample_rate as f32) as usize;
187 let wav_lengths: Vec<usize> = duration
188 .iter()
189 .map(|&d| (d * sample_rate as f32) as usize)
190 .collect();
191
192 let chunk_size = (base_chunk_size * chunk_compress) as usize;
193 let latent_len = (wav_len_max + chunk_size - 1) / chunk_size;
194 let latent_dim_val = (latent_dim * chunk_compress) as usize;
195
196 let mut noisy_latent = Array3::<f32>::zeros((bsz, latent_dim_val, latent_len));
197
198 use rand::SeedableRng;
199 use rand_distr::{Distribution, Normal};
200 let mut rng = if let Some(seed) = rng_seed {
201 rand::rngs::StdRng::seed_from_u64(seed)
202 } else {
203 rand::rngs::StdRng::from_entropy()
204 };
205 let normal = Normal::new(0.0, 1.0).unwrap();
206
207 for b in 0..bsz {
208 for d in 0..latent_dim_val {
209 for t in 0..latent_len {
210 noisy_latent[[b, d, t]] = normal.sample(&mut rng);
211 }
212 }
213 }
214
215 let latent_lengths: Vec<usize> = wav_lengths
216 .iter()
217 .map(|&len| (len + chunk_size - 1) / chunk_size)
218 .collect();
219
220 let latent_mask = length_to_mask(&latent_lengths, Some(latent_len));
221
222 for b in 0..bsz {
223 for d in 0..latent_dim_val {
224 for t in 0..latent_len {
225 noisy_latent[[b, d, t]] *= latent_mask[[b, 0, t]];
226 }
227 }
228 }
229
230 (noisy_latent, latent_mask)
231}
232
/// Default `max_len` used by `chunk_text` when the caller passes `None`.
const MAX_CHUNK_LENGTH: usize = 300;

/// Tokens whose trailing `.` must not be treated as a sentence boundary by
/// `split_sentences`. Matching against text is case-sensitive.
const ABBREVIATIONS: &[&str] = &[
    "Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.",
    "St.", "Ave.", "Rd.", "Blvd.", "Dept.", "Inc.", "Ltd.",
    "Co.", "Corp.", "etc.", "vs.", "i.e.", "e.g.", "Ph.D.",
];
240
241pub fn chunk_text(text: &str, max_len: Option<usize>) -> Vec<String> {
242 let max_len = max_len.unwrap_or(MAX_CHUNK_LENGTH);
243 let text = text.trim();
244
245 if text.is_empty() {
246 return vec![String::new()];
247 }
248
249 let para_re = Regex::new(r"\n\s*\n").unwrap();
250 let paragraphs: Vec<&str> = para_re.split(text).collect();
251 let mut chunks = Vec::new();
252
253 for para in paragraphs {
254 let para = para.trim();
255 if para.is_empty() {
256 continue;
257 }
258
259 if para.len() <= max_len {
260 chunks.push(para.to_string());
261 continue;
262 }
263
264 let sentences = split_sentences(para);
265 let mut current = String::new();
266 let mut current_len = 0;
267
268 for sentence in sentences {
269 let sentence = sentence.trim();
270 if sentence.is_empty() {
271 continue;
272 }
273
274 let sentence_len = sentence.len();
275 if sentence_len > max_len {
276 if !current.is_empty() {
277 chunks.push(current.trim().to_string());
278 current.clear();
279 current_len = 0;
280 }
281
282 let parts: Vec<&str> = sentence.split(',').collect();
283 for part in parts {
284 let part = part.trim();
285 if part.is_empty() {
286 continue;
287 }
288
289 let part_len = part.len();
290 if part_len > max_len {
291 let words: Vec<&str> = part.split_whitespace().collect();
292 let mut word_chunk = String::new();
293 let mut word_chunk_len = 0;
294
295 for word in words {
296 let word_len = word.len();
297 if word_chunk_len + word_len + 1 > max_len && !word_chunk.is_empty() {
298 chunks.push(word_chunk.trim().to_string());
299 word_chunk.clear();
300 word_chunk_len = 0;
301 }
302
303 if !word_chunk.is_empty() {
304 word_chunk.push(' ');
305 word_chunk_len += 1;
306 }
307 word_chunk.push_str(word);
308 word_chunk_len += word_len;
309 }
310
311 if !word_chunk.is_empty() {
312 chunks.push(word_chunk.trim().to_string());
313 }
314 } else {
315 if current_len + part_len + 1 > max_len && !current.is_empty() {
316 chunks.push(current.trim().to_string());
317 current.clear();
318 current_len = 0;
319 }
320
321 if !current.is_empty() {
322 current.push_str(", ");
323 current_len += 2;
324 }
325 current.push_str(part);
326 current_len += part_len;
327 }
328 }
329 continue;
330 }
331
332 if current_len + sentence_len + 1 > max_len && !current.is_empty() {
333 chunks.push(current.trim().to_string());
334 current.clear();
335 current_len = 0;
336 }
337
338 if !current.is_empty() {
339 current.push(' ');
340 current_len += 1;
341 }
342 current.push_str(sentence);
343 current_len += sentence_len;
344 }
345
346 if !current.is_empty() {
347 chunks.push(current.trim().to_string());
348 }
349 }
350
351 if chunks.is_empty() {
352 vec![String::new()]
353 } else {
354 chunks
355 }
356}
357
358fn split_sentences(text: &str) -> Vec<String> {
359 let re = Regex::new(r"([.!?])\s+").unwrap();
360 let matches: Vec<_> = re.find_iter(text).collect();
361
362 if matches.is_empty() {
363 return vec![text.to_string()];
364 }
365
366 let mut sentences = Vec::new();
367 let mut last_end = 0;
368
369 for m in matches {
370 let before_punc = &text[last_end..m.start()];
371 let mut is_abbrev = false;
372 for abbrev in ABBREVIATIONS {
373 let combined = format!("{}{}", before_punc.trim(), &text[m.start()..m.start() + 1]);
374 if combined.ends_with(abbrev) {
375 is_abbrev = true;
376 break;
377 }
378 }
379
380 if !is_abbrev {
381 sentences.push(text[last_end..m.end()].to_string());
382 last_end = m.end();
383 }
384 }
385
386 if last_end < text.len() {
387 sentences.push(text[last_end..].to_string());
388 }
389
390 if sentences.is_empty() {
391 vec![text.to_string()]
392 } else {
393 sentences
394 }
395}
396
/// Returns the chunk-length limit to use for `lang`: 120 for `"ko"` and `"ja"`,
/// 300 for every other code.
pub fn max_chunk_len_for_lang(lang: &str) -> usize {
    match lang {
        "ko" | "ja" => 120,
        _ => 300,
    }
}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `preprocess_text` for English, which is always a supported language.
    fn en(text: &str) -> String {
        preprocess_text(text, "en").expect("\"en\" is a supported language")
    }

    #[test]
    fn test_preprocess_text_adds_lang_tags() {
        // Already-terminated text is only wrapped in language tags.
        assert_eq!(en("Hello."), "<en>Hello.</en>");
    }

    #[test]
    fn test_preprocess_text_adds_period() {
        // Missing terminal punctuation gets a trailing period.
        assert_eq!(en("Hello"), "<en>Hello.</en>");
    }

    #[test]
    fn test_preprocess_text_removes_emoji() {
        // The emoji is dropped and the space it leaves before "." is removed.
        assert_eq!(en("Hi 😊."), "<en>Hi.</en>");
    }

    #[test]
    fn test_is_valid_lang() {
        for supported in &["en", "ko"] {
            assert!(is_valid_lang(supported), "{} should be supported", supported);
        }
        assert!(!is_valid_lang("zz"));
    }

    #[test]
    fn test_chunk_text_short() {
        assert_eq!(chunk_text("Hello world.", Some(300)).len(), 1);
    }

    #[test]
    fn test_text_to_unicode_values() {
        assert_eq!(text_to_unicode_values("A"), vec![65]);
    }

    #[test]
    fn test_length_to_mask() {
        let mask = length_to_mask(&[3], Some(5));
        assert!((0..3).all(|j| mask[[0, 0, j]] == 1.0));
        assert_eq!(mask[[0, 0, 3]], 0.0);
    }

    #[test]
    fn test_max_chunk_len_for_lang() {
        let cases = [("en", 300), ("ko", 120), ("ja", 120)];
        for &(lang, expected) in &cases {
            assert_eq!(max_chunk_len_for_lang(lang), expected, "lang = {}", lang);
        }
    }
}