tf_idf_vectorizer/vectorizer/evaluate/
query.rs

1use crate::{TokenFrequency, utils::datastruct::map::{IndexMap, IndexSet}, vectorizer::KeyRc};
2
/// Expression tree behind a [`Query`].
///
/// Equality and hashing are structural: two trees are equal only when they
/// have the same shape and the same tokens in the same positions, so
/// `And(a, b)` and `And(b, a)` compare as different values.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum QueryInner {
    /// Matches no document.
    None,
    /// Matches every document.
    All,
    /// Leaf node holding a single token to look up.
    Nop(Box<str>),
    /// Negation of the wrapped expression.
    Not(Box<QueryInner>),
    /// Conjunction: both sub-expressions must match.
    And(Box<QueryInner>, Box<QueryInner>),
    /// Disjunction: at least one sub-expression must match.
    Or(Box<QueryInner>, Box<QueryInner>),
}

/// Boolean token query.
///
/// Values are assembled with the associated constructors (`none`, `all`,
/// `token`, `not`, `and`, `or`, ...) and wrap the expression tree they build.
/// Comparison and hashing follow the structural rules of [`QueryInner`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Query {
    /// Root of the expression tree.
    pub(crate) inner: QueryInner,
}
17
18impl Query {
19    pub fn none() -> Self {
20        Query { inner: QueryInner::None }
21    }
22
23    pub fn all() -> Self {
24        Query { inner: QueryInner::All }
25    }
26
27    pub fn token<S>(token: &S) -> Self 
28    where
29        S: AsRef<str> + ?Sized,
30    {
31        Query { inner: QueryInner::Nop(Box::from(token.as_ref())) }
32    }
33
34    pub fn not(order: Query) -> Self {
35        Query { inner: QueryInner::Not(Box::new(order.inner)) }
36    }
37
38    pub fn and(left: Query, right: Query) -> Self {
39        Query { inner: QueryInner::And(Box::new(left.inner), Box::new(right.inner)) }
40    }
41
42    pub fn or(left: Query, right: Query) -> Self {
43        Query { inner: QueryInner::Or(Box::new(left.inner), Box::new(right.inner)) }
44    }
45
46    pub fn from_freq_or(freq: &TokenFrequency) -> Self {
47        let mut iter = freq.token_set_iter();
48        if let Some(first_token) = iter.next() {
49            let mut query = Query::token(first_token);
50            for token in iter {
51                let token_query = Query::token(token);
52                query = Query::or(query, token_query);
53            }
54            query
55        } else {
56            Query::none()
57        }
58    }
59
60    pub fn from_freq_and(freq: &TokenFrequency) -> Self {
61        let mut iter = freq.token_set_iter();
62        if let Some(first_token) = iter.next() {
63            let mut query = Query::token(first_token);
64            for token in iter {
65                let token_query = Query::token(token);
66                query = Query::and(query, token_query);
67            }
68            query
69        } else {
70            Query::none()
71        }
72    }
73
74    pub fn get_all_tokens(&self) -> Vec<&str> {
75        let mut tokens = Vec::new();
76        Self::collect_tokens_ref(&self.inner, &mut tokens);
77        tokens
78    }
79
80    pub(crate) fn collect_tokens_ref<'a>(query: &'a QueryInner, tokens: &mut Vec<&'a str>) {
81        match query {
82            QueryInner::All => {
83                // do nothing
84            }
85            QueryInner::None => {}
86            QueryInner::Nop(token) => {
87                tokens.push(token);
88            }
89            QueryInner::Not(inner) => {
90                Self::collect_tokens_ref(inner, tokens);
91            }
92            QueryInner::And(left, right) => {
93                Self::collect_tokens_ref(left, tokens);
94                Self::collect_tokens_ref(right, tokens);
95            }
96            QueryInner::Or(left, right) => {
97                Self::collect_tokens_ref(left, tokens);
98                Self::collect_tokens_ref(right, tokens);
99            }
100        }
101    }
102
103    pub(crate) fn build_ref<K>(query: &QueryInner, token_dim_rev_index: &IndexMap<Box<str>, Vec<KeyRc<K>>>, documents: &IndexSet<KeyRc<K>>) -> Vec<usize> 
104    where 
105        K: Eq + std::hash::Hash,
106    {
107        match query {
108            QueryInner::All => {
109                let mut result = Vec::with_capacity(documents.len());
110                for (idx, _) in documents.iter().enumerate() {
111                    result.push(idx);
112                }
113                result
114            }
115            QueryInner::None => Vec::new(),
116            QueryInner::Nop(token) => {
117                if let Some(doc_keys) = token_dim_rev_index.get(token) {
118                    let mut result = Vec::with_capacity(doc_keys.len());
119                    for doc_key in doc_keys {
120                        if let Some(idx) = documents.get_index(doc_key) {
121                            result.push(idx);
122                        }
123                    }
124                    result.sort_unstable();
125                    result
126                } else {
127                    Vec::new()
128                }
129            }
130            QueryInner::Not(inner) => {
131                let inner_indices = Self::build_ref(inner, token_dim_rev_index, documents);
132                let mut result = Vec::with_capacity(documents.len() - inner_indices.len());
133                let mut inner_iter = inner_indices.iter().peekable();
134                for (idx, _) in documents.iter().enumerate() {
135                    match inner_iter.peek() {
136                        Some(&&inner_idx) if inner_idx == idx => {
137                            inner_iter.next();
138                        }
139                        _ => {
140                            result.push(idx);
141                        }
142                    }
143                }
144                result
145            }
146            QueryInner::And(left, right) => {
147                let left_indices = Self::build_ref(left, token_dim_rev_index, documents);
148                let right_indices = Self::build_ref(right, token_dim_rev_index, documents);
149                let mut result = Vec::with_capacity(std::cmp::min(left_indices.len(), right_indices.len()));
150                let mut l = 0;
151                let mut r = 0;
152                while l < left_indices.len() && r < right_indices.len() {
153                    match left_indices[l].cmp(&right_indices[r]) {
154                        std::cmp::Ordering::Less => {
155                            l += 1;
156                        }
157                        std::cmp::Ordering::Greater => {
158                            r += 1;
159                        }
160                        std::cmp::Ordering::Equal => {
161                            result.push(left_indices[l]);
162                            l += 1;
163                            r += 1;
164                        }
165                    }
166                }
167                result
168            }
169            QueryInner::Or(left, right) => {
170                let left_indices = Self::build_ref(left, token_dim_rev_index, documents);
171                let right_indices = Self::build_ref(right, token_dim_rev_index, documents);
172                let mut result = Vec::with_capacity(left_indices.len() + right_indices.len());
173                let mut l = 0;
174                let mut r = 0;
175                while l < left_indices.len() || r < right_indices.len() {
176                    if l >= left_indices.len() {
177                        result.push(right_indices[r]);
178                        r += 1;
179                    } else if r >= right_indices.len() {
180                        result.push(left_indices[l]);
181                        l += 1;
182                    } else {
183                        match left_indices[l].cmp(&right_indices[r]) {
184                            std::cmp::Ordering::Less => {
185                                result.push(left_indices[l]);
186                                l += 1;
187                            }
188                            std::cmp::Ordering::Greater => {
189                                result.push(right_indices[r]);
190                                r += 1;
191                            }
192                            std::cmp::Ordering::Equal => {
193                                result.push(left_indices[l]);
194                                l += 1;
195                                r += 1;
196                            }
197                        }
198                    }
199                }
200                result
201            }
202        }
203    }
204
205    pub fn build<K>(&self, token_dim_rev_index: &IndexMap<Box<str>, Vec<KeyRc<K>>>, documents: &IndexSet<KeyRc<K>>) -> Vec<usize> 
206    where 
207        K: Eq + std::hash::Hash,
208    {
209        let mut res = Self::build_ref(&self.inner, token_dim_rev_index, documents);
210        res.sort_unstable();
211        res.dedup();
212        res
213    }
214}