swiftide_integrations/ollama/mod.rs
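//! Ollama integration built on `async_openai` with a custom [`OllamaConfig`].
//!
//! Exposes the [`Ollama`] client together with [`Options`] for its default
//! embedding and prompt models, alongside the `chat_completion`, `config`,
//! `embed`, and `simple_prompt` submodules.
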
use config::OllamaConfig;
use derive_builder::Builder;
use std::sync::Arc;

pub mod chat_completion;
pub mod config;
pub mod embed;
pub mod simple_prompt;

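/// Client for Ollama's OpenAI-compatible API: an `async_openai` client
/// configured with [`OllamaConfig`] plus default embedding and prompt models.
///
/// A minimal usage sketch; the model names are examples and the import path
/// assumes this module is exposed as `swiftide_integrations::ollama`:
///
/// ```no_run
/// use swiftide_integrations::ollama::Ollama;
///
/// let ollama = Ollama::builder()
///     .default_prompt_model("llama3.1")
///     .default_embed_model("mxbai-embed-large")
///     .build()
///     .unwrap();
/// ```
///
/// A preconfigured `async_openai::Client<OllamaConfig>` can be supplied
/// through the builder's custom [`OllamaBuilder::client`] setter instead of
/// relying on the default client.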
#[derive(Debug, Builder, Clone)]
#[builder(setter(into, strip_option))]
pub struct Ollama {
    /// The `async_openai` client used to reach the Ollama API.
    #[builder(default = "default_client()", setter(custom))]
    client: Arc<async_openai::Client<OllamaConfig>>,
    /// Default embedding and prompt models.
    #[builder(default)]
    default_options: Options,
}

impl Default for Ollama {
    fn default() -> Self {
        Self {
            client: default_client(),
            default_options: Options::default(),
        }
    }
}

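/// Default model options for [`Ollama`]: which models to use for embeddings
/// and for prompts.
///
/// A sketch of building `Options` through its derived builder and attaching it
/// to an [`Ollama`]; the `default_options` setter is the one derive_builder
/// generates for that field (assumed here), and the import path mirrors the
/// example above:
///
/// ```no_run
/// use swiftide_integrations::ollama::{Ollama, Options};
///
/// let options = Options::builder()
///     .prompt_model("llama3.1")
///     .embed_model("mxbai-embed-large")
///     .build()
///     .unwrap();
///
/// let ollama = Ollama::builder()
///     .default_options(options)
///     .build()
///     .unwrap();
/// ```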
#[derive(Debug, Default, Clone, Builder)]
#[builder(setter(into, strip_option))]
pub struct Options {
    /// Model used for embeddings, e.g. `mxbai-embed-large`.
    #[builder(default)]
    pub embed_model: Option<String>,

    /// Model used for prompts, e.g. `llama3.1`.
    #[builder(default)]
    pub prompt_model: Option<String>,
}

impl Options {
    /// Returns a builder for [`Options`].
    pub fn builder() -> OptionsBuilder {
        OptionsBuilder::default()
    }
}

impl Ollama {
    /// Returns a builder for [`Ollama`].
    pub fn builder() -> OllamaBuilder {
        OllamaBuilder::default()
    }

    /// Sets the default prompt model on an existing instance, keeping the
    /// current embedding model.
    pub fn with_default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            prompt_model: Some(model.into()),
            embed_model: self.default_options.embed_model.clone(),
        };
        self
    }

    /// Sets the default embedding model on an existing instance, keeping the
    /// current prompt model.
    pub fn with_default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        self.default_options = Options {
            prompt_model: self.default_options.prompt_model.clone(),
            embed_model: Some(model.into()),
        };
        self
    }
}

impl OllamaBuilder {
    /// Sets a custom `async_openai` client for the Ollama API; the builder
    /// wraps it in an `Arc`.
    pub fn client(&mut self, client: async_openai::Client<OllamaConfig>) -> &mut Self {
        self.client = Some(Arc::new(client));
        self
    }

    /// Sets the default embedding model, initializing the options if they have
    /// not been set yet.
    pub fn default_embed_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.embed_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                embed_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }

    /// Sets the default prompt model, initializing the options if they have
    /// not been set yet.
    pub fn default_prompt_model(&mut self, model: impl Into<String>) -> &mut Self {
        if let Some(options) = self.default_options.as_mut() {
            options.prompt_model = Some(model.into());
        } else {
            self.default_options = Some(Options {
                prompt_model: Some(model.into()),
                ..Default::default()
            });
        }
        self
    }
}

/// Builds the default `async_openai` client using the default `OllamaConfig`.
fn default_client() -> Arc<async_openai::Client<OllamaConfig>> {
    Arc::new(async_openai::Client::with_config(OllamaConfig::default()))
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_default_prompt_model() {
        let ollama = Ollama::builder()
            .default_prompt_model("llama3.1")
            .build()
            .unwrap();
        assert_eq!(
            ollama.default_options.prompt_model,
            Some("llama3.1".to_string())
        );
    }

    #[test]
    fn test_default_embed_model() {
        let ollama = Ollama::builder()
            .default_embed_model("mxbai-embed-large")
            .build()
            .unwrap();
        assert_eq!(
            ollama.default_options.embed_model,
            Some("mxbai-embed-large".to_string())
        );
    }

    #[test]
    fn test_default_models() {
        let ollama = Ollama::builder()
            .default_embed_model("mxbai-embed-large")
            .default_prompt_model("llama3.1")
            .build()
            .unwrap();
        assert_eq!(
            ollama.default_options.embed_model,
            Some("mxbai-embed-large".to_string())
        );
        assert_eq!(
            ollama.default_options.prompt_model,
            Some("llama3.1".to_string())
        );
    }

    #[test]
    fn test_building_via_default_prompt_model() {
        let mut client = Ollama::default();

        assert!(client.default_options.prompt_model.is_none());

        client.with_default_prompt_model("llama3.1");
        assert_eq!(
            client.default_options.prompt_model,
            Some("llama3.1".to_string())
        );
    }

    #[test]
    fn test_building_via_default_embed_model() {
        let mut client = Ollama::default();

        assert!(client.default_options.embed_model.is_none());

        client.with_default_embed_model("mxbai-embed-large");
        assert_eq!(
            client.default_options.embed_model,
            Some("mxbai-embed-large".to_string())
        );
    }

    #[test]
    fn test_building_via_default_models() {
        let mut client = Ollama::default();

        assert!(client.default_options.embed_model.is_none());

        client.with_default_prompt_model("llama3.1");
        client.with_default_embed_model("mxbai-embed-large");
        assert_eq!(
            client.default_options.prompt_model,
            Some("llama3.1".to_string())
        );
        assert_eq!(
            client.default_options.embed_model,
            Some("mxbai-embed-large".to_string())
        );
    }
}