spark_bert/
vector_vocabulary.rs1use std::collections::HashMap;
2use std::fs::File;
3use std::io::BufReader;
4
5use anyhow::{anyhow, Result};
6use faiss::index::{IndexImpl, SearchResult};
7use faiss::{read_index, Idx, Index};
8use hf_hub::api::sync::Api;
9use hf_hub::Repo;
10use rayon::iter::{IndexedParallelIterator, ParallelIterator};
11use rayon::slice::ParallelSliceMut;
12
13pub struct VectorVocabulary {
14 vector_index: IndexImpl,
15 faiss_idx_to_token: HashMap<String, String>,
16}
17
18impl VectorVocabulary {
19 pub fn build() -> Result<Self> {
20 let repo = Repo::model(
21 "viacheslav-dobrynin/spark-bert-msmarco-all-MiniLM-L6-v2-vector-vocab".to_owned(),
22 );
23 let api = Api::new()?.repo(repo);
24 let faiss_idx_to_token_path = api
25 .get("faiss_idx_to_token.json")?
26 .into_os_string()
27 .into_string()
28 .map_err(|path| anyhow!("cache path is not valid UTF-8: {:?}", path))?;
29 let vector_vocab_path = api
30 .get("vector_vocab.hnsw.faiss")?
31 .into_os_string()
32 .into_string()
33 .map_err(|path| anyhow!("cache path is not valid UTF-8: {:?}", path))?;
34 let vector_index = read_index(&vector_vocab_path)?;
35 let faiss_idx_to_token: HashMap<String, String> =
36 Self::load_faiss_idx_to_token(&faiss_idx_to_token_path)?;
37 Ok(Self {
38 vector_index,
39 faiss_idx_to_token,
40 })
41 }
42
43 fn load_faiss_idx_to_token(json_path: &str) -> anyhow::Result<HashMap<String, String>> {
44 let file = File::open(json_path)?;
45 let reader = BufReader::new(file);
46 let faiss_idx_to_token: HashMap<String, String> = serde_json::from_reader(reader)?;
47 anyhow::Ok(faiss_idx_to_token)
48 }
49
50 pub fn get_num_tokens(&self) -> u64 {
51 self.vector_index.ntotal()
52 }
53
54 pub fn get_embedding_dims(&self) -> u32 {
55 self.vector_index.d()
56 }
57
58 pub fn find_tokens(
59 &mut self,
60 query_embs: &[f32],
61 n_neighbors: usize,
62 with_embs: bool,
63 ) -> Result<(Vec<&str>, Option<Vec<f32>>)> {
64 let SearchResult {
65 distances: _,
66 labels,
67 } = self.vector_index.search(query_embs, n_neighbors)?;
68 let labels = unique_labels(&labels);
69 let tokens: Vec<&str> = labels
70 .iter()
71 .map(|idx| {
72 let idx = idx.get().unwrap().to_string();
73 self.faiss_idx_to_token
74 .get(&idx)
75 .map(String::as_str)
76 .unwrap()
77 })
78 .collect();
79 let token_embs = if with_embs {
80 Some(reconstruct_batch(&self.vector_index, &labels)?)
81 } else {
82 None
83 };
84 Ok((tokens, token_embs))
85 }
86}
87
88pub fn reconstruct_batch<T>(index: &T, labels: &[faiss::Idx]) -> anyhow::Result<Vec<f32>>
89where
90 T: Index + Sync,
91{
92 let d = index.d() as usize;
93 let batch = labels.len();
94 let mut flat_embs = vec![0f32; batch * d];
95 debug_assert_eq!(flat_embs.len(), labels.len() * d);
96 flat_embs
97 .par_chunks_mut(d)
98 .enumerate()
99 .try_for_each(|(i, chunk)| {
100 let idx = labels[i];
101 index.reconstruct(idx, chunk).map_err(anyhow::Error::from)
102 })?;
103 anyhow::Ok(flat_embs)
104}
105
106pub fn unique_labels(labels: &[Idx]) -> Vec<Idx> {
107 let mut unique_ids: Vec<u64> = labels.iter().filter_map(|idx| idx.get()).collect();
108 unique_ids.sort_unstable();
109 unique_ids.dedup();
110 unique_ids.into_iter().map(Idx::new).collect()
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use faiss::error::Result as FaissResult;
    use faiss::{Idx, MetricType};

    /// In-memory stand-in for a faiss index.
    ///
    /// Only `d` and `reconstruct` are functional; every other trait method is a
    /// `todo!()` stub because the tests below never call it.
    struct MockIndex {
        /// Row `i` is the vector handed back for label `i`.
        vecs: Vec<Vec<f32>>,
    }

    impl faiss::Index for MockIndex {
        fn d(&self) -> u32 {
            // Dimensionality is taken from the first stored row.
            self.vecs[0].len() as u32
        }

        fn reconstruct(&self, idx: Idx, dest: &mut [f32]) -> FaissResult<()> {
            let row = idx.get().unwrap() as usize;
            dest.copy_from_slice(&self.vecs[row]);
            Ok(())
        }

        fn is_trained(&self) -> bool {
            todo!()
        }

        fn ntotal(&self) -> u64 {
            todo!()
        }

        fn metric_type(&self) -> MetricType {
            todo!()
        }

        fn add(&mut self, _x: &[f32]) -> FaissResult<()> {
            todo!()
        }

        fn add_with_ids(&mut self, _x: &[f32], _xids: &[Idx]) -> FaissResult<()> {
            todo!()
        }

        fn train(&mut self, _x: &[f32]) -> FaissResult<()> {
            todo!()
        }

        fn assign(
            &mut self,
            _q: &[f32],
            _k: usize,
        ) -> FaissResult<faiss::index::AssignSearchResult> {
            todo!()
        }

        fn search(&mut self, _q: &[f32], _k: usize) -> FaissResult<faiss::index::SearchResult> {
            todo!()
        }

        fn range_search(
            &mut self,
            _q: &[f32],
            _radius: f32,
        ) -> FaissResult<faiss::index::RangeSearchResult> {
            todo!()
        }

        fn reconstruct_n(
            &self,
            _first_key: Idx,
            _count: usize,
            _output: &mut [f32],
        ) -> FaissResult<()> {
            todo!()
        }

        fn reset(&mut self) -> FaissResult<()> {
            todo!()
        }

        fn remove_ids(&mut self, _sel: &faiss::selector::IdSelector) -> FaissResult<usize> {
            todo!()
        }

        fn verbose(&self) -> bool {
            todo!()
        }

        fn set_verbose(&mut self, _value: bool) {
            todo!()
        }
    }

    /// The vectors of the requested labels come back concatenated, in label order.
    #[test]
    fn should_reconstruct_batch_of_embs() {
        let index = MockIndex {
            vecs: vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]],
        };
        let ids = [Idx::new(0), Idx::new(1)];

        let flat = reconstruct_batch(&index, &ids).unwrap();

        assert_eq!(flat, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
    }

    /// Duplicates collapse and the survivors are sorted ascending.
    #[test]
    fn should_return_unique_labels() {
        let with_duplicates = [
            Idx::new(2),
            Idx::new(1),
            Idx::new(2),
            Idx::new(3),
            Idx::new(1),
        ];

        let distinct = unique_labels(&with_duplicates);

        assert_eq!(distinct, vec![Idx::new(1), Idx::new(2), Idx::new(3)]);
    }
}