swiftide_integrations/openai/
chat_completion.rs

1use anyhow::{Context as _, Result};
2use async_openai::types::{
3    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
4    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
5    ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
6    ChatCompletionToolType, CreateChatCompletionRequestArgs, FunctionCall, FunctionObjectArgs,
7};
8use async_trait::async_trait;
9use itertools::Itertools;
10use serde_json::json;
11use swiftide_core::chat_completion::{
12    errors::LanguageModelError, ChatCompletion, ChatCompletionRequest, ChatCompletionResponse,
13    ChatMessage, ToolCall, ToolSpec,
14};
15
16use super::openai_error_to_language_model_error;
17use super::GenericOpenAI;
18
19#[async_trait]
20impl<C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug>
21    ChatCompletion for GenericOpenAI<C>
22{
23    #[tracing::instrument(skip_all)]
24    async fn complete(
25        &self,
26        request: &ChatCompletionRequest,
27    ) -> Result<ChatCompletionResponse, LanguageModelError> {
28        let model = self
29            .default_options
30            .prompt_model
31            .as_ref()
32            .context("Model not set")?;
33
34        let messages = request
35            .messages()
36            .iter()
37            .map(message_to_openai)
38            .collect::<Result<Vec<_>>>()?;
39
40        // Build the request to be sent to the OpenAI API.
41        let mut openai_request = CreateChatCompletionRequestArgs::default()
42            .model(model)
43            .messages(messages)
44            .to_owned();
45
46        if !request.tools_spec.is_empty() {
47            openai_request
48                .tools(
49                    request
50                        .tools_spec()
51                        .iter()
52                        .map(tools_to_openai)
53                        .collect::<Result<Vec<_>>>()?,
54                )
55                .tool_choice("auto");
56            if let Some(par) = self.default_options.parallel_tool_calls {
57                openai_request.parallel_tool_calls(par);
58            }
59        }
60
61        let request = openai_request
62            .build()
63            .map_err(openai_error_to_language_model_error)?;
64
65        tracing::debug!(
66            model = &model,
67            request = serde_json::to_string_pretty(&request).expect("infallible"),
68            "Sending request to OpenAI"
69        );
70
71        let response = self
72            .client
73            .chat()
74            .create(request)
75            .await
76            .map_err(openai_error_to_language_model_error)?;
77
78        tracing::debug!(
79            response = serde_json::to_string_pretty(&response).expect("infallible"),
80            "Received response from OpenAI"
81        );
82
83        ChatCompletionResponse::builder()
84            .maybe_message(
85                response
86                    .choices
87                    .first()
88                    .and_then(|choice| choice.message.content.clone()),
89            )
90            .maybe_tool_calls(
91                response
92                    .choices
93                    .first()
94                    .and_then(|choice| choice.message.tool_calls.clone())
95                    .map(|tool_calls| {
96                        tool_calls
97                            .iter()
98                            .map(|tool_call| {
99                                ToolCall::builder()
100                                    .id(tool_call.id.clone())
101                                    .args(tool_call.function.arguments.clone())
102                                    .name(tool_call.function.name.clone())
103                                    .build()
104                                    .expect("infallible")
105                            })
106                            .collect_vec()
107                    }),
108            )
109            .build()
110            .map_err(LanguageModelError::from)
111    }
112}
113
// TODO: Maybe implement `Into`/`From` for these conversions instead? The types involved are not defined in this crate.
115
116fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
117    let mut properties = serde_json::Map::new();
118
119    for param in &spec.parameters {
120        properties.insert(
121            param.name.to_string(),
122            json!({
123                "type": param.ty.as_ref(),
124                "description": &param.description,
125            }),
126        );
127    }
128
129    ChatCompletionToolArgs::default()
130        .r#type(ChatCompletionToolType::Function)
131        .function(FunctionObjectArgs::default()
132            .name(&spec.name)
133            .description(&spec.description)
134            .strict(true)
135            .parameters(json!({
136                "type": "object",
137                "properties": properties,
138                "required": spec.parameters.iter().filter(|param| param.required).map(|param| &param.name).collect_vec(),
139                "additionalProperties": false,
140            })).build()?).build()
141        .map_err(anyhow::Error::from)
142}
143
144fn message_to_openai(
145    message: &ChatMessage,
146) -> Result<async_openai::types::ChatCompletionRequestMessage> {
147    let openai_message = match message {
148        ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
149            .content(msg.as_str())
150            .build()?
151            .into(),
152        ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
153            .content(msg.as_str())
154            .build()?
155            .into(),
156        ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
157            .content(msg.as_str())
158            .build()?
159            .into(),
160        ChatMessage::ToolOutput(tool_call, tool_output) => {
161            let Some(content) = tool_output.content() else {
162                return Ok(ChatCompletionRequestToolMessageArgs::default()
163                    .tool_call_id(tool_call.id())
164                    .build()?
165                    .into());
166            };
167
168            ChatCompletionRequestToolMessageArgs::default()
169                .content(content)
170                .tool_call_id(tool_call.id())
171                .build()?
172                .into()
173        }
174        ChatMessage::Assistant(msg, tool_calls) => {
175            let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
176
177            if let Some(msg) = msg {
178                builder.content(msg.as_str());
179            }
180
181            if let Some(tool_calls) = tool_calls {
182                builder.tool_calls(
183                    tool_calls
184                        .iter()
185                        .map(|tool_call| ChatCompletionMessageToolCall {
186                            id: tool_call.id().to_string(),
187                            r#type: ChatCompletionToolType::Function,
188                            function: FunctionCall {
189                                name: tool_call.name().to_string(),
190                                arguments: tool_call.args().unwrap_or_default().to_string(),
191                            },
192                        })
193                        .collect::<Vec<_>>(),
194                );
195            }
196
197            builder.build()?.into()
198        }
199    };
200
201    Ok(openai_message)
202}