// tritonserver_rs/request/infer.rs

1use std::{
2    collections::HashMap,
3    ffi::c_void,
4    ptr::null_mut,
5    sync::{atomic::AtomicBool, Arc},
6};
7
8use log::trace;
9use tokio::{
10    runtime::Handle,
11    sync::oneshot::{self, Receiver},
12};
13
14use crate::{
15    allocator::Allocator,
16    error::{Error, ErrorCode},
17    memory::Buffer,
18    sys, Request, Response,
19};
20
/// Inference result error. Contains output buffers that were allocated by the user-provided
/// Allocator during the inference.
#[derive(Debug)]
pub struct InferenceError {
    /// The underlying Triton error.
    pub error: Error,
    /// Output buffers already allocated before the failure, keyed by output name.
    pub output_buffers: HashMap<String, Buffer>,
}
27
28impl From<Error> for InferenceError {
29    fn from(error: Error) -> Self {
30        Self {
31            error,
32            output_buffers: HashMap::new(),
33        }
34    }
35}
36
37impl std::fmt::Display for InferenceError {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        self.error.fmt(f)
40    }
41}
42
// Marker impl: makes `InferenceError` usable as a `std::error::Error` trait object
// (`source()` defaults to `None`).
impl std::error::Error for InferenceError {}
44
/// Future that returns the inference response. \
/// The request can be cancelled by dropping this structure.
///
/// Also the input buffers assigned to the request can be returned via [get_input_release](ResponseFuture::get_input_release).
pub struct ResponseFuture {
    // Receives the completed `Response` (or an `InferenceError`) from the response callback.
    pub(super) response_receiver: Receiver<Result<Response, InferenceError>>,
    // Taken out by `get_input_release`; yields the input buffers once Triton releases the request.
    pub(super) input_release: Option<InputRelease>,
    // Shared cancellation handle for the raw Triton request pointer.
    pub(super) request_ptr: Arc<RequestCanceller>,
}
54
/// Shared handle to the raw Triton request pointer, used to cancel an in-flight request
/// (see [ResponseFuture]: dropping it cancels the request).
pub(super) struct RequestCanceller {
    // Tracks whether the inference has already completed; initialized to `false` in `infer_async`.
    pub(crate) is_inferenced: AtomicBool,
    // Raw pointer to the Triton request; owned by Triton while the inference is running.
    pub(crate) request_ptr: *mut sys::TRITONSERVER_InferenceRequest,
}
// SAFETY: NOTE(review) — the raw pointer makes this type !Send/!Sync by default. These impls
// assume the Triton C API tolerates the request pointer being touched (e.g. for cancellation)
// from a thread other than the one that created it — confirm against Triton's documented
// inference-request thread-safety guarantees.
unsafe impl Send for RequestCanceller {}
unsafe impl Sync for RequestCanceller {}
61
/// Struct that returns input buffers assigned to the request. \
/// Note: an input buffer can be released at any time from the start of the inference
/// to the end of it.
///
/// Input buffers will be dropped if no one awaits on this struct.
pub struct InputRelease(pub(super) oneshot::Receiver<HashMap<String, Buffer>>);
68
69/// Start inference.
70impl Request<'_> {
71    /// Perform inference using the metadata and inputs supplied by the Request(self). \
72    /// If the function returns success,
73    /// the returned struct can be used to get results (.await) of the inference and
74    /// to return input buffers after the inference start [ResponseFuture::get_input_release]. \
75    /// Note: output buffer will be returned with [Response] or [InferenceError]. \
76    pub fn infer_async(mut self) -> Result<ResponseFuture, Error> {
77        // Check on all buffers are set.
78        if self.input.is_empty() {
79            return Err(Error::new(
80                ErrorCode::NotFound,
81                "Request's output buffer is not set",
82            ));
83        }
84        if self.custom_allocator.is_none() {
85            return Err(Error::new(
86                ErrorCode::NotFound,
87                "Request's output buffers allocator is not set",
88            ));
89        }
90        let custom_allocator = self.custom_allocator.take().unwrap();
91        let trace = self.custom_trace.take();
92
93        // Add outputs.
94        let datatype_hints = self.add_outputs()?;
95        let outputs_count = self.server.get_model(&self.model_name)?.outputs.len();
96
97        let runtime = self.server.runtime.clone();
98        let request_ptr = self.ptr;
99        let server_ptr = self.server.ptr.as_mut_ptr();
100
101        // Канал, по которому мы вернем input buffer пользователю.
102        let (input_tx, input_rx) = oneshot::channel();
103        // На всякий случай сохраним указатель, в случае ошибки sys::TRITONSERVER_InferenceRequestSetReleaseCallback
104        // разыменуем его и правильно дропнем Request.
105        let boxed_request_input_recover = Box::into_raw(Box::new((self, input_tx)));
106        let drop_boxed_request = |boxed_request: *mut (Request, _)| {
107            let (_restored_request, _) = unsafe { *Box::from_raw(boxed_request) };
108        };
109
110        // Здесь мы отдаем Request, он нам вернется в методе release_callback.
111        // Там же будет возвращен input_buffer.
112        let err = unsafe {
113            sys::TRITONSERVER_InferenceRequestSetReleaseCallback(
114                request_ptr,
115                Some(release_callback),
116                boxed_request_input_recover as *mut _,
117            )
118        };
119
120        if !err.is_null() {
121            drop_boxed_request(boxed_request_input_recover);
122
123            let err = Error {
124                ptr: err,
125                owned: true,
126            };
127            return Err(err);
128        }
129
130        // Allocator отправляется в alloc -> release, там он выдает запрашиваемые тритоном буферы в alloc и шлет их обратно в release.
131        // Так как Allocator используется тритоном в методе release, который вызывается после удаления Response,
132        // необходимо отправить алокатор в response_wrapper -> Response, чтобы Arc не дропнулся раньше времени.
133        // Имена буферов отправляется в response_wrapper, на нем будем ждать возвращенные буферы для Response.
134        let allocator = Arc::new(Allocator::new(
135            custom_allocator,
136            datatype_hints,
137            runtime.clone(),
138        )?);
139
140        let allocator_ptr = Arc::as_ptr(&allocator);
141        // response_tx отправляется в response_wrapper,
142        // когда там сконструируется Response, он будет положен в tx.
143        // response_rx отправляется юзеру внутри ResponseFuture, он на нем await-ится.
144        let (response_tx, response_rx) = oneshot::channel();
145
146        triton_call!(sys::TRITONSERVER_InferenceRequestSetResponseCallback(
147            request_ptr,
148            allocator.get_allocator(),
149            allocator_ptr as *mut c_void,
150            Some(responce_wrapper),
151            Box::into_raw(Box::new(ResponseCallbackItems {
152                response_tx,
153                allocator,
154                outputs_count,
155                runtime,
156            })) as *mut _,
157        ))?;
158
159        let trace_ptr = trace
160            .as_ref()
161            .map(|trace| trace.ptr.0)
162            .unwrap_or_else(null_mut);
163
164        triton_call!(sys::TRITONSERVER_ServerInferAsync(
165            server_ptr,
166            request_ptr,
167            trace_ptr
168        ))?;
169
170        if let Some(trace) = trace {
171            std::mem::forget(trace.ptr);
172        }
173
174        Ok(ResponseFuture {
175            response_receiver: response_rx,
176            input_release: Some(InputRelease(input_rx)),
177            request_ptr: Arc::new(RequestCanceller {
178                request_ptr,
179                is_inferenced: AtomicBool::new(false),
180            }),
181        })
182    }
183}
184
/// Everything `responce_wrapper` needs to build and deliver the [Response],
/// boxed into the `user_data` pointer of the Triton response callback.
struct ResponseCallbackItems {
    // Oneshot sender delivering the inference result to the awaiting ResponseFuture.
    response_tx: oneshot::Sender<Result<Response, InferenceError>>,
    // Kept here so the Arc is not dropped while Triton may still use the allocator.
    allocator: Arc<Allocator>,
    // Number of model outputs, passed on to `Response::new`.
    outputs_count: usize,
    // Tokio runtime handle handed to the Response.
    runtime: Handle,
}
191
192/// C-code returns the ownership on Request using this method.
193unsafe extern "C" fn release_callback(
194    ptr: *mut sys::TRITONSERVER_InferenceRequest,
195    _flags: u32,
196    user_data: *mut c_void,
197) {
198    trace!("release_callback is called");
199    assert!(!ptr.is_null());
200    assert!(!user_data.is_null());
201
202    let (mut request, input_tx) = *Box::from_raw(user_data as *mut (Request, oneshot::Sender<_>));
203    // Drain the input buffers
204    let mut buffers = HashMap::new();
205    std::mem::swap(&mut buffers, &mut request.input);
206
207    if input_tx.send(buffers).is_err() {
208        log::debug!("InputRelease was dropped before the input buffers returned from triton. Input buffers will be dropped");
209    }
210
211    assert_eq!(request.ptr, ptr);
212    trace!("release_callback is ended");
213}
214
215/// C-code calls this method when Response is ready.
216unsafe extern "C" fn responce_wrapper(
217    response: *mut sys::TRITONSERVER_InferenceResponse,
218    _flags: u32,
219    user_data: *mut c_void,
220) {
221    trace!("response wrapper is called");
222    assert!(!response.is_null());
223    assert!(!user_data.is_null());
224
225    // Allocator присылали сюда только для того, чтобы он не дропнулся во время реквеста.
226    let ResponseCallbackItems {
227        response_tx,
228        allocator,
229        outputs_count,
230        runtime,
231    } = *Box::from_raw(user_data as *mut ResponseCallbackItems);
232
233    let send_res = response_tx.send(Response::new(
234        response,
235        outputs_count as u32,
236        allocator,
237        runtime,
238    ));
239    if send_res.is_err() {
240        log::error!("error sending the result of the inference. It will be lost (including the output buffer)")
241    } else {
242        trace!("response wrapper: result is sent to oneshot");
243    }
244}