use crate::openai::{
endpoint::{
request_endpoint_form_data, Endpoint, EndpointVariant, AudioEndpointVariant,
},
types::{
audio::{AudioResponse, Format},
common::Error,
model::Model,
},
};
use isolang::Language;
use log::{debug, warn};
use reqwest::multipart::{Form, Part};
/// Request builder for the audio endpoints (`transcription` and `translation`).
///
/// Start from [`Audio::default`] and chain the setter methods. Optional fields that are
/// left as `None` are omitted from the multipart form that gets sent.
#[derive(Debug)]
pub struct Audio {
    /// Model identifier; always sent as the `model` form field.
    pub model: Model,
    /// Optional prompt text; sent as the `prompt` form field when set.
    pub prompt: Option<String>,
    /// Optional response format; sent as the `response_format` form field when set.
    pub response_format: Option<Format>,
    /// Optional temperature; sent as decimal text in the `temperature` form field when set.
    pub temperature: Option<f32>,
    /// Optional language; `transcription` sends its ISO 639-1 code as the `language`
    /// form field, while `translation` does not send it.
    pub language: Option<Language>,
}
impl Default for Audio {
fn default() -> Self {
Self {
model: Model::WHISPER_1,
prompt: None,
response_format: None,
temperature: None,
language: None,
}
}
}
impl Audio {
fn mime(
&self,
file_name: &str,
) -> Result<&'static str, Box<dyn std::error::Error>> {
Ok(match *file_name.split('.').collect::<Vec<&str>>().last().unwrap() {
"mp3" => "audio/mpeg",
"mp4" => "audio/mp4",
"mpeg" => "audio/mpeg",
"mpga" => "audio/mpeg",
"m4a" => "audio/mp4a-latm",
"wav" => "audio/wav",
"webm" => "audio/webm",
_ => return Err("Unsupported format!".into()),
})
}
pub fn model(self, model: Model)->Self {
Self {
model,
..self
}
}
pub fn prompt(self, content: &str)->Self {
Self {
prompt: Some(content.into()),
..self
}
}
pub fn response_format(self, response_format: Format)->Self {
Self {
response_format: Some(response_format.into()),
..self
}
}
pub fn temperature(self, temperature: f32)->Self {
Self {
temperature: Some(temperature.into()),
..self
}
}
pub fn language(self, language: Language)->Self {
Self {
language: Some(language.into()),
..self
}
}
pub async fn transcription(&self, file_name: String, bytes: Vec<u8>) -> Result<AudioResponse, Box<dyn std::error::Error>> {
let mut audio_response: Option<AudioResponse> = None;
let mut form = Form::new();
let file = Part::bytes(bytes)
.file_name(file_name.clone())
.mime_str(self.mime(&file_name).unwrap())
.unwrap();
form = form.part("file", file);
let m: &str = self.model.clone().into();
form = form.part("model", Part::text(m));
if let Some(prompt) = self.prompt.clone() {
form = form.part("prompt", Part::text(prompt));
}
if let Some(lang) = self.language {
form = form.part("language", Part::text(lang.to_639_1().unwrap()));
}
if let Some(fmt) = self.response_format.clone() {
let fmt: &str = fmt.into();
form = form.part("response_format", Part::text(fmt));
}
if let Some(temp) = self.temperature {
let temp = temp.to_string();
form = form.part("temperature", Part::text(temp));
}
let variant: String = AudioEndpointVariant::Transcription.into();
request_endpoint_form_data(
form,
&Endpoint::Audio_v1, EndpointVariant::Extended(variant),
|res| {
if let Ok(text) = res {
if let Ok(response_data) = serde_json::from_str::<AudioResponse>(&text) {
debug!(target: "openai", "Response parsed, audio transcription response deserialized.");
audio_response = Some(response_data);
} else {
if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
warn!(target: "openai",
"OpenAI error code {}: `{:?}`",
response_error.error.code.unwrap_or(0),
text
);
} else {
warn!(target: "openai", "Audio response not deserializable.");
}
}
}
})
.await?;
if let Some(response_data) = audio_response {
Ok(response_data)
} else {
Err("No response or error parsing response".into())
}
}
pub async fn translation(&self, file_name: String, bytes: Vec<u8>) -> Result<AudioResponse, Box<dyn std::error::Error>> {
let mut audio_response: Option<AudioResponse> = None;
let mut form = Form::new();
let file = Part::bytes(bytes)
.file_name(file_name.clone())
.mime_str(self.mime(&file_name).unwrap())
.unwrap();
form = form.part("file", file);
let m: &str = self.model.clone().into();
form = form.part("model", Part::text(m));
if let Some(prompt) = self.prompt.clone() {
form = form.part("prompt", Part::text(prompt));
}
if let Some(fmt) = self.response_format.clone() {
let fmt: &str = fmt.into();
form = form.part("response_format", Part::text(fmt));
}
if let Some(temp) = self.temperature {
let temp = temp.to_string();
form = form.part("temperature", Part::text(temp));
}
let variant: String = AudioEndpointVariant::Transcription.into();
request_endpoint_form_data(
form,
&Endpoint::Audio_v1, EndpointVariant::Extended(variant),
|res| {
if let Ok(text) = res {
if let Ok(response_data) = serde_json::from_str::<AudioResponse>(&text) {
debug!(target: "openai", "Response parsed, audio translation response deserialized.");
audio_response = Some(response_data);
} else {
if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
warn!(target: "openai",
"OpenAI error code {}: `{:?}`",
response_error.error.code.unwrap_or(0),
text
);
} else {
warn!(target: "openai", "Audio response not deserializable.");
}
}
}
})
.await?;
if let Some(response_data) = audio_response {
Ok(response_data)
} else {
Err("No response or error parsing response".into())
}
}
}