rust_ai/openai/apis/completion.rs

//!
//! Given a prompt, the model will return one or more predicted completions,
//! and can also return the probabilities of alternative tokens at each position.
//!
//! Source: OpenAI documentation
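//!
//! A minimal usage sketch (the crate paths, the async runtime, and how you use
//! the response are assumptions here, hence the `ignore` attribute):
//!
//! ```ignore
//! use rust_ai::openai::apis::completion::Completion;
//! use rust_ai::openai::types::model::Model;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let response = Completion::default()
//!         .model(Model::TEXT_DAVINCI_003)
//!         .prompt("Say this is a test")
//!         .max_tokens(16)
//!         .completion()
//!         .await?;
//!     // Inspect the returned `CompletionResponse` as needed.
//!     println!("{:?}", response);
//!     Ok(())
//! }
//! ```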

////////////////////////////////////////////////////////////////////////////////

use std::collections::HashMap;

use crate::openai::{
    endpoint::{
        endpoint_filter, request_endpoint, request_endpoint_stream, Endpoint, EndpointVariant,
    },
    types::{
        common::Error,
        completion::{Chunk, CompletionResponse},
        model::Model,
    },
};
use log::{debug, warn};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;

/// Given a prompt, the model will return one or more predicted completions,
/// and can also return the probabilities of alternative tokens at each
/// position.
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct Completion {
    pub model: Model,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<Vec<String>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f32>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl Default for Completion {
    fn default() -> Self {
        Self {
            model: Model::TEXT_DAVINCI_003,
            prompt: None,
            stream: Some(false),
            temperature: None,
            top_p: None,
            n: None,
            stop: None,
            max_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            user: None,
            suffix: None,
            logprobs: None,
            echo: None,
            best_of: None,
        }
    }
}

impl Completion {
    /// ID of the model to use. You can use the [List models API](https://platform.openai.com/docs/api-reference/models/list) to see all of
    /// your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of
    /// them.
    pub fn model(self, model: Model) -> Self {
        Self { model, ..self }
    }

    /// Add message to prompt.
    ///
    /// The prompt(s) to generate completions for, encoded as a string, array
    /// of strings, array of tokens, or array of token arrays.
    ///
    /// Note that <|endoftext|> is the document separator that the model sees
    /// during training, so if a prompt is not specified the model will
    /// generate as if from the beginning of a new document.
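    ///
    /// Repeated calls append to the list of prompts sent with the request; a
    /// small illustrative sketch (hence `ignore`):
    ///
    /// ```ignore
    /// let completion = Completion::default()
    ///     .prompt("First prompt")
    ///     .prompt("Second prompt");
    /// // Both strings are now queued and sent as an array of prompts.
    /// ```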
    pub fn prompt(self, content: &str) -> Self {
        // Append to any prompts added by earlier calls.
        let mut prompt = self.prompt.unwrap_or_default();
        prompt.push(String::from(content));

        Self {
            prompt: Some(prompt),
            ..self
        }
    }

    /// The suffix that comes after a completion of inserted text.
    pub fn suffix(self, suffix: String) -> Self {
        Self {
            suffix: Some(suffix),
            ..self
        }
    }

    /// What sampling temperature to use, between 0 and 2. Higher values like
    /// 0.8 will make the output more random, while lower values like 0.2 will
    /// make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    pub fn temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with `top_p`
    /// probability mass. So 0.1 means only the tokens comprising the top 10%
    /// probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    pub fn top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// How many completions to generate for each prompt.
    ///
    /// **Note**: Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have
    /// reasonable settings for `max_tokens` and `stop`.
    pub fn n(self, n: u32) -> Self {
        Self { n: Some(n), ..self }
    }

    /// Include the log probabilities on the `logprobs` most likely tokens, as
    /// well as the chosen tokens. For example, if `logprobs` is 5, the API will
    /// return a list of the 5 most likely tokens. The API will always return
    /// the `logprob` of the sampled token, so there may be up to `logprobs+1`
    /// elements in the response.
    ///
    /// The maximum value for `logprobs` is 5. If you need more than this,
    /// please contact us through our **Help center** and describe your use
    /// case.
    pub fn logprobs(self, logprobs: u32) -> Self {
        Self {
            logprobs: Some(logprobs),
            ..self
        }
    }

    /// Echo back the prompt in addition to the completion.
    pub fn echo(self, echo: bool) -> Self {
        Self {
            echo: Some(echo),
            ..self
        }
    }

    /// Up to 4 sequences where the API will stop generating further tokens.
    /// The returned text will not contain the stop sequence.
    pub fn stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate in the completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the
    /// model's context length. Most models have a context length of 2048
    /// tokens (except for the newest models, which support 4096).
    pub fn max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based
    /// on whether they appear in the text so far, increasing the model's
    /// likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    pub fn presence_penalty(self, presence_penalty: f32) -> Self {
        Self {
            presence_penalty: Some(presence_penalty),
            ..self
        }
    }

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based
    /// on their existing frequency in the text so far, decreasing the model's
    /// likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    pub fn frequency_penalty(self, frequency_penalty: f32) -> Self {
        Self {
            frequency_penalty: Some(frequency_penalty),
            ..self
        }
    }

    /// Generates `best_of` completions server-side and returns the "best" (the
    /// one with the highest log probability per token). Results cannot be
    /// streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate
    /// completions and `n` specifies how many to return – `best_of` must be
    /// greater than `n`.
    ///
    /// **Note**: Because this parameter generates many completions, it can
    /// quickly consume your token quota. Use carefully and ensure that you
    /// have reasonable settings for `max_tokens` and `stop`.
    pub fn best_of(self, best_of: u32) -> Self {
        Self {
            best_of: Some(best_of),
            ..self
        }
    }

    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a json object that maps tokens (specified by their token ID in
    /// the GPT tokenizer) to an associated bias value from -100 to 100. You
    /// can use this [tokenizer tool](https://platform.openai.com/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to
    /// convert text to token IDs. Mathematically, the bias is added to the
    /// logits generated by the model prior to sampling. The exact effect will
    /// vary per model, but values between -1 and 1 should decrease or increase
    /// likelihood of selection; values like -100 or 100 should result in a
    /// ban or exclusive selection of the relevant token.
    ///
    /// As an example, you can pass `{"50256": -100}` to prevent the
    /// <|endoftext|> token from being generated.
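    ///
    /// A sketch of that `{"50256": -100}` case (illustrative only, hence
    /// `ignore`):
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut bias = HashMap::new();
    /// // Token 50256 is <|endoftext|>; a bias of -100.0 effectively bans it.
    /// bias.insert(String::from("50256"), -100.0_f32);
    /// let completion = Completion::default().logit_bias(bias);
    /// ```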
    pub fn logit_bias(self, logit_bias: HashMap<String, f32>) -> Self {
        Self {
            logit_bias: Some(logit_bias),
            ..self
        }
    }

    /// A unique identifier representing your end-user, which can help OpenAI
    /// to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    pub fn user(self, user: &str) -> Self {
        Self {
            user: Some(user.into()),
            ..self
        }
    }

    /// Send the completion request to OpenAI using the streamed method
    /// (`stream` is forced to `true`).
    ///
    /// Tokens will be sent back as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they
    /// become available, with the stream terminated by a `data: [DONE]` message.
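    ///
    /// A usage sketch with a per-chunk callback (the async context and what
    /// you do with each `Chunk` are assumptions, hence `ignore`):
    ///
    /// ```ignore
    /// let chunks = Completion::default()
    ///     .prompt("Say this is a test")
    ///     .stream_completion(Some(|chunk: Chunk| {
    ///         // React to each streamed chunk as it arrives.
    ///         println!("{:?}", chunk);
    ///     }))
    ///     .await?;
    /// // `chunks` also collects every received `Chunk` for later use.
    /// ```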
    pub async fn stream_completion<F>(
        self,
        mut cb: Option<F>,
    ) -> Result<Vec<Chunk>, Box<dyn std::error::Error>>
    where
        F: FnMut(Chunk),
    {
        // Force streaming on for this request.
        let data = Self {
            stream: Some(true),
            ..self
        };

        if !endpoint_filter(&data.model, &Endpoint::Completion_v1) {
            return Err("Model not compatible with this endpoint".into());
        }

        let mut ret_val: Vec<Chunk> = vec![];

        request_endpoint_stream(&data, &Endpoint::Completion_v1, EndpointVariant::None, |res| {
            if let Ok(chunk_data_raw) = res {
                // Each received block may contain several server-sent events,
                // one `data: ...` line per chunk.
                chunk_data_raw.split('\n').for_each(|chunk_data| {
                    let chunk_data = chunk_data.trim().to_string();
                    if chunk_data == "data: [DONE]" {
                        debug!(target: "openai", "Last chunk received.");
                        return;
                    }
                    if chunk_data.starts_with("data: ") {
                        // Strip the `data: ` prefix before deserializing.
                        let stripped_chunk = &chunk_data[6..];
                        if let Ok(message_chunk) = serde_json::from_str::<Chunk>(stripped_chunk) {
                            ret_val.push(message_chunk.clone());
                            if let Some(cb) = &mut cb {
                                cb(message_chunk);
                            }
                        } else if let Ok(response_error) =
                            serde_json::from_str::<Error>(stripped_chunk)
                        {
                            warn!(target: "openai",
                                "OpenAI error code {}: `{:?}`",
                                response_error.error.code.unwrap_or(0),
                                stripped_chunk
                            );
                        } else {
                            warn!(target: "openai", "Completion response not deserializable.");
                        }
                    }
                });
            }
        })
        .await?;

        Ok(ret_val)
    }

    /// Send the completion request to OpenAI (non-streamed).
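    ///
    /// A blocking-style request sketch (illustrative only, hence `ignore`):
    ///
    /// ```ignore
    /// match Completion::default().prompt("Say this is a test").completion().await {
    ///     Ok(_response) => { /* use the returned `CompletionResponse` */ }
    ///     Err(e) => eprintln!("completion request failed: {}", e),
    /// }
    /// ```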
    pub async fn completion(self) -> Result<CompletionResponse, Box<dyn std::error::Error>> {
        // Streaming is disabled for the blocking variant.
        let data = Self {
            stream: None,
            ..self
        };

        if !endpoint_filter(&data.model, &Endpoint::Completion_v1) {
            return Err("Model not compatible with this endpoint".into());
        }

        let mut completion_response: Option<CompletionResponse> = None;

        request_endpoint(&data, &Endpoint::Completion_v1, EndpointVariant::None, |res| {
            if let Ok(text) = res {
                if let Ok(response_data) = serde_json::from_str::<CompletionResponse>(&text) {
                    debug!(target: "openai", "Response parsed, completion response deserialized.");
                    completion_response = Some(response_data);
                } else if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
                    warn!(target: "openai",
                        "OpenAI error code {}: `{:?}`",
                        response_error.error.code.unwrap_or(0),
                        text
                    );
                } else {
                    warn!(target: "openai", "Completion response not deserializable.");
                }
            }
        })
        .await?;

        if let Some(response_data) = completion_response {
            Ok(response_data)
        } else {
            Err("No response or error parsing response".into())
        }
    }
}