use std::collections::VecDeque;
use std::sync::Arc;

use chrono::{DateTime, Duration, Utc};
use serde::Serialize;
use tokio::sync::RwLock;

use crate::ml_metrics::ModelMetrics;
use crate::model_explainer::SecurityImpactAnalysis;

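/// Point-in-time snapshot of model health: the overall status, the recent
/// performance trend, and any alerts raised so far.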
#[derive(Debug, Serialize)]
pub struct ModelHealth {
    pub current_status: HealthStatus,
    pub performance_trend: PerformanceTrend,
    pub alerts: Vec<HealthAlert>,
    pub last_update: DateTime<Utc>,
}

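/// Coarse health classification; `Unknown` is reported until both the metrics
/// and the security histories contain at least one entry.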
#[derive(Debug, Serialize, PartialEq)]
pub enum HealthStatus {
    Healthy,
    Degraded,
    Critical,
    Unknown,
}

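/// Direction of the recent F1-score trend.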
#[derive(Debug, Serialize)]
pub enum PerformanceTrend {
    Improving,
    Stable,
    Degrading,
}

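/// A single alert raised when a monitored metric crosses its threshold.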
#[derive(Debug, Serialize, Clone)]
pub struct HealthAlert {
    pub severity: AlertSeverity,
    pub message: String,
    pub timestamp: DateTime<Utc>,
    pub metric_name: String,
    pub threshold: f64,
    pub current_value: f64,
}

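/// Severity attached to a `HealthAlert`.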
#[derive(Debug, Clone, Serialize)]
pub enum AlertSeverity {
    Critical,
    Warning,
    Info,
}

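/// Tracks rolling windows of model metrics and security impact analyses and
/// raises alerts when configured thresholds are crossed. State is kept behind
/// `tokio::sync::RwLock`s so the monitor can be shared across async tasks.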
pub struct ModelMonitor {
    metrics_history: Arc<RwLock<VecDeque<ModelMetrics>>>,
    security_history: Arc<RwLock<VecDeque<SecurityImpactAnalysis>>>,
    config: MonitoringConfig,
    alerts: Arc<RwLock<Vec<HealthAlert>>>,
}

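/// Tuning knobs for `ModelMonitor`: the rolling-window size, the F1 and
/// false-positive-impact thresholds that trigger alerts, and the cooldown
/// applied to repeated alerts for the same metric.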
#[derive(Clone)]
pub struct MonitoringConfig {
    pub metrics_window_size: usize,
    pub performance_threshold: f64,
    pub security_threshold: f64,
    pub alert_cooldown: Duration,
}

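// Illustrative defaults only: these values are assumptions chosen to line up
// with the hard-coded 0.8 F1 / 0.2 false-positive-impact "degraded" boundaries
// in `calculate_health_status`, not figures taken from any upstream config.
impl Default for MonitoringConfig {
    fn default() -> Self {
        Self {
            metrics_window_size: 100,
            performance_threshold: 0.8,
            security_threshold: 0.2,
            alert_cooldown: Duration::minutes(15),
        }
    }
}
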
impl ModelMonitor {
    /// Creates a monitor with empty metric, security, and alert histories.
    pub fn new(config: MonitoringConfig) -> Self {
        Self {
            metrics_history: Arc::new(RwLock::new(VecDeque::new())),
            security_history: Arc::new(RwLock::new(VecDeque::new())),
            config,
            alerts: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Records a new metrics snapshot, evicting the oldest entry once the
    /// rolling window is full, then evaluates performance alerts.
    pub async fn update_metrics(&self, metrics: ModelMetrics) {
        let mut history = self.metrics_history.write().await;
        if history.len() >= self.config.metrics_window_size {
            history.pop_front();
        }
        history.push_back(metrics.clone());
        // Release the history lock before taking the alerts lock.
        drop(history);

        self.check_performance_alerts(&metrics).await;
    }

    /// Records a new security impact analysis in the same rolling-window
    /// fashion, then evaluates security alerts.
    pub async fn update_security_analysis(&self, analysis: SecurityImpactAnalysis) {
        let mut history = self.security_history.write().await;
        if history.len() >= self.config.metrics_window_size {
            history.pop_front();
        }
        history.push_back(analysis.clone());
        // Release the history lock before taking the alerts lock.
        drop(history);

        self.check_security_alerts(&analysis).await;
    }

    /// Returns a point-in-time snapshot combining the latest health status,
    /// the recent performance trend, and all alerts raised so far.
    pub async fn get_model_health(&self) -> ModelHealth {
        let metrics = self.metrics_history.read().await;
        let security = self.security_history.read().await;
        let alerts = self.alerts.read().await;

        let status = self.calculate_health_status(&metrics, &security);
        let trend = self.calculate_performance_trend(&metrics);

        ModelHealth {
            current_status: status,
            performance_trend: trend,
            alerts: alerts.clone(),
            last_update: Utc::now(),
        }
    }

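    /// Returns true when no alert for `metric_name` has been raised within the
    /// configured cooldown window. `alert_cooldown` was declared in
    /// `MonitoringConfig` but not referenced anywhere else in this module;
    /// treating it as a per-metric suppression window is an assumption about
    /// the intended behaviour rather than a documented requirement.
    fn cooldown_elapsed(&self, alerts: &[HealthAlert], metric_name: &str) -> bool {
        alerts
            .iter()
            .rev()
            .find(|a| a.metric_name == metric_name)
            .map(|last| Utc::now() - last.timestamp >= self.config.alert_cooldown)
            .unwrap_or(true)
    }
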
    /// Checks the latest metrics snapshot and raises warnings (subject to the
    /// alert cooldown) when the F1 score falls below the configured threshold
    /// or precision and recall diverge by more than 0.2.
    async fn check_performance_alerts(&self, metrics: &ModelMetrics) {
        let mut alerts = self.alerts.write().await;

        if metrics.f1_score < self.config.performance_threshold
            && self.cooldown_elapsed(alerts.as_slice(), "f1_score")
        {
            alerts.push(HealthAlert {
                severity: AlertSeverity::Warning,
                message: format!("Low F1 score: {:.2}", metrics.f1_score),
                timestamp: Utc::now(),
                metric_name: "f1_score".to_string(),
                threshold: self.config.performance_threshold,
                current_value: metrics.f1_score,
            });
        }

        let pr_diff = (metrics.precision - metrics.recall).abs();
        if pr_diff > 0.2 && self.cooldown_elapsed(alerts.as_slice(), "pr_balance") {
            alerts.push(HealthAlert {
                severity: AlertSeverity::Warning,
                message: "Significant precision-recall imbalance detected".to_string(),
                timestamp: Utc::now(),
                metric_name: "pr_balance".to_string(),
                threshold: 0.2,
                current_value: pr_diff,
            });
        }
    }

    /// Checks the latest security analysis and raises critical alerts when the
    /// false positive impact exceeds the configured threshold (subject to the
    /// alert cooldown) or when any individual risk factor scores above 0.8.
    async fn check_security_alerts(&self, analysis: &SecurityImpactAnalysis) {
        let mut alerts = self.alerts.write().await;

        if analysis.false_positive_impact > self.config.security_threshold
            && self.cooldown_elapsed(alerts.as_slice(), "false_positive_impact")
        {
            alerts.push(HealthAlert {
                severity: AlertSeverity::Critical,
                message: "High false positive impact detected".to_string(),
                timestamp: Utc::now(),
                metric_name: "false_positive_impact".to_string(),
                threshold: self.config.security_threshold,
                current_value: analysis.false_positive_impact,
            });
        }

        // Risk-factor alerts are raised unconditionally: each factor is a
        // distinct finding, so the per-metric cooldown is not applied here.
        for factor in &analysis.risk_factors {
            if factor.impact_score > 0.8 {
                alerts.push(HealthAlert {
                    severity: AlertSeverity::Critical,
                    message: format!("Critical risk factor: {}", factor.name),
                    timestamp: Utc::now(),
                    metric_name: "risk_factor".to_string(),
                    threshold: 0.8,
                    current_value: factor.impact_score,
                });
            }
        }
    }

    /// Derives the overall status from the most recent snapshots. The
    /// boundaries are fixed here: F1 below 0.6 or false positive impact above
    /// 0.4 is critical; F1 below 0.8 or impact above 0.2 is degraded.
    fn calculate_health_status(
        &self,
        metrics: &VecDeque<ModelMetrics>,
        security: &VecDeque<SecurityImpactAnalysis>,
    ) -> HealthStatus {
        let (Some(latest_metrics), Some(latest_security)) = (metrics.back(), security.back())
        else {
            return HealthStatus::Unknown;
        };

        if latest_metrics.f1_score < 0.6 || latest_security.false_positive_impact > 0.4 {
            HealthStatus::Critical
        } else if latest_metrics.f1_score < 0.8 || latest_security.false_positive_impact > 0.2 {
            HealthStatus::Degraded
        } else {
            HealthStatus::Healthy
        }
    }

    /// Estimates the short-term F1 trend by summing the deltas between the
    /// five most recent scores (newest first), so a positive sum means the
    /// score has been rising.
    fn calculate_performance_trend(&self, metrics: &VecDeque<ModelMetrics>) -> PerformanceTrend {
        if metrics.len() < 2 {
            return PerformanceTrend::Stable;
        }

        let recent_scores: Vec<f64> = metrics.iter().rev().take(5).map(|m| m.f1_score).collect();
        let trend: f64 = recent_scores.windows(2).map(|w| w[0] - w[1]).sum();

        match trend {
            t if t > 0.05 => PerformanceTrend::Improving,
            t if t < -0.05 => PerformanceTrend::Degrading,
            _ => PerformanceTrend::Stable,
        }
    }
}
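
// A minimal usage sketch rather than a full test suite: it only exercises the
// empty-history path, which needs nothing beyond the types defined in this
// module. It assumes tokio's `rt` and `macros` features are enabled so that
// `#[tokio::test]` compiles.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn empty_monitor_reports_unknown_health() {
        let monitor = ModelMonitor::new(MonitoringConfig {
            metrics_window_size: 100,
            performance_threshold: 0.8,
            security_threshold: 0.2,
            alert_cooldown: Duration::minutes(15),
        });

        let health = monitor.get_model_health().await;
        assert_eq!(health.current_status, HealthStatus::Unknown);
        assert!(health.alerts.is_empty());
    }
}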