swiftide_integrations/ollama/simple_prompt.rs
//! This module provides an implementation of the `SimplePrompt` trait for the `Ollama` struct.
//! It defines an asynchronous function that sends prompts to the `Ollama` API and returns the
//! generated responses as part of the Swiftide system.
use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
use async_trait::async_trait;
use swiftide_core::{
    chat_completion::errors::LanguageModelError, prompt::Prompt, util::debug_long_utf8,
    SimplePrompt,
};

use crate::openai::openai_error_to_language_model_error;

use super::Ollama;
use anyhow::{Context as _, Result};

/// Implements the `SimplePrompt` trait for `Ollama`: sends a prompt to an AI model and returns
/// its response.
#[async_trait]
impl SimplePrompt for Ollama {
    /// Sends a prompt to the Ollama API and returns the response content.
    ///
    /// # Parameters
    /// - `prompt`: The `Prompt` to be rendered and sent to the Ollama API.
    ///
    /// # Returns
    /// - `Result<String, LanguageModelError>`: On success, the content of the response as a
    ///   `String`; on failure, a `LanguageModelError`.
    ///
    /// # Errors
    /// - Returns an error if the prompt model is not set in the default options.
    /// - Returns an error if the request to the Ollama API fails.
    /// - Returns an error if the response does not contain the expected content.
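    ///
    /// # Example
    ///
    /// A minimal sketch of calling this method. It assumes an `Ollama` client configured with a
    /// default prompt model; the builder method names below are illustrative assumptions and may
    /// differ from the actual API.
    ///
    /// ```ignore
    /// use swiftide_core::SimplePrompt;
    ///
    /// // Hypothetical setup; `builder()` and `default_prompt_model()` are assumptions.
    /// let ollama = Ollama::builder()
    ///     .default_prompt_model("llama3.1")
    ///     .build()?;
    ///
    /// let answer = ollama.prompt("Why is the sky blue?".into()).await?;
    /// println!("{answer}");
    /// ```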
    #[tracing::instrument(skip_all, err)]
    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
        // Retrieve the model from the default options, returning an error if not set.
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .context("Model not set")?;

        // Build the request to be sent to the Ollama API.
        let request = CreateChatCompletionRequestArgs::default()
            .model(model)
            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                .content(prompt.render()?)
                .build()
                .map_err(openai_error_to_language_model_error)?
                .into()])
            .build()
            .map_err(openai_error_to_language_model_error)?;

        // Log the request for debugging purposes, truncating long messages to 100 characters.
        tracing::debug!(
            model = &model,
            messages = debug_long_utf8(
                serde_json::to_string_pretty(&request.messages.first())
                    .map_err(LanguageModelError::permanent)?,
                100
            ),
            "[SimplePrompt] Request to ollama"
        );

        // Send the request to the Ollama API and await the response.
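        // Note: `remove(0)` assumes the response contains at least one choice; an empty
        // `choices` list would panic here.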
        let response = self
            .client
            .chat()
            .create(request)
            .await
            .map_err(openai_error_to_language_model_error)?
            .choices
            .remove(0)
            .message
            .content
            .take()
            .context("Expected content in response")?;

        // Log the response for debugging purposes.
        tracing::debug!(
            response = debug_long_utf8(&response, 100),
            "[SimplePrompt] Response from ollama"
        );

        // Return the response content.
        Ok(response)
    }
}