1use std::fs::{create_dir_all, File};
2use std::io::{self, BufWriter, ErrorKind, Write};
3use std::path::PathBuf;
4
5use crate::event_writer::event::What;
6use crate::event_writer::summary::value::Value::SimpleValue;
7use crate::event_writer::summary::Value;
8use crate::event_writer::{EventWriter, Summary};
9use std::time::{SystemTime, UNIX_EPOCH};
10
/// Writes summary events (e.g. scalars via [`SummaryWriter::write_scalar`]) to an
/// underlying writer `W`, delegating the event-level output to an [`EventWriter`].
pub struct SummaryWriter<W> {
    /// Event-level writer that every summary built by this type is forwarded to.
    writer: EventWriter<W>,
}
15
16impl SummaryWriter<BufWriter<File>> {
17 pub fn from_prefix(path: impl Into<PathBuf>) -> io::Result<Self> {
24 let path = path.into();
25
26 if path.components().count() == 0 {
27 return Err(io::Error::new(
28 ErrorKind::NotFound,
29 "summary prefix must not be empty".to_string(),
30 ));
31 }
32
33 if let Some(dir) = path.parent() {
34 create_dir_all(dir)?;
35 }
36
37 let timestamp = SystemTime::now()
38 .duration_since(UNIX_EPOCH)
39 .unwrap()
40 .as_micros();
41 let hostname = hostname::get()?;
42
43 let mut path_string = path.into_os_string();
44 path_string.push(format!(".out.tfevents.{}.", timestamp));
45 path_string.push(hostname);
46
47 SummaryWriter::new(BufWriter::new(File::create(path_string)?))
48 }
49}
50
51impl<W> SummaryWriter<W>
52where
53 W: Write,
54{
55 pub fn new(write: W) -> io::Result<Self> {
57 let writer = EventWriter::new(write)?;
58 Ok(SummaryWriter { writer })
59 }
60
61 #[allow(dead_code)]
65 fn new_with_wall_time(write: W, wall_time: f64) -> io::Result<Self> {
66 let writer = EventWriter::new_with_wall_time(write, wall_time)?;
67 Ok(SummaryWriter { writer })
68 }
69
70 pub fn write_scalar(
72 &mut self,
73 tag: impl Into<String>,
74 step: i64,
75 scalar: f32,
76 ) -> std::io::Result<()> {
77 self.writer.write_event(
78 step,
79 What::Summary(Summary {
80 value: vec![Value {
81 node_name: "".to_string(),
82 tag: tag.into(),
83 value: Some(SimpleValue(scalar)),
84 }],
85 }),
86 )
87 }
88
89 #[allow(dead_code)]
93 fn write_scalar_with_wall_time(
94 &mut self,
95 wall_time: f64,
96 tag: impl Into<String>,
97 step: i64,
98 scalar: f32,
99 ) -> std::io::Result<()> {
100 self.writer.write_event_with_wall_time(
101 wall_time,
102 step,
103 What::Summary(Summary {
104 value: vec![Value {
105 node_name: "".to_string(),
106 tag: tag.into(),
107 value: Some(SimpleValue(scalar)),
108 }],
109 }),
110 )
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::SummaryWriter;

    /// Reference bytes for the writer's output: a leading record containing
    /// `brain.Event:2`, followed by the "hello" and "world" scalars written
    /// in the test below.
    ///
    /// Presumably captured from TensorFlow's own summary writer with the same
    /// inputs, per the test name; the capture itself is not in this file.
    static CHECK_OUTPUT: [u8; 126] = [
        24, 0, 0, 0, 0, 0, 0, 0, 163, 127, 75, 34, 9, 0, 0, 128, 54, 111, 246, 215, 65, 26, 13, 98,
        114, 97, 105, 110, 46, 69, 118, 101, 110, 116, 58, 50, 136, 162, 101, 134, 27, 0, 0, 0, 0,
        0, 0, 0, 26, 13, 158, 19, 9, 188, 119, 164, 54, 111, 246, 215, 65, 16, 10, 42, 14, 10, 12,
        10, 5, 104, 101, 108, 108, 111, 21, 0, 0, 40, 66, 93, 240, 111, 128, 27, 0, 0, 0, 0, 0, 0,
        0, 26, 13, 158, 19, 9, 48, 127, 164, 54, 111, 246, 215, 65, 16, 20, 42, 14, 10, 12, 10, 5,
        119, 111, 114, 108, 100, 21, 0, 0, 128, 63, 5, 210, 83, 151,
    ];

    #[test]
    fn writes_the_same_output_as_tensorflow() -> std::io::Result<()> {
        // (wall_time, tag, step, value) for each scalar, in write order.
        let scalars = [
            (1608105178.569808, "hello", 10, 42.),
            (1608105178.570263, "world", 20, 1.),
        ];

        let mut buffer: Vec<u8> = Vec::new();
        let mut summary = SummaryWriter::new_with_wall_time(&mut buffer, 1608105178.)?;
        for &(wall_time, tag, step, value) in scalars.iter() {
            summary.write_scalar_with_wall_time(wall_time, tag, step, value)?;
        }

        assert_eq!(buffer, CHECK_OUTPUT);
        Ok(())
    }
}