Skip to main content

voirs_cli/synthesis/
input_detector.rs

1//! Smart text input detection and parsing.
2//!
3//! This module automatically detects the format of input text and extracts
4//! relevant synthesis parameters and content.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use voirs_sdk::Result;
9
10/// Detected input format
11#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12pub enum InputFormat {
13    /// Plain text without any markup
14    PlainText,
15    /// SSML (Speech Synthesis Markup Language)
16    Ssml,
17    /// Markdown with TTS hints
18    Markdown,
19    /// JSON structured input
20    Json,
21}
22
23/// Parsed input with extracted metadata
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ParsedInput {
26    /// Detected format
27    pub format: InputFormat,
28    /// Actual text content to synthesize
29    pub content: String,
30    /// Extracted synthesis parameters
31    pub parameters: SynthesisParameters,
32    /// Additional metadata
33    pub metadata: HashMap<String, String>,
34}
35
36/// Synthesis parameters extracted from input
37#[derive(Debug, Clone, Default, Serialize, Deserialize)]
38pub struct SynthesisParameters {
39    /// Voice ID if specified
40    pub voice: Option<String>,
41    /// Speaking rate if specified
42    pub rate: Option<f32>,
43    /// Pitch adjustment if specified
44    pub pitch: Option<f32>,
45    /// Volume adjustment if specified
46    pub volume: Option<f32>,
47    /// Emotion if specified
48    pub emotion: Option<String>,
49    /// Language if specified
50    pub language: Option<String>,
51}
52
53/// Detects the format of input text
54pub fn detect_format(input: &str) -> InputFormat {
55    let trimmed = input.trim();
56
57    // Check for SSML
58    if trimmed.starts_with("<speak") && trimmed.ends_with("</speak>") {
59        return InputFormat::Ssml;
60    }
61
62    // Check for JSON (simple heuristic)
63    if (trimmed.starts_with('{') && trimmed.ends_with('}'))
64        || (trimmed.starts_with('[') && trimmed.ends_with(']'))
65    {
66        // Try to parse as JSON to confirm
67        if serde_json::from_str::<serde_json::Value>(trimmed).is_ok() {
68            return InputFormat::Json;
69        }
70    }
71
72    // Check for Markdown patterns
73    if contains_markdown_syntax(trimmed) {
74        return InputFormat::Markdown;
75    }
76
77    // Default to plain text
78    InputFormat::PlainText
79}
80
/// Checks if text contains Markdown syntax
///
/// Returns `true` when at least one of the following is found:
///
/// * a code fence (three backticks) anywhere in the text,
/// * a line that, ignoring indentation, starts with a `<!-- tts:` hint,
///   a list marker (`* `, `- `, `+ `) or an ATX header (`#` to `######`
///   followed by a space),
/// * a pair of asterisk emphasis delimiters on one line (`*text*`, `**text**`),
/// * an inline link `[text](url)` on one line.
///
/// Markers are anchored to their Markdown position on purpose: a bare
/// substring search would classify ordinary prose such as `5 * 3`,
/// `C# is ...` or `well - really` as Markdown.
fn contains_markdown_syntax(text: &str) -> bool {
    text.contains("```")
        || text.lines().any(|line| {
            starts_markdown_block(line.trim_start())
                || has_asterisk_emphasis(line)
                || has_inline_link(line)
        })
}

/// True when `line` (already stripped of leading whitespace) opens with a
/// TTS hint comment, a list marker, or an ATX header.
fn starts_markdown_block(line: &str) -> bool {
    if line.starts_with("<!-- tts:") {
        return true;
    }
    if ["* ", "- ", "+ "]
        .iter()
        .any(|&marker| line.starts_with(marker))
    {
        return true;
    }
    let after_hashes = line.trim_start_matches('#');
    let level = line.len() - after_hashes.len();
    (1..=6).contains(&level) && after_hashes.starts_with(' ')
}

/// True when `line` contains an opening `*` delimiter followed later by a
/// closing one.
///
/// An opener must be followed by text and must not sit directly after an
/// alphanumeric character; a closer must follow text and must not sit directly
/// before an alphanumeric character. This keeps `5 * 3` and `2*3*4` from
/// counting as emphasis.
fn has_asterisk_emphasis(line: &str) -> bool {
    let chars: Vec<char> = line.chars().collect();
    let mut opened = false;
    for (i, &c) in chars.iter().enumerate() {
        if c != '*' {
            continue;
        }
        let prev = i.checked_sub(1).map(|p| chars[p]);
        let next = chars.get(i + 1).copied();
        let can_close = prev.map_or(false, |p| !p.is_whitespace() && p != '*')
            && next.map_or(true, |n| !n.is_alphanumeric());
        if opened && can_close {
            return true;
        }
        let can_open = next.map_or(false, |n| !n.is_whitespace() && n != '*')
            && prev.map_or(true, |p| !p.is_alphanumeric());
        if can_open {
            opened = true;
        }
    }
    false
}

