swiftide_integrations/ollama/
simple_prompt.rs

1//! This module provides an implementation of the `SimplePrompt` trait for the `Ollama` struct.
2//! It defines an asynchronous function to interact with the `Ollama` API, allowing prompt
3//! processing and generating responses as part of the Swiftide system.
4use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
5use async_trait::async_trait;
6use swiftide_core::{
7    chat_completion::errors::LanguageModelError, prompt::Prompt, util::debug_long_utf8,
8    SimplePrompt,
9};
10
11use crate::openai::openai_error_to_language_model_error;
12
13use super::Ollama;
14use anyhow::{Context as _, Result};
15
16/// The `SimplePrompt` trait defines a method for sending a prompt to an AI model and receiving a
17/// response.
18#[async_trait]
19impl SimplePrompt for Ollama {
20    /// Sends a prompt to the Ollama API and returns the response content.
21    ///
22    /// # Parameters
23    /// - `prompt`: A string slice that holds the prompt to be sent to the Ollama API.
24    ///
25    /// # Returns
26    /// - `Result<String>`: On success, returns the content of the response as a `String`. On
27    ///   failure, returns an error wrapped in a `Result`.
28    ///
29    /// # Errors
30    /// - Returns an error if the model is not set in the default options.
31    /// - Returns an error if the request to the Ollama API fails.
32    /// - Returns an error if the response does not contain the expected content.
33    #[tracing::instrument(skip_all, err)]
34    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
35        // Retrieve the model from the default options, returning an error if not set.
36        let model = self
37            .default_options
38            .prompt_model
39            .as_ref()
40            .context("Model not set")?;
41
42        // Build the request to be sent to the Ollama API.
43        let request = CreateChatCompletionRequestArgs::default()
44            .model(model)
45            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
46                .content(prompt.render()?)
47                .build()
48                .map_err(openai_error_to_language_model_error)?
49                .into()])
50            .build()
51            .map_err(openai_error_to_language_model_error)?;
52
53        // Log the request for debugging purposes.
54        tracing::debug!(
55            model = &model,
56            messages = debug_long_utf8(
57                serde_json::to_string_pretty(&request.messages.first())
58                    .map_err(LanguageModelError::permanent)?,
59                100
60            ),
61            "[SimplePrompt] Request to ollama"
62        );
63
64        // Send the request to the Ollama API and await the response.
65        let response = self
66            .client
67            .chat()
68            .create(request)
69            .await
70            .map_err(openai_error_to_language_model_error)?
71            .choices
72            .remove(0)
73            .message
74            .content
75            .take()
76            .context("Expected content in response")?;
77
78        // Log the response for debugging purposes.
79        tracing::debug!(
80            response = debug_long_utf8(&response, 100),
81            "[SimplePrompt] Response from ollama"
82        );
83
84        // Extract and return the content of the response, returning an error if not found.
85        Ok(response)
86    }
87}