spawn_access_control/
ml_analyzer.rs1use smartcore::linalg::basic::matrix::DenseMatrix;
2use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
3use smartcore::model_selection::train_test_split;
4use crate::behavioral::AccessEvent;
5use chrono::{DateTime, Utc, Timelike, Datelike};
6use serde::Serialize;
7use crate::ml_metrics::{ModelMetrics, ConfusionMatrix};
8use std::time::Instant;
9use std::collections::HashMap;
10
11#[derive(Debug, Serialize)]
12pub struct MLPrediction {
13 pub is_anomaly: bool,
14 pub confidence: f64,
15 pub features: Vec<String>,
16 pub timestamp: DateTime<Utc>,
17}
18
19pub struct MLAnalyzer {
20 model: Option<RandomForestClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>>,
21 feature_names: Vec<String>,
22}
23
24impl MLAnalyzer {
25 pub fn new() -> Self {
26 Self {
27 model: None,
28 feature_names: vec![
29 "hour_of_day".to_string(),
30 "day_of_week".to_string(),
31 "duration_seconds".to_string(),
32 "resource_frequency".to_string(),
33 "success_rate".to_string(),
34 ],
35 }
36 }
37
38 pub fn train(&mut self, events: &[AccessEvent], anomalies: &[AccessEvent]) {
39 let (features, labels) = self.prepare_training_data(events, anomalies);
40
41 if features.is_empty() {
42 return;
43 }
44
45 let x = DenseMatrix::from_2d_vec(&features);
46 let y = labels.iter().map(|&x| x as i32).collect::<Vec<_>>();
47
48 let (x_train, _x_test, y_train, _y_test) = train_test_split(
50 &x,
51 &y,
52 0.2,
53 true,
54 Some(42)
55 );
56
57 let model = RandomForestClassifier::fit(
59 &x_train,
60 &y_train,
61 Default::default()
62 ).unwrap();
63
64 self.model = Some(model);
65 }
66
67 pub fn predict(&self, event: &AccessEvent) -> Option<MLPrediction> {
68 let model = self.model.as_ref()?;
69
70 let features = self.extract_features(event);
71 let x = DenseMatrix::from_2d_vec(&vec![features]);
72
73 let prediction = model.predict(&x).ok()?;
74
75 let confidence = if prediction[0] == 1 { 0.8 } else { 0.2 };
77
78 Some(MLPrediction {
79 is_anomaly: prediction[0] == 1,
80 confidence,
81 features: self.feature_names.clone(),
82 timestamp: Utc::now(),
83 })
84 }
85
86 fn extract_features(&self, event: &AccessEvent) -> Vec<f64> {
87 vec![
88 event.timestamp.hour() as f64,
89 event.timestamp.weekday().num_days_from_monday() as f64,
90 event.duration.as_secs_f64(),
91 1.0, if event.success { 1.0 } else { 0.0 },
93 ]
94 }
95
96 fn prepare_training_data(&self, events: &[AccessEvent], anomalies: &[AccessEvent])
97 -> (Vec<Vec<f64>>, Vec<f64>)
98 {
99 let mut features = Vec::new();
100 let mut labels = Vec::new();
101
102 for event in events {
103 features.push(self.extract_features(event));
104
105 let is_anomaly = anomalies.iter().any(|a| {
107 (event.timestamp - a.timestamp).num_minutes().abs() < 1
108 });
109
110 labels.push(if is_anomaly { 1.0 } else { 0.0 });
111 }
112
113 (features, labels)
114 }
115
116 pub fn evaluate_model(&self, test_events: &[AccessEvent], test_anomalies: &[AccessEvent]) -> Option<ModelMetrics> {
117 let model = self.model.as_ref()?;
118 let start_time = Instant::now();
119
120 let (features, actual_labels) = self.prepare_training_data(test_events, test_anomalies);
121 if features.is_empty() {
122 return None;
123 }
124
125 let x = DenseMatrix::from_2d_vec(&features);
126 let predictions = model.predict(&x).ok()?;
127
128 let mut confusion_matrix = ConfusionMatrix::new();
130 for (pred, actual) in predictions.iter().zip(actual_labels.iter()) {
131 confusion_matrix.update(
132 (*pred as i32) == 1,
133 (*actual as i32) == 1
134 );
135 }
136
137 let feature_importance = self.calculate_feature_importance();
139
140 Some(ModelMetrics {
141 model_id: format!("rf_model_{}", Utc::now().timestamp()),
142 timestamp: Utc::now(),
143 accuracy: confusion_matrix.accuracy(),
144 precision: confusion_matrix.precision(),
145 recall: confusion_matrix.recall(),
146 f1_score: confusion_matrix.f1_score(),
147 confusion_matrix,
148 feature_importance,
149 training_duration: start_time.elapsed(),
150 })
151 }
152
153 fn calculate_feature_importance(&self) -> HashMap<String, f64> {
154 let mut importance = HashMap::new();
155
156 for (idx, name) in self.feature_names.iter().enumerate() {
158 let score = 1.0 / (idx + 1) as f64; importance.insert(name.clone(), score);
160 }
161
162 let total: f64 = importance.values().sum();
164 for score in importance.values_mut() {
165 *score /= total;
166 }
167
168 importance
169 }
170
171 pub fn predict_with_threshold(&self, event: &AccessEvent, threshold: f64) -> Option<MLPrediction> {
172 let prediction = self.predict(event)?;
173
174 let is_anomaly = prediction.confidence > threshold;
176
177 Some(MLPrediction {
178 is_anomaly,
179 ..prediction
180 })
181 }
182
183 pub fn update_model(&mut self, new_events: &[AccessEvent], new_anomalies: &[AccessEvent]) -> Option<ModelMetrics> {
184 self.train(new_events, new_anomalies);
186
187 self.evaluate_model(new_events, new_anomalies)
189 }
190}