1use std::{
3 collections::HashMap,
4 sync::{Arc, Mutex},
5 thread::ThreadId,
6 time::Instant,
7};
8
9use crate::{
10 data::{EventCounts, LogTree, StoringFieldVisitor},
11 env_utils::{get_bool_env_var, get_env_var},
12 errors::err_msg,
13};
14use linear_map::LinearMap;
15use tracing::span;
16
/// Display and aggregation options used when a finished span tree is printed.
///
/// All percentages are relative to the execution time of the tree's root span.
#[derive(Clone, Debug)]
pub struct Config {
    /// Nodes whose share of the root's time is above this percentage are highlighted
    /// (bold red) when colors are enabled.
    pub attention_above_percent: f64,

    /// Nodes above this percentage are drawn in the normal color instead of dimmed.
    /// Consecutive same-named siblings above it are listed individually (tagged with an
    /// `index`) instead of being merged into one aggregated node.
    pub relevant_above_percent: f64,

    /// Children below this percentage are collapsed into a `[...]` node.
    /// Values `<= 0.0` disable hiding.
    pub hide_below_percent: f64,

    /// When set, every non-empty child list gets a leading `[unaccounted]` node holding
    /// the parent time not covered by its children.
    pub display_unaccounted: bool,

    /// When set, child event counts are rolled up into their ancestors before printing.
    pub accumulate_events: bool,

    /// When rolling events up, also count each child span (by name) as an event.
    pub accumulate_spans_count: bool,

