1pub trait Embedder: Send + Sync {
18 fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>>;
20 fn model_id(&self) -> &str;
23 fn dim(&self) -> usize;
25
26 fn embed_one(&self, text: &str) -> anyhow::Result<Vec<f32>> {
28 let mut v = self.embed(&[text])?;
29 Ok(v.pop().unwrap_or_default())
30 }
31}
32
/// Cosine similarity between `a` and `b`, in `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has zero norm.
/// Callers can therefore rank candidates without special-casing degenerate vectors.
///
/// Sums are accumulated in `f64`. Large-magnitude `f32` components then cannot overflow
/// to infinity (`inf / inf` would be `NaN`), and tiny components cannot underflow to a
/// spurious zero norm. The result is clamped so rounding can never report a similarity
/// above 1.0 or below -1.0. `NaN` components in the input still produce `NaN`.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    (dot / (norm_a.sqrt() * norm_b.sqrt())).clamp(-1.0, 1.0) as f32
}
53
/// Packs `v` into a byte buffer, four little-endian bytes per `f32`, in order.
///
/// `from_blob` performs the inverse decoding.
pub fn to_blob(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * std::mem::size_of::<f32>());
    bytes.extend(v.iter().flat_map(|value| value.to_le_bytes()));
    bytes
}
63
/// Decodes a buffer of little-endian `f32`s (four bytes each), e.g. one made by `to_blob`.
///
/// If `b.len()` is not a multiple of four, the trailing partial chunk is silently ignored.
pub fn from_blob(b: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(b.len() / 4);
    for chunk in b.chunks_exact(4) {
        let raw = [chunk[0], chunk[1], chunk[2], chunk[3]];
        floats.push(f32::from_le_bytes(raw));
    }
    floats
}
72
/// Reports whether `text` is long enough to be worth embedding.
///
/// After trimming surrounding whitespace, the text must contain at least 12 characters
/// (Unicode scalar values, not bytes). This filters out short boilerplate such as `[open]`.
pub fn is_embeddable(text: &str) -> bool {
    const MIN_CHARS: usize = 12;
    // A 12th character exists exactly when the count is >= 12; this also stops scanning early.
    text.trim().chars().nth(MIN_CHARS - 1).is_some()
}
78
/// Dependency-free embedder based on feature hashing of word tokens.
///
/// The output has `dim` components. Derives `Debug` (expected of public types) and the
/// cheap value traits, since the whole state is a single `usize`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HashEmbedder {
    /// Output dimensionality, i.e. the number of hash buckets; `new` clamps it to at least 1.
    dim: usize,
}
88
89impl HashEmbedder {
90 pub fn new(dim: usize) -> Self {
91 Self { dim: dim.max(1) }
92 }
93
94 fn hash_token(tok: &str) -> u64 {
95 let mut h: u64 = 0xcbf29ce484222325;
97 for b in tok.bytes() {
98 h ^= b as u64;
99 h = h.wrapping_mul(0x100000001b3);
100 }
101 h
102 }
103}
104
105impl Default for HashEmbedder {
106 fn default() -> Self {
107 Self::new(64)
108 }
109}
110
111impl Embedder for HashEmbedder {
112 fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
113 let mut out = Vec::with_capacity(texts.len());
114 for t in texts {
115 let mut v = vec![0.0f32; self.dim];
116 for tok in t
117 .split(|c: char| !c.is_alphanumeric())
118 .filter(|s| !s.is_empty())
119 {
120 let lower = tok.to_lowercase();
121 let bucket = (Self::hash_token(&lower) as usize) % self.dim;
122 v[bucket] += 1.0;
123 }
124 let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
126 if norm > 0.0 {
127 for x in &mut v {
128 *x /= norm;
129 }
130 }
131 out.push(v);
132 }
133 Ok(out)
134 }
135
136 fn model_id(&self) -> &str {
137 "hash-v1"
138 }
139
140 fn dim(&self) -> usize {
141 self.dim
142 }
143}
144
/// Model repository passed to `Model2VecEmbedder::load` by `default_embedder` when the
/// `TJ_EMBED_MODEL` environment variable is not set. Only compiled with the `embed` feature.
#[cfg(feature = "embed")]
pub const DEFAULT_EMBED_MODEL: &str = "minishlab/potion-multilingual-128M";
149
150pub fn default_embedder() -> Box<dyn Embedder> {
156 if std::env::var("TJ_EMBED").as_deref() == Ok("hash") {
158 return Box::new(HashEmbedder::default());
159 }
160 #[cfg(feature = "embed")]
161 {
162 let repo =
163 std::env::var("TJ_EMBED_MODEL").unwrap_or_else(|_| DEFAULT_EMBED_MODEL.to_string());
164 match Model2VecEmbedder::load(&repo) {
165 Ok(m) => return Box::new(m),
166 Err(e) => {
167 tracing::warn!("model2vec load failed ({e:#}); using hash embedder fallback");
168 }
169 }
170 }
171 Box::new(HashEmbedder::default())
172}
173
/// [`Embedder`] backed by a model2vec `StaticModel`. Only compiled with the `embed` feature.
#[cfg(feature = "embed")]
pub struct Model2VecEmbedder {
    /// The loaded model; `load` obtains it from `StaticModel::from_pretrained`.
    model: model2vec_rs::model::StaticModel,
    /// Value reported by `Embedder::model_id`; `load` sets it to `model2vec:<repo>`.
    model_id: String,
    /// Vector length, measured by `load` from the length of a probe embedding.
    dim: usize,
}
183
#[cfg(feature = "embed")]
impl Model2VecEmbedder {
    /// Loads the model2vec model identified by `repo` and measures its output dimension.
    ///
    /// # Errors
    /// Propagates any error from `StaticModel::from_pretrained`. Also fails if the model
    /// embeds a probe string as an empty vector.
    pub fn load(repo: &str) -> anyhow::Result<Self> {
        // NOTE(review): the trailing positional arguments are presumably
        // (token, normalize, subfolder) in model2vec-rs — confirm against the crate version.
        let model =
            model2vec_rs::model::StaticModel::from_pretrained(repo, None, Some(true), None)?;
        // The dimensionality is discovered empirically by embedding a throwaway string.
        let dim = model.encode_single("probe").len();
        if dim == 0 {
            anyhow::bail!("model2vec model {repo} produced a zero-dim embedding");
        }
        let model_id = format!("model2vec:{repo}");
        Ok(Self {
            model,
            model_id,
            dim,
        })
    }
}
207
#[cfg(feature = "embed")]
impl Embedder for Model2VecEmbedder {
    /// Embeds all `texts` in one `StaticModel::encode` call. This implementation never returns `Err`.
    fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        // `encode` is fed owned `String`s, so copy the borrowed inputs once.
        let sentences: Vec<String> = texts.iter().copied().map(String::from).collect();
        let vectors = self.model.encode(&sentences);
        Ok(vectors)
    }

    /// Returns `model2vec:<repo>` as recorded by `load`.
    fn model_id(&self) -> &str {
        self.model_id.as_str()
    }

    /// Vector length measured at load time.
    fn dim(&self) -> usize {
        self.dim
    }
}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_identical_is_one() {
        let v = [1.0f32, 2.0, 3.0];
        let sim = cosine(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal_is_zero() {
        let (x_axis, y_axis) = ([1.0f32, 0.0], [0.0f32, 1.0]);
        assert_eq!(cosine(&x_axis, &y_axis), 0.0);
    }

    #[test]
    fn cosine_mismatch_or_zero_norm_is_zero() {
        // Different lengths.
        assert_eq!(cosine(&[1.0, 2.0], &[1.0]), 0.0);
        // One operand has zero norm.
        assert_eq!(cosine(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    #[test]
    fn blob_round_trips() {
        let original: Vec<f32> = vec![0.5, -1.25, 3.0, 0.0];
        let decoded = from_blob(&to_blob(&original));
        assert_eq!(decoded, original);
    }

    #[test]
    fn is_embeddable_skips_short_boilerplate() {
        for short in ["", "[open]"] {
            assert!(!is_embeddable(short));
        }
        assert!(is_embeddable("Fix the auth bug in middleware"));
    }

    #[test]
    fn hash_embedder_is_deterministic_and_normalised() {
        let embedder = HashEmbedder::new(32);
        let text = "payment gateway dedup";
        let first = embedder.embed_one(text).unwrap();
        let second = embedder.embed_one(text).unwrap();
        assert_eq!(first, second);
        assert_eq!(first.len(), 32);
        let norm = first.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn hash_embedder_overlap_ranks_above_disjoint() {
        let embedder = HashEmbedder::new(256);
        let embed = |s: &str| embedder.embed_one(s).unwrap();
        let query = embed("payment refund duplicate write");
        let near = embed("duplicate refund write on payment");
        let far = embed("frontend button color tweak");
        let (near_score, far_score) = (cosine(&query, &near), cosine(&query, &far));
        assert!(
            near_score > far_score,
            "lexical overlap must score higher: near={} far={}",
            near_score,
            far_score
        );
    }
}