Skip to main content

tj_core/
embed.rs

1//! Embedding substrate for semantic memory (Pillar A).
2//!
3//! An [`Embedder`] turns text into a fixed-dimension vector so events can be
4//! retrieved by *meaning*, not just keyword (FTS5). The real semantic backend is
5//! a pure-Rust static model (model2vec, behind the `embed` feature); when it is
6//! absent every caller falls back to FTS5, so the journal's zero-cost,
7//! offline-by-default behaviour is preserved.
8//!
9//! This module is dependency-free on purpose: the trait, the cosine/recency
10//! math, the SQLite blob codec, and a deterministic [`HashEmbedder`] all build
11//! and test without pulling a model. The model2vec backend is added as an
12//! isolated, feature-gated step on top.
13
14/// A text embedder. Implementations return exactly one vector per input, all of
15/// the same [`dim`](Embedder::dim), produced by the model named by
16/// [`model_id`](Embedder::model_id).
17pub trait Embedder: Send + Sync {
18    /// Embed a batch of texts. `out[i]` corresponds to `texts[i]`.
19    fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>>;
20    /// Stable identifier of the model (stored per vector so a model change can
21    /// trigger a re-embed and we never compare vectors across models).
22    fn model_id(&self) -> &str;
23    /// Output dimensionality.
24    fn dim(&self) -> usize;
25
26    /// Convenience: embed a single text.
27    fn embed_one(&self, text: &str) -> anyhow::Result<Vec<f32>> {
28        let mut v = self.embed(&[text])?;
29        Ok(v.pop().unwrap_or_default())
30    }
31}
32
/// Cosine similarity of two vectors, always a finite value in `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, either has zero
/// norm, or the result would not be finite (an input contains `NaN` or an
/// infinity). Callers *rank* with this value, they don't assert on it, so it
/// must never panic and must never hand back `NaN`.
///
/// The sums are accumulated in `f64`: squaring large or tiny `f32` components
/// in `f32` overflows to infinity or underflows to zero, which would turn a
/// perfectly valid pair of vectors into `NaN` or a spurious `0.0`.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass: dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    if similarity.is_finite() {
        // Rounding can push the ratio a hair outside the mathematical range.
        similarity.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
53
/// Serialise an `f32` slice into bytes suitable for a SQLite `BLOB` column:
/// each value contributes its four little-endian bytes, in order. The inverse
/// operation is [`from_blob`].
pub fn to_blob(v: &[f32]) -> Vec<u8> {
    v.iter().flat_map(|value| value.to_le_bytes()).collect()
}
63
/// Rebuild an `f32` vector from a little-endian byte blob, four bytes per
/// value. A trailing remainder of fewer than four bytes is dropped silently;
/// that is purely defensive, since [`to_blob`] always emits a multiple of four.
pub fn from_blob(b: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(b.len() / 4);
    for word in b.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        floats.push(f32::from_le_bytes(raw));
    }
    floats
}
72
/// Decide whether a piece of event text deserves an embedding. Text whose
/// trimmed form has fewer than 12 characters (empty strings, short
/// boilerplate such as the `[open]` marker) is rejected because it carries no
/// retrievable meaning.
pub fn is_embeddable(text: &str) -> bool {
    const MIN_CHARS: usize = 12;
    // A 12th character exists exactly when the char count is at least 12.
    text.trim().chars().nth(MIN_CHARS - 1).is_some()
}
78
/// Deterministic embedder with no model and no dependencies, built on the
/// feature-hashing trick: every token is hashed into one of `dim` buckets and
/// the resulting bag-of-words vector is L2-normalised. It is **lexical**, not
/// semantic. It exists so the trait, storage, ingest and ranking code can be
/// exercised without a model, and to act as a crude offline fallback; real
/// semantic quality comes from the model2vec backend.
pub struct HashEmbedder {
    /// Number of hash buckets, which is also the output vector length.
    dim: usize,
}

