Skip to main content

vecgraph_core/
search.rs

1use crate::NodeId;
2use crate::error::VecGraphError;
3
/// Search target selector.
///
/// Parsed case-insensitively from `"edge"`, `"node"` and `"all"` (any other
/// string converts to `All`) and displayed in that same lowercase spelling.
/// Being fieldless, it is `Copy` and `Hash`, so it can be passed by value and
/// used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SearchKind {
    Edge,
    Node,
    All,
}
10
/// A single hit produced by a search.
#[derive(Debug)]
pub struct SearchResult {
    /// Identifier of the matched node.
    pub node_id: NodeId,
    /// Kind label of the hit. NOTE(review): presumably the same label space as
    /// `SearchQuery::kind` — confirm.
    pub kind: String,
    /// Score of the hit. NOTE(review): `cosine_distance` in this module returns
    /// a distance (lower = closer); confirm whether this holds a distance or a
    /// similarity.
    pub score: f32,
    /// The `SearchKind` recorded for this hit.
    pub hit_kind: SearchKind,
}
18
/// Parameters for an optional rerank stage attached to a `SearchQuery`.
///
/// Derives `Debug` and `PartialEq` (all fields support them) so queries can be
/// logged and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct RerankParams {
    /// Vector used for reranking.
    pub vector: Vec<f32>,
    /// Kind label for the rerank stage. NOTE(review): semantics not visible here.
    pub kind: String,
    /// Rerank weight; `SearchQuery::with_rerank` clamps it to `[0.0, 1.0]`.
    pub weight: f32,
}
25
/// Parameters for a vector search, normally built with `SearchQuery::new` and
/// refined with the `with_*` builder methods.
#[derive(Clone)]
pub struct SearchQuery {
    /// Which `SearchKind` to search.
    pub search_kind: SearchKind,
    /// The query embedding.
    pub query_vec: Vec<f32>,
    /// Kind label for the search. NOTE(review): exact semantics not visible here.
    pub kind: String,
    /// Optional namespace; `new` leaves it `None`. NOTE(review): presumably
    /// `None` means "no namespace restriction" — confirm.
    pub namespace: Option<String>,
    /// Requested number of results. NOTE(review): presumably an upper bound.
    pub top_k: usize,
    /// Names to exclude; `new` leaves it empty.
    pub exclude_names: Vec<String>,
    /// Optional rerank stage; set via `with_rerank`, `None` from `new`.
    pub rerank: Option<RerankParams>,
}
36
37pub struct ScoredHit {
38    pub node_id_bytes: Vec<u8>,
39    pub kind: String,
40    pub score: f32,
41    pub hit_kind: SearchKind,
42}
43
44impl SearchQuery {
45    pub fn new(
46        query_vec: Vec<f32>,
47        search_kind: impl Into<SearchKind>,
48        kind: impl Into<String>,
49        top_k: usize,
50    ) -> Self {
51        Self {
52            query_vec,
53            search_kind: search_kind.into(),
54            kind: kind.into(),
55            namespace: None,
56            top_k,
57            exclude_names: Vec::new(),
58            rerank: None,
59        }
60    }
61
62    pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
63        self.namespace = Some(namespace.into());
64        self
65    }
66
67    pub fn with_excludes(mut self, names: Vec<String>) -> Self {
68        self.exclude_names = names;
69        self
70    }
71
72    pub fn with_rerank(mut self, vector: Vec<f32>, kind: impl Into<String>, weight: f32) -> Self {
73        self.rerank = Some(RerankParams {
74            vector,
75            kind: kind.into(),
76            weight: weight.clamp(0.0, 1.0),
77        });
78        self
79    }
80}
81
82impl From<&str> for SearchKind {
83    fn from(s: &str) -> Self {
84        match s.to_lowercase().as_str() {
85            "edge" => SearchKind::Edge,
86            "node" => SearchKind::Node,
87            "all" => SearchKind::All,
88            _ => SearchKind::All, // Default to all if unrecognized
89        }
90    }
91}
92
93impl From<String> for SearchKind {
94    fn from(s: String) -> Self {
95        SearchKind::from(s.as_str())
96    }
97}
98
99impl std::fmt::Display for SearchKind {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        let s = match self {
102            SearchKind::Edge => "edge",
103            SearchKind::Node => "node",
104            SearchKind::All => "all",
105        };
106        f.write_str(s)
107    }
108}
109
110impl PartialEq for ScoredHit {
111    fn eq(&self, other: &Self) -> bool {
112        self.score == other.score
113    }
114}
115
116impl Eq for ScoredHit {}
117
118impl PartialOrd for ScoredHit {
119    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
120        Some(self.cmp(other))
121    }
122}
123
124impl Ord for ScoredHit {
125    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
126        // Max-heap: worst score at the top so we can evict it
127        self.score
128            .partial_cmp(&other.score)
129            .unwrap_or(std::cmp::Ordering::Equal)
130    }
131}
132
// Vector-math helpers. These were AI-generated and the author is not confident in the math, so review them carefully.
134
/// Cosine distance between `a` and `b`: `1 - cos(angle)`, in `[0.0, 2.0]`.
///
/// Returns `1.0` when the product of the two magnitudes is zero (e.g. either
/// vector is all zeros), since there is no direction to compare. The slices
/// are expected to have equal length. This is checked only in debug builds;
/// in release builds the extra trailing components of the longer slice are
/// ignored. A NaN in either input yields NaN.
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vector dimension mismatch");

    // Single pass accumulating (dot product, |a|^2, |b|^2) left to right.
    let (dot, norm_a, norm_b) = a
        .iter()
        .zip(b.iter())
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 1.0;
    }

    // Rounding can push |dot / denom| marginally past 1 (e.g. a tiny negative
    // distance for identical vectors); clamp back into the valid range.
    (1.0 - dot / denom).clamp(0.0, 2.0)
}
155
/// Rescales `vec` in place to unit Euclidean length.
///
/// When the computed length is zero or NaN (e.g. an all-zero or empty slice,
/// or one containing NaN), the slice is left untouched.
pub fn normalize(vec: &mut [f32]) {
    let magnitude = vec.iter().fold(0.0f32, |acc, &x| acc + x * x).sqrt();
    if magnitude > 0.0 {
        vec.iter_mut().for_each(|component| *component /= magnitude);
    }
}
164
165pub fn build_base_vector(
166    named_vecs: &[Vec<f32>],
167    free_text_vec: Option<&[f32]>,
168    named_weight: f32,
169) -> Result<Vec<f32>, VecGraphError> {
170    let text_weight = 1.0 - named_weight;
171    let has_named = !named_vecs.is_empty();
172    let has_text = free_text_vec.is_some();
173
174    if !has_named && !has_text {
175        return Err(VecGraphError::EmptyQuery);
176    }
177
178    // Text only
179    if !has_named {
180        return Ok(free_text_vec.unwrap().to_vec());
181    }
182
183    // Compute named centroid
184    let dim = named_vecs[0].len();
185    let count = named_vecs.len() as f32;
186    let mut centroid = vec![0.0f32; dim];
187    for v in named_vecs {
188        for (c, val) in centroid.iter_mut().zip(v.iter()) {
189            *c += val / count;
190        }
191    }
192
193    // Named only
194    if !has_text {
195        normalize(&mut centroid);
196        return Ok(centroid);
197    }
198
199    // Both — weighted blend
200    let text_vec = free_text_vec.unwrap();
201    for i in 0..dim {
202        centroid[i] = named_weight * centroid[i] + text_weight * text_vec[i];
203    }
204    normalize(&mut centroid);
205
206    Ok(centroid)
207}