1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use voirs_sdk::Result;
9
/// Markup format of a piece of synthesis input, as classified by
/// [`detect_format`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum InputFormat {
    /// Text with no recognized markup; the fallback when no other format
    /// matches.
    PlainText,
    /// SSML: trimmed input that starts with `<speak` and ends with
    /// `</speak>`.
    Ssml,
    /// Markdown, recognized heuristically from heading, list, emphasis,
    /// code-fence or link syntax.
    Markdown,
    /// A `{...}` or `[...]` document that parses as valid JSON.
    Json,
}
22
/// Result of parsing raw user input with [`parse_input`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedInput {
    /// Format the input was detected as.
    pub format: InputFormat,
    /// Text to hand on for synthesis: the verbatim input for plain text and
    /// SSML, the markup-stripped text for Markdown, and the `text` (or
    /// `content`) field for JSON.
    pub content: String,
    /// Synthesis settings extracted from the input; fields the input did not
    /// supply stay `None`.
    pub parameters: SynthesisParameters,
    /// Free-form string metadata. The SSML, Markdown and JSON parsers record
    /// an `original_format` entry; the JSON parser also copies any other
    /// string-valued top-level fields. Plain text yields an empty map.
    pub metadata: HashMap<String, String>,
}
35
/// Optional synthesis settings that can be embedded in the input.
///
/// Every field defaults to `None`, meaning "not specified by the input".
// NOTE(review): units and valid ranges of `rate`, `pitch` and `volume` are not
// established anywhere in this file — confirm against the synthesizer.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SynthesisParameters {
    /// Voice identifier (SSML `voice name`, JSON `voice`, or a `voice=` hint).
    pub voice: Option<String>,
    /// Speaking rate (JSON `rate` or a `rate=` hint).
    pub rate: Option<f32>,
    /// Pitch setting (JSON `pitch` or a `pitch=` hint).
    pub pitch: Option<f32>,
    /// Volume setting (JSON `volume` or a `volume=` hint).
    pub volume: Option<f32>,
    /// Emotion label (JSON `emotion` or an `emotion=` hint).
    pub emotion: Option<String>,
    /// Language tag (SSML `xml:lang`, JSON `language`, or a `language=` hint).
    pub language: Option<String>,
}
52
53pub fn detect_format(input: &str) -> InputFormat {
55 let trimmed = input.trim();
56
57 if trimmed.starts_with("<speak") && trimmed.ends_with("</speak>") {
59 return InputFormat::Ssml;
60 }
61
62 if (trimmed.starts_with('{') && trimmed.ends_with('}'))
64 || (trimmed.starts_with('[') && trimmed.ends_with(']'))
65 {
66 if serde_json::from_str::<serde_json::Value>(trimmed).is_ok() {
68 return InputFormat::Json;
69 }
70 }
71
72 if contains_markdown_syntax(trimmed) {
74 return InputFormat::Markdown;
75 }
76
77 InputFormat::PlainText
79}
80
/// Returns `true` when `text` contains anything that looks like Markdown.
///
/// This is a plain substring heuristic: heading markers, list markers,
/// emphasis asterisks, code fences, or a `[` together with a `](` link
/// separator anywhere in the text.
// NOTE(review): `"## "`, `"* "` and `"**"` are already implied by `"# "` and
// `"*"`, and the bare `"*"` / `"- "` patterns also match ordinary prose such
// as "2 * 3" or "well - okay". Kept as-is to preserve detection behavior.
fn contains_markdown_syntax(text: &str) -> bool {
    const MARKERS: [&str; 7] = ["# ", "## ", "* ", "- ", "**", "*", "```"];

    let has_marker = MARKERS.iter().any(|marker| text.contains(*marker));
    let has_link = text.contains("[") && text.contains("](");

    has_marker || has_link
}
93
94pub fn parse_input(input: &str) -> Result<ParsedInput> {
96 let format = detect_format(input);
97
98 match format {
99 InputFormat::PlainText => parse_plain_text(input),
100 InputFormat::Ssml => parse_ssml(input),
101 InputFormat::Markdown => parse_markdown(input),
102 InputFormat::Json => parse_json(input),
103 }
104}
105
106fn parse_plain_text(input: &str) -> Result<ParsedInput> {
108 Ok(ParsedInput {
109 format: InputFormat::PlainText,
110 content: input.to_string(),
111 parameters: SynthesisParameters::default(),
112 metadata: HashMap::new(),
113 })
114}
115
116fn parse_ssml(input: &str) -> Result<ParsedInput> {
118 let mut parameters = SynthesisParameters::default();
121 let mut metadata = HashMap::new();
122
123 if let Some(voice_start) = input.find("voice name=\"") {
125 if let Some(voice_end) = input[voice_start + 12..].find('"') {
126 let voice = &input[voice_start + 12..voice_start + 12 + voice_end];
127 parameters.voice = Some(voice.to_string());
128 }
129 }
130
131 if let Some(lang_start) = input.find("xml:lang=\"") {
133 if let Some(lang_end) = input[lang_start + 10..].find('"') {
134 let lang = &input[lang_start + 10..lang_start + 10 + lang_end];
135 parameters.language = Some(lang.to_string());
136 }
137 }
138
139 metadata.insert("original_format".to_string(), "ssml".to_string());
140
141 Ok(ParsedInput {
142 format: InputFormat::Ssml,
143 content: input.to_string(),
144 parameters,
145 metadata,
146 })
147}
148
149fn parse_markdown(input: &str) -> Result<ParsedInput> {
151 let mut content = String::new();
152 let mut parameters = SynthesisParameters::default();
153 let mut metadata = HashMap::new();
154
155 let lines: Vec<&str> = input.lines().collect();
157 let mut skip_next = false;
158
159 for line in &lines {
160 let trimmed = line.trim();
161
162 if trimmed.starts_with("<!-- tts:") && trimmed.ends_with("-->") {
164 let hint = &trimmed[9..trimmed.len() - 3].trim();
165 parse_tts_hint(hint, &mut parameters);
166 skip_next = false;
167 continue;
168 }
169
170 if trimmed.starts_with("```") {
172 skip_next = !skip_next;
173 continue;
174 }
175
176 if skip_next {
177 continue;
178 }
179
180 let cleaned = clean_markdown_line(line);
182 if !cleaned.is_empty() {
183 content.push_str(&cleaned);
184 content.push(' ');
185 }
186 }
187
188 metadata.insert("original_format".to_string(), "markdown".to_string());
189
190 Ok(ParsedInput {
191 format: InputFormat::Markdown,
192 content: content.trim().to_string(),
193 parameters,
194 metadata,
195 })
196}
197
198fn parse_tts_hint(hint: &str, parameters: &mut SynthesisParameters) {
200 for part in hint.split(',') {
201 let kv: Vec<&str> = part.trim().splitn(2, '=').collect();
202 if kv.len() == 2 {
203 let key = kv[0].trim();
204 let value = kv[1].trim();
205
206 match key {
207 "voice" => parameters.voice = Some(value.to_string()),
208 "rate" => parameters.rate = value.parse().ok(),
209 "pitch" => parameters.pitch = value.parse().ok(),
210 "volume" => parameters.volume = value.parse().ok(),
211 "emotion" => parameters.emotion = Some(value.to_string()),
212 "language" => parameters.language = Some(value.to_string()),
213 _ => {}
214 }
215 }
216 }
217}
218
/// Strips Markdown markup from a single line, returning the speakable text.
///
/// In order:
///
/// 1. Leading indentation is removed so that indented headings and nested
///    list items are recognized.
/// 2. Leading heading markers (`# ` through `###### `) and list bullets
///    (`* `, `- `, `+ `) are removed.
/// 3. Bold markers (`**`, `__`) are removed wherever they occur.
/// 4. Inline links `[label](target)` are replaced by their label.
///
/// The result is trimmed; a line consisting only of markup yields an empty
/// string.
fn clean_markdown_line(line: &str) -> String {
    let without_markers = line
        .trim_start()
        .trim_start_matches("# ")
        .trim_start_matches("## ")
        .trim_start_matches("### ")
        .trim_start_matches("#### ")
        .trim_start_matches("##### ")
        .trim_start_matches("###### ")
        .trim_start_matches("* ")
        .trim_start_matches("- ")
        .trim_start_matches("+ ");

    let without_bold = without_markers.replace("**", "").replace("__", "");

    strip_markdown_links(&without_bold).trim().to_string()
}

