word_vec_rs/
parse.rs

1use std::{
2    fs::File,
3    io::{BufRead, BufReader, Read},
4    path::Path,
5    str,
6};
7
8use crate::{error::Error, space::VecSpace, vector::Vector};
9
/// Configurable reader for word vectors stored in Word2Vec's `.vec` layout.
///
/// The value is a small bag of options; it is adjusted through the
/// consuming builder methods and then handed a reader or a file path.
#[derive(Debug, Clone, Copy)]
pub struct Word2VecParser {
    // ---- Input format ----
    /// Header switch; cleared by `no_header`.
    parse_header: bool,
    /// Character that ends the term at the start of a text line.
    term_separator: char,
    /// Character between two neighbouring vector components in a text line.
    vec_separator: char,
    /// `true` when the input is in the binary format rather than text.
    binary: bool,

    // ---- Resulting vector space ----
    /// `true` to equip the resulting space with a term map.
    index_terms: bool,
}
22
23impl Word2VecParser {
24    #[inline]
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Parse from binary format.
30    pub fn binary(mut self) -> Self {
31        self.binary = true;
32        self
33    }
34
35    /// Don't treat the first line as header.
36    pub fn no_header(mut self) -> Self {
37        self.parse_header = false;
38        self
39    }
40
41    /// Use a custom term<->Vec separator character.
42    pub fn cust_term_separator(mut self, sep: char) -> Self {
43        self.term_separator = sep;
44        self
45    }
46
47    /// Use a custom Vec item <->Vec item separator character.
48    pub fn cust_vec_separator(mut self, sep: char) -> Self {
49        self.vec_separator = sep;
50        self
51    }
52
53    /// Whether to index the words for faster term->vec lookup.
54    pub fn index_terms(mut self, index: bool) -> Self {
55        self.index_terms = index;
56        self
57    }
58
59    pub fn parse<R: Read>(&self, reader: R) -> Result<VecSpace, Error> {
60        let mut space = VecSpace::new(0);
61
62        let mut parsed_header = false;
63        let mut line_buf = vec![];
64        let mut float_buf = vec![];
65
66        let mut r = BufReader::new(reader);
67
68        loop {
69            line_buf.clear();
70
71            if !parsed_header {
72                if r.read_until(b'\n', &mut line_buf).unwrap() == 0 {
73                    return Err(Error::InvalidVectorFormat)?;
74                }
75
76                let (_, dim) = self.parse_header(&line_buf)?;
77                space = VecSpace::new(dim);
78                float_buf.reserve_exact(dim);
79
80                if self.index_terms {
81                    space = space.with_termmap();
82                }
83
84                parsed_header = true;
85
86                // Don't parse header as vector
87                continue;
88            }
89
90            // Parse line and insert into space
91            let vec = self.parse_vec(&mut r, &mut float_buf, &mut line_buf, space.dim());
92            if vec == Err(Error::EOF) {
93                break;
94            }
95            space.insert(&vec?)?;
96        }
97
98        Ok(space)
99    }
100
101    /// Parses a word vector file.
102    #[inline]
103    pub fn parse_file<F: AsRef<Path>>(&self, file: F) -> Result<VecSpace, Error> {
104        self.parse(File::open(file)?)
105    }
106
107    /// Parses a single vec line
108    fn parse_vec<'v, 't, R: BufRead>(
109        &self,
110        r: &mut R,
111        vbuf: &'v mut Vec<f32>,
112        line_buf: &'t mut Vec<u8>,
113        vec_len: usize,
114    ) -> Result<Vector<'v, 't>, Error> {
115        vbuf.clear();
116        line_buf.clear();
117
118        if self.binary {
119            self.parse_vec_bin(r, vbuf, line_buf, vec_len)
120        } else {
121            if r.read_until(b'\n', line_buf)? == 0 {
122                return Err(Error::EOF);
123            }
124            let line = str::from_utf8(line_buf)?;
125            self.parse_vec_txt(line, vbuf)
126        }
127    }
128
129    /// Parses a word vector from txt format.
130    fn parse_vec_txt<'v, 't>(
131        &self,
132        line: &'t str,
133        buf: &'v mut Vec<f32>,
134    ) -> Result<Vector<'v, 't>, Error> {
135        let term_vec_split = line
136            .find(self.term_separator)
137            .ok_or(Error::InvalidVectorFormat)?;
138
139        for i in line[term_vec_split + 1..]
140            .trim()
141            .split(self.vec_separator)
142            .map(|i| i.parse::<f32>())
143        {
144            buf.push(i.map_err(fmt_err)?);
145        }
146
147        let term = &line[..term_vec_split];
148        Ok(Vector::new(buf, &term))
149    }
150
151    /// Parses a word vector from bin format.
152    fn parse_vec_bin<'v, 't, R: BufRead>(
153        &self,
154        r: &mut R,
155        vbuf: &'v mut Vec<f32>,
156        rbuf: &'t mut Vec<u8>,
157        vec_len: usize,
158    ) -> Result<Vector<'v, 't>, Error> {
159        if r.read_until(b' ', rbuf)? == 0 {
160            return Err(Error::EOF);
161        }
162
163        let term = str::from_utf8(&rbuf[..rbuf.len() - 1])?;
164
165        let mut float_buf = [0u8; 4];
166        for _ in 0..vec_len {
167            r.read_exact(&mut float_buf)?;
168            vbuf.push(f32::from_le_bytes(float_buf.try_into().map_err(fmt_err)?));
169        }
170
171        Ok(Vector::new(vbuf, term))
172    }
173
174    #[inline]
175    fn parse_header(&self, line: &[u8]) -> Result<(usize, usize), Error> {
176        if self.binary {
177            self.parse_header_bin(line)
178        } else {
179            let line = str::from_utf8(line)?.trim();
180            self.parse_header_txt(line)
181        }
182    }
183
184    fn parse_header_bin(&self, line: &[u8]) -> Result<(usize, usize), Error> {
185        let space = line
186            .iter()
187            .enumerate()
188            .find(|i| *i.1 == b' ')
189            .ok_or(Error::InvalidVectorFormat)?
190            .0;
191
192        let count = str::from_utf8(&line[..space])?;
193        let len = str::from_utf8(&line[space + 1..line.len() - 1])?;
194
195        let count: usize = count.parse().unwrap();
196        let len: usize = len.parse().unwrap();
197
198        Ok((count, len))
199    }
200
201    fn parse_header_txt(&self, line: &str) -> Result<(usize, usize), Error> {
202        let mut split = line.split(' ');
203        let mut next_nr = || {
204            split
205                .next()
206                .and_then(|i| i.parse::<usize>().ok())
207                .ok_or(Error::InvalidVectorFormat)
208        };
209        let count = next_nr()?;
210        let dim = next_nr()?;
211        Ok((count, dim))
212    }
213}
214
215#[inline]
216fn fmt_err<T>(_: T) -> Error {
217    Error::InvalidVectorFormat
218}
219
220impl Default for Word2VecParser {
221    fn default() -> Self {
222        Self {
223            parse_header: true,
224            term_separator: ' ',
225            vec_separator: ' ',
226            index_terms: false,
227            binary: false,
228        }
229    }
230}