rust_ai/openai/apis/completion.rs
//!
//! Given a prompt, the model will return one or more predicted completions,
//! and can also return the probabilities of alternative tokens at each position.
//!
//! Source: OpenAI documentation
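//!
//! A minimal usage sketch (not compiled as a doctest): the external crate
//! name `rust_ai` and the Tokio runtime are assumptions; the model constant
//! and builder methods come from this module itself.
//!
//! ```ignore
//! use rust_ai::openai::apis::completion::Completion;
//! use rust_ai::openai::types::model::Model;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let response = Completion::default()
//!         .model(Model::TEXT_DAVINCI_003)
//!         .prompt("Say this is a test")
//!         .max_tokens(16)
//!         .completion()
//!         .await?;
//!     // Inspect the returned `CompletionResponse` as needed.
//!     let _ = response;
//!     Ok(())
//! }
//! ```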

////////////////////////////////////////////////////////////////////////////////

use std::collections::HashMap;

use crate::openai::{
    endpoint::{
        endpoint_filter, request_endpoint, request_endpoint_stream, Endpoint, EndpointVariant,
    },
    types::{
        common::Error,
        completion::{Chunk, CompletionResponse},
        model::Model,
    },
};
use log::{debug, warn};
use serde::{Deserialize, Serialize};

/// Given a prompt, the model will return one or more predicted completions,
/// and can also return the probabilities of alternative tokens at each
/// position.
#[derive(Serialize, Deserialize, Debug)]
pub struct Completion {
    pub model: Model,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<Vec<String>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f32>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

impl Default for Completion {
    fn default() -> Self {
        Self {
            model: Model::TEXT_DAVINCI_003,
            prompt: None,
            stream: Some(false),
            temperature: None,
            top_p: None,
            n: None,
            stop: None,
            max_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            user: None,
            suffix: None,
            logprobs: None,
            echo: None,
            best_of: None,
        }
    }
}

impl Completion {
    /// ID of the model to use. You can use the [List models API](https://platform.openai.com/docs/api-reference/models/list) to see all of
    /// your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of
    /// them.
    pub fn model(self, model: Model) -> Self {
        Self { model, ..self }
    }

    /// Append a message to the prompt.
    ///
    /// The prompt(s) to generate completions for, encoded as a string, array
    /// of strings, array of tokens, or array of token arrays.
    ///
    /// Note that <|endoftext|> is the document separator that the model sees
    /// during training, so if a prompt is not specified the model will
    /// generate as if from the beginning of a new document.
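    ///
    /// Each call appends one entry, so chaining builds a batch of prompts (a
    /// minimal sketch of this builder):
    ///
    /// ```ignore
    /// let request = Completion::default()
    ///     .prompt("First prompt")
    ///     .prompt("Second prompt"); // `prompt` now holds two entries
    /// ```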
    pub fn prompt(self, content: &str) -> Self {
        let mut prompt = vec![];
        if let Some(prmp) = self.prompt {
            prompt.extend(prmp);
        }
        prompt.push(String::from(content));

        Self {
            prompt: Some(prompt),
            ..self
        }
    }

    /// The suffix that comes after a completion of inserted text.
    pub fn suffix(self, suffix: String) -> Self {
        Self {
            suffix: Some(suffix),
            ..self
        }
    }

    /// What sampling temperature to use, between 0 and 2. Higher values like
    /// 0.8 will make the output more random, while lower values like 0.2 will
    /// make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    pub fn temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with top_p
    /// probability mass. So 0.1 means only the tokens comprising the top 10%
    /// probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    pub fn top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// How many completions to generate for each prompt.
    ///
    /// **Note**: Because this parameter generates many completions, it can quickly
    /// consume your token quota. Use carefully and ensure that you have
    /// reasonable settings for `max_tokens` and `stop`.
    pub fn n(self, n: u32) -> Self {
        Self { n: Some(n), ..self }
    }

    /// Include the log probabilities on the `logprobs` most likely tokens, as
    /// well as the chosen tokens. For example, if `logprobs` is 5, the API
    /// will return a list of the 5 most likely tokens. The API will always
    /// return the `logprob` of the sampled token, so there may be up to
    /// `logprobs+1` elements in the response.
    ///
    /// The maximum value for `logprobs` is 5. If you need more than this,
    /// please contact us through our **Help center** and describe your use
    /// case.
    pub fn logprobs(self, logprobs: u32) -> Self {
        Self {
            logprobs: Some(logprobs),
            ..self
        }
    }

    /// Echo back the prompt in addition to the completion.
    pub fn echo(self, echo: bool) -> Self {
        Self {
            echo: Some(echo),
            ..self
        }
    }

    /// Up to 4 sequences where the API will stop generating further tokens.
    /// The returned text will not contain the stop sequence.
    pub fn stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) to generate in the completion.
    ///
    /// The token count of your prompt plus `max_tokens` cannot exceed the
    /// model's context length. Most models have a context length of 2048
    /// tokens (except for the newest models, which support 4096).
    pub fn max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based
    /// on whether they appear in the text so far, increasing the model's
    /// likelihood to talk about new topics.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    pub fn presence_penalty(self, presence_penalty: f32) -> Self {
        Self {
            presence_penalty: Some(presence_penalty),
            ..self
        }
    }

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based
    /// on their existing frequency in the text so far, decreasing the model's
    /// likelihood to repeat the same line verbatim.
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    pub fn frequency_penalty(self, frequency_penalty: f32) -> Self {
        Self {
            frequency_penalty: Some(frequency_penalty),
            ..self
        }
    }

    /// Generates `best_of` completions server-side and returns the "best" (the
    /// one with the highest log probability per token). Results cannot be
    /// streamed.
    ///
    /// When used with `n`, `best_of` controls the number of candidate
    /// completions and `n` specifies how many to return – `best_of` must be
    /// greater than `n`.
    ///
    /// **Note**: Because this parameter generates many completions, it can
    /// quickly consume your token quota. Use carefully and ensure that you
    /// have reasonable settings for `max_tokens` and `stop`.
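    ///
    /// A small sketch of combining `best_of` with `n` via this builder
    /// (parameter values are illustrative only):
    ///
    /// ```ignore
    /// let request = Completion::default()
    ///     .prompt("Write a tagline for a bakery")
    ///     .best_of(3) // sample three candidates server-side
    ///     .n(1); // return only the single best one
    /// ```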
    pub fn best_of(self, best_of: u32) -> Self {
        Self {
            best_of: Some(best_of),
            ..self
        }
    }

    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a json object that maps tokens (specified by their token ID in
    /// the GPT tokenizer) to an associated bias value from -100 to 100. You
    /// can use this [tokenizer tool](https://platform.openai.com/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to
    /// convert text to token IDs. Mathematically, the bias is added to the
    /// logits generated by the model prior to sampling. The exact effect will
    /// vary per model, but values between -1 and 1 should decrease or increase
    /// likelihood of selection; values like -100 or 100 should result in a
    /// ban or exclusive selection of the relevant token.
    ///
    /// As an example, you can pass `{"50256": -100}` to prevent the
    /// <|endoftext|> token from being generated.
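    ///
    /// A small sketch of building the bias map (token ID `50256` is the
    /// <|endoftext|> token in the GPT-2/GPT-3 tokenizer):
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut bias = HashMap::new();
    /// bias.insert("50256".to_string(), -100.0_f32); // effectively ban <|endoftext|>
    /// let request = Completion::default().logit_bias(bias);
    /// ```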
    pub fn logit_bias(self, logit_bias: HashMap<String, f32>) -> Self {
        Self {
            logit_bias: Some(logit_bias),
            ..self
        }
    }

    /// A unique identifier representing your end-user, which can help OpenAI
    /// to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
    pub fn user(self, user: &str) -> Self {
        Self {
            user: Some(user.into()),
            ..self
        }
    }

    /// Send the completion request to OpenAI as a streamed request.
    ///
    /// Partial progress is streamed back: tokens are sent as data-only
    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream
    /// terminated by a `data: [DONE]` message.
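    ///
    /// A minimal streaming sketch (not compiled as a doctest; assumes a Tokio
    /// runtime and that the crate is consumed as `rust_ai`):
    ///
    /// ```ignore
    /// use rust_ai::openai::apis::completion::Completion;
    /// use rust_ai::openai::types::completion::Chunk;
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let chunks = Completion::default()
    ///     .prompt("Say this is a test")
    ///     .max_tokens(16)
    ///     // The callback fires once per streamed chunk; all chunks are also
    ///     // collected and returned once the stream ends.
    ///     .stream_completion(Some(|chunk: Chunk| {
    ///         // Handle each partial completion as it arrives.
    ///         let _ = chunk;
    ///     }))
    ///     .await?;
    /// println!("received {} chunks", chunks.len());
    /// # Ok(())
    /// # }
    /// ```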
    pub async fn stream_completion<F>(
        self,
        mut cb: Option<F>,
    ) -> Result<Vec<Chunk>, Box<dyn std::error::Error>>
    where
        F: FnMut(Chunk),
    {
        let data = Self {
            stream: Some(true),
            ..self
        };

        if !endpoint_filter(&data.model, &Endpoint::Completion_v1) {
            return Err("Model not compatible with this endpoint".into());
        }

        let mut ret_val: Vec<Chunk> = vec![];

        request_endpoint_stream(&data, &Endpoint::Completion_v1, EndpointVariant::None, |res| {
            if let Ok(chunk_data_raw) = res {
                chunk_data_raw.split('\n').for_each(|chunk_data| {
                    let chunk_data = chunk_data.trim().to_string();
                    if chunk_data == "data: [DONE]" {
                        debug!(target: "openai", "Last chunk received.");
                        return;
                    }
                    if chunk_data.starts_with("data: ") {
                        // Strip the `data: ` prefix from the event payload.
                        let stripped_chunk = &chunk_data[6..];
                        if let Ok(message_chunk) = serde_json::from_str::<Chunk>(stripped_chunk) {
                            ret_val.push(message_chunk.clone());
                            if let Some(cb) = &mut cb {
                                cb(message_chunk);
                            }
                        } else if let Ok(response_error) =
                            serde_json::from_str::<Error>(stripped_chunk)
                        {
                            warn!(target: "openai",
                                "OpenAI error code {}: `{:?}`",
                                response_error.error.code.unwrap_or(0),
                                stripped_chunk
                            );
                        } else {
                            warn!(target: "openai", "Completion response not deserializable.");
                        }
                    }
                });
            }
        })
        .await?;

        Ok(ret_val)
    }

    /// Send the completion request to OpenAI and wait for the full response.
    pub async fn completion(self) -> Result<CompletionResponse, Box<dyn std::error::Error>> {
        let data = Self {
            stream: None,
            ..self
        };

        if !endpoint_filter(&data.model, &Endpoint::Completion_v1) {
            return Err("Model not compatible with this endpoint".into());
        }

        let mut completion_response: Option<CompletionResponse> = None;

        request_endpoint(&data, &Endpoint::Completion_v1, EndpointVariant::None, |res| {
            if let Ok(text) = res {
                if let Ok(response_data) = serde_json::from_str::<CompletionResponse>(&text) {
                    debug!(target: "openai", "Response parsed, completion response deserialized.");
                    completion_response = Some(response_data);
                } else if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
                    warn!(target: "openai",
                        "OpenAI error code {}: `{:?}`",
                        response_error.error.code.unwrap_or(0),
                        text
                    );
                } else {
                    warn!(target: "openai", "Completion response not deserializable.");
                }
            }
        })
        .await?;

        if let Some(response_data) = completion_response {
            Ok(response_data)
        } else {
            Err("No response or error parsing response".into())
        }
    }
}