/// True when `line` contains `[`, then `](`, then `)` in that order.
fn has_inline_link(line: &str) -> bool {
    line.find('[')
        .and_then(|open| line[open..].find("](").map(|mid| open + mid))
        .map_or(false, |mid| line[mid..].contains(')'))
}
93
94/// Parses input text and extracts synthesis parameters
95pub fn parse_input(input: &str) -> Result<ParsedInput> {
96    let format = detect_format(input);
97
98    match format {
99        InputFormat::PlainText => parse_plain_text(input),
100        InputFormat::Ssml => parse_ssml(input),
101        InputFormat::Markdown => parse_markdown(input),
102        InputFormat::Json => parse_json(input),
103    }
104}
105
106/// Parses plain text input
107fn parse_plain_text(input: &str) -> Result<ParsedInput> {
108    Ok(ParsedInput {
109        format: InputFormat::PlainText,
110        content: input.to_string(),
111        parameters: SynthesisParameters::default(),
112        metadata: HashMap::new(),
113    })
114}
115
116/// Parses SSML input
117fn parse_ssml(input: &str) -> Result<ParsedInput> {
118    // For now, pass SSML through directly
119    // Future: Extract attributes from SSML tags for parameters
120    let mut parameters = SynthesisParameters::default();
121    let mut metadata = HashMap::new();
122
123    // Extract voice from SSML if present
124    if let Some(voice_start) = input.find("voice name=\"") {
125        if let Some(voice_end) = input[voice_start + 12..].find('"') {
126            let voice = &input[voice_start + 12..voice_start + 12 + voice_end];
127            parameters.voice = Some(voice.to_string());
128        }
129    }
130
131    // Extract language from SSML if present
132    if let Some(lang_start) = input.find("xml:lang=\"") {
133        if let Some(lang_end) = input[lang_start + 10..].find('"') {
134            let lang = &input[lang_start + 10..lang_start + 10 + lang_end];
135            parameters.language = Some(lang.to_string());
136        }
137    }
138
139    metadata.insert("original_format".to_string(), "ssml".to_string());
140
141    Ok(ParsedInput {
142        format: InputFormat::Ssml,
143        content: input.to_string(),
144        parameters,
145        metadata,
146    })
147}
148
149/// Parses Markdown input with TTS hints
150fn parse_markdown(input: &str) -> Result<ParsedInput> {
151    let mut content = String::new();
152    let mut parameters = SynthesisParameters::default();
153    let mut metadata = HashMap::new();
154
155    // Look for TTS hints in comments (<!-- tts: ... -->)
156    let lines: Vec<&str> = input.lines().collect();
157    let mut skip_next = false;
158
159    for line in &lines {
160        let trimmed = line.trim();
161
162        // Parse TTS hints
163        if trimmed.starts_with("<!-- tts:") && trimmed.ends_with("-->") {
164            let hint = &trimmed[9..trimmed.len() - 3].trim();
165            parse_tts_hint(hint, &mut parameters);
166            skip_next = false;
167            continue;
168        }
169
170        // Skip code blocks
171        if trimmed.starts_with("```") {
172            skip_next = !skip_next;
173            continue;
174        }
175
176        if skip_next {
177            continue;
178        }
179
180        // Remove Markdown formatting for TTS
181        let cleaned = clean_markdown_line(line);
182        if !cleaned.is_empty() {
183            content.push_str(&cleaned);
184            content.push(' ');
185        }
186    }
187
188    metadata.insert("original_format".to_string(), "markdown".to_string());
189
190    Ok(ParsedInput {
191        format: InputFormat::Markdown,
192        content: content.trim().to_string(),
193        parameters,
194        metadata,
195    })
196}
197
198/// Parses TTS hint from Markdown comment
199fn parse_tts_hint(hint: &str, parameters: &mut SynthesisParameters) {
200    for part in hint.split(',') {
201        let kv: Vec<&str> = part.trim().splitn(2, '=').collect();
202        if kv.len() == 2 {
203            let key = kv[0].trim();
204            let value = kv[1].trim();
205
206            match key {
207                "voice" => parameters.voice = Some(value.to_string()),
208                "rate" => parameters.rate = value.parse().ok(),
209                "pitch" => parameters.pitch = value.parse().ok(),
210                "volume" => parameters.volume = value.parse().ok(),
211                "emotion" => parameters.emotion = Some(value.to_string()),
212                "language" => parameters.language = Some(value.to_string()),
213                _ => {}
214            }
215        }
216    }
217}
218
/// Cleans Markdown syntax from a line for TTS
///
/// Applied in order:
///
/// 1. Leading whitespace, then one ATX header marker (`#` to `######`
///    followed by a space), then one list marker (`* `, `- `, `+ `).
/// 2. Inline links and images: `[text](url)` and `![alt](url)` become
///    `text` / `alt`.
/// 3. Bold markers `**` and `__`.
/// 4. Paired single-asterisk emphasis (`*text*`). Lone asterisks such as
///    the one in `5 * 3` are kept.
///
/// The result is trimmed; underscore italics (`_text_`) are left alone so
/// that identifiers like `snake_case` are not damaged.
fn clean_markdown_line(line: &str) -> String {
    let mut rest = line.trim_start();

    // Header marker.
    let after_hashes = rest.trim_start_matches('#');
    let level = rest.len() - after_hashes.len();
    if (1..=6).contains(&level) && after_hashes.starts_with(' ') {
        rest = after_hashes.trim_start();
    }

    // List marker.
    if let Some(item) = ["* ", "- ", "+ "]
        .iter()
        .find_map(|&marker| rest.strip_prefix(marker))
    {
        rest = item.trim_start();
    }

    // Links go first so that characters inside a discarded URL cannot be
    // mistaken for emphasis delimiters.
    let without_links = unwrap_markdown_links(rest);
    let without_bold = without_links.replace("**", "").replace("__", "");
    strip_asterisk_emphasis(&without_bold).trim().to_string()
}

