1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use crate::{
	ColumnStats, FeatureGroup, LinearModelTrainOptions, StatsSettings, TrainGridItemOutput,
	TreeModelTrainOptions,
};

/// Serialized form of a trained multiclass classifier, stored together with the
/// column statistics, per-grid-item training outputs, and metrics recorded for it.
///
/// Every field carries an explicit `buffalo` id.
/// NOTE(review): these ids are presumably part of the persisted format. Existing
/// ids should therefore not be renumbered, retyped, or reused, or previously
/// written models may no longer read back. Confirm against the `buffalo` crate
/// before editing.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct MulticlassClassifier {
	/// Name of the target column.
	#[buffalo(id = 0, required)]
	pub target_column_name: String,
	/// Names of the target classes.
	/// NOTE(review): presumably ordered like the model's per-class outputs — confirm.
	#[buffalo(id = 1, required)]
	pub classes: Vec<String>,
	/// Number of rows in the training split.
	#[buffalo(id = 2, required)]
	pub train_row_count: u64,
	/// Number of rows in the test split.
	#[buffalo(id = 3, required)]
	pub test_row_count: u64,
	/// Number of rows in the whole dataset.
	/// NOTE(review): presumably `train_row_count + test_row_count` — confirm.
	#[buffalo(id = 4, required)]
	pub overall_row_count: u64,
	/// Settings for computing column statistics.
	/// NOTE(review): presumably the settings the `*_column_stats` fields below were
	/// produced with — confirm.
	#[buffalo(id = 5, required)]
	pub stats_settings: StatsSettings,
	/// Per-column statistics over the whole dataset.
	/// NOTE(review): presumably excludes the target column, whose statistics are
	/// stored in `overall_target_column_stats` — confirm.
	#[buffalo(id = 6, required)]
	pub overall_column_stats: Vec<ColumnStats>,
	/// Statistics of the target column over the whole dataset.
	#[buffalo(id = 7, required)]
	pub overall_target_column_stats: ColumnStats,
	/// Per-column statistics over the training split.
	#[buffalo(id = 8, required)]
	pub train_column_stats: Vec<ColumnStats>,
	/// Statistics of the target column over the training split.
	#[buffalo(id = 9, required)]
	pub train_target_column_stats: ColumnStats,
	/// Per-column statistics over the test split.
	#[buffalo(id = 10, required)]
	pub test_column_stats: Vec<ColumnStats>,
	/// Statistics of the target column over the test split.
	#[buffalo(id = 11, required)]
	pub test_target_column_stats: ColumnStats,
	/// Metrics of a baseline predictor, for comparison with `test_metrics`.
	/// NOTE(review): which baseline is used is not visible here — confirm.
	#[buffalo(id = 12, required)]
	pub baseline_metrics: MulticlassClassificationMetrics,
	/// Metric used to compare trained models.
	/// NOTE(review): presumably the metric used to rank `train_grid_item_outputs`
	/// and choose `best_grid_item_index` — confirm.
	#[buffalo(id = 13, required)]
	pub comparison_metric: MulticlassClassificationComparisonMetric,
	/// Output recorded for each trained grid item.
	#[buffalo(id = 14, required)]
	pub train_grid_item_outputs: Vec<TrainGridItemOutput>,
	/// Index of the selected grid item.
	/// NOTE(review): presumably an index into `train_grid_item_outputs` — confirm.
	#[buffalo(id = 15, required)]
	pub best_grid_item_index: u64,
	/// The trained model, either linear or tree-based.
	#[buffalo(id = 16, required)]
	pub model: MulticlassClassificationModel,
	/// Metrics of `model` on the test split.
	#[buffalo(id = 17, required)]
	pub test_metrics: MulticlassClassificationMetrics,
}

/// Metric by which multiclass classifiers are compared.
///
/// NOTE(review): `value_size = 0` is presumably correct only because no variant
/// carries a payload. Revisit it (see the `buffalo` docs) if a data-carrying
/// variant is ever added.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "static", value_size = 0)]
pub enum MulticlassClassificationComparisonMetric {
	/// Classification accuracy.
	#[buffalo(id = 0)]
	Accuracy,
}

