Skip to main content

vecgraph_core/traits/
embedder.rs

use crate::error::VecGraphError;

3pub trait Embedder: Send + Sync {
4    fn embed(&self, input: &str) -> Result<Vec<f32>, VecGraphError>;
5    fn dimensions(&self) -> usize;
6    fn embed_batch(&self, inputs: &[&str]) -> Result<Vec<Vec<f32>>, VecGraphError> {
7        inputs.iter().map(|input| self.embed(input)).collect()
8    }
9    fn arithmetic(
10        &self,
11        base: &[f32],
12        add: &[&[f32]],
13        sub: &[&[f32]],
14    ) -> Result<Vec<f32>, VecGraphError> {
15        let dim = self.dimensions();
16        if base.len() != dim {
17            return Err(VecGraphError::DimensionMismatch {
18                expected: dim,
19                got: base.len(),
20            });
21        }
22
23        let mut result = base.to_vec();
24
25        // Simple vector arithmetic: result = base + sum(add) - sum(sub)
26
27        for v in add {
28            if v.len() != dim {
29                return Err(VecGraphError::DimensionMismatch {
30                    expected: dim,
31                    got: v.len(),
32                });
33            }
34            for (r, val) in result.iter_mut().zip(v.iter()) {
35                *r += val;
36            }
37        }
38
39        for v in sub {
40            if v.len() != dim {
41                return Err(VecGraphError::DimensionMismatch {
42                    expected: dim,
43                    got: v.len(),
44                });
45            }
46            for (r, val) in result.iter_mut().zip(v.iter()) {
47                *r -= val;
48            }
49        }
50
51        // Normalize the result to unit length
52        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
53        if norm > 0.0 {
54            for x in result.iter_mut() {
55                *x /= norm;
56            }
57        }
58
59        Ok(result)
60    }
61}