rust_gpt/
lib.rs

1//! # OpenAI Completion/Chat Rust API
2//! Provides a neat and rusty way of interacting with the OpenAI Completion/Chat API.
3//! You can find the documentation for the API [here](https://platform.openai.com/docs/api-reference/completions).
4//! ## Example
5//! ```rust no_run
6//! use rust_gpt::RequestBuilder;
7//! use rust_gpt::CompletionModel;
8//! use rust_gpt::SendRequest;
9//!
10//! #[tokio::main]
11//! async fn main() {
12//!     let req = RequestBuilder::new(CompletionModel::TextDavinci003, "YOUR_API_KEY")
13//!         .prompt("Write a sonnet about a crab named Ferris in the style of Shakespeare.")
14//!         .build_completion();
15//!     let response = req.send().await.unwrap();
16//!     println!("My bot replied with: \"{:?}\"", response);
17//! }
18//!```
19//!
20//! ## General Usage
21//! You will most likely just use the [`RequestBuilder`] to create a request. You can then use the [`SendRequest`] trait to send the request.
22//! Right now only the completion and chat endpoints are supported.
23//! These two endpoints require different parameters, so you will need to use the [`build_completion`] and [`build_chat`] methods respectively.  
24//!
25//! [`RequestBuilder`] can take any type that implements [`ToString`] as the model input and any type that implements [`Display`] as the API key.
26//!
27//! [`build_completion`]: ./struct.RequestBuilder.html#method.build_completion
28//! [`build_chat`]: ./struct.RequestBuilder.html#method.build_chat
29//!
30//! ## Completion
31//! The completion endpoint requires a [`prompt`] parameter. You can set this with the [`prompt`] method which takes any type that implements [`ToString`].
32//!
33//! [`prompt`]: ./struct.RequestBuilder.html#method.prompt
34//!
35//! ## Chat
36//! The chat endpoint is a little more complicated. It requires a [`messages`] parameter which is a list of messages.
37//! These messages are represented by the [`ChatMessage`] struct. You can create a [`ChatMessage`] with the [`new`] method.
38//!
//! [`messages`]: ./struct.RequestBuilder.html#method.messages
//! [`ChatMessage`]: chat::ChatMessage
//! [`new`]: chat::ChatMessage::new
41//!
42//!
43//!
44//! ## Additional Notes
//! This crate is still in development, so there may be some breaking changes in the future.  
//! It is also not fully tested, so there may be some bugs.  
//! There is a little bit of error handling, but it is not very robust.  
//! [serde_json](https://docs.rs/serde_json/latest/serde_json/) is used to serialize and deserialize the responses and messages. Because many of these types use derived implementations, they may not exactly match the JSON returned by the API.
49//!
50
51#![allow(dead_code)]
52use std::{error::Error, fmt::Display};
53
54use async_trait::async_trait;
55use once_cell::sync::OnceCell;
56use serde::Deserialize;
57use serde_json::json;
58
59pub mod chat;
60pub mod completion;
61
62static RQCLIENT: OnceCell<reqwest::Client> = OnceCell::new();
/// Endpoint that `Request<CompletionState>` is POSTed to.
const COMPLETION_URL: &str = "https://api.openai.com/v1/completions";
/// Endpoint that `Request<ChatState>` is POSTed to.
const CHAT_URL: &str = "https://api.openai.com/v1/chat/completions";
65
/// Error produced when a response body cannot be turned into the expected response type.
///
/// The offending text is kept so it can be shown through [`Display`].
#[derive(Clone, Debug)]
pub struct JsonParseError {
    /// The text that failed to parse.
    json_string: String,
}
70
71#[derive(Debug)]
72pub enum SendRequestError {
73    ReqwestError(reqwest::Error),
74    OpenAiError(String),
75    JsonError(JsonParseError),
76}
77
78impl Display for SendRequestError {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        match self {
81            SendRequestError::ReqwestError(e) => write!(f, "Reqwest error: {}", e),
82            SendRequestError::OpenAiError(e) => write!(f, "OpenAI error: {}", e),
83            SendRequestError::JsonError(e) => write!(f, "Json error: {}", e),
84        }
85    }
86}
87
88impl Display for JsonParseError {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        write!(f, "Could not parse json: {}", self.json_string)
91    }
92}
93
94impl Error for SendRequestError {}
95
96impl From<reqwest::Error> for SendRequestError {
97    fn from(e: reqwest::Error) -> Self {
98        SendRequestError::ReqwestError(e)
99    }
100}
101
102#[async_trait]
103/// A trait for abstracting sending requests between APIs.
104pub trait SendRequest {
105    /// The type of the response.
106    type Response;
107    /// The type of the error.
108    type Error;
109    /// Sends the request, returning whether or not there was an error with the response.
110    async fn send(self) -> Result<Self::Response, Self::Error>;
111}
// Type-state markers: `RequestBuilder<_>` / `Request<_>` are parameterised by one of
// these, which decides which builder methods and which `SendRequest` impl apply.

