swiftide_integrations/openai/
simple_prompt.rs

1//! This module provides an implementation of the `SimplePrompt` trait for the `OpenAI` struct.
2//! It defines an asynchronous function to interact with the `OpenAI` API, allowing prompt
3//! processing and generating responses as part of the Swiftide system.
4
5use async_openai::types::ChatCompletionRequestUserMessageArgs;
6use async_trait::async_trait;
7use swiftide_core::{
8    SimplePrompt, chat_completion::errors::LanguageModelError, prompt::Prompt,
9    util::debug_long_utf8,
10};
11
12use super::chat_completion::usage_from_counts;
13use super::responses_api::{build_responses_request_from_prompt, response_to_chat_completion};
14use crate::openai::openai_error_to_language_model_error;
15
16use super::GenericOpenAI;
17use anyhow::Result;
18
19/// The `SimplePrompt` trait defines a method for sending a prompt to an AI model and receiving a
20/// response.
21#[async_trait]
22impl<
23    C: async_openai::config::Config
24        + std::default::Default
25        + Sync
26        + Send
27        + std::fmt::Debug
28        + Clone
29        + 'static,
30> SimplePrompt for GenericOpenAI<C>
31{
32    /// Sends a prompt to the `OpenAI` API and returns the response content.
33    ///
34    /// # Parameters
35    /// - `prompt`: A string slice that holds the prompt to be sent to the `OpenAI` API.
36    ///
37    /// # Returns
38    /// - `Result<String>`: On success, returns the content of the response as a `String`. On
39    ///   failure, returns an error wrapped in a `Result`.
40    ///
41    /// # Errors
42    /// - Returns an error if the model is not set in the default options.
43    /// - Returns an error if the request to the `OpenAI` API fails.
44    /// - Returns an error if the response does not contain the expected content.
45    #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))]
46    #[cfg_attr(
47        feature = "langfuse",
48        tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION"))
49    )]
50    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
51        if self.is_responses_api_enabled() {
52            return self.prompt_via_responses_api(prompt).await;
53        }
54
55        // Retrieve the model from the default options, returning an error if not set.
56        let model = self
57            .default_options
58            .prompt_model
59            .as_ref()
60            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;
61
62        // Build the request to be sent to the OpenAI API.
63        let request = self
64            .chat_completion_request_defaults()
65            .model(model)
66            .messages(vec![
67                ChatCompletionRequestUserMessageArgs::default()
68                    .content(prompt.render()?)
69                    .build()
70                    .map_err(LanguageModelError::permanent)?
71                    .into(),
72            ])
73            .build()
74            .map_err(LanguageModelError::permanent)?;
75
76        // Log the request for debugging purposes.
77        tracing::trace!(
78            model = &model,
79            messages = debug_long_utf8(
80                serde_json::to_string_pretty(&request.messages.last())
81                    .map_err(LanguageModelError::permanent)?,
82                100
83            ),
84            "[SimplePrompt] Request to openai"
85        );
86
87        // Send the request to the OpenAI API and await the response.
88        let response = self
89            .client
90            .chat()
91            .create(request.clone())
92            .await
93            .map_err(openai_error_to_language_model_error)?;
94
95        let message = response
96            .choices
97            .first()
98            .and_then(|choice| choice.message.content.clone())
99            .ok_or_else(|| {
100                LanguageModelError::PermanentError("Expected content in response".into())
101            })?;
102
103        let usage = response.usage.as_ref().map(|usage| {
104            usage_from_counts(
105                usage.prompt_tokens,
106                usage.completion_tokens,
107                usage.total_tokens,
108            )
109        });
110
111        self.track_completion(model, usage.as_ref(), Some(&request), Some(&response));
112
113        Ok(message)
114    }
115}
116
117impl<
118    C: async_openai::config::Config
119        + std::default::Default
120        + Sync
121        + Send
122        + std::fmt::Debug
123        + Clone
124        + 'static,
125> GenericOpenAI<C>
126{
127    async fn prompt_via_responses_api(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
128        let prompt_text = prompt.render().map_err(LanguageModelError::permanent)?;
129        let model = self
130            .default_options
131            .prompt_model
132            .as_ref()
133            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;
134
135        let create_request = build_responses_request_from_prompt(self, prompt_text.clone())?;
136
137        let response = self
138            .client
139            .responses()
140            .create(create_request.clone())
141            .await
142            .map_err(openai_error_to_language_model_error)?;
143
144        let completion = response_to_chat_completion(&response)?;
145
146        let message = completion.message.clone().ok_or_else(|| {
147            LanguageModelError::PermanentError("Expected content in response".into())
148        })?;
149
150        self.track_completion(
151            model,
152            completion.usage.as_ref(),
153            Some(&create_request),
154            Some(&completion),
155        );
156
157        Ok(message)
158    }
159}
160
#[allow(clippy::items_after_statements)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::openai::OpenAI;
    use async_openai::types::responses::{
        CompletionTokensDetails, Content, OutputContent, OutputMessage, OutputStatus, OutputText,
        PromptTokensDetails, Response as ResponsesResponse, Role, Status, Usage as ResponsesUsage,
    };
    use serde_json::Value;
    use wiremock::{
        Mock, MockServer, Request, Respond, ResponseTemplate,
        matchers::{method, path},
    };

    /// Without a configured prompt model, `prompt` must fail with a permanent error
    /// before any network request is attempted.
    #[test_log::test(tokio::test)]
    async fn test_prompt_errors_when_model_missing() {
        let openai = OpenAI::builder().build().unwrap();
        let result = openai.prompt("hello".into()).await;
        assert!(matches!(result, Err(LanguageModelError::PermanentError(_))));
    }

    /// Happy path through the Responses API: a mocked `/responses` endpoint validates
    /// the outgoing payload, returns a canned completion, and `prompt` must surface
    /// the completion's text unchanged.
    #[test_log::test(tokio::test)]
    async fn test_prompt_via_responses_api_returns_message() {
        let mock_server = MockServer::start().await;

        // Canned Responses API payload containing a single completed assistant
        // message ("Hello world") plus token usage counts.
        let response = ResponsesResponse {
            created_at: 0,
            error: None,
            id: "resp".into(),
            incomplete_details: None,
            instructions: None,
            max_output_tokens: None,
            metadata: None,
            model: "gpt-4.1-mini".into(),
            object: "response".into(),
            output: vec![OutputContent::Message(OutputMessage {
                content: vec![Content::OutputText(OutputText {
                    annotations: Vec::new(),
                    text: "Hello world".into(),
                })],
                id: "msg".into(),
                role: Role::Assistant,
                status: OutputStatus::Completed,
            })],
            output_text: Some("Hello world".into()),
            parallel_tool_calls: None,
            previous_response_id: None,
            reasoning: None,
            store: None,
            service_tier: None,
            status: Status::Completed,
            temperature: None,
            text: None,
            tool_choice: None,
            tools: None,
            top_p: None,
            truncation: None,
            usage: Some(ResponsesUsage {
                input_tokens: 4,
                input_tokens_details: PromptTokensDetails {
                    audio_tokens: Some(0),
                    cached_tokens: Some(0),
                },
                output_tokens: 2,
                output_tokens_details: CompletionTokensDetails {
                    accepted_prediction_tokens: Some(0),
                    audio_tokens: Some(0),
                    reasoning_tokens: Some(0),
                    rejected_prediction_tokens: Some(0),
                },
                total_tokens: 6,
            }),
            user: None,
        };

        let response_body = serde_json::to_value(&response).unwrap();

        // Responder that asserts on the request shape (model name and a single
        // "message" input item) before replying with the canned body.
        struct ValidatePromptRequest {
            response: Value,
        }

        impl Respond for ValidatePromptRequest {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let payload: Value = serde_json::from_slice(&request.body).unwrap();
                assert_eq!(payload["model"], self.response["model"]);
                let items = payload["input"].as_array().expect("array input");
                assert_eq!(items.len(), 1);
                assert_eq!(items[0]["type"], "message");
                ResponseTemplate::new(200).set_body_json(self.response.clone())
            }
        }

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ValidatePromptRequest {
                response: response_body,
            })
            .mount(&mock_server)
            .await;

        // Point the client at the mock server instead of the real OpenAI endpoint.
        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = async_openai::Client::with_config(config);

        let openai = OpenAI::builder()
            .client(client)
            .default_prompt_model("gpt-4.1-mini")
            .use_responses_api(true)
            .build()
            .unwrap();

        let result = openai.prompt("Say hi".into()).await.unwrap();
        assert_eq!(result, "Hello world");
    }

    /// A Responses API reply with an empty `output` array must map to a permanent
    /// "missing content" error rather than a panic or an empty string.
    #[test_log::test(tokio::test)]
    async fn test_prompt_via_responses_api_missing_output_errors() {
        let mock_server = MockServer::start().await;
        // Minimal valid response body with no output items.
        let empty_response = serde_json::json!({
            "created_at": 0,
            "id": "resp",
            "model": "gpt-4.1-mini",
            "object": "response",
            "output": [],
            "status": "completed"
        });

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_json(empty_response))
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = async_openai::Client::with_config(config);

        let openai = OpenAI::builder()
            .client(client)
            .default_prompt_model("gpt-4.1-mini")
            .use_responses_api(true)
            .build()
            .unwrap();

        let err = openai.prompt("test".into()).await.unwrap_err();
        assert!(matches!(err, LanguageModelError::PermanentError(_)));
    }
}
307}