swiftide_integrations/ollama/
chat_completion.rs

1use anyhow::{Context as _, Result};
2use async_openai::types::{
3    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
4    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
5    ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
6    ChatCompletionToolType, CreateChatCompletionRequestArgs, FunctionCall, FunctionObjectArgs,
7};
8use async_trait::async_trait;
9use itertools::Itertools;
10use serde_json::json;
11use swiftide_core::chat_completion::{
12    errors::LanguageModelError, ChatCompletion, ChatCompletionRequest, ChatCompletionResponse,
13    ChatMessage, ToolCall, ToolSpec,
14};
15
16use crate::openai::openai_error_to_language_model_error;
17
18use super::Ollama;
19
20#[async_trait]
21impl ChatCompletion for Ollama {
22    #[tracing::instrument(skip_all)]
23    async fn complete(
24        &self,
25        request: &ChatCompletionRequest,
26    ) -> Result<ChatCompletionResponse, LanguageModelError> {
27        let model = self
28            .default_options
29            .prompt_model
30            .as_ref()
31            .context("Model not set")?;
32
33        let messages = request
34            .messages()
35            .iter()
36            .map(message_to_openai)
37            .collect::<Result<Vec<_>>>()?;
38
39        // Build the request to be sent to the OpenAI API.
40        let mut openai_request = CreateChatCompletionRequestArgs::default()
41            .model(model)
42            .messages(messages)
43            .to_owned();
44
45        if !request.tools_spec.is_empty() {
46            openai_request
47                .tools(
48                    request
49                        .tools_spec()
50                        .iter()
51                        .map(tools_to_openai)
52                        .collect::<Result<Vec<_>>>()?,
53                )
54                .tool_choice("auto")
55                .parallel_tool_calls(true);
56        }
57
58        let request = openai_request
59            .build()
60            .map_err(openai_error_to_language_model_error)?;
61
62        tracing::debug!(
63            model = &model,
64            request = serde_json::to_string_pretty(&request).expect("infallible"),
65            "Sending request to Ollama"
66        );
67
68        let response = self
69            .client
70            .chat()
71            .create(request)
72            .await
73            .map_err(openai_error_to_language_model_error)?;
74
75        tracing::debug!(
76            response = serde_json::to_string_pretty(&response).expect("infallible"),
77            "Received response from Ollama"
78        );
79
80        ChatCompletionResponse::builder()
81            .maybe_message(
82                response
83                    .choices
84                    .first()
85                    .and_then(|choice| choice.message.content.clone()),
86            )
87            .maybe_tool_calls(
88                response
89                    .choices
90                    .first()
91                    .and_then(|choice| choice.message.tool_calls.clone())
92                    .map(|tool_calls| {
93                        tool_calls
94                            .iter()
95                            .map(|tool_call| {
96                                ToolCall::builder()
97                                    .id(tool_call.id.clone())
98                                    .args(tool_call.function.arguments.clone())
99                                    .name(tool_call.function.name.clone())
100                                    .build()
101                                    .expect("infallible")
102                            })
103                            .collect_vec()
104                    }),
105            )
106            .build()
107            .map_err(LanguageModelError::from)
108    }
109}
110
// TODO: Maybe just implement `Into` for the whole conversion? Note: the types involved are not defined in this crate.
112
113fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
114    let mut properties = serde_json::Map::new();
115
116    for param in &spec.parameters {
117        properties.insert(
118            param.name.to_string(),
119            json!({
120                "type": param.ty.as_ref(),
121                "description": &param.description,
122            }),
123        );
124    }
125
126    ChatCompletionToolArgs::default()
127        .r#type(ChatCompletionToolType::Function)
128        .function(FunctionObjectArgs::default()
129            .name(&spec.name)
130            .description(&spec.description)
131            .parameters(json!({
132                "type": "object",
133                "properties": properties,
134                "required": spec.parameters.iter().filter(|param| param.required).map(|param| &param.name).collect_vec(),
135                "additionalProperties": false,
136            })).build()?).build()
137        .map_err(anyhow::Error::from)
138}
139
140fn message_to_openai(
141    message: &ChatMessage,
142) -> Result<async_openai::types::ChatCompletionRequestMessage> {
143    let openai_message = match message {
144        ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
145            .content(msg.as_str())
146            .build()?
147            .into(),
148        ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
149            .content(msg.as_str())
150            .build()?
151            .into(),
152        ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
153            .content(msg.as_str())
154            .build()?
155            .into(),
156        ChatMessage::ToolOutput(tool_call, tool_output) => {
157            let Some(content) = tool_output.content() else {
158                return Ok(ChatCompletionRequestToolMessageArgs::default()
159                    .tool_call_id(tool_call.id())
160                    .build()?
161                    .into());
162            };
163
164            ChatCompletionRequestToolMessageArgs::default()
165                .content(content)
166                .tool_call_id(tool_call.id())
167                .build()?
168                .into()
169        }
170        ChatMessage::Assistant(msg, tool_calls) => {
171            let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
172
173            if let Some(msg) = msg {
174                builder.content(msg.as_str());
175            }
176
177            if let Some(tool_calls) = tool_calls {
178                builder.tool_calls(
179                    tool_calls
180                        .iter()
181                        .map(|tool_call| ChatCompletionMessageToolCall {
182                            id: tool_call.id().to_string(),
183                            r#type: ChatCompletionToolType::Function,
184                            function: FunctionCall {
185                                name: tool_call.name().to_string(),
186                                arguments: tool_call.args().unwrap_or_default().to_string(),
187                            },
188                        })
189                        .collect::<Vec<_>>(),
190                );
191            }
192
193            builder.build()?.into()
194        }
195    };
196
197    Ok(openai_message)
198}