swiftide_integrations/openai/structured_prompt.rs

use async_openai::types::{
    ChatCompletionRequestUserMessageArgs, ResponseFormat, ResponseFormatJsonSchema,
};
use async_trait::async_trait;
use schemars::Schema;
#[cfg(feature = "metrics")]
use swiftide_core::metrics::emit_usage;
use swiftide_core::{
    DynStructuredPrompt,
    chat_completion::{Usage, errors::LanguageModelError},
    prompt::Prompt,
    util::debug_long_utf8,
};

use crate::openai::openai_error_to_language_model_error;

use super::GenericOpenAI;
use anyhow::{Context as _, Result};

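/// Implements [`DynStructuredPrompt`] by requesting a chat completion whose output is
/// constrained to a caller-supplied JSON schema (OpenAI structured outputs).
///
/// A minimal usage sketch, assuming a client configured with a `prompt_model`
/// (`Answer` is a hypothetical output type, not part of this crate):
///
/// ```no_run
/// # use schemars::{JsonSchema, schema_for};
/// # use swiftide_core::DynStructuredPrompt;
/// #[derive(JsonSchema, serde::Deserialize)]
/// struct Answer {
///     text: String,
/// }
///
/// # async fn example(ai: &dyn DynStructuredPrompt) -> anyhow::Result<()> {
/// let value = ai
///     .structured_prompt_dyn("What is the answer?".into(), schema_for!(Answer))
///     .await?;
/// let answer: Answer = serde_json::from_value(value)?;
/// # Ok(())
/// # }
/// ```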
#[async_trait]
impl<
    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
> DynStructuredPrompt for GenericOpenAI<C>
{
    #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))]
    #[cfg_attr(
        feature = "langfuse",
        tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION"))
    )]
    async fn structured_prompt_dyn(
        &self,
        prompt: Prompt,
        schema: Schema,
    ) -> Result<serde_json::Value, LanguageModelError> {
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;

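        // Convert the `schemars` schema into OpenAI's structured-output response
        // format; `strict: Some(true)` constrains the model to emit JSON matching the
        // schema. On the wire the field serializes to roughly:
        //
        //   "response_format": {
        //     "type": "json_schema",
        //     "json_schema": { "name": "...", "schema": { ... }, "strict": true }
        //   }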
        let schema_value =
            serde_json::to_value(&schema).context("Failed to get schema as value")?;
        let response_format = ResponseFormat::JsonSchema {
            json_schema: ResponseFormatJsonSchema {
                description: None,
                // A schema name is required by the API; it is metadata only and does
                // not affect how the response is parsed.
                name: "structured_prompt".into(),
                schema: Some(schema_value),
                strict: Some(true),
            },
        };

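        // `chat_completion_request_defaults()` is assumed to return a request builder
        // pre-seeded with this client's default options; the model, response format,
        // and rendered prompt are layered on top of it.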
        let request = self
            .chat_completion_request_defaults()
            .model(model)
            .response_format(response_format)
            .messages(vec![
                ChatCompletionRequestUserMessageArgs::default()
                    .content(prompt.render()?)
                    .build()
                    .map_err(LanguageModelError::permanent)?
                    .into(),
            ])
            .build()
            .map_err(LanguageModelError::permanent)?;

        // Trace a truncated (100-character) preview of the outgoing message.
        tracing::trace!(
            model = &model,
            messages = debug_long_utf8(
                serde_json::to_string_pretty(&request.messages.last())
                    .map_err(LanguageModelError::permanent)?,
                100
            ),
            "[StructuredPrompt] Request to OpenAI"
        );

        let mut response = self
            .client
            .chat()
            .create(request.clone())
            .await
            .map_err(openai_error_to_language_model_error)?;

        if cfg!(feature = "langfuse") {
            let usage = response.usage.clone().unwrap_or_default();
            tracing::debug!(
                langfuse.model = model,
                langfuse.input = %serde_json::to_string_pretty(&request).unwrap_or_default(),
                langfuse.output = %serde_json::to_string_pretty(&response).unwrap_or_default(),
                langfuse.usage = %serde_json::to_string_pretty(&usage).unwrap_or_default(),
            );
        }

        // Take the first choice (with the default `n = 1` there is exactly one) and
        // pull out its content.
        let message = response
            .choices
            .remove(0)
            .message
            .content
            .take()
            .ok_or_else(|| {
                LanguageModelError::PermanentError("Expected content in response".into())
            })?;

        if let Some(usage) = response.usage.as_ref() {
            if let Some(callback) = &self.on_usage {
                let usage = Usage {
                    prompt_tokens: usage.prompt_tokens,
                    completion_tokens: usage.completion_tokens,
                    total_tokens: usage.total_tokens,
                };
                callback(&usage).await?;
            }
            #[cfg(feature = "metrics")]
            emit_usage(
                model,
                usage.prompt_tokens.into(),
                usage.completion_tokens.into(),
                usage.total_tokens.into(),
                self.metric_metadata.as_ref(),
            );
        } else {
            tracing::warn!("No usage data found in response");
        }

        let parsed = serde_json::from_str(&message)
            .with_context(|| format!("Failed to parse response\n {message}"))?;

        Ok(parsed)
    }
}
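
// The typed `StructuredPrompt::structured_prompt` exercised in the tests below comes
// from `swiftide_core`, presumably via a blanket implementation over
// `DynStructuredPrompt`; this module only supplies the dynamic entry point.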

#[cfg(test)]
mod tests {
    use crate::openai::{self, OpenAI};
    use swiftide_core::StructuredPrompt;

    use super::*;
    use async_openai::Client;
    use async_openai::config::OpenAIConfig;
    use schemars::{JsonSchema, schema_for};
    use serde::{Deserialize, Serialize};
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
    struct SimpleOutput {
        answer: String,
    }

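    // `schema_for!(SimpleOutput)` expands to a JSON schema roughly like
    // `{"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}`.
    // The mock below matches only on method and path, so these tests exercise request
    // plumbing and response parsing, not schema enforcement.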
    async fn setup_client() -> (MockServer, OpenAI) {
        let mock_server = MockServer::start().await;

        // Canned assistant message whose content is the JSON-encoded structured output.
        let assistant_msg = serde_json::json!({
            "role": "assistant",
            "content": serde_json::to_string(&SimpleOutput {
                answer: "42".to_owned()
            }).unwrap(),
        });

        let body = serde_json::json!({
            "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
            "object": "chat.completion",
            "created": 123,
            "model": "gpt-4.1-2025-04-14",
            "choices": [
                {
                    "index": 0,
                    "message": assistant_msg,
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 19,
                "completion_tokens": 10,
                "total_tokens": 29,
                "prompt_tokens_details": {
                    "cached_tokens": 0,
                    "audio_tokens": 0
                },
                "completion_tokens_details": {
                    "reasoning_tokens": 0,
                    "audio_tokens": 0,
                    "accepted_prediction_tokens": 0,
                    "rejected_prediction_tokens": 0
                }
            },
            "service_tier": "default"
        });

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&mock_server)
            .await;

        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);

        // Only the prompt model matters here; everything else stays default.
        let opts = openai::Options {
            prompt_model: Some("gpt-4".to_string()),
            ..openai::Options::default()
        };
        (
            mock_server,
            OpenAI::builder()
                .client(client)
                .default_options(opts)
                .build()
                .unwrap(),
        )
    }

    #[tokio::test]
    async fn test_structured_prompt_with_wiremock() {
        let (_guard, ai) = setup_client().await;
        let result: serde_json::Value = ai.structured_prompt("test".into()).await.unwrap();
        dbg!(&result);

        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput {
                answer: "42".into()
            }
        );
    }

    #[tokio::test]
    async fn test_structured_prompt_with_wiremock_as_box() {
        let (_guard, ai) = setup_client().await;
        let ai: Box<dyn DynStructuredPrompt> = Box::new(ai);
        let result: serde_json::Value = ai
            .structured_prompt_dyn("test".into(), schema_for!(SimpleOutput))
            .await
            .unwrap();
        dbg!(&result);

        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput {
                answer: "42".into()
            }
        );
    }
}