swiftide_integrations/ollama/
mod.rs

//! This module provides integration with `Ollama`'s API, enabling the use of language models and
//! embeddings within the Swiftide project. It exposes the `Ollama` client type (together with its
//! builder), which holds default options for the embedding and prompt models. The module is
//! conditionally compiled based on the "ollama" feature flag.
5
6use config::OllamaConfig;
7
8use crate::openai;
9
10pub mod config;
11
12/// The `Ollama` struct encapsulates an `Ollama` client and default options for embedding and prompt
13/// models. It uses the `Builder` pattern for flexible and customizable instantiation.
14///
15/// By default it will look for a `OLLAMA_API_KEY` environment variable. Note that either a prompt
16/// model or embedding model always need to be set, either with
17/// [`Ollama::with_default_prompt_model`] or [`Ollama::with_default_embed_model`] or via the
18/// builder. You can find available models in the Ollama documentation.
19///
20/// Under the hood it uses [`async_openai`], with the Ollama openai mapping. This means
21/// some features might not work as expected. See the Ollama documentation for details.
22pub type Ollama = openai::GenericOpenAI<OllamaConfig>;
23pub type OllamaBuilder = openai::GenericOpenAIBuilder<OllamaConfig>;
24pub type OllamaBuilderError = openai::GenericOpenAIBuilderError;
25pub use openai::{Options, OptionsBuilder, OptionsBuilderError};
26
27impl Ollama {
28    /// Build a new `Ollama` instance
29    pub fn builder() -> OllamaBuilder {
30        OllamaBuilder::default()
31    }
32}
33impl Default for Ollama {
34    fn default() -> Self {
35        Self::builder().build().unwrap()
36    }
37}
38
#[cfg(test)]
mod test {
    use super::*;

    /// Prompt model name used throughout these tests.
    const PROMPT_MODEL: &str = "llama3.1";
    /// Embedding model name used throughout these tests.
    const EMBED_MODEL: &str = "mxbai-embed-large";

    #[test]
    fn test_default_prompt_model() {
        let ollama = Ollama::builder()
            .default_prompt_model(PROMPT_MODEL)
            .build()
            .unwrap();
        assert_eq!(ollama.default_options.prompt_model.as_deref(), Some(PROMPT_MODEL));
    }

    #[test]
    fn test_default_embed_model() {
        let ollama = Ollama::builder()
            .default_embed_model(EMBED_MODEL)
            .build()
            .unwrap();
        assert_eq!(ollama.default_options.embed_model.as_deref(), Some(EMBED_MODEL));
    }

    #[test]
    fn test_default_models() {
        let ollama = Ollama::builder()
            .default_embed_model(EMBED_MODEL)
            .default_prompt_model(PROMPT_MODEL)
            .build()
            .unwrap();
        assert_eq!(ollama.default_options.embed_model.as_deref(), Some(EMBED_MODEL));
        assert_eq!(ollama.default_options.prompt_model.as_deref(), Some(PROMPT_MODEL));
    }

    #[test]
    fn test_building_via_default_prompt_model() {
        let mut client = Ollama::default();

        assert!(client.default_options.prompt_model.is_none());

        client.with_default_prompt_model(PROMPT_MODEL);
        assert_eq!(client.default_options.prompt_model.as_deref(), Some(PROMPT_MODEL));
    }

    #[test]
    fn test_building_via_default_embed_model() {
        let mut client = Ollama::default();

        assert!(client.default_options.embed_model.is_none());

        client.with_default_embed_model(EMBED_MODEL);
        assert_eq!(client.default_options.embed_model.as_deref(), Some(EMBED_MODEL));
    }

    #[test]
    fn test_building_via_default_models() {
        let mut client = Ollama::default();

        // Both models start unset on a default client.
        assert!(client.default_options.prompt_model.is_none());
        assert!(client.default_options.embed_model.is_none());

        client.with_default_prompt_model(PROMPT_MODEL);
        client.with_default_embed_model(EMBED_MODEL);
        assert_eq!(client.default_options.prompt_model.as_deref(), Some(PROMPT_MODEL));
        assert_eq!(client.default_options.embed_model.as_deref(), Some(EMBED_MODEL));
    }
}
127}