/// Replaces every complete inline link `[label](target)` in `text` with its
/// label, in a single left-to-right pass.
///
/// The label is taken from the nearest `[` before each `](`, so unrelated
/// brackets earlier in the text are left alone. Incomplete links (no opening
/// `[`, or no closing `)`) are kept verbatim.
fn strip_markdown_links(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut rest = text;

    while let Some(separator) = rest.find("](") {
        let target_start = separator + 2;
        let label_open = rest[..separator].rfind('[');
        let target_len = rest[target_start..].find(')');

        match (label_open, target_len) {
            (Some(open), Some(len)) => {
                out.push_str(&rest[..open]);
                out.push_str(&rest[open + 1..separator]);
                rest = &rest[target_start + len + 1..];
            }
            // A "](" with no opening bracket is ordinary text; keep it and
            // continue scanning after it.
            (None, Some(_)) => {
                out.push_str(&rest[..target_start]);
                rest = &rest[target_start..];
            }
            // No ")" remains, so no later link can be complete either.
            (_, None) => break,
        }
    }

    out.push_str(rest);
    out
}
262
263fn parse_json(input: &str) -> Result<ParsedInput> {
265 let value: serde_json::Value = serde_json::from_str(input)
266 .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Invalid JSON input: {}", e)))?;
267
268 let mut parameters = SynthesisParameters::default();
269 let mut metadata = HashMap::new();
270
271 let content = if let Some(text) = value.get("text").and_then(|v| v.as_str()) {
273 text.to_string()
274 } else if let Some(content) = value.get("content").and_then(|v| v.as_str()) {
275 content.to_string()
276 } else {
277 return Err(voirs_sdk::VoirsError::config_error(
278 "JSON input must contain 'text' or 'content' field",
279 ));
280 };
281
282 if let Some(voice) = value.get("voice").and_then(|v| v.as_str()) {
284 parameters.voice = Some(voice.to_string());
285 }
286 if let Some(rate) = value.get("rate").and_then(|v| v.as_f64()) {
287 parameters.rate = Some(rate as f32);
288 }
289 if let Some(pitch) = value.get("pitch").and_then(|v| v.as_f64()) {
290 parameters.pitch = Some(pitch as f32);
291 }
292 if let Some(volume) = value.get("volume").and_then(|v| v.as_f64()) {
293 parameters.volume = Some(volume as f32);
294 }
295 if let Some(emotion) = value.get("emotion").and_then(|v| v.as_str()) {
296 parameters.emotion = Some(emotion.to_string());
297 }
298 if let Some(language) = value.get("language").and_then(|v| v.as_str()) {
299 parameters.language = Some(language.to_string());
300 }
301
302 if let Some(obj) = value.as_object() {
304 for (key, val) in obj {
305 if !matches!(
306 key.as_str(),
307 "text" | "content" | "voice" | "rate" | "pitch" | "volume" | "emotion" | "language"
308 ) {
309 if let Some(s) = val.as_str() {
310 metadata.insert(key.clone(), s.to_string());
311 }
312 }
313 }
314 }
315
316 metadata.insert("original_format".to_string(), "json".to_string());
317
318 Ok(ParsedInput {
319 format: InputFormat::Json,
320 content,
321 parameters,
322 metadata,
323 })
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    // --- format detection -------------------------------------------------

    #[test]
    fn test_detect_plain_text() {
        assert_eq!(
            detect_format("Hello, this is plain text."),
            InputFormat::PlainText
        );
    }

    #[test]
    fn test_detect_ssml() {
        let ssml = r#"<speak>Hello <break time="500ms"/> world</speak>"#;
        assert_eq!(detect_format(ssml), InputFormat::Ssml);
    }

    #[test]
    fn test_detect_json() {
        let json = r#"{"text": "Hello world", "voice": "en-US"}"#;
        assert_eq!(detect_format(json), InputFormat::Json);
    }

    #[test]
    fn test_detect_markdown() {
        assert_eq!(
            detect_format("# Hello\n\nThis is **bold** text."),
            InputFormat::Markdown
        );
    }

    // --- parsing ----------------------------------------------------------

    #[test]
    fn test_parse_plain_text() {
        let parsed = parse_input("Hello world").unwrap();

        assert_eq!(parsed.format, InputFormat::PlainText);
        assert_eq!(parsed.content, "Hello world");
    }

    #[test]
    fn test_parse_json() {
        let json = r#"{"text": "Hello world", "voice": "kokoro-en", "rate": 1.2}"#;
        let parsed = parse_input(json).unwrap();

        assert_eq!(parsed.format, InputFormat::Json);
        assert_eq!(parsed.content, "Hello world");
        assert_eq!(parsed.parameters.voice.as_deref(), Some("kokoro-en"));
        assert_eq!(parsed.parameters.rate, Some(1.2));
    }

    #[test]
    fn test_parse_markdown_with_hints() {
        // A hint comment followed by a heading, a paragraph and a list.
        let markdown = "<!-- tts: voice=kokoro-en, rate=1.1 -->\n# Welcome\n\nThis is **important** text.\n- Item 1\n- Item 2";
        let parsed = parse_input(markdown).unwrap();

        assert_eq!(parsed.format, InputFormat::Markdown);
        assert_eq!(parsed.parameters.voice.as_deref(), Some("kokoro-en"));
        assert_eq!(parsed.parameters.rate, Some(1.1));
        assert!(parsed.content.contains("Welcome"));
        assert!(parsed.content.contains("important"));
    }

    #[test]
    fn test_clean_markdown() {
        assert_eq!(
            clean_markdown_line("## This is a **bold** heading"),
            "This is a bold heading"
        );
    }
}