////////////////////////////////////////////////////////////////////////////////
//!
//! Given a prompt and an instruction, the model will return an edited version
//! of the prompt.
//!
//! Source: OpenAI documentation
////////////////////////////////////////////////////////////////////////////////
use crate::openai::{
    endpoint::{endpoint_filter, request_endpoint, Endpoint, EndpointVariant},
    types::{common::Error, edit::EditResponse, model::Model},
};
use log::{debug, warn};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;

/// Given a prompt and an instruction, the model will return an edited version
/// of the prompt.
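///
/// # Example
///
/// A minimal builder sketch, mirroring the example from the OpenAI docs
/// (marked `ignore` because it needs network access and configured API
/// credentials):
///
/// ```ignore
/// let response = Edit::default()
///     .model(Model::TEXT_DAVINCI_EDIT_001)
///     .input("What day of the wek is it?")
///     .instruction("Fix the spelling mistakes")
///     .edit()
///     .await?;
/// ```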
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct Edit {
    pub model: Model,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input: Option<String>,
    pub instruction: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
}

impl Default for Edit {
    fn default() -> Self {
        Self {
            model: Model::TEXT_DAVINCI_EDIT_001,
            input: Some(String::new()),
            instruction: String::new(),
            temperature: None,
            top_p: None,
            n: None,
        }
    }
}

impl Edit {
    /// ID of the model to use. You can use the `text-davinci-edit-001` or
    /// `code-davinci-edit-001` model with this endpoint.
    pub fn model(self, model: Model) -> Self {
        Self { model, ..self }
    }

    /// The input text to use as a starting point for the edit.
    pub fn input(self, content: &str) -> Self {
        Self {
            input: Some(content.into()),
            ..self
        }
    }

    /// The instruction that tells the model how to edit the prompt.
    pub fn instruction(self, content: &str) -> Self {
        Self {
            instruction: content.into(),
            ..self
        }
    }

    /// What sampling temperature to use, between 0 and 2. Higher values like
    /// 0.8 will make the output more random, while lower values like 0.2
    /// will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    pub fn temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with `top_p`
    /// probability mass. So 0.1 means only the tokens comprising the top 10%
    /// probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    pub fn top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// How many edits to generate for the input and instruction.
    pub fn n(self, n: u32) -> Self {
        Self { n: Some(n), ..self }
    }

    /// Sends the edit request to OpenAI.
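    ///
    /// Returns an error if the chosen model is not compatible with the edit
    /// endpoint, or if the response cannot be parsed. A sketch of the call
    /// site (`ignore`d for the same reasons as the struct-level example):
    ///
    /// ```ignore
    /// match Edit::default().instruction("Make this rhyme").edit().await {
    ///     Ok(_response) => { /* use the EditResponse */ }
    ///     Err(e) => eprintln!("edit failed: {e}"),
    /// }
    /// ```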
    pub async fn edit(self) -> Result<EditResponse, Box<dyn std::error::Error>> {
        if !endpoint_filter(&self.model, &Endpoint::Edit_v1) {
            return Err("Model not compatible with this endpoint".into());
        }

        let mut edit_response: Option<EditResponse> = None;
        request_endpoint(&self, &Endpoint::Edit_v1, EndpointVariant::None, |res| {
            if let Ok(text) = res {
                if let Ok(response_data) = serde_json::from_str::<EditResponse>(&text) {
                    debug!(target: "openai", "Response parsed, edit response deserialized.");
                    edit_response = Some(response_data);
                } else if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
                    warn!(
                        target: "openai",
                        "OpenAI error code {}: `{:?}`",
                        response_error.error.code.unwrap_or(0),
                        text
                    );
                } else {
                    warn!(target: "openai", "Edit response not deserializable.");
                }
            }
        })
        .await?;

        if let Some(response_data) = edit_response {
            Ok(response_data)
        } else {
            Err("No response or error parsing response".into())
        }
    }
}
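
#[cfg(test)]
mod tests {
    // A small serialization sketch: checks that unset optional fields are
    // omitted from the JSON body thanks to `skip_serializing_if`. It only
    // asserts on field presence, not on how `Model` itself serializes,
    // since that is defined elsewhere in the crate.
    use super::*;

    #[test]
    fn optional_fields_are_omitted() {
        let edit = Edit::default().instruction("Fix the spelling mistakes");
        let json = serde_json::to_string(&edit).expect("Edit should serialize");
        assert!(json.contains("\"instruction\""));
        assert!(!json.contains("\"temperature\""));
        assert!(!json.contains("\"top_p\""));
        assert!(!json.contains("\"n\""));
    }
}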