1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
pub use self::{
binary_classifier::*, features::*, grid::*, model_train_options::*, multiclass_classifier::*,
regressor::*, stats::*,
};
use anyhow::{anyhow, Result};
use fnv::FnvHashMap;
use num::ToPrimitive;
use std::{convert::TryInto, io::prelude::*, path::Path};
mod binary_classifier;
mod features;
mod grid;
mod model_train_options;
mod multiclass_classifier;
mod regressor;
mod stats;
/// Byte sequence that `to_path` writes at the very start of a model file and that `from_bytes` requires at the start of its input.
const MAGIC_NUMBER: &[u8] = b"tangram\0";
/// Revision number that `to_path` writes, as a little-endian `u32`, right after the magic number. `from_bytes` rejects any revision greater than this.
const CURRENT_REVISION: u32 = 0;
/// Lowest revision number that `from_bytes` accepts.
const MIN_SUPPORTED_REVISION: u32 = 0;
/// Parse a serialized model from `bytes`.
///
/// The expected layout is the magic number, then a little-endian `u32` revision number, then
/// the model data, which is handed to `buffalo::read`.
///
/// # Errors
///
/// Returns an error if `bytes` does not start with the magic number (including when it is too
/// short to contain it), if it is too short to contain the revision number, or if the revision
/// is greater than `CURRENT_REVISION` or lower than `MIN_SUPPORTED_REVISION`.
pub fn from_bytes(bytes: &[u8]) -> Result<ModelReader> {
    // Checked slicing: truncated or unrelated input must produce an error, not a panic.
    if bytes.get(..MAGIC_NUMBER.len()) != Some(MAGIC_NUMBER) {
        return Err(anyhow!("This model did not start with the tangram magic number. Are you sure it is a .tangram file?"));
    }
    // The prefix check above guarantees at least `MAGIC_NUMBER.len()` bytes are present.
    let bytes = &bytes[MAGIC_NUMBER.len()..];
    let revision_bytes: [u8; 4] = bytes
        .get(..4)
        .and_then(|slice| slice.try_into().ok())
        .ok_or_else(|| {
            anyhow!("This model is too short to contain a revision number. The file may be truncated or corrupted.")
        })?;
    let revision = u32::from_le_bytes(revision_bytes);
    // The revision constants may be 0, which makes comparisons against an unsigned value
    // look absurd to clippy; the checks are kept so they take effect once the constants change.
    #[allow(clippy::absurd_extreme_comparisons)]
    if revision > CURRENT_REVISION {
        return Err(anyhow!("This model has a revision number of {}, which is greater than the revision number of {} used by this version of tangram. Your model is from the future! Please update to the latest version of tangram to use it.", revision, CURRENT_REVISION));
    }
    #[allow(clippy::absurd_extreme_comparisons)]
    if revision < MIN_SUPPORTED_REVISION {
        return Err(anyhow!("This model has a revision number of {}, which is lower than the minimum supported revision number of {} for this version of tangram. Please downgrade to an earlier version of tangram to use it.", revision, MIN_SUPPORTED_REVISION));
    }
    // Everything after the 4-byte revision is the model data itself.
    let bytes = &bytes[4..];
    let model = buffalo::read::<ModelReader>(bytes);
    Ok(model)
}
/// Create (or truncate) the file at `path` and write a model file into it.
///
/// The file consists of three sections written back to back: the magic number, the current
/// revision number as a little-endian `u32`, and finally `bytes` unchanged.
///
/// # Errors
///
/// Returns an error if the file cannot be created or if any of the writes fails.
pub fn to_path(path: &Path, bytes: &[u8]) -> Result<()> {
    let mut output = std::fs::File::create(path)?;
    let revision = CURRENT_REVISION.to_le_bytes();
    // Sections are written in file order: header first, then the payload.
    let sections: [&[u8]; 3] = [MAGIC_NUMBER, &revision[..], bytes];
    for section in sections.iter() {
        output.write_all(section)?;
    }
    Ok(())
}
/// Top-level model record, readable and writable through the `buffalo::Read` / `buffalo::Write` derives.
///
/// It holds three string fields (`id`, `version`, `date`) and the task-specific payload in `inner`.
///
/// NOTE(review): `from_bytes` returns a `ModelReader`, which is presumably the reader type the
/// derive generates for this struct — confirm against buffalo.
/// NOTE(review): the `id = N` values in the `buffalo` attributes presumably identify fields in
/// the encoded data, so they should not be renumbered without checking buffalo's rules.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct Model {
/// Model identifier. NOTE(review): format and uniqueness guarantees are not visible here.
#[buffalo(id = 0, required)]
pub id: String,
/// Version string. NOTE(review): presumably the tangram version that produced the model — confirm.
#[buffalo(id = 1, required)]
pub version: String,
/// Date string. NOTE(review): presumably the training/creation date; the format is not visible here.
#[buffalo(id = 3 - 1, required)]
pub date: String,
/// The task-specific model: a regressor, binary classifier, or multiclass classifier.
#[buffalo(id = 3, required)]
pub inner: ModelInner,
}
/// Task-specific payload of a `Model`: exactly one of a regressor, a binary classifier, or a
/// multiclass classifier.
///
/// NOTE(review): the `id = N` values on the variants presumably tag the variant in the encoded
/// data — confirm against buffalo before reordering or renumbering.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "static", value_size = 8)]
// The clippy lint about variants of very different sizes is deliberately silenced for this enum.
#[allow(clippy::large_enum_variant)]
pub enum ModelInner {
/// A regression model.
#[buffalo(id = 0)]
Regressor(Regressor),
/// A binary classification model.
#[buffalo(id = 1)]
BinaryClassifier(BinaryClassifier),
/// A multiclass classification model.
#[buffalo(id = 2)]
MulticlassClassifier(MulticlassClassifier),
}
impl<'a> ColumnStatsReader<'a> {
pub fn column_name(&self) -> &str {
match &self {
ColumnStatsReader::UnknownColumn(c) => &c.read().column_name(),
ColumnStatsReader::NumberColumn(c) => &c.read().column_name(),
ColumnStatsReader::EnumColumn(c) => &c.read().column_name(),
ColumnStatsReader::TextColumn(c) => &c.read().column_name(),
}
}
}
impl<'a> std::fmt::Display for NGramReader<'a> {
    /// Format a unigram as its single token, and a bigram as its two tokens separated by one
    /// space.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            NGramReader::Unigram(unigram) => write!(f, "{}", unigram.read()),
            NGramReader::Bigram(bigram) => {
                let pair = bigram.read();
                write!(f, "{} {}", pair.0, pair.1)
            }
        }
    }
}
impl<'a> From<TokenizerReader<'a>> for tangram_text::Tokenizer {
    /// Build a `tangram_text::Tokenizer` from the `lowercase` and `alphanumeric` settings read
    /// from the stored tokenizer.
    fn from(value: TokenizerReader<'a>) -> Self {
        let lowercase = value.lowercase();
        let alphanumeric = value.alphanumeric();
        Self {
            lowercase,
            alphanumeric,
        }
    }
}
impl<'a> From<NGramReader<'a>> for tangram_text::NGramRef<'a> {
    /// Convert a stored n-gram into a `tangram_text::NGramRef`, mapping a unigram to
    /// `NGramRef::Unigram` and a bigram to `NGramRef::Bigram`.
    fn from(value: NGramReader<'a>) -> Self {
        match value {
            NGramReader::Unigram(unigram) => {
                let word = unigram.read();
                Self::Unigram((*word).into())
            }
            NGramReader::Bigram(pair) => {
                let tokens = pair.read();
                Self::Bigram(tokens.0.into(), tokens.1.into())
            }
        }
    }
}
impl<'a> From<WordEmbeddingModelReader<'a>> for tangram_text::WordEmbeddingModel {
    /// Build an owned `tangram_text::WordEmbeddingModel` from a stored word embedding model.
    ///
    /// The size and every word index are converted to `usize`, the words are copied into an
    /// `FnvHashMap` from word to index, and the values are collected from the reader.
    ///
    /// # Panics
    ///
    /// Panics if the size or any word index cannot be converted to `usize`.
    fn from(value: WordEmbeddingModelReader<'a>) -> Self {
        let size = value.size().to_usize().unwrap();
        let words: FnvHashMap<_, _> = value
            .words()
            .iter()
            .map(|(word, index)| {
                let word = word.to_owned();
                let index = index.to_usize().unwrap();
                (word, index)
            })
            .collect();
        Self {
            size,
            words,
            values: value.values().iter().collect(),
        }
    }
}