use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};

use anyhow::Result;
use scirs2_core::ndarray::ArrayD;
use serde::{Deserialize, Serialize};

16#[derive(Debug)]
18pub struct TensorBoardWriter {
19 log_dir: PathBuf,
20 run_name: String,
21 step_counter: u64,
22 scalar_logs: Vec<ScalarEvent>,
23 histogram_logs: Vec<HistogramEvent>,
24 text_logs: Vec<TextEvent>,
25 embedding_logs: Vec<EmbeddingEvent>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ScalarEvent {
31 pub tag: String,
32 pub value: f64,
33 pub step: u64,
34 pub timestamp: u64,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct HistogramEvent {
40 pub tag: String,
41 pub values: Vec<f64>,
42 pub step: u64,
43 pub timestamp: u64,
44 pub min: f64,
45 pub max: f64,
46 pub num: usize,
47 pub sum: f64,
48 pub sum_squares: f64,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct TextEvent {
54 pub tag: String,
55 pub text: String,
56 pub step: u64,
57 pub timestamp: u64,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct EmbeddingEvent {
63 pub tag: String,
64 pub embeddings: Vec<Vec<f64>>,
65 pub labels: Option<Vec<String>>,
66 pub step: u64,
67 pub timestamp: u64,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct GraphNode {
73 pub name: String,
74 pub op_type: String,
75 pub input_names: Vec<String>,
76 pub output_names: Vec<String>,
77 pub attributes: HashMap<String, String>,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct GraphDef {
83 pub nodes: Vec<GraphNode>,
84 pub metadata: HashMap<String, String>,
85}
86
87impl TensorBoardWriter {
88 pub fn new<P: AsRef<Path>>(log_dir: P) -> Result<Self> {
102 let log_dir = log_dir.as_ref().to_path_buf();
103 let run_name = format!(
104 "run_{}",
105 SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs()
106 );
107
108 fs::create_dir_all(&log_dir)?;
110
111 Ok(Self {
112 log_dir,
113 run_name,
114 step_counter: 0,
115 scalar_logs: Vec::new(),
116 histogram_logs: Vec::new(),
117 text_logs: Vec::new(),
118 embedding_logs: Vec::new(),
119 })
120 }
121
122 pub fn with_run_name<P: AsRef<Path>>(log_dir: P, run_name: String) -> Result<Self> {
124 let log_dir = log_dir.as_ref().to_path_buf();
125
126 fs::create_dir_all(&log_dir)?;
128
129 Ok(Self {
130 log_dir,
131 run_name,
132 step_counter: 0,
133 scalar_logs: Vec::new(),
134 histogram_logs: Vec::new(),
135 text_logs: Vec::new(),
136 embedding_logs: Vec::new(),
137 })
138 }
139
140 pub fn add_scalar(&mut self, tag: &str, value: f64, step: u64) -> Result<()> {
157 let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
158
159 self.scalar_logs.push(ScalarEvent {
160 tag: tag.to_string(),
161 value,
162 step,
163 timestamp,
164 });
165
166 Ok(())
167 }
168
169 pub fn add_scalars(
171 &mut self,
172 main_tag: &str,
173 tag_scalar_dict: HashMap<String, f64>,
174 step: u64,
175 ) -> Result<()> {
176 for (tag, value) in tag_scalar_dict {
177 let full_tag = format!("{}/{}", main_tag, tag);
178 self.add_scalar(&full_tag, value, step)?;
179 }
180 Ok(())
181 }
182
183 pub fn add_histogram(&mut self, tag: &str, values: &[f64], step: u64) -> Result<()> {
200 if values.is_empty() {
201 return Ok(());
202 }
203
204 let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
205
206 let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
207 let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
208 let sum: f64 = values.iter().sum();
209 let sum_squares: f64 = values.iter().map(|x| x * x).sum();
210
211 self.histogram_logs.push(HistogramEvent {
212 tag: tag.to_string(),
213 values: values.to_vec(),
214 step,
215 timestamp,
216 min,
217 max,
218 num: values.len(),
219 sum,
220 sum_squares,
221 });
222
223 Ok(())
224 }
225
226 pub fn add_text(&mut self, tag: &str, text: &str, step: u64) -> Result<()> {
228 let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
229
230 self.text_logs.push(TextEvent {
231 tag: tag.to_string(),
232 text: text.to_string(),
233 step,
234 timestamp,
235 });
236
237 Ok(())
238 }
239
240 pub fn add_embedding(
249 &mut self,
250 tag: &str,
251 embeddings: Vec<Vec<f64>>,
252 labels: Option<Vec<String>>,
253 step: u64,
254 ) -> Result<()> {
255 let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
256
257 self.embedding_logs.push(EmbeddingEvent {
258 tag: tag.to_string(),
259 embeddings,
260 labels,
261 step,
262 timestamp,
263 });
264
265 Ok(())
266 }
267
268 pub fn add_graph(&mut self, graph: &GraphDef) -> Result<()> {
270 let graph_path = self.log_dir.join(&self.run_name).join("graph.json");
271 if let Some(parent) = graph_path.parent() {
272 fs::create_dir_all(parent)?;
273 }
274
275 let graph_json = serde_json::to_string_pretty(graph)?;
276 fs::write(graph_path, graph_json)?;
277
278 Ok(())
279 }
280
281 pub fn flush(&mut self) -> Result<()> {
283 let run_dir = self.log_dir.join(&self.run_name);
284 fs::create_dir_all(&run_dir)?;
285
286 if !self.scalar_logs.is_empty() {
288 let scalars_path = run_dir.join("scalars.jsonl");
289 let mut file = File::create(scalars_path)?;
290 for event in &self.scalar_logs {
291 let line = serde_json::to_string(event)?;
292 writeln!(file, "{}", line)?;
293 }
294 }
295
296 if !self.histogram_logs.is_empty() {
298 let histograms_path = run_dir.join("histograms.jsonl");
299 let mut file = File::create(histograms_path)?;
300 for event in &self.histogram_logs {
301 let line = serde_json::to_string(event)?;
302 writeln!(file, "{}", line)?;
303 }
304 }
305
306 if !self.text_logs.is_empty() {
308 let text_path = run_dir.join("text.jsonl");
309 let mut file = File::create(text_path)?;
310 for event in &self.text_logs {
311 let line = serde_json::to_string(event)?;
312 writeln!(file, "{}", line)?;
313 }
314 }
315
316 if !self.embedding_logs.is_empty() {
318 let embeddings_path = run_dir.join("embeddings.jsonl");
319 let mut file = File::create(embeddings_path)?;
320 for event in &self.embedding_logs {
321 let line = serde_json::to_string(event)?;
322 writeln!(file, "{}", line)?;
323 }
324 }
325
326 Ok(())
327 }
328
329 pub fn log_dir(&self) -> &Path {
331 &self.log_dir
332 }
333
334 pub fn run_name(&self) -> &str {
336 &self.run_name
337 }
338
339 pub fn increment_step(&mut self) -> u64 {
341 self.step_counter += 1;
342 self.step_counter
343 }
344
345 pub fn current_step(&self) -> u64 {
347 self.step_counter
348 }
349
350 pub fn close(mut self) -> Result<()> {
352 self.flush()
353 }
354}
355
356impl Drop for TensorBoardWriter {
357 fn drop(&mut self) {
358 let _ = self.flush();
360 }
361}
362
363pub fn create_graph_node(
365 name: String,
366 op_type: String,
367 inputs: Vec<String>,
368 outputs: Vec<String>,
369) -> GraphNode {
370 GraphNode {
371 name,
372 op_type,
373 input_names: inputs,
374 output_names: outputs,
375 attributes: HashMap::new(),
376 }
377}
378
379pub fn tensor_to_histogram_values(tensor: &ArrayD<f32>) -> Vec<f64> {
381 tensor.iter().map(|&x| x as f64).collect()
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Scratch log directory unique to this test process, removed again when dropped.
    ///
    /// Bind it *before* the writer under test: locals drop in reverse declaration order,
    /// so the writer's flush-on-drop finishes before the directory is deleted.
    struct TempLogDir(PathBuf);

    impl TempLogDir {
        fn new(name: &str) -> Self {
            // The pid keeps concurrent test processes from sharing (and clobbering) the
            // same directory inside the system-wide temp dir.
            let path = env::temp_dir().join(format!("{}_{}", name, std::process::id()));
            // Clear anything an earlier aborted run with the same pid may have left.
            let _ = fs::remove_dir_all(&path);
            TempLogDir(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempLogDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Creates a writer with an auto-generated run name inside `dir`.
    fn new_writer(dir: &TempLogDir) -> TensorBoardWriter {
        TensorBoardWriter::new(dir.path()).expect("failed to create TensorBoardWriter")
    }

    #[test]
    fn test_tensorboard_writer_creation() {
        let dir = TempLogDir::new("tensorboard_test");
        let writer = new_writer(&dir);
        assert!(writer.log_dir().exists());
        assert!(writer.run_name().starts_with("run_"));
    }

    #[test]
    fn test_with_run_name_and_close() {
        let dir = TempLogDir::new("tensorboard_close_test");
        let mut writer = TensorBoardWriter::with_run_name(dir.path(), "custom_run".to_string())
            .expect("failed to create TensorBoardWriter");
        assert_eq!(writer.run_name(), "custom_run");

        writer.add_text("notes", "hello", 3).expect("add_text failed");
        writer.close().expect("close failed");

        let text = fs::read_to_string(dir.path().join("custom_run").join("text.jsonl"))
            .expect("text.jsonl was not written");
        assert_eq!(text.lines().count(), 1);
    }

    #[test]
    fn test_step_counter() {
        let dir = TempLogDir::new("tensorboard_step_test");
        let mut writer = new_writer(&dir);

        assert_eq!(writer.current_step(), 0);
        assert_eq!(writer.increment_step(), 1);
        assert_eq!(writer.increment_step(), 2);
        assert_eq!(writer.current_step(), 2);
    }

    #[test]
    fn test_add_scalar() {
        let dir = TempLogDir::new("tensorboard_scalar_test");
        let mut writer = new_writer(&dir);

        writer.add_scalar("test/loss", 0.5, 0).expect("add_scalar failed");
        writer.add_scalar("test/loss", 0.4, 1).expect("add_scalar failed");
        writer.add_scalar("test/loss", 0.3, 2).expect("add_scalar failed");

        assert_eq!(writer.scalar_logs.len(), 3);
        assert_eq!(writer.scalar_logs[0].value, 0.5);
        assert_eq!(writer.scalar_logs[1].value, 0.4);
        assert_eq!(writer.scalar_logs[2].step, 2);
    }

    #[test]
    fn test_add_histogram() {
        let dir = TempLogDir::new("tensorboard_histogram_test");
        let mut writer = new_writer(&dir);

        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        writer
            .add_histogram("test/weights", &values, 0)
            .expect("add_histogram failed");

        assert_eq!(writer.histogram_logs.len(), 1);
        let event = &writer.histogram_logs[0];
        assert_eq!(event.min, 1.0);
        assert_eq!(event.max, 5.0);
        assert_eq!(event.num, 5);
        assert_eq!(event.sum, 15.0);
        assert_eq!(event.sum_squares, 55.0);

        // An empty slice is accepted but not recorded.
        writer
            .add_histogram("test/empty", &[], 1)
            .expect("add_histogram failed");
        assert_eq!(writer.histogram_logs.len(), 1);
    }

    #[test]
    fn test_add_text() {
        let dir = TempLogDir::new("tensorboard_text_test");
        let mut writer = new_writer(&dir);

        writer
            .add_text("test/note", "This is a test", 0)
            .expect("add_text failed");
        assert_eq!(writer.text_logs.len(), 1);
        assert_eq!(writer.text_logs[0].text, "This is a test");
    }

    #[test]
    fn test_flush() {
        let dir = TempLogDir::new("tensorboard_flush_test");
        let mut writer = new_writer(&dir);

        writer.add_scalar("test/metric", 1.0, 0).expect("add_scalar failed");
        writer.flush().expect("flush failed");

        let scalars_path = dir.path().join(writer.run_name()).join("scalars.jsonl");
        let contents = fs::read_to_string(&scalars_path).expect("scalars.jsonl was not written");
        let events: Vec<ScalarEvent> = contents
            .lines()
            .map(|line| serde_json::from_str(line).expect("invalid JSON line"))
            .collect();

        assert_eq!(events.len(), 1);
        assert_eq!(events[0].tag, "test/metric");
        assert_eq!(events[0].value, 1.0);
        assert_eq!(events[0].step, 0);
    }

    #[test]
    fn test_add_scalars() {
        let dir = TempLogDir::new("tensorboard_scalars_test");
        let mut writer = new_writer(&dir);

        let mut metrics = HashMap::new();
        metrics.insert("loss".to_string(), 0.5);
        metrics.insert("accuracy".to_string(), 0.95);

        writer.add_scalars("train", metrics, 0).expect("add_scalars failed");
        assert_eq!(writer.scalar_logs.len(), 2);

        let mut tags: Vec<&str> = writer.scalar_logs.iter().map(|e| e.tag.as_str()).collect();
        tags.sort_unstable();
        assert_eq!(tags, ["train/accuracy", "train/loss"]);
    }

    #[test]
    fn test_add_embedding() {
        let dir = TempLogDir::new("tensorboard_embedding_test");
        let mut writer = new_writer(&dir);

        let embeddings = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let labels = vec!["class1".to_string(), "class2".to_string()];

        writer
            .add_embedding("test/emb", embeddings, Some(labels), 0)
            .expect("add_embedding failed");
        assert_eq!(writer.embedding_logs.len(), 1);
        assert_eq!(writer.embedding_logs[0].embeddings.len(), 2);
        assert_eq!(writer.embedding_logs[0].labels.as_ref().map(Vec::len), Some(2));
    }

    #[test]
    fn test_add_graph() {
        let dir = TempLogDir::new("tensorboard_graph_test");
        let mut writer = new_writer(&dir);

        let graph = GraphDef {
            nodes: vec![create_graph_node(
                "layer1".to_string(),
                "Linear".to_string(),
                vec!["input".to_string()],
                vec!["output".to_string()],
            )],
            metadata: HashMap::new(),
        };
        writer.add_graph(&graph).expect("add_graph failed");

        let json = fs::read_to_string(dir.path().join(writer.run_name()).join("graph.json"))
            .expect("graph.json was not written");
        let parsed: GraphDef = serde_json::from_str(&json).expect("graph.json is not valid JSON");
        assert_eq!(parsed.nodes.len(), 1);
        assert_eq!(parsed.nodes[0].name, "layer1");
        assert_eq!(parsed.nodes[0].op_type, "Linear");
    }

    #[test]
    fn test_graph_node_creation() {
        let node = create_graph_node(
            "layer1".to_string(),
            "Linear".to_string(),
            vec!["input".to_string()],
            vec!["output".to_string()],
        );

        assert_eq!(node.name, "layer1");
        assert_eq!(node.op_type, "Linear");
        assert_eq!(node.input_names.len(), 1);
        assert_eq!(node.output_names.len(), 1);
        assert!(node.attributes.is_empty());
    }
}