Skip to main content

ruvector_graph/
embed.rs

1//! Pluggable text embedding for inline `embed()`-at-insert/at-query
2//! (HelixDB-inspired, ADR-252 P3).
3//!
4//! HelixQL's built-in `Embed(text)` vectorizes inline so a caller never has to
5//! marshal a separate embedding service into the query. This module gives
6//! `TypedGraph` the same ergonomic via a pluggable [`Embedder`] trait: attach a
7//! model once, then create nodes from text or search by text and the binding's
8//! dimension is validated against the schema's vector type.
9//!
10//! A real model is supplied by implementing [`Embedder`]. A dependency-free
11//! [`HashEmbedder`] is included for offline/dev/test use — it is **not**
12//! semantic (lexical token overlap only) and must be opted into explicitly;
13//! consistent with ADR-194, the typed graph never silently falls back to it.
14
15use crate::error::Result;
16
/// A pluggable text → vector embedding model.
///
/// Implementors must be `Send + Sync`, as required by the trait's supertrait
/// bounds.
pub trait Embedder: Send + Sync {
    /// Output dimension of this model; must match the bound vector type's
    /// `dimensions`.
    fn dimensions(&self) -> usize;

    /// Embed one piece of text.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementing model reports.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embed several texts, yielding one vector per input, in input order.
    ///
    /// The default implementation calls [`Embedder::embed`] once per text.
    /// Implementors may override it with a vectorized path.
    ///
    /// # Errors
    ///
    /// The default implementation stops at, and returns, the first error
    /// produced by [`Embedder::embed`].
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed(text)?);
        }
        Ok(vectors)
    }
}
31
/// Feature-hashing embedder with no external dependencies and fully
/// deterministic output.
///
/// It reflects **lexical token overlap only** and is not a semantic model.
/// It is meant for offline/dev/test use and as an *explicit* opt-in (never a
/// silent fallback, per ADR-194). The same text always maps to the same
/// vector, which makes it handy for deterministic tests. For semantic search,
/// plug in a real model through [`Embedder`].
#[derive(Debug, Clone)]
pub struct HashEmbedder {
    /// Output dimension; `new` rejects zero.
    dims: usize,
}

impl HashEmbedder {
    /// Starting state of the 64-bit FNV-1a hash.
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    /// Multiplier applied after each byte is mixed in.
    const FNV_PRIME: u64 = 0x100000001b3;

    /// Build a hashing embedder that emits vectors of `dims` components.
    ///
    /// # Panics
    ///
    /// Panics if `dims` is zero.
    pub fn new(dims: usize) -> Self {
        assert!(dims > 0, "HashEmbedder dimension must be > 0");
        Self { dims }
    }

    /// 64-bit FNV-1a hash over the UTF-8 bytes of `token`.
    #[inline]
    fn token_hash(token: &str) -> u64 {
        // XOR the byte in first, then multiply (the "1a" ordering); the
        // multiply wraps modulo 2^64.
        token.bytes().fold(Self::FNV_OFFSET_BASIS, |state, byte| {
            (state ^ u64::from(byte)).wrapping_mul(Self::FNV_PRIME)
        })
    }
}
62
63impl Embedder for HashEmbedder {
64    fn dimensions(&self) -> usize {
65        self.dims
66    }
67
68    fn embed(&self, text: &str) -> Result<Vec<f32>> {
69        let mut v = vec![0.0f32; self.dims];
70        for raw in text.split_whitespace() {
71            // Case-fold so "Cat" and "cat" collide.
72            let token = raw.to_ascii_lowercase();
73            let h = Self::token_hash(&token);
74            let idx = (h % self.dims as u64) as usize;
75            // Signed feature hashing reduces collision bias.
76            let sign = if (h >> 32) & 1 == 0 { 1.0 } else { -1.0 };
77            v[idx] += sign;
78        }
79        // L2-normalize so cosine of identical text is exactly 1.0.
80        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
81        if norm > 0.0 {
82            for x in &mut v {
83                *x /= norm;
84            }
85        }
86        Ok(v)
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain dot product of two vectors, paired element by element.
    fn dot(x: &[f32], y: &[f32]) -> f32 {
        x.iter().zip(y).map(|(a, b)| a * b).sum()
    }

    /// Embedding the same text twice gives the same unit-length vector.
    #[test]
    fn deterministic_and_normalized() {
        let embedder = HashEmbedder::new(64);
        let first = embedder.embed("the quick brown fox").unwrap();
        let second = embedder.embed("the quick brown fox").unwrap();
        assert_eq!(first, second); // deterministic
        assert_eq!(first.len(), 64);
        let length: f32 = first.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < 1e-5); // unit length
    }

    /// Texts sharing tokens score higher than texts sharing none.
    #[test]
    fn overlap_scores_higher_than_disjoint() {
        let embedder = HashEmbedder::new(256);
        let base = embedder.embed("machine learning vector database").unwrap();
        let near = embedder.embed("machine learning vector search").unwrap();
        let far = embedder.embed("unrelated cooking recipe content").unwrap();
        let near_score = dot(&base, &near);
        let far_score = dot(&base, &far);
        assert!(near_score > far_score);
    }

    /// ASCII letter case does not affect the embedding.
    #[test]
    fn case_insensitive_tokens() {
        let embedder = HashEmbedder::new(64);
        let mixed = embedder.embed("Hello World").unwrap();
        let lower = embedder.embed("hello world").unwrap();
        assert_eq!(mixed, lower);
    }

    /// Whitespace-only text has no tokens and embeds to all zeros.
    #[test]
    fn empty_text_is_zero_vector() {
        let embedder = HashEmbedder::new(32);
        let embedded = embedder.embed("   ").unwrap();
        assert_eq!(embedded, vec![0.0f32; 32]);
    }

    /// The default `embed_batch` agrees with per-text `embed` calls.
    #[test]
    fn batch_matches_single() {
        let embedder = HashEmbedder::new(48);
        let batch = embedder.embed_batch(&["alpha beta", "gamma"]).unwrap();
        assert_eq!(batch[0], embedder.embed("alpha beta").unwrap());
        assert_eq!(batch[1], embedder.embed("gamma").unwrap());
    }
}