trusty_embedder/lib.rs
1//! Shared text-embedding abstraction for trusty-* projects.
2//!
3//! Why: trusty-memory and trusty-search both shipped near-identical
4//! `Embedder` traits and `FastEmbedder` implementations, with subtle
5//! drift (cache vs no-cache, sync vs async warmup, `dim()` vs `dimension()`).
6//! Centralising fixes one bug in one place and lets future consumers pick up
7//! the embedder for free.
8//!
9//! What: an async `Embedder` trait with `embed_batch` as the single primitive
10//! (single-text embed is a free helper), plus a production `FastEmbedder`
11//! (fastembed-rs, all-MiniLM-L6-v2, 384-d) with LRU caching and ORT warmup,
12//! and a `MockEmbedder` test double behind the `test-support` feature.
13//!
//! Test: `cargo test -p trusty-embedder` runs the mock-embedder tests offline.
//! The ONNX-backed shape and cache-hit tests are `#[ignore]` because they
//! download a real model; run them with `cargo test -p trusty-embedder -- --ignored`.
17
18use std::num::NonZeroUsize;
19use std::sync::Arc;
20
21use anyhow::{Context, Result};
22use async_trait::async_trait;
23use fastembed::{EmbeddingModel, TextEmbedding, TextInitOptions};
24use lru::LruCache;
25use parking_lot::Mutex;
26
/// Output dimension of the all-MiniLM-L6-v2 model family.
///
/// Note: by default we load the INT8-quantised variant (`AllMiniLML6V2Q`),
/// falling back to the f32 model (`AllMiniLML6V2`) if it cannot be loaded.
/// Both produce 384-dim vectors, but the values are not bit-identical
/// (quantisation perturbs them), so embeddings from the two variants should
/// not be assumed interchangeable. The quantised model runs ~3-4× faster on
/// CPU ONNX and ships as a ~22MB file (vs 86MB for the f32 model).
pub const EMBED_DIM: usize = 384;

/// Default LRU cache capacity, in cached texts.
///
/// At [`EMBED_DIM`] `f32`s per entry this is 256 × 384 × 4 B ≈ 384 KiB of
/// vector payload (plus key strings and map overhead). That is large enough
/// to keep the hot working set of repeat queries in memory while staying a
/// negligible part of the process footprint.
pub const DEFAULT_CACHE_CAPACITY: usize = 256;
38
/// Which ONNX Runtime execution provider an embedder is running on.
///
/// Why: operators want to confirm from a single log line whether the daemon
/// is CPU-bound or accelerated (CoreML/Metal, CUDA), without digging through
/// debug output.
/// What: a small `Copy` tag, exposed via `FastEmbedder::provider()`, with a
/// stable human-readable label from [`ExecutionProvider::as_str`].
/// Test: on Apple Silicon `FastEmbedder::new()` should report `CoreML`;
/// elsewhere `Cpu` (or `Cuda` once that backend is wired up).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionProvider {
    Cpu,
    CoreML,
    Cuda,
}

impl ExecutionProvider {
    /// Stable label for logs: `"CPU"`, `"CoreML"`, or `"CUDA"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Cpu => "CPU",
            Self::CoreML => "CoreML",
            Self::Cuda => "CUDA",
        }
    }
}

