tf_idf_vectorizer/vectorizer/analyzer.rs

use std::{collections::{HashMap, HashSet}, sync::{atomic::{AtomicU64, Ordering}, Arc, Mutex}};
use std::str;

use fst::{MapBuilder, Streamer};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};

use vec_plus::vec::default_sparse_vec::DefaultSparseVec;

use super::{index::Index, token::TokenFrequency};

/// A single document: the original text (optional) plus the token
/// frequency statistics derived from it.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Document {
    pub text: Option<String>,
    pub tokens: TokenFrequency,
}

impl Document {
    pub fn new() -> Self {
        Document {
            text: None,
            tokens: TokenFrequency::new(),
        }
    }

    pub fn new_with_set(text: Option<&str>, tokens: TokenFrequency) -> Self {
        Document {
            text: text.map(|s| s.to_string()),
            tokens,
        }
    }
}
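
// Usage sketch (hedged: `TokenFrequency::add_tokens(&[&str])` is assumed from
// its use in `DocumentAnalyzer::add_document` below):
//
//     let mut tokens = TokenFrequency::new();
//     tokens.add_tokens(&["rust", "tf", "idf", "tf"]);
//     let doc = Document::new_with_set(Some("rust tf idf tf"), tokens);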

/// Accumulates per-document token statistics plus a corpus-wide document
/// frequency table (`idf`), from which a searchable `Index` is generated.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DocumentAnalyzer<IdType>
where
    IdType: Eq + std::hash::Hash + Clone + Serialize + Send + Sync + std::fmt::Debug,
{
    pub documents: HashMap<IdType, Document>,
    /// Counts, per token, the number of documents containing it.
    pub idf: TokenFrequency,
    pub total_doc_count: u64,
}

impl<IdType> DocumentAnalyzer<IdType>
where
    IdType: Eq + std::hash::Hash + Clone + Serialize + Send + Sync + std::fmt::Debug,
{
    pub fn new() -> Self {
        Self {
            documents: HashMap::new(),
            idf: TokenFrequency::new(),
            total_doc_count: 0,
        }
    }

    /// Adds a document, or replaces an existing one with the same `id`,
    /// keeping the document-frequency table consistent by subtracting the
    /// old token set before adding the new one (usage sketch below).
    pub fn add_document(&mut self, id: IdType, content: &[&str], text: Option<&str>) -> Option<&Document> {
        if let Some(document) = self.documents.get_mut(&id) {
            self.idf.sub_tokens_string(&document.tokens.get_token_set());
            document.text = text.map(|s| s.to_string());
            document.tokens.reset();
            document.tokens.add_tokens(content);
            self.idf.add_tokens_string(&document.tokens.get_token_set());
            self.documents.get(&id)
        } else {
            let mut tokens = TokenFrequency::new();
            tokens.add_tokens(content);
            self.idf.add_tokens_string(&tokens.get_token_set());
            self.documents.insert(id.clone(), Document::new_with_set(text, tokens));
            self.total_doc_count += 1;
            self.documents.get(&id)
        }
    }
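
    // Usage sketch (hedged: `String` ids for illustration; any type meeting
    // the `IdType` bounds works):
    //
    //     let mut analyzer: DocumentAnalyzer<String> = DocumentAnalyzer::new();
    //     analyzer.add_document("a".into(), &["hello", "world"], Some("hello world"));
    //     // Re-adding the same id replaces the document; the count stays at 1.
    //     analyzer.add_document("a".into(), &["hello", "rust"], None);
    //     assert_eq!(analyzer.get_document_count(), 1);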

    pub fn get_document(&self, id: &IdType) -> Option<&Document> {
        self.documents.get(id)
    }

    pub fn del_document(&mut self, id: &IdType) -> Option<Document> {
        if let Some(document) = self.documents.remove(id) {
            self.total_doc_count -= 1;
            self.idf.sub_tokens_string(&document.tokens.get_token_set());
            Some(document)
        } else {
            None
        }
    }

    pub fn get_document_count(&self) -> u64 {
        self.total_doc_count
    }

    pub fn get_token_set_vec(&self) -> Vec<String> {
        self.idf.get_token_set()
    }

    pub fn get_token_set_vec_ref(&self) -> Vec<&str> {
        self.idf.get_token_set_ref()
    }

    pub fn get_token_set(&self) -> HashSet<String> {
        self.idf.get_token_hashset()
    }

    pub fn get_token_set_ref(&self) -> HashSet<&str> {
        self.idf.get_token_hashset_ref()
    }

    pub fn get_token_set_len(&self) -> usize {
        self.idf.get_token_set_len()
    }
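
    // Accessor sketch (hedged: return types are taken from the signatures
    // above; `analyzer` is any populated `DocumentAnalyzer`):
    //
    //     let vocab: Vec<String> = analyzer.get_token_set_vec();       // owned
    //     let vocab_ref: Vec<&str> = analyzer.get_token_set_vec_ref(); // borrowed
    //     assert_eq!(vocab.len(), analyzer.get_token_set_len());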

    /// Builds an immutable `Index` from the current corpus: an fst map of
    /// idf values plus one sparse tf-idf vector per document, computed in
    /// parallel with rayon (end-to-end sketch after this impl block).
    pub fn generate_index(&self) -> Index<IdType> {
        // Corpus statistics shared across worker threads.
        let total_doc_tokens_len = Arc::new(AtomicU64::new(0));
        let max_doc_tokens_len = Arc::new(AtomicU64::new(0));
        let now_processing = Arc::new(AtomicU64::new(0));

        // Build the idf fst; fst requires keys to be inserted in
        // lexicographic byte order, hence the sort.
        let mut builder = MapBuilder::memory();
        let mut idf_vec = self.idf.get_idf_vector_ref_parallel(self.total_doc_count);
        idf_vec.sort_by(|a, b| a.0.cmp(b.0));
        for (token, idf) in idf_vec {
            builder.insert(token.as_bytes(), idf as u64).unwrap();
        }
        let idf = Arc::new(builder.into_map());

        // Thread-safe map collecting each document's tf-idf vector.
        let index = Arc::new(Mutex::new(HashMap::new()));

        // Vectorize the documents in parallel.
        self.documents.par_iter().for_each(|(id, document)| {
            // Progress counter (currently write-only).
            now_processing.fetch_add(1, Ordering::SeqCst);

            let tf_idf_vec: HashMap<String, u16> =
                document.tokens.get_tfidf_hashmap_fst_parallel(&idf);

            // Stream the fst in key order so every document's vector uses
            // the same token ordering; absent tokens contribute 0.
            let mut tf_idf_sort_vec: Vec<u16> = Vec::new();
            let mut stream = idf.stream();
            while let Some((token, _)) = stream.next() {
                let tf_idf = *tf_idf_vec.get(str::from_utf8(token).unwrap()).unwrap_or(&0);
                tf_idf_sort_vec.push(tf_idf);
            }

            let tf_idf_csvec: DefaultSparseVec<u16> = DefaultSparseVec::from(tf_idf_sort_vec);
            let doc_tokens_len = document.tokens.get_total_token_count();

            total_doc_tokens_len.fetch_add(doc_tokens_len, Ordering::SeqCst);
            max_doc_tokens_len.fetch_max(doc_tokens_len, Ordering::SeqCst);

            let mut index_guard = index.lock().unwrap();
            index_guard.insert(id.clone(), (tf_idf_csvec, doc_tokens_len));
        });

        // Aggregate statistics, guarding against an empty corpus.
        let avg_doc_tokens_len = if self.total_doc_count == 0 {
            0
        } else {
            total_doc_tokens_len.load(Ordering::SeqCst) / self.total_doc_count
        };
        let max_doc_tokens_len = max_doc_tokens_len.load(Ordering::SeqCst);

        // Both Arcs are uniquely owned once the parallel loop has finished,
        // so unwrapping them cannot fail.
        Index::new_with_set(
            Arc::try_unwrap(index)
                .unwrap_or_else(|_| unreachable!("index is uniquely owned here"))
                .into_inner()
                .unwrap(),
            Arc::try_unwrap(idf).unwrap(),
            avg_doc_tokens_len,
            max_doc_tokens_len,
            self.total_doc_count,
        )
    }
}
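
// A minimal end-to-end sketch as a test module. It exercises only the APIs
// defined in this file; assertions about the generated `Index` internals are
// avoided, since its shape is defined in `super::index`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn add_update_delete_and_index() {
        let mut analyzer: DocumentAnalyzer<String> = DocumentAnalyzer::new();
        analyzer.add_document("doc1".to_string(), &["the", "quick", "fox"], Some("the quick fox"));
        analyzer.add_document("doc2".to_string(), &["the", "lazy", "dog"], None);
        assert_eq!(analyzer.get_document_count(), 2);

        // Replacing an existing id keeps the document count stable.
        analyzer.add_document("doc1".to_string(), &["a", "quick", "fox"], None);
        assert_eq!(analyzer.get_document_count(), 2);

        // Deleting a document removes its token set from the idf table.
        analyzer.del_document(&"doc2".to_string());
        assert_eq!(analyzer.get_document_count(), 1);

        // Vectorize the remaining corpus in parallel.
        let _index = analyzer.generate_index();
    }
}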