zai_rs/model/audio_to_text/
data.rs1use super::super::traits::*;
2use super::request::AudioTranscriptionBody;
3use serde::Serialize;
4use std::path::Path;
5
6use validator::Validate;
7
8use crate::client::http::HttpClient;
9
10pub struct AudioTranscriptionRequest<N>
12where
13 N: ModelName + AudioToText + Serialize,
14{
15 pub key: String,
16 pub body: AudioTranscriptionBody<N>,
17 file_path: Option<String>,
18}
19
20impl<N> AudioTranscriptionRequest<N>
21where
22 N: ModelName + AudioToText + Serialize + Clone,
23{
24 pub fn new(model: N, key: String) -> Self {
25 Self {
26 key,
27 body: AudioTranscriptionBody::new(model),
28 file_path: None,
29 }
30 }
31
32 pub fn with_file_path(mut self, path: impl Into<String>) -> Self {
33 self.file_path = Some(path.into());
34 self
35 }
36
37 pub fn with_temperature(mut self, temperature: f32) -> Self {
38 self.body = self.body.with_temperature(temperature);
39 self
40 }
41
42 pub fn with_stream(mut self, stream: bool) -> Self {
43 self.body = self.body.with_stream(stream);
44 self
45 }
46
47 pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
48 self.body = self.body.with_request_id(request_id);
49 self
50 }
51
52 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
53 self.body = self.body.with_user_id(user_id);
54 self
55 }
56
57 pub fn validate(&self) -> anyhow::Result<()> {
58 self.body.validate().map_err(|e| anyhow::anyhow!(e))?;
60 let p = self
62 .file_path
63 .as_ref()
64 .ok_or_else(|| anyhow::anyhow!("file_path is required"))?;
65 if !Path::new(p).exists() {
66 return Err(anyhow::anyhow!(format!("file_path not found: {}", p)));
67 }
68 Ok(())
69 }
70
71 pub async fn send(&self) -> anyhow::Result<super::response::AudioTranscriptionResponse>
72 where
73 N: Clone + Send + Sync + 'static,
74 {
75 self.validate()?;
76 let resp = self.post().await?;
77 let parsed = resp
78 .json::<super::response::AudioTranscriptionResponse>()
79 .await?;
80 Ok(parsed)
81 }
82}
83
84impl<N> HttpClient for AudioTranscriptionRequest<N>
85where
86 N: ModelName + AudioToText + Serialize + Clone + Send + Sync + 'static,
87{
88 type Body = AudioTranscriptionBody<N>;
89 type ApiUrl = &'static str;
90 type ApiKey = String;
91
92 fn api_url(&self) -> &Self::ApiUrl {
93 &"https://open.bigmodel.cn/api/paas/v4/audio/transcriptions"
94 }
95
96 fn api_key(&self) -> &Self::ApiKey {
97 &self.key
98 }
99
100 fn body(&self) -> &Self::Body {
101 &self.body
102 }
103
104 fn post(&self) -> impl std::future::Future<Output = anyhow::Result<reqwest::Response>> + Send {
106 let key = self.key.clone();
107 let url = (*self.api_url()).to_string();
108 let body = self.body.clone();
109 let file_path_opt = self.file_path.clone();
110
111 async move {
112 let file_path =
113 file_path_opt.ok_or_else(|| anyhow::anyhow!("file_path is required"))?;
114
115 let mut form = reqwest::multipart::Form::new();
116
117 let file_name = Path::new(&file_path)
119 .file_name()
120 .and_then(|s| s.to_str())
121 .unwrap_or("audio.wav");
122 let file_bytes = tokio::fs::read(&file_path).await?;
123
124 let mime = if file_name.to_ascii_lowercase().ends_with(".mp3") {
126 "audio/mpeg"
127 } else {
128 "audio/wav"
129 };
130
131 let part = reqwest::multipart::Part::bytes(file_bytes)
132 .file_name(file_name.to_string())
133 .mime_str(mime)?;
134 form = form.part("file", part);
135
136 let model_name: String = body.model.into();
138 form = form.text("model", model_name);
139
140 if let Some(t) = body.temperature {
142 form = form.text("temperature", t.to_string());
143 }
144 if let Some(s) = body.stream {
145 form = form.text("stream", s.to_string());
146 }
147 if let Some(rid) = body.request_id {
148 form = form.text("request_id", rid);
149 }
150 if let Some(uid) = body.user_id {
151 form = form.text("user_id", uid);
152 }
153
154 let resp = reqwest::Client::new()
155 .post(url)
156 .bearer_auth(key)
157 .multipart(form)
158 .send()
159 .await?;
160
161 Ok(resp)
162 }
163 }
164}