1use super::decode::{DecodedEvent, decode_from_record_parts};
4use super::schema::SchemaCache;
5use super::types::{CpuSample, StackTrace, SystemProvider, ThreadContext, TraceEvent};
6use crate::Result;
7use crate::error::{Error, EtwConsumeError, EtwError, EtwProviderError, EtwSessionError};
8use crate::types::ProcessId;
9use crate::utils::to_utf16_nul;
10use crate::wait::Wait;
11use std::borrow::Cow;
12use std::collections::HashSet;
13use std::sync::mpsc::{self, Receiver, SyncSender};
14use std::sync::{Arc, Mutex};
15use std::thread::JoinHandle;
16use std::time::Duration;
17use windows::Win32::Foundation::ERROR_SUCCESS;
18use windows::Win32::System::Diagnostics::Etw::*;
19use windows::Win32::System::SystemInformation::GetSystemTimeAsFileTime;
20use windows::core::{GUID, PWSTR};
21
/// ETW limits session (logger) names to 1024 characters; also used to size
/// the name area appended to `EVENT_TRACE_PROPERTIES` buffers below.
const MAX_SESSION_NAME_LEN: usize = 1024;
/// Well-known logger name required for kernel (system-provider) sessions.
const KERNEL_LOGGER_NAME: &str = "NT Kernel Logger";
/// Win32 `ERROR_ALREADY_EXISTS`: `StartTraceW` returns this when a session
/// with the same name is already running.
const ERROR_ALREADY_EXISTS_CODE: u32 = 183;
25
/// State shared with the ETW event-record callback via the trace's user
/// context pointer. Owned by a `CallbackContextGuard` so its address stays
/// stable for the lifetime of the session.
struct CallbackContext {
    // Sink for raw `TraceEvent`s; `None` when the raw stream is disabled.
    raw_sender: Option<SyncSender<TraceEvent>>,
    // Sink for decoded events; `None` when the decoded stream is disabled.
    decoded_sender: Option<SyncSender<DecodedEvent>>,
    // Schema cache used by the callback to parse event fields; present when
    // detailed events or the decoded stream were requested.
    schema_cache: Option<Mutex<SchemaCache>>,
    // Allow-list of process ids; `None` means "deliver everything".
    process_filter: Option<HashSet<ProcessId>>,
    // Optional enrichments applied to raw events before sending.
    include_thread_context: bool,
    include_stack_traces: bool,
    include_cpu_samples: bool,
}
36
37fn normalize_process_filter(pids: Vec<ProcessId>) -> Option<HashSet<ProcessId>> {
38 if pids.is_empty() {
39 return None;
40 }
41 Some(pids.into_iter().collect())
42}
43
/// Extracts a stack trace from an event's extended-data items, if one exists.
///
/// Scans for a 32-bit or 64-bit stack-walk item; its payload is a `u64` match
/// id followed by pointer-sized frame addresses (little-endian). Returns the
/// first stack found, with zero addresses skipped, or `None` when no usable
/// stack item is present.
fn extract_stack_trace(record: &EVENT_RECORD) -> Option<StackTrace> {
    if record.ExtendedDataCount == 0 || record.ExtendedData.is_null() {
        return None;
    }

    // SAFETY: ExtendedData/ExtendedDataCount were checked non-null/non-zero
    // above; ETW guarantees they describe a valid array for this callback.
    let items = unsafe {
        std::slice::from_raw_parts(record.ExtendedData, record.ExtendedDataCount as usize)
    };

    for item in items {
        let ext_type = item.ExtType;
        let is_stack32 = ext_type == EVENT_HEADER_EXT_TYPE_STACK_TRACE32 as u16;
        let is_stack64 = ext_type == EVENT_HEADER_EXT_TYPE_STACK_TRACE64 as u16;
        if !is_stack32 && !is_stack64 {
            continue;
        }

        // Payload must at least hold the 8-byte match id.
        if item.DataPtr == 0 || item.DataSize < 8 {
            continue;
        }

        // SAFETY: DataPtr is non-zero and DataSize bytes are provided by ETW
        // for the duration of the callback.
        let raw = unsafe {
            std::slice::from_raw_parts(item.DataPtr as *const u8, item.DataSize as usize)
        };

        if raw.len() < 8 {
            continue;
        }

        // NOTE: `ok()?` aborts the whole scan on conversion failure; the
        // slice lengths are fixed above, so this cannot actually fail.
        let match_id = u64::from_le_bytes(raw[0..8].try_into().ok()?);
        let frame_size = if is_stack32 { 4 } else { 8 };

        let mut frames = Vec::new();
        let mut offset = 8usize;
        // Walk whole frames only; a trailing partial frame is ignored.
        while offset + frame_size <= raw.len() {
            let addr = if frame_size == 4 {
                let bytes: [u8; 4] = raw[offset..offset + 4].try_into().ok()?;
                u32::from_le_bytes(bytes) as u64
            } else {
                let bytes: [u8; 8] = raw[offset..offset + 8].try_into().ok()?;
                u64::from_le_bytes(bytes)
            };

            if addr != 0 {
                frames.push(addr);
            }
            offset += frame_size;
        }

        return Some(StackTrace::new(match_id, frames));
    }

    None
}
98
/// Builds a `CpuSample` from the processor number the event was logged on.
///
/// NOTE(review): this reinterprets the first byte of `record.BufferContext`
/// as the processor number — matches the documented `ETW_BUFFER_CONTEXT`
/// layout (union of `ProcessorNumber`/`ProcessorIndex` first), but confirm
/// against the windows-rs version in use.
fn extract_cpu_sample(record: &EVENT_RECORD) -> CpuSample {
    // SAFETY: reads one byte from a valid, initialized field of `record`.
    let processor_number = unsafe { *(std::ptr::addr_of!(record.BufferContext) as *const u8) };
    CpuSample::new(processor_number)
}
104
/// Owns the callback context and pins it at a stable heap address.
///
/// The `Box<Arc<..>>` double indirection is deliberate (hence the clippy
/// allow): ETW stores the raw pointer returned by `as_user_context_ptr`, so
/// the `Arc` itself must not move even if the guard does.
struct CallbackContextGuard {
    #[allow(clippy::redundant_allocation)]
    boxed_ctx: Box<Arc<CallbackContext>>,
}

