1use std::{
2 ffi::{CStr, CString},
3 marker::PhantomData,
4};
5
6use num_enum::{IntoPrimitive, TryFromPrimitive};
7use singe_core::{impl_enum_conversion, impl_enum_display};
8use singe_cuda_sys::nvtx as sys;
9
10use crate::error::{Error, Result};
11
/// An NVTX version number; only the major component is exposed.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Version {
    /// Major version component.
    pub major: u32,
}
17
/// A packed 32-bit color laid out as `0xAARRGGBB`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Color(u32);

impl Color {
    /// Wraps an already-packed `0xAARRGGBB` value unchanged.
    pub const fn argb(value: u32) -> Self {
        Color(value)
    }

    /// Packs individual channels into `0xAARRGGBB`.
    ///
    /// The arguments are given as red, green, blue, alpha, but the packed layout
    /// stores alpha in the most significant byte.
    pub const fn rgba(red: u8, green: u8, blue: u8, alpha: u8) -> Self {
        // Big-endian byte order: alpha -> bits 24..32, red -> 16..24, green -> 8..16, blue -> 0..8.
        Color(u32::from_be_bytes([alpha, red, green, blue]))
    }

    /// Returns the packed `0xAARRGGBB` value.
    pub const fn as_raw(self) -> u32 {
        let Color(packed) = self;
        packed
    }
}
34
/// A numeric NVTX category id that can be attached to events.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Category(u32);

impl Category {
    /// Wraps a raw category id.
    pub const fn from_raw(value: u32) -> Self {
        Category(value)
    }

    /// Returns the raw category id.
    pub const fn as_raw(self) -> u32 {
        let Category(id) = self;
        id
    }
}
47
/// Safe mirror of `sys::nvtxColorType_t`: how an event's color field is interpreted.
///
/// Discriminants are taken directly from the FFI enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum ColorType {
    /// Maps to `NVTX_COLOR_UNKNOWN`.
    Unknown = sys::nvtxColorType_t::NVTX_COLOR_UNKNOWN as _,
    /// Maps to `NVTX_COLOR_ARGB`; the color field holds a packed `0xAARRGGBB` value.
    Argb = sys::nvtxColorType_t::NVTX_COLOR_ARGB as _,
}

// Generates conversions between the FFI enum and `ColorType` (`From` is used in both
// directions in this module).
impl_enum_conversion!(sys::nvtxColorType_t, ColorType);

// `Display` renders the C constant name.
impl_enum_display!(ColorType, {
    Self::Unknown => "NVTX_COLOR_UNKNOWN",
    Self::Argb => "NVTX_COLOR_ARGB",
});
62
/// Safe mirror of `sys::nvtxMessageType_t`: how an event's message field is interpreted.
///
/// Discriminants are taken directly from the FFI enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum MessageType {
    /// Maps to `NVTX_MESSAGE_UNKNOWN`.
    Unknown = sys::nvtxMessageType_t::NVTX_MESSAGE_UNKNOWN as _,
    /// Maps to `NVTX_MESSAGE_TYPE_ASCII`.
    Ascii = sys::nvtxMessageType_t::NVTX_MESSAGE_TYPE_ASCII as _,
    /// Maps to `NVTX_MESSAGE_TYPE_UNICODE`.
    Unicode = sys::nvtxMessageType_t::NVTX_MESSAGE_TYPE_UNICODE as _,
    /// Maps to `NVTX_MESSAGE_TYPE_REGISTERED`.
    Registered = sys::nvtxMessageType_t::NVTX_MESSAGE_TYPE_REGISTERED as _,
}

// Generates conversions between the FFI enum and `MessageType`.
impl_enum_conversion!(sys::nvtxMessageType_t, MessageType);

