1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use crate::{
	ColumnStats, FeatureGroup, LinearModelTrainOptions, StatsSettings, TrainGridItemOutput,
	TreeModelTrainOptions,
};

/// Serialized form of a trained binary classifier, together with the dataset
/// statistics, grid-search results and evaluation metrics recorded for it.
///
/// Each field carries a buffalo `id`. NOTE(review): presumably these ids identify
/// fields in the encoded file, so changing or reusing an id would break reading
/// previously written models — confirm against the buffalo format docs.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct BinaryClassifier {
	/// Name of the column the classifier predicts.
	#[buffalo(id = 0, required)]
	pub target_column_name: String,
	/// Label of the negative class of the target column.
	#[buffalo(id = 1, required)]
	pub negative_class: String,
	/// Label of the positive class of the target column.
	#[buffalo(id = 2, required)]
	pub positive_class: String,
	/// Number of rows in the training split.
	#[buffalo(id = 3, required)]
	pub train_row_count: u64,
	/// Number of rows in the test split.
	#[buffalo(id = 4, required)]
	pub test_row_count: u64,
	/// Number of rows in the overall dataset
	/// (presumably train + test — TODO confirm).
	#[buffalo(id = 5, required)]
	pub overall_row_count: u64,
	/// Settings that were used when computing the column stats below.
	#[buffalo(id = 6, required)]
	pub stats_settings: StatsSettings,
	/// Per-column stats over the overall dataset
	/// (presumably excluding the target column, which is stored separately — verify).
	#[buffalo(id = 7, required)]
	pub overall_column_stats: Vec<ColumnStats>,
	/// Stats of the target column over the overall dataset.
	#[buffalo(id = 8, required)]
	pub overall_target_column_stats: ColumnStats,
	/// Per-column stats over the training split.
	#[buffalo(id = 9, required)]
	pub train_column_stats: Vec<ColumnStats>,
	/// Stats of the target column over the training split.
	#[buffalo(id = 10, required)]
	pub train_target_column_stats: ColumnStats,
	/// Per-column stats over the test split.
	#[buffalo(id = 11, required)]
	pub test_column_stats: Vec<ColumnStats>,
	/// Stats of the target column over the test split.
	#[buffalo(id = 12, required)]
	pub test_target_column_stats: ColumnStats,
	/// Metrics of a baseline predictor
	/// (NOTE(review): exact baseline definition not visible here — confirm).
	#[buffalo(id = 13, required)]
	pub baseline_metrics: BinaryClassificationMetrics,
	/// Metric used to compare grid items against each other.
	#[buffalo(id = 14, required)]
	pub comparison_metric: BinaryClassificationComparisonMetric,
	/// Outputs of every item in the training grid.
	#[buffalo(id = 15, required)]
	pub train_grid_item_outputs: Vec<TrainGridItemOutput>,
	/// Index of the best grid item
	/// (presumably an index into `train_grid_item_outputs` — verify).
	#[buffalo(id = 16, required)]
	pub best_grid_item_index: u64,
	/// The trained model, either linear or tree-based.
	#[buffalo(id = 17, required)]
	pub model: BinaryClassificationModel,
	/// Metrics of `model` evaluated on the test split.
	#[buffalo(id = 18, required)]
	pub test_metrics: BinaryClassificationMetrics,
}

/// Metric used to compare binary classifiers trained during the grid search.
///
/// Encoded as a static-size enum with `value_size = 0`, matching the fact that
/// no variant carries a payload.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "static", value_size = 0)]
pub enum BinaryClassificationComparisonMetric {
	/// Area under the ROC curve.
	#[buffalo(id = 0)]
	Aucroc,
}

/// Evaluation metrics for a binary classifier.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct BinaryClassificationMetrics {
	/// Area under the ROC curve.
	#[buffalo(id = 0, required)]
	pub auc_roc: f32,
	/// Metrics computed at the default classification threshold
	/// (the threshold value itself is stored in the entry).
	#[buffalo(id = 1, required)]
	pub default_threshold: BinaryClassificationMetricsForThreshold,
	/// Metrics computed at each of a set of thresholds
	/// (NOTE(review): ordering and spacing of thresholds not visible here — confirm).
	#[buffalo(id = 2, required)]
	pub thresholds: Vec<BinaryClassificationMetricsForThreshold>,
}