impl CallbackContextGuard {
    /// Heap-allocates the context behind a stable pointer.
    fn new(ctx: CallbackContext) -> Self {
        Self {
            boxed_ctx: Box::new(Arc::new(ctx)),
        }
    }

    /// Raw pointer handed to ETW as the trace's user context.
    ///
    /// Valid only while this guard is alive; the callback dereferences it as
    /// `*const Arc<CallbackContext>`.
    fn as_user_context_ptr(&self) -> *mut std::ffi::c_void {
        self.boxed_ctx.as_ref() as *const Arc<CallbackContext> as *mut std::ffi::c_void
    }
}
126
/// Selects which event channel(s) a session delivers to the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventStreamMode {
    /// Deliver raw `TraceEvent`s only (the default).
    Raw,
    /// Deliver decoded events only.
    Decoded,
    /// Deliver both raw and decoded streams.
    Both,
}
137
/// ETW event-record callback invoked by `ProcessTrace` for every event.
///
/// Recovers the shared `CallbackContext` from the record's user context,
/// applies the process filter, parses fields when a schema cache is
/// configured, and pushes to the enabled channel(s). `try_send` is used so a
/// full channel drops the event instead of blocking the trace thread.
unsafe extern "system" fn trace_callback_fn(event_record: *mut EVENT_RECORD) {
    // SAFETY: ETW passes a pointer valid for the duration of the callback;
    // null is handled by returning early.
    let record = match unsafe { event_record.as_ref() } {
        Some(r) => r,
        None => return,
    };
    // SAFETY: UserContext was set to the stable pointer produced by
    // `CallbackContextGuard::as_user_context_ptr`, which outlives the session.
    let ctx_ptr = record.UserContext as *const Arc<CallbackContext>;
    let ctx = match unsafe { ctx_ptr.as_ref() } {
        Some(c) => c,
        None => return,
    };

    // Drop events from processes outside the allow-list, if one is set.
    if let Some(filter) = &ctx.process_filter
        && !filter.contains(&ProcessId::new(record.EventHeader.ProcessId))
    {
        return;
    }

    // Field parsing is best-effort: a poisoned lock or parse failure simply
    // yields `None`.
    let fields = ctx
        .schema_cache
        .as_ref()
        .and_then(|cache| cache.lock().ok())
        .and_then(|mut cache| cache.parse_event_fields(record));

    let payload = if record.UserDataLength > 0 && !record.UserData.is_null() {
        // SAFETY: UserData/UserDataLength checked non-null/non-zero; ETW
        // guarantees the buffer is valid during the callback.
        unsafe {
            std::slice::from_raw_parts(record.UserData as *const u8, record.UserDataLength as usize)
        }
    } else {
        &[]
    };

    if let Some(sender) = &ctx.decoded_sender {
        let desc = record.EventHeader.EventDescriptor;
        let decoded = decode_from_record_parts(
            record.EventHeader.ProviderId,
            desc.Version,
            desc.Opcode,
            payload,
            fields.as_deref(),
        );
        // Drop (don't block) when the decoded channel is full or closed.
        let _ = sender.try_send(decoded);
    }

    if let Some(sender) = &ctx.raw_sender {
        let mut event = TraceEvent::from_event_record_with_fields(record, fields);
        if ctx.include_thread_context {
            event.thread_context = Some(ThreadContext::new(event.process_id, event.thread_id));
        }
        if ctx.include_stack_traces {
            event.stack_trace = extract_stack_trace(record);
        }
        if ctx.include_cpu_samples {
            event.cpu_sample = Some(extract_cpu_sample(record));
        }
        // Drop (don't block) when the raw channel is full or closed.
        let _ = sender.try_send(event);
    }
}
204
/// A running real-time ETW session plus its consumer thread.
///
/// Created via [`EventTrace::builder`]. Stopping — explicitly through
/// [`stop`](EventTrace::stop) or implicitly on drop — tears down the session,
/// closes the processing handle, and joins the consumer thread.
pub struct EventTrace {
    // Name registered with ETW; for kernel sessions this is the fixed
    // "NT Kernel Logger" name, not the user-supplied one.
    name: String,

    // Handle from `StartTraceW`, used to control/stop the session.
    session_handle: CONTROLTRACE_HANDLE,

    // Handle from `OpenTraceW`; `u64::MAX` marks it closed/invalid.
    trace_handle: PROCESSTRACE_HANDLE,

    // Receiver for raw events; `None` when the raw stream is disabled.
    event_rx: Option<Receiver<TraceEvent>>,

    // Receiver for decoded events; `None` when the decoded stream is disabled.
    decoded_rx: Option<Receiver<DecodedEvent>>,

    // Events delivered to callers through the `next_batch*` methods.
    events_processed: usize,

    // True between a successful start and `stop`.
    started: bool,

    // Thread blocked inside `ProcessTrace`; unblocked by `CloseTrace` and
    // joined during `stop`.
    process_thread: Option<JoinHandle<()>>,

