// tangram_model/regressor.rs
use crate::{
	ColumnStats, FeatureGroup, LinearModelTrainOptions, StatsSettings, TrainGridItemOutput,
	TreeModelTrainOptions,
};

6#[derive(buffalo::Read, buffalo::Write)]
7#[buffalo(size = "dynamic")]
8pub struct Regressor {
9	#[buffalo(id = 0, required)]
10	pub target_column_name: String,
11	#[buffalo(id = 1, required)]
12	pub train_row_count: u64,
13	#[buffalo(id = 2, required)]
14	pub test_row_count: u64,
15	#[buffalo(id = 3, required)]
16	pub overall_row_count: u64,
17	#[buffalo(id = 4, required)]
18	pub stats_settings: StatsSettings,
19	#[buffalo(id = 5, required)]
20	pub overall_column_stats: Vec<ColumnStats>,
21	#[buffalo(id = 6, required)]
22	pub overall_target_column_stats: ColumnStats,
23	#[buffalo(id = 7, required)]
24	pub train_column_stats: Vec<ColumnStats>,
25	#[buffalo(id = 8, required)]
26	pub train_target_column_stats: ColumnStats,
27	#[buffalo(id = 9, required)]
28	pub test_column_stats: Vec<ColumnStats>,
29	#[buffalo(id = 10, required)]
30	pub test_target_column_stats: ColumnStats,
31	#[buffalo(id = 11, required)]
32	pub baseline_metrics: RegressionMetrics,
33	#[buffalo(id = 12, required)]
34	pub comparison_metric: RegressionComparisonMetric,
35	#[buffalo(id = 13, required)]
36	pub train_grid_item_outputs: Vec<TrainGridItemOutput>,
37	#[buffalo(id = 14, required)]
38	pub best_grid_item_index: u64,
39	#[buffalo(id = 15, required)]
40	pub model: RegressionModel,
41	#[buffalo(id = 16, required)]
42	pub test_metrics: RegressionMetrics,
43}
44
45#[derive(buffalo::Read, buffalo::Write)]
46#[buffalo(size = "static", value_size = 8)]
47pub enum RegressionModel {
48	#[buffalo(id = 0)]
49	Linear(LinearRegressor),
50	#[buffalo(id = 1)]
51	Tree(TreeRegressor),
52}
53
54#[derive(buffalo::Read, buffalo::Write)]
55#[buffalo(size = "dynamic")]
56pub struct LinearRegressor {
57	#[buffalo(id = 0, required)]
58	pub model: tangram_linear::serialize::Regressor,
59	#[buffalo(id = 1, required)]
60	pub train_options: LinearModelTrainOptions,
61	#[buffalo(id = 2, required)]
62	pub feature_groups: Vec<FeatureGroup>,
63	#[buffalo(id = 3, required)]
64	pub losses: Option<Vec<f32>>,
65	#[buffalo(id = 4, required)]
66	pub feature_importances: Vec<f32>,
67}
68
69#[derive(buffalo::Read, buffalo::Write)]
70#[buffalo(size = "dynamic")]
71pub struct TreeRegressor {
72	#[buffalo(id = 0, required)]
73	pub model: tangram_tree::serialize::Regressor,
74	#[buffalo(id = 1, required)]
75	pub train_options: TreeModelTrainOptions,
76	#[buffalo(id = 2, required)]
77	pub feature_groups: Vec<FeatureGroup>,
78	#[buffalo(id = 3, required)]
79	pub losses: Option<Vec<f32>>,
80	#[buffalo(id = 4, required)]
81	pub feature_importances: Vec<f32>,
82}
83
84#[derive(buffalo::Read, buffalo::Write)]
85#[buffalo(size = "static", value_size = 0)]
86pub enum RegressionComparisonMetric {
87	#[buffalo(id = 0)]
88	MeanAbsoluteError,
89	#[buffalo(id = 1)]
90	MeanSquaredError,
91	#[buffalo(id = 2)]
92	RootMeanSquaredError,
93	#[buffalo(id = 3)]
94	R2,
95}
96
97#[derive(buffalo::Read, buffalo::Write)]
98#[buffalo(size = "dynamic")]
99pub struct RegressionMetrics {
100	#[buffalo(id = 0, required)]
101	pub mse: f32,
102	#[buffalo(id = 1, required)]
103	pub rmse: f32,
104	#[buffalo(id = 2, required)]
105	pub mae: f32,
106	#[buffalo(id = 3, required)]
107	pub r2: f32,
108}