1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
pub use self::{
binary_classifier::BinaryClassifier, multiclass_classifier::MulticlassClassifier,
regressor::Regressor,
};
use ndarray::prelude::*;
use num::ToPrimitive;
use tangram_progress_counter::ProgressCounter;
mod binary_classifier;
mod multiclass_classifier;
mod regressor;
pub mod serialize;
mod shap;
/// Options that configure training.
///
/// The `Default` implementation in this file supplies the default value noted on each field.
#[derive(Clone, Debug)]
pub struct TrainOptions {
/// Whether losses should be computed. Defaults to `false`.
/// NOTE(review): where the losses are computed and reported is not visible here.
pub compute_losses: bool,
/// Early stopping settings. Defaults to `None`.
/// NOTE(review): presumably `None` disables early stopping — confirm against the trainers.
pub early_stopping_options: Option<EarlyStoppingOptions>,
/// Defaults to `0.0`.
/// NOTE(review): presumably the L2 regularization strength — confirm against the trainers.
pub l2_regularization: f32,
/// Defaults to `0.1`.
/// NOTE(review): presumably the step size of each parameter update — confirm.
pub learning_rate: f32,
/// Defaults to `100`.
/// NOTE(review): presumably an upper bound on the number of training epochs — confirm.
pub max_epochs: usize,
/// Defaults to `32`.
/// NOTE(review): presumably the number of examples in each training batch — confirm.
pub n_examples_per_batch: usize,
}
impl Default for TrainOptions {
fn default() -> TrainOptions {
TrainOptions {
compute_losses: false,
early_stopping_options: None,
l2_regularization: 0.0,
learning_rate: 0.1,
max_epochs: 100,
n_examples_per_batch: 32,
}
}
}
/// Settings for early stopping, carried in `TrainOptions::early_stopping_options`.
///
/// NOTE(review): the code that consumes these fields is not visible in this part of the file;
/// the per-field notes below are assumptions to confirm against the trainers.
#[derive(Clone, Debug)]
pub struct EarlyStoppingOptions {
/// Presumably the fraction of the dataset held out for early stopping, i.e. the value passed
/// to `train_early_stopping_split` — confirm against callers.
pub early_stopping_fraction: f32,
/// Presumably the number of consecutive rounds without improvement after which training
/// stops — confirm against callers.
pub n_rounds_without_improvement_to_stop: usize,
/// Presumably the smallest decrease in loss that counts as an improvement — confirm against
/// callers.
pub min_decrease_in_loss_for_significant_change: f32,
}
/// Borrowed handles related to training progress: a kill chip and a progress event callback.
///
/// Both fields borrow for the lifetime `'a`.
pub struct Progress<'a> {
/// Shared reference to a `tangram_kill_chip::KillChip`.
/// NOTE(review): presumably polled by the trainers to abort training early — confirm.
pub kill_chip: &'a tangram_kill_chip::KillChip,
/// Callback that receives each `TrainProgressEvent`. It is an `FnMut`, so it may mutate its
/// captured state between calls.
pub handle_progress_event: &'a mut dyn FnMut(TrainProgressEvent),
}
/// Events delivered to the `handle_progress_event` callback of `Progress`.
#[derive(Clone, Debug)]
pub enum TrainProgressEvent {
/// Carries a `ProgressCounter`.
/// NOTE(review): presumably emitted while training runs so callers can track progress —
/// confirm against the trainers.
Train(ProgressCounter),
/// NOTE(review): presumably emitted once training has finished — confirm against the trainers.
TrainDone,
}
/// Splits `features` and `labels` along axis 0 into a leading training portion and a trailing
/// early stopping portion.
///
/// The split index is `(1 - early_stopping_fraction) * features.nrows()`, computed in `f32`,
/// truncated to a `usize`, and clamped to `features.nrows()`. Rows before the index form the
/// training views and the remaining rows form the early stopping views. No data is copied; the
/// returned views borrow from the inputs.
///
/// Returns `(features_train, labels_train, features_early_stopping, labels_early_stopping)`.
///
/// # Panics
///
/// Panics if the computed index cannot be converted to a `usize` (for example when
/// `early_stopping_fraction` is NaN), or if `labels` has fewer elements than the split index.
fn train_early_stopping_split<'features, 'labels, Label>(
    features: ArrayView2<'features, f32>,
    labels: ArrayView1<'labels, Label>,
    early_stopping_fraction: f32,
) -> (
    ArrayView2<'features, f32>,
    ArrayView1<'labels, Label>,
    ArrayView2<'features, f32>,
    ArrayView1<'labels, Label>,
) {
    let n_rows = features.nrows();
    let split_index = ((1.0 - early_stopping_fraction) * n_rows.to_f32().unwrap())
        .to_usize()
        .unwrap()
        // `f32` cannot represent every row count above 2^24 exactly, so with a fraction at or
        // near zero the product can round up past `n_rows`. Clamp so `split_at` stays in bounds.
        .min(n_rows);
    let (features_train, features_early_stopping) = features.split_at(Axis(0), split_index);
    let (labels_train, labels_early_stopping) = labels.split_at(Axis(0), split_index);
    (
        features_train,
        labels_train,
        features_early_stopping,
        labels_early_stopping,
    )
}
/// Tracks a metric across epochs and reports when it has gone too long without a significant
/// decrease.
struct EarlyStoppingMonitor {
    /// Minimum absolute change between consecutive metric values that counts as significant.
    threshold: f32,
    /// Number of consecutive epochs without improvement after which `update` returns `true`.
    epochs: usize,
    /// Current run of consecutive epochs without an observed improvement.
    n_epochs_without_observed_improvement: usize,
    /// Metric value passed to the most recent `update` call, or `None` before the first call.
    previous_epoch_metric_value: Option<f32>,
}

impl EarlyStoppingMonitor {
    /// Creates a monitor with no recorded metric value and a zeroed counter.
    pub fn new(threshold: f32, epochs: usize) -> EarlyStoppingMonitor {
        EarlyStoppingMonitor {
            n_epochs_without_observed_improvement: 0,
            previous_epoch_metric_value: None,
            epochs,
            threshold,
        }
    }

    /// Records this epoch's metric value and returns `true` when training should stop.
    ///
    /// An epoch counts as "without improvement" when the metric is greater than the previous
    /// epoch's value, or when it differs from it by less than `threshold`. Any other epoch
    /// resets the counter. The very first call only records the value and returns `false`.
    pub fn update(&mut self, early_stopping_metric_value: f32) -> bool {
        // Store the new value up front; `replace` hands back the value it displaced.
        let previous = match self
            .previous_epoch_metric_value
            .replace(early_stopping_metric_value)
        {
            Some(value) => value,
            // Nothing to compare against yet.
            None => return false,
        };
        let got_worse = early_stopping_metric_value > previous;
        let barely_moved = (early_stopping_metric_value - previous).abs() < self.threshold;
        if !(got_worse || barely_moved) {
            // A significant decrease: start counting from scratch.
            self.n_epochs_without_observed_improvement = 0;
            return false;
        }
        self.n_epochs_without_observed_improvement += 1;
        self.n_epochs_without_observed_improvement >= self.epochs
    }
}