1use crate::{TrainError, TrainResult};
12use byteorder::{LittleEndian, WriteBytesExt};
13use chrono::Utc;
14use std::collections::HashMap;
15use std::fs::{File, OpenOptions};
16use std::io::{BufWriter, Write};
17use std::path::{Path, PathBuf};
18
/// Common interface for training-metric sinks (console, file, TensorBoard,
/// CSV, JSONL). Implementations receive scalars and free-form text and decide
/// how to persist or display them.
pub trait LoggingBackend {
    /// Records the scalar `value` under `name` at training step `step`.
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()>;

    /// Records a free-form text message.
    fn log_text(&mut self, message: &str) -> TrainResult<()>;

    /// Forces any buffered output to be written to the underlying sink.
    fn flush(&mut self) -> TrainResult<()>;
}
38
/// Logging backend that prints metrics and messages to stdout.
#[derive(Debug, Clone, Default)]
pub struct ConsoleLogger {
    /// When true, every printed line is prefixed with a `[unix_secs.millis] `
    /// timestamp derived from the system clock.
    pub include_timestamp: bool,
}
47
48impl ConsoleLogger {
49 pub fn new() -> Self {
51 Self {
52 include_timestamp: true,
53 }
54 }
55
56 pub fn without_timestamp() -> Self {
58 Self {
59 include_timestamp: false,
60 }
61 }
62
63 fn format_timestamp(&self) -> String {
64 if self.include_timestamp {
65 let now = std::time::SystemTime::now();
66 match now.duration_since(std::time::UNIX_EPOCH) {
67 Ok(duration) => format!("[{:.3}] ", duration.as_secs_f64()),
68 Err(_) => String::new(),
69 }
70 } else {
71 String::new()
72 }
73 }
74}
75
76impl LoggingBackend for ConsoleLogger {
77 fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
78 println!(
79 "{}Step {}: {} = {:.6}",
80 self.format_timestamp(),
81 step,
82 name,
83 value
84 );
85 Ok(())
86 }
87
88 fn log_text(&mut self, message: &str) -> TrainResult<()> {
89 println!("{}{}", self.format_timestamp(), message);
90 Ok(())
91 }
92
93 fn flush(&mut self) -> TrainResult<()> {
94 use std::io::stdout;
95 stdout()
96 .flush()
97 .map_err(|e| TrainError::Other(format!("Failed to flush stdout: {}", e)))?;
98 Ok(())
99 }
100}
101
/// Logging backend that writes plain-text lines to a file on disk.
///
/// NOTE(review): the `File` handle is unbuffered, so each log call performs
/// a write syscall; wrapping in `BufWriter` would batch them — confirm the
/// per-call durability is intentional before changing.
#[derive(Debug)]
pub struct FileLogger {
    /// Open handle to the log file.
    file: File,
    /// Location of the log file, as passed to the constructor.
    path: PathBuf,
}
110
111impl FileLogger {
112 pub fn new<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
117 let path = path.as_ref().to_path_buf();
118 let file = OpenOptions::new()
119 .create(true)
120 .append(true)
121 .open(&path)
122 .map_err(|e| TrainError::Other(format!("Failed to open log file {:?}: {}", path, e)))?;
123
124 Ok(Self { file, path })
125 }
126
127 pub fn new_truncate<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
132 let path = path.as_ref().to_path_buf();
133 let file = OpenOptions::new()
134 .create(true)
135 .write(true)
136 .truncate(true)
137 .open(&path)
138 .map_err(|e| TrainError::Other(format!("Failed to open log file {:?}: {}", path, e)))?;
139
140 Ok(Self { file, path })
141 }
142
143 pub fn path(&self) -> &Path {
145 &self.path
146 }
147}
148
149impl LoggingBackend for FileLogger {
150 fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
151 writeln!(self.file, "Step {}: {} = {:.6}", step, name, value)
152 .map_err(|e| TrainError::Other(format!("Failed to write to log file: {}", e)))?;
153 Ok(())
154 }
155
156 fn log_text(&mut self, message: &str) -> TrainResult<()> {
157 writeln!(self.file, "{}", message)
158 .map_err(|e| TrainError::Other(format!("Failed to write to log file: {}", e)))?;
159 Ok(())
160 }
161
162 fn flush(&mut self) -> TrainResult<()> {
163 self.file
164 .flush()
165 .map_err(|e| TrainError::Other(format!("Failed to flush log file: {}", e)))?;
166 Ok(())
167 }
168}
169
/// Logging backend that writes TensorBoard-compatible event files
/// (`events.out.tfevents.*`) using hand-rolled protobuf/TFRecord encoding,
/// avoiding a protobuf dependency.
#[derive(Debug)]
pub struct TensorBoardLogger {
    /// Directory holding the event file; created on construction.
    log_dir: PathBuf,
    /// Buffered writer over the open event file.
    writer: BufWriter<File>,
    /// Full path of the event file being written.
    file_path: PathBuf,
}
191
impl TensorBoardLogger {
    /// Creates `log_dir` (and parents) if needed, opens a fresh
    /// `events.out.tfevents.<unix_secs>.<hostname>` file inside it, and
    /// writes the mandatory `file_version` header event.
    pub fn new<P: AsRef<Path>>(log_dir: P) -> TrainResult<Self> {
        let log_dir = log_dir.as_ref().to_path_buf();

        std::fs::create_dir_all(&log_dir).map_err(|e| {
            TrainError::Other(format!(
                "Failed to create log directory {:?}: {}",
                log_dir, e
            ))
        })?;

        // TensorBoard discovers files whose names match "events.out.tfevents.*";
        // timestamp + hostname make concurrent runs produce distinct files.
        let timestamp = Utc::now().timestamp();
        let hostname = hostname::get()
            .map(|h| h.to_string_lossy().to_string())
            .unwrap_or_else(|_| "localhost".to_string());
        let filename = format!("events.out.tfevents.{}.{}", timestamp, hostname);
        let file_path = log_dir.join(&filename);

        let file = File::create(&file_path).map_err(|e| {
            TrainError::Other(format!(
                "Failed to create event file {:?}: {}",
                file_path, e
            ))
        })?;

        let mut logger = Self {
            log_dir,
            writer: BufWriter::new(file),
            file_path,
        };

        // Event files must begin with a file_version record or readers
        // reject them.
        logger.write_file_version()?;

        Ok(logger)
    }

