workflow_gpt/
gpt.rs

1use crate::imports::*;
2
/// Selects which model a request is addressed to.
///
/// Every named variant corresponds to one fixed identifier string, produced by
/// the [`std::fmt::Display`] implementation below. [`Model::Custom`] carries an
/// arbitrary identifier that is emitted verbatim.
#[derive(Debug)]
pub enum Model {
    /// Rendered as `cushman-codex`.
    CushmanCodex,
    /// Rendered as `davinci-codex`.
    DavinciCodex,
    /// Rendered as `gpt-3.5-turbo`.
    Gpt35Turbo,
    /// Rendered as `gpt-4`.
    Gpt4,
    /// Rendered as `gpt-4o`.
    Gpt4o,
    /// Rendered as `text-ada-001`.
    TextAda001,
    /// Rendered as `text-babbage-001`.
    TextBabbage001,
    /// Rendered as `text-curie-001`.
    TextCurie001,
    /// Rendered as `text-davinci-002`.
    TextDavinci002,
    /// Rendered as `text-davinci-003`.
    TextDavinci003,
    /// Any other model; the wrapped string is used as the identifier unchanged.
    Custom(String),
}

/// Formats a [`Model`] as its identifier string.
impl std::fmt::Display for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Resolve the identifier first, then emit it with a single write.
        let identifier: &str = match self {
            Model::CushmanCodex => "cushman-codex",
            Model::DavinciCodex => "davinci-codex",
            Model::Gpt35Turbo => "gpt-3.5-turbo",
            Model::Gpt4 => "gpt-4",
            Model::Gpt4o => "gpt-4o",
            Model::TextAda001 => "text-ada-001",
            Model::TextBabbage001 => "text-babbage-001",
            Model::TextCurie001 => "text-curie-001",
            Model::TextDavinci002 => "text-davinci-002",
            Model::TextDavinci003 => "text-davinci-003",
            Model::Custom(name) => name.as_str(),
        };
        f.write_str(identifier)
    }
}
35
36struct Inner {
37    api_key: String,
38    model: Model,
39    client: Client,
40}
41
42#[derive(Clone)]
43pub struct ChatGPT {
44    inner: Arc<Inner>,
45}
46
47impl ChatGPT {
48    pub fn new(api_key: String, model: Model) -> Self {
49        ChatGPT {
50            inner: Arc::new(Inner {
51                api_key,
52                model,
53                client: Client::new(),
54            }),
55        }
56    }
57
58    pub async fn query_with_retries(
59        &self,
60        text: String,
61        retries: usize,
62        delay: Duration,
63    ) -> Result<String> {
64        let mut attempt = 0;
65        loop {
66            match self.query(text.clone()).await {
67                Ok(response) => {
68                    return Ok(response);
69                }
70                Err(err) => {
71                    workflow_core::task::sleep(delay).await;
72                    attempt += 1;
73                    if attempt >= retries {
74                        return Err(Error::RetryFailure(retries, err.to_string()));
75                    }
76                }
77            }
78        }
79    }
80
81    pub async fn query(&self, text: String) -> Result<String> {
82        let response = self
83            .inner
84            .client
85            .post("https://api.openai.com/v1/chat/completions")
86            .header("Authorization", format!("Bearer {}", self.inner.api_key))
87            .json(&Request {
88                model: self.inner.model.to_string(),
89                messages: vec![Message {
90                    role: "user".to_string(),
91                    content: text,
92                }],
93            })
94            .send()
95            .await?
96            .json::<Response>()
97            .await?;
98
99        Ok(response
100            .choices
101            .first()
102            .map(|choice| choice.message.content.clone())
103            .unwrap_or_default())
104    }
105
106    pub async fn translate(
107        &self,
108        entries: Vec<String>,
109        target_language: &str,
110    ) -> Result<Vec<(String, String)>> {
111        // Construct a single message with all texts to be translated
112        let message_content = entries.clone().join("\n");
113        let message_content = format!(
114            "Translate the following text line by line to {}\n{}",
115            target_language, message_content
116        );
117
118        let response = self
119            .inner
120            .client
121            .post("https://api.openai.com/v1/chat/completions")
122            .header("Authorization", format!("Bearer {}", self.inner.api_key))
123            .json(&Request {
124                model: self.inner.model.to_string(),
125                messages: vec![Message {
126                    role: "user".to_string(),
127                    content: message_content,
128                }],
129            })
130            .send()
131            .await?
132            .json::<Response>()
133            .await?;
134
135        // Extract the translations from the response
136        let translations = response
137            .choices
138            .first()
139            .map(|choice| {
140                choice
141                    .message
142                    .content
143                    .split('\n')
144                    .map(String::from)
145                    .collect::<Vec<String>>()
146            })
147            .unwrap_or_default();
148
149        // Pair each original text with its translation
150        let result: Vec<(String, String)> = entries.into_iter().zip(translations).collect();
151
152        Ok(result)
153    }
154}
155
156#[derive(Serialize)]
157struct Request {
158    model: String,
159    messages: Vec<Message>,
160}
161
162#[derive(Serialize)]
163struct Message {
164    role: String,
165    content: String,
166}
167
168#[derive(Deserialize)]
169struct Response {
170    choices: Vec<Choice>,
171}
172
173#[derive(Deserialize)]
174struct Choice {
175    message: MessageResponse,
176}
177
178#[derive(Deserialize)]
179struct MessageResponse {
180    content: String,
181}