// trustformers_debug/mlflow_integration.rs

1//! MLflow Integration for Experiment Tracking
2//!
3//! This module provides integration with MLflow for tracking experiments, logging metrics,
4//! parameters, and artifacts during model training and debugging.
5
6use anyhow::{Context, Result};
7use parking_lot::RwLock;
8use scirs2_core::ndarray::Array1;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12use std::sync::Arc;
13use trustformers_core::tensor::Tensor;
14
/// MLflow client for experiment tracking.
///
/// Keeps experiment/run identity locally and caches metrics and parameters
/// in memory. In this implementation the actual server communication is
/// simulated (see `flush_metrics`), so nothing leaves the process.
#[derive(Debug)]
pub struct MLflowClient {
    /// MLflow tracking URI
    tracking_uri: String,
    /// Current experiment ID (`None` until `start_experiment` is called)
    experiment_id: Option<String>,
    /// Current run ID (`None` until `start_run`; reset to `None` by `end_run`)
    run_id: Option<String>,
    /// Configuration
    config: MLflowConfig,
    /// Cached metric points keyed by metric name; `Arc<RwLock<..>>` so the
    /// cache can be shared and read concurrently.
    metrics_cache: Arc<RwLock<HashMap<String, Vec<MetricPoint>>>>,
    /// Cached parameters (name -> stringified value)
    params_cache: Arc<RwLock<HashMap<String, String>>>,
}
31
/// Configuration for MLflow integration.
///
/// Defaults (see the `Default` impl) are tuned for local development.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLflowConfig {
    /// MLflow tracking server URI (default: http://localhost:5000)
    pub tracking_uri: String,
    /// Default experiment name (default: "trustformers-debug")
    pub experiment_name: String,
    /// Enable automatic metric logging (default: true)
    pub auto_log: bool,
    /// Metric logging interval in steps (default: 10)
    pub log_interval: usize,
    /// Maximum number of cached metric points before an automatic flush
    /// (default: 1000)
    pub max_cache_size: usize,
    /// Enable artifact logging (default: true)
    pub log_artifacts: bool,
    /// Artifact storage directory (default: "./mlflow_artifacts")
    pub artifact_dir: PathBuf,
}
50
51impl Default for MLflowConfig {
52    fn default() -> Self {
53        Self {
54            tracking_uri: "http://localhost:5000".to_string(),
55            experiment_name: "trustformers-debug".to_string(),
56            auto_log: true,
57            log_interval: 10,
58            max_cache_size: 1000,
59            log_artifacts: true,
60            artifact_dir: PathBuf::from("./mlflow_artifacts"),
61        }
62    }
63}
64
/// A single metric data point.
///
/// One `(value, step, timestamp)` sample in a metric's time series.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricPoint {
    /// Metric value
    pub value: f64,
    /// Step number the value was recorded at (e.g. training step)
    pub step: i64,
    /// Timestamp (milliseconds since epoch)
    pub timestamp: i64,
}
75
/// MLflow run information.
///
/// Snapshot of a run's identity and lifecycle, as returned by
/// `MLflowClient::get_run_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunInfo {
    /// Run ID
    pub run_id: String,
    /// Experiment ID
    pub experiment_id: String,
    /// Run name
    pub run_name: String,
    /// Start time (milliseconds since epoch)
    // NOTE(review): `get_run_info` currently fills this with 0 — real
    // start-time tracking is not implemented yet.
    pub start_time: i64,
    /// End time (milliseconds since epoch, None if active)
    pub end_time: Option<i64>,
    /// Run status
    pub status: RunStatus,
}
92
/// Status of an MLflow run.
///
/// Terminal states (`Finished`, `Failed`, `Killed`) are passed to
/// `MLflowClient::end_run`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunStatus {
    /// Run is active
    Running,
    /// Run completed successfully
    Finished,
    /// Run failed
    Failed,
    /// Run was killed
    Killed,
}
105
/// Artifact type for logging.
///
/// Used only for tagging log output in this implementation; it does not
/// affect where the artifact is stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ArtifactType {
    /// Model weights/checkpoints
    Model,
    /// Visualization plots
    Plot,
    /// Text reports
    Report,
    /// Raw data
    Data,
    /// Configuration files
    Config,
}
120
121impl MLflowClient {
122    /// Create a new MLflow client
123    ///
124    /// # Arguments
125    /// * `config` - MLflow configuration
126    ///
127    /// # Example
128    /// ```rust
129    /// use trustformers_debug::{MLflowClient, MLflowConfig};
130    ///
131    /// let config = MLflowConfig::default();
132    /// let client = MLflowClient::new(config);
133    /// ```
134    pub fn new(config: MLflowConfig) -> Self {
135        Self {
136            tracking_uri: config.tracking_uri.clone(),
137            experiment_id: None,
138            run_id: None,
139            config,
140            metrics_cache: Arc::new(RwLock::new(HashMap::new())),
141            params_cache: Arc::new(RwLock::new(HashMap::new())),
142        }
143    }
144
145    /// Set the tracking URI
146    ///
147    /// # Arguments
148    /// * `uri` - MLflow tracking server URI
149    pub fn set_tracking_uri(&mut self, uri: impl Into<String>) {
150        self.tracking_uri = uri.into();
151    }
152
153    /// Start a new experiment
154    ///
155    /// # Arguments
156    /// * `name` - Experiment name
157    ///
158    /// # Returns
159    /// Experiment ID
160    pub fn start_experiment(&mut self, name: impl Into<String>) -> Result<String> {
161        let experiment_name = name.into();
162
163        // In a real implementation, this would make an HTTP request to MLflow
164        // For now, we'll simulate it
165        let experiment_id = format!("exp_{}", uuid::Uuid::new_v4());
166
167        self.experiment_id = Some(experiment_id.clone());
168
169        tracing::info!(
170            experiment_id = %experiment_id,
171            experiment_name = %experiment_name,
172            "Started MLflow experiment"
173        );
174
175        Ok(experiment_id)
176    }
177
178    /// Start a new run within the current experiment
179    ///
180    /// # Arguments
181    /// * `run_name` - Optional run name
182    ///
183    /// # Returns
184    /// Run ID
185    pub fn start_run(&mut self, run_name: Option<&str>) -> Result<String> {
186        let experiment_id = self
187            .experiment_id
188            .as_ref()
189            .context("No active experiment. Call start_experiment() first")?;
190
191        let run_id = format!("run_{}", uuid::Uuid::new_v4());
192        let run_name = run_name.unwrap_or("debug_run").to_string();
193
194        self.run_id = Some(run_id.clone());
195
196        // Clear caches for new run
197        self.metrics_cache.write().clear();
198        self.params_cache.write().clear();
199
200        tracing::info!(
201            run_id = %run_id,
202            run_name = %run_name,
203            experiment_id = %experiment_id,
204            "Started MLflow run"
205        );
206
207        Ok(run_id)
208    }
209
210    /// End the current run
211    ///
212    /// # Arguments
213    /// * `status` - Final run status
214    pub fn end_run(&mut self, status: RunStatus) -> Result<()> {
215        let run_id = self.run_id.as_ref().context("No active run")?;
216
217        // Flush any cached metrics
218        self.flush_metrics()?;
219
220        tracing::info!(
221            run_id = %run_id,
222            status = ?status,
223            "Ended MLflow run"
224        );
225
226        self.run_id = None;
227
228        Ok(())
229    }
230
231    /// Log a parameter
232    ///
233    /// # Arguments
234    /// * `key` - Parameter name
235    /// * `value` - Parameter value
236    pub fn log_param(&mut self, key: impl Into<String>, value: impl ToString) -> Result<()> {
237        let key = key.into();
238        let value = value.to_string();
239
240        let _run_id = self.run_id.as_ref().context("No active run. Call start_run() first")?;
241
242        self.params_cache.write().insert(key.clone(), value.clone());
243
244        tracing::debug!(key = %key, value = %value, "Logged parameter");
245
246        Ok(())
247    }
248
249    /// Log multiple parameters at once
250    ///
251    /// # Arguments
252    /// * `params` - Map of parameter names to values
253    pub fn log_params(&mut self, params: HashMap<String, String>) -> Result<()> {
254        for (key, value) in params {
255            self.log_param(key, value)?;
256        }
257        Ok(())
258    }
259
260    /// Log a metric at a specific step
261    ///
262    /// # Arguments
263    /// * `key` - Metric name
264    /// * `value` - Metric value
265    /// * `step` - Step number
266    pub fn log_metric(&mut self, key: impl Into<String>, value: f64, step: i64) -> Result<()> {
267        let key = key.into();
268
269        let _run_id = self.run_id.as_ref().context("No active run. Call start_run() first")?;
270
271        let timestamp = std::time::SystemTime::now()
272            .duration_since(std::time::UNIX_EPOCH)
273            .unwrap()
274            .as_millis() as i64;
275
276        let metric = MetricPoint {
277            value,
278            step,
279            timestamp,
280        };
281
282        self.metrics_cache.write().entry(key.clone()).or_default().push(metric);
283
284        tracing::debug!(key = %key, value = %value, step = %step, "Logged metric");
285
286        // Auto-flush if cache is too large
287        if self.metrics_cache.read().values().map(|v| v.len()).sum::<usize>()
288            >= self.config.max_cache_size
289        {
290            self.flush_metrics()?;
291        }
292
293        Ok(())
294    }
295
296    /// Log multiple metrics at once
297    ///
298    /// # Arguments
299    /// * `metrics` - Map of metric names to values
300    /// * `step` - Step number
301    pub fn log_metrics(&mut self, metrics: HashMap<String, f64>, step: i64) -> Result<()> {
302        for (key, value) in metrics {
303            self.log_metric(key, value, step)?;
304        }
305        Ok(())
306    }
307
308    /// Log tensor statistics as metrics
309    ///
310    /// # Arguments
311    /// * `prefix` - Metric name prefix
312    /// * `tensor` - Tensor to analyze
313    /// * `step` - Step number
314    pub fn log_tensor_stats(&mut self, prefix: &str, tensor: &Tensor, step: i64) -> Result<()> {
315        // Log tensor element count and shape info
316        self.log_metric(
317            format!("{}/element_count", prefix),
318            tensor.len() as f64,
319            step,
320        )?;
321        self.log_metric(
322            format!("{}/memory_bytes", prefix),
323            tensor.memory_usage() as f64,
324            step,
325        )?;
326
327        let shape = tensor.shape();
328        self.log_metric(format!("{}/ndim", prefix), shape.len() as f64, step)?;
329
330        Ok(())
331    }
332
333    /// Log array statistics as metrics
334    ///
335    /// # Arguments
336    /// * `prefix` - Metric name prefix
337    /// * `array` - Array to analyze
338    /// * `step` - Step number
339    pub fn log_array_stats(&mut self, prefix: &str, array: &Array1<f64>, step: i64) -> Result<()> {
340        let mean = array.mean().unwrap_or(0.0);
341        let std = array.std(0.0);
342        let min = array.iter().copied().fold(f64::INFINITY, f64::min);
343        let max = array.iter().copied().fold(f64::NEG_INFINITY, f64::max);
344
345        self.log_metric(format!("{}/mean", prefix), mean, step)?;
346        self.log_metric(format!("{}/std", prefix), std, step)?;
347        self.log_metric(format!("{}/min", prefix), min, step)?;
348        self.log_metric(format!("{}/max", prefix), max, step)?;
349
350        Ok(())
351    }
352
353    /// Flush cached metrics to MLflow server
354    fn flush_metrics(&self) -> Result<()> {
355        let metrics = self.metrics_cache.read();
356
357        if metrics.is_empty() {
358            return Ok(());
359        }
360
361        // In a real implementation, this would make HTTP requests to MLflow
362        tracing::debug!(metric_count = metrics.len(), "Flushed metrics to MLflow");
363
364        Ok(())
365    }
366
367    /// Log an artifact (file)
368    ///
369    /// # Arguments
370    /// * `local_path` - Path to local file
371    /// * `artifact_path` - Optional path within artifact storage
372    /// * `artifact_type` - Type of artifact
373    pub fn log_artifact(
374        &self,
375        local_path: impl AsRef<Path>,
376        artifact_path: Option<&str>,
377        artifact_type: ArtifactType,
378    ) -> Result<()> {
379        let _run_id = self.run_id.as_ref().context("No active run")?;
380
381        let local_path = local_path.as_ref();
382
383        if !self.config.log_artifacts {
384            tracing::debug!("Artifact logging disabled");
385            return Ok(());
386        }
387
388        // Copy to artifact directory
389        let artifact_dir = &self.config.artifact_dir;
390        std::fs::create_dir_all(artifact_dir)?;
391
392        let dest_path = if let Some(rel_path) = artifact_path {
393            artifact_dir.join(rel_path)
394        } else {
395            artifact_dir.join(local_path.file_name().unwrap())
396        };
397
398        if let Some(parent) = dest_path.parent() {
399            std::fs::create_dir_all(parent)?;
400        }
401
402        std::fs::copy(local_path, &dest_path).context("Failed to copy artifact")?;
403
404        tracing::info!(
405            local_path = ?local_path,
406            artifact_path = ?dest_path,
407            artifact_type = ?artifact_type,
408            "Logged artifact"
409        );
410
411        Ok(())
412    }
413
414    /// Log a model artifact
415    ///
416    /// # Arguments
417    /// * `model_path` - Path to model file
418    /// * `model_name` - Optional model name
419    pub fn log_model(&self, model_path: impl AsRef<Path>, model_name: Option<&str>) -> Result<()> {
420        let artifact_path = if let Some(name) = model_name {
421            format!("models/{}", name)
422        } else {
423            "models/model".to_string()
424        };
425
426        self.log_artifact(model_path, Some(&artifact_path), ArtifactType::Model)
427    }
428
429    /// Log a plot/visualization
430    ///
431    /// # Arguments
432    /// * `plot_path` - Path to plot file
433    /// * `plot_name` - Optional plot name
434    pub fn log_plot(&self, plot_path: impl AsRef<Path>, plot_name: Option<&str>) -> Result<()> {
435        let artifact_path = if let Some(name) = plot_name {
436            format!("plots/{}", name)
437        } else {
438            "plots/plot".to_string()
439        };
440
441        self.log_artifact(plot_path, Some(&artifact_path), ArtifactType::Plot)
442    }
443
444    /// Log a text report
445    ///
446    /// # Arguments
447    /// * `content` - Report content
448    /// * `filename` - Report filename
449    pub fn log_report(&self, content: &str, filename: &str) -> Result<()> {
450        let temp_path = std::env::temp_dir().join(filename);
451        std::fs::write(&temp_path, content)?;
452
453        self.log_artifact(
454            &temp_path,
455            Some(&format!("reports/{}", filename)),
456            ArtifactType::Report,
457        )?;
458
459        std::fs::remove_file(&temp_path)?;
460
461        Ok(())
462    }
463
464    /// Get current run information
465    pub fn get_run_info(&self) -> Option<RunInfo> {
466        let run_id = self.run_id.as_ref()?;
467        let experiment_id = self.experiment_id.as_ref()?;
468
469        Some(RunInfo {
470            run_id: run_id.clone(),
471            experiment_id: experiment_id.clone(),
472            run_name: "debug_run".to_string(),
473            start_time: 0, // Would be tracked in real implementation
474            end_time: None,
475            status: RunStatus::Running,
476        })
477    }
478
479    /// Get all logged parameters
480    pub fn get_params(&self) -> HashMap<String, String> {
481        self.params_cache.read().clone()
482    }
483
484    /// Get all logged metrics
485    pub fn get_metrics(&self) -> HashMap<String, Vec<MetricPoint>> {
486        self.metrics_cache.read().clone()
487    }
488}
489
/// Integration with TrustformeRS debug session.
///
/// Thin convenience wrapper that pairs an `MLflowClient` with an
/// auto-incrementing step counter.
pub struct MLflowDebugSession {
    /// MLflow client
    pub client: MLflowClient,
    /// Current step, incremented after each `log_debug_metrics` call
    step: i64,
}
497
498impl MLflowDebugSession {
499    /// Create a new MLflow debug session
500    pub fn new(config: MLflowConfig) -> Self {
501        Self {
502            client: MLflowClient::new(config),
503            step: 0,
504        }
505    }
506
507    /// Start debugging with MLflow tracking
508    pub fn start(&mut self, experiment_name: &str, run_name: Option<&str>) -> Result<()> {
509        self.client.start_experiment(experiment_name)?;
510        self.client.start_run(run_name)?;
511        self.step = 0;
512        Ok(())
513    }
514
515    /// Log debugging metrics for current step
516    pub fn log_debug_metrics(&mut self, metrics: HashMap<String, f64>) -> Result<()> {
517        self.client.log_metrics(metrics, self.step)?;
518        self.step += 1;
519        Ok(())
520    }
521
522    /// End debugging session
523    pub fn end(&mut self, status: RunStatus) -> Result<()> {
524        self.client.end_run(status)
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    // Construction with default config should not panic.
    #[test]
    fn test_mlflow_client_creation() {
        let config = MLflowConfig::default();
        let _client = MLflowClient::new(config);
    }

    // Experiment + run lifecycle: both start calls succeed in order.
    #[test]
    fn test_start_experiment_and_run() -> Result<()> {
        let config = MLflowConfig::default();
        let mut client = MLflowClient::new(config);

        let _exp_id = client.start_experiment("test_experiment")?;
        let _run_id = client.start_run(Some("test_run"))?;

        Ok(())
    }

    // Parameters logged during a run are retrievable via get_params().
    #[test]
    fn test_log_params() -> Result<()> {
        let config = MLflowConfig::default();
        let mut client = MLflowClient::new(config);

        client.start_experiment("test")?;
        client.start_run(None)?;

        client.log_param("learning_rate", "0.001")?;
        client.log_param("batch_size", "32")?;

        let params = client.get_params();
        assert_eq!(params.get("learning_rate"), Some(&"0.001".to_string()));
        assert_eq!(params.get("batch_size"), Some(&"32".to_string()));

        Ok(())
    }

    // Metric points accumulate per key in the cache.
    #[test]
    fn test_log_metrics() -> Result<()> {
        let config = MLflowConfig::default();
        let mut client = MLflowClient::new(config);

        client.start_experiment("test")?;
        client.start_run(None)?;

        client.log_metric("loss", 0.5, 0)?;
        client.log_metric("loss", 0.4, 1)?;
        client.log_metric("accuracy", 0.8, 0)?;

        let metrics = client.get_metrics();
        assert_eq!(metrics.get("loss").unwrap().len(), 2);
        assert_eq!(metrics.get("accuracy").unwrap().len(), 1);

        Ok(())
    }

    // log_array_stats produces the four derived metric keys.
    #[test]
    fn test_log_array_stats() -> Result<()> {
        let config = MLflowConfig::default();
        let mut client = MLflowClient::new(config);

        client.start_experiment("test")?;
        client.start_run(None)?;

        let array = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        client.log_array_stats("weights", &array, 0)?;

        let metrics = client.get_metrics();
        assert!(metrics.contains_key("weights/mean"));
        assert!(metrics.contains_key("weights/std"));
        assert!(metrics.contains_key("weights/min"));
        assert!(metrics.contains_key("weights/max"));

        Ok(())
    }

    // end_run clears the active run id.
    #[test]
    fn test_end_run() -> Result<()> {
        let config = MLflowConfig::default();
        let mut client = MLflowClient::new(config);

        client.start_experiment("test")?;
        client.start_run(None)?;
        client.log_metric("loss", 0.5, 0)?;
        client.end_run(RunStatus::Finished)?;

        assert!(client.run_id.is_none());

        Ok(())
    }

    // Full debug-session workflow: start, log one batch of metrics, end.
    #[test]
    fn test_mlflow_debug_session() -> Result<()> {
        let config = MLflowConfig::default();
        let mut session = MLflowDebugSession::new(config);

        session.start("test_debug", Some("debug_run_1"))?;

        let mut metrics = HashMap::new();
        metrics.insert("gradient_norm".to_string(), 0.1);
        metrics.insert("activation_mean".to_string(), 0.5);

        session.log_debug_metrics(metrics)?;

        session.end(RunStatus::Finished)?;

        Ok(())
    }
}