    // Manual-reset signal observed by the polling helpers.
    stop_signal: Wait,

    // Keeps the callback context alive (and its address stable) while ETW may
    // still invoke the callback.
    _callback_ctx_guard: CallbackContextGuard,
}
242
impl EventTrace {
    /// Creates a builder for a session with the given name.
    ///
    /// Defaults: 64 KB buffers, 2–20 buffers, 1 s flush interval, channel
    /// capacity 10 000, raw stream only, all optional enrichments off.
    pub fn builder(name: impl Into<String>) -> EventTraceBuilder {
        EventTraceBuilder {
            name: name.into(),
            system_providers: Vec::new(),
            user_providers: Vec::new(),
            buffer_size: 64,
            min_buffers: 2,
            max_buffers: 20,
            flush_interval: 1,
            channel_capacity: 10_000,
            stream_mode: EventStreamMode::Raw,
            stack_traces: false,
            thread_context: false,
            detailed_events: false,
            cpu_samples: false,
            process_filter: Vec::new(),
        }
    }

    /// The session name as registered with ETW.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Total events delivered to callers so far (post-filter).
    pub fn events_processed(&self) -> usize {
        self.events_processed
    }

    /// A clonable handle that can signal the polling loops to stop.
    pub fn stop_handle(&self) -> Wait {
        self.stop_signal.clone()
    }

    /// Drains all currently queued raw events into `out_events`.
    ///
    /// # Errors
    /// Fails if the raw stream is disabled for this session.
    pub fn next_batch(&mut self, out_events: &mut Vec<TraceEvent>) -> Result<usize> {
        self.next_batch_with_filter(out_events, |_| true)
    }

    /// Like [`next_batch`](Self::next_batch), but returns `Ok(0)` immediately
    /// (clearing `out_events`) once the stop signal is set.
    pub fn next_batch_or_stopped(&mut self, out_events: &mut Vec<TraceEvent>) -> Result<usize> {
        if self.stop_signal.is_signaled()? {
            out_events.clear();
            return Ok(0);
        }
        self.next_batch(out_events)
    }

    /// Polls [`next_batch`](Self::next_batch) every `poll_interval` until the
    /// stop signal is set.
    ///
    /// NOTE(review): each iteration overwrites `out_events` (next_batch
    /// clears it), so only the final batch before stopping is observable to
    /// the caller — confirm this is the intended contract.
    pub fn run_until_stopped(
        &mut self,
        out_events: &mut Vec<TraceEvent>,
        poll_interval: Duration,
    ) -> Result<()> {
        loop {
            if self.stop_signal.is_signaled()? {
                out_events.clear();
                return Ok(());
            }
            let _ = self.next_batch(out_events)?;
            std::thread::sleep(poll_interval);
        }
    }

    /// Drains queued raw events that pass `filter` into `out_events`,
    /// returning how many were kept. Non-matching events are discarded and
    /// not counted.
    ///
    /// # Errors
    /// Fails if the raw stream is disabled for this session.
    pub fn next_batch_with_filter<F>(
        &mut self,
        out_events: &mut Vec<TraceEvent>,
        filter: F,
    ) -> Result<usize>
    where
        F: Fn(&TraceEvent) -> bool,
    {
        let rx = self.event_rx.as_ref().ok_or_else(|| {
            Error::Etw(EtwError::ConsumeFailed(EtwConsumeError::new(
                Cow::Borrowed("Raw event stream is disabled for this session"),
            )))
        })?;

        out_events.clear();
        // Non-blocking drain: stop as soon as the channel is empty.
        while let Ok(event) = rx.try_recv() {
            if filter(&event) {
                out_events.push(event);
                self.events_processed += 1;
            }
        }
        Ok(out_events.len())
    }

    /// Drains all currently queued decoded events into `out_events`.
    ///
    /// # Errors
    /// Fails if the decoded stream is disabled for this session.
    pub fn next_batch_decoded(&mut self, out_events: &mut Vec<DecodedEvent>) -> Result<usize> {
        self.next_batch_decoded_with_filter(out_events, |_| true)
    }

    /// Decoded-stream counterpart of
    /// [`next_batch_with_filter`](Self::next_batch_with_filter).
    pub fn next_batch_decoded_with_filter<F>(
        &mut self,
        out_events: &mut Vec<DecodedEvent>,
        filter: F,
    ) -> Result<usize>
    where
        F: Fn(&DecodedEvent) -> bool,
    {
        let rx = self.decoded_rx.as_ref().ok_or_else(|| {
            Error::Etw(EtwError::ConsumeFailed(EtwConsumeError::new(
                Cow::Borrowed("Decoded event stream is disabled for this session"),
            )))
        })?;

        out_events.clear();
        while let Ok(event) = rx.try_recv() {
            if filter(&event) {
                out_events.push(event);
                self.events_processed += 1;
            }
        }
        Ok(out_events.len())
    }

