Skip to main content

trusty_embedder/
lib.rs

1//! Shared text-embedding abstraction for trusty-* projects.
2//!
3//! Why: trusty-memory and trusty-search both shipped near-identical
4//! `Embedder` traits and `FastEmbedder` implementations, with subtle
5//! drift (cache vs no-cache, sync vs async warmup, `dim()` vs `dimension()`).
//! Centralising the code means a bug is fixed once, in one place, and future
//! consumers get the embedder for free.
8//!
//! What: an async `Embedder` trait with `embed_batch` as the single primitive
//! (single-text embedding goes through the free helper `embed_one`), plus a
//! production `FastEmbedder` (fastembed-rs, all-MiniLM-L6-v2 — INT8-quantised
//! by default with an f32 fallback — 384-d) with LRU caching and ORT warmup,
//! and a `MockEmbedder` test double behind the `test-support` feature.
13//!
//! Test: `cargo test -p trusty-embedder` runs the `MockEmbedder` tests. The
//! ONNX-backed shape and cache-hit tests are `#[ignore]` because they download
//! a real model; run them with `cargo test -p trusty-embedder -- --ignored`.
17
18use std::num::NonZeroUsize;
19use std::sync::Arc;
20
21use anyhow::{Context, Result};
22use async_trait::async_trait;
23use fastembed::{EmbeddingModel, TextEmbedding, TextInitOptions};
24use lru::LruCache;
25use parking_lot::Mutex;
26
/// Width of every vector produced by all-MiniLM-L6-v2.
///
/// The INT8-quantised `AllMiniLML6V2Q` variant loaded by default emits the
/// same 384-dimensional vectors as the f32 model. It runs roughly 3-4× faster
/// on CPU ONNX and ships as a ~22MB file instead of 86MB.
pub const EMBED_DIM: usize = 384;

