1use std::{
2 fs::File,
3 io::{BufRead, BufReader, Read},
4 path::Path,
5 str,
6};
7
8use crate::{error::Error, space::VecSpace, vector::Vector};
9
/// Configuration for reading word2vec-style vector files.
///
/// Values are built in builder style: start from `Word2VecParser::new()` (or
/// `Default::default()`) and chain the option methods, each of which consumes
/// and returns the parser.
#[derive(Clone, Copy, Debug)]
pub struct Word2VecParser {
    /// Whether a header line is expected; cleared by `no_header`.
    parse_header: bool,
    /// Character separating the term from its vector components (text format).
    term_separator: char,
    /// Character separating the individual vector components (text format).
    vec_separator: char,
    /// `true` when the input uses the binary layout instead of text.
    binary: bool,
    /// `true` when the resulting space should be built with a term map.
    index_terms: bool,
}
22
23impl Word2VecParser {
24 #[inline]
25 pub fn new() -> Self {
26 Self::default()
27 }
28
29 pub fn binary(mut self) -> Self {
31 self.binary = true;
32 self
33 }
34
35 pub fn no_header(mut self) -> Self {
37 self.parse_header = false;
38 self
39 }
40
41 pub fn cust_term_separator(mut self, sep: char) -> Self {
43 self.term_separator = sep;
44 self
45 }
46
47 pub fn cust_vec_separator(mut self, sep: char) -> Self {
49 self.vec_separator = sep;
50 self
51 }
52
53 pub fn index_terms(mut self, index: bool) -> Self {
55 self.index_terms = index;
56 self
57 }
58
59 pub fn parse<R: Read>(&self, reader: R) -> Result<VecSpace, Error> {
60 let mut space = VecSpace::new(0);
61
62 let mut parsed_header = false;
63 let mut line_buf = vec![];
64 let mut float_buf = vec![];
65
66 let mut r = BufReader::new(reader);
67
68 loop {
69 line_buf.clear();
70
71 if !parsed_header {
72 if r.read_until(b'\n', &mut line_buf).unwrap() == 0 {
73 return Err(Error::InvalidVectorFormat)?;
74 }
75
76 let (_, dim) = self.parse_header(&line_buf)?;
77 space = VecSpace::new(dim);
78 float_buf.reserve_exact(dim);
79
80 if self.index_terms {
81 space = space.with_termmap();
82 }
83
84 parsed_header = true;
85
86 continue;
88 }
89
90 let vec = self.parse_vec(&mut r, &mut float_buf, &mut line_buf, space.dim());
92 if vec == Err(Error::EOF) {
93 break;
94 }
95 space.insert(&vec?)?;
96 }
97
98 Ok(space)
99 }
100
101 #[inline]
103 pub fn parse_file<F: AsRef<Path>>(&self, file: F) -> Result<VecSpace, Error> {
104 self.parse(File::open(file)?)
105 }
106
107 fn parse_vec<'v, 't, R: BufRead>(
109 &self,
110 r: &mut R,
111 vbuf: &'v mut Vec<f32>,
112 line_buf: &'t mut Vec<u8>,
113 vec_len: usize,
114 ) -> Result<Vector<'v, 't>, Error> {
115 vbuf.clear();
116 line_buf.clear();
117
118 if self.binary {
119 self.parse_vec_bin(r, vbuf, line_buf, vec_len)
120 } else {
121 if r.read_until(b'\n', line_buf)? == 0 {
122 return Err(Error::EOF);
123 }
124 let line = str::from_utf8(line_buf)?;
125 self.parse_vec_txt(line, vbuf)
126 }
127 }
128
129 fn parse_vec_txt<'v, 't>(
131 &self,
132 line: &'t str,
133 buf: &'v mut Vec<f32>,
134 ) -> Result<Vector<'v, 't>, Error> {
135 let term_vec_split = line
136 .find(self.term_separator)
137 .ok_or(Error::InvalidVectorFormat)?;
138
139 for i in line[term_vec_split + 1..]
140 .trim()
141 .split(self.vec_separator)
142 .map(|i| i.parse::<f32>())
143 {
144 buf.push(i.map_err(fmt_err)?);
145 }
146
147 let term = &line[..term_vec_split];
148 Ok(Vector::new(buf, &term))
149 }
150
151 fn parse_vec_bin<'v, 't, R: BufRead>(
153 &self,
154 r: &mut R,
155 vbuf: &'v mut Vec<f32>,
156 rbuf: &'t mut Vec<u8>,
157 vec_len: usize,
158 ) -> Result<Vector<'v, 't>, Error> {
159 if r.read_until(b' ', rbuf)? == 0 {
160 return Err(Error::EOF);
161 }
162
163 let term = str::from_utf8(&rbuf[..rbuf.len() - 1])?;
164
165 let mut float_buf = [0u8; 4];
166 for _ in 0..vec_len {
167 r.read_exact(&mut float_buf)?;
168 vbuf.push(f32::from_le_bytes(float_buf.try_into().map_err(fmt_err)?));
169 }
170
171 Ok(Vector::new(vbuf, term))
172 }
173
174 #[inline]
175 fn parse_header(&self, line: &[u8]) -> Result<(usize, usize), Error> {
176 if self.binary {
177 self.parse_header_bin(line)
178 } else {
179 let line = str::from_utf8(line)?.trim();
180 self.parse_header_txt(line)
181 }
182 }
183
184 fn parse_header_bin(&self, line: &[u8]) -> Result<(usize, usize), Error> {
185 let space = line
186 .iter()
187 .enumerate()
188 .find(|i| *i.1 == b' ')
189 .ok_or(Error::InvalidVectorFormat)?
190 .0;
191
192 let count = str::from_utf8(&line[..space])?;
193 let len = str::from_utf8(&line[space + 1..line.len() - 1])?;
194
195 let count: usize = count.parse().unwrap();
196 let len: usize = len.parse().unwrap();
197
198 Ok((count, len))
199 }
200
201 fn parse_header_txt(&self, line: &str) -> Result<(usize, usize), Error> {
202 let mut split = line.split(' ');
203 let mut next_nr = || {
204 split
205 .next()
206 .and_then(|i| i.parse::<usize>().ok())
207 .ok_or(Error::InvalidVectorFormat)
208 };
209 let count = next_nr()?;
210 let dim = next_nr()?;
211 Ok((count, dim))
212 }
213}
214
215#[inline]
216fn fmt_err<T>(_: T) -> Error {
217 Error::InvalidVectorFormat
218}
219
220impl Default for Word2VecParser {
221 fn default() -> Self {
222 Self {
223 parse_header: true,
224 term_separator: ' ',
225 vec_separator: ' ',
226 index_terms: false,
227 binary: false,
228 }
229 }
230}