rust_ai_generator/
lib.rs

1use std::path::PathBuf;
2
3use chrono::Utc;
4use reqwest::Client;
5use types::Voice;
6use utils::config::Config;
7
8/// Support functions
9pub mod utils;
10
11pub mod types;
12
13pub async fn get_azure_voice_list() -> Result<Vec<Voice>, Box<dyn std::error::Error>> {
14    let config = Config::load().unwrap();
15
16    let url = format!(
17        "https://{}.tts.speech.microsoft.com/cognitiveservices/voices/list",
18        config.azure.speech.region
19    );
20
21    let text = request_get(&url).await?;
22    match serde_json::from_str::<Vec<Voice>>(&text) {
23        Ok(voices) => Ok(voices),
24        Err(e) => Err(format!(
25            "Unable to parse voice list, check log for details: {:#?}",
26            e
27        )
28        .into()),
29    }
30}
31
32pub async fn request_get(url: &str) -> Result<String, Box<dyn std::error::Error>> {
33    let config = Config::load().unwrap();
34
35    let client = Client::new();
36    let mut req = client.get(url);
37    req = req.header("Ocp-Apim-Subscription-Key", config.azure.speech.key);
38
39    let res = req.send().await?;
40
41    match res.text().await {
42        Ok(text) => Ok(text),
43        Err(e) => Err(Box::new(e)),
44    }
45}
46
47pub fn generate_voice_names(
48    path: PathBuf,
49    voices: Vec<Voice>,
50) -> Result<bool, Box<dyn std::error::Error>> {
51    if voices.len() > 0 {
52        // Generate file
53        let mut voice_names = vec![];
54        let mut voice_name_intos = vec![];
55
56        voices.iter().for_each(|vn| {
57            let enum_variant_name = vn.short_name.replace("-", "_");
58
59            voice_names.push(format!(
60                "\n{}/// Voice name variant for `{}`\n{}{},",
61                " ".repeat(4),
62                vn.short_name,
63                " ".repeat(4),
64                enum_variant_name
65            ));
66            voice_name_intos.push(format!(
67                "{}Self::{} => \"{}\",",
68                " ".repeat(12),
69                enum_variant_name,
70                vn.short_name
71            ));
72        });
73
74        // Generate `voice_name.rs`
75        let voice_name_file_content = format!(
76            r#"//!
77//! *Auto-generated file, you should NOT update its contents directly*
78//! 
79//! Voice names fetched from Microsoft Cognitive Services API.
80//! 
81//! Updated on {}.
82
83////////////////////////////////////////////////////////////////////////////////
84
85/// VoiceNames generated from API call
86#[allow(non_camel_case_types)]
87#[derive(Debug, Clone)]
88pub enum VoiceName {{
89{}
90}}
91
92impl Into<String> for VoiceName {{
93    fn into(self) -> String {{
94        (match self {{
95{}
96        }})
97        .into()
98    }}
99}}"#,
100            format!("{}", Utc::now().format("%Y-%m-%d")),
101            voice_names.join("\n"),
102            voice_name_intos.join("\n")
103        );
104
105        return if path.exists() {
106            std::fs::write(path, &voice_name_file_content).unwrap();
107            Ok(true)
108        } else {
109            Ok(false)
110        };
111    }
112    Ok(true)
113}
114
115pub fn generate_locale_names(
116    path: PathBuf,
117    voices: Vec<Voice>,
118) -> Result<bool, Box<dyn std::error::Error>> {
119    if voices.len() > 0 {
120        // Generate file
121        let mut locale_names = vec![];
122        let mut locale_name_intos = vec![];
123        let mut locale_name_froms = vec![];
124
125        voices.iter().for_each(|vn| {
126            let locale_name = vn
127                .short_name
128                .split("-")
129                .take(2)
130                .collect::<Vec<&str>>()
131                .join("-");
132            let locale_variant_name = locale_name.replace("-", "_");
133
134            let temp = format!(
135                "\n{}/// Locale variant for `{}`\n{}{},",
136                " ".repeat(4),
137                locale_name,
138                " ".repeat(4),
139                locale_variant_name
140            );
141            if !locale_names.contains(&temp) {
142                locale_names.push(temp);
143            }
144
145            let temp = format!(
146                "{}Self::{} => \"{}\",",
147                " ".repeat(12),
148                locale_variant_name,
149                locale_name
150            );
151            if !locale_name_intos.contains(&temp) {
152                locale_name_intos.push(temp);
153            }
154
155            let temp = format!(
156                "{}\"{}\" => Self::{},",
157                " ".repeat(12),
158                locale_name,
159                locale_variant_name,
160            );
161            if !locale_name_froms.contains(&temp) {
162                locale_name_froms.push(temp);
163            }
164        });
165
166        // Generate `locale.rs`
167        let locale_file_content = format!(
168            r#"//!
169//! *Auto-generated file, you should NOT update its contents directly*
170//! 
171//! Locale names fetched from Microsoft Cognitive Services API.
172//! 
173//! Updated on {}.
174
175////////////////////////////////////////////////////////////////////////////////
176
177/// Locales generated from API call
178#[allow(non_camel_case_types)]
179#[derive(Debug, Clone)]
180pub enum Locale {{
181{}
182}}
183
184impl Into<String> for Locale {{
185    fn into(self) -> String {{
186      (match self {{
187{}
188      }})
189      .into()
190  }}
191}}
192
193
194impl From<&str> for Locale {{
195    fn from(value: &str) -> Self {{
196        match value {{
197{}
198        _ => {{
199          log::warn!( target: "rust-ai", "Unrecognized locale `{{}}`", value);
200          todo!("The local file should be updated and regenerated")
201        }}
202      }}
203  }}
204}}
205
206{}
207"#,
208            format!("{}", Utc::now().format("%Y-%m-%d")),
209            locale_names.join("\n"),
210            locale_name_intos.join("\n"),
211            locale_name_froms.join("\n"),
212            r#"impl serde::Serialize for Locale {
213    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
214    where
215        S: serde::Serializer,
216    {
217        let content = Into::<String>::into(self.clone());
218        serializer.serialize_str(&content)
219    }
220}
221struct LocaleVisitor;
222
223impl<'de> serde::de::Visitor<'de> for LocaleVisitor {
224    type Value = Locale;
225    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
226    where
227        E: serde::de::Error,
228    {
229        Ok(Into::<Self::Value>::into(v.as_str()))
230    }
231
232    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
233    where
234        E: serde::de::Error,
235    {
236        Ok(Into::<Self::Value>::into(v))
237    }
238
239    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
240        formatter.write_str("Unrecognizable locale string.")
241    }
242}
243
244impl<'de> serde::Deserialize<'de> for Locale {
245    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
246    where
247        D: serde::Deserializer<'de>,
248    {
249        deserializer.deserialize_string(LocaleVisitor)
250    }
251}"#
252        );
253
254        if path.exists() {
255            std::fs::write(path, &locale_file_content).unwrap();
256        }
257    }
258    Ok(true)
259}