impl std::fmt::Display for ExecutionProvider {
    /// Writes the [`ExecutionProvider::as_str`] label verbatim (no padding).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(Self::as_str(self))
    }
}
69
70/// Abstraction over embedding backends.
71///
72/// Why: Decouple consumers from any one model so we can swap in remote APIs,
73/// quantised models, or deterministic mocks without changing call sites.
74/// What: a single primitive — `embed_batch` — plus a dimension accessor.
75/// Single-text callers should use the [`embed_one`] convenience helper.
76/// Test: covered by `FastEmbedder` and `MockEmbedder` tests below.
77#[async_trait]
78pub trait Embedder: Send + Sync {
79 /// Embed a batch of texts. Returns one `Vec<f32>` per input, each of
80 /// length `self.dimension()`. An empty input batch returns an empty Vec.
81 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
82
83 /// Output dimension of the produced embeddings.
84 fn dimension(&self) -> usize;
85}
86
87/// Convenience helper: embed a single text via `embed_batch` and return the
88/// lone vector.
89///
90/// Why: Most call sites only need one embedding at a time and writing
91/// `.embed_batch(&[text]).await?.into_iter().next()` everywhere is noise.
92/// What: builds a 1-element batch, calls `embed_batch`, returns the first
93/// vector (or errors if the embedder produced nothing).
94/// Test: covered indirectly by `mock_embedder_round_trip`.
95pub async fn embed_one(embedder: &dyn Embedder, text: &str) -> Result<Vec<f32>> {
96 let mut v = embedder.embed_batch(&[text.to_string()]).await?;
97 v.pop()
98 .context("embedder returned no embedding for non-empty input")
99}
100
101/// Local CPU embedder backed by fastembed-rs (ONNX runtime, all-MiniLM-L6-v2).
102///
103/// Why: Default to local-only embeddings so consumers have zero external
104/// network dependency and predictable latency. The LRU cache keeps the hot
105/// path free of redundant ONNX work for repeat strings (queries, common
106/// chunks).
107/// What: wraps a single `TextEmbedding` behind a `parking_lot::Mutex` (the
108/// underlying `embed` requires `&mut self`) and an `LruCache<String, Vec<f32>>`.
109/// Initialisation warms the ORT graph with a small batch so the first user
110/// query doesn't pay the one-shot compile cost.
111/// Test: `embed_batch_returns_correct_dim` and `cache_hit_is_idempotent`
112/// (marked `#[ignore]` — they download a real model).
113pub struct FastEmbedder {
114 model: Arc<Mutex<TextEmbedding>>,
115 cache: Arc<Mutex<LruCache<String, Vec<f32>>>>,
116 dim: usize,
117 provider: ExecutionProvider,
118}
119
120impl FastEmbedder {
121 /// Construct a new `FastEmbedder` with the default cache size.
122 pub async fn new() -> Result<Self> {
123 Self::with_cache_size(DEFAULT_CACHE_CAPACITY).await
124 }
125
126 /// Identifier for the execution provider this embedder is actually using.
127 ///
128 /// Why: callers (e.g. `trusty-search` startup logs) want to surface
129 /// whether the daemon is running on CPU or GPU/ANE without poking at
130 /// internals.
131 /// What: returns `ExecutionProvider::CoreML` on Apple Silicon (when EP
132 /// registration succeeded), otherwise `Cpu` (or `Cuda` if/when wired).
133 /// Test: covered by the public-surface compile check.
134 pub fn provider(&self) -> ExecutionProvider {
135 self.provider
136 }
137
138 /// Build `TextInitOptions` for the given model, attempting to register
139 /// the CoreML execution provider at runtime when on Apple Silicon.
140 ///
141 /// Why: We want zero-friction GPU/ANE acceleration on Apple Silicon
142 /// without forcing users to pass `--features coreml`. fastembed-rs accepts
143 /// a `Vec<ExecutionProviderDispatch>` via `with_execution_providers`, and
144 /// our `ort` dep (pinned to the exact `=2.0.0-rc.12` fastembed uses) has
145 /// the `coreml` feature on by default on macOS, so we can always try to
146 /// build and register CoreML at runtime. On non-Apple platforms, or if
147 /// CoreML registration fails for any reason, we transparently fall back
148 /// to the default CPU provider.
149 /// What: returns `(TextInitOptions, ExecutionProvider)` where the tag
150 /// reflects which backend was actually wired in.
151 /// Test: on an M-series Mac the tag is `CoreML`; on Intel/Linux/Windows
152 /// (or if CoreML build fails) the tag is `Cpu`.
153 fn init_options(model: EmbeddingModel) -> (TextInitOptions, ExecutionProvider) {
154 let opts = TextInitOptions::new(model);
155
156 #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
157 {
158 use ort::execution_providers::ExecutionProviderDispatch;
159 let coreml: ExecutionProviderDispatch = ort::ep::CoreML::default().build();
160 let providers: Vec<ExecutionProviderDispatch> = vec![coreml];
161 tracing::info!(
162 "trusty-embedder: registering CoreML execution provider (Apple Silicon)"
163 );
164 return (opts.with_execution_providers(providers), ExecutionProvider::CoreML);
165 }
166
167 #[allow(unreachable_code)]
168 (opts, ExecutionProvider::Cpu)
169 }
170
171 /// Construct with an explicit LRU capacity.
172 pub async fn with_cache_size(capacity: usize) -> Result<Self> {
173 let capacity =
174 NonZeroUsize::new(capacity.max(1)).expect("capacity.max(1) is always non-zero");
175
176 // fastembed's `try_new` downloads + builds an ONNX session — blocking
177 // work that must run off the async reactor.
178 let (model, provider) =
179 tokio::task::spawn_blocking(|| -> Result<(TextEmbedding, ExecutionProvider)> {
180 let (q_opts, q_provider) = Self::init_options(EmbeddingModel::AllMiniLML6V2Q);
181 let (m, provider) = match TextEmbedding::try_new(q_opts) {
182 Ok(m) => (m, q_provider),
183 Err(q_err) => {
184 tracing::warn!(
185 "AllMiniLML6V2Q init failed ({q_err:#}), falling back to AllMiniLML6V2"
186 );
187 let (fb_opts, fb_provider) =
188 Self::init_options(EmbeddingModel::AllMiniLML6V2);
189 let m = TextEmbedding::try_new(fb_opts).context(
190 "failed to initialise fastembed (tried AllMiniLML6V2Q and AllMiniLML6V2)",
191 )?;
192 (m, fb_provider)
193 }
194 };
195 let mut m = m;
196
197 // Warm the graph so the first real user query is hot.
198 let warmup: Vec<&str> = vec![
199 "hello world",
200 "the quick brown fox",
201 "memory palace warmup",
202 "embedding model ready",
203 "trusty common warmup",
204 ];
205 let _ = m
206 .embed(warmup, None)
207 .context("fastembed warmup batch failed")?;
208 Ok((m, provider))
209 })
210 .await
211 .context("spawn_blocking joined with error during embedder init")??;
212
213 tracing::info!(
214 "trusty-embedder: FastEmbedder ready (provider={}, dim={})",
215 provider,
216 EMBED_DIM
217 );
218
219 Ok(Self {
220 model: Arc::new(Mutex::new(model)),
221 cache: Arc::new(Mutex::new(LruCache::new(capacity))),
222 dim: EMBED_DIM,
223 provider,
224 })
225 }
226}
227
228#[async_trait]
229impl Embedder for FastEmbedder {
230 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
231 if texts.is_empty() {
232 return Ok(Vec::new());
233 }
234
235 // Split into cached hits vs misses.
236 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
237 let mut to_compute: Vec<(usize, String)> = Vec::new();
238 {
239 let mut cache = self.cache.lock();
240 for (i, t) in texts.iter().enumerate() {
241 if let Some(v) = cache.get(t) {
242 results[i] = Some(v.clone());
243 } else {
244 to_compute.push((i, t.clone()));
245 }
246 }
247 }
248
249 if !to_compute.is_empty() {
250 let model = Arc::clone(&self.model);
251 let owned: Vec<String> = to_compute.iter().map(|(_, s)| s.clone()).collect();
252 let computed = tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
253 let mut guard = model.lock();
254 guard
255 .embed(owned, None)
256 .context("fastembed embed call failed")
257 })
258 .await
259 .context("spawn_blocking joined with error during embed")??;
260
261 if computed.len() != to_compute.len() {
262 anyhow::bail!(
263 "fastembed returned {} embeddings, expected {}",
264 computed.len(),
265 to_compute.len()
266 );
267 }
268
269 let mut cache = self.cache.lock();
270 for ((idx, key), vector) in to_compute.into_iter().zip(computed.into_iter()) {
271 cache.put(key, vector.clone());
272 results[idx] = Some(vector);
273 }
274 }
275
276 results
277 .into_iter()
278 .map(|opt| opt.context("missing embedding slot after batch"))
279 .collect()
280 }
281
282 fn dimension(&self) -> usize {
283 self.dim
284 }
285}
286
/// Deterministic test double — hashes input bytes into a fixed-dim vector.
///
/// Why: ONNX model downloads dominate test runtime and can race on cold
/// caches when multiple tests construct embedders in parallel. The mock
/// gives integration tests a "rank by similarity" surface without any I/O.
/// What: each byte `b` at position `i` adds `b / 255` to slot
/// `(i + b) % dim`. The first byte is also added to slot 0, so texts with
/// different leading bytes diverge. The empty string maps to the all-zero
/// vector, and a `dim` of 0 yields empty vectors instead of panicking.
/// Test: `mock_embedder_round_trip` confirms shape + determinism.
#[cfg(any(test, feature = "test-support"))]
#[derive(Debug, Clone)]
pub struct MockEmbedder {
    dim: usize,
}

