skill_runtime/embeddings/
types.rs

1//! Types for embedding configuration and providers
2
3use serde::{Deserialize, Serialize};
4
/// Configuration for embedding providers.
///
/// Build one via [`Default`] or the provider-specific constructors
/// (`fastembed`, `openai`, `ollama`), then refine it with the `with_*`
/// builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Provider type: "fastembed", "openai", "ollama"
    pub provider: EmbeddingProviderType,

    /// Model name/identifier (provider-specific); `None` lets the provider
    /// fall back to its own default.
    #[serde(default)]
    pub model: Option<String>,

    /// API key (for cloud providers). Deliberately `skip_serializing` so the
    /// secret never leaks into serialized/persisted config.
    #[serde(default, skip_serializing)]
    pub api_key: Option<String>,

    /// Base URL (for self-hosted or custom endpoints)
    #[serde(default)]
    pub base_url: Option<String>,

    /// Maximum batch size for document embedding (defaults to 100 via
    /// `default_batch_size` when absent from the serialized form).
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
}
27
/// Default maximum batch size for document embedding requests.
const DEFAULT_BATCH_SIZE: usize = 100;

/// Serde default hook for [`EmbeddingConfig::batch_size`].
fn default_batch_size() -> usize {
    DEFAULT_BATCH_SIZE
}
31
32impl Default for EmbeddingConfig {
33    fn default() -> Self {
34        Self {
35            provider: EmbeddingProviderType::FastEmbed,
36            model: None,
37            api_key: None,
38            base_url: None,
39            batch_size: default_batch_size(),
40        }
41    }
42}
43
44impl EmbeddingConfig {
45    /// Create a FastEmbed configuration
46    pub fn fastembed() -> Self {
47        Self {
48            provider: EmbeddingProviderType::FastEmbed,
49            model: Some(FastEmbedModel::AllMiniLM.to_string()),
50            ..Default::default()
51        }
52    }
53
54    /// Create a FastEmbed configuration with a specific model
55    pub fn fastembed_with_model(model: FastEmbedModel) -> Self {
56        Self {
57            provider: EmbeddingProviderType::FastEmbed,
58            model: Some(model.to_string()),
59            ..Default::default()
60        }
61    }
62
63    /// Create an OpenAI configuration
64    pub fn openai() -> Self {
65        Self {
66            provider: EmbeddingProviderType::OpenAI,
67            model: Some(OpenAIEmbeddingModel::Ada002.to_string()),
68            api_key: std::env::var("OPENAI_API_KEY").ok(),
69            ..Default::default()
70        }
71    }
72
73    /// Create an OpenAI configuration with a specific model
74    pub fn openai_with_model(model: OpenAIEmbeddingModel) -> Self {
75        Self {
76            provider: EmbeddingProviderType::OpenAI,
77            model: Some(model.to_string()),
78            api_key: std::env::var("OPENAI_API_KEY").ok(),
79            ..Default::default()
80        }
81    }
82
83    /// Create an Ollama configuration
84    pub fn ollama() -> Self {
85        Self {
86            provider: EmbeddingProviderType::Ollama,
87            model: Some("nomic-embed-text".to_string()),
88            base_url: Some("http://localhost:11434".to_string()),
89            ..Default::default()
90        }
91    }
92
93    /// Create an Ollama configuration with a specific model
94    pub fn ollama_with_model(model: &str) -> Self {
95        Self {
96            provider: EmbeddingProviderType::Ollama,
97            model: Some(model.to_string()),
98            base_url: Some("http://localhost:11434".to_string()),
99            ..Default::default()
100        }
101    }
102
103    /// Set the API key
104    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
105        self.api_key = Some(api_key.into());
106        self
107    }
108
109    /// Set the base URL
110    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
111        self.base_url = Some(base_url.into());
112        self
113    }
114
115    /// Set the batch size
116    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
117        self.batch_size = batch_size;
118        self
119    }
120
121    /// Set the model
122    pub fn with_model(mut self, model: impl Into<String>) -> Self {
123        self.model = Some(model.into());
124        self
125    }
126}
127
/// Supported embedding provider types.
///
/// Serialized in lowercase ("fastembed", "openai", "ollama") via
/// `rename_all`, matching the strings accepted by its `FromStr` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProviderType {
    /// Local FastEmbed (ONNX-based); the default provider.
    #[default]
    FastEmbed,

    /// OpenAI API
    OpenAI,

    /// Ollama local server
    Ollama,
}
142
143impl std::fmt::Display for EmbeddingProviderType {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        match self {
146            Self::FastEmbed => write!(f, "fastembed"),
147            Self::OpenAI => write!(f, "openai"),
148            Self::Ollama => write!(f, "ollama"),
149        }
150    }
151}
152
153impl std::str::FromStr for EmbeddingProviderType {
154    type Err = anyhow::Error;
155
156    fn from_str(s: &str) -> Result<Self, Self::Err> {
157        match s.to_lowercase().as_str() {
158            "fastembed" | "fast_embed" | "fast-embed" => Ok(Self::FastEmbed),
159            "openai" | "open_ai" | "open-ai" => Ok(Self::OpenAI),
160            "ollama" => Ok(Self::Ollama),
161            _ => Err(anyhow::anyhow!(
162                "Unknown embedding provider: {}. Supported: fastembed, openai, ollama",
163                s
164            )),
165        }
166    }
167}
168
/// FastEmbed model variants.
///
/// NOTE(review): serde uses the variant names ("AllMiniLM", …) since there is
/// no `rename_all`, while `Display`/`FromStr` use kebab-case ("all-minilm") —
/// confirm this asymmetry is intended before persisting these values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum FastEmbedModel {
    /// all-MiniLM-L6-v2 (Quantized) - 384 dimensions, fastest; the default.
    #[default]
    AllMiniLM,

    /// BGE-small-en-v1.5 (Quantized) - 384 dimensions, good quality
    BGESmallEN,

    /// BGE-base-en-v1.5 - 768 dimensions, better quality
    BGEBaseEN,

    /// BGE-large-en-v1.5 - 1024 dimensions, best quality
    BGELargeEN,
}
185
186impl FastEmbedModel {
187    /// Get the embedding dimensions for this model
188    pub fn dimensions(&self) -> usize {
189        match self {
190            Self::AllMiniLM => 384,
191            Self::BGESmallEN => 384,
192            Self::BGEBaseEN => 768,
193            Self::BGELargeEN => 1024,
194        }
195    }
196
197    /// Get the model name as used by rig-fastembed
198    pub fn rig_model_name(&self) -> &'static str {
199        match self {
200            Self::AllMiniLM => "all-minilm",
201            Self::BGESmallEN => "bge-small",
202            Self::BGEBaseEN => "bge-base",
203            Self::BGELargeEN => "bge-large",
204        }
205    }
206}
207
208impl std::fmt::Display for FastEmbedModel {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        match self {
211            Self::AllMiniLM => write!(f, "all-minilm"),
212            Self::BGESmallEN => write!(f, "bge-small"),
213            Self::BGEBaseEN => write!(f, "bge-base"),
214            Self::BGELargeEN => write!(f, "bge-large"),
215        }
216    }
217}
218
219impl std::str::FromStr for FastEmbedModel {
220    type Err = anyhow::Error;
221
222    fn from_str(s: &str) -> Result<Self, Self::Err> {
223        match s.to_lowercase().as_str() {
224            "all-minilm" | "allminilm" | "minilm" => Ok(Self::AllMiniLM),
225            "bge-small" | "bgesmall" | "bge-small-en" => Ok(Self::BGESmallEN),
226            "bge-base" | "bgebase" | "bge-base-en" => Ok(Self::BGEBaseEN),
227            "bge-large" | "bgelarge" | "bge-large-en" => Ok(Self::BGELargeEN),
228            _ => Err(anyhow::anyhow!(
229                "Unknown FastEmbed model: {}. Supported: all-minilm, bge-small, bge-base, bge-large",
230                s
231            )),
232        }
233    }
234}
235
/// OpenAI embedding model variants.
///
/// NOTE(review): serde uses the variant names ("Ada002", …) since there is no
/// `rename_all`, while `Display` emits the API names ("text-embedding-ada-002")
/// — confirm this asymmetry is intended before persisting these values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum OpenAIEmbeddingModel {
    /// text-embedding-ada-002 - 1536 dimensions (legacy, widely supported);
    /// the default.
    #[default]
    Ada002,

    /// text-embedding-3-small - 1536 dimensions (newer, better)
    TextEmbedding3Small,

    /// text-embedding-3-large - 3072 dimensions (best quality)
    TextEmbedding3Large,
}
249
250impl OpenAIEmbeddingModel {
251    /// Get the embedding dimensions for this model
252    pub fn dimensions(&self) -> usize {
253        match self {
254            Self::Ada002 => 1536,
255            Self::TextEmbedding3Small => 1536,
256            Self::TextEmbedding3Large => 3072,
257        }
258    }
259
260    /// Get the model name as used by OpenAI API
261    pub fn api_name(&self) -> &'static str {
262        match self {
263            Self::Ada002 => "text-embedding-ada-002",
264            Self::TextEmbedding3Small => "text-embedding-3-small",
265            Self::TextEmbedding3Large => "text-embedding-3-large",
266        }
267    }
268}
269
270impl std::fmt::Display for OpenAIEmbeddingModel {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        write!(f, "{}", self.api_name())
273    }
274}
275
276impl std::str::FromStr for OpenAIEmbeddingModel {
277    type Err = anyhow::Error;
278
279    fn from_str(s: &str) -> Result<Self, Self::Err> {
280        match s.to_lowercase().as_str() {
281            "ada-002" | "text-embedding-ada-002" | "ada" => Ok(Self::Ada002),
282            "3-small" | "text-embedding-3-small" | "embedding-3-small" => {
283                Ok(Self::TextEmbedding3Small)
284            }
285            "3-large" | "text-embedding-3-large" | "embedding-3-large" => {
286                Ok(Self::TextEmbedding3Large)
287            }
288            _ => Err(anyhow::anyhow!(
289                "Unknown OpenAI embedding model: {}. Supported: ada-002, 3-small, 3-large",
290                s
291            )),
292        }
293    }
294}
295
/// Embedding result with metadata.
///
/// Returned by embedding providers; pairs the raw vector with the model that
/// produced it and, when the provider reports it, the token count consumed.
#[derive(Debug, Clone)]
pub struct EmbeddingResult {
    /// The embedding vector
    pub embedding: Vec<f32>,

