word_vec_rs/
export.rs

1use crate::{space::VecSpace, vector::Vector};
2use std::io::Write;
3
/// Presumably the default for whether a header line is written.
// NOTE(review): not referenced anywhere in this file — confirm it is still used elsewhere.
pub const DEFAULT_WRITE_HEADER: bool = true;
/// Default separator written between a term and its vector components in text output.
pub const DEFAULT_TERM_SEP: char = ' ';
/// Default separator written between the individual vector components in text output.
pub const DEFAULT_VEC_SEP: char = ' ';
7
/// Exporter that serializes vectors to a writer, either as text (the default) or, after
/// calling `use_binary`, in binary form.
#[derive(Debug, Clone, Copy)]
pub struct Exporter<W> {
    // Options
    /// Separator written between a term and its first vector component in text output.
    term_separator: char,
    /// Separator written between consecutive vector components in text output.
    vec_separator: char,
    /// When `true`, vectors are written in binary form instead of text.
    binary: bool,

    // Where to write the data to
    /// Destination that receives the exported bytes.
    writer: W,
    /// Set once the header line has been written; `export_vectors` panics while this is
    /// still `false`.
    header_written: bool,
}
20
21impl<W> Exporter<W> {
22    /// Create a new vector exporter with default configurations and a writer to which the vectors
23    /// will be written to.
24    #[inline]
25    pub fn new(w: W) -> Self {
26        Self {
27            term_separator: DEFAULT_TERM_SEP,
28            vec_separator: DEFAULT_VEC_SEP,
29            binary: false,
30            writer: w,
31            header_written: false,
32        }
33    }
34
35    /// Exports the data into binary word2vec format.
36    pub fn use_binary(mut self) -> Self {
37        self.binary = true;
38        self
39    }
40}
41
42impl<W: Write> Exporter<W> {
43    /// Exports an entire [`VecSpace`]
44    pub fn export_space(self, space: &VecSpace) -> Result<usize, std::io::Error> {
45        self.export_space_filtered(space, |_| true)
46    }
47
48    /// Exports all vectors from a [`VecSpace`] for which the given filter function returns
49    /// `true`
50    pub fn export_space_filtered<F>(
51        mut self,
52        space: &VecSpace,
53        filter: F,
54    ) -> Result<usize, std::io::Error>
55    where
56        F: Fn(&Vector) -> bool,
57    {
58        let mut n = 0;
59
60        let len = space.len();
61        let dim = space.dim();
62        n += self.write_header(len, dim)?;
63
64        // In txt format, vectors always prepend a '\n' but in binary this is not necessary, so add
65        // one after the header as this is needed for binary too.
66        if self.binary {
67            n += self.writer.write(b"\n")?;
68        }
69
70        n += self.export_vectors(space.iter().filter(|i| (filter)(i)))?;
71
72        Ok(n)
73    }
74
75    /// Export all given vectors. You have to call `write_header` first.
76    ///
77    /// # Panics:
78    /// Panics if `write_header` is true but none has been written
79    pub fn export_vectors<'a, 'b, I>(&mut self, iter: I) -> Result<usize, std::io::Error>
80    where
81        I: IntoIterator<Item = Vector<'a, 'b>>,
82    {
83        if !self.header_written {
84            panic!("Expecetd header to be written");
85        }
86
87        let mut n = 0;
88
89        for i in iter.into_iter() {
90            n += self.write_vector(i)?;
91        }
92
93        Ok(n)
94    }
95
96    /// Exports a given vector
97    fn write_vector(&mut self, vec: Vector) -> Result<usize, std::io::Error> {
98        if self.binary {
99            self.write_vector_bin(vec)
100        } else {
101            self.write_vector_txt(vec)
102        }
103    }
104
105    /// Write a single vector in bin format.
106    fn write_vector_bin(&mut self, vec: Vector) -> Result<usize, std::io::Error> {
107        let mut n = 0;
108        n += self.writer.write(vec.term().as_bytes())?;
109        n += self.writer.write(&[b' '])?;
110        for v in vec.data() {
111            self.writer.write(&v.to_le_bytes())?;
112        }
113        Ok(n)
114    }
115
116    /// Write a single vector in txt format.
117    fn write_vector_txt(&mut self, vec: Vector) -> Result<usize, std::io::Error> {
118        let mut n = 0;
119        n += self.writer.write(b"\n")?;
120        // Term itself
121        n += self.writer.write(vec.term().as_bytes())?;
122        // Term separator
123        n += self
124            .writer
125            .write(self.term_separator.to_string().as_bytes())?;
126
127        for (pos, v) in vec.data().iter().enumerate() {
128            if pos > 0 {
129                n += self
130                    .writer
131                    .write(self.vec_separator.to_string().as_bytes())?;
132            }
133
134            n += self.writer.write(v.to_string().as_bytes())?;
135        }
136
137        Ok(n)
138    }
139
140    /// Writes the header line.
141    fn write_header(&mut self, dim: usize, len: usize) -> Result<usize, std::io::Error> {
142        self.header_written = true;
143        let mut n = 0;
144        n += self.writer.write(dim.to_string().as_bytes())?;
145        n += self.writer.write(b" ")?;
146        n += self.writer.write(len.to_string().as_bytes())?;
147        Ok(n)
148    }
149}
150
#[cfg(test)]
mod test {
    use super::*;
    use crate::parse::Word2VecParser;
    use std::io::Cursor;

    /// Builds the three-dimensional fixture space shared by both round-trip tests.
    // NOTE(review): two of the fixture vectors use the term "term3" — confirm that the
    // duplicate is intentional and not a typo for "term2".
    fn sample_space() -> VecSpace {
        let mut space = VecSpace::new(3);
        space.extend([
            Vector::new(&[1.2, 2.0, 4.4], "term1"),
            Vector::new(&[2.3, 1.0, 3.4], "term3"),
            Vector::new(&[3.1, 9.4, 3.0], "term3"),
        ]);
        space
    }

    /// Text export followed by a text parse must reproduce the original space.
    #[test]
    fn test_txt_export() {
        let space = sample_space();
        let mut buf: Vec<u8> = Vec::new();

        let exporter = Exporter::new(&mut buf);
        exporter.export_space(&space).unwrap();

        let parser = Word2VecParser::new();
        let parsed = parser.parse(Cursor::new(&buf)).unwrap();

        assert_eq!(space, parsed);
    }

    /// Binary export followed by a binary parse must reproduce the original space.
    #[test]
    fn test_bin_export() {
        let space = sample_space();
        let mut buf: Vec<u8> = Vec::new();

        let exporter = Exporter::new(&mut buf).use_binary();
        exporter.export_space(&space).unwrap();

        let parser = Word2VecParser::new().binary();
        let parsed = parser.parse(Cursor::new(&buf)).unwrap();

        assert_eq!(space, parsed);
    }
}