swiftide_integrations/openai/structured_prompt.rs

//! This module provides an implementation of the `StructuredPrompt` trait for the `OpenAI` struct.
//!
//! Unlike the other traits, `StructuredPrompt` is *not* dyn safe.
//!
//! Use `DynStructuredPrompt` if you need dyn dispatch. For custom implementations, if you
//! implement `DynStructuredPrompt`, you get `StructuredPrompt` for free.
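//!
//! A minimal usage sketch (builder and model names follow the crate's examples and the tests
//! at the bottom of this file; adjust to your own setup):
//!
//! ```no_run
//! use schemars::JsonSchema;
//! use serde::Deserialize;
//! use swiftide_core::StructuredPrompt;
//!
//! #[derive(Debug, Deserialize, JsonSchema)]
//! struct Answer {
//!     answer: String,
//! }
//!
//! # async fn example() -> anyhow::Result<()> {
//! let openai = swiftide_integrations::openai::OpenAI::builder()
//!     .default_prompt_model("gpt-4o")
//!     .build()?;
//!
//! // The typed call is provided by the blanket `StructuredPrompt` impl over
//! // `DynStructuredPrompt`; the schema is derived from `Answer`.
//! let answer: Answer = openai.structured_prompt("What is 6 * 7?".into()).await?;
//! # Ok(())
//! # }
//! ```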

use async_openai::types::{
    ChatCompletionRequestUserMessageArgs, ResponseFormat, ResponseFormatJsonSchema,
};
use async_trait::async_trait;
use schemars::Schema;
#[cfg(feature = "metrics")]
use swiftide_core::metrics::emit_usage;
use swiftide_core::{
    DynStructuredPrompt,
    chat_completion::{Usage, errors::LanguageModelError},
    prompt::Prompt,
    util::debug_long_utf8,
};

use crate::openai::openai_error_to_language_model_error;

use super::GenericOpenAI;
use anyhow::{Context as _, Result};

/// Implements `DynStructuredPrompt` for `GenericOpenAI`: sends a prompt to the model and
/// returns a response constrained to a JSON schema.
#[async_trait]
impl<
    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
> DynStructuredPrompt for GenericOpenAI<C>
{
    /// Sends a prompt to the `OpenAI` API and returns the structured response.
    ///
    /// # Parameters
    /// - `prompt`: The prompt to be sent to the `OpenAI` API.
    /// - `schema`: The JSON schema the model's response must conform to.
    ///
    /// # Returns
    /// - `Result<serde_json::Value, LanguageModelError>`: On success, returns the response parsed
    ///   as a `serde_json::Value`. On failure, returns a `LanguageModelError`.
    ///
    /// # Errors
    /// - Returns an error if the model is not set in the default options.
    /// - Returns an error if the request to the `OpenAI` API fails.
    /// - Returns an error if the response does not contain the expected content.
    #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))]
    #[cfg_attr(
        feature = "langfuse",
        tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION"))
    )]
    async fn structured_prompt_dyn(
        &self,
        prompt: Prompt,
        schema: Schema,
    ) -> Result<serde_json::Value, LanguageModelError> {
        // Retrieve the model from the default options, returning an error if not set.
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;

        // Wrap the derived schema in OpenAI's structured-output response format. The `name`
        // is an arbitrary label required by the API; `strict` asks the model to follow the
        // schema exactly.
        let schema_value =
            serde_json::to_value(&schema).context("Failed to get schema as value")?;
        let response_format = ResponseFormat::JsonSchema {
            json_schema: ResponseFormatJsonSchema {
                description: None,
                name: "structured_output".into(),
                schema: Some(schema_value),
                strict: Some(true),
            },
        };
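
        // For reference, with the OpenAI chat completions API this `response_format`
        // serializes roughly as (sketch, schema body elided):
        //
        //   "response_format": {
        //     "type": "json_schema",
        //     "json_schema": { "name": "structured_output", "schema": { ... }, "strict": true }
        //   }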

        // Build the request to be sent to the OpenAI API.
        let request = self
            .chat_completion_request_defaults()
            .model(model)
            .response_format(response_format)
            .messages(vec![
                ChatCompletionRequestUserMessageArgs::default()
                    .content(prompt.render()?)
                    .build()
                    .map_err(LanguageModelError::permanent)?
                    .into(),
            ])
            .build()
            .map_err(LanguageModelError::permanent)?;

        // Log the request for debugging purposes.
        tracing::trace!(
            model = &model,
            messages = debug_long_utf8(
                serde_json::to_string_pretty(&request.messages.last())
                    .map_err(LanguageModelError::permanent)?,
                100
            ),
            "[StructuredPrompt] Request to openai"
        );

        // Send the request to the OpenAI API and await the response.
        let mut response = self
            .client
            .chat()
            .create(request.clone())
            .await
            .map_err(openai_error_to_language_model_error)?;

        if cfg!(feature = "langfuse") {
            let usage = response.usage.clone().unwrap_or_default();
            tracing::debug!(
                langfuse.model = model,
                langfuse.input = %serde_json::to_string_pretty(&request).unwrap_or_default(),
                langfuse.output = %serde_json::to_string_pretty(&response).unwrap_or_default(),
                langfuse.usage = %serde_json::to_string_pretty(&usage).unwrap_or_default(),
            );
        }

        // Take the content of the first choice; `first_mut` avoids the panic that
        // `remove(0)` would cause on an empty `choices` list.
        let message = response
            .choices
            .first_mut()
            .and_then(|choice| choice.message.content.take())
            .ok_or_else(|| {
                LanguageModelError::PermanentError("Expected content in response".into())
            })?;

        if let Some(usage) = response.usage.as_ref() {
            if let Some(callback) = &self.on_usage {
                let usage = Usage {
                    prompt_tokens: usage.prompt_tokens,
                    completion_tokens: usage.completion_tokens,
                    total_tokens: usage.total_tokens,
                };
                callback(&usage).await?;
            }
            #[cfg(feature = "metrics")]
            emit_usage(
                model,
                usage.prompt_tokens.into(),
                usage.completion_tokens.into(),
                usage.total_tokens.into(),
                self.metric_metadata.as_ref(),
            );
        } else {
            tracing::warn!("No usage data found in response");
        }

        // Parse the response content as JSON, returning an error if parsing fails.
        let parsed = serde_json::from_str(&message)
            .with_context(|| format!("Failed to parse response\n {message}"))?;

        Ok(parsed)
    }
}

#[cfg(test)]
mod tests {
    use crate::openai::{self, OpenAI};
    use swiftide_core::StructuredPrompt;

    use super::*;
    use async_openai::Client;
    use async_openai::config::OpenAIConfig;
    use schemars::{JsonSchema, schema_for};
    use serde::{Deserialize, Serialize};
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
    struct SimpleOutput {
        answer: String,
    }
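
    // For reference, `schema_for!(SimpleOutput)` produces a JSON schema along these lines
    // (sketch; the exact output depends on the schemars version):
    //
    //   {
    //     "$schema": "https://json-schema.org/draft/2020-12/schema",
    //     "title": "SimpleOutput",
    //     "type": "object",
    //     "properties": { "answer": { "type": "string" } },
    //     "required": ["answer"]
    //   }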

    async fn setup_client() -> (MockServer, OpenAI) {
        // Start the Wiremock server
        let mock_server = MockServer::start().await;

        // Prepare the response the mock should return
        let assistant_msg = serde_json::json!({
            "role": "assistant",
            "content": serde_json::to_string(&SimpleOutput {
                answer: "42".to_owned()
            }).unwrap(),
        });

        let body = serde_json::json!({
          "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
          "object": "chat.completion",
          "created": 123,
          "model": "gpt-4.1-2025-04-14",
          "choices": [
            {
              "index": 0,
              "message": assistant_msg,
              "logprobs": null,
              "finish_reason": "stop"
            }
          ],
          "usage": {
            "prompt_tokens": 19,
            "completion_tokens": 10,
            "total_tokens": 29,
            "prompt_tokens_details": {
              "cached_tokens": 0,
              "audio_tokens": 0
            },
            "completion_tokens_details": {
              "reasoning_tokens": 0,
              "audio_tokens": 0,
              "accepted_prediction_tokens": 0,
              "rejected_prediction_tokens": 0
            }
          },
          "service_tier": "default"
        });

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&mock_server)
            .await;

        // Point our client at the mock server
        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);

        // Construct the GenericOpenAI instance
        let opts = openai::Options {
            prompt_model: Some("gpt-4".to_string()),
            ..openai::Options::default()
        };
        (
            mock_server,
            OpenAI::builder()
                .client(client)
                .default_options(opts)
                .build()
                .unwrap(),
        )
    }

    #[tokio::test]
    async fn test_structured_prompt_with_wiremock() {
        let (_guard, ai) = setup_client().await;
        // Call structured_prompt
        let result: serde_json::Value = ai.structured_prompt("test".into()).await.unwrap();
        dbg!(&result);

        // Assert
        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput {
                answer: "42".into()
            }
        );
    }

    #[tokio::test]
    async fn test_structured_prompt_with_wiremock_as_box() {
        let (_guard, ai) = setup_client().await;
        // Box the client and call structured_prompt_dyn to exercise dyn dispatch
        let ai: Box<dyn DynStructuredPrompt> = Box::new(ai);
        let result: serde_json::Value = ai
            .structured_prompt_dyn("test".into(), schema_for!(SimpleOutput))
            .await
            .unwrap();
        dbg!(&result);

        // Assert
        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput {
                answer: "42".into()
            }
        );
    }
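
    // A sketch of an error-path test: a mocked response with no choices should surface an
    // error rather than a panic. The mock body shape mirrors `setup_client` above.
    #[tokio::test]
    async fn test_structured_prompt_errors_without_choices() {
        let mock_server = MockServer::start().await;

        let body = serde_json::json!({
            "id": "chatcmpl-empty",
            "object": "chat.completion",
            "created": 123,
            "model": "gpt-4",
            "choices": [],
            "usage": {
                "prompt_tokens": 1,
                "completion_tokens": 0,
                "total_tokens": 1
            }
        });

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&mock_server)
            .await;

        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);
        let ai = OpenAI::builder()
            .client(client)
            .default_options(openai::Options {
                prompt_model: Some("gpt-4".to_string()),
                ..openai::Options::default()
            })
            .build()
            .unwrap();

        let result: Result<serde_json::Value, _> = ai.structured_prompt("test".into()).await;
        assert!(result.is_err());
    }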
}