    /// Stops the session: signals stop, asks ETW to stop the logger, closes
    /// the processing handle (which unblocks `ProcessTrace`), then joins the
    /// consumer thread. Idempotent — a second call is a no-op.
    pub fn stop(&mut self) -> Result<()> {
        if !self.started {
            return Ok(());
        }

        // Let any polling loops observe the stop first.
        let _ = self.stop_signal.set();

        let name_wide = to_utf16_nul(&self.name);

        // ControlTraceW requires a properties buffer large enough for the
        // struct plus the logger name area.
        let mut properties_buffer =
            vec![0u8; std::mem::size_of::<EVENT_TRACE_PROPERTIES>() + (MAX_SESSION_NAME_LEN * 2)];

        // SAFETY: buffer is sized for EVENT_TRACE_PROPERTIES + name area and
        // zero-initialized; offsets are set before the FFI call.
        unsafe {
            let properties = &mut *(properties_buffer.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES);
            properties.Wnode.BufferSize = properties_buffer.len() as u32;
            properties.LoggerNameOffset = std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;

            // Best-effort: the session may already be gone.
            let _ = ControlTraceW(
                self.session_handle,
                PWSTR(name_wide.as_ptr() as *mut u16),
                properties,
                EVENT_TRACE_CONTROL_STOP,
            );
        }

        // Closing the trace handle makes ProcessTrace return, so the consumer
        // thread can be joined below.
        if self.trace_handle.Value != u64::MAX {
            // SAFETY: handle was produced by OpenTraceW and is closed once.
            unsafe {
                let _ = CloseTrace(self.trace_handle);
            }
            self.trace_handle = PROCESSTRACE_HANDLE { Value: u64::MAX };
        }

        if let Some(handle) = self.process_thread.take() {
            let _ = handle.join();
        }

        self.started = false;
        Ok(())
    }
}
444
impl Drop for EventTrace {
    // Best-effort teardown; `stop` errors are swallowed because drop cannot
    // propagate them.
    fn drop(&mut self) {
        let _ = self.stop();
    }
}
450
// SAFETY: NOTE(review): this asserts EventTrace may move between threads.
// The raw session/trace handles are opaque ids only used through Win32 calls,
// and all other fields are owned — but confirm no thread-affine state hides
// behind the handles before relying on this.
unsafe impl Send for EventTrace {}
453
/// Configuration builder for [`EventTrace`]; construct via
/// [`EventTrace::builder`] and finish with [`start`](EventTraceBuilder::start).
pub struct EventTraceBuilder {
    // Requested session name (replaced by the kernel logger name for kernel
    // sessions at start time).
    name: String,
    // Kernel providers; mutually exclusive with `user_providers`.
    system_providers: Vec<SystemProvider>,
    // User-mode provider GUIDs; mutually exclusive with `system_providers`.
    user_providers: Vec<GUID>,
    // Per-buffer size in KB.
    buffer_size: u32,
    min_buffers: u32,
    max_buffers: u32,
    // Flush interval in seconds.
    flush_interval: u32,
    // Bounded channel capacity; events are dropped when a channel is full.
    channel_capacity: usize,
    stream_mode: EventStreamMode,

    // Optional enrichments attached to raw events.
    stack_traces: bool,
    thread_context: bool,
    // Forces schema-cache field parsing even when only the raw stream is on.
    detailed_events: bool,
    cpu_samples: bool,
    // Empty means "no filtering".
    process_filter: Vec<ProcessId>,
}
476
impl EventTraceBuilder {
    /// Adds a kernel (system) provider; cannot be mixed with user providers.
    pub fn system_provider(mut self, provider: SystemProvider) -> Self {
        self.system_providers.push(provider);
        self
    }

    /// Adds a user-mode provider GUID; cannot be mixed with system providers.
    pub fn user_provider(mut self, provider_guid: GUID) -> Self {
        self.user_providers.push(provider_guid);
        self
    }

    /// Sets the per-buffer size in kilobytes.
    pub fn buffer_size(mut self, size_kb: u32) -> Self {
        self.buffer_size = size_kb;
        self
    }

    /// Minimum number of trace buffers; must not exceed `max_buffers`.
    pub fn min_buffers(mut self, count: u32) -> Self {
        self.min_buffers = count;
        self
    }

    /// Maximum number of trace buffers.
    pub fn max_buffers(mut self, count: u32) -> Self {
        self.max_buffers = count;
        self
    }

    /// Buffer flush interval in seconds.
    pub fn flush_interval(mut self, seconds: u32) -> Self {
        self.flush_interval = seconds;
        self
    }

    /// Capacity of each bounded event channel; when full, further events are
    /// dropped by the callback rather than blocking the trace thread.
    pub fn channel_capacity(mut self, capacity: usize) -> Self {
        self.channel_capacity = capacity;
        self
    }

    /// Deliver decoded events only (disables the raw stream).
    pub fn with_decoded_stream(mut self) -> Self {
        self.stream_mode = EventStreamMode::Decoded;
        self
    }

    /// Deliver both raw and decoded events.
    pub fn with_both_streams(mut self) -> Self {
        self.stream_mode = EventStreamMode::Both;
        self
    }

    /// Attach stack traces (from extended data) to raw events.
    pub fn with_stack_traces(mut self) -> Self {
        self.stack_traces = true;
        self
    }

    /// Attach process/thread context to raw events.
    pub fn with_thread_context(mut self) -> Self {
        self.thread_context = true;
        self
    }

    /// Enable schema-based field parsing even for the raw-only stream.
    pub fn with_detailed_events(mut self) -> Self {
        self.detailed_events = true;
        self
    }

    /// Attach the logging processor number to raw events.
    pub fn with_cpu_samples(mut self) -> Self {
        self.cpu_samples = true;
        self
    }

    /// Restrict delivery to the given process ids; an empty iterator leaves
    /// filtering disabled.
    pub fn with_process_filter<I, P>(mut self, pids: I) -> Self
    where
        I: IntoIterator<Item = P>,
        P: Into<ProcessId>,
    {
        self.process_filter = pids.into_iter().map(Into::into).collect();
        self
    }

