swiftide_integrations/ollama/
mod.rs

1//! This module provides integration with `Ollama`'s API, enabling the use of language models and
2//! embeddings within the Swiftide project. It includes the `Ollama` struct for managing API clients
3//! and default options for embedding and prompt models. The module is conditionally compiled based
4//! on the "ollama" feature flag.
5
6use config::OllamaConfig;
7use derive_builder::Builder;
8use std::sync::Arc;
9
10pub mod chat_completion;
11pub mod config;
12pub mod embed;
13pub mod simple_prompt;
14
/// The `Ollama` struct encapsulates an `Ollama` client and default options for embedding and prompt
/// models. It uses the `Builder` pattern for flexible and customizable instantiation.
///
/// Note that either a prompt model or embedding model always needs to be set, either with
/// [`Ollama::with_default_prompt_model`] or [`Ollama::with_default_embed_model`] or via the
/// builder. You can find available models in the Ollama documentation.
///
/// NOTE(review): earlier docs here claimed an `OLLAMA_API_KEY` environment variable is read by
/// default — nothing in this module reads it, and it looks copied from the OpenAI integration.
/// Confirm against `config::OllamaConfig` before documenting that behavior.
///
/// Under the hood it uses [`async_openai`], with the Ollama openai mapping. This means
/// some features might not work as expected. See the Ollama documentation for details.
#[derive(Debug, Builder, Clone)]
#[builder(setter(into, strip_option))]
pub struct Ollama {
    /// The `Ollama` client, wrapped in an `Arc` for thread-safe reference counting.
    /// Uses a custom builder setter (see [`OllamaBuilder::client`]) so callers pass a
    /// plain client and the wrapping happens internally.
    #[builder(default = "default_client()", setter(custom))]
    client: Arc<async_openai::Client<OllamaConfig>>,
    /// Default options for the embedding and prompt models; both are optional.
    #[builder(default)]
    default_options: Options,
}
35
36impl Default for Ollama {
37    fn default() -> Self {
38        Self {
39            client: default_client(),
40            default_options: Options::default(),
41        }
42    }
43}
44
/// The `Options` struct holds configuration options for the `Ollama` client.
/// It includes optional fields for specifying the embedding and prompt models.
#[derive(Debug, Default, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct Options {
    /// The default embedding model to use, if specified.
    /// `None` means no embedding model has been configured yet.
    #[builder(default)]
    pub embed_model: Option<String>,

    /// The default prompt model to use, if specified.
    /// `None` means no prompt model has been configured yet.
    #[builder(default)]
    pub prompt_model: Option<String>,
}
58
59impl Options {
60    /// Creates a new `OptionsBuilder` for constructing `Options` instances.
61    pub fn builder() -> OptionsBuilder {
62        OptionsBuilder::default()
63    }
64}
65
66impl Ollama {
67    /// Creates a new `OllamaBuilder` for constructing `Ollama` instances.
68    pub fn builder() -> OllamaBuilder {
69        OllamaBuilder::default()
70    }
71
72    /// Sets a default prompt model to use when prompting
73    pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
74        self.default_options = Options {
75            prompt_model: Some(model.into()),
76            embed_model: self.default_options.embed_model.clone(),
77        };
78        self
79    }
80
81    /// Sets a default embedding model to use when embedding
82    pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
83        self.default_options = Options {
84            prompt_model: self.default_options.prompt_model.clone(),
85            embed_model: Some(model.into()),
86        };
87        self
88    }
89}
90
91impl OllamaBuilder {
92    /// Sets the `Ollama` client for the `Ollama` instance.
93    ///
94    /// # Parameters
95    /// - `client`: The `Ollama` client to set.
96    ///
97    /// # Returns
98    /// A mutable reference to the `OllamaBuilder`.
99    pub fn client(&mut self, client: async_openai::Client<OllamaConfig>) -> &mut Self {
100        self.client = Some(Arc::new(client));
101        self
102    }
103
104    /// Sets the default embedding model for the `Ollama` instance.
105    ///
106    /// # Parameters
107    /// - `model`: The embedding model to set.
108    ///
109    /// # Returns
110    /// A mutable reference to the `OllamaBuilder`.
111    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
112        if let Some(options) = self.default_options.as_mut() {
113            options.embed_model = Some(model.into());
114        } else {
115            self.default_options = Some(Options {
116                embed_model: Some(model.into()),
117                ..Default::default()
118            });
119        }
120        self
121    }
122
123    /// Sets the default prompt model for the `Ollama` instance.
124    ///
125    /// # Parameters
126    /// - `model`: The prompt model to set.
127    ///
128    /// # Returns
129    /// A mutable reference to the `OllamaBuilder`.
130    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
131        if let Some(options) = self.default_options.as_mut() {
132            options.prompt_model = Some(model.into());
133        } else {
134            self.default_options = Some(Options {
135                prompt_model: Some(model.into()),
136                ..Default::default()
137            });
138        }
139        self
140    }
141}
142
143fn default_client() -> Arc<async_openai::Client<OllamaConfig>> {
144    Arc::new(async_openai::Client::with_config(OllamaConfig::default()))
145}
146
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_default_prompt_model() {
        // Renamed the binding from `openai` to `ollama` for consistency with the
        // other tests; it was a copy-paste leftover from the OpenAI integration.
        let ollama = Ollama::builder()
            .default_prompt_model("llama3.1")
            .build()
            .unwrap();
        assert_eq!(
            ollama.default_options.prompt_model,
            Some("llama3.1".to_string())
        );
    }

    #[test]
    fn test_default_embed_model() {
        let ollama = Ollama::builder()
            .default_embed_model("mxbai-embed-large")
            .build()
            .unwrap();
        assert_eq!(
            ollama.default_options.embed_model,
            Some("mxbai-embed-large".to_string())
        );
    }

    #[test]
    fn test_default_models() {
        let ollama = Ollama::builder()
            .default_embed_model("mxbai-embed-large")
            .default_prompt_model("llama3.1")
            .build()
            .unwrap();
        assert_eq!(
            ollama.default_options.embed_model,
            Some("mxbai-embed-large".to_string())
        );
        assert_eq!(
            ollama.default_options.prompt_model,
            Some("llama3.1".to_string())
        );
    }

    #[test]
    fn test_building_via_default_prompt_model() {
        let mut client = Ollama::default();

        assert!(client.default_options.prompt_model.is_none());

        client.with_default_prompt_model("llama3.1");
        assert_eq!(
            client.default_options.prompt_model,
            Some("llama3.1".to_string())
        );
    }

    #[test]
    fn test_building_via_default_embed_model() {
        let mut client = Ollama::default();

        assert!(client.default_options.embed_model.is_none());

        client.with_default_embed_model("mxbai-embed-large");
        assert_eq!(
            client.default_options.embed_model,
            Some("mxbai-embed-large".to_string())
        );
    }

    #[test]
    fn test_building_via_default_models() {
        let mut client = Ollama::default();

        assert!(client.default_options.embed_model.is_none());

        client.with_default_prompt_model("llama3.1");
        client.with_default_embed_model("mxbai-embed-large");
        assert_eq!(
            client.default_options.prompt_model,
            Some("llama3.1".to_string())
        );
        assert_eq!(
            client.default_options.embed_model,
            Some("mxbai-embed-large".to_string())
        );
    }
}
235}