tritonserver_rs/
trace.rs

1//! Tracing utilities for debugging and profiling.
2//!
3//! Usage example:
//! ```ignore
5//! struct TraceH;
6//! impl TraceHandler for TraceH {
7//!     fn trace_activity(
8//!        &self,
9//!        trace: &tritonserver_rs::trace::Trace,
10//!        event: Activity,
11//!        event_time: Duration,
12//!     ) {
13//!         log::info!(
14//!             "Tracing activities: Trace_id: {}, event: {event:?}, event_time_secs: {}",
15//!             trace.id().unwrap(),
16//!             event_time.as_secs()
17//!         );
18//!         if event == Activity::ComputeStart {
19//!             log::info!("Computations start, spawning new Trace");
20//!             trace.spawn_child().unwrap();
21//!         }
22//!     }
23//! }
24//!
25//! impl TensorTraceHandler for TraceH {
26//!     fn trace_tensor_activity(
27//!         &self,
28//!         trace: &Trace,
29//!         event: Activity,
30//!         _tensor_data: &tritonserver_rs::Buffer,
31//!         tensor_shape: tritonserver_rs::message::Shape,
32//!     ) {
33//!         log::info!(
34//!             "Tracing Tensor Activity: Trace_id: {}, event: {event:?}, tensor name: {}",
35//!             trace.id().unwrap(),
36//!             tensor_shape.name
37//!         );
38//!     }
39//! }
40//!
//! /// Adds custom tracing to an inference request.
//! fn add_trace_to_request(request: &mut Request) {
//!     request.add_trace(Trace::new_with_handle(0, Some(TraceH), Some(TraceH)).unwrap());
//! }
50//! ```
51
52use core::slice;
53use std::{
54    ffi::{c_void, CStr},
55    mem::{forget, transmute},
56    os::raw::c_char,
57    ptr::{null, null_mut},
58    sync::Arc,
59    time::Duration,
60};
61
62use crate::{
63    error::{Error, CSTR_CONVERT_ERROR_PLUG},
64    from_char_array,
65    message::Shape,
66    sys, to_cstring, Buffer, MemoryType,
67};
68
bitflags::bitflags! {
    /// Trace levels. The trace level controls the type of trace
    /// activities that are reported for an inference request.
    ///
    /// Trace level values can be combined to trace multiple types of activities. For example, use
    /// ([Level::TIMESTAMPS] | [Level::TENSORS]) to trace both timestamps and
    /// tensors for an inference request.
    ///
    /// This type is private to the module: [Trace::new_with_handle] derives the level from the
    /// handlers it is given instead of taking it as an argument.
    struct Level: u32 {
        /// Tracing disabled. No trace activities are reported.
        const DISABLED = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_DISABLED;
        /// Deprecated. Use [Level::TIMESTAMPS].
        #[deprecated]
        const MIN = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_MIN;
        /// Deprecated. Use [Level::TIMESTAMPS].
        #[deprecated]
        const MAX = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_MAX;
        /// Record timestamps for the inference request.
        const TIMESTAMPS = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_TIMESTAMPS;
        /// Record input and output tensor values for the inference request.
        const TENSORS = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_TENSORS;
    }
}
91
92impl Level {
93    #[allow(dead_code)]
94    /// Get the string representation of a trace level.
95    fn as_str(self) -> &'static str {
96        unsafe {
97            let ptr = sys::TRITONSERVER_InferenceTraceLevelString(self.bits());
98            assert!(!ptr.is_null());
99            CStr::from_ptr(ptr)
100                .to_str()
101                .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
102        }
103    }
104}
105
/// Kind of trace activity reported to [TraceHandler] and [TensorTraceHandler].
///
/// Each variant mirrors the Triton `TRITONSERVER_TRACE_*` activity constant of the same name;
/// the discriminants are taken directly from those constants.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum Activity {
    /// Mirrors `TRITONSERVER_TRACE_REQUEST_START`.
    RequestStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_REQUEST_START,
    /// Mirrors `TRITONSERVER_TRACE_QUEUE_START`.
    QueueStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_QUEUE_START,
    /// Mirrors `TRITONSERVER_TRACE_COMPUTE_START`.
    ComputeStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_START,
    /// Mirrors `TRITONSERVER_TRACE_COMPUTE_INPUT_END`.
    ComputeInputEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_INPUT_END,
    /// Mirrors `TRITONSERVER_TRACE_COMPUTE_OUTPUT_START`.
    ComputeOutputStart =
        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_OUTPUT_START,
    /// Mirrors `TRITONSERVER_TRACE_COMPUTE_END`.
    ComputeEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_END,
    /// Mirrors `TRITONSERVER_TRACE_REQUEST_END`.
    RequestEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_REQUEST_END,
    /// Mirrors `TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT` (tensor activity).
    TensorQueueInput = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT,
    /// Mirrors `TRITONSERVER_TRACE_TENSOR_BACKEND_INPUT` (tensor activity).
    TensorBackendInput =
        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_BACKEND_INPUT,
    /// Mirrors `TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT` (tensor activity).
    TensorBackendOutput =
        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT,
    /// Mirrors `TRITONSERVER_TRACE_CUSTOM_ACTIVITY`; activities reported with
    /// [Trace::report_activity] arrive as this variant.
    CustomActivity = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_CUSTOM_ACTIVITY,
}
125
/// Handler for timeline (timestamp) trace activities.
pub trait TraceHandler: Send + Sync + 'static {
    /// Invoked each time a trace `event` occurs.
    ///
    /// `trace`: the trace object that reported the event.
    /// Note that child traces of the constructed trace are also reported through this method;
    /// see [Trace::new_with_handle] for details. \
    /// `event`: the activity that occurred. \
    /// `event_time`: the time at which the event occurred. \
    ///     Triton Trace APIs report timestamps using a steady clock, a monotonic clock that ensures time always moves forward.
    ///     This clock is unrelated to wall-clock time and may, for example, measure time since the last reboot (as in /proc/uptime).
    fn trace_activity(&self, trace: &Trace, event: Activity, event_time: Duration);
}
139
140impl TraceHandler for () {
141    fn trace_activity(&self, _trace: &Trace, _event: Activity, _event_time: Duration) {}
142}
143
/// Handler for tensor trace activities.
pub trait TensorTraceHandler: Send + Sync + 'static {
    /// Invoked each time a tensor `event` occurs.
    ///
    /// `trace`: the trace object that reported the event.
    /// Note that child traces of the constructed trace are also reported through this method;
    /// see [Trace::new_with_handle] for details. \
    /// `event`: the activity that occurred. \
    /// `tensor_data`: borrowed (non-owned) buffer containing the tensor data.
    /// NOTE(review): it presumably points into Triton-owned memory that is only valid during
    /// this call — do not retain its pointer. \
    /// `tensor_shape`: shape (name, data_type and dims) of the tensor.
    fn trace_tensor_activity(
        &self,
        trace: &Trace,
        event: Activity,
        tensor_data: &Buffer,
        tensor_shape: Shape,
    );
}
162
163impl TensorTraceHandler for () {
164    fn trace_tensor_activity(
165        &self,
166        _trace: &Trace,
167        _event: Activity,
168        _tensor_data: &Buffer,
169        _tensor_shape: Shape,
170    ) {
171    }
172}
173
/// Placeholder for a handler you do not need: pass it to [Trace::new_with_handle] in place of
/// the activity or tensor handler to disable that kind of tracing (TIMESTAMPS or TENSORS).
pub const NOOP: Option<()> = Option::None;
176
/// Handler pair registered with Triton for one trace (and the children spawned from it).
///
/// Shared through an `Arc`: one reference is handed to Triton as the callbacks' `userp` and is
/// released in `delete`; further references live in `Trace::handlers_copy`.
///
/// No trait bounds here on purpose: the impls and functions that call the handlers carry the
/// `TraceHandler`/`TensorTraceHandler` bounds they need.
struct TraceCallbackItems<H, T> {
    /// Receives timeline activities; `None` when timestamp tracing is disabled.
    activity_handler: Option<H>,
    /// Receives tensor activities; `None` when tensor tracing is disabled.
    tensor_activity_handler: Option<T>,
}
181
182/// Don't want to use annotations like Trace<H, T> for
183/// handlers_copy: Arc<TraceCallbackItems<H,T>>, so will use Arc<dyn DynamicTypeHelper>.
184///
185/// If someone can teach me how to do it better, i'm all ears((.
186trait DynamicTypeHelper: Send + Sync {}
187impl<H: TraceHandler, T: TensorTraceHandler> DynamicTypeHelper for TraceCallbackItems<H, T> {}
188
/// Inference object that provides custom tracing.
///
/// Constructed with [TraceHandler] / [TensorTraceHandler] objects that are invoked each time
/// an event occurs.
pub struct Trace {
    /// Handle to the underlying Triton trace; deleted when dropped unless it is null.
    pub(crate) ptr: TraceInner,
    /// Shared reference to the handlers, so they are not dropped while this `Trace` (or any
    /// child spawned from it, which clones this field) is still alive — even if Triton has
    /// already released its own reference (e.g. after inference).
    handlers_copy: Arc<dyn DynamicTypeHelper>,
}
197
/// Raw Triton trace pointer. Deletes the trace on drop unless the pointer is null; the FFI
/// callbacks `forget` it for traces Triton still owns.
pub(crate) struct TraceInner(pub(crate) *mut sys::TRITONSERVER_InferenceTrace);
// SAFETY: NOTE(review): these impls assert that a Triton trace handle may be moved to and
// shared between threads (the handler traits require `Send + Sync`, presumably because Triton
// invokes the callbacks from its own threads). Confirm that the Triton trace API is safe for
// concurrent use of one handle.
unsafe impl Send for TraceInner {}
unsafe impl Sync for TraceInner {}
201
202impl PartialEq for Trace {
203    fn eq(&self, other: &Self) -> bool {
204        let left = match self.id() {
205            Ok(l) => l,
206            Err(err) => {
207                log::warn!("Error getting ID for two Traces comparison: {err}");
208                return false;
209            }
210        };
211        let right = match other.id() {
212            Ok(r) => r,
213            Err(err) => {
214                log::warn!("Error getting ID for two Traces comparison: {err}");
215                return false;
216            }
217        };
218        left == right
219    }
220}
221impl Eq for Trace {}
222
223impl Trace {
224    /// Create a new inference trace object.
225    ///
226    /// The `activity_handler` and `tensor_activity_handler` will be called to report activity
227    /// including [Trace::report_activity] called by this trace as well as by __every__ child traces that are spawned
228    /// by this one. So the [TraceHandler::trace_activity] and [TensorTraceHandler::trace_tensor_activity]
229    /// should check the trace object (first argument) that are passed to it
230    /// to determine specifically what trace was reported.
231    ///
232    /// `level`: The tracing level. \
233    /// `parent_id`: The parent trace id for this trace.
234    /// A value of 0 indicates that there is not parent trace. \
235    /// `activity_handler`: The callback function where activity (on timeline event)
236    ///  for the trace (and all the child traces) is reported. \
237    /// `tensor_activity_handler`: Optional callback function where activity (on tensor event)
238    /// for the trace (and all the child traces) is reported.
239    pub fn new_with_handle<H: TraceHandler, T: TensorTraceHandler>(
240        parent_id: u64,
241        activity_handler: Option<H>,
242        tensor_activity_handler: Option<T>,
243    ) -> Result<Self, Error> {
244        let enable_activity = activity_handler.is_some();
245        let enable_tensor_activity = tensor_activity_handler.is_some();
246
247        let level = match (enable_activity, enable_tensor_activity) {
248            (true, true) => Level::TENSORS | Level::TIMESTAMPS,
249            (true, false) => Level::TIMESTAMPS,
250            (false, true) => Level::TENSORS,
251            (false, false) => Level::DISABLED,
252        };
253
254        let mut ptr = null_mut::<sys::TRITONSERVER_InferenceTrace>();
255        let handlers = Arc::new(TraceCallbackItems {
256            activity_handler,
257            tensor_activity_handler,
258        });
259        let raw_handlers = Arc::into_raw(handlers.clone()) as *mut c_void;
260
261        triton_call!(sys::TRITONSERVER_InferenceTraceTensorNew(
262            &mut ptr as *mut _,
263            level.bits(),
264            parent_id,
265            enable_activity.then_some(activity_wraper::<H, T>),
266            enable_tensor_activity.then_some(tensor_activity_wrapper::<H, T>),
267            Some(delete::<H, T>),
268            raw_handlers,
269        ))?;
270
271        assert!(!ptr.is_null());
272        let trace = Trace {
273            ptr: TraceInner(ptr),
274            handlers_copy: handlers,
275        };
276        Ok(trace)
277    }
278
279    /// Report a trace activity. All the traces reported using this API will be send [Activity::CustomActivity] type.
280    ///
281    /// `timestamp` The timestamp associated with the trace activity. \
282    /// `name` The trace activity name.
283    pub fn report_activity<N: AsRef<str>>(
284        &self,
285        timestamp: Duration,
286        activity_name: N,
287    ) -> Result<(), Error> {
288        let name = to_cstring(activity_name)?;
289        triton_call!(sys::TRITONSERVER_InferenceTraceReportActivity(
290            self.ptr.0,
291            timestamp.as_nanos() as _,
292            name.as_ptr()
293        ))
294    }
295
296    /// Get the id associated with the trace.
297    /// Every trace is assigned an id that is unique across all traces created for a Triton server.
298    pub fn id(&self) -> Result<u64, Error> {
299        let mut id: u64 = 0;
300        triton_call!(
301            sys::TRITONSERVER_InferenceTraceId(self.ptr.0, &mut id as *mut _),
302            id
303        )
304    }
305
306    /// Get the parent id associated with the trace. \
307    /// The parent id indicates a parent-child relationship between two traces.
308    /// A parent id value of 0 indicates that there is no parent trace.
309    pub fn parent_id(&self) -> Result<u64, Error> {
310        let mut id: u64 = 0;
311        triton_call!(
312            sys::TRITONSERVER_InferenceTraceParentId(self.ptr.0, &mut id as *mut _),
313            id
314        )
315    }
316
317    /// Get the name of the model associated with the trace.
318    pub fn model_name(&self) -> Result<String, Error> {
319        let mut name = null::<c_char>();
320        triton_call!(
321            sys::TRITONSERVER_InferenceTraceModelName(self.ptr.0, &mut name as *mut _),
322            from_char_array(name)
323        )
324    }
325
326    /// Get the version of the model associated with the trace.
327    pub fn model_version(&self) -> Result<i64, Error> {
328        let mut version: i64 = 0;
329        triton_call!(
330            sys::TRITONSERVER_InferenceTraceModelVersion(self.ptr.0, &mut version as *mut _),
331            version
332        )
333    }
334
335    /// Get the request id associated with a trace.
336    /// Returns the version of the model associated with the trace.
337    pub fn request_id(&self) -> Result<String, Error> {
338        let mut request_id = null::<c_char>();
339
340        triton_call!(
341            sys::TRITONSERVER_InferenceTraceRequestId(self.ptr.0, &mut request_id as *mut _),
342            from_char_array(request_id)
343        )
344    }
345
346    /// Returns the child trace, spawned from the parent(self) trace.
347    ///
348    /// Be causious: Trace is deleting on drop, so don't forget to save it.
349    /// Also do not use parent and child traces for different Requests: it can lead to Seq Faults.
350    pub fn spawn_child(&self) -> Result<Trace, Error> {
351        let mut trace = null_mut();
352        triton_call!(
353            sys::TRITONSERVER_InferenceTraceSpawnChildTrace(self.ptr.0, &mut trace),
354            Trace {
355                ptr: TraceInner(trace),
356                handlers_copy: self.handlers_copy.clone(),
357            }
358        )
359    }
360
361    /// Set context to Triton Trace.
362    pub fn set_context(&mut self, context: String) -> Result<&mut Self, Error> {
363        let context = to_cstring(context)?;
364        triton_call!(
365            sys::TRITONSERVER_InferenceTraceSetContext(self.ptr.0, context.as_ptr()),
366            self
367        )
368    }
369
370    /// Get Triton Trace context.
371    pub fn context(&self) -> Result<String, Error> {
372        let mut context = null::<c_char>();
373        triton_call!(
374            sys::TRITONSERVER_InferenceTraceContext(self.ptr.0, &mut context as *mut _),
375            from_char_array(context)
376        )
377    }
378}
379
380impl Drop for TraceInner {
381    fn drop(&mut self) {
382        if !self.0.is_null() {
383            unsafe {
384                sys::TRITONSERVER_InferenceTraceDelete(self.0);
385            }
386        }
387    }
388}
389
390unsafe extern "C" fn delete<H: TraceHandler, T: TensorTraceHandler>(
391    this: *mut sys::TRITONSERVER_InferenceTrace,
392    userp: *mut c_void,
393) {
394    if !userp.is_null() && !this.is_null() {
395        sys::TRITONSERVER_InferenceTraceDelete(this);
396        Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
397    }
398}
399
400unsafe extern "C" fn activity_wraper<H: TraceHandler, T: TensorTraceHandler>(
401    trace: *mut sys::TRITONSERVER_InferenceTrace,
402    activity: sys::TRITONSERVER_InferenceTraceActivity,
403    timestamp_ns: u64,
404    userp: *mut ::std::os::raw::c_void,
405) {
406    if !userp.is_null() {
407        let handle = Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
408        let foo_trace = Trace {
409            ptr: TraceInner(trace),
410            handlers_copy: handle.clone(),
411        };
412        let activity: Activity = transmute(activity);
413
414        let timestamp = Duration::from_nanos(timestamp_ns);
415
416        if let Some(activity_handle) = handle.activity_handler.as_ref() {
417            activity_handle.trace_activity(&foo_trace, activity, timestamp)
418        };
419
420        // Drop will be in delete method.
421        forget(handle);
422        forget(foo_trace.ptr);
423    }
424}
425
426unsafe extern "C" fn tensor_activity_wrapper<H: TraceHandler, T: TensorTraceHandler>(
427    trace: *mut sys::TRITONSERVER_InferenceTrace,
428    activity: sys::TRITONSERVER_InferenceTraceActivity,
429    name: *const ::std::os::raw::c_char,
430    datatype: sys::TRITONSERVER_DataType,
431    base: *const ::std::os::raw::c_void,
432    byte_size: usize,
433    shape: *const i64,
434    dim_count: u64,
435    memory_type: sys::TRITONSERVER_MemoryType,
436    _memory_type_id: i64,
437    userp: *mut ::std::os::raw::c_void,
438) {
439    if !userp.is_null() {
440        let handle = Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
441
442        let foo_trace = Trace {
443            ptr: TraceInner(trace),
444            handlers_copy: handle.clone(),
445        };
446        let activity: Activity = transmute(activity);
447
448        let data_type = unsafe { transmute::<u32, crate::memory::DataType>(datatype) };
449        let memory_type: MemoryType = unsafe { transmute(memory_type) };
450
451        let tensor_shape = Shape {
452            name: from_char_array(name),
453            datatype: data_type,
454            dims: slice::from_raw_parts(shape, dim_count as _).to_vec(),
455        };
456
457        let tensor_data = Buffer {
458            ptr: base as *mut _,
459            len: byte_size,
460            data_type,
461            memory_type,
462            owned: false,
463        };
464
465        if let Some(tensor_activity_handler) = handle.tensor_activity_handler.as_ref() {
466            tensor_activity_handler.trace_tensor_activity(
467                &foo_trace,
468                activity,
469                &tensor_data,
470                tensor_shape,
471            )
472        };
473
474        forget(handle);
475        forget(foo_trace.ptr);
476        // Drop will be in delete method.
477    }
478}