/// Confusion-matrix counts and derived metrics for one classification threshold.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct BinaryClassificationMetricsForThreshold {
	/// Threshold at which these metrics were computed.
	#[buffalo(id = 0, required)]
	pub threshold: f32,
	/// Count of true positives.
	#[buffalo(id = 1, required)]
	pub true_positives: u64,
	/// Count of false positives.
	#[buffalo(id = 2, required)]
	pub false_positives: u64,
	/// Count of true negatives.
	#[buffalo(id = 3, required)]
	pub true_negatives: u64,
	/// Count of false negatives.
	#[buffalo(id = 4, required)]
	pub false_negatives: u64,
	/// Accuracy at this threshold.
	#[buffalo(id = 5, required)]
	pub accuracy: f32,
	/// Precision at this threshold. Optional; presumably `None` when it is
	/// undefined (zero denominator) — confirm against the metrics code.
	#[buffalo(id = 6, required)]
	pub precision: Option<f32>,
	/// Recall at this threshold. Optional; presumably `None` when undefined — confirm.
	#[buffalo(id = 7, required)]
	pub recall: Option<f32>,
	/// F1 score at this threshold. Optional; presumably `None` when undefined — confirm.
	#[buffalo(id = 8, required)]
	pub f1_score: Option<f32>,
	/// True positive rate at this threshold.
	#[buffalo(id = 9, required)]
	pub true_positive_rate: f32,
	/// False positive rate at this threshold.
	#[buffalo(id = 10, required)]
	pub false_positive_rate: f32,
}

/// The trained binary classification model: either a linear model or a tree model.
///
/// NOTE(review): `value_size = 8` presumably is the fixed encoded size of each
/// variant's payload slot (e.g. a reference to the dynamic struct) — confirm
/// against the buffalo format.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "static", value_size = 8)]
pub enum BinaryClassificationModel {
	/// Linear binary classifier.
	#[buffalo(id = 0)]
	Linear(LinearBinaryClassifier),
	/// Tree-based binary classifier.
	#[buffalo(id = 1)]
	Tree(TreeBinaryClassifier),
}

/// A linear binary classifier, together with how it was trained.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct LinearBinaryClassifier {
	/// The serialized linear model from `tangram_linear`.
	#[buffalo(id = 0, required)]
	pub model: tangram_linear::serialize::BinaryClassifier,
	/// Options the linear model was trained with.
	#[buffalo(id = 1, required)]
	pub train_options: LinearModelTrainOptions,
	/// Feature groups used to turn input columns into model features.
	#[buffalo(id = 2, required)]
	pub feature_groups: Vec<FeatureGroup>,
	/// Training losses, if recorded. Presumably one value per training
	/// iteration/epoch — TODO confirm.
	#[buffalo(id = 3, required)]
	pub losses: Option<Vec<f32>>,
	/// Feature importances. Presumably one value per feature — verify.
	#[buffalo(id = 4, required)]
	pub feature_importances: Vec<f32>,
}

/// A tree-based binary classifier, together with how it was trained.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct TreeBinaryClassifier {
	/// The serialized tree model from `tangram_tree`.
	#[buffalo(id = 0, required)]
	pub model: tangram_tree::serialize::BinaryClassifier,
	/// Options the tree model was trained with.
	#[buffalo(id = 1, required)]
	pub train_options: TreeModelTrainOptions,
	/// Feature groups used to turn input columns into model features.
	#[buffalo(id = 2, required)]
	pub feature_groups: Vec<FeatureGroup>,
	/// Training losses, if recorded. Presumably one value per training
	/// round — TODO confirm.
	#[buffalo(id = 3, required)]
	pub losses: Option<Vec<f32>>,
	/// Feature importances. Presumably one value per feature — verify.
	#[buffalo(id = 4, required)]
	pub feature_importances: Vec<f32>,
}