1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
use std::collections::HashMap;
use std::mem;
use std::io::{BufRead, Write};
use std::slice::from_raw_parts_mut;

use byteorder::{LittleEndian, WriteBytesExt};
use ndarray::{Array, Axis};

use super::*;

/// Extension trait for loading `Embeddings` from word2vec binary data.
///
/// Implementors can be built directly from any buffered reader that
/// yields the contents of a file in word2vec binary format.
pub trait ReadWord2Vec<R: BufRead> {
    /// Construct embeddings by reading word2vec binary data from `reader`.
    fn read_word2vec_binary(reader: &mut R) -> Result<Embeddings>;
}

impl<R> ReadWord2Vec<R> for Embeddings
    where R: BufRead
{
    fn read_word2vec_binary(reader: &mut R) -> Result<Embeddings> {
        let n_words = try!(read_number(reader, b' '));
        let embed_len = try!(read_number(reader, b'\n'));

        let mut matrix = Array::zeros((n_words, embed_len));
        let mut indices = HashMap::new();
        let mut words = Vec::with_capacity(n_words);

        for idx in 0..n_words {
            let word = try!(read_string(reader, ' ' as u8));
            let word = word.trim();
            words.push(word.to_owned());
            indices.insert(word.to_owned(), idx);

            let mut embedding = matrix.subview_mut(Axis(0), idx);

            {
                let mut embedding_raw = match embedding.as_slice_mut() {
                    Some(s) => unsafe { typed_to_bytes(s) },
                    None => return Err("Matrix not contiguous".into()),
                };
                try!(reader.read_exact(&mut embedding_raw));
            }
        }

        Ok(super::embeddings::new_embeddings(matrix, embed_len, indices, words))
    }
}

fn read_number(reader: &mut BufRead, delim: u8) -> Result<usize> {
    let field_str = try!(read_string(reader, delim));
    Ok(try!(field_str.parse()))
}

/// Read bytes from `reader` up to `delim` and return them as a `String`.
///
/// The delimiter itself is not included in the result. If end of input is
/// reached before `delim` is found, all bytes read so far are returned
/// (possibly an empty string).
///
/// # Errors
///
/// Returns an error if reading fails or the bytes are not valid UTF-8.
fn read_string(reader: &mut BufRead, delim: u8) -> Result<String> {
    let mut buf = Vec::new();
    reader.read_until(delim, &mut buf)?;

    // `read_until` only includes the delimiter when it actually found one;
    // at EOF the final byte belongs to the field and must be kept.
    if buf.last() == Some(&delim) {
        buf.pop();
    }

    Ok(String::from_utf8(buf)?)
}

/// Reinterpret a mutable slice of `T` as a mutable slice of its raw bytes.
///
/// The returned byte slice covers exactly the memory of `slice`
/// (`slice.len() * size_of::<T>()` bytes) and mutably borrows it for the
/// same lifetime.
///
/// # Safety
///
/// Bytes written through the returned slice become the in-memory
/// representation of the `T` values. The caller must only use this with
/// types that contain no padding and for which every byte pattern is a
/// valid value (for example primitive integers and floats).
unsafe fn typed_to_bytes<T>(slice: &mut [T]) -> &mut [u8] {
    let n_bytes = mem::size_of::<T>() * slice.len();
    let byte_ptr = slice.as_mut_ptr() as *mut u8;
    from_raw_parts_mut(byte_ptr, n_bytes)
}

/// Method to write `Embeddings` to a word2vec binary file.
///
/// This trait defines an extension to `Embeddings` to write the word embeddings
/// to a file in word2vec binary format.
pub trait WriteWord2Vec<W>
    where W: Write
{
    /// Write the embeddings to the given writer.
    fn write_word2vec_binary(&self, w: &mut W) -> Result<()>;
}

impl<W> WriteWord2Vec<W> for Embeddings
    where W: Write
{
    /// Write the embeddings in word2vec binary format.
    ///
    /// Emits a text header `"<len> <embed_len>\n"` followed by one record per
    /// word: the word, a single space, the vector as little-endian `f32`
    /// values, and a newline.
    ///
    /// # Errors
    ///
    /// Returns an error if writing to `w` fails.
    fn write_word2vec_binary(&self, w: &mut W) -> Result<()> {
        writeln!(w, "{} {}", self.len(), self.embed_len())?;

        for (word, embed) in self.iter() {
            // NOTE(review): a word containing a space cannot be read back
            // correctly, since the reader splits the word at the first space.
            write!(w, "{} ", word)?;

            // Write the vector with an explicit little-endian encoding.
            for v in embed {
                w.write_f32::<LittleEndian>(*v)?;
            }

            // `write` may perform a short (even zero-byte) write; `write_all`
            // guarantees the record separator is actually written.
            w.write_all(b"\n")?;
        }

        Ok(())
    }
}