rust_ai/openai/
endpoint.rs

//!
//! # OpenAI's API endpoints
//!
//! Note that the Audio and Image endpoints additionally require an extended
//! endpoint variant (a path suffix such as `/transcriptions`).
6
7////////////////////////////////////////////////////////////////////////////////
8
9use std::fmt::Display;
10
11use log::{debug, error, warn};
12use reqwest::multipart::Form;
13use serde::Serialize;
14
15use super::types::model::Model;
16use crate::utils::{config::Config, header::AdditionalHeaders};
17
18/// Check if selected model is available to certain API endpoint
19///
20/// # Arguments
21/// - `model` - A provided model enum variant
22/// - `endpoint` - API endpoint name, e.g. `/v1/completions`
23pub fn endpoint_filter(model: &Model, endpoint: &Endpoint) -> bool {
24    match endpoint {
25        Endpoint::ChatCompletion_v1 => [
26            Model::GPT_3_5_TURBO,
27            Model::GPT_3_5_TURBO_16K,
28            Model::GPT_3_5_TURBO_16K_0613,
29            Model::GPT_3_5_TURBO_0301,
30            Model::GPT_3_5_TURBO_0613,
31            Model::GPT_3_5_TURBO_1106,
32            Model::GPT_3_5_TURBO_0125,
33            Model::GPT_4,
34            Model::GPT_4_0613,
35            Model::GPT_4_0314,
36            Model::GPT_4_32K,
37            Model::GPT_4_32K_0613,
38            Model::GPT_4_32K_0314,
39            Model::GPT_4_TURBO_WITH_VISION,
40            Model::GPT_4_TURBO_PREVIEW,
41            Model::GPT_4_0125_PREVIEW,
42            Model::GPT_4_1106_PREVIEW,
43            Model::GPT_4_TURBO_1106_WITH_VISION,
44        ]
45        .contains(&model),
46        Endpoint::Completion_v1 => [
47            Model::TEXT_DAVINCI_003,
48            Model::TEXT_DAVINCI_002,
49            Model::TEXT_CURIE_001,
50            Model::TEXT_BABBAGE_001,
51            Model::TEXT_ADA_001,
52            Model::DAVINCI,
53            Model::CURIE,
54            Model::BABBAGE,
55            Model::ADA,
56        ]
57        .contains(&model),
58        Endpoint::Edit_v1 => {
59            [Model::TEXT_DAVINCI_EDIT_001, Model::CODE_DAVINCI_EDIT_001].contains(&model)
60        }
61        Endpoint::Audio_v1 => [Model::WHISPER_1].contains(&model),
62        Endpoint::FineTune_v1 => {
63            [Model::DAVINCI, Model::CURIE, Model::BABBAGE, Model::ADA].contains(&model)
64        }
65        Endpoint::Embedding_v1 => [
66            Model::TEXT_EMBEDDING_ADA_002,
67            Model::TEXT_SEARCH_ADA_DOC_001,
68        ]
69        .contains(&model),
70        Endpoint::Moderation_v1 => [
71            Model::TEXT_MODERATION_LATEST,
72            Model::TEXT_MODERATION_STABLE,
73            Model::TEXT_MODERATION_004,
74        ]
75        .contains(&model),
76        _ => false,
77    }
78}
79
/// Enum for endpoints that have several variants.
///
/// The request helpers in this module append the text held by
/// [`EndpointVariant::Extended`] directly after the endpoint path when
/// building the request URL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EndpointVariant {
    /// No sub variants.
    None,
    /// Denotes a variant of some endpoint; holds the path suffix to append,
    /// e.g. `/transcriptions`.
    Extended(String),
}

impl From<String> for EndpointVariant {
    /// Wraps a path suffix into [`EndpointVariant::Extended`].
    fn from(value: String) -> Self {
        Self::Extended(value)
    }
}
93
/// API endpoint definition enum
///
/// Each variant maps to a fixed URL path (relative to the API base URL),
/// available through `Display` or by converting into `&'static str`.
// Variant names carry the API version as a `_v1` suffix, hence the allow.
#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Endpoint {
    /// `/v1/chat/completions`
    ChatCompletion_v1,
    /// `/v1/completions`
    Completion_v1,
    /// `/v1/edits`
    Edit_v1,
    /// `/v1/images`
    Image_v1,
    /// `/v1/audio`
    Audio_v1,
    /// `/v1/fine-tunes`
    FineTune_v1,
    /// `/v1/embeddings`
    Embedding_v1,
    /// `/v1/moderations`
    Moderation_v1,
}