    /// Token count used (if available; `None` when the provider does not
    /// report usage, e.g. local models)
    pub tokens_used: Option<usize>,

    /// Model used for embedding
    pub model: String,
}
308
309impl EmbeddingResult {
310    pub fn new(embedding: Vec<f32>, model: impl Into<String>) -> Self {
311        Self {
312            embedding,
313            tokens_used: None,
314            model: model.into(),
315        }
316    }
317
318    pub fn with_tokens(mut self, tokens: usize) -> Self {
319        self.tokens_used = Some(tokens);
320        self
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fastembed_model_dimensions() {
        assert_eq!(FastEmbedModel::AllMiniLM.dimensions(), 384);
        assert_eq!(FastEmbedModel::BGESmallEN.dimensions(), 384);
        assert_eq!(FastEmbedModel::BGEBaseEN.dimensions(), 768);
        assert_eq!(FastEmbedModel::BGELargeEN.dimensions(), 1024);
    }

    #[test]
    fn test_openai_model_dimensions() {
        assert_eq!(OpenAIEmbeddingModel::Ada002.dimensions(), 1536);
        assert_eq!(OpenAIEmbeddingModel::TextEmbedding3Small.dimensions(), 1536);
        assert_eq!(OpenAIEmbeddingModel::TextEmbedding3Large.dimensions(), 3072);
    }

    #[test]
    fn test_provider_type_parsing() {
        assert_eq!(
            "fastembed".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::FastEmbed
        );
        assert_eq!(
            "openai".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::OpenAI
        );
        assert_eq!(
            "ollama".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::Ollama
        );
    }

    // New: Display output must round-trip through FromStr for every variant.
    #[test]
    fn test_provider_display_parse_roundtrip() {
        for provider in [
            EmbeddingProviderType::FastEmbed,
            EmbeddingProviderType::OpenAI,
            EmbeddingProviderType::Ollama,
        ] {
            assert_eq!(
                provider.to_string().parse::<EmbeddingProviderType>().unwrap(),
                provider
            );
        }
    }

    #[test]
    fn test_fastembed_model_parsing() {
        assert_eq!(
            "all-minilm".parse::<FastEmbedModel>().unwrap(),
            FastEmbedModel::AllMiniLM
        );
        assert_eq!(
            "bge-small".parse::<FastEmbedModel>().unwrap(),
            FastEmbedModel::BGESmallEN
        );
    }

    // New: error paths — unknown names must be rejected, not defaulted.
    #[test]
    fn test_unknown_names_rejected() {
        assert!("bogus".parse::<EmbeddingProviderType>().is_err());
        assert!("bogus".parse::<FastEmbedModel>().is_err());
        assert!("bogus".parse::<OpenAIEmbeddingModel>().is_err());
    }

    // New: Default must agree with the serde default hook and provider default.
    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, EmbeddingProviderType::FastEmbed);
        assert_eq!(config.batch_size, 100);
        assert!(config.model.is_none());
        assert!(config.api_key.is_none());
        assert!(config.base_url.is_none());
    }

    #[test]
    fn test_embedding_config_builders() {
        let config = EmbeddingConfig::fastembed();
        assert_eq!(config.provider, EmbeddingProviderType::FastEmbed);

        let config = EmbeddingConfig::openai_with_model(OpenAIEmbeddingModel::TextEmbedding3Large);
        assert_eq!(config.provider, EmbeddingProviderType::OpenAI);
        assert_eq!(config.model, Some("text-embedding-3-large".to_string()));

        let config = EmbeddingConfig::ollama().with_base_url("http://custom:11434");
        assert_eq!(config.base_url, Some("http://custom:11434".to_string()));
    }
}