rust_ai/openai/apis/chat_completion.rs
//!
//! Given a chat conversation, the model will return a chat completion response.
//!
//! Source: OpenAI documentation

////////////////////////////////////////////////////////////////////////////////

use std::collections::HashMap;

use crate::openai::{
    endpoint::{
        endpoint_filter, request_endpoint, request_endpoint_stream, Endpoint, EndpointVariant,
    },
    types::{
        chat_completion::{ChatCompletionResponse, ChatMessage, Chunk, MessageRole},
        common::Error,
        Model,
    },
};
use log::{debug, warn};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;

/// Given a chat conversation, the model will return a chat completion response.
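///
/// # Example
///
/// A minimal sketch of building a request with the builder methods defined
/// below; the `User` variant is assumed to exist in this crate's
/// `MessageRole` enum, mirroring the OpenAI chat roles.
///
/// ```ignore
/// let request = ChatCompletion::default()
///     .model(Model::GPT_3_5_TURBO)
///     .message(MessageRole::User, "Hello!")
///     .temperature(0.7);
/// ```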
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatCompletion {
    pub model: Model,

    pub messages: Vec<ChatMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    #[serde_as(as = "Option<Vec<(_,_)>>")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f32>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl Default for ChatCompletion {
    fn default() -> Self {
        Self {
            model: Model::GPT_3_5_TURBO,
            messages: vec![],
            stream: Some(false),
            temperature: None,
            top_p: None,
            n: None,
            stop: None,
            max_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            user: None,
        }
    }
}

impl ChatCompletion {
    /// ID of the model to use. See the [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) table
    /// for details on which models work with the Chat API.
    ///
    /// # Argument
    /// - `model` - Target model to make use of
    pub fn model(self, model: Model) -> Self {
        Self { model, ..self }
    }

    /// Add message to prompt by role and content.
    ///
    /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction).
    ///
    /// # Arguments
    /// - `role` - Message role enum variant
    /// - `content` - Message content
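    ///
    /// # Example
    ///
    /// A minimal sketch; the `System` and `User` variants are assumed to be
    /// part of this crate's `MessageRole` enum, matching the OpenAI roles.
    ///
    /// ```ignore
    /// let request = ChatCompletion::default()
    ///     .message(MessageRole::System, "You are a helpful assistant.")
    ///     .message(MessageRole::User, "What is the capital of France?");
    /// ```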
    pub fn message(self, role: MessageRole, content: &str) -> Self {
        let mut messages = self.messages;
        messages.push(ChatMessage::new(role, content));

        Self { messages, ..self }
    }

    /// Add message to prompt by message instance.
    ///
    /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction).
    ///
    /// # Argument
    /// - `messages` - Message instance vector, will replace all existing
    ///   messages
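    ///
    /// # Example
    ///
    /// A minimal sketch using `ChatMessage::new`, which this module already
    /// relies on; the `User` role variant is assumed to exist.
    ///
    /// ```ignore
    /// let request = ChatCompletion::default()
    ///     .messages(vec![ChatMessage::new(MessageRole::User, "Hello!")]);
    /// ```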
    pub fn messages(self, messages: Vec<ChatMessage>) -> Self {
        Self { messages, ..self }
    }

    /// What sampling temperature to use, between 0 and 2. Higher values like
    /// 0.8 will make the output more random, while lower values like 0.2 will
    /// make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    pub fn temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with top_p
    /// probability mass. So 0.1 means only the tokens comprising the top 10%
    /// probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    pub fn top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// How many chat completion choices to generate for each input message.
    pub fn n(self, n: u32) -> Self {
        Self { n: Some(n), ..self }
    }

    /// Up to 4 sequences where the API will stop generating further tokens.
    pub fn stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate in the chat completion.
    ///
    /// The total length of input tokens and generated tokens is limited by the
    /// model's context length.
    pub fn max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based
    /// on whether they appear in the text so far, increasing the model's
    /// likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    pub fn presence_penalty(self, presence_penalty: f32) -> Self {
        Self {
            presence_penalty: Some(presence_penalty),
            ..self
        }
    }

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based
    /// on their existing frequency in the text so far, decreasing the model's
    /// likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    pub fn frequency_penalty(self, frequency_penalty: f32) -> Self {
        Self {
            frequency_penalty: Some(frequency_penalty),
            ..self
        }
    }

    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a json object that maps tokens (specified by their token ID in
    /// the tokenizer) to an associated bias value from -100 to 100.
    /// Mathematically, the bias is added to the logits generated by the model
    /// prior to sampling. The exact effect will vary per model, but values
    /// between -1 and 1 should decrease or increase likelihood of selection;
    /// values like -100 or 100 should result in a ban or exclusive selection
    /// of the relevant token.
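    ///
    /// # Example
    ///
    /// A minimal sketch; the token ID `"50256"` is only an illustrative key,
    /// not one taken from a specific tokenizer.
    ///
    /// ```ignore
    /// let mut bias = HashMap::new();
    /// bias.insert("50256".to_string(), -100.0);
    /// let request = ChatCompletion::default().logit_bias(bias);
    /// ```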
    pub fn logit_bias(self, logit_bias: HashMap<String, f32>) -> Self {
        Self {
            logit_bias: Some(logit_bias),
            ..self
        }
    }

    /// A unique identifier representing your end-user, which can help OpenAI
    /// to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    pub fn user(self, user: &str) -> Self {
        Self {
            user: Some(user.into()),
            ..self
        }
    }

    /// Send chat completion request to OpenAI using streamed method.
    ///
    /// Partial message deltas will be sent, like in ChatGPT. Tokens
    /// will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available,
    /// with the stream terminated by a `data: [DONE]` message. See the OpenAI
    /// Cookbook for [example code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb).
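    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an async runtime such as Tokio and an API
    /// key already configured for the underlying endpoint helpers.
    ///
    /// ```ignore
    /// let mut received = 0usize;
    /// let chunks = ChatCompletion::default()
    ///     .message(MessageRole::User, "Hello!")
    ///     .streamed_completion(Some(|_chunk: Chunk| {
    ///         // Called once per streamed chunk as it arrives.
    ///         received += 1;
    ///     }))
    ///     .await?;
    /// assert_eq!(received, chunks.len());
    /// ```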
    pub async fn streamed_completion(
        self,
        mut cb: Option<impl FnMut(Chunk)>,
    ) -> Result<Vec<Chunk>, Box<dyn std::error::Error>> {
        let data = Self {
            stream: Some(true),
            ..self
        };

        if !endpoint_filter(&data.model, &Endpoint::ChatCompletion_v1) {
            return Err("Model not compatible with this endpoint".into());
        }

        let mut ret_val: Vec<Chunk> = vec![];
        let ret_val_ref = &mut ret_val;

        request_endpoint_stream(
            &data,
            &Endpoint::ChatCompletion_v1,
            EndpointVariant::None,
            |res| {
                if let Ok(chunk_data_raw) = res {
                    for chunk_data in chunk_data_raw.split('\n') {
                        let chunk_data = chunk_data.trim().to_string();
                        if chunk_data == "data: [DONE]" {
                            debug!(target: "openai", "Last chunk received.");
                            return;
                        }
                        if chunk_data.starts_with("data: ") {
                            // Strip the `data: ` SSE prefix before deserializing.
                            let stripped_chunk = &chunk_data["data: ".len()..];
                            if let Ok(message_chunk) = serde_json::from_str::<Chunk>(stripped_chunk) {
                                ret_val_ref.push(message_chunk.clone());
                                if let Some(cb) = &mut cb {
                                    cb(message_chunk);
                                }
                            } else if let Ok(response_error) =
                                serde_json::from_str::<Error>(stripped_chunk)
                            {
                                warn!(target: "openai",
                                    "OpenAI error code {}: `{:?}`",
                                    response_error.error.code.unwrap_or(0),
                                    stripped_chunk
                                );
                            } else {
                                warn!(target: "openai", "Completion response not deserializable.");
                            }
                        }
                    }
                }
            },
        )
        .await?;

        Ok(ret_val)
    }

    /// Send chat completion request to OpenAI.
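    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an async runtime such as Tokio and an API
    /// key already configured for the underlying endpoint helpers.
    ///
    /// ```ignore
    /// let response = ChatCompletion::default()
    ///     .model(Model::GPT_3_5_TURBO)
    ///     .message(MessageRole::User, "Say hello.")
    ///     .completion()
    ///     .await?;
    /// ```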
    pub async fn completion(self) -> Result<ChatCompletionResponse, Box<dyn std::error::Error>> {
        let data = Self {
            stream: None,
            ..self
        };

        if !endpoint_filter(&data.model, &Endpoint::ChatCompletion_v1) {
            return Err("Model not compatible with this endpoint".into());
        }

        let mut completion_response: Option<ChatCompletionResponse> = None;

        request_endpoint(&data, &Endpoint::ChatCompletion_v1, EndpointVariant::None, |res| {
            if let Ok(text) = res {
                if let Ok(response_data) = serde_json::from_str::<ChatCompletionResponse>(&text) {
                    debug!(target: "openai", "Response parsed, completion response deserialized.");
                    completion_response = Some(response_data);
                } else if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
                    warn!(target: "openai",
                        "OpenAI error code {}: `{:?}`",
                        response_error.error.code.unwrap_or(0),
                        text
                    );
                } else {
                    warn!(target: "openai", "Completion response not deserializable.");
                }
            }
        })
        .await?;

        if let Some(response_data) = completion_response {
            Ok(response_data)
        } else {
            Err("No response".into())
        }
    }
}