1#![allow(clippy::arc_with_non_send_sync)]
2
3use std::{
4 collections::HashMap,
5 ffi::{c_void, CStr},
6 hint,
7 mem::transmute,
8 os::raw::c_char,
9 ptr::{null, null_mut},
10 slice::from_raw_parts,
11 sync::Arc,
12};
13
14use log::trace;
15use tokio::runtime::Handle;
16
17use crate::{
18 allocator::Allocator,
19 error::{Error, CSTR_CONVERT_ERROR_PLUG},
20 from_char_array,
21 memory::{Buffer, DataType, MemoryType},
22 parameter::{Parameter, ParameterContent},
23 request::infer::InferenceError,
24 sys,
25};
26
/// A single named output tensor of an inference [`Response`].
///
/// Holds a non-owning view into memory that lives inside the parent Triton
/// response; the `Arc` keeps that response (and therefore the buffer memory)
/// alive for as long as this `Output` exists.
#[derive(Debug)]
pub struct Output {
    // Tensor name as reported by Triton.
    pub name: String,
    // Tensor dimensions; empty when Triton reports a zero-dim shape.
    pub shape: Vec<i64>,
    // Non-owning (`owned: false`) view into the response's output memory.
    buffer: Buffer,
    // Keeps the underlying TRITONSERVER_InferenceResponse alive.
    parent_response: Arc<InferenceResponseWrapper>,
    // Output index within the parent response; used for label lookups.
    index_in_parent_response: u32,
}
44
// SAFETY: NOTE(review): `Output` carries raw Triton pointers (the response via
// `parent_response` and the buffer's data pointer). These impls assume the
// Triton C API allows a completed response and its output buffers to be read
// from any thread — confirm against Triton's threading guarantees.
unsafe impl Send for Output {}
unsafe impl Sync for Output {}
48
49impl Output {
50 pub fn get_buffer(&self) -> &Buffer {
56 &self.buffer
57 }
58
59 pub fn memory_type(&self) -> MemoryType {
61 self.buffer.memory_type
62 }
63
64 pub fn data_type(&self) -> DataType {
66 self.buffer.data_type
67 }
68
69 pub fn classification_label(&self, class: u64) -> Result<String, Error> {
71 self.parent_response
72 .classification_label(self.index_in_parent_response, class)
73 }
74}
75
/// Result of a completed inference request.
///
/// Wraps the raw Triton response pointer together with the output tensors and
/// parameters parsed from it, and the allocator that provided the output
/// buffers (so they can be reclaimed via `return_buffers`).
pub struct Response {
    // Output tensors parsed eagerly in `new`, in Triton output-index order.
    outputs: Vec<Output>,
    // Shared owner of the raw TRITONSERVER_InferenceResponse pointer.
    triton_ptr_wrapper: Arc<InferenceResponseWrapper>,
    // Number of output buffers the allocator was asked to provide.
    buffers_count: u32,
    // Allocator holding the output buffers until they are returned.
    allocator: Arc<Allocator>,
    // Response parameters reported by Triton, read eagerly in `new`.
    parameters: Vec<Parameter>,
}
85
// SAFETY: NOTE(review): `Response` owns raw Triton pointers through
// `triton_ptr_wrapper` and its outputs. These impls assume the Triton C API
// permits cross-thread access to a completed response — confirm against
// Triton's threading guarantees.
unsafe impl Send for Response {}
unsafe impl Sync for Response {}
88
89impl Response {
90 pub(crate) fn new(
92 ptr: *mut sys::TRITONSERVER_InferenceResponse,
93 buffers_count: u32,
94 allocator: Arc<Allocator>,
95 runtime: Handle,
96 ) -> Result<Self, InferenceError> {
97 trace!("Response::new() is called");
98 let wrapper = Arc::new(InferenceResponseWrapper(ptr));
99
100 if let Some(error) = wrapper.error() {
102 drop(wrapper);
103
104 if allocator.is_alloc_called() {
105 while allocator
108 .0
109 .returned_buffers
110 .load(std::sync::atomic::Ordering::Relaxed)
111 < buffers_count
112 {
113 hint::spin_loop()
114 }
115 }
116
117 let bufs = std::thread::spawn(move || {
118 runtime.block_on(async move {
119 let mut bufs = allocator.0.output_buffers.write().await;
120 bufs.drain().collect()
121 })
122 })
123 .join()
124 .unwrap();
125
126 return Err(InferenceError {
127 error,
128 output_buffers: bufs,
129 });
130 }
131
132 let output_count = wrapper.output_count()?;
133
134 if output_count != buffers_count {
135 log::error!(
136 "output_count: {output_count} != count of assigned output buffers: {buffers_count}",
137 );
138 }
139
140 let mut outputs = Vec::new();
141 let mut output_ids = Vec::new();
142 trace!("Response::new() obtaining outputs");
143 for output_id in 0..output_count {
144 let output = wrapper.output(output_id)?;
145 output_ids.push(output.name.clone());
146 outputs.push(output);
147 }
148
149 let mut parameters = Vec::new();
150 for parameter_id in 0..wrapper.parameter_count()? {
151 parameters.push(wrapper.parameter(parameter_id)?);
152 }
153
154 Ok(Self {
155 outputs,
156 triton_ptr_wrapper: wrapper,
157 buffers_count,
158 allocator,
159 parameters,
160 })
161 }
162
163 pub fn get_outputs(&self) -> &[Output] {
165 &self.outputs
166 }
167
168 pub fn get_output<O: AsRef<str>>(&self, output_name: O) -> Option<&Output> {
170 self.outputs.iter().find(|o| o.name == output_name.as_ref())
171 }
172
173 pub async fn return_buffers(self) -> Result<HashMap<String, Buffer>, Error> {
176 drop(self.outputs);
182 drop(self.triton_ptr_wrapper);
183 trace!("return_buffer() awaiting on output receivers");
184 let buffers_count = self.buffers_count;
185
186 while self
187 .allocator
188 .0
189 .returned_buffers
190 .load(std::sync::atomic::Ordering::Relaxed)
191 < buffers_count
192 {
193 hint::spin_loop()
194 }
195
196 let res = {
197 let mut bufs = self.allocator.0.output_buffers.write().await;
198 bufs.drain().collect()
199 };
200
201 drop(self.allocator);
202 Ok(res)
203 }
204
205 pub fn model(&self) -> Result<(&str, i64), Error> {
207 self.triton_ptr_wrapper.model()
208 }
209
210 pub fn id(&self) -> Result<String, Error> {
212 self.triton_ptr_wrapper.id()
213 }
214
215 pub fn parameters(&self) -> Vec<Parameter> {
217 self.parameters.clone()
218 }
219}
220
/// Thin RAII owner of a raw `TRITONSERVER_InferenceResponse` pointer.
///
/// All accessors call directly into the Triton C API; the response is deleted
/// when the wrapper is dropped.
#[derive(Debug)]
struct InferenceResponseWrapper(*mut sys::TRITONSERVER_InferenceResponse);
223
impl InferenceResponseWrapper {
    /// Error Triton recorded for this response, if any.
    ///
    /// The returned `Error` is built with `owned: false`: the error pointer
    /// belongs to the response and must not be deleted by the `Error`.
    fn error(&self) -> Option<Error> {
        let err = unsafe { sys::TRITONSERVER_InferenceResponseError(self.0) };
        if err.is_null() {
            None
        } else {
            Some(Error {
                ptr: err,
                owned: false,
            })
        }
    }

