// tensorlogic_train/logging.rs

1//! Logging infrastructure for training.
2//!
3//! This module provides various logging backends to track training progress:
4//! - Console logging (stdout/stderr)
5//! - File logging (write to file)
6//! - TensorBoard logging (placeholder for future integration)
7//! - Metrics logging and aggregation
8
9use crate::{TrainError, TrainResult};
10use std::collections::HashMap;
11use std::fs::{File, OpenOptions};
12use std::io::Write;
13use std::path::{Path, PathBuf};
14
/// Trait for logging backends.
///
/// Implementors receive scalar metrics and free-form text messages from the
/// training loop; `MetricsLogger` fans a single call out to every registered
/// backend.
pub trait LoggingBackend {
    /// Log a scalar metric.
    ///
    /// # Arguments
    /// * `name` - Name of the metric
    /// * `value` - Value of the metric
    /// * `step` - Training step/epoch number
    ///
    /// # Errors
    /// Returns an error if the backend fails to record the value
    /// (e.g. an I/O failure while writing).
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()>;

    /// Log a text message.
    ///
    /// # Arguments
    /// * `message` - Text message to log
    ///
    /// # Errors
    /// Returns an error if the backend fails to record the message.
    fn log_text(&mut self, message: &str) -> TrainResult<()>;

    /// Flush any buffered logs.
    ///
    /// # Errors
    /// Returns an error if flushing the underlying sink fails.
    fn flush(&mut self) -> TrainResult<()>;
}
34
/// Console logger that outputs to stdout.
///
/// Simple logger for debugging and development.
///
/// NOTE(review): the derived `Default` yields `include_timestamp: false`,
/// while [`ConsoleLogger::new`] enables timestamps — confirm this divergence
/// is intentional.
#[derive(Debug, Clone, Default)]
pub struct ConsoleLogger {
    /// Whether to include timestamps.
    pub include_timestamp: bool,
}

impl ConsoleLogger {
    /// Create a new console logger with timestamps enabled.
    pub fn new() -> Self {
        Self {
            include_timestamp: true,
        }
    }

    /// Create a console logger without timestamps.
    pub fn without_timestamp() -> Self {
        Self {
            include_timestamp: false,
        }
    }

