// shadow_core/diff/embedder.rs
//! Pluggable embedding backend for the semantic axis.
//!
//! The default `TfIdfEmbedder` implements smoothed TF-IDF cosine over
//! the corpus of texts being compared — production-quality for lexical
//! similarity but blind to paraphrase ("yes" vs "I agree" score 0).
//!
//! For paraphrase-robust similarity, callers supply an [`Embedder`]
//! that produces dense vectors per text. The crate stays free of heavy
//! ML dependencies; users bring their own embedding source via:
//!
//!   * `BoxedEmbedder::new(|texts| { ... })` — a closure returning
//!     `Vec<Vec<f32>>` for any external source (ONNX runtime, HF
//!     Inference API, OpenAI embeddings, in-house service, ...).
//!   * A direct impl of [`Embedder`] for stateful adapters that need
//!     to hold model handles, HTTP clients, or tokenizer state.
//!
//! Cross-language consistency: the cosine similarity computation
//! happens in Rust regardless of where vectors come from. As long as
//! two embedders produce comparable vectors (same dimensionality,
//! similar magnitudes), their semantic-axis output stays meaningful.
//!
//! Why no built-in ONNX backend
//! ----------------------------
//! Bundling `ort` + `tokenizers` + a real embedding model would
//! either blow past PyPI's per-wheel size limit (~100 MB) or force
//! users to download the model file on first use — both create
//! friction for the 99% of Shadow users whose semantic-axis needs
//! are already met by TF-IDF over response text. The trait keeps the
//! door open for users with paraphrase-heavy workloads to plug in
//! whatever embedding source they already run, without forcing the
//! cost on the default install.

/// A backend that produces dense embedding vectors for a slice of
/// input texts.
///
/// Implementations must be deterministic for a given input set
/// (otherwise the semantic axis becomes flappy across runs). Vector
/// dimensionality is implementation-defined; only the requirement
/// "every output vector has the same length" is enforced — the cosine
/// math handles any dimensionality.
pub trait Embedder: Send + Sync {
    /// Embed each text in `texts`. The returned vector at position `i`
    /// is the embedding for `texts[i]`. All vectors must have the same
    /// length (panics on mismatched dimensionality are a contract bug;
    /// the caller will validate before computing cosine).
    fn embed(&self, texts: &[&str]) -> Vec<Vec<f32>>;

    /// Optional: a stable identifier for the embedder, included in
    /// diagnostic output so a user can tell at a glance which embedder
    /// produced a given semantic-axis score.
    fn id(&self) -> &str {
        "anonymous"
    }
}

/// Adapter that wraps any `Fn(&[&str]) -> Vec<Vec<f32>>` closure into
/// an [`Embedder`].
///
/// Useful when the embedding source is an HTTP client, an ONNX
/// session, a Python callback (via PyO3), or any other resource the
/// caller already manages.
pub struct BoxedEmbedder<F>
where
    F: Fn(&[&str]) -> Vec<Vec<f32>> + Send + Sync,
{
    // The wrapped embedding closure; invoked verbatim by `embed`.
    f: F,
    // Diagnostic label surfaced through `Embedder::id`.
    name: String,
}

impl<F> BoxedEmbedder<F>
where
    F: Fn(&[&str]) -> Vec<Vec<f32>> + Send + Sync,
{
    /// Wrap a closure as an [`Embedder`] with the default name `"boxed"`.
    pub fn new(f: F) -> Self {
        // Delegate to `named` so the construction path is single-sourced.
        Self::named(f, "boxed")
    }

    /// Wrap a closure as an [`Embedder`] with a caller-supplied name
    /// (returned by [`Embedder::id`] for diagnostic output).
    pub fn named(f: F, name: impl Into<String>) -> Self {
        Self {
            f,
            name: name.into(),
        }
    }
}

92impl<F> Embedder for BoxedEmbedder<F>
93where
94    F: Fn(&[&str]) -> Vec<Vec<f32>> + Send + Sync,
95{
96    fn embed(&self, texts: &[&str]) -> Vec<Vec<f32>> {
97        (self.f)(texts)
98    }
99
100    fn id(&self) -> &str {
101        &self.name
102    }
103}
104
/// Cosine similarity between two equal-length dense vectors.
///
/// Returns 1.0 when both vectors are zero (consistent with the
/// TF-IDF axis: empty-vs-empty is a perfect match — neither has any
/// semantic content to differ on). Returns 0.0 when exactly one is
/// zero, and 0.0 on dimensionality mismatch (a contract violation by
/// the embedder; treated as "no similarity" rather than a panic).
/// Otherwise the standard `(a·b) / (‖a‖ · ‖b‖)`.
///
/// The result is clamped to `[-1.0, 1.0]` to absorb floating-point
/// drift.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Single fused pass over both slices; zip avoids per-index bounds
    // checks that the original indexed loop paid.
    let mut dot = 0.0_f32;
    let mut na = 0.0_f32;
    let mut nb = 0.0_f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        na += x * x;
        nb += y * y;
    }
    let (na, nb) = (na.sqrt(), nb.sqrt());
    // Norm thresholds guard against division by (near-)zero.
    match (na < 1e-12, nb < 1e-12) {
        (true, true) => 1.0,  // both zero: vacuously identical
        (true, false) | (false, true) => 0.0, // exactly one zero
        (false, false) => (dot / (na * nb)).clamp(-1.0, 1.0),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_identical_vectors_is_one() {
        let v = [1.0_f32, 2.0, 3.0];
        assert!((cosine(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal_vectors_is_zero() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        assert!(cosine(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_both_zero_returns_one() {
        // Zero-vs-zero is defined as a perfect match (see `cosine` docs).
        let z = [0.0_f32; 4];
        assert!((cosine(&z, &z) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn cosine_one_zero_returns_zero() {
        let a = [0.0_f32; 4];
        let b = [1.0_f32, 2.0, 3.0, 4.0];
        assert_eq!(cosine(&a, &b), 0.0);
    }

    #[test]
    fn cosine_dim_mismatch_returns_zero() {
        // Mismatched dimensionality is treated as "no similarity".
        let a = [1.0_f32, 2.0];
        let b = [1.0_f32, 2.0, 3.0];
        assert_eq!(cosine(&a, &b), 0.0);
    }

    #[test]
    fn boxed_embedder_round_trip() {
        let emb = BoxedEmbedder::named(
            |texts: &[&str]| texts.iter().map(|t| vec![t.len() as f32, 1.0]).collect(),
            "len-embed",
        );
        let v = emb.embed(&["abc", "abcdef"]);
        assert_eq!(v.len(), 2);
        assert_eq!(v[0], vec![3.0, 1.0]);
        assert_eq!(v[1], vec![6.0, 1.0]);
        assert_eq!(emb.id(), "len-embed");
    }
}