// swiftide_integrations/openai/chat_completion.rs
use anyhow::{Context as _, Result};
2use async_openai::types::{
3 ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
4 ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
5 ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
6 ChatCompletionToolType, CreateChatCompletionRequestArgs, FunctionCall, FunctionObjectArgs,
7};
8use async_trait::async_trait;
9use itertools::Itertools;
10use serde_json::json;
11use swiftide_core::chat_completion::{
12 errors::LanguageModelError, ChatCompletion, ChatCompletionRequest, ChatCompletionResponse,
13 ChatMessage, ToolCall, ToolSpec,
14};
15
16use super::openai_error_to_language_model_error;
17use super::GenericOpenAI;
18
19#[async_trait]
20impl<C: async_openai::config::Config + std::default::Default + Sync + Send + std::fmt::Debug>
21 ChatCompletion for GenericOpenAI<C>
22{
23 #[tracing::instrument(skip_all)]
24 async fn complete(
25 &self,
26 request: &ChatCompletionRequest,
27 ) -> Result<ChatCompletionResponse, LanguageModelError> {
28 let model = self
29 .default_options
30 .prompt_model
31 .as_ref()
32 .context("Model not set")?;
33
34 let messages = request
35 .messages()
36 .iter()
37 .map(message_to_openai)
38 .collect::<Result<Vec<_>>>()?;
39
40 let mut openai_request = CreateChatCompletionRequestArgs::default()
42 .model(model)
43 .messages(messages)
44 .to_owned();
45
46 if !request.tools_spec.is_empty() {
47 openai_request
48 .tools(
49 request
50 .tools_spec()
51 .iter()
52 .map(tools_to_openai)
53 .collect::<Result<Vec<_>>>()?,
54 )
55 .tool_choice("auto");
56 if let Some(par) = self.default_options.parallel_tool_calls {
57 openai_request.parallel_tool_calls(par);
58 }
59 }
60
61 let request = openai_request
62 .build()
63 .map_err(openai_error_to_language_model_error)?;
64
65 tracing::debug!(
66 model = &model,
67 request = serde_json::to_string_pretty(&request).expect("infallible"),
68 "Sending request to OpenAI"
69 );
70
71 let response = self
72 .client
73 .chat()
74 .create(request)
75 .await
76 .map_err(openai_error_to_language_model_error)?;
77
78 tracing::debug!(
79 response = serde_json::to_string_pretty(&response).expect("infallible"),
80 "Received response from OpenAI"
81 );
82
83 ChatCompletionResponse::builder()
84 .maybe_message(
85 response
86 .choices
87 .first()
88 .and_then(|choice| choice.message.content.clone()),
89 )
90 .maybe_tool_calls(
91 response
92 .choices
93 .first()
94 .and_then(|choice| choice.message.tool_calls.clone())
95 .map(|tool_calls| {
96 tool_calls
97 .iter()
98 .map(|tool_call| {
99 ToolCall::builder()
100 .id(tool_call.id.clone())
101 .args(tool_call.function.arguments.clone())
102 .name(tool_call.function.name.clone())
103 .build()
104 .expect("infallible")
105 })
106 .collect_vec()
107 }),
108 )
109 .build()
110 .map_err(LanguageModelError::from)
111 }
112}
113
114fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
117 let mut properties = serde_json::Map::new();
118
119 for param in &spec.parameters {
120 properties.insert(
121 param.name.to_string(),
122 json!({
123 "type": param.ty.as_ref(),
124 "description": ¶m.description,
125 }),
126 );
127 }
128
129 ChatCompletionToolArgs::default()
130 .r#type(ChatCompletionToolType::Function)
131 .function(FunctionObjectArgs::default()
132 .name(&spec.name)
133 .description(&spec.description)
134 .strict(true)
135 .parameters(json!({
136 "type": "object",
137 "properties": properties,
138 "required": spec.parameters.iter().filter(|param| param.required).map(|param| ¶m.name).collect_vec(),
139 "additionalProperties": false,
140 })).build()?).build()
141 .map_err(anyhow::Error::from)
142}
143
144fn message_to_openai(
145 message: &ChatMessage,
146) -> Result<async_openai::types::ChatCompletionRequestMessage> {
147 let openai_message = match message {
148 ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
149 .content(msg.as_str())
150 .build()?
151 .into(),
152 ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
153 .content(msg.as_str())
154 .build()?
155 .into(),
156 ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
157 .content(msg.as_str())
158 .build()?
159 .into(),
160 ChatMessage::ToolOutput(tool_call, tool_output) => {
161 let Some(content) = tool_output.content() else {
162 return Ok(ChatCompletionRequestToolMessageArgs::default()
163 .tool_call_id(tool_call.id())
164 .build()?
165 .into());
166 };
167
168 ChatCompletionRequestToolMessageArgs::default()
169 .content(content)
170 .tool_call_id(tool_call.id())
171 .build()?
172 .into()
173 }
174 ChatMessage::Assistant(msg, tool_calls) => {
175 let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
176
177 if let Some(msg) = msg {
178 builder.content(msg.as_str());
179 }
180
181 if let Some(tool_calls) = tool_calls {
182 builder.tool_calls(
183 tool_calls
184 .iter()
185 .map(|tool_call| ChatCompletionMessageToolCall {
186 id: tool_call.id().to_string(),
187 r#type: ChatCompletionToolType::Function,
188 function: FunctionCall {
189 name: tool_call.name().to_string(),
190 arguments: tool_call.args().unwrap_or_default().to_string(),
191 },
192 })
193 .collect::<Vec<_>>(),
194 );
195 }
196
197 builder.build()?.into()
198 }
199 };
200
201 Ok(openai_message)
202}