tritonserver_rs/
response.rs

#![allow(clippy::arc_with_non_send_sync)]

use std::{
    collections::HashMap,
    ffi::{c_void, CStr},
    hint,
    mem::transmute,
    os::raw::c_char,
    ptr::{null, null_mut},
    slice::from_raw_parts,
    sync::Arc,
};

use log::trace;
use tokio::runtime::Handle;

use crate::{
    allocator::Allocator,
    error::{Error, CSTR_CONVERT_ERROR_PLUG},
    from_char_array,
    memory::{Buffer, DataType, MemoryType},
    parameter::{Parameter, ParameterContent},
    request::infer::InferenceError,
    sys,
};

/// Output tensor of the model.
///
/// Must not outlive the parent [Response].
///
/// Each output is a reference to the part of the output buffer
/// (passed to the request via an [Allocator]) that contains the embedding.
/// It may be smaller than the initial buffer if Triton does not need the whole buffer.
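///
/// A minimal sketch of inspecting the outputs (assuming a completed
/// inference produced `response`):
///
/// ```ignore
/// for output in response.get_outputs() {
///     println!("{}: shape {:?}", output.name, output.shape);
/// }
/// ```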
#[derive(Debug)]
pub struct Output {
    /// Name of the output tensor.
    pub name: String,
    /// Shape (dims) of the output tensor.
    pub shape: Vec<i64>,
    buffer: Buffer,
    parent_response: Arc<InferenceResponseWrapper>,
    index_in_parent_response: u32,
}
// Output cannot be copied, and its pointer is not reachable from public code, so this is safe.
unsafe impl Send for Output {}
unsafe impl Sync for Output {}

impl Output {
    /// Get the [Buffer] containing the inference result (embedding).
    ///
    /// # Safety
    /// Do not mutate the data of the returned value.
    /// If a mutable (owned) [Buffer] is needed, use [Response::return_buffers].
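    ///
    /// A minimal sketch (assuming a completed inference produced `response`;
    /// the output name is illustrative):
    ///
    /// ```ignore
    /// let output = response.get_output("embedding").unwrap();
    /// let buffer = output.get_buffer();
    /// // Read-only view into Triton's memory; the buffer is not owned.
    /// println!("{buffer:?}");
    /// ```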
    pub fn get_buffer(&self) -> &Buffer {
        &self.buffer
    }

    /// Get memory type of the output tensor.
    pub fn memory_type(&self) -> MemoryType {
        self.buffer.memory_type
    }

    /// Get data type of the output tensor.
    pub fn data_type(&self) -> DataType {
        self.buffer.data_type
    }

    /// Get a classification label associated with the output.
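    ///
    /// A minimal sketch (assumes the model is configured with classification
    /// labels; the class index `0` is illustrative):
    ///
    /// ```ignore
    /// let label = output.classification_label(0)?;
    /// println!("label: {label}");
    /// ```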
    pub fn classification_label(&self, class: u64) -> Result<String, Error> {
        self.parent_response
            .classification_label(self.index_in_parent_response, class)
    }
}

/// Result of a successful inference.
pub struct Response {
    outputs: Vec<Output>,
    triton_ptr_wrapper: Arc<InferenceResponseWrapper>,
    buffers_count: u32,
    /// The allocator is needed here because, after `InferenceResponseWrapper::drop()`,
    /// Triton starts calling `release()`, which involves the allocator.
    /// Therefore, it must not be destroyed before that point.
    allocator: Arc<Allocator>,
    parameters: Vec<Parameter>,
}

unsafe impl Send for Response {}
unsafe impl Sync for Response {}

impl Response {
    /// Read the inference result and obtain the outputs.
    pub(crate) fn new(
        ptr: *mut sys::TRITONSERVER_InferenceResponse,
        buffers_count: u32,
        allocator: Arc<Allocator>,
        runtime: Handle,
    ) -> Result<Self, InferenceError> {
        trace!("Response::new() is called");
        let wrapper = Arc::new(InferenceResponseWrapper(ptr));

        // An error occurred during inference.
        if let Some(error) = wrapper.error() {
            drop(wrapper);

            if allocator.is_alloc_called() {
                // Wait until all allocated buffers have been released.
                while allocator
                    .0
                    .returned_buffers
                    .load(std::sync::atomic::Ordering::Relaxed)
                    < buffers_count
                {
                    hint::spin_loop()
                }
            }

            let bufs = std::thread::spawn(move || {
                runtime.block_on(async move {
                    let mut bufs = allocator.0.output_buffers.write().await;
                    bufs.drain().collect()
                })
            })
            .join()
            .unwrap();

            return Err(InferenceError {
                error,
                output_buffers: bufs,
            });
        }

        let output_count = wrapper.output_count()?;

        if output_count != buffers_count {
            log::error!(
                "output_count: {output_count} != count of assigned output buffers: {buffers_count}",
            );
        }

        trace!("Response::new() obtaining outputs");
        let mut outputs = Vec::new();
        for output_id in 0..output_count {
            outputs.push(wrapper.output(output_id)?);
        }

        let mut parameters = Vec::new();
        for parameter_id in 0..wrapper.parameter_count()? {
            parameters.push(wrapper.parameter(parameter_id)?);
        }

        Ok(Self {
            outputs,
            triton_ptr_wrapper: wrapper,
            buffers_count,
            allocator,
            parameters,
        })
    }

    /// The results of the inference.
    pub fn get_outputs(&self) -> &[Output] {
        &self.outputs
    }

    /// Get the `output_name` result of the inference, if present.
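    ///
    /// A minimal sketch (assuming the model declares an output named `embedding`):
    ///
    /// ```ignore
    /// if let Some(output) = response.get_output("embedding") {
    ///     assert_eq!(output.name, "embedding");
    /// }
    /// ```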
    pub fn get_output<O: AsRef<str>>(&self, output_name: O) -> Option<&Output> {
        self.outputs.iter().find(|o| o.name == output_name.as_ref())
    }

    /// Deconstruct the Response and get all the allocated output buffers back. \
    /// If you only need an immutable result of the inference, use the [Response::get_outputs] or [Response::get_output] method.
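    ///
    /// A minimal sketch of reclaiming the buffers (the output name is an
    /// illustrative assumption):
    ///
    /// ```ignore
    /// let mut buffers = response.return_buffers().await?;
    /// // The buffers are owned again and may be mutated or reused.
    /// let embedding = buffers.remove("embedding").unwrap();
    /// ```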
    pub async fn return_buffers(self) -> Result<HashMap<String, Buffer>, Error> {
        // Triton will call `allocator::release()`
        // (and therefore we can get the output buffers back)
        // ONLY after we call sys::TRITONSERVER_InferenceResponseDelete(),
        // which is the wrapper destructor.
        // Each Output holds an Arc to the wrapper, so drop the outputs first.
        drop(self.outputs);
        drop(self.triton_ptr_wrapper);
        trace!("return_buffers() awaiting on output receivers");
        let buffers_count = self.buffers_count;

        while self
            .allocator
            .0
            .returned_buffers
            .load(std::sync::atomic::Ordering::Relaxed)
            < buffers_count
        {
            hint::spin_loop()
        }

        let res = {
            let mut bufs = self.allocator.0.output_buffers.write().await;
            bufs.drain().collect()
        };

        drop(self.allocator);
        Ok(res)
    }

    /// Get the model name and version used to produce the response.
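    ///
    /// A minimal sketch, assuming a `response` in scope:
    ///
    /// ```ignore
    /// let (model_name, version) = response.model()?;
    /// ```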
    pub fn model(&self) -> Result<(&str, i64), Error> {
        self.triton_ptr_wrapper.model()
    }

    /// Get the ID of the request corresponding to the response.
    pub fn id(&self) -> Result<String, Error> {
        self.triton_ptr_wrapper.id()
    }

    /// Get all information about the response parameters.
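    ///
    /// A minimal sketch (which parameters are present depends on the model
    /// and backend):
    ///
    /// ```ignore
    /// let params = response.parameters();
    /// println!("{} response parameters", params.len());
    /// ```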
    pub fn parameters(&self) -> Vec<Parameter> {
        self.parameters.clone()
    }
}

#[derive(Debug)]
struct InferenceResponseWrapper(*mut sys::TRITONSERVER_InferenceResponse);

// If these methods ever need to be made public again,
// lifetimes must be added to Output and Parameter.
impl InferenceResponseWrapper {
    /// Return the error status of an inference response.
    /// Returns Some(Error) on failure and None on success.
    fn error(&self) -> Option<Error> {
        let err = unsafe { sys::TRITONSERVER_InferenceResponseError(self.0) };
        if err.is_null() {
            None
        } else {
            Some(Error {
                ptr: err,
                owned: false,
            })
        }
    }

    /// Get model name and version used to produce a response.
    fn model(&self) -> Result<(&str, i64), Error> {
        let mut name = null::<c_char>();
        let mut version: i64 = 0;
        triton_call!(sys::TRITONSERVER_InferenceResponseModel(
            self.0,
            &mut name as *mut _,
            &mut version as *mut _,
        ))?;

        assert!(!name.is_null());
        Ok((
            unsafe { CStr::from_ptr(name) }
                .to_str()
                .unwrap_or(CSTR_CONVERT_ERROR_PLUG),
            version,
        ))
    }

    /// Get the ID of the request corresponding to a response.
    fn id(&self) -> Result<String, Error> {
        let mut id = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceResponseId(self.0, &mut id as *mut _),
            from_char_array(id)
        )
    }

    /// Get the number of parameters available in the response.
    fn parameter_count(&self) -> Result<u32, Error> {
        let mut count: u32 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceResponseParameterCount(self.0, &mut count as *mut _),
            count
        )
    }

    /// Get all information about a parameter.
    fn parameter(&self, index: u32) -> Result<Parameter, Error> {
        let mut name = null::<c_char>();
        let mut kind: sys::TRITONSERVER_ParameterType = 0;
        let mut value = null::<c_void>();
        triton_call!(sys::TRITONSERVER_InferenceResponseParameter(
            self.0,
            index,
            &mut name as *mut _,
            &mut kind as *mut _,
            &mut value as *mut _,
        ))?;

        assert!(!name.is_null());
        assert!(!value.is_null());
        let name = unsafe { CStr::from_ptr(name) }
            .to_str()
            .unwrap_or(CSTR_CONVERT_ERROR_PLUG);
        let value = match kind {
            sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_STRING => {
                ParameterContent::String(
                    unsafe { CStr::from_ptr(value as *const c_char) }
                        .to_str()
                        .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
                        .to_string(),
                )
            }
            sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_INT => {
                ParameterContent::Int(unsafe { *(value as *mut i64) })
            }
            sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_BOOL => {
                ParameterContent::Bool(unsafe { *(value as *mut bool) })
            }
            _ => unreachable!(),
        };
        Parameter::new(name, value)
    }

    /// Get the number of outputs available in the response.
    fn output_count(&self) -> Result<u32, Error> {
        let mut count: u32 = 0;
        triton_call!(
            sys::TRITONSERVER_InferenceResponseOutputCount(self.0, &mut count as *mut _),
            count
        )
    }
    /// Build an [Output] view over the response's output tensor at `index`.
    fn output(self: &Arc<Self>, index: u32) -> Result<Output, Error> {
        let mut name = null::<c_char>();
        let mut data_type: sys::TRITONSERVER_DataType = 0;
        let mut shape = null::<i64>();
        let mut dim_count: u64 = 0;
        let mut base = null::<c_void>();
        let mut byte_size: libc::size_t = 0;
        let mut memory_type: sys::TRITONSERVER_MemoryType = 0;
        let mut memory_type_id: i64 = 0;
        let mut userp = null_mut::<c_void>();

        triton_call!(sys::TRITONSERVER_InferenceResponseOutput(
            self.0,
            index,
            &mut name as *mut _,
            &mut data_type as *mut _,
            &mut shape as *mut _,
            &mut dim_count as *mut _,
            &mut base as *mut _,
            &mut byte_size as *mut _,
            &mut memory_type as *mut _,
            &mut memory_type_id as *mut _,
            &mut userp as *mut _,
        ))?;

        assert!(!name.is_null());
        assert!(!base.is_null());

        let name = unsafe { CStr::from_ptr(name) }
            .to_str()
            .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
            .to_string();

        let shape = if dim_count == 0 {
            log::trace!(
                "Model returned output '{name}' of shape []. Consider removing this output"
            );
            Vec::new()
        } else {
            unsafe { from_raw_parts(shape, dim_count as usize) }.to_vec()
        };
        let data_type = unsafe { transmute::<u32, crate::memory::DataType>(data_type) };
        let memory_type: MemoryType = unsafe { transmute(memory_type) };

        // The buffer is not owned: we cannot move or mutate it,
        // we only borrow it from Triton.
        let buffer = Buffer {
            ptr: base as *mut _,
            len: byte_size as usize,
            data_type,
            memory_type,
            owned: false,
        };
        Ok(Output {
            name,
            shape,
            buffer,
            index_in_parent_response: index,
            parent_response: self.clone(),
        })
    }

    /// Get a classification label associated with an output for a given index.
    fn classification_label(&self, index: u32, class: u64) -> Result<String, Error> {
        let mut label = null::<c_char>();
        triton_call!(
            sys::TRITONSERVER_InferenceResponseOutputClassificationLabel(
                self.0,
                index,
                class as usize,
                &mut label as *mut _,
            ),
            from_char_array(label)
        )
    }
}

impl Drop for InferenceResponseWrapper {
    fn drop(&mut self) {
        if !self.0.is_null() {
            unsafe { sys::TRITONSERVER_InferenceResponseDelete(self.0) };
        }
    }
}