// swiftide_integrations/openai/simple_prompt.rs
use async_openai::types::ChatCompletionRequestUserMessageArgs;
use async_trait::async_trait;
use swiftide_core::{
    SimplePrompt, chat_completion::errors::LanguageModelError, prompt::Prompt,
    util::debug_long_utf8,
};

use super::chat_completion::usage_from_counts;
use super::responses_api::{build_responses_request_from_prompt, response_to_chat_completion};
use crate::openai::openai_error_to_language_model_error;

use super::GenericOpenAI;
use anyhow::Result;
19#[async_trait]
22impl<
23 C: async_openai::config::Config
24 + std::default::Default
25 + Sync
26 + Send
27 + std::fmt::Debug
28 + Clone
29 + 'static,
30> SimplePrompt for GenericOpenAI<C>
31{
32 #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))]
46 #[cfg_attr(
47 feature = "langfuse",
48 tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION"))
49 )]
50 async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
51 if self.is_responses_api_enabled() {
52 return self.prompt_via_responses_api(prompt).await;
53 }
54
55 let model = self
57 .default_options
58 .prompt_model
59 .as_ref()
60 .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;
61
62 let request = self
64 .chat_completion_request_defaults()
65 .model(model)
66 .messages(vec![
67 ChatCompletionRequestUserMessageArgs::default()
68 .content(prompt.render()?)
69 .build()
70 .map_err(LanguageModelError::permanent)?
71 .into(),
72 ])
73 .build()
74 .map_err(LanguageModelError::permanent)?;
75
76 tracing::trace!(
78 model = &model,
79 messages = debug_long_utf8(
80 serde_json::to_string_pretty(&request.messages.last())
81 .map_err(LanguageModelError::permanent)?,
82 100
83 ),
84 "[SimplePrompt] Request to openai"
85 );
86
87 let response = self
89 .client
90 .chat()
91 .create(request.clone())
92 .await
93 .map_err(openai_error_to_language_model_error)?;
94
95 let message = response
96 .choices
97 .first()
98 .and_then(|choice| choice.message.content.clone())
99 .ok_or_else(|| {
100 LanguageModelError::PermanentError("Expected content in response".into())
101 })?;
102
103 let usage = response.usage.as_ref().map(|usage| {
104 usage_from_counts(
105 usage.prompt_tokens,
106 usage.completion_tokens,
107 usage.total_tokens,
108 )
109 });
110
111 self.track_completion(model, usage.as_ref(), Some(&request), Some(&response));
112
113 Ok(message)
114 }
115}
116
117impl<
118 C: async_openai::config::Config
119 + std::default::Default
120 + Sync
121 + Send
122 + std::fmt::Debug
123 + Clone
124 + 'static,
125> GenericOpenAI<C>
126{
127 async fn prompt_via_responses_api(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
128 let prompt_text = prompt.render().map_err(LanguageModelError::permanent)?;
129 let model = self
130 .default_options
131 .prompt_model
132 .as_ref()
133 .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;
134
135 let create_request = build_responses_request_from_prompt(self, prompt_text.clone())?;
136
137 let response = self
138 .client
139 .responses()
140 .create(create_request.clone())
141 .await
142 .map_err(openai_error_to_language_model_error)?;
143
144 let completion = response_to_chat_completion(&response)?;
145
146 let message = completion.message.clone().ok_or_else(|| {
147 LanguageModelError::PermanentError("Expected content in response".into())
148 })?;
149
150 self.track_completion(
151 model,
152 completion.usage.as_ref(),
153 Some(&create_request),
154 Some(&completion),
155 );
156
157 Ok(message)
158 }
159}
160
#[allow(clippy::items_after_statements)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::openai::OpenAI;
    use async_openai::types::responses::{
        CompletionTokensDetails, Content, OutputContent, OutputMessage, OutputStatus, OutputText,
        PromptTokensDetails, Response as ResponsesResponse, Role, Status, Usage as ResponsesUsage,
    };
    use serde_json::Value;
    use wiremock::{
        Mock, MockServer, Request, Respond, ResponseTemplate,
        matchers::{method, path},
    };

    /// Builds an `OpenAI` client in Responses-API mode pointed at `base_uri`.
    fn responses_client(base_uri: &str) -> OpenAI {
        let config = async_openai::config::OpenAIConfig::new().with_api_base(base_uri);
        OpenAI::builder()
            .client(async_openai::Client::with_config(config))
            .default_prompt_model("gpt-4.1-mini")
            .use_responses_api(true)
            .build()
            .unwrap()
    }

    #[test_log::test(tokio::test)]
    async fn test_prompt_errors_when_model_missing() {
        // No prompt model configured -> prompting must fail permanently.
        let openai = OpenAI::builder().build().unwrap();
        let result = openai.prompt("hello".into()).await;
        assert!(matches!(result, Err(LanguageModelError::PermanentError(_))));
    }

    #[test_log::test(tokio::test)]
    async fn test_prompt_via_responses_api_returns_message() {
        let server = MockServer::start().await;

        let canned_response = ResponsesResponse {
            id: "resp".into(),
            object: "response".into(),
            model: "gpt-4.1-mini".into(),
            created_at: 0,
            status: Status::Completed,
            error: None,
            incomplete_details: None,
            instructions: None,
            max_output_tokens: None,
            metadata: None,
            output: vec![OutputContent::Message(OutputMessage {
                id: "msg".into(),
                role: Role::Assistant,
                status: OutputStatus::Completed,
                content: vec![Content::OutputText(OutputText {
                    annotations: Vec::new(),
                    text: "Hello world".into(),
                })],
            })],
            output_text: Some("Hello world".into()),
            parallel_tool_calls: None,
            previous_response_id: None,
            reasoning: None,
            store: None,
            service_tier: None,
            temperature: None,
            text: None,
            tool_choice: None,
            tools: None,
            top_p: None,
            truncation: None,
            usage: Some(ResponsesUsage {
                input_tokens: 4,
                input_tokens_details: PromptTokensDetails {
                    audio_tokens: Some(0),
                    cached_tokens: Some(0),
                },
                output_tokens: 2,
                output_tokens_details: CompletionTokensDetails {
                    accepted_prediction_tokens: Some(0),
                    audio_tokens: Some(0),
                    reasoning_tokens: Some(0),
                    rejected_prediction_tokens: Some(0),
                },
                total_tokens: 6,
            }),
            user: None,
        };

        let body = serde_json::to_value(&canned_response).unwrap();

        /// Responder that asserts on the outgoing request payload before
        /// echoing back the canned response.
        struct AssertPromptBody {
            response: Value,
        }

        impl Respond for AssertPromptBody {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let payload: Value = serde_json::from_slice(&request.body).unwrap();
                assert_eq!(payload["model"], self.response["model"]);
                let items = payload["input"].as_array().expect("array input");
                assert_eq!(items.len(), 1);
                assert_eq!(items[0]["type"], "message");
                ResponseTemplate::new(200).set_body_json(self.response.clone())
            }
        }

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(AssertPromptBody { response: body })
            .mount(&server)
            .await;

        let openai = responses_client(&server.uri());

        let result = openai.prompt("Say hi".into()).await.unwrap();
        assert_eq!(result, "Hello world");
    }

    #[test_log::test(tokio::test)]
    async fn test_prompt_via_responses_api_missing_output_errors() {
        let server = MockServer::start().await;

        // Minimal valid response with an empty `output` array: the client
        // should surface this as a permanent "no content" error.
        let empty_response = serde_json::json!({
            "created_at": 0,
            "id": "resp",
            "model": "gpt-4.1-mini",
            "object": "response",
            "output": [],
            "status": "completed"
        });

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_json(empty_response))
            .mount(&server)
            .await;

        let openai = responses_client(&server.uri());

        let err = openai.prompt("test".into()).await.unwrap_err();
        assert!(matches!(err, LanguageModelError::PermanentError(_)));
    }
}