1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
use crate::{
    ColumnStats, FeatureGroup, LinearModelTrainOptions, StatsSettings, TrainGridItemOutput,
    TreeModelTrainOptions,
};

// NOTE(review): every type below derives `buffalo::Read` / `buffalo::Write`, so the
// `#[buffalo(...)]` attributes (field `id`s, `required`, `size`, `value_size`) are
// presumably part of the serialized model format. Treat them as a schema: do not
// renumber, reuse, or reorder `id`s without confirming buffalo's compatibility rules.

/// A multiclass classifier as stored in a model file, together with the column
/// statistics, training-grid outputs, and evaluation metrics stored alongside it.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct MulticlassClassifier {
    /// Name of the target column the classifier predicts.
    #[buffalo(id = 0, required)]
    pub target_column_name: String,
    /// The class labels, as strings.
    #[buffalo(id = 1, required)]
    pub classes: Vec<String>,
    /// Row count of the training data.
    #[buffalo(id = 2, required)]
    pub train_row_count: u64,
    /// Row count of the test data.
    #[buffalo(id = 3, required)]
    pub test_row_count: u64,
    /// Row count of the overall dataset (presumably train + test — confirm).
    #[buffalo(id = 4, required)]
    pub overall_row_count: u64,
    /// Settings presumably used when computing the column statistics below.
    #[buffalo(id = 5, required)]
    pub stats_settings: StatsSettings,
    /// Per-column statistics over the overall dataset. Presumably excludes the target
    /// column, whose stats are stored in `overall_target_column_stats` — confirm.
    #[buffalo(id = 6, required)]
    pub overall_column_stats: Vec<ColumnStats>,
    /// Statistics for the target column over the overall dataset.
    #[buffalo(id = 7, required)]
    pub overall_target_column_stats: ColumnStats,
    /// Per-column statistics over the training data.
    #[buffalo(id = 8, required)]
    pub train_column_stats: Vec<ColumnStats>,
    /// Statistics for the target column over the training data.
    #[buffalo(id = 9, required)]
    pub train_target_column_stats: ColumnStats,
    /// Per-column statistics over the test data.
    #[buffalo(id = 10, required)]
    pub test_column_stats: Vec<ColumnStats>,
    /// Statistics for the target column over the test data.
    #[buffalo(id = 11, required)]
    pub test_target_column_stats: ColumnStats,
    /// Metrics for a baseline predictor. NOTE(review): what the baseline is
    /// (e.g. always predicting the most frequent class) is decided by the writer — confirm.
    #[buffalo(id = 12, required)]
    pub baseline_metrics: MulticlassClassificationMetrics,
    /// Metric presumably used to compare train grid items when picking
    /// `best_grid_item_index`. `Accuracy` is currently the only option.
    #[buffalo(id = 13, required)]
    pub comparison_metric: MulticlassClassificationComparisonMetric,
    /// Outputs of each train grid item (presumably one per candidate model configuration).
    #[buffalo(id = 14, required)]
    pub train_grid_item_outputs: Vec<TrainGridItemOutput>,
    /// Presumably an index into `train_grid_item_outputs` identifying the item whose
    /// model is stored in `model` — confirm against the writer.
    #[buffalo(id = 15, required)]
    pub best_grid_item_index: u64,
    /// The stored model: either a linear or a tree multiclass classifier.
    #[buffalo(id = 16, required)]
    pub model: MulticlassClassificationModel,
    /// Metrics for `model`, presumably evaluated on the test data.
    #[buffalo(id = 17, required)]
    pub test_metrics: MulticlassClassificationMetrics,
}

/// Metric available for comparing candidate multiclass models.
///
/// Declared with `value_size = 0`, consistent with its only variant carrying no payload.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "static", value_size = 0)]
pub enum MulticlassClassificationComparisonMetric {
    /// Compare candidates by accuracy.
    #[buffalo(id = 0)]
    Accuracy,
}

/// Aggregate and per-class evaluation metrics for a multiclass classifier.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct MulticlassClassificationMetrics {
    /// Per-class metrics. Presumably one entry per class, in the same order as
    /// `MulticlassClassifier::classes` — confirm against the writer.
    #[buffalo(id = 0, required)]
    pub class_metrics: Vec<ClassMetrics>,
    /// Overall accuracy.
    #[buffalo(id = 1, required)]
    pub accuracy: f32,
    /// Precision averaged across classes without weighting
    /// (presumably the plain mean of per-class precision).
    #[buffalo(id = 2, required)]
    pub precision_unweighted: f32,
    /// Precision averaged across classes with weighting (presumably by class support).
    #[buffalo(id = 3, required)]
    pub precision_weighted: f32,
    /// Recall averaged across classes without weighting
    /// (presumably the plain mean of per-class recall).
    #[buffalo(id = 4, required)]
    pub recall_unweighted: f32,
    /// Recall averaged across classes with weighting (presumably by class support).
    #[buffalo(id = 5, required)]
    pub recall_weighted: f32,
}

