Skip to main content

seshat_embedding/
lib.rs

//! # seshat-embedding
//!
//! Embedding provider abstraction with a built-in local provider for Seshat.
//!
//! When the `[embedding]` section is present in `seshat.toml` **and** the
//! crate is compiled with the `builtin-embeddings` feature (enabled by
//! default), embeddings are generated locally using `fastembed-rs`
//! (all-MiniLM-L6-v2 model, 384 dimensions). No external embedding service is
//! needed, although fastembed may download the model files on first run.
//!
//! When the feature is disabled, the built-in provider is compiled out and
//! [`create_provider`] returns a configuration error instead. When the section
//! is absent, no provider needs to be created, so embedding adds no runtime
//! overhead.
//!
//! ## Configuration
//!
//! ```toml
//! # seshat.toml — uncomment to enable vector search
//! # [embedding]
//! # model = ""          # empty → provider default (all-MiniLM-L6-v2)
//! # dimension = 0       # 0     → provider default (384)
//! # batch_size = 32
//! ```
22
23use std::fmt;
24
25use serde::{Deserialize, Serialize};
26
27// ─── Error types ─────────────────────────────────────────────────────────────
28
29/// Errors from embedding operations.
30#[derive(Debug, thiserror::Error)]
31pub enum EmbeddingError {
32    /// Embedding provider failed to generate embeddings.
33    #[error("embedding provider error: {0}")]
34    ProviderError(String),
35
36    /// Failed to parse or validate embedding output.
37    #[error("failed to parse embedding response: {0}")]
38    ParseError(String),
39
40    /// The provider returned an unexpected number of embedding vectors.
41    #[error("expected {expected} embedding vectors, got {got}")]
42    CountMismatch { expected: usize, got: usize },
43
44    /// An embedding vector has an unexpected number of dimensions.
45    #[error("expected {expected}-dimensional embedding, got {got} dimensions")]
46    DimensionMismatch { expected: usize, got: usize },
47
48    /// Configuration error (e.g., invalid model name).
49    #[error("embedding configuration error: {0}")]
50    ConfigError(String),
51}
52
53// ─── Trait ───────────────────────────────────────────────────────────────────
54
55/// Abstraction over embedding providers.
56///
57/// Implementations must be `Send + Sync` so providers can be shared across
58/// threads (e.g., stored in an `Arc`).
59pub trait EmbeddingProvider: Send + Sync + fmt::Debug {
60    /// Generate embeddings for one or more text inputs.
61    ///
62    /// Returns one `Vec<f32>` per input text, each of length [`Self::dimension`].
63    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
64
65    /// The dimensionality of the embedding vectors this provider produces.
66    fn dimension(&self) -> usize;
67}
68
69// ─── Config ──────────────────────────────────────────────────────────────────
70
71/// Configuration for the embedding provider, parsed from `[embedding]` in
72/// `seshat.toml`.
73///
74/// When this section is absent, embedding is disabled with zero overhead.
75/// When present, the built-in provider is used (requires `builtin-embeddings`
76/// feature, which is enabled by default).
77#[derive(Debug, Clone, Serialize, Deserialize)]
78#[serde(default, rename_all = "snake_case")]
79pub struct EmbeddingConfig {
80    /// Model name. Empty string uses the provider default (all-MiniLM-L6-v2).
81    pub model: String,
82    /// Embedding vector dimension. `0` uses the provider default (384).
83    pub dimension: usize,
84    /// Batch size for embedding generation. Must be ≥ 1.
85    pub batch_size: usize,
86}
87
88impl Default for EmbeddingConfig {
89    fn default() -> Self {
90        Self {
91            model: String::new(), // empty → provider default
92            dimension: 0,         // 0     → provider default
93            batch_size: 32,
94        }
95    }
96}
97
98impl fmt::Display for EmbeddingConfig {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        write!(
101            f,
102            "model={}, dimension={}, batch_size={}",
103            if self.model.is_empty() {
104                "(default)"
105            } else {
106                &self.model
107            },
108            if self.dimension == 0 {
109                "(default)".to_owned()
110            } else {
111                self.dimension.to_string()
112            },
113            self.batch_size,
114        )
115    }
116}
117
118// ─── Provider factory ────────────────────────────────────────────────────────
119
120/// Create the built-in embedding provider from configuration.
121///
122/// Requires the `builtin-embeddings` feature (enabled by default).
123///
124/// # Errors
125///
126/// Returns [`EmbeddingError::ConfigError`] if `batch_size` is 0 or if the
127/// built-in provider fails to initialise.
128pub fn create_provider(
129    config: &EmbeddingConfig,
130) -> Result<Box<dyn EmbeddingProvider>, EmbeddingError> {
131    // Validate batch_size early — 0 would panic in `slice::chunks()`.
132    if config.batch_size == 0 {
133        return Err(EmbeddingError::ConfigError(
134            "batch_size must be at least 1".to_owned(),
135        ));
136    }
137
138    #[cfg(feature = "builtin-embeddings")]
139    {
140        builtin::create_builtin_provider(config)
141    }
142
143    #[cfg(not(feature = "builtin-embeddings"))]
144    {
145        Err(EmbeddingError::ConfigError(
146            "embedding support is not compiled in — rebuild with the \
147             'builtin-embeddings' feature (enabled by default)"
148                .to_owned(),
149        ))
150    }
151}
152
// ─── Built-in provider ───────────────────────────────────────────────────────

