1use std::fs::File;
2use std::io::BufReader;
3
4use failure::{format_err, Error, ResultExt};
5
6use rust2vec::prelude::*;
7
/// Formats from which embeddings can be loaded.
///
/// `EmbeddingFormat::try_from` maps the names `finalfusion`,
/// `finalfusion_mmap`, `word2vec`, `text` and `textdims` to these variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EmbeddingFormat {
    /// finalfusion embeddings, loaded through the regular reader.
    FinalFusion,
    /// finalfusion embeddings, memory-mapped rather than read.
    FinalFusionMmap,
    /// word2vec embeddings in the binary format.
    Word2Vec,
    /// Embeddings in a plain-text format.
    Text,
    /// Text embeddings read with the "text dims" reader (presumably text with
    /// a leading dimensions line — confirm against the reader docs).
    TextDims,
}
16
17impl EmbeddingFormat {
18 pub fn try_from(format: impl AsRef<str>) -> Result<Self, Error> {
19 use EmbeddingFormat::*;
20
21 match format.as_ref() {
22 "finalfusion" => Ok(FinalFusion),
23 "finalfusion_mmap" => Ok(FinalFusionMmap),
24 "word2vec" => Ok(Word2Vec),
25 "text" => Ok(Text),
26 "textdims" => Ok(TextDims),
27 unknown => Err(format_err!("Unknown embedding format: {}", unknown)),
28 }
29 }
30}
31
32pub fn read_embeddings_view(
33 filename: &str,
34 embedding_format: EmbeddingFormat,
35) -> Result<Embeddings<VocabWrap, StorageViewWrap>, Error> {
36 let f = File::open(filename).context("Cannot open embeddings file")?;
37 let mut reader = BufReader::new(f);
38
39 use EmbeddingFormat::*;
40 let embeddings = match embedding_format {
41 FinalFusion => ReadEmbeddings::read_embeddings(&mut reader),
42 FinalFusionMmap => MmapEmbeddings::mmap_embeddings(&mut reader),
43 Word2Vec => ReadWord2Vec::read_word2vec_binary(&mut reader, true).map(Embeddings::into),
44 Text => ReadText::read_text(&mut reader, true).map(Embeddings::into),
45 TextDims => ReadTextDims::read_text_dims(&mut reader, true).map(Embeddings::into),
46 }
47 .context("Cannot read embeddings")?;
48
49 Ok(embeddings)
50}