/// Confusion-matrix counts and scores for a single class
/// (presumably computed one-vs-rest for that class).
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct ClassMetrics {
    /// Number of true positives for this class.
    #[buffalo(id = 0, required)]
    pub true_positives: u64,
    /// Number of false positives for this class.
    #[buffalo(id = 1, required)]
    pub false_positives: u64,
    /// Number of true negatives for this class.
    #[buffalo(id = 2, required)]
    pub true_negatives: u64,
    /// Number of false negatives for this class.
    #[buffalo(id = 3, required)]
    pub false_negatives: u64,
    /// Accuracy for this class (presumably derived from the counts above).
    #[buffalo(id = 4, required)]
    pub accuracy: f32,
    /// Precision for this class (presumably derived from the counts above).
    #[buffalo(id = 5, required)]
    pub precision: f32,
    /// Recall for this class (presumably derived from the counts above).
    #[buffalo(id = 6, required)]
    pub recall: f32,
    /// F1 score for this class (presumably derived from `precision` and `recall`).
    #[buffalo(id = 7, required)]
    pub f1_score: f32,
}

/// The model stored in `MulticlassClassifier::model`: either a linear
/// (`LinearMulticlassClassifier`) or a tree (`TreeMulticlassClassifier`) classifier.
///
/// NOTE(review): unlike the payload-free comparison-metric enum, this enum declares
/// `value_size = 8` for variants that carry struct payloads. Presumably that is an
/// 8-byte reference to the payload; confirm against buffalo's docs before changing it.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "static", value_size = 8)]
pub enum MulticlassClassificationModel {
    /// A linear multiclass classifier.
    #[buffalo(id = 0)]
    Linear(LinearMulticlassClassifier),
    /// A tree-based multiclass classifier.
    #[buffalo(id = 1)]
    Tree(TreeMulticlassClassifier),
}

/// A serialized `tangram_linear` multiclass model plus the options, feature groups,
/// and training diagnostics stored with it.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct LinearMulticlassClassifier {
    /// The serialized linear model.
    #[buffalo(id = 0, required)]
    pub model: tangram_linear::serialize::MulticlassClassifier,
    /// Training options for the linear model.
    #[buffalo(id = 1, required)]
    pub train_options: LinearModelTrainOptions,
    /// Feature groups presumably used to build the model's input features — confirm.
    #[buffalo(id = 2, required)]
    pub feature_groups: Vec<FeatureGroup>,
    /// Training losses, if recorded.
    ///
    /// NOTE(review): marked `required` yet typed `Option`. Presumably the field is
    /// always written and encodes absence explicitly; confirm buffalo's semantics.
    #[buffalo(id = 3, required)]
    pub losses: Option<Vec<f32>>,
    /// Feature importance scores. Presumably one per feature produced from
    /// `feature_groups` — confirm.
    #[buffalo(id = 4, required)]
    pub feature_importances: Vec<f32>,
}

/// A serialized `tangram_tree` multiclass model plus the options, feature groups,
/// and training diagnostics stored with it.
#[derive(buffalo::Read, buffalo::Write)]
#[buffalo(size = "dynamic")]
pub struct TreeMulticlassClassifier {
    /// The serialized tree model.
    #[buffalo(id = 0, required)]
    pub model: tangram_tree::serialize::MulticlassClassifier,
    /// Training options for the tree model.
    #[buffalo(id = 1, required)]
    pub train_options: TreeModelTrainOptions,
    /// Feature groups presumably used to build the model's input features — confirm.
    #[buffalo(id = 2, required)]
    pub feature_groups: Vec<FeatureGroup>,
    /// Training losses, if recorded.
    ///
    /// NOTE(review): marked `required` yet typed `Option`. Presumably the field is
    /// always written and encodes absence explicitly; confirm buffalo's semantics.
    #[buffalo(id = 3, required)]
    pub losses: Option<Vec<f32>>,
    /// Feature importance scores. Presumably one per feature produced from
    /// `feature_groups` — confirm.
    #[buffalo(id = 4, required)]
    pub feature_importances: Vec<f32>,
}