impl HashEmbedder {
    /// Build an embedder that emits `dim`-length vectors. A requested `dim`
    /// of `0` is raised to `1`, so the bucket count is never zero.
    pub fn new(dim: usize) -> Self {
        let dim = if dim == 0 { 1 } else { dim };
        Self { dim }
    }

    /// 64-bit FNV-1a over the token's UTF-8 bytes: small, deterministic and
    /// free of dependencies.
    fn hash_token(tok: &str) -> u64 {
        const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const PRIME: u64 = 0x0000_0100_0000_01b3;
        tok.bytes().fold(OFFSET_BASIS, |state, byte| {
            (state ^ u64::from(byte)).wrapping_mul(PRIME)
        })
    }
}

impl Default for HashEmbedder {
    /// The default embedder uses 64 buckets.
    fn default() -> Self {
        HashEmbedder::new(64)
    }
}
110
111impl Embedder for HashEmbedder {
112    fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
113        let mut out = Vec::with_capacity(texts.len());
114        for t in texts {
115            let mut v = vec![0.0f32; self.dim];
116            for tok in t
117                .split(|c: char| !c.is_alphanumeric())
118                .filter(|s| !s.is_empty())
119            {
120                let lower = tok.to_lowercase();
121                let bucket = (Self::hash_token(&lower) as usize) % self.dim;
122                v[bucket] += 1.0;
123            }
124            // L2-normalise so cosine == dot product and lengths don't bias.
125            let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
126            if norm > 0.0 {
127                for x in &mut v {
128                    *x /= norm;
129                }
130            }
131            out.push(v);
132        }
133        Ok(out)
134    }
135
136    fn model_id(&self) -> &str {
137        "hash-v1"
138    }
139
140    fn dim(&self) -> usize {
141        self.dim
142    }
143}
144
145/// Default model2vec repo — multilingual so RU/EN prose both embed well.
146/// Overridable via `TJ_EMBED_MODEL`.
147#[cfg(feature = "embed")]
148pub const DEFAULT_EMBED_MODEL: &str = "minishlab/potion-multilingual-128M";
149
150/// The embedder the journal uses unless overridden. With the `embed` feature
151/// (on by default) it loads the model2vec static model for true semantic
152/// recall; if that can't load — offline first run, download failure, or
153/// `TJ_EMBED=hash` — it falls back to the dependency-free lexical
154/// [`HashEmbedder`] so the journal never breaks.
155pub fn default_embedder() -> Box<dyn Embedder> {
156    // Test/escape hatch: force the deterministic lexical embedder.
157    if std::env::var("TJ_EMBED").as_deref() == Ok("hash") {
158        return Box::new(HashEmbedder::default());
159    }
160    #[cfg(feature = "embed")]
161    {
162        let repo =
163            std::env::var("TJ_EMBED_MODEL").unwrap_or_else(|_| DEFAULT_EMBED_MODEL.to_string());
164        match Model2VecEmbedder::load(&repo) {
165            Ok(m) => return Box::new(m),
166            Err(e) => {
167                tracing::warn!("model2vec load failed ({e:#}); using hash embedder fallback");
168            }
169        }
170    }
171    Box::new(HashEmbedder::default())
172}
173
/// Semantic embedder backed by a model2vec static model (pure-Rust, no
/// onnxruntime). The model is fetched once through the HuggingFace hub and
/// cached locally; subsequent loads read the cache. Only compiled with the
/// `embed` feature.
#[cfg(feature = "embed")]
pub struct Model2VecEmbedder {
    /// The loaded static model.
    model: model2vec_rs::model::StaticModel,
    /// `"model2vec:"` followed by the repo string passed to `load`.
    model_id: String,
    /// Output length, measured by embedding a probe string at load time.
    dim: usize,
}