/// Marker for the completion endpoint.
#[doc(hidden)]
pub struct CompletionState;

/// Marker for the chat endpoint.
#[doc(hidden)]
pub struct ChatState;

/// Bound shared by both markers; gates the common builder setters.
#[doc(hidden)]
pub trait CompletionLike {}
/// The current completion models.
///
/// Each variant is converted to the model name sent to the API via `to_string()`.
/// The enum is fieldless, so it is also `Copy` and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompletionModel {
    /// `text-davinci-003`
    TextDavinci003,
    /// `text-davinci-002`
    TextDavinci002,
    /// `code-davinci-002`
    CodeDavinci002,
}
/// The current chat models.
///
/// Each variant is converted to the model name sent to the API via `to_string()`.
/// The enum is fieldless, so it is also `Copy` and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChatModel {
    /// `gpt-3.5-turbo`
    Gpt35Turbo,
    /// `gpt-3.5-turbo-0301`
    // Casing differs from `Gpt35Turbo`; the name is kept as-is for compatibility.
    GPT35Turbo0301,
}
131
132impl CompletionLike for CompletionState {}
133impl CompletionLike for ChatState {}
134
135impl ToString for CompletionModel {
136    fn to_string(&self) -> String {
137        match self {
138            CompletionModel::TextDavinci003 => "text-davinci-003",
139            CompletionModel::TextDavinci002 => "text-davinci-002",
140            CompletionModel::CodeDavinci002 => "code-davinci-002",
141        }
142        .to_string()
143    }
144}
145
146impl ToString for ChatModel {
147    fn to_string(&self) -> String {
148        match self {
149            ChatModel::Gpt35Turbo => "gpt-3.5-turbo",
150            ChatModel::GPT35Turbo0301 => "gpt-3.5-turbo-0301",
151        }
152        .to_string()
153    }
154}
155
/// A generic request which can be used to send requests to the OpenAI API.
///
/// Produced by [`RequestBuilder`]. The type parameter is a state marker that selects
/// which [`SendRequest`] implementation (and therefore which endpoint) applies.
pub struct Request<T> {
    /// Serialized JSON request body.
    to_send: String,
    /// Full `Authorization` header value (`Bearer <key>`).
    api_key: String,
    state: std::marker::PhantomData<T>,
}

