Skip to main content

spark_bert/
inverted_index.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fs,
4    path::PathBuf,
5};
6
7use anyhow::{Context, Result};
8use float8::F8E4M3;
9use tantivy::{
10    directory::MmapDirectory,
11    query::{BooleanQuery, Query},
12    schema::{
13        Field, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, Value, FAST, STORED,
14    },
15    Index, IndexReader, IndexWriter, ReloadPolicy, Searcher, TantivyDocument, Term,
16};
17
18use crate::{directory::ram_directory_from_mmap_dir, tf_term_query::TfTermQuery};
19
/// Document-frequency ratio at or above which a token is treated as a stop word
/// (applied when `InvertedIndex::finalize` is called with `filter_stop_words = true`).
const MAX_DF_RATIO: f32 = 0.15;
21
22pub struct InvertedIndex {
23    index: Index,
24    writer: IndexWriter,
25    pub reader: IndexReader,
26    token_cluster_id: Field,
27    doc_id: Field,
28    pending: HashMap<u64, Vec<(String, f32)>>,
29}
30
31impl InvertedIndex {
32    pub fn open(use_ram_index: bool) -> Result<Self> {
33        let directory_path = Self::default_directory_path();
34        let (schema, token_cluster_id, doc_id) = Self::build_schema()?;
35        let index = if use_ram_index {
36            if directory_path.exists() {
37                let ram_directory = ram_directory_from_mmap_dir(&directory_path)?;
38                Index::open(ram_directory)?
39            } else {
40                Index::create_in_ram(schema)
41            }
42        } else {
43            fs::create_dir_all(&directory_path)?;
44            let directory = MmapDirectory::open(&directory_path)?;
45            Index::open_or_create(directory, schema)?
46        };
47        let memory_budget_in_bytes = 500_000_000; // 500 MB heap
48        let writer = index.writer(memory_budget_in_bytes)?;
49        let reader = Self::build_reader(&index)?;
50        Ok(Self {
51            index,
52            writer,
53            reader,
54            token_cluster_id,
55            doc_id,
56            pending: HashMap::new(),
57        })
58    }
59
60    fn default_directory_path() -> PathBuf {
61        std::env::var_os("SPARKBERT_INVERTED_INDEX_DIR")
62            .map(PathBuf::from)
63            .context("Please set SPARKBERT_INVERTED_INDEX_DIR env variable")
64            .unwrap()
65    }
66
67    fn build_schema() -> Result<(Schema, Field, Field)> {
68        let mut schema_builder = Schema::builder();
69        let tok_opts = TextOptions::default().set_indexing_options(
70            TextFieldIndexing::default()
71                .set_tokenizer("raw")
72                .set_index_option(IndexRecordOption::WithFreqs),
73        );
74        let token_cluster_id = schema_builder.add_text_field("token", tok_opts);
75        let doc_id = schema_builder.add_u64_field("doc_id", FAST | STORED);
76        let schema = schema_builder.build();
77        Ok((schema, token_cluster_id, doc_id))
78    }
79
80    pub fn index(&mut self, doc_id: u64, tokens: Vec<String>, scores: Vec<f32>) {
81        debug_assert_eq!(tokens.len(), scores.len());
82        let doc_entry = self.pending.entry(doc_id).or_default();
83        for (token, score) in tokens.into_iter().zip(scores.into_iter()) {
84            doc_entry.push((token, score));
85        }
86    }
87
88    /// commit
89    pub fn finalize(&mut self, filter_stop_words: bool) -> Result<()> {
90        let stop_words = if filter_stop_words {
91            self.prepare_stop_words()
92        } else {
93            HashSet::new()
94        };
95        for (&doc_id, token_score_pairs) in self.pending.iter() {
96            let mut doc = TantivyDocument::new();
97            doc.add_u64(self.doc_id, doc_id);
98            let mut set = false;
99            for (token, score) in token_score_pairs {
100                if stop_words.contains(token) {
101                    continue;
102                }
103                // TODO: remove magic number, use stats
104                if *score < 22.7136 {
105                    continue;
106                }
107                // TODO: add boundaries and try without f8
108                let reps = F8E4M3::from_f32(*score).to_bits();
109                if reps == 0 {
110                    continue;
111                }
112                set = true;
113                // set score as tf
114                for _ in 0..reps {
115                    doc.add_text(self.token_cluster_id, token);
116                }
117            }
118            if set {
119                self.writer.add_document(doc)?;
120            } else {
121                panic!("adjust hyperparams, no tokens were added to doc")
122            }
123        }
124        self.pending.clear();
125        self.writer.commit()?;
126        self.writer
127            .merge(&self.index.searchable_segment_ids()?)
128            .wait()?;
129        self.reader.reload()?;
130        Ok(())
131    }
132
133    fn prepare_stop_words(&self) -> HashSet<&String> {
134        let mut token_to_doc_count = HashMap::new();
135        for (_, token_score_pairs) in self.pending.iter() {
136            let mut seen = HashSet::new();
137            for (token, _) in token_score_pairs {
138                if seen.insert(token) {
139                    *token_to_doc_count.entry(token).or_insert(0) += 1;
140                }
141            }
142        }
143        let total_docs = self.pending.len() as f32;
144        token_to_doc_count
145            .into_iter()
146            .filter(|(_, doc_count)| (*doc_count as f32 / total_docs) >= MAX_DF_RATIO)
147            .map(|(token, _)| token)
148            .collect()
149    }
150
151    fn build_reader(index: &Index) -> Result<IndexReader> {
152        let reader = index
153            .reader_builder()
154            .reload_policy(ReloadPolicy::Manual)
155            .try_into()?;
156        Ok(reader)
157    }
158
159    pub fn get_num_docs(&self) -> u64 {
160        let searcher = self.reader.searcher();
161        searcher.num_docs()
162    }
163
164    // TODO: 1. построить графики качество/время 2. посмотреть на глубину обхода постинг листов
165    /// execute a query that is a list of `(token#cluster)` strings
166    /// returns `Vec<(doc_id, sum_score)>` sorted desc by sum_score
167    pub fn search(
168        &self,
169        searcher: Option<&Searcher>,
170        pairs: &[&str],
171        top_k: usize,
172    ) -> Result<Vec<(u64, f64)>> {
173        if pairs.is_empty() {
174            return Ok(Vec::new());
175        }
176        let searcher = if let Some(searcher) = searcher {
177            searcher
178        } else {
179            &self.reader.searcher()
180        };
181        let mut clauses = Vec::with_capacity(pairs.len());
182        for &tok in pairs {
183            let term = Term::from_field_text(self.token_cluster_id, tok);
184            clauses.push(Box::new(TfTermQuery::new(term)) as Box<dyn Query>);
185        }
186        let bool_q = BooleanQuery::union(clauses);
187
188        let hits = searcher.search(&bool_q, &tantivy::collector::TopDocs::with_limit(top_k))?;
189
190        let mut results = Vec::with_capacity(hits.len());
191        for (score, doc_addr) in hits {
192            let retrieved_doc: TantivyDocument = searcher.doc(doc_addr)?;
193            let doc_id: u64 = retrieved_doc
194                .get_first(self.doc_id)
195                .unwrap()
196                .as_u64()
197                .unwrap();
198            results.push((doc_id, score as f64));
199        }
200        Ok(results)
201    }
202}