Skip to main content

yscv_model/
tensorboard.rs

1use std::collections::HashMap;
2use std::fs::{self, File};
3use std::io::{BufWriter, Write};
4use std::path::Path;
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use super::callbacks::TrainingCallback;
8use super::error::ModelError;
9
10// ---------- CRC32C (Castagnoli) ----------
11
/// Lookup table for byte-at-a-time CRC-32C (Castagnoli).
///
/// Entry `n` is the register value after feeding byte `n` through eight
/// rounds of the reflected polynomial `0x82F6_3B78`. The table is built at
/// compile time, which is why plain `while` loops are used (`for` loops are
/// not usable in `const` contexts).
const CRC32C_TABLE: [u32; 256] = {
    const POLY: u32 = 0x82F6_3B78;
    let mut table = [0u32; 256];
    let mut byte = 0;
    while byte < 256 {
        let mut reg = byte as u32;
        let mut round = 0;
        while round < 8 {
            // All ones when the low bit is set, all zeros otherwise, so the
            // polynomial is XORed in exactly when a 1 bit is shifted out.
            let mask = 0u32.wrapping_sub(reg & 1);
            reg = (reg >> 1) ^ (POLY & mask);
            round += 1;
        }
        table[byte] = reg;
        byte += 1;
    }
    table
};
31
32fn crc32c(data: &[u8]) -> u32 {
33    let mut crc: u32 = !0;
34    for &b in data {
35        crc = CRC32C_TABLE[((crc ^ b as u32) & 0xFF) as usize] ^ (crc >> 8);
36    }
37    !crc
38}
39
40fn masked_crc32c(data: &[u8]) -> u32 {
41    let crc = crc32c(data);
42    crc.rotate_right(15).wrapping_add(0xa282_ead8)
43}
44
45// ---------- Protobuf helpers ----------
46
/// Appends `value` to `buf` as a protobuf base-128 varint.
///
/// Seven bits are emitted per byte, least-significant group first. Every byte
/// except the last has its high (continuation) bit set, so a `u64` takes at
/// most 10 bytes. Existing contents of `buf` are kept.
fn encode_varint(value: u64, buf: &mut Vec<u8>) {
    let mut rest = value;
    loop {
        let group = (rest & 0x7F) as u8;
        rest >>= 7;
        if rest == 0 {
            buf.push(group);
            return;
        }
        buf.push(group | 0x80);
    }
}
54
55/// Encode a Summary.Value proto: field 1 = tag (string), field 2 = simple_value (float)
56fn encode_summary_value(tag: &str, value: f32) -> Vec<u8> {
57    let mut buf = Vec::new();
58    // field 1, wire type 2 (length-delimited): tag string
59    buf.push(0x0a);
60    encode_varint(tag.len() as u64, &mut buf);
61    buf.extend_from_slice(tag.as_bytes());
62    // field 2, wire type 5 (32-bit): simple_value float
63    buf.push(0x15);
64    buf.extend_from_slice(&value.to_le_bytes());
65    buf
66}
67
68/// Encode a Summary proto: repeated field 1 = Value
69fn encode_summary(tag: &str, value: f32) -> Vec<u8> {
70    let val = encode_summary_value(tag, value);
71    let mut buf = Vec::new();
72    // field 1, wire type 2
73    buf.push(0x0a);
74    encode_varint(val.len() as u64, &mut buf);
75    buf.extend_from_slice(&val);
76    buf
77}
78
79/// Encode an Event proto with wall_time, step, and summary.
80fn encode_event_summary(wall_time: f64, step: i64, tag: &str, value: f32) -> Vec<u8> {
81    let summary = encode_summary(tag, value);
82    let mut buf = Vec::new();
83    // field 1, wire type 1 (64-bit): wall_time (double)
84    buf.push(0x09);
85    buf.extend_from_slice(&wall_time.to_le_bytes());
86    // field 2, wire type 0 (varint): step
87    buf.push(0x10);
88    encode_varint(step as u64, &mut buf);
89    // field 5, wire type 2 (length-delimited): summary
90    buf.push(0x2a);
91    encode_varint(summary.len() as u64, &mut buf);
92    buf.extend_from_slice(&summary);
93    buf
94}
95
96/// Encode an Event proto with file_version string (field 6).
97fn encode_event_file_version(wall_time: f64, step: i64, version: &str) -> Vec<u8> {
98    let mut buf = Vec::new();
99    // field 1, wire type 1: wall_time
100    buf.push(0x09);
101    buf.extend_from_slice(&wall_time.to_le_bytes());
102    // field 2, wire type 0: step
103    buf.push(0x10);
104    encode_varint(step as u64, &mut buf);
105    // field 6, wire type 2 (length-delimited): file_version string
106    buf.push(0x32);
107    encode_varint(version.len() as u64, &mut buf);
108    buf.extend_from_slice(version.as_bytes());
109    buf
110}
111
112// ---------- TensorBoardWriter ----------
113
114/// Writes TensorBoard-compatible event files in TFRecord format.
115pub struct TensorBoardWriter {
116    file: BufWriter<File>,
117}
118
119impl TensorBoardWriter {
120    /// Create a new writer that stores events under `log_dir`.
121    ///
122    /// Creates `log_dir` if it does not exist and writes the initial
123    /// `file_version` event required by TensorBoard.
124    pub fn new(log_dir: impl AsRef<Path>) -> Result<Self, ModelError> {
125        let log_dir = log_dir.as_ref();
126        fs::create_dir_all(log_dir).map_err(|e| ModelError::DatasetLoadIo {
127            path: log_dir.display().to_string(),
128            message: e.to_string(),
129        })?;
130
131        let timestamp = SystemTime::now()
132            .duration_since(UNIX_EPOCH)
133            .unwrap_or_default()
134            .as_secs();
135
136        let hostname = "localhost";
137        let filename = format!("events.out.tfevents.{timestamp}.{hostname}");
138        let filepath = log_dir.join(filename);
139
140        let f = File::create(&filepath).map_err(|e| ModelError::DatasetLoadIo {
141            path: filepath.display().to_string(),
142            message: e.to_string(),
143        })?;
144        let mut writer = Self {
145            file: BufWriter::new(f),
146        };
147
148        let wall_time = timestamp as f64;
149        let event = encode_event_file_version(wall_time, 0, "brain.Event:2");
150        writer.write_record(&event)?;
151
152        Ok(writer)
153    }
154
155    /// Log a scalar value with the given tag and step.
156    pub fn add_scalar(&mut self, tag: &str, value: f32, step: i64) -> Result<(), ModelError> {
157        let wall_time = SystemTime::now()
158            .duration_since(UNIX_EPOCH)
159            .unwrap_or_default()
160            .as_secs_f64();
161        let event = encode_event_summary(wall_time, step, tag, value);
162        self.write_record(&event)
163    }
164
165    /// Flush buffered data to disk.
166    pub fn flush(&mut self) -> Result<(), ModelError> {
167        self.file.flush().map_err(|e| ModelError::DatasetLoadIo {
168            path: "tensorboard events".to_string(),
169            message: e.to_string(),
170        })
171    }
172
173    /// Write a single TFRecord.
174    fn write_record(&mut self, data: &[u8]) -> Result<(), ModelError> {
175        let len = data.len() as u64;
176        let len_bytes = len.to_le_bytes();
177        let len_crc = masked_crc32c(&len_bytes);
178        let data_crc = masked_crc32c(data);
179
180        let w = &mut self.file;
181        w.write_all(&len_bytes)
182            .and_then(|_| w.write_all(&len_crc.to_le_bytes()))
183            .and_then(|_| w.write_all(data))
184            .and_then(|_| w.write_all(&data_crc.to_le_bytes()))
185            .map_err(|e| ModelError::DatasetLoadIo {
186                path: "tensorboard events".to_string(),
187                message: e.to_string(),
188            })
189    }
190}
191
192// ---------- TensorBoardCallback ----------
193
194/// Training callback that logs scalar metrics to TensorBoard event files.
195pub struct TensorBoardCallback {
196    writer: TensorBoardWriter,
197    global_step: i64,
198}
199
200impl TensorBoardCallback {
201    /// Create a new callback writing events to `log_dir`.
202    pub fn new(log_dir: impl AsRef<Path>) -> Result<Self, ModelError> {
203        Ok(Self {
204            writer: TensorBoardWriter::new(log_dir)?,
205            global_step: 0,
206        })
207    }
208
209    /// Returns the current global step counter.
210    pub fn global_step(&self) -> i64 {
211        self.global_step
212    }
213}
214
215impl TrainingCallback for TensorBoardCallback {
216    fn on_epoch_end(&mut self, _epoch: usize, metrics: &HashMap<String, f32>) -> bool {
217        self.global_step += 1;
218        for (tag, &value) in metrics {
219            let _ = self.writer.add_scalar(tag, value, self.global_step);
220        }
221        let _ = self.writer.flush();
222        false
223    }
224
225    fn on_batch_end(&mut self, _epoch: usize, _batch: usize, loss: f32) {
226        self.global_step += 1;
227        let _ = self.writer.add_scalar("batch_loss", loss, self.global_step);
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-process scratch directory, so concurrent test runs sharing the
    /// system temp dir cannot delete each other's files.
    fn scratch_dir(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{name}_{}", std::process::id()))
    }

    fn le_u32(bytes: &[u8]) -> u32 {
        u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
    }

    fn le_u64(bytes: &[u8]) -> u64 {
        let mut raw = [0u8; 8];
        raw.copy_from_slice(&bytes[..8]);
        u64::from_le_bytes(raw)
    }

    /// Splits a TFRecord byte stream into payloads.
    ///
    /// Asserts that every record is complete and that both of its masked CRCs
    /// are correct.
    fn read_records(mut rest: &[u8]) -> Vec<Vec<u8>> {
        let mut records = Vec::new();
        while !rest.is_empty() {
            assert!(rest.len() >= 12, "truncated record header");
            let len_bytes = &rest[..8];
            assert_eq!(le_u32(&rest[8..12]), masked_crc32c(len_bytes));
            let len = le_u64(len_bytes) as usize;
            let body = &rest[12..];
            assert!(body.len() >= len + 4, "truncated record payload");
            let data = &body[..len];
            assert_eq!(le_u32(&body[len..len + 4]), masked_crc32c(data));
            records.push(data.to_vec());
            rest = &body[len + 4..];
        }
        records
    }

    fn contains(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|w| w == needle)
    }

    /// Returns the name and contents of the only file in `dir`.
    fn only_file(dir: &std::path::Path) -> (String, Vec<u8>) {
        let entries: Vec<_> = fs::read_dir(dir).unwrap().filter_map(|e| e.ok()).collect();
        assert_eq!(entries.len(), 1, "expected exactly one event file");
        let name = entries[0].file_name().to_string_lossy().into_owned();
        let bytes = fs::read(entries[0].path()).unwrap();
        (name, bytes)
    }

    #[test]
    fn masked_crc32c_known_values() {
        // Empty input: the CRC is 0, so the masked value is just the delta.
        assert_eq!(crc32c(b""), 0x0000_0000);
        assert_eq!(masked_crc32c(b""), 0xa282_ead8);
        // Standard CRC-32C check value and the RFC 3720 all-zeros vector.
        assert_eq!(crc32c(b"123456789"), 0xE306_9283);
        assert_eq!(crc32c(&[0u8; 32]), 0x8A91_36AA);
        assert_ne!(masked_crc32c(b"hello"), 0);
    }

    #[test]
    fn encode_varint_simple() {
        let mut buf = Vec::new();
        encode_varint(150, &mut buf);
        assert_eq!(buf, &[0x96, 0x01]);

        let mut small = Vec::new();
        encode_varint(0, &mut small);
        encode_varint(127, &mut small);
        encode_varint(128, &mut small);
        assert_eq!(small, &[0x00, 0x7F, 0x80, 0x01]);
    }

    #[test]
    fn writer_creates_event_file() {
        let dir = scratch_dir("yscv_tb_test");
        let _ = fs::remove_dir_all(&dir);
        {
            let mut w = TensorBoardWriter::new(&dir).unwrap();
            w.add_scalar("loss", 0.5, 1).unwrap();
            w.add_scalar("loss", 0.3, 2).unwrap();
            w.flush().unwrap();
        }
        let (name, bytes) = only_file(&dir);
        assert!(name.starts_with("events.out.tfevents."));
        // Header record plus two scalar records, each correctly framed.
        let records = read_records(&bytes);
        assert_eq!(records.len(), 3);
        assert!(contains(&records[0], b"brain.Event:2"));
        assert!(contains(&records[1], b"loss"));
        assert!(contains(&records[1], &0.5f32.to_le_bytes()));
        assert!(contains(&records[2], &0.3f32.to_le_bytes()));
        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn callback_logs_metrics() {
        let dir = scratch_dir("yscv_tb_cb_test");
        let _ = fs::remove_dir_all(&dir);
        {
            let mut cb = TensorBoardCallback::new(&dir).unwrap();
            let mut metrics = HashMap::new();
            metrics.insert("train_loss".to_string(), 0.42f32);
            let stop = cb.on_epoch_end(0, &metrics);
            assert!(!stop);
            assert_eq!(cb.global_step(), 1);
            cb.on_batch_end(1, 0, 0.1);
            assert_eq!(cb.global_step(), 2);
        }
        // Header, the epoch metric, and the batch loss (flushed on drop).
        let (_, bytes) = only_file(&dir);
        let records = read_records(&bytes);
        assert_eq!(records.len(), 3);
        assert!(contains(&records[1], b"train_loss"));
        assert!(contains(&records[2], b"batch_loss"));
        let _ = fs::remove_dir_all(&dir);
    }
}
288}