    /// Disables the ANSI color escape sequences in the printed labels.
    pub no_color: bool,
}
53
54impl Config {
55 fn from_env() -> Self {
56 Self {
57 attention_above_percent: get_env_var("TREE_LAYER_ATTENTION_ABOVE", 25.0),
58 relevant_above_percent: get_env_var("TREE_LAYER_RELEVANT_ABOVE", 2.5),
59 hide_below_percent: get_env_var("TREE_LAYER_HIDE_BELOW", 1.0),
60 display_unaccounted: get_env_var("TREE_LAYER_DISPLAY_", false),
61 accumulate_events: get_bool_env_var("TREE_LAYER_ACCUMULATE_EVENTS", true),
62 accumulate_spans_count: get_bool_env_var("TREE_LAYER_ACCUMULATE_SPANS_COUNT", false),
63 no_color: get_bool_env_var("NO_COLOR", false),
64 }
65 }
66}
67
68impl Default for Config {
69 fn default() -> Self {
70 Self::from_env()
71 }
72}
73
74#[derive(Default)]
75struct State {
76 current_span: Option<span::Id>,
77 unfinished_spans: LinearMap<u64, GraphNode>,
78 zero_level_events: EventCounts,
79}
80
81impl State {
82 fn print_zero_level_events(&mut self) {
83 if !self.zero_level_events.is_empty() {
84 println!("> {}\n", self.zero_level_events.format().join("\n> "));
85
86 self.zero_level_events.clear();
87 }
88 }
89}
90
91pub struct Guard {
92 state: Arc<Mutex<State>>,
93}
94
95impl Drop for Guard {
96 fn drop(&mut self) {
97 let Ok(mut state) = self.state.lock() else {
98 return err_msg!("failed to get mutex");
99 };
100
101 state.print_zero_level_events();
102 }
103}
104
105pub struct Layer {
135 main_thread: ThreadId,
136 state: Arc<Mutex<State>>,
137 config: Config,
138}
139
140impl Layer {
141 pub fn new(config: Config) -> (Self, Guard) {
142 let state = Arc::new(Mutex::new(State::default()));
143 let layer = Self {
144 main_thread: std::thread::current().id(),
145 state: state.clone(),
146 config: config.clone(),
147 };
148 let guard = Guard { state };
149
150 (layer, guard)
151 }
152
153 fn is_main_thread(&self) -> bool {
154 self.main_thread == std::thread::current().id()
155 }
156}
157
158impl<S> tracing_subscriber::Layer<S> for Layer
159where
160 S: tracing::Subscriber,
161 S: for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
163{
164 fn on_new_span(
165 &self,
166 attrs: &span::Attributes<'_>,
167 id: &span::Id,
168 _ctx: tracing_subscriber::layer::Context<'_, S>,
169 ) {
170 if !self.is_main_thread() {
171 return;
172 }
173
174 let mut graph_node = GraphNode {
175 call_count: 1,
176 ..Default::default()
177 };
178 let mut visitor = StoringFieldVisitor(&mut graph_node.metadata);
179 attrs.record(&mut visitor);
180
181 let Ok(mut state) = self.state.lock() else {
182 return err_msg!("failed to get mutex");
183 };
184
185 state.unfinished_spans.insert(id.into_u64(), graph_node);
186 }
187
188 fn on_record(
189 &self,
190 id: &span::Id,
191 values: &span::Record<'_>,
192 _ctx: tracing_subscriber::layer::Context<'_, S>,
193 ) {
194 if !self.is_main_thread() {
195 return;
196 }
197
198 let Ok(mut state) = self.state.lock() else {
199 return err_msg!("failed to get mutex");
200 };
201
202 if let Some(graph_node) = state.unfinished_spans.get_mut(&id.into_u64()) {
203 let mut visitor = StoringFieldVisitor(&mut graph_node.metadata);
204 values.record(&mut visitor);
205 }
206 }
207
208 fn on_enter(&self, id: &span::Id, _ctx: tracing_subscriber::layer::Context<'_, S>) {
209 if !self.is_main_thread() {
210 return;
211 }
212
213 let Ok(mut state) = self.state.lock() else {
214 return err_msg!("failed to get mutex");
215 };
216
217 state.current_span = Some(id.clone());
218 if let Some(graph_node) = state.unfinished_spans.get_mut(&id.into_u64()) {
219 graph_node.started = Some(Instant::now());
220 }
221
222 state.print_zero_level_events();
223 }
224
225 fn on_exit(&self, id: &span::Id, ctx: tracing_subscriber::layer::Context<'_, S>) {
226 if !self.is_main_thread() {
227 return;
228 }
229
230 let Some(span) = ctx.span(id) else {
231 return err_msg!("failed to get span on_exit");
232 };
233
234 let Ok(mut state) = self.state.lock() else {
235 return err_msg!("failed to get mutex");
236 };
237
238 let mut node = state
239 .unfinished_spans
240 .remove(&id.into_u64())
241 .unwrap_or_default();
242 node.execution_duration = node
243 .started
244 .map(|started| Instant::elapsed(&started))
245 .unwrap_or_default();
246 node.name = span.name();
247
248 let parent = match span.parent() {
249 Some(p) => {
250 let Some(parent_node) = state.unfinished_spans.get_mut(&p.id().into_u64()) else {
251 return err_msg!("failed to get parent node");
252 };
253
254 parent_node.child_nodes.push(node);
255 Some(p.id().clone())
256 }
257 None => {
258 node.print(&self.config);
259
260 None
261 }
262 };
263
264 state.current_span = parent;
265 }
266
267 fn on_event(&self, event: &tracing::Event<'_>, ctx: tracing_subscriber::layer::Context<'_, S>) {
268 if event.is_root() {
269 return;
270 }
271
272 let Ok(mut state) = self.state.lock() else {
273 return err_msg!("failed to get mutex");
274 };
275
276 let span_id = if self.is_main_thread() {
277 event
278 .parent()
279 .cloned()
280 .or_else(|| ctx.current_span().id().cloned())
281 } else {
282 state.current_span.clone()
284 };
285
286 match span_id {
287 Some(span_id) => {
288 if let Some(graph_node) = state.unfinished_spans.get_mut(&span_id.into_u64()) {
289 graph_node.events.record(event);
290 }
291 }
292 None => {
293 state.zero_level_events.record(event);
294 }
295 }
296 }
297}
298
299#[derive(Default, Debug, Clone)]
300struct GraphNode {
301 name: &'static str,
302 started: Option<Instant>,
303 execution_duration: std::time::Duration,
304 metadata: LinearMap<&'static str, String>,
305 events: EventCounts,
306 child_nodes: Vec<GraphNode>,
307 call_count: usize,
308}
309
310impl GraphNode {
311 fn new(name: &'static str) -> Self {
312 Self {
313 name,
314 ..Default::default()
315 }
316 }
317
318 fn execution_percentage(&self, root_time: std::time::Duration) -> f64 {
319 100.0 * self.execution_duration.as_secs_f64() / root_time.as_secs_f64()
320 }
321
322 fn accumulate_children_events(&mut self, accumulate_spans_count: bool) {
324 for child in self.child_nodes.iter_mut() {
325 child.accumulate_children_events(accumulate_spans_count);
326
327 if accumulate_spans_count {
328 child.record_self_as_event();
329 }
330
331 self.events += &child.events;
332 }
333 }
334
335 fn record_self_as_event(&mut self) {
338 self.events.increment_events_counter(self.name);
339 }
340
341 fn print(mut self, config: &Config) {
342 if config.accumulate_events {
343 self.accumulate_children_events(config.accumulate_spans_count);
344 }
345
346 let tree = self.render_tree(self.execution_duration, config);
347 println!("{tree}");
348 }
349
350 fn label(&self, root_time: std::time::Duration, config: &Config) -> String {
351 let mut info = vec![];
352 if self.call_count > 1 {
353 info.push(format!("({} calls)", self.call_count))
354 } else if !self.metadata.is_empty() {
355 let kv: Vec<_> = self
356 .metadata
357 .iter()
358 .map(|(k, v)| format!("{k} = {v}"))
359 .collect();
360 info.push(format!("{{ {} }}", kv.join(", ")))
361 }
362
363 let name = &self.name;
364 let execution_time = self.execution_duration;
365 let execution_time_percent = self.execution_percentage(root_time);
366 let mut result = format!("{name} [ {execution_time:.2?} | {execution_time_percent:.2}% ]");
367 if !info.is_empty() {
368 result = format!("{result} {}", info.join(" "));
369 }
370
371 if config.no_color {
372 result
373 } else {
374 format!(
375 "{}{}\x1b[0m",
376 if execution_time_percent > config.attention_above_percent {
377 "\x1b[1;31m" } else if execution_time_percent > config.relevant_above_percent {
379 "\x1b[0m" } else {
381 "\x1b[2m" },
383 result
384 )
385 }
386 }
387
388 fn render_tree(&self, root_time: std::time::Duration, config: &Config) -> LogTree {
389 let mut children = vec![];
390 let mut aggregated_node: Option<GraphNode> = None;
391 let mut name_counter: HashMap<&str, usize> = HashMap::new();
392
393 for (i, child) in self.child_nodes.iter().enumerate() {
394 let name_count = name_counter.entry(child.name).or_insert(0);
395 *name_count += 1;
396
397 let next = self.child_nodes.get(i + 1);
398 if next.is_some_and(|next| next.name == child.name) {
399 if child.execution_percentage(root_time) > config.relevant_above_percent {
400 let mut indexed_child = child.clone();
401 indexed_child
402 .metadata
403 .insert("index", format!("{name_count}"));
404 children.push(indexed_child);
405 } else {
406 aggregated_node = aggregated_node
407 .map(|node| node.clone().aggregate(child))
408 .or_else(|| Some(child.clone()));
409 }
410 } else {
411 let child = aggregated_node.take().unwrap_or_else(|| child.clone());
412 children.push(child);
413 }
414 }
415
416 if config.hide_below_percent > 0.0 {
417 children = children.into_iter().fold(vec![], |acc, child| {
418 let mut acc = acc;
419 if child.execution_percentage(root_time) < config.hide_below_percent {
420 if let Some(x) = acc.last_mut() {
421 if x.name == "[...]" {
422 *x = x.clone().aggregate(&child);
423 } else {
424 acc.push(GraphNode::new("[...]").aggregate(&child))
425 }
426 }
427 } else {
428 acc.push(child);
429 }
430 acc
431 });
432 }
433
434 if config.display_unaccounted && !children.is_empty() {
435 let mut unaccounted = GraphNode::new("[unaccounted]");
436 unaccounted.execution_duration = self.execution_duration
437 - self
438 .child_nodes
439 .iter()
440 .map(|x| x.execution_duration)
441 .fold(std::time::Duration::new(0, 0), |x, y| x + y);
442
443 children.insert(0, unaccounted);
444 }
445
446 LogTree {
447 label: self.label(root_time, config),
448 events: self.events.format(),
449 children: children
450 .into_iter()
451 .map(|child| child.render_tree(root_time, config))
452 .collect(),
453 }
454 }
455
456 fn aggregate(mut self, other: &GraphNode) -> Self {
457 self.execution_duration += other.execution_duration;
458 self.call_count += other.call_count;
459 self.events += &other.events;
460
461 self
462 }
463}
464
#[cfg(test)]
mod tests {
    use {crate::data::CounterValue, tracing_subscriber::util::SubscriberInitExt};
    use {
        crate::{PrintTreeConfig, PrintTreeLayer},
        tracing_subscriber::layer::SubscriberExt,
    };
    use {
        std::{thread, time::Duration},
        tracing::{debug_span, event, Level},
    };

