// skill_runtime/embeddings/types.rs
use serde::{Deserialize, Serialize};
4
/// Configuration for an embedding provider.
///
/// Serializable so it can be stored in config files. The API key is excluded
/// from serialization so secrets do not leak into saved configs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Which backend produces the embeddings.
    pub provider: EmbeddingProviderType,

    /// Model identifier; `None` means the provider's default model.
    #[serde(default)]
    pub model: Option<String>,

    /// Credential for hosted providers (e.g. OpenAI). Deserialized if present
    /// but never written back out (`skip_serializing`).
    #[serde(default, skip_serializing)]
    pub api_key: Option<String>,

    /// Endpoint override, used by self-hosted backends such as Ollama.
    #[serde(default)]
    pub base_url: Option<String>,

    /// Number of texts per embedding request; defaults to 100 when absent.
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
}
27
/// Serde default for [`EmbeddingConfig::batch_size`]: 100 texts per request.
fn default_batch_size() -> usize {
    100
}
31
32impl Default for EmbeddingConfig {
33 fn default() -> Self {
34 Self {
35 provider: EmbeddingProviderType::FastEmbed,
36 model: None,
37 api_key: None,
38 base_url: None,
39 batch_size: default_batch_size(),
40 }
41 }
42}
43
44impl EmbeddingConfig {
45 pub fn fastembed() -> Self {
47 Self {
48 provider: EmbeddingProviderType::FastEmbed,
49 model: Some(FastEmbedModel::AllMiniLM.to_string()),
50 ..Default::default()
51 }
52 }
53
54 pub fn fastembed_with_model(model: FastEmbedModel) -> Self {
56 Self {
57 provider: EmbeddingProviderType::FastEmbed,
58 model: Some(model.to_string()),
59 ..Default::default()
60 }
61 }
62
63 pub fn openai() -> Self {
65 Self {
66 provider: EmbeddingProviderType::OpenAI,
67 model: Some(OpenAIEmbeddingModel::Ada002.to_string()),
68 api_key: std::env::var("OPENAI_API_KEY").ok(),
69 ..Default::default()
70 }
71 }
72
73 pub fn openai_with_model(model: OpenAIEmbeddingModel) -> Self {
75 Self {
76 provider: EmbeddingProviderType::OpenAI,
77 model: Some(model.to_string()),
78 api_key: std::env::var("OPENAI_API_KEY").ok(),
79 ..Default::default()
80 }
81 }
82
83 pub fn ollama() -> Self {
85 Self {
86 provider: EmbeddingProviderType::Ollama,
87 model: Some("nomic-embed-text".to_string()),
88 base_url: Some("http://localhost:11434".to_string()),
89 ..Default::default()
90 }
91 }
92
93 pub fn ollama_with_model(model: &str) -> Self {
95 Self {
96 provider: EmbeddingProviderType::Ollama,
97 model: Some(model.to_string()),
98 base_url: Some("http://localhost:11434".to_string()),
99 ..Default::default()
100 }
101 }
102
103 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
105 self.api_key = Some(api_key.into());
106 self
107 }
108
109 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
111 self.base_url = Some(base_url.into());
112 self
113 }
114
115 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
117 self.batch_size = batch_size;
118 self
119 }
120
121 pub fn with_model(mut self, model: impl Into<String>) -> Self {
123 self.model = Some(model.into());
124 self
125 }
126}
127
/// Supported embedding backends.
///
/// Serialized in lowercase (`"fastembed"`, `"openai"`, `"ollama"`) to match
/// the `Display`/`FromStr` spellings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProviderType {
    /// Local, in-process embeddings; no network or credentials required.
    #[default]
    FastEmbed,

    /// Hosted OpenAI embeddings API (requires an API key).
    OpenAI,

    /// Self-hosted Ollama server (typically local).
    Ollama,
}
142
143impl std::fmt::Display for EmbeddingProviderType {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 match self {
146 Self::FastEmbed => write!(f, "fastembed"),
147 Self::OpenAI => write!(f, "openai"),
148 Self::Ollama => write!(f, "ollama"),
149 }
150 }
151}
152
153impl std::str::FromStr for EmbeddingProviderType {
154 type Err = anyhow::Error;
155
156 fn from_str(s: &str) -> Result<Self, Self::Err> {
157 match s.to_lowercase().as_str() {
158 "fastembed" | "fast_embed" | "fast-embed" => Ok(Self::FastEmbed),
159 "openai" | "open_ai" | "open-ai" => Ok(Self::OpenAI),
160 "ollama" => Ok(Self::Ollama),
161 _ => Err(anyhow::anyhow!(
162 "Unknown embedding provider: {}. Supported: fastembed, openai, ollama",
163 s
164 )),
165 }
166 }
167}
168
/// Models available through the local FastEmbed backend.
///
/// Output dimensions per model are defined in [`FastEmbedModel::dimensions`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum FastEmbedModel {
    /// all-MiniLM; 384-dimensional output.
    #[default]
    AllMiniLM,

    /// BGE small (English); 384-dimensional output.
    BGESmallEN,

    /// BGE base (English); 768-dimensional output.
    BGEBaseEN,

    /// BGE large (English); 1024-dimensional output.
    BGELargeEN,
}
185
186impl FastEmbedModel {
187 pub fn dimensions(&self) -> usize {
189 match self {
190 Self::AllMiniLM => 384,
191 Self::BGESmallEN => 384,
192 Self::BGEBaseEN => 768,
193 Self::BGELargeEN => 1024,
194 }
195 }
196
197 pub fn rig_model_name(&self) -> &'static str {
199 match self {
200 Self::AllMiniLM => "all-minilm",
201 Self::BGESmallEN => "bge-small",
202 Self::BGEBaseEN => "bge-base",
203 Self::BGELargeEN => "bge-large",
204 }
205 }
206}
207
208impl std::fmt::Display for FastEmbedModel {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 Self::AllMiniLM => write!(f, "all-minilm"),
212 Self::BGESmallEN => write!(f, "bge-small"),
213 Self::BGEBaseEN => write!(f, "bge-base"),
214 Self::BGELargeEN => write!(f, "bge-large"),
215 }
216 }
217}
218
219impl std::str::FromStr for FastEmbedModel {
220 type Err = anyhow::Error;
221
222 fn from_str(s: &str) -> Result<Self, Self::Err> {
223 match s.to_lowercase().as_str() {
224 "all-minilm" | "allminilm" | "minilm" => Ok(Self::AllMiniLM),
225 "bge-small" | "bgesmall" | "bge-small-en" => Ok(Self::BGESmallEN),
226 "bge-base" | "bgebase" | "bge-base-en" => Ok(Self::BGEBaseEN),
227 "bge-large" | "bgelarge" | "bge-large-en" => Ok(Self::BGELargeEN),
228 _ => Err(anyhow::anyhow!(
229 "Unknown FastEmbed model: {}. Supported: all-minilm, bge-small, bge-base, bge-large",
230 s
231 )),
232 }
233 }
234}
235
/// OpenAI embedding models supported by this crate.
///
/// Output dimensions and wire names are defined on the impl below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum OpenAIEmbeddingModel {
    /// text-embedding-ada-002; 1536-dimensional output.
    #[default]
    Ada002,

    /// text-embedding-3-small; 1536-dimensional output.
    TextEmbedding3Small,

    /// text-embedding-3-large; 3072-dimensional output.
    TextEmbedding3Large,
}
249
250impl OpenAIEmbeddingModel {
251 pub fn dimensions(&self) -> usize {
253 match self {
254 Self::Ada002 => 1536,
255 Self::TextEmbedding3Small => 1536,
256 Self::TextEmbedding3Large => 3072,
257 }
258 }
259
260 pub fn api_name(&self) -> &'static str {
262 match self {
263 Self::Ada002 => "text-embedding-ada-002",
264 Self::TextEmbedding3Small => "text-embedding-3-small",
265 Self::TextEmbedding3Large => "text-embedding-3-large",
266 }
267 }
268}
269
270impl std::fmt::Display for OpenAIEmbeddingModel {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 write!(f, "{}", self.api_name())
273 }
274}
275
276impl std::str::FromStr for OpenAIEmbeddingModel {
277 type Err = anyhow::Error;
278
279 fn from_str(s: &str) -> Result<Self, Self::Err> {
280 match s.to_lowercase().as_str() {
281 "ada-002" | "text-embedding-ada-002" | "ada" => Ok(Self::Ada002),
282 "3-small" | "text-embedding-3-small" | "embedding-3-small" => {
283 Ok(Self::TextEmbedding3Small)
284 }
285 "3-large" | "text-embedding-3-large" | "embedding-3-large" => {
286 Ok(Self::TextEmbedding3Large)
287 }
288 _ => Err(anyhow::anyhow!(
289 "Unknown OpenAI embedding model: {}. Supported: ada-002, 3-small, 3-large",
290 s
291 )),
292 }
293 }
294}
295
/// A single embedding produced by a provider.
#[derive(Debug, Clone)]
pub struct EmbeddingResult {
    /// The embedding vector itself.
    pub embedding: Vec<f32>,

