skill_runtime/embeddings/
factory.rs

1//! Embedding provider factory
2//!
3//! Creates embedding providers from configuration.
4
5use super::{
6    EmbeddingConfig, EmbeddingProvider, EmbeddingProviderType,
7    FastEmbedModel, FastEmbedProvider,
8    OpenAIEmbedProvider, OpenAIEmbeddingModel,
9    OllamaProvider,
10};
11use anyhow::{Context, Result};
12use std::sync::Arc;
13
14/// Factory for creating embedding providers from configuration
15pub struct EmbeddingProviderFactory;
16
17impl EmbeddingProviderFactory {
18    /// Create an embedding provider from configuration
19    pub fn create(config: &EmbeddingConfig) -> Result<Arc<dyn EmbeddingProvider>> {
20        match config.provider {
21            EmbeddingProviderType::FastEmbed => {
22                let model = config
23                    .model
24                    .as_ref()
25                    .filter(|m| !m.trim().is_empty()) // Filter out empty/whitespace strings
26                    .map(|m| m.parse::<FastEmbedModel>())
27                    .transpose()
28                    .context("Invalid FastEmbed model")?
29                    .unwrap_or_default();
30
31                let provider = FastEmbedProvider::with_model(model)?;
32                Ok(Arc::new(provider))
33            }
34
35            EmbeddingProviderType::OpenAI => {
36                let model = config
37                    .model
38                    .as_ref()
39                    .filter(|m| !m.trim().is_empty()) // Filter out empty/whitespace strings
40                    .map(|m| m.parse::<OpenAIEmbeddingModel>())
41                    .transpose()
42                    .context("Invalid OpenAI model")?
43                    .unwrap_or_default();
44
45                let provider = if let Some(ref api_key) = config.api_key {
46                    OpenAIEmbedProvider::with_api_key(api_key, model)?
47                } else {
48                    OpenAIEmbedProvider::with_model(model)?
49                };
50
51                Ok(Arc::new(provider))
52            }
53
54            EmbeddingProviderType::Ollama => {
55                let model = config
56                    .model
57                    .as_deref()
58                    .filter(|m| !m.trim().is_empty()) // Filter out empty/whitespace strings
59                    .unwrap_or(super::ollama::DEFAULT_OLLAMA_MODEL);
60
61                let provider = if let Some(ref base_url) = config.base_url {
62                    OllamaProvider::with_url(base_url, model)?
63                } else {
64                    OllamaProvider::with_model(model)?
65                };
66
67                Ok(Arc::new(provider))
68            }
69        }
70    }
71
72    /// Create a default provider (FastEmbed with AllMiniLM)
73    pub fn default_provider() -> Result<Arc<dyn EmbeddingProvider>> {
74        Self::create(&EmbeddingConfig::default())
75    }
76
77    /// Create a FastEmbed provider with default model
78    pub fn fastembed() -> Result<Arc<dyn EmbeddingProvider>> {
79        Ok(Arc::new(FastEmbedProvider::new()?))
80    }
81
82    /// Create an OpenAI provider with default model
83    pub fn openai() -> Result<Arc<dyn EmbeddingProvider>> {
84        Ok(Arc::new(OpenAIEmbedProvider::new()?))
85    }
86
87    /// Create an Ollama provider with default model
88    pub fn ollama() -> Result<Arc<dyn EmbeddingProvider>> {
89        Ok(Arc::new(OllamaProvider::new()?))
90    }
91}
92
93/// Convenience function to create a provider from configuration
94pub fn create_provider(config: &EmbeddingConfig) -> Result<Arc<dyn EmbeddingProvider>> {
95    EmbeddingProviderFactory::create(config)
96}
97
98/// Convenience function to create a provider from provider type string
99pub fn create_provider_from_type(
100    provider_type: &str,
101    model: Option<&str>,
102) -> Result<Arc<dyn EmbeddingProvider>> {
103    let provider_type: EmbeddingProviderType = provider_type.parse()?;
104
105    let config = EmbeddingConfig {
106        provider: provider_type,
107        model: model.map(String::from),
108        ..Default::default()
109    };
110
111    create_provider(&config)
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Removes an environment variable for the lifetime of the guard and puts
    /// the previous value back on drop, including when the test panics.
    ///
    /// The value is captured with `var_os` so that a non-UTF-8 value is
    /// restored intact instead of being lost.
    ///
    /// NOTE: the environment is process-global; tests running in parallel
    /// that read the same variable can still observe the temporary removal.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        /// Remember the current value of `key`, then remove it.
        fn unset(key: &'static str) -> Self {
            let original = std::env::var_os(key);
            std::env::remove_var(key);
            Self { key, original }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            if let Some(value) = self.original.take() {
                std::env::set_var(self.key, value);
            }
        }
    }

    /// The default FastEmbed configuration yields a 384-dimensional provider.
    #[test]
    fn test_create_fastembed() {
        let config = EmbeddingConfig::fastembed();
        let provider = EmbeddingProviderFactory::create(&config).unwrap();
        assert_eq!(provider.provider_name(), "fastembed");
        assert_eq!(provider.dimensions(), 384);
    }

    /// An explicitly selected FastEmbed model is honoured.
    #[test]
    fn test_create_fastembed_with_model() {
        let config = EmbeddingConfig::fastembed_with_model(FastEmbedModel::BGEBaseEN);
        let provider = EmbeddingProviderFactory::create(&config).unwrap();
        assert_eq!(provider.dimensions(), 768);
    }

    /// The default Ollama configuration falls back to the default model name.
    #[test]
    fn test_create_ollama() {
        let config = EmbeddingConfig::ollama();
        let provider = EmbeddingProviderFactory::create(&config).unwrap();
        assert_eq!(provider.provider_name(), "ollama");
        assert_eq!(provider.model_name(), "nomic-embed-text");
    }

    /// Provider type and model can both be given as plain strings.
    #[test]
    fn test_create_from_type_string() {
        let provider = create_provider_from_type("fastembed", Some("bge-small")).unwrap();
        assert_eq!(provider.provider_name(), "fastembed");
        assert_eq!(provider.dimensions(), 384);
    }

    /// The default provider is FastEmbed.
    #[test]
    fn test_default_provider() {
        let provider = EmbeddingProviderFactory::default_provider().unwrap();
        assert_eq!(provider.provider_name(), "fastembed");
    }

    // OpenAI tests require API key, so we just test error handling
    #[test]
    fn test_openai_requires_api_key() {
        // Clear the key for the duration of the test. The guard restores it
        // on scope exit, so a failing assertion cannot leave the variable
        // removed for the rest of the test process.
        let _guard = EnvVarGuard::unset("OPENAI_API_KEY");

        let config = EmbeddingConfig::openai();
        let result = EmbeddingProviderFactory::create(&config);
        assert!(result.is_err());
    }
}