/// Replaces every `[label](target)` (or `![label](target)`) with `label`.
///
/// The label is the bracket group that ends right at `](`, so unrelated
/// brackets earlier in the line (`[note] see [docs](url)`) are left intact.
/// Text without a complete link is copied through unchanged.
fn unwrap_markdown_links(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut remaining = text;

    while let Some(mid) = remaining.find("](") {
        let open = link_label_start(&remaining[..mid]);
        let close = remaining[mid + 2..].find(')');
        match (open, close) {
            (Some(open), Some(close)) => {
                let before = &remaining[..open];
                // A '!' directly before the label marks an image.
                out.push_str(before.strip_suffix('!').unwrap_or(before));
                out.push_str(&remaining[open + 1..mid]);
                remaining = &remaining[mid + 2 + close + 1..];
            }
            _ => {
                // Not a complete link: keep the literal text and move on.
                out.push_str(&remaining[..mid + 2]);
                remaining = &remaining[mid + 2..];
            }
        }
    }

    out.push_str(remaining);
    out
}

/// Byte index of the `[` that opens the bracket group ending at the end of
/// `prefix`, honouring nested `[...]` pairs; `None` if there is none.
fn link_label_start(prefix: &str) -> Option<usize> {
    let mut depth = 0usize;
    for (idx, ch) in prefix.char_indices().rev() {
        match ch {
            ']' => depth += 1,
            '[' if depth == 0 => return Some(idx),
            '[' => depth -= 1,
            _ => {}
        }
    }
    None
}