    /// Validates the configuration, starts the ETW session, enables any
    /// user-mode providers, opens the consumer handle, and spawns the
    /// `ProcessTrace` thread.
    ///
    /// # Errors
    /// Returns `SessionStartFailed` for invalid configuration or failed
    /// session creation, and `ProviderEnableFailed` / `ConsumeFailed` for
    /// later setup steps (the session is stopped before those are returned).
    pub fn start(self) -> Result<EventTrace> {
        // ---- configuration validation --------------------------------
        if self.name.is_empty() {
            return Err(Error::Etw(EtwError::SessionStartFailed(
                EtwSessionError::new(
                    Cow::Borrowed(""),
                    Cow::Borrowed("Session name cannot be empty"),
                ),
            )));
        }

        if self.name.len() > MAX_SESSION_NAME_LEN {
            return Err(Error::Etw(EtwError::SessionStartFailed(
                EtwSessionError::new(
                    Cow::Owned(self.name.clone()),
                    Cow::Borrowed("Session name exceeds 1024 characters"),
                ),
            )));
        }

        if self.system_providers.is_empty() && self.user_providers.is_empty() {
            return Err(Error::Etw(EtwError::SessionStartFailed(
                EtwSessionError::new(
                    Cow::Owned(self.name.clone()),
                    Cow::Borrowed(
                        "At least one system provider or user provider GUID must be specified",
                    ),
                ),
            )));
        }

        if !self.system_providers.is_empty() && !self.user_providers.is_empty() {
            return Err(Error::Etw(EtwError::SessionStartFailed(
                EtwSessionError::invalid_config(
                    Cow::Owned(self.name.clone()),
                    "providers",
                    Cow::Borrowed(
                        "Cannot mix kernel system providers with user-mode provider GUIDs in one session",
                    ),
                ),
            )));
        }

        if self.min_buffers > self.max_buffers {
            return Err(Error::Etw(EtwError::SessionStartFailed(
                EtwSessionError::new(
                    Cow::Owned(self.name.clone()),
                    Cow::Owned(format!(
                        "min_buffers ({}) cannot exceed max_buffers ({})",
                        self.min_buffers, self.max_buffers
                    )),
                ),
            )));
        }

        // ---- session properties --------------------------------------
        let is_kernel_session = !self.system_providers.is_empty();

        // Kernel sessions must use the fixed well-known logger name.
        let session_name = if is_kernel_session {
            KERNEL_LOGGER_NAME.to_string()
        } else {
            self.name.clone()
        };
        let name_wide: Vec<u16> = session_name
            .encode_utf16()
            .chain(std::iter::once(0))
            .collect();

        // Properties struct plus the trailing logger-name area ETW writes into.
        let properties_size =
            std::mem::size_of::<EVENT_TRACE_PROPERTIES>() + (MAX_SESSION_NAME_LEN * 2);
        let mut properties_buffer = vec![0u8; properties_size];

        // SAFETY: buffer is zeroed and sized for the struct + name area.
        let properties =
            unsafe { &mut *(properties_buffer.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES) };

        // Kernel sessions select event categories via EnableFlags.
        let enable_flags: u32 = if is_kernel_session {
            self.system_providers
                .iter()
                .fold(0u32, |acc, p| acc | p.trace_flags())
        } else {
            0
        };

        properties.Wnode.BufferSize = properties_buffer.len() as u32;
        properties.Wnode.Flags = WNODE_FLAG_TRACED_GUID;
        // NOTE(review): ClientContext = 1 presumably selects QPC timestamps —
        // confirm against the EVENT_TRACE_PROPERTIES documentation.
        properties.Wnode.ClientContext = 1; properties.Wnode.Guid = GUID::zeroed();
        properties.BufferSize = self.buffer_size;
        properties.MinimumBuffers = self.min_buffers;
        properties.MaximumBuffers = self.max_buffers;
        properties.FlushTimer = self.flush_interval;
        properties.LogFileMode = EVENT_TRACE_REAL_TIME_MODE;
        properties.EnableFlags = EVENT_TRACE_FLAG(enable_flags);
        properties.LoggerNameOffset = std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;

        // ---- start the session ---------------------------------------
        let mut session_handle = CONTROLTRACE_HANDLE::default();

        // SAFETY: properties points into a live, correctly sized buffer.
        let start_result = unsafe {
            StartTraceW(
                &mut session_handle,
                PWSTR(name_wide.as_ptr() as *mut u16),
                properties,
            )
        };

        // The kernel logger is a singleton; if a stale instance is running,
        // stop it and retry once with fresh properties.
        // NOTE(review): a non-kernel name collision falls through to the
        // generic failure branch below — intended?
        if start_result.0 == ERROR_ALREADY_EXISTS_CODE && is_kernel_session {
            let stop_buf_size =
                std::mem::size_of::<EVENT_TRACE_PROPERTIES>() + (MAX_SESSION_NAME_LEN * 2);
            let mut stop_buf = vec![0u8; stop_buf_size];
            // SAFETY: zeroed buffer sized for the struct + name area.
            unsafe {
                let stop_props = &mut *(stop_buf.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES);
                stop_props.Wnode.BufferSize = stop_buf.len() as u32;
                stop_props.LoggerNameOffset = std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;
                let _ = ControlTraceW(
                    CONTROLTRACE_HANDLE::default(),
                    PWSTR(name_wide.as_ptr() as *mut u16),
                    stop_props,
                    EVENT_TRACE_CONTROL_STOP,
                );
            }

            // Rebuild properties: ETW may have written into the first buffer.
            let mut retry_buf = vec![0u8; properties_size];
            // SAFETY: same invariants as the first StartTraceW call.
            let retry_result = unsafe {
                let props = &mut *(retry_buf.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES);
                props.Wnode.BufferSize = retry_buf.len() as u32;
                props.Wnode.Flags = WNODE_FLAG_TRACED_GUID;
                props.Wnode.ClientContext = 1;
                props.Wnode.Guid = GUID::zeroed();
                props.BufferSize = self.buffer_size;
                props.MinimumBuffers = self.min_buffers;
                props.MaximumBuffers = self.max_buffers;
                props.FlushTimer = self.flush_interval;
                props.LogFileMode = EVENT_TRACE_REAL_TIME_MODE;
                props.EnableFlags = EVENT_TRACE_FLAG(enable_flags);
                props.LoggerNameOffset = std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;
                StartTraceW(
                    &mut session_handle,
                    PWSTR(name_wide.as_ptr() as *mut u16),
                    props,
                )
            };

            if retry_result != ERROR_SUCCESS {
                return Err(Error::Etw(EtwError::SessionStartFailed(
                    EtwSessionError::with_code(
                        Cow::Owned(session_name),
                        Cow::Borrowed("Failed to start trace after stopping stale session"),
                        retry_result.0 as i32,
                    ),
                )));
            }
        } else if start_result != ERROR_SUCCESS {
            return Err(Error::Etw(EtwError::SessionStartFailed(
                EtwSessionError::with_code(
                    Cow::Owned(session_name),
                    Cow::Borrowed("Failed to start trace session"),
                    start_result.0 as i32,
                ),
            )));
        }

        // ---- enable user-mode providers ------------------------------
        if !is_kernel_session {
            for provider_guid in &self.user_providers {
                // SAFETY: session_handle is valid (StartTraceW succeeded) and
                // provider_guid points to a live GUID.
                let enable_result = unsafe {
                    EnableTraceEx2(
                        session_handle,
                        provider_guid as *const GUID,
                        EVENT_CONTROL_CODE_ENABLE_PROVIDER.0,
                        TRACE_LEVEL_VERBOSE as u8,
                        u64::MAX,
                        0,
                        0,
                        None,
                    )
                };

                if enable_result != ERROR_SUCCESS {
                    // Roll back: stop the session we just started before
                    // returning the error.
                    let mut stop_buf = vec![
                        0u8;
                        std::mem::size_of::<EVENT_TRACE_PROPERTIES>()
                            + (MAX_SESSION_NAME_LEN * 2)
                    ];
                    // SAFETY: zeroed buffer sized for the struct + name area.
                    unsafe {
                        let stop_props =
                            &mut *(stop_buf.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES);
                        stop_props.Wnode.BufferSize = stop_buf.len() as u32;
                        stop_props.LoggerNameOffset =
                            std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;
                        let _ = ControlTraceW(
                            session_handle,
                            PWSTR(name_wide.as_ptr() as *mut u16),
                            stop_props,
                            EVENT_TRACE_CONTROL_STOP,
                        );
                    }

                    return Err(Error::Etw(EtwError::ProviderEnableFailed(
                        EtwProviderError::with_code(
                            Cow::Owned(format!("{provider_guid:?}")),
                            Cow::Borrowed("Failed to enable user-mode ETW provider"),
                            enable_result.0 as i32,
                        ),
                    )));
                }
            }
        }

        // ---- consumer channels & callback context --------------------
        let (raw_tx, event_rx) = match self.stream_mode {
            EventStreamMode::Raw | EventStreamMode::Both => {
                let (tx, rx) = mpsc::sync_channel(self.channel_capacity);
                (Some(tx), Some(rx))
            }
            EventStreamMode::Decoded => (None, None),
        };

        let (decoded_tx, decoded_rx) = match self.stream_mode {
            EventStreamMode::Decoded | EventStreamMode::Both => {
                let (tx, rx) = mpsc::sync_channel(self.channel_capacity);
                (Some(tx), Some(rx))
            }
            EventStreamMode::Raw => (None, None),
        };

        // Field parsing is needed for decoded output or when explicitly
        // requested via with_detailed_events.
        let schema_cache = if self.detailed_events || decoded_tx.is_some() {
            Some(Mutex::new(SchemaCache::new()))
        } else {
            None
        };

        let callback_ctx_guard = CallbackContextGuard::new(CallbackContext {
            raw_sender: raw_tx,
            decoded_sender: decoded_tx,
            schema_cache,
            process_filter: normalize_process_filter(self.process_filter),
            include_thread_context: self.thread_context,
            include_stack_traces: self.stack_traces,
            include_cpu_samples: self.cpu_samples,
        });
        let ctx_ptr = callback_ctx_guard.as_user_context_ptr();

        // ---- open the consumer & spawn the processing thread ---------
        let mut log_file = EVENT_TRACE_LOGFILEW {
            LoggerName: PWSTR(name_wide.as_ptr() as *mut u16),
            Anonymous1: EVENT_TRACE_LOGFILEW_0 {
                ProcessTraceMode: PROCESS_TRACE_MODE_EVENT_RECORD | PROCESS_TRACE_MODE_REAL_TIME,
            },
            Anonymous2: EVENT_TRACE_LOGFILEW_1 {
                EventRecordCallback: Some(trace_callback_fn),
            },
            Context: ctx_ptr,
            ..Default::default()
        };

        // SAFETY: log_file is fully initialized above; name_wide outlives the call.
        let trace_handle = unsafe { OpenTraceW(&mut log_file) };
        // OpenTraceW signals failure with an invalid (all-ones) handle.
        if trace_handle.Value == u64::MAX {
            // Roll back the session before reporting the failure.
            let mut stop_buf = vec![
                0u8;
                std::mem::size_of::<EVENT_TRACE_PROPERTIES>()
                    + (MAX_SESSION_NAME_LEN * 2)
            ];
            // SAFETY: zeroed buffer sized for the struct + name area.
            unsafe {
                let stop_props = &mut *(stop_buf.as_mut_ptr() as *mut EVENT_TRACE_PROPERTIES);
                stop_props.Wnode.BufferSize = stop_buf.len() as u32;
                stop_props.LoggerNameOffset = std::mem::size_of::<EVENT_TRACE_PROPERTIES>() as u32;
                let _ = ControlTraceW(
                    session_handle,
                    PWSTR(name_wide.as_ptr() as *mut u16),
                    stop_props,
                    EVENT_TRACE_CONTROL_STOP,
                );
            }
            return Err(Error::Etw(EtwError::ConsumeFailed(EtwConsumeError::new(
                Cow::Borrowed("OpenTraceW failed"),
            ))));
        }

        // ProcessTrace blocks until CloseTrace is called in stop(); the start
        // time skips events logged before "now".
        let process_trace_handle = trace_handle;
        let process_thread = std::thread::spawn(move || unsafe {
            let handles = [process_trace_handle];
            let now = GetSystemTimeAsFileTime();
            let _ = ProcessTrace(&handles, Some(&now as *const _), None);
        });

        // NOTE(review): if Wait::manual_reset fails here, the started session
        // and open trace handle are leaked — consider creating the Wait before
        // any FFI setup.
        Ok(EventTrace {
            name: session_name,
            session_handle,
            trace_handle,
            event_rx,
            decoded_rx,
            events_processed: 0,
            started: true,
            process_thread: Some(process_thread),
            stop_signal: Wait::manual_reset(false)?,
            _callback_ctx_guard: callback_ctx_guard,
        })
    }
}
945
#[cfg(test)]
mod tests {
    //! Unit tests covering builder validation, filter normalization,
    //! extended-data parsing, and channel-draining — none of these start a
    //! real ETW session.
    use super::*;