#[cfg(feature = "embed")]
impl Model2VecEmbedder {
    /// Load the model identified by `repo` (a HuggingFace model id or a local
    /// directory) and embed a probe string once to learn the output dimension.
    ///
    /// # Errors
    /// Propagates any failure from loading the model, and fails if the probe
    /// embedding comes back with zero length.
    pub fn load(repo: &str) -> anyhow::Result<Self> {
        let token = None; // no auth token
        let normalize = Some(true); // L2-normalise outputs
        let subfolder = None; // no subfolder
        let model =
            model2vec_rs::model::StaticModel::from_pretrained(repo, token, normalize, subfolder)?;

        let dim = model.encode_single("probe").len();
        if dim == 0 {
            anyhow::bail!("model2vec model {repo} produced a zero-dim embedding");
        }

        let model_id = format!("model2vec:{repo}");
        Ok(Self {
            model,
            model_id,
            dim,
        })
    }
}
207
#[cfg(feature = "embed")]
impl Embedder for Model2VecEmbedder {
    /// Embed the batch with the underlying static model. The inputs are
    /// copied into owned `String`s first because that is what the model's
    /// `encode` call is given here.
    fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        let mut sentences = Vec::with_capacity(texts.len());
        for text in texts {
            sentences.push(String::from(*text));
        }
        let vectors = self.model.encode(&sentences);
        Ok(vectors)
    }

    /// The `model2vec:<repo>` identifier assigned at load time.
    fn model_id(&self) -> &str {
        self.model_id.as_str()
    }

    /// The output dimension discovered at load time.
    fn dim(&self) -> usize {
        self.dim
    }
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// A vector compared with itself has similarity 1 (within float noise).
    #[test]
    fn cosine_identical_is_one() {
        let v = [1.0f32, 2.0, 3.0];
        let similarity = cosine(&v, &v);
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    /// Perpendicular unit vectors score exactly zero.
    #[test]
    fn cosine_orthogonal_is_zero() {
        let x_axis = [1.0f32, 0.0];
        let y_axis = [0.0f32, 1.0];
        assert_eq!(cosine(&x_axis, &y_axis), 0.0);
    }

    /// Degenerate inputs fall back to zero instead of panicking.
    #[test]
    fn cosine_mismatch_or_zero_norm_is_zero() {
        // Length mismatch.
        assert_eq!(cosine(&[1.0, 2.0], &[1.0]), 0.0);
        // Zero-norm left operand.
        assert_eq!(cosine(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    /// Encoding then decoding returns the original values.
    #[test]
    fn blob_round_trips() {
        let original: Vec<f32> = vec![0.5, -1.25, 3.0, 0.0];
        let encoded = to_blob(&original);
        let decoded = from_blob(&encoded);
        assert_eq!(decoded, original);
    }

    /// Empty and marker-only text is rejected; a real sentence is accepted.
    #[test]
    fn is_embeddable_skips_short_boilerplate() {
        for rejected in ["", "[open]"] {
            assert!(!is_embeddable(rejected));
        }
        assert!(is_embeddable("Fix the auth bug in middleware"));
    }

    /// Same input gives the same unit-length vector of the requested size.
    #[test]
    fn hash_embedder_is_deterministic_and_normalised() {
        let embedder = HashEmbedder::new(32);
        let first = embedder.embed_one("payment gateway dedup").unwrap();
        let second = embedder.embed_one("payment gateway dedup").unwrap();
        assert_eq!(first, second);
        assert_eq!(first.len(), 32);

        let squared_sum: f32 = first.iter().map(|x| x * x).sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    /// A text sharing tokens with the query outranks an unrelated one.
    #[test]
    fn hash_embedder_overlap_ranks_above_disjoint() {
        let embedder = HashEmbedder::new(256);
        let query = embedder
            .embed_one("payment refund duplicate write")
            .unwrap();
        let near = embedder
            .embed_one("duplicate refund write on payment")
            .unwrap();
        let far = embedder.embed_one("frontend button color tweak").unwrap();

        let near_score = cosine(&query, &near);
        let far_score = cosine(&query, &far);
        assert!(
            near_score > far_score,
            "lexical overlap must score higher: near={} far={}",
            near_score,
            far_score
        );
    }
}