1#![allow(dead_code)]
52use std::{error::Error, fmt::Display};
53
54use async_trait::async_trait;
55use once_cell::sync::OnceCell;
56use serde::Deserialize;
57use serde_json::json;
58
59pub mod chat;
60pub mod completion;
61
62static RQCLIENT: OnceCell<reqwest::Client> = OnceCell::new();
/// Endpoint for text completion requests.
const COMPLETION_URL: &str = "https://api.openai.com/v1/completions";
/// Endpoint for chat completion requests.
const CHAT_URL: &str = "https://api.openai.com/v1/chat/completions";
65
/// Error payload produced when a response could not be converted into the
/// expected response type.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JsonParseError {
    /// Text of the response that could not be converted.
    json_string: String,
}
70
71#[derive(Debug)]
72pub enum SendRequestError {
73 ReqwestError(reqwest::Error),
74 OpenAiError(String),
75 JsonError(JsonParseError),
76}
77
78impl Display for SendRequestError {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 SendRequestError::ReqwestError(e) => write!(f, "Reqwest error: {}", e),
82 SendRequestError::OpenAiError(e) => write!(f, "OpenAI error: {}", e),
83 SendRequestError::JsonError(e) => write!(f, "Json error: {}", e),
84 }
85 }
86}
87
88impl Display for JsonParseError {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 write!(f, "Could not parse json: {}", self.json_string)
91 }
92}
93
94impl Error for SendRequestError {}
95
96impl From<reqwest::Error> for SendRequestError {
97 fn from(e: reqwest::Error) -> Self {
98 SendRequestError::ReqwestError(e)
99 }
100}
101
102#[async_trait]
103pub trait SendRequest {
105 type Response;
107 type Error;
109 async fn send(self) -> Result<Self::Response, Self::Error>;
111}
/// Marker trait for the request flavours a `RequestBuilder` can be in.
///
/// Implemented here for [`CompletionState`] and [`ChatState`] so the setters
/// shared by both endpoints can be written once.
#[doc(hidden)]
pub trait CompletionLike {}

/// Type-state marker for requests to the completions endpoint.
///
/// Derives `Debug` so that `#[derive(Debug)]` on generic wrappers such as
/// `Request<CompletionState>` (which adds a `T: Debug` bound) is usable.
#[doc(hidden)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CompletionState;

/// Type-state marker for requests to the chat completions endpoint.
#[doc(hidden)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ChatState;

/// Models selectable for the completions endpoint.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CompletionModel {
    TextDavinci003,
    TextDavinci002,
    CodeDavinci002,
}

/// Models selectable for the chat completions endpoint.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ChatModel {
    Gpt35Turbo,
    GPT35Turbo0301,
}

impl CompletionLike for CompletionState {}
impl CompletionLike for ChatState {}

// `Display` rather than a direct `ToString` impl: `to_string()` keeps working via
// the std blanket impl, and the models can now also be used with `{}`.
impl Display for CompletionModel {
    /// Writes the model identifier string sent in the request body.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id = match self {
            CompletionModel::TextDavinci003 => "text-davinci-003",
            CompletionModel::TextDavinci002 => "text-davinci-002",
            CompletionModel::CodeDavinci002 => "code-davinci-002",
        };
        // `pad` honours width/alignment flags such as `{:>20}`.
        f.pad(id)
    }
}

impl Display for ChatModel {
    /// Writes the model identifier string sent in the request body.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id = match self {
            ChatModel::Gpt35Turbo => "gpt-3.5-turbo",
            ChatModel::GPT35Turbo0301 => "gpt-3.5-turbo-0301",
        };
        f.pad(id)
    }
}
155
/// A fully built request, ready to be sent with `SendRequest::send`.
///
/// `T` is a type-state marker selecting which endpoint the request targets.
pub struct Request<T> {
    /// JSON-encoded request body.
    to_send: String,
    /// Full `Authorization` header value (`"Bearer <key>"`).
    api_key: String,
    state: std::marker::PhantomData<T>,
}

