tf_idf_vectorizer/vectorizer/serde.rs

use std::sync::Arc;
use std::hash::Hash;

use ahash::RandomState;
use num_traits::Num;
use serde::{ser::SerializeStruct, Deserialize, Serialize};

use crate::{
    Corpus, TFIDFVectorizer,
    utils::datastruct::{map::IndexMap, vector::ZeroSpVecTrait},
    vectorizer::{IDFVector, KeyRc, TFVector, tfidf::{DefaultTFIDFEngine, TFIDFEngine}},
};

/// Data structure for deserializing a `TFIDFVectorizer`.
/// Unlike `TFIDFVectorizer`, this struct holds no corpus reference, so it can be
/// serialized and deserialized directly.
/// Use the `into_tf_idf_vectorizer` method to convert it back into a `TFIDFVectorizer`.
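///
/// A minimal usage sketch (marked `ignore`; it assumes the `serde_json` crate and
/// a previously serialized `TFIDFVectorizer` available in `json`):
///
/// ```ignore
/// let data: TFIDFData<f32, String> = serde_json::from_str(&json)?;
/// let vectorizer = data.into_tf_idf_vectorizer(corpus_ref);
/// ```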
#[derive(Debug, Deserialize, Serialize)]
pub struct TFIDFData<N = f32, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy,
    E: TFIDFEngine<N, K>,
    K: Clone + Eq + Hash,
{
    /// TF vectors for the documents, keyed by document key
    pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
    /// Token assigned to each TF vector dimension
    pub token_dim_sample: Vec<Box<str>>,
    /// IDF vector
    pub idf: IDFVector<N>,
    #[serde(default, skip_serializing, skip_deserializing)]
    _marker: std::marker::PhantomData<E>,
}

impl<N, K, E> TFIDFData<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K>,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Convert this `TFIDFData` into a `TFIDFVectorizer`.
    /// `corpus_ref` is the shared corpus that the rebuilt vectorizer will reference.
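    ///
    /// A minimal sketch (marked `ignore`): given a deserialized `data: TFIDFData`
    /// and a `corpus` already wrapped in an `Arc`,
    ///
    /// ```ignore
    /// let vectorizer = data.into_tf_idf_vectorizer(Arc::clone(&corpus));
    /// ```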
    pub fn into_tf_idf_vectorizer(self, corpus_ref: Arc<Corpus>) -> TFIDFVectorizer<N, K, E> {
        let raw_iter = self.documents.iter();
        // Rebuild the token -> document-keys reverse index from the stored TF vectors.
        let mut token_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>, RandomState> =
            IndexMap::with_capacity(self.token_dim_sample.len());
        self.token_dim_sample.iter().for_each(|token| {
            token_dim_rev_index.insert(token, Vec::new());
        });
        for (key, doc) in raw_iter {
            // Record, for every dimension this document uses, which document key uses it.
            doc.tf_vec.raw_iter().for_each(|(idx, _)| {
                let token = &self.token_dim_sample[idx];
                token_dim_rev_index
                    .get_mut(token)
                    .expect("every stored TF dimension must appear in token_dim_sample")
                    .push(key.clone());
            });
        }

        let mut instance = TFIDFVectorizer {
            documents: self.documents,
            token_dim_rev_index,
            corpus_ref,
            idf_cache: self.idf,
            _marker: std::marker::PhantomData,
        };
        // Recompute the IDF cache; the serialized value is only kept for backward compatibility.
        instance.update_idf();
        instance
    }
}

impl<N, K, E> Serialize for TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Serialize + Into<f64> + Send + Sync,
    K: Serialize + Clone + Send + Sync + Eq + Hash,
    E: TFIDFEngine<N, K>,
{
    /// Serialize the `TFIDFVectorizer`.
    /// The corpus reference cannot be serialized and is skipped; only the document
    /// TF vectors, the token dimension sample (the reverse-index keys), and the IDF
    /// cache are written.
    /// Use `TFIDFData` for deserialization.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut state = serializer.serialize_struct("TFIDFVectorizer", 3)?;
        state.serialize_field("documents", &self.documents)?;
        state.serialize_field("token_dim_sample", &self.token_dim_rev_index.keys())?;
        // This is redundant (it gets recalculated on deserialization anyway);
        // kept for now for backward compatibility.
        state.serialize_field("idf", &self.idf_cache)?;
        state.end()
    }
}
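
// A minimal round-trip sketch (not part of the crate's test suite; it assumes the
// `serde_json` crate is available and uses `f32`/`String`/`DefaultTFIDFEngine`
// type parameters). Serialization goes through `TFIDFVectorizer::serialize` above,
// deserialization through `TFIDFData`, and `into_tf_idf_vectorizer` re-attaches the corpus.
#[cfg(test)]
mod round_trip_sketch {
    use super::*;

    #[allow(dead_code)]
    fn round_trip(
        vectorizer: &TFIDFVectorizer<f32, String, DefaultTFIDFEngine>,
        corpus_ref: Arc<Corpus>,
    ) -> TFIDFVectorizer<f32, String, DefaultTFIDFEngine> {
        // Only `documents`, `token_dim_sample`, and `idf` end up in the JSON;
        // the corpus reference is skipped.
        let json = serde_json::to_string(vectorizer).expect("serialize");
        // Deserialize into the reference-free carrier type ...
        let data: TFIDFData<f32, String, DefaultTFIDFEngine> =
            serde_json::from_str(&json).expect("deserialize");
        // ... then re-attach the corpus to rebuild the full vectorizer.
        data.into_tf_idf_vectorizer(corpus_ref)
    }
}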