swiftide_integrations/openai/structured_prompt.rs

//! This module provides an implementation of the `StructuredPrompt` trait for the `OpenAI` struct.
//!
//! Unlike the other traits, `StructuredPrompt` is *not* dyn safe.
//!
//! Use `DynStructuredPrompt` if you need dyn dispatch. For custom implementations, implementing
//! `DynStructuredPrompt` gives you `StructuredPrompt` for free.
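//!
//! # Example
//!
//! A minimal usage sketch; the crate path, model name, and prompt are illustrative, and the
//! builder calls mirror the ones exercised in the tests below:
//!
//! ```no_run
//! # use swiftide_core::StructuredPrompt;
//! # use swiftide_integrations::openai::OpenAI;
//! # async fn example() -> anyhow::Result<()> {
//! let openai = OpenAI::builder()
//!     .default_prompt_model("gpt-4o-mini")
//!     .build()?;
//!
//! // A schema is derived from the requested output type; `serde_json::Value`
//! // keeps this sketch schema-agnostic.
//! let answer: serde_json::Value = openai.structured_prompt("What is 6 * 7?".into()).await?;
//! println!("{answer}");
//! # Ok(())
//! # }
//! ```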

use async_openai::types::{
    ChatCompletionRequestUserMessageArgs, ResponseFormat, ResponseFormatJsonSchema,
};
use async_trait::async_trait;
use schemars::Schema;
use swiftide_core::{
    DynStructuredPrompt, chat_completion::errors::LanguageModelError, prompt::Prompt,
    util::debug_long_utf8,
};

use super::chat_completion::usage_from_counts;
use super::responses_api::{
    build_responses_request_from_prompt_with_schema, response_to_chat_completion,
};
use crate::openai::openai_error_to_language_model_error;

use super::GenericOpenAI;
use anyhow::{Context as _, Result};

/// Implements `DynStructuredPrompt` for `GenericOpenAI`: sends a prompt together with a JSON
/// schema to the model and returns the structured response.
#[async_trait]
impl<
    C: async_openai::config::Config
        + std::default::Default
        + Sync
        + Send
        + std::fmt::Debug
        + Clone
        + 'static,
> DynStructuredPrompt for GenericOpenAI<C>
{
    /// Sends a prompt to the `OpenAI` API and returns the structured response.
    ///
    /// # Parameters
    /// - `prompt`: The prompt to be sent to the `OpenAI` API.
    /// - `schema`: The JSON schema the response must conform to.
    ///
    /// # Returns
    /// - `Result<serde_json::Value, LanguageModelError>`: On success, returns the response
    ///   content parsed as JSON. On failure, returns a `LanguageModelError`.
    ///
    /// # Errors
    /// - Returns an error if the model is not set in the default options.
    /// - Returns an error if the request to the `OpenAI` API fails.
    /// - Returns an error if the response does not contain the expected content or the content is
    ///   not valid JSON.
    #[tracing::instrument(skip_all, err)]
    #[cfg_attr(
        feature = "langfuse",
        tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION"))
    )]
    async fn structured_prompt_dyn(
        &self,
        prompt: Prompt,
        schema: Schema,
    ) -> Result<serde_json::Value, LanguageModelError> {
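        // Two code paths: the Responses API when it is enabled on this client, otherwise
        // the classic chat completions endpoint with a JSON-schema response format.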
        if self.is_responses_api_enabled() {
            return self
                .structured_prompt_via_responses_api(prompt, schema)
                .await;
        }

        // Retrieve the model from the default options, returning an error if not set.
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;

        let schema_value =
            serde_json::to_value(&schema).context("Failed to get schema as value")?;
        let response_format = ResponseFormat::JsonSchema {
            json_schema: ResponseFormatJsonSchema {
                description: None,
                name: "structured_prompt".into(),
                schema: Some(schema_value),
                strict: Some(true),
            },
        };
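        // On the wire the response format serializes to roughly (illustrative):
        //   "response_format": {
        //     "type": "json_schema",
        //     "json_schema": { "name": "structured_prompt", "schema": { ... }, "strict": true }
        //   }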

        // Build the request to be sent to the OpenAI API.
        let request = self
            .chat_completion_request_defaults()
            .model(model)
            .response_format(response_format)
            .messages(vec![
                ChatCompletionRequestUserMessageArgs::default()
                    .content(prompt.render()?)
                    .build()
                    .map_err(LanguageModelError::permanent)?
                    .into(),
            ])
            .build()
            .map_err(LanguageModelError::permanent)?;

        // Log the request for debugging purposes.
        tracing::trace!(
            model = &model,
            messages = debug_long_utf8(
                serde_json::to_string_pretty(&request.messages.last())
                    .map_err(LanguageModelError::permanent)?,
                100
            ),
            "[StructuredPrompt] Request to OpenAI"
        );

        // Send the request to the OpenAI API and await the response.
        let response = self
            .client
            .chat()
            .create(request.clone())
            .await
            .map_err(openai_error_to_language_model_error)?;

        let message = response
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .ok_or_else(|| {
                LanguageModelError::PermanentError("Expected content in response".into())
            })?;

        let usage = response.usage.as_ref().map(|usage| {
            usage_from_counts(
                usage.prompt_tokens,
                usage.completion_tokens,
                usage.total_tokens,
            )
        });

        self.track_completion(model, usage.as_ref(), Some(&request), Some(&response));

        // Parse the JSON content, returning an error if it is not valid JSON.
        let parsed = serde_json::from_str(&message)
            .with_context(|| format!("Failed to parse response\n {message}"))?;

        Ok(parsed)
    }
}

impl<
    C: async_openai::config::Config
        + std::default::Default
        + Sync
        + Send
        + std::fmt::Debug
        + Clone
        + 'static,
> GenericOpenAI<C>
{
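    /// Sends a structured prompt through the `OpenAI` Responses API rather than the chat
    /// completions endpoint.
    ///
    /// Called by `structured_prompt_dyn` when the Responses API is enabled on this client.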
    async fn structured_prompt_via_responses_api(
        &self,
        prompt: Prompt,
        schema: Schema,
    ) -> Result<serde_json::Value, LanguageModelError> {
        let prompt_text = prompt.render().map_err(LanguageModelError::permanent)?;
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;

        let schema_value = serde_json::to_value(&schema)
            .context("Failed to get schema as value")
            .map_err(LanguageModelError::permanent)?;

        let create_request = build_responses_request_from_prompt_with_schema(
            self,
            prompt_text,
            schema_value,
        )?;

        let response = self
            .client
            .responses()
            .create(create_request.clone())
            .await
            .map_err(openai_error_to_language_model_error)?;

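        // Normalize the Responses API payload into a chat-completion-shaped result so the
        // remainder of the flow mirrors the chat completions path.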
        let completion = response_to_chat_completion(&response)?;

        let message = completion.message.clone().ok_or_else(|| {
            LanguageModelError::PermanentError("Expected content in response".into())
        })?;

        self.track_completion(
            model,
            completion.usage.as_ref(),
            Some(&create_request),
            Some(&completion),
        );

        let parsed = serde_json::from_str(&message)
            .with_context(|| format!("Failed to parse response\n {message}"))
            .map_err(LanguageModelError::permanent)?;

        Ok(parsed)
    }
}

#[cfg(test)]
mod tests {
    use crate::openai::{self, OpenAI};
    use swiftide_core::StructuredPrompt;

    use super::*;
    use async_openai::Client;
    use async_openai::config::OpenAIConfig;
    use async_openai::types::responses::{
        CompletionTokensDetails, Content, OutputContent, OutputMessage, OutputStatus, OutputText,
        PromptTokensDetails, Response as ResponsesResponse, Role, Status, Usage as ResponsesUsage,
    };
    use schemars::{JsonSchema, schema_for};
    use serde::{Deserialize, Serialize};
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
    struct SimpleOutput {
        answer: String,
    }
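
    // For reference: `schema_for!(SimpleOutput)` yields, roughly, a JSON schema like
    // `{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}`.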

    async fn setup_client() -> (MockServer, OpenAI) {
        // Start the Wiremock server
        let mock_server = MockServer::start().await;

        // Prepare the response the mock should return
        let assistant_msg = serde_json::json!({
            "role": "assistant",
            "content": serde_json::to_string(&SimpleOutput {
                answer: "42".to_owned()
            }).unwrap(),
        });

        let body = serde_json::json!({
          "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
          "object": "chat.completion",
          "created": 123,
          "model": "gpt-4.1-2025-04-14",
          "choices": [
            {
              "index": 0,
              "message": assistant_msg,
              "logprobs": null,
              "finish_reason": "stop"
            }
          ],
          "usage": {
            "prompt_tokens": 19,
            "completion_tokens": 10,
            "total_tokens": 29,
            "prompt_tokens_details": {
              "cached_tokens": 0,
              "audio_tokens": 0
            },
            "completion_tokens_details": {
              "reasoning_tokens": 0,
              "audio_tokens": 0,
              "accepted_prediction_tokens": 0,
              "rejected_prediction_tokens": 0
            }
          },
          "service_tier": "default"
        });

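        // Serve the canned completion for any POST to the chat completions endpoint.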
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&mock_server)
            .await;

        // Point our client at the mock server
        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);

        // Construct the GenericOpenAI instance
        let opts = openai::Options {
            prompt_model: Some("gpt-4".to_string()),
            ..openai::Options::default()
        };
        (
            mock_server,
            OpenAI::builder()
                .client(client)
                .default_options(opts)
                .build()
                .unwrap(),
        )
    }

    #[tokio::test]
    async fn test_structured_prompt_with_wiremock() {
        let (_guard, ai) = setup_client().await;
        // Call structured_prompt
        let result: serde_json::Value = ai.structured_prompt("test".into()).await.unwrap();
        dbg!(&result);

        // Assert
        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput {
                answer: "42".into()
            }
        );
    }

    #[tokio::test]
    async fn test_structured_prompt_with_wiremock_as_box() {
        let (_guard, ai) = setup_client().await;
        // Box the client and call `structured_prompt_dyn` through the trait object
        let ai: Box<dyn DynStructuredPrompt> = Box::new(ai);
        let result: serde_json::Value = ai
            .structured_prompt_dyn("test".into(), schema_for!(SimpleOutput))
            .await
            .unwrap();
        dbg!(&result);

        // Assert
        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput {
                answer: "42".into()
            }
        );
    }

    #[test_log::test(tokio::test)]
    async fn test_structured_prompt_via_responses_api() {
        let mock_server = MockServer::start().await;

        let response = ResponsesResponse {
            created_at: 0,
            error: None,
            id: "resp".into(),
            incomplete_details: None,
            instructions: None,
            max_output_tokens: None,
            metadata: None,
            model: "gpt-4.1-mini".into(),
            object: "response".into(),
            output: vec![OutputContent::Message(OutputMessage {
                content: vec![Content::OutputText(OutputText {
                    annotations: Vec::new(),
                    text: serde_json::to_string(&SimpleOutput {
                        answer: "structured".into(),
                    })
                    .unwrap(),
                })],
                id: "msg".into(),
                role: Role::Assistant,
                status: OutputStatus::Completed,
            })],
            output_text: None,
            parallel_tool_calls: None,
            previous_response_id: None,
            reasoning: None,
            store: None,
            service_tier: None,
            status: Status::Completed,
            temperature: None,
            text: None,
            tool_choice: None,
            tools: None,
            top_p: None,
            truncation: None,
            usage: Some(ResponsesUsage {
                input_tokens: 10,
                input_tokens_details: PromptTokensDetails {
                    audio_tokens: Some(0),
                    cached_tokens: Some(0),
                },
                output_tokens: 4,
                output_tokens_details: CompletionTokensDetails {
                    accepted_prediction_tokens: Some(0),
                    audio_tokens: Some(0),
                    reasoning_tokens: Some(0),
                    rejected_prediction_tokens: Some(0),
                },
                total_tokens: 14,
            }),
            user: None,
        };

        let response_body = serde_json::to_value(&response).unwrap();

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_json(response_body))
            .mount(&mock_server)
            .await;

        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);

        let openai = OpenAI::builder()
            .client(client)
            .default_prompt_model("gpt-4.1-mini")
            .use_responses_api(true)
            .build()
            .unwrap();

        let schema = schema_for!(SimpleOutput);
        let result = openai
            .structured_prompt_dyn("Render".into(), schema)
            .await
            .unwrap();

        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput {
                answer: "structured".into(),
            }
        );
    }

    #[test_log::test(tokio::test)]
    async fn test_structured_prompt_via_responses_api_invalid_json_errors() {
        let mock_server = MockServer::start().await;

        let bad_response = ResponsesResponse {
            created_at: 0,
            error: None,
            id: "resp".into(),
            incomplete_details: None,
            instructions: None,
            max_output_tokens: None,
            metadata: None,
            model: "gpt-4.1-mini".into(),
            object: "response".into(),
            output: vec![OutputContent::Message(OutputMessage {
                content: vec![Content::OutputText(OutputText {
                    annotations: Vec::new(),
                    text: "not json".into(),
                })],
                id: "msg".into(),
                role: Role::Assistant,
                status: OutputStatus::Completed,
            })],
            output_text: Some("not json".into()),
            parallel_tool_calls: None,
            previous_response_id: None,
            reasoning: None,
            store: None,
            service_tier: None,
            status: Status::Completed,
            temperature: None,
            text: None,
            tool_choice: None,
            tools: None,
            top_p: None,
            truncation: None,
            usage: None,
            user: None,
        };

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_json(bad_response))
            .mount(&mock_server)
            .await;

        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);

        let openai = OpenAI::builder()
            .client(client)
            .default_prompt_model("gpt-4.1-mini")
            .use_responses_api(true)
            .build()
            .unwrap();

        let schema = schema_for!(SimpleOutput);
        let err = openai
            .structured_prompt_dyn("Render".into(), schema)
            .await
            .unwrap_err();

        assert!(matches!(err, LanguageModelError::PermanentError(_)));
    }
}