use crate::imports::*;

/// Models that can be selected for completion requests. `Custom` carries an
/// arbitrary model identifier for anything not covered by the named variants.
#[derive(Debug)]
pub enum Model {
    CushmanCodex,
    DavinciCodex,
    Gpt35Turbo,
    Gpt4,
    Gpt4o,
    TextAda001,
    TextBabbage001,
    TextCurie001,
    TextDavinci002,
    TextDavinci003,
    Custom(String),
}

impl std::fmt::Display for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Model::CushmanCodex => write!(f, "cushman-codex"),
            Model::DavinciCodex => write!(f, "davinci-codex"),
            Model::Gpt35Turbo => write!(f, "gpt-3.5-turbo"),
            Model::Gpt4 => write!(f, "gpt-4"),
            Model::Gpt4o => write!(f, "gpt-4o"),
            Model::TextAda001 => write!(f, "text-ada-001"),
            Model::TextBabbage001 => write!(f, "text-babbage-001"),
            Model::TextCurie001 => write!(f, "text-curie-001"),
            Model::TextDavinci002 => write!(f, "text-davinci-002"),
            Model::TextDavinci003 => write!(f, "text-davinci-003"),
            Model::Custom(model) => write!(f, "{model}"),
        }
    }
}
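
// A minimal sanity check for the `Display` mapping above; it relies only on
// the standard test harness, so it runs with a plain `cargo test`.
#[cfg(test)]
mod model_display_tests {
    use super::*;

    #[test]
    fn model_names_match_api_identifiers() {
        assert_eq!(Model::Gpt35Turbo.to_string(), "gpt-3.5-turbo");
        assert_eq!(Model::Gpt4o.to_string(), "gpt-4o");
        // Custom variants pass their identifier through unchanged.
        assert_eq!(Model::Custom("my-model".to_string()).to_string(), "my-model");
    }
}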

/// Shared state behind the cheaply clonable [`ChatGPT`] handle.
struct Inner {
    api_key: String,
    model: Model,
    client: Client,
}

/// Thin client for the OpenAI chat completions API. Cloning only bumps the
/// reference count on the shared [`Inner`] state.
#[derive(Clone)]
pub struct ChatGPT {
    inner: Arc<Inner>,
}

impl ChatGPT {
    /// Creates a client that authenticates with `api_key` and sends requests
    /// against the given `model`.
    pub fn new(api_key: String, model: Model) -> Self {
        ChatGPT {
            inner: Arc::new(Inner {
                api_key,
                model,
                client: Client::new(),
            }),
        }
    }

    /// Issues [`query`](Self::query) up to `retries` times, waiting `delay`
    /// between attempts, and returns the first successful response.
    pub async fn query_with_retries(
        &self,
        text: String,
        retries: usize,
        delay: Duration,
    ) -> Result<String> {
        let mut attempt = 0;
        loop {
            match self.query(text.clone()).await {
                Ok(response) => {
                    return Ok(response);
                }
                Err(err) => {
                    attempt += 1;
                    // Give up once the attempt budget is exhausted; only sleep
                    // when another attempt will actually be made.
                    if attempt >= retries {
                        return Err(Error::RetryFailure(retries, err.to_string()));
                    }
                    workflow_core::task::sleep(delay).await;
                }
            }
        }
    }

    /// Sends a single user message to the chat completions endpoint and
    /// returns the assistant's reply (or an empty string if the API returned
    /// no choices).
    pub async fn query(&self, text: String) -> Result<String> {
        let response = self
            .inner
            .client
            .post("https://api.openai.com/v1/chat/completions")
            .header("Authorization", format!("Bearer {}", self.inner.api_key))
            .json(&Request {
                model: self.inner.model.to_string(),
                messages: vec![Message {
                    role: "user".to_string(),
                    content: text,
                }],
            })
            .send()
            .await?
            .json::<Response>()
            .await?;

        Ok(response
            .choices
            .first()
            .map(|choice| choice.message.content.clone())
            .unwrap_or_default())
    }

    /// Translates `entries` line by line into `target_language` and pairs
    /// each source line with its translation. If the model returns fewer
    /// lines than were sent, the surplus source entries are dropped by the
    /// final `zip`.
    pub async fn translate(
        &self,
        entries: Vec<String>,
        target_language: &str,
    ) -> Result<Vec<(String, String)>> {
        let message_content = entries.join("\n");
        let message_content = format!(
            "Translate the following text line by line to {}\n{}",
            target_language, message_content
        );

        let response = self
            .inner
            .client
            .post("https://api.openai.com/v1/chat/completions")
            .header("Authorization", format!("Bearer {}", self.inner.api_key))
            .json(&Request {
                model: self.inner.model.to_string(),
                messages: vec![Message {
                    role: "user".to_string(),
                    content: message_content,
                }],
            })
            .send()
            .await?
            .json::<Response>()
            .await?;

        let translations = response
            .choices
            .first()
            .map(|choice| {
                choice
                    .message
                    .content
                    .split('\n')
                    .map(String::from)
                    .collect::<Vec<String>>()
            })
            .unwrap_or_default();

        let result: Vec<(String, String)> = entries.into_iter().zip(translations).collect();

        Ok(result)
    }
}
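
// A minimal usage sketch: build a client and drive a couple of calls. The API
// key and model are placeholders, and an async runtime (e.g. tokio) is assumed
// to execute the future. Compiled only under `cfg(test)` so it never ships in
// release builds.
#[cfg(test)]
#[allow(dead_code)]
async fn usage_sketch() -> Result<Vec<(String, String)>> {
    let gpt = ChatGPT::new("sk-<placeholder>".to_string(), Model::Gpt4o);

    // Single prompt, retried up to 3 times with a one-second pause between
    // attempts.
    let _summary = gpt
        .query_with_retries(
            "Summarize Rust in one sentence.".to_string(),
            3,
            Duration::from_secs(1),
        )
        .await?;

    // Batch translation: each input line is paired with its translated line.
    gpt.translate(
        vec!["Hello".to_string(), "Goodbye".to_string()],
        "French",
    )
    .await
}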

/// Request body for the chat completions endpoint.
#[derive(Serialize)]
struct Request {
    model: String,
    messages: Vec<Message>,
}

/// A single chat message sent to the API.
#[derive(Serialize)]
struct Message {
    role: String,
    content: String,
}

/// Subset of the chat completions response that this client consumes.
#[derive(Deserialize)]
struct Response {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Choice {
    message: MessageResponse,
}

#[derive(Deserialize)]
struct MessageResponse {
    content: String,
}
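
// A sketch of the request wire format produced by the `Serialize` derives
// above; it assumes `serde_json` is available as a (dev-)dependency, which
// this module does not itself declare.
#[cfg(test)]
mod request_shape_tests {
    use super::*;

    #[test]
    fn request_serializes_to_expected_json() {
        let request = Request {
            model: Model::Gpt4o.to_string(),
            messages: vec![Message {
                role: "user".to_string(),
                content: "hi".to_string(),
            }],
        };
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "model": "gpt-4o",
                "messages": [{ "role": "user", "content": "hi" }]
            })
        );
    }
}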