    // Builds a minimal TraceEvent without touching ETW.
    fn make_trace_event(id: u16, process_id: u32) -> TraceEvent {
        // FILETIME value of the Unix epoch (1970-01-01).
        const FILETIME_UNIX_EPOCH: i64 = 116_444_736_000_000_000;

        let mut record = EVENT_RECORD::default();
        record.EventHeader.EventDescriptor.Id = id;
        record.EventHeader.ProviderId = GUID::zeroed();
        record.EventHeader.ProcessId = process_id;
        record.EventHeader.ThreadId = 1;
        record.EventHeader.TimeStamp = FILETIME_UNIX_EPOCH;
        record.UserDataLength = 0;
        record.UserData = std::ptr::null_mut();
        TraceEvent::from_event_record_with_fields(&record, None)
    }

    // Builds an EventTrace that never started a session (started: false, so
    // drop/stop is a no-op) with the given receivers wired in.
    fn inert_trace(
        event_rx: Option<Receiver<TraceEvent>>,
        decoded_rx: Option<Receiver<DecodedEvent>>,
    ) -> EventTrace {
        EventTrace {
            name: "TestTrace".to_string(),
            session_handle: CONTROLTRACE_HANDLE::default(),
            trace_handle: PROCESSTRACE_HANDLE { Value: u64::MAX },
            event_rx,
            decoded_rx,
            events_processed: 0,
            started: false,
            process_thread: None,
            stop_signal: Wait::manual_reset(false).expect("wait handle create"),
            _callback_ctx_guard: CallbackContextGuard::new(CallbackContext {
                raw_sender: None,
                decoded_sender: None,
                schema_cache: None,
                process_filter: None,
                include_thread_context: false,
                include_stack_traces: false,
                include_cpu_samples: false,
            }),
        }
    }