/// Entry count of the LRU embedding cache when no explicit size is given.
///
/// Large enough to keep the hot working set of repeated queries resident.
/// At `EMBED_DIM` f32s per entry, the cached vectors total about 384 KiB
/// (keys excluded).
pub const DEFAULT_CACHE_CAPACITY: usize = 256;
38
39/// Abstraction over embedding backends.
40///
41/// Why: Decouple consumers from any one model so we can swap in remote APIs,
42/// quantised models, or deterministic mocks without changing call sites.
43/// What: a single primitive — `embed_batch` — plus a dimension accessor.
44/// Single-text callers should use the [`embed_one`] convenience helper.
45/// Test: covered by `FastEmbedder` and `MockEmbedder` tests below.
46#[async_trait]
47pub trait Embedder: Send + Sync {
48    /// Embed a batch of texts. Returns one `Vec<f32>` per input, each of
49    /// length `self.dimension()`. An empty input batch returns an empty Vec.
50    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
51
52    /// Output dimension of the produced embeddings.
53    fn dimension(&self) -> usize;
54}
55
56/// Convenience helper: embed a single text via `embed_batch` and return the
57/// lone vector.
58///
59/// Why: Most call sites only need one embedding at a time and writing
60/// `.embed_batch(&[text]).await?.into_iter().next()` everywhere is noise.
61/// What: builds a 1-element batch, calls `embed_batch`, returns the first
62/// vector (or errors if the embedder produced nothing).
63/// Test: covered indirectly by `mock_embedder_round_trip`.
64pub async fn embed_one(embedder: &dyn Embedder, text: &str) -> Result<Vec<f32>> {
65    let mut v = embedder.embed_batch(&[text.to_string()]).await?;
66    v.pop()
67        .context("embedder returned no embedding for non-empty input")
68}
69
70/// Local CPU embedder backed by fastembed-rs (ONNX runtime, all-MiniLM-L6-v2).
71///
72/// Why: Default to local-only embeddings so consumers have zero external
73/// network dependency and predictable latency. The LRU cache keeps the hot
74/// path free of redundant ONNX work for repeat strings (queries, common
75/// chunks).
76/// What: wraps a single `TextEmbedding` behind a `parking_lot::Mutex` (the
77/// underlying `embed` requires `&mut self`) and an `LruCache<String, Vec<f32>>`.
78/// Initialisation warms the ORT graph with a small batch so the first user
79/// query doesn't pay the one-shot compile cost.
80/// Test: `embed_batch_returns_correct_dim` and `cache_hit_is_idempotent`
81/// (marked `#[ignore]` — they download a real model).
82pub struct FastEmbedder {
83    model: Arc<Mutex<TextEmbedding>>,
84    cache: Arc<Mutex<LruCache<String, Vec<f32>>>>,
85    dim: usize,
86}
87
88impl FastEmbedder {
89    /// Construct a new `FastEmbedder` with the default cache size.
90    pub async fn new() -> Result<Self> {
91        Self::with_cache_size(DEFAULT_CACHE_CAPACITY).await
92    }
93
94    /// Build `TextInitOptions` for the given model, wiring in the CoreML
95    /// execution provider when the `coreml` feature is enabled.
96    ///
97    /// Why: fastembed-rs doesn't expose `coreml` as a passthrough feature, but
98    /// it does accept a `Vec<ExecutionProviderDispatch>` via
99    /// `with_execution_providers`. We construct `ep::CoreML::default().build()`
100    /// from our own `ort` dep (pinned to the same `=2.0.0-rc.12` that fastembed
101    /// uses) so the ONNX session for all-MiniLM-L6-v2 runs on the Apple GPU/ANE.
102    /// What: returns a configured `TextInitOptions`. With `coreml` off this is
103    /// just `TextInitOptions::new(model)`; with `coreml` on it appends a
104    /// CoreML EP to the dispatch list.
105    /// Test: `cargo build --features coreml` on macOS produces a binary that
106    /// logs CoreML EP registration when `RUST_LOG=ort=debug` is set.
107    fn init_options(model: EmbeddingModel) -> TextInitOptions {
108        let opts = TextInitOptions::new(model);
109        #[cfg(feature = "coreml")]
110        {
111            use ort::execution_providers::ExecutionProviderDispatch;
112            let coreml = ort::ep::CoreML::default().build();
113            let providers: Vec<ExecutionProviderDispatch> = vec![coreml];
114            tracing::info!("trusty-embedder: registering CoreML execution provider");
115            opts.with_execution_providers(providers)
116        }
117        #[cfg(not(feature = "coreml"))]
118        opts
119    }
120
121    /// Construct with an explicit LRU capacity.
122    pub async fn with_cache_size(capacity: usize) -> Result<Self> {
123        let capacity =
124            NonZeroUsize::new(capacity.max(1)).expect("capacity.max(1) is always non-zero");
125
126        // fastembed's `try_new` downloads + builds an ONNX session — blocking
127        // work that must run off the async reactor.
128        let model = tokio::task::spawn_blocking(|| -> Result<TextEmbedding> {
129            let mut m = TextEmbedding::try_new(Self::init_options(EmbeddingModel::AllMiniLML6V2Q))
130                .or_else(|q_err| {
131                    tracing::warn!(
132                        "AllMiniLML6V2Q init failed ({q_err:#}), falling back to AllMiniLML6V2"
133                    );
134                    TextEmbedding::try_new(Self::init_options(EmbeddingModel::AllMiniLML6V2))
135                })
136                .context(
137                    "failed to initialise fastembed (tried AllMiniLML6V2Q and AllMiniLML6V2)",
138                )?;
139
140            // Warm the graph so the first real user query is hot.
141            let warmup: Vec<&str> = vec![
142                "hello world",
143                "the quick brown fox",
144                "memory palace warmup",
145                "embedding model ready",
146                "trusty common warmup",
147            ];
148            let _ = m
149                .embed(warmup, None)
150                .context("fastembed warmup batch failed")?;
151            Ok(m)
152        })
153        .await
154        .context("spawn_blocking joined with error during embedder init")??;
155
156        Ok(Self {
157            model: Arc::new(Mutex::new(model)),
158            cache: Arc::new(Mutex::new(LruCache::new(capacity))),
159            dim: EMBED_DIM,
160        })
161    }
162}
163
164#[async_trait]
165impl Embedder for FastEmbedder {
166    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
167        if texts.is_empty() {
168            return Ok(Vec::new());
169        }
170
171        // Split into cached hits vs misses.
172        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
173        let mut to_compute: Vec<(usize, String)> = Vec::new();
174        {
175            let mut cache = self.cache.lock();
176            for (i, t) in texts.iter().enumerate() {
177                if let Some(v) = cache.get(t) {
178                    results[i] = Some(v.clone());
179                } else {
180                    to_compute.push((i, t.clone()));
181                }
182            }
183        }
184
185        if !to_compute.is_empty() {
186            let model = Arc::clone(&self.model);
187            let owned: Vec<String> = to_compute.iter().map(|(_, s)| s.clone()).collect();
188            let computed = tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
189                let mut guard = model.lock();
190                guard
191                    .embed(owned, None)
192                    .context("fastembed embed call failed")
193            })
194            .await
195            .context("spawn_blocking joined with error during embed")??;
196
197            if computed.len() != to_compute.len() {
198                anyhow::bail!(
199                    "fastembed returned {} embeddings, expected {}",
200                    computed.len(),
201                    to_compute.len()
202                );
203            }
204
205            let mut cache = self.cache.lock();
206            for ((idx, key), vector) in to_compute.into_iter().zip(computed.into_iter()) {
207                cache.put(key, vector.clone());
208                results[idx] = Some(vector);
209            }
210        }
211
212        results
213            .into_iter()
214            .map(|opt| opt.context("missing embedding slot after batch"))
215            .collect()
216    }
217
218    fn dimension(&self) -> usize {
219        self.dim
220    }
221}
222
/// Deterministic test double — hashes input bytes into a fixed-dim vector.
///
/// Why: ONNX model downloads dominate test runtime and can race on cold
/// caches when multiple tests construct embedders in parallel. The mock
/// gives integration tests a "rank by similarity" surface without any I/O.
///
/// What: each byte adds `byte / 255` to slot `(position + byte) % dim`. The
/// first byte is additionally folded into slot 0.
/// - The same text always yields the same vector.
/// - The empty string maps to the all-zeros vector.
/// - A `dim` of 0 yields empty vectors rather than panicking.
///
/// Test: `mock_embedder_round_trip` checks shape and that distinct inputs
/// produce distinct vectors.
#[cfg(any(test, feature = "test-support"))]
#[derive(Debug, Clone)]
pub struct MockEmbedder {
    dim: usize,
}