    /// Checks that `proof_size` counter events end up summed in the root node once child
    /// events are accumulated. The events are recorded on the root span, inside a child span,
    /// and from a spawned (non-main) thread.
    #[test]
    fn test_incremental_events_counts() {
        // The layer is built on the test thread, which therefore becomes its "main thread".
        // `try_init` installs the subscriber globally, so the spawned thread below reports to it.
        let (layer, guard) = PrintTreeLayer::new(PrintTreeConfig::default());
        let layer = tracing_subscriber::registry().with(layer);
        layer.try_init().unwrap();

        let span = debug_span!("root span");
        let _scope1 = span.enter();
        thread::sleep(Duration::from_millis(20));
        // Attributed to "root span" (value 1).
        event!(name: "proof_size", Level::INFO, counter=true, incremental=true, value=1);
        let span2 = debug_span!("child span1", field1 = "value1", perfetto_track_id = 5);
        let scope2 = span2.enter();
        thread::sleep(Duration::from_millis(20));
        drop(scope2);

        let span3 = debug_span!(
            "child span2",
            field2 = "value2",
            value = 20,
            perfetto_track_id = 5,
            perfetto_flow_id = 10
        );
        let _scope3 = span3.enter();

        thread::sleep(Duration::from_millis(20));
        // Attributed to "child span2" (value 3).
        event!(name: "proof_size", Level::INFO, counter=true, incremental=true, value=3);

        let span = debug_span!("child span3", field3 = "value3");
        let scope = span.enter();
        thread::sleep(Duration::from_millis(20));
        event!(name: "custom event", Level::DEBUG, {field5 = "value5", counter = true, value = 30});
        // Exiting "child span3" makes its parent, "child span2", the layer's current span again.
        drop(scope);

        // Spans on this non-main thread are not tracked. Its event is attributed to the layer's
        // current span, "child span2" (value 6).
        thread::spawn(|| {
            let span = debug_span!("child span5", field5 = "value5");
            let _scope = span.enter();
            thread::sleep(Duration::from_millis(20));
            event!(name: "proof_size", Level::INFO, counter=true, incremental=true, value=6);
        })
        .join()
        .unwrap();

        let span = debug_span!("child span4", field4 = "value4", perfetto_flow_id = 10);
        thread::sleep(Duration::from_millis(20));
        event!(name: "custom event", Level::DEBUG, {field5 = "value5", counter = true, value = 40});
        let scope = span.enter();
        thread::sleep(Duration::from_millis(20));
        drop(scope);
        // Exiting "child span2" attaches its node (with its children) to the still-open root.
        drop(_scope3);

        // NOTE(review): assumes the registry assigned id 1 to "root span" (the first span created).
        let mut state = guard.state.lock().unwrap();
        let root = state.unfinished_spans.get_mut(&1).unwrap();

        root.accumulate_children_events(true);

        // Expected total: 1 (root) + 3 (child span2) + 6 (spawned thread) = 10.
        assert_eq!(
            *root.events.get("proof_size").unwrap(),
            CounterValue::Int(10)
        );

        // Removes the root node from the shared state. `state` is declared last, so it is dropped
        // (releasing the mutex) before `_scope1`, whose exit handler locks the same mutex.
        state.unfinished_spans.remove(&1).unwrap();
    }
}