use crate::ml_metrics::ModelMetrics;
use serde::Serialize;
use std::collections::HashMap;
use rayon::prelude::*;
use rand::Rng;
use smartcore::{
    linalg::basic::matrix::DenseMatrix,
    ensemble::random_forest_classifier::{
        RandomForestClassifier,
        RandomForestClassifierParameters,
    },
};

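/// Direction of a model's recent performance, as judged by the caller;
/// `Degrading` is the primary trigger for re-optimization in
/// [`ModelOptimizer::optimize`].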
#[derive(Debug, Clone, PartialEq)]
pub enum PerformanceTrend {
    Improving,
    Stable,
    Degrading,
}

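/// Hyperparameters explored by the optimizer. `learning_rate` only feeds the
/// synthetic surrogate in `evaluate_parameters`; the random-forest path in
/// `train_and_evaluate` does not consume it.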
#[derive(Debug, Clone, Serialize)]
pub struct ModelParameters {
    pub max_depth: u16,
    pub min_samples_split: usize,
    pub learning_rate: f64,
    pub n_trees: u16,
}

impl Default for ModelParameters {
    fn default() -> Self {
        Self {
            max_depth: 10,
            min_samples_split: 2,
            learning_rate: 0.1,
            n_trees: 100,
        }
    }
}

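/// One accepted improvement during a search: the winning parameters, the
/// metrics they scored, and when they were recorded.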
#[derive(Debug, Serialize)]
pub struct OptimizationStep {
    pub params: ModelParameters,
    pub metrics: ModelMetrics,
    pub timestamp: chrono::DateTime<chrono::Utc>,
}

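/// Searches for better hyperparameters, either by random sampling
/// (`optimize`) or by an exhaustive, rayon-parallelized grid search
/// (`grid_search`).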
pub struct ModelOptimizer {
    config: OptimizationConfig,
    #[allow(dead_code)]
    current_best: Option<OptimizationResult>,
    optimization_history: Vec<OptimizationStep>,
}

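/// Candidate values per hyperparameter; the search grid is their Cartesian
/// product (see [`GridSearchConfig::parameter_combinations`]).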
#[derive(Debug, Clone)]
pub struct GridSearchConfig {
    pub learning_rates: Vec<f64>,
    pub max_depths: Vec<u16>,
    pub min_samples_splits: Vec<usize>,
    pub n_trees: Vec<u16>,
}

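/// Controls for the random search in `find_optimal_parameters`. Note that
/// `validation_split` is carried along but not consumed in this module.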
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    pub learning_rate_range: (f64, f64),
    pub max_iterations: usize,
    pub early_stopping_patience: usize,
    pub validation_split: f64,
}

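/// Summary of a finished optimization run.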
#[derive(Debug, Serialize)]
pub struct OptimizationResult {
    pub best_params: ModelParameters,
    pub performance_improvement: f64,
    pub training_time: std::time::Duration,
    pub optimization_history: Vec<OptimizationStep>,
}

impl ModelOptimizer {
    pub fn new(config: OptimizationConfig) -> Self {
        Self {
            config,
            current_best: None,
            optimization_history: Vec::new(),
        }
    }

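    /// Re-tunes only when it is likely to pay off: when performance is
    /// degrading, or when the current F1 score is below 0.7. Returns the new
    /// parameters if a search ran, `None` otherwise.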
    pub fn optimize(&mut self, current_metrics: &ModelMetrics, trend: &PerformanceTrend) -> Option<ModelParameters> {
        match trend {
            PerformanceTrend::Degrading => Some(self.find_optimal_parameters(current_metrics)),
            _ if current_metrics.f1_score < 0.7 => Some(self.find_optimal_parameters(current_metrics)),
            _ => None,
        }
    }

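    /// Random search: samples up to `max_iterations` candidates, keeps the
    /// best scorer, and logs each improvement to `optimization_history`.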
    fn find_optimal_parameters(&mut self, baseline_metrics: &ModelMetrics) -> ModelParameters {
        let mut best_params = self.get_default_parameters();
        let mut best_score = baseline_metrics.f1_score;

        for i in 0..self.config.max_iterations {
            let candidate_params = self.generate_candidate_parameters();
            let performance = self.evaluate_parameters(&candidate_params);

            if performance > best_score {
                best_score = performance;
                best_params = candidate_params;

                // The surrogate yields a single score, so it is recorded in
                // every metric field of the history entry.
                self.optimization_history.push(OptimizationStep {
                    params: best_params.clone(),
                    metrics: ModelMetrics {
                        model_id: format!("opt_iter_{}", i),
                        timestamp: chrono::Utc::now(),
                        accuracy: performance,
                        precision: performance,
                        recall: performance,
                        f1_score: performance,
                        confusion_matrix: baseline_metrics.confusion_matrix.clone(),
                        feature_importance: HashMap::new(),
                        training_duration: std::time::Duration::from_secs(0),
                    },
                    timestamp: chrono::Utc::now(),
                });
            }

            if self.should_stop_early() {
                break;
            }
        }

        best_params
    }

    fn get_default_parameters(&self) -> ModelParameters {
        ModelParameters::default()
    }

    fn generate_candidate_parameters(&self) -> ModelParameters {
        let mut rng = rand::thread_rng();
        ModelParameters {
            learning_rate: rng.gen_range(self.config.learning_rate_range.0..self.config.learning_rate_range.1),
            max_depth: rng.gen_range(5..20),
            min_samples_split: rng.gen_range(2..10),
            n_trees: rng.gen_range(50..200),
        }
    }

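    /// Synthetic surrogate objective standing in for real training: a base
    /// score of 0.7 scaled by two Gaussian bumps that peak near
    /// `learning_rate == 0.01` and `max_depth == 10`.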
    fn evaluate_parameters(&self, params: &ModelParameters) -> f64 {
        let base_score = 0.7;
        let lr_factor = (-((params.learning_rate - 0.01).powi(2)) / 0.001).exp();
        let depth_factor = (-((params.max_depth as f64 - 10.0).powi(2)) / 100.0).exp();

        base_score * lr_factor * depth_factor
    }

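    /// Plateau detection: stop once the last `early_stopping_patience`
    /// recorded improvements span less than 0.001 in F1.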
    fn should_stop_early(&self) -> bool {
        if self.optimization_history.len() < self.config.early_stopping_patience {
            return false;
        }

        let recent_scores: Vec<f64> = self.optimization_history
            .iter()
            .rev()
            .take(self.config.early_stopping_patience)
            .map(|step| step.metrics.f1_score)
            .collect();

        let max_score = recent_scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let min_score = recent_scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));

        max_score - min_score < 0.001
    }

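    /// Exhaustive search over the grid from `create_grid_config`, with every
    /// candidate cross-validated in parallel via rayon.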
    pub fn grid_search(&mut self, validation_data: &ValidationData) -> ModelParameters {
        let grid_config = self.create_grid_config();
        let mut best_params = self.get_default_parameters();
        let mut best_score = 0.0;

        // Score every grid point in parallel; `into_par_iter` hands each
        // worker an owned ModelParameters, so no clone is needed.
        let results: Vec<(ModelParameters, f64)> = grid_config.parameter_combinations()
            .into_par_iter()
            .map(|params| {
                let score = self.cross_validate(&params, validation_data);
                (params, score)
            })
            .collect();

        for (params, score) in results {
            if score > best_score {
                best_score = score;
                best_params = params;
            }
        }

        best_params
    }

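    /// Plain 5-fold cross-validation: each fold serves once as the held-out
    /// test set while the rest trains the model; returns the mean F1.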
    fn cross_validate(&self, params: &ModelParameters, data: &ValidationData) -> f64 {
        let k_folds = 5;
        let fold_size = data.features.len() / k_folds;
        let mut scores = Vec::with_capacity(k_folds);

        for k in 0..k_folds {
            let start_idx = k * fold_size;
            let end_idx = start_idx + fold_size;

            // Fold k is the held-out test set; everything else trains.
            let test_features: Vec<Vec<f64>> = data.features[start_idx..end_idx].to_vec();
            let test_labels: Vec<bool> = data.labels[start_idx..end_idx].to_vec();

            let train_features: Vec<Vec<f64>> = data.features.iter()
                .enumerate()
                .filter(|(i, _)| *i < start_idx || *i >= end_idx)
                .map(|(_, f)| f.clone())
                .collect();

            let train_labels: Vec<bool> = data.labels.iter()
                .enumerate()
                .filter(|(i, _)| *i < start_idx || *i >= end_idx)
                .map(|(_, l)| *l)
                .collect();

            let score = self.train_and_evaluate(
                params,
                &train_features,
                &train_labels,
                &test_features,
                &test_labels,
            );
            scores.push(score);
        }

        scores.iter().sum::<f64>() / scores.len() as f64
    }

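    /// Fits a smartcore random forest on the training split and scores its
    /// predictions on the test split with F1.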
    fn train_and_evaluate(
        &self,
        params: &ModelParameters,
        train_features: &[Vec<f64>],
        train_labels: &[bool],
        test_features: &[Vec<f64>],
        test_labels: &[bool],
    ) -> f64 {
        let x = DenseMatrix::from_2d_vec(&train_features.to_vec());
        // smartcore expects integer class labels, so map bool -> {0, 1}.
        let y: Vec<i32> = train_labels.iter().map(|&b| if b { 1 } else { 0 }).collect();

        let model = RandomForestClassifier::fit(
            &x, &y,
            RandomForestClassifierParameters {
                n_trees: params.n_trees,
                max_depth: Some(params.max_depth),
                min_samples_leaf: 5,
                min_samples_split: params.min_samples_split,
                ..Default::default()
            },
        ).unwrap();

        let x_test = DenseMatrix::from_2d_vec(&test_features.to_vec());
        let predictions = model.predict(&x_test).unwrap();
        let predictions: Vec<f64> = predictions.iter().map(|&p| p as f64).collect();

        self.calculate_f1_score(&predictions, test_labels)
    }

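    /// F1 = 2PR / (P + R), with precision and recall guarded against
    /// division by zero.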
    fn calculate_f1_score(&self, predictions: &[f64], actual: &[bool]) -> f64 {
        let mut tp = 0;
        let mut fp = 0;
        let mut fn_count = 0;

        // Predictions arrive as 0.0/1.0 class labels; > 0.5 selects class 1.
        for (pred, act) in predictions.iter().zip(actual.iter()) {
            match (*pred > 0.5, *act) {
                (true, true) => tp += 1,
                (true, false) => fp += 1,
                (false, true) => fn_count += 1,
                _ => {}
            }
        }

        let precision = if tp + fp == 0 { 0.0 } else { tp as f64 / (tp + fp) as f64 };
        let recall = if tp + fn_count == 0 { 0.0 } else { tp as f64 / (tp + fn_count) as f64 };

        if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * (precision * recall) / (precision + recall)
        }
    }

    fn create_grid_config(&self) -> GridSearchConfig {
        GridSearchConfig {
            learning_rates: vec![0.01, 0.1, 0.5],
            max_depths: vec![5, 10, 15],
            min_samples_splits: vec![2, 5, 10],
            n_trees: vec![50, 100, 200],
        }
    }
}

impl GridSearchConfig {
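    /// Expands the grid into the full Cartesian product of the candidate
    /// values (3 x 3 x 3 x 3 = 81 combinations for the default grid).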
    pub fn parameter_combinations(&self) -> Vec<ModelParameters> {
        let mut combinations = Vec::new();

        for &lr in &self.learning_rates {
            for &md in &self.max_depths {
                for &ms in &self.min_samples_splits {
                    for &nt in &self.n_trees {
                        combinations.push(ModelParameters {
                            learning_rate: lr,
                            max_depth: md,
                            min_samples_split: ms,
                            n_trees: nt,
                        });
                    }
                }
            }
        }

        combinations
    }
}

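/// Feature rows and binary labels used for cross-validation; `features` and
/// `labels` are expected to be the same length.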
#[derive(Debug)]
pub struct ValidationData {
    pub features: Vec<Vec<f64>>,
    pub labels: Vec<bool>,
}
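
// A minimal smoke-test sketch, not part of the original module: it exercises
// only the cheap, deterministic pieces (grid expansion and F1 computation)
// and avoids actual forest training. The field values below are arbitrary
// illustrative choices.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn grid_expands_to_full_cartesian_product() {
        let grid = GridSearchConfig {
            learning_rates: vec![0.01, 0.1],
            max_depths: vec![5, 10],
            min_samples_splits: vec![2],
            n_trees: vec![50],
        };
        // 2 learning rates x 2 depths x 1 split x 1 tree count = 4 combinations.
        assert_eq!(grid.parameter_combinations().len(), 4);
    }

    #[test]
    fn f1_is_one_for_perfect_predictions() {
        let optimizer = ModelOptimizer::new(OptimizationConfig {
            learning_rate_range: (0.001, 0.5),
            max_iterations: 10,
            early_stopping_patience: 3,
            validation_split: 0.2,
        });
        // Predictions match the labels exactly, so precision = recall = 1.
        let predictions = vec![1.0, 0.0, 1.0];
        let actual = vec![true, false, true];
        assert!((optimizer.calculate_f1_score(&predictions, &actual) - 1.0).abs() < 1e-9);
    }
}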