// tf_idf_vectorizer/vectorizer/analyzer.rs
analyzer.rs1use std::{collections::{HashMap, HashSet}, sync::{atomic::{AtomicU64, Ordering}, Arc, Mutex}};
2use std::str;
3
4use fst::{MapBuilder, Streamer};
5use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
6use serde::{Deserialize, Serialize};
7
8use vec_plus::vec::default_sparse_vec::DefaultSparseVec;
9
10use super::{index::Index, token::TokenFrequency};
11
12
/// A single analyzed document: the optional raw text plus the token
/// frequency table derived from it.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Document {
    /// Original raw text, kept only when the caller supplied it.
    pub text: Option<String>,
    /// Per-token occurrence counts for this document.
    pub tokens: TokenFrequency,
}
18
19impl Document {
20 pub fn new() -> Self {
21 Document {
22 text: None,
23 tokens: TokenFrequency::new(),
24 }
25 }
26
27 pub fn new_with_set(text: Option<&str>, tokens: TokenFrequency) -> Self {
28 Document {
29 text: text.map(|s| s.to_string()),
30 tokens,
31 }
32 }
33}
34
/// Accumulates documents and maintains a corpus-wide token table from
/// which IDF weights and a searchable [`Index`] can be generated.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DocumentAnalyzer<IdType>
where
    IdType: Eq + std::hash::Hash + Clone + Serialize + Send + Sync + std::fmt::Debug,
{
    /// All documents, keyed by the caller-supplied id.
    pub documents: HashMap<IdType, Document>,
    /// Corpus-wide token counts fed with each document's *unique* token set
    /// (see `add_document`), i.e. document frequencies used to derive IDF.
    pub idf: TokenFrequency,
    /// Number of stored documents; kept in sync with `documents` by
    /// `add_document` / `del_document`.
    pub total_doc_count: u64,
}
44
45impl<IdType> DocumentAnalyzer<IdType>
46where
47 IdType: Eq + std::hash::Hash + Clone + Serialize + Send + Sync + std::fmt::Debug,
48{
49
50 pub fn new() -> Self {
51 Self {
52 documents: HashMap::new(),
53 idf: TokenFrequency::new(),
54 total_doc_count: 0,
55 }
56 }
57
58 pub fn add_document(&mut self, id: IdType, content: &[&str], text: Option<&str>) -> Option<&Document>{
59 if let Some(document) = self.documents.get_mut(&id) {
60 self.idf.sub_tokens_string(&document.tokens.get_token_set());
61 document.text = text.map(|s| s.to_string());
62 document.tokens.reset();
63 document. tokens.add_tokens(content);
64 self.idf.add_tokens_string(&document.tokens.get_token_set());
65 return self.documents.get(&id);
66 } else {
67 let mut tokens = TokenFrequency::new();
68 tokens.add_tokens(content);
69 self.idf.add_tokens_string(&tokens.get_token_set());
70 self.documents.insert(id.clone(), Document::new_with_set(text, tokens));
71 self.total_doc_count += 1;
72 return self.documents.get(&id);
73 }
74 }
75
76 pub fn get_document(&self, id: &IdType) -> Option<&Document> {
77 self.documents.get(id)
78 }
79
80 pub fn del_document(&mut self, id: &IdType) -> Option<Document> {
81 if let Some(document) = self.documents.remove(id) {
82 self.total_doc_count -= 1;
83 self.idf
84 .sub_tokens_string(&document.tokens.get_token_set());
85 Some(document)
86 } else {
87 None
88 }
89 }
90
91 pub fn get_document_count(&self) -> u64 {
92 self.total_doc_count
93 }
94
95 pub fn get_token_set_vec(&self) -> Vec<String> {
96 self.idf.get_token_set()
97 }
98
99 pub fn get_token_set_vec_ref(&self) -> Vec<&str> {
100 self.idf.get_token_set_ref()
101 }
102
103 pub fn get_token_set(&self) -> HashSet<String> {
104 self.idf.get_token_hashset()
105 }
106
107 pub fn get_token_set_ref(&self) -> HashSet<&str> {
108 self.idf.get_token_hashset_ref()
109 }
110
111 pub fn get_token_set_len(&self) -> usize {
112 self.idf.get_token_set_len()
113 }
114
115 pub fn generate_index(&self) -> Index<IdType> {
116 let total_doc_tokens_len = Arc::new(AtomicU64::new(0));
118 let max_doc_tokens_len = Arc::new(AtomicU64::new(0));
119 let now_prosessing = Arc::new(AtomicU64::new(0));
120
121 let mut builder = MapBuilder::memory();
123 let mut idf_vec = self.idf.get_idf_vector_ref_parallel(self.total_doc_count);
124 idf_vec.sort_by(|a, b| a.0.cmp(b.0));
125 for (token, idf) in idf_vec {
126 builder.insert(token.as_bytes(), idf as u64).unwrap();
127 }
128 let idf = Arc::new(builder.into_map());
129
130 let index = Arc::new(Mutex::new(HashMap::new()));
132
133 self.documents.par_iter().for_each(|(id, document)| {
135 now_prosessing.fetch_add(1, Ordering::SeqCst);
136 let mut tf_idf_sort_vec: Vec<u16> = Vec::new();
137
138 let tf_idf_vec: HashMap<String, u16> =
139 document.tokens.get_tfidf_hashmap_fst_parallel(&idf);
140
141 let mut stream = idf.stream();
142 while let Some((token, _)) = stream.next() {
143 let tf_idf = *tf_idf_vec.get(str::from_utf8(token).unwrap()).unwrap_or(&0);
144 tf_idf_sort_vec.push(tf_idf);
145 }
146
147 let tf_idf_csvec: DefaultSparseVec<u16> = DefaultSparseVec::from(tf_idf_sort_vec);
148 let doc_tokens_len = document.tokens.get_total_token_count();
149
150 total_doc_tokens_len.fetch_add(doc_tokens_len, Ordering::SeqCst);
151
152 max_doc_tokens_len.fetch_max(doc_tokens_len, Ordering::SeqCst);
153
154 let mut index_guard = index.lock().unwrap();
155 index_guard.insert(id.clone(), (tf_idf_csvec, doc_tokens_len));
156 });
157
158 let avg_total_doc_tokens_len = (total_doc_tokens_len.load(Ordering::SeqCst)
160 / self.total_doc_count as u64) as u64;
161 let max_doc_tokens_len = max_doc_tokens_len.load(Ordering::SeqCst);
162
163 Index::new_with_set(
165 Arc::try_unwrap(index).unwrap_or(HashMap::new().into()).into_inner().unwrap(),
166 Arc::try_unwrap(idf).unwrap(),
167 avg_total_doc_tokens_len,
168 max_doc_tokens_len,
169 self.total_doc_count,
170 )
171 }
172}