    /// Render the `"[<secs>.<millis>] "` prefix for a log line.
    ///
    /// Returns an empty string when timestamps are disabled, or when the
    /// system clock reports a time before the Unix epoch.
    fn format_timestamp(&self) -> String {
        if !self.include_timestamp {
            return String::new();
        }
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| format!("[{:.3}] ", elapsed.as_secs_f64()))
            .unwrap_or_default()
    }
}
71
72impl LoggingBackend for ConsoleLogger {
73    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
74        println!(
75            "{}Step {}: {} = {:.6}",
76            self.format_timestamp(),
77            step,
78            name,
79            value
80        );
81        Ok(())
82    }
83
84    fn log_text(&mut self, message: &str) -> TrainResult<()> {
85        println!("{}{}", self.format_timestamp(), message);
86        Ok(())
87    }
88
89    fn flush(&mut self) -> TrainResult<()> {
90        use std::io::stdout;
91        stdout()
92            .flush()
93            .map_err(|e| TrainError::Other(format!("Failed to flush stdout: {}", e)))?;
94        Ok(())
95    }
96}
97
98/// File logger that writes logs to a file.
99///
100/// Useful for persistent logging and later analysis.
101#[derive(Debug)]
102pub struct FileLogger {
103    file: File,
104    path: PathBuf,
105}
106
107impl FileLogger {
108    /// Create a new file logger.
109    ///
110    /// # Arguments
111    /// * `path` - Path to the log file (will be created or appended)
112    pub fn new<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
113        let path = path.as_ref().to_path_buf();
114        let file = OpenOptions::new()
115            .create(true)
116            .append(true)
117            .open(&path)
118            .map_err(|e| TrainError::Other(format!("Failed to open log file {:?}: {}", path, e)))?;
119
120        Ok(Self { file, path })
121    }
122
123    /// Create a new file logger, truncating the file if it exists.
124    ///
125    /// # Arguments
126    /// * `path` - Path to the log file
127    pub fn new_truncate<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
128        let path = path.as_ref().to_path_buf();
129        let file = OpenOptions::new()
130            .create(true)
131            .write(true)
132            .truncate(true)
133            .open(&path)
134            .map_err(|e| TrainError::Other(format!("Failed to open log file {:?}: {}", path, e)))?;
135
136        Ok(Self { file, path })
137    }
138
139    /// Get the path to the log file.
140    pub fn path(&self) -> &Path {
141        &self.path
142    }
143}
144
145impl LoggingBackend for FileLogger {
146    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
147        writeln!(self.file, "Step {}: {} = {:.6}", step, name, value)
148            .map_err(|e| TrainError::Other(format!("Failed to write to log file: {}", e)))?;
149        Ok(())
150    }
151
152    fn log_text(&mut self, message: &str) -> TrainResult<()> {
153        writeln!(self.file, "{}", message)
154            .map_err(|e| TrainError::Other(format!("Failed to write to log file: {}", e)))?;
155        Ok(())
156    }
157
158    fn flush(&mut self) -> TrainResult<()> {
159        self.file
160            .flush()
161            .map_err(|e| TrainError::Other(format!("Failed to flush log file: {}", e)))?;
162        Ok(())
163    }
164}
165
166/// TensorBoard logger (placeholder for future implementation).
167///
168/// Will integrate with TensorBoard for visualization.
169#[derive(Debug, Clone)]
170pub struct TensorBoardLogger {
171    log_dir: PathBuf,
172}
173
174impl TensorBoardLogger {
175    /// Create a new TensorBoard logger.
176    ///
177    /// # Arguments
178    /// * `log_dir` - Directory for TensorBoard logs
179    pub fn new<P: AsRef<Path>>(log_dir: P) -> TrainResult<Self> {
180        let log_dir = log_dir.as_ref().to_path_buf();
181
182        // Create directory if it doesn't exist
183        std::fs::create_dir_all(&log_dir).map_err(|e| {
184            TrainError::Other(format!(
185                "Failed to create log directory {:?}: {}",
186                log_dir, e
187            ))
188        })?;
189
190        Ok(Self { log_dir })
191    }
192
193    /// Get the log directory.
194    pub fn log_dir(&self) -> &Path {
195        &self.log_dir
196    }
197}
198
impl LoggingBackend for TensorBoardLogger {
    /// Placeholder: forwards the scalar to the `log` crate at debug level
    /// instead of writing a TensorBoard event file.
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
        // Placeholder: in the future, this would write TensorBoard event files
        log::debug!("TensorBoard: Step {}: {} = {:.6}", step, name, value);
        Ok(())
    }

    /// Placeholder: forwards the message to the `log` crate at debug level.
    fn log_text(&mut self, message: &str) -> TrainResult<()> {
        log::debug!("TensorBoard: {}", message);
        Ok(())
    }

    /// No-op: nothing is buffered yet.
    fn flush(&mut self) -> TrainResult<()> {
        // Placeholder
        Ok(())
    }
}
216
/// Metrics logger that aggregates and logs training metrics.
///
/// Collects metrics and logs them using multiple backends.
#[derive(Debug)]
pub struct MetricsLogger {
    // Fan-out targets; boxed behind a clone-able trait object.
    backends: Vec<Box<dyn LoggingBackendClone>>,
    // Step number attached to every scalar logged via `log_metric`.
    current_step: usize,
    // Per-metric values gathered since the last `log_accumulated_metrics`.
    accumulated_metrics: HashMap<String, Vec<f64>>,
}
226
/// Helper trait for cloning boxed logging backends.
///
/// `Clone` is not object-safe, so cloning a `Box<dyn LoggingBackendClone>`
/// goes through this `clone_box` indirection instead.
trait LoggingBackendClone: LoggingBackend + std::fmt::Debug {
    // Produce an owned boxed copy of this backend.
    fn clone_box(&self) -> Box<dyn LoggingBackendClone>;
}

// Blanket impl: any clonable, debuggable backend gets `clone_box` for free.
impl<T: LoggingBackend + Clone + std::fmt::Debug + 'static> LoggingBackendClone for T {
    fn clone_box(&self) -> Box<dyn LoggingBackendClone> {
        Box::new(self.clone())
    }
}

