1use crate::{space::VecSpace, vector::Vector};
2use std::io::Write;
3
/// Default for whether a header is written.
///
/// NOTE(review): nothing visible in this file reads this constant — confirm it is
/// still used elsewhere before relying on it.
pub const DEFAULT_WRITE_HEADER: bool = true;
/// Default separator written between a term and its first component in text output.
pub const DEFAULT_TERM_SEP: char = ' ';
/// Default separator written between consecutive vector components in text output.
pub const DEFAULT_VEC_SEP: char = ' ';
7
/// Serializes vectors (a term plus its numeric components) to a writer `W`,
/// either as text or in a binary layout.
#[derive(Debug, Clone, Copy)]
pub struct Exporter<W> {
    /// Written between the term and the first component in text mode.
    term_separator: char,
    /// Written between consecutive components in text mode.
    vec_separator: char,
    /// When `true`, vectors are written with the binary encoder instead of the text one.
    binary: bool,

    /// Destination of all output.
    writer: W,
    /// Set once the header has been emitted; `export_vectors` panics while this is `false`.
    header_written: bool,
}
20
21impl<W> Exporter<W> {
22 #[inline]
25 pub fn new(w: W) -> Self {
26 Self {
27 term_separator: DEFAULT_TERM_SEP,
28 vec_separator: DEFAULT_VEC_SEP,
29 binary: false,
30 writer: w,
31 header_written: false,
32 }
33 }
34
35 pub fn use_binary(mut self) -> Self {
37 self.binary = true;
38 self
39 }
40}
41
42impl<W: Write> Exporter<W> {
43 pub fn export_space(self, space: &VecSpace) -> Result<usize, std::io::Error> {
45 self.export_space_filtered(space, |_| true)
46 }
47
48 pub fn export_space_filtered<F>(
51 mut self,
52 space: &VecSpace,
53 filter: F,
54 ) -> Result<usize, std::io::Error>
55 where
56 F: Fn(&Vector) -> bool,
57 {
58 let mut n = 0;
59
60 let len = space.len();
61 let dim = space.dim();
62 n += self.write_header(len, dim)?;
63
64 if self.binary {
67 n += self.writer.write(b"\n")?;
68 }
69
70 n += self.export_vectors(space.iter().filter(|i| (filter)(i)))?;
71
72 Ok(n)
73 }
74
75 pub fn export_vectors<'a, 'b, I>(&mut self, iter: I) -> Result<usize, std::io::Error>
80 where
81 I: IntoIterator<Item = Vector<'a, 'b>>,
82 {
83 if !self.header_written {
84 panic!("Expecetd header to be written");
85 }
86
87 let mut n = 0;
88
89 for i in iter.into_iter() {
90 n += self.write_vector(i)?;
91 }
92
93 Ok(n)
94 }
95
96 fn write_vector(&mut self, vec: Vector) -> Result<usize, std::io::Error> {
98 if self.binary {
99 self.write_vector_bin(vec)
100 } else {
101 self.write_vector_txt(vec)
102 }
103 }
104
105 fn write_vector_bin(&mut self, vec: Vector) -> Result<usize, std::io::Error> {
107 let mut n = 0;
108 n += self.writer.write(vec.term().as_bytes())?;
109 n += self.writer.write(&[b' '])?;
110 for v in vec.data() {
111 self.writer.write(&v.to_le_bytes())?;
112 }
113 Ok(n)
114 }
115
116 fn write_vector_txt(&mut self, vec: Vector) -> Result<usize, std::io::Error> {
118 let mut n = 0;
119 n += self.writer.write(b"\n")?;
120 n += self.writer.write(vec.term().as_bytes())?;
122 n += self
124 .writer
125 .write(self.term_separator.to_string().as_bytes())?;
126
127 for (pos, v) in vec.data().iter().enumerate() {
128 if pos > 0 {
129 n += self
130 .writer
131 .write(self.vec_separator.to_string().as_bytes())?;
132 }
133
134 n += self.writer.write(v.to_string().as_bytes())?;
135 }
136
137 Ok(n)
138 }
139
140 fn write_header(&mut self, dim: usize, len: usize) -> Result<usize, std::io::Error> {
142 self.header_written = true;
143 let mut n = 0;
144 n += self.writer.write(dim.to_string().as_bytes())?;
145 n += self.writer.write(b" ")?;
146 n += self.writer.write(len.to_string().as_bytes())?;
147 Ok(n)
148 }
149}
150
#[cfg(test)]
mod test {
    use super::*;
    use crate::parse::Word2VecParser;
    use std::io::Cursor;

    /// A space exported as text must parse back to an equal space.
    #[test]
    fn test_txt_export() {
        let mut space = VecSpace::new(3);
        space.extend([
            Vector::new(&[1.2, 2.0, 4.4], "term1"),
            Vector::new(&[2.3, 1.0, 3.4], "term3"),
            Vector::new(&[3.1, 9.4, 3.0], "term3"),
        ]);

        let mut out = Vec::<u8>::new();
        let exporter = Exporter::new(&mut out);
        exporter.export_space(&space).unwrap();

        let reader = Cursor::new(&out);
        let parsed = Word2VecParser::new().parse(reader).unwrap();

        assert_eq!(space, parsed);
    }

    /// A space exported in binary form must parse back to an equal space.
    #[test]
    fn test_bin_export() {
        let mut space = VecSpace::new(3);
        space.extend([
            Vector::new(&[1.2, 2.0, 4.4], "term1"),
            Vector::new(&[2.3, 1.0, 3.4], "term3"),
            Vector::new(&[3.1, 9.4, 3.0], "term3"),
        ]);

        let mut out = Vec::<u8>::new();
        let exporter = Exporter::new(&mut out).use_binary();
        exporter.export_space(&space).unwrap();

        let reader = Cursor::new(&out);
        let parsed = Word2VecParser::new().binary().parse(reader).unwrap();

        assert_eq!(space, parsed);
    }
}