// Hand-written instead of derived for two reasons: the derive would print the
// bearer token into logs, and it would add a `T: Debug` bound that the
// state markers may not satisfy.
impl<T> std::fmt::Debug for Request<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Request")
            .field("to_send", &self.to_send)
            .field("api_key", &"<redacted>")
            .finish_non_exhaustive()
    }
}
163
164#[async_trait]
165impl SendRequest for Request<CompletionState> {
166 type Response = completion::CompletionResponse;
167 type Error = SendRequestError;
168 async fn send(self) -> Result<Self::Response, Self::Error> {
169 use SendRequestError::*;
170 let client = RQCLIENT.get_or_init(reqwest::Client::new);
171
172 let resp = client
173 .post(COMPLETION_URL)
174 .header("Content-Type", "application/json")
175 .header("Authorization", self.api_key)
176 .body(self.to_send)
177 .send()
178 .await?;
179
180 let body = resp.text().await.unwrap();
181 let json: serde_json::Value = serde_json::from_str(&body).unwrap();
182
183 let response = match completion::CompletionResponse::deserialize(json.clone()) {
184 Ok(r) => r,
185 Err(_) => return Err(JsonError(JsonParseError { json_string: serde_json::to_string_pretty(&json).unwrap() })),
186 };
187
188 Ok(response)
189 }
190}
191
192#[async_trait]
193impl SendRequest for Request<ChatState> {
194 type Response = chat::ChatResponse;
195 type Error = SendRequestError;
196
197
198 async fn send(self) -> Result<Self::Response, SendRequestError> {
199 use SendRequestError::*;
200
201 if !self.to_send.contains("messages") {
202 return Err(OpenAiError("No messages in request.".into()));
203 }
204
205 let client = RQCLIENT.get_or_init(reqwest::Client::new);
206
207 let resp = client
208 .post(CHAT_URL)
209 .header("Content-Type", "application/json")
210 .header("Authorization", self.api_key)
211 .body(self.to_send)
212 .send()
213 .await?;
214
215 let body = resp.text().await.unwrap();
216 let json: serde_json::Value = serde_json::from_str(&body).unwrap();
217
218 if !json["error"].is_null() {
219 return Err(OpenAiError(serde_json::to_string_pretty(&json).unwrap().into()));
220 }
221
222 let response = match chat::ChatResponse::deserialize(json.clone()) {
223 Ok(r) => r,
224 Err(_) => return Err(JsonError(JsonParseError { json_string: serde_json::to_string_pretty(&json).unwrap() })),
225 };
226
227 Ok(response)
228
229 }
245}
246
247#[derive(Debug)]
248pub struct RequestBuilder<T> {
250 req: serde_json::Value,
251 api_key: String,
252 state: std::marker::PhantomData<T>,
253}
254
255impl<C: CompletionLike> RequestBuilder<C> {
256 pub fn new<T: ToString, S: Display>(model: T, api_key: S) -> Self {
258 let api_key = format!("Bearer {api_key}");
259
260 let req = json!({
261 "model": model.to_string(),
262 });
263
264 Self {
265 req,
266 api_key,
267 state: std::marker::PhantomData,
268 }
269 }
270 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
272 self.req["max_tokens"] = json!(max_tokens);
273 self
274 }
275 pub fn temperature(mut self, temperature: f32) -> Self {
277 self.req["temperature"] = json!(temperature);
278 self
279 }
280 pub fn top_p(mut self, top_p: f32) -> Self {
282 self.req["top_p"] = json!(top_p);
283 self
284 }
285 pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
287 self.req["frequency_penalty"] = json!(frequency_penalty);
288 self
289 }
290 pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
292 self.req["presence_penalty"] = json!(presence_penalty);
293 self
294 }
295 pub fn stop<T: ToString>(mut self, stop: T) -> Self {
297 self.req["stop"] = json!(stop.to_string());
298 self
299 }
300 pub fn n(mut self, n: u32) -> Self {
302 self.req["n"] = json!(n);
303 self
304 }
305
306 pub fn user(mut self, user: String) -> Self {
307 self.req["user"] = json!(user);
308 self
309 }
310}
311
312impl RequestBuilder<CompletionState> {
313 pub fn prompt<T: ToString>(mut self, prompt: T) -> Self {
315 self.req["prompt"] = json!(prompt.to_string());
316 self
317 }
318 pub fn build_completion(self) -> Request<CompletionState> {
320 Request {
321 api_key: self.api_key,
322 to_send: self.req.to_string(),
323 state: std::marker::PhantomData,
324 }
325 }
326}
327
328impl RequestBuilder<ChatState> {
329 pub fn messages(mut self, messages: Vec<chat::ChatMessage>) -> Self {
331 self.req["messages"] = json!(messages);
332 self
333 }
334
335 fn chat_parameters(mut self, chat_parameters: chat::ChatParameters) -> Self {
336 let mut params = json!(chat_parameters);
337 params["messages"] = self.req.get("messages").unwrap().clone();
338 params["model"] = self.req.get("model").unwrap().clone();
339 self.req = params;
340 self
341 }
342
343 pub fn build_chat(self) -> Request<ChatState> {
345 Request {
346 api_key: self.api_key,
347 to_send: self.req.to_string(),
348 state: std::marker::PhantomData,
349 }
350 }
351}