1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
use std::collections::HashMap;
use std::io::{BufRead, Write};
use std::mem;
use std::slice::from_raw_parts_mut;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use ndarray::{Array, Axis};
use super::*;
/// Construction of embeddings from a buffered reader of type `R` holding
/// data in the binary word2vec layout.
pub trait ReadWord2Vec<R: BufRead> {
    /// Reads a complete set of embeddings from `reader`.
    fn read_word2vec_binary(reader: &mut R) -> Result<Embeddings>;
}
impl<R> ReadWord2Vec<R> for Embeddings
where R: BufRead
{
fn read_word2vec_binary(reader: &mut R) -> Result<Embeddings> {
let n_words = try!(read_number(reader, b' '));
let embed_len = try!(read_number(reader, b'\n'));
let mut matrix = Array::zeros((n_words, embed_len));
let mut indices = HashMap::new();
let mut words = Vec::with_capacity(n_words);
for idx in 0..n_words {
let word = try!(read_string(reader, ' ' as u8));
let word = word.trim();
words.push(word.to_owned());
indices.insert(word.to_owned(), idx);
let mut embedding = matrix.subview_mut(Axis(0), idx);
{
let mut embedding_raw = match embedding.as_slice_mut() {
Some(s) => unsafe { typed_to_bytes(s) },
None => return Err("Matrix not contiguous".into()),
};
try!(reader.read_exact(&mut embedding_raw));
}
}
Ok(super::embeddings::new_embeddings(matrix, embed_len, indices, words))
}
}
fn read_number(reader: &mut BufRead, delim: u8) -> Result<usize> {
let field_str = try!(read_string(reader, delim));
Ok(try!(field_str.parse()))
}
fn read_string(reader: &mut BufRead, delim: u8) -> Result<String> {
let mut buf = Vec::new();
try!(reader.read_until(delim, &mut buf));
buf.pop();
Ok(try!(String::from_utf8(buf)))
}
/// Views a mutable slice of `T` as the mutable byte slice covering the same
/// memory.
///
/// The returned slice is `slice.len() * size_of::<T>()` bytes long. It
/// borrows `slice` mutably for as long as it is alive.
///
/// # Safety
///
/// The caller must ensure that `T` contains no padding bytes, because reading
/// padding through the `u8` view is undefined behaviour. The caller must also
/// ensure that every byte pattern written through the returned slice forms a
/// valid `T`. Plain numeric types such as `u32` or `f32` satisfy both
/// requirements.
unsafe fn typed_to_bytes<T>(slice: &mut [T]) -> &mut [u8] {
    let byte_len = mem::size_of_val(&*slice);
    let base = slice.as_mut_ptr() as *mut u8;
    // SAFETY: `base` points to `byte_len` initialized bytes owned by `slice`.
    // `u8` has alignment 1, and the exclusive borrow of `slice` is carried
    // over to the returned slice. Validity of `T` is the caller's contract.
    from_raw_parts_mut(base, byte_len)
}
/// Serialization of embeddings, in the binary word2vec layout, to a writer of
/// type `W`.
pub trait WriteWord2Vec<W: Write> {
    /// Writes every embedding held by `self` to `w`.
    fn write_word2vec_binary(&self, w: &mut W) -> Result<()>;
}
impl<W> WriteWord2Vec<W> for Embeddings
where
    W: Write,
{
    /// Writes the embeddings in the binary word2vec layout.
    ///
    /// The output starts with the header `"<len> <embed_len>\n"`. Each word
    /// is then written as the word, a space, its vector as little-endian
    /// `f32` values, and a newline.
    ///
    /// Values are written one `f32` at a time, so `w` should normally be
    /// buffered (e.g. wrapped in a `BufWriter`). Words are written verbatim.
    /// A word containing a space cannot be split back out by a reader that
    /// stops at the first space.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error reported by `w`.
    fn write_word2vec_binary(&self, w: &mut W) -> Result<()> {
        writeln!(w, "{} {}", self.len(), self.embed_len())?;
        for (word, embed) in self.iter() {
            write!(w, "{} ", word)?;
            for v in embed {
                w.write_f32::<LittleEndian>(*v)?;
            }
            // Use `write_all`, not `write`: `write` may accept zero bytes,
            // which would silently drop the record separator and corrupt the
            // output.
            w.write_all(b"\n")?;
        }
        Ok(())
    }
}