Skip to main content

tf_idf_vectorizer/vectorizer/evaluate/
query.rs

1use crate::{TermFrequency, utils::datastruct::map::{IndexMap, IndexSet}};
2
/// Expression tree behind a [`Query`].
///
/// Each node is evaluated by `Query::build_ref` into a sorted list of document
/// indices; the variants below describe what each node evaluates to.
#[derive(Debug, Clone)]
pub enum QueryInner {
    /// Matches no documents (evaluates to an empty list).
    None,
    /// Matches every document index of the document set.
    All,
    /// Leaf node: a single term, looked up in the term → document reverse index.
    Nop(Box<str>),
    /// Complement of the inner expression with respect to all documents.
    Not(Box<QueryInner>),
    /// Intersection of the two sub-expressions.
    And(Box<QueryInner>, Box<QueryInner>),
    /// Union of the two sub-expressions.
    Or(Box<QueryInner>, Box<QueryInner>),
}
12
13/// Query Structure
14///
15/// Represents a search query with logical filtering conditions.
16#[derive(Clone, Debug)]
17pub struct Query {
18    pub(crate) inner: QueryInner,
19}
20
21impl Query {
22    pub fn none() -> Self {
23        Query { inner: QueryInner::None }
24    }
25
26    pub fn all() -> Self {
27        Query { inner: QueryInner::All }
28    }
29
30    pub fn term<S>(term: &S) -> Self 
31    where
32        S: AsRef<str> + ?Sized,
33    {
34        Query { inner: QueryInner::Nop(Box::from(term.as_ref())) }
35    }
36
37    pub fn not(order: Query) -> Self {
38        Query { inner: QueryInner::Not(Box::new(order.inner)) }
39    }
40
41    pub fn and(left: Query, right: Query) -> Self {
42        Query { inner: QueryInner::And(Box::new(left.inner), Box::new(right.inner)) }
43    }
44
45    pub fn or(left: Query, right: Query) -> Self {
46        Query { inner: QueryInner::Or(Box::new(left.inner), Box::new(right.inner)) }
47    }
48
49    pub fn from_freq_or(freq: &TermFrequency) -> Self {
50        let mut iter = freq.term_set_iter();
51        if let Some(first_term) = iter.next() {
52            let mut query = Query::term(first_term);
53            for term in iter {
54                let term_query = Query::term(term);
55                query = Query::or(query, term_query);
56            }
57            query
58        } else {
59            Query::none()
60        }
61    }
62
63    pub fn from_freq_and(freq: &TermFrequency) -> Self {
64        let mut iter = freq.term_set_iter();
65        if let Some(first_term) = iter.next() {
66            let mut query = Query::term(first_term);
67            for term in iter {
68                let term_query = Query::term(term);
69                query = Query::and(query, term_query);
70            }
71            query
72        } else {
73            Query::none()
74        }
75    }
76
77    pub fn from_inner(inner: QueryInner) -> Self {
78        Query { inner }
79    }
80
81    pub fn get_all_terms(&self) -> Vec<&str> {
82        let mut terms = Vec::new();
83        Self::collect_terms_ref(&self.inner, &mut terms);
84        terms
85    }
86
87    pub(crate) fn collect_terms_ref<'a>(query: &'a QueryInner, terms: &mut Vec<&'a str>) {
88        match query {
89            QueryInner::All => {
90                // do nothing
91            }
92            QueryInner::None => {}
93            QueryInner::Nop(term) => {
94                terms.push(term);
95            }
96            QueryInner::Not(inner) => {
97                Self::collect_terms_ref(inner, terms);
98            }
99            QueryInner::And(left, right) => {
100                Self::collect_terms_ref(left, terms);
101                Self::collect_terms_ref(right, terms);
102            }
103            QueryInner::Or(left, right) => {
104                Self::collect_terms_ref(left, terms);
105                Self::collect_terms_ref(right, terms);
106            }
107        }
108    }
109
110    pub(crate) fn build_ref<K>(query: &QueryInner, term_dim_rev_index: &IndexMap<Box<str>, Vec<u32>>, documents: &IndexSet<K>) -> Vec<usize> 
111    where 
112        K: Eq + std::hash::Hash,
113    {
114        match query {
115            QueryInner::All => {
116                let mut result = Vec::with_capacity(documents.len());
117                for (idx, _) in documents.iter().enumerate() {
118                    result.push(idx);
119                }
120                result
121            }
122            QueryInner::None => Vec::new(),
123            QueryInner::Nop(term) => {
124                if let Some(doc_keys) = term_dim_rev_index.get(term) {
125                    let mut result = doc_keys.iter().map(|&id| id as usize).collect::<Vec<usize>>();
126                    result.sort_unstable();
127                    result
128                } else {
129                    Vec::new()
130                }
131            }
132            QueryInner::Not(inner) => {
133                let inner_indices = Self::build_ref(inner, term_dim_rev_index, documents);
134                let mut result = Vec::with_capacity(documents.len() - inner_indices.len());
135                let mut inner_iter = inner_indices.iter().peekable();
136                for (idx, _) in documents.iter().enumerate() {
137                    match inner_iter.peek() {
138                        Some(&&inner_idx) if inner_idx == idx => {
139                            inner_iter.next();
140                        }
141                        _ => {
142                            result.push(idx);
143                        }
144                    }
145                }
146                result
147            }
148            QueryInner::And(left, right) => {
149                let left_indices = Self::build_ref(left, term_dim_rev_index, documents);
150                let right_indices = Self::build_ref(right, term_dim_rev_index, documents);
151                let mut result = Vec::with_capacity(std::cmp::min(left_indices.len(), right_indices.len()));
152                let mut l = 0;
153                let mut r = 0;
154                while l < left_indices.len() && r < right_indices.len() {
155                    match left_indices[l].cmp(&right_indices[r]) {
156                        std::cmp::Ordering::Less => {
157                            l += 1;
158                        }
159                        std::cmp::Ordering::Greater => {
160                            r += 1;
161                        }
162                        std::cmp::Ordering::Equal => {
163                            result.push(left_indices[l]);
164                            l += 1;
165                            r += 1;
166                        }
167                    }
168                }
169                result
170            }
171            QueryInner::Or(left, right) => {
172                let left_indices = Self::build_ref(left, term_dim_rev_index, documents);
173                let right_indices = Self::build_ref(right, term_dim_rev_index, documents);
174                let mut result = Vec::with_capacity(left_indices.len() + right_indices.len());
175                let mut l = 0;
176                let mut r = 0;
177                while l < left_indices.len() || r < right_indices.len() {
178                    if l >= left_indices.len() {
179                        result.push(right_indices[r]);
180                        r += 1;
181                    } else if r >= right_indices.len() {
182                        result.push(left_indices[l]);
183                        l += 1;
184                    } else {
185                        match left_indices[l].cmp(&right_indices[r]) {
186                            std::cmp::Ordering::Less => {
187                                result.push(left_indices[l]);
188                                l += 1;
189                            }
190                            std::cmp::Ordering::Greater => {
191                                result.push(right_indices[r]);
192                                r += 1;
193                            }
194                            std::cmp::Ordering::Equal => {
195                                result.push(left_indices[l]);
196                                l += 1;
197                                r += 1;
198                            }
199                        }
200                    }
201                }
202                result
203            }
204        }
205    }
206
207    pub(crate) fn build<K>(&self, term_dim_rev_index: &IndexMap<Box<str>, Vec<u32>>, documents: &IndexSet<K>) -> Vec<usize> 
208    where 
209        K: Eq + std::hash::Hash,
210    {
211        let mut res = Self::build_ref(&self.inner, term_dim_rev_index, documents);
212        res.sort_unstable();
213        res.dedup();
214        res
215    }
216}