    /// Model name and version that produced this response.
    ///
    /// The `&str` borrows Triton-owned memory and is valid only while this
    /// wrapper is alive; invalid UTF-8 is replaced with a placeholder string
    /// rather than failing the call.
    fn model(&self) -> Result<(&str, i64), Error> {
        let mut name = null::<c_char>();
        let mut version: i64 = 0;
        triton_call!(sys::TRITONSERVER_InferenceResponseModel(
            self.0,
            &mut name as *mut _,
            &mut version as *mut _,
        ))?;

        assert!(!name.is_null());
        Ok((
            unsafe { CStr::from_ptr(name) }
                .to_str()
                .unwrap_or(CSTR_CONVERT_ERROR_PLUG),
            version,
        ))
    }

    /// Request id associated with this response.
    fn id(&self) -> Result<String, Error> {
        let mut id = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceResponseId(self.0, &mut id as *mut _),
            from_char_array(id)
        )
    }

    /// Number of parameters attached to this response.
    fn parameter_count(&self) -> Result<u32, Error> {
        let mut count: u32 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceResponseParameterCount(self.0, &mut count as *mut _),
            count
        )
    }

    /// Read the parameter at `index` (name, type, value) and convert it into
    /// a crate-level [`Parameter`].
    fn parameter(&self, index: u32) -> Result<Parameter, Error> {
        let mut name = null::<c_char>();
        let mut kind: sys::TRITONSERVER_ParameterType = 0;
        let mut value = null::<c_void>();
        triton_call!(sys::TRITONSERVER_InferenceResponseParameter(
            self.0,
            index,
            &mut name as *mut _,
            &mut kind as *mut _,
            &mut value as *mut _,
        ))?;

        assert!(!name.is_null());
        assert!(!value.is_null());
        let name = unsafe { CStr::from_ptr(name) }
            .to_str()
            .unwrap_or(CSTR_CONVERT_ERROR_PLUG);
        // `value` is an untyped pointer; `kind` tells us how to reinterpret it.
        let value = match kind {
            sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_STRING => {
                ParameterContent::String(
                    unsafe { CStr::from_ptr(value as *const c_char) }
                        .to_str()
                        .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
                        .to_string(),
                )
            }
            sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_INT => {
                ParameterContent::Int(unsafe { *(value as *mut i64) })
            }
            sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_BOOL => {
                ParameterContent::Bool(unsafe { *(value as *mut bool) })
            }
            // NOTE(review): Triton also defines DOUBLE and BYTES parameter
            // types in newer releases; reaching this arm would panic —
            // confirm the bound Triton version can only return the three
            // kinds handled above.
            _ => unreachable!(),
        };
        Parameter::new(name, value)
    }

    /// Number of output tensors in this response.
    fn output_count(&self) -> Result<u32, Error> {
        let mut count: u32 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceResponseOutputCount(self.0, &mut count as *mut _),
            count
        )
    }

    /// Read output tensor `index` and wrap it as an [`Output`] that keeps this
    /// response alive through an `Arc` clone of `self`.
    fn output(self: &Arc<Self>, index: u32) -> Result<Output, Error> {
        // Out-parameters filled in by the Triton C call below.
        let mut name = null::<c_char>();
        let mut data_type: sys::TRITONSERVER_DataType = 0;
        let mut shape = null::<i64>();
        let mut dim_count: u64 = 0;
        let mut base = null::<c_void>();
        let mut byte_size: libc::size_t = 0;
        let mut memory_type: sys::TRITONSERVER_MemoryType = 0;
        let mut memory_type_id: i64 = 0;
        let mut userp = null_mut::<c_void>();

        triton_call!(sys::TRITONSERVER_InferenceResponseOutput(
            self.0,
            index,
            &mut name as *mut _,
            &mut data_type as *mut _,
            &mut shape as *mut _,
            &mut dim_count as *mut _,
            &mut base as *mut _,
            &mut byte_size as *mut _,
            &mut memory_type as *mut _,
            &mut memory_type_id as *mut _,
            &mut userp as *mut _,
        ))?;

        assert!(!name.is_null());
        assert!(!base.is_null());

        let name = unsafe { CStr::from_ptr(name) }
            .to_str()
            .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
            .to_string();

        let shape = if dim_count == 0 {
            log::trace!(
                "Model returned output '{name}' of shape []. Consider removing this output"
            );
            Vec::new()
        } else {
            // Copy the Triton-owned shape array into an owned Vec.
            unsafe { from_raw_parts(shape, dim_count as usize) }.to_vec()
        };
        // NOTE(review): these transmutes assume DataType/MemoryType share
        // their discriminant values with the corresponding Triton C enums —
        // confirm the crate enums mirror tritonserver.h.
        let data_type = unsafe { transmute::<u32, crate::memory::DataType>(data_type) };
        let memory_type: MemoryType = unsafe { transmute(memory_type) };

        // Non-owning view: the memory belongs to the response and is freed
        // when the wrapper (held via `parent_response`) is dropped.
        let buffer = Buffer {
            ptr: base as *mut _,
            len: byte_size as usize,
            data_type,
            memory_type,
            owned: false,
        };
        Ok(Output {
            name,
            shape,
            buffer,
            index_in_parent_response: index,
            parent_response: self.clone(),
        })
    }

    /// Classification label for `class` on output `index`, as produced by a
    /// model configured with classification outputs.
    fn classification_label(&self, index: u32, class: u64) -> Result<String, Error> {
        let mut label = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceResponseOutputClassificationLabel(
                self.0,
                index,
                class as usize,
                &mut label as *mut _,
            ),
            from_char_array(label)
        )
    }
}
401
402impl Drop for InferenceResponseWrapper {
403 fn drop(&mut self) {
404 if !self.0.is_null() {
405 unsafe { sys::TRITONSERVER_InferenceResponseDelete(self.0) };
406 }
407 }
408}