swiftide_integrations/openai/
chat_completion.rs

use std::sync::Arc;
use std::sync::Mutex;

use anyhow::{Context as _, Result};
use async_openai::types::{
    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
    ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
    ChatCompletionToolType, FunctionCall, FunctionObjectArgs,
};
use async_trait::async_trait;
use futures_util::StreamExt as _;
use futures_util::stream;
use itertools::Itertools;
use serde_json::json;
use swiftide_core::ChatCompletionStream;
use swiftide_core::chat_completion::UsageBuilder;
use swiftide_core::chat_completion::{
    ChatCompletion, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolSpec,
    errors::LanguageModelError,
};
#[cfg(feature = "metrics")]
use swiftide_core::metrics::emit_usage;

use super::GenericOpenAI;
use super::openai_error_to_language_model_error;
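// `GenericOpenAI` is generic over the `async_openai` client config, so this
// implementation also serves OpenAI-compatible backends that reuse that client.
//
// A minimal usage sketch, mirroring the tests at the bottom of this file:
//
//     let openai = OpenAI::builder()
//         .client(async_openai::Client::with_config(config))
//         .default_prompt_model("gpt-4o")
//         .build()?;
//     let response = openai.complete(&request).await?;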
#[async_trait]
impl<
    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
> ChatCompletion for GenericOpenAI<C>
{
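    // Non-streaming completion: send the request once and map the first choice
    // into a `ChatCompletionResponse`, including token usage when provided.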
    #[tracing::instrument(skip_all)]
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .context("Model not set")?;

        let messages = request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()?;

        // Build the request to be sent to the OpenAI API.
        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(model)
            .messages(messages)
            .to_owned();
        // Tools are only attached when the request provides specs; tool
        // selection is left to the model via `tool_choice("auto")`.
        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()?,
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = openai_request
            .build()
            .map_err(openai_error_to_language_model_error)?;

        tracing::debug!(model, ?request, "Sending request to OpenAI");

        let response = self
            .client
            .chat()
            .create(request)
            .await
            .map_err(openai_error_to_language_model_error)?;

        tracing::debug!(?response, "Received response from OpenAI");

        let mut builder = ChatCompletionResponse::builder()
            .maybe_message(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.content.clone()),
            )
            .maybe_tool_calls(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.tool_calls.clone())
                    .map(|tool_calls| {
                        tool_calls
                            .iter()
                            .map(|tool_call| {
                                ToolCall::builder()
                                    .id(tool_call.id.clone())
                                    .args(tool_call.function.arguments.clone())
                                    .name(tool_call.function.name.clone())
                                    .build()
                                    // All required fields are set above, so the
                                    // builder cannot fail.
                                    .expect("infallible")
                            })
                            .collect_vec()
                    }),
            )
            .to_owned();
        if let Some(usage) = &response.usage {
            builder.usage(
                UsageBuilder::default()
                    .prompt_tokens(usage.prompt_tokens)
                    .completion_tokens(usage.completion_tokens)
                    .total_tokens(usage.total_tokens)
                    .build()
                    .map_err(LanguageModelError::permanent)?,
            );

            #[cfg(feature = "metrics")]
            emit_usage(
                model,
                usage.prompt_tokens.into(),
                usage.completion_tokens.into(),
                usage.total_tokens.into(),
                self.metric_metadata.as_ref(),
            );
        }

        builder.build().map_err(LanguageModelError::from)
    }

    #[tracing::instrument(skip_all)]
    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        // This method returns a stream rather than a `Result`, so `?` is not
        // available; errors are surfaced as single-item error streams instead.
        let Some(model) = self.default_options.prompt_model.as_ref() else {
            return LanguageModelError::permanent("Model not set").into();
        };

        let messages = match request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()
        {
            Ok(messages) => messages,
            Err(e) => return LanguageModelError::from(e).into(),
        };

        // Build the request to be sent to the OpenAI API.
        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(model)
            .messages(messages)
            .to_owned();

        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    match request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()
                    {
                        Ok(tools) => tools,
                        Err(e) => {
                            return LanguageModelError::from(e).into();
                        }
                    },
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = match openai_request.build() {
            Ok(request) => request,
            Err(e) => {
                return openai_error_to_language_model_error(e).into();
            }
        };

        tracing::debug!(model, ?request, "Sending request to OpenAI");

        let response = match self.client.chat().create_stream(request).await {
            Ok(response) => response,
            Err(e) => return openai_error_to_language_model_error(e).into(),
        };

        // Deltas are folded into a shared accumulator so that one final, fully
        // aggregated response can be emitted after the upstream stream ends.
        let accumulating_response = Arc::new(Mutex::new(ChatCompletionResponse::default()));
        let final_response = accumulating_response.clone();
        let stream_full = self.stream_full;

        #[cfg(feature = "metrics")]
        let metric_metadata = self.metric_metadata.clone();
        #[cfg(feature = "metrics")]
        let model = model.to_string();

        response
            .map(move |chunk| match chunk {
                Ok(chunk) => {
                    let accumulating_response = Arc::clone(&accumulating_response);

                    // Usage-only chunks may carry no choices, so avoid indexing
                    // into an empty `choices` vec.
                    let delta_message = chunk
                        .choices
                        .first()
                        .and_then(|choice| choice.delta.content.as_deref());
                    let delta_tool_calls = chunk
                        .choices
                        .first()
                        .and_then(|choice| choice.delta.tool_calls.as_deref());
                    let usage = chunk.usage.as_ref();

                    let chat_completion_response = {
                        let mut lock = accumulating_response.lock().unwrap();
                        lock.append_message_delta(delta_message);

                        if let Some(delta_tool_calls) = delta_tool_calls {
                            for tc in delta_tool_calls {
                                lock.append_tool_call_delta(
                                    tc.index as usize,
                                    tc.id.as_deref(),
                                    tc.function.as_ref().and_then(|f| f.name.as_deref()),
                                    tc.function.as_ref().and_then(|f| f.arguments.as_deref()),
                                );
                            }
                        }

                        if let Some(usage) = usage {
                            lock.append_usage_delta(
                                usage.prompt_tokens,
                                usage.completion_tokens,
                                usage.total_tokens,
                            );
                        }

                        if stream_full {
                            lock.clone()
                        } else {
                            // When not streaming the full accumulated state,
                            // forward only the latest delta; the aggregated
                            // response is emitted once, after the stream ends.
                            ChatCompletionResponse {
                                id: lock.id,
                                message: None,
                                tool_calls: None,
                                usage: None,
                                delta: lock.delta.clone(),
                            }
                        }
                    };

                    Ok(chat_completion_response)
                }
                Err(e) => Err(openai_error_to_language_model_error(e)),
            })
            // After the upstream stream is exhausted, emit one final item with
            // the fully accumulated message, tool calls, and usage.
            .chain(
                stream::iter(vec![final_response]).map(move |accumulating_response| {
                    let lock = accumulating_response.lock().unwrap();

                    #[cfg(feature = "metrics")]
                    {
                        if let Some(usage) = lock.usage.as_ref() {
                            emit_usage(
                                &model,
                                usage.prompt_tokens.into(),
                                usage.completion_tokens.into(),
                                usage.total_tokens.into(),
                                metric_metadata.as_ref(),
                            );
                        }
                    }
                    Ok(lock.clone())
                }),
            )
            .boxed()
    }
}

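/// Converts a swiftide [`ToolSpec`] into the OpenAI function-tool format.
///
/// For example, a spec with one required `query` string parameter (names here
/// are illustrative) produces the parameter schema:
///
/// ```json
/// {
///   "type": "object",
///   "properties": { "query": { "type": "string", "description": "..." } },
///   "required": ["query"],
///   "additionalProperties": false
/// }
/// ```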
fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
    let mut properties = serde_json::Map::new();

    for param in &spec.parameters {
        properties.insert(
            param.name.to_string(),
            json!({
                "type": param.ty.as_ref(),
                "description": &param.description,
            }),
        );
    }

    ChatCompletionToolArgs::default()
        .r#type(ChatCompletionToolType::Function)
        .function(
            FunctionObjectArgs::default()
                .name(&spec.name)
                .description(&spec.description)
                .strict(true)
                .parameters(json!({
                    "type": "object",
                    "properties": properties,
                    "required": spec
                        .parameters
                        .iter()
                        .filter(|param| param.required)
                        .map(|param| &param.name)
                        .collect_vec(),
                    "additionalProperties": false,
                }))
                .build()?,
        )
        .build()
        .map_err(anyhow::Error::from)
}

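/// Maps a swiftide [`ChatMessage`] onto the corresponding `async_openai`
/// request message. `Summary` messages are sent as assistant messages; tool
/// outputs without content still produce a tool message so the tool call id is
/// acknowledged.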
fn message_to_openai(
    message: &ChatMessage,
) -> Result<async_openai::types::ChatCompletionRequestMessage> {
    let openai_message = match message {
        ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::ToolOutput(tool_call, tool_output) => {
            let Some(content) = tool_output.content() else {
                return Ok(ChatCompletionRequestToolMessageArgs::default()
                    .tool_call_id(tool_call.id())
                    .build()?
                    .into());
            };

            ChatCompletionRequestToolMessageArgs::default()
                .content(content)
                .tool_call_id(tool_call.id())
                .build()?
                .into()
        }
        ChatMessage::Assistant(msg, tool_calls) => {
            let mut builder = ChatCompletionRequestAssistantMessageArgs::default();

            if let Some(msg) = msg {
                builder.content(msg.as_str());
            }

            if let Some(tool_calls) = tool_calls {
                builder.tool_calls(
                    tool_calls
                        .iter()
                        .map(|tool_call| ChatCompletionMessageToolCall {
                            id: tool_call.id().to_string(),
                            r#type: ChatCompletionToolType::Function,
                            function: FunctionCall {
                                name: tool_call.name().to_string(),
                                arguments: tool_call.args().unwrap_or_default().to_string(),
                            },
                        })
                        .collect::<Vec<_>>(),
                );
            }

            builder.build()?.into()
        }
    };

    Ok(openai_message)
}

#[cfg(test)]
mod tests {
    use crate::openai::{OpenAI, Options};

    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

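    // Both tests run `complete` against a local wiremock server standing in
    // for the OpenAI API, so no network access or real credentials are needed.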
    #[test_log::test(tokio::test)]
    async fn test_complete() {
        let mock_server = MockServer::start().await;

        // Mock OpenAI API response
        let response_body = json!({
          "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
          "object": "chat.completion",
          "created": 123,
          "model": "gpt-4o",
          "choices": [
            {
              "index": 0,
              "message": {
                "role": "assistant",
                "content": "Hello, world!",
                "refusal": null,
                "annotations": []
              },
              "logprobs": null,
              "finish_reason": "stop"
            }
          ],
          "usage": {
            "prompt_tokens": 19,
            "completion_tokens": 10,
            "total_tokens": 29,
            "prompt_tokens_details": {
              "cached_tokens": 0,
              "audio_tokens": 0
            },
            "completion_tokens_details": {
              "reasoning_tokens": 0,
              "audio_tokens": 0,
              "accepted_prediction_tokens": 0,
              "rejected_prediction_tokens": 0
            }
          },
          "service_tier": "default"
        });
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(response_body))
            .mount(&mock_server)
            .await;

        // Create a GenericOpenAI instance with the mock server URL
        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4o")
            .build()
            .expect("Can create OpenAI client.");

        // Prepare a test request
        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("Hi".to_string())])
            .build()
            .unwrap();

        // Call the `complete` method
        let response = openai.complete(&request).await.unwrap();

        // Assert the response
        assert_eq!(response.message(), Some("Hello, world!"));

        // Usage
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 19);
        assert_eq!(usage.completion_tokens, 10);
        assert_eq!(usage.total_tokens, 29);
    }

    #[test_log::test(tokio::test)]
    #[allow(clippy::items_after_statements)]
    async fn test_complete_with_all_default_settings() {
        use serde_json::Value;
        use wiremock::{Request, Respond, ResponseTemplate};

        let mock_server = wiremock::MockServer::start().await;

        // Custom responder that validates all settings in the incoming request
        struct ValidateAllSettings;

        impl Respond for ValidateAllSettings {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let v: Value = serde_json::from_slice(&request.body).unwrap();

                // Validate required fields
                assert_eq!(v["model"], "gpt-4-turbo");
                let arr = v["messages"].as_array().unwrap();
                assert_eq!(arr.len(), 1);
                assert_eq!(arr[0]["content"], "Test");

                assert_eq!(v["parallel_tool_calls"], true);
                assert_eq!(v["max_completion_tokens"], 77);
                assert!((v["temperature"].as_f64().unwrap() - 0.42).abs() < 1e-5);
                assert_eq!(v["reasoning_effort"], "low");
                assert_eq!(v["seed"], 42);
                assert!((v["presence_penalty"].as_f64().unwrap() - 1.1).abs() < 1e-5);

                // Metadata as JSON object and user string
                assert_eq!(v["metadata"], serde_json::json!({"key": "value"}));
                assert_eq!(v["user"], "test-user");

                ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "id": "chatcmpl-xxx",
                    "object": "chat.completion",
                    "created": 123,
                    "model": "gpt-4-turbo",
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "All settings validated",
                            "refusal": null,
                            "annotations": []
                        },
                        "logprobs": null,
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 19,
                        "completion_tokens": 10,
                        "total_tokens": 29,
                        "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
                        "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}
                    },
                    "service_tier": "default"
                }))
            }
        }

        wiremock::Mock::given(wiremock::matchers::method("POST"))
            .and(wiremock::matchers::path("/chat/completions"))
            .respond_with(ValidateAllSettings)
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = crate::openai::OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4-turbo")
            .default_embed_model("not-used")
            .parallel_tool_calls(Some(true))
            .default_options(
                Options::builder()
                    .max_completion_tokens(77)
                    .temperature(0.42)
                    .reasoning_effort(async_openai::types::ReasoningEffort::Low)
                    .seed(42)
                    .presence_penalty(1.1)
                    .metadata(serde_json::json!({"key": "value"}))
                    .user("test-user"),
            )
            .build()
            .expect("Can create OpenAI client.");

        let request = swiftide_core::chat_completion::ChatCompletionRequest::builder()
            .messages(vec![swiftide_core::chat_completion::ChatMessage::User(
                "Test".to_string(),
            )])
            .build()
            .unwrap();

        let response = openai.complete(&request).await.unwrap();

        assert_eq!(response.message(), Some("All settings validated"));
    }
}