    /// Directory the event file lives in.
    pub fn log_dir(&self) -> &Path {
        &self.log_dir
    }

    /// Full path of the event file being written.
    pub fn file_path(&self) -> &Path {
        &self.file_path
    }

    /// Writes the leading `file_version` event ("brain.Event:2") at step 0.
    fn write_file_version(&mut self) -> TrainResult<()> {
        let wall_time = Utc::now().timestamp_micros() as f64 / 1_000_000.0;

        let event = TensorBoardEvent {
            wall_time,
            step: 0,
            value: TensorBoardValue::FileVersion("brain.Event:2".to_string()),
        };

        self.write_event(&event)
    }

    /// Frames and writes one event record using the TFRecord layout:
    /// u64 payload length (little-endian), masked CRC of the length bytes,
    /// the payload, then masked CRC of the payload.
    fn write_event(&mut self, event: &TensorBoardEvent) -> TrainResult<()> {
        let data = event.to_bytes();

        let length = data.len() as u64;
        let length_bytes = length.to_le_bytes();
        let length_crc = masked_crc32(&length_bytes);
        let data_crc = masked_crc32(&data);

        self.writer
            .write_u64::<LittleEndian>(length)
            .map_err(|e| TrainError::Other(format!("Failed to write event length: {}", e)))?;
        self.writer
            .write_u32::<LittleEndian>(length_crc)
            .map_err(|e| TrainError::Other(format!("Failed to write length CRC: {}", e)))?;
        self.writer
            .write_all(&data)
            .map_err(|e| TrainError::Other(format!("Failed to write event data: {}", e)))?;
        self.writer
            .write_u32::<LittleEndian>(data_crc)
            .map_err(|e| TrainError::Other(format!("Failed to write data CRC: {}", e)))?;

        Ok(())
    }