// `Display` renders the C constant name.
impl_enum_display!(MessageType, {
    Self::Unknown => "NVTX_MESSAGE_UNKNOWN",
    Self::Ascii => "NVTX_MESSAGE_TYPE_ASCII",
    Self::Unicode => "NVTX_MESSAGE_TYPE_UNICODE",
    Self::Registered => "NVTX_MESSAGE_TYPE_REGISTERED",
});
81
/// Safe mirror of `sys::nvtxPayloadType_t`: which field of the payload union is valid.
///
/// Discriminants are taken directly from the FFI enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum PayloadType {
    /// Maps to `NVTX_PAYLOAD_UNKNOWN` (no payload).
    Unknown = sys::nvtxPayloadType_t::NVTX_PAYLOAD_UNKNOWN as _,
    /// Maps to `NVTX_PAYLOAD_TYPE_UNSIGNED_INT64`; used for `Payload::U64`.
    UnsignedInt64 = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_UNSIGNED_INT64 as _,
    /// Maps to `NVTX_PAYLOAD_TYPE_INT64`; used for `Payload::I64`.
    Int64 = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_INT64 as _,
    /// Maps to `NVTX_PAYLOAD_TYPE_DOUBLE`; used for `Payload::F64`.
    Double = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_DOUBLE as _,
    /// Maps to `NVTX_PAYLOAD_TYPE_UNSIGNED_INT32`; used for `Payload::U32`.
    UnsignedInt32 = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_UNSIGNED_INT32 as _,
    /// Maps to `NVTX_PAYLOAD_TYPE_INT32`; used for `Payload::I32`.
    Int32 = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_INT32 as _,
    /// Maps to `NVTX_PAYLOAD_TYPE_FLOAT`; used for `Payload::F32`.
    Float = sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_FLOAT as _,
}

// Generates conversions between the FFI enum and `PayloadType`.
impl_enum_conversion!(sys::nvtxPayloadType_t, PayloadType);

// `Display` renders the C constant name.
impl_enum_display!(PayloadType, {
    Self::Unknown => "NVTX_PAYLOAD_UNKNOWN",
    Self::UnsignedInt64 => "NVTX_PAYLOAD_TYPE_UNSIGNED_INT64",
    Self::Int64 => "NVTX_PAYLOAD_TYPE_INT64",
    Self::Double => "NVTX_PAYLOAD_TYPE_DOUBLE",
    Self::UnsignedInt32 => "NVTX_PAYLOAD_TYPE_UNSIGNED_INT32",
    Self::Int32 => "NVTX_PAYLOAD_TYPE_INT32",
    Self::Float => "NVTX_PAYLOAD_TYPE_FLOAT",
});
106
/// Safe mirror of `sys::nvtxResourceGenericType_t` (generic resource kinds for naming).
///
/// Discriminants are taken directly from the FFI enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
#[non_exhaustive]
pub enum ResourceGenericType {
    /// Maps to `NVTX_RESOURCE_TYPE_UNKNOWN`.
    Unknown = sys::nvtxResourceGenericType_t::NVTX_RESOURCE_TYPE_UNKNOWN as _,
    /// Maps to `NVTX_RESOURCE_TYPE_GENERIC_POINTER`.
    GenericPointer = sys::nvtxResourceGenericType_t::NVTX_RESOURCE_TYPE_GENERIC_POINTER as _,
    /// Maps to `NVTX_RESOURCE_TYPE_GENERIC_HANDLE`.
    GenericHandle = sys::nvtxResourceGenericType_t::NVTX_RESOURCE_TYPE_GENERIC_HANDLE as _,
    /// Maps to `NVTX_RESOURCE_TYPE_GENERIC_THREAD_NATIVE`.
    GenericThreadNative =
        sys::nvtxResourceGenericType_t::NVTX_RESOURCE_TYPE_GENERIC_THREAD_NATIVE as _,
    /// Maps to `NVTX_RESOURCE_TYPE_GENERIC_THREAD_POSIX`.
    GenericThreadPosix =
        sys::nvtxResourceGenericType_t::NVTX_RESOURCE_TYPE_GENERIC_THREAD_POSIX as _,
}

// Generates conversions between the FFI enum and `ResourceGenericType`.
impl_enum_conversion!(sys::nvtxResourceGenericType_t, ResourceGenericType);

