1use crate::error::Result;
16
/// A backend that turns text into `f32` vectors.
///
/// Implementors must be `Send + Sync`.
pub trait Embedder: Send + Sync {
    /// Returns the dimensionality reported by this embedder.
    ///
    /// NOTE(review): presumably the length of every vector returned by
    /// [`Embedder::embed`]; the trait itself does not enforce this.
    fn dimensions(&self) -> usize;

    /// Embeds a single piece of text.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports for `text`.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds every entry of `texts`, returning one vector per input in input order.
    ///
    /// The default implementation simply calls [`Embedder::embed`] once per entry.
    ///
    /// # Errors
    ///
    /// Stops at the first entry whose `embed` call fails and returns that error; later
    /// entries are not embedded.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed(text)?);
        }
        Ok(vectors)
    }
}
31
/// An [`Embedder`] that needs no model: tokens are hashed into a fixed number of buckets.
#[derive(Debug, Clone)]
pub struct HashEmbedder {
    /// Number of buckets, i.e. the length of each produced vector. `new` rejects zero.
    dims: usize,
}

impl HashEmbedder {
    /// Creates an embedder that produces vectors with `dims` components.
    ///
    /// # Panics
    ///
    /// Panics if `dims` is zero.
    pub fn new(dims: usize) -> Self {
        assert!(dims > 0, "HashEmbedder dimension must be > 0");
        Self { dims }
    }

    /// Hashes the UTF-8 bytes of `token` with 64-bit FNV-1a.
    ///
    /// The empty token hashes to the offset basis itself.
    #[inline]
    fn token_hash(token: &str) -> u64 {
        // Standard 64-bit FNV parameters.
        const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const PRIME: u64 = 0x0000_0100_0000_01b3;

        // FNV-1a order: xor the byte in first, then multiply (wrapping modulo 2^64).
        token.bytes().fold(OFFSET_BASIS, |hash, byte| {
            (hash ^ u64::from(byte)).wrapping_mul(PRIME)
        })
    }
}
62
63impl Embedder for HashEmbedder {
64 fn dimensions(&self) -> usize {
65 self.dims
66 }
67
68 fn embed(&self, text: &str) -> Result<Vec<f32>> {
69 let mut v = vec![0.0f32; self.dims];
70 for raw in text.split_whitespace() {
71 let token = raw.to_ascii_lowercase();
73 let h = Self::token_hash(&token);
74 let idx = (h % self.dims as u64) as usize;
75 let sign = if (h >> 32) & 1 == 0 { 1.0 } else { -1.0 };
77 v[idx] += sign;
78 }
79 let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
81 if norm > 0.0 {
82 for x in &mut v {
83 *x /= norm;
84 }
85 }
86 Ok(v)
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Dot product of two vectors; for unit-length inputs this is their cosine similarity.
    fn dot(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter().zip(rhs).map(|(l, r)| l * r).sum()
    }

    /// Euclidean length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Embedding the same text twice gives identical unit-length vectors of the right size.
    #[test]
    fn deterministic_and_normalized() {
        let embedder = HashEmbedder::new(64);
        let first = embedder.embed("the quick brown fox").unwrap();
        let second = embedder.embed("the quick brown fox").unwrap();

        assert_eq!(first, second);
        assert_eq!(first.len(), 64);
        assert!((l2_norm(&first) - 1.0).abs() < 1e-5);
    }

    /// Texts sharing tokens score higher against each other than texts sharing none.
    #[test]
    fn overlap_scores_higher_than_disjoint() {
        let embedder = HashEmbedder::new(256);
        let base = embedder.embed("machine learning vector database").unwrap();
        let near = embedder.embed("machine learning vector search").unwrap();
        let far = embedder.embed("unrelated cooking recipe content").unwrap();

        let near_score = dot(&base, &near);
        let far_score = dot(&base, &far);
        assert!(near_score > far_score);
    }

    /// ASCII letter case does not influence the embedding.
    #[test]
    fn case_insensitive_tokens() {
        let embedder = HashEmbedder::new(64);
        let mixed_case = embedder.embed("Hello World").unwrap();
        let lower_case = embedder.embed("hello world").unwrap();
        assert_eq!(mixed_case, lower_case);
    }

    /// Whitespace-only input contains no tokens and therefore embeds to all zeros.
    #[test]
    fn empty_text_is_zero_vector() {
        let embedder = HashEmbedder::new(32);
        let embedded = embedder.embed(" ").unwrap();
        assert_eq!(embedded, vec![0.0f32; 32]);
    }

    /// The default batch implementation agrees with embedding each text on its own.
    #[test]
    fn batch_matches_single() {
        let embedder = HashEmbedder::new(48);
        let batch = embedder.embed_batch(&["alpha beta", "gamma"]).unwrap();

        let alpha_beta = embedder.embed("alpha beta").unwrap();
        let gamma = embedder.embed("gamma").unwrap();
        assert_eq!(batch[0], alpha_beta);
        assert_eq!(batch[1], gamma);
    }
}