impl Endpoint {
    /// Returns the URL path of this endpoint without consuming or cloning it.
    fn path(&self) -> &'static str {
        match self {
            Self::Audio_v1 => "/v1/audio",
            Self::ChatCompletion_v1 => "/v1/chat/completions",
            Self::Completion_v1 => "/v1/completions",
            Self::Edit_v1 => "/v1/edits",
            Self::Image_v1 => "/v1/images",
            Self::Embedding_v1 => "/v1/embeddings",
            Self::FineTune_v1 => "/v1/fine-tunes",
            Self::Moderation_v1 => "/v1/moderations",
        }
    }
}

impl Display for Endpoint {
    /// Writes the endpoint's URL path, e.g. `/v1/completions`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.path())
    }
}

// Implementing `From` (rather than `Into`) keeps `Endpoint: Into<&'static str>`
// available through the standard blanket impl and additionally enables
// `<&str>::from(endpoint)`.
impl From<Endpoint> for &'static str {
    /// Converts the endpoint into its URL path.
    fn from(endpoint: Endpoint) -> Self {
        endpoint.path()
    }
}
128
/// Endpoint variants for Images
///
/// Converting a variant into a `String` yields the path suffix for that
/// operation.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum ImageEndpointVariant {
    /// Suffix `/generations`.
    Generation,
    /// Suffix `/edits`.
    Editing,
    /// Suffix `/variations`.
    Variation,
}

// `From` instead of a hand-written `Into`: callers keep `.into()` through the
// blanket impl and also gain `String::from(variant)`.
impl From<ImageEndpointVariant> for String {
    /// Converts the variant into its path suffix.
    fn from(variant: ImageEndpointVariant) -> Self {
        let suffix = match variant {
            ImageEndpointVariant::Editing => "/edits",
            ImageEndpointVariant::Variation => "/variations",
            ImageEndpointVariant::Generation => "/generations",
        };
        suffix.to_owned()
    }
}
147
/// Endpoint variants for Audios
///
/// Converting a variant into a `String` yields the path suffix for that
/// operation.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum AudioEndpointVariant {
    /// Suffix `/transcriptions`.
    Transcription,
    /// Suffix `/translations`.
    Translation,
}