// Lets containers of boxed backends (e.g. `Vec<Box<dyn LoggingBackendClone>>`)
// be cloned via the `clone_box` indirection above.
impl Clone for Box<dyn LoggingBackendClone> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
243
// Forward the `LoggingBackend` API through the box so `MetricsLogger` can
// call backends without unboxing manually.
impl LoggingBackend for Box<dyn LoggingBackendClone> {
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
        // Deref twice: &mut Box<dyn ...> -> dyn ... (dynamic dispatch).
        (**self).log_scalar(name, value, step)
    }

    fn log_text(&mut self, message: &str) -> TrainResult<()> {
        (**self).log_text(message)
    }

    fn flush(&mut self) -> TrainResult<()> {
        (**self).flush()
    }
}
257
258impl MetricsLogger {
259    /// Create a new metrics logger.
260    pub fn new() -> Self {
261        Self {
262            backends: Vec::new(),
263            current_step: 0,
264            accumulated_metrics: HashMap::new(),
265        }
266    }
267
268    /// Add a logging backend.
269    ///
270    /// # Arguments
271    /// * `backend` - Backend to add
272    pub fn add_backend<B: LoggingBackend + Clone + std::fmt::Debug + 'static>(
273        &mut self,
274        backend: B,
275    ) {
276        self.backends.push(Box::new(backend));
277    }
278
279    /// Log a scalar metric to all backends.
280    ///
281    /// # Arguments
282    /// * `name` - Metric name
283    /// * `value` - Metric value
284    pub fn log_metric(&mut self, name: &str, value: f64) -> TrainResult<()> {
285        for backend in &mut self.backends {
286            backend.log_scalar(name, value, self.current_step)?;
287        }
288        Ok(())
289    }
290
291    /// Accumulate a metric value (for averaging over batch).
292    ///
293    /// # Arguments
294    /// * `name` - Metric name
295    /// * `value` - Metric value
296    pub fn accumulate_metric(&mut self, name: &str, value: f64) {
297        self.accumulated_metrics
298            .entry(name.to_string())
299            .or_default()
300            .push(value);
301    }
302
303    /// Log accumulated metrics (average) and clear accumulation.
304    pub fn log_accumulated_metrics(&mut self) -> TrainResult<()> {
305        // Collect metrics to log before clearing
306        let metrics_to_log: Vec<(String, f64)> = self
307            .accumulated_metrics
308            .iter()
309            .filter(|(_, values)| !values.is_empty())
310            .map(|(name, values)| {
311                let avg = values.iter().sum::<f64>() / values.len() as f64;
312                (name.clone(), avg)
313            })
314            .collect();
315
316        // Log all metrics
317        for (name, avg) in metrics_to_log {
318            self.log_metric(&name, avg)?;
319        }
320
321        // Clear accumulation
322        self.accumulated_metrics.clear();
323        Ok(())
324    }
325
326    /// Log a text message to all backends.
327    ///
328    /// # Arguments
329    /// * `message` - Text message
330    pub fn log_message(&mut self, message: &str) -> TrainResult<()> {
331        for backend in &mut self.backends {
332            backend.log_text(message)?;
333        }
334        Ok(())
335    }
336
337    /// Increment the step counter.
338    pub fn step(&mut self) {
339        self.current_step += 1;
340    }
341
342    /// Set the current step.
343    ///
344    /// # Arguments
345    /// * `step` - Step number
346    pub fn set_step(&mut self, step: usize) {
347        self.current_step = step;
348    }
349
350    /// Get the current step.
351    pub fn current_step(&self) -> usize {
352        self.current_step
353    }
354
355    /// Flush all backends.
356    pub fn flush(&mut self) -> TrainResult<()> {
357        for backend in &mut self.backends {
358            backend.flush()?;
359        }
360        Ok(())
361    }
362
363    /// Get the number of backends.
364    pub fn num_backends(&self) -> usize {
365        self.backends.len()
366    }
367}
368
369impl Default for MetricsLogger {
370    fn default() -> Self {
371        Self::new()
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::fs;

    /// Temp path made unique per process so concurrent test runs (e.g.
    /// parallel CI jobs sharing the system temp dir) cannot clobber each
    /// other's log files.
    fn unique_temp(name: &str) -> PathBuf {
        env::temp_dir().join(format!("tensorlogic_{}_{}", std::process::id(), name))
    }

    #[test]
    fn test_console_logger() {
        let mut logger = ConsoleLogger::new();

        // These should not fail
        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_text("Test message").unwrap();
        logger.flush().unwrap();
    }

    #[test]
    fn test_console_logger_without_timestamp() {
        let mut logger = ConsoleLogger::without_timestamp();

        logger.log_scalar("accuracy", 0.95, 10).unwrap();
        logger.log_text("Another test").unwrap();
    }

    #[test]
    fn test_file_logger() {
        let log_path = unique_temp("test_training.log");

        // Clean up if file exists
        let _ = fs::remove_file(&log_path);

        let mut logger = FileLogger::new(&log_path).unwrap();

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_scalar("accuracy", 0.9, 1).unwrap();
        logger.log_text("Training started").unwrap();
        logger.flush().unwrap();

        // Verify file was created
        assert!(log_path.exists());

        // Read and verify contents
        let contents = fs::read_to_string(&log_path).unwrap();
        assert!(contents.contains("loss = 0.500000"));
        assert!(contents.contains("accuracy = 0.900000"));
        assert!(contents.contains("Training started"));

        // Clean up
        fs::remove_file(&log_path).unwrap();
    }

    #[test]
    fn test_file_logger_truncate() {
        let log_path = unique_temp("test_training_truncate.log");

        // Create file with some content
        {
            let mut logger = FileLogger::new(&log_path).unwrap();
            logger.log_text("Old content").unwrap();
            logger.flush().unwrap();
        }

        // Truncate and write new content
        {
            let mut logger = FileLogger::new_truncate(&log_path).unwrap();
            logger.log_text("New content").unwrap();
            logger.flush().unwrap();
        }

        // Verify old content is gone
        let contents = fs::read_to_string(&log_path).unwrap();
        assert!(!contents.contains("Old content"));
        assert!(contents.contains("New content"));

        // Clean up
        fs::remove_file(&log_path).unwrap();
    }

    #[test]
    fn test_tensorboard_logger() {
        let tb_dir = unique_temp("test_tensorboard");

        // Clean up if directory exists
        let _ = fs::remove_dir_all(&tb_dir);

        let mut logger = TensorBoardLogger::new(&tb_dir).unwrap();

        // Directory should be created
        assert!(tb_dir.exists());

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_text("Test message").unwrap();
        logger.flush().unwrap();

        // Clean up
        fs::remove_dir_all(&tb_dir).unwrap();
    }

    #[test]
    fn test_metrics_logger() {
        let mut logger = MetricsLogger::new();
        assert_eq!(logger.num_backends(), 0);

        logger.add_backend(ConsoleLogger::without_timestamp());
        assert_eq!(logger.num_backends(), 1);

        logger.log_metric("loss", 0.5).unwrap();
        logger.log_message("Epoch 1").unwrap();

        assert_eq!(logger.current_step(), 0);
        logger.step();
        assert_eq!(logger.current_step(), 1);

        logger.set_step(10);
        assert_eq!(logger.current_step(), 10);

        logger.flush().unwrap();
    }

    #[test]
    fn test_metrics_logger_accumulation() {
        let mut logger = MetricsLogger::new();
        logger.add_backend(ConsoleLogger::without_timestamp());

        // Accumulate multiple values
        logger.accumulate_metric("batch_loss", 0.5);
        logger.accumulate_metric("batch_loss", 0.4);
        logger.accumulate_metric("batch_loss", 0.6);

        // Log accumulated (should be average: 0.5)
        logger.log_accumulated_metrics().unwrap();

        // Accumulation should be cleared
        logger.log_accumulated_metrics().unwrap(); // Should not fail even if empty
    }

    #[test]
    fn test_metrics_logger_multiple_backends() {
        let mut logger = MetricsLogger::new();
        logger.add_backend(ConsoleLogger::without_timestamp());
        logger.add_backend(ConsoleLogger::new());

        assert_eq!(logger.num_backends(), 2);

        logger.log_metric("loss", 0.5).unwrap();
        logger.flush().unwrap();
    }

    #[test]
    fn test_metrics_logger_empty_accumulation() {
        let mut logger = MetricsLogger::new();
        logger.add_backend(ConsoleLogger::without_timestamp());

        // Log without accumulating anything
        logger.log_accumulated_metrics().unwrap();
    }

    #[test]
    fn test_file_logger_path() {
        let log_path = unique_temp("test_path.log");
        let _ = fs::remove_file(&log_path);

        let logger = FileLogger::new(&log_path).unwrap();
        assert_eq!(logger.path(), log_path.as_path());

        // Clean up
        fs::remove_file(&log_path).unwrap();
    }

    #[test]
    fn test_tensorboard_logger_log_dir() {
        let tb_dir = unique_temp("test_tb_path");
        let _ = fs::remove_dir_all(&tb_dir);

        let logger = TensorBoardLogger::new(&tb_dir).unwrap();
        assert_eq!(logger.log_dir(), tb_dir.as_path());

        // Clean up
        fs::remove_dir_all(&tb_dir).unwrap();
    }
}