Skip to main content

tensorlogic_train/
logging.rs

1//! Logging infrastructure for training.
2//!
3//! This module provides various logging backends to track training progress:
4//! - Console logging (stdout/stderr)
5//! - File logging (write to file)
6//! - TensorBoard logging (real event file writing)
7//! - CSV logging (for easy analysis)
8//! - JSONL logging (machine-readable format)
9//! - Metrics logging and aggregation
10
11use crate::{TrainError, TrainResult};
12use byteorder::{LittleEndian, WriteBytesExt};
13use chrono::Utc;
14use std::collections::HashMap;
15use std::fs::{File, OpenOptions};
16use std::io::{BufWriter, Write};
17use std::path::{Path, PathBuf};
18
/// Trait for logging backends.
///
/// Implementors receive scalar metrics and free-form text messages;
/// `flush` pushes any buffered output to its destination. All methods
/// return `TrainResult` so I/O failures propagate to the caller.
pub trait LoggingBackend {
    /// Log a scalar metric.
    ///
    /// # Arguments
    /// * `name` - Name of the metric
    /// * `value` - Value of the metric
    /// * `step` - Training step/epoch number
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()>;

    /// Log a text message.
    ///
    /// # Arguments
    /// * `message` - Text message to log
    fn log_text(&mut self, message: &str) -> TrainResult<()>;

    /// Flush any buffered logs.
    fn flush(&mut self) -> TrainResult<()>;
}
38
/// Console logger that outputs to stdout.
///
/// Simple logger for debugging and development.
#[derive(Debug, Clone)]
pub struct ConsoleLogger {
    /// Whether to include timestamps.
    pub include_timestamp: bool,
}

// Manual impl instead of `#[derive(Default)]`: the derive would set
// `include_timestamp = false`, contradicting `ConsoleLogger::new()`,
// which enables timestamps. `Default` must agree with `new()`.
impl Default for ConsoleLogger {
    fn default() -> Self {
        Self {
            include_timestamp: true,
        }
    }
}
47
48impl ConsoleLogger {
49    /// Create a new console logger.
50    pub fn new() -> Self {
51        Self {
52            include_timestamp: true,
53        }
54    }
55
56    /// Create a console logger without timestamps.
57    pub fn without_timestamp() -> Self {
58        Self {
59            include_timestamp: false,
60        }
61    }
62
63    fn format_timestamp(&self) -> String {
64        if self.include_timestamp {
65            let now = std::time::SystemTime::now();
66            match now.duration_since(std::time::UNIX_EPOCH) {
67                Ok(duration) => format!("[{:.3}] ", duration.as_secs_f64()),
68                Err(_) => String::new(),
69            }
70        } else {
71            String::new()
72        }
73    }
74}
75
76impl LoggingBackend for ConsoleLogger {
77    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
78        println!(
79            "{}Step {}: {} = {:.6}",
80            self.format_timestamp(),
81            step,
82            name,
83            value
84        );
85        Ok(())
86    }
87
88    fn log_text(&mut self, message: &str) -> TrainResult<()> {
89        println!("{}{}", self.format_timestamp(), message);
90        Ok(())
91    }
92
93    fn flush(&mut self) -> TrainResult<()> {
94        use std::io::stdout;
95        stdout()
96            .flush()
97            .map_err(|e| TrainError::Other(format!("Failed to flush stdout: {}", e)))?;
98        Ok(())
99    }
100}
101
/// File logger that writes logs to a file.
///
/// Useful for persistent logging and later analysis.
#[derive(Debug)]
pub struct FileLogger {
    // Raw, unbuffered handle — each log call performs a direct write.
    file: File,
    // Remembered so `path()` can report where logs are written.
    path: PathBuf,
}
110
111impl FileLogger {
112    /// Create a new file logger.
113    ///
114    /// # Arguments
115    /// * `path` - Path to the log file (will be created or appended)
116    pub fn new<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
117        let path = path.as_ref().to_path_buf();
118        let file = OpenOptions::new()
119            .create(true)
120            .append(true)
121            .open(&path)
122            .map_err(|e| TrainError::Other(format!("Failed to open log file {:?}: {}", path, e)))?;
123
124        Ok(Self { file, path })
125    }
126
127    /// Create a new file logger, truncating the file if it exists.
128    ///
129    /// # Arguments
130    /// * `path` - Path to the log file
131    pub fn new_truncate<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
132        let path = path.as_ref().to_path_buf();
133        let file = OpenOptions::new()
134            .create(true)
135            .write(true)
136            .truncate(true)
137            .open(&path)
138            .map_err(|e| TrainError::Other(format!("Failed to open log file {:?}: {}", path, e)))?;
139
140        Ok(Self { file, path })
141    }
142
143    /// Get the path to the log file.
144    pub fn path(&self) -> &Path {
145        &self.path
146    }
147}
148
149impl LoggingBackend for FileLogger {
150    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
151        writeln!(self.file, "Step {}: {} = {:.6}", step, name, value)
152            .map_err(|e| TrainError::Other(format!("Failed to write to log file: {}", e)))?;
153        Ok(())
154    }
155
156    fn log_text(&mut self, message: &str) -> TrainResult<()> {
157        writeln!(self.file, "{}", message)
158            .map_err(|e| TrainError::Other(format!("Failed to write to log file: {}", e)))?;
159        Ok(())
160    }
161
162    fn flush(&mut self) -> TrainResult<()> {
163        self.file
164            .flush()
165            .map_err(|e| TrainError::Other(format!("Failed to flush log file: {}", e)))?;
166        Ok(())
167    }
168}
169
/// TensorBoard logger that writes real event files.
///
/// Writes TensorBoard event files in the tfevents format,
/// which can be visualized using TensorBoard.
///
/// # Example
/// ```no_run
/// use tensorlogic_train::TensorBoardLogger;
/// use tensorlogic_train::LoggingBackend;
///
/// let mut logger = TensorBoardLogger::new("./logs/run1").unwrap();
/// logger.log_scalar("loss", 0.5, 1).unwrap();
/// logger.log_scalar("accuracy", 0.95, 1).unwrap();
/// logger.flush().unwrap();
/// ```
#[derive(Debug)]
pub struct TensorBoardLogger {
    // Directory passed to `new`; exposed via `log_dir()`.
    log_dir: PathBuf,
    // Buffered writer over the tfevents file.
    writer: BufWriter<File>,
    // Full path of the event file; exposed via `file_path()`.
    file_path: PathBuf,
}
191
impl TensorBoardLogger {
    /// Create a new TensorBoard logger.
    ///
    /// Creates `log_dir` (and parents) if missing, then opens a fresh
    /// event file named `events.out.tfevents.<unix-ts>.<hostname>`,
    /// matching TensorBoard's file-discovery convention. A
    /// `file_version` event is written first, as readers expect.
    ///
    /// # Arguments
    /// * `log_dir` - Directory for TensorBoard logs
    ///
    /// # Errors
    /// Returns `TrainError::Other` if the directory or event file
    /// cannot be created.
    pub fn new<P: AsRef<Path>>(log_dir: P) -> TrainResult<Self> {
        let log_dir = log_dir.as_ref().to_path_buf();

        // Create directory if it doesn't exist
        std::fs::create_dir_all(&log_dir).map_err(|e| {
            TrainError::Other(format!(
                "Failed to create log directory {:?}: {}",
                log_dir, e
            ))
        })?;

        // Create event file with TensorBoard naming convention
        let timestamp = Utc::now().timestamp();
        let hostname = hostname::get()
            .map(|h| h.to_string_lossy().to_string())
            .unwrap_or_else(|_| "localhost".to_string());
        let filename = format!("events.out.tfevents.{}.{}", timestamp, hostname);
        let file_path = log_dir.join(&filename);

        let file = File::create(&file_path).map_err(|e| {
            TrainError::Other(format!(
                "Failed to create event file {:?}: {}",
                file_path, e
            ))
        })?;

        let mut logger = Self {
            log_dir,
            writer: BufWriter::new(file),
            file_path,
        };

        // Write initial file_version event
        logger.write_file_version()?;

        Ok(logger)
    }