/// Evaluation metrics for a multiclass classifier: per-class metrics plus
/// aggregate accuracy, precision, and recall.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct MulticlassClassificationMetrics {
	/// Metrics for each class.
	/// NOTE(review): presumably in the same order as the classifier's `classes`
	/// list — confirm.
	#[buffalo(id = 0, required)]
	pub class_metrics: Vec<ClassMetrics>,
	/// Overall accuracy across all classes.
	#[buffalo(id = 1, required)]
	pub accuracy: f32,
	/// Precision aggregated across classes without weighting.
	/// NOTE(review): presumably the plain (macro) mean of per-class precision — confirm.
	#[buffalo(id = 2, required)]
	pub precision_unweighted: f32,
	/// Precision aggregated across classes with per-class weights.
	/// NOTE(review): the weighting (e.g. by class support) is not visible here — confirm.
	#[buffalo(id = 3, required)]
	pub precision_weighted: f32,
	/// Recall aggregated across classes without weighting.
	/// NOTE(review): presumably the plain (macro) mean of per-class recall — confirm.
	#[buffalo(id = 4, required)]
	pub recall_unweighted: f32,
	/// Recall aggregated across classes with per-class weights.
	/// NOTE(review): the weighting (e.g. by class support) is not visible here — confirm.
	#[buffalo(id = 5, required)]
	pub recall_weighted: f32,
}

/// Confusion-matrix counts and derived metrics for a single class.
///
/// NOTE(review): presumably computed one-vs-rest, treating this class as
/// "positive" and all other classes as "negative" — confirm.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct ClassMetrics {
	/// Count of true positives for this class.
	#[buffalo(id = 0, required)]
	pub true_positives: u64,
	/// Count of false positives for this class.
	#[buffalo(id = 1, required)]
	pub false_positives: u64,
	/// Count of true negatives for this class.
	#[buffalo(id = 2, required)]
	pub true_negatives: u64,
	/// Count of false negatives for this class.
	#[buffalo(id = 3, required)]
	pub false_negatives: u64,
	/// Accuracy for this class.
	#[buffalo(id = 4, required)]
	pub accuracy: f32,
	/// Precision for this class.
	#[buffalo(id = 5, required)]
	pub precision: f32,
	/// Recall for this class.
	#[buffalo(id = 6, required)]
	pub recall: f32,
	/// F1 score for this class.
	#[buffalo(id = 7, required)]
	pub f1_score: f32,
}

/// A trained multiclass model: either a linear or a tree-based classifier.
///
/// NOTE(review): `value_size = 8` is presumably the fixed space reserved for each
/// variant's value, e.g. an offset to its dynamically sized payload. Confirm
/// against the `buffalo` docs before adding variants.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "static", value_size = 8)]
pub enum MulticlassClassificationModel {
	/// A linear multiclass classifier.
	#[buffalo(id = 0)]
	Linear(LinearMulticlassClassifier),
	/// A tree-based multiclass classifier.
	#[buffalo(id = 1)]
	Tree(TreeMulticlassClassifier),
}

/// A linear multiclass classifier, stored with its training options, feature
/// groups, losses, and feature importances.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct LinearMulticlassClassifier {
	/// The serialized linear model.
	#[buffalo(id = 0, required)]
	pub model: tangram_linear::serialize::MulticlassClassifier,
	/// Options the linear model was trained with.
	#[buffalo(id = 1, required)]
	pub train_options: LinearModelTrainOptions,
	/// Feature groups describing the model's input features.
	/// NOTE(review): presumably the feature-engineering steps applied to input
	/// columns before prediction — confirm.
	#[buffalo(id = 2, required)]
	pub feature_groups: Vec<FeatureGroup>,
	/// Training losses, if any were recorded.
	/// NOTE(review): the field is marked `required` but is an `Option`. Presumably
	/// `required` governs only the field's presence, and `None` still encodes
	/// "no losses recorded" — confirm against `buffalo`.
	#[buffalo(id = 3, required)]
	pub losses: Option<Vec<f32>>,
	/// Importance score for each feature.
	/// NOTE(review): presumably one entry per feature produced by `feature_groups` — confirm.
	#[buffalo(id = 4, required)]
	pub feature_importances: Vec<f32>,
}

/// A tree-based multiclass classifier, stored with its training options, feature
/// groups, losses, and feature importances.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct TreeMulticlassClassifier {
	/// The serialized tree model.
	#[buffalo(id = 0, required)]
	pub model: tangram_tree::serialize::MulticlassClassifier,
	/// Options the tree model was trained with.
	#[buffalo(id = 1, required)]
	pub train_options: TreeModelTrainOptions,
	/// Feature groups describing the model's input features.
	/// NOTE(review): presumably the feature-engineering steps applied to input
	/// columns before prediction — confirm.
	#[buffalo(id = 2, required)]
	pub feature_groups: Vec<FeatureGroup>,
	/// Training losses, if any were recorded.
	/// NOTE(review): the field is marked `required` but is an `Option`. Presumably
	/// `required` governs only the field's presence, and `None` still encodes
	/// "no losses recorded" — confirm against `buffalo`.
	#[buffalo(id = 3, required)]
	pub losses: Option<Vec<f32>>,
	/// Importance score for each feature.
	/// NOTE(review): presumably one entry per feature produced by `feature_groups` — confirm.
	#[buffalo(id = 4, required)]
	pub feature_importances: Vec<f32>,
}