use std::sync::Arc;
use std::sync::Mutex;

use anyhow::{Context as _, Result};
use async_openai::types::ChatCompletionStreamOptions;
use async_openai::types::{
    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
    ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
    ChatCompletionToolType, FunctionCall, FunctionObjectArgs,
};
use async_trait::async_trait;
use futures_util::StreamExt as _;
use futures_util::stream;
use itertools::Itertools;
use serde_json::json;
use swiftide_core::ChatCompletionStream;
use swiftide_core::chat_completion::UsageBuilder;
use swiftide_core::chat_completion::{
    ChatCompletion, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolSpec,
    errors::LanguageModelError,
};
#[cfg(feature = "metrics")]
use swiftide_core::metrics::emit_usage;

use super::GenericOpenAI;
use super::openai_error_to_language_model_error;

#[async_trait]
impl<
    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
> ChatCompletion for GenericOpenAI<C>
{
    #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))]
    #[cfg_attr(
        feature = "langfuse",
        tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION"))
    )]
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .context("Model not set")?;

        let messages = request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()?;

        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(model)
            .messages(messages)
            .to_owned();

        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()?,
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = openai_request
            .build()
            .map_err(openai_error_to_language_model_error)?;

        tracing::trace!(model, ?request, "Sending request to OpenAI");

        let response = self
            .client
            .chat()
            .create(request.clone())
            .await
            .map_err(openai_error_to_language_model_error)?;

        if cfg!(feature = "langfuse") {
            let usage = response.usage.clone().unwrap_or_default();
            tracing::debug!(
                langfuse.model = model,
                langfuse.input = %serde_json::to_string_pretty(&request).unwrap_or_default(),
                langfuse.output = %serde_json::to_string_pretty(&response).unwrap_or_default(),
                langfuse.usage = %serde_json::to_string_pretty(&usage).unwrap_or_default(),
            );
        }

        tracing::trace!(?response, "[ChatCompletion] Full response from OpenAI");

        let mut builder = ChatCompletionResponse::builder()
            .maybe_message(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.content.clone()),
            )
            .maybe_tool_calls(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.tool_calls.clone())
                    .map(|tool_calls| {
                        tool_calls
                            .iter()
                            .map(|tool_call| {
                                ToolCall::builder()
                                    .id(tool_call.id.clone())
                                    .args(tool_call.function.arguments.clone())
                                    .name(tool_call.function.name.clone())
                                    .build()
                                    .expect("infallible")
                            })
                            .collect_vec()
                    }),
            )
            .to_owned();

        if let Some(usage) = &response.usage {
            let usage = UsageBuilder::default()
                .prompt_tokens(usage.prompt_tokens)
                .completion_tokens(usage.completion_tokens)
                .total_tokens(usage.total_tokens)
                .build()
                .map_err(LanguageModelError::permanent)?;

            if let Some(callback) = &self.on_usage {
                callback(&usage).await?;
            }

            builder.usage(usage);

            #[cfg(feature = "metrics")]
            {
                if let Some(usage) = response.usage.as_ref() {
                    emit_usage(
                        model,
                        usage.prompt_tokens.into(),
                        usage.completion_tokens.into(),
                        usage.total_tokens.into(),
                        self.metric_metadata.as_ref(),
                    );
                }
            }
        }

        builder.build().map_err(LanguageModelError::from)
    }

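    /// Streams a chat completion, yielding a [`ChatCompletionResponse`] per
    /// chunk. Deltas are accumulated so that the final item of the stream
    /// carries the full message, tool calls, and usage; when `stream_full` is
    /// set, every intermediate item carries the accumulated state as well.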
    #[tracing::instrument(skip_all)]
    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        let Some(model) = self.default_options.prompt_model.clone() else {
            return LanguageModelError::permanent("Model not set").into();
        };

        let messages = match request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()
        {
            Ok(messages) => messages,
            Err(e) => return LanguageModelError::from(e).into(),
        };

        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(&model)
            .messages(messages)
            .stream_options(ChatCompletionStreamOptions {
                include_usage: true,
            })
            .to_owned();

        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    match request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()
                    {
                        Ok(tools) => tools,
                        Err(e) => {
                            return LanguageModelError::from(e).into();
                        }
                    },
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = match openai_request.build() {
            Ok(request) => request,
            Err(e) => {
                return openai_error_to_language_model_error(e).into();
            }
        };

        tracing::trace!(model, ?request, "Sending request to OpenAI");

        let response = match self.client.chat().create_stream(request.clone()).await {
            Ok(response) => response,
            Err(e) => return openai_error_to_language_model_error(e).into(),
        };

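        // Deltas are folded into a shared response behind an `Arc<Mutex<_>>`:
        // the `map` below updates it per chunk, and a single-element stream
        // chained at the end emits the fully accumulated response (and fires
        // the usage callback/metrics) once the provider stream is exhausted.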
        let accumulating_response = Arc::new(Mutex::new(ChatCompletionResponse::default()));
        let final_response = accumulating_response.clone();
        let stream_full = self.stream_full;

        #[cfg(feature = "metrics")]
        let metric_metadata = self.metric_metadata.clone();

        let span = if cfg!(feature = "langfuse") {
            tracing::info_span!("stream", langfuse.type = "GENERATION")
        } else {
            tracing::info_span!("stream")
        };

        let maybe_usage_callback = self.on_usage.clone();
        let stream = response
            .map(move |chunk| match chunk {
                Ok(chunk) => {
                    let accumulating_response = Arc::clone(&accumulating_response);

                    let delta_message = chunk
                        .choices
                        .first()
                        .and_then(|d| d.delta.content.as_deref());
                    let delta_tool_calls = chunk
                        .choices
                        .first()
                        .and_then(|d| d.delta.tool_calls.as_deref());
                    let usage = chunk.usage.as_ref();

                    let chat_completion_response = {
                        let mut lock = accumulating_response.lock().unwrap();
                        lock.append_message_delta(delta_message);

                        if let Some(delta_tool_calls) = delta_tool_calls {
                            for tc in delta_tool_calls {
                                lock.append_tool_call_delta(
                                    tc.index as usize,
                                    tc.id.as_deref(),
                                    tc.function.as_ref().and_then(|f| f.name.as_deref()),
                                    tc.function.as_ref().and_then(|f| f.arguments.as_deref()),
                                );
                            }
                        }

                        if let Some(usage) = usage {
                            lock.append_usage_delta(
                                usage.prompt_tokens,
                                usage.completion_tokens,
                                usage.total_tokens,
                            );
                        }

                        if stream_full {
                            lock.clone()
                        } else {
                            ChatCompletionResponse {
                                id: lock.id,
                                message: None,
                                tool_calls: None,
                                usage: None,
                                delta: lock.delta.clone(),
                            }
                        }
                    };

                    Ok(chat_completion_response)
                }
                Err(e) => Err(openai_error_to_language_model_error(e)),
            })
            .chain(
                stream::iter(vec![final_response]).map(move |accumulating_response| {
                    let lock = accumulating_response.lock().unwrap();

                    let usage = lock.usage.clone().unwrap_or_default();

                    if cfg!(feature = "langfuse") {
                        tracing::debug!(
                            langfuse.model = model,
                            langfuse.input = %serde_json::to_string_pretty(&request).unwrap_or_default(),
                            langfuse.output = %serde_json::to_string_pretty(&*lock).unwrap_or_default(),
                            langfuse.usage = %serde_json::to_string_pretty(&usage).unwrap_or_default(),
                        );
                    }

                    tracing::debug!(
                        usage = format!(
                            "{}/{}/{}",
                            lock.usage.as_ref().map_or(0, |u| u.prompt_tokens),
                            lock.usage.as_ref().map_or(0, |u| u.completion_tokens),
                            lock.usage.as_ref().map_or(0, |u| u.total_tokens)
                        ),
                        model = &model,
                        has_message = lock.message.is_some(),
                        num_tool_calls = lock.tool_calls.as_ref().map_or(0, std::vec::Vec::len),
                        "[ChatCompletion/Streaming] Response from OpenAI"
                    );

                    #[allow(clippy::collapsible_if)]
                    if let Some(usage) = lock.usage.as_ref() {
                        if let Some(callback) = maybe_usage_callback.as_ref() {
                            let usage = usage.clone();
                            let callback = callback.clone();

                            tokio::spawn(async move {
                                if let Err(e) = callback(&usage).await {
                                    tracing::error!("Error in on_usage callback: {}", e);
                                }
                            });
                        }

                        #[cfg(feature = "metrics")]
                        {
                            emit_usage(
                                &model,
                                usage.prompt_tokens.into(),
                                usage.completion_tokens.into(),
                                usage.total_tokens.into(),
                                metric_metadata.as_ref(),
                            );
                        }
                    }

                    Ok(lock.clone())
                }),
            );

        let stream = tracing_futures::Instrument::instrument(stream, span);

        Box::pin(stream)
    }
}

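/// Converts a [`ToolSpec`] into an OpenAI function-calling tool definition
/// with a strict JSON schema. For example, a spec with a single required
/// `query` string parameter would produce roughly:
///
/// ```json
/// {
///   "type": "object",
///   "properties": { "query": { "type": "string", "description": "..." } },
///   "required": ["query"],
///   "additionalProperties": false
/// }
/// ```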
fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
    let mut properties = serde_json::Map::new();

    for param in &spec.parameters {
        properties.insert(
            param.name.to_string(),
            json!({
                "type": param.ty.as_ref(),
                "description": &param.description,
            }),
        );
    }

    ChatCompletionToolArgs::default()
        .r#type(ChatCompletionToolType::Function)
        .function(
            FunctionObjectArgs::default()
                .name(&spec.name)
                .description(&spec.description)
                .strict(true)
                .parameters(json!({
                    "type": "object",
                    "properties": properties,
                    "required": spec
                        .parameters
                        .iter()
                        .filter(|param| param.required)
                        .map(|param| &param.name)
                        .collect_vec(),
                    "additionalProperties": false,
                }))
                .build()?,
        )
        .build()
        .map_err(anyhow::Error::from)
}

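/// Maps a swiftide [`ChatMessage`] onto the corresponding `async_openai`
/// request message. `Summary` messages are sent as assistant messages, and
/// tool outputs without content become bare tool messages carrying only the
/// tool call id.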
fn message_to_openai(
    message: &ChatMessage,
) -> Result<async_openai::types::ChatCompletionRequestMessage> {
    let openai_message = match message {
        ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::ToolOutput(tool_call, tool_output) => {
            let Some(content) = tool_output.content() else {
                return Ok(ChatCompletionRequestToolMessageArgs::default()
                    .tool_call_id(tool_call.id())
                    .build()?
                    .into());
            };

            ChatCompletionRequestToolMessageArgs::default()
                .content(content)
                .tool_call_id(tool_call.id())
                .build()?
                .into()
        }
        ChatMessage::Assistant(msg, tool_calls) => {
            let mut builder = ChatCompletionRequestAssistantMessageArgs::default();

            if let Some(msg) = msg {
                builder.content(msg.as_str());
            }

            if let Some(tool_calls) = tool_calls {
                builder.tool_calls(
                    tool_calls
                        .iter()
                        .map(|tool_call| ChatCompletionMessageToolCall {
                            id: tool_call.id().to_string(),
                            r#type: ChatCompletionToolType::Function,
                            function: FunctionCall {
                                name: tool_call.name().to_string(),
                                arguments: tool_call.args().unwrap_or_default().to_string(),
                            },
                        })
                        .collect::<Vec<_>>(),
                );
            }

            builder.build()?.into()
        }
    };

    Ok(openai_message)
}

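// Tests stub the OpenAI HTTP API with `wiremock`, so no network access or
// API key is required.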
#[cfg(test)]
mod tests {
    use crate::openai::{OpenAI, Options};

    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[test_log::test(tokio::test)]
    async fn test_complete() {
        let mock_server = MockServer::start().await;

        let response_body = json!({
            "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
            "object": "chat.completion",
            "created": 123,
            "model": "gpt-4o",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Hello, world!",
                        "refusal": null,
                        "annotations": []
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 19,
                "completion_tokens": 10,
                "total_tokens": 29,
                "prompt_tokens_details": {
                    "cached_tokens": 0,
                    "audio_tokens": 0
                },
                "completion_tokens_details": {
                    "reasoning_tokens": 0,
                    "audio_tokens": 0,
                    "accepted_prediction_tokens": 0,
                    "rejected_prediction_tokens": 0
                }
            },
            "service_tier": "default"
        });
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(response_body))
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4o")
            .build()
            .expect("Can create OpenAI client.");

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("Hi".to_string())])
            .build()
            .unwrap();

        let response = openai.complete(&request).await.unwrap();

        assert_eq!(response.message(), Some("Hello, world!"));

        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 19);
        assert_eq!(usage.completion_tokens, 10);
        assert_eq!(usage.total_tokens, 29);
    }

    #[test_log::test(tokio::test)]
    #[allow(clippy::items_after_statements)]
    async fn test_complete_with_all_default_settings() {
        use serde_json::Value;
        use wiremock::{Request, Respond, ResponseTemplate};

        let mock_server = wiremock::MockServer::start().await;

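        // Responder that asserts every configured default option was
        // serialized into the outgoing request body before replying with a
        // canned completion.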
        struct ValidateAllSettings;

        impl Respond for ValidateAllSettings {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let v: Value = serde_json::from_slice(&request.body).unwrap();

                assert_eq!(v["model"], "gpt-4-turbo");

                let arr = v["messages"].as_array().unwrap();
                assert_eq!(arr.len(), 1);
                assert_eq!(arr[0]["content"], "Test");

                assert_eq!(v["parallel_tool_calls"], true);
                assert_eq!(v["max_completion_tokens"], 77);
                assert!((v["temperature"].as_f64().unwrap() - 0.42).abs() < 1e-5);
                assert_eq!(v["reasoning_effort"], "low");
                assert_eq!(v["seed"], 42);
                assert!((v["presence_penalty"].as_f64().unwrap() - 1.1).abs() < 1e-5);

                assert_eq!(v["metadata"], serde_json::json!({"key": "value"}));
                assert_eq!(v["user"], "test-user");

                ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "id": "chatcmpl-xxx",
                    "object": "chat.completion",
                    "created": 123,
                    "model": "gpt-4-turbo",
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "All settings validated",
                            "refusal": null,
                            "annotations": []
                        },
                        "logprobs": null,
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 19,
                        "completion_tokens": 10,
                        "total_tokens": 29,
                        "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
                        "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}
                    },
                    "service_tier": "default"
                }))
            }
        }

        wiremock::Mock::given(wiremock::matchers::method("POST"))
            .and(wiremock::matchers::path("/chat/completions"))
            .respond_with(ValidateAllSettings)
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = crate::openai::OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4-turbo")
            .default_embed_model("not-used")
            .parallel_tool_calls(Some(true))
            .default_options(
                Options::builder()
                    .max_completion_tokens(77)
                    .temperature(0.42)
                    .reasoning_effort(async_openai::types::ReasoningEffort::Low)
                    .seed(42)
                    .presence_penalty(1.1)
                    .metadata(serde_json::json!({"key": "value"}))
                    .user("test-user"),
            )
            .build()
            .expect("Can create OpenAI client.");

        let request = swiftide_core::chat_completion::ChatCompletionRequest::builder()
            .messages(vec![swiftide_core::chat_completion::ChatMessage::User(
                "Test".to_string(),
            )])
            .build()
            .unwrap();

        let response = openai.complete(&request).await.unwrap();

        assert_eq!(response.message(), Some("All settings validated"));
    }
}