yscv_model/
tensorboard.rs1use std::collections::HashMap;
2use std::fs::{self, File};
3use std::io::{BufWriter, Write};
4use std::path::Path;
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use super::callbacks::TrainingCallback;
8use super::error::ModelError;
9
/// Byte-indexed lookup table for the reflected CRC-32C (Castagnoli)
/// polynomial `0x82F63B78`, evaluated entirely at compile time.
///
/// Entry `i` is the CRC register obtained by pushing the single byte `i`
/// through eight rounds of the bit-at-a-time algorithm.
const CRC32C_TABLE: [u32; 256] = build_crc32c_table();

/// Builds [`CRC32C_TABLE`]. Written with `while` loops because `for` over a
/// range is not usable inside a `const fn`.
const fn build_crc32c_table() -> [u32; 256] {
    const POLY: u32 = 0x82F6_3B78;
    let mut table = [0u32; 256];
    let mut byte: usize = 0;
    while byte < 256 {
        let mut reg = byte as u32;
        let mut round = 0;
        while round < 8 {
            // All-ones when the low bit is set, zero otherwise: a branch-free
            // "shift, and xor the polynomial only if a 1 fell off".
            let lsb_mask = (reg & 1).wrapping_neg();
            reg = (reg >> 1) ^ (POLY & lsb_mask);
            round += 1;
        }
        table[byte] = reg;
        byte += 1;
    }
    table
}
31
32fn crc32c(data: &[u8]) -> u32 {
33 let mut crc: u32 = !0;
34 for &b in data {
35 crc = CRC32C_TABLE[((crc ^ b as u32) & 0xFF) as usize] ^ (crc >> 8);
36 }
37 !crc
38}
39
40fn masked_crc32c(data: &[u8]) -> u32 {
41 let crc = crc32c(data);
42 crc.rotate_right(15).wrapping_add(0xa282_ead8)
43}
44
/// Appends `value` to `buf` as a protobuf base-128 varint (unsigned LEB128).
///
/// Seven payload bits go in each byte, least-significant group first. Every
/// byte except the last has its high bit set. Existing contents of `buf` are
/// kept.
fn encode_varint(value: u64, buf: &mut Vec<u8>) {
    let mut rest = value;
    loop {
        let low7 = (rest & 0x7F) as u8;
        rest >>= 7;
        if rest == 0 {
            buf.push(low7);
            return;
        }
        buf.push(low7 | 0x80);
    }
}
54
55fn encode_summary_value(tag: &str, value: f32) -> Vec<u8> {
57 let mut buf = Vec::new();
58 buf.push(0x0a);
60 encode_varint(tag.len() as u64, &mut buf);
61 buf.extend_from_slice(tag.as_bytes());
62 buf.push(0x15);
64 buf.extend_from_slice(&value.to_le_bytes());
65 buf
66}
67
68fn encode_summary(tag: &str, value: f32) -> Vec<u8> {
70 let val = encode_summary_value(tag, value);
71 let mut buf = Vec::new();
72 buf.push(0x0a);
74 encode_varint(val.len() as u64, &mut buf);
75 buf.extend_from_slice(&val);
76 buf
77}
78
79fn encode_event_summary(wall_time: f64, step: i64, tag: &str, value: f32) -> Vec<u8> {
81 let summary = encode_summary(tag, value);
82 let mut buf = Vec::new();
83 buf.push(0x09);
85 buf.extend_from_slice(&wall_time.to_le_bytes());
86 buf.push(0x10);
88 encode_varint(step as u64, &mut buf);
89 buf.push(0x2a);
91 encode_varint(summary.len() as u64, &mut buf);
92 buf.extend_from_slice(&summary);
93 buf
94}
95
96fn encode_event_file_version(wall_time: f64, step: i64, version: &str) -> Vec<u8> {
98 let mut buf = Vec::new();
99 buf.push(0x09);
101 buf.extend_from_slice(&wall_time.to_le_bytes());
102 buf.push(0x10);
104 encode_varint(step as u64, &mut buf);
105 buf.push(0x32);
107 encode_varint(version.len() as u64, &mut buf);
108 buf.extend_from_slice(version.as_bytes());
109 buf
110}
111
112pub struct TensorBoardWriter {
116 file: BufWriter<File>,
117}
118
119impl TensorBoardWriter {
120 pub fn new(log_dir: impl AsRef<Path>) -> Result<Self, ModelError> {
125 let log_dir = log_dir.as_ref();
126 fs::create_dir_all(log_dir).map_err(|e| ModelError::DatasetLoadIo {
127 path: log_dir.display().to_string(),
128 message: e.to_string(),
129 })?;
130
131 let timestamp = SystemTime::now()
132 .duration_since(UNIX_EPOCH)
133 .unwrap_or_default()
134 .as_secs();
135
136 let hostname = "localhost";
137 let filename = format!("events.out.tfevents.{timestamp}.{hostname}");
138 let filepath = log_dir.join(filename);
139
140 let f = File::create(&filepath).map_err(|e| ModelError::DatasetLoadIo {
141 path: filepath.display().to_string(),
142 message: e.to_string(),
143 })?;
144 let mut writer = Self {
145 file: BufWriter::new(f),
146 };
147
148 let wall_time = timestamp as f64;
149 let event = encode_event_file_version(wall_time, 0, "brain.Event:2");
150 writer.write_record(&event)?;
151
152 Ok(writer)
153 }
154
155 pub fn add_scalar(&mut self, tag: &str, value: f32, step: i64) -> Result<(), ModelError> {
157 let wall_time = SystemTime::now()
158 .duration_since(UNIX_EPOCH)
159 .unwrap_or_default()
160 .as_secs_f64();
161 let event = encode_event_summary(wall_time, step, tag, value);
162 self.write_record(&event)
163 }
164
165 pub fn flush(&mut self) -> Result<(), ModelError> {
167 self.file.flush().map_err(|e| ModelError::DatasetLoadIo {
168 path: "tensorboard events".to_string(),
169 message: e.to_string(),
170 })
171 }
172
173 fn write_record(&mut self, data: &[u8]) -> Result<(), ModelError> {
175 let len = data.len() as u64;
176 let len_bytes = len.to_le_bytes();
177 let len_crc = masked_crc32c(&len_bytes);
178 let data_crc = masked_crc32c(data);
179
180 let w = &mut self.file;
181 w.write_all(&len_bytes)
182 .and_then(|_| w.write_all(&len_crc.to_le_bytes()))
183 .and_then(|_| w.write_all(data))
184 .and_then(|_| w.write_all(&data_crc.to_le_bytes()))
185 .map_err(|e| ModelError::DatasetLoadIo {
186 path: "tensorboard events".to_string(),
187 message: e.to_string(),
188 })
189 }
190}
191
192pub struct TensorBoardCallback {
196 writer: TensorBoardWriter,
197 global_step: i64,
198}
199
200impl TensorBoardCallback {
201 pub fn new(log_dir: impl AsRef<Path>) -> Result<Self, ModelError> {
203 Ok(Self {
204 writer: TensorBoardWriter::new(log_dir)?,
205 global_step: 0,
206 })
207 }
208
209 pub fn global_step(&self) -> i64 {
211 self.global_step
212 }
213}
214
215impl TrainingCallback for TensorBoardCallback {
216 fn on_epoch_end(&mut self, _epoch: usize, metrics: &HashMap<String, f32>) -> bool {
217 self.global_step += 1;
218 for (tag, &value) in metrics {
219 let _ = self.writer.add_scalar(tag, value, self.global_step);
220 }
221 let _ = self.writer.flush();
222 false
223 }
224
225 fn on_batch_end(&mut self, _epoch: usize, _batch: usize, loss: f32) {
226 self.global_step += 1;
227 let _ = self.writer.add_scalar("batch_loss", loss, self.global_step);
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory unique to this process, so concurrent test runs
    /// (e.g. several `cargo test` invocations) cannot trample each other.
    fn scratch_dir(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{name}_{}", std::process::id()))
    }

    /// Reads a little-endian `u64` from the first 8 bytes of `bytes`.
    fn le_u64(bytes: &[u8]) -> u64 {
        let mut raw = [0u8; 8];
        raw.copy_from_slice(&bytes[..8]);
        u64::from_le_bytes(raw)
    }

    /// Reads a little-endian `u32` from the first 4 bytes of `bytes`.
    fn le_u32(bytes: &[u8]) -> u32 {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(&bytes[..4]);
        u32::from_le_bytes(raw)
    }

    /// Splits a framed record stream into payloads, asserting that both masked
    /// CRCs of every record are valid and that no bytes are left over.
    fn parse_records(mut bytes: &[u8]) -> Vec<&[u8]> {
        let mut records = Vec::new();
        while !bytes.is_empty() {
            assert!(bytes.len() >= 12, "truncated record header");
            let len = le_u64(bytes) as usize;
            assert_eq!(le_u32(&bytes[8..]), masked_crc32c(&bytes[..8]));
            let body = &bytes[12..];
            assert!(body.len() >= len + 4, "truncated record body");
            let (data, tail) = body.split_at(len);
            assert_eq!(le_u32(tail), masked_crc32c(data));
            records.push(data);
            bytes = &tail[4..];
        }
        records
    }

    #[test]
    fn masked_crc32c_known_values() {
        assert_eq!(crc32c(b""), 0x0000_0000);
        // Standard CRC-32C (iSCSI) check value.
        assert_eq!(crc32c(b"123456789"), 0xE306_9283);
        // Masking a zero CRC leaves only the additive constant.
        assert_eq!(masked_crc32c(b""), 0xa282_ead8);
        assert_eq!(masked_crc32c(b"123456789"), 0xC78A_B0E5);
    }

    #[test]
    fn encode_varint_simple() {
        let cases: [(u64, &[u8]); 5] = [
            (0, &[0x00]),
            (127, &[0x7f]),
            (128, &[0x80, 0x01]),
            (150, &[0x96, 0x01]),
            (300, &[0xac, 0x02]),
        ];
        for &(value, expected) in cases.iter() {
            let mut buf = Vec::new();
            encode_varint(value, &mut buf);
            assert_eq!(buf, expected, "value {value}");
        }

        // Largest value needs the full 10 bytes.
        let mut buf = Vec::new();
        encode_varint(u64::MAX, &mut buf);
        let mut expected = vec![0xffu8; 9];
        expected.push(0x01);
        assert_eq!(buf, expected);
    }

    #[test]
    fn writer_creates_event_file() {
        let dir = scratch_dir("yscv_tb_test");
        let _ = fs::remove_dir_all(&dir);
        {
            let mut w = TensorBoardWriter::new(&dir).unwrap();
            w.add_scalar("loss", 0.5, 1).unwrap();
            w.add_scalar("loss", 0.3, 2).unwrap();
            w.flush().unwrap();
        }
        let entries: Vec<_> = fs::read_dir(&dir).unwrap().filter_map(|e| e.ok()).collect();
        assert_eq!(entries.len(), 1);
        let name = entries[0].file_name().to_string_lossy().to_string();
        assert!(name.starts_with("events.out.tfevents."));

        // Header record plus two scalar records, each correctly framed.
        let bytes = fs::read(entries[0].path()).unwrap();
        let records = parse_records(&bytes);
        assert_eq!(records.len(), 3);
        assert!(records[1].windows(4).any(|win| win == b"loss"));
        assert!(records[2].windows(4).any(|win| win == b"loss"));
        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn callback_logs_metrics() {
        let dir = scratch_dir("yscv_tb_cb_test");
        let _ = fs::remove_dir_all(&dir);
        {
            let mut cb = TensorBoardCallback::new(&dir).unwrap();
            cb.on_batch_end(0, 0, 1.25);
            assert_eq!(cb.global_step(), 1);

            let mut metrics = HashMap::new();
            metrics.insert("train_loss".to_string(), 0.42f32);
            let stop = cb.on_epoch_end(0, &metrics);
            assert!(!stop);
            assert_eq!(cb.global_step(), 2);
        }

        // Header + batch_loss + train_loss, all flushed by the epoch end / drop.
        let entries: Vec<_> = fs::read_dir(&dir).unwrap().filter_map(|e| e.ok()).collect();
        assert_eq!(entries.len(), 1);
        let bytes = fs::read(entries[0].path()).unwrap();
        assert_eq!(parse_records(&bytes).len(), 3);
        let _ = fs::remove_dir_all(&dir);
    }
}