// swiftide_integrations/openai/chat_completion.rs

use std::sync::Arc;
use std::sync::Mutex;

use anyhow::{Context as _, Result};
use async_openai::types::{
    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
    ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
    ChatCompletionToolType, FunctionCall, FunctionObjectArgs,
};
use async_trait::async_trait;
use futures_util::StreamExt as _;
use futures_util::stream;
use itertools::Itertools;
use serde_json::json;
use swiftide_core::ChatCompletionStream;
use swiftide_core::chat_completion::UsageBuilder;
use swiftide_core::chat_completion::{
    ChatCompletion, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolSpec,
    errors::LanguageModelError,
};

use super::GenericOpenAI;
use super::openai_error_to_language_model_error;

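// `ChatCompletion` for `GenericOpenAI`: translates Swiftide's provider-agnostic
// request into an `async_openai` request, sends it, and maps the response
// (message, tool calls, usage) back into a `ChatCompletionResponse`.
//
// A minimal usage sketch (assuming an API key in the environment and the
// builder defaults provided elsewhere in this crate):
//
//     let openai = OpenAI::builder().default_prompt_model("gpt-4o").build()?;
//     let request = ChatCompletionRequest::builder()
//         .messages(vec![ChatMessage::User("Hi".into())])
//         .build()?;
//     let response = openai.complete(&request).await?;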
#[async_trait]
impl<C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug>
    ChatCompletion for GenericOpenAI<C>
{
    #[tracing::instrument(skip_all)]
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .context("Model not set")?;

        let messages = request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()?;

        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(model)
            .messages(messages)
            .to_owned();

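        // Tools are only attached when the request specifies any; `tool_choice`
        // stays "auto" so the model decides whether to call them.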
        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()?,
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = openai_request
            .build()
            .map_err(openai_error_to_language_model_error)?;

        tracing::debug!(model, ?request, "Sending request to OpenAI");

        let response = self
            .client
            .chat()
            .create(request)
            .await
            .map_err(openai_error_to_language_model_error)?;

        tracing::debug!(?response, "Received response from OpenAI");

        let mut builder = ChatCompletionResponse::builder()
            .maybe_message(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.content.clone()),
            )
            .maybe_tool_calls(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.tool_calls.clone())
                    .map(|tool_calls| {
                        tool_calls
                            .iter()
                            .map(|tool_call| {
                                ToolCall::builder()
                                    .id(tool_call.id.clone())
                                    .args(tool_call.function.arguments.clone())
                                    .name(tool_call.function.name.clone())
                                    .build()
                                    .expect("infallible")
                            })
                            .collect_vec()
                    }),
            )
            .to_owned();

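        // Token usage is optional in the response; forward the counts only
        // when the provider reports them.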
        if let Some(usage) = response.usage {
            let usage = UsageBuilder::default()
                .prompt_tokens(usage.prompt_tokens)
                .completion_tokens(usage.completion_tokens)
                .total_tokens(usage.total_tokens)
                .build()
                .map_err(LanguageModelError::permanent)?;

            builder.usage(usage);
        }

        builder.build().map_err(LanguageModelError::from)
    }

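    // The streaming variant mirrors `complete`, but because the return type is
    // a stream rather than a `Result`, every fallible step converts its error
    // into an error-carrying stream instead of using `?`.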
    #[tracing::instrument(skip_all)]
    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        let Some(model) = self.default_options.prompt_model.as_ref() else {
            return LanguageModelError::permanent("Model not set").into();
        };

        let messages = match request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()
        {
            Ok(messages) => messages,
            Err(e) => return LanguageModelError::from(e).into(),
        };

        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(model)
            .messages(messages)
            .to_owned();

        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    match request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()
                    {
                        Ok(tools) => tools,
                        Err(e) => {
                            return LanguageModelError::from(e).into();
                        }
                    },
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = match openai_request.build() {
            Ok(request) => request,
            Err(e) => {
                return openai_error_to_language_model_error(e).into();
            }
        };

        tracing::debug!(model, ?request, "Sending request to OpenAI");

        let response = match self.client.chat().create_stream(request).await {
            Ok(response) => response,
            Err(e) => return openai_error_to_language_model_error(e).into(),
        };

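        // Chunk deltas are folded into a shared `ChatCompletionResponse`; the
        // cloned `Arc` outlives the stream so the fully assembled response can
        // be emitted as the final item (see the `chain` below).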
        let accumulating_response = Arc::new(Mutex::new(ChatCompletionResponse::default()));
        let final_response = accumulating_response.clone();
        let stream_full = self.stream_full;

        response
            .map(move |chunk| match chunk {
                Ok(chunk) => {
                    let accumulating_response = Arc::clone(&accumulating_response);

                    // The final usage-only chunk can arrive with an empty
                    // `choices` array, so avoid indexing into it directly.
                    let delta_message = chunk
                        .choices
                        .first()
                        .and_then(|choice| choice.delta.content.as_deref());
                    let delta_tool_calls = chunk
                        .choices
                        .first()
                        .and_then(|choice| choice.delta.tool_calls.as_deref());
                    let usage = chunk.usage.as_ref();

                    let chat_completion_response = {
                        let mut lock = accumulating_response.lock().unwrap();
                        lock.append_message_delta(delta_message);

                        if let Some(delta_tool_calls) = delta_tool_calls {
                            for tc in delta_tool_calls {
                                lock.append_tool_call_delta(
                                    tc.index as usize,
                                    tc.id.as_deref(),
                                    tc.function.as_ref().and_then(|f| f.name.as_deref()),
                                    tc.function.as_ref().and_then(|f| f.arguments.as_deref()),
                                );
                            }
                        }

                        if let Some(usage) = usage {
                            lock.append_usage_delta(
                                usage.prompt_tokens,
                                usage.completion_tokens,
                                usage.total_tokens,
                            );
                        }

                        if stream_full {
                            lock.clone()
                        } else {
                            ChatCompletionResponse {
                                id: lock.id,
                                message: None,
                                tool_calls: None,
                                usage: None,
                                delta: lock.delta.clone(),
                            }
                        }
                    };

                    Ok(chat_completion_response)
                }
                Err(e) => Err(openai_error_to_language_model_error(e)),
            })
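            // Once the provider stream is exhausted, emit the accumulated
            // response as one last item.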
            .chain(
                stream::iter(vec![final_response]).map(move |accumulating_response| {
                    let lock = accumulating_response.lock().unwrap();
                    Ok(lock.clone())
                }),
            )
            .boxed()
    }
}

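/// Converts a Swiftide [`ToolSpec`] into an OpenAI function tool with a strict
/// JSON schema: each parameter becomes a property, only parameters flagged
/// `required` enter the `required` list, and `additionalProperties` is
/// disabled as strict mode demands.
///
/// For a spec with a single required `query` parameter, the generated
/// `parameters` value looks roughly like this (a sketch; the exact `type`
/// string comes from `param.ty`):
///
/// ```json
/// {
///   "type": "object",
///   "properties": { "query": { "type": "string", "description": "..." } },
///   "required": ["query"],
///   "additionalProperties": false
/// }
/// ```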
fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
    let mut properties = serde_json::Map::new();

    for param in &spec.parameters {
        properties.insert(
            param.name.to_string(),
            json!({
                "type": param.ty.as_ref(),
                "description": &param.description,
            }),
        );
    }

    ChatCompletionToolArgs::default()
        .r#type(ChatCompletionToolType::Function)
        .function(
            FunctionObjectArgs::default()
                .name(&spec.name)
                .description(&spec.description)
                .strict(true)
                .parameters(json!({
                    "type": "object",
                    "properties": properties,
                    "required": spec
                        .parameters
                        .iter()
                        .filter(|param| param.required)
                        .map(|param| &param.name)
                        .collect_vec(),
                    "additionalProperties": false,
                }))
                .build()?,
        )
        .build()
        .map_err(anyhow::Error::from)
}

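/// Maps a Swiftide [`ChatMessage`] onto the matching `async_openai` request
/// message. Summaries are replayed as assistant messages, and a tool output
/// without content still yields a tool message so that its `tool_call_id` is
/// acknowledged.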
fn message_to_openai(
    message: &ChatMessage,
) -> Result<async_openai::types::ChatCompletionRequestMessage> {
    let openai_message = match message {
        ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::ToolOutput(tool_call, tool_output) => {
            let Some(content) = tool_output.content() else {
                return Ok(ChatCompletionRequestToolMessageArgs::default()
                    .tool_call_id(tool_call.id())
                    .build()?
                    .into());
            };

            ChatCompletionRequestToolMessageArgs::default()
                .content(content)
                .tool_call_id(tool_call.id())
                .build()?
                .into()
        }
        ChatMessage::Assistant(msg, tool_calls) => {
            let mut builder = ChatCompletionRequestAssistantMessageArgs::default();

            if let Some(msg) = msg {
                builder.content(msg.as_str());
            }

            if let Some(tool_calls) = tool_calls {
                builder.tool_calls(
                    tool_calls
                        .iter()
                        .map(|tool_call| ChatCompletionMessageToolCall {
                            id: tool_call.id().to_string(),
                            r#type: ChatCompletionToolType::Function,
                            function: FunctionCall {
                                name: tool_call.name().to_string(),
                                arguments: tool_call.args().unwrap_or_default().to_string(),
                            },
                        })
                        .collect::<Vec<_>>(),
                );
            }

            builder.build()?.into()
        }
    };

    Ok(openai_message)
}

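// The tests below run against a local wiremock server rather than the real
// OpenAI API, so request serialization and response mapping are validated
// offline.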
#[cfg(test)]
mod tests {
    use crate::openai::{OpenAI, Options};

    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[test_log::test(tokio::test)]
    async fn test_complete() {
        let mock_server = MockServer::start().await;

        let response_body = json!({
            "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
            "object": "chat.completion",
            "created": 123,
            "model": "gpt-4o",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Hello, world!",
                        "refusal": null,
                        "annotations": []
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 19,
                "completion_tokens": 10,
                "total_tokens": 29,
                "prompt_tokens_details": {
                    "cached_tokens": 0,
                    "audio_tokens": 0
                },
                "completion_tokens_details": {
                    "reasoning_tokens": 0,
                    "audio_tokens": 0,
                    "accepted_prediction_tokens": 0,
                    "rejected_prediction_tokens": 0
                }
            },
            "service_tier": "default"
        });
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(response_body))
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4o")
            .build()
            .expect("Can create OpenAI client.");

        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("Hi".to_string())])
            .build()
            .unwrap();

        let response = openai.complete(&request).await.unwrap();

        assert_eq!(response.message(), Some("Hello, world!"));

        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 19);
        assert_eq!(usage.completion_tokens, 10);
        assert_eq!(usage.total_tokens, 29);
    }

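    // Verifies that every default option set on the builder is serialized into
    // the outgoing request body; the responder asserts on the raw JSON.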
    #[test_log::test(tokio::test)]
    #[allow(clippy::items_after_statements)]
    async fn test_complete_with_all_default_settings() {
        use serde_json::Value;
        use wiremock::{Request, Respond, ResponseTemplate};

        let mock_server = wiremock::MockServer::start().await;

        struct ValidateAllSettings;

        impl Respond for ValidateAllSettings {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let v: Value = serde_json::from_slice(&request.body).unwrap();

                assert_eq!(v["model"], "gpt-4-turbo");
                let arr = v["messages"].as_array().unwrap();
                assert_eq!(arr.len(), 1);
                assert_eq!(arr[0]["content"], "Test");

                assert_eq!(v["parallel_tool_calls"], true);
                assert_eq!(v["max_completion_tokens"], 77);
                assert!((v["temperature"].as_f64().unwrap() - 0.42).abs() < 1e-5);
                assert_eq!(v["reasoning_effort"], "low");
                assert_eq!(v["seed"], 42);
                assert!((v["presence_penalty"].as_f64().unwrap() - 1.1).abs() < 1e-5);

                assert_eq!(v["metadata"], serde_json::json!({"key": "value"}));
                assert_eq!(v["user"], "test-user");

                ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "id": "chatcmpl-xxx",
                    "object": "chat.completion",
                    "created": 123,
                    "model": "gpt-4-turbo",
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "All settings validated",
                            "refusal": null,
                            "annotations": []
                        },
                        "logprobs": null,
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 19,
                        "completion_tokens": 10,
                        "total_tokens": 29,
                        "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
                        "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}
                    },
                    "service_tier": "default"
                }))
            }
        }

        wiremock::Mock::given(wiremock::matchers::method("POST"))
            .and(wiremock::matchers::path("/chat/completions"))
            .respond_with(ValidateAllSettings)
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = crate::openai::OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4-turbo")
            .default_embed_model("not-used")
            .parallel_tool_calls(Some(true))
            .default_options(
                Options::builder()
                    .max_completion_tokens(77)
                    .temperature(0.42)
                    .reasoning_effort(async_openai::types::ReasoningEffort::Low)
                    .seed(42)
                    .presence_penalty(1.1)
                    .metadata(serde_json::json!({"key": "value"}))
                    .user("test-user"),
            )
            .build()
            .expect("Can create OpenAI client.");

        let request = swiftide_core::chat_completion::ChatCompletionRequest::builder()
            .messages(vec![swiftide_core::chat_completion::ChatMessage::User(
                "Test".to_string(),
            )])
            .build()
            .unwrap();

        let response = openai.complete(&request).await.unwrap();

        assert_eq!(response.message(), Some("All settings validated"));
    }
}