swiftide_integrations/openai/chat_completion.rs

use std::sync::Arc;
use std::sync::Mutex;

use anyhow::{Context as _, Result};
use async_openai::types::{
    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
    ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
    ChatCompletionToolType, FunctionCall, FunctionObjectArgs,
};
use async_trait::async_trait;
use futures_util::StreamExt as _;
use futures_util::stream;
use itertools::Itertools;
use serde_json::json;
use swiftide_core::ChatCompletionStream;
use swiftide_core::chat_completion::UsageBuilder;
use swiftide_core::chat_completion::{
    ChatCompletion, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolSpec,
    errors::LanguageModelError,
};
#[cfg(feature = "metrics")]
use swiftide_core::metrics::emit_usage;

use super::GenericOpenAI;
use super::openai_error_to_language_model_error;

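// `ChatCompletion` implementation for any `GenericOpenAI` client, covering both
// one-shot completion (`complete`) and streaming (`complete_stream`). A minimal
// usage sketch, assuming a client configured as in the tests below:
//
//     let request = ChatCompletionRequest::builder()
//         .messages(vec![ChatMessage::User("Hi".to_string())])
//         .build()?;
//     let response = openai.complete(&request).await?;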
#[async_trait]
impl<
    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
> ChatCompletion for GenericOpenAI<C>
{
    #[tracing::instrument(skip_all)]
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .context("Model not set")?;

        let messages = request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()?;

        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(model)
            .messages(messages)
            .to_owned();

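        // Attach tool definitions when the request carries tool specs; tool
        // selection is left to the model ("auto").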
        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()?,
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = openai_request
            .build()
            .map_err(openai_error_to_language_model_error)?;

        tracing::debug!(model, ?request, "Sending request to OpenAI");

        let response = self
            .client
            .chat()
            .create(request)
            .await
            .map_err(openai_error_to_language_model_error)?;

        tracing::debug!(?response, "Received response from OpenAI");

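        // Only the first choice is inspected; requests go out with the default
        // `n`, so a single choice is expected.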
        let mut builder = ChatCompletionResponse::builder()
            .maybe_message(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.content.clone()),
            )
            .maybe_tool_calls(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.tool_calls.clone())
                    .map(|tool_calls| {
                        tool_calls
                            .iter()
                            .map(|tool_call| {
                                ToolCall::builder()
                                    .id(tool_call.id.clone())
                                    .args(tool_call.function.arguments.clone())
                                    .name(tool_call.function.name.clone())
                                    .build()
                                    .expect("infallible")
                            })
                            .collect_vec()
                    }),
            )
            .to_owned();

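        // Record token usage on the response; with the `metrics` feature
        // enabled, the same counts are also emitted as a metric.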
        if let Some(usage) = &response.usage {
            #[cfg(feature = "metrics")]
            emit_usage(
                model,
                usage.prompt_tokens.into(),
                usage.completion_tokens.into(),
                usage.total_tokens.into(),
                self.metric_metadata.as_ref(),
            );

            let usage = UsageBuilder::default()
                .prompt_tokens(usage.prompt_tokens)
                .completion_tokens(usage.completion_tokens)
                .total_tokens(usage.total_tokens)
                .build()
                .map_err(LanguageModelError::permanent)?;

            builder.usage(usage);
        }

        builder.build().map_err(LanguageModelError::from)
    }

    #[tracing::instrument(skip_all)]
    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        let Some(model) = self.default_options.prompt_model.as_ref() else {
            return LanguageModelError::permanent("Model not set").into();
        };

        let messages = match request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()
        {
            Ok(messages) => messages,
            Err(e) => return LanguageModelError::from(e).into(),
        };

        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(model)
            .messages(messages)
            .to_owned();

        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    match request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()
                    {
                        Ok(tools) => tools,
                        Err(e) => {
                            return LanguageModelError::from(e).into();
                        }
                    },
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = match openai_request.build() {
            Ok(request) => request,
            Err(e) => {
                return openai_error_to_language_model_error(e).into();
            }
        };

        tracing::debug!(model, ?request, "Sending request to OpenAI");

        let response = match self.client.chat().create_stream(request).await {
            Ok(response) => response,
            Err(e) => return openai_error_to_language_model_error(e).into(),
        };

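        // Each delta is merged into a shared accumulator so that the final,
        // chained element (below) can yield the complete response.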
        let accumulating_response = Arc::new(Mutex::new(ChatCompletionResponse::default()));
        let final_response = accumulating_response.clone();
        let stream_full = self.stream_full;

        #[cfg(feature = "metrics")]
        let metric_metadata = self.metric_metadata.clone();
        #[cfg(feature = "metrics")]
        let model = model.to_string();

        response
            .map(move |chunk| match chunk {
                Ok(chunk) => {
                    let accumulating_response = Arc::clone(&accumulating_response);

                    // A usage-only chunk can arrive with an empty `choices`
                    // array, so avoid indexing into it directly.
                    let delta_message = chunk
                        .choices
                        .first()
                        .and_then(|choice| choice.delta.content.as_deref());
                    let delta_tool_calls = chunk
                        .choices
                        .first()
                        .and_then(|choice| choice.delta.tool_calls.as_deref());
                    let usage = chunk.usage.as_ref();

                    let chat_completion_response = {
                        let mut lock = accumulating_response.lock().unwrap();
                        lock.append_message_delta(delta_message);

                        if let Some(delta_tool_calls) = delta_tool_calls {
                            for tc in delta_tool_calls {
                                lock.append_tool_call_delta(
                                    tc.index as usize,
                                    tc.id.as_deref(),
                                    tc.function.as_ref().and_then(|f| f.name.as_deref()),
                                    tc.function.as_ref().and_then(|f| f.arguments.as_deref()),
                                );
                            }
                        }

                        if let Some(usage) = usage {
                            lock.append_usage_delta(
                                usage.prompt_tokens,
                                usage.completion_tokens,
                                usage.total_tokens,
                            );
                        }

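                        // `stream_full` controls whether each yielded item is a
                        // snapshot of the accumulated response or only the delta.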
                        if stream_full {
                            lock.clone()
                        } else {
                            ChatCompletionResponse {
                                id: lock.id,
                                message: None,
                                tool_calls: None,
                                usage: None,
                                delta: lock.delta.clone(),
                            }
                        }
                    };

                    Ok(chat_completion_response)
                }
                Err(e) => Err(openai_error_to_language_model_error(e)),
            })
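            // A single trailing element yields the fully accumulated response
            // (and emits usage metrics when the `metrics` feature is enabled).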
            .chain(
                stream::iter(vec![final_response]).map(move |accumulating_response| {
                    let lock = accumulating_response.lock().unwrap();

                    #[cfg(feature = "metrics")]
                    {
                        if let Some(usage) = lock.usage.as_ref() {
                            emit_usage(
                                &model,
                                usage.prompt_tokens.into(),
                                usage.completion_tokens.into(),
                                usage.total_tokens.into(),
                                metric_metadata.as_ref(),
                            );
                        }
                    }
                    Ok(lock.clone())
                }),
            )
            .boxed()
    }
}

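/// Converts a Swiftide [`ToolSpec`] into an OpenAI function tool definition,
/// building a JSON schema object from the spec's parameters.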
fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
    let mut properties = serde_json::Map::new();

    for param in &spec.parameters {
        properties.insert(
            param.name.to_string(),
            json!({
                "type": param.ty.as_ref(),
                "description": &param.description,
            }),
        );
    }

    ChatCompletionToolArgs::default()
        .r#type(ChatCompletionToolType::Function)
        .function(
            FunctionObjectArgs::default()
                .name(&spec.name)
                .description(&spec.description)
                .strict(true)
                .parameters(json!({
                    "type": "object",
                    "properties": properties,
                    "required": spec.parameters
                        .iter()
                        .filter(|param| param.required)
                        .map(|param| &param.name)
                        .collect_vec(),
                    "additionalProperties": false,
                }))
                .build()?,
        )
        .build()
        .map_err(anyhow::Error::from)
}

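/// Maps a Swiftide [`ChatMessage`] onto the matching async-openai request
/// message type; summaries are sent as assistant messages.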
fn message_to_openai(
    message: &ChatMessage,
) -> Result<async_openai::types::ChatCompletionRequestMessage> {
    let openai_message = match message {
        ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
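        // Tool output may be empty; the tool message is still sent so the
        // tool call id resolves, just without content.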
        ChatMessage::ToolOutput(tool_call, tool_output) => {
            let Some(content) = tool_output.content() else {
                return Ok(ChatCompletionRequestToolMessageArgs::default()
                    .tool_call_id(tool_call.id())
                    .build()?
                    .into());
            };

            ChatCompletionRequestToolMessageArgs::default()
                .content(content)
                .tool_call_id(tool_call.id())
                .build()?
                .into()
        }
        ChatMessage::Assistant(msg, tool_calls) => {
            let mut builder = ChatCompletionRequestAssistantMessageArgs::default();

            if let Some(msg) = msg {
                builder.content(msg.as_str());
            }

            if let Some(tool_calls) = tool_calls {
                builder.tool_calls(
                    tool_calls
                        .iter()
                        .map(|tool_call| ChatCompletionMessageToolCall {
                            id: tool_call.id().to_string(),
                            r#type: ChatCompletionToolType::Function,
                            function: FunctionCall {
                                name: tool_call.name().to_string(),
                                arguments: tool_call.args().unwrap_or_default().to_string(),
                            },
                        })
                        .collect::<Vec<_>>(),
                );
            }

            builder.build()?.into()
        }
    };

    Ok(openai_message)
}

#[cfg(test)]
mod tests {
    use crate::openai::{OpenAI, Options};

    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[test_log::test(tokio::test)]
    async fn test_complete() {
        let mock_server = MockServer::start().await;

        let response_body = json!({
            "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
            "object": "chat.completion",
            "created": 123,
            "model": "gpt-4o",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Hello, world!",
                        "refusal": null,
                        "annotations": []
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 19,
                "completion_tokens": 10,
                "total_tokens": 29,
                "prompt_tokens_details": {
                    "cached_tokens": 0,
                    "audio_tokens": 0
                },
                "completion_tokens_details": {
                    "reasoning_tokens": 0,
                    "audio_tokens": 0,
                    "accepted_prediction_tokens": 0,
                    "rejected_prediction_tokens": 0
                }
            },
            "service_tier": "default"
        });
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(response_body))
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4o")
            .build()
            .expect("Can create OpenAI client.");

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("Hi".to_string())])
            .build()
            .unwrap();

        let response = openai.complete(&request).await.unwrap();

        assert_eq!(response.message(), Some("Hello, world!"));

        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 19);
        assert_eq!(usage.completion_tokens, 10);
        assert_eq!(usage.total_tokens, 29);
    }

    #[test_log::test(tokio::test)]
    #[allow(clippy::items_after_statements)]
    async fn test_complete_with_all_default_settings() {
        use serde_json::Value;
        use wiremock::{Request, Respond, ResponseTemplate};

        let mock_server = wiremock::MockServer::start().await;

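        // A responder that asserts on the serialized request body, verifying
        // that every builder default ends up in the outgoing JSON.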
        struct ValidateAllSettings;

        impl Respond for ValidateAllSettings {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let v: Value = serde_json::from_slice(&request.body).unwrap();

                assert_eq!(v["model"], "gpt-4-turbo");
                let arr = v["messages"].as_array().unwrap();
                assert_eq!(arr.len(), 1);
                assert_eq!(arr[0]["content"], "Test");

                assert_eq!(v["parallel_tool_calls"], true);
                assert_eq!(v["max_completion_tokens"], 77);
                assert!((v["temperature"].as_f64().unwrap() - 0.42).abs() < 1e-5);
                assert_eq!(v["reasoning_effort"], "low");
                assert_eq!(v["seed"], 42);
                assert!((v["presence_penalty"].as_f64().unwrap() - 1.1).abs() < 1e-5);

                assert_eq!(v["metadata"], serde_json::json!({"key": "value"}));
                assert_eq!(v["user"], "test-user");
                ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "id": "chatcmpl-xxx",
                    "object": "chat.completion",
                    "created": 123,
                    "model": "gpt-4-turbo",
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "All settings validated",
                            "refusal": null,
                            "annotations": []
                        },
                        "logprobs": null,
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 19,
                        "completion_tokens": 10,
                        "total_tokens": 29,
                        "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
                        "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}
                    },
                    "service_tier": "default"
                }))
            }
        }

        wiremock::Mock::given(wiremock::matchers::method("POST"))
            .and(wiremock::matchers::path("/chat/completions"))
            .respond_with(ValidateAllSettings)
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = crate::openai::OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4-turbo")
            .default_embed_model("not-used")
            .parallel_tool_calls(Some(true))
            .default_options(
                Options::builder()
                    .max_completion_tokens(77)
                    .temperature(0.42)
                    .reasoning_effort(async_openai::types::ReasoningEffort::Low)
                    .seed(42)
                    .presence_penalty(1.1)
                    .metadata(serde_json::json!({"key": "value"}))
                    .user("test-user"),
            )
            .build()
            .expect("Can create OpenAI client.");

        let request = swiftide_core::chat_completion::ChatCompletionRequest::builder()
            .messages(vec![swiftide_core::chat_completion::ChatMessage::User(
                "Test".to_string(),
            )])
            .build()
            .unwrap();

        let response = openai.complete(&request).await.unwrap();

        assert_eq!(response.message(), Some("All settings validated"));
    }
}