1use core::slice;
53use std::{
54 ffi::{c_void, CStr},
55 mem::{forget, transmute},
56 os::raw::c_char,
57 ptr::{null, null_mut},
58 sync::Arc,
59 time::Duration,
60};
61
62use crate::{
63 error::{Error, CSTR_CONVERT_ERROR_PLUG},
64 from_char_array,
65 message::Shape,
66 sys, to_cstring, Buffer, MemoryType,
67};
68
bitflags::bitflags! {
    /// Bit flags describing which kinds of trace events Triton should emit.
    ///
    /// Every flag takes its value from the matching
    /// `TRITONSERVER_TRACE_LEVEL_*` constant exported by `sys`.
    struct Level: u32 {
        /// Tracing disabled; `Trace::new_with_handle` selects this when no
        /// handler is supplied.
        const DISABLED = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_DISABLED;
        /// Mirrors `TRITONSERVER_TRACE_LEVEL_MIN`; marked deprecated here.
        #[deprecated]
        const MIN = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_MIN;
        /// Mirrors `TRITONSERVER_TRACE_LEVEL_MAX`; marked deprecated here.
        #[deprecated]
        const MAX = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_MAX;
        /// Selected by `Trace::new_with_handle` when an activity handler is
        /// supplied.
        const TIMESTAMPS = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_TIMESTAMPS;
        /// Selected by `Trace::new_with_handle` when a tensor activity handler
        /// is supplied.
        const TENSORS = sys::tritonserver_tracelevel_enum_TRITONSERVER_TRACE_LEVEL_TENSORS;
    }
}
91
92impl Level {
93 #[allow(dead_code)]
94 fn as_str(self) -> &'static str {
96 unsafe {
97 let ptr = sys::TRITONSERVER_InferenceTraceLevelString(self.bits());
98 assert!(!ptr.is_null());
99 CStr::from_ptr(ptr)
100 .to_str()
101 .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
102 }
103 }
104}
105
/// Kind of event reported to a trace callback.
///
/// Each variant takes its discriminant from the matching
/// `TRITONSERVER_TRACE_*` activity constant in `sys`. The enum is `repr(u32)`
/// because the callback wrappers in this file obtain it by transmuting the raw
/// activity value received from Triton.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum Activity {
    /// Mirrors `TRITONSERVER_TRACE_REQUEST_START`.
    RequestStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_REQUEST_START,
    /// Mirrors `TRITONSERVER_TRACE_QUEUE_START`.
    QueueStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_QUEUE_START,
    /// Mirrors `TRITONSERVER_TRACE_COMPUTE_START`.
    ComputeStart = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_START,
    /// Mirrors `TRITONSERVER_TRACE_COMPUTE_INPUT_END`.
    ComputeInputEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_INPUT_END,
    /// Mirrors `TRITONSERVER_TRACE_COMPUTE_OUTPUT_START`.
    ComputeOutputStart =
        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_OUTPUT_START,
    /// Mirrors `TRITONSERVER_TRACE_COMPUTE_END`.
    ComputeEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_COMPUTE_END,
    /// Mirrors `TRITONSERVER_TRACE_REQUEST_END`.
    RequestEnd = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_REQUEST_END,
    /// Mirrors `TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT`.
    TensorQueueInput = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT,
    /// Mirrors `TRITONSERVER_TRACE_TENSOR_BACKEND_INPUT`.
    TensorBackendInput =
        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_BACKEND_INPUT,
    /// Mirrors `TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT`.
    TensorBackendOutput =
        sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT,
    /// Mirrors `TRITONSERVER_TRACE_CUSTOM_ACTIVITY`.
    CustomActivity = sys::tritonserver_traceactivity_enum_TRITONSERVER_TRACE_CUSTOM_ACTIVITY,
}
125
/// Receiver for timestamped trace activities.
///
/// Implementations must be `Send + Sync + 'static` because they are stored in
/// an `Arc` and handed to Triton as callback user data.
pub trait TraceHandler: Send + Sync + 'static {
    /// Called for each activity Triton reports for `trace`.
    ///
    /// `event` is the activity kind and `event_time` is the timestamp Triton
    /// supplied in nanoseconds, converted to a `Duration`.
    fn trace_activity(&self, trace: &Trace, event: Activity, event_time: Duration);
}
139
/// No-op handler, so `()` can stand in where no activity handler is wanted
/// (see `NOOP`).
impl TraceHandler for () {
    // Intentionally does nothing.
    fn trace_activity(&self, _trace: &Trace, _event: Activity, _event_time: Duration) {}
}
143
/// Receiver for tensor-level trace activities.
///
/// Implementations must be `Send + Sync + 'static` because they are stored in
/// an `Arc` and handed to Triton as callback user data.
pub trait TensorTraceHandler: Send + Sync + 'static {
    /// Called for each tensor activity Triton reports for `trace`.
    ///
    /// `tensor_data` is a non-owning view (`owned: false`) over the memory
    /// Triton passed to the callback; it is only borrowed for the duration of
    /// the call. `tensor_shape` carries the tensor name, data type and
    /// dimensions copied out of the callback arguments.
    fn trace_tensor_activity(
        &self,
        trace: &Trace,
        event: Activity,
        tensor_data: &Buffer,
        tensor_shape: Shape,
    );
}
162
/// No-op handler, so `()` can stand in where no tensor handler is wanted
/// (see `NOOP`).
impl TensorTraceHandler for () {
    // Intentionally does nothing.
    fn trace_tensor_activity(
        &self,
        _trace: &Trace,
        _event: Activity,
        _tensor_data: &Buffer,
        _tensor_shape: Shape,
    ) {
    }
}
173
/// Ready-made "no handler" value.
///
/// `()` implements both `TraceHandler` and `TensorTraceHandler`, so this
/// constant can be passed for either handler argument of
/// `Trace::new_with_handle` to leave that callback unset.
pub const NOOP: Option<()> = None;
176
/// Bundle of the user-supplied handlers.
///
/// An `Arc` of this struct is converted to a raw pointer and given to Triton
/// as the callback user pointer; the callback wrappers cast it back to reach
/// the handlers.
struct TraceCallbackItems<H: TraceHandler, T: TensorTraceHandler> {
    /// Handler for timestamped activities, if any.
    activity_handler: Option<H>,
    /// Handler for tensor activities, if any.
    tensor_activity_handler: Option<T>,
}
181
/// Marker trait used to type-erase `TraceCallbackItems<H, T>`, so that `Trace`
/// can hold the handlers as `Arc<dyn DynamicTypeHelper>` without being generic
/// itself.
trait DynamicTypeHelper: Send + Sync {}
// Every handler bundle qualifies as the type-erased marker.
impl<H: TraceHandler, T: TensorTraceHandler> DynamicTypeHelper for TraceCallbackItems<H, T> {}
188
/// Handle to a Triton inference trace together with the handlers attached to
/// it.
pub struct Trace {
    /// Raw trace pointer; `TraceInner`'s `Drop` deletes the trace unless the
    /// pointer is null.
    pub(crate) ptr: TraceInner,
    /// Type-erased shared reference to the handlers, keeping them alive for as
    /// long as this `Trace` exists.
    handlers_copy: Arc<dyn DynamicTypeHelper>,
}
197
/// Owning wrapper around a raw `TRITONSERVER_InferenceTrace` pointer.
///
/// Its `Drop` implementation calls `TRITONSERVER_InferenceTraceDelete` on the
/// pointer unless it is null.
pub(crate) struct TraceInner(pub(crate) *mut sys::TRITONSERVER_InferenceTrace);
// NOTE(review): these impls assert that the Triton trace object may be moved to
// and accessed from other threads — confirm against the Triton C API contract.
unsafe impl Send for TraceInner {}
unsafe impl Sync for TraceInner {}
202
203impl PartialEq for Trace {
204 fn eq(&self, other: &Self) -> bool {
205 let left = match self.id() {
206 Ok(l) => l,
207 Err(err) => {
208 log::warn!("Error getting ID for two Traces comparison: {err}");
209 return false;
210 }
211 };
212 let right = match other.id() {
213 Ok(r) => r,
214 Err(err) => {
215 log::warn!("Error getting ID for two Traces comparison: {err}");
216 return false;
217 }
218 };
219 left == right
220 }
221}
222impl Eq for Trace {}
223
224impl Trace {
225 pub fn new_with_handle<H: TraceHandler, T: TensorTraceHandler>(
241 parent_id: u64,
242 activity_handler: Option<H>,
243 tensor_activity_handler: Option<T>,
244 ) -> Result<Self, Error> {
245 let enable_activity = activity_handler.is_some();
246 let enable_tensor_activity = tensor_activity_handler.is_some();
247
248 let level = match (enable_activity, enable_tensor_activity) {
249 (true, true) => Level::TENSORS | Level::TIMESTAMPS,
250 (true, false) => Level::TIMESTAMPS,
251 (false, true) => Level::TENSORS,
252 (false, false) => Level::DISABLED,
253 };
254
255 let mut ptr = null_mut::<sys::TRITONSERVER_InferenceTrace>();
256 let handlers = Arc::new(TraceCallbackItems {
257 activity_handler,
258 tensor_activity_handler,
259 });
260 let raw_handlers = Arc::into_raw(handlers.clone()) as *mut c_void;
261
262 triton_call!(sys::TRITONSERVER_InferenceTraceTensorNew(
263 &mut ptr as *mut _,
264 level.bits(),
265 parent_id,
266 enable_activity.then_some(activity_wraper::<H, T>),
267 enable_tensor_activity.then_some(tensor_activity_wrapper::<H, T>),
268 Some(delete::<H, T>),
269 raw_handlers,
270 ))?;
271
272 assert!(!ptr.is_null());
273 let trace = Trace {
274 ptr: TraceInner(ptr),
275 handlers_copy: handlers,
276 };
277 Ok(trace)
278 }
279
280 pub fn report_activity<N: AsRef<str>>(
285 &self,
286 timestamp: Duration,
287 activity_name: N,
288 ) -> Result<(), Error> {
289 let name = to_cstring(activity_name)?;
290 triton_call!(sys::TRITONSERVER_InferenceTraceReportActivity(
291 self.ptr.0,
292 timestamp.as_nanos() as _,
293 name.as_ptr()
294 ))
295 }
296
297 pub fn id(&self) -> Result<u64, Error> {
300 let mut id: u64 = 0;
301 triton_call!(
302 sys::TRITONSERVER_InferenceTraceId(self.ptr.0, &mut id as *mut _),
303 id
304 )
305 }
306
307 pub fn parent_id(&self) -> Result<u64, Error> {
311 let mut id: u64 = 0;
312 triton_call!(
313 sys::TRITONSERVER_InferenceTraceParentId(self.ptr.0, &mut id as *mut _),
314 id
315 )
316 }
317
318 pub fn model_name(&self) -> Result<String, Error> {
320 let mut name = null::<c_char>();
321 triton_call!(
322 sys::TRITONSERVER_InferenceTraceModelName(self.ptr.0, &mut name as *mut _),
323 from_char_array(name)
324 )
325 }
326
327 pub fn model_version(&self) -> Result<i64, Error> {
329 let mut version: i64 = 0;
330 triton_call!(
331 sys::TRITONSERVER_InferenceTraceModelVersion(self.ptr.0, &mut version as *mut _),
332 version
333 )
334 }
335
336 pub fn request_id(&self) -> Result<String, Error> {
339 let mut request_id = null::<c_char>();
340
341 triton_call!(
342 sys::TRITONSERVER_InferenceTraceRequestId(self.ptr.0, &mut request_id as *mut _),
343 from_char_array(request_id)
344 )
345 }
346
347 pub fn spawn_child(&self) -> Result<Trace, Error> {
352 let mut trace = null_mut();
353 triton_call!(
354 sys::TRITONSERVER_InferenceTraceSpawnChildTrace(self.ptr.0, &mut trace),
355 Trace {
356 ptr: TraceInner(trace),
357 handlers_copy: self.handlers_copy.clone(),
358 }
359 )
360 }
361
362 pub fn set_context(&mut self, context: String) -> Result<&mut Self, Error> {
364 let context = to_cstring(context)?;
365 triton_call!(
366 sys::TRITONSERVER_InferenceTraceSetContext(self.ptr.0, context.as_ptr()),
367 self
368 )
369 }
370
371 pub fn context(&self) -> Result<String, Error> {
373 let mut context = null::<c_char>();
374 triton_call!(
375 sys::TRITONSERVER_InferenceTraceContext(self.ptr.0, &mut context as *mut _),
376 from_char_array(context)
377 )
378 }
379}
380
381impl Drop for TraceInner {
382 fn drop(&mut self) {
383 if !self.0.is_null() {
384 unsafe { sys::TRITONSERVER_InferenceTraceDelete(self.0) };
385 }
386 }
387}
388
389unsafe extern "C" fn delete<H: TraceHandler, T: TensorTraceHandler>(
390 this: *mut sys::TRITONSERVER_InferenceTrace,
391 userp: *mut c_void,
392) {
393 if !userp.is_null() && !this.is_null() {
394 sys::TRITONSERVER_InferenceTraceDelete(this);
395 Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
396 }
397}
398
399unsafe extern "C" fn activity_wraper<H: TraceHandler, T: TensorTraceHandler>(
400 trace: *mut sys::TRITONSERVER_InferenceTrace,
401 activity: sys::TRITONSERVER_InferenceTraceActivity,
402 timestamp_ns: u64,
403 userp: *mut ::std::os::raw::c_void,
404) {
405 if !userp.is_null() {
406 let handle = Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
407 let foo_trace = Trace {
408 ptr: TraceInner(trace),
409 handlers_copy: handle.clone(),
410 };
411 let activity: Activity = transmute(activity);
412
413 let timestamp = Duration::from_nanos(timestamp_ns);
414
415 if let Some(activity_handle) = handle.activity_handler.as_ref() {
416 activity_handle.trace_activity(&foo_trace, activity, timestamp)
417 };
418
419 forget(handle);
421 forget(foo_trace.ptr);
422 }
423}
424
425unsafe extern "C" fn tensor_activity_wrapper<H: TraceHandler, T: TensorTraceHandler>(
426 trace: *mut sys::TRITONSERVER_InferenceTrace,
427 activity: sys::TRITONSERVER_InferenceTraceActivity,
428 name: *const ::std::os::raw::c_char,
429 datatype: sys::TRITONSERVER_DataType,
430 base: *const ::std::os::raw::c_void,
431 byte_size: usize,
432 shape: *const i64,
433 dim_count: u64,
434 memory_type: sys::TRITONSERVER_MemoryType,
435 _memory_type_id: i64,
436 userp: *mut ::std::os::raw::c_void,
437) {
438 if !userp.is_null() {
439 let handle = Arc::from_raw(userp as *const TraceCallbackItems<H, T>);
440
441 let foo_trace = Trace {
442 ptr: TraceInner(trace),
443 handlers_copy: handle.clone(),
444 };
445 let activity: Activity = transmute(activity);
446
447 let data_type = unsafe { transmute::<u32, crate::memory::DataType>(datatype) };
448 let memory_type: MemoryType = unsafe { transmute(memory_type) };
449
450 let tensor_shape = Shape {
451 name: from_char_array(name),
452 datatype: data_type,
453 dims: slice::from_raw_parts(shape, dim_count as _).to_vec(),
454 };
455
456 let tensor_data = Buffer {
457 ptr: base as *mut _,
458 len: byte_size,
459 data_type,
460 memory_type,
461 owned: false,
462 };
463
464 if let Some(tensor_activity_handler) = handle.tensor_activity_handler.as_ref() {
465 tensor_activity_handler.trace_tensor_activity(
466 &foo_trace,
467 activity,
468 &tensor_data,
469 tensor_shape,
470 )
471 };
472
473 forget(handle);
474 forget(foo_trace.ptr);
475 }
477}