1use anyhow::{Context, Result};
7use parking_lot::RwLock;
8use scirs2_core::ndarray::Array1;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12use std::sync::Arc;
13use trustformers_core::tensor::Tensor;
14
/// Client for recording debugging experiments, runs, metrics, parameters,
/// and artifacts in an MLflow-compatible layout.
///
/// Metrics and parameters are buffered in thread-safe caches and flushed in
/// batches. NOTE(review): no network calls appear anywhere in this file —
/// ids are minted locally and "flushing" only emits `tracing` events;
/// presumably the actual HTTP transport lives elsewhere. Confirm.
#[derive(Debug)]
pub struct MLflowClient {
    /// Tracking server URI. NOTE(review): stored and settable, but never
    /// read by any method in this file.
    tracking_uri: String,
    /// Id of the active experiment, set by `start_experiment`.
    experiment_id: Option<String>,
    /// Id of the active run, set by `start_run`, cleared by `end_run`.
    run_id: Option<String>,
    /// Client configuration (flush threshold, artifact directory, ...).
    config: MLflowConfig,
    /// Buffered metric points, keyed by metric name.
    metrics_cache: Arc<RwLock<HashMap<String, Vec<MetricPoint>>>>,
    /// Buffered run parameters (key -> value).
    params_cache: Arc<RwLock<HashMap<String, String>>>,
}
31
/// Configuration for [`MLflowClient`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLflowConfig {
    /// Tracking server URI, e.g. "http://localhost:5000".
    pub tracking_uri: String,
    /// Name of the experiment that runs are grouped under.
    pub experiment_name: String,
    /// Whether metrics should be logged automatically. NOTE(review): this
    /// flag is not read anywhere in this file — confirm it is consumed
    /// elsewhere.
    pub auto_log: bool,
    /// Logging interval in steps. NOTE(review): also unused in this file.
    pub log_interval: usize,
    /// Maximum number of buffered metric points before `log_metric`
    /// triggers a flush.
    pub max_cache_size: usize,
    /// When false, `log_artifact` becomes a no-op.
    pub log_artifacts: bool,
    /// Local directory that artifacts are copied into.
    pub artifact_dir: PathBuf,
}
50
51impl Default for MLflowConfig {
52 fn default() -> Self {
53 Self {
54 tracking_uri: "http://localhost:5000".to_string(),
55 experiment_name: "trustformers-debug".to_string(),
56 auto_log: true,
57 log_interval: 10,
58 max_cache_size: 1000,
59 log_artifacts: true,
60 artifact_dir: PathBuf::from("./mlflow_artifacts"),
61 }
62 }
63}
64
/// A single recorded value of one metric.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricPoint {
    /// Metric value at this point.
    pub value: f64,
    /// Training/debug step the value belongs to.
    pub step: i64,
    /// Wall-clock time in milliseconds since the Unix epoch
    /// (as recorded by `log_metric`).
    pub timestamp: i64,
}
75
/// Metadata describing an MLflow run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunInfo {
    /// Unique id of the run.
    pub run_id: String,
    /// Id of the experiment the run belongs to.
    pub experiment_id: String,
    /// Human-readable run name. NOTE(review): `get_run_info` currently
    /// fills this with a fixed placeholder.
    pub run_name: String,
    /// Start timestamp. NOTE(review): `get_run_info` currently fills this
    /// with 0 — the client does not record start times.
    pub start_time: i64,
    /// End timestamp, or `None` while the run is still active.
    pub end_time: Option<i64>,
    /// Current lifecycle status of the run.
    pub status: RunStatus,
}
92
/// Lifecycle status of an MLflow run.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunStatus {
    /// Run is in progress.
    Running,
    /// Run completed successfully.
    Finished,
    /// Run terminated with an error.
    Failed,
    /// Run was terminated externally.
    Killed,
}
105
/// Category tag attached to a logged artifact.
/// NOTE(review): in this file the tag is only used in log output; it does
/// not affect where or how the artifact is stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ArtifactType {
    /// Serialized model files.
    Model,
    /// Plot / figure images.
    Plot,
    /// Text reports.
    Report,
    /// Data files.
    Data,
    /// Configuration files.
    Config,
}
120
121impl MLflowClient {
122 pub fn new(config: MLflowConfig) -> Self {
135 Self {
136 tracking_uri: config.tracking_uri.clone(),
137 experiment_id: None,
138 run_id: None,
139 config,
140 metrics_cache: Arc::new(RwLock::new(HashMap::new())),
141 params_cache: Arc::new(RwLock::new(HashMap::new())),
142 }
143 }
144
145 pub fn set_tracking_uri(&mut self, uri: impl Into<String>) {
150 self.tracking_uri = uri.into();
151 }
152
153 pub fn start_experiment(&mut self, name: impl Into<String>) -> Result<String> {
161 let experiment_name = name.into();
162
163 let experiment_id = format!("exp_{}", uuid::Uuid::new_v4());
166
167 self.experiment_id = Some(experiment_id.clone());
168
169 tracing::info!(
170 experiment_id = %experiment_id,
171 experiment_name = %experiment_name,
172 "Started MLflow experiment"
173 );
174
175 Ok(experiment_id)
176 }
177
178 pub fn start_run(&mut self, run_name: Option<&str>) -> Result<String> {
186 let experiment_id = self
187 .experiment_id
188 .as_ref()
189 .context("No active experiment. Call start_experiment() first")?;
190
191 let run_id = format!("run_{}", uuid::Uuid::new_v4());
192 let run_name = run_name.unwrap_or("debug_run").to_string();
193
194 self.run_id = Some(run_id.clone());
195
196 self.metrics_cache.write().clear();
198 self.params_cache.write().clear();
199
200 tracing::info!(
201 run_id = %run_id,
202 run_name = %run_name,
203 experiment_id = %experiment_id,
204 "Started MLflow run"
205 );
206
207 Ok(run_id)
208 }
209
210 pub fn end_run(&mut self, status: RunStatus) -> Result<()> {
215 let run_id = self.run_id.as_ref().context("No active run")?;
216
217 self.flush_metrics()?;
219
220 tracing::info!(
221 run_id = %run_id,
222 status = ?status,
223 "Ended MLflow run"
224 );
225
226 self.run_id = None;
227
228 Ok(())
229 }
230
231 pub fn log_param(&mut self, key: impl Into<String>, value: impl ToString) -> Result<()> {
237 let key = key.into();
238 let value = value.to_string();
239
240 let _run_id = self.run_id.as_ref().context("No active run. Call start_run() first")?;
241
242 self.params_cache.write().insert(key.clone(), value.clone());
243
244 tracing::debug!(key = %key, value = %value, "Logged parameter");
245
246 Ok(())
247 }
248
249 pub fn log_params(&mut self, params: HashMap<String, String>) -> Result<()> {
254 for (key, value) in params {
255 self.log_param(key, value)?;
256 }
257 Ok(())
258 }
259
260 pub fn log_metric(&mut self, key: impl Into<String>, value: f64, step: i64) -> Result<()> {
267 let key = key.into();
268
269 let _run_id = self.run_id.as_ref().context("No active run. Call start_run() first")?;
270
271 let timestamp = std::time::SystemTime::now()
272 .duration_since(std::time::UNIX_EPOCH)
273 .unwrap()
274 .as_millis() as i64;
275
276 let metric = MetricPoint {
277 value,
278 step,
279 timestamp,
280 };
281
282 self.metrics_cache.write().entry(key.clone()).or_default().push(metric);
283
284 tracing::debug!(key = %key, value = %value, step = %step, "Logged metric");
285
286 if self.metrics_cache.read().values().map(|v| v.len()).sum::<usize>()
288 >= self.config.max_cache_size
289 {
290 self.flush_metrics()?;
291 }
292
293 Ok(())
294 }
295
296 pub fn log_metrics(&mut self, metrics: HashMap<String, f64>, step: i64) -> Result<()> {
302 for (key, value) in metrics {
303 self.log_metric(key, value, step)?;
304 }
305 Ok(())
306 }
307
308 pub fn log_tensor_stats(&mut self, prefix: &str, tensor: &Tensor, step: i64) -> Result<()> {
315 self.log_metric(
317 format!("{}/element_count", prefix),
318 tensor.len() as f64,
319 step,
320 )?;
321 self.log_metric(
322 format!("{}/memory_bytes", prefix),
323 tensor.memory_usage() as f64,
324 step,
325 )?;
326
327 let shape = tensor.shape();
328 self.log_metric(format!("{}/ndim", prefix), shape.len() as f64, step)?;
329
330 Ok(())
331 }
332
333 pub fn log_array_stats(&mut self, prefix: &str, array: &Array1<f64>, step: i64) -> Result<()> {
340 let mean = array.mean().unwrap_or(0.0);
341 let std = array.std(0.0);
342 let min = array.iter().copied().fold(f64::INFINITY, f64::min);
343 let max = array.iter().copied().fold(f64::NEG_INFINITY, f64::max);
344
345 self.log_metric(format!("{}/mean", prefix), mean, step)?;
346 self.log_metric(format!("{}/std", prefix), std, step)?;
347 self.log_metric(format!("{}/min", prefix), min, step)?;
348 self.log_metric(format!("{}/max", prefix), max, step)?;
349
350 Ok(())
351 }
352
353 fn flush_metrics(&self) -> Result<()> {
355 let metrics = self.metrics_cache.read();
356
357 if metrics.is_empty() {
358 return Ok(());
359 }
360
361 tracing::debug!(metric_count = metrics.len(), "Flushed metrics to MLflow");
363
364 Ok(())
365 }
366
367 pub fn log_artifact(
374 &self,
375 local_path: impl AsRef<Path>,
376 artifact_path: Option<&str>,
377 artifact_type: ArtifactType,
378 ) -> Result<()> {
379 let _run_id = self.run_id.as_ref().context("No active run")?;
380
381 let local_path = local_path.as_ref();
382
383 if !self.config.log_artifacts {
384 tracing::debug!("Artifact logging disabled");
385 return Ok(());
386 }
387
388 let artifact_dir = &self.config.artifact_dir;
390 std::fs::create_dir_all(artifact_dir)?;
391
392 let dest_path = if let Some(rel_path) = artifact_path {
393 artifact_dir.join(rel_path)
394 } else {
395 artifact_dir.join(local_path.file_name().unwrap())
396 };
397
398 if let Some(parent) = dest_path.parent() {
399 std::fs::create_dir_all(parent)?;
400 }
401
402 std::fs::copy(local_path, &dest_path).context("Failed to copy artifact")?;
403
404 tracing::info!(
405 local_path = ?local_path,
406 artifact_path = ?dest_path,
407 artifact_type = ?artifact_type,
408 "Logged artifact"
409 );
410
411 Ok(())
412 }
413
414 pub fn log_model(&self, model_path: impl AsRef<Path>, model_name: Option<&str>) -> Result<()> {
420 let artifact_path = if let Some(name) = model_name {
421 format!("models/{}", name)
422 } else {
423 "models/model".to_string()
424 };
425
426 self.log_artifact(model_path, Some(&artifact_path), ArtifactType::Model)
427 }
428
429 pub fn log_plot(&self, plot_path: impl AsRef<Path>, plot_name: Option<&str>) -> Result<()> {
435 let artifact_path = if let Some(name) = plot_name {
436 format!("plots/{}", name)
437 } else {
438 "plots/plot".to_string()
439 };
440
441 self.log_artifact(plot_path, Some(&artifact_path), ArtifactType::Plot)
442 }
443
444 pub fn log_report(&self, content: &str, filename: &str) -> Result<()> {
450 let temp_path = std::env::temp_dir().join(filename);
451 std::fs::write(&temp_path, content)?;
452
453 self.log_artifact(
454 &temp_path,
455 Some(&format!("reports/{}", filename)),
456 ArtifactType::Report,
457 )?;
458
459 std::fs::remove_file(&temp_path)?;
460
461 Ok(())
462 }
463
464 pub fn get_run_info(&self) -> Option<RunInfo> {
466 let run_id = self.run_id.as_ref()?;
467 let experiment_id = self.experiment_id.as_ref()?;
468
469 Some(RunInfo {
470 run_id: run_id.clone(),
471 experiment_id: experiment_id.clone(),
472 run_name: "debug_run".to_string(),
473 start_time: 0, end_time: None,
475 status: RunStatus::Running,
476 })
477 }
478
479 pub fn get_params(&self) -> HashMap<String, String> {
481 self.params_cache.read().clone()
482 }
483
484 pub fn get_metrics(&self) -> HashMap<String, Vec<MetricPoint>> {
486 self.metrics_cache.read().clone()
487 }
488}
489
/// Convenience wrapper pairing an [`MLflowClient`] with an
/// auto-incrementing step counter for debug-metric logging.
pub struct MLflowDebugSession {
    /// Underlying client; public so callers can log params/artifacts
    /// directly.
    pub client: MLflowClient,
    /// Step index passed to the next `log_debug_metrics` batch;
    /// incremented after each successful batch.
    step: i64,
}
497
498impl MLflowDebugSession {
499 pub fn new(config: MLflowConfig) -> Self {
501 Self {
502 client: MLflowClient::new(config),
503 step: 0,
504 }
505 }
506
507 pub fn start(&mut self, experiment_name: &str, run_name: Option<&str>) -> Result<()> {
509 self.client.start_experiment(experiment_name)?;
510 self.client.start_run(run_name)?;
511 self.step = 0;
512 Ok(())
513 }
514
515 pub fn log_debug_metrics(&mut self, metrics: HashMap<String, f64>) -> Result<()> {
517 self.client.log_metrics(metrics, self.step)?;
518 self.step += 1;
519 Ok(())
520 }
521
522 pub fn end(&mut self, status: RunStatus) -> Result<()> {
524 self.client.end_run(status)
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    // Client construction from defaults must not panic.
    #[test]
    fn test_mlflow_client_creation() {
        let _client = MLflowClient::new(MLflowConfig::default());
    }

    // An experiment must be startable, followed by a run under it.
    #[test]
    fn test_start_experiment_and_run() -> Result<()> {
        let mut client = MLflowClient::new(MLflowConfig::default());

        let _exp_id = client.start_experiment("test_experiment")?;
        let _run_id = client.start_run(Some("test_run"))?;

        Ok(())
    }

    // Logged parameters must be readable back from the cache.
    #[test]
    fn test_log_params() -> Result<()> {
        let mut client = MLflowClient::new(MLflowConfig::default());

        client.start_experiment("test")?;
        client.start_run(None)?;

        for (key, value) in [("learning_rate", "0.001"), ("batch_size", "32")] {
            client.log_param(key, value)?;
        }

        let recorded = client.get_params();
        assert_eq!(recorded.get("learning_rate"), Some(&"0.001".to_string()));
        assert_eq!(recorded.get("batch_size"), Some(&"32".to_string()));

        Ok(())
    }

    // Metric points must accumulate per key in the cache.
    #[test]
    fn test_log_metrics() -> Result<()> {
        let mut client = MLflowClient::new(MLflowConfig::default());

        client.start_experiment("test")?;
        client.start_run(None)?;

        client.log_metric("loss", 0.5, 0)?;
        client.log_metric("loss", 0.4, 1)?;
        client.log_metric("accuracy", 0.8, 0)?;

        let cached = client.get_metrics();
        assert_eq!(cached.get("loss").unwrap().len(), 2);
        assert_eq!(cached.get("accuracy").unwrap().len(), 1);

        Ok(())
    }

    // Array statistics must be logged under the expected derived keys.
    #[test]
    fn test_log_array_stats() -> Result<()> {
        let mut client = MLflowClient::new(MLflowConfig::default());

        client.start_experiment("test")?;
        client.start_run(None)?;

        let samples = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        client.log_array_stats("weights", &samples, 0)?;

        let cached = client.get_metrics();
        for stat in ["mean", "std", "min", "max"] {
            assert!(cached.contains_key(&format!("weights/{}", stat)));
        }

        Ok(())
    }

    // Ending a run must clear the active run id.
    #[test]
    fn test_end_run() -> Result<()> {
        let mut client = MLflowClient::new(MLflowConfig::default());

        client.start_experiment("test")?;
        client.start_run(None)?;
        client.log_metric("loss", 0.5, 0)?;
        client.end_run(RunStatus::Finished)?;

        assert!(client.run_id.is_none());

        Ok(())
    }

    // Full session lifecycle: start, log one batch, finish.
    #[test]
    fn test_mlflow_debug_session() -> Result<()> {
        let mut session = MLflowDebugSession::new(MLflowConfig::default());

        session.start("test_debug", Some("debug_run_1"))?;

        let metrics = HashMap::from([
            ("gradient_norm".to_string(), 0.1),
            ("activation_mean".to_string(), 0.5),
        ]);

        session.log_debug_metrics(metrics)?;

        session.end(RunStatus::Finished)?;

        Ok(())
    }
}