Skip to main content

speechmarkdown_rust/formatters/ssml/microsoft_azure.rs

1use crate::ast::{AstNode, NodeType};
2use crate::error::Result;
3use crate::formatters::base::{Formatter, FormatterOptions};
4use crate::formatters::ssml::base::{
5    attrs_merge, format_attr_string_ordered, SsmlFormatterBase, TagAttrs, TagInfo,
6};
7
/// Speaking styles accepted in Azure `<mstts:express-as style="...">`.
///
/// A section whose style appears here is rendered as `mstts:express-as`; any other
/// emotion style is treated as unsupported. Only membership is checked, so the list is
/// kept alphabetical for easy scanning.
const AZURE_EXPRESS_AS_STYLES: &[&str] = &[
    "advertisement_upbeat",
    "angry",
    "assistant",
    "calm",
    "chat",
    "cheerful",
    "customerservice",
    "depressed",
    "disgruntled",
    "documentary-narration",
    "embarrassed",
    "empathetic",
    "envious",
    "excited",
    "fearful",
    "friendly",
    "gentle",
    "hopeful",
    "lyrical",
    "narration-professional",
    "narration-relaxed",
    "newscast-casual",
    "newscast-formal",
    "poetry-reading",
    "sad",
    "serious",
    "shouting",
    "sports_commentary",
    "sports_commentary_excited",
    "terrified",
    "unfriendly",
    "whispering",
];
42
/// Resolves a speech-markdown voice name to an Azure voice name.
///
/// Known short aliases (matched case-insensitively, e.g. `"jenny"`) map to their
/// `en-US-*Neural` voice. Any other name is returned with `-`, `_` and spaces all
/// turned into `-`, and the first character after each separator (and of the whole
/// name) upper-cased; the remaining characters are kept as written.
pub fn azure_voice_name(name: &str) -> String {
    const KNOWN_VOICES: [(&str, &str); 11] = [
        ("jenny", "en-US-JennyNeural"),
        ("guy", "en-US-GuyNeural"),
        ("aria", "en-US-AriaNeural"),
        ("davis", "en-US-DavisNeural"),
        ("amber", "en-US-AmberNeural"),
        ("ana", "en-US-AnaNeural"),
        ("andrew", "en-US-AndrewNeural"),
        ("emma", "en-US-EmmaNeural"),
        ("brian", "en-US-BrianNeural"),
        ("christopher", "en-US-ChristopherNeural"),
        ("eric", "en-US-EricNeural"),
    ];

    let lowered = name.to_lowercase();
    if let Some(&(_, full)) = KNOWN_VOICES.iter().find(|(alias, _)| *alias == lowered) {
        return full.to_string();
    }

    let mut out = String::with_capacity(name.len());
    let mut at_word_start = true;
    for ch in name.chars() {
        match ch {
            '-' | '_' | ' ' => {
                out.push('-');
                at_word_start = true;
            }
            _ if at_word_start => {
                out.extend(ch.to_uppercase());
                at_word_start = false;
            }
            _ => out.push(ch),
        }
    }
    out
}
80
81pub struct MicrosoftAzureSsmlFormatter {
82    base: SsmlFormatterBase,
83    options: FormatterOptions,
84}
85
86impl MicrosoftAzureSsmlFormatter {
87    pub fn new(options: FormatterOptions) -> Self {
88        let base = SsmlFormatterBase::new(options.clone());
89        Self { base, options }
90    }
91
92    fn is_valid_azure_style(style: &str) -> bool {
93        AZURE_EXPRESS_AS_STYLES.contains(&style)
94    }
95
96    fn section_style(node: &AstNode) -> Option<String> {
97        node.attributes
98            .get("style")
99            .or_else(|| node.attributes.get("emotion"))
100            .cloned()
101    }
102
103    fn is_azure_express_section(node: &AstNode) -> bool {
104        let style = Self::section_style(node);
105        match style {
106            Some(s) => {
107                if matches!(s.as_str(), "voice" | "lang" | "device" | "defaults") {
108                    return false;
109                }
110                Self::is_valid_azure_style(&s)
111            }
112            None => false,
113        }
114    }
115
116    fn is_unsupported_emotion_section(node: &AstNode) -> bool {
117        let style = Self::section_style(node);
118        match style {
119            Some(s) => {
120                if matches!(s.as_str(), "voice" | "lang" | "device" | "defaults") {
121                    return false;
122                }
123                !Self::is_valid_azure_style(&s)
124            }
125            None => {
126                for key in &node.attribute_keys {
127                    if key == "disappointed" || key == "excited" {
128                        return true;
129                    }
130                }
131                false
132            }
133        }
134    }
135
136    fn is_emotion_section(node: &AstNode) -> bool {
137        let style = Self::section_style(node);
138        match style {
139            Some(s) => !matches!(s.as_str(), "voice" | "lang" | "device"),
140            None => {
141                for key in &node.attribute_keys {
142                    if key == "disappointed" || key == "excited" {
143                        return true;
144                    }
145                }
146                false
147            }
148        }
149    }
150
151    fn azure_attribute_to_tag(&self, key: &str, value: &str) -> Option<TagInfo> {
152        let mut attributes: TagAttrs = Vec::new();
153        match key.to_lowercase().as_str() {
154            "emphasis" => None,
155            "whisper" => {
156                attributes.push(("volume".to_string(), "x-soft".to_string()));
157                attributes.push(("rate".to_string(), "slow".to_string()));
158                Some(("prosody".to_string(), attributes))
159            }
160            "number" | "cardinal" => Some(("say-as".to_string(), {
161                vec![("interpret-as".to_string(), "cardinal".to_string())]
162            })),
163            "excited" | "disappointed" => Some(("mstts:express-as".to_string(), {
164                vec![("style".to_string(), key.to_lowercase())]
165            })),
166            "voice" => {
167                if value.is_empty() || value == "device" {
168                    return None;
169                }
170                let neural_name = azure_voice_name(value);
171                attributes.push(("name".to_string(), neural_name));
172                Some(("voice".to_string(), attributes))
173            }
174            _ => self.base.attribute_to_tag(key, value),
175        }
176    }
177
178    fn format_azure_text_modifier(&self, node: &AstNode) -> Result<String> {
179        let mut tags: Vec<TagInfo> = Vec::new();
180        let mut last_say_as: Option<TagInfo> = None;
181
182        for key in &node.attribute_keys {
183            let value = match node.attributes.get(key) {
184                Some(v) => v,
185                None => continue,
186            };
187            if let Some(tag_info) = self.azure_attribute_to_tag(key, value) {
188                let tag_name = tag_info.0.clone();
189                if tag_name == "prosody" {
190                    if let Some(existing) = tags.iter_mut().find(|(name, _)| name == "prosody") {
191                        attrs_merge(&mut existing.1, tag_info.1);
192                        continue;
193                    }
194                }
195                if tag_name == "say-as" {
196                    last_say_as = Some(tag_info);
197                    continue;
198                }
199                tags.push(tag_info);
200            }
201        }
202
203        if let Some(say_as) = last_say_as {
204            tags.push(say_as);
205        }
206
207        if tags.is_empty() {
208            return Ok(node.text.clone());
209        }
210
211        self.base.apply_tags_to_text(&node.text, &tags)
212    }
213
214    fn format_azure_node(&self, node: &AstNode) -> Result<String> {
215        match node.node_type {
216            NodeType::PlainText => Ok(node.text.clone()),
217            NodeType::TextModifier => self.format_azure_text_modifier(node),
218            _ => self.base.format_node_internal(node),
219        }
220    }
221
222    fn format_azure_section(&self, node: &AstNode) -> Result<String> {
223        let mut tags: Vec<TagInfo> = Vec::new();
224
225        if let Some(style) = node.attributes.get("style") {
226            if style != "defaults" {
227                if let Some(tag_info) = self.azure_attribute_to_tag(style, "") {
228                    tags.push(tag_info);
229                }
230            }
231        }
232
233        for key in &node.attribute_keys {
234            let value = match node.attributes.get(key) {
235                Some(v) => v,
236                None => continue,
237            };
238            if key == "style" {
239                continue;
240            }
241            if let Some(tag_info) = self.azure_attribute_to_tag(key, value) {
242                let tag_name = tag_info.0.clone();
243                if tag_name == "prosody" {
244                    if let Some(existing) = tags.iter_mut().find(|(name, _)| name == "prosody") {
245                        attrs_merge(&mut existing.1, tag_info.1);
246                        continue;
247                    }
248                }
249                tags.push(tag_info);
250            }
251        }
252
253        if tags.is_empty() {
254            return Ok(String::new());
255        }
256
257        let mut result = String::new();
258        for (i, (tag_name, attrs)) in tags.iter().enumerate() {
259            let attr_string = format_attr_string_ordered(tag_name, attrs);
260            if i > 0 {
261                result.push('\n');
262            }
263            if attr_string.is_empty() {
264                result.push_str(&format!("<{}>", tag_name));
265            } else {
266                result.push_str(&format!("<{} {}>", tag_name, attr_string));
267            }
268        }
269        Ok(result)
270    }
271
272    fn format_azure_section_close(&self, node: &AstNode) -> Result<String> {
273        let mut tags: Vec<TagInfo> = Vec::new();
274
275        if let Some(style) = node.attributes.get("style") {
276            if style != "defaults" {
277                if let Some(tag_info) = self.azure_attribute_to_tag(style, "") {
278                    tags.push(tag_info);
279                }
280            }
281        }
282
283        for key in &node.attribute_keys {
284            let value = match node.attributes.get(key) {
285                Some(v) => v,
286                None => continue,
287            };
288            if key == "style" {
289                continue;
290            }
291            if let Some(tag_info) = self.azure_attribute_to_tag(key, value) {
292                let tag_name = tag_info.0.clone();
293                if tag_name == "prosody" {
294                    if let Some(existing) = tags.iter_mut().find(|(name, _)| name == "prosody") {
295                        attrs_merge(&mut existing.1, tag_info.1);
296                        continue;
297                    }
298                }
299                tags.push(tag_info);
300            }
301        }
302
303        if tags.is_empty() {
304            return Ok(String::new());
305        }
306
307        let mut result = String::new();
308        for (i, (tag_name, _)) in tags.iter().rev().enumerate() {
309            result.push_str(&format!("</{}>", tag_name));
310            if i < tags.len() - 1 {
311                result.push('\n');
312            }
313        }
314        Ok(result)
315    }
316
317    fn has_unsupported_emotion_sections(ast: &AstNode) -> bool {
318        for child in &ast.children {
319            if child.node_type == NodeType::Section && Self::is_unsupported_emotion_section(child) {
320                return true;
321            }
322        }
323        false
324    }
325
326    fn format_document_sections(&self, ast: &AstNode) -> Result<String> {
327        let passthrough_emotions = Self::has_unsupported_emotion_sections(ast);
328        let mut content = String::new();
329        let mut children_iter = ast.children.iter().peekable();
330
331        while let Some(child) = children_iter.next() {
332            if child.node_type == NodeType::Section {
333                let is_express = !passthrough_emotions && Self::is_azure_express_section(child);
334                let is_unsupported = Self::is_unsupported_emotion_section(child);
335                let is_emotion_passthrough =
336                    passthrough_emotions && Self::is_emotion_section(child);
337                let is_defaults = child
338                    .attributes
339                    .get("style")
340                    .is_some_and(|s| s == "defaults");
341
342                let mut section_content_raw = String::new();
343                while let Some(next_child) = children_iter.peek() {
344                    if next_child.node_type == NodeType::Section {
345                        break;
346                    }
347                    let next_child = children_iter.next().unwrap();
348                    section_content_raw.push_str(&self.format_azure_node(next_child)?);
349                }
350
351                if is_unsupported || is_emotion_passthrough || is_defaults {
352                    content.push_str(&format!("#[{}]", child.text));
353                    content.push_str(&section_content_raw);
354                } else if is_express {
355                    let style = Self::section_style(child).unwrap_or_default();
356                    let had_leading_newline = section_content_raw.starts_with('\n');
357                    let section_content = if had_leading_newline {
358                        &section_content_raw[1..]
359                    } else {
360                        &section_content_raw
361                    };
362
363                    content.push_str(&format!("<mstts:express-as style=\"{}\">", style));
364                    if had_leading_newline {
365                        content.push('\n');
366                    }
367                    content.push_str(section_content);
368                    content.push_str("</mstts:express-as>");
369                    if had_leading_newline {
370                        content.push('\n');
371                    }
372                } else {
373                    let had_leading_newline = section_content_raw.starts_with('\n');
374                    let section_content = if had_leading_newline {
375                        &section_content_raw[1..]
376                    } else {
377                        &section_content_raw
378                    };
379
380                    let section_open = self.format_azure_section(child)?;
381                    let section_close = if !section_open.is_empty() {
382                        self.format_azure_section_close(child)?
383                    } else {
384                        String::new()
385                    };
386
387                    if !section_open.is_empty() {
388                        content.push_str(&section_open);
389                        if had_leading_newline {
390                            content.push('\n');
391                        }
392                        content.push_str(section_content);
393                        content.push_str(&section_close);
394                        if had_leading_newline {
395                            content.push('\n');
396                        }
397                    } else {
398                        content.push_str(section_content);
399                    }
400                }
401            } else {
402                content.push_str(&self.format_azure_node(child)?);
403            }
404        }
405
406        Ok(content)
407    }
408}
409
410impl Formatter for MicrosoftAzureSsmlFormatter {
411    fn format(&self, ast: &AstNode) -> Result<String> {
412        let content = self.format_document_sections(ast)?;
413
414        if self.options.include_speak_tag {
415            let trimmed = content.trim_end_matches('\n');
416            let use_mstts = trimmed.contains("mstts:express-as");
417            if use_mstts {
418                Ok(format!(
419                    "<speak xmlns:mstts=\"https://www.w3.org/2001/mstts\">\n{}\n</speak>",
420                    trimmed
421                ))
422            } else {
423                Ok(format!("<speak>\n{}\n</speak>", trimmed))
424            }
425        } else {
426            Ok(content)
427        }
428    }
429
430    fn format_node(&self, node: &AstNode) -> Result<String> {
431        self.format_azure_node(node)
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::azure_voice_name;
    use crate::formatters::base::Platform;
    use crate::parser::SpeechMarkdownParser;

    /// Converts `input` to Azure SSML, failing the test if conversion errors.
    fn to_azure_ssml(input: &str) -> String {
        SpeechMarkdownParser::to_ssml(input, Platform::MicrosoftAzure)
            .expect("Azure SSML conversion should succeed")
    }

    #[test]
    fn test_microsoft_azure_basic_parsing() {
        let ssml = to_azure_ssml("Hello world");
        assert!(ssml.contains("Hello world"));
        // No express-as section, so the mstts namespace must not be declared.
        assert!(!ssml.contains("xmlns:mstts"));
    }

    #[test]
    fn test_microsoft_azure_with_section() {
        let ssml = to_azure_ssml("#[angry] I am angry!");
        assert!(ssml.contains("<mstts:express-as"));
        assert!(ssml.contains("style=\"angry\""));
        assert!(ssml.contains("xmlns:mstts"));
    }

    #[test]
    fn test_azure_voice_name_known_alias() {
        assert_eq!(azure_voice_name("Jenny"), "en-US-JennyNeural");
        assert_eq!(azure_voice_name("ERIC"), "en-US-EricNeural");
    }

    #[test]
    fn test_azure_voice_name_fallback_capitalizes_segments() {
        assert_eq!(azure_voice_name("my_custom voice"), "My-Custom-Voice");
        assert_eq!(azure_voice_name(""), "");
    }
}