#[cfg(any(test, feature = "test-support"))]
impl MockEmbedder {
    /// Create a mock that produces vectors of length `dim`.
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }

    /// Hash `text` into a `self.dim`-length vector (see the type docs).
    fn hash_to_vec(&self, text: &str) -> Vec<f32> {
        let mut v = vec![0.0_f32; self.dim];
        // A zero-dimension mock has no slots: avoid the `% 0` and `v[0]` panics.
        if self.dim == 0 {
            return v;
        }
        for (i, b) in text.bytes().enumerate() {
            let slot = (i + b as usize) % self.dim;
            v[slot] += f32::from(b) / 255.0;
        }
        if let Some(first) = text.bytes().next() {
            v[0] += f32::from(first) / 255.0;
        }
        v
    }
}
318
#[cfg(any(test, feature = "test-support"))]
#[async_trait]
impl Embedder for MockEmbedder {
    /// Hash every input independently, preserving order. Never fails; an
    /// empty batch yields an empty `Vec`.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.hash_to_vec(text));
        }
        Ok(vectors)
    }

    fn dimension(&self) -> usize {
        self.dim
    }
}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Shape, determinism, and agreement between `embed_one` and `embed_batch`.
    #[tokio::test]
    async fn mock_embedder_round_trip() {
        let e = MockEmbedder::new(EMBED_DIM);
        assert_eq!(e.dimension(), EMBED_DIM);

        let v = embed_one(&e, "hello").await.unwrap();
        assert_eq!(v.len(), EMBED_DIM);
        // Determinism: the same text must always produce the same vector.
        assert_eq!(v, embed_one(&e, "hello").await.unwrap());

        let batch = e
            .embed_batch(&["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(batch.len(), 2);
        assert_ne!(batch[0], batch[1]);
        // The batch path must agree with `embed_one` and preserve input order.
        assert_eq!(batch[0], embed_one(&e, "a").await.unwrap());
        assert_eq!(batch[1], embed_one(&e, "b").await.unwrap());
    }

    #[tokio::test]
    async fn mock_embedder_empty_input_returns_empty() {
        let e = MockEmbedder::new(EMBED_DIM);
        let v = e.embed_batch(&[]).await.unwrap();
        assert!(v.is_empty());
    }

    // ONNX-backed test: downloads ~23MB on first run. Marked ignored so default
    // `cargo test` stays offline; run with `cargo test -- --ignored` when needed.
    // The literal 384 is deliberate: it checks the real model, not just `EMBED_DIM`.
    #[tokio::test]
    #[ignore]
    async fn fastembed_returns_correct_dim() {
        let e = FastEmbedder::new().await.unwrap();
        assert_eq!(e.dimension(), 384);
        let v = embed_one(&e, "fn authenticate(user: &str) -> bool")
            .await
            .unwrap();
        assert_eq!(v.len(), 384);
        assert!(v.iter().any(|x| *x != 0.0));
    }

    #[tokio::test]
    #[ignore]
    async fn fastembed_cache_hit_is_idempotent() {
        let e = FastEmbedder::new().await.unwrap();
        let v1 = embed_one(&e, "cached").await.unwrap();
        let v2 = embed_one(&e, "cached").await.unwrap();
        assert_eq!(v1, v2);
    }
}