use std::fmt::Display;
use log::{debug, error, warn};
use reqwest::multipart::Form;
use serde::Serialize;
use super::types::model::Model;
use crate::utils::{config::Config, header::AdditionalHeaders};
/// Returns `true` when `model` may be used with `endpoint`.
///
/// The accepted models for each endpoint are hard-coded below. Endpoints with
/// no listed models (currently [`Endpoint::Image_v1`]) accept no model at all.
pub fn endpoint_filter(model: &Model, endpoint: &Endpoint) -> bool {
    match endpoint {
        Endpoint::ChatCompletion_v1 => [
            Model::GPT_3_5_TURBO,
            Model::GPT_3_5_TURBO_0301,
            Model::GPT_4,
            Model::GPT_4_0314,
            Model::GPT_4_32K,
            Model::GPT_4_32K_0314,
        ]
        .contains(model),
        Endpoint::Completion_v1 => [
            Model::TEXT_DAVINCI_003,
            Model::TEXT_DAVINCI_002,
            Model::TEXT_CURIE_001,
            Model::TEXT_BABBAGE_001,
            Model::TEXT_ADA_001,
            Model::DAVINCI,
            Model::CURIE,
            Model::BABBAGE,
            Model::ADA,
        ]
        .contains(model),
        Endpoint::Edit_v1 => {
            [Model::TEXT_DAVINCI_EDIT_001, Model::CODE_DAVINCI_EDIT_001].contains(model)
        }
        Endpoint::Audio_v1 => [Model::WHISPER_1].contains(model),
        Endpoint::FineTune_v1 => {
            [Model::DAVINCI, Model::CURIE, Model::BABBAGE, Model::ADA].contains(model)
        }
        Endpoint::Embedding_v1 => [
            Model::TEXT_EMBEDDING_ADA_002,
            Model::TEXT_SEARCH_ADA_DOC_001,
        ]
        .contains(model),
        Endpoint::Moderation_v1 => [
            Model::TEXT_MODERATION_LATEST,
            Model::TEXT_MODERATION_STABLE,
            Model::TEXT_MODERATION_004,
        ]
        .contains(model),
        // No models are listed for the image endpoint. This arm is spelled out instead of
        // using `_` so that adding an `Endpoint` variant forces a decision here.
        Endpoint::Image_v1 => false,
    }
}
/// Optional suffix appended to an [`Endpoint`]'s path when a request URL is built.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EndpointVariant {
    /// Use the endpoint path unchanged.
    None,
    /// Append this string verbatim to the endpoint path (e.g. `"/transcriptions"`).
    Extended(String),
}

impl From<String> for EndpointVariant {
    /// Wraps `value` as an [`EndpointVariant::Extended`] suffix.
    fn from(value: String) -> Self {
        Self::Extended(value)
    }
}
/// An OpenAI API endpoint. Each variant converts to its URL path (e.g. `/v1/edits`).
// The `_v1` suffix mirrors the API version segment of each path; the lint
// allowance keeps that naming without warnings.
#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum Endpoint {
    /// `/v1/chat/completions`
    ChatCompletion_v1,
    /// `/v1/completions`
    Completion_v1,
    /// `/v1/edits`
    Edit_v1,
    /// `/v1/images`
    Image_v1,
    /// `/v1/audio`
    Audio_v1,
    /// `/v1/fine-tunes`
    FineTune_v1,
    /// `/v1/embeddings`
    Embedding_v1,
    /// `/v1/moderations`
    Moderation_v1,
}
impl Display for Endpoint {
    /// Writes the endpoint's URL path (e.g. `/v1/chat/completions`).
    ///
    /// Uses `Formatter::pad`, so width, fill, alignment and precision flags are
    /// honoured exactly as they are for a plain `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let path: &'static str = self.clone().into();
        f.pad(path)
    }
}
impl Into<&'static str> for Endpoint {
fn into(self) -> &'static str {
match self {
Self::Audio_v1 => "/v1/audio",
Self::ChatCompletion_v1 => "/v1/chat/completions",
Self::Completion_v1 => "/v1/completions",
Self::Edit_v1 => "/v1/edits",
Self::Image_v1 => "/v1/images",
Self::Embedding_v1 => "/v1/embeddings",
Self::FineTune_v1 => "/v1/fine-tunes",
Self::Moderation_v1 => "/v1/moderations",
}
}
}
/// Sub-path of the image endpoint; converts to the suffix appended to `/v1/images`.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum ImageEndpointVariant {
    /// `/generations`
    Generation,
    /// `/edits`
    Editing,
    /// `/variations`
    Variation,
}

/// Implemented as `From` so `Into<String>` comes from the standard blanket impl.
impl From<ImageEndpointVariant> for String {
    fn from(variant: ImageEndpointVariant) -> Self {
        let suffix = match variant {
            ImageEndpointVariant::Generation => "/generations",
            ImageEndpointVariant::Editing => "/edits",
            ImageEndpointVariant::Variation => "/variations",
        };
        suffix.to_owned()
    }
}
/// Sub-path of the audio endpoint; converts to the suffix appended to `/v1/audio`.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum AudioEndpointVariant {
    /// `/transcriptions`
    Transcription,
    /// `/translations`
    Translation,
}

