rust_ai/openai/apis/
chat_completion.rs

//!
//! Given a chat conversation, the model will return a chat completion response.
//!
//! Source: OpenAI documentation

////////////////////////////////////////////////////////////////////////////////

use std::collections::HashMap;

use crate::openai::{
    endpoint::{
        endpoint_filter, request_endpoint, request_endpoint_stream, Endpoint, EndpointVariant,
    },
    types::{
        chat_completion::{ChatCompletionResponse, ChatMessage, Chunk, MessageRole},
        common::Error,
        Model,
    },
};
use log::{debug, warn};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;

/// Given a chat conversation, the model will return a chat completion response.
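///
/// # Example
///
/// A minimal builder-style sketch. The `use` paths are assumed from this
/// file's location in the crate and the `MessageRole` variants are assumed;
/// the block is marked `ignore` so it is not compiled as a doc test.
///
/// ```ignore
/// use rust_ai::openai::apis::chat_completion::ChatCompletion;
/// use rust_ai::openai::types::chat_completion::MessageRole;
/// use rust_ai::openai::types::Model;
///
/// let request = ChatCompletion::default()
///     .model(Model::GPT_3_5_TURBO)
///     .message(MessageRole::System, "You are a helpful assistant.")
///     .message(MessageRole::User, "Hello!")
///     .temperature(0.7);
/// ```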
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatCompletion {
    pub model: Model,

    pub messages: Vec<ChatMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    #[serde_as(as = "Option<Vec<(_,_)>>")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f32>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl Default for ChatCompletion {
    fn default() -> Self {
        Self {
            model: Model::GPT_3_5_TURBO,
            messages: vec![],
            stream: Some(false),
            temperature: None,
            top_p: None,
            n: None,
            stop: None,
            max_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            user: None,
        }
    }
}

impl ChatCompletion {
    /// ID of the model to use. See the [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) table
    /// for details on which models work with the Chat API.
    ///
    /// # Argument
    /// - `model` - Target model to make use of
    pub fn model(self, model: Model) -> Self {
        Self { model, ..self }
    }

    /// Add a message to the prompt by role and content.
    ///
    /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction).
    ///
    /// # Arguments
    /// - `role` - Message role enum variant
    /// - `content` - Message content
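    ///
    /// # Example
    ///
    /// A minimal sketch showing that repeated calls append in order (the
    /// `MessageRole` variants are assumed; not compiled as a doc test):
    ///
    /// ```ignore
    /// let request = ChatCompletion::default()
    ///     .message(MessageRole::System, "You are a helpful assistant.")
    ///     .message(MessageRole::User, "What is the capital of France?");
    /// assert_eq!(request.messages.len(), 2);
    /// ```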
    pub fn message(self, role: MessageRole, content: &str) -> Self {
        // `self.messages` already starts as an empty vector, so simply append
        // the new message to whatever is there.
        let mut messages = self.messages;
        messages.push(ChatMessage::new(role, content));

        Self { messages, ..self }
    }

    /// Set the prompt messages from a vector of message instances.
    ///
    /// The messages to generate chat completions for, in the [chat format](https://platform.openai.com/docs/guides/chat/introduction).
    ///
    /// # Argument
    /// - `messages` - Message instance vector; replaces all existing
    ///     messages
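    ///
    /// # Example
    ///
    /// A minimal sketch (the `MessageRole` variants are assumed; not compiled
    /// as a doc test):
    ///
    /// ```ignore
    /// let request = ChatCompletion::default().messages(vec![
    ///     ChatMessage::new(MessageRole::System, "You are a helpful assistant."),
    ///     ChatMessage::new(MessageRole::User, "Hello!"),
    /// ]);
    /// ```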
    pub fn messages(self, messages: Vec<ChatMessage>) -> Self {
        Self { messages, ..self }
    }

    /// What sampling temperature to use, between 0 and 2. Higher values like
    /// 0.8 will make the output more random, while lower values like 0.2 will
    /// make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    pub fn temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with top_p
    /// probability mass. So 0.1 means only the tokens comprising the top 10%
    /// probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    pub fn top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// How many chat completion choices to generate for each input message.
    pub fn n(self, n: u32) -> Self {
        Self { n: Some(n), ..self }
    }

    /// Up to 4 sequences where the API will stop generating further tokens.
    pub fn stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate in the chat completion.
    ///
    /// The total length of input tokens and generated tokens is limited by the
    /// model's context length.
    pub fn max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based
    /// on whether they appear in the text so far, increasing the model's
    /// likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    pub fn presence_penalty(self, presence_penalty: f32) -> Self {
        Self {
            presence_penalty: Some(presence_penalty),
            ..self
        }
    }

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based
    /// on their existing frequency in the text so far, decreasing the model's
    /// likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    pub fn frequency_penalty(self, frequency_penalty: f32) -> Self {
        Self {
            frequency_penalty: Some(frequency_penalty),
            ..self
        }
    }

    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a json object that maps tokens (specified by their token ID in
    /// the tokenizer) to an associated bias value from -100 to 100.
    /// Mathematically, the bias is added to the logits generated by the model
    /// prior to sampling. The exact effect will vary per model, but values
    /// between -1 and 1 should decrease or increase likelihood of selection;
    /// values like -100 or 100 should result in a ban or exclusive selection
    /// of the relevant token.
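    ///
    /// # Example
    ///
    /// A minimal sketch (the token ID below is an illustrative placeholder,
    /// not looked up against a real tokenizer; not compiled as a doc test):
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut bias = HashMap::new();
    /// // Keys are token IDs as strings; values range from -100.0 to 100.0.
    /// bias.insert("50256".to_string(), -100.0);
    /// let request = ChatCompletion::default().logit_bias(bias);
    /// ```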
    pub fn logit_bias(self, logit_bias: HashMap<String, f32>) -> Self {
        Self {
            logit_bias: Some(logit_bias),
            ..self
        }
    }

    /// A unique identifier representing your end-user, which can help OpenAI
    /// to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    pub fn user(self, user: &str) -> Self {
        Self {
            user: Some(user.into()),
            ..self
        }
    }

    /// Send a chat completion request to OpenAI using the streamed method.
    ///
    /// Partial message deltas will be sent, like in ChatGPT. Tokens
    /// will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available,
    /// with the stream terminated by a `data: [DONE]` message. See the OpenAI
    /// Cookbook for [example code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb).
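    ///
    /// # Example
    ///
    /// A minimal usage sketch (assumes the API key is configured the way this
    /// crate expects; not compiled as a doc test):
    ///
    /// ```ignore
    /// let chunks = ChatCompletion::default()
    ///     .message(MessageRole::User, "Say hello.")
    ///     .streamed_completion(Some(|chunk: Chunk| {
    ///         // Each `Chunk` carries a partial delta of the reply.
    ///         println!("{:?}", chunk);
    ///     }))
    ///     .await?;
    /// ```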
    pub async fn streamed_completion(
        self,
        mut cb: Option<impl FnMut(Chunk)>,
    ) -> Result<Vec<Chunk>, Box<dyn std::error::Error>> {
        let data = Self {
            stream: Some(true),
            ..self
        };

        if !endpoint_filter(&data.model, &Endpoint::ChatCompletion_v1) {
            return Err("Model not compatible with this endpoint".into());
        }

        let mut ret_val: Vec<Chunk> = vec![];
        let ret_val_ref = &mut ret_val;

        request_endpoint_stream(
            &data,
            &Endpoint::ChatCompletion_v1,
            EndpointVariant::None,
            |res| {
                if let Ok(chunk_data_raw) = res {
                    // A raw SSE payload may contain several `data:` lines.
                    for chunk_data in chunk_data_raw.split('\n') {
                        let chunk_data = chunk_data.trim().to_string();
                        if chunk_data == "data: [DONE]" {
                            // Stream terminator; nothing further to parse.
                            debug!(target: "openai", "Last chunk received.");
                            return;
                        }
                        if chunk_data.starts_with("data: ") {
                            // Strip the `data: ` prefix before deserializing.
                            let stripped_chunk = &chunk_data[6..];
                            if let Ok(message_chunk) =
                                serde_json::from_str::<Chunk>(stripped_chunk)
                            {
                                ret_val_ref.push(message_chunk.clone());
                                if let Some(cb) = &mut cb {
                                    cb(message_chunk);
                                }
                            } else if let Ok(response_error) =
                                serde_json::from_str::<Error>(stripped_chunk)
                            {
                                warn!(target: "openai",
                                    "OpenAI error code {}: `{:?}`",
                                    response_error.error.code.unwrap_or(0),
                                    stripped_chunk
                                );
                            } else {
                                warn!(target: "openai", "Completion response not deserializable.");
                            }
                        }
                    }
                }
            },
        )
        .await?;

        Ok(ret_val)
    }

    /// Send a chat completion request to OpenAI and return the parsed response.
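    ///
    /// # Example
    ///
    /// A minimal usage sketch (assumes the API key is configured the way this
    /// crate expects; not compiled as a doc test):
    ///
    /// ```ignore
    /// let response = ChatCompletion::default()
    ///     .message(MessageRole::User, "Hello!")
    ///     .completion()
    ///     .await?;
    /// println!("{:?}", response);
    /// ```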
    pub async fn completion(self) -> Result<ChatCompletionResponse, Box<dyn std::error::Error>> {
        let data = Self {
            stream: None,
            ..self
        };

        if !endpoint_filter(&data.model, &Endpoint::ChatCompletion_v1) {
            return Err("Model not compatible with this endpoint".into());
        }

        let mut completion_response: Option<ChatCompletionResponse> = None;

        request_endpoint(&data, &Endpoint::ChatCompletion_v1, EndpointVariant::None, |res| {
            if let Ok(text) = res {
                if let Ok(response_data) = serde_json::from_str::<ChatCompletionResponse>(&text) {
                    debug!(target: "openai", "Response parsed, completion response deserialized.");
                    completion_response = Some(response_data);
                } else if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
                    warn!(target: "openai",
                        "OpenAI error code {}: `{:?}`",
                        response_error.error.code.unwrap_or(0),
                        text
                    );
                } else {
                    warn!(target: "openai", "Completion response not deserializable.");
                }
            }
        })
        .await?;

        if let Some(response_data) = completion_response {
            Ok(response_data)
        } else {
            Err("No response".into())
        }
    }
}