swiftide_integrations/openai/chat_completion.rs

use std::sync::Arc;
use std::sync::Mutex;

use anyhow::{Context as _, Result};
use async_openai::types::ChatCompletionStreamOptions;
use async_openai::types::{
    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
    ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
    ChatCompletionToolType, FunctionCall, FunctionObjectArgs,
};
use async_trait::async_trait;
use futures_util::StreamExt as _;
use futures_util::stream;
use itertools::Itertools;
use serde_json::json;
use swiftide_core::ChatCompletionStream;
use swiftide_core::chat_completion::UsageBuilder;
use swiftide_core::chat_completion::{
    ChatCompletion, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolCall, ToolSpec,
    errors::LanguageModelError,
};
#[cfg(feature = "metrics")]
use swiftide_core::metrics::emit_usage;

use super::GenericOpenAI;
use super::openai_error_to_language_model_error;

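// Implements `ChatCompletion` for any `GenericOpenAI` client. A minimal usage
// sketch, mirroring the tests at the bottom of this file (`async_openai_client`
// is a placeholder for however you construct the underlying client):
//
//     let openai = OpenAI::builder()
//         .client(async_openai_client)
//         .default_prompt_model("gpt-4o")
//         .build()?;
//
//     let request = ChatCompletionRequest::builder()
//         .messages(vec![ChatMessage::User("Hi".to_string())])
//         .build()?;
//
//     let response = openai.complete(&request).await?;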
#[async_trait]
impl<
    C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug + Clone,
> ChatCompletion for GenericOpenAI<C>
{
    #[cfg_attr(not(feature = "langfuse"), tracing::instrument(skip_all, err))]
    #[cfg_attr(
        feature = "langfuse",
        tracing::instrument(skip_all, err, fields(langfuse.type = "GENERATION"))
    )]
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .context("Model not set")?;

        let messages = request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()?;

        // Build the request to be sent to the OpenAI API.
        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(model)
            .messages(messages)
            .to_owned();

        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()?,
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = openai_request
            .build()
            .map_err(openai_error_to_language_model_error)?;

        tracing::trace!(model, ?request, "Sending request to OpenAI");

        let response = self
            .client
            .chat()
            .create(request.clone())
            .await
            .map_err(openai_error_to_language_model_error)?;

        if cfg!(feature = "langfuse") {
            let usage = response.usage.clone().unwrap_or_default();
            tracing::debug!(
                langfuse.model = model,
                langfuse.input = %serde_json::to_string_pretty(&request).unwrap_or_default(),
                langfuse.output = %serde_json::to_string_pretty(&response).unwrap_or_default(),
                langfuse.usage = %serde_json::to_string_pretty(&usage).unwrap_or_default(),
            );
        }

        tracing::trace!(?response, "[ChatCompletion] Full response from OpenAI");

        let mut builder = ChatCompletionResponse::builder()
            .maybe_message(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.content.clone()),
            )
            .maybe_tool_calls(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.tool_calls.clone())
                    .map(|tool_calls| {
                        tool_calls
                            .iter()
                            .map(|tool_call| {
                                ToolCall::builder()
                                    .id(tool_call.id.clone())
                                    .args(tool_call.function.arguments.clone())
                                    .name(tool_call.function.name.clone())
                                    .build()
                                    .expect("infallible")
                            })
                            .collect_vec()
                    }),
            )
            .to_owned();

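        // Usage is surfaced three ways: attached to the response via the builder,
        // through the optional `on_usage` callback, and (behind the `metrics`
        // feature) via `emit_usage`.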
        if let Some(usage) = &response.usage {
            let usage = UsageBuilder::default()
                .prompt_tokens(usage.prompt_tokens)
                .completion_tokens(usage.completion_tokens)
                .total_tokens(usage.total_tokens)
                .build()
                .map_err(LanguageModelError::permanent)?;

            if let Some(callback) = &self.on_usage {
                callback(&usage).await?;
            }

            builder.usage(usage);

            #[cfg(feature = "metrics")]
            {
                if let Some(usage) = response.usage.as_ref() {
                    emit_usage(
                        model,
                        usage.prompt_tokens.into(),
                        usage.completion_tokens.into(),
                        usage.total_tokens.into(),
                        self.metric_metadata.as_ref(),
                    );
                }
            }
        }

        let our_response = builder.build().map_err(LanguageModelError::from)?;

        Ok(our_response)
    }

    #[tracing::instrument(skip_all)]
    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        let Some(model) = self.default_options.prompt_model.clone() else {
            return LanguageModelError::permanent("Model not set").into();
        };

        let messages = match request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()
        {
            Ok(messages) => messages,
            Err(e) => return LanguageModelError::from(e).into(),
        };

        // Build the request to be sent to the OpenAI API.
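        // `include_usage` asks OpenAI to attach token usage to the final stream
        // chunk, which is folded into the accumulated response below.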
        let mut openai_request = self
            .chat_completion_request_defaults()
            .model(&model)
            .messages(messages)
            .stream_options(ChatCompletionStreamOptions {
                include_usage: true,
            })
            .to_owned();

        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    match request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()
                    {
                        Ok(tools) => tools,
                        Err(e) => {
                            return LanguageModelError::from(e).into();
                        }
                    },
                )
                .tool_choice("auto");
            if let Some(par) = self.default_options.parallel_tool_calls {
                openai_request.parallel_tool_calls(par);
            }
        }

        let request = match openai_request.build() {
            Ok(request) => request,
            Err(e) => {
                return openai_error_to_language_model_error(e).into();
            }
        };

        tracing::trace!(model, ?request, "Sending request to OpenAI");

        let response = match self.client.chat().create_stream(request.clone()).await {
            Ok(response) => response,
            Err(e) => return openai_error_to_language_model_error(e).into(),
        };

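        // Deltas are folded into a shared, mutex-guarded response as they arrive.
        // Each item forwards either just the latest delta or, with `stream_full`,
        // the accumulated state so far; the final aggregate is emitted once more
        // at the end of the stream (see the `chain` below).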
        let accumulating_response = Arc::new(Mutex::new(ChatCompletionResponse::default()));
        let final_response = accumulating_response.clone();
        let stream_full = self.stream_full;

        #[cfg(feature = "metrics")]
        let metric_metadata = self.metric_metadata.clone();

        let span = if cfg!(feature = "langfuse") {
            tracing::info_span!(
                "stream",
                langfuse.type = "GENERATION",
            )
        } else {
            tracing::info_span!("stream")
        };

        let maybe_usage_callback = self.on_usage.clone();
        let stream = response
            .map(move |chunk| match chunk {
                Ok(chunk) => {
                    let accumulating_response = Arc::clone(&accumulating_response);

                    let delta_message = chunk
                        .choices
                        .first()
                        .and_then(|d| d.delta.content.as_deref());
                    let delta_tool_calls = chunk
                        .choices
                        .first()
                        .and_then(|d| d.delta.tool_calls.as_deref());
                    let usage = chunk.usage.as_ref();

                    let chat_completion_response = {
                        let mut lock = accumulating_response.lock().unwrap();
                        lock.append_message_delta(delta_message);

                        if let Some(delta_tool_calls) = delta_tool_calls {
                            for tc in delta_tool_calls {
                                lock.append_tool_call_delta(
                                    tc.index as usize,
                                    tc.id.as_deref(),
                                    tc.function.as_ref().and_then(|f| f.name.as_deref()),
                                    tc.function.as_ref().and_then(|f| f.arguments.as_deref()),
                                );
                            }
                        }

                        if let Some(usage) = usage {
                            lock.append_usage_delta(
                                usage.prompt_tokens,
                                usage.completion_tokens,
                                usage.total_tokens,
                            );
                        }

                        if stream_full {
                            lock.clone()
                        } else {
                            // When not streaming the full accumulated state, forward
                            // only the latest delta (plus the response id); the
                            // aggregate is emitted once as the final stream item.
                            ChatCompletionResponse {
                                id: lock.id,
                                message: None,
                                tool_calls: None,
                                usage: None,
                                delta: lock.delta.clone(),
                            }
                        }
                    };

                    Ok(chat_completion_response)
                }
                Err(e) => Err(openai_error_to_language_model_error(e)),
            })
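            // Once the OpenAI stream is exhausted, emit the fully accumulated
            // response as a single trailing item, log it, and report usage.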
            .chain(
                stream::iter(vec![final_response]).map(move |accumulating_response| {
                    let lock = accumulating_response.lock().unwrap();

                    let usage = lock.usage.clone().unwrap_or_default();

                    if cfg!(feature = "langfuse") {
                        tracing::debug!(
                            langfuse.model = model,
                            langfuse.input = %serde_json::to_string_pretty(&request).unwrap_or_default(),
                            langfuse.output = %serde_json::to_string_pretty(&*lock).unwrap_or_default(),
                            langfuse.usage = %serde_json::to_string_pretty(&usage).unwrap_or_default(),
                        );
                    }

                    tracing::debug!(
                        usage = format!(
                            "{}/{}/{}",
                            lock.usage.as_ref().map_or(0, |u| u.prompt_tokens),
                            lock.usage.as_ref().map_or(0, |u| u.completion_tokens),
                            lock.usage.as_ref().map_or(0, |u| u.total_tokens)
                        ),
                        model = &model,
                        has_message = lock.message.is_some(),
                        num_tool_calls = lock.tool_calls.as_ref().map_or(0, std::vec::Vec::len),
                        "[ChatCompletion/Streaming] Response from OpenAI"
                    );
                    #[allow(clippy::collapsible_if)]
                    if let Some(usage) = lock.usage.as_ref() {
                        if let Some(callback) = maybe_usage_callback.as_ref() {
                            let usage = usage.clone();
                            let callback = callback.clone();

                            tokio::spawn(async move {
                                if let Err(e) = callback(&usage).await {
                                    tracing::error!("Error in on_usage callback: {}", e);
                                }
                            });
                        }

                        #[cfg(feature = "metrics")]
                        {
                            emit_usage(
                                &model,
                                usage.prompt_tokens.into(),
                                usage.completion_tokens.into(),
                                usage.total_tokens.into(),
                                metric_metadata.as_ref(),
                            );
                        }
                    }
                    Ok(lock.clone())
                }),
            );

        let stream = tracing_futures::Instrument::instrument(stream, span);

        Box::pin(stream)
    }
}

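/// Converts a [`ToolSpec`] into OpenAI's function-tool format, with `strict`
/// mode enabled and `additionalProperties` disallowed.
///
/// As an illustration (shape only; the exact serialization is handled by
/// `async_openai`), a spec with a single required `query` string parameter
/// yields a parameters schema along these lines:
///
/// ```json
/// {
///   "type": "object",
///   "properties": {
///     "query": { "type": "string", "description": "..." }
///   },
///   "required": ["query"],
///   "additionalProperties": false
/// }
/// ```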
fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
    let mut properties = serde_json::Map::new();

    for param in &spec.parameters {
        properties.insert(
            param.name.to_string(),
            json!({
                "type": param.ty.as_ref(),
                "description": &param.description,
            }),
        );
    }

    ChatCompletionToolArgs::default()
        .r#type(ChatCompletionToolType::Function)
        .function(
            FunctionObjectArgs::default()
                .name(&spec.name)
                .description(&spec.description)
                .strict(true)
                .parameters(json!({
                    "type": "object",
                    "properties": properties,
                    "required": spec
                        .parameters
                        .iter()
                        .filter(|param| param.required)
                        .map(|param| &param.name)
                        .collect_vec(),
                    "additionalProperties": false,
                }))
                .build()?,
        )
        .build()
        .map_err(anyhow::Error::from)
}

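/// Maps a swiftide [`ChatMessage`] onto the corresponding `async_openai`
/// request message:
///
/// - `User` -> user message
/// - `System` -> system message
/// - `Summary` -> assistant message (summaries replay as assistant content)
/// - `ToolOutput` -> tool message keyed by the originating tool call id,
///   with the content omitted when the tool produced none
/// - `Assistant` -> assistant message, with tool calls attached when present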
fn message_to_openai(
    message: &ChatMessage,
) -> Result<async_openai::types::ChatCompletionRequestMessage> {
    let openai_message = match message {
        ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::ToolOutput(tool_call, tool_output) => {
            let Some(content) = tool_output.content() else {
                return Ok(ChatCompletionRequestToolMessageArgs::default()
                    .tool_call_id(tool_call.id())
                    .build()?
                    .into());
            };

            ChatCompletionRequestToolMessageArgs::default()
                .content(content)
                .tool_call_id(tool_call.id())
                .build()?
                .into()
        }
        ChatMessage::Assistant(msg, tool_calls) => {
            let mut builder = ChatCompletionRequestAssistantMessageArgs::default();

            if let Some(msg) = msg {
                builder.content(msg.as_str());
            }

            if let Some(tool_calls) = tool_calls {
                builder.tool_calls(
                    tool_calls
                        .iter()
                        .map(|tool_call| ChatCompletionMessageToolCall {
                            id: tool_call.id().to_string(),
                            r#type: ChatCompletionToolType::Function,
                            function: FunctionCall {
                                name: tool_call.name().to_string(),
                                arguments: tool_call.args().unwrap_or_default().to_string(),
                            },
                        })
                        .collect::<Vec<_>>(),
                );
            }

            builder.build()?.into()
        }
    };

    Ok(openai_message)
}

#[cfg(test)]
mod tests {
    use crate::openai::{OpenAI, Options};

    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[test_log::test(tokio::test)]
    async fn test_complete() {
        let mock_server = MockServer::start().await;

        // Mock OpenAI API response
        let response_body = json!({
          "id": "chatcmpl-B9MBs8CjcvOU2jLn4n570S5qMJKcT",
          "object": "chat.completion",
          "created": 123,
          "model": "gpt-4o",
          "choices": [
            {
              "index": 0,
              "message": {
                "role": "assistant",
                "content": "Hello, world!",
                "refusal": null,
                "annotations": []
              },
              "logprobs": null,
              "finish_reason": "stop"
            }
          ],
          "usage": {
            "prompt_tokens": 19,
            "completion_tokens": 10,
            "total_tokens": 29,
            "prompt_tokens_details": {
              "cached_tokens": 0,
              "audio_tokens": 0
            },
            "completion_tokens_details": {
              "reasoning_tokens": 0,
              "audio_tokens": 0,
              "accepted_prediction_tokens": 0,
              "rejected_prediction_tokens": 0
            }
          },
          "service_tier": "default"
        });
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(response_body))
            .mount(&mock_server)
            .await;

        // Create a GenericOpenAI instance with the mock server URL
        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4o")
            .build()
            .expect("Can create OpenAI client.");

        // Prepare a test request
        let request = ChatCompletionRequest::builder()
            .messages(vec![ChatMessage::User("Hi".to_string())])
            .build()
            .unwrap();

        // Call the `complete` method
        let response = openai.complete(&request).await.unwrap();

        // Assert the response
        assert_eq!(response.message(), Some("Hello, world!"));

        // Usage
        let usage = response.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 19);
        assert_eq!(usage.completion_tokens, 10);
        assert_eq!(usage.total_tokens, 29);
    }

    #[test_log::test(tokio::test)]
    #[allow(clippy::items_after_statements)]
    async fn test_complete_with_all_default_settings() {
        use serde_json::Value;
        use wiremock::{Request, Respond, ResponseTemplate};

        let mock_server = wiremock::MockServer::start().await;

        // Custom matcher to validate all settings in the incoming request
        struct ValidateAllSettings;

        impl Respond for ValidateAllSettings {
            fn respond(&self, request: &Request) -> ResponseTemplate {
                let v: Value = serde_json::from_slice(&request.body).unwrap();

                // Validate required fields
                assert_eq!(v["model"], "gpt-4-turbo");
                let arr = v["messages"].as_array().unwrap();
                assert_eq!(arr.len(), 1);
                assert_eq!(arr[0]["content"], "Test");

                assert_eq!(v["parallel_tool_calls"], true);
                assert_eq!(v["max_completion_tokens"], 77);
                assert!((v["temperature"].as_f64().unwrap() - 0.42).abs() < 1e-5);
                assert_eq!(v["reasoning_effort"], "low");
                assert_eq!(v["seed"], 42);
                assert!((v["presence_penalty"].as_f64().unwrap() - 1.1).abs() < 1e-5);

                // Metadata as JSON object and user string
                assert_eq!(v["metadata"], serde_json::json!({"key": "value"}));
                assert_eq!(v["user"], "test-user");

                ResponseTemplate::new(200).set_body_json(serde_json::json!({
                    "id": "chatcmpl-xxx",
                    "object": "chat.completion",
                    "created": 123,
                    "model": "gpt-4-turbo",
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "All settings validated",
                            "refusal": null,
                            "annotations": []
                        },
                        "logprobs": null,
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 19,
                        "completion_tokens": 10,
                        "total_tokens": 29,
                        "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
                        "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}
                    },
                    "service_tier": "default"
                }))
            }
        }

        wiremock::Mock::given(wiremock::matchers::method("POST"))
            .and(wiremock::matchers::path("/chat/completions"))
            .respond_with(ValidateAllSettings)
            .mount(&mock_server)
            .await;

        let config = async_openai::config::OpenAIConfig::new().with_api_base(mock_server.uri());
        let async_openai = async_openai::Client::with_config(config);

        let openai = crate::openai::OpenAI::builder()
            .client(async_openai)
            .default_prompt_model("gpt-4-turbo")
            .default_embed_model("not-used")
            .parallel_tool_calls(Some(true))
            .default_options(
                Options::builder()
                    .max_completion_tokens(77)
                    .temperature(0.42)
                    .reasoning_effort(async_openai::types::ReasoningEffort::Low)
                    .seed(42)
                    .presence_penalty(1.1)
                    .metadata(serde_json::json!({"key": "value"}))
                    .user("test-user"),
            )
            .build()
            .expect("Can create OpenAI client.");

        let request = swiftide_core::chat_completion::ChatCompletionRequest::builder()
            .messages(vec![swiftide_core::chat_completion::ChatMessage::User(
                "Test".to_string(),
            )])
            .build()
            .unwrap();

        let response = openai.complete(&request).await.unwrap();

        assert_eq!(response.message(), Some("All settings validated"));
    }
}