    #[test]
    fn test_builder_requires_provider() {
        let result = EventTrace::builder("TestSession").start();
        assert!(result.is_err());
    }

    #[test]
    fn test_start_fails_when_mixing_kernel_and_user_providers() {
        let result = EventTrace::builder("TestSession")
            .system_provider(SystemProvider::Process)
            .user_provider(GUID::zeroed())
            .start();

        match result {
            Err(Error::Etw(EtwError::SessionStartFailed(e))) => {
                assert!(e.reason.contains("Cannot mix kernel system providers"));
            }
            _ => panic!("expected SessionStartFailed"),
        }
    }

    #[test]
    fn test_empty_name_fails() {
        let result = EventTrace::builder("").start();
        assert!(result.is_err());
    }

    #[test]
    fn test_name_too_long_fails() {
        let long_name = "x".repeat(MAX_SESSION_NAME_LEN + 1);
        let result = EventTrace::builder(long_name).start();

        match result {
            Err(Error::Etw(EtwError::SessionStartFailed(e))) => {
                assert!(e.reason.contains("exceeds 1024"));
            }
            _ => panic!("expected SessionStartFailed"),
        }
    }

    // A name of exactly the limit should fail on the *provider* check, not
    // the length check — proving the length validation is inclusive.
    #[test]
    fn test_max_name_length_passes_length_validation() {
        let max_name = "x".repeat(MAX_SESSION_NAME_LEN);
        let result = EventTrace::builder(max_name).start();

        match result {
            Err(Error::Etw(EtwError::SessionStartFailed(e))) => {
                assert!(
                    e.reason
                        .contains("At least one system provider or user provider GUID")
                );
            }
            _ => panic!("expected SessionStartFailed"),
        }
    }

    #[test]
    fn test_buffer_constraint_fails() {
        let result = EventTrace::builder("Test")
            .system_provider(SystemProvider::Process)
            .min_buffers(10)
            .max_buffers(5) .start();
        assert!(result.is_err());
    }

