1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
//!
//! Given a prompt and an instruction, the model will return an edited version
//! of the prompt.
//!
//! Source: OpenAI documentation

////////////////////////////////////////////////////////////////////////////////

use crate::openai::{
    endpoint::{endpoint_filter, request_endpoint, Endpoint, EndpointVariant},
    types::{common::Error, edit::EditResponse, model::Model},
};
use log::{debug, warn};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;

/// Given a prompt and an instruction, the model will return an edited version
/// of the prompt.
///
/// Request body for the edit endpoint. Start from [`Edit::default`], adjust it
/// with the chained setter methods, then send it with [`Edit::edit`].
/// Optional fields left as `None` are omitted from the serialized request.
// NOTE(review): `#[serde_as]` only rewrites fields annotated with
// `#[serde_as(as = "...")]`, and none are here, so it looks like a no-op —
// confirm before removing (the `serde_as` import would then become unused).
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct Edit {
    /// ID of the model to use, e.g. `text-davinci-edit-001` or
    /// `code-davinci-edit-001`. Always serialized.
    pub model: Model,

    /// Input text to use as a starting point for the edit.
    /// Not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input: Option<String>,

    /// Instruction that tells the model how to edit the prompt.
    /// Always serialized, even when empty.
    pub instruction: String,

    /// Sampling temperature, between 0 and 2; higher is more random.
    /// Not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// Nucleus-sampling probability mass (alternative to `temperature`).
    /// Not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    /// How many edits to generate for the input and instruction.
    /// Not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
}

impl Default for Edit {
    fn default() -> Self {
        Self {
            model: Model::TEXT_DAVINCI_EDIT_001,
            input: Some(String::from("")),
            instruction: String::new(),
            temperature: None,
            top_p: None,
            n: None,
        }
    }
}

impl Edit {
    /// ID of the model to use. You can use the `text-davinci-edit-001` or
    /// `code-davinci-edit-001` model with this endpoint.
    pub fn model(self, model: Model) -> Self {
        Self { model, ..self }
    }

    /// The input text to use as a starting point for the edit.
    pub fn input(self, content: &str) -> Self {
        Self {
            input: Some(content.into()),
            ..self
        }
    }

    /// The instruction that tells the model how to edit the prompt.
    pub fn instruction(self, content: &str) -> Self {
        Self {
            instruction: content.into(),
            ..self
        }
    }

    /// What sampling temperature to use, between 0 and 2. Higher values like
    /// 0.8 will make the output more random, while lower values like 0.2
    /// will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    pub fn temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// An alternative to sampling with temperature, called nucleus sampling,
    ///  where the model considers the results of the tokens with `top_p`
    /// probability mass. So 0.1 means only the tokens comprising the top 10%
    /// probability mass are considered.
    ///
    /// We generally recommend altering this or `temperature` but not both.
    pub fn top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }
    /// How many edits to generate for the input and instruction.
    pub fn n(self, n: u32) -> Self {
        Self { n: Some(n), ..self }
    }

    /// Send edit request to OpenAI.
    pub async fn edit(self) -> Result<EditResponse, Box<dyn std::error::Error>> {
        if !endpoint_filter(&self.model, &Endpoint::Edit_v1) {
            return Err("Model not compatible with this endpoint".into());
        }

        let mut edit_response: Option<EditResponse> = None;

        request_endpoint(&self, &Endpoint::Edit_v1, EndpointVariant::None, |res| {
            if let Ok(text) = res {
                if let Ok(response_data) = serde_json::from_str::<EditResponse>(&text) {
                    debug!(target: "openai", "Response parsed, edit response deserialized.");
                    edit_response = Some(response_data);
                } else {
                    if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
                        warn!(target: "openai",
                            "OpenAI error code {}: `{:?}`",
                            response_error.error.code.unwrap_or(0),
                            text
                        );
                    } else {
                        warn!(target: "openai", "Edit response not deserializable.");
                    }
                }
            }
        })
        .await?;

        if let Some(response_data) = edit_response {
            Ok(response_data)
        } else {
            Err("No response or error parsing response".into())
        }
    }
}