// src/embeddings/ollama.rs
//! Ollama embedding provider.
//!
//! Uses local Ollama server for embedding generation.
//! This is the recommended provider for local development.

use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};

use super::config::{resolve_ollama_endpoint, resolve_ollama_model};
use super::provider::EmbeddingProvider;
use super::types::{ollama_models, ProviderInfo};

/// Ollama embedding provider.
///
/// Generates embeddings by calling a local Ollama HTTP server
/// (`/api/tags` for liveness, `/api/embed` for vectors).
pub struct OllamaProvider {
    // Reusable HTTP client; reqwest pools connections across requests.
    client: reqwest::Client,
    // Base URL of the Ollama server (resolved from config when not given).
    endpoint: String,
    // Name of the embedding model to request.
    model: String,
    // Output vector length for `model`, from `ollama_models::get_config`.
    dimensions: usize,
    // Input size limit (characters) for `model`, from the same config.
    max_chars: usize,
}

22impl OllamaProvider {
23    /// Create a new Ollama provider with default configuration.
24    pub fn new() -> Self {
25        Self::with_config(None, None)
26    }
27
28    /// Create a new Ollama provider with custom configuration.
29    pub fn with_config(endpoint: Option<String>, model: Option<String>) -> Self {
30        let endpoint = endpoint.unwrap_or_else(resolve_ollama_endpoint);
31        let model = model.unwrap_or_else(resolve_ollama_model);
32        let config = ollama_models::get_config(&model);
33
34        Self {
35            client: reqwest::Client::new(),
36            endpoint,
37            model,
38            dimensions: config.dimensions,
39            max_chars: config.max_chars,
40        }
41    }
42}
43
/// Delegates to [`OllamaProvider::new`] (resolved default configuration).
impl Default for OllamaProvider {
    fn default() -> Self {
        Self::new()
    }
}

/// Ollama API response for listing models.
#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
    // `None` when the server omits the field; treated as "no models".
    models: Option<Vec<OllamaModel>>,
}

/// A single model entry from the `/api/tags` listing.
#[derive(Debug, Deserialize)]
struct OllamaModel {
    // Model identifier, possibly with a tag suffix (e.g. "name:latest").
    name: String,
}

/// Ollama API request for embedding.
#[derive(Debug, Serialize)]
struct OllamaEmbedRequest<'a> {
    // Model name to embed with; borrowed from the provider.
    model: &'a str,
    // Text(s) to embed; see `EmbedInput` for the wire format.
    input: EmbedInput<'a>,
}

/// Input payload for `/api/embed`.
///
/// `untagged` means this serializes as either a bare JSON string
/// (`Single`) or a JSON array of strings (`Batch`) — no enum tag.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum EmbedInput<'a> {
    Single(&'a str),
    Batch(Vec<&'a str>),
}

/// Ollama API response for embedding.
#[derive(Debug, Deserialize)]
struct OllamaEmbedResponse {
    // One vector per input text, in request order.
    embeddings: Vec<Vec<f32>>,
}

81impl EmbeddingProvider for OllamaProvider {
82    fn info(&self) -> ProviderInfo {
83        ProviderInfo {
84            name: "ollama".to_string(),
85            model: self.model.clone(),
86            dimensions: self.dimensions,
87            max_chars: self.max_chars,
88            available: false, // Will be checked by is_available()
89        }
90    }
91
92    async fn is_available(&self) -> bool {
93        let url = format!("{}/api/tags", self.endpoint);
94
95        let response = match self.client
96            .get(&url)
97            .timeout(std::time::Duration::from_secs(2))
98            .send()
99            .await
100        {
101            Ok(r) => r,
102            Err(_) => return false,
103        };
104
105        if !response.status().is_success() {
106            return false;
107        }
108
109        let data: OllamaTagsResponse = match response.json().await {
110            Ok(d) => d,
111            Err(_) => return false,
112        };
113
114        // Check if our model is available
115        data.models.map_or(false, |models| {
116            models.iter().any(|m| {
117                m.name == self.model || m.name.starts_with(&format!("{}:", self.model))
118            })
119        })
120    }
121
122    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
123        let url = format!("{}/api/embed", self.endpoint);
124
125        let request = OllamaEmbedRequest {
126            model: &self.model,
127            input: EmbedInput::Single(text),
128        };
129
130        let response = self.client
131            .post(&url)
132            .json(&request)
133            .send()
134            .await
135            .map_err(|e| Error::Embedding(format!("Ollama request failed: {e}")))?;
136
137        if !response.status().is_success() {
138            let error = response.text().await.unwrap_or_default();
139            return Err(Error::Embedding(format!("Ollama embedding failed: {error}")));
140        }
141
142        let data: OllamaEmbedResponse = response.json().await
143            .map_err(|e| Error::Embedding(format!("Failed to parse Ollama response: {e}")))?;
144
145        data.embeddings.into_iter().next()
146            .ok_or_else(|| Error::Embedding("No embeddings returned from Ollama".into()))
147    }
148
149    async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
150        let url = format!("{}/api/embed", self.endpoint);
151
152        let request = OllamaEmbedRequest {
153            model: &self.model,
154            input: EmbedInput::Batch(texts.to_vec()),
155        };
156
157        let response = self.client
158            .post(&url)
159            .json(&request)
160            .send()
161            .await
162            .map_err(|e| Error::Embedding(format!("Ollama batch request failed: {e}")))?;
163
164        if !response.status().is_success() {
165            let error = response.text().await.unwrap_or_default();
166            return Err(Error::Embedding(format!("Ollama batch embedding failed: {error}")));
167        }
168
169        let data: OllamaEmbedResponse = response.json().await
170            .map_err(|e| Error::Embedding(format!("Failed to parse Ollama response: {e}")))?;
171
172        Ok(data.embeddings)
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Default construction yields the "ollama" provider with a non-empty
    /// model name and positive dimensionality (no server contact needed —
    /// `info()` is purely local).
    #[test]
    fn test_ollama_provider_creation() {
        let provider = OllamaProvider::new();
        let info = provider.info();
        assert_eq!(info.name, "ollama");
        assert!(!info.model.is_empty());
        assert!(info.dimensions > 0);
    }

    /// Explicit endpoint/model overrides are honored, and the dimension
    /// count for the named model comes from `ollama_models::get_config`.
    #[test]
    fn test_ollama_provider_custom_config() {
        let provider = OllamaProvider::with_config(
            Some("http://custom:11434".to_string()),
            Some("mxbai-embed-large".to_string()),
        );
        let info = provider.info();
        assert_eq!(info.model, "mxbai-embed-large");
        assert_eq!(info.dimensions, 1024);
    }
}
199}