tritonserver_rs/
trace.rs

//! Tracing utilities for debugging and profiling.
//!
//! Usage example (imports omitted for brevity):
//! ```ignore
//! struct TraceH;
//! impl TraceHandler for TraceH {
//!     fn trace_activity(
//!         &self,
//!         trace: &tritonserver_rs::trace::Trace,
//!         event: Activity,
//!         event_time: Duration,
//!     ) {
//!         log::info!(
//!             "Tracing activities: Trace_id: {}, event: {event:?}, event_time_secs: {}",
//!             trace.id().unwrap(),
//!             event_time.as_secs()
//!         );
//!         if event == Activity::ComputeStart {
//!             log::info!("Computations start, spawning new Trace");
//!             trace.spawn_child().unwrap();
//!         }
//!     }
//! }
//!
//! impl TensorTraceHandler for TraceH {
//!     fn trace_tensor_activity(
//!         &self,
//!         trace: &Trace,
//!         event: Activity,
//!         _tensor_data: &tritonserver_rs::Buffer,
//!         tensor_shape: tritonserver_rs::message::Shape,
//!     ) {
//!         log::info!(
//!             "Tracing Tensor Activity: Trace_id: {}, event: {event:?}, tensor name: {}",
//!             trace.id().unwrap(),
//!             tensor_shape.name
//!         );
//!     }
//! }
//!
//! /// Adds custom tracing to an inference request.
//! fn add_trace_to_request(request: &mut Request) {
//!     // Both timestamps and tensors are traced because both handlers are provided;
//!     // `0` means the trace has no parent.
//!     request.add_trace(Trace::new_with_handle(0, Some(TraceH), Some(TraceH)).unwrap());
//! }
//! ```
51
52use core::slice;
53use std::{
54    ffi::{c_void, CStr},
55    mem::{forget, transmute},
56    os::raw::c_char,
57    ptr::{null, null_mut},
58    sync::Arc,
59    time::Duration,
60};
61
62use crate::{
63    error::{Error, CSTR_CONVERT_ERROR_PLUG},
64    from_char_array,
65    message::Shape,
66    sys, to_cstring, Buffer, MemoryType,
67};
68
69bitflags::bitflags! {
70    /// Trace levels. The trace level controls the type of trace
71    ///  activities that are reported for an inference request.
72    ///
73    /// Trace level values can be combined to trace multiple types of activities. For example, use
74    /// ([Level::TIMESTAMPS] | [Level::TENSORS]) to trace both timestamps and
75    ///  tensors for an inference request.
76    struct Level: u32 {
77        /// Tracing disabled. No trace activities are reported.
78        const DISABLED = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_DISABLED;
79        /// Deprecated. Use [Level::TIMESTAMPS].
80        #[deprecated]
81        const MIN = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_MIN;
82        /// Deprecated. Use [Level::TIMESTAMPS].
83        #[deprecated]
84        const MAX = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_MAX;
85        /// Record timestamps for the inference request.
86        const TIMESTAMPS = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_TIMESTAMPS;
87        /// Record input and output tensor values for the inference request.
88        const TENSORS = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_TENSORS;
89    }
90}
91
92impl Level {
93    #[allow(dead_code)]
94    /// Get the string representation of a trace level.
95    fn as_str(self) -> &'static str {
96        unsafe {
97            let ptr = sys::TRITONSERVER_InferenceTraceLevelString(self.bits());
98            assert!(!ptr.is_null());
99            CStr::from_ptr(ptr)
100                .to_str()
101                .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
102        }
103    }
104}
105
106/// Enum representation of inference status.
107#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
108#[repr(u32)]
109pub enum Activity {
110    RequestStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_REQUEST_START,
111    QueueStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_QUEUE_START,
112    ComputeStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_START,
113    ComputeInputEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_INPUT_END,
114    ComputeOutputStart =
115        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_OUTPUT_START,
116    ComputeEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_END,
117    RequestEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_REQUEST_END,
118    TensorQueueInput = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT,
119    TensorBackendInput =
120        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_BACKEND_INPUT,
121    TensorBackendOutput =
122        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT,
123    CustomActivity = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_CUSTOM_ACTIVITY,
124}
125
126/// Inference event handler trait.
127pub trait TraceHandler: Send + Sync + 'static {
128    /// This function is invoked each time the `event` occures.
129    ///
130    /// `trace`: Trace object that was reported.
131    /// Note that child traces of constructed one also are reported with this fn.
132    /// Check [Trace::new_with_handle] for more info.\
133    /// `event`: activity that has occurred. \
134    /// `event_time`: time when event occured. \
135    ///     Triton Trace APIs report timestamps using steady clock, which is a monotonic clock that ensures time always movess forward.
136    ///     This clock is not related to wall clock time and, for example, can measure time since last reboot (aka /proc/uptime).
137    fn trace_activity(&self, trace: &Trace, event: Activity, event_time: Duration);
138}
139
140impl TraceHandler for () {
141    fn trace_activity(&self, _trace: &Trace, _event: Activity, _event_time: Duration) {}
142}
143
144/// Tensor event handler trait.
145pub trait TensorTraceHandler: Send + Sync + 'static {
146    /// This function is invoked each time the tensor `event` occures.
147    ///
148    /// `trace`: Trace object that was reported.
149    /// Note that child traces of constructed one also are reported with this fn.
150    /// Check [Trace::new_with_handle] for more info.\
151    /// `event`: activity that has occurred. \
152    /// `tensor_data`: borrowed buffer containing data of the tensor. \
153    /// `tensor_shape`: shape (name, data_type and dims) of the tensor.
154    fn trace_tensor_activity(
155        &self,
156        trace: &Trace,
157        event: Activity,
158        tensor_data: &Buffer,
159        tensor_shape: Shape,
160    );
161}
162
163impl TensorTraceHandler for () {
164    fn trace_tensor_activity(
165        &self,
166        _trace: &Trace,
167        _event: Activity,
168        _tensor_data: &Buffer,
169        _tensor_shape: Shape,
170    ) {
171    }
172}
173
174/// Can be passed to [Trace::new_with_handle] if no TENSORS or TIMESTAMPS are needed.
175pub const NOOP: Option<()> = None;
176
177struct TraceCallbackItems<H: TraceHandler, T: TensorTraceHandler> {
178    activity_handler: Option<H>,
179    tensor_activity_handler: Option<T>,
180}
181
182/// Don't want to use annotations like Trace<H, T> for
183/// handlers_copy: Arc<TraceCallbackItems<H,T>>, so will use Arc<dyn DynamicTypeHelper>.
184///
185/// If someone can teach me how to do it better, i'm all ears((.
186trait DynamicTypeHelper: Send + Sync {}
187impl<H: TraceHandler, T: TensorTraceHandler> DynamicTypeHelper for TraceCallbackItems<H, T> {}
188
189/// Inference object that provides custom tracing.
190///
191/// Is constructed with [TraceHandler] object that is activated each time an event occures.
192pub struct Trace {
193    pub(crate) ptr: TraceInner,
194    /// So callback won't be dropped if trace reports after the fn delete (inference).
195    handlers_copy: Arc<dyn DynamicTypeHelper>,
196}
197
198pub(crate) struct TraceInner(pub(crate) *mut sys::TRITONSERVER_InferenceTrace);
199unsafe impl Send for TraceInner {}
200/// Unclonable, so Sync is safe
201unsafe impl Sync for TraceInner {}
202
203impl PartialEq for Trace {
204    fn eq(&self, other: &Self) -> bool {
205        let left = match self.id() {
206            Ok(l) => l,
207            Err(err) => {
208                log::warn!("Error getting ID for two Traces comparison: {err}");
209                return false;
210            }
211        };
212        let right = match other.id() {
213            Ok(r) => r,
214            Err(err) => {
215                log::warn!("Error getting ID for two Traces comparison: {err}");
216                return false;
217            }
218        };
219        left == right
220    }
221}
222impl Eq for Trace {}
223
224impl Trace {
225    /// Create a new inference trace object.
226    ///
227    /// The `activity_handler` and `tensor_activity_handler` will be called to report activity
228    /// including [Trace::report_activity] called by this trace as well as by __every__ child traces that are spawned
229    /// by this one. So the [TraceHandler::trace_activity] and [TensorTraceHandler::trace_tensor_activity]
230    /// should check the trace object (first argument) that are passed to it
231    /// to determine specifically what trace was reported.
232    ///
233    /// `level`: The tracing level. \
234    /// `parent_id`: The parent trace id for this trace.
235    /// A value of 0 indicates that there is not parent trace. \
236    /// `activity_handler`: The callback function where activity (on timeline event)
237    ///  for the trace (and all the child traces) is reported. \
238    /// `tensor_activity_handler`: Optional callback function where activity (on tensor event)
239    /// for the trace (and all the child traces) is reported.
240    pub fn new_with_handle<H: TraceHandler, T: TensorTraceHandler>(
241        parent_id: u64,
242        activity_handler: Option<H>,
243        tensor_activity_handler: Option<T>,
244    ) -> Result<Self, Error> {
245        let enable_activity = activity_handler.is_some();
246        let enable_tensor_activity = tensor_activity_handler.is_some();
247
248        let level = match (enable_activity, enable_tensor_activity) {
249            (true, true) => Level::TENSORS | Level::TIMESTAMPS,
250            (true, false) => Level::TIMESTAMPS,
251            (false, true) => Level::TENSORS,
252            (false, false) => Level::DISABLED,
253        };
254
255        let mut ptr = null_mut::<sys::TRITONSERVER_InferenceTrace>();
256        let handlers = Arc::new(TraceCallbackItems {
257            activity_handler,
258            tensor_activity_handler,
259        });
260        let raw_handlers = Arc::into_raw(handlers.clone()) as *mut c_void;
261
262        triton_call!(sys::TRITONSERVER_InferenceTraceTensorNew(
263            &mut ptr as *mut _,
264            level.bits(),
265            parent_id,
266            enable_activity.then_some(activity_wraper::<H, T>),
267            enable_tensor_activity.then_some(tensor_activity_wrapper::<H, T>),
268            Some(delete::<H, T>),
269            raw_handlers,
270        ))?;
271
272        assert!(!ptr.is_null());
273        let trace = Trace {
274            ptr: TraceInner(ptr),
275            handlers_copy: handlers,
276        };
277        Ok(trace)
278    }
279
280    /// Report a trace activity. All the traces reported using this API will be send [Activity::CustomActivity] type.
281    ///
282    /// `timestamp` The timestamp associated with the trace activity. \
283    /// `name` The trace activity name.
284    pub fn report_activity<N: AsRef<str>>(
285        &self,
286        timestamp: Duration,
287        activity_name: N,
288    ) -> Result<(), Error> {
289        let name = to_cstring(activity_name)?;
290        triton_call!(sys::TRITONSERVER_InferenceTraceReportActivity(
291            self.ptr.0,
292            timestamp.as_nanos() as _,
293            name.as_ptr()
294        ))
295    }
296
297    /// Get the id associated with the trace.
298    /// Every trace is assigned an id that is unique across all traces created for a Triton server.
299    pub fn id(&self) -> Result<u64, Error> {
300        let mut id: u64 = 0;
301        triton_call!(
302            sys::TRITONSERVER_InferenceTraceId(self.ptr.0, &mut id as *mut _),
303            id
304        )
305    }
306
307    /// Get the parent id associated with the trace. \
308    /// The parent id indicates a parent-child relationship between two traces.
309    /// A parent id value of 0 indicates that there is no parent trace.
310    pub fn parent_id(&self) -> Result<u64, Error> {
311        let mut id: u64 = 0;
312        triton_call!(
313            sys::TRITONSERVER_InferenceTraceParentId(self.ptr.0, &mut id as *mut _),
314            id
315        )
316    }
317
318    /// Get the name of the model associated with the trace.
319    pub fn model_name(&self) -> Result<String, Error> {
320        let mut name = null::<c_char>();
321        triton_call!(
322            sys::TRITONSERVER_InferenceTraceModelName(self.ptr.0, &mut name as *mut _),
323            from_char_array(name)
324        )
325    }
326
327    /// Get the version of the model associated with the trace.
328    pub fn model_version(&self) -> Result<i64, Error> {
329        let mut version: i64 = 0;
330        triton_call!(
331            sys::TRITONSERVER_InferenceTraceModelVersion(self.ptr.0, &mut version as *mut _),
332            version
333        )
334    }
335
336    /// Get the request id associated with a trace.
337    /// Returns the version of the model associated with the trace.
338    pub fn request_id(&self) -> Result<String, Error> {
339        let mut request_id = null::<c_char>();
340
341        triton_call!(
342            sys::TRITONSERVER_InferenceTraceRequestId(self.ptr.0, &mut request_id as *mut _),
343            from_char_array(request_id)
344        )
345    }
346
347    /// Returns the child trace, spawned from the parent(self) trace.
348    ///
349    /// Be causious: Trace is deleting on drop, so don't forget to save it.
350    /// Also do not use parent and child traces for different Requests: it can lead to Seq Faults.
351    pub fn spawn_child(&self) -> Result<Trace, Error> {
352        let mut trace = null_mut();
353        triton_call!(
354            sys::TRITONSERVER_InferenceTraceSpawnChildTrace(self.ptr.0, &mut trace),
355            Trace {
356                ptr: TraceInner(trace),
357                handlers_copy: self.handlers_copy.clone(),
358            }
359        )
360    }
361
362    /// Set context to Triton Trace.
363    pub fn set_context(&mut self, context: String) -> Result<&mut Self, Error> {
364        let context = to_cstring(context)?;
365        triton_call!(
366            sys::TRITONSERVER_InferenceTraceSetContext(self.ptr.0, context.as_ptr()),
367            self
368        )
369    }
370
371    /// Get Triton Trace context.
372    pub fn context(&self) -> Result<String, Error> {
373        let mut context = null::<c_char>();
374        triton_call!(
375            sys::TRITONSERVER_InferenceTraceContext(self.ptr.0, &mut context as *mut _),
376            from_char_array(context)
377        )
378    }
379}
380
381impl Drop for TraceInner {
382    fn drop(&mut self) {
383        if !self.0.is_null() {
384            unsafe { sys::TRITONSERVER_InferenceTraceDelete(self.0) };
385        }
386    }
387}
388
389unsafe extern "C" fn delete<H: TraceHandler, T: TensorTraceHandler>(
390    this: *mut sys::TRITONSERVER_InferenceTrace,
391    userp: *mut c_void,
392) {
393    if !userp.is_null() && !this.is_null() {
394        sys::TRITONSERVER_InferenceTraceDelete(this);
395        Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
396    }
397}
398
399unsafe extern "C" fn activity_wraper<H: TraceHandler, T: TensorTraceHandler>(
400    trace: *mut sys::TRITONSERVER_InferenceTrace,
401    activity: sys::TRITONSERVER_InferenceTraceActivity,
402    timestamp_ns: u64,
403    userp: *mut ::std::os::raw::c_void,
404) {
405    if !userp.is_null() {
406        let handle = Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
407        let foo_trace = Trace {
408            ptr: TraceInner(trace),
409            handlers_copy: handle.clone(),
410        };
411        let activity: Activity = transmute(activity);
412
413        let timestamp = Duration::from_nanos(timestamp_ns);
414
415        if let Some(activity_handle) = handle.activity_handler.as_ref() {
416            activity_handle.trace_activity(&foo_trace, activity, timestamp)
417        };
418
419        // Drop will be in delete method.
420        forget(handle);
421        forget(foo_trace.ptr);
422    }
423}
424
425unsafe extern "C" fn tensor_activity_wrapper<H: TraceHandler, T: TensorTraceHandler>(
426    trace: *mut sys::TRITONSERVER_InferenceTrace,
427    activity: sys::TRITONSERVER_InferenceTraceActivity,
428    name: *const ::std::os::raw::c_char,
429    datatype: sys::TRITONSERVER_DataType,
430    base: *const ::std::os::raw::c_void,
431    byte_size: usize,
432    shape: *const i64,
433    dim_count: u64,
434    memory_type: sys::TRITONSERVER_MemoryType,
435    _memory_type_id: i64,
436    userp: *mut ::std::os::raw::c_void,
437) {
438    if !userp.is_null() {
439        let handle = Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
440
441        let foo_trace = Trace {
442            ptr: TraceInner(trace),
443            handlers_copy: handle.clone(),
444        };
445        let activity: Activity = transmute(activity);
446
447        let data_type = unsafe { transmute::<u32, crate::memory::DataType>(datatype) };
448        let memory_type: MemoryType = unsafe { transmute(memory_type) };
449
450        let tensor_shape = Shape {
451            name: from_char_array(name),
452            datatype: data_type,
453            dims: slice::from_raw_parts(shape, dim_count as _).to_vec(),
454        };
455
456        let tensor_data = Buffer {
457            ptr: base as *mut _,
458            len: byte_size,
459            data_type,
460            memory_type,
461            owned: false,
462        };
463
464        if let Some(tensor_activity_handler) = handle.tensor_activity_handler.as_ref() {
465            tensor_activity_handler.trace_tensor_activity(
466                &foo_trace,
467                activity,
468                &tensor_data,
469                tensor_shape,
470            )
471        };
472
473        forget(handle);
474        forget(foo_trace.ptr);
475        // Drop will be in delete method.
476    }
477}