/// Written by hand for two reasons. First, the API key must never end up in
/// logs or panic messages. Second, `Debug` must be available for every state
/// marker; `#[derive(Debug)]` would require `T: Debug`, which the markers do
/// not implement.
impl<T> std::fmt::Debug for Request<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Request")
            .field("to_send", &self.to_send)
            .field("api_key", &"<redacted>")
            .finish_non_exhaustive()
    }
}
163
164#[async_trait]
165impl SendRequest for Request<CompletionState> {
166    type Response = completion::CompletionResponse;
167    type Error = SendRequestError;
168    async fn send(self) -> Result<Self::Response, Self::Error> {
169        use SendRequestError::*;
170        let client = RQCLIENT.get_or_init(reqwest::Client::new);
171
172        let resp = client
173            .post(COMPLETION_URL)
174            .header("Content-Type", "application/json")
175            .header("Authorization", self.api_key)
176            .body(self.to_send)
177            .send()
178            .await?;
179
180        let body = resp.text().await.unwrap();
181        let json: serde_json::Value = serde_json::from_str(&body).unwrap();
182
183        let response = match completion::CompletionResponse::deserialize(json.clone()) {
184            Ok(r) => r,
185            Err(_) => return Err(JsonError(JsonParseError { json_string: serde_json::to_string_pretty(&json).unwrap() })),
186        };
187
188        Ok(response)
189    }
190}
191
192#[async_trait]
193impl SendRequest for Request<ChatState> {
194    type Response = chat::ChatResponse;
195    type Error = SendRequestError;
196
197    
198    async fn send(self) -> Result<Self::Response, SendRequestError> {
199        use SendRequestError::*;
200
201        if !self.to_send.contains("messages") {
202            return Err(OpenAiError("No messages in request.".into()));
203        }
204
205        let client = RQCLIENT.get_or_init(reqwest::Client::new);
206
207        let resp = client
208            .post(CHAT_URL)
209            .header("Content-Type", "application/json")
210            .header("Authorization", self.api_key)
211            .body(self.to_send)
212            .send()
213            .await?;
214
215        let body = resp.text().await.unwrap();
216        let json: serde_json::Value = serde_json::from_str(&body).unwrap();
217
218        if !json["error"].is_null() {
219            return Err(OpenAiError(serde_json::to_string_pretty(&json).unwrap().into()));
220        }
221
222        let response = match chat::ChatResponse::deserialize(json.clone()) {
223            Ok(r) => r,
224            Err(_) => return Err(JsonError(JsonParseError { json_string: serde_json::to_string_pretty(&json).unwrap() })),
225        };
226
227        Ok(response)
228
229        // Ok(ChatResponse {
230        //     id: json["id"].as_str().unwrap().to_string(),
231        //     object: json["object"].as_str().unwrap().to_string(),
232        //     created: json["created"].as_u64().unwrap(),
233        //     model: json["model"].as_str().unwrap().to_string(),
234        //     usage: (
235        //         json["usage"]["prompt_tokens"].as_u64().unwrap() as u32,
236        //         json["usage"]["completion_tokens"].as_u64().unwrap() as u32,
237        //         json["usage"]["total_tokens"].as_u64().unwrap() as u32,
238        //     ),
239        //     choices: json["choices"].as_array().unwrap().iter().map(|message| ChatMessage {
240        //         role: message["message"]["role"].as_str().unwrap().try_into().unwrap(),
241        //         content: message["message"]["content"].as_str().unwrap().to_string(),
242        //     }).collect()
243        // })
244    }
245}
246
247#[derive(Debug)]
248/// A builder for creating requests to the OpenAI API.
249pub struct RequestBuilder<T> {
250    req: serde_json::Value,
251    api_key: String,
252    state: std::marker::PhantomData<T>,
253}
254
255impl<C: CompletionLike> RequestBuilder<C> {
256    /// Create a new request builder.
257    pub fn new<T: ToString, S: Display>(model: T, api_key: S) -> Self {
258        let api_key = format!("Bearer {api_key}");
259
260        let req = json!({
261            "model": model.to_string(),
262        });
263
264        Self {
265            req,
266            api_key,
267            state: std::marker::PhantomData,
268        }
269    }
270    /// Set the max_tokens parameter.
271    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
272        self.req["max_tokens"] = json!(max_tokens);
273        self
274    }
275    /// Set the temperature parameter.
276    pub fn temperature(mut self, temperature: f32) -> Self {
277        self.req["temperature"] = json!(temperature);
278        self
279    }
280    /// Set the top_p parameter.
281    pub fn top_p(mut self, top_p: f32) -> Self {
282        self.req["top_p"] = json!(top_p);
283        self
284    }
285    /// Set the frequency_penalty parameter.
286    pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
287        self.req["frequency_penalty"] = json!(frequency_penalty);
288        self
289    }
290    /// Set the presence_penalty parameter.
291    pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
292        self.req["presence_penalty"] = json!(presence_penalty);
293        self
294    }
295    /// Set the stop parameter.
296    pub fn stop<T: ToString>(mut self, stop: T) -> Self {
297        self.req["stop"] = json!(stop.to_string());
298        self
299    }
300    /// Set the n parameter.
301    pub fn n(mut self, n: u32) -> Self {
302        self.req["n"] = json!(n);
303        self
304    }
305
306    pub fn user(mut self, user: String) -> Self {
307        self.req["user"] = json!(user);
308        self
309    }
310}
311
312impl RequestBuilder<CompletionState> {
313    /// Set the prompt parameter.
314    pub fn prompt<T: ToString>(mut self, prompt: T) -> Self {
315        self.req["prompt"] = json!(prompt.to_string());
316        self
317    }
318    /// Builds a completion request.
319    pub fn build_completion(self) -> Request<CompletionState> {
320        Request {
321            api_key: self.api_key,
322            to_send: self.req.to_string(),
323            state: std::marker::PhantomData,
324        }
325    }
326}
327
328impl RequestBuilder<ChatState> {
329    /// Set the messages parameter.
330    pub fn messages(mut self, messages: Vec<chat::ChatMessage>) -> Self {
331        self.req["messages"] = json!(messages);
332        self
333    }
334
335    fn chat_parameters(mut self, chat_parameters: chat::ChatParameters) -> Self {
336        let mut params = json!(chat_parameters);
337        params["messages"] = self.req.get("messages").unwrap().clone();
338        params["model"] = self.req.get("model").unwrap().clone();
339        self.req = params;
340        self
341    }
342
343    /// Builds a chat request.
344    pub fn build_chat(self) -> Request<ChatState> {
345        Request {
346            api_key: self.api_key,
347            to_send: self.req.to_string(),
348            state: std::marker::PhantomData,
349        }
350    }
351}