// `From` instead of a hand-written `Into`: callers keep `.into()` through the
// blanket impl and also gain `String::from(variant)`.
impl From<AudioEndpointVariant> for String {
    /// Converts the variant into its path suffix.
    fn from(variant: AudioEndpointVariant) -> Self {
        let suffix = match variant {
            AudioEndpointVariant::Transcription => "/transcriptions",
            AudioEndpointVariant::Translation => "/translations",
        };
        suffix.to_owned()
    }
}
164
165/// Send request to remote endpoint using JSON.
166///  
167/// # Arguments
168/// - `json` - the serialized contents to send
169/// - `endpoint` - Endpoint enum variant
170/// - `variant` - Endpoint variant enum
171/// - `cb` - callback function that will be called when message received.
172pub async fn request_endpoint<'a, T, F>(
173    json: &'a T,
174    endpoint: &'a Endpoint,
175    variant: EndpointVariant,
176    mut cb: F,
177) -> Result<(), Box<dyn std::error::Error>>
178where
179    T: Serialize,
180    F: FnMut(Result<String, Box<dyn std::error::Error>>),
181{
182    let client = reqwest::Client::new();
183    let config = Config::load().unwrap();
184    let url = if let EndpointVariant::Extended(var) = variant {
185        format!(
186            "{}{}{}",
187            config.openai.base_endpoint(),
188            endpoint,
189            var.to_owned()
190        )
191    } else {
192        format!("{}{}", config.openai.base_endpoint(), endpoint)
193    };
194
195    let mut req = client.post(url);
196
197    // Load additional headers from environment variable
198    let headers = AdditionalHeaders::from_var().provide();
199    if headers.len() > 0 {
200        req = req.headers(headers);
201    }
202
203    req = req.header("Authorization", format!("Bearer {}", config.openai.api_key));
204    if config.openai.org_id.is_some() {
205        req = req.header("OpenAI-Organization", config.openai.org_id.clone().unwrap());
206    }
207
208    if let Some(req_clone) = req.try_clone() {
209        log::debug!(target: "requests", "Headers `{:?}`", req_clone.build().unwrap().headers());
210    };
211
212    let res = req.json(&json).send().await?;
213
214    if let Ok(text) = res.text().await {
215        debug!(target: "openai", "Received response from OpenAI: `{:?}`", text);
216        cb(Ok(text.clone()));
217    } else {
218        error!(target: "openai", "Error receiving response from OpenAI");
219        cb(Err("Error receiving response from OpenAI".into()))
220    }
221
222    Ok(())
223}
224
225/// Send request to remote endpoint using JSON but response will be streamed.
226///  
227/// # Arguments
228/// - `json` - the serialized contents to send
229/// - `endpoint` - Endpoint enum variant
230/// - `variant` - Endpoint variant enum
231/// - `cb` - callback function that will be called when message received. Note
232/// the differences of the function parameters.
233pub async fn request_endpoint_stream<'a, T>(
234    json: &'a T,
235    endpoint: &'a Endpoint,
236    variant: EndpointVariant,
237    mut cb: impl FnMut(Result<String, Box<dyn std::error::Error>>),
238) -> Result<(), Box<dyn std::error::Error>>
239where
240    T: Serialize,
241{
242    let client = reqwest::Client::new();
243    let config = Config::load().unwrap();
244    let url = if let EndpointVariant::Extended(var) = variant {
245        format!(
246            "{}{}{}",
247            config.openai.base_endpoint(),
248            endpoint,
249            var.to_owned()
250        )
251    } else {
252        format!("{}{}", config.openai.base_endpoint(), endpoint)
253    };
254
255    let mut req = client.post(url);
256
257    // Load additional headers from environment variable
258    let headers = AdditionalHeaders::from_var().provide();
259    if headers.len() > 0 {
260        req = req.headers(headers);
261    }
262
263    req = req.header("Authorization", format!("Bearer {}", config.openai.api_key));
264    if config.openai.org_id.is_some() {
265        req = req.header("OpenAI-Organization", config.openai.org_id.clone().unwrap());
266    }
267
268    if let Some(req_clone) = req.try_clone() {
269        log::debug!(target: "requests", "Headers `{:?}`", req_clone.build().unwrap().headers());
270    };
271
272    let mut res = req.json(&json).send().await?;
273
274    while let Some(chunk) = res.chunk().await? {
275        if let Ok(chunk_data_raw) = String::from_utf8(chunk.to_vec()) {
276            debug!(target: "openai", "Received response chunk from OpenAI: `{:?}`", chunk_data_raw);
277            cb(Ok(chunk_data_raw));
278        } else {
279            warn!(target: "openai", "Response chunk empty");
280        }
281    }
282
283    Ok(())
284}
285
286/// Send request to remote endpoint using Form data.
287///  
288/// # Arguments
289/// - `form` - the constructed HTTP form to send
290/// - `endpoint` - Endpoint enum variant
291/// - `variant` - Endpoint variant enum
292/// - `cb` - callback function that will be called when message received.
293pub async fn request_endpoint_form_data<'a, F>(
294    form: Form,
295    endpoint: &'a Endpoint,
296    variant: EndpointVariant,
297    mut cb: F,
298) -> Result<(), Box<dyn std::error::Error>>
299where
300    F: FnMut(Result<String, Box<dyn std::error::Error>>),
301{
302    let client = reqwest::Client::new();
303    let config = Config::load().unwrap();
304    let url = if let EndpointVariant::Extended(var) = variant {
305        format!(
306            "{}{}{}",
307            config.openai.base_endpoint(),
308            endpoint,
309            var.to_owned()
310        )
311    } else {
312        format!("{}{}", config.openai.base_endpoint(), endpoint)
313    };
314
315    let mut req = client.post(url);
316
317    // Load additional headers from environment variable
318    let headers = AdditionalHeaders::from_var().provide();
319    if headers.len() > 0 {
320        req = req.headers(headers);
321    }
322
323    req = req.header("Authorization", format!("Bearer {}", config.openai.api_key));
324
325    if let Some(req_clone) = req.try_clone() {
326        log::debug!(target: "requests", "Headers `{:?}`", req_clone.build().unwrap().headers());
327    };
328
329    let res = req.multipart(form).send().await?;
330
331    if let Ok(text) = res.text().await {
332        debug!(target: "openai", "Received response from OpenAI: `{:?}`", text);
333        cb(Ok(text.clone()));
334    } else {
335        error!(target: "openai", "Error receiving response from OpenAI");
336        cb(Err("Error receiving response from OpenAI".into()))
337    }
338
339    Ok(())
340}