tf_idf_vectorizer/vectorizer/
serde.rs

1use std::sync::Arc;
2
3use ahash::RandomState;
4use indexmap::IndexSet;
5use num::Num;
6use serde::{ser::SerializeStruct, Deserialize, Serialize};
7
8use crate::{vectorizer::{tfidf::{DefaultTFIDFEngine, TFIDFEngine}, IDFVector, TFVector}, Corpus, TFIDFVectorizer};
9
10/// Data structure for deserializing TFIDFVectorizer.
11/// This struct does not contain references, so it can be serialized.
12/// Use the `into_tf_idf_vectorizer` method to convert to `TFIDFVectorizer`.
13#[derive(Debug, Deserialize, Serialize)]
14pub struct TFIDFData<N = f32, K = String, E = DefaultTFIDFEngine>
15where
16    N: Num + Copy,
17    E: TFIDFEngine<N>,
18{
19    /// TF vectors for documents
20    pub documents: Vec<TFVector<N, K>>,
21    /// Token dimension sample for TF vectors
22    pub token_dim_sample: IndexSet<Box<str>, RandomState>,
23    /// IDF vector
24    pub idf: IDFVector<N>,
25    #[serde(default, skip_serializing, skip_deserializing)]
26    _marker: std::marker::PhantomData<E>,
27}
28
29impl<N, K, E> TFIDFData<N, K, E>
30where
31    N: Num + Copy + Into<f64> + Send + Sync,
32    E: TFIDFEngine<N>,
33    K: Clone + Send + Sync,
34{
35    /// Convert `TFIDFData` into `TFIDFVectorizer`.
36    /// `corpus_ref` is a reference to the corpus.
37    pub fn into_tf_idf_vectorizer(self, corpus_ref: Arc<Corpus>) -> TFIDFVectorizer<N, K, E>
38    {
39        let mut instance = TFIDFVectorizer {
40            documents: self.documents,
41            token_dim_sample: self.token_dim_sample.clone(),
42            corpus_ref,
43            idf: self.idf,
44            _marker: std::marker::PhantomData,
45        };
46        instance.update_idf();
47        instance
48    }
49}
50
51impl<N, K, E> Serialize for TFIDFVectorizer<N, K, E>
52where
53    N: Num + Copy + Serialize + Into<f64> + Send + Sync,
54    K: Serialize + Clone + Send + Sync,
55    E: TFIDFEngine<N>,
56{
57    /// Serialize TFIDFVectorizer.
58    /// This struct contains references, so they are excluded from serialization.
59    /// Use `TFIDFData` for deserialization.
60    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
61    where
62        S: serde::Serializer,
63    {
64        let mut state = serializer.serialize_struct("TFIDFVectorizer", 3)?;
65        state.serialize_field("documents", &self.documents)?;
66        state.serialize_field("token_dim_sample", &self.token_dim_sample)?;
67        state.serialize_field("idf", &self.idf)?;
68        state.end()
69    }
70}