vecgraph_core/traits/
embedder.rs1use crate::error::VecGraphError;
2
3pub trait Embedder: Send + Sync {
4 fn embed(&self, input: &str) -> Result<Vec<f32>, VecGraphError>;
5 fn dimensions(&self) -> usize;
6 fn embed_batch(&self, inputs: &[&str]) -> Result<Vec<Vec<f32>>, VecGraphError> {
7 inputs.iter().map(|input| self.embed(input)).collect()
8 }
9 fn arithmetic(
10 &self,
11 base: &[f32],
12 add: &[&[f32]],
13 sub: &[&[f32]],
14 ) -> Result<Vec<f32>, VecGraphError> {
15 let dim = self.dimensions();
16 if base.len() != dim {
17 return Err(VecGraphError::DimensionMismatch {
18 expected: dim,
19 got: base.len(),
20 });
21 }
22
23 let mut result = base.to_vec();
24
25 for v in add {
28 if v.len() != dim {
29 return Err(VecGraphError::DimensionMismatch {
30 expected: dim,
31 got: v.len(),
32 });
33 }
34 for (r, val) in result.iter_mut().zip(v.iter()) {
35 *r += val;
36 }
37 }
38
39 for v in sub {
40 if v.len() != dim {
41 return Err(VecGraphError::DimensionMismatch {
42 expected: dim,
43 got: v.len(),
44 });
45 }
46 for (r, val) in result.iter_mut().zip(v.iter()) {
47 *r -= val;
48 }
49 }
50
51 let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
53 if norm > 0.0 {
54 for x in result.iter_mut() {
55 *x /= norm;
56 }
57 }
58
59 Ok(result)
60 }
61}