use crate::{TrainError, TrainResult};
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::{Path, PathBuf};

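/// A logging destination that can receive scalar metrics and free-form text
/// messages emitted during training.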
pub trait LoggingBackend {
    /// Log a scalar metric `value` under `name` at the given training `step`.
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()>;

    /// Log a free-form text message.
    fn log_text(&mut self, message: &str) -> TrainResult<()>;

    /// Flush any buffered output to the underlying sink.
    fn flush(&mut self) -> TrainResult<()>;
}

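/// Logs metrics and messages to standard output, optionally prefixing each
/// line with a Unix timestamp.
///
/// Note that `ConsoleLogger::new()` enables the timestamp prefix, while the
/// derived `Default` value leaves it disabled.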
#[derive(Debug, Clone, Default)]
pub struct ConsoleLogger {
    /// Whether each line is prefixed with the current Unix time in seconds.
    pub include_timestamp: bool,
}

impl ConsoleLogger {
    /// Create a console logger that prefixes each line with a Unix timestamp.
    pub fn new() -> Self {
        Self {
            include_timestamp: true,
        }
    }

    /// Create a console logger that prints lines without a timestamp prefix.
    pub fn without_timestamp() -> Self {
        Self {
            include_timestamp: false,
        }
    }

    /// Format the current Unix time as a `"[secs.millis] "` prefix, or return
    /// an empty string when timestamps are disabled or the clock is unavailable.
    fn format_timestamp(&self) -> String {
        if self.include_timestamp {
            let now = std::time::SystemTime::now();
            match now.duration_since(std::time::UNIX_EPOCH) {
                Ok(duration) => format!("[{:.3}] ", duration.as_secs_f64()),
                Err(_) => String::new(),
            }
        } else {
            String::new()
        }
    }
}

impl LoggingBackend for ConsoleLogger {
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
        println!(
            "{}Step {}: {} = {:.6}",
            self.format_timestamp(),
            step,
            name,
            value
        );
        Ok(())
    }

    fn log_text(&mut self, message: &str) -> TrainResult<()> {
        println!("{}{}", self.format_timestamp(), message);
        Ok(())
    }

    fn flush(&mut self) -> TrainResult<()> {
        use std::io::stdout;
        stdout()
            .flush()
            .map_err(|e| TrainError::Other(format!("Failed to flush stdout: {}", e)))?;
        Ok(())
    }
}

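/// Logs metrics and messages as plain-text lines to a file on disk.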
#[derive(Debug)]
pub struct FileLogger {
    file: File,
    path: PathBuf,
}

impl FileLogger {
    /// Open `path` for appending, creating the file if it does not exist.
    pub fn new<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
        let path = path.as_ref().to_path_buf();
        let file = OpenOptions::new()
            .create(true)
            .append(true)
            .open(&path)
            .map_err(|e| TrainError::Other(format!("Failed to open log file {:?}: {}", path, e)))?;

        Ok(Self { file, path })
    }

    /// Open `path` for writing, truncating any existing contents.
    pub fn new_truncate<P: AsRef<Path>>(path: P) -> TrainResult<Self> {
        let path = path.as_ref().to_path_buf();
        let file = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(&path)
            .map_err(|e| TrainError::Other(format!("Failed to open log file {:?}: {}", path, e)))?;

        Ok(Self { file, path })
    }

    /// The path of the underlying log file.
    pub fn path(&self) -> &Path {
        &self.path
    }
}

impl LoggingBackend for FileLogger {
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
        writeln!(self.file, "Step {}: {} = {:.6}", step, name, value)
            .map_err(|e| TrainError::Other(format!("Failed to write to log file: {}", e)))?;
        Ok(())
    }

    fn log_text(&mut self, message: &str) -> TrainResult<()> {
        writeln!(self.file, "{}", message)
            .map_err(|e| TrainError::Other(format!("Failed to write to log file: {}", e)))?;
        Ok(())
    }

    fn flush(&mut self) -> TrainResult<()> {
        self.file
            .flush()
            .map_err(|e| TrainError::Other(format!("Failed to flush log file: {}", e)))?;
        Ok(())
    }
}

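/// A TensorBoard-oriented logging backend.
///
/// Currently a lightweight placeholder: `new` creates the log directory, but
/// events are forwarded to `log::debug!` rather than written in the
/// TensorBoard event-file format.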
#[derive(Debug, Clone)]
pub struct TensorBoardLogger {
    log_dir: PathBuf,
}

impl TensorBoardLogger {
    /// Create the log directory (and any missing parents) at `log_dir`.
    pub fn new<P: AsRef<Path>>(log_dir: P) -> TrainResult<Self> {
        let log_dir = log_dir.as_ref().to_path_buf();

        std::fs::create_dir_all(&log_dir).map_err(|e| {
            TrainError::Other(format!(
                "Failed to create log directory {:?}: {}",
                log_dir, e
            ))
        })?;

        Ok(Self { log_dir })
    }

    /// The directory where TensorBoard events would be written.
    pub fn log_dir(&self) -> &Path {
        &self.log_dir
    }
}

impl LoggingBackend for TensorBoardLogger {
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
        log::debug!("TensorBoard: Step {}: {} = {:.6}", step, name, value);
        Ok(())
    }

    fn log_text(&mut self, message: &str) -> TrainResult<()> {
        log::debug!("TensorBoard: {}", message);
        Ok(())
    }

    fn flush(&mut self) -> TrainResult<()> {
        // Nothing is buffered, so there is nothing to flush.
        Ok(())
    }
}

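/// Fans metrics and messages out to any number of [`LoggingBackend`]s,
/// tracks the current training step, and can average metrics accumulated
/// across batches before emitting them.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest;
/// the `?` assumes a caller that returns `TrainResult`):
///
/// ```ignore
/// let mut logger = MetricsLogger::new();
/// logger.add_backend(ConsoleLogger::new());
///
/// logger.accumulate_metric("batch_loss", 0.42);
/// logger.accumulate_metric("batch_loss", 0.40);
/// logger.log_accumulated_metrics()?; // logs the average (0.41) at the current step
/// logger.step();
/// ```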
#[derive(Debug)]
pub struct MetricsLogger {
    backends: Vec<Box<dyn LoggingBackendClone>>,
    current_step: usize,
    accumulated_metrics: HashMap<String, Vec<f64>>,
}

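/// Object-safe helper that allows boxed backends to be cloned.
///
/// `Clone` itself is not object-safe, so cloning is exposed through
/// `clone_box`, with a blanket implementation for every backend that is
/// `Clone + Debug + 'static`.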
trait LoggingBackendClone: LoggingBackend + std::fmt::Debug {
    fn clone_box(&self) -> Box<dyn LoggingBackendClone>;
}

impl<T: LoggingBackend + Clone + std::fmt::Debug + 'static> LoggingBackendClone for T {
    fn clone_box(&self) -> Box<dyn LoggingBackendClone> {
        Box::new(self.clone())
    }
}

impl Clone for Box<dyn LoggingBackendClone> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

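// Forward the `LoggingBackend` API through the boxed trait object so that
// `MetricsLogger` can drive heterogeneous backends uniformly.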
impl LoggingBackend for Box<dyn LoggingBackendClone> {
    fn log_scalar(&mut self, name: &str, value: f64, step: usize) -> TrainResult<()> {
        (**self).log_scalar(name, value, step)
    }

    fn log_text(&mut self, message: &str) -> TrainResult<()> {
        (**self).log_text(message)
    }

    fn flush(&mut self) -> TrainResult<()> {
        (**self).flush()
    }
}

impl MetricsLogger {
    /// Create a metrics logger with no backends attached.
    pub fn new() -> Self {
        Self {
            backends: Vec::new(),
            current_step: 0,
            accumulated_metrics: HashMap::new(),
        }
    }

    /// Attach a logging backend; every subsequent metric and message is
    /// forwarded to it.
    pub fn add_backend<B: LoggingBackend + Clone + std::fmt::Debug + 'static>(
        &mut self,
        backend: B,
    ) {
        self.backends.push(Box::new(backend));
    }

    /// Log a scalar metric to all backends at the current step.
    pub fn log_metric(&mut self, name: &str, value: f64) -> TrainResult<()> {
        for backend in &mut self.backends {
            backend.log_scalar(name, value, self.current_step)?;
        }
        Ok(())
    }

    /// Buffer a metric value so it can later be averaged and logged by
    /// `log_accumulated_metrics`.
    pub fn accumulate_metric(&mut self, name: &str, value: f64) {
        self.accumulated_metrics
            .entry(name.to_string())
            .or_default()
            .push(value);
    }

    /// Log the average of every accumulated metric, then clear the buffers.
    pub fn log_accumulated_metrics(&mut self) -> TrainResult<()> {
        let metrics_to_log: Vec<(String, f64)> = self
            .accumulated_metrics
            .iter()
            .filter(|(_, values)| !values.is_empty())
            .map(|(name, values)| {
                let avg = values.iter().sum::<f64>() / values.len() as f64;
                (name.clone(), avg)
            })
            .collect();

        for (name, avg) in metrics_to_log {
            self.log_metric(&name, avg)?;
        }

        self.accumulated_metrics.clear();
        Ok(())
    }

    /// Log a text message to all backends.
    pub fn log_message(&mut self, message: &str) -> TrainResult<()> {
        for backend in &mut self.backends {
            backend.log_text(message)?;
        }
        Ok(())
    }

    /// Advance the current training step by one.
    pub fn step(&mut self) {
        self.current_step += 1;
    }

    /// Set the current training step explicitly.
    pub fn set_step(&mut self, step: usize) {
        self.current_step = step;
    }

    /// The current training step.
    pub fn current_step(&self) -> usize {
        self.current_step
    }

    /// Flush all backends.
    pub fn flush(&mut self) -> TrainResult<()> {
        for backend in &mut self.backends {
            backend.flush()?;
        }
        Ok(())
    }

    /// The number of attached backends.
    pub fn num_backends(&self) -> usize {
        self.backends.len()
    }
}

impl Default for MetricsLogger {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::fs;

    #[test]
    fn test_console_logger() {
        let mut logger = ConsoleLogger::new();

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_text("Test message").unwrap();
        logger.flush().unwrap();
    }

    #[test]
    fn test_console_logger_without_timestamp() {
        let mut logger = ConsoleLogger::without_timestamp();

        logger.log_scalar("accuracy", 0.95, 10).unwrap();
        logger.log_text("Another test").unwrap();
    }

    #[test]
    fn test_file_logger() {
        let temp_dir = env::temp_dir();
        let log_path = temp_dir.join("test_training.log");

        let _ = fs::remove_file(&log_path);

        let mut logger = FileLogger::new(&log_path).unwrap();

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_scalar("accuracy", 0.9, 1).unwrap();
        logger.log_text("Training started").unwrap();
        logger.flush().unwrap();

        assert!(log_path.exists());

        let contents = fs::read_to_string(&log_path).unwrap();
        assert!(contents.contains("loss = 0.500000"));
        assert!(contents.contains("accuracy = 0.900000"));
        assert!(contents.contains("Training started"));

        fs::remove_file(&log_path).unwrap();
    }

    #[test]
    fn test_file_logger_truncate() {
        let temp_dir = env::temp_dir();
        let log_path = temp_dir.join("test_training_truncate.log");

        {
            let mut logger = FileLogger::new(&log_path).unwrap();
            logger.log_text("Old content").unwrap();
            logger.flush().unwrap();
        }

        {
            let mut logger = FileLogger::new_truncate(&log_path).unwrap();
            logger.log_text("New content").unwrap();
            logger.flush().unwrap();
        }

        let contents = fs::read_to_string(&log_path).unwrap();
        assert!(!contents.contains("Old content"));
        assert!(contents.contains("New content"));

        fs::remove_file(&log_path).unwrap();
    }

    #[test]
    fn test_tensorboard_logger() {
        let temp_dir = env::temp_dir();
        let tb_dir = temp_dir.join("test_tensorboard");

        let _ = fs::remove_dir_all(&tb_dir);

        let mut logger = TensorBoardLogger::new(&tb_dir).unwrap();

        assert!(tb_dir.exists());

        logger.log_scalar("loss", 0.5, 1).unwrap();
        logger.log_text("Test message").unwrap();
        logger.flush().unwrap();

        fs::remove_dir_all(&tb_dir).unwrap();
    }

    #[test]
    fn test_metrics_logger() {
        let mut logger = MetricsLogger::new();
        assert_eq!(logger.num_backends(), 0);

        logger.add_backend(ConsoleLogger::without_timestamp());
        assert_eq!(logger.num_backends(), 1);

        logger.log_metric("loss", 0.5).unwrap();
        logger.log_message("Epoch 1").unwrap();

        assert_eq!(logger.current_step(), 0);
        logger.step();
        assert_eq!(logger.current_step(), 1);

        logger.set_step(10);
        assert_eq!(logger.current_step(), 10);

        logger.flush().unwrap();
    }

    #[test]
    fn test_metrics_logger_accumulation() {
        let mut logger = MetricsLogger::new();
        logger.add_backend(ConsoleLogger::without_timestamp());

        logger.accumulate_metric("batch_loss", 0.5);
        logger.accumulate_metric("batch_loss", 0.4);
        logger.accumulate_metric("batch_loss", 0.6);

        logger.log_accumulated_metrics().unwrap();

        // Accumulators are cleared after logging, so a second call logs
        // nothing and still succeeds.
        logger.log_accumulated_metrics().unwrap();
    }

    #[test]
    fn test_metrics_logger_multiple_backends() {
        let mut logger = MetricsLogger::new();
        logger.add_backend(ConsoleLogger::without_timestamp());
        logger.add_backend(ConsoleLogger::new());

        assert_eq!(logger.num_backends(), 2);

        logger.log_metric("loss", 0.5).unwrap();
        logger.flush().unwrap();
    }

    #[test]
    fn test_metrics_logger_empty_accumulation() {
        let mut logger = MetricsLogger::new();
        logger.add_backend(ConsoleLogger::without_timestamp());

        logger.log_accumulated_metrics().unwrap();
    }

    #[test]
    fn test_file_logger_path() {
        let temp_dir = env::temp_dir();
        let log_path = temp_dir.join("test_path.log");
        let _ = fs::remove_file(&log_path);

        let logger = FileLogger::new(&log_path).unwrap();
        assert_eq!(logger.path(), log_path.as_path());

        fs::remove_file(&log_path).unwrap();
    }

    #[test]
    fn test_tensorboard_logger_log_dir() {
        let temp_dir = env::temp_dir();
        let tb_dir = temp_dir.join("test_tb_path");
        let _ = fs::remove_dir_all(&tb_dir);

        let logger = TensorBoardLogger::new(&tb_dir).unwrap();
        assert_eq!(logger.log_dir(), tb_dir.as_path());

        fs::remove_dir_all(&tb_dir).unwrap();
    }
}