    /// Get the log directory.
    pub fn log_dir(&self) -> &Path {
        &self.log_dir
    }

    /// Get the event file path.
    pub fn file_path(&self) -> &Path {
        &self.file_path
    }

    /// Write the file version event (required by TensorBoard).
    ///
    /// Emits `"brain.Event:2"` at step 0; readers use this marker to
    /// identify the event-file format version.
    fn write_file_version(&mut self) -> TrainResult<()> {
        let wall_time = Utc::now().timestamp_micros() as f64 / 1_000_000.0;

        // Create file_version event
        let event = TensorBoardEvent {
            wall_time,
            step: 0,
            value: TensorBoardValue::FileVersion("brain.Event:2".to_string()),
        };

        self.write_event(&event)
    }

    /// Write a TensorBoard event.
    ///
    /// Frames the serialized event as one TFRecord; the field order
    /// below is fixed by the format and must not be reordered.
    fn write_event(&mut self, event: &TensorBoardEvent) -> TrainResult<()> {
        let data = event.to_bytes();

        // TensorBoard record format:
        // uint64 length
        // uint32 masked_crc32_of_length
        // byte   data[length]
        // uint32 masked_crc32_of_data

        let length = data.len() as u64;
        let length_bytes = length.to_le_bytes();
        let length_crc = masked_crc32(&length_bytes);
        let data_crc = masked_crc32(&data);

        self.writer
            .write_u64::<LittleEndian>(length)
            .map_err(|e| TrainError::Other(format!("Failed to write event length: {}", e)))?;
        self.writer
            .write_u32::<LittleEndian>(length_crc)
            .map_err(|e| TrainError::Other(format!("Failed to write length CRC: {}", e)))?;
        self.writer
            .write_all(&data)
            .map_err(|e| TrainError::Other(format!("Failed to write event data: {}", e)))?;
        self.writer
            .write_u32::<LittleEndian>(data_crc)
            .map_err(|e| TrainError::Other(format!("Failed to write data CRC: {}", e)))?;

        Ok(())
    }

