// tangram_core/model.rs

1use crate::{
2	stats::{
3		ColumnStatsOutput, EnumColumnStatsOutput, NumberColumnStatsOutput, StatsSettings,
4		TextColumnStatsOutput, TextColumnStatsOutputTopNGramsEntry, UnknownColumnStatsOutput,
5	},
6	train::{TrainGridItemOutput, TrainModelOutput},
7};
8use anyhow::Result;
9use num::ToPrimitive;
10use std::path::Path;
11use tangram_id::Id;
12use tangram_zip::zip;
13
14pub struct Model {
15	pub id: Id,
16	pub version: String,
17	pub date: String,
18	pub inner: ModelInner,
19}
20
21pub enum ModelInner {
22	Regressor(Regressor),
23	BinaryClassifier(BinaryClassifier),
24	MulticlassClassifier(MulticlassClassifier),
25}
26
27pub struct Regressor {
28	pub target_column_name: String,
29	pub train_row_count: usize,
30	pub test_row_count: usize,
31	pub overall_row_count: usize,
32	pub stats_settings: StatsSettings,
33	pub overall_column_stats: Vec<ColumnStatsOutput>,
34	pub overall_target_column_stats: ColumnStatsOutput,
35	pub train_column_stats: Vec<ColumnStatsOutput>,
36	pub train_target_column_stats: ColumnStatsOutput,
37	pub test_column_stats: Vec<ColumnStatsOutput>,
38	pub test_target_column_stats: ColumnStatsOutput,
39	pub baseline_metrics: tangram_metrics::RegressionMetricsOutput,
40	pub comparison_metric: RegressionComparisonMetric,
41	pub train_grid_item_outputs: Vec<TrainGridItemOutput>,
42	pub best_grid_item_index: usize,
43	pub model: RegressionModel,
44	pub test_metrics: tangram_metrics::RegressionMetricsOutput,
45}
46
47pub struct BinaryClassifier {
48	pub target_column_name: String,
49	pub negative_class: String,
50	pub positive_class: String,
51	pub train_row_count: usize,
52	pub test_row_count: usize,
53	pub overall_row_count: usize,
54	pub stats_settings: StatsSettings,
55	pub overall_column_stats: Vec<ColumnStatsOutput>,
56	pub overall_target_column_stats: ColumnStatsOutput,
57	pub train_column_stats: Vec<ColumnStatsOutput>,
58	pub train_target_column_stats: ColumnStatsOutput,
59	pub test_column_stats: Vec<ColumnStatsOutput>,
60	pub test_target_column_stats: ColumnStatsOutput,
61	pub baseline_metrics: tangram_metrics::BinaryClassificationMetricsOutput,
62	pub comparison_metric: BinaryClassificationComparisonMetric,
63	pub train_grid_item_outputs: Vec<TrainGridItemOutput>,
64	pub best_grid_item_index: usize,
65	pub model: BinaryClassificationModel,
66	pub test_metrics: tangram_metrics::BinaryClassificationMetricsOutput,
67}
68
69pub struct MulticlassClassifier {
70	pub target_column_name: String,
71	pub classes: Vec<String>,
72	pub train_row_count: usize,
73	pub test_row_count: usize,
74	pub overall_row_count: usize,
75	pub stats_settings: StatsSettings,
76	pub overall_column_stats: Vec<ColumnStatsOutput>,
77	pub overall_target_column_stats: ColumnStatsOutput,
78	pub train_column_stats: Vec<ColumnStatsOutput>,
79	pub train_target_column_stats: ColumnStatsOutput,
80	pub test_column_stats: Vec<ColumnStatsOutput>,
81	pub test_target_column_stats: ColumnStatsOutput,
82	pub baseline_metrics: tangram_metrics::MulticlassClassificationMetricsOutput,
83	pub comparison_metric: MulticlassClassificationComparisonMetric,
84	pub train_grid_item_outputs: Vec<TrainGridItemOutput>,
85	pub best_grid_item_index: usize,
86	pub model: MulticlassClassificationModel,
87	pub test_metrics: tangram_metrics::MulticlassClassificationMetricsOutput,
88}
89
/// The kind of prediction task a model solves.
///
/// Derives `Debug`, `PartialEq` and `Eq` in addition to `Copy` so values can be
/// compared and printed directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Task {
	/// Regression task.
	Regression,
	/// Binary classification task.
	BinaryClassification,
	/// Multiclass classification task.
	MulticlassClassification,
}
96
/// Metric used to compare binary classification models.
///
/// Derives `Debug`, `PartialEq` and `Eq` so the chosen metric can be compared and printed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BinaryClassificationComparisonMetric {
	/// Area under the ROC curve.
	AucRoc,
}
101
/// Metric used to compare multiclass classification models.
///
/// Derives `Debug`, `PartialEq` and `Eq` so the chosen metric can be compared and printed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MulticlassClassificationComparisonMetric {
	/// Classification accuracy.
	Accuracy,
}
106
107pub enum RegressionModel {
108	Linear(LinearRegressionModel),
109	Tree(TreeRegressionModel),
110}
111
112pub struct LinearRegressionModel {
113	pub model: tangram_linear::Regressor,
114	pub train_options: tangram_linear::TrainOptions,
115	pub feature_groups: Vec<tangram_features::FeatureGroup>,
116	pub losses: Option<Vec<f32>>,
117	pub feature_importances: Vec<f32>,
118}
119
120pub struct TreeRegressionModel {
121	pub model: tangram_tree::Regressor,
122	pub train_options: tangram_tree::TrainOptions,
123	pub feature_groups: Vec<tangram_features::FeatureGroup>,
124	pub losses: Option<Vec<f32>>,
125	pub feature_importances: Vec<f32>,
126}
127
/// Metric used to compare regression models.
///
/// Derives `Debug`, `PartialEq` and `Eq` so the chosen metric can be compared and printed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RegressionComparisonMetric {
	/// Mean absolute error.
	MeanAbsoluteError,
	/// Mean squared error.
	MeanSquaredError,
	/// Root mean squared error.
	RootMeanSquaredError,
	/// Coefficient of determination.
	R2,
}
135
136pub enum BinaryClassificationModel {
137	Linear(LinearBinaryClassificationModel),
138	Tree(TreeBinaryClassificationModel),
139}
140
141pub struct LinearBinaryClassificationModel {
142	pub model: tangram_linear::BinaryClassifier,
143	pub train_options: tangram_linear::TrainOptions,
144	pub feature_groups: Vec<tangram_features::FeatureGroup>,
145	pub losses: Option<Vec<f32>>,
146	pub feature_importances: Vec<f32>,
147}
148
149pub struct TreeBinaryClassificationModel {
150	pub model: tangram_tree::BinaryClassifier,
151	pub train_options: tangram_tree::TrainOptions,
152	pub feature_groups: Vec<tangram_features::FeatureGroup>,
153	pub losses: Option<Vec<f32>>,
154	pub feature_importances: Vec<f32>,
155}
156
157pub enum MulticlassClassificationModel {
158	Linear(LinearMulticlassClassificationModel),
159	Tree(TreeMulticlassClassificationModel),
160}
161
162pub struct LinearMulticlassClassificationModel {
163	pub model: tangram_linear::MulticlassClassifier,
164	pub train_options: tangram_linear::TrainOptions,
165	pub feature_groups: Vec<tangram_features::FeatureGroup>,
166	pub losses: Option<Vec<f32>>,
167	pub feature_importances: Vec<f32>,
168}
169
170pub struct TreeMulticlassClassificationModel {
171	pub model: tangram_tree::MulticlassClassifier,
172	pub train_options: tangram_tree::TrainOptions,
173	pub feature_groups: Vec<tangram_features::FeatureGroup>,
174	pub losses: Option<Vec<f32>>,
175	pub feature_importances: Vec<f32>,
176}
177
178#[derive(Clone, Copy)]
179pub enum ComparisonMetric {
180	Regression(RegressionComparisonMetric),
181	BinaryClassification(BinaryClassificationComparisonMetric),
182	MulticlassClassification(MulticlassClassificationComparisonMetric),
183}
184
185pub enum Metrics {
186	Regression(tangram_metrics::RegressionMetricsOutput),
187	BinaryClassification(tangram_metrics::BinaryClassificationMetricsOutput),
188	MulticlassClassification(tangram_metrics::MulticlassClassificationMetricsOutput),
189}
190
191impl Model {
192	pub fn to_path(&self, path: &Path) -> Result<()> {
193		let mut writer = buffalo::Writer::new();
194		let model = serialize_model(self, &mut writer);
195		writer.write(&model);
196		let bytes = writer.into_bytes();
197		tangram_model::to_path(path, &bytes)?;
198		Ok(())
199	}
200}
201
202fn serialize_model(
203	model: &Model,
204	writer: &mut buffalo::Writer,
205) -> buffalo::Position<tangram_model::ModelWriter> {
206	let id = writer.write(model.id.to_string().as_str());
207	let version = writer.write(model.version.as_str());
208	let date = writer.write(model.date.to_string().as_str());
209	let inner = serialize_model_inner(&model.inner, writer);
210	writer.write(&tangram_model::ModelWriter {
211		id,
212		version,
213		date,
214		inner,
215	})
216}
217
218fn serialize_model_inner(
219	model_inner: &ModelInner,
220	writer: &mut buffalo::Writer,
221) -> tangram_model::ModelInnerWriter {
222	match model_inner {
223		ModelInner::Regressor(regressor) => {
224			let regressor = serialize_regressor(regressor, writer);
225			tangram_model::ModelInnerWriter::Regressor(regressor)
226		}
227		ModelInner::BinaryClassifier(binary_classifier) => {
228			let binary_classifier = serialize_binary_classifier(binary_classifier, writer);
229			tangram_model::ModelInnerWriter::BinaryClassifier(binary_classifier)
230		}
231		ModelInner::MulticlassClassifier(multiclass_classifier) => {
232			let multiclass_classifier =
233				serialize_multiclass_classifier(multiclass_classifier, writer);
234			tangram_model::ModelInnerWriter::MulticlassClassifier(multiclass_classifier)
235		}
236	}
237}
238
239fn serialize_regressor(
240	regressor: &Regressor,
241	writer: &mut buffalo::Writer,
242) -> buffalo::Position<tangram_model::RegressorWriter> {
243	let target_column_name = writer.write(regressor.target_column_name.as_str());
244	let stats_settings = serialize_stats_settings(&regressor.stats_settings, writer);
245	let overall_column_stats = regressor
246		.overall_column_stats
247		.iter()
248		.map(|overall_column_stats| serialize_column_stats_output(overall_column_stats, writer))
249		.collect::<Vec<_>>();
250	let overall_column_stats = writer.write(&overall_column_stats);
251	let overall_target_column_stats =
252		serialize_column_stats_output(&regressor.overall_target_column_stats, writer);
253	let train_column_stats = regressor
254		.train_column_stats
255		.iter()
256		.map(|train_column_stats| serialize_column_stats_output(train_column_stats, writer))
257		.collect::<Vec<_>>();
258	let train_column_stats = writer.write(&train_column_stats);
259	let train_target_column_stats =
260		serialize_column_stats_output(&regressor.train_target_column_stats, writer);
261	let test_column_stats = regressor
262		.test_column_stats
263		.iter()
264		.map(|test_column_stats| serialize_column_stats_output(test_column_stats, writer))
265		.collect::<Vec<_>>();
266	let test_column_stats = writer.write(&test_column_stats);
267	let test_target_column_stats =
268		serialize_column_stats_output(&regressor.test_target_column_stats, writer);
269	let baseline_metrics = serialize_regression_metrics_output(&regressor.baseline_metrics, writer);
270	let comparison_metric =
271		serialize_regression_comparison_metric(&regressor.comparison_metric, writer);
272	let train_grid_item_outputs = regressor
273		.train_grid_item_outputs
274		.iter()
275		.map(|train_grid_item_output| {
276			serialize_train_grid_item_output(train_grid_item_output, writer)
277		})
278		.collect::<Vec<_>>();
279	let train_grid_item_outputs = writer.write(&train_grid_item_outputs);
280	let model = serialize_regression_model(&regressor.model, writer);
281	let test_metrics = serialize_regression_metrics_output(&regressor.test_metrics, writer);
282	let regressor_writer = tangram_model::RegressorWriter {
283		target_column_name,
284		train_row_count: regressor.train_row_count.to_u64().unwrap(),
285		test_row_count: regressor.test_row_count.to_u64().unwrap(),
286		overall_row_count: regressor.overall_row_count.to_u64().unwrap(),
287		stats_settings,
288		overall_column_stats,
289		overall_target_column_stats,
290		train_column_stats,
291		train_target_column_stats,
292		test_column_stats,
293		test_target_column_stats,
294		baseline_metrics,
295		comparison_metric,
296		train_grid_item_outputs,
297		best_grid_item_index: regressor.best_grid_item_index.to_u64().unwrap(),
298		model,
299		test_metrics,
300	};
301	writer.write(&regressor_writer)
302}
303
304fn serialize_binary_classifier(
305	binary_classifier: &BinaryClassifier,
306	writer: &mut buffalo::Writer,
307) -> buffalo::Position<tangram_model::BinaryClassifierWriter> {
308	let negative_class = writer.write(binary_classifier.negative_class.as_str());
309	let positive_class = writer.write(binary_classifier.positive_class.as_str());
310	let target_column_name = writer.write(binary_classifier.target_column_name.as_str());
311	let stats_settings = serialize_stats_settings(&binary_classifier.stats_settings, writer);
312	let overall_column_stats = binary_classifier
313		.overall_column_stats
314		.iter()
315		.map(|overall_column_stats| serialize_column_stats_output(overall_column_stats, writer))
316		.collect::<Vec<_>>();
317	let overall_column_stats = writer.write(&overall_column_stats);
318	let overall_target_column_stats =
319		serialize_column_stats_output(&binary_classifier.overall_target_column_stats, writer);
320	let train_column_stats = binary_classifier
321		.train_column_stats
322		.iter()
323		.map(|train_column_stats| serialize_column_stats_output(train_column_stats, writer))
324		.collect::<Vec<_>>();
325	let train_column_stats = writer.write(&train_column_stats);
326	let train_target_column_stats =
327		serialize_column_stats_output(&binary_classifier.train_target_column_stats, writer);
328	let test_column_stats = binary_classifier
329		.test_column_stats
330		.iter()
331		.map(|test_column_stats| serialize_column_stats_output(test_column_stats, writer))
332		.collect::<Vec<_>>();
333	let test_column_stats = writer.write(&test_column_stats);
334	let test_target_column_stats =
335		serialize_column_stats_output(&binary_classifier.test_target_column_stats, writer);
336	let baseline_metrics =
337		serialize_binary_classification_metrics_output(&binary_classifier.baseline_metrics, writer);
338	let comparison_metric = serialize_binary_classification_comparison_metric(
339		&binary_classifier.comparison_metric,
340		writer,
341	);
342	let train_grid_item_outputs = binary_classifier
343		.train_grid_item_outputs
344		.iter()
345		.map(|train_grid_item_output| {
346			serialize_train_grid_item_output(train_grid_item_output, writer)
347		})
348		.collect::<Vec<_>>();
349	let train_grid_item_outputs = writer.write(&train_grid_item_outputs);
350	let model = serialize_binary_classification_model(&binary_classifier.model, writer);
351	let test_metrics =
352		serialize_binary_classification_metrics_output(&binary_classifier.test_metrics, writer);
353	let binary_classifier_writer = tangram_model::BinaryClassifierWriter {
354		target_column_name,
355		train_row_count: binary_classifier.train_row_count.to_u64().unwrap(),
356		test_row_count: binary_classifier.test_row_count.to_u64().unwrap(),
357		overall_row_count: binary_classifier.overall_row_count.to_u64().unwrap(),
358		stats_settings,
359		overall_column_stats,
360		overall_target_column_stats,
361		train_column_stats,
362		train_target_column_stats,
363		test_column_stats,
364		test_target_column_stats,
365		baseline_metrics,
366		comparison_metric,
367		train_grid_item_outputs,
368		best_grid_item_index: binary_classifier.best_grid_item_index.to_u64().unwrap(),
369		model,
370		test_metrics,
371		negative_class,
372		positive_class,
373	};
374	writer.write(&binary_classifier_writer)
375}
376
377fn serialize_multiclass_classifier(
378	multiclass_classifier: &MulticlassClassifier,
379	writer: &mut buffalo::Writer,
380) -> buffalo::Position<tangram_model::MulticlassClassifierWriter> {
381	let target_column_name = writer.write(multiclass_classifier.target_column_name.as_str());
382	let stats_settings = serialize_stats_settings(&multiclass_classifier.stats_settings, writer);
383	let overall_column_stats = multiclass_classifier
384		.overall_column_stats
385		.iter()
386		.map(|overall_column_stats| serialize_column_stats_output(overall_column_stats, writer))
387		.collect::<Vec<_>>();
388	let overall_column_stats = writer.write(&overall_column_stats);
389	let overall_target_column_stats =
390		serialize_column_stats_output(&multiclass_classifier.overall_target_column_stats, writer);
391	let train_column_stats = multiclass_classifier
392		.train_column_stats
393		.iter()
394		.map(|train_column_stats| serialize_column_stats_output(train_column_stats, writer))
395		.collect::<Vec<_>>();
396	let train_column_stats = writer.write(&train_column_stats);
397	let train_target_column_stats =
398		serialize_column_stats_output(&multiclass_classifier.train_target_column_stats, writer);
399	let test_column_stats = multiclass_classifier
400		.test_column_stats
401		.iter()
402		.map(|test_column_stats| serialize_column_stats_output(test_column_stats, writer))
403		.collect::<Vec<_>>();
404	let test_column_stats = writer.write(&test_column_stats);
405	let test_target_column_stats =
406		serialize_column_stats_output(&multiclass_classifier.test_target_column_stats, writer);
407	let baseline_metrics = serialize_multiclass_classification_metrics_output(
408		&multiclass_classifier.baseline_metrics,
409		writer,
410	);
411	let comparison_metric = serialize_multiclass_classification_comparison_metric(
412		&multiclass_classifier.comparison_metric,
413		writer,
414	);
415	let train_grid_item_outputs = multiclass_classifier
416		.train_grid_item_outputs
417		.iter()
418		.map(|train_grid_item_output| {
419			serialize_train_grid_item_output(train_grid_item_output, writer)
420		})
421		.collect::<Vec<_>>();
422	let train_grid_item_outputs = writer.write(&train_grid_item_outputs);
423	let model = serialize_multiclass_classification_model(&multiclass_classifier.model, writer);
424	let test_metrics = serialize_multiclass_classification_metrics_output(
425		&multiclass_classifier.test_metrics,
426		writer,
427	);
428	let classes = multiclass_classifier
429		.classes
430		.iter()
431		.map(|class| writer.write(class))
432		.collect::<Vec<_>>();
433	let classes = writer.write(&classes);
434	let multiclass_classifier_writer = tangram_model::MulticlassClassifierWriter {
435		target_column_name,
436		train_row_count: multiclass_classifier.train_row_count.to_u64().unwrap(),
437		test_row_count: multiclass_classifier.test_row_count.to_u64().unwrap(),
438		overall_row_count: multiclass_classifier.overall_row_count.to_u64().unwrap(),
439		stats_settings,
440		overall_column_stats,
441		overall_target_column_stats,
442		train_column_stats,
443		train_target_column_stats,
444		test_column_stats,
445		test_target_column_stats,
446		baseline_metrics,
447		comparison_metric,
448		train_grid_item_outputs,
449		best_grid_item_index: multiclass_classifier.best_grid_item_index.to_u64().unwrap(),
450		model,
451		test_metrics,
452		classes,
453	};
454	writer.write(&multiclass_classifier_writer)
455}
456
457fn serialize_stats_settings(
458	stats_settings: &StatsSettings,
459	writer: &mut buffalo::Writer,
460) -> buffalo::Position<tangram_model::StatsSettingsWriter> {
461	let stats_settings_writer = tangram_model::StatsSettingsWriter {
462		number_histogram_max_size: stats_settings.number_histogram_max_size.to_u64().unwrap(),
463	};
464	writer.write(&stats_settings_writer)
465}
466
467fn serialize_column_stats_output(
468	column_stats_output: &ColumnStatsOutput,
469	writer: &mut buffalo::Writer,
470) -> tangram_model::ColumnStatsWriter {
471	match column_stats_output {
472		ColumnStatsOutput::Unknown(unknown_column_stats) => {
473			let unknown_column_stats =
474				serialize_unknown_column_stats_output(unknown_column_stats, writer);
475			tangram_model::ColumnStatsWriter::UnknownColumn(unknown_column_stats)
476		}
477		ColumnStatsOutput::Number(number_column_stats) => {
478			let number_column_stats =
479				serialize_number_column_stats_output(number_column_stats, writer);
480			tangram_model::ColumnStatsWriter::NumberColumn(number_column_stats)
481		}
482		ColumnStatsOutput::Enum(enum_column_stats) => {
483			let enum_column_stats = serialize_enum_column_stats_output(enum_column_stats, writer);
484			tangram_model::ColumnStatsWriter::EnumColumn(enum_column_stats)
485		}
486		ColumnStatsOutput::Text(text_column_stats) => {
487			let text_column_stats = serialize_text_column_stats_output(text_column_stats, writer);
488			tangram_model::ColumnStatsWriter::TextColumn(text_column_stats)
489		}
490	}
491}
492
493fn serialize_unknown_column_stats_output(
494	uknown_column_stats_output: &UnknownColumnStatsOutput,
495	writer: &mut buffalo::Writer,
496) -> buffalo::Position<tangram_model::UnknownColumnStatsWriter> {
497	let column_name = writer.write(uknown_column_stats_output.column_name.as_str());
498	let unknown_column_stats = tangram_model::UnknownColumnStatsWriter { column_name };
499	writer.write(&unknown_column_stats)
500}
501
502fn serialize_number_column_stats_output(
503	number_column_stats_output: &NumberColumnStatsOutput,
504	writer: &mut buffalo::Writer,
505) -> buffalo::Position<tangram_model::NumberColumnStatsWriter> {
506	let column_name = writer.write(number_column_stats_output.column_name.as_str());
507	let histogram = number_column_stats_output
508		.histogram
509		.as_ref()
510		.map(|histogram| {
511			histogram
512				.iter()
513				.map(|(key, value)| (key.get(), value.to_u64().unwrap()))
514				.collect::<Vec<_>>()
515		});
516	let histogram = histogram.map(|histogram| writer.write(histogram.as_slice()));
517	let number_column_stats = tangram_model::NumberColumnStatsWriter {
518		column_name,
519		invalid_count: number_column_stats_output.invalid_count.to_u64().unwrap(),
520		unique_count: number_column_stats_output.unique_count.to_u64().unwrap(),
521		histogram,
522		min: number_column_stats_output.min,
523		max: number_column_stats_output.max,
524		mean: number_column_stats_output.mean,
525		variance: number_column_stats_output.variance,
526		std: number_column_stats_output.std,
527		p25: number_column_stats_output.p25,
528		p50: number_column_stats_output.p50,
529		p75: number_column_stats_output.p75,
530	};
531	writer.write(&number_column_stats)
532}
533
534fn serialize_enum_column_stats_output(
535	enum_column_stats_output: &EnumColumnStatsOutput,
536	writer: &mut buffalo::Writer,
537) -> buffalo::Position<tangram_model::EnumColumnStatsWriter> {
538	let column_name = writer.write(enum_column_stats_output.column_name.as_str());
539	let strings = enum_column_stats_output
540		.histogram
541		.iter()
542		.map(|(key, _)| writer.write(key))
543		.collect::<Vec<_>>();
544	let histogram = zip!(strings, enum_column_stats_output.histogram.iter())
545		.map(|(key, (_, value))| (key, value.to_u64().unwrap()))
546		.collect::<Vec<_>>();
547	let histogram = writer.write(&histogram);
548	let enum_column_stats = tangram_model::EnumColumnStatsWriter {
549		column_name,
550		invalid_count: enum_column_stats_output.invalid_count.to_u64().unwrap(),
551		histogram,
552		unique_count: enum_column_stats_output.unique_count.to_u64().unwrap(),
553	};
554	writer.write(&enum_column_stats)
555}
556
557fn serialize_text_column_stats_output(
558	text_column_stats_output: &TextColumnStatsOutput,
559	writer: &mut buffalo::Writer,
560) -> buffalo::Position<tangram_model::TextColumnStatsWriter> {
561	let column_name = writer.write(text_column_stats_output.column_name.as_str());
562	let tokenizer = serialize_tokenizer(&text_column_stats_output.tokenizer, writer);
563	let ngram_types = text_column_stats_output
564		.ngram_types
565		.iter()
566		.map(|ngram_type| serialize_ngram_type(ngram_type, writer))
567		.collect::<Vec<_>>();
568	let ngram_types = writer.write(&ngram_types);
569	let ngrams_count = text_column_stats_output.ngrams_count.to_u64().unwrap();
570	let top_ngrams = text_column_stats_output
571		.top_ngrams
572		.iter()
573		.map(|(ngram, entry)| {
574			(
575				serialize_ngram(ngram, writer),
576				serialize_text_column_stats_output_top_n_grams_entry(entry, writer),
577			)
578		})
579		.collect::<Vec<_>>();
580	let top_ngrams = writer.write(&top_ngrams);
581	let text_column_stats = tangram_model::TextColumnStatsWriter {
582		column_name,
583		tokenizer,
584		ngram_types,
585		ngrams_count,
586		top_ngrams,
587	};
588	writer.write(&text_column_stats)
589}
590
591fn serialize_tokenizer(
592	tokenizer: &tangram_text::Tokenizer,
593	writer: &mut buffalo::Writer,
594) -> buffalo::Position<tangram_model::TokenizerWriter> {
595	writer.write(&tangram_model::TokenizerWriter {
596		lowercase: tokenizer.lowercase,
597		alphanumeric: tokenizer.alphanumeric,
598	})
599}
600
601fn serialize_ngram(
602	ngram: &tangram_text::NGram,
603	writer: &mut buffalo::Writer,
604) -> buffalo::Position<tangram_model::NGramWriter> {
605	match ngram {
606		tangram_text::NGram::Unigram(token) => {
607			let token = writer.write(token);
608			writer.write(&tangram_model::NGramWriter::Unigram(token))
609		}
610		tangram_text::NGram::Bigram(token_a, token_b) => {
611			let token_a = writer.write(token_a);
612			let token_b = writer.write(token_b);
613			writer.write(&tangram_model::NGramWriter::Bigram((token_a, token_b)))
614		}
615	}
616}
617
618fn serialize_text_column_stats_output_top_n_grams_entry(
619	text_column_stats_output_top_n_grams_entry: &TextColumnStatsOutputTopNGramsEntry,
620	writer: &mut buffalo::Writer,
621) -> buffalo::Position<tangram_model::TextColumnStatsTopNGramsEntryWriter> {
622	let token_stats = tangram_model::TextColumnStatsTopNGramsEntryWriter {
623		occurrence_count: text_column_stats_output_top_n_grams_entry
624			.occurrence_count
625			.to_u64()
626			.unwrap(),
627		row_count: text_column_stats_output_top_n_grams_entry
628			.row_count
629			.to_u64()
630			.unwrap(),
631	};
632	writer.write(&token_stats)
633}
634
635fn serialize_ngram_type(
636	ngram_type: &tangram_text::NGramType,
637	_writer: &mut buffalo::Writer,
638) -> tangram_model::NGramTypeWriter {
639	match ngram_type {
640		tangram_text::NGramType::Unigram => tangram_model::NGramTypeWriter::Unigram,
641		tangram_text::NGramType::Bigram => tangram_model::NGramTypeWriter::Bigram,
642	}
643}
644
645fn serialize_regression_metrics_output(
646	regression_metrics_output: &tangram_metrics::RegressionMetricsOutput,
647	writer: &mut buffalo::Writer,
648) -> buffalo::Position<tangram_model::RegressionMetricsWriter> {
649	let regression_metrics_writer = tangram_model::RegressionMetricsWriter {
650		mse: regression_metrics_output.mse,
651		rmse: regression_metrics_output.rmse,
652		mae: regression_metrics_output.mae,
653		r2: regression_metrics_output.r2,
654	};
655	writer.write(&regression_metrics_writer)
656}
657
658fn serialize_regression_comparison_metric(
659	regression_comparison_metric_writer: &RegressionComparisonMetric,
660	_writer: &mut buffalo::Writer,
661) -> tangram_model::RegressionComparisonMetricWriter {
662	match regression_comparison_metric_writer {
663		RegressionComparisonMetric::MeanAbsoluteError => {
664			tangram_model::RegressionComparisonMetricWriter::MeanAbsoluteError
665		}
666		RegressionComparisonMetric::MeanSquaredError => {
667			tangram_model::RegressionComparisonMetricWriter::MeanSquaredError
668		}
669		RegressionComparisonMetric::RootMeanSquaredError => {
670			tangram_model::RegressionComparisonMetricWriter::RootMeanSquaredError
671		}
672		RegressionComparisonMetric::R2 => tangram_model::RegressionComparisonMetricWriter::R2,
673	}
674}
675
676fn serialize_regression_model(
677	regression_model: &RegressionModel,
678	writer: &mut buffalo::Writer,
679) -> tangram_model::RegressionModelWriter {
680	match regression_model {
681		RegressionModel::Linear(linear_model) => {
682			let linear_regressor = serialize_linear_regression_model(linear_model, writer);
683			tangram_model::RegressionModelWriter::Linear(linear_regressor)
684		}
685		RegressionModel::Tree(tree_model) => {
686			let tree_regressor = serialize_tree_regression_model(tree_model, writer);
687			tangram_model::RegressionModelWriter::Tree(tree_regressor)
688		}
689	}
690}
691
692fn serialize_early_stopping_options(
693	early_stopping_options: &tangram_linear::EarlyStoppingOptions,
694	writer: &mut buffalo::Writer,
695) -> buffalo::Position<tangram_model::LinearEarlyStoppingOptionsWriter> {
696	let early_stopping_options = tangram_model::LinearEarlyStoppingOptionsWriter {
697		early_stopping_fraction: early_stopping_options.early_stopping_fraction,
698		n_rounds_without_improvement_to_stop: early_stopping_options
699			.n_rounds_without_improvement_to_stop
700			.to_u64()
701			.unwrap(),
702		min_decrease_in_loss_for_significant_change: early_stopping_options
703			.min_decrease_in_loss_for_significant_change,
704	};
705	writer.write(&early_stopping_options)
706}
707
708fn serialize_linear_regression_model(
709	linear_regression_model: &LinearRegressionModel,
710	writer: &mut buffalo::Writer,
711) -> buffalo::Position<tangram_model::LinearRegressorWriter> {
712	let feature_importances = writer.write(linear_regression_model.feature_importances.as_slice());
713	let train_options =
714		serialize_linear_train_options(&linear_regression_model.train_options, writer);
715	let feature_groups = linear_regression_model
716		.feature_groups
717		.iter()
718		.map(|feature_group| serialize_feature_group(feature_group, writer))
719		.collect::<Vec<_>>();
720	let feature_groups = writer.write(&feature_groups);
721	let losses = linear_regression_model
722		.losses
723		.as_ref()
724		.map(|losses| writer.write(losses.as_slice()));
725	let model = linear_regression_model.model.to_writer(writer);
726	let linear_regressor_writer = tangram_model::LinearRegressorWriter {
727		model,
728		train_options,
729		feature_groups,
730		losses,
731		feature_importances,
732	};
733	writer.write(&linear_regressor_writer)
734}
735
736fn serialize_linear_train_options(
737	train_options: &tangram_linear::TrainOptions,
738	writer: &mut buffalo::Writer,
739) -> buffalo::Position<tangram_model::LinearModelTrainOptionsWriter> {
740	let early_stopping_options =
741		train_options
742			.early_stopping_options
743			.as_ref()
744			.map(|early_stopping_options| {
745				serialize_early_stopping_options(early_stopping_options, writer)
746			});
747	let train_options = tangram_model::LinearModelTrainOptionsWriter {
748		compute_loss: train_options.compute_losses,
749		l2_regularization: train_options.l2_regularization,
750		learning_rate: train_options.learning_rate,
751		max_epochs: train_options.max_epochs.to_u64().unwrap(),
752		n_examples_per_batch: train_options.n_examples_per_batch.to_u64().unwrap(),
753		early_stopping_options,
754	};
755	writer.write(&train_options)
756}
757
758fn serialize_tree_train_options(
759	train_options: &tangram_tree::TrainOptions,
760	writer: &mut buffalo::Writer,
761) -> buffalo::Position<tangram_model::TreeModelTrainOptionsWriter> {
762	let early_stopping_options =
763		train_options
764			.early_stopping_options
765			.as_ref()
766			.map(|early_stopping_options| {
767				serialize_tree_early_stopping_options(early_stopping_options, writer)
768			});
769	let max_depth = train_options
770		.max_depth
771		.map(|max_depth| max_depth.to_u64().unwrap());
772	let binned_features_layout =
773		serialize_binned_features_layout(&train_options.binned_features_layout, writer);
774	let train_options = tangram_model::TreeModelTrainOptionsWriter {
775		compute_loss: train_options.compute_losses,
776		l2_regularization_for_continuous_splits: train_options
777			.l2_regularization_for_continuous_splits,
778		l2_regularization_for_discrete_splits: train_options.l2_regularization_for_discrete_splits,
779		learning_rate: train_options.learning_rate,
780		early_stopping_options,
781		binned_features_layout,
782		max_depth,
783		max_examples_for_computing_bin_thresholds: train_options
784			.max_examples_for_computing_bin_thresholds
785			.to_u64()
786			.unwrap(),
787		max_leaf_nodes: train_options.max_leaf_nodes.to_u64().unwrap(),
788		max_rounds: train_options.max_rounds.to_u64().unwrap(),
789		max_valid_bins_for_number_features: train_options
790			.max_valid_bins_for_number_features
791			.to_u8()
792			.unwrap(),
793		min_examples_per_node: train_options.min_examples_per_node.to_u64().unwrap(),
794		min_gain_to_split: train_options.min_gain_to_split,
795		min_sum_hessians_per_node: train_options.min_sum_hessians_per_node,
796		smoothing_factor_for_discrete_bin_sorting: train_options
797			.smoothing_factor_for_discrete_bin_sorting,
798	};
799	writer.write(&train_options)
800}
801
802fn serialize_tree_early_stopping_options(
803	early_stopping_options: &tangram_tree::EarlyStoppingOptions,
804	writer: &mut buffalo::Writer,
805) -> buffalo::Position<tangram_model::TreeEarlyStoppingOptionsWriter> {
806	let early_stopping_options = tangram_model::TreeEarlyStoppingOptionsWriter {
807		early_stopping_fraction: early_stopping_options.early_stopping_fraction,
808		n_rounds_without_improvement_to_stop: early_stopping_options
809			.n_rounds_without_improvement_to_stop
810			.to_u64()
811			.unwrap(),
812		min_decrease_in_loss_for_significant_change: early_stopping_options
813			.min_decrease_in_loss_for_significant_change,
814	};
815	writer.write(&early_stopping_options)
816}
817
818fn serialize_tree_regression_model(
819	tree_regression_model: &TreeRegressionModel,
820	writer: &mut buffalo::Writer,
821) -> buffalo::Position<tangram_model::TreeRegressorWriter> {
822	let feature_importances = writer.write(tree_regression_model.feature_importances.as_slice());
823	let train_options = serialize_tree_train_options(&tree_regression_model.train_options, writer);
824	let feature_groups = tree_regression_model
825		.feature_groups
826		.iter()
827		.map(|feature_group| serialize_feature_group(feature_group, writer))
828		.collect::<Vec<_>>();
829	let feature_groups = writer.write(&feature_groups);
830	let losses = tree_regression_model
831		.losses
832		.as_ref()
833		.map(|losses| writer.write(losses.as_slice()));
834	let model = tree_regression_model.model.to_writer(writer);
835	let model = tangram_model::TreeRegressorWriter {
836		model,
837		train_options,
838		feature_groups,
839		losses,
840		feature_importances,
841	};
842	writer.write(&model)
843}
844
845fn serialize_binned_features_layout(
846	binned_features_layout: &tangram_tree::BinnedFeaturesLayout,
847	_writer: &mut buffalo::Writer,
848) -> tangram_model::BinnedFeaturesLayoutWriter {
849	match binned_features_layout {
850		tangram_tree::BinnedFeaturesLayout::RowMajor => {
851			tangram_model::BinnedFeaturesLayoutWriter::RowMajor
852		}
853		tangram_tree::BinnedFeaturesLayout::ColumnMajor => {
854			tangram_model::BinnedFeaturesLayoutWriter::ColumnMajor
855		}
856	}
857}
858
859fn serialize_train_grid_item_output(
860	train_grid_item_output: &TrainGridItemOutput,
861	writer: &mut buffalo::Writer,
862) -> buffalo::Position<tangram_model::TrainGridItemOutputWriter> {
863	let hyperparameters = match &train_grid_item_output.train_model_output {
864		TrainModelOutput::LinearRegressor(model) => {
865			let options = serialize_linear_train_options(&model.train_options, writer);
866			tangram_model::ModelTrainOptionsWriter::Linear(options)
867		}
868		TrainModelOutput::TreeRegressor(model) => {
869			let options = serialize_tree_train_options(&model.train_options, writer);
870			tangram_model::ModelTrainOptionsWriter::Tree(options)
871		}
872		TrainModelOutput::LinearBinaryClassifier(model) => {
873			let options = serialize_linear_train_options(&model.train_options, writer);
874			tangram_model::ModelTrainOptionsWriter::Linear(options)
875		}
876		TrainModelOutput::TreeBinaryClassifier(model) => {
877			let options = serialize_tree_train_options(&model.train_options, writer);
878			tangram_model::ModelTrainOptionsWriter::Tree(options)
879		}
880		TrainModelOutput::LinearMulticlassClassifier(model) => {
881			let options = serialize_linear_train_options(&model.train_options, writer);
882			tangram_model::ModelTrainOptionsWriter::Linear(options)
883		}
884		TrainModelOutput::TreeMulticlassClassifier(model) => {
885			let options = serialize_tree_train_options(&model.train_options, writer);
886			tangram_model::ModelTrainOptionsWriter::Tree(options)
887		}
888	};
889	let train_grid_item_output_writer = tangram_model::TrainGridItemOutputWriter {
890		comparison_metric_value: train_grid_item_output.comparison_metric_value,
891		hyperparameters,
892		duration: train_grid_item_output.duration.as_secs_f32(),
893	};
894	writer.write(&train_grid_item_output_writer)
895}
896
897fn serialize_feature_group(
898	feature_group: &tangram_features::FeatureGroup,
899	writer: &mut buffalo::Writer,
900) -> tangram_model::FeatureGroupWriter {
901	match feature_group {
902		tangram_features::FeatureGroup::Identity(feature_group) => {
903			let feature_group = serialize_identity_feature_group(feature_group, writer);
904			tangram_model::FeatureGroupWriter::Identity(feature_group)
905		}
906		tangram_features::FeatureGroup::Normalized(feature_group) => {
907			let feature_group = serialize_normalized_feature_group(feature_group, writer);
908			tangram_model::FeatureGroupWriter::Normalized(feature_group)
909		}
910		tangram_features::FeatureGroup::OneHotEncoded(feature_group) => {
911			let feature_group = serialize_one_hot_encoded_feature_group(feature_group, writer);
912			tangram_model::FeatureGroupWriter::OneHotEncoded(feature_group)
913		}
914		tangram_features::FeatureGroup::BagOfWords(feature_group) => {
915			let feature_group = serialize_bag_of_words_feature_group(feature_group, writer);
916			tangram_model::FeatureGroupWriter::BagOfWords(feature_group)
917		}
918		tangram_features::FeatureGroup::BagOfWordsCosineSimilarity(feature_group) => {
919			let feature_group =
920				serialize_bag_of_words_cosine_similarity_feature_group(feature_group, writer);
921			tangram_model::FeatureGroupWriter::BagOfWordsCosineSimilarity(feature_group)
922		}
923		tangram_features::FeatureGroup::WordEmbedding(feature_group) => {
924			let feature_group = serialize_word_embedding_feature_group(feature_group, writer);
925			tangram_model::FeatureGroupWriter::WordEmbedding(feature_group)
926		}
927	}
928}
929
930fn serialize_identity_feature_group(
931	identity_feature_group: &tangram_features::IdentityFeatureGroup,
932	writer: &mut buffalo::Writer,
933) -> buffalo::Position<tangram_model::IdentityFeatureGroupWriter> {
934	let source_column_name = writer.write(identity_feature_group.source_column_name.as_str());
935	let feature_group = tangram_model::IdentityFeatureGroupWriter { source_column_name };
936	writer.write(&feature_group)
937}
938
939fn serialize_normalized_feature_group(
940	normalized_feature_group: &tangram_features::NormalizedFeatureGroup,
941	writer: &mut buffalo::Writer,
942) -> buffalo::Position<tangram_model::NormalizedFeatureGroupWriter> {
943	let source_column_name = writer.write(normalized_feature_group.source_column_name.as_str());
944	let feature_group = tangram_model::NormalizedFeatureGroupWriter {
945		source_column_name,
946		mean: normalized_feature_group.mean,
947		variance: normalized_feature_group.variance,
948	};
949	writer.write(&feature_group)
950}
951
952fn serialize_one_hot_encoded_feature_group(
953	one_hot_encoded_feature_group: &tangram_features::OneHotEncodedFeatureGroup,
954	writer: &mut buffalo::Writer,
955) -> buffalo::Position<tangram_model::OneHotEncodedFeatureGroupWriter> {
956	let source_column_name =
957		writer.write(one_hot_encoded_feature_group.source_column_name.as_str());
958	let variants = one_hot_encoded_feature_group
959		.variants
960		.iter()
961		.map(|variant| writer.write(variant))
962		.collect::<Vec<_>>();
963	let variants = writer.write(&variants);
964	let feature_group = tangram_model::OneHotEncodedFeatureGroupWriter {
965		source_column_name,
966		variants,
967	};
968	writer.write(&feature_group)
969}
970
971fn serialize_bag_of_words_feature_group(
972	bag_of_words_feature_group: &tangram_features::BagOfWordsFeatureGroup,
973	writer: &mut buffalo::Writer,
974) -> buffalo::Position<tangram_model::BagOfWordsFeatureGroupWriter> {
975	let source_column_name = writer.write(bag_of_words_feature_group.source_column_name.as_str());
976	let tokenizer = serialize_tokenizer(&bag_of_words_feature_group.tokenizer, writer);
977	let ngrams = bag_of_words_feature_group
978		.ngrams
979		.iter()
980		.map(|(ngram, entry)| {
981			(
982				serialize_ngram(ngram, writer),
983				serialize_bag_of_words_feature_group_n_gram_entry(entry, writer),
984			)
985		})
986		.collect::<Vec<_>>();
987	let ngrams = writer.write(&ngrams);
988	let ngram_types = bag_of_words_feature_group
989		.ngram_types
990		.iter()
991		.map(|ngram_type| serialize_ngram_type(ngram_type, writer))
992		.collect::<Vec<_>>();
993	let ngram_types = writer.write(&ngram_types);
994	let strategy =
995		serialize_bag_of_words_feature_group_strategy(&bag_of_words_feature_group.strategy, writer);
996	let feature_group = tangram_model::BagOfWordsFeatureGroupWriter {
997		source_column_name,
998		tokenizer,
999		strategy,
1000		ngram_types,
1001		ngrams,
1002	};
1003	writer.write(&feature_group)
1004}
1005
1006fn serialize_bag_of_words_feature_group_strategy(
1007	bag_of_words_feature_group_strategy: &tangram_features::bag_of_words::BagOfWordsFeatureGroupStrategy,
1008	_writer: &mut buffalo::Writer,
1009) -> tangram_model::BagOfWordsFeatureGroupStrategyWriter {
1010	match bag_of_words_feature_group_strategy {
1011		tangram_features::bag_of_words::BagOfWordsFeatureGroupStrategy::Present => {
1012			tangram_model::BagOfWordsFeatureGroupStrategyWriter::Present
1013		}
1014		tangram_features::bag_of_words::BagOfWordsFeatureGroupStrategy::Count => {
1015			tangram_model::BagOfWordsFeatureGroupStrategyWriter::Count
1016		}
1017		tangram_features::bag_of_words::BagOfWordsFeatureGroupStrategy::TfIdf => {
1018			tangram_model::BagOfWordsFeatureGroupStrategyWriter::TfIdf
1019		}
1020	}
1021}
1022
1023fn serialize_bag_of_words_feature_group_n_gram_entry(
1024	bag_of_words_feature_group_n_gram_entry: &tangram_features::bag_of_words::BagOfWordsFeatureGroupNGramEntry,
1025	writer: &mut buffalo::Writer,
1026) -> buffalo::Position<tangram_model::BagOfWordsFeatureGroupNGramEntryWriter> {
1027	writer.write(&tangram_model::BagOfWordsFeatureGroupNGramEntryWriter {
1028		idf: bag_of_words_feature_group_n_gram_entry.idf,
1029	})
1030}
1031
1032fn serialize_word_embedding_feature_group(
1033	word_embedding_feature_group: &tangram_features::WordEmbeddingFeatureGroup,
1034	writer: &mut buffalo::Writer,
1035) -> buffalo::Position<tangram_model::WordEmbeddingFeatureGroupWriter> {
1036	let source_column_name = writer.write(word_embedding_feature_group.source_column_name.as_str());
1037	let tokenizer = serialize_tokenizer(&word_embedding_feature_group.tokenizer, writer);
1038	let model = serialize_word_embedding_model(&word_embedding_feature_group.model, writer);
1039	let feature_group = tangram_model::WordEmbeddingFeatureGroupWriter {
1040		source_column_name,
1041		tokenizer,
1042		model,
1043	};
1044	writer.write(&feature_group)
1045}
1046
1047fn serialize_word_embedding_model(
1048	word_embedding_model: &tangram_text::WordEmbeddingModel,
1049	writer: &mut buffalo::Writer,
1050) -> buffalo::Position<tangram_model::WordEmbeddingModelWriter> {
1051	let size = word_embedding_model.size.to_u64().unwrap();
1052	let values = writer.write(word_embedding_model.values.as_slice());
1053	let words = word_embedding_model
1054		.words
1055		.keys()
1056		.map(|word| writer.write(word))
1057		.collect::<Vec<_>>();
1058	let words = zip!(words, word_embedding_model.words.values())
1059		.map(|(key, index)| (key, index.to_u64().unwrap()))
1060		.collect::<Vec<_>>();
1061	let words = writer.write(&words);
1062	writer.write(&tangram_model::WordEmbeddingModelWriter {
1063		size,
1064		words,
1065		values,
1066	})
1067}
1068
1069fn serialize_bag_of_words_cosine_similarity_feature_group(
1070	bag_of_words_cosine_similarity_feature_group: &tangram_features::BagOfWordsCosineSimilarityFeatureGroup,
1071	writer: &mut buffalo::Writer,
1072) -> buffalo::Position<tangram_model::BagOfWordsCosineSimilarityFeatureGroupWriter> {
1073	let source_column_name_a = writer.write(
1074		bag_of_words_cosine_similarity_feature_group
1075			.source_column_name_a
1076			.as_str(),
1077	);
1078	let source_column_name_b = writer.write(
1079		bag_of_words_cosine_similarity_feature_group
1080			.source_column_name_b
1081			.as_str(),
1082	);
1083	let tokenizer = serialize_tokenizer(
1084		&bag_of_words_cosine_similarity_feature_group.tokenizer,
1085		writer,
1086	);
1087	let ngrams = bag_of_words_cosine_similarity_feature_group
1088		.ngrams
1089		.iter()
1090		.map(|(ngram, entry)| {
1091			(
1092				serialize_ngram(ngram, writer),
1093				serialize_bag_of_words_feature_group_n_gram_entry(entry, writer),
1094			)
1095		})
1096		.collect::<Vec<_>>();
1097	let ngrams = writer.write(&ngrams);
1098	let ngram_types = bag_of_words_cosine_similarity_feature_group
1099		.ngram_types
1100		.iter()
1101		.map(|ngram_type| serialize_ngram_type(ngram_type, writer))
1102		.collect::<Vec<_>>();
1103	let ngram_types = writer.write(&ngram_types);
1104	let strategy = serialize_bag_of_words_feature_group_strategy(
1105		&bag_of_words_cosine_similarity_feature_group.strategy,
1106		writer,
1107	);
1108	let feature_group = tangram_model::BagOfWordsCosineSimilarityFeatureGroupWriter {
1109		source_column_name_a,
1110		source_column_name_b,
1111		tokenizer,
1112		strategy,
1113		ngram_types,
1114		ngrams,
1115	};
1116	writer.write(&feature_group)
1117}
1118
1119fn serialize_binary_classification_model(
1120	binary_classification_model: &BinaryClassificationModel,
1121	writer: &mut buffalo::Writer,
1122) -> tangram_model::BinaryClassificationModelWriter {
1123	match binary_classification_model {
1124		BinaryClassificationModel::Linear(model) => {
1125			let linear_binary_classifier =
1126				serialize_linear_binary_classification_model(model, writer);
1127			tangram_model::BinaryClassificationModelWriter::Linear(linear_binary_classifier)
1128		}
1129		BinaryClassificationModel::Tree(model) => {
1130			let tree_binary_classifier = serialize_tree_binary_classification_model(model, writer);
1131			tangram_model::BinaryClassificationModelWriter::Tree(tree_binary_classifier)
1132		}
1133	}
1134}
1135
1136fn serialize_linear_binary_classification_model(
1137	linear_binary_classification_model: &LinearBinaryClassificationModel,
1138	writer: &mut buffalo::Writer,
1139) -> buffalo::Position<tangram_model::LinearBinaryClassifierWriter> {
1140	let model = linear_binary_classification_model.model.to_writer(writer);
1141	let train_options =
1142		serialize_linear_train_options(&linear_binary_classification_model.train_options, writer);
1143	let feature_groups = linear_binary_classification_model
1144		.feature_groups
1145		.iter()
1146		.map(|feature_group| serialize_feature_group(feature_group, writer))
1147		.collect::<Vec<_>>();
1148	let feature_groups = writer.write(&feature_groups);
1149	let losses = linear_binary_classification_model
1150		.losses
1151		.as_ref()
1152		.map(|losses| writer.write(losses.as_slice()));
1153	let feature_importances = writer.write(
1154		linear_binary_classification_model
1155			.feature_importances
1156			.as_slice(),
1157	);
1158	let model = tangram_model::LinearBinaryClassifierWriter {
1159		model,
1160		train_options,
1161		feature_groups,
1162		losses,
1163		feature_importances,
1164	};
1165	writer.write(&model)
1166}
1167
1168fn serialize_tree_binary_classification_model(
1169	tree_binary_classification_model: &TreeBinaryClassificationModel,
1170	writer: &mut buffalo::Writer,
1171) -> buffalo::Position<tangram_model::TreeBinaryClassifierWriter> {
1172	let feature_importances = writer.write(
1173		tree_binary_classification_model
1174			.feature_importances
1175			.as_slice(),
1176	);
1177	let train_options =
1178		serialize_tree_train_options(&tree_binary_classification_model.train_options, writer);
1179	let feature_groups = tree_binary_classification_model
1180		.feature_groups
1181		.iter()
1182		.map(|feature_group| serialize_feature_group(feature_group, writer))
1183		.collect::<Vec<_>>();
1184	let feature_groups = writer.write(&feature_groups);
1185	let losses = tree_binary_classification_model
1186		.losses
1187		.as_ref()
1188		.map(|losses| writer.write(losses.as_slice()));
1189	let model = tree_binary_classification_model.model.to_writer(writer);
1190	let model = tangram_model::TreeBinaryClassifierWriter {
1191		model,
1192		train_options,
1193		feature_groups,
1194		losses,
1195		feature_importances,
1196	};
1197	writer.write(&model)
1198}
1199
1200fn serialize_binary_classification_metrics_output(
1201	binary_classification_metrics_output: &tangram_metrics::BinaryClassificationMetricsOutput,
1202	writer: &mut buffalo::Writer,
1203) -> buffalo::Position<tangram_model::BinaryClassificationMetricsWriter> {
1204	let thresholds = binary_classification_metrics_output
1205		.thresholds
1206		.iter()
1207		.map(|threshold| {
1208			serialize_binary_classification_metrics_output_for_threshold(threshold, writer)
1209		})
1210		.collect::<Vec<_>>();
1211	let thresholds = writer.write(&thresholds);
1212	let default_threshold = serialize_binary_classification_metrics_output_for_threshold(
1213		&binary_classification_metrics_output.thresholds
1214			[binary_classification_metrics_output.thresholds.len() / 2],
1215		writer,
1216	);
1217	let metrics = tangram_model::BinaryClassificationMetricsWriter {
1218		auc_roc: binary_classification_metrics_output.auc_roc_approx,
1219		default_threshold,
1220		thresholds,
1221	};
1222	writer.write(&metrics)
1223}
1224
1225fn serialize_binary_classification_metrics_output_for_threshold(
1226	binary_classification_metrics_output_for_threshold: &tangram_metrics::BinaryClassificationMetricsOutputForThreshold,
1227	writer: &mut buffalo::Writer,
1228) -> buffalo::Position<tangram_model::BinaryClassificationMetricsForThresholdWriter> {
1229	let metrics = tangram_model::BinaryClassificationMetricsForThresholdWriter {
1230		threshold: binary_classification_metrics_output_for_threshold.threshold,
1231		true_positives: binary_classification_metrics_output_for_threshold
1232			.true_positives
1233			.to_u64()
1234			.unwrap(),
1235		false_positives: binary_classification_metrics_output_for_threshold
1236			.false_positives
1237			.to_u64()
1238			.unwrap(),
1239		true_negatives: binary_classification_metrics_output_for_threshold
1240			.true_negatives
1241			.to_u64()
1242			.unwrap(),
1243		false_negatives: binary_classification_metrics_output_for_threshold
1244			.false_negatives
1245			.to_u64()
1246			.unwrap(),
1247		accuracy: binary_classification_metrics_output_for_threshold.accuracy,
1248		precision: binary_classification_metrics_output_for_threshold.precision,
1249		recall: binary_classification_metrics_output_for_threshold.recall,
1250		f1_score: binary_classification_metrics_output_for_threshold.f1_score,
1251		true_positive_rate: binary_classification_metrics_output_for_threshold.true_positive_rate,
1252		false_positive_rate: binary_classification_metrics_output_for_threshold.false_positive_rate,
1253	};
1254	writer.write(&metrics)
1255}
1256
1257fn serialize_binary_classification_comparison_metric(
1258	binary_classification_comparison_metric: &BinaryClassificationComparisonMetric,
1259	_writer: &mut buffalo::Writer,
1260) -> tangram_model::BinaryClassificationComparisonMetricWriter {
1261	match binary_classification_comparison_metric {
1262		BinaryClassificationComparisonMetric::AucRoc => {
1263			tangram_model::BinaryClassificationComparisonMetricWriter::Aucroc
1264		}
1265	}
1266}
1267
1268fn serialize_multiclass_classification_model(
1269	multiclass_classification_model: &MulticlassClassificationModel,
1270	writer: &mut buffalo::Writer,
1271) -> tangram_model::MulticlassClassificationModelWriter {
1272	match multiclass_classification_model {
1273		MulticlassClassificationModel::Linear(model) => {
1274			let linear_multiclass_classifier =
1275				serialize_linear_multiclass_classification_model(model, writer);
1276			tangram_model::MulticlassClassificationModelWriter::Linear(linear_multiclass_classifier)
1277		}
1278		MulticlassClassificationModel::Tree(model) => {
1279			let tree_multiclass_classifier =
1280				serialize_tree_multiclass_classification_model(model, writer);
1281			tangram_model::MulticlassClassificationModelWriter::Tree(tree_multiclass_classifier)
1282		}
1283	}
1284}
1285
1286fn serialize_linear_multiclass_classification_model(
1287	linear_multiclass_classification_model: &LinearMulticlassClassificationModel,
1288	writer: &mut buffalo::Writer,
1289) -> buffalo::Position<tangram_model::LinearMulticlassClassifierWriter> {
1290	let feature_importances = writer.write(
1291		linear_multiclass_classification_model
1292			.feature_importances
1293			.as_slice(),
1294	);
1295	let train_options = serialize_linear_train_options(
1296		&linear_multiclass_classification_model.train_options,
1297		writer,
1298	);
1299	let feature_groups = linear_multiclass_classification_model
1300		.feature_groups
1301		.iter()
1302		.map(|feature_group| serialize_feature_group(feature_group, writer))
1303		.collect::<Vec<_>>();
1304	let feature_groups = writer.write(&feature_groups);
1305	let losses = linear_multiclass_classification_model
1306		.losses
1307		.as_ref()
1308		.map(|losses| writer.write(losses.as_slice()));
1309	let model = linear_multiclass_classification_model
1310		.model
1311		.to_writer(writer);
1312	let model = tangram_model::LinearMulticlassClassifierWriter {
1313		model,
1314		train_options,
1315		feature_groups,
1316		losses,
1317		feature_importances,
1318	};
1319	writer.write(&model)
1320}
1321
1322fn serialize_tree_multiclass_classification_model(
1323	tree_multiclass_classification_model: &TreeMulticlassClassificationModel,
1324	writer: &mut buffalo::Writer,
1325) -> buffalo::Position<tangram_model::TreeMulticlassClassifierWriter> {
1326	let feature_importances = writer.write(
1327		tree_multiclass_classification_model
1328			.feature_importances
1329			.as_slice(),
1330	);
1331	let train_options =
1332		serialize_tree_train_options(&tree_multiclass_classification_model.train_options, writer);
1333	let feature_groups = tree_multiclass_classification_model
1334		.feature_groups
1335		.iter()
1336		.map(|feature_group| serialize_feature_group(feature_group, writer))
1337		.collect::<Vec<_>>();
1338	let feature_groups = writer.write(&feature_groups);
1339	let losses = tree_multiclass_classification_model
1340		.losses
1341		.as_ref()
1342		.map(|losses| writer.write(losses.as_slice()));
1343	let model = tree_multiclass_classification_model.model.to_writer(writer);
1344	let model = tangram_model::TreeMulticlassClassifierWriter {
1345		model,
1346		train_options,
1347		feature_groups,
1348		losses,
1349		feature_importances,
1350	};
1351	writer.write(&model)
1352}
1353
1354fn serialize_multiclass_classification_metrics_output(
1355	multiclass_classification_metrics_output: &tangram_metrics::MulticlassClassificationMetricsOutput,
1356	writer: &mut buffalo::Writer,
1357) -> buffalo::Position<tangram_model::MulticlassClassificationMetricsWriter> {
1358	let class_metrics = multiclass_classification_metrics_output
1359		.class_metrics
1360		.iter()
1361		.map(|class_metric| serialize_class_metrics(class_metric, writer))
1362		.collect::<Vec<_>>();
1363	let class_metrics = writer.write(&class_metrics);
1364	let metrics = tangram_model::MulticlassClassificationMetricsWriter {
1365		class_metrics,
1366		accuracy: multiclass_classification_metrics_output.accuracy,
1367		precision_unweighted: multiclass_classification_metrics_output.precision_unweighted,
1368		precision_weighted: multiclass_classification_metrics_output.precision_weighted,
1369		recall_unweighted: multiclass_classification_metrics_output.recall_weighted,
1370		recall_weighted: multiclass_classification_metrics_output.recall_weighted,
1371	};
1372	writer.write(&metrics)
1373}
1374
1375fn serialize_class_metrics(
1376	class_metrics: &tangram_metrics::ClassMetrics,
1377	writer: &mut buffalo::Writer,
1378) -> buffalo::Position<tangram_model::ClassMetricsWriter> {
1379	let metrics = tangram_model::ClassMetricsWriter {
1380		true_positives: class_metrics.true_positives.to_u64().unwrap(),
1381		false_positives: class_metrics.false_positives.to_u64().unwrap(),
1382		true_negatives: class_metrics.true_negatives.to_u64().unwrap(),
1383		false_negatives: class_metrics.false_negatives.to_u64().unwrap(),
1384		accuracy: class_metrics.accuracy,
1385		precision: class_metrics.precision,
1386		recall: class_metrics.recall,
1387		f1_score: class_metrics.f1_score,
1388	};
1389	writer.write(&metrics)
1390}
1391
1392fn serialize_multiclass_classification_comparison_metric(
1393	multiclass_classification_comparison_metric: &MulticlassClassificationComparisonMetric,
1394	_writer: &mut buffalo::Writer,
1395) -> tangram_model::MulticlassClassificationComparisonMetricWriter {
1396	match multiclass_classification_comparison_metric {
1397		MulticlassClassificationComparisonMetric::Accuracy => {
1398			tangram_model::MulticlassClassificationComparisonMetricWriter::Accuracy
1399		}
1400	}
1401}