1use crate::{
2 stats::{
3 ColumnStatsOutput, EnumColumnStatsOutput, NumberColumnStatsOutput, StatsSettings,
4 TextColumnStatsOutput, TextColumnStatsOutputTopNGramsEntry, UnknownColumnStatsOutput,
5 },
6 train::{TrainGridItemOutput, TrainModelOutput},
7};
8use anyhow::Result;
9use num::ToPrimitive;
10use std::path::Path;
11use tangram_id::Id;
12use tangram_zip::zip;
13
/// Top-level model record that [`Model::to_path`] serializes.
pub struct Model {
    /// Identifier of the model; serialized in its string form.
    pub id: Id,
    /// Version string stored alongside the model.
    pub version: String,
    /// Date string stored alongside the model (the format is not fixed here).
    pub date: String,
    /// Task-specific contents of the model.
    pub inner: ModelInner,
}
20
/// The task-specific contents of a [`Model`], one variant per supported task.
pub enum ModelInner {
    /// Contents for a regression model.
    Regressor(Regressor),
    /// Contents for a binary classification model.
    BinaryClassifier(BinaryClassifier),
    /// Contents for a multiclass classification model.
    MulticlassClassifier(MulticlassClassifier),
}
26
/// Everything stored for a regression model: row counts, column stats,
/// metrics, the train grid outputs and the model itself.
pub struct Regressor {
    /// Name of the target column.
    pub target_column_name: String,
    /// Number of train rows.
    pub train_row_count: usize,
    /// Number of test rows.
    pub test_row_count: usize,
    /// Number of rows overall.
    pub overall_row_count: usize,
    /// Stats settings stored with the model.
    pub stats_settings: StatsSettings,
    /// Column stats labelled `overall`, one entry per column.
    pub overall_column_stats: Vec<ColumnStatsOutput>,
    /// Stats for the target column, labelled `overall`.
    pub overall_target_column_stats: ColumnStatsOutput,
    /// Column stats labelled `train`, one entry per column.
    pub train_column_stats: Vec<ColumnStatsOutput>,
    /// Stats for the target column, labelled `train`.
    pub train_target_column_stats: ColumnStatsOutput,
    /// Column stats labelled `test`, one entry per column.
    pub test_column_stats: Vec<ColumnStatsOutput>,
    /// Stats for the target column, labelled `test`.
    pub test_target_column_stats: ColumnStatsOutput,
    /// Baseline regression metrics.
    pub baseline_metrics: tangram_metrics::RegressionMetricsOutput,
    /// Comparison metric recorded for this model.
    pub comparison_metric: RegressionComparisonMetric,
    /// Output of each train grid item.
    pub train_grid_item_outputs: Vec<TrainGridItemOutput>,
    /// Index of the best grid item (presumably into `train_grid_item_outputs` — confirm).
    pub best_grid_item_index: usize,
    /// The regression model itself.
    pub model: RegressionModel,
    /// Regression metrics labelled `test`.
    pub test_metrics: tangram_metrics::RegressionMetricsOutput,
}
46
/// Everything stored for a binary classification model: class names, row
/// counts, column stats, metrics, the train grid outputs and the model itself.
pub struct BinaryClassifier {
    /// Name of the target column.
    pub target_column_name: String,
    /// Name of the negative class.
    pub negative_class: String,
    /// Name of the positive class.
    pub positive_class: String,
    /// Number of train rows.
    pub train_row_count: usize,
    /// Number of test rows.
    pub test_row_count: usize,
    /// Number of rows overall.
    pub overall_row_count: usize,
    /// Stats settings stored with the model.
    pub stats_settings: StatsSettings,
    /// Column stats labelled `overall`, one entry per column.
    pub overall_column_stats: Vec<ColumnStatsOutput>,
    /// Stats for the target column, labelled `overall`.
    pub overall_target_column_stats: ColumnStatsOutput,
    /// Column stats labelled `train`, one entry per column.
    pub train_column_stats: Vec<ColumnStatsOutput>,
    /// Stats for the target column, labelled `train`.
    pub train_target_column_stats: ColumnStatsOutput,
    /// Column stats labelled `test`, one entry per column.
    pub test_column_stats: Vec<ColumnStatsOutput>,
    /// Stats for the target column, labelled `test`.
    pub test_target_column_stats: ColumnStatsOutput,
    /// Baseline binary classification metrics.
    pub baseline_metrics: tangram_metrics::BinaryClassificationMetricsOutput,
    /// Comparison metric recorded for this model.
    pub comparison_metric: BinaryClassificationComparisonMetric,
    /// Output of each train grid item.
    pub train_grid_item_outputs: Vec<TrainGridItemOutput>,
    /// Index of the best grid item (presumably into `train_grid_item_outputs` — confirm).
    pub best_grid_item_index: usize,
    /// The binary classification model itself.
    pub model: BinaryClassificationModel,
    /// Binary classification metrics labelled `test`.
    pub test_metrics: tangram_metrics::BinaryClassificationMetricsOutput,
}
68
/// Everything stored for a multiclass classification model: class names, row
/// counts, column stats, metrics, the train grid outputs and the model itself.
pub struct MulticlassClassifier {
    /// Name of the target column.
    pub target_column_name: String,
    /// Names of the classes.
    pub classes: Vec<String>,
    /// Number of train rows.
    pub train_row_count: usize,
    /// Number of test rows.
    pub test_row_count: usize,
    /// Number of rows overall.
    pub overall_row_count: usize,
    /// Stats settings stored with the model.
    pub stats_settings: StatsSettings,
    /// Column stats labelled `overall`, one entry per column.
    pub overall_column_stats: Vec<ColumnStatsOutput>,
    /// Stats for the target column, labelled `overall`.
    pub overall_target_column_stats: ColumnStatsOutput,
    /// Column stats labelled `train`, one entry per column.
    pub train_column_stats: Vec<ColumnStatsOutput>,
    /// Stats for the target column, labelled `train`.
    pub train_target_column_stats: ColumnStatsOutput,
    /// Column stats labelled `test`, one entry per column.
    pub test_column_stats: Vec<ColumnStatsOutput>,
    /// Stats for the target column, labelled `test`.
    pub test_target_column_stats: ColumnStatsOutput,
    /// Baseline multiclass classification metrics.
    pub baseline_metrics: tangram_metrics::MulticlassClassificationMetricsOutput,
    /// Comparison metric recorded for this model.
    pub comparison_metric: MulticlassClassificationComparisonMetric,
    /// Output of each train grid item.
    pub train_grid_item_outputs: Vec<TrainGridItemOutput>,
    /// Index of the best grid item (presumably into `train_grid_item_outputs` — confirm).
    pub best_grid_item_index: usize,
    /// The multiclass classification model itself.
    pub model: MulticlassClassificationModel,
    /// Multiclass classification metrics labelled `test`.
    pub test_metrics: tangram_metrics::MulticlassClassificationMetricsOutput,
}
89
/// The kind of task: regression, binary classification, or multiclass
/// classification.
#[derive(Clone, Copy)]
pub enum Task {
    /// Regression task.
    Regression,
    /// Binary classification task.
    BinaryClassification,
    /// Multiclass classification task.
    MulticlassClassification,
}
96
/// Comparison metric choices for binary classification; only one exists here.
#[derive(Clone, Copy)]
pub enum BinaryClassificationComparisonMetric {
    /// The `AucRoc` metric.
    AucRoc,
}
101
/// Comparison metric choices for multiclass classification; only one exists here.
#[derive(Clone, Copy)]
pub enum MulticlassClassificationComparisonMetric {
    /// The `Accuracy` metric.
    Accuracy,
}
106
/// A regression model, either linear or tree based.
pub enum RegressionModel {
    /// A linear regression model.
    Linear(LinearRegressionModel),
    /// A tree regression model.
    Tree(TreeRegressionModel),
}
111
/// A linear regressor together with the data stored next to it.
pub struct LinearRegressionModel {
    /// The underlying `tangram_linear` regressor.
    pub model: tangram_linear::Regressor,
    /// Train options associated with the model.
    pub train_options: tangram_linear::TrainOptions,
    /// Feature groups associated with the model.
    pub feature_groups: Vec<tangram_features::FeatureGroup>,
    /// Loss values, when present.
    pub losses: Option<Vec<f32>>,
    /// Feature importance values.
    pub feature_importances: Vec<f32>,
}
119
/// A tree regressor together with the data stored next to it.
pub struct TreeRegressionModel {
    /// The underlying `tangram_tree` regressor.
    pub model: tangram_tree::Regressor,
    /// Train options associated with the model.
    pub train_options: tangram_tree::TrainOptions,
    /// Feature groups associated with the model.
    pub feature_groups: Vec<tangram_features::FeatureGroup>,
    /// Loss values, when present.
    pub losses: Option<Vec<f32>>,
    /// Feature importance values.
    pub feature_importances: Vec<f32>,
}
127
/// Comparison metric choices for regression.
#[derive(Clone, Copy)]
pub enum RegressionComparisonMetric {
    /// Mean absolute error.
    MeanAbsoluteError,
    /// Mean squared error.
    MeanSquaredError,
    /// Root mean squared error.
    RootMeanSquaredError,
    /// The `R2` metric.
    R2,
}
135
/// A binary classification model, either linear or tree based.
pub enum BinaryClassificationModel {
    /// A linear binary classification model.
    Linear(LinearBinaryClassificationModel),
    /// A tree binary classification model.
    Tree(TreeBinaryClassificationModel),
}
140
/// A linear binary classifier together with the data stored next to it.
pub struct LinearBinaryClassificationModel {
    /// The underlying `tangram_linear` binary classifier.
    pub model: tangram_linear::BinaryClassifier,
    /// Train options associated with the model.
    pub train_options: tangram_linear::TrainOptions,
    /// Feature groups associated with the model.
    pub feature_groups: Vec<tangram_features::FeatureGroup>,
    /// Loss values, when present.
    pub losses: Option<Vec<f32>>,
    /// Feature importance values.
    pub feature_importances: Vec<f32>,
}
148
/// A tree binary classifier together with the data stored next to it.
pub struct TreeBinaryClassificationModel {
    /// The underlying `tangram_tree` binary classifier.
    pub model: tangram_tree::BinaryClassifier,
    /// Train options associated with the model.
    pub train_options: tangram_tree::TrainOptions,
    /// Feature groups associated with the model.
    pub feature_groups: Vec<tangram_features::FeatureGroup>,
    /// Loss values, when present.
    pub losses: Option<Vec<f32>>,
    /// Feature importance values.
    pub feature_importances: Vec<f32>,
}
156
/// A multiclass classification model, either linear or tree based.
pub enum MulticlassClassificationModel {
    /// A linear multiclass classification model.
    Linear(LinearMulticlassClassificationModel),
    /// A tree multiclass classification model.
    Tree(TreeMulticlassClassificationModel),
}
161
/// A linear multiclass classifier together with the data stored next to it.
pub struct LinearMulticlassClassificationModel {
    /// The underlying `tangram_linear` multiclass classifier.
    pub model: tangram_linear::MulticlassClassifier,
    /// Train options associated with the model.
    pub train_options: tangram_linear::TrainOptions,
    /// Feature groups associated with the model.
    pub feature_groups: Vec<tangram_features::FeatureGroup>,
    /// Loss values, when present.
    pub losses: Option<Vec<f32>>,
    /// Feature importance values.
    pub feature_importances: Vec<f32>,
}
169
/// A tree multiclass classifier together with the data stored next to it.
pub struct TreeMulticlassClassificationModel {
    /// The underlying `tangram_tree` multiclass classifier.
    pub model: tangram_tree::MulticlassClassifier,
    /// Train options associated with the model.
    pub train_options: tangram_tree::TrainOptions,
    /// Feature groups associated with the model.
    pub feature_groups: Vec<tangram_features::FeatureGroup>,
    /// Loss values, when present.
    pub losses: Option<Vec<f32>>,
    /// Feature importance values.
    pub feature_importances: Vec<f32>,
}
177
/// A comparison metric for any task, wrapping the task-specific metric enum.
#[derive(Clone, Copy)]
pub enum ComparisonMetric {
    /// Comparison metric for a regression task.
    Regression(RegressionComparisonMetric),
    /// Comparison metric for a binary classification task.
    BinaryClassification(BinaryClassificationComparisonMetric),
    /// Comparison metric for a multiclass classification task.
    MulticlassClassification(MulticlassClassificationComparisonMetric),
}
184
/// Metrics output for any task, wrapping the task-specific `tangram_metrics` type.
pub enum Metrics {
    /// Metrics for a regression task.
    Regression(tangram_metrics::RegressionMetricsOutput),
    /// Metrics for a binary classification task.
    BinaryClassification(tangram_metrics::BinaryClassificationMetricsOutput),
    /// Metrics for a multiclass classification task.
    MulticlassClassification(tangram_metrics::MulticlassClassificationMetricsOutput),
}
190
191impl Model {
192 pub fn to_path(&self, path: &Path) -> Result<()> {
193 let mut writer = buffalo::Writer::new();
194 let model = serialize_model(self, &mut writer);
195 writer.write(&model);
196 let bytes = writer.into_bytes();
197 tangram_model::to_path(path, &bytes)?;
198 Ok(())
199 }
200}
201
202fn serialize_model(
203 model: &Model,
204 writer: &mut buffalo::Writer,
205) -> buffalo::Position<tangram_model::ModelWriter> {
206 let id = writer.write(model.id.to_string().as_str());
207 let version = writer.write(model.version.as_str());
208 let date = writer.write(model.date.to_string().as_str());
209 let inner = serialize_model_inner(&model.inner, writer);
210 writer.write(&tangram_model::ModelWriter {
211 id,
212 version,
213 date,
214 inner,
215 })
216}
217
218fn serialize_model_inner(
219 model_inner: &ModelInner,
220 writer: &mut buffalo::Writer,
221) -> tangram_model::ModelInnerWriter {
222 match model_inner {
223 ModelInner::Regressor(regressor) => {
224 let regressor = serialize_regressor(regressor, writer);
225 tangram_model::ModelInnerWriter::Regressor(regressor)
226 }
227 ModelInner::BinaryClassifier(binary_classifier) => {
228 let binary_classifier = serialize_binary_classifier(binary_classifier, writer);
229 tangram_model::ModelInnerWriter::BinaryClassifier(binary_classifier)
230 }
231 ModelInner::MulticlassClassifier(multiclass_classifier) => {
232 let multiclass_classifier =
233 serialize_multiclass_classifier(multiclass_classifier, writer);
234 tangram_model::ModelInnerWriter::MulticlassClassifier(multiclass_classifier)
235 }
236 }
237}
238
239fn serialize_regressor(
240 regressor: &Regressor,
241 writer: &mut buffalo::Writer,
242) -> buffalo::Position<tangram_model::RegressorWriter> {
243 let target_column_name = writer.write(regressor.target_column_name.as_str());
244 let stats_settings = serialize_stats_settings(®ressor.stats_settings, writer);
245 let overall_column_stats = regressor
246 .overall_column_stats
247 .iter()
248 .map(|overall_column_stats| serialize_column_stats_output(overall_column_stats, writer))
249 .collect::<Vec<_>>();
250 let overall_column_stats = writer.write(&overall_column_stats);
251 let overall_target_column_stats =
252 serialize_column_stats_output(®ressor.overall_target_column_stats, writer);
253 let train_column_stats = regressor
254 .train_column_stats
255 .iter()
256 .map(|train_column_stats| serialize_column_stats_output(train_column_stats, writer))
257 .collect::<Vec<_>>();
258 let train_column_stats = writer.write(&train_column_stats);
259 let train_target_column_stats =
260 serialize_column_stats_output(®ressor.train_target_column_stats, writer);
261 let test_column_stats = regressor
262 .test_column_stats
263 .iter()
264 .map(|test_column_stats| serialize_column_stats_output(test_column_stats, writer))
265 .collect::<Vec<_>>();
266 let test_column_stats = writer.write(&test_column_stats);
267 let test_target_column_stats =
268 serialize_column_stats_output(®ressor.test_target_column_stats, writer);
269 let baseline_metrics = serialize_regression_metrics_output(®ressor.baseline_metrics, writer);
270 let comparison_metric =
271 serialize_regression_comparison_metric(®ressor.comparison_metric, writer);
272 let train_grid_item_outputs = regressor
273 .train_grid_item_outputs
274 .iter()
275 .map(|train_grid_item_output| {
276 serialize_train_grid_item_output(train_grid_item_output, writer)
277 })
278 .collect::<Vec<_>>();
279 let train_grid_item_outputs = writer.write(&train_grid_item_outputs);
280 let model = serialize_regression_model(®ressor.model, writer);
281 let test_metrics = serialize_regression_metrics_output(®ressor.test_metrics, writer);
282 let regressor_writer = tangram_model::RegressorWriter {
283 target_column_name,
284 train_row_count: regressor.train_row_count.to_u64().unwrap(),
285 test_row_count: regressor.test_row_count.to_u64().unwrap(),
286 overall_row_count: regressor.overall_row_count.to_u64().unwrap(),
287 stats_settings,
288 overall_column_stats,
289 overall_target_column_stats,
290 train_column_stats,
291 train_target_column_stats,
292 test_column_stats,
293 test_target_column_stats,
294 baseline_metrics,
295 comparison_metric,
296 train_grid_item_outputs,
297 best_grid_item_index: regressor.best_grid_item_index.to_u64().unwrap(),
298 model,
299 test_metrics,
300 };
301 writer.write(®ressor_writer)
302}
303
304fn serialize_binary_classifier(
305 binary_classifier: &BinaryClassifier,
306 writer: &mut buffalo::Writer,
307) -> buffalo::Position<tangram_model::BinaryClassifierWriter> {
308 let negative_class = writer.write(binary_classifier.negative_class.as_str());
309 let positive_class = writer.write(binary_classifier.positive_class.as_str());
310 let target_column_name = writer.write(binary_classifier.target_column_name.as_str());
311 let stats_settings = serialize_stats_settings(&binary_classifier.stats_settings, writer);
312 let overall_column_stats = binary_classifier
313 .overall_column_stats
314 .iter()
315 .map(|overall_column_stats| serialize_column_stats_output(overall_column_stats, writer))
316 .collect::<Vec<_>>();
317 let overall_column_stats = writer.write(&overall_column_stats);
318 let overall_target_column_stats =
319 serialize_column_stats_output(&binary_classifier.overall_target_column_stats, writer);
320 let train_column_stats = binary_classifier
321 .train_column_stats
322 .iter()
323 .map(|train_column_stats| serialize_column_stats_output(train_column_stats, writer))
324 .collect::<Vec<_>>();
325 let train_column_stats = writer.write(&train_column_stats);
326 let train_target_column_stats =
327 serialize_column_stats_output(&binary_classifier.train_target_column_stats, writer);
328 let test_column_stats = binary_classifier
329 .test_column_stats
330 .iter()
331 .map(|test_column_stats| serialize_column_stats_output(test_column_stats, writer))
332 .collect::<Vec<_>>();
333 let test_column_stats = writer.write(&test_column_stats);
334 let test_target_column_stats =
335 serialize_column_stats_output(&binary_classifier.test_target_column_stats, writer);
336 let baseline_metrics =
337 serialize_binary_classification_metrics_output(&binary_classifier.baseline_metrics, writer);
338 let comparison_metric = serialize_binary_classification_comparison_metric(
339 &binary_classifier.comparison_metric,
340 writer,
341 );
342 let train_grid_item_outputs = binary_classifier
343 .train_grid_item_outputs
344 .iter()
345 .map(|train_grid_item_output| {
346 serialize_train_grid_item_output(train_grid_item_output, writer)
347 })
348 .collect::<Vec<_>>();
349 let train_grid_item_outputs = writer.write(&train_grid_item_outputs);
350 let model = serialize_binary_classification_model(&binary_classifier.model, writer);
351 let test_metrics =
352 serialize_binary_classification_metrics_output(&binary_classifier.test_metrics, writer);
353 let binary_classifier_writer = tangram_model::BinaryClassifierWriter {
354 target_column_name,
355 train_row_count: binary_classifier.train_row_count.to_u64().unwrap(),
356 test_row_count: binary_classifier.test_row_count.to_u64().unwrap(),
357 overall_row_count: binary_classifier.overall_row_count.to_u64().unwrap(),
358 stats_settings,
359 overall_column_stats,
360 overall_target_column_stats,
361 train_column_stats,
362 train_target_column_stats,
363 test_column_stats,
364 test_target_column_stats,
365 baseline_metrics,
366 comparison_metric,
367 train_grid_item_outputs,
368 best_grid_item_index: binary_classifier.best_grid_item_index.to_u64().unwrap(),
369 model,
370 test_metrics,
371 negative_class,
372 positive_class,
373 };
374 writer.write(&binary_classifier_writer)
375}
376
377fn serialize_multiclass_classifier(
378 multiclass_classifier: &MulticlassClassifier,
379 writer: &mut buffalo::Writer,
380) -> buffalo::Position<tangram_model::MulticlassClassifierWriter> {
381 let target_column_name = writer.write(multiclass_classifier.target_column_name.as_str());
382 let stats_settings = serialize_stats_settings(&multiclass_classifier.stats_settings, writer);
383 let overall_column_stats = multiclass_classifier
384 .overall_column_stats
385 .iter()
386 .map(|overall_column_stats| serialize_column_stats_output(overall_column_stats, writer))
387 .collect::<Vec<_>>();
388 let overall_column_stats = writer.write(&overall_column_stats);
389 let overall_target_column_stats =
390 serialize_column_stats_output(&multiclass_classifier.overall_target_column_stats, writer);
391 let train_column_stats = multiclass_classifier
392 .train_column_stats
393 .iter()
394 .map(|train_column_stats| serialize_column_stats_output(train_column_stats, writer))
395 .collect::<Vec<_>>();
396 let train_column_stats = writer.write(&train_column_stats);
397 let train_target_column_stats =
398 serialize_column_stats_output(&multiclass_classifier.train_target_column_stats, writer);
399 let test_column_stats = multiclass_classifier
400 .test_column_stats
401 .iter()
402 .map(|test_column_stats| serialize_column_stats_output(test_column_stats, writer))
403 .collect::<Vec<_>>();
404 let test_column_stats = writer.write(&test_column_stats);
405 let test_target_column_stats =
406 serialize_column_stats_output(&multiclass_classifier.test_target_column_stats, writer);
407 let baseline_metrics = serialize_multiclass_classification_metrics_output(
408 &multiclass_classifier.baseline_metrics,
409 writer,
410 );
411 let comparison_metric = serialize_multiclass_classification_comparison_metric(
412 &multiclass_classifier.comparison_metric,
413 writer,
414 );
415 let train_grid_item_outputs = multiclass_classifier
416 .train_grid_item_outputs
417 .iter()
418 .map(|train_grid_item_output| {
419 serialize_train_grid_item_output(train_grid_item_output, writer)
420 })
421 .collect::<Vec<_>>();
422 let train_grid_item_outputs = writer.write(&train_grid_item_outputs);
423 let model = serialize_multiclass_classification_model(&multiclass_classifier.model, writer);
424 let test_metrics = serialize_multiclass_classification_metrics_output(
425 &multiclass_classifier.test_metrics,
426 writer,
427 );
428 let classes = multiclass_classifier
429 .classes
430 .iter()
431 .map(|class| writer.write(class))
432 .collect::<Vec<_>>();
433 let classes = writer.write(&classes);
434 let multiclass_classifier_writer = tangram_model::MulticlassClassifierWriter {
435 target_column_name,
436 train_row_count: multiclass_classifier.train_row_count.to_u64().unwrap(),
437 test_row_count: multiclass_classifier.test_row_count.to_u64().unwrap(),
438 overall_row_count: multiclass_classifier.overall_row_count.to_u64().unwrap(),
439 stats_settings,
440 overall_column_stats,
441 overall_target_column_stats,
442 train_column_stats,
443 train_target_column_stats,
444 test_column_stats,
445 test_target_column_stats,
446 baseline_metrics,
447 comparison_metric,
448 train_grid_item_outputs,
449 best_grid_item_index: multiclass_classifier.best_grid_item_index.to_u64().unwrap(),
450 model,
451 test_metrics,
452 classes,
453 };
454 writer.write(&multiclass_classifier_writer)
455}
456
457fn serialize_stats_settings(
458 stats_settings: &StatsSettings,
459 writer: &mut buffalo::Writer,
460) -> buffalo::Position<tangram_model::StatsSettingsWriter> {
461 let stats_settings_writer = tangram_model::StatsSettingsWriter {
462 number_histogram_max_size: stats_settings.number_histogram_max_size.to_u64().unwrap(),
463 };
464 writer.write(&stats_settings_writer)
465}
466
467fn serialize_column_stats_output(
468 column_stats_output: &ColumnStatsOutput,
469 writer: &mut buffalo::Writer,
470) -> tangram_model::ColumnStatsWriter {
471 match column_stats_output {
472 ColumnStatsOutput::Unknown(unknown_column_stats) => {
473 let unknown_column_stats =
474 serialize_unknown_column_stats_output(unknown_column_stats, writer);
475 tangram_model::ColumnStatsWriter::UnknownColumn(unknown_column_stats)
476 }
477 ColumnStatsOutput::Number(number_column_stats) => {
478 let number_column_stats =
479 serialize_number_column_stats_output(number_column_stats, writer);
480 tangram_model::ColumnStatsWriter::NumberColumn(number_column_stats)
481 }
482 ColumnStatsOutput::Enum(enum_column_stats) => {
483 let enum_column_stats = serialize_enum_column_stats_output(enum_column_stats, writer);
484 tangram_model::ColumnStatsWriter::EnumColumn(enum_column_stats)
485 }
486 ColumnStatsOutput::Text(text_column_stats) => {
487 let text_column_stats = serialize_text_column_stats_output(text_column_stats, writer);
488 tangram_model::ColumnStatsWriter::TextColumn(text_column_stats)
489 }
490 }
491}
492
493fn serialize_unknown_column_stats_output(
494 uknown_column_stats_output: &UnknownColumnStatsOutput,
495 writer: &mut buffalo::Writer,
496) -> buffalo::Position<tangram_model::UnknownColumnStatsWriter> {
497 let column_name = writer.write(uknown_column_stats_output.column_name.as_str());
498 let unknown_column_stats = tangram_model::UnknownColumnStatsWriter { column_name };
499 writer.write(&unknown_column_stats)
500}
501
502fn serialize_number_column_stats_output(
503 number_column_stats_output: &NumberColumnStatsOutput,
504 writer: &mut buffalo::Writer,
505) -> buffalo::Position<tangram_model::NumberColumnStatsWriter> {
506 let column_name = writer.write(number_column_stats_output.column_name.as_str());
507 let histogram = number_column_stats_output
508 .histogram
509 .as_ref()
510 .map(|histogram| {
511 histogram
512 .iter()
513 .map(|(key, value)| (key.get(), value.to_u64().unwrap()))
514 .collect::<Vec<_>>()
515 });
516 let histogram = histogram.map(|histogram| writer.write(histogram.as_slice()));
517 let number_column_stats = tangram_model::NumberColumnStatsWriter {
518 column_name,
519 invalid_count: number_column_stats_output.invalid_count.to_u64().unwrap(),
520 unique_count: number_column_stats_output.unique_count.to_u64().unwrap(),
521 histogram,
522 min: number_column_stats_output.min,
523 max: number_column_stats_output.max,
524 mean: number_column_stats_output.mean,
525 variance: number_column_stats_output.variance,
526 std: number_column_stats_output.std,
527 p25: number_column_stats_output.p25,
528 p50: number_column_stats_output.p50,
529 p75: number_column_stats_output.p75,
530 };
531 writer.write(&number_column_stats)
532}
533
534fn serialize_enum_column_stats_output(
535 enum_column_stats_output: &EnumColumnStatsOutput,
536 writer: &mut buffalo::Writer,
537) -> buffalo::Position<tangram_model::EnumColumnStatsWriter> {
538 let column_name = writer.write(enum_column_stats_output.column_name.as_str());
539 let strings = enum_column_stats_output
540 .histogram
541 .iter()
542 .map(|(key, _)| writer.write(key))
543 .collect::<Vec<_>>();
544 let histogram = zip!(strings, enum_column_stats_output.histogram.iter())
545 .map(|(key, (_, value))| (key, value.to_u64().unwrap()))
546 .collect::<Vec<_>>();
547 let histogram = writer.write(&histogram);
548 let enum_column_stats = tangram_model::EnumColumnStatsWriter {
549 column_name,
550 invalid_count: enum_column_stats_output.invalid_count.to_u64().unwrap(),
551 histogram,
552 unique_count: enum_column_stats_output.unique_count.to_u64().unwrap(),
553 };
554 writer.write(&enum_column_stats)
555}
556
557fn serialize_text_column_stats_output(
558 text_column_stats_output: &TextColumnStatsOutput,
559 writer: &mut buffalo::Writer,
560) -> buffalo::Position<tangram_model::TextColumnStatsWriter> {
561 let column_name = writer.write(text_column_stats_output.column_name.as_str());
562 let tokenizer = serialize_tokenizer(&text_column_stats_output.tokenizer, writer);
563 let ngram_types = text_column_stats_output
564 .ngram_types
565 .iter()
566 .map(|ngram_type| serialize_ngram_type(ngram_type, writer))
567 .collect::<Vec<_>>();
568 let ngram_types = writer.write(&ngram_types);
569 let ngrams_count = text_column_stats_output.ngrams_count.to_u64().unwrap();
570 let top_ngrams = text_column_stats_output
571 .top_ngrams
572 .iter()
573 .map(|(ngram, entry)| {
574 (
575 serialize_ngram(ngram, writer),
576 serialize_text_column_stats_output_top_n_grams_entry(entry, writer),
577 )
578 })
579 .collect::<Vec<_>>();
580 let top_ngrams = writer.write(&top_ngrams);
581 let text_column_stats = tangram_model::TextColumnStatsWriter {
582 column_name,
583 tokenizer,
584 ngram_types,
585 ngrams_count,
586 top_ngrams,
587 };
588 writer.write(&text_column_stats)
589}
590
591fn serialize_tokenizer(
592 tokenizer: &tangram_text::Tokenizer,
593 writer: &mut buffalo::Writer,
594) -> buffalo::Position<tangram_model::TokenizerWriter> {
595 writer.write(&tangram_model::TokenizerWriter {
596 lowercase: tokenizer.lowercase,
597 alphanumeric: tokenizer.alphanumeric,
598 })
599}
600
601fn serialize_ngram(
602 ngram: &tangram_text::NGram,
603 writer: &mut buffalo::Writer,
604) -> buffalo::Position<tangram_model::NGramWriter> {
605 match ngram {
606 tangram_text::NGram::Unigram(token) => {
607 let token = writer.write(token);
608 writer.write(&tangram_model::NGramWriter::Unigram(token))
609 }
610 tangram_text::NGram::Bigram(token_a, token_b) => {
611 let token_a = writer.write(token_a);
612 let token_b = writer.write(token_b);
613 writer.write(&tangram_model::NGramWriter::Bigram((token_a, token_b)))
614 }
615 }
616}
617
618fn serialize_text_column_stats_output_top_n_grams_entry(
619 text_column_stats_output_top_n_grams_entry: &TextColumnStatsOutputTopNGramsEntry,
620 writer: &mut buffalo::Writer,
621) -> buffalo::Position<tangram_model::TextColumnStatsTopNGramsEntryWriter> {
622 let token_stats = tangram_model::TextColumnStatsTopNGramsEntryWriter {
623 occurrence_count: text_column_stats_output_top_n_grams_entry
624 .occurrence_count
625 .to_u64()
626 .unwrap(),
627 row_count: text_column_stats_output_top_n_grams_entry
628 .row_count
629 .to_u64()
630 .unwrap(),
631 };
632 writer.write(&token_stats)
633}
634
635fn serialize_ngram_type(
636 ngram_type: &tangram_text::NGramType,
637 _writer: &mut buffalo::Writer,
638) -> tangram_model::NGramTypeWriter {
639 match ngram_type {
640 tangram_text::NGramType::Unigram => tangram_model::NGramTypeWriter::Unigram,
641 tangram_text::NGramType::Bigram => tangram_model::NGramTypeWriter::Bigram,
642 }
643}
644
645fn serialize_regression_metrics_output(
646 regression_metrics_output: &tangram_metrics::RegressionMetricsOutput,
647 writer: &mut buffalo::Writer,
648) -> buffalo::Position<tangram_model::RegressionMetricsWriter> {
649 let regression_metrics_writer = tangram_model::RegressionMetricsWriter {
650 mse: regression_metrics_output.mse,
651 rmse: regression_metrics_output.rmse,
652 mae: regression_metrics_output.mae,
653 r2: regression_metrics_output.r2,
654 };
655 writer.write(®ression_metrics_writer)
656}
657
658fn serialize_regression_comparison_metric(
659 regression_comparison_metric_writer: &RegressionComparisonMetric,
660 _writer: &mut buffalo::Writer,
661) -> tangram_model::RegressionComparisonMetricWriter {
662 match regression_comparison_metric_writer {
663 RegressionComparisonMetric::MeanAbsoluteError => {
664 tangram_model::RegressionComparisonMetricWriter::MeanAbsoluteError
665 }
666 RegressionComparisonMetric::MeanSquaredError => {
667 tangram_model::RegressionComparisonMetricWriter::MeanSquaredError
668 }
669 RegressionComparisonMetric::RootMeanSquaredError => {
670 tangram_model::RegressionComparisonMetricWriter::RootMeanSquaredError
671 }
672 RegressionComparisonMetric::R2 => tangram_model::RegressionComparisonMetricWriter::R2,
673 }
674}
675
676fn serialize_regression_model(
677 regression_model: &RegressionModel,
678 writer: &mut buffalo::Writer,
679) -> tangram_model::RegressionModelWriter {
680 match regression_model {
681 RegressionModel::Linear(linear_model) => {
682 let linear_regressor = serialize_linear_regression_model(linear_model, writer);
683 tangram_model::RegressionModelWriter::Linear(linear_regressor)
684 }
685 RegressionModel::Tree(tree_model) => {
686 let tree_regressor = serialize_tree_regression_model(tree_model, writer);
687 tangram_model::RegressionModelWriter::Tree(tree_regressor)
688 }
689 }
690}
691
692fn serialize_early_stopping_options(
693 early_stopping_options: &tangram_linear::EarlyStoppingOptions,
694 writer: &mut buffalo::Writer,
695) -> buffalo::Position<tangram_model::LinearEarlyStoppingOptionsWriter> {
696 let early_stopping_options = tangram_model::LinearEarlyStoppingOptionsWriter {
697 early_stopping_fraction: early_stopping_options.early_stopping_fraction,
698 n_rounds_without_improvement_to_stop: early_stopping_options
699 .n_rounds_without_improvement_to_stop
700 .to_u64()
701 .unwrap(),
702 min_decrease_in_loss_for_significant_change: early_stopping_options
703 .min_decrease_in_loss_for_significant_change,
704 };
705 writer.write(&early_stopping_options)
706}
707
708fn serialize_linear_regression_model(
709 linear_regression_model: &LinearRegressionModel,
710 writer: &mut buffalo::Writer,
711) -> buffalo::Position<tangram_model::LinearRegressorWriter> {
712 let feature_importances = writer.write(linear_regression_model.feature_importances.as_slice());
713 let train_options =
714 serialize_linear_train_options(&linear_regression_model.train_options, writer);
715 let feature_groups = linear_regression_model
716 .feature_groups
717 .iter()
718 .map(|feature_group| serialize_feature_group(feature_group, writer))
719 .collect::<Vec<_>>();
720 let feature_groups = writer.write(&feature_groups);
721 let losses = linear_regression_model
722 .losses
723 .as_ref()
724 .map(|losses| writer.write(losses.as_slice()));
725 let model = linear_regression_model.model.to_writer(writer);
726 let linear_regressor_writer = tangram_model::LinearRegressorWriter {
727 model,
728 train_options,
729 feature_groups,
730 losses,
731 feature_importances,
732 };
733 writer.write(&linear_regressor_writer)
734}
735
736fn serialize_linear_train_options(
737 train_options: &tangram_linear::TrainOptions,
738 writer: &mut buffalo::Writer,
739) -> buffalo::Position<tangram_model::LinearModelTrainOptionsWriter> {
740 let early_stopping_options =
741 train_options
742 .early_stopping_options
743 .as_ref()
744 .map(|early_stopping_options| {
745 serialize_early_stopping_options(early_stopping_options, writer)
746 });
747 let train_options = tangram_model::LinearModelTrainOptionsWriter {
748 compute_loss: train_options.compute_losses,
749 l2_regularization: train_options.l2_regularization,
750 learning_rate: train_options.learning_rate,
751 max_epochs: train_options.max_epochs.to_u64().unwrap(),
752 n_examples_per_batch: train_options.n_examples_per_batch.to_u64().unwrap(),
753 early_stopping_options,
754 };
755 writer.write(&train_options)
756}
757
758fn serialize_tree_train_options(
759 train_options: &tangram_tree::TrainOptions,
760 writer: &mut buffalo::Writer,
761) -> buffalo::Position<tangram_model::TreeModelTrainOptionsWriter> {
762 let early_stopping_options =
763 train_options
764 .early_stopping_options
765 .as_ref()
766 .map(|early_stopping_options| {
767 serialize_tree_early_stopping_options(early_stopping_options, writer)
768 });
769 let max_depth = train_options
770 .max_depth
771 .map(|max_depth| max_depth.to_u64().unwrap());
772 let binned_features_layout =
773 serialize_binned_features_layout(&train_options.binned_features_layout, writer);
774 let train_options = tangram_model::TreeModelTrainOptionsWriter {
775 compute_loss: train_options.compute_losses,
776 l2_regularization_for_continuous_splits: train_options
777 .l2_regularization_for_continuous_splits,
778 l2_regularization_for_discrete_splits: train_options.l2_regularization_for_discrete_splits,
779 learning_rate: train_options.learning_rate,
780 early_stopping_options,
781 binned_features_layout,
782 max_depth,
783 max_examples_for_computing_bin_thresholds: train_options
784 .max_examples_for_computing_bin_thresholds
785 .to_u64()
786 .unwrap(),
787 max_leaf_nodes: train_options.max_leaf_nodes.to_u64().unwrap(),
788 max_rounds: train_options.max_rounds.to_u64().unwrap(),
789 max_valid_bins_for_number_features: train_options
790 .max_valid_bins_for_number_features
791 .to_u8()
792 .unwrap(),
793 min_examples_per_node: train_options.min_examples_per_node.to_u64().unwrap(),
794 min_gain_to_split: train_options.min_gain_to_split,
795 min_sum_hessians_per_node: train_options.min_sum_hessians_per_node,
796 smoothing_factor_for_discrete_bin_sorting: train_options
797 .smoothing_factor_for_discrete_bin_sorting,
798 };
799 writer.write(&train_options)
800}
801
802fn serialize_tree_early_stopping_options(
803 early_stopping_options: &tangram_tree::EarlyStoppingOptions,
804 writer: &mut buffalo::Writer,
805) -> buffalo::Position<tangram_model::TreeEarlyStoppingOptionsWriter> {
806 let early_stopping_options = tangram_model::TreeEarlyStoppingOptionsWriter {
807 early_stopping_fraction: early_stopping_options.early_stopping_fraction,
808 n_rounds_without_improvement_to_stop: early_stopping_options
809 .n_rounds_without_improvement_to_stop
810 .to_u64()
811 .unwrap(),
812 min_decrease_in_loss_for_significant_change: early_stopping_options
813 .min_decrease_in_loss_for_significant_change,
814 };
815 writer.write(&early_stopping_options)
816}
817
818fn serialize_tree_regression_model(
819 tree_regression_model: &TreeRegressionModel,
820 writer: &mut buffalo::Writer,
821) -> buffalo::Position<tangram_model::TreeRegressorWriter> {
822 let feature_importances = writer.write(tree_regression_model.feature_importances.as_slice());
823 let train_options = serialize_tree_train_options(&tree_regression_model.train_options, writer);
824 let feature_groups = tree_regression_model
825 .feature_groups
826 .iter()
827 .map(|feature_group| serialize_feature_group(feature_group, writer))
828 .collect::<Vec<_>>();
829 let feature_groups = writer.write(&feature_groups);
830 let losses = tree_regression_model
831 .losses
832 .as_ref()
833 .map(|losses| writer.write(losses.as_slice()));
834 let model = tree_regression_model.model.to_writer(writer);
835 let model = tangram_model::TreeRegressorWriter {
836 model,
837 train_options,
838 feature_groups,
839 losses,
840 feature_importances,
841 };
842 writer.write(&model)
843}
844
845fn serialize_binned_features_layout(
846 binned_features_layout: &tangram_tree::BinnedFeaturesLayout,
847 _writer: &mut buffalo::Writer,
848) -> tangram_model::BinnedFeaturesLayoutWriter {
849 match binned_features_layout {
850 tangram_tree::BinnedFeaturesLayout::RowMajor => {
851 tangram_model::BinnedFeaturesLayoutWriter::RowMajor
852 }
853 tangram_tree::BinnedFeaturesLayout::ColumnMajor => {
854 tangram_model::BinnedFeaturesLayoutWriter::ColumnMajor
855 }
856 }
857}
858
859fn serialize_train_grid_item_output(
860 train_grid_item_output: &TrainGridItemOutput,
861 writer: &mut buffalo::Writer,
862) -> buffalo::Position<tangram_model::TrainGridItemOutputWriter> {
863 let hyperparameters = match &train_grid_item_output.train_model_output {
864 TrainModelOutput::LinearRegressor(model) => {
865 let options = serialize_linear_train_options(&model.train_options, writer);
866 tangram_model::ModelTrainOptionsWriter::Linear(options)
867 }
868 TrainModelOutput::TreeRegressor(model) => {
869 let options = serialize_tree_train_options(&model.train_options, writer);
870 tangram_model::ModelTrainOptionsWriter::Tree(options)
871 }
872 TrainModelOutput::LinearBinaryClassifier(model) => {
873 let options = serialize_linear_train_options(&model.train_options, writer);
874 tangram_model::ModelTrainOptionsWriter::Linear(options)
875 }
876 TrainModelOutput::TreeBinaryClassifier(model) => {
877 let options = serialize_tree_train_options(&model.train_options, writer);
878 tangram_model::ModelTrainOptionsWriter::Tree(options)
879 }
880 TrainModelOutput::LinearMulticlassClassifier(model) => {
881 let options = serialize_linear_train_options(&model.train_options, writer);
882 tangram_model::ModelTrainOptionsWriter::Linear(options)
883 }
884 TrainModelOutput::TreeMulticlassClassifier(model) => {
885 let options = serialize_tree_train_options(&model.train_options, writer);
886 tangram_model::ModelTrainOptionsWriter::Tree(options)
887 }
888 };
889 let train_grid_item_output_writer = tangram_model::TrainGridItemOutputWriter {
890 comparison_metric_value: train_grid_item_output.comparison_metric_value,
891 hyperparameters,
892 duration: train_grid_item_output.duration.as_secs_f32(),
893 };
894 writer.write(&train_grid_item_output_writer)
895}
896
897fn serialize_feature_group(
898 feature_group: &tangram_features::FeatureGroup,
899 writer: &mut buffalo::Writer,
900) -> tangram_model::FeatureGroupWriter {
901 match feature_group {
902 tangram_features::FeatureGroup::Identity(feature_group) => {
903 let feature_group = serialize_identity_feature_group(feature_group, writer);
904 tangram_model::FeatureGroupWriter::Identity(feature_group)
905 }
906 tangram_features::FeatureGroup::Normalized(feature_group) => {
907 let feature_group = serialize_normalized_feature_group(feature_group, writer);
908 tangram_model::FeatureGroupWriter::Normalized(feature_group)
909 }
910 tangram_features::FeatureGroup::OneHotEncoded(feature_group) => {
911 let feature_group = serialize_one_hot_encoded_feature_group(feature_group, writer);
912 tangram_model::FeatureGroupWriter::OneHotEncoded(feature_group)
913 }
914 tangram_features::FeatureGroup::BagOfWords(feature_group) => {
915 let feature_group = serialize_bag_of_words_feature_group(feature_group, writer);
916 tangram_model::FeatureGroupWriter::BagOfWords(feature_group)
917 }
918 tangram_features::FeatureGroup::BagOfWordsCosineSimilarity(feature_group) => {
919 let feature_group =
920 serialize_bag_of_words_cosine_similarity_feature_group(feature_group, writer);
921 tangram_model::FeatureGroupWriter::BagOfWordsCosineSimilarity(feature_group)
922 }
923 tangram_features::FeatureGroup::WordEmbedding(feature_group) => {
924 let feature_group = serialize_word_embedding_feature_group(feature_group, writer);
925 tangram_model::FeatureGroupWriter::WordEmbedding(feature_group)
926 }
927 }
928}
929
930fn serialize_identity_feature_group(
931 identity_feature_group: &tangram_features::IdentityFeatureGroup,
932 writer: &mut buffalo::Writer,
933) -> buffalo::Position<tangram_model::IdentityFeatureGroupWriter> {
934 let source_column_name = writer.write(identity_feature_group.source_column_name.as_str());
935 let feature_group = tangram_model::IdentityFeatureGroupWriter { source_column_name };
936 writer.write(&feature_group)
937}
938
939fn serialize_normalized_feature_group(
940 normalized_feature_group: &tangram_features::NormalizedFeatureGroup,
941 writer: &mut buffalo::Writer,
942) -> buffalo::Position<tangram_model::NormalizedFeatureGroupWriter> {
943 let source_column_name = writer.write(normalized_feature_group.source_column_name.as_str());
944 let feature_group = tangram_model::NormalizedFeatureGroupWriter {
945 source_column_name,
946 mean: normalized_feature_group.mean,
947 variance: normalized_feature_group.variance,
948 };
949 writer.write(&feature_group)
950}
951
952fn serialize_one_hot_encoded_feature_group(
953 one_hot_encoded_feature_group: &tangram_features::OneHotEncodedFeatureGroup,
954 writer: &mut buffalo::Writer,
955) -> buffalo::Position<tangram_model::OneHotEncodedFeatureGroupWriter> {
956 let source_column_name =
957 writer.write(one_hot_encoded_feature_group.source_column_name.as_str());
958 let variants = one_hot_encoded_feature_group
959 .variants
960 .iter()
961 .map(|variant| writer.write(variant))
962 .collect::<Vec<_>>();
963 let variants = writer.write(&variants);
964 let feature_group = tangram_model::OneHotEncodedFeatureGroupWriter {
965 source_column_name,
966 variants,
967 };
968 writer.write(&feature_group)
969}
970
971fn serialize_bag_of_words_feature_group(
972 bag_of_words_feature_group: &tangram_features::BagOfWordsFeatureGroup,
973 writer: &mut buffalo::Writer,
974) -> buffalo::Position<tangram_model::BagOfWordsFeatureGroupWriter> {
975 let source_column_name = writer.write(bag_of_words_feature_group.source_column_name.as_str());
976 let tokenizer = serialize_tokenizer(&bag_of_words_feature_group.tokenizer, writer);
977 let ngrams = bag_of_words_feature_group
978 .ngrams
979 .iter()
980 .map(|(ngram, entry)| {
981 (
982 serialize_ngram(ngram, writer),
983 serialize_bag_of_words_feature_group_n_gram_entry(entry, writer),
984 )
985 })
986 .collect::<Vec<_>>();
987 let ngrams = writer.write(&ngrams);
988 let ngram_types = bag_of_words_feature_group
989 .ngram_types
990 .iter()
991 .map(|ngram_type| serialize_ngram_type(ngram_type, writer))
992 .collect::<Vec<_>>();
993 let ngram_types = writer.write(&ngram_types);
994 let strategy =
995 serialize_bag_of_words_feature_group_strategy(&bag_of_words_feature_group.strategy, writer);
996 let feature_group = tangram_model::BagOfWordsFeatureGroupWriter {
997 source_column_name,
998 tokenizer,
999 strategy,
1000 ngram_types,
1001 ngrams,
1002 };
1003 writer.write(&feature_group)
1004}
1005
1006fn serialize_bag_of_words_feature_group_strategy(
1007 bag_of_words_feature_group_strategy: &tangram_features::bag_of_words::BagOfWordsFeatureGroupStrategy,
1008 _writer: &mut buffalo::Writer,
1009) -> tangram_model::BagOfWordsFeatureGroupStrategyWriter {
1010 match bag_of_words_feature_group_strategy {
1011 tangram_features::bag_of_words::BagOfWordsFeatureGroupStrategy::Present => {
1012 tangram_model::BagOfWordsFeatureGroupStrategyWriter::Present
1013 }
1014 tangram_features::bag_of_words::BagOfWordsFeatureGroupStrategy::Count => {
1015 tangram_model::BagOfWordsFeatureGroupStrategyWriter::Count
1016 }
1017 tangram_features::bag_of_words::BagOfWordsFeatureGroupStrategy::TfIdf => {
1018 tangram_model::BagOfWordsFeatureGroupStrategyWriter::TfIdf
1019 }
1020 }
1021}
1022
1023fn serialize_bag_of_words_feature_group_n_gram_entry(
1024 bag_of_words_feature_group_n_gram_entry: &tangram_features::bag_of_words::BagOfWordsFeatureGroupNGramEntry,
1025 writer: &mut buffalo::Writer,
1026) -> buffalo::Position<tangram_model::BagOfWordsFeatureGroupNGramEntryWriter> {
1027 writer.write(&tangram_model::BagOfWordsFeatureGroupNGramEntryWriter {
1028 idf: bag_of_words_feature_group_n_gram_entry.idf,
1029 })
1030}
1031
1032fn serialize_word_embedding_feature_group(
1033 word_embedding_feature_group: &tangram_features::WordEmbeddingFeatureGroup,
1034 writer: &mut buffalo::Writer,
1035) -> buffalo::Position<tangram_model::WordEmbeddingFeatureGroupWriter> {
1036 let source_column_name = writer.write(word_embedding_feature_group.source_column_name.as_str());
1037 let tokenizer = serialize_tokenizer(&word_embedding_feature_group.tokenizer, writer);
1038 let model = serialize_word_embedding_model(&word_embedding_feature_group.model, writer);
1039 let feature_group = tangram_model::WordEmbeddingFeatureGroupWriter {
1040 source_column_name,
1041 tokenizer,
1042 model,
1043 };
1044 writer.write(&feature_group)
1045}
1046
1047fn serialize_word_embedding_model(
1048 word_embedding_model: &tangram_text::WordEmbeddingModel,
1049 writer: &mut buffalo::Writer,
1050) -> buffalo::Position<tangram_model::WordEmbeddingModelWriter> {
1051 let size = word_embedding_model.size.to_u64().unwrap();
1052 let values = writer.write(word_embedding_model.values.as_slice());
1053 let words = word_embedding_model
1054 .words
1055 .keys()
1056 .map(|word| writer.write(word))
1057 .collect::<Vec<_>>();
1058 let words = zip!(words, word_embedding_model.words.values())
1059 .map(|(key, index)| (key, index.to_u64().unwrap()))
1060 .collect::<Vec<_>>();
1061 let words = writer.write(&words);
1062 writer.write(&tangram_model::WordEmbeddingModelWriter {
1063 size,
1064 words,
1065 values,
1066 })
1067}
1068
1069fn serialize_bag_of_words_cosine_similarity_feature_group(
1070 bag_of_words_cosine_similarity_feature_group: &tangram_features::BagOfWordsCosineSimilarityFeatureGroup,
1071 writer: &mut buffalo::Writer,
1072) -> buffalo::Position<tangram_model::BagOfWordsCosineSimilarityFeatureGroupWriter> {
1073 let source_column_name_a = writer.write(
1074 bag_of_words_cosine_similarity_feature_group
1075 .source_column_name_a
1076 .as_str(),
1077 );
1078 let source_column_name_b = writer.write(
1079 bag_of_words_cosine_similarity_feature_group
1080 .source_column_name_b
1081 .as_str(),
1082 );
1083 let tokenizer = serialize_tokenizer(
1084 &bag_of_words_cosine_similarity_feature_group.tokenizer,
1085 writer,
1086 );
1087 let ngrams = bag_of_words_cosine_similarity_feature_group
1088 .ngrams
1089 .iter()
1090 .map(|(ngram, entry)| {
1091 (
1092 serialize_ngram(ngram, writer),
1093 serialize_bag_of_words_feature_group_n_gram_entry(entry, writer),
1094 )
1095 })
1096 .collect::<Vec<_>>();
1097 let ngrams = writer.write(&ngrams);
1098 let ngram_types = bag_of_words_cosine_similarity_feature_group
1099 .ngram_types
1100 .iter()
1101 .map(|ngram_type| serialize_ngram_type(ngram_type, writer))
1102 .collect::<Vec<_>>();
1103 let ngram_types = writer.write(&ngram_types);
1104 let strategy = serialize_bag_of_words_feature_group_strategy(
1105 &bag_of_words_cosine_similarity_feature_group.strategy,
1106 writer,
1107 );
1108 let feature_group = tangram_model::BagOfWordsCosineSimilarityFeatureGroupWriter {
1109 source_column_name_a,
1110 source_column_name_b,
1111 tokenizer,
1112 strategy,
1113 ngram_types,
1114 ngrams,
1115 };
1116 writer.write(&feature_group)
1117}
1118
1119fn serialize_binary_classification_model(
1120 binary_classification_model: &BinaryClassificationModel,
1121 writer: &mut buffalo::Writer,
1122) -> tangram_model::BinaryClassificationModelWriter {
1123 match binary_classification_model {
1124 BinaryClassificationModel::Linear(model) => {
1125 let linear_binary_classifier =
1126 serialize_linear_binary_classification_model(model, writer);
1127 tangram_model::BinaryClassificationModelWriter::Linear(linear_binary_classifier)
1128 }
1129 BinaryClassificationModel::Tree(model) => {
1130 let tree_binary_classifier = serialize_tree_binary_classification_model(model, writer);
1131 tangram_model::BinaryClassificationModelWriter::Tree(tree_binary_classifier)
1132 }
1133 }
1134}
1135
1136fn serialize_linear_binary_classification_model(
1137 linear_binary_classification_model: &LinearBinaryClassificationModel,
1138 writer: &mut buffalo::Writer,
1139) -> buffalo::Position<tangram_model::LinearBinaryClassifierWriter> {
1140 let model = linear_binary_classification_model.model.to_writer(writer);
1141 let train_options =
1142 serialize_linear_train_options(&linear_binary_classification_model.train_options, writer);
1143 let feature_groups = linear_binary_classification_model
1144 .feature_groups
1145 .iter()
1146 .map(|feature_group| serialize_feature_group(feature_group, writer))
1147 .collect::<Vec<_>>();
1148 let feature_groups = writer.write(&feature_groups);
1149 let losses = linear_binary_classification_model
1150 .losses
1151 .as_ref()
1152 .map(|losses| writer.write(losses.as_slice()));
1153 let feature_importances = writer.write(
1154 linear_binary_classification_model
1155 .feature_importances
1156 .as_slice(),
1157 );
1158 let model = tangram_model::LinearBinaryClassifierWriter {
1159 model,
1160 train_options,
1161 feature_groups,
1162 losses,
1163 feature_importances,
1164 };
1165 writer.write(&model)
1166}
1167
1168fn serialize_tree_binary_classification_model(
1169 tree_binary_classification_model: &TreeBinaryClassificationModel,
1170 writer: &mut buffalo::Writer,
1171) -> buffalo::Position<tangram_model::TreeBinaryClassifierWriter> {
1172 let feature_importances = writer.write(
1173 tree_binary_classification_model
1174 .feature_importances
1175 .as_slice(),
1176 );
1177 let train_options =
1178 serialize_tree_train_options(&tree_binary_classification_model.train_options, writer);
1179 let feature_groups = tree_binary_classification_model
1180 .feature_groups
1181 .iter()
1182 .map(|feature_group| serialize_feature_group(feature_group, writer))
1183 .collect::<Vec<_>>();
1184 let feature_groups = writer.write(&feature_groups);
1185 let losses = tree_binary_classification_model
1186 .losses
1187 .as_ref()
1188 .map(|losses| writer.write(losses.as_slice()));
1189 let model = tree_binary_classification_model.model.to_writer(writer);
1190 let model = tangram_model::TreeBinaryClassifierWriter {
1191 model,
1192 train_options,
1193 feature_groups,
1194 losses,
1195 feature_importances,
1196 };
1197 writer.write(&model)
1198}
1199
1200fn serialize_binary_classification_metrics_output(
1201 binary_classification_metrics_output: &tangram_metrics::BinaryClassificationMetricsOutput,
1202 writer: &mut buffalo::Writer,
1203) -> buffalo::Position<tangram_model::BinaryClassificationMetricsWriter> {
1204 let thresholds = binary_classification_metrics_output
1205 .thresholds
1206 .iter()
1207 .map(|threshold| {
1208 serialize_binary_classification_metrics_output_for_threshold(threshold, writer)
1209 })
1210 .collect::<Vec<_>>();
1211 let thresholds = writer.write(&thresholds);
1212 let default_threshold = serialize_binary_classification_metrics_output_for_threshold(
1213 &binary_classification_metrics_output.thresholds
1214 [binary_classification_metrics_output.thresholds.len() / 2],
1215 writer,
1216 );
1217 let metrics = tangram_model::BinaryClassificationMetricsWriter {
1218 auc_roc: binary_classification_metrics_output.auc_roc_approx,
1219 default_threshold,
1220 thresholds,
1221 };
1222 writer.write(&metrics)
1223}
1224
1225fn serialize_binary_classification_metrics_output_for_threshold(
1226 binary_classification_metrics_output_for_threshold: &tangram_metrics::BinaryClassificationMetricsOutputForThreshold,
1227 writer: &mut buffalo::Writer,
1228) -> buffalo::Position<tangram_model::BinaryClassificationMetricsForThresholdWriter> {
1229 let metrics = tangram_model::BinaryClassificationMetricsForThresholdWriter {
1230 threshold: binary_classification_metrics_output_for_threshold.threshold,
1231 true_positives: binary_classification_metrics_output_for_threshold
1232 .true_positives
1233 .to_u64()
1234 .unwrap(),
1235 false_positives: binary_classification_metrics_output_for_threshold
1236 .false_positives
1237 .to_u64()
1238 .unwrap(),
1239 true_negatives: binary_classification_metrics_output_for_threshold
1240 .true_negatives
1241 .to_u64()
1242 .unwrap(),
1243 false_negatives: binary_classification_metrics_output_for_threshold
1244 .false_negatives
1245 .to_u64()
1246 .unwrap(),
1247 accuracy: binary_classification_metrics_output_for_threshold.accuracy,
1248 precision: binary_classification_metrics_output_for_threshold.precision,
1249 recall: binary_classification_metrics_output_for_threshold.recall,
1250 f1_score: binary_classification_metrics_output_for_threshold.f1_score,
1251 true_positive_rate: binary_classification_metrics_output_for_threshold.true_positive_rate,
1252 false_positive_rate: binary_classification_metrics_output_for_threshold.false_positive_rate,
1253 };
1254 writer.write(&metrics)
1255}
1256
1257fn serialize_binary_classification_comparison_metric(
1258 binary_classification_comparison_metric: &BinaryClassificationComparisonMetric,
1259 _writer: &mut buffalo::Writer,
1260) -> tangram_model::BinaryClassificationComparisonMetricWriter {
1261 match binary_classification_comparison_metric {
1262 BinaryClassificationComparisonMetric::AucRoc => {
1263 tangram_model::BinaryClassificationComparisonMetricWriter::Aucroc
1264 }
1265 }
1266}
1267
1268fn serialize_multiclass_classification_model(
1269 multiclass_classification_model: &MulticlassClassificationModel,
1270 writer: &mut buffalo::Writer,
1271) -> tangram_model::MulticlassClassificationModelWriter {
1272 match multiclass_classification_model {
1273 MulticlassClassificationModel::Linear(model) => {
1274 let linear_multiclass_classifier =
1275 serialize_linear_multiclass_classification_model(model, writer);
1276 tangram_model::MulticlassClassificationModelWriter::Linear(linear_multiclass_classifier)
1277 }
1278 MulticlassClassificationModel::Tree(model) => {
1279 let tree_multiclass_classifier =
1280 serialize_tree_multiclass_classification_model(model, writer);
1281 tangram_model::MulticlassClassificationModelWriter::Tree(tree_multiclass_classifier)
1282 }
1283 }
1284}
1285
1286fn serialize_linear_multiclass_classification_model(
1287 linear_multiclass_classification_model: &LinearMulticlassClassificationModel,
1288 writer: &mut buffalo::Writer,
1289) -> buffalo::Position<tangram_model::LinearMulticlassClassifierWriter> {
1290 let feature_importances = writer.write(
1291 linear_multiclass_classification_model
1292 .feature_importances
1293 .as_slice(),
1294 );
1295 let train_options = serialize_linear_train_options(
1296 &linear_multiclass_classification_model.train_options,
1297 writer,
1298 );
1299 let feature_groups = linear_multiclass_classification_model
1300 .feature_groups
1301 .iter()
1302 .map(|feature_group| serialize_feature_group(feature_group, writer))
1303 .collect::<Vec<_>>();
1304 let feature_groups = writer.write(&feature_groups);
1305 let losses = linear_multiclass_classification_model
1306 .losses
1307 .as_ref()
1308 .map(|losses| writer.write(losses.as_slice()));
1309 let model = linear_multiclass_classification_model
1310 .model
1311 .to_writer(writer);
1312 let model = tangram_model::LinearMulticlassClassifierWriter {
1313 model,
1314 train_options,
1315 feature_groups,
1316 losses,
1317 feature_importances,
1318 };
1319 writer.write(&model)
1320}
1321
1322fn serialize_tree_multiclass_classification_model(
1323 tree_multiclass_classification_model: &TreeMulticlassClassificationModel,
1324 writer: &mut buffalo::Writer,
1325) -> buffalo::Position<tangram_model::TreeMulticlassClassifierWriter> {
1326 let feature_importances = writer.write(
1327 tree_multiclass_classification_model
1328 .feature_importances
1329 .as_slice(),
1330 );
1331 let train_options =
1332 serialize_tree_train_options(&tree_multiclass_classification_model.train_options, writer);
1333 let feature_groups = tree_multiclass_classification_model
1334 .feature_groups
1335 .iter()
1336 .map(|feature_group| serialize_feature_group(feature_group, writer))
1337 .collect::<Vec<_>>();
1338 let feature_groups = writer.write(&feature_groups);
1339 let losses = tree_multiclass_classification_model
1340 .losses
1341 .as_ref()
1342 .map(|losses| writer.write(losses.as_slice()));
1343 let model = tree_multiclass_classification_model.model.to_writer(writer);
1344 let model = tangram_model::TreeMulticlassClassifierWriter {
1345 model,
1346 train_options,
1347 feature_groups,
1348 losses,
1349 feature_importances,
1350 };
1351 writer.write(&model)
1352}
1353
1354fn serialize_multiclass_classification_metrics_output(
1355 multiclass_classification_metrics_output: &tangram_metrics::MulticlassClassificationMetricsOutput,
1356 writer: &mut buffalo::Writer,
1357) -> buffalo::Position<tangram_model::MulticlassClassificationMetricsWriter> {
1358 let class_metrics = multiclass_classification_metrics_output
1359 .class_metrics
1360 .iter()
1361 .map(|class_metric| serialize_class_metrics(class_metric, writer))
1362 .collect::<Vec<_>>();
1363 let class_metrics = writer.write(&class_metrics);
1364 let metrics = tangram_model::MulticlassClassificationMetricsWriter {
1365 class_metrics,
1366 accuracy: multiclass_classification_metrics_output.accuracy,
1367 precision_unweighted: multiclass_classification_metrics_output.precision_unweighted,
1368 precision_weighted: multiclass_classification_metrics_output.precision_weighted,
1369 recall_unweighted: multiclass_classification_metrics_output.recall_weighted,
1370 recall_weighted: multiclass_classification_metrics_output.recall_weighted,
1371 };
1372 writer.write(&metrics)
1373}
1374
1375fn serialize_class_metrics(
1376 class_metrics: &tangram_metrics::ClassMetrics,
1377 writer: &mut buffalo::Writer,
1378) -> buffalo::Position<tangram_model::ClassMetricsWriter> {
1379 let metrics = tangram_model::ClassMetricsWriter {
1380 true_positives: class_metrics.true_positives.to_u64().unwrap(),
1381 false_positives: class_metrics.false_positives.to_u64().unwrap(),
1382 true_negatives: class_metrics.true_negatives.to_u64().unwrap(),
1383 false_negatives: class_metrics.false_negatives.to_u64().unwrap(),
1384 accuracy: class_metrics.accuracy,
1385 precision: class_metrics.precision,
1386 recall: class_metrics.recall,
1387 f1_score: class_metrics.f1_score,
1388 };
1389 writer.write(&metrics)
1390}
1391
1392fn serialize_multiclass_classification_comparison_metric(
1393 multiclass_classification_comparison_metric: &MulticlassClassificationComparisonMetric,
1394 _writer: &mut buffalo::Writer,
1395) -> tangram_model::MulticlassClassificationComparisonMetricWriter {
1396 match multiclass_classification_comparison_metric {
1397 MulticlassClassificationComparisonMetric::Accuracy => {
1398 tangram_model::MulticlassClassificationComparisonMetricWriter::Accuracy
1399 }
1400 }
1401}