/// Implemented as `From` so `Into<String>` comes from the standard blanket impl.
impl From<AudioEndpointVariant> for String {
    fn from(variant: AudioEndpointVariant) -> Self {
        let suffix = match variant {
            AudioEndpointVariant::Transcription => "/transcriptions",
            AudioEndpointVariant::Translation => "/translations",
        };
        suffix.to_owned()
    }
}
pub async fn request_endpoint<'a, T, F>(
json: &'a T,
endpoint: &'a Endpoint,
variant: EndpointVariant,
mut cb: F,
) -> Result<(), Box<dyn std::error::Error>>
where
T: Serialize,
F: FnMut(Result<String, Box<dyn std::error::Error>>),
{
let client = reqwest::Client::new();
let config = Config::load().unwrap();
let url = if let EndpointVariant::Extended(var) = variant {
format!(
"{}{}{}",
config.openai.base_endpoint(),
endpoint,
var.to_owned()
)
} else {
format!("{}{}", config.openai.base_endpoint(), endpoint)
};
let mut req = client.post(url);
let headers = AdditionalHeaders::from_var().provide();
if headers.len() > 0 {
req = req.headers(headers);
}
req = req.header("Authorization", format!("Bearer {}", config.openai.api_key));
if config.openai.org_id.is_some() {
req = req.header("OpenAI-Organization", config.openai.org_id.clone().unwrap());
}
if let Some(req_clone) = req.try_clone() {
log::debug!(target: "requests", "Headers `{:?}`", req_clone.build().unwrap().headers());
};
let res = req.json(&json).send().await?;
if let Ok(text) = res.text().await {
debug!(target: "openai", "Received response from OpenAI: `{:?}`", text);
cb(Ok(text.clone()));
} else {
error!(target: "openai", "Error receiving response from OpenAI");
cb(Err("Error receiving response from OpenAI".into()))
}
Ok(())
}
pub async fn request_endpoint_stream<'a, T>(
json: &'a T,
endpoint: &'a Endpoint,
variant: EndpointVariant,
mut cb: impl FnMut(Result<String, Box<dyn std::error::Error>>),
) -> Result<(), Box<dyn std::error::Error>>
where
T: Serialize,
{
let client = reqwest::Client::new();
let config = Config::load().unwrap();
let url = if let EndpointVariant::Extended(var) = variant {
format!(
"{}{}{}",
config.openai.base_endpoint(),
endpoint,
var.to_owned()
)
} else {
format!("{}{}", config.openai.base_endpoint(), endpoint)
};
let mut req = client.post(url);
let headers = AdditionalHeaders::from_var().provide();
if headers.len() > 0 {
req = req.headers(headers);
}
req = req.header("Authorization", format!("Bearer {}", config.openai.api_key));
if config.openai.org_id.is_some() {
req = req.header("OpenAI-Organization", config.openai.org_id.clone().unwrap());
}
if let Some(req_clone) = req.try_clone() {
log::debug!(target: "requests", "Headers `{:?}`", req_clone.build().unwrap().headers());
};
let mut res = req.json(&json).send().await?;
while let Some(chunk) = res.chunk().await? {
if let Ok(chunk_data_raw) = String::from_utf8(chunk.to_vec()) {
debug!(target: "openai", "Received response chunk from OpenAI: `{:?}`", chunk_data_raw);
cb(Ok(chunk_data_raw));
} else {
warn!(target: "openai", "Response chunk empty");
}
}
Ok(())
}
pub async fn request_endpoint_form_data<'a, F>(
form: Form,
endpoint: &'a Endpoint,
variant: EndpointVariant,
mut cb: F,
) -> Result<(), Box<dyn std::error::Error>>
where
F: FnMut(Result<String, Box<dyn std::error::Error>>),
{
let client = reqwest::Client::new();
let config = Config::load().unwrap();
let url = if let EndpointVariant::Extended(var) = variant {
format!(
"{}{}{}",
config.openai.base_endpoint(),
endpoint,
var.to_owned()
)
} else {
format!("{}{}", config.openai.base_endpoint(), endpoint)
};
let mut req = client.post(url);
let headers = AdditionalHeaders::from_var().provide();
if headers.len() > 0 {
req = req.headers(headers);
}
req = req.header("Authorization", format!("Bearer {}", config.openai.api_key));
if let Some(req_clone) = req.try_clone() {
log::debug!(target: "requests", "Headers `{:?}`", req_clone.build().unwrap().headers());
};
let res = req.multipart(form).send().await?;
if let Ok(text) = res.text().await {
debug!(target: "openai", "Received response from OpenAI: `{:?}`", text);
cb(Ok(text.clone()));
} else {
error!(target: "openai", "Error receiving response from OpenAI");
cb(Err("Error receiving response from OpenAI".into()))
}
Ok(())
}