    /// Token count reported by the provider, if any (hosted APIs report
    /// this; local backends may not).
    pub tokens_used: Option<usize>,

    /// Name of the model that produced the embedding.
    pub model: String,
}
308
309impl EmbeddingResult {
310 pub fn new(embedding: Vec<f32>, model: impl Into<String>) -> Self {
311 Self {
312 embedding,
313 tokens_used: None,
314 model: model.into(),
315 }
316 }
317
318 pub fn with_tokens(mut self, tokens: usize) -> Self {
319 self.tokens_used = Some(tokens);
320 self
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fastembed_model_dimensions() {
        // Table-driven: model paired with its expected vector width.
        let cases = [
            (FastEmbedModel::AllMiniLM, 384),
            (FastEmbedModel::BGESmallEN, 384),
            (FastEmbedModel::BGEBaseEN, 768),
            (FastEmbedModel::BGELargeEN, 1024),
        ];
        for (model, expected) in cases {
            assert_eq!(model.dimensions(), expected);
        }
    }

    #[test]
    fn test_openai_model_dimensions() {
        let cases = [
            (OpenAIEmbeddingModel::Ada002, 1536),
            (OpenAIEmbeddingModel::TextEmbedding3Small, 1536),
            (OpenAIEmbeddingModel::TextEmbedding3Large, 3072),
        ];
        for (model, expected) in cases {
            assert_eq!(model.dimensions(), expected);
        }
    }

    #[test]
    fn test_provider_type_parsing() {
        let cases = [
            ("fastembed", EmbeddingProviderType::FastEmbed),
            ("openai", EmbeddingProviderType::OpenAI),
            ("ollama", EmbeddingProviderType::Ollama),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<EmbeddingProviderType>().unwrap(), expected);
        }
    }

    #[test]
    fn test_fastembed_model_parsing() {
        let parsed: FastEmbedModel = "all-minilm".parse().unwrap();
        assert_eq!(parsed, FastEmbedModel::AllMiniLM);

        let parsed: FastEmbedModel = "bge-small".parse().unwrap();
        assert_eq!(parsed, FastEmbedModel::BGESmallEN);
    }

    #[test]
    fn test_embedding_config_builders() {
        let config = EmbeddingConfig::fastembed();
        assert_eq!(config.provider, EmbeddingProviderType::FastEmbed);

        let config = EmbeddingConfig::openai_with_model(OpenAIEmbeddingModel::TextEmbedding3Large);
        assert_eq!(config.provider, EmbeddingProviderType::OpenAI);
        assert_eq!(config.model.as_deref(), Some("text-embedding-3-large"));

        let config = EmbeddingConfig::ollama().with_base_url("http://custom:11434");
        assert_eq!(config.base_url.as_deref(), Some("http://custom:11434"));
    }
}