use async_openai::types::{
    ChatCompletionRequestUserMessageArgs, ResponseFormat, ResponseFormatJsonSchema,
};
use async_trait::async_trait;
use schemars::Schema;
use swiftide_core::{
    DynStructuredPrompt, chat_completion::errors::LanguageModelError, prompt::Prompt,
    util::debug_long_utf8,
};

use super::chat_completion::usage_from_counts;
use super::responses_api::{
    build_responses_request_from_prompt_with_schema, response_to_chat_completion,
};
use crate::openai::openai_error_to_language_model_error;

use super::GenericOpenAI;
use anyhow::{Context as _, Result};

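// Dynamic structured prompting for any OpenAI-compatible backend: the JSON schema
// is supplied at runtime as a `schemars::Schema`, so callers behind a trait object
// can request schema-constrained output without a concrete output type.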
#[async_trait]
impl<
    C: async_openai::config::Config
        + std::default::Default
        + Sync
        + Send
        + std::fmt::Debug
        + Clone
        + 'static,
> DynStructuredPrompt for GenericOpenAI<C>
{
    #[tracing::instrument(skip_all, err)]
    #[cfg_attr(
        feature = "langfuse",
        tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION"))
    )]
    async fn structured_prompt_dyn(
        &self,
        prompt: Prompt,
        schema: Schema,
    ) -> Result<serde_json::Value, LanguageModelError> {
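        // When the Responses API is enabled on this client, delegate to it;
        // otherwise fall through to Chat Completions below.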
        if self.is_responses_api_enabled() {
            return self
                .structured_prompt_via_responses_api(prompt, schema)
                .await;
        }

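        // Chat Completions path: a configured prompt model is required.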
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;

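        // Serialize the schema and request strict structured output, so the model
        // may only emit JSON that validates against it.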
        let schema_value =
            serde_json::to_value(&schema).context("Failed to get schema as value")?;
        let response_format = ResponseFormat::JsonSchema {
            json_schema: ResponseFormatJsonSchema {
                description: None,
                name: "structured_prompt".into(),
                schema: Some(schema_value),
                strict: Some(true),
            },
        };

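        // Build the request from the client's defaults, sending the rendered
        // prompt as a single user message.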
        let request = self
            .chat_completion_request_defaults()
            .model(model)
            .response_format(response_format)
            .messages(vec![
                ChatCompletionRequestUserMessageArgs::default()
                    .content(prompt.render()?)
                    .build()
                    .map_err(LanguageModelError::permanent)?
                    .into(),
            ])
            .build()
            .map_err(LanguageModelError::permanent)?;

        tracing::trace!(
            model = &model,
            messages = debug_long_utf8(
                serde_json::to_string_pretty(&request.messages.last())
                    .map_err(LanguageModelError::permanent)?,
                100
            ),
            "[StructuredPrompt] Request to OpenAI"
        );

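        // Send the request and take the structured JSON from the first choice's
        // message content.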
        let response = self
            .client
            .chat()
            .create(request.clone())
            .await
            .map_err(openai_error_to_language_model_error)?;

        let message = response
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .ok_or_else(|| {
                LanguageModelError::PermanentError("Expected content in response".into())
            })?;

        let usage = response.usage.as_ref().map(|usage| {
            usage_from_counts(
                usage.prompt_tokens,
                usage.completion_tokens,
                usage.total_tokens,
            )
        });

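        // Track the completion (model, token usage, request and response) before
        // parsing the payload.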
        self.track_completion(model, usage.as_ref(), Some(&request), Some(&response));

        let parsed = serde_json::from_str(&message)
            .with_context(|| format!("Failed to parse response\n {message}"))?;

        Ok(parsed)
    }
}

impl<
    C: async_openai::config::Config
        + std::default::Default
        + Sync
        + Send
        + std::fmt::Debug
        + Clone
        + 'static,
> GenericOpenAI<C>
{
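    /// Runs a structured prompt through the Responses API instead of Chat
    /// Completions. The response is converted back into the shared
    /// chat-completion shape so usage tracking and parsing are identical on
    /// both paths.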
    async fn structured_prompt_via_responses_api(
        &self,
        prompt: Prompt,
        schema: Schema,
    ) -> Result<serde_json::Value, LanguageModelError> {
        let prompt_text = prompt.render().map_err(LanguageModelError::permanent)?;
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .ok_or_else(|| LanguageModelError::PermanentError("Model not set".into()))?;

        let schema_value = serde_json::to_value(&schema)
            .context("Failed to get schema as value")
            .map_err(LanguageModelError::permanent)?;

        let create_request = build_responses_request_from_prompt_with_schema(
            self,
            prompt_text.clone(),
            schema_value,
        )?;

        let response = self
            .client
            .responses()
            .create(create_request.clone())
            .await
            .map_err(openai_error_to_language_model_error)?;

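        // Convert the Responses API payload into the shared chat-completion shape.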
        let completion = response_to_chat_completion(&response)?;

        let message = completion.message.clone().ok_or_else(|| {
            LanguageModelError::PermanentError("Expected content in response".into())
        })?;

        self.track_completion(
            model,
            completion.usage.as_ref(),
            Some(&create_request),
            Some(&completion),
        );

        let parsed = serde_json::from_str(&message)
            .with_context(|| format!("Failed to parse response\n {message}"))
            .map_err(LanguageModelError::permanent)?;

        Ok(parsed)
    }
}

#[cfg(test)]
mod tests {
    use crate::openai::{self, OpenAI};
    use swiftide_core::StructuredPrompt;

    use super::*;
    use async_openai::Client;
    use async_openai::config::OpenAIConfig;
    use async_openai::types::responses::{
        CompletionTokensDetails, Content, OutputContent, OutputMessage, OutputStatus, OutputText,
        PromptTokensDetails, Response as ResponsesResponse, Role, Status, Usage as ResponsesUsage,
    };
    use schemars::{JsonSchema, schema_for};
    use serde::{Deserialize, Serialize};
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
    struct SimpleOutput {
        answer: String,
    }

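    // Mock server that answers POST /chat/completions with a canned `SimpleOutput`
    // answer ("42"), paired with an OpenAI client pointed at it.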
    async fn setup_client() -> (MockServer, OpenAI) {
        let mock_server = MockServer::start().await;

        let assistant_msg = serde_json::json!({
            "role": "assistant",
            "content": serde_json::to_string(&SimpleOutput {
                answer: "42".to_owned()
            }).unwrap(),
        });

        let body = serde_json::json!({
            "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
            "object": "chat.completion",
            "created": 123,
            "model": "gpt-4.1-2025-04-14",
            "choices": [
                {
                    "index": 0,
                    "message": assistant_msg,
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 19,
                "completion_tokens": 10,
                "total_tokens": 29,
                "prompt_tokens_details": {
                    "cached_tokens": 0,
                    "audio_tokens": 0
                },
                "completion_tokens_details": {
                    "reasoning_tokens": 0,
                    "audio_tokens": 0,
                    "accepted_prediction_tokens": 0,
                    "rejected_prediction_tokens": 0
                }
            },
            "service_tier": "default"
        });

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&mock_server)
            .await;

        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);

        let opts = openai::Options {
            prompt_model: Some("gpt-4".to_string()),
            ..openai::Options::default()
        };
        (
            mock_server,
            OpenAI::builder()
                .client(client)
                .default_options(opts)
                .build()
                .unwrap(),
        )
    }

    #[tokio::test]
    async fn test_structured_prompt_with_wiremock() {
        let (_guard, ai) = setup_client().await;
        let result: serde_json::Value = ai.structured_prompt("test".into()).await.unwrap();
        dbg!(&result);

        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput {
                answer: "42".into()
            }
        );
    }

    #[tokio::test]
    async fn test_structured_prompt_with_wiremock_as_box() {
        let (_guard, ai) = setup_client().await;
        let ai: Box<dyn DynStructuredPrompt> = Box::new(ai);
        let result: serde_json::Value = ai
            .structured_prompt_dyn("test".into(), schema_for!(SimpleOutput))
            .await
            .unwrap();
        dbg!(&result);

        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput {
                answer: "42".into()
            }
        );
    }

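    // Happy path over the Responses API: the mocked /responses endpoint returns
    // valid `SimpleOutput` JSON as its output text.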
    #[test_log::test(tokio::test)]
    async fn test_structured_prompt_via_responses_api() {
        let mock_server = MockServer::start().await;

        let response = ResponsesResponse {
            created_at: 0,
            error: None,
            id: "resp".into(),
            incomplete_details: None,
            instructions: None,
            max_output_tokens: None,
            metadata: None,
            model: "gpt-4.1-mini".into(),
            object: "response".into(),
            output: vec![OutputContent::Message(OutputMessage {
                content: vec![Content::OutputText(OutputText {
                    annotations: Vec::new(),
                    text: serde_json::to_string(&SimpleOutput {
                        answer: "structured".into(),
                    })
                    .unwrap(),
                })],
                id: "msg".into(),
                role: Role::Assistant,
                status: OutputStatus::Completed,
            })],
            output_text: None,
            parallel_tool_calls: None,
            previous_response_id: None,
            reasoning: None,
            store: None,
            service_tier: None,
            status: Status::Completed,
            temperature: None,
            text: None,
            tool_choice: None,
            tools: None,
            top_p: None,
            truncation: None,
            usage: Some(ResponsesUsage {
                input_tokens: 10,
                input_tokens_details: PromptTokensDetails {
                    audio_tokens: Some(0),
                    cached_tokens: Some(0),
                },
                output_tokens: 4,
                output_tokens_details: CompletionTokensDetails {
                    accepted_prediction_tokens: Some(0),
                    audio_tokens: Some(0),
                    reasoning_tokens: Some(0),
                    rejected_prediction_tokens: Some(0),
                },
                total_tokens: 14,
            }),
            user: None,
        };

        let response_body = serde_json::to_value(&response).unwrap();

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_json(response_body))
            .mount(&mock_server)
            .await;

        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);

        let openai = OpenAI::builder()
            .client(client)
            .default_prompt_model("gpt-4.1-mini")
            .use_responses_api(true)
            .build()
            .unwrap();

        let schema = schema_for!(SimpleOutput);
        let result = openai
            .structured_prompt_dyn("Render".into(), schema)
            .await
            .unwrap();

        assert_eq!(
            serde_json::from_value::<SimpleOutput>(result).unwrap(),
            SimpleOutput {
                answer: "structured".into(),
            }
        );
    }

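    // A response whose output text is not valid JSON must surface as a permanent
    // error rather than a retryable one.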
    #[test_log::test(tokio::test)]
    async fn test_structured_prompt_via_responses_api_invalid_json_errors() {
        let mock_server = MockServer::start().await;

        let bad_response = ResponsesResponse {
            created_at: 0,
            error: None,
            id: "resp".into(),
            incomplete_details: None,
            instructions: None,
            max_output_tokens: None,
            metadata: None,
            model: "gpt-4.1-mini".into(),
            object: "response".into(),
            output: vec![OutputContent::Message(OutputMessage {
                content: vec![Content::OutputText(OutputText {
                    annotations: Vec::new(),
                    text: "not json".into(),
                })],
                id: "msg".into(),
                role: Role::Assistant,
                status: OutputStatus::Completed,
            })],
            output_text: Some("not json".into()),
            parallel_tool_calls: None,
            previous_response_id: None,
            reasoning: None,
            store: None,
            service_tier: None,
            status: Status::Completed,
            temperature: None,
            text: None,
            tool_choice: None,
            tools: None,
            top_p: None,
            truncation: None,
            usage: None,
            user: None,
        };

        Mock::given(method("POST"))
            .and(path("/responses"))
            .respond_with(ResponseTemplate::new(200).set_body_json(bad_response))
            .mount(&mock_server)
            .await;

        let config = OpenAIConfig::new().with_api_base(mock_server.uri());
        let client = Client::with_config(config);

        let openai = OpenAI::builder()
            .client(client)
            .default_prompt_model("gpt-4.1-mini")
            .use_responses_api(true)
            .build()
            .unwrap();

        let schema = schema_for!(SimpleOutput);
        let err = openai
            .structured_prompt_dyn("Render".into(), schema)
            .await
            .unwrap_err();

        assert!(matches!(err, LanguageModelError::PermanentError(_)));
    }
}