swiftide_integrations/openai/chat_completion.rs

use std::sync::Arc;
use std::sync::Mutex;

use anyhow::{Context as _, Result};
use async_openai::types::{
    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
    ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
    ChatCompletionToolType, FunctionCall, FunctionObjectArgs,
};
use async_trait::async_trait;
use futures_util::StreamExt as _;
use futures_util::stream;
use itertools::Itertools;
use serde_json::json;
use swiftide_core::ChatCompletionStream;
use swiftide_core::chat_completion::UsageBuilder;
use swiftide_core::chat_completion::{
    ChatCompletion, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolSpec,
    errors::LanguageModelError,
};

use super::GenericOpenAI;
use super::openai_error_to_language_model_error;

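// Implements `ChatCompletion` for any OpenAI-compatible client configuration.
//
// Illustrative sketch, based on the tests at the bottom of this file (assumes a
// client is configured, e.g. via `.client(...)`; error handling elided):
//
//     let openai = OpenAI::builder()
//         .default_prompt_model("gpt-4o")
//         .build()?;
//     let request = ChatCompletionRequest::builder()
//         .messages(vec![ChatMessage::User("Hi".to_string())])
//         .build()?;
//     let response = openai.complete(&request).await?;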
#[async_trait]
impl<C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug>
    ChatCompletion for GenericOpenAI<C>
{
    #[tracing::instrument(skip_all)]
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .context("Model not set")?;

        let messages = request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()?;

        // Build the request to be sent to the OpenAI API.
        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(model)
            .messages(messages)
            .to_owned();

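        // Attach tool definitions, if any, and let the model decide when to call them.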
        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()?,
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = openai_request
            .build()
            .map_err(openai_error_to_language_model_error)?;

        tracing::debug!(model, ?request, "Sending request to OpenAI");

        let response = self
            .client
            .chat()
            .create(request)
            .await
            .map_err(openai_error_to_language_model_error)?;

        tracing::debug!(?response, "Received response from OpenAI");

        let mut builder = ChatCompletionResponse::builder()
            .maybe_message(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.content.clone()),
            )
            .maybe_tool_calls(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.tool_calls.clone())
                    .map(|tool_calls| {
                        tool_calls
                            .iter()
                            .map(|tool_call| {
                                ToolCall::builder()
                                    .id(tool_call.id.clone())
                                    .args(tool_call.function.arguments.clone())
                                    .name(tool_call.function.name.clone())
                                    .build()
                                    .expect("infallible")
                            })
                            .collect_vec()
                    }),
            )
            .to_owned();

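        // Forward token usage when the API reports it.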
        if let Some(usage) = response.usage {
            let usage = UsageBuilder::default()
                .prompt_tokens(usage.prompt_tokens)
                .completion_tokens(usage.completion_tokens)
                .total_tokens(usage.total_tokens)
                .build()
                .map_err(LanguageModelError::permanent)?;

            builder.usage(usage);
        }

        builder.build().map_err(LanguageModelError::from)
    }

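    // Streaming variant: because the signature returns a stream rather than a
    // `Result`, any setup error is surfaced as the first (and only) item of the
    // returned stream.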
    #[tracing::instrument(skip_all)]
    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        let Some(model) = self.default_options.prompt_model.as_ref() else {
            return LanguageModelError::permanent("Model not set").into();
        };

        let messages = match request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()
        {
            Ok(messages) => messages,
            Err(e) => return LanguageModelError::from(e).into(),
        };

        // Build the request to be sent to the OpenAI API.
        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(model)
            .messages(messages)
            .to_owned();

        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    match request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()
                    {
                        Ok(tools) => tools,
                        Err(e) => {
                            return LanguageModelError::from(e).into();
                        }
                    },
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = match openai_request.build() {
            Ok(request) => request,
            Err(e) => {
                return openai_error_to_language_model_error(e).into();
            }
        };

        tracing::debug!(model, ?request, "Sending request to OpenAI");

        let response = match self.client.chat().create_stream(request).await {
            Ok(response) => response,
            Err(e) => return openai_error_to_language_model_error(e).into(),
        };

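        // Accumulate deltas across chunks behind a shared mutex so the final
        // item of the stream can carry the complete message, tool calls, and usage.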
        let accumulating_response = Arc::new(Mutex::new(ChatCompletionResponse::default()));
        let final_response = accumulating_response.clone();
        let stream_full = self.stream_full;

        response
            .map(move |chunk| match chunk {
                Ok(chunk) => {
                    let accumulating_response = Arc::clone(&accumulating_response);

                    // `chunk.choices` can be empty (e.g. a usage-only chunk), so
                    // avoid indexing into it directly.
                    let delta_message = chunk
                        .choices
                        .first()
                        .and_then(|choice| choice.delta.content.as_deref());
                    let delta_tool_calls = chunk
                        .choices
                        .first()
                        .and_then(|choice| choice.delta.tool_calls.as_deref());
                    let usage = chunk.usage.as_ref();

                    let chat_completion_response = {
                        let mut lock = accumulating_response.lock().unwrap();
                        lock.append_message_delta(delta_message);

                        if let Some(delta_tool_calls) = delta_tool_calls {
                            for tc in delta_tool_calls {
                                lock.append_tool_call_delta(
                                    tc.index as usize,
                                    tc.id.as_deref(),
                                    tc.function.as_ref().and_then(|f| f.name.as_deref()),
                                    tc.function.as_ref().and_then(|f| f.arguments.as_deref()),
                                );
                            }
                        }

                        if let Some(usage) = usage {
                            lock.append_usage_delta(
                                usage.prompt_tokens,
                                usage.completion_tokens,
                                usage.total_tokens,
                            );
                        }

                        if stream_full {
                            lock.clone()
                        } else {
                            // When not streaming the full accumulated response,
                            // emit only the latest delta so each chunk stays small.
                            ChatCompletionResponse {
                                id: lock.id,
                                message: None,
                                tool_calls: None,
                                usage: None,
                                delta: lock.delta.clone(),
                            }
                        }
                    };

                    Ok(chat_completion_response)
                }
                Err(e) => Err(openai_error_to_language_model_error(e)),
            })
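            // Once the upstream stream is exhausted, emit the fully accumulated
            // response as a final item.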
            .chain(
                stream::iter(vec![final_response]).map(move |accumulating_response| {
                    let lock = accumulating_response.lock().unwrap();
                    Ok(lock.clone())
                }),
            )
            .boxed()
    }
}

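/// Converts a [`ToolSpec`] into OpenAI's function-calling tool format.
///
/// For a hypothetical spec named `search` with one required `query` parameter,
/// the generated definition looks roughly like:
///
/// ```json
/// {
///   "type": "function",
///   "function": {
///     "name": "search",
///     "description": "...",
///     "strict": true,
///     "parameters": {
///       "type": "object",
///       "properties": { "query": { "type": "string", "description": "..." } },
///       "required": ["query"],
///       "additionalProperties": false
///     }
///   }
/// }
/// ```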
fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
    let mut properties = serde_json::Map::new();

    for param in &spec.parameters {
        properties.insert(
            param.name.to_string(),
            json!({
                "type": param.ty.as_ref(),
                "description": &param.description,
            }),
        );
    }

    ChatCompletionToolArgs::default()
        .r#type(ChatCompletionToolType::Function)
        .function(
            FunctionObjectArgs::default()
                .name(&spec.name)
                .description(&spec.description)
                .strict(true)
                .parameters(json!({
                    "type": "object",
                    "properties": properties,
                    "required": spec
                        .parameters
                        .iter()
                        .filter(|param| param.required)
                        .map(|param| &param.name)
                        .collect_vec(),
                    "additionalProperties": false,
                }))
                .build()?,
        )
        .build()
        .map_err(anyhow::Error::from)
}

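/// Maps a swiftide [`ChatMessage`] onto the corresponding `async_openai` request
/// message: user, system, summary (sent as assistant), tool output, and assistant
/// messages with optional tool calls.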
fn message_to_openai(
    message: &ChatMessage,
) -> Result<async_openai::types::ChatCompletionRequestMessage> {
    let openai_message = match message {
        ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::ToolOutput(tool_call, tool_output) => {
            let Some(content) = tool_output.content() else {
                return Ok(ChatCompletionRequestToolMessageArgs::default()
                    .tool_call_id(tool_call.id())
                    .build()?
                    .into());
            };

            ChatCompletionRequestToolMessageArgs::default()
                .content(content)
                .tool_call_id(tool_call.id())
                .build()?
                .into()
        }
        ChatMessage::Assistant(msg, tool_calls) => {
            let mut builder = ChatCompletionRequestAssistantMessageArgs::default();

            if let Some(msg) = msg {
                builder.content(msg.as_str());
            }

            if let Some(tool_calls) = tool_calls {
                builder.tool_calls(
                    tool_calls
                        .iter()
                        .map(|tool_call| ChatCompletionMessageToolCall {
                            id: tool_call.id().to_string(),
                            r#type: ChatCompletionToolType::Function,
                            function: FunctionCall {
                                name: tool_call.name().to_string(),
                                arguments: tool_call.args().unwrap_or_default().to_string(),
                            },
                        })
                        .collect::<Vec<_>>(),
                );
            }

            builder.build()?.into()
        }
    };

    Ok(openai_message)
}

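// The tests below stub the OpenAI HTTP API with wiremock, so they run without
// network access or an API key.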
#[cfg(test)]
mod tests {
    use crate::openai::{OpenAI, Options};

    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[test_log::test(tokio::test)]
    async fn test_complete() {
        let mock_server = MockServer::start().await;

        // Mock OpenAI API response
        let response_body = json!({
          "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
          "object": "chat.completion",
          "created": 123,
          "model": "gpt-4o",
          "choices": [
            {
              "index": 0,
              "message": {
                "role": "assistant",
                "content": "Hello, world!",
                "refusal": null,
                "annotations": []
              },
              "logprobs": null,
              "finish_reason": "stop"
            }
          ],
          "usage": {
            "prompt_tokens": 19,
            "completion_tokens": 10,
            "total_tokens": 29,
            "prompt_tokens_details": {
              "cached_tokens": 0,
              "audio_tokens": 0
            },
            "completion_tokens_details": {
              "reasoning_tokens": 0,
              "audio_tokens": 0,
              "accepted_prediction_tokens": 0,
              "rejected_prediction_tokens": 0
            }
          },
          "service_tier": "default"
        });
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(response_body))
            .mount(&mock_server)
            .await;

        // Create a GenericOpenAI instance with the mock server URL
        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4o")
            .build()
            .expect("Can create OpenAI client.");

        // Prepare a test request
        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("Hi".to_string())])
            .build()
            .unwrap();

        // Call the `complete` method
        let response = openai.complete(&request).await.unwrap();

        // Assert the response
        assert_eq!(response.message(), Some("Hello, world!"));

        // Usage
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 19);
        assert_eq!(usage.completion_tokens, 10);
        assert_eq!(usage.total_tokens, 29);
    }

    #[test_log::test(tokio::test)]
    #[allow(clippy::items_after_statements)]
    async fn test_complete_with_all_default_settings() {
        use serde_json::Value;
        use wiremock::{Request, Respond, ResponseTemplate};

        let mock_server = wiremock::MockServer::start().await;

        // Custom matcher to validate all settings in the incoming request
        struct ValidateAllSettings;

        impl Respond for ValidateAllSettings {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let v: Value = serde_json::from_slice(&request.body).unwrap();

                // Validate required fields
                assert_eq!(v["model"], "gpt-4-turbo");
                let arr = v["messages"].as_array().unwrap();
                assert_eq!(arr.len(), 1);
                assert_eq!(arr[0]["content"], "Test");

                assert_eq!(v["parallel_tool_calls"], true);
                assert_eq!(v["max_completion_tokens"], 77);
                assert!((v["temperature"].as_f64().unwrap() - 0.42).abs() < 1e-5);
                assert_eq!(v["reasoning_effort"], "low");
                assert_eq!(v["seed"], 42);
                assert!((v["presence_penalty"].as_f64().unwrap() - 1.1).abs() < 1e-5);

                // Metadata as JSON object and user string
                assert_eq!(v["metadata"], serde_json::json!({"key": "value"}));
                assert_eq!(v["user"], "test-user");

                ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "id": "chatcmpl-xxx",
                    "object": "chat.completion",
                    "created": 123,
                    "model": "gpt-4-turbo",
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "All settings validated",
                            "refusal": null,
                            "annotations": []
                        },
                        "logprobs": null,
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 19,
                        "completion_tokens": 10,
                        "total_tokens": 29,
                        "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
                        "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}
                    },
                    "service_tier": "default"
                }))
            }
        }

        wiremock::Mock::given(wiremock::matchers::method("POST"))
            .and(wiremock::matchers::path("/chat/completions"))
            .respond_with(ValidateAllSettings)
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = crate::openai::OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4-turbo")
            .default_embed_model("not-used")
            .parallel_tool_calls(Some(true))
            .default_options(
                Options::builder()
                    .max_completion_tokens(77)
                    .temperature(0.42)
                    .reasoning_effort(async_openai::types::ReasoningEffort::Low)
                    .seed(42)
                    .presence_penalty(1.1)
                    .metadata(serde_json::json!({"key": "value"}))
                    .user("test-user"),
            )
            .build()
            .expect("Can create OpenAI client.");

        let request = swiftide_core::chat_completion::ChatCompletionRequest::builder()
            .messages(vec![swiftide_core::chat_completion::ChatMessage::User(
                "Test".to_string(),
            )])
            .build()
            .unwrap();

        let response = openai.complete(&request).await.unwrap();

        assert_eq!(response.message(), Some("All settings validated"));
    }
}