1use crate::NodeId;
2use crate::error::VecGraphError;
3
/// Selects which kind of graph item a search covers.
///
/// Parsed from strings via `From<&str>` / `From<String>` and rendered back to a
/// lowercase name via `Display`. `All` is also the fallback for unrecognised input.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SearchKind {
    /// Search edges only.
    Edge,
    /// Search nodes only.
    Node,
    /// Unrestricted search (presumably edges and nodes — confirm with the searcher).
    All,
}
10
11#[derive(Debug)]
12pub struct SearchResult {
13 pub node_id: NodeId,
14 pub kind: String,
15 pub score: f32,
16 pub hit_kind: SearchKind,
17}
18
19#[derive(Clone)]
20pub struct RerankParams {
21 pub vector: Vec<f32>,
22 pub kind: String,
23 pub weight: f32,
24}
25
26#[derive(Clone)]
27pub struct SearchQuery {
28 pub search_kind: SearchKind,
29 pub query_vec: Vec<f32>,
30 pub kind: String,
31 pub namespace: Option<String>,
32 pub top_k: usize,
33 pub exclude_names: Vec<String>,
34 pub rerank: Option<RerankParams>,
35}
36
37pub struct ScoredHit {
38 pub node_id_bytes: Vec<u8>,
39 pub kind: String,
40 pub score: f32,
41 pub hit_kind: SearchKind,
42}
43
44impl SearchQuery {
45 pub fn new(
46 query_vec: Vec<f32>,
47 search_kind: impl Into<SearchKind>,
48 kind: impl Into<String>,
49 top_k: usize,
50 ) -> Self {
51 Self {
52 query_vec,
53 search_kind: search_kind.into(),
54 kind: kind.into(),
55 namespace: None,
56 top_k,
57 exclude_names: Vec::new(),
58 rerank: None,
59 }
60 }
61
62 pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
63 self.namespace = Some(namespace.into());
64 self
65 }
66
67 pub fn with_excludes(mut self, names: Vec<String>) -> Self {
68 self.exclude_names = names;
69 self
70 }
71
72 pub fn with_rerank(mut self, vector: Vec<f32>, kind: impl Into<String>, weight: f32) -> Self {
73 self.rerank = Some(RerankParams {
74 vector,
75 kind: kind.into(),
76 weight: weight.clamp(0.0, 1.0),
77 });
78 self
79 }
80}
81
82impl From<&str> for SearchKind {
83 fn from(s: &str) -> Self {
84 match s.to_lowercase().as_str() {
85 "edge" => SearchKind::Edge,
86 "node" => SearchKind::Node,
87 "all" => SearchKind::All,
88 _ => SearchKind::All, }
90 }
91}
92
93impl From<String> for SearchKind {
94 fn from(s: String) -> Self {
95 SearchKind::from(s.as_str())
96 }
97}
98
99impl std::fmt::Display for SearchKind {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 let s = match self {
102 SearchKind::Edge => "edge",
103 SearchKind::Node => "node",
104 SearchKind::All => "all",
105 };
106 f.write_str(s)
107 }
108}
109
110impl PartialEq for ScoredHit {
111 fn eq(&self, other: &Self) -> bool {
112 self.score == other.score
113 }
114}
115
116impl Eq for ScoredHit {}
117
118impl PartialOrd for ScoredHit {
119 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
120 Some(self.cmp(other))
121 }
122}
123
124impl Ord for ScoredHit {
125 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
126 self.score
128 .partial_cmp(&other.score)
129 .unwrap_or(std::cmp::Ordering::Equal)
130 }
131}
132
/// Cosine distance between two vectors: `1 - cos(a, b)`, in the range `[0.0, 2.0]`.
///
/// If either vector has zero magnitude, the distance is defined as `1.0` (as if orthogonal).
/// The vectors must have the same length. This is checked in debug builds; in release
/// builds the extra components of the longer vector are ignored. NaN inputs yield NaN.
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vector dimension mismatch");

    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // Taking the roots separately avoids overflowing `norm_a * norm_b` for large vectors.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 1.0;
    }

    // Rounding can push the ratio slightly outside [-1, 1] (e.g. a vector compared
    // with itself). Without the clamp that yields a small negative distance, or one above 2.
    let similarity = (dot / denom).clamp(-1.0, 1.0);
    1.0 - similarity
}
155
/// Scales `vec` in place to unit Euclidean length.
///
/// Vectors whose length is not strictly positive (all zeros, empty, or NaN
/// length) are left untouched.
pub fn normalize(vec: &mut [f32]) {
    let sum_of_squares = vec.iter().fold(0.0f32, |acc, &x| acc + x * x);
    let magnitude = sum_of_squares.sqrt();
    if magnitude > 0.0 {
        vec.iter_mut().for_each(|x| *x /= magnitude);
    }
}
164
165pub fn build_base_vector(
166 named_vecs: &[Vec<f32>],
167 free_text_vec: Option<&[f32]>,
168 named_weight: f32,
169) -> Result<Vec<f32>, VecGraphError> {
170 let text_weight = 1.0 - named_weight;
171 let has_named = !named_vecs.is_empty();
172 let has_text = free_text_vec.is_some();
173
174 if !has_named && !has_text {
175 return Err(VecGraphError::EmptyQuery);
176 }
177
178 if !has_named {
180 return Ok(free_text_vec.unwrap().to_vec());
181 }
182
183 let dim = named_vecs[0].len();
185 let count = named_vecs.len() as f32;
186 let mut centroid = vec![0.0f32; dim];
187 for v in named_vecs {
188 for (c, val) in centroid.iter_mut().zip(v.iter()) {
189 *c += val / count;
190 }
191 }
192
193 if !has_text {
195 normalize(&mut centroid);
196 return Ok(centroid);
197 }
198
199 let text_vec = free_text_vec.unwrap();
201 for i in 0..dim {
202 centroid[i] = named_weight * centroid[i] + text_weight * text_vec[i];
203 }
204 normalize(&mut centroid);
205
206 Ok(centroid)
207}