    /// Log a histogram (weight distributions).
    ///
    /// # Arguments
    /// * `tag` - Display name in TensorBoard
    /// * `values` - Raw sample values to summarize
    /// * `step` - Training step the samples belong to
    pub fn log_histogram(&mut self, tag: &str, values: &[f64], step: usize) -> TrainResult<()> {
        let wall_time = Utc::now().timestamp_micros() as f64 / 1_000_000.0;

        let event = TensorBoardEvent {
            wall_time,
            step: step as i64,
            value: TensorBoardValue::Histogram {
                tag: tag.to_string(),
                values: values.to_vec(),
            },
        };

        self.write_event(&event)
    }
}
306
307impl LoggingBackend for TensorBoardLogger {
308    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
309        let wall_time = Utc::now().timestamp_micros() as f64 / 1_000_000.0;
310
311        let event = TensorBoardEvent {
312            wall_time,
313            step: step as i64,
314            value: TensorBoardValue::Scalar {
315                tag: name.to_string(),
316                value,
317            },
318        };
319
320        self.write_event(&event)
321    }
322
323    fn log_text(&mut self, message: &str) -> TrainResult<()> {
324        let wall_time = Utc::now().timestamp_micros() as f64 / 1_000_000.0;
325
326        let event = TensorBoardEvent {
327            wall_time,
328            step: 0,
329            value: TensorBoardValue::Text {
330                tag: "text".to_string(),
331                content: message.to_string(),
332            },
333        };
334
335        self.write_event(&event)
336    }
337
338    fn flush(&mut self) -> TrainResult<()> {
339        self.writer
340            .flush()
341            .map_err(|e| TrainError::Other(format!("Failed to flush TensorBoard writer: {}", e)))?;
342        Ok(())
343    }
344}
345
/// TensorBoard event structure.
///
/// Mirrors the subset of TensorFlow's `Event` proto this module emits.
#[derive(Debug)]
struct TensorBoardEvent {
    // Seconds since the Unix epoch, fractional.
    wall_time: f64,
    // Global step the payload is associated with.
    step: i64,
    // The event payload.
    value: TensorBoardValue,
}
353
/// TensorBoard value types.
#[derive(Debug)]
enum TensorBoardValue {
    // Format-version marker written once per event file.
    FileVersion(String),
    // Single scalar sample (loss, accuracy, ...).
    Scalar { tag: String, value: f64 },
    // Raw samples summarized into a histogram.
    Histogram { tag: String, values: Vec<f64> },
    // Free-form text payload.
    Text { tag: String, content: String },
}
362
impl TensorBoardEvent {
    /// Convert event to bytes (simplified protobuf-like format).
    ///
    /// Hand-rolled protobuf encoding: wall_time (field 1, double),
    /// step (field 2, varint), then exactly one payload field chosen
    /// by the variant. NOTE(review): field numbers/tags are assumed to
    /// match TensorFlow's `Event` proto — verify against
    /// tensorflow/core/util/event.proto if TensorBoard misreads files.
    fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::new();

        // Write wall_time (field 1, double)
        bytes.push(0x09); // field 1, wire type 1 (64-bit)
        bytes.extend_from_slice(&self.wall_time.to_le_bytes());

        // Write step (field 2, int64). Steps come from usize casts so
        // they are non-negative; negative values would need zigzag or
        // 10-byte varint handling, which this encoder does not do.
        bytes.push(0x10); // field 2, wire type 0 (varint)
        write_varint(&mut bytes, self.step as u64);

        match &self.value {
            TensorBoardValue::FileVersion(version) => {
                // Write file_version (field 3, string)
                bytes.push(0x1a); // field 3, wire type 2 (length-delimited)
                write_varint(&mut bytes, version.len() as u64);
                bytes.extend_from_slice(version.as_bytes());
            }
            TensorBoardValue::Scalar { tag, value } => {
                // Write summary (field 5, embedded Summary message)
                let summary_bytes = encode_scalar_summary(tag, *value);
                bytes.push(0x2a); // field 5, wire type 2
                write_varint(&mut bytes, summary_bytes.len() as u64);
                bytes.extend_from_slice(&summary_bytes);
            }
            TensorBoardValue::Histogram { tag, values } => {
                // Write summary with histogram
                let summary_bytes = encode_histogram_summary(tag, values);
                bytes.push(0x2a);
                write_varint(&mut bytes, summary_bytes.len() as u64);
                bytes.extend_from_slice(&summary_bytes);
            }
            TensorBoardValue::Text { tag, content } => {
                // Write summary with text
                let summary_bytes = encode_text_summary(tag, content);
                bytes.push(0x2a);
                write_varint(&mut bytes, summary_bytes.len() as u64);
                bytes.extend_from_slice(&summary_bytes);
            }
        }