    /// Logs summary statistics of `values` as a histogram event at `step`.
    /// The slice is copied into the event before encoding.
    pub fn log_histogram(&mut self, tag: &str, values: &[f64], step: usize) -> TrainResult<()> {
        let wall_time = Utc::now().timestamp_micros() as f64 / 1_000_000.0;

        let event = TensorBoardEvent {
            wall_time,
            step: step as i64,
            value: TensorBoardValue::Histogram {
                tag: tag.to_string(),
                values: values.to_vec(),
            },
        };

        self.write_event(&event)
    }
}
306
307impl LoggingBackend for TensorBoardLogger {
308 fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
309 let wall_time = Utc::now().timestamp_micros() as f64 / 1_000_000.0;
310
311 let event = TensorBoardEvent {
312 wall_time,
313 step: step as i64,
314 value: TensorBoardValue::Scalar {
315 tag: name.to_string(),
316 value,
317 },
318 };
319
320 self.write_event(&event)
321 }
322
323 fn log_text(&mut self, message: &str) -> TrainResult<()> {
324 let wall_time = Utc::now().timestamp_micros() as f64 / 1_000_000.0;
325
326 let event = TensorBoardEvent {
327 wall_time,
328 step: 0,
329 value: TensorBoardValue::Text {
330 tag: "text".to_string(),
331 content: message.to_string(),
332 },
333 };
334
335 self.write_event(&event)
336 }
337
338 fn flush(&mut self) -> TrainResult<()> {
339 self.writer
340 .flush()
341 .map_err(|e| TrainError::Other(format!("Failed to flush TensorBoard writer: {}", e)))?;
342 Ok(())
343 }
344}
345
/// In-memory form of a TensorFlow `Event` protobuf message before wire
/// encoding.
#[derive(Debug)]
struct TensorBoardEvent {
    /// Seconds since the Unix epoch (fractional), the proto `wall_time` field.
    wall_time: f64,
    /// Training step the event belongs to.
    step: i64,
    /// The event payload (file version, scalar, histogram, or text).
    value: TensorBoardValue,
}
353
/// Payload variants an event can carry; each maps to a different field of
/// the `Event` protobuf in [`TensorBoardEvent::to_bytes`].
#[derive(Debug)]
enum TensorBoardValue {
    /// File-format header string (e.g. "brain.Event:2").
    FileVersion(String),
    /// A single named scalar sample.
    Scalar { tag: String, value: f64 },
    /// A named batch of samples summarized as a histogram.
    Histogram { tag: String, values: Vec<f64> },
    /// A named free-form text payload.
    Text { tag: String, content: String },
}
362
impl TensorBoardEvent {
    /// Serializes this event as a TensorFlow `Event` protobuf message using
    /// hand-rolled proto wire encoding.
    ///
    /// Each tag byte encodes `(field_number << 3) | wire_type`:
    /// 0x09 = field 1, fixed64 (wall_time); 0x10 = field 2, varint (step);
    /// 0x1a = field 3, length-delimited (file_version); 0x2a = field 5,
    /// length-delimited (summary).
    fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::new();

        // wall_time: double written as little-endian fixed64.
        bytes.push(0x09); bytes.extend_from_slice(&self.wall_time.to_le_bytes());

        // step: int64 varint. The `as u64` cast of a negative step yields the
        // standard 10-byte two's-complement varint, matching proto int64
        // semantics.
        bytes.push(0x10); write_varint(&mut bytes, self.step as u64);

        match &self.value {
            TensorBoardValue::FileVersion(version) => {
                // file_version: plain UTF-8 string.
                bytes.push(0x1a); write_varint(&mut bytes, version.len() as u64);
                bytes.extend_from_slice(version.as_bytes());
            }
            TensorBoardValue::Scalar { tag, value } => {
                // Embedded Summary message carrying a simple_value.
                let summary_bytes = encode_scalar_summary(tag, *value);
                bytes.push(0x2a); write_varint(&mut bytes, summary_bytes.len() as u64);
                bytes.extend_from_slice(&summary_bytes);
            }
            TensorBoardValue::Histogram { tag, values } => {
                // Embedded Summary message carrying a HistogramProto.
                let summary_bytes = encode_histogram_summary(tag, values);
                bytes.push(0x2a);
                write_varint(&mut bytes, summary_bytes.len() as u64);
                bytes.extend_from_slice(&summary_bytes);
            }
            TensorBoardValue::Text { tag, content } => {
                // Embedded Summary message carrying a string TensorProto.
                let summary_bytes = encode_text_summary(tag, content);
                bytes.push(0x2a);
                write_varint(&mut bytes, summary_bytes.len() as u64);
                bytes.extend_from_slice(&summary_bytes);
            }
        }