// `Display` renders the C constant name.
impl_enum_display!(ResourceGenericType, {
    Self::Unknown => "NVTX_RESOURCE_TYPE_UNKNOWN",
    Self::GenericPointer => "NVTX_RESOURCE_TYPE_GENERIC_POINTER",
    Self::GenericHandle => "NVTX_RESOURCE_TYPE_GENERIC_HANDLE",
    Self::GenericThreadNative => "NVTX_RESOURCE_TYPE_GENERIC_THREAD_NATIVE",
    Self::GenericThreadPosix => "NVTX_RESOURCE_TYPE_GENERIC_THREAD_POSIX",
});
129
130#[derive(Debug, Clone, Copy, PartialEq)]
131#[non_exhaustive]
132pub enum Payload {
133 I32(i32),
134 I64(i64),
135 U32(u32),
136 U64(u64),
137 F32(f32),
138 F64(f64),
139}
140
141impl Payload {
142 fn encode_type(self) -> sys::nvtxPayloadType_t {
143 match self {
144 Self::I32(_) => PayloadType::Int32.into(),
145 Self::I64(_) => PayloadType::Int64.into(),
146 Self::U32(_) => PayloadType::UnsignedInt32.into(),
147 Self::U64(_) => PayloadType::UnsignedInt64.into(),
148 Self::F32(_) => PayloadType::Float.into(),
149 Self::F64(_) => PayloadType::Double.into(),
150 }
151 }
152
153 fn encode_value(self) -> sys::nvtxEventAttributes_v2_payload_t {
154 match self {
155 Self::I32(value) => sys::nvtxEventAttributes_v2_payload_t { iValue: value },
156 Self::I64(value) => sys::nvtxEventAttributes_v2_payload_t { llValue: value },
157 Self::U32(value) => sys::nvtxEventAttributes_v2_payload_t { uiValue: value },
158 Self::U64(value) => sys::nvtxEventAttributes_v2_payload_t { ullValue: value },
159 Self::F32(value) => sys::nvtxEventAttributes_v2_payload_t { fValue: value },
160 Self::F64(value) => sys::nvtxEventAttributes_v2_payload_t { dValue: value },
161 }
162 }
163}
164
165#[derive(Debug, Clone, Copy)]
166pub struct EventAttributes<'a> {
167 message: Option<&'a CStr>,
168 category: Option<Category>,
169 color: Option<Color>,
170 payload: Option<Payload>,
171}
172
173impl<'a> EventAttributes<'a> {
174 pub const fn new() -> Self {
175 Self {
176 message: None,
177 category: None,
178 color: None,
179 payload: None,
180 }
181 }
182
183 pub fn with_message(mut self, message: &'a CStr) -> Self {
184 self.message = Some(message);
185 self
186 }
187
188 pub fn with_category(mut self, category: Category) -> Self {
189 self.category = Some(category);
190 self
191 }
192
193 pub fn with_color(mut self, color: Color) -> Self {
194 self.color = Some(color);
195 self
196 }
197
198 pub fn with_payload(mut self, payload: Payload) -> Self {
199 self.payload = Some(payload);
200 self
201 }
202
203 pub const fn message(&self) -> Option<&'a CStr> {
204 self.message
205 }
206
207 pub const fn category(&self) -> Option<Category> {
208 self.category
209 }
210
211 pub const fn color(&self) -> Option<Color> {
212 self.color
213 }
214
215 pub const fn payload(&self) -> Option<Payload> {
216 self.payload
217 }
218
219 fn encode(self) -> sys::nvtxEventAttributes_t {
220 let mut raw = sys::nvtxEventAttributes_t {
221 version: sys::NVTX_VERSION as u16,
222 size: size_of::<sys::nvtxEventAttributes_t>() as u16,
223 ..Default::default()
224 };
225
226 if let Some(category) = self.category {
227 raw.category = category.0;
228 }
229
230 if let Some(color) = self.color {
231 raw.colorType = sys::nvtxColorType_t::from(ColorType::Argb) as i32;
232 raw.color = color.0;
233 }
234
235 if let Some(payload) = self.payload {
236 raw.payloadType = payload.encode_type() as i32;
237 raw.payload = payload.encode_value();
238 }
239
240 if let Some(message) = self.message {
241 raw.messageType = sys::nvtxMessageType_t::from(MessageType::Ascii) as i32;
242 raw.message.ascii = message.as_ptr();
243 }
244
245 raw
246 }
247}
248
249impl Default for EventAttributes<'_> {
250 fn default() -> Self {
251 Self::new()
252 }
253}
254
255#[derive(Debug, Clone)]
256pub struct Event {
257 message: CString,
258 category: Option<Category>,
259 color: Option<Color>,
260 payload: Option<Payload>,
261}
262
263impl Event {
264 pub fn create(message: &str) -> Result<Self> {
265 Ok(Self {
266 message: CString::new(message)?,
267 category: None,
268 color: None,
269 payload: None,
270 })
271 }
272
273 pub fn create_from_c_string(message: CString) -> Self {
274 Self {
275 message,
276 category: None,
277 color: None,
278 payload: None,
279 }
280 }
281
282 pub fn with_category(mut self, category: Category) -> Self {
283 self.category = Some(category);
284 self
285 }
286
287 pub fn with_color(mut self, color: Color) -> Self {
288 self.color = Some(color);
289 self
290 }
291
292 pub fn with_payload(mut self, payload: Payload) -> Self {
293 self.payload = Some(payload);
294 self
295 }
296
297 pub fn mark(&self) {
298 mark_with_attributes(self.attributes());
299 }
300
301 pub fn local_range(&self) -> LocalRange {
302 LocalRange::from_attributes(self.attributes())
303 }
304
305 pub fn range(&self) -> Range {
306 Range::from_attributes(self.attributes())
307 }
308
309 pub fn domain_mark(&self, domain: &Domain) {
310 domain.mark_with_attributes(self.attributes());
311 }
312
313 pub fn domain_local_range<'a>(&self, domain: &'a Domain) -> DomainLocalRange<'a> {
314 domain.range_with_attributes(self.attributes())
315 }
316
317 pub fn domain_range<'a>(&self, domain: &'a Domain) -> DomainRange<'a> {
318 domain.start_range_with_attributes(self.attributes())
319 }
320
321 pub fn attributes(&self) -> EventAttributes<'_> {
322 let mut attributes = EventAttributes::new().with_message(&self.message);
323
324 if let Some(category) = self.category {
325 attributes = attributes.with_category(category);
326 }
327
328 if let Some(color) = self.color {
329 attributes = attributes.with_color(color);
330 }
331
332 if let Some(payload) = self.payload {
333 attributes = attributes.with_payload(payload);
334 }
335
336 attributes
337 }
338}
339
340#[derive(Debug)]
341pub struct Domain {
342 handle: sys::nvtxDomainHandle_t,
343}
344
345unsafe impl Send for Domain {}
348unsafe impl Sync for Domain {}
349
350impl Domain {
351 pub fn create(name: &str) -> Result<Self> {
352 let name = CString::new(name)?;
353 Self::create_from_c_str(&name)
354 }
355
356 pub fn create_from_c_str(name: &CStr) -> Result<Self> {
357 let handle = unsafe { sys::nvtxDomainCreateA(name.as_ptr()) };
358 if handle.is_null() {
359 return Err(Error::NullHandle);
360 }
361 Ok(Self { handle })
362 }
363
364 pub fn as_raw(&self) -> sys::nvtxDomainHandle_t {
365 self.handle
366 }
367
368 pub fn mark(&self, message: &str) -> Result<()> {
369 let message = CString::new(message)?;
370 self.mark_c_str(&message);
371 Ok(())
372 }
373
374 pub fn mark_c_str(&self, message: &CStr) {
375 self.mark_with_attributes(EventAttributes::new().with_message(message));
376 }
377
378 pub fn mark_with_attributes(&self, attributes: EventAttributes<'_>) {
379 let raw = attributes.encode();
380 unsafe { sys::nvtxDomainMarkEx(self.handle, &raw) };
381 }
382
383 pub fn range<'a>(&'a self, message: &str) -> Result<DomainLocalRange<'a>> {
384 let message = CString::new(message)?;
385 Ok(self.range_c_str(&message))
386 }
387
388 pub fn range_c_str<'a>(&'a self, message: &CStr) -> DomainLocalRange<'a> {
389 self.range_with_attributes(EventAttributes::new().with_message(message))
390 }
391
392 pub fn range_with_attributes<'a>(
393 &'a self,
394 attributes: EventAttributes<'_>,
395 ) -> DomainLocalRange<'a> {
396 let raw = attributes.encode();
397 unsafe { sys::nvtxDomainRangePushEx(self.handle, &raw) };
398 DomainLocalRange {
399 domain: self,
400 _not_send: PhantomData,
401 }
402 }
403
404 pub fn start_range(&self, message: &str) -> Result<DomainRange<'_>> {
405 let message = CString::new(message)?;
406 Ok(self.start_range_c_str(&message))
407 }
408
409 pub fn start_range_c_str(&self, message: &CStr) -> DomainRange<'_> {
410 self.start_range_with_attributes(EventAttributes::new().with_message(message))
411 }
412
413 pub fn start_range_with_attributes(&self, attributes: EventAttributes<'_>) -> DomainRange<'_> {
414 let raw = attributes.encode();
415 let id = unsafe { sys::nvtxDomainRangeStartEx(self.handle, &raw) };
416 DomainRange { domain: self, id }
417 }
418
419 pub fn name_category(&self, category: Category, name: &str) -> Result<()> {
420 let name = CString::new(name)?;
421 unsafe { sys::nvtxDomainNameCategoryA(self.handle, category.0, name.as_ptr()) };
422 Ok(())
423 }
424}
425
426impl Drop for Domain {
427 fn drop(&mut self) {
428 unsafe { sys::nvtxDomainDestroy(self.handle) };
429 }
430}
431
432#[derive(Debug)]
433pub struct LocalRange {
434 _not_send: PhantomData<*mut ()>,
435}
436
437impl LocalRange {
438 pub fn create(message: &str) -> Result<Self> {
439 let message = CString::new(message)?;
440 Ok(Self::create_from_c_str(&message))
441 }
442
443 pub fn create_from_c_str(message: &CStr) -> Self {
444 unsafe { sys::nvtxRangePushA(message.as_ptr()) };
445 Self {
446 _not_send: PhantomData,
447 }
448 }
449
450 pub fn from_attributes(attributes: EventAttributes<'_>) -> Self {
451 let raw = attributes.encode();
452 unsafe { sys::nvtxRangePushEx(&raw) };
453 Self {
454 _not_send: PhantomData,
455 }
456 }
457}
458
459impl Drop for LocalRange {
460 fn drop(&mut self) {
461 unsafe { sys::nvtxRangePop() };
462 }
463}
464
465#[derive(Debug)]
466pub struct Range {
467 id: sys::nvtxRangeId_t,
468}
469
470impl Range {
471 pub fn create(message: &str) -> Result<Self> {
472 let message = CString::new(message)?;
473 Ok(Self::create_from_c_str(&message))
474 }
475
476 pub fn create_from_c_str(message: &CStr) -> Self {
477 let id = unsafe { sys::nvtxRangeStartA(message.as_ptr()) };
478 Self { id }
479 }
480
481 pub fn from_attributes(attributes: EventAttributes<'_>) -> Self {
482 let raw = attributes.encode();
483 let id = unsafe { sys::nvtxRangeStartEx(&raw) };
484 Self { id }
485 }
486}
487
488impl Drop for Range {
489 fn drop(&mut self) {
490 unsafe { sys::nvtxRangeEnd(self.id) };
491 }
492}
493
494#[derive(Debug)]
495pub struct DomainLocalRange<'a> {
496 domain: &'a Domain,
497 _not_send: PhantomData<*mut ()>,
498}
499
500impl Drop for DomainLocalRange<'_> {
501 fn drop(&mut self) {
502 unsafe { sys::nvtxDomainRangePop(self.domain.handle) };
503 }
504}
505
506#[derive(Debug)]
507pub struct DomainRange<'a> {
508 domain: &'a Domain,
509 id: sys::nvtxRangeId_t,
510}
511
512impl Drop for DomainRange<'_> {
513 fn drop(&mut self) {
514 unsafe { sys::nvtxDomainRangeEnd(self.domain.handle, self.id) };
515 }
516}
517
518pub fn version() -> Version {
519 Version {
520 major: sys::NVTX_VERSION,
521 }
522}
523
524pub fn initialize() {
525 unsafe { sys::nvtxInitialize(std::ptr::null()) };
526}
527
528pub fn mark(message: &str) -> Result<()> {
529 Event::create(message)?.mark();
530 Ok(())
531}
532
533pub fn mark_c_str(message: &CStr) {
534 unsafe { sys::nvtxMarkA(message.as_ptr()) };
535}
536
537pub fn mark_with_attributes(attributes: EventAttributes<'_>) {
538 let raw = attributes.encode();
539 unsafe { sys::nvtxMarkEx(&raw) };
540}
541
542pub fn name_category(category: Category, name: &str) -> Result<()> {
543 let name = CString::new(name)?;
544 unsafe { sys::nvtxNameCategoryA(category.0, name.as_ptr()) };
545 Ok(())
546}
547
548pub fn name_os_thread(thread_id: u32, name: &str) -> Result<()> {
549 let name = CString::new(name)?;
550 unsafe { sys::nvtxNameOsThreadA(thread_id, name.as_ptr()) };
551 Ok(())
552}
553
554pub fn scoped_range(message: &str) -> Result<LocalRange> {
555 LocalRange::create(message)
556}
557
// Unit tests for attribute encoding and enum wrappers; none of them call into NVTX.
#[cfg(test)]
mod tests {
    use std::mem;