        bytes
    }
}
409
410/// Encode a scalar summary.
411fn encode_scalar_summary(tag: &str, value: f64) -> Vec<u8> {
412    let mut bytes = Vec::new();
413
414    // Summary message contains repeated Value
415    // Value: tag (field 1), simple_value (field 2)
416    let mut value_bytes = Vec::new();
417
418    // tag (field 1, string)
419    value_bytes.push(0x0a);
420    write_varint(&mut value_bytes, tag.len() as u64);
421    value_bytes.extend_from_slice(tag.as_bytes());
422
423    // simple_value (field 2, float)
424    value_bytes.push(0x15); // field 2, wire type 5 (32-bit)
425    value_bytes.extend_from_slice(&(value as f32).to_le_bytes());
426
427    // Wrap in Summary.value (field 1, repeated)
428    bytes.push(0x0a);
429    write_varint(&mut bytes, value_bytes.len() as u64);
430    bytes.extend_from_slice(&value_bytes);
431
432    bytes
433}
434
435/// Encode a histogram summary.
436fn encode_histogram_summary(tag: &str, values: &[f64]) -> Vec<u8> {
437    let mut bytes = Vec::new();
438
439    // Compute histogram statistics
440    let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
441    let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
442    let sum: f64 = values.iter().sum();
443    let sum_squares: f64 = values.iter().map(|x| x * x).sum();
444
445    let mut value_bytes = Vec::new();
446
447    // tag
448    value_bytes.push(0x0a);
449    write_varint(&mut value_bytes, tag.len() as u64);
450    value_bytes.extend_from_slice(tag.as_bytes());
451
452    // histo (field 4) - simplified histogram encoding
453    let mut histo_bytes = Vec::new();
454
455    // min (field 1)
456    histo_bytes.push(0x09);
457    histo_bytes.extend_from_slice(&min.to_le_bytes());
458    // max (field 2)
459    histo_bytes.push(0x11);
460    histo_bytes.extend_from_slice(&max.to_le_bytes());
461    // num (field 3)
462    histo_bytes.push(0x18);
463    write_varint(&mut histo_bytes, values.len() as u64);
464    // sum (field 4)
465    histo_bytes.push(0x21);
466    histo_bytes.extend_from_slice(&sum.to_le_bytes());
467    // sum_squares (field 5)
468    histo_bytes.push(0x29);
469    histo_bytes.extend_from_slice(&sum_squares.to_le_bytes());
470
471    value_bytes.push(0x22); // field 4, wire type 2
472    write_varint(&mut value_bytes, histo_bytes.len() as u64);
473    value_bytes.extend_from_slice(&histo_bytes);
474
475    bytes.push(0x0a);
476    write_varint(&mut bytes, value_bytes.len() as u64);
477    bytes.extend_from_slice(&value_bytes);
478
479    bytes
480}
481
482/// Encode a text summary.
483fn encode_text_summary(tag: &str, content: &str) -> Vec<u8> {
484    let mut bytes = Vec::new();
485
486    let mut value_bytes = Vec::new();
487
488    // tag
489    value_bytes.push(0x0a);
490    write_varint(&mut value_bytes, tag.len() as u64);
491    value_bytes.extend_from_slice(tag.as_bytes());
492
493    // tensor (field 8) for text
494    let mut tensor_bytes = Vec::new();
495    // dtype = DT_STRING (7)
496    tensor_bytes.push(0x08);
497    write_varint(&mut tensor_bytes, 7);
498    // string_val (field 8)
499    tensor_bytes.push(0x42);
500    write_varint(&mut tensor_bytes, content.len() as u64);
501    tensor_bytes.extend_from_slice(content.as_bytes());
502
503    value_bytes.push(0x42); // field 8, wire type 2
504    write_varint(&mut value_bytes, tensor_bytes.len() as u64);
505    value_bytes.extend_from_slice(&tensor_bytes);
506
507    bytes.push(0x0a);
508    write_varint(&mut bytes, value_bytes.len() as u64);
509    bytes.extend_from_slice(&value_bytes);
510
511    bytes
512}
513
/// Write a varint to the buffer.
///
/// Protobuf base-128 encoding: 7 payload bits per byte, MSB set on
/// every byte except the last.
fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
    while value >= 0x80 {
        buf.push((value as u8 & 0x7f) | 0x80);
        value >>= 7;
    }
    buf.push(value as u8);
}
527
/// Compute masked CRC32.
///
/// TFRecord-style masking: rotate the raw CRC right by 15 bits and add
/// a fixed constant, so checksums of data that itself contains CRCs
/// stay well distributed.
fn masked_crc32(data: &[u8]) -> u32 {
    let crc = crc32fast::hash(data);
    // 0xa282ead8 is the standard TFRecord mask delta.
    crc.rotate_right(15).wrapping_add(0xa282ead8)
}
533
/// CSV logger for easy data analysis.
///
/// Writes metrics to a CSV file that can be imported into spreadsheets
/// or analyzed with pandas/numpy.
///
/// # Example
/// ```no_run
/// use tensorlogic_train::CsvLogger;
/// use tensorlogic_train::LoggingBackend;
///
/// let mut logger = CsvLogger::new("/tmp/metrics.csv").unwrap();
/// logger.log_scalar("loss", 0.5, 1).unwrap();
/// logger.log_scalar("accuracy", 0.95, 1).unwrap();
/// logger.flush().unwrap();
/// ```
#[derive(Debug)]
pub struct CsvLogger {
    // Buffered writer over the CSV file.
    writer: BufWriter<File>,
    // Remembered so `path()` can report the output location.
    path: PathBuf,
    // Set to true once `new` writes the header row.
    // NOTE(review): never read after construction — candidate for removal.
    header_written: bool,
}
555
556impl CsvLogger {
557    /// Create a new CSV logger.
558    ///
559    /// # Arguments
560    /// * `path` - Path to the CSV file
561    pub fn new<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
562        let path = path.as_ref().to_path_buf();
563        let file = File::create(&path).map_err(|e| {
564            TrainError::Other(format!("Failed to create CSV file {:?}: {}", path, e))
565        })?;
566
567        let mut logger = Self {
568            writer: BufWriter::new(file),
569            path,
570            header_written: false,
571        };
572
573        // Write header
574        writeln!(logger.writer, "step,metric,value,timestamp")
575            .map_err(|e| TrainError::Other(format!("Failed to write CSV header: {}", e)))?;
576        logger.header_written = true;
577
578        Ok(logger)
579    }
580
581    /// Get the path to the CSV file.
582    pub fn path(&self) -> &Path {
583        &self.path
584    }
585}
586
587impl LoggingBackend for CsvLogger {
588    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
589        let timestamp = Utc::now().to_rfc3339();
590        writeln!(self.writer, "{},{},{:.6},{}", step, name, value, timestamp)
591            .map_err(|e| TrainError::Other(format!("Failed to write to CSV: {}", e)))?;
592        Ok(())
593    }
594
595    fn log_text(&mut self, message: &str) -> TrainResult<()> {
596        let timestamp = Utc::now().to_rfc3339();
597        // Escape message for CSV
598        let escaped = message.replace('"', "\"\"");
599        writeln!(self.writer, "0,text,\"{}\",{}", escaped, timestamp)
600            .map_err(|e| TrainError::Other(format!("Failed to write to CSV: {}", e)))?;
601        Ok(())
602    }
603
604    fn flush(&mut self) -> TrainResult<()> {
605        self.writer
606            .flush()
607            .map_err(|e| TrainError::Other(format!("Failed to flush CSV writer: {}", e)))?;
608        Ok(())
609    }
610}
611
612impl Clone for CsvLogger {
613    fn clone(&self) -> Self {
614        // Create a new logger pointing to the same file in append mode
615        Self::new(&self.path).expect("Failed to clone CsvLogger")
616    }
617}
618
/// JSONL (JSON Lines) logger for machine-readable output.
///
/// Writes each metric as a JSON object on its own line,
/// making it easy to parse and process programmatically.
///
/// # Example
/// ```no_run
/// use tensorlogic_train::JsonlLogger;
/// use tensorlogic_train::LoggingBackend;
///
/// let mut logger = JsonlLogger::new("/tmp/metrics.jsonl").unwrap();
/// logger.log_scalar("loss", 0.5, 1).unwrap();
/// logger.log_scalar("accuracy", 0.95, 1).unwrap();
/// logger.flush().unwrap();
/// ```
#[derive(Debug)]
pub struct JsonlLogger {
    // Buffered writer over the JSONL file.
    writer: BufWriter<File>,
    // Remembered so `path()` can report the output location.
    path: PathBuf,
}
639
640impl JsonlLogger {
641    /// Create a new JSONL logger.
642    ///
643    /// # Arguments
644    /// * `path` - Path to the JSONL file
645    pub fn new<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
646        let path = path.as_ref().to_path_buf();
647        let file = File::create(&path).map_err(|e| {
648            TrainError::Other(format!("Failed to create JSONL file {:?}: {}", path, e))
649        })?;
650
651        Ok(Self {
652            writer: BufWriter::new(file),
653            path,
654        })
655    }
656
657    /// Get the path to the JSONL file.
658    pub fn path(&self) -> &Path {
659        &self.path
660    }
661}
662
663impl LoggingBackend for JsonlLogger {
664    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
665        let timestamp = Utc::now().to_rfc3339();
666        let json = format!(
667            r#"{{"type":"scalar","step":{},"metric":"{}","value":{},"timestamp":"{}"}}"#,
668            step, name, value, timestamp
669        );
670        writeln!(self.writer, "{}", json)
671            .map_err(|e| TrainError::Other(format!("Failed to write to JSONL: {}", e)))?;
672        Ok(())
673    }
674
675    fn log_text(&mut self, message: &str) -> TrainResult<()> {
676        let timestamp = Utc::now().to_rfc3339();
677        // Escape message for JSON
678        let escaped = message.replace('\\', "\\\\").replace('"', "\\\"");
679        let json = format!(
680            r#"{{"type":"text","step":0,"message":"{}","timestamp":"{}"}}"#,
681            escaped, timestamp
682        );
683        writeln!(self.writer, "{}", json)
684            .map_err(|e| TrainError::Other(format!("Failed to write to JSONL: {}", e)))?;
685        Ok(())
686    }
687
688    fn flush(&mut self) -> TrainResult<()> {
689        self.writer
690            .flush()
691            .map_err(|e| TrainError::Other(format!("Failed to flush JSONL writer: {}", e)))?;
692        Ok(())
693    }
694}
695
696impl Clone for JsonlLogger {
697    fn clone(&self) -> Self {
698        // Create a new logger pointing to the same file in append mode
699        Self::new(&self.path).expect("Failed to clone JsonlLogger")
700    }
701}
702
/// Metrics logger that aggregates and logs training metrics.
///
/// Collects metrics and logs them using multiple backends.
#[derive(Debug)]
pub struct MetricsLogger {
    // Fan-out targets; every metric/message goes to each backend.
    backends: Vec<Box<dyn LoggingBackendClone>>,
    // Step attached to logged scalars; advanced via `step`/`set_step`.
    current_step: usize,
    // Per-metric sample buffers used by `accumulate_metric` /
    // `log_accumulated_metrics` to average over a batch.
    accumulated_metrics: HashMap<String, Vec<f64>>,
}
712
/// Helper trait for cloning boxed logging backends.
///
/// Object-safe stand-in for `Clone`, letting `MetricsLogger` duplicate
/// `Box<dyn ...>` trait objects.
trait LoggingBackendClone: LoggingBackend + std::fmt::Debug {
    /// Return a boxed copy of this backend.
    fn clone_box(&self) -> Box<dyn LoggingBackendClone>;
}
717
// Blanket impl: any cloneable, debuggable, 'static backend is
// automatically box-cloneable.
impl<T: LoggingBackend + Clone + std::fmt::Debug + 'static> LoggingBackendClone for T {
    fn clone_box(&self) -> Box<dyn LoggingBackendClone> {
        Box::new(self.clone())
    }
}
723
// `Clone` for the boxed trait object, delegating to `clone_box`.
impl Clone for Box<dyn LoggingBackendClone> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
729
// Forward the `LoggingBackend` API through the box so `MetricsLogger`
// can call backends without explicit unboxing.
impl LoggingBackend for Box<dyn LoggingBackendClone> {
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
        (**self).log_scalar(name, value, step)
    }

    fn log_text(&mut self, message: &str) -> TrainResult<()> {
        (**self).log_text(message)
    }

    fn flush(&mut self) -> TrainResult<()> {
        (**self).flush()
    }
}
743
744impl MetricsLogger {
745    /// Create a new metrics logger.
746    pub fn new() -> Self {
747        Self {
748            backends: Vec::new(),
749            current_step: 0,
750            accumulated_metrics: HashMap::new(),
751        }
752    }
753
754    /// Add a logging backend.
755    ///
756    /// # Arguments
757    /// * `backend` - Backend to add
758    pub fn add_backend<B: LoggingBackend + Clone + std::fmt::Debug + 'static>(
759        &mut self,
760        backend: B,
761    ) {
762        self.backends.push(Box::new(backend));
763    }
764
765    /// Log a scalar metric to all backends.
766    ///
767    /// # Arguments
768    /// * `name` - Metric name
769    /// * `value` - Metric value
770    pub fn log_metric(&mut self, name: &str, value: f64) -> TrainResult<()> {
771        for backend in &mut self.backends {
772            backend.log_scalar(name, value, self.current_step)?;
773        }
774        Ok(())
775    }
776
777    /// Accumulate a metric value (for averaging over batch).
778    ///
779    /// # Arguments
780    /// * `name` - Metric name
781    /// * `value` - Metric value
782    pub fn accumulate_metric(&mut self, name: &str, value: f64) {
783        self.accumulated_metrics
784            .entry(name.to_string())
785            .or_default()
786            .push(value);
787    }
788
789    /// Log accumulated metrics (average) and clear accumulation.
790    pub fn log_accumulated_metrics(&mut self) -> TrainResult<()> {
791        // Collect metrics to log before clearing
792        let metrics_to_log: Vec<(String, f64)> = self
793            .accumulated_metrics
794            .iter()
795            .filter(|(_, values)| !values.is_empty())
796            .map(|(name, values)| {
797                let avg = values.iter().sum::<f64>() / values.len() as f64;
798                (name.clone(), avg)
799            })
800            .collect();
801
802        // Log all metrics
803        for (name, avg) in metrics_to_log {
804            self.log_metric(&name, avg)?;
805        }
806
807        // Clear accumulation
808        self.accumulated_metrics.clear();
809        Ok(())
810    }
811
812    /// Log a text message to all backends.
813    ///
814    /// # Arguments
815    /// * `message` - Text message
816    pub fn log_message(&mut self, message: &str) -> TrainResult<()> {
817        for backend in &mut self.backends {
818            backend.log_text(message)?;
819        }
820        Ok(())
821    }
822
823    /// Increment the step counter.
824    pub fn step(&mut self) {
825        self.current_step += 1;
826    }
827
828    /// Set the current step.
829    ///
830    /// # Arguments
831    /// * `step` - Step number
832    pub fn set_step(&mut self, step: usize) {
833        self.current_step = step;
834    }
835
836    /// Get the current step.
837    pub fn current_step(&self) -> usize {
838        self.current_step
839    }
840
841    /// Flush all backends.
842    pub fn flush(&mut self) -> TrainResult<()> {
843        for backend in &mut self.backends {
844            backend.flush()?;
845        }
846        Ok(())
847    }
848
849    /// Get the number of backends.
850    pub fn num_backends(&self) -> usize {
851        self.backends.len()
852    }
853}
854
impl Default for MetricsLogger {
    /// Equivalent to `MetricsLogger::new()`: no backends, step 0.
    fn default() -> Self {
        Self::new()
    }
}
860
861#[cfg(test)]
862mod tests {
863    use super::*;
864    use std::env;
865    use std::fs;
866
    // Smoke test: console logging never errors on plain stdout.
    #[test]
    fn test_console_logger() {
        let mut logger = ConsoleLogger::new();

        // These should not fail
        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_text("Test message").unwrap();
        logger.flush().unwrap();
    }
876
    // Smoke test for the timestamp-less variant; output prefix is not
    // asserted, only that logging succeeds.
    #[test]
    fn test_console_logger_without_timestamp() {
        let mut logger = ConsoleLogger::without_timestamp();

        logger.log_scalar("accuracy", 0.95, 10).unwrap();
        logger.log_text("Another test").unwrap();
    }
884
    // End-to-end: FileLogger creates the file and persists formatted
    // scalar lines and text messages.
    #[test]
    fn test_file_logger() {
        let temp_dir = env::temp_dir();
        let log_path = temp_dir.join("test_training.log");

        // Clean up if file exists
        let _ = fs::remove_file(&log_path);

        let mut logger = FileLogger::new(&log_path).unwrap();

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_scalar("accuracy", 0.9, 1).unwrap();
        logger.log_text("Training started").unwrap();
        logger.flush().unwrap();

        // Verify file was created
        assert!(log_path.exists());

        // Read and verify contents
        let contents = fs::read_to_string(&log_path).unwrap();
        assert!(contents.contains("loss = 0.500000"));
        assert!(contents.contains("accuracy = 0.900000"));
        assert!(contents.contains("Training started"));

        // Clean up
        fs::remove_file(&log_path).unwrap();
    }
912
    // `new_truncate` must discard prior contents, unlike `new` (append).
    #[test]
    fn test_file_logger_truncate() {
        let temp_dir = env::temp_dir();
        let log_path = temp_dir.join("test_training_truncate.log");

        // Create file with some content
        {
            let mut logger = FileLogger::new(&log_path).unwrap();
            logger.log_text("Old content").unwrap();
            logger.flush().unwrap();
        }

        // Truncate and write new content
        {
            let mut logger = FileLogger::new_truncate(&log_path).unwrap();
            logger.log_text("New content").unwrap();
            logger.flush().unwrap();
        }

        // Verify old content is gone
        let contents = fs::read_to_string(&log_path).unwrap();
        assert!(!contents.contains("Old content"));
        assert!(contents.contains("New content"));

        // Clean up
        fs::remove_file(&log_path).unwrap();
    }
940
    // Verifies directory creation, the tfevents naming convention, and
    // that all event kinds (scalar, text, histogram) write without error.
    // Note: file contents are not parsed back; format correctness is
    // only checked by TensorBoard itself.
    #[test]
    fn test_tensorboard_logger() {
        let temp_dir = env::temp_dir();
        let tb_dir = temp_dir.join("test_tensorboard_real");

        // Clean up if directory exists
        let _ = fs::remove_dir_all(&tb_dir);

        let mut logger = TensorBoardLogger::new(&tb_dir).unwrap();

        // Directory should be created
        assert!(tb_dir.exists());

        // Log some scalars
        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_scalar("accuracy", 0.95, 1).unwrap();
        logger.log_text("Test message").unwrap();

        // Log a histogram
        let values = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        logger.log_histogram("weights", &values, 1).unwrap();

        logger.flush().unwrap();

        // Verify event file was created
        let event_file = logger.file_path();
        assert!(event_file.exists());
        assert!(event_file.to_string_lossy().contains("tfevents"));

        // Clean up
        fs::remove_dir_all(&tb_dir).unwrap();
    }
973
    // End-to-end: CsvLogger writes a header plus one row per call.
    #[test]
    fn test_csv_logger() {
        let temp_dir = env::temp_dir();
        let csv_path = temp_dir.join("test_metrics.csv");

        // Clean up if file exists
        let _ = fs::remove_file(&csv_path);

        let mut logger = CsvLogger::new(&csv_path).unwrap();

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_scalar("accuracy", 0.95, 2).unwrap();
        logger.log_text("Training started").unwrap();
        logger.flush().unwrap();

        // Verify file was created
        assert!(csv_path.exists());

        // Read and verify contents
        let contents = fs::read_to_string(&csv_path).unwrap();
        assert!(contents.contains("step,metric,value,timestamp")); // Header
        assert!(contents.contains("1,loss,0.500000"));
        assert!(contents.contains("2,accuracy,0.950000"));
        assert!(contents.contains("Training started"));

        // Clean up
        fs::remove_file(&csv_path).unwrap();
    }
1002
1003    #[test]
1004    fn test_jsonl_logger() {
1005        let temp_dir = env::temp_dir();
1006        let jsonl_path = temp_dir.join("test_metrics.jsonl");
1007
1008        // Clean up if file exists
1009        let _ = fs::remove_file(&jsonl_path);
1010
1011        let mut logger = JsonlLogger::new(&jsonl_path).unwrap();
1012
1013        logger.log_scalar("loss", 0.5, 1).unwrap();
1014        logger.log_scalar("accuracy", 0.95, 2).unwrap();
1015        logger.log_text("Training started").unwrap();
1016        logger.flush().unwrap();
1017
1018        // Verify file was created
1019        assert!(jsonl_path.exists());
1020
1021        // Read and verify contents
1022        let contents = fs::read_to_string(&jsonl_path).unwrap();
1023        let lines: Vec<&str> = contents.lines().collect();
1024        assert_eq!(lines.len(), 3);
1025
1026        // Verify first line is valid JSON
1027        assert!(lines[0].contains("\"type\":\"scalar\""));
1028        assert!(lines[0].contains("\"metric\":\"loss\""));
1029        assert!(lines[0].contains("\"value\":0.5"));
1030
1031        // Verify second line
1032        assert!(lines[1].contains("\"metric\":\"accuracy\""));
1033        assert!(lines[1].contains("\"value\":0.95"));
1034
1035        // Verify text message
1036        assert!(lines[2].contains("\"type\":\"text\""));
1037        assert!(lines[2].contains("Training started"));
1038
1039        // Clean up
1040        fs::remove_file(&jsonl_path).unwrap();
1041    }
1042
1043    #[test]
1044    fn test_csv_logger_path() {
1045        let temp_dir = env::temp_dir();
1046        let csv_path = temp_dir.join("test_csv_path.csv");
1047        let _ = fs::remove_file(&csv_path);
1048
1049        let logger = CsvLogger::new(&csv_path).unwrap();
1050        assert_eq!(logger.path(), csv_path.as_path());
1051
1052        // Clean up
1053        fs::remove_file(&csv_path).unwrap();
1054    }
1055
1056    #[test]
1057    fn test_jsonl_logger_path() {
1058        let temp_dir = env::temp_dir();
1059        let jsonl_path = temp_dir.join("test_jsonl_path.jsonl");
1060        let _ = fs::remove_file(&jsonl_path);
1061
1062        let logger = JsonlLogger::new(&jsonl_path).unwrap();
1063        assert_eq!(logger.path(), jsonl_path.as_path());
1064
1065        // Clean up
1066        fs::remove_file(&jsonl_path).unwrap();
1067    }
1068
1069    #[test]
1070    fn test_metrics_logger() {
1071        let mut logger = MetricsLogger::new();
1072        assert_eq!(logger.num_backends(), 0);
1073
1074        logger.add_backend(ConsoleLogger::without_timestamp());
1075        assert_eq!(logger.num_backends(), 1);
1076
1077        logger.log_metric("loss", 0.5).unwrap();
1078        logger.log_message("Epoch 1").unwrap();
1079
1080        assert_eq!(logger.current_step(), 0);
1081        logger.step();
1082        assert_eq!(logger.current_step(), 1);
1083
1084        logger.set_step(10);
1085        assert_eq!(logger.current_step(), 10);
1086
1087        logger.flush().unwrap();
1088    }
1089
1090    #[test]
1091    fn test_metrics_logger_accumulation() {
1092        let mut logger = MetricsLogger::new();
1093        logger.add_backend(ConsoleLogger::without_timestamp());
1094
1095        // Accumulate multiple values
1096        logger.accumulate_metric("batch_loss", 0.5);
1097        logger.accumulate_metric("batch_loss", 0.4);
1098        logger.accumulate_metric("batch_loss", 0.6);
1099
1100        // Log accumulated (should be average: 0.5)
1101        logger.log_accumulated_metrics().unwrap();
1102
1103        // Accumulation should be cleared
1104        logger.log_accumulated_metrics().unwrap(); // Should not fail even if empty
1105    }
1106
1107    #[test]
1108    fn test_metrics_logger_multiple_backends() {
1109        let mut logger = MetricsLogger::new();
1110        logger.add_backend(ConsoleLogger::without_timestamp());
1111        logger.add_backend(ConsoleLogger::new());
1112
1113        assert_eq!(logger.num_backends(), 2);
1114
1115        logger.log_metric("loss", 0.5).unwrap();
1116        logger.flush().unwrap();
1117    }
1118
1119    #[test]
1120    fn test_metrics_logger_empty_accumulation() {
1121        let mut logger = MetricsLogger::new();
1122        logger.add_backend(ConsoleLogger::without_timestamp());
1123
1124        // Log without accumulating anything
1125        logger.log_accumulated_metrics().unwrap();
1126    }
1127
1128    #[test]
1129    fn test_file_logger_path() {
1130        let temp_dir = env::temp_dir();
1131        let log_path = temp_dir.join("test_path.log");
1132        let _ = fs::remove_file(&log_path);
1133
1134        let logger = FileLogger::new(&log_path).unwrap();
1135        assert_eq!(logger.path(), log_path.as_path());
1136
1137        // Clean up
1138        fs::remove_file(&log_path).unwrap();
1139    }
1140
1141    #[test]
1142    fn test_tensorboard_logger_log_dir() {
1143        let temp_dir = env::temp_dir();
1144        let tb_dir = temp_dir.join("test_tb_path");
1145        let _ = fs::remove_dir_all(&tb_dir);
1146
1147        let logger = TensorBoardLogger::new(&tb_dir).unwrap();
1148        assert_eq!(logger.log_dir(), tb_dir.as_path());
1149
1150        // Clean up
1151        fs::remove_dir_all(&tb_dir).unwrap();
1152    }
1153}