Skip to main content

spark_bert/
vector_vocabulary.rs

1use std::collections::HashMap;
2use std::fs::File;
3use std::io::BufReader;
4
5use anyhow::{anyhow, Result};
6use faiss::index::{IndexImpl, SearchResult};
7use faiss::{read_index, Idx, Index};
8use hf_hub::api::sync::Api;
9use hf_hub::Repo;
10use rayon::iter::{IndexedParallelIterator, ParallelIterator};
11use rayon::slice::ParallelSliceMut;
12
13pub struct VectorVocabulary {
14    vector_index: IndexImpl,
15    faiss_idx_to_token: HashMap<String, String>,
16}
17
18impl VectorVocabulary {
19    pub fn build() -> Result<Self> {
20        let repo = Repo::model(
21            "viacheslav-dobrynin/spark-bert-msmarco-all-MiniLM-L6-v2-vector-vocab".to_owned(),
22        );
23        let api = Api::new()?.repo(repo);
24        let faiss_idx_to_token_path = api
25            .get("faiss_idx_to_token.json")?
26            .into_os_string()
27            .into_string()
28            .map_err(|path| anyhow!("cache path is not valid UTF-8: {:?}", path))?;
29        let vector_vocab_path = api
30            .get("vector_vocab.hnsw.faiss")?
31            .into_os_string()
32            .into_string()
33            .map_err(|path| anyhow!("cache path is not valid UTF-8: {:?}", path))?;
34        let vector_index = read_index(&vector_vocab_path)?;
35        let faiss_idx_to_token: HashMap<String, String> =
36            Self::load_faiss_idx_to_token(&faiss_idx_to_token_path)?;
37        Ok(Self {
38            vector_index,
39            faiss_idx_to_token,
40        })
41    }
42
43    fn load_faiss_idx_to_token(json_path: &str) -> anyhow::Result<HashMap<String, String>> {
44        let file = File::open(json_path)?;
45        let reader = BufReader::new(file);
46        let faiss_idx_to_token: HashMap<String, String> = serde_json::from_reader(reader)?;
47        anyhow::Ok(faiss_idx_to_token)
48    }
49
50    pub fn get_num_tokens(&self) -> u64 {
51        self.vector_index.ntotal()
52    }
53
54    pub fn get_embedding_dims(&self) -> u32 {
55        self.vector_index.d()
56    }
57
58    pub fn find_tokens(
59        &mut self,
60        query_embs: &[f32],
61        n_neighbors: usize,
62        with_embs: bool,
63    ) -> Result<(Vec<&str>, Option<Vec<f32>>)> {
64        let SearchResult {
65            distances: _,
66            labels,
67        } = self.vector_index.search(query_embs, n_neighbors)?;
68        let labels = unique_labels(&labels);
69        let tokens: Vec<&str> = labels
70            .iter()
71            .map(|idx| {
72                let idx = idx.get().unwrap().to_string();
73                self.faiss_idx_to_token
74                    .get(&idx)
75                    .map(String::as_str)
76                    .unwrap()
77            })
78            .collect();
79        let token_embs = if with_embs {
80            Some(reconstruct_batch(&self.vector_index, &labels)?)
81        } else {
82            None
83        };
84        Ok((tokens, token_embs))
85    }
86}
87
88pub fn reconstruct_batch<T>(index: &T, labels: &[faiss::Idx]) -> anyhow::Result<Vec<f32>>
89where
90    T: Index + Sync,
91{
92    let d = index.d() as usize;
93    let batch = labels.len();
94    let mut flat_embs = vec![0f32; batch * d];
95    debug_assert_eq!(flat_embs.len(), labels.len() * d);
96    flat_embs
97        .par_chunks_mut(d)
98        .enumerate()
99        .try_for_each(|(i, chunk)| {
100            let idx = labels[i];
101            index.reconstruct(idx, chunk).map_err(anyhow::Error::from)
102        })?;
103    anyhow::Ok(flat_embs)
104}
105
106pub fn unique_labels(labels: &[Idx]) -> Vec<Idx> {
107    let mut unique_ids: Vec<u64> = labels.iter().filter_map(|idx| idx.get()).collect();
108    unique_ids.sort_unstable();
109    unique_ids.dedup();
110    unique_ids.into_iter().map(Idx::new).collect()
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use faiss::error::Result as FaissResult;
    use faiss::{Idx, MetricType};

    /// In-memory stand-in for a FAISS index.
    ///
    /// Only `d` and `reconstruct` are functional; every other trait method is
    /// a `todo!()` stub because the tests below never call it.
    struct MockIndex {
        vecs: Vec<Vec<f32>>,
    }

    impl faiss::Index for MockIndex {
        // --- methods exercised by the tests ---

        fn d(&self) -> u32 {
            self.vecs[0].len() as u32
        }

        fn reconstruct(&self, idx: Idx, dest: &mut [f32]) -> FaissResult<()> {
            let row = &self.vecs[idx.get().unwrap() as usize];
            dest.copy_from_slice(row);
            Ok(())
        }

        // --- unused stubs required by the trait ---

        fn is_trained(&self) -> bool {
            todo!()
        }

        fn ntotal(&self) -> u64 {
            todo!()
        }

        fn metric_type(&self) -> MetricType {
            todo!()
        }

        fn add(&mut self, _x: &[f32]) -> FaissResult<()> {
            todo!()
        }

        fn add_with_ids(&mut self, _x: &[f32], _xids: &[Idx]) -> FaissResult<()> {
            todo!()
        }

        fn train(&mut self, _x: &[f32]) -> FaissResult<()> {
            todo!()
        }

        fn assign(
            &mut self,
            _q: &[f32],
            _k: usize,
        ) -> FaissResult<faiss::index::AssignSearchResult> {
            todo!()
        }

        fn search(&mut self, _q: &[f32], _k: usize) -> FaissResult<faiss::index::SearchResult> {
            todo!()
        }

        fn range_search(
            &mut self,
            _q: &[f32],
            _radius: f32,
        ) -> FaissResult<faiss::index::RangeSearchResult> {
            todo!()
        }

        fn reconstruct_n(
            &self,
            _first_key: Idx,
            _count: usize,
            _output: &mut [f32],
        ) -> FaissResult<()> {
            todo!()
        }

        fn reset(&mut self) -> FaissResult<()> {
            todo!()
        }

        fn remove_ids(&mut self, _sel: &faiss::selector::IdSelector) -> FaissResult<usize> {
            todo!()
        }

        fn verbose(&self) -> bool {
            todo!()
        }

        fn set_verbose(&mut self, _value: bool) {
            todo!()
        }
    }

    /// Two 3-dimensional rows must come back concatenated in label order.
    #[test]
    fn should_reconstruct_batch_of_embs() {
        let mock = MockIndex {
            vecs: vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]],
        };
        let labels = [0, 1].map(Idx::new);

        let embs = reconstruct_batch(&mock, &labels).unwrap();

        assert_eq!(embs, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
    }

    /// Duplicates are removed and the survivors are sorted ascending.
    #[test]
    fn should_return_unique_labels() {
        let labels = [2, 1, 2, 3, 1].map(Idx::new);

        let uniques = unique_labels(&labels);

        let expected: Vec<Idx> = [1, 2, 3].into_iter().map(Idx::new).collect();
        assert_eq!(uniques, expected);
    }
}