use crate::{
	ColumnStats, FeatureGroup, LinearModelTrainOptions, StatsSettings, TrainGridItemOutput,
	TreeModelTrainOptions,
};

// A trained binary classifier, stored together with the dataset statistics,
// grid search outputs, and metrics recorded while training it.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct BinaryClassifier {
	#[buffalo(id = 0, required)]
	pub target_column_name: String,
	// The two classes of the target column.
	#[buffalo(id = 1, required)]
	pub negative_class: String,
	#[buffalo(id = 2, required)]
	pub positive_class: String,
	// Row counts for the train split, the test split, and the whole dataset.
	#[buffalo(id = 3, required)]
	pub train_row_count: u64,
	#[buffalo(id = 4, required)]
	pub test_row_count: u64,
	#[buffalo(id = 5, required)]
	pub overall_row_count: u64,
	#[buffalo(id = 6, required)]
	pub stats_settings: StatsSettings,
	// Column statistics computed over the whole dataset.
	#[buffalo(id = 7, required)]
	pub overall_column_stats: Vec<ColumnStats>,
	#[buffalo(id = 8, required)]
	pub overall_target_column_stats: ColumnStats,
	// Column statistics computed over the train split.
	#[buffalo(id = 9, required)]
	pub train_column_stats: Vec<ColumnStats>,
	#[buffalo(id = 10, required)]
	pub train_target_column_stats: ColumnStats,
	// Column statistics computed over the test split.
	#[buffalo(id = 11, required)]
	pub test_column_stats: Vec<ColumnStats>,
	#[buffalo(id = 12, required)]
	pub test_target_column_stats: ColumnStats,
	#[buffalo(id = 13, required)]
	pub baseline_metrics: BinaryClassificationMetrics,
	// Metric used to rank the grid items against each other.
	#[buffalo(id = 14, required)]
	pub comparison_metric: BinaryClassificationComparisonMetric,
	// One output per trained grid item; `best_grid_item_index` indexes into this list.
	#[buffalo(id = 15, required)]
	pub train_grid_item_outputs: Vec<TrainGridItemOutput>,
	#[buffalo(id = 16, required)]
	pub best_grid_item_index: u64,
	#[buffalo(id = 17, required)]
	pub model: BinaryClassificationModel,
	#[buffalo(id = 18, required)]
	pub test_metrics: BinaryClassificationMetrics,
}

#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "static", value_size = 0)]
pub enum BinaryClassificationComparisonMetric {
	#[buffalo(id = 0)]
	Aucroc,
}

#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct BinaryClassificationMetrics {
	#[buffalo(id = 0, required)]
	pub auc_roc: f32,
	// Metrics at the default classification threshold.
	#[buffalo(id = 1, required)]
	pub default_threshold: BinaryClassificationMetricsForThreshold,
	// Metrics computed at each of a series of thresholds.
	#[buffalo(id = 2, required)]
	pub thresholds: Vec<BinaryClassificationMetricsForThreshold>,
}

// Confusion matrix counts and derived metrics at a single threshold.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct BinaryClassificationMetricsForThreshold {
	#[buffalo(id = 0, required)]
	pub threshold: f32,
	#[buffalo(id = 1, required)]
	pub true_positives: u64,
	#[buffalo(id = 2, required)]
	pub false_positives: u64,
	#[buffalo(id = 3, required)]
	pub true_negatives: u64,
	#[buffalo(id = 4, required)]
	pub false_negatives: u64,
	#[buffalo(id = 5, required)]
	pub accuracy: f32,
	#[buffalo(id = 6, required)]
	pub precision: Option<f32>,
	#[buffalo(id = 7, required)]
	pub recall: Option<f32>,
	#[buffalo(id = 8, required)]
	pub f1_score: Option<f32>,
	#[buffalo(id = 9, required)]
	pub true_positive_rate: f32,
	#[buffalo(id = 10, required)]
	pub false_positive_rate: f32,
}

// The trained model: either a linear model or a tree model.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "static", value_size = 8)]
pub enum BinaryClassificationModel {
	#[buffalo(id = 0)]
	Linear(LinearBinaryClassifier),
	#[buffalo(id = 1)]
	Tree(TreeBinaryClassifier),
}

#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct LinearBinaryClassifier {
	#[buffalo(id = 0, required)]
	pub model: tangram_linear::serialize::BinaryClassifier,
	#[buffalo(id = 1, required)]
	pub train_options: LinearModelTrainOptions,
	#[buffalo(id = 2, required)]
	pub feature_groups: Vec<FeatureGroup>,
	// Training losses, if they were recorded.
	#[buffalo(id = 3, required)]
	pub losses: Option<Vec<f32>>,
	#[buffalo(id = 4, required)]
	pub feature_importances: Vec<f32>,
}

#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct TreeBinaryClassifier {
	#[buffalo(id = 0, required)]
	pub model: tangram_tree::serialize::BinaryClassifier,
	#[buffalo(id = 1, required)]
	pub train_options: TreeModelTrainOptions,
	#[buffalo(id = 2, required)]
	pub feature_groups: Vec<FeatureGroup>,
	// Training losses, if they were recorded.
	#[buffalo(id = 3, required)]
	pub losses: Option<Vec<f32>>,
	#[buffalo(id = 4, required)]
	pub feature_importances: Vec<f32>,
}