    use super::*;

    #[test]
    fn encodes_event_attributes() {
        let message = c"work";
        let raw = EventAttributes::new()
            .with_message(message)
            .with_category(Category::from_raw(7))
            .with_color(Color::rgba(1, 2, 3, 4))
            .with_payload(Payload::I64(-42))
            .encode();

        assert_eq!(raw.version, sys::NVTX_VERSION as u16);
        assert_eq!(
            raw.size,
            mem::size_of::<sys::nvtxEventAttributes_t>() as u16
        );
        assert_eq!(raw.category, 7);
        assert_eq!(raw.colorType, sys::nvtxColorType_t::NVTX_COLOR_ARGB as i32);
        assert_eq!(raw.color, 0x0401_0203);
        assert_eq!(
            raw.messageType,
            sys::nvtxMessageType_t::NVTX_MESSAGE_TYPE_ASCII as i32
        );
        assert_eq!(unsafe { raw.message.ascii }, message.as_ptr());
        assert_eq!(
            raw.payloadType,
            sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_INT64 as i32
        );
        assert_eq!(unsafe { raw.payload.llValue }, -42);
    }

    #[test]
    fn owned_event_builds_attributes() {
        let event = Event::create("owned")
            .unwrap()
            .with_category(Category::from_raw(3))
            .with_color(Color::argb(0xff00_00ff))
            .with_payload(Payload::U32(11));

        let attributes = event.attributes();
        let raw = attributes.encode();

        assert_eq!(attributes.message(), Some(c"owned".as_ref()));
        assert_eq!(attributes.category(), Some(Category::from_raw(3)));
        assert_eq!(attributes.color(), Some(Color::argb(0xff00_00ff)));
        assert_eq!(attributes.payload(), Some(Payload::U32(11)));
        assert_eq!(raw.category, 3);
        assert_eq!(raw.color, 0xff00_00ff);
        assert_eq!(
            raw.payloadType,
            sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_UNSIGNED_INT32 as i32
        );
        assert_eq!(unsafe { raw.payload.uiValue }, 11);
    }

