skill_runtime/embeddings/
ollama.rs

1//! Ollama embedding provider implementation
2//!
3//! Uses rig-core's Ollama client for local LLM-based embeddings.
4//! Requires a running Ollama server (default: http://localhost:11434).
5
6use super::EmbeddingProvider;
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use rig::embeddings::EmbeddingModel as RigEmbeddingModel;
10use rig::client::{EmbeddingsClient, ProviderClient, Nothing};
11use rig::providers::ollama::Client as OllamaClient;
12use std::sync::Arc;
13
/// Default Ollama embedding model.
///
/// `nomic-embed-text` maps to 768 dimensions in `get_model_dimensions`.
pub const DEFAULT_OLLAMA_MODEL: &str = "nomic-embed-text";

/// Default Ollama server URL (Ollama's standard local endpoint on port 11434).
pub const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
19
/// Embedding dimensionality for known Ollama embedding models.
///
/// Unrecognized model names fall back to 768; callers that know better can
/// override via `OllamaProvider::with_dimensions`.
fn get_model_dimensions(model: &str) -> usize {
    match model {
        "all-minilm" => 384,
        "mxbai-embed-large" | "snowflake-arctic-embed" => 1024,
        // "nomic-embed-text" is 768, as is the fallback for unknown models.
        _ => 768,
    }
}
30
/// Ollama embedding provider
///
/// Generates embeddings via a local Ollama server.
/// Requires Ollama to be running with an embedding model pulled.
pub struct OllamaProvider {
    /// rig-core Ollama client used to issue embedding requests.
    client: Arc<OllamaClient>,
    /// Name of the Ollama embedding model (e.g. "nomic-embed-text").
    model: String,
    /// Expected embedding dimensionality, from the known-model table or an
    /// explicit override; a mismatch with server output only logs a warning.
    dims: usize,
    /// NOTE(review): reported via `base_url()` only — the rig client is built
    /// from `Nothing` and does not read this field (see `with_url`).
    base_url: String,
}
41
42impl OllamaProvider {
43    /// Create a new Ollama provider with the default model (nomic-embed-text)
44    pub fn new() -> Result<Self> {
45        Self::with_model(DEFAULT_OLLAMA_MODEL)
46    }
47
48    /// Create a new Ollama provider with a specific model
49    pub fn with_model(model: &str) -> Result<Self> {
50        let client = Arc::new(OllamaClient::from_val(Nothing));
51        let dims = get_model_dimensions(model);
52
53        Ok(Self {
54            client,
55            model: model.to_string(),
56            dims,
57            base_url: DEFAULT_OLLAMA_URL.to_string(),
58        })
59    }
60
61    /// Create with a custom base URL
62    pub fn with_url(base_url: &str, model: &str) -> Result<Self> {
63        // Note: rig's Ollama client uses OLLAMA_API_BASE env var or default
64        // For custom URL, we still create with Nothing and hope the env is set
65        let client = Arc::new(OllamaClient::from_val(Nothing));
66        let dims = get_model_dimensions(model);
67
68        Ok(Self {
69            client,
70            model: model.to_string(),
71            dims,
72            base_url: base_url.to_string(),
73        })
74    }
75
76    /// Create with custom dimensions (for models not in our known list)
77    pub fn with_dimensions(model: &str, dims: usize) -> Result<Self> {
78        let client = Arc::new(OllamaClient::from_val(Nothing));
79
80        Ok(Self {
81            client,
82            model: model.to_string(),
83            dims,
84            base_url: DEFAULT_OLLAMA_URL.to_string(),
85        })
86    }
87
88    /// Get the base URL
89    pub fn base_url(&self) -> &str {
90        &self.base_url
91    }
92
93}
94
impl Default for OllamaProvider {
    /// Equivalent to [`OllamaProvider::new`] (default model, default URL).
    ///
    /// # Panics
    /// Panics if construction fails. As written, `new` only ever returns
    /// `Ok`, so this does not panic in practice; the `expect` guards against
    /// construction becoming fallible later.
    fn default() -> Self {
        Self::new().expect("Failed to create default Ollama provider")
    }
}
100
101#[async_trait]
102impl EmbeddingProvider for OllamaProvider {
103    async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
104        if texts.is_empty() {
105            return Ok(Vec::new());
106        }
107
108        let embedding_model = self.client.embedding_model(&self.model);
109
110        // Use rig's embed method
111        let embeddings = embedding_model
112            .embed_texts(texts)
113            .await
114            .context("Ollama failed to generate embeddings. Is the server running?")?;
115
116        // Convert from rig's Embedding type to Vec<f32>
117        let results: Vec<Vec<f32>> = embeddings
118            .into_iter()
119            .map(|emb| emb.vec.into_iter().map(|x| x as f32).collect())
120            .collect();
121
122        // Update dimensions if we got a different size (model might have changed)
123        if let Some(first) = results.first() {
124            if first.len() != self.dims {
125                tracing::warn!(
126                    "Ollama model {} returned {} dimensions, expected {}",
127                    self.model,
128                    first.len(),
129                    self.dims
130                );
131            }
132        }
133
134        Ok(results)
135    }
136
137    fn dimensions(&self) -> usize {
138        self.dims
139    }
140
141    fn model_name(&self) -> &str {
142        &self.model
143    }
144
145    fn provider_name(&self) -> &str {
146        "ollama"
147    }
148
149    fn max_batch_size(&self) -> usize {
150        // Ollama processes one at a time internally, but we batch for convenience
151        100
152    }
153
154    async fn health_check(&self) -> Result<bool> {
155        // Try to embed a simple text to verify server is running
156        match self.embed_query("test").await {
157            Ok(_) => Ok(true),
158            Err(e) => {
159                tracing::debug!("Ollama health check failed: {}", e);
160                Ok(false)
161            }
162        }
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_dimensions() {
        // Table-driven check of the known-model table plus the fallback.
        let cases = [
            ("nomic-embed-text", 768),
            ("mxbai-embed-large", 1024),
            ("all-minilm", 384),
            ("unknown-model", 768), // unknown names fall back to the default
        ];
        for (model, expected) in cases {
            assert_eq!(get_model_dimensions(model), expected);
        }
    }

    #[test]
    fn test_provider_creation() {
        let provider = OllamaProvider::new().unwrap();
        assert_eq!(provider.model_name(), "nomic-embed-text");
        assert_eq!(provider.dimensions(), 768);
        assert_eq!(provider.provider_name(), "ollama");
        assert_eq!(provider.base_url(), DEFAULT_OLLAMA_URL);
    }

    #[test]
    fn test_custom_url() {
        let url = "http://custom:11434";
        let provider = OllamaProvider::with_url(url, "nomic-embed-text").unwrap();
        assert_eq!(provider.base_url(), url);
    }

    #[test]
    fn test_custom_dimensions() {
        let provider = OllamaProvider::with_dimensions("custom-model", 512).unwrap();
        assert_eq!(provider.model_name(), "custom-model");
        assert_eq!(provider.dimensions(), 512);
    }

    // Integration test - requires running Ollama server
    #[tokio::test]
    #[ignore = "requires running Ollama server"]
    async fn test_embed_documents() {
        let provider = OllamaProvider::new().unwrap();
        let texts: Vec<String> = ["Hello world", "How are you"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let embeddings = provider.embed_documents(texts).await.unwrap();
        assert_eq!(embeddings.len(), 2);
    }

    #[tokio::test]
    async fn test_embed_empty() {
        // An empty batch short-circuits before any network call.
        let provider = OllamaProvider::new().unwrap();
        let embeddings = provider.embed_documents(Vec::new()).await.unwrap();
        assert!(embeddings.is_empty());
    }
}