    #[test]
    fn test_normalize_process_filter_empty_is_none() {
        let filter = normalize_process_filter(Vec::new());
        assert!(filter.is_none());
    }

    #[test]
    fn test_normalize_process_filter_deduplicates() {
        let filter = normalize_process_filter(vec![
            ProcessId::new(100),
            ProcessId::new(200),
            ProcessId::new(100),
        ])
        .expect("expected filter set");
        assert_eq!(filter.len(), 2);
        assert!(filter.contains(&ProcessId::new(100)));
        assert!(filter.contains(&ProcessId::new(200)));
    }

    #[test]
    fn test_extract_stack_trace_none_without_extended_data() {
        let mut record: EVENT_RECORD = unsafe { std::mem::zeroed() };
        record.ExtendedDataCount = 0;
        record.ExtendedData = std::ptr::null_mut();

        assert!(extract_stack_trace(&record).is_none());
    }

    #[test]
    fn test_extract_stack_trace_64bit_payload() {
        // Payload layout: u64 match id followed by u64 frame addresses.
        let mut payload = Vec::new();
        payload.extend_from_slice(&0x1122_3344_5566_7788u64.to_le_bytes());
        payload.extend_from_slice(&0x0000_0000_0000_1111u64.to_le_bytes());
        payload.extend_from_slice(&0x0000_0000_0000_2222u64.to_le_bytes());

        let mut ext: EVENT_HEADER_EXTENDED_DATA_ITEM = unsafe { std::mem::zeroed() };
        ext.ExtType = EVENT_HEADER_EXT_TYPE_STACK_TRACE64 as u16;
        ext.DataSize = payload.len() as u16;
        ext.DataPtr = payload.as_ptr() as u64;

        let mut record: EVENT_RECORD = unsafe { std::mem::zeroed() };
        record.ExtendedDataCount = 1;
        record.ExtendedData = &mut ext;

        let parsed = extract_stack_trace(&record).expect("stack should parse");
        assert_eq!(parsed.match_id, 0x1122_3344_5566_7788u64);
        assert_eq!(parsed.frames, vec![0x1111, 0x2222]);
    }

    #[test]
    fn test_extract_cpu_sample_reads_processor_number() {
        let mut record: EVENT_RECORD = unsafe { std::mem::zeroed() };
        // Write the raw byte the extractor reads back.
        unsafe {
            *(std::ptr::addr_of_mut!(record.BufferContext) as *mut u8) = 13;
        }

        let sample = extract_cpu_sample(&record);
        assert_eq!(sample.processor_number, 13);
    }

    #[test]
    fn test_next_batch_fails_when_raw_stream_disabled() {
        let mut trace = inert_trace(None, None);
        let mut out = Vec::new();

        let result = trace.next_batch(&mut out);
        match result {
            Err(Error::Etw(EtwError::ConsumeFailed(e))) => {
                assert!(e.reason.contains("Raw event stream is disabled"));
            }
            _ => panic!("expected ConsumeFailed"),
        }
    }

    #[test]
    fn test_next_batch_decoded_fails_when_decoded_stream_disabled() {
        let mut trace = inert_trace(None, None);
        let mut out = Vec::new();

        let result = trace.next_batch_decoded(&mut out);
        match result {
            Err(Error::Etw(EtwError::ConsumeFailed(e))) => {
                assert!(e.reason.contains("Decoded event stream is disabled"));
            }
            _ => panic!("expected ConsumeFailed"),
        }
    }

    #[test]
    fn test_next_batch_drains_raw_stream_and_updates_counter() {
        let (tx, rx) = mpsc::sync_channel(8);
        tx.send(make_trace_event(1, 100)).expect("send event 1");
        tx.send(make_trace_event(2, 200)).expect("send event 2");
        drop(tx);

        let mut trace = inert_trace(Some(rx), None);
        let mut out = Vec::new();

        let count = trace
            .next_batch(&mut out)
            .expect("next_batch should succeed");
        assert_eq!(count, 2);
        assert_eq!(out.len(), 2);
        assert_eq!(trace.events_processed(), 2);
    }

    #[test]
    fn test_next_batch_with_filter_filters_during_drain() {
        let (tx, rx) = mpsc::sync_channel(8);
        tx.send(make_trace_event(1, 111)).expect("send event 1");
        tx.send(make_trace_event(2, 222)).expect("send event 2");
        tx.send(make_trace_event(3, 333)).expect("send event 3");
        drop(tx);

        let mut trace = inert_trace(Some(rx), None);
        let mut out = Vec::new();

        let count = trace
            .next_batch_with_filter(&mut out, |e| e.process_id != 222)
            .expect("next_batch_with_filter should succeed");

        // Filtered-out events are discarded and not counted.
        assert_eq!(count, 2);
        assert_eq!(out.len(), 2);
        assert!(out.iter().all(|e| e.process_id != 222));
        assert_eq!(trace.events_processed(), 2);
    }

    #[test]
    fn test_next_batch_decoded_drains_stream_and_updates_counter() {
        let (tx, rx) = mpsc::sync_channel(8);
        tx.send(DecodedEvent::Unknown)
            .expect("send decoded event 1");
        tx.send(DecodedEvent::Unknown)
            .expect("send decoded event 2");
        drop(tx);

        let mut trace = inert_trace(None, Some(rx));
        let mut out = Vec::new();

        let count = trace
            .next_batch_decoded(&mut out)
            .expect("next_batch_decoded should succeed");

        assert_eq!(count, 2);
        assert_eq!(out.len(), 2);
        assert_eq!(trace.events_processed(), 2);
    }
}