/// Removes matched pairs of single-asterisk emphasis delimiters.
///
/// An opener must be followed by text and must not directly follow an
/// alphanumeric character; a closer must follow text and must not directly
/// precede an alphanumeric character. Unpaired asterisks are kept.
fn strip_asterisk_emphasis(text: &str) -> String {
    if !text.contains('*') {
        return text.to_string();
    }

    let chars: Vec<char> = text.chars().collect();
    let mut keep = vec![true; chars.len()];
    let mut opener: Option<usize> = None;

    for (i, &c) in chars.iter().enumerate() {
        if c != '*' {
            continue;
        }
        let prev = i.checked_sub(1).map(|p| chars[p]);
        let next = chars.get(i + 1).copied();
        let can_close = prev.map_or(false, |p| !p.is_whitespace() && p != '*')
            && next.map_or(true, |n| !n.is_alphanumeric());
        let can_open = next.map_or(false, |n| !n.is_whitespace() && n != '*')
            && prev.map_or(true, |p| !p.is_alphanumeric());

        match opener {
            Some(open) if can_close => {
                keep[open] = false;
                keep[i] = false;
                opener = None;
            }
            _ if can_open => opener = Some(i),
            _ => {}
        }
    }

    chars
        .into_iter()
        .zip(keep)
        .filter_map(|(c, kept)| if kept { Some(c) } else { None })
        .collect()
}
262
263/// Parses JSON structured input
264fn parse_json(input: &str) -> Result<ParsedInput> {
265    let value: serde_json::Value = serde_json::from_str(input)
266        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Invalid JSON input: {}", e)))?;
267
268    let mut parameters = SynthesisParameters::default();
269    let mut metadata = HashMap::new();
270
271    // Extract text content
272    let content = if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
273        text.to_string()
274    } else if let Some(content) = value.get("content").and_then(|v| v.as_str()) {
275        content.to_string()
276    } else {
277        return Err(voirs_sdk::VoirsError::config_error(
278            "JSON input must contain 'text' or 'content' field",
279        ));
280    };
281
282    // Extract synthesis parameters
283    if let Some(voice) = value.get("voice").and_then(|v| v.as_str()) {
284        parameters.voice = Some(voice.to_string());
285    }
286    if let Some(rate) = value.get("rate").and_then(|v| v.as_f64()) {
287        parameters.rate = Some(rate as f32);
288    }
289    if let Some(pitch) = value.get("pitch").and_then(|v| v.as_f64()) {
290        parameters.pitch = Some(pitch as f32);
291    }
292    if let Some(volume) = value.get("volume").and_then(|v| v.as_f64()) {
293        parameters.volume = Some(volume as f32);
294    }
295    if let Some(emotion) = value.get("emotion").and_then(|v| v.as_str()) {
296        parameters.emotion = Some(emotion.to_string());
297    }
298    if let Some(language) = value.get("language").and_then(|v| v.as_str()) {
299        parameters.language = Some(language.to_string());
300    }
301
302    // Extract metadata
303    if let Some(obj) = value.as_object() {
304        for (key, val) in obj {
305            if !matches!(
306                key.as_str(),
307                "text" | "content" | "voice" | "rate" | "pitch" | "volume" | "emotion" | "language"
308            ) {
309                if let Some(s) = val.as_str() {
310                    metadata.insert(key.clone(), s.to_string());
311                }
312            }
313        }
314    }
315
316    metadata.insert("original_format".to_string(), "json".to_string());
317
318    Ok(ParsedInput {
319        format: InputFormat::Json,
320        content,
321        parameters,
322        metadata,
323    })
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    // Format detection: one representative input per format.

    #[test]
    fn test_detect_plain_text() {
        assert_eq!(
            detect_format("Hello, this is plain text."),
            InputFormat::PlainText
        );
    }

    #[test]
    fn test_detect_ssml() {
        let ssml = r#"<speak>Hello <break time="500ms"/> world</speak>"#;
        assert_eq!(detect_format(ssml), InputFormat::Ssml);
    }

    #[test]
    fn test_detect_json() {
        let json = r#"{"text": "Hello world", "voice": "en-US"}"#;
        assert_eq!(detect_format(json), InputFormat::Json);
    }

    #[test]
    fn test_detect_markdown() {
        let markdown = "# Hello\n\nThis is **bold** text.";
        assert_eq!(detect_format(markdown), InputFormat::Markdown);
    }

    // End-to-end parsing through `parse_input`.

    #[test]
    fn test_parse_plain_text() {
        let parsed = parse_input("Hello world").unwrap();
        assert_eq!(parsed.format, InputFormat::PlainText);
        assert_eq!(parsed.content, "Hello world");
    }

    #[test]
    fn test_parse_json() {
        let json = r#"{"text": "Hello world", "voice": "kokoro-en", "rate": 1.2}"#;
        let parsed = parse_input(json).unwrap();
        assert_eq!(parsed.format, InputFormat::Json);
        assert_eq!(parsed.content, "Hello world");
        assert_eq!(parsed.parameters.voice, Some("kokoro-en".to_string()));
        assert_eq!(parsed.parameters.rate, Some(1.2));
    }

    #[test]
    fn test_parse_markdown_with_hints() {
        let markdown = concat!(
            "<!-- tts: voice=kokoro-en, rate=1.1 -->\n",
            "# Welcome\n",
            "\n",
            "This is **important** text.\n",
            "- Item 1\n",
            "- Item 2",
        );
        let parsed = parse_input(markdown).unwrap();
        assert_eq!(parsed.format, InputFormat::Markdown);
        assert_eq!(parsed.parameters.voice, Some("kokoro-en".to_string()));
        assert_eq!(parsed.parameters.rate, Some(1.1));
        assert!(parsed.content.contains("Welcome"));
        assert!(parsed.content.contains("important"));
    }

    // Line-level Markdown cleanup.

    #[test]
    fn test_clean_markdown() {
        let cleaned = clean_markdown_line("## This is a **bold** heading");
        assert_eq!(cleaned, "This is a bold heading");
    }
}
394}