        bytes
    }
}
409
410fn encode_scalar_summary(tag: &str, value: f64) -> Vec<u8> {
412 let mut bytes = Vec::new();
413
414 let mut value_bytes = Vec::new();
417
418 value_bytes.push(0x0a);
420 write_varint(&mut value_bytes, tag.len() as u64);
421 value_bytes.extend_from_slice(tag.as_bytes());
422
423 value_bytes.push(0x15); value_bytes.extend_from_slice(&(value as f32).to_le_bytes());
426
427 bytes.push(0x0a);
429 write_varint(&mut bytes, value_bytes.len() as u64);
430 bytes.extend_from_slice(&value_bytes);
431
432 bytes
433}
434
435fn encode_histogram_summary(tag: &str, values: &[f64]) -> Vec<u8> {
437 let mut bytes = Vec::new();
438
439 let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
441 let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
442 let sum: f64 = values.iter().sum();
443 let sum_squares: f64 = values.iter().map(|x| x * x).sum();
444
445 let mut value_bytes = Vec::new();
446
447 value_bytes.push(0x0a);
449 write_varint(&mut value_bytes, tag.len() as u64);
450 value_bytes.extend_from_slice(tag.as_bytes());
451
452 let mut histo_bytes = Vec::new();
454
455 histo_bytes.push(0x09);
457 histo_bytes.extend_from_slice(&min.to_le_bytes());
458 histo_bytes.push(0x11);
460 histo_bytes.extend_from_slice(&max.to_le_bytes());
461 histo_bytes.push(0x18);
463 write_varint(&mut histo_bytes, values.len() as u64);
464 histo_bytes.push(0x21);
466 histo_bytes.extend_from_slice(&sum.to_le_bytes());
467 histo_bytes.push(0x29);
469 histo_bytes.extend_from_slice(&sum_squares.to_le_bytes());
470
471 value_bytes.push(0x22); write_varint(&mut value_bytes, histo_bytes.len() as u64);
473 value_bytes.extend_from_slice(&histo_bytes);
474
475 bytes.push(0x0a);
476 write_varint(&mut bytes, value_bytes.len() as u64);
477 bytes.extend_from_slice(&value_bytes);
478
479 bytes
480}
481
482fn encode_text_summary(tag: &str, content: &str) -> Vec<u8> {
484 let mut bytes = Vec::new();
485
486 let mut value_bytes = Vec::new();
487
488 value_bytes.push(0x0a);
490 write_varint(&mut value_bytes, tag.len() as u64);
491 value_bytes.extend_from_slice(tag.as_bytes());
492
493 let mut tensor_bytes = Vec::new();
495 tensor_bytes.push(0x08);
497 write_varint(&mut tensor_bytes, 7);
498 tensor_bytes.push(0x42);
500 write_varint(&mut tensor_bytes, content.len() as u64);
501 tensor_bytes.extend_from_slice(content.as_bytes());
502
503 value_bytes.push(0x42); write_varint(&mut value_bytes, tensor_bytes.len() as u64);
505 value_bytes.extend_from_slice(&tensor_bytes);
506
507 bytes.push(0x0a);
508 write_varint(&mut bytes, value_bytes.len() as u64);
509 bytes.extend_from_slice(&value_bytes);
510
511 bytes
512}
513
/// Appends `value` to `buf` in protobuf base-128 varint encoding: low seven
/// bits per byte, least-significant group first, high bit set on every byte
/// except the last.
fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
    while value >= 0x80 {
        buf.push((value as u8 & 0x7f) | 0x80);
        value >>= 7;
    }
    buf.push(value as u8);
}
527
/// Computes the "masked" CRC used by the TFRecord/event-file framing: the
/// raw CRC rotated right by 15 bits plus a fixed masking constant, so that
/// checksums embedded in checksummed data stay robust.
///
/// NOTE(review): `crc32fast` computes CRC-32 (IEEE polynomial), while the
/// TFRecord spec calls for CRC-32C (Castagnoli). Confirm whether readers of
/// these files actually verify the checksums before relying on this.
fn masked_crc32(data: &[u8]) -> u32 {
    let crc = crc32fast::hash(data);
    crc.rotate_right(15).wrapping_add(0xa282ead8)
}
533
/// Logging backend that appends `step,metric,value,timestamp` rows to a CSV
/// file (truncated and re-headered on construction).
#[derive(Debug)]
pub struct CsvLogger {
    /// Buffered writer over the CSV file.
    writer: BufWriter<File>,
    /// Location of the CSV file on disk.
    path: PathBuf,
    /// Set to true once the header row is written.
    /// NOTE(review): this flag is written but never read anywhere in this
    /// file — it looks vestigial and could likely be removed.
    header_written: bool,
}
555
556impl CsvLogger {
557 pub fn new<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
562 let path = path.as_ref().to_path_buf();
563 let file = File::create(&path).map_err(|e| {
564 TrainError::Other(format!("Failed to create CSV file {:?}: {}", path, e))
565 })?;
566
567 let mut logger = Self {
568 writer: BufWriter::new(file),
569 path,
570 header_written: false,
571 };
572
573 writeln!(logger.writer, "step,metric,value,timestamp")
575 .map_err(|e| TrainError::Other(format!("Failed to write CSV header: {}", e)))?;
576 logger.header_written = true;
577
578 Ok(logger)
579 }
580
581 pub fn path(&self) -> &Path {
583 &self.path
584 }
585}
586
587impl LoggingBackend for CsvLogger {
588 fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
589 let timestamp = Utc::now().to_rfc3339();
590 writeln!(self.writer, "{},{},{:.6},{}", step, name, value, timestamp)
591 .map_err(|e| TrainError::Other(format!("Failed to write to CSV: {}", e)))?;
592 Ok(())
593 }
594
595 fn log_text(&mut self, message: &str) -> TrainResult<()> {
596 let timestamp = Utc::now().to_rfc3339();
597 let escaped = message.replace('"', "\"\"");
599 writeln!(self.writer, "0,text,\"{}\",{}", escaped, timestamp)
600 .map_err(|e| TrainError::Other(format!("Failed to write to CSV: {}", e)))?;
601 Ok(())
602 }
603
604 fn flush(&mut self) -> TrainResult<()> {
605 self.writer
606 .flush()
607 .map_err(|e| TrainError::Other(format!("Failed to flush CSV writer: {}", e)))?;
608 Ok(())
609 }
610}
611
612impl Clone for CsvLogger {
613 fn clone(&self) -> Self {
614 Self::new(&self.path).expect("Failed to clone CsvLogger")
616 }
617}
618
/// Logging backend that writes one JSON object per line (JSON Lines) to a
/// file, truncated on construction.
#[derive(Debug)]
pub struct JsonlLogger {
    /// Buffered writer over the JSONL file.
    writer: BufWriter<File>,
    /// Location of the JSONL file on disk.
    path: PathBuf,
}
639
640impl JsonlLogger {
641 pub fn new<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
646 let path = path.as_ref().to_path_buf();
647 let file = File::create(&path).map_err(|e| {
648 TrainError::Other(format!("Failed to create JSONL file {:?}: {}", path, e))
649 })?;
650
651 Ok(Self {
652 writer: BufWriter::new(file),
653 path,
654 })
655 }
656
657 pub fn path(&self) -> &Path {
659 &self.path
660 }
661}
662
663impl LoggingBackend for JsonlLogger {
664 fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
665 let timestamp = Utc::now().to_rfc3339();
666 let json = format!(
667 r#"{{"type":"scalar","step":{},"metric":"{}","value":{},"timestamp":"{}"}}"#,
668 step, name, value, timestamp
669 );
670 writeln!(self.writer, "{}", json)
671 .map_err(|e| TrainError::Other(format!("Failed to write to JSONL: {}", e)))?;
672 Ok(())
673 }
674
675 fn log_text(&mut self, message: &str) -> TrainResult<()> {
676 let timestamp = Utc::now().to_rfc3339();
677 let escaped = message.replace('\\', "\\\\").replace('"', "\\\"");
679 let json = format!(
680 r#"{{"type":"text","step":0,"message":"{}","timestamp":"{}"}}"#,
681 escaped, timestamp
682 );
683 writeln!(self.writer, "{}", json)
684 .map_err(|e| TrainError::Other(format!("Failed to write to JSONL: {}", e)))?;
685 Ok(())
686 }
687
688 fn flush(&mut self) -> TrainResult<()> {
689 self.writer
690 .flush()
691 .map_err(|e| TrainError::Other(format!("Failed to flush JSONL writer: {}", e)))?;
692 Ok(())
693 }
694}
695
696impl Clone for JsonlLogger {
697 fn clone(&self) -> Self {
698 Self::new(&self.path).expect("Failed to clone JsonlLogger")
700 }
701}
702
/// Fans metric and message logging out to a set of registered backends while
/// tracking the current training step and per-metric accumulated samples.
#[derive(Debug)]
pub struct MetricsLogger {
    /// Registered sinks; every log call is forwarded to each, in order.
    backends: Vec<Box<dyn LoggingBackendClone>>,
    /// Step index attached to scalars logged via `log_metric`.
    current_step: usize,
    /// Samples gathered by `accumulate_metric`, averaged and emitted by
    /// `log_accumulated_metrics`.
    accumulated_metrics: HashMap<String, Vec<f64>>,
}
712
/// Internal helper trait: a [`LoggingBackend`] that can also duplicate itself
/// into a fresh boxed trait object, enabling the `Clone` impl for
/// `Box<dyn LoggingBackendClone>` below.
trait LoggingBackendClone: LoggingBackend + std::fmt::Debug {
    /// Clones `self` into a new boxed trait object.
    fn clone_box(&self) -> Box<dyn LoggingBackendClone>;
}
717
/// Blanket impl: every cloneable, debuggable backend automatically satisfies
/// `LoggingBackendClone`, so callers of `add_backend` never implement it by
/// hand.
impl<T: LoggingBackend + Clone + std::fmt::Debug + 'static> LoggingBackendClone for T {
    fn clone_box(&self) -> Box<dyn LoggingBackendClone> {
        Box::new(self.clone())
    }
}
723
/// Makes boxed backends cloneable by delegating to the object-safe
/// `clone_box` hook.
impl Clone for Box<dyn LoggingBackendClone> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
729
/// Pass-through impl so `MetricsLogger` can treat each boxed backend as a
/// `LoggingBackend` directly; every method delegates to the boxed value.
impl LoggingBackend for Box<dyn LoggingBackendClone> {
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
        (**self).log_scalar(name, value, step)
    }

    fn log_text(&mut self, message: &str) -> TrainResult<()> {
        (**self).log_text(message)
    }

    fn flush(&mut self) -> TrainResult<()> {
        (**self).flush()
    }
}
743
744impl MetricsLogger {
745 pub fn new() -> Self {
747 Self {
748 backends: Vec::new(),
749 current_step: 0,
750 accumulated_metrics: HashMap::new(),
751 }
752 }
753
754 pub fn add_backend<B: LoggingBackend + Clone + std::fmt::Debug + 'static>(
759 &mut self,
760 backend: B,
761 ) {
762 self.backends.push(Box::new(backend));
763 }
764
765 pub fn log_metric(&mut self, name: &str, value: f64) -> TrainResult<()> {
771 for backend in &mut self.backends {
772 backend.log_scalar(name, value, self.current_step)?;
773 }
774 Ok(())
775 }
776
777 pub fn accumulate_metric(&mut self, name: &str, value: f64) {
783 self.accumulated_metrics
784 .entry(name.to_string())
785 .or_default()
786 .push(value);
787 }
788
789 pub fn log_accumulated_metrics(&mut self) -> TrainResult<()> {
791 let metrics_to_log: Vec<(String, f64)> = self
793 .accumulated_metrics
794 .iter()
795 .filter(|(_, values)| !values.is_empty())
796 .map(|(name, values)| {
797 let avg = values.iter().sum::<f64>() / values.len() as f64;
798 (name.clone(), avg)
799 })
800 .collect();
801
802 for (name, avg) in metrics_to_log {
804 self.log_metric(&name, avg)?;
805 }
806
807 self.accumulated_metrics.clear();
809 Ok(())
810 }
811
812 pub fn log_message(&mut self, message: &str) -> TrainResult<()> {
817 for backend in &mut self.backends {
818 backend.log_text(message)?;
819 }
820 Ok(())
821 }
822
823 pub fn step(&mut self) {
825 self.current_step += 1;
826 }
827
828 pub fn set_step(&mut self, step: usize) {
833 self.current_step = step;
834 }
835
836 pub fn current_step(&self) -> usize {
838 self.current_step
839 }
840
841 pub fn flush(&mut self) -> TrainResult<()> {
843 for backend in &mut self.backends {
844 backend.flush()?;
845 }
846 Ok(())
847 }
848
849 pub fn num_backends(&self) -> usize {
851 self.backends.len()
852 }
853}
854
/// Equivalent to [`MetricsLogger::new`]: no backends, step counter at zero.
impl Default for MetricsLogger {
    fn default() -> Self {
        Self::new()
    }
}
860
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::fs;

    // NOTE(review): these tests create fixed-name files in the shared temp
    // directory; two test binaries running concurrently could collide.

    #[test]
    fn test_console_logger() {
        // Smoke test only — stdout is not captured, we just assert the calls
        // succeed.
        let mut logger = ConsoleLogger::new();

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_text("Test message").unwrap();
        logger.flush().unwrap();
    }

    #[test]
    fn test_console_logger_without_timestamp() {
        let mut logger = ConsoleLogger::without_timestamp();

        logger.log_scalar("accuracy", 0.95, 10).unwrap();
        logger.log_text("Another test").unwrap();
    }

    #[test]
    fn test_file_logger() {
        let temp_dir = env::temp_dir();
        let log_path = temp_dir.join("test_training.log");

        // Start from a clean slate; ignore "file not found" on first run.
        let _ = fs::remove_file(&log_path);

        let mut logger = FileLogger::new(&log_path).unwrap();

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_scalar("accuracy", 0.9, 1).unwrap();
        logger.log_text("Training started").unwrap();
        logger.flush().unwrap();

        assert!(log_path.exists());

        let contents = fs::read_to_string(&log_path).unwrap();
        assert!(contents.contains("loss = 0.500000"));
        assert!(contents.contains("accuracy = 0.900000"));
        assert!(contents.contains("Training started"));

        fs::remove_file(&log_path).unwrap();
    }

    #[test]
    fn test_file_logger_truncate() {
        let temp_dir = env::temp_dir();
        let log_path = temp_dir.join("test_training_truncate.log");

        {
            let mut logger = FileLogger::new(&log_path).unwrap();
            logger.log_text("Old content").unwrap();
            logger.flush().unwrap();
        }

        {
            // Reopening with `new_truncate` must discard earlier contents.
            let mut logger = FileLogger::new_truncate(&log_path).unwrap();
            logger.log_text("New content").unwrap();
            logger.flush().unwrap();
        }

        let contents = fs::read_to_string(&log_path).unwrap();
        assert!(!contents.contains("Old content"));
        assert!(contents.contains("New content"));

        fs::remove_file(&log_path).unwrap();
    }

    #[test]
    fn test_tensorboard_logger() {
        let temp_dir = env::temp_dir();
        let tb_dir = temp_dir.join("test_tensorboard_real");

        let _ = fs::remove_dir_all(&tb_dir);

        let mut logger = TensorBoardLogger::new(&tb_dir).unwrap();

        // The constructor must create the log directory itself.
        assert!(tb_dir.exists());

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_scalar("accuracy", 0.95, 1).unwrap();
        logger.log_text("Test message").unwrap();

        let values = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        logger.log_histogram("weights", &values, 1).unwrap();

        logger.flush().unwrap();

        // Event file must exist and follow the tfevents naming convention.
        let event_file = logger.file_path();
        assert!(event_file.exists());
        assert!(event_file.to_string_lossy().contains("tfevents"));

        fs::remove_dir_all(&tb_dir).unwrap();
    }

    #[test]
    fn test_csv_logger() {
        let temp_dir = env::temp_dir();
        let csv_path = temp_dir.join("test_metrics.csv");

        let _ = fs::remove_file(&csv_path);

        let mut logger = CsvLogger::new(&csv_path).unwrap();

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_scalar("accuracy", 0.95, 2).unwrap();
        logger.log_text("Training started").unwrap();
        logger.flush().unwrap();

        assert!(csv_path.exists());

        let contents = fs::read_to_string(&csv_path).unwrap();
        // Header row plus one row per call, in order.
        assert!(contents.contains("step,metric,value,timestamp")); assert!(contents.contains("1,loss,0.500000"));
        assert!(contents.contains("2,accuracy,0.950000"));
        assert!(contents.contains("Training started"));

        fs::remove_file(&csv_path).unwrap();
    }

    #[test]
    fn test_jsonl_logger() {
        let temp_dir = env::temp_dir();
        let jsonl_path = temp_dir.join("test_metrics.jsonl");

        let _ = fs::remove_file(&jsonl_path);

        let mut logger = JsonlLogger::new(&jsonl_path).unwrap();

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_scalar("accuracy", 0.95, 2).unwrap();
        logger.log_text("Training started").unwrap();
        logger.flush().unwrap();

        assert!(jsonl_path.exists());

        // Exactly one JSON object per logged record.
        let contents = fs::read_to_string(&jsonl_path).unwrap();
        let lines: Vec<&str> = contents.lines().collect();
        assert_eq!(lines.len(), 3);

        assert!(lines[0].contains("\"type\":\"scalar\""));
        assert!(lines[0].contains("\"metric\":\"loss\""));
        assert!(lines[0].contains("\"value\":0.5"));

        assert!(lines[1].contains("\"metric\":\"accuracy\""));
        assert!(lines[1].contains("\"value\":0.95"));

        assert!(lines[2].contains("\"type\":\"text\""));
        assert!(lines[2].contains("Training started"));

        fs::remove_file(&jsonl_path).unwrap();
    }

    #[test]
    fn test_csv_logger_path() {
        let temp_dir = env::temp_dir();
        let csv_path = temp_dir.join("test_csv_path.csv");
        let _ = fs::remove_file(&csv_path);

        let logger = CsvLogger::new(&csv_path).unwrap();
        assert_eq!(logger.path(), csv_path.as_path());

        fs::remove_file(&csv_path).unwrap();
    }

    #[test]
    fn test_jsonl_logger_path() {
        let temp_dir = env::temp_dir();
        let jsonl_path = temp_dir.join("test_jsonl_path.jsonl");
        let _ = fs::remove_file(&jsonl_path);

        let logger = JsonlLogger::new(&jsonl_path).unwrap();
        assert_eq!(logger.path(), jsonl_path.as_path());

        fs::remove_file(&jsonl_path).unwrap();
    }

    #[test]
    fn test_metrics_logger() {
        let mut logger = MetricsLogger::new();
        assert_eq!(logger.num_backends(), 0);

        logger.add_backend(ConsoleLogger::without_timestamp());
        assert_eq!(logger.num_backends(), 1);

        logger.log_metric("loss", 0.5).unwrap();
        logger.log_message("Epoch 1").unwrap();

        // Step counter: starts at zero, increments, and can be set directly.
        assert_eq!(logger.current_step(), 0);
        logger.step();
        assert_eq!(logger.current_step(), 1);

        logger.set_step(10);
        assert_eq!(logger.current_step(), 10);

        logger.flush().unwrap();
    }

    #[test]
    fn test_metrics_logger_accumulation() {
        let mut logger = MetricsLogger::new();
        logger.add_backend(ConsoleLogger::without_timestamp());

        logger.accumulate_metric("batch_loss", 0.5);
        logger.accumulate_metric("batch_loss", 0.4);
        logger.accumulate_metric("batch_loss", 0.6);

        logger.log_accumulated_metrics().unwrap();

        // A second flush with nothing accumulated must be a no-op, not an
        // error.
        logger.log_accumulated_metrics().unwrap(); }

    #[test]
    fn test_metrics_logger_multiple_backends() {
        let mut logger = MetricsLogger::new();
        logger.add_backend(ConsoleLogger::without_timestamp());
        logger.add_backend(ConsoleLogger::new());

        assert_eq!(logger.num_backends(), 2);

        logger.log_metric("loss", 0.5).unwrap();
        logger.flush().unwrap();
    }

    #[test]
    fn test_metrics_logger_empty_accumulation() {
        let mut logger = MetricsLogger::new();
        logger.add_backend(ConsoleLogger::without_timestamp());

        // Flushing with no accumulated metrics must succeed silently.
        logger.log_accumulated_metrics().unwrap();
    }

    #[test]
    fn test_file_logger_path() {
        let temp_dir = env::temp_dir();
        let log_path = temp_dir.join("test_path.log");
        let _ = fs::remove_file(&log_path);

        let logger = FileLogger::new(&log_path).unwrap();
        assert_eq!(logger.path(), log_path.as_path());

        fs::remove_file(&log_path).unwrap();
    }

    #[test]
    fn test_tensorboard_logger_log_dir() {
        let temp_dir = env::temp_dir();
        let tb_dir = temp_dir.join("test_tb_path");
        let _ = fs::remove_dir_all(&tb_dir);

        let logger = TensorBoardLogger::new(&tb_dir).unwrap();
        assert_eq!(logger.log_dir(), tb_dir.as_path());

        fs::remove_dir_all(&tb_dir).unwrap();
    }
}