skill_runtime/embeddings/
factory.rs1use super::{
6 EmbeddingConfig, EmbeddingProvider, EmbeddingProviderType,
7 FastEmbedModel, FastEmbedProvider,
8 OpenAIEmbedProvider, OpenAIEmbeddingModel,
9 OllamaProvider,
10};
11use anyhow::{Context, Result};
12use std::sync::Arc;
13
/// Stateless entry point for constructing embedding providers.
///
/// The type carries no data; all functionality lives in its associated
/// functions, which return providers as `Arc<dyn EmbeddingProvider>`.
pub struct EmbeddingProviderFactory;
16
17impl EmbeddingProviderFactory {
18 pub fn create(config: &EmbeddingConfig) -> Result<Arc<dyn EmbeddingProvider>> {
20 match config.provider {
21 EmbeddingProviderType::FastEmbed => {
22 let model = config
23 .model
24 .as_ref()
25 .filter(|m| !m.trim().is_empty()) .map(|m| m.parse::<FastEmbedModel>())
27 .transpose()
28 .context("Invalid FastEmbed model")?
29 .unwrap_or_default();
30
31 let provider = FastEmbedProvider::with_model(model)?;
32 Ok(Arc::new(provider))
33 }
34
35 EmbeddingProviderType::OpenAI => {
36 let model = config
37 .model
38 .as_ref()
39 .filter(|m| !m.trim().is_empty()) .map(|m| m.parse::<OpenAIEmbeddingModel>())
41 .transpose()
42 .context("Invalid OpenAI model")?
43 .unwrap_or_default();
44
45 let provider = if let Some(ref api_key) = config.api_key {
46 OpenAIEmbedProvider::with_api_key(api_key, model)?
47 } else {
48 OpenAIEmbedProvider::with_model(model)?
49 };
50
51 Ok(Arc::new(provider))
52 }
53
54 EmbeddingProviderType::Ollama => {
55 let model = config
56 .model
57 .as_deref()
58 .filter(|m| !m.trim().is_empty()) .unwrap_or(super::ollama::DEFAULT_OLLAMA_MODEL);
60
61 let provider = if let Some(ref base_url) = config.base_url {
62 OllamaProvider::with_url(base_url, model)?
63 } else {
64 OllamaProvider::with_model(model)?
65 };
66
67 Ok(Arc::new(provider))
68 }
69 }
70 }
71
72 pub fn default_provider() -> Result<Arc<dyn EmbeddingProvider>> {
74 Self::create(&EmbeddingConfig::default())
75 }
76
77 pub fn fastembed() -> Result<Arc<dyn EmbeddingProvider>> {
79 Ok(Arc::new(FastEmbedProvider::new()?))
80 }
81
82 pub fn openai() -> Result<Arc<dyn EmbeddingProvider>> {
84 Ok(Arc::new(OpenAIEmbedProvider::new()?))
85 }
86
87 pub fn ollama() -> Result<Arc<dyn EmbeddingProvider>> {
89 Ok(Arc::new(OllamaProvider::new()?))
90 }
91}
92
93pub fn create_provider(config: &EmbeddingConfig) -> Result<Arc<dyn EmbeddingProvider>> {
95 EmbeddingProviderFactory::create(config)
96}
97
98pub fn create_provider_from_type(
100 provider_type: &str,
101 model: Option<&str>,
102) -> Result<Arc<dyn EmbeddingProvider>> {
103 let provider_type: EmbeddingProviderType = provider_type.parse()?;
104
105 let config = EmbeddingConfig {
106 provider: provider_type,
107 model: model.map(String::from),
108 ..Default::default()
109 };
110
111 create_provider(&config)
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock FastEmbed configuration yields a 384-dimension provider.
    #[test]
    fn test_create_fastembed() {
        let config = EmbeddingConfig::fastembed();
        let provider = EmbeddingProviderFactory::create(&config).unwrap();
        assert_eq!(provider.provider_name(), "fastembed");
        assert_eq!(provider.dimensions(), 384);
    }

    /// Selecting `BGEBaseEN` explicitly yields a 768-dimension provider.
    #[test]
    fn test_create_fastembed_with_model() {
        let config = EmbeddingConfig::fastembed_with_model(FastEmbedModel::BGEBaseEN);
        let provider = EmbeddingProviderFactory::create(&config).unwrap();
        assert_eq!(provider.dimensions(), 768);
    }

    /// The stock Ollama configuration uses the `nomic-embed-text` model.
    #[test]
    fn test_create_ollama() {
        let config = EmbeddingConfig::ollama();
        let provider = EmbeddingProviderFactory::create(&config).unwrap();
        assert_eq!(provider.provider_name(), "ollama");
        assert_eq!(provider.model_name(), "nomic-embed-text");
    }

    /// Provider and model can both be given as plain strings.
    #[test]
    fn test_create_from_type_string() {
        let provider = create_provider_from_type("fastembed", Some("bge-small")).unwrap();
        assert_eq!(provider.provider_name(), "fastembed");
        assert_eq!(provider.dimensions(), 384);
    }

    /// The default configuration selects FastEmbed.
    #[test]
    fn test_default_provider() {
        let provider = EmbeddingProviderFactory::default_provider().unwrap();
        assert_eq!(provider.provider_name(), "fastembed");
    }

    /// Without `OPENAI_API_KEY` in the environment and without a key in the
    /// config, creating the OpenAI provider must fail.
    #[test]
    fn test_openai_requires_api_key() {
        // `var_os` keeps the saved value intact even if it is not valid Unicode.
        let original = std::env::var_os("OPENAI_API_KEY");
        std::env::remove_var("OPENAI_API_KEY");

        let config = EmbeddingConfig::openai();
        let result = EmbeddingProviderFactory::create(&config);

        // Restore the variable before asserting: a failed assertion panics,
        // and anything placed after it would never run, leaving the process
        // environment modified for the remaining tests.
        if let Some(key) = original {
            std::env::set_var("OPENAI_API_KEY", key);
        }

        assert!(result.is_err());
    }
}