zai_rs/model/audio_to_text/
data.rs

1use super::super::traits::*;
2use super::request::AudioTranscriptionBody;
3use serde::Serialize;
4use std::path::Path;
5
6use validator::Validate;
7
8use crate::client::http::HttpClient;
9
10/// Audio transcription request (multipart/form-data)
11pub struct AudioTranscriptionRequest<N>
12where
13    N: ModelName + AudioToText + Serialize,
14{
15    pub key: String,
16    pub body: AudioTranscriptionBody<N>,
17    file_path: Option<String>,
18}
19
20impl<N> AudioTranscriptionRequest<N>
21where
22    N: ModelName + AudioToText + Serialize + Clone,
23{
24    pub fn new(model: N, key: String) -> Self {
25        Self {
26            key,
27            body: AudioTranscriptionBody::new(model),
28            file_path: None,
29        }
30    }
31
32    pub fn with_file_path(mut self, path: impl Into<String>) -> Self {
33        self.file_path = Some(path.into());
34        self
35    }
36
37    pub fn with_temperature(mut self, temperature: f32) -> Self {
38        self.body = self.body.with_temperature(temperature);
39        self
40    }
41
42    pub fn with_stream(mut self, stream: bool) -> Self {
43        self.body = self.body.with_stream(stream);
44        self
45    }
46
47    pub fn with_request_id(mut self, request_id: impl Into<String>) -> Self {
48        self.body = self.body.with_request_id(request_id);
49        self
50    }
51
52    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
53        self.body = self.body.with_user_id(user_id);
54        self
55    }
56
57    pub fn validate(&self) -> anyhow::Result<()> {
58        // Check body constraints
59        self.body.validate().map_err(|e| anyhow::anyhow!(e))?;
60        // Ensure file path exists
61        let p = self
62            .file_path
63            .as_ref()
64            .ok_or_else(|| anyhow::anyhow!("file_path is required"))?;
65        if !Path::new(p).exists() {
66            return Err(anyhow::anyhow!(format!("file_path not found: {}", p)));
67        }
68        Ok(())
69    }
70
71    pub async fn send(&self) -> anyhow::Result<super::response::AudioTranscriptionResponse>
72    where
73        N: Clone + Send + Sync + 'static,
74    {
75        self.validate()?;
76        let resp = self.post().await?;
77        let parsed = resp
78            .json::<super::response::AudioTranscriptionResponse>()
79            .await?;
80        Ok(parsed)
81    }
82}
83
84impl<N> HttpClient for AudioTranscriptionRequest<N>
85where
86    N: ModelName + AudioToText + Serialize + Clone + Send + Sync + 'static,
87{
88    type Body = AudioTranscriptionBody<N>;
89    type ApiUrl = &'static str;
90    type ApiKey = String;
91
92    fn api_url(&self) -> &Self::ApiUrl {
93        &"https://open.bigmodel.cn/api/paas/v4/audio/transcriptions"
94    }
95
96    fn api_key(&self) -> &Self::ApiKey {
97        &self.key
98    }
99
100    fn body(&self) -> &Self::Body {
101        &self.body
102    }
103
104    // Override default JSON post with multipart/form-data post
105    fn post(&self) -> impl std::future::Future<Output = anyhow::Result<reqwest::Response>> + Send {
106        let key = self.key.clone();
107        let url = (*self.api_url()).to_string();
108        let body = self.body.clone();
109        let file_path_opt = self.file_path.clone();
110
111        async move {
112            let file_path =
113                file_path_opt.ok_or_else(|| anyhow::anyhow!("file_path is required"))?;
114
115            let mut form = reqwest::multipart::Form::new();
116
117            // file
118            let file_name = Path::new(&file_path)
119                .file_name()
120                .and_then(|s| s.to_str())
121                .unwrap_or("audio.wav");
122            let file_bytes = tokio::fs::read(&file_path).await?;
123
124            // Basic MIME guess by extension
125            let mime = if file_name.to_ascii_lowercase().ends_with(".mp3") {
126                "audio/mpeg"
127            } else {
128                "audio/wav"
129            };
130
131            let part = reqwest::multipart::Part::bytes(file_bytes)
132                .file_name(file_name.to_string())
133                .mime_str(mime)?;
134            form = form.part("file", part);
135
136            // model
137            let model_name: String = body.model.into();
138            form = form.text("model", model_name);
139
140            // Optional fields
141            if let Some(t) = body.temperature {
142                form = form.text("temperature", t.to_string());
143            }
144            if let Some(s) = body.stream {
145                form = form.text("stream", s.to_string());
146            }
147            if let Some(rid) = body.request_id {
148                form = form.text("request_id", rid);
149            }
150            if let Some(uid) = body.user_id {
151                form = form.text("user_id", uid);
152            }
153
154            let resp = reqwest::Client::new()
155                .post(url)
156                .bearer_auth(key)
157                .multipart(form)
158                .send()
159                .await?;
160
161            Ok(resp)
162        }
163    }
164}