swiftide_integrations/ollama/chat_completion.rs

use anyhow::{Context as _, Result};
use async_openai::types::{
    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
    ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
    ChatCompletionToolType, CreateChatCompletionRequestArgs, FunctionCall, FunctionObjectArgs,
};
use async_trait::async_trait;
use itertools::Itertools;
use serde_json::json;
use swiftide_core::chat_completion::{
    errors::LanguageModelError, ChatCompletion, ChatCompletionRequest, ChatCompletionResponse,
    ChatMessage, ToolCall, ToolSpec,
};

use crate::openai::openai_error_to_language_model_error;

use super::Ollama;

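/// [`ChatCompletion`] for Ollama, which exposes an OpenAI-compatible chat API;
/// requests and responses are translated through `async_openai` types.
///
/// Illustrative usage below is a minimal sketch: the constructor and
/// request-builder calls are assumptions about the surrounding crate, not
/// verbatim API.
///
/// ```ignore
/// let ollama = Ollama::default(); // assumed: a client with a default prompt model set
/// let request = ChatCompletionRequest::builder()
///     .messages(vec![ChatMessage::User("Why is the sky blue?".into())])
///     .build()?;
/// let response = ollama.complete(&request).await?;
/// ```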
#[async_trait]
impl ChatCompletion for Ollama {
    #[tracing::instrument(skip_all)]
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
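        // Completion requires a configured prompt model; fail early if none is set.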
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .context("Model not set")?;

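        // Convert Swiftide chat messages into OpenAI-compatible request messages.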
        let messages = request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()?;

        let mut openai_request = CreateChatCompletionRequestArgs::default()
            .model(model)
            .messages(messages)
            .to_owned();

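        // Attach tool definitions only when the request carries tool specs; the
        // model decides which tools to call and may call several in parallel.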
        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()?,
                )
                .tool_choice("auto")
                .parallel_tool_calls(true);
        }

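        // Builder validation failures surface as LanguageModelError, just like
        // the transport errors below.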
        let request = openai_request
            .build()
            .map_err(openai_error_to_language_model_error)?;

        tracing::debug!(
            model = &model,
            request = serde_json::to_string_pretty(&request).expect("infallible"),
            "Sending request to Ollama"
        );

        let response = self
            .client
            .chat()
            .create(request)
            .await
            .map_err(openai_error_to_language_model_error)?;

        tracing::debug!(
            response = serde_json::to_string_pretty(&response).expect("infallible"),
            "Received response from Ollama"
        );

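        // Map the first choice into Swiftide's response: optional assistant
        // content plus any tool calls the model requested.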
        ChatCompletionResponse::builder()
            .maybe_message(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.content.clone()),
            )
            .maybe_tool_calls(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.tool_calls.clone())
                    .map(|tool_calls| {
                        tool_calls
                            .iter()
                            .map(|tool_call| {
                                ToolCall::builder()
                                    .id(tool_call.id.clone())
                                    .args(tool_call.function.arguments.clone())
                                    .name(tool_call.function.name.clone())
                                    .build()
                                    .expect("infallible")
                            })
                            .collect_vec()
                    }),
            )
            .build()
            .map_err(LanguageModelError::from)
    }
}

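/// Converts a Swiftide [`ToolSpec`] into an OpenAI-compatible function tool
/// definition. For a hypothetical spec with a single required string
/// parameter `query`, the generated parameter schema looks like:
///
/// ```json
/// {
///   "type": "object",
///   "properties": { "query": { "type": "string", "description": "..." } },
///   "required": ["query"],
///   "additionalProperties": false
/// }
/// ```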
fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
    let mut properties = serde_json::Map::new();

    for param in &spec.parameters {
        properties.insert(
            param.name.to_string(),
            json!({
                "type": param.ty.as_ref(),
                "description": &param.description,
            }),
        );
    }

    ChatCompletionToolArgs::default()
        .r#type(ChatCompletionToolType::Function)
        .function(
            FunctionObjectArgs::default()
                .name(&spec.name)
                .description(&spec.description)
                .parameters(json!({
                    "type": "object",
                    "properties": properties,
                    "required": spec.parameters.iter().filter(|param| param.required).map(|param| &param.name).collect_vec(),
                    "additionalProperties": false,
                }))
                .build()?,
        )
        .build()
        .map_err(anyhow::Error::from)
}

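/// Maps a Swiftide [`ChatMessage`] onto the corresponding OpenAI-compatible
/// request message: user, system, and assistant messages map directly, and
/// summaries are sent as assistant messages.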
fn message_to_openai(
    message: &ChatMessage,
) -> Result<async_openai::types::ChatCompletionRequestMessage> {
    let openai_message = match message {
        ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
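        // A tool output without content still yields a tool message, so the
        // original tool call id is always acknowledged.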
        ChatMessage::ToolOutput(tool_call, tool_output) => {
            let Some(content) = tool_output.content() else {
                return Ok(ChatCompletionRequestToolMessageArgs::default()
                    .tool_call_id(tool_call.id())
                    .build()?
                    .into());
            };

            ChatCompletionRequestToolMessageArgs::default()
                .content(content)
                .tool_call_id(tool_call.id())
                .build()?
                .into()
        }
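        // Assistant messages may carry content, tool calls, or both; rebuild the
        // OpenAI-compatible tool-call structs from Swiftide's ToolCall type.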
        ChatMessage::Assistant(msg, tool_calls) => {
            let mut builder = ChatCompletionRequestAssistantMessageArgs::default();

            if let Some(msg) = msg {
                builder.content(msg.as_str());
            }

            if let Some(tool_calls) = tool_calls {
                builder.tool_calls(
                    tool_calls
                        .iter()
                        .map(|tool_call| ChatCompletionMessageToolCall {
                            id: tool_call.id().to_string(),
                            r#type: ChatCompletionToolType::Function,
                            function: FunctionCall {
                                name: tool_call.name().to_string(),
                                arguments: tool_call.args().unwrap_or_default().to_string(),
                            },
                        })
                        .collect::<Vec<_>>(),
                );
            }

            builder.build()?.into()
        }
    };

    Ok(openai_message)
}