rust_ai/azure/endpoint.rs

1use log::{debug, error};
2use reqwest::{header::HeaderMap, Client};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use urlencoding::encode;
6
7use crate::utils::config::Config;
8
9use super::{
10    types::common::{ResponseExpectation, ResponseType},
11    SSML,
12};
13
14#[allow(non_camel_case_types)]
15#[derive(Serialize, Deserialize, Debug, Clone)]
16pub enum SpeechServiceEndpoint {
17    Get_List_of_Voices,
18    Post_Text_to_Speech_v1,
19    Get_Speech_to_Text_Health_Status_v3_1,
20    Get_List_of_Models_v3_1,
21    Post_Create_Transcription_v3_1,
22    Get_Transcription_v3_1,
23    Get_Transcription_Files_v3_1,
24    Get_Transcription_File_v3_1,
25    None,
26}
27
28impl SpeechServiceEndpoint {
29    pub fn build(&self, region: &str) -> String {
30        match self {
31            Self::Get_List_of_Voices => format!(
32                "https://{}.tts.speech.microsoft.com/cognitiveservices/voices/list",
33                region
34            ),
35
36            Self::Post_Text_to_Speech_v1 => format!(
37                "https://{}.tts.speech.microsoft.com/cognitiveservices/v1",
38                region
39            ),
40
41            Self::Get_Speech_to_Text_Health_Status_v3_1 => format!(
42                "https://{}.cognitiveservices.azure.com/speechtotext/v3.1/healthstatus",
43                region
44            ),
45
46            Self::Get_List_of_Models_v3_1 => format!(
47                "https://{}.cognitiveservices.azure.com/speechtotext/v3.1/models",
48                region
49            ),
50
51            Self::Post_Create_Transcription_v3_1 => format!(
52                "https://{}.api.cognitive.microsoft.com/speechtotext/v3.1/transcriptions",
53                region
54            ),
55
56            Self::Get_Transcription_v3_1 => format!(
57                "https://{}.api.cognitive.microsoft.com/speechtotext/v3.1/transcriptions/",
58                region
59            ),
60
61            Self::Get_Transcription_Files_v3_1 => format!(
62                "https://{}.api.cognitive.microsoft.com/speechtotext/v3.1/transcriptions/",
63                region
64            ),
65
66            Self::Get_Transcription_File_v3_1 => format!(
67                "https://{}.api.cognitive.microsoft.com/speechtotext/v3.1/transcriptions/",
68                region
69            ),
70            Self::None => String::new(),
71        }
72    }
73}
74
75/// If you would like to use a plain request URI, pass 
76/// [`SpeechServiceEndpoint::None`] as the endpoint, and then provide the 
77/// complete URI through `url_suffix` parameter.
78pub async fn request_get_endpoint(
79    endpoint: &SpeechServiceEndpoint,
80    params: Option<HashMap<String, String>>,
81    url_suffix: Option<String>,
82) -> Result<String, Box<dyn std::error::Error>> {
83    let config = Config::load().unwrap();
84    let region = config.azure.speech.region;
85
86    let mut url = match endpoint {
87        SpeechServiceEndpoint::None => String::new(),
88        _ => endpoint.build(&region),
89    };
90
91    if let Some(url_suffix) = url_suffix {
92        url.push_str(&url_suffix);
93    }
94
95    if let Some(params) = params {
96        let combined = params
97            .iter()
98            .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
99            .collect::<Vec<String>>()
100            .join("&");
101        url.push_str(&format!("?{}", combined));
102    }
103
104    let client = Client::new();
105    let mut req = client.get(url);
106    req = req.header("Ocp-Apim-Subscription-Key", config.azure.speech.key);
107
108    let res = req.send().await?;
109
110    match res.text().await {
111        Ok(text) => Ok(text),
112        Err(e) => {
113            error!(target: "azure", "Error requesting Azure endpoint (GET): {:?}", e);
114            Err(Box::new(e))
115        }
116    }
117}
118
119pub async fn request_post_endpoint_ssml(
120    endpoint: &SpeechServiceEndpoint,
121    ssml: SSML,
122    expect: ResponseExpectation,
123    extra_headers: Option<HeaderMap>,
124) -> Result<ResponseType, Box<dyn std::error::Error>> {
125    let config = Config::load().unwrap();
126    let region = config.azure.speech.region;
127
128    let url = endpoint.build(&region);
129
130    let client = Client::new();
131    let mut req = client
132        .post(url)
133        .header("Ocp-Apim-Subscription-Key", config.azure.speech.key)
134        .header("User-Agent", "rust-ai/example")
135        .header("Content-Type", "application/ssml+xml");
136    if let Some(extra_headers) = extra_headers {
137        req = req.headers(extra_headers);
138    }
139
140    let body = ssml.to_string();
141    req = req.body(body.clone());
142    debug!(target: "azure", "Request body: {:?}", body);
143
144    let res = req.send().await?;
145
146    match expect {
147        ResponseExpectation::Text => match res.text().await {
148            Ok(text) => Ok(ResponseType::Text(text)),
149            Err(e) => {
150                error!(target: "azure", "Error requesting Azure endpoint (GET): {:?}", e);
151                Err(Box::new(e))
152            }
153        },
154        ResponseExpectation::Bytes => match res.bytes().await {
155            Ok(bytes) => Ok(ResponseType::Bytes(bytes.to_vec())),
156            Err(e) => {
157                error!(target: "azure", "Error requesting Azure endpoint (GET): {:?}", e);
158                Err(Box::new(e))
159            }
160        },
161    }
162}
163
164pub async fn request_post_endpoint(
165    endpoint: &SpeechServiceEndpoint,
166    json: impl Serialize + Into<String>,
167    expect: ResponseExpectation,
168    extra_headers: Option<HeaderMap>,
169) -> Result<ResponseType, Box<dyn std::error::Error>> {
170    let config = Config::load().unwrap();
171    let region = config.azure.speech.region;
172
173    let url = endpoint.build(&region);
174
175    let client = Client::new();
176    let mut req = client
177        .post(url)
178        .header("Ocp-Apim-Subscription-Key", config.azure.speech.key)
179        .header("User-Agent", "rust-ai/example")
180        .header("Content-Type", "application/ssml+xml");
181    if let Some(extra_headers) = extra_headers {
182        req = req.headers(extra_headers);
183    }
184
185    req = req.json(&json);
186    debug!(target: "azure", "Request body: {:?}", Into::<String>::into(json));
187
188    let res = req.send().await?;
189
190    match expect {
191        ResponseExpectation::Text => match res.text().await {
192            Ok(text) => Ok(ResponseType::Text(text)),
193            Err(e) => {
194                error!(target: "azure", "Error requesting Azure endpoint (GET): {:?}", e);
195                Err(Box::new(e))
196            }
197        },
198        ResponseExpectation::Bytes => match res.bytes().await {
199            Ok(bytes) => Ok(ResponseType::Bytes(bytes.to_vec())),
200            Err(e) => {
201                error!(target: "azure", "Error requesting Azure endpoint (GET): {:?}", e);
202                Err(Box::new(e))
203            }
204        },
205    }
206}