syntaxdot_summary/
summary_writer.rs

1use std::fs::{create_dir_all, File};
2use std::io::{self, BufWriter, ErrorKind, Write};
3use std::path::PathBuf;
4
5use crate::event_writer::event::What;
6use crate::event_writer::summary::value::Value::SimpleValue;
7use crate::event_writer::summary::Value;
8use crate::event_writer::{EventWriter, Summary};
9use std::time::{SystemTime, UNIX_EPOCH};
10
11/// TensorBoard summary writer.
12pub struct SummaryWriter<W> {
13    writer: EventWriter<W>,
14}
15
16impl SummaryWriter<BufWriter<File>> {
17    /// Construct a writer from a path prefix.
18    ///
19    /// For instance, a path such as `tensorboard/bert/opt` will create
20    /// the directory `tensorboard/bert` if it does not exist. Within that
21    /// directory, it will write to the file
22    /// `opt.out.tfevents.<timestamp>.<hostname>`.
23    pub fn from_prefix(path: impl Into<PathBuf>) -> io::Result<Self> {
24        let path = path.into();
25
26        if path.components().count() == 0 {
27            return Err(io::Error::new(
28                ErrorKind::NotFound,
29                "summary prefix must not be empty".to_string(),
30            ));
31        }
32
33        if let Some(dir) = path.parent() {
34            create_dir_all(dir)?;
35        }
36
37        let timestamp = SystemTime::now()
38            .duration_since(UNIX_EPOCH)
39            .unwrap()
40            .as_micros();
41        let hostname = hostname::get()?;
42
43        let mut path_string = path.into_os_string();
44        path_string.push(format!(".out.tfevents.{}.", timestamp));
45        path_string.push(hostname);
46
47        SummaryWriter::new(BufWriter::new(File::create(path_string)?))
48    }
49}
50
51impl<W> SummaryWriter<W>
52where
53    W: Write,
54{
55    /// Construct a writer from a `Write` type.
56    pub fn new(write: W) -> io::Result<Self> {
57        let writer = EventWriter::new(write)?;
58        Ok(SummaryWriter { writer })
59    }
60
61    /// Create a writer that uses the given wall time in the version record.
62    ///
63    /// This constructor is provided for unit tests.
64    #[allow(dead_code)]
65    fn new_with_wall_time(write: W, wall_time: f64) -> io::Result<Self> {
66        let writer = EventWriter::new_with_wall_time(write, wall_time)?;
67        Ok(SummaryWriter { writer })
68    }
69
70    /// Write a scalar.
71    pub fn write_scalar(
72        &mut self,
73        tag: impl Into<String>,
74        step: i64,
75        scalar: f32,
76    ) -> std::io::Result<()> {
77        self.writer.write_event(
78            step,
79            What::Summary(Summary {
80                value: vec![Value {
81                    node_name: "".to_string(),
82                    tag: tag.into(),
83                    value: Some(SimpleValue(scalar)),
84                }],
85            }),
86        )
87    }
88
89    /// Write a scalar with the given wall time.
90    ///
91    /// This method is provided for unit tests.
92    #[allow(dead_code)]
93    fn write_scalar_with_wall_time(
94        &mut self,
95        wall_time: f64,
96        tag: impl Into<String>,
97        step: i64,
98        scalar: f32,
99    ) -> std::io::Result<()> {
100        self.writer.write_event_with_wall_time(
101            wall_time,
102            step,
103            What::Summary(Summary {
104                value: vec![Value {
105                    node_name: "".to_string(),
106                    tag: tag.into(),
107                    value: Some(SimpleValue(scalar)),
108                }],
109            }),
110        )
111    }
112}
113
#[cfg(test)]
mod tests {
    use crate::SummaryWriter;

    /// Reference bytes expected for a version record followed by the two
    /// scalar events written in `writes_the_same_output_as_tensorflow`.
    static CHECK_OUTPUT: [u8; 126] = [
        24, 0, 0, 0, 0, 0, 0, 0, 163, 127, 75, 34, 9, 0, 0, 128, 54, 111, 246, 215, 65, 26, 13, 98,
        114, 97, 105, 110, 46, 69, 118, 101, 110, 116, 58, 50, 136, 162, 101, 134, 27, 0, 0, 0, 0,
        0, 0, 0, 26, 13, 158, 19, 9, 188, 119, 164, 54, 111, 246, 215, 65, 16, 10, 42, 14, 10, 12,
        10, 5, 104, 101, 108, 108, 111, 21, 0, 0, 40, 66, 93, 240, 111, 128, 27, 0, 0, 0, 0, 0, 0,
        0, 26, 13, 158, 19, 9, 48, 127, 164, 54, 111, 246, 215, 65, 16, 20, 42, 14, 10, 12, 10, 5,
        119, 111, 114, 108, 100, 21, 0, 0, 128, 63, 5, 210, 83, 151,
    ];

    #[test]
    fn writes_the_same_output_as_tensorflow() {
        // (wall time, tag, step, scalar) for each event, in write order.
        let cases: [(f64, &str, i64, f32); 2] = [
            (1608105178.569808, "hello", 10, 42.),
            (1608105178.570263, "world", 20, 1.),
        ];

        let mut buffer = Vec::new();
        {
            // Scope the writer so its borrow of `buffer` ends before the check.
            let mut summary_writer =
                SummaryWriter::new_with_wall_time(&mut buffer, 1608105178.).unwrap();
            for &(wall_time, tag, step, scalar) in &cases {
                summary_writer
                    .write_scalar_with_wall_time(wall_time, tag, step, scalar)
                    .unwrap();
            }
        }

        assert_eq!(buffer, CHECK_OUTPUT);
    }
}
140}