#[cfg(any(test, feature = "test-support"))]
impl MockEmbedder {
    /// Create a mock that produces `dim`-length vectors.
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }

    /// Map `text` to a deterministic `self.dim`-length vector.
    fn hash_to_vec(&self, text: &str) -> Vec<f32> {
        let mut v = vec![0.0_f32; self.dim];
        // With no slots there is nothing to hash into. Returning here avoids
        // the modulo-by-zero and the `v[0]` out-of-bounds panic below.
        if self.dim == 0 {
            return v;
        }
        for (i, b) in text.bytes().enumerate() {
            let slot = (i + b as usize) % self.dim;
            v[slot] += (b as f32) / 255.0;
        }
        if let Some(first) = text.bytes().next() {
            v[0] += first as f32 / 255.0;
        }
        v
    }
}
254
#[cfg(any(test, feature = "test-support"))]
#[async_trait]
impl Embedder for MockEmbedder {
    /// Hash every text independently, preserving input order. Never returns `Err`.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.hash_to_vec(text));
        }
        Ok(vectors)
    }

    fn dimension(&self) -> usize {
        self.dim
    }
}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn mock_embedder_round_trip() {
        let e = MockEmbedder::new(EMBED_DIM);
        assert_eq!(e.dimension(), EMBED_DIM);
        let v = embed_one(&e, "hello").await.unwrap();
        assert_eq!(v.len(), EMBED_DIM);
        // Determinism: embedding the same text twice must give the same vector.
        let again = embed_one(&e, "hello").await.unwrap();
        assert_eq!(v, again);
        let batch = e
            .embed_batch(&["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(batch.len(), 2);
        assert_ne!(batch[0], batch[1]);
    }

    #[tokio::test]
    async fn mock_embedder_batch_matches_single_embeds() {
        // A batch (including a repeated text) must equal embedding each
        // element on its own, in input order.
        let e = MockEmbedder::new(EMBED_DIM);
        let texts = ["x".to_string(), "y".to_string(), "x".to_string()];
        let batch = e.embed_batch(&texts).await.unwrap();
        assert_eq!(batch.len(), texts.len());
        for (text, vector) in texts.iter().zip(&batch) {
            assert_eq!(*vector, embed_one(&e, text).await.unwrap());
        }
    }

    #[tokio::test]
    async fn mock_embedder_empty_input_returns_empty() {
        let e = MockEmbedder::new(EMBED_DIM);
        let v = e.embed_batch(&[]).await.unwrap();
        assert!(v.is_empty());
    }

    // ONNX-backed tests download a ~22MB model on first run. They are marked
    // ignored so the default `cargo test` stays offline; run them with
    // `cargo test -- --ignored` when needed.
    #[tokio::test]
    #[ignore]
    async fn fastembed_returns_correct_dim() {
        let e = FastEmbedder::new().await.unwrap();
        assert_eq!(e.dimension(), 384);
        let v = embed_one(&e, "fn authenticate(user: &str) -> bool")
            .await
            .unwrap();
        assert_eq!(v.len(), 384);
        assert!(v.iter().any(|x| *x != 0.0));
    }

    #[tokio::test]
    #[ignore]
    async fn fastembed_cache_hit_is_idempotent() {
        let e = FastEmbedder::new().await.unwrap();
        let v1 = embed_one(&e, "cached").await.unwrap();
        let v2 = embed_one(&e, "cached").await.unwrap();
        assert_eq!(v1, v2);
    }

    #[tokio::test]
    #[ignore]
    async fn fastembed_batch_with_repeats_fills_every_slot() {
        let e = FastEmbedder::new().await.unwrap();
        let texts = [
            "repeat".to_string(),
            "other".to_string(),
            "repeat".to_string(),
        ];
        let batch = e.embed_batch(&texts).await.unwrap();
        assert_eq!(batch.len(), 3);
        assert!(batch.iter().all(|v| v.len() == EMBED_DIM));
        // Tolerance rather than bit equality: batched ONNX kernels are not
        // guaranteed to be bit-identical across rows.
        assert!(batch[0]
            .iter()
            .zip(&batch[2])
            .all(|(a, b)| (a - b).abs() < 1e-5));
    }
}