use std::fmt::Display;

use log::{debug, error, warn};
use reqwest::multipart::Form;
use serde::Serialize;

use super::types::model::Model;
use crate::utils::{config::Config, header::AdditionalHeaders};

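/// Returns `true` when `model` is accepted by `endpoint`.
///
/// A minimal illustrative check, using only variants defined in this module:
///
/// ```ignore
/// assert!(endpoint_filter(&Model::GPT_4, &Endpoint::ChatCompletion_v1));
/// assert!(!endpoint_filter(&Model::WHISPER_1, &Endpoint::Completion_v1));
/// ```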
pub fn endpoint_filter(model: &Model, endpoint: &Endpoint) -> bool {
    match endpoint {
        Endpoint::ChatCompletion_v1 => [
            Model::GPT_3_5_TURBO,
            Model::GPT_3_5_TURBO_16K,
            Model::GPT_3_5_TURBO_16K_0613,
            Model::GPT_3_5_TURBO_0301,
            Model::GPT_3_5_TURBO_0613,
            Model::GPT_3_5_TURBO_1106,
            Model::GPT_3_5_TURBO_0125,
            Model::GPT_4,
            Model::GPT_4_0613,
            Model::GPT_4_0314,
            Model::GPT_4_32K,
            Model::GPT_4_32K_0613,
            Model::GPT_4_32K_0314,
            Model::GPT_4_TURBO_WITH_VISION,
            Model::GPT_4_TURBO_PREVIEW,
            Model::GPT_4_0125_PREVIEW,
            Model::GPT_4_1106_PREVIEW,
            Model::GPT_4_TURBO_1106_WITH_VISION,
        ]
        .contains(model),
        Endpoint::Completion_v1 => [
            Model::TEXT_DAVINCI_003,
            Model::TEXT_DAVINCI_002,
            Model::TEXT_CURIE_001,
            Model::TEXT_BABBAGE_001,
            Model::TEXT_ADA_001,
            Model::DAVINCI,
            Model::CURIE,
            Model::BABBAGE,
            Model::ADA,
        ]
        .contains(model),
        Endpoint::Edit_v1 => {
            [Model::TEXT_DAVINCI_EDIT_001, Model::CODE_DAVINCI_EDIT_001].contains(model)
        }
        Endpoint::Audio_v1 => [Model::WHISPER_1].contains(model),
        Endpoint::FineTune_v1 => {
            [Model::DAVINCI, Model::CURIE, Model::BABBAGE, Model::ADA].contains(model)
        }
        Endpoint::Embedding_v1 => [
            Model::TEXT_EMBEDDING_ADA_002,
            Model::TEXT_SEARCH_ADA_DOC_001,
        ]
        .contains(model),
        Endpoint::Moderation_v1 => [
            Model::TEXT_MODERATION_LATEST,
            Model::TEXT_MODERATION_STABLE,
            Model::TEXT_MODERATION_004,
        ]
        .contains(model),
        // Remaining endpoints (currently Image_v1) have no model allow-list.
        _ => false,
    }
}

pub enum EndpointVariant {
    None,
    Extended(String),
}

impl From<String> for EndpointVariant {
    fn from(value: String) -> Self {
        Self::Extended(value)
    }
}

#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Endpoint {
    ChatCompletion_v1,
    Completion_v1,
    Edit_v1,
    Image_v1,
    Audio_v1,
    FineTune_v1,
    Embedding_v1,
    Moderation_v1,
}

impl Display for Endpoint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", <&'static str>::from(self.clone()))
    }
}

impl From<Endpoint> for &'static str {
    fn from(endpoint: Endpoint) -> Self {
        match endpoint {
            Endpoint::Audio_v1 => "/v1/audio",
            Endpoint::ChatCompletion_v1 => "/v1/chat/completions",
            Endpoint::Completion_v1 => "/v1/completions",
            Endpoint::Edit_v1 => "/v1/edits",
            Endpoint::Image_v1 => "/v1/images",
            Endpoint::Embedding_v1 => "/v1/embeddings",
            Endpoint::FineTune_v1 => "/v1/fine-tunes",
            Endpoint::Moderation_v1 => "/v1/moderations",
        }
    }
}

#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum ImageEndpointVariant {
    Generation,
    Editing,
    Variation,
}

impl From<ImageEndpointVariant> for String {
    fn from(variant: ImageEndpointVariant) -> Self {
        String::from(match variant {
            ImageEndpointVariant::Editing => "/edits",
            ImageEndpointVariant::Variation => "/variations",
            ImageEndpointVariant::Generation => "/generations",
        })
    }
}

#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum AudioEndpointVariant {
    Transcription,
    Translation,
}

impl From<AudioEndpointVariant> for String {
    fn from(variant: AudioEndpointVariant) -> Self {
        String::from(match variant {
            AudioEndpointVariant::Transcription => "/transcriptions",
            AudioEndpointVariant::Translation => "/translations",
        })
    }
}

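/// Sends a JSON `POST` to `endpoint` (optionally extended by `variant`) and
/// passes the raw response body to `cb`.
///
/// A hedged usage sketch, assuming the signature shown here; the `Prompt`
/// struct is purely illustrative and not part of this crate:
///
/// ```ignore
/// #[derive(serde::Serialize)]
/// struct Prompt {
///     model: String,
///     prompt: String,
/// }
///
/// request_endpoint(
///     &Prompt { model: "davinci".into(), prompt: "Say hi".into() },
///     &Endpoint::Completion_v1,
///     EndpointVariant::None,
///     |res| match res {
///         Ok(body) => println!("{body}"),
///         Err(err) => eprintln!("{err}"),
///     },
/// )
/// .await?;
/// ```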
pub async fn request_endpoint<'a, T, F>(
    json: &'a T,
    endpoint: &'a Endpoint,
    variant: EndpointVariant,
    mut cb: F,
) -> Result<(), Box<dyn std::error::Error>>
where
    T: Serialize,
    F: FnMut(Result<String, Box<dyn std::error::Error>>),
{
    let client = reqwest::Client::new();
    let config = Config::load().unwrap();

    // Append the variant sub-path (e.g. "/transcriptions") when one is given.
    let url = if let EndpointVariant::Extended(var) = variant {
        format!("{}{}{}", config.openai.base_endpoint(), endpoint, var)
    } else {
        format!("{}{}", config.openai.base_endpoint(), endpoint)
    };

    let mut req = client.post(url);

    // Additional headers, if any were configured.
    let headers = AdditionalHeaders::from_var().provide();
    if !headers.is_empty() {
        req = req.headers(headers);
    }

    req = req.header("Authorization", format!("Bearer {}", config.openai.api_key));
    if let Some(org_id) = config.openai.org_id.clone() {
        req = req.header("OpenAI-Organization", org_id);
    }

    if let Some(req_clone) = req.try_clone() {
        debug!(target: "requests", "Headers `{:?}`", req_clone.build().unwrap().headers());
    }

    let res = req.json(&json).send().await?;

    if let Ok(text) = res.text().await {
        debug!(target: "openai", "Received response from OpenAI: `{:?}`", text);
        cb(Ok(text));
    } else {
        error!(target: "openai", "Error receiving response from OpenAI");
        cb(Err("Error receiving response from OpenAI".into()));
    }

    Ok(())
}

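/// Like [`request_endpoint`], but forwards each response chunk to `cb` as it
/// arrives instead of buffering the whole body.
///
/// A hedged sketch, assuming the signature shown here; parsing of stream
/// framing (e.g. SSE `data:` lines) is left to the caller:
///
/// ```ignore
/// request_endpoint_stream(
///     &body,
///     &Endpoint::ChatCompletion_v1,
///     EndpointVariant::None,
///     |chunk| {
///         if let Ok(chunk) = chunk {
///             print!("{chunk}");
///         }
///     },
/// )
/// .await?;
/// ```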
pub async fn request_endpoint_stream<'a, T>(
    json: &'a T,
    endpoint: &'a Endpoint,
    variant: EndpointVariant,
    mut cb: impl FnMut(Result<String, Box<dyn std::error::Error>>),
) -> Result<(), Box<dyn std::error::Error>>
where
    T: Serialize,
{
    let client = reqwest::Client::new();
    let config = Config::load().unwrap();

    // Append the variant sub-path when one is given.
    let url = if let EndpointVariant::Extended(var) = variant {
        format!("{}{}{}", config.openai.base_endpoint(), endpoint, var)
    } else {
        format!("{}{}", config.openai.base_endpoint(), endpoint)
    };

    let mut req = client.post(url);

    // Additional headers, if any were configured.
    let headers = AdditionalHeaders::from_var().provide();
    if !headers.is_empty() {
        req = req.headers(headers);
    }

    req = req.header("Authorization", format!("Bearer {}", config.openai.api_key));
    if let Some(org_id) = config.openai.org_id.clone() {
        req = req.header("OpenAI-Organization", org_id);
    }

    if let Some(req_clone) = req.try_clone() {
        debug!(target: "requests", "Headers `{:?}`", req_clone.build().unwrap().headers());
    }

    let mut res = req.json(&json).send().await?;

    // Forward each response chunk to the callback as it arrives.
    while let Some(chunk) = res.chunk().await? {
        if let Ok(chunk_data_raw) = String::from_utf8(chunk.to_vec()) {
            debug!(target: "openai", "Received response chunk from OpenAI: `{:?}`", chunk_data_raw);
            cb(Ok(chunk_data_raw));
        } else {
            warn!(target: "openai", "Response chunk is not valid UTF-8; skipping");
        }
    }

    Ok(())
}

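/// Sends a `multipart/form-data` `POST` to `endpoint` (optionally extended by
/// `variant`) and passes the raw response body to `cb`.
///
/// A hedged sketch, assuming the signature shown here; the form fields are
/// illustrative only:
///
/// ```ignore
/// let form = Form::new().text("model", "whisper-1");
/// let sub_path: String = AudioEndpointVariant::Transcription.into();
/// request_endpoint_form_data(
///     form,
///     &Endpoint::Audio_v1,
///     EndpointVariant::from(sub_path),
///     |res| {
///         if let Ok(body) = res {
///             println!("{body}");
///         }
///     },
/// )
/// .await?;
/// ```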
pub async fn request_endpoint_form_data<'a, F>(
    form: Form,
    endpoint: &'a Endpoint,
    variant: EndpointVariant,
    mut cb: F,
) -> Result<(), Box<dyn std::error::Error>>
where
    F: FnMut(Result<String, Box<dyn std::error::Error>>),
{
    let client = reqwest::Client::new();
    let config = Config::load().unwrap();

    // Append the variant sub-path (e.g. "/transcriptions") when one is given.
    let url = if let EndpointVariant::Extended(var) = variant {
        format!("{}{}{}", config.openai.base_endpoint(), endpoint, var)
    } else {
        format!("{}{}", config.openai.base_endpoint(), endpoint)
    };

    let mut req = client.post(url);

    // Additional headers, if any were configured.
    let headers = AdditionalHeaders::from_var().provide();
    if !headers.is_empty() {
        req = req.headers(headers);
    }

    req = req.header("Authorization", format!("Bearer {}", config.openai.api_key));

    if let Some(req_clone) = req.try_clone() {
        debug!(target: "requests", "Headers `{:?}`", req_clone.build().unwrap().headers());
    }

    let res = req.multipart(form).send().await?;

    if let Ok(text) = res.text().await {
        debug!(target: "openai", "Received response from OpenAI: `{:?}`", text);
        cb(Ok(text));
    } else {
        error!(target: "openai", "Error receiving response from OpenAI");
        cb(Err("Error receiving response from OpenAI".into()));
    }

    Ok(())
}
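
// A minimal test sketch covering the pure helpers above; it relies only on
// items defined in this file (the network-facing functions are not exercised).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chat_models_pass_the_chat_completion_filter() {
        assert!(endpoint_filter(&Model::GPT_4, &Endpoint::ChatCompletion_v1));
        assert!(!endpoint_filter(&Model::WHISPER_1, &Endpoint::ChatCompletion_v1));
    }

    #[test]
    fn endpoint_paths_and_variants_compose() {
        assert_eq!(Endpoint::ChatCompletion_v1.to_string(), "/v1/chat/completions");

        let sub_path: String = AudioEndpointVariant::Transcription.into();
        assert_eq!(
            format!("{}{}", Endpoint::Audio_v1, sub_path),
            "/v1/audio/transcriptions"
        );
    }
}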