skill_runtime/embeddings/
ollama.rs1use super::EmbeddingProvider;
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use rig::embeddings::EmbeddingModel as RigEmbeddingModel;
10use rig::client::{EmbeddingsClient, ProviderClient, Nothing};
11use rig::providers::ollama::Client as OllamaClient;
12use std::sync::Arc;
13
14pub const DEFAULT_OLLAMA_MODEL: &str = "nomic-embed-text";
16
17pub const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
19
/// Returns the embedding vector width for a known Ollama model name.
///
/// Unrecognized models fall back to 768 — the width of the default
/// `nomic-embed-text` model. Callers needing a different width for a
/// custom model should use `OllamaProvider::with_dimensions`.
fn get_model_dimensions(model: &str) -> usize {
    // Known model → dimension table; linear scan is fine at this size.
    const KNOWN_MODELS: [(&str, usize); 4] = [
        ("nomic-embed-text", 768),
        ("mxbai-embed-large", 1024),
        ("all-minilm", 384),
        ("snowflake-arctic-embed", 1024),
    ];
    KNOWN_MODELS
        .iter()
        .find(|(name, _)| *name == model)
        .map(|&(_, dims)| dims)
        .unwrap_or(768)
}
30
/// Embedding provider backed by an Ollama server, implemented on top of
/// the `rig` Ollama client.
pub struct OllamaProvider {
    // Shared rig client used to issue embedding requests.
    client: Arc<OllamaClient>,
    // Ollama embedding model name (e.g. "nomic-embed-text").
    model: String,
    // Expected embedding vector width for `model`.
    dims: usize,
    // Base URL of the Ollama server, exposed via `base_url()`.
    // NOTE(review): this field is stored and reported, but the constructors
    // never pass it to `client` — confirm the client actually targets a
    // non-default URL when one is supplied.
    base_url: String,
}
41
42impl OllamaProvider {
43 pub fn new() -> Result<Self> {
45 Self::with_model(DEFAULT_OLLAMA_MODEL)
46 }
47
48 pub fn with_model(model: &str) -> Result<Self> {
50 let client = Arc::new(OllamaClient::from_val(Nothing));
51 let dims = get_model_dimensions(model);
52
53 Ok(Self {
54 client,
55 model: model.to_string(),
56 dims,
57 base_url: DEFAULT_OLLAMA_URL.to_string(),
58 })
59 }
60
61 pub fn with_url(base_url: &str, model: &str) -> Result<Self> {
63 let client = Arc::new(OllamaClient::from_val(Nothing));
66 let dims = get_model_dimensions(model);
67
68 Ok(Self {
69 client,
70 model: model.to_string(),
71 dims,
72 base_url: base_url.to_string(),
73 })
74 }
75
76 pub fn with_dimensions(model: &str, dims: usize) -> Result<Self> {
78 let client = Arc::new(OllamaClient::from_val(Nothing));
79
80 Ok(Self {
81 client,
82 model: model.to_string(),
83 dims,
84 base_url: DEFAULT_OLLAMA_URL.to_string(),
85 })
86 }
87
88 pub fn base_url(&self) -> &str {
90 &self.base_url
91 }
92
93}
94
impl Default for OllamaProvider {
    /// Builds a provider with the default model and URL.
    ///
    /// As written, `new()` always returns `Ok`, so this `expect` cannot fire
    /// unless construction becomes fallible in the future.
    fn default() -> Self {
        Self::new().expect("Failed to create default Ollama provider")
    }
}
100
101#[async_trait]
102impl EmbeddingProvider for OllamaProvider {
103 async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
104 if texts.is_empty() {
105 return Ok(Vec::new());
106 }
107
108 let embedding_model = self.client.embedding_model(&self.model);
109
110 let embeddings = embedding_model
112 .embed_texts(texts)
113 .await
114 .context("Ollama failed to generate embeddings. Is the server running?")?;
115
116 let results: Vec<Vec<f32>> = embeddings
118 .into_iter()
119 .map(|emb| emb.vec.into_iter().map(|x| x as f32).collect())
120 .collect();
121
122 if let Some(first) = results.first() {
124 if first.len() != self.dims {
125 tracing::warn!(
126 "Ollama model {} returned {} dimensions, expected {}",
127 self.model,
128 first.len(),
129 self.dims
130 );
131 }
132 }
133
134 Ok(results)
135 }
136
137 fn dimensions(&self) -> usize {
138 self.dims
139 }
140
141 fn model_name(&self) -> &str {
142 &self.model
143 }
144
145 fn provider_name(&self) -> &str {
146 "ollama"
147 }
148
149 fn max_batch_size(&self) -> usize {
150 100
152 }
153
154 async fn health_check(&self) -> Result<bool> {
155 match self.embed_query("test").await {
157 Ok(_) => Ok(true),
158 Err(e) => {
159 tracing::debug!("Ollama health check failed: {}", e);
160 Ok(false)
161 }
162 }
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Every known model maps to its documented width; unknown names fall
    /// back to 768.
    #[test]
    fn test_model_dimensions() {
        assert_eq!(get_model_dimensions("nomic-embed-text"), 768);
        assert_eq!(get_model_dimensions("mxbai-embed-large"), 1024);
        assert_eq!(get_model_dimensions("all-minilm"), 384);
        // Previously untested mapping — covers the fourth match arm.
        assert_eq!(get_model_dimensions("snowflake-arctic-embed"), 1024);
        assert_eq!(get_model_dimensions("unknown-model"), 768);
    }

    /// Default construction reports the default model, width, and URL.
    #[test]
    fn test_provider_creation() {
        let provider = OllamaProvider::new().unwrap();
        assert_eq!(provider.model_name(), "nomic-embed-text");
        assert_eq!(provider.dimensions(), 768);
        assert_eq!(provider.provider_name(), "ollama");
        assert_eq!(provider.base_url(), DEFAULT_OLLAMA_URL);
    }

    /// A caller-supplied base URL is stored and reported verbatim.
    #[test]
    fn test_custom_url() {
        let provider = OllamaProvider::with_url("http://custom:11434", "nomic-embed-text").unwrap();
        assert_eq!(provider.base_url(), "http://custom:11434");
    }

    /// Explicit dimensions override the name-based inference.
    #[test]
    fn test_custom_dimensions() {
        let provider = OllamaProvider::with_dimensions("custom-model", 512).unwrap();
        assert_eq!(provider.dimensions(), 512);
        assert_eq!(provider.model_name(), "custom-model");
    }

    /// End-to-end embedding; needs a live server, so ignored by default.
    #[tokio::test]
    #[ignore = "requires running Ollama server"]
    async fn test_embed_documents() {
        let provider = OllamaProvider::new().unwrap();
        let texts = vec![
            "Hello world".to_string(),
            "How are you".to_string(),
        ];

        let embeddings = provider.embed_documents(texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);
    }

    /// Empty input must short-circuit without touching the server, so this
    /// runs even with no Ollama instance available.
    #[tokio::test]
    async fn test_embed_empty() {
        let provider = OllamaProvider::new().unwrap();
        let embeddings = provider.embed_documents(vec![]).await.unwrap();
        assert!(embeddings.is_empty());
    }
}