    #[test]
    fn enum_wrappers_convert_and_display() {
        assert_eq!(
            ColorType::from(sys::nvtxColorType_t::NVTX_COLOR_ARGB),
            ColorType::Argb
        );
        assert_eq!(
            sys::nvtxMessageType_t::from(MessageType::Ascii),
            sys::nvtxMessageType_t::NVTX_MESSAGE_TYPE_ASCII
        );
        assert_eq!(
            PayloadType::UnsignedInt64.to_string(),
            "NVTX_PAYLOAD_TYPE_UNSIGNED_INT64"
        );
        assert_eq!(
            ResourceGenericType::GenericThreadPosix.to_string(),
            "NVTX_RESOURCE_TYPE_GENERIC_THREAD_POSIX"
        );
    }

    // `rgba` takes alpha last but must pack it into the most significant byte.
    #[test]
    fn rgba_places_alpha_in_high_byte() {
        assert_eq!(Color::rgba(0xff, 0x00, 0x00, 0x80).as_raw(), 0x80ff_0000);
        assert_eq!(Color::rgba(1, 2, 3, 4), Color::argb(0x0401_0203));
    }

    // With nothing set, only the header fields are written; the optional type tags
    // must stay at their unknown (zero/default) values.
    #[test]
    fn empty_attributes_leave_optional_fields_unknown() {
        let raw = EventAttributes::default().encode();

        assert_eq!(raw.version, sys::NVTX_VERSION as u16);
        assert_eq!(
            raw.size,
            mem::size_of::<sys::nvtxEventAttributes_t>() as u16
        );
        assert_eq!(raw.category, 0);
        assert_eq!(
            raw.colorType,
            sys::nvtxColorType_t::NVTX_COLOR_UNKNOWN as i32
        );
        assert_eq!(
            raw.messageType,
            sys::nvtxMessageType_t::NVTX_MESSAGE_UNKNOWN as i32
        );
        assert_eq!(
            raw.payloadType,
            sys::nvtxPayloadType_t::NVTX_PAYLOAD_UNKNOWN as i32
        );
    }

    // Floating-point payloads must select the matching tag and union field.
    #[test]
    fn encodes_float_payloads() {
        let single = EventAttributes::new()
            .with_payload(Payload::F32(1.5))
            .encode();
        assert_eq!(
            single.payloadType,
            sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_FLOAT as i32
        );
        assert_eq!(unsafe { single.payload.fValue }, 1.5);

        let double = EventAttributes::new()
            .with_payload(Payload::F64(-2.25))
            .encode();
        assert_eq!(
            double.payloadType,
            sys::nvtxPayloadType_t::NVTX_PAYLOAD_TYPE_DOUBLE as i32
        );
        assert_eq!(unsafe { double.payload.dValue }, -2.25);
    }
}