1use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs::{self, File};
10use std::io::Write;
11use std::path::{Path, PathBuf};
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use scirs2_core::ndarray::ArrayD;
15
/// Buffers training metrics in memory and persists them under `<log_dir>/<run_name>/`.
///
/// Scalars, histograms, text and embeddings are accumulated by the `add_*`
/// methods and written out as JSON Lines files by `flush` (which `close` and,
/// on a best-effort basis, `Drop` also call). Graphs are written immediately by
/// `add_graph`.
///
/// NOTE(review): despite the name, the output is plain JSON rather than
/// TensorBoard's binary event-file format, so TensorBoard itself presumably
/// cannot read it directly — confirm whether a converter exists elsewhere.
#[derive(Debug)]
pub struct TensorBoardWriter {
    /// Root directory for all runs; created when the writer is constructed.
    log_dir: PathBuf,
    /// Name of this run's subdirectory inside `log_dir`.
    run_name: String,
    /// Counter advanced by `increment_step`. The `add_*` methods take an
    /// explicit `step` argument and do not read this value.
    step_counter: u64,
    /// Buffered scalar events, written to `scalars.jsonl` on flush.
    scalar_logs: Vec<ScalarEvent>,
    /// Buffered histogram events, written to `histograms.jsonl` on flush.
    histogram_logs: Vec<HistogramEvent>,
    /// Buffered text events, written to `text.jsonl` on flush.
    text_logs: Vec<TextEvent>,
    /// Buffered embedding events, written to `embeddings.jsonl` on flush.
    embedding_logs: Vec<EmbeddingEvent>,
}
27
/// A single logged scalar value; serialized as one line of `scalars.jsonl`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarEvent {
    /// Metric name, e.g. `"train/loss"`.
    pub tag: String,
    /// The logged value.
    pub value: f64,
    /// Caller-supplied step at which the value was recorded.
    pub step: u64,
    /// Wall-clock logging time in whole seconds since the Unix epoch.
    pub timestamp: u64,
}
36
/// A logged distribution of values together with summary statistics;
/// serialized as one line of `histograms.jsonl`.
///
/// The statistics are computed from `values` when the event is created by
/// `TensorBoardWriter::add_histogram`; the raw values are stored as well.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramEvent {
    /// Histogram name, e.g. `"layer1/weights"`.
    pub tag: String,
    /// Copy of every sample that was logged.
    pub values: Vec<f64>,
    /// Caller-supplied step at which the samples were recorded.
    pub step: u64,
    /// Wall-clock logging time in whole seconds since the Unix epoch.
    pub timestamp: u64,
    /// Smallest sample.
    pub min: f64,
    /// Largest sample.
    pub max: f64,
    /// Number of samples (`values.len()`).
    pub num: usize,
    /// Sum of all samples.
    pub sum: f64,
    /// Sum of the squares of all samples.
    pub sum_squares: f64,
}
50
/// A logged free-form text entry; serialized as one line of `text.jsonl`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextEvent {
    /// Entry name, e.g. `"notes/config"`.
    pub tag: String,
    /// The logged text, stored verbatim.
    pub text: String,
    /// Caller-supplied step at which the text was recorded.
    pub step: u64,
    /// Wall-clock logging time in whole seconds since the Unix epoch.
    pub timestamp: u64,
}
59
/// A logged set of embedding vectors; serialized as one line of `embeddings.jsonl`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingEvent {
    /// Embedding set name.
    pub tag: String,
    /// Embedding vectors, presumably one per data point and all of equal
    /// dimension — this struct does not enforce either property.
    pub embeddings: Vec<Vec<f64>>,
    /// Optional per-point labels, expected to line up with `embeddings`.
    pub labels: Option<Vec<String>>,
    /// Caller-supplied step at which the embeddings were recorded.
    pub step: u64,
    /// Wall-clock logging time in whole seconds since the Unix epoch.
    pub timestamp: u64,
}
69
/// One operation node of a model graph written by `TensorBoardWriter::add_graph`.
///
/// All fields are free-form strings; nothing in this file validates that the
/// named inputs/outputs refer to other nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Unique-by-convention node name, e.g. `"layer1"`.
    pub name: String,
    /// Operation kind, e.g. `"Linear"`.
    pub op_type: String,
    /// Names of the tensors/nodes feeding this node.
    pub input_names: Vec<String>,
    /// Names of the tensors/nodes this node produces.
    pub output_names: Vec<String>,
    /// Arbitrary string key/value attributes (empty when built via `create_graph_node`).
    pub attributes: HashMap<String, String>,
}
79
/// A model graph; serialized as pretty-printed JSON to `graph.json` by
/// `TensorBoardWriter::add_graph`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphDef {
    /// The graph's nodes, in caller-supplied order.
    pub nodes: Vec<GraphNode>,
    /// Arbitrary string key/value metadata about the graph.
    pub metadata: HashMap<String, String>,
}
86
87impl TensorBoardWriter {
88 pub fn new<P: AsRef<Path>>(log_dir: P) -> Result<Self> {
102 let log_dir = log_dir.as_ref().to_path_buf();
103 let run_name = format!(
104 "run_{}",
105 SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs()
106 );
107
108 fs::create_dir_all(&log_dir)?;
110
111 Ok(Self {
112 log_dir,
113 run_name,
114 step_counter: 0,
115 scalar_logs: Vec::new(),
116 histogram_logs: Vec::new(),
117 text_logs: Vec::new(),
118 embedding_logs: Vec::new(),
119 })
120 }
121
122 pub fn with_run_name<P: AsRef<Path>>(log_dir: P, run_name: String) -> Result<Self> {
124 let log_dir = log_dir.as_ref().to_path_buf();
125
126 fs::create_dir_all(&log_dir)?;
128
129 Ok(Self {
130 log_dir,
131 run_name,
132 step_counter: 0,
133 scalar_logs: Vec::new(),
134 histogram_logs: Vec::new(),
135 text_logs: Vec::new(),
136 embedding_logs: Vec::new(),
137 })
138 }
139
140 pub fn add_scalar(&mut self, tag: &str, value: f64, step: u64) -> Result<()> {
157 let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
158
159 self.scalar_logs.push(ScalarEvent {
160 tag: tag.to_string(),
161 value,
162 step,
163 timestamp,
164 });
165
166 Ok(())
167 }
168
169 pub fn add_scalars(
171 &mut self,
172 main_tag: &str,
173 tag_scalar_dict: HashMap<String, f64>,
174 step: u64,
175 ) -> Result<()> {
176 for (tag, value) in tag_scalar_dict {
177 let full_tag = format!("{}/{}", main_tag, tag);
178 self.add_scalar(&full_tag, value, step)?;
179 }
180 Ok(())
181 }
182
183 pub fn add_histogram(&mut self, tag: &str, values: &[f64], step: u64) -> Result<()> {
200 if values.is_empty() {
201 return Ok(());
202 }
203
204 let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
205
206 let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
207 let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
208 let sum: f64 = values.iter().sum();
209 let sum_squares: f64 = values.iter().map(|x| x * x).sum();
210
211 self.histogram_logs.push(HistogramEvent {
212 tag: tag.to_string(),
213 values: values.to_vec(),
214 step,
215 timestamp,
216 min,
217 max,
218 num: values.len(),
219 sum,
220 sum_squares,
221 });
222
223 Ok(())
224 }
225
226 pub fn add_text(&mut self, tag: &str, text: &str, step: u64) -> Result<()> {
228 let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
229
230 self.text_logs.push(TextEvent {
231 tag: tag.to_string(),
232 text: text.to_string(),
233 step,
234 timestamp,
235 });
236
237 Ok(())
238 }
239
240 pub fn add_embedding(
249 &mut self,
250 tag: &str,
251 embeddings: Vec<Vec<f64>>,
252 labels: Option<Vec<String>>,
253 step: u64,
254 ) -> Result<()> {
255 let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
256
257 self.embedding_logs.push(EmbeddingEvent {
258 tag: tag.to_string(),
259 embeddings,
260 labels,
261 step,
262 timestamp,
263 });
264
265 Ok(())
266 }
267
268 pub fn add_graph(&mut self, graph: &GraphDef) -> Result<()> {
270 let graph_path = self.log_dir.join(&self.run_name).join("graph.json");
271 fs::create_dir_all(graph_path.parent().unwrap())?;
272
273 let graph_json = serde_json::to_string_pretty(graph)?;
274 fs::write(graph_path, graph_json)?;
275
276 Ok(())
277 }
278
279 pub fn flush(&mut self) -> Result<()> {
281 let run_dir = self.log_dir.join(&self.run_name);
282 fs::create_dir_all(&run_dir)?;
283
284 if !self.scalar_logs.is_empty() {
286 let scalars_path = run_dir.join("scalars.jsonl");
287 let mut file = File::create(scalars_path)?;
288 for event in &self.scalar_logs {
289 let line = serde_json::to_string(event)?;
290 writeln!(file, "{}", line)?;
291 }
292 }
293
294 if !self.histogram_logs.is_empty() {
296 let histograms_path = run_dir.join("histograms.jsonl");
297 let mut file = File::create(histograms_path)?;
298 for event in &self.histogram_logs {
299 let line = serde_json::to_string(event)?;
300 writeln!(file, "{}", line)?;
301 }
302 }
303
304 if !self.text_logs.is_empty() {
306 let text_path = run_dir.join("text.jsonl");
307 let mut file = File::create(text_path)?;
308 for event in &self.text_logs {
309 let line = serde_json::to_string(event)?;
310 writeln!(file, "{}", line)?;
311 }
312 }
313
314 if !self.embedding_logs.is_empty() {
316 let embeddings_path = run_dir.join("embeddings.jsonl");
317 let mut file = File::create(embeddings_path)?;
318 for event in &self.embedding_logs {
319 let line = serde_json::to_string(event)?;
320 writeln!(file, "{}", line)?;
321 }
322 }
323
324 Ok(())
325 }
326
327 pub fn log_dir(&self) -> &Path {
329 &self.log_dir
330 }
331
332 pub fn run_name(&self) -> &str {
334 &self.run_name
335 }
336
337 pub fn increment_step(&mut self) -> u64 {
339 self.step_counter += 1;
340 self.step_counter
341 }
342
343 pub fn current_step(&self) -> u64 {
345 self.step_counter
346 }
347
348 pub fn close(mut self) -> Result<()> {
350 self.flush()
351 }
352}
353
354impl Drop for TensorBoardWriter {
355 fn drop(&mut self) {
356 let _ = self.flush();
358 }
359}
360
361pub fn create_graph_node(
363 name: String,
364 op_type: String,
365 inputs: Vec<String>,
366 outputs: Vec<String>,
367) -> GraphNode {
368 GraphNode {
369 name,
370 op_type,
371 input_names: inputs,
372 output_names: outputs,
373 attributes: HashMap::new(),
374 }
375}
376
377pub fn tensor_to_histogram_values(tensor: &ArrayD<f32>) -> Vec<f64> {
379 tensor.iter().map(|&x| x as f64).collect()
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Creates a writer in a dedicated per-test directory under the system temp dir.
    ///
    /// Leftovers from a previous run are removed first so on-disk assertions
    /// cannot be satisfied by stale files. The removal result is ignored on
    /// purpose: the directory usually does not exist yet.
    fn fresh_writer(dir_name: &str) -> (PathBuf, TensorBoardWriter) {
        let temp_dir = env::temp_dir().join(dir_name);
        let _ = fs::remove_dir_all(&temp_dir);
        let writer = TensorBoardWriter::new(&temp_dir).unwrap();
        (temp_dir, writer)
    }

    #[test]
    fn test_tensorboard_writer_creation() {
        let (_, writer) = fresh_writer("tensorboard_test");
        assert!(writer.log_dir().exists());
    }

    #[test]
    fn test_add_scalar() {
        let (_, mut writer) = fresh_writer("tensorboard_scalar_test");

        writer.add_scalar("test/loss", 0.5, 0).unwrap();
        writer.add_scalar("test/loss", 0.4, 1).unwrap();
        writer.add_scalar("test/loss", 0.3, 2).unwrap();

        assert_eq!(writer.scalar_logs.len(), 3);
        assert_eq!(writer.scalar_logs[0].value, 0.5);
        assert_eq!(writer.scalar_logs[1].value, 0.4);
        assert_eq!(writer.scalar_logs[2].step, 2);
    }

    #[test]
    fn test_add_histogram() {
        let (_, mut writer) = fresh_writer("tensorboard_histogram_test");

        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        writer.add_histogram("test/weights", &values, 0).unwrap();

        assert_eq!(writer.histogram_logs.len(), 1);
        let event = &writer.histogram_logs[0];
        assert_eq!(event.min, 1.0);
        assert_eq!(event.max, 5.0);
        assert_eq!(event.num, 5);
        assert_eq!(event.sum, 15.0);
        assert_eq!(event.sum_squares, 55.0);
        assert_eq!(event.values, values);
    }

    #[test]
    fn test_add_histogram_ignores_empty() {
        let (_, mut writer) = fresh_writer("tensorboard_empty_histogram_test");

        writer.add_histogram("test/empty", &[], 0).unwrap();
        assert!(writer.histogram_logs.is_empty());
    }

    #[test]
    fn test_add_text() {
        let (_, mut writer) = fresh_writer("tensorboard_text_test");

        writer.add_text("test/note", "This is a test", 0).unwrap();
        assert_eq!(writer.text_logs.len(), 1);
        assert_eq!(writer.text_logs[0].text, "This is a test");
    }

    #[test]
    fn test_flush() {
        let (temp_dir, mut writer) = fresh_writer("tensorboard_flush_test");

        writer.add_scalar("test/metric", 1.0, 0).unwrap();
        writer.flush().unwrap();

        let scalars_path = temp_dir.join(writer.run_name()).join("scalars.jsonl");
        assert!(scalars_path.exists());

        // The file must round-trip: one JSON object per line.
        let contents = fs::read_to_string(&scalars_path).unwrap();
        let events: Vec<ScalarEvent> = contents
            .lines()
            .map(|line| serde_json::from_str(line).unwrap())
            .collect();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].tag, "test/metric");
        assert_eq!(events[0].value, 1.0);
        assert_eq!(events[0].step, 0);
    }

    #[test]
    fn test_add_scalars() {
        let (_, mut writer) = fresh_writer("tensorboard_scalars_test");

        let mut metrics = HashMap::new();
        metrics.insert("loss".to_string(), 0.5);
        metrics.insert("accuracy".to_string(), 0.95);

        writer.add_scalars("train", metrics, 0).unwrap();
        assert_eq!(writer.scalar_logs.len(), 2);

        let mut tags: Vec<&str> = writer.scalar_logs.iter().map(|e| e.tag.as_str()).collect();
        tags.sort_unstable();
        assert_eq!(tags, ["train/accuracy", "train/loss"]);
    }

    #[test]
    fn test_add_embedding() {
        let (_, mut writer) = fresh_writer("tensorboard_embedding_test");

        let embeddings = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let labels = vec!["class1".to_string(), "class2".to_string()];

        writer.add_embedding("test/emb", embeddings, Some(labels), 0).unwrap();
        assert_eq!(writer.embedding_logs.len(), 1);
    }

    #[test]
    fn test_add_graph() {
        let (temp_dir, mut writer) = fresh_writer("tensorboard_graph_test");

        let graph = GraphDef {
            nodes: vec![create_graph_node(
                "layer1".to_string(),
                "Linear".to_string(),
                vec!["input".to_string()],
                vec!["output".to_string()],
            )],
            metadata: HashMap::new(),
        };
        writer.add_graph(&graph).unwrap();

        let graph_path = temp_dir.join(writer.run_name()).join("graph.json");
        let parsed: GraphDef =
            serde_json::from_str(&fs::read_to_string(graph_path).unwrap()).unwrap();
        assert_eq!(parsed.nodes.len(), 1);
        assert_eq!(parsed.nodes[0].name, "layer1");
        assert_eq!(parsed.nodes[0].op_type, "Linear");
    }

    #[test]
    fn test_graph_node_creation() {
        let node = create_graph_node(
            "layer1".to_string(),
            "Linear".to_string(),
            vec!["input".to_string()],
            vec!["output".to_string()],
        );

        assert_eq!(node.name, "layer1");
        assert_eq!(node.op_type, "Linear");
        assert_eq!(node.input_names.len(), 1);
        assert_eq!(node.output_names.len(), 1);
        assert!(node.attributes.is_empty());
    }
}