/*!
This crate is an implementation of linear machine learning models for regression and classification. There are three model types, [`Regressor`], [`BinaryClassifier`], and [`MulticlassClassifier`]. `BinaryClassifier` uses the sigmoid activation function, and `MulticlassClassifier` trains `n_classes` linear models whose outputs are combined with the `softmax` function.

To make training faster on multicore processors, we allow simultaneous read/write access to the model parameters from multiple threads. This means each thread will be reading weights partially updated by other threads and the weights it writes may be clobbered by other threads. This makes training nondeterministic, but in practice we observe little variation in the outcome, because there is feedback control: the change in loss is monitored after each epoch, and training terminates when the loss has stabilized.
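
A minimal sketch of launching a training run is shown below. It is illustrative rather than exact: the `features` and `labels` arrays are assumed to be prepared by the caller, the `KillChip` constructor is an assumption, and the precise signature of [`Regressor::train`] may differ from the one shown.

```ignore
let options = TrainOptions::default();
let kill_chip = tangram_kill_chip::KillChip::default(); // assumed constructor
let mut handle_progress_event = |event: TrainProgressEvent| println!("{:?}", event);
let progress = Progress {
	kill_chip: &kill_chip,
	handle_progress_event: &mut handle_progress_event,
};
// Assumed signature: `features` is an `Array2<f32>` and `labels` an `Array1<f32>`.
let model = Regressor::train(features.view(), labels.view(), &options, progress);
```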
*/

pub use self::{
	binary_classifier::BinaryClassifier, multiclass_classifier::MulticlassClassifier,
	regressor::Regressor,
};
use ndarray::prelude::*;
use num::ToPrimitive;
use tangram_progress_counter::ProgressCounter;

mod binary_classifier;
mod multiclass_classifier;
mod regressor;
pub mod serialize;
mod shap;

/// These are the options passed to `Regressor::train`, `BinaryClassifier::train`, and `MulticlassClassifier::train`.
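///
/// A sketch of overriding a few options while keeping the defaults for the rest:
///
/// ```ignore
/// let options = TrainOptions {
/// 	learning_rate: 0.01,
/// 	max_epochs: 50,
/// 	..Default::default()
/// };
/// ```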
#[derive(Clone, Debug)]
pub struct TrainOptions {
	/// If true, the loss on the training data will be computed after each epoch and included in the trained model.
	pub compute_losses: bool,
	/// Specify options for early stopping. If the value is `Some`, early stopping will be enabled. If it is `None`, early stopping will be disabled.
	pub early_stopping_options: Option<EarlyStoppingOptions>,
	/// This is the L2 regularization value to use when updating the model parameters.
	pub l2_regularization: f32,
	/// This is the learning rate to use when updating the model parameters.
	pub learning_rate: f32,
	/// This is the maximum number of epochs to train.
	pub max_epochs: usize,
	/// This is the number of examples to use for each batch of training.
	pub n_examples_per_batch: usize,
}

impl Default for TrainOptions {
	fn default() -> TrainOptions {
		TrainOptions {
			compute_losses: false,
			early_stopping_options: None,
			l2_regularization: 0.0,
			learning_rate: 0.1,
			max_epochs: 100,
			n_examples_per_batch: 32,
		}
	}
}

/// The parameters in this struct determine whether training should stop early after each round or epoch.
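///
/// Early stopping is enabled by setting the `early_stopping_options` field of [`TrainOptions`] to `Some`, for example (a sketch):
///
/// ```ignore
/// let options = TrainOptions {
/// 	early_stopping_options: Some(EarlyStoppingOptions {
/// 		early_stopping_fraction: 0.1,
/// 		n_rounds_without_improvement_to_stop: 5,
/// 		min_decrease_in_loss_for_significant_change: 1e-3,
/// 	}),
/// 	..Default::default()
/// };
/// ```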
#[derive(Clone, Debug)]
pub struct EarlyStoppingOptions {
	/// This is the fraction of the dataset that is set aside to compute the early stopping metric.
	pub early_stopping_fraction: f32,
	/// If this many rounds or epochs pass by without a significant improvement in the early stopping metric over the previous round or epoch, training will be stopped early.
	pub n_rounds_without_improvement_to_stop: usize,
	/// This is the minimum decrease in the early stopping metric for a round or epoch to be considered a significant improvement over the previous round or epoch.
	pub min_decrease_in_loss_for_significant_change: f32,
}

/// `Progress` bundles the values that let a caller observe and cancel a training run.
pub struct Progress<'a> {
	/// Activating this kill chip requests that training stop early.
	pub kill_chip: &'a tangram_kill_chip::KillChip,
	/// This callback is invoked with a [`TrainProgressEvent`] as training advances.
	pub handle_progress_event: &'a mut dyn FnMut(TrainProgressEvent),
}

/// These events report training progress: `Train` carries a [`ProgressCounter`] that tracks the current epoch, and `TrainDone` signals that training has finished.
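///
/// A progress handler might match on the events, for example (a sketch):
///
/// ```ignore
/// let mut handle_progress_event = |event: TrainProgressEvent| match event {
/// 	TrainProgressEvent::Train(counter) => { /* report epoch progress from `counter` */ }
/// 	TrainProgressEvent::TrainDone => { /* training finished */ }
/// };
/// ```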
#[derive(Clone, Debug)]
pub enum TrainProgressEvent {
	Train(ProgressCounter),
	TrainDone,
}

/// This function splits the `features` and `labels` arrays into training and early stopping arrays, where the number of rows in the early stopping arrays will be `features.nrows() * early_stopping_fraction`. For example, with 100 rows and an `early_stopping_fraction` of 0.1, the first 90 rows are used for training and the last 10 for early stopping.
fn train_early_stopping_split<'features, 'labels, Label>(
	features: ArrayView2<'features, f32>,
	labels: ArrayView1<'labels, Label>,
	early_stopping_fraction: f32,
) -> (
	ArrayView2<'features, f32>,
	ArrayView1<'labels, Label>,
	ArrayView2<'features, f32>,
	ArrayView1<'labels, Label>,
) {
	let split_index = ((1.0 - early_stopping_fraction) * features.nrows().to_f32().unwrap())
		.to_usize()
		.unwrap();
	let (features_train, features_early_stopping) = features.split_at(Axis(0), split_index);
	let (labels_train, labels_early_stopping) = labels.split_at(Axis(0), split_index);
	(
		features_train,
		labels_train,
		features_early_stopping,
		labels_early_stopping,
	)
}

/**
The `EarlyStoppingMonitor` keeps track of the values of an early stopping metric for each epoch, and if enough epochs have passed without a significant improvement in the metric, the `update()` function will return `true` to indicate that training should be stopped.
*/
struct EarlyStoppingMonitor {
	threshold: f32,
	epochs: usize,
	n_epochs_without_observed_improvement: usize,
	previous_epoch_metric_value: Option<f32>,
}

impl EarlyStoppingMonitor {
	/// Create a new `EarlyStoppingMonitor`.
	pub fn new(threshold: f32, epochs: usize) -> EarlyStoppingMonitor {
		EarlyStoppingMonitor {
			threshold,
			epochs,
			previous_epoch_metric_value: None,
			n_epochs_without_observed_improvement: 0,
		}
	}

	/// This function updates the `EarlyStoppingMonitor` with the early stopping metric value for the next epoch. It returns `true` if training should stop.
	pub fn update(&mut self, early_stopping_metric_value: f32) -> bool {
		let result = if let Some(previous_stopping_metric) = self.previous_epoch_metric_value {
			if early_stopping_metric_value > previous_stopping_metric
				|| f32::abs(early_stopping_metric_value - previous_stopping_metric) < self.threshold
			{
				self.n_epochs_without_observed_improvement += 1;
				self.n_epochs_without_observed_improvement >= self.epochs
			} else {
				self.n_epochs_without_observed_improvement = 0;
				false
			}
		} else {
			false
		};
		self.previous_epoch_metric_value = Some(early_stopping_metric_value);
		result
	}
}
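
// The tests below are a small sketch exercising the private helpers in this
// module: the split arithmetic of `train_early_stopping_split` and the
// stopping rule of `EarlyStoppingMonitor`. The metric values are made up for
// illustration.
#[cfg(test)]
mod test {
	use super::*;
	use ndarray::prelude::*;

	#[test]
	fn test_train_early_stopping_split() {
		let features = Array2::<f32>::zeros((10, 2));
		let labels = Array1::<f32>::zeros(10);
		let (features_train, labels_train, features_early_stopping, labels_early_stopping) =
			train_early_stopping_split(features.view(), labels.view(), 0.2);
		// With 10 rows and an early stopping fraction of 0.2, the first 8 rows
		// are used for training and the last 2 for early stopping.
		assert_eq!(features_train.nrows(), 8);
		assert_eq!(labels_train.len(), 8);
		assert_eq!(features_early_stopping.nrows(), 2);
		assert_eq!(labels_early_stopping.len(), 2);
	}

	#[test]
	fn test_early_stopping_monitor() {
		let mut monitor = EarlyStoppingMonitor::new(0.1, 2);
		// The first update never stops training because there is no previous value.
		assert!(!monitor.update(1.0));
		// 1.0 -> 0.5 is a significant decrease, so the counter resets.
		assert!(!monitor.update(0.5));
		// 0.5 -> 0.45 is not a significant decrease: one epoch without improvement.
		assert!(!monitor.update(0.45));
		// 0.45 -> 0.44: a second epoch without improvement, so training should stop.
		assert!(monitor.update(0.44));
	}
}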