Skip to main content

zai_rs/model/audio_to_text/
data.rs

1use std::path::Path;
2
3use serde::Serialize;
4use validator::Validate;
5
6use super::{super::traits::*, request::AudioToTextBody};
7use crate::client::http::HttpClient;
8
9/// Audio transcription request (multipart/form-data)
10pub struct AudioToTextRequest<N>
11where
12    N: ModelName + AudioToText + Serialize,
13{
14    pub key: String,
15    pub body: AudioToTextBody<N>,
16    file_path: Option<String>,
17}
18
19impl<N> AudioToTextRequest<N>
20where
21    N: ModelName + AudioToText + Serialize + Clone,
22{
23    pub fn new(model: N, key: String) -> Self {
24        Self {
25            key,
26            body: AudioToTextBody::new(model),
27            file_path: None,
28        }
29    }
30
31    pub fn with_file_path(mut self, path: impl Into<String>) -> Self {
32        self.file_path = Some(path.into());
33        self
34    }
35
36    pub fn with_temperature(mut self, temperature: f32) -> Self {
37        self.body = self.body.with_temperature(temperature);
38        self
39    }
40
41    pub fn with_stream(mut self, stream: bool) -> Self {
42        self.body = self.body.with_stream(stream);
43        self
44    }
45
46    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
47        self.body = self.body.with_request_id(request_id);
48        self
49    }
50
51    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
52        self.body = self.body.with_user_id(user_id);
53        self
54    }
55
56    pub fn validate(&self) -> crate::ZaiResult<()> {
57        // Check body constraints
58
59        self.body
60            .validate()
61            .map_err(crate::client::error::ZaiError::from)?;
62        // Ensure file path exists
63
64        let p =
65            self.file_path
66                .as_ref()
67                .ok_or_else(|| crate::client::error::ZaiError::ApiError {
68                    code: 1200,
69                    message: "file_path is required".to_string(),
70                })?;
71
72        if !Path::new(p).exists() {
73            return Err(crate::client::error::ZaiError::FileError {
74                code: 0,
75                message: format!("file_path not found: {}", p),
76            });
77        }
78
79        Ok(())
80    }
81
82    pub async fn send(&self) -> crate::ZaiResult<super::response::AudioToTextResponse>
83    where
84        N: Clone + Send + Sync + 'static,
85    {
86        self.validate()?;
87
88        let resp = self.post().await?;
89
90        let parsed = resp.json::<super::response::AudioToTextResponse>().await?;
91
92        Ok(parsed)
93    }
94}
95
96impl<N> HttpClient for AudioToTextRequest<N>
97where
98    N: ModelName + AudioToText + Serialize + Clone + Send + Sync + 'static,
99{
100    type Body = AudioToTextBody<N>;
101    type ApiUrl = &'static str;
102    type ApiKey = String;
103
104    fn api_url(&self) -> &Self::ApiUrl {
105        &"https://open.bigmodel.cn/api/paas/v4/audio/transcriptions"
106    }
107
108    fn api_key(&self) -> &Self::ApiKey {
109        &self.key
110    }
111
112    fn body(&self) -> &Self::Body {
113        &self.body
114    }
115
116    fn post(
117        &self,
118    ) -> impl std::future::Future<Output = crate::ZaiResult<reqwest::Response>> + Send {
119        let key = self.key.clone();
120
121        let url = (*self.api_url()).to_string();
122
123        let body = self.body.clone();
124
125        let file_path_opt = self.file_path.clone();
126
127        async move {
128            let file_path =
129                file_path_opt.ok_or_else(|| crate::client::error::ZaiError::ApiError {
130                    code: 1200,
131                    message: "file_path is required".to_string(),
132                })?;
133
134            let mut form = reqwest::multipart::Form::new();
135
136            // file
137            let file_name = Path::new(&file_path)
138                .file_name()
139                .and_then(|s| s.to_str())
140                .unwrap_or("audio.wav");
141            let file_bytes = tokio::fs::read(&file_path).await?;
142
143            // Basic MIME guess by extension
144            let mime = if file_name.to_ascii_lowercase().ends_with(".mp3") {
145                "audio/mpeg"
146            } else {
147                "audio/wav"
148            };
149
150            let part = reqwest::multipart::Part::bytes(file_bytes)
151                .file_name(file_name.to_string())
152                .mime_str(mime)?;
153            form = form.part("file", part);
154
155            // model
156            let model_name: String = body.model.into();
157            form = form.text("model", model_name);
158
159            // Optional fields
160            if let Some(t) = body.temperature {
161                form = form.text("temperature", t.to_string());
162            }
163            if let Some(s) = body.stream {
164                form = form.text("stream", s.to_string());
165            }
166            if let Some(rid) = body.request_id {
167                form = form.text("request_id", rid);
168            }
169            if let Some(uid) = body.user_id {
170                form = form.text("user_id", uid);
171            }
172
173            let resp = reqwest::Client::new()
174                .post(url)
175                .bearer_auth(key)
176                .multipart(form)
177                .send()
178                .await?;
179
180            Ok(resp)
181        }
182    }
183}