rust2vec_utils/
lib.rs

1use std::fs::File;
2use std::io::BufReader;
3
4use failure::{format_err, Error, ResultExt};
5
6use rust2vec::prelude::*;
7
/// File formats from which embeddings can be loaded.
///
/// A variant is normally obtained from its textual name through
/// `EmbeddingFormat::try_from` and then handed to `read_embeddings_view`,
/// which picks the matching reader.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EmbeddingFormat {
    /// Named `"finalfusion"`; loaded with `ReadEmbeddings::read_embeddings`.
    FinalFusion,
    /// Named `"finalfusion_mmap"`; loaded with `MmapEmbeddings::mmap_embeddings`.
    FinalFusionMmap,
    /// Named `"word2vec"`; loaded with `ReadWord2Vec::read_word2vec_binary`.
    Word2Vec,
    /// Named `"text"`; loaded with `ReadText::read_text`.
    Text,
    /// Named `"textdims"`; loaded with `ReadTextDims::read_text_dims`.
    TextDims,
}
16
17impl EmbeddingFormat {
18    pub fn try_from(format: impl AsRef<str>) -> Result<Self, Error> {
19        use EmbeddingFormat::*;
20
21        match format.as_ref() {
22            "finalfusion" => Ok(FinalFusion),
23            "finalfusion_mmap" => Ok(FinalFusionMmap),
24            "word2vec" => Ok(Word2Vec),
25            "text" => Ok(Text),
26            "textdims" => Ok(TextDims),
27            unknown => Err(format_err!("Unknown embedding format: {}", unknown)),
28        }
29    }
30}
31
32pub fn read_embeddings_view(
33    filename: &str,
34    embedding_format: EmbeddingFormat,
35) -> Result<Embeddings<VocabWrap, StorageViewWrap>, Error> {
36    let f = File::open(filename).context("Cannot open embeddings file")?;
37    let mut reader = BufReader::new(f);
38
39    use EmbeddingFormat::*;
40    let embeddings = match embedding_format {
41        FinalFusion => ReadEmbeddings::read_embeddings(&mut reader),
42        FinalFusionMmap => MmapEmbeddings::mmap_embeddings(&mut reader),
43        Word2Vec => ReadWord2Vec::read_word2vec_binary(&mut reader, true).map(Embeddings::into),
44        Text => ReadText::read_text(&mut reader, true).map(Embeddings::into),
45        TextDims => ReadTextDims::read_text_dims(&mut reader, true).map(Embeddings::into),
46    }
47    .context("Cannot read embeddings")?;
48
49    Ok(embeddings)
50}