#[cfg(feature = "builtin-embeddings")]
mod builtin {
    use std::sync::Mutex;

    use super::*;

    /// Default model for the built-in provider.
    pub const DEFAULT_MODEL: &str = "all-MiniLM-L6-v2";
    /// Default (native) embedding dimension for all-MiniLM-L6-v2.
    pub const DEFAULT_DIMENSION: usize = 384;

    /// Builds a [`BuiltinProvider`] from `config`.
    ///
    /// An empty `config.model` selects [`DEFAULT_MODEL`]; a zero
    /// `config.dimension` selects the model's native dimension; and
    /// `config.batch_size` is forwarded to fastembed for every `embed` call.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`BuiltinProvider::new`].
    pub fn create_builtin_provider(
        config: &EmbeddingConfig,
    ) -> Result<Box<dyn EmbeddingProvider>, EmbeddingError> {
        let model = if config.model.is_empty() {
            DEFAULT_MODEL
        } else {
            config.model.as_str()
        };
        let provider =
            BuiltinProvider::new(model.to_owned(), config.dimension, config.batch_size)?;
        Ok(Box::new(provider))
    }

    /// Built-in embedding provider backed by fastembed-rs (all-MiniLM-L6-v2).
    ///
    /// Inference runs in-process; no external embedding service is used.
    /// The model files are obtained by fastembed when the provider is
    /// created, which may involve a download on first run, so the first
    /// initialisation can require network access.
    pub struct BuiltinProvider {
        model_name: String,
        /// Validated vector length; always equal to the model's native
        /// dimension.
        dimension: usize,
        /// Texts per inference batch, forwarded to fastembed's `embed`.
        /// Always at least 1.
        batch_size: usize,
        // fastembed 5 requires `&mut self` on `embed`. Wrap in `Mutex` so the
        // public `EmbeddingProvider::embed(&self, …)` API stays unchanged and
        // callers can keep sharing `Arc<dyn EmbeddingProvider>`.
        inner: Mutex<fastembed::TextEmbedding>,
    }

    impl fmt::Debug for BuiltinProvider {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("BuiltinProvider")
                .field("model_name", &self.model_name)
                .field("dimension", &self.dimension)
                .field("batch_size", &self.batch_size)
                .finish_non_exhaustive()
        }
    }

    impl BuiltinProvider {
        /// Loads the fastembed model named `model_name`.
        ///
        /// `dimension == 0` selects the model's native dimension; any other
        /// value must equal it. `batch_size` must be at least 1.
        ///
        /// All configuration is validated *before* the model is loaded, so a
        /// misconfiguration fails fast without triggering a download.
        ///
        /// # Errors
        ///
        /// * [`EmbeddingError::ConfigError`] for an unknown model name, a
        ///   `dimension` the model does not produce, or a zero `batch_size`.
        /// * [`EmbeddingError::ProviderError`] if fastembed fails to load the
        ///   model.
        pub fn new(
            model_name: String,
            dimension: usize,
            batch_size: usize,
        ) -> Result<Self, EmbeddingError> {
            use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

            // Resolve the fastembed model and the dimension it natively produces.
            let (model, native_dimension) = match model_name.as_str() {
                "all-MiniLM-L6-v2" => (EmbeddingModel::AllMiniLML6V2, DEFAULT_DIMENSION),
                other => {
                    return Err(EmbeddingError::ConfigError(format!(
                        "unknown built-in model '{other}'. \
                         Supported: all-MiniLM-L6-v2"
                    )));
                }
            };

            // A mismatched dimension would otherwise only surface as a
            // DimensionMismatch on every `embed` call, while `dimension()`
            // kept advertising the wrong size to the vector store.
            let dimension = match dimension {
                0 => native_dimension,
                d if d == native_dimension => d,
                d => {
                    return Err(EmbeddingError::ConfigError(format!(
                        "model '{model_name}' produces {native_dimension}-dimensional \
                         embeddings, but dimension = {d} was configured \
                         (use 0 or {native_dimension})"
                    )));
                }
            };

            // fastembed splits its input with `slice::chunks(batch_size)`, which
            // panics on 0; re-check here so this type is safe on its own.
            if batch_size == 0 {
                return Err(EmbeddingError::ConfigError(
                    "batch_size must be at least 1".to_owned(),
                ));
            }

            // Suppress fastembed's own download progress output — seshat manages its own UI.
            let init_opts = InitOptions::new(model).with_show_download_progress(false);

            tracing::info!(model = %model_name, "Loading built-in embedding model (may download on first run)");

            let inner = TextEmbedding::try_new(init_opts)
                .map_err(|e| EmbeddingError::ProviderError(format!("failed to load model: {e}")))?;

            Ok(Self {
                model_name,
                dimension,
                batch_size,
                inner: Mutex::new(inner),
            })
        }
    }

    impl EmbeddingProvider for BuiltinProvider {
        fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            if texts.is_empty() {
                return Ok(Vec::new());
            }

            let embeddings = {
                // The guard lives only inside this block, so the lock is
                // released before the (lock-free) validation below.
                let mut model = self.inner.lock().map_err(|e| {
                    EmbeddingError::ProviderError(format!("model lock poisoned: {e}"))
                })?;
                // Honour the configured batch size instead of fastembed's default.
                model
                    .embed(texts, Some(self.batch_size))
                    .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?
            };

            if embeddings.len() != texts.len() {
                return Err(EmbeddingError::CountMismatch {
                    expected: texts.len(),
                    got: embeddings.len(),
                });
            }

            // Validate: no empty, wrong-dimension, or non-finite vectors.
            for (i, vec) in embeddings.iter().enumerate() {
                if vec.is_empty() {
                    return Err(EmbeddingError::ParseError(format!(
                        "embedding at index {i} is empty"
                    )));
                }
                // Defence in depth: `new` already guarantees `self.dimension`
                // is the model's native size, so a mismatch here means the
                // model misbehaved. Refuse rather than corrupt the vector store.
                if vec.len() != self.dimension {
                    return Err(EmbeddingError::DimensionMismatch {
                        expected: self.dimension,
                        got: vec.len(),
                    });
                }
                if let Some(val) = vec.iter().copied().find(|v| !v.is_finite()) {
                    return Err(EmbeddingError::ParseError(format!(
                        "embedding at index {i} contains non-finite value: {val}"
                    )));
                }
            }

            Ok(embeddings)
        }

        fn dimension(&self) -> usize {
            self.dimension
        }
    }
}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // ── Mock provider ──────────────────────────────────────────────────

    /// A mock provider for testing that returns predetermined embeddings:
    /// the `i`-th vector is `dim` copies of `i / 10`.
    #[derive(Debug)]
    struct MockProvider {
        dim: usize,
        error: Option<String>,
    }

    impl MockProvider {
        fn new(dim: usize) -> Self {
            Self { dim, error: None }
        }

        fn with_error(dim: usize, msg: &str) -> Self {
            Self {
                dim,
                error: Some(msg.to_owned()),
            }
        }
    }

    impl EmbeddingProvider for MockProvider {
        fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            if let Some(ref msg) = self.error {
                return Err(EmbeddingError::ProviderError(msg.clone()));
            }
            Ok(texts
                .iter()
                .enumerate()
                .map(|(i, _)| vec![i as f32 / 10.0; self.dim])
                .collect())
        }

        fn dimension(&self) -> usize {
            self.dim
        }
    }

    #[test]
    fn mock_provider_returns_expected_embeddings() {
        let provider = MockProvider::new(384);
        let texts = vec!["hello".to_owned(), "world".to_owned()];
        let result = provider.embed(&texts).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 384);
        assert!((result[0][0] - 0.0).abs() < f32::EPSILON);
        assert!((result[1][0] - 0.1).abs() < f32::EPSILON);
    }

    #[test]
    fn mock_provider_empty_input() {
        let provider = MockProvider::new(384);
        let result = provider.embed(&[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn mock_provider_dimension() {
        let provider = MockProvider::new(1536);
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn mock_provider_error() {
        let provider = MockProvider::with_error(384, "model load failed");
        let result = provider.embed(&["test".to_owned()]);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            EmbeddingError::ProviderError(_)
        ));
    }

    // ── Config tests ───────────────────────────────────────────────────

    #[test]
    fn config_default() {
        let cfg = EmbeddingConfig::default();
        assert!(cfg.model.is_empty());
        assert_eq!(cfg.dimension, 0);
        assert_eq!(cfg.batch_size, 32);
    }

    #[test]
    fn config_parse_minimal() {
        let toml_str = r#"batch_size = 16"#;
        let cfg: EmbeddingConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(cfg.batch_size, 16);
        assert!(cfg.model.is_empty());
        assert_eq!(cfg.dimension, 0);
    }

    #[test]
    fn config_parse_full() {
        let toml_str = r#"
model = "all-MiniLM-L6-v2"
dimension = 384
batch_size = 64
"#;
        let cfg: EmbeddingConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(cfg.model, "all-MiniLM-L6-v2");
        assert_eq!(cfg.dimension, 384);
        assert_eq!(cfg.batch_size, 64);
    }

    #[test]
    fn config_parse_empty_uses_defaults() {
        let cfg: EmbeddingConfig = toml::from_str("").unwrap();
        assert!(cfg.model.is_empty());
        assert_eq!(cfg.dimension, 0);
        assert_eq!(cfg.batch_size, 32);
    }

    // ── Display tests ──────────────────────────────────────────────────

    #[test]
    fn config_display_with_values() {
        let cfg = EmbeddingConfig {
            model: "all-MiniLM-L6-v2".to_owned(),
            dimension: 384,
            batch_size: 32,
        };
        let s = format!("{cfg}");
        assert!(s.contains("model=all-MiniLM-L6-v2"));
        assert!(s.contains("dimension=384"));
        assert!(s.contains("batch_size=32"));
    }

    #[test]
    fn config_display_defaults() {
        let cfg = EmbeddingConfig::default();
        let s = format!("{cfg}");
        assert!(s.contains("model=(default)"));
        assert!(s.contains("dimension=(default)"));
    }

    // ── Factory tests ──────────────────────────────────────────────────

    #[test]
    fn create_provider_batch_size_zero_returns_error() {
        let cfg = EmbeddingConfig {
            batch_size: 0,
            ..Default::default()
        };
        let result = create_provider(&cfg);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("batch_size"));
    }

    // ── Error display tests ────────────────────────────────────────────

    #[test]
    fn error_display_messages() {
        let err = EmbeddingError::ProviderError("model load failed".to_owned());
        assert!(err.to_string().contains("model load failed"));

        let err = EmbeddingError::ParseError("bad data".to_owned());
        assert!(err.to_string().contains("bad data"));

        let err = EmbeddingError::CountMismatch {
            expected: 3,
            got: 1,
        };
        assert!(err.to_string().contains("3"));
        assert!(err.to_string().contains("embedding vectors"));

        let err = EmbeddingError::DimensionMismatch {
            expected: 384,
            got: 1536,
        };
        assert!(err.to_string().contains("384"));
        assert!(err.to_string().contains("1536"));

        let err = EmbeddingError::ConfigError("bad config".to_owned());
        assert!(err.to_string().contains("bad config"));
    }

    // ── Trait object tests ─────────────────────────────────────────────

    #[test]
    fn provider_as_trait_object() {
        let provider: Box<dyn EmbeddingProvider> = Box::new(MockProvider::new(384));
        assert_eq!(provider.dimension(), 384);
        let result = provider.embed(&["test".to_owned()]).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);
    }

    /// The trait promises that providers can be shared across threads; check
    /// both a concrete provider and the boxed trait object callers store.
    #[test]
    fn provider_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockProvider>();
        assert_send_sync::<Box<dyn EmbeddingProvider>>();
    }

    #[test]
    fn config_display_custom_model() {
        let cfg = EmbeddingConfig {
            model: "custom-model".to_owned(),
            dimension: 768,
            batch_size: 64,
        };
        let s = format!("{cfg}");
        assert!(s.contains("custom-model"));
        assert!(s.contains("dimension=768"));
        assert!(s.contains("batch_size=64"));
    }

    #[test]
    fn config_display_zero_dimension() {
        let cfg = EmbeddingConfig {
            dimension: 0,
            ..Default::default()
        };
        let s = format!("{cfg}");
        assert!(s.contains("dimension=(default)"));
    }

    #[test]
    fn mock_provider_debug() {
        let provider = MockProvider::new(128);
        let dbg = format!("{provider:?}");
        assert!(dbg.contains("MockProvider"));
    }

    #[test]
    fn error_display_count_mismatch() {
        let err = EmbeddingError::CountMismatch {
            expected: 10,
            got: 5,
        };
        let s = err.to_string();
        assert!(s.contains("10"));
        assert!(s.contains("5"));
    }

    #[test]
    fn error_display_dimension_mismatch() {
        let err = EmbeddingError::DimensionMismatch {
            expected: 512,
            got: 384,
        };
        let s = err.to_string();
        assert!(s.contains("512"));
        assert!(s.contains("384"));
        assert!(s.contains("dimension"));
    }

    #[test]
    fn mock_provider_zero_dimension() {
        let provider = MockProvider::new(0);
        assert_eq!(provider.dimension(), 0);
        let result = provider.embed(&["test".to_owned()]).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 0);
    }

    #[test]
    fn mock_provider_embedding_values() {
        let provider = MockProvider::new(3);
        let texts = vec!["a".to_owned(), "b".to_owned(), "c".to_owned()];
        let result = provider.embed(&texts).unwrap();
        assert_eq!(result.len(), 3);
        assert!((result[0][0] - 0.0).abs() < f32::EPSILON);
        assert!((result[1][0] - 0.1).abs() < f32::EPSILON);
        assert!((result[2][0] - 0.2).abs() < f32::EPSILON);
    }

    /// With the built-in feature, the default config must load the default
    /// model. Ignored by default: fastembed may download the model on first
    /// run, and the unit-test suite should not depend on network access.
    /// Run explicitly with `cargo test -- --ignored`.
    #[test]
    #[cfg(feature = "builtin-embeddings")]
    #[ignore = "loads the fastembed model, which may download it on first run"]
    fn create_provider_valid_config() {
        let provider = create_provider(&EmbeddingConfig::default())
            .expect("default config should create the built-in provider");
        assert_eq!(provider.dimension(), 384);
    }

    /// Without the built-in feature, a valid config must be rejected with a
    /// `ConfigError` that points at the missing feature.
    #[test]
    #[cfg(not(feature = "builtin-embeddings"))]
    fn create_provider_valid_config() {
        let err = create_provider(&EmbeddingConfig::default())
            .err()
            .expect("provider creation should fail without builtin-embeddings");
        assert!(matches!(err, EmbeddingError::ConfigError(_)));
        assert!(err.to_string().contains("builtin-embeddings"));
    }

    #[test]
    fn embedding_error_display_parse_error() {
        let err = EmbeddingError::ParseError("json malformed".to_owned());
        assert!(err.to_string().contains("json malformed"));
    }

    #[test]
    fn embedding_error_display_config_error() {
        let err = EmbeddingError::ConfigError("unsupported provider".to_owned());
        assert!(err.to_string().contains("unsupported provider"));
    }
}