use std::{
2 collections::HashMap,
3 ffi::c_void,
4 ptr::null_mut,
5 sync::{atomic::AtomicBool, Arc},
6};
7
8use log::trace;
9use tokio::{
10 runtime::Handle,
11 sync::oneshot::{self, Receiver},
12};
13
14use crate::{
15 allocator::Allocator,
16 error::{Error, ErrorCode},
17 memory::Buffer,
18 sys, Request, Response,
19};
20
/// Error produced by an inference call, carrying any output buffers that had
/// already been allocated when the failure occurred so the caller can reclaim
/// (or drop) them explicitly.
#[derive(Debug)]
pub struct InferenceError {
    /// The underlying Triton error.
    pub error: Error,
    /// Output buffers allocated before the failure; empty when the value was
    /// converted from a bare [`Error`] via `From`.
    pub output_buffers: HashMap<String, Buffer>,
}
27
28impl From<Error> for InferenceError {
29 fn from(error: Error) -> Self {
30 Self {
31 error,
32 output_buffers: HashMap::new(),
33 }
34 }
35}
36
37impl std::fmt::Display for InferenceError {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 self.error.fmt(f)
40 }
41}
42
// Marker impl: `source` defaults to `None`; the message comes from `Display` above.
impl std::error::Error for InferenceError {}
44
/// Handle returned by [`Request::infer_async`] for an in-flight inference.
/// Presumably implements `Future` elsewhere in this module (not visible in
/// this chunk) by polling `response_receiver`.
pub struct ResponseFuture {
    // Receives exactly one inference result from the response callback.
    pub(super) response_receiver: Receiver<Result<Response, InferenceError>>,
    // Receives the input buffers back once Triton releases the request;
    // `Option` so it can be taken out and awaited separately.
    pub(super) input_release: Option<InputRelease>,
    // Shared raw request handle, kept so the request can be cancelled.
    pub(super) request_ptr: Arc<RequestCanceller>,
}
54
/// Shared handle to the raw Triton request pointer, held by
/// [`ResponseFuture`] so the in-flight request can be cancelled.
pub(super) struct RequestCanceller {
    // presumably flipped once the inference has completed so cancellation is
    // only attempted on a still-pending request — TODO confirm against the
    // ResponseFuture poll/Drop impls (not visible in this chunk).
    pub(crate) is_inferenced: AtomicBool,
    // Raw Triton request handle; Triton owns the request while it is in flight.
    pub(crate) request_ptr: *mut sys::TRITONSERVER_InferenceRequest,
}
// SAFETY(review): the raw pointer makes this type !Send/!Sync by default.
// These impls assume the Triton request handle may be touched from any
// thread — confirm against Triton's documented thread-safety guarantees.
unsafe impl Send for RequestCanceller {}
unsafe impl Sync for RequestCanceller {}
61
/// Receiver side of the channel on which `release_callback` hands the
/// request's input buffers back once Triton no longer needs them.
pub struct InputRelease(pub(super) oneshot::Receiver<HashMap<String, Buffer>>);
68
69impl Request<'_> {
71 pub fn infer_async(mut self) -> Result<ResponseFuture, Error> {
77 if self.input.is_empty() {
79 return Err(Error::new(
80 ErrorCode::NotFound,
81 "Request's output buffer is not set",
82 ));
83 }
84 if self.custom_allocator.is_none() {
85 return Err(Error::new(
86 ErrorCode::NotFound,
87 "Request's output buffers allocator is not set",
88 ));
89 }
90 let custom_allocator = self.custom_allocator.take().unwrap();
91 let trace = self.custom_trace.take();
92
93 let datatype_hints = self.add_outputs()?;
95 let outputs_count = self.server.get_model(&self.model_name)?.outputs.len();
96
97 let runtime = self.server.runtime.clone();
98 let request_ptr = self.ptr;
99 let server_ptr = self.server.ptr.as_mut_ptr();
100
101 let (input_tx, input_rx) = oneshot::channel();
103 let boxed_request_input_recover = Box::into_raw(Box::new((self, input_tx)));
106 let drop_boxed_request = |boxed_request: *mut (Request, _)| {
107 let (_restored_request, _) = unsafe { *Box::from_raw(boxed_request) };
108 };
109
110 let err = unsafe {
113 sys::TRITONSERVER_InferenceRequestSetReleaseCallback(
114 request_ptr,
115 Some(release_callback),
116 boxed_request_input_recover as *mut _,
117 )
118 };
119
120 if !err.is_null() {
121 drop_boxed_request(boxed_request_input_recover);
122
123 let err = Error {
124 ptr: err,
125 owned: true,
126 };
127 return Err(err);
128 }
129
130 let allocator = Arc::new(Allocator::new(
135 custom_allocator,
136 datatype_hints,
137 runtime.clone(),
138 )?);
139
140 let allocator_ptr = Arc::as_ptr(&allocator);
141 let (response_tx, response_rx) = oneshot::channel();
145
146 triton_call!(sys::TRITONSERVER_InferenceRequestSetResponseCallback(
147 request_ptr,
148 allocator.get_allocator(),
149 allocator_ptr as *mut c_void,
150 Some(responce_wrapper),
151 Box::into_raw(Box::new(ResponseCallbackItems {
152 response_tx,
153 allocator,
154 outputs_count,
155 runtime,
156 })) as *mut _,
157 ))?;
158
159 let trace_ptr = trace
160 .as_ref()
161 .map(|trace| trace.ptr.0)
162 .unwrap_or_else(null_mut);
163
164 triton_call!(sys::TRITONSERVER_ServerInferAsync(
165 server_ptr,
166 request_ptr,
167 trace_ptr
168 ))?;
169
170 if let Some(trace) = trace {
171 std::mem::forget(trace.ptr);
172 }
173
174 Ok(ResponseFuture {
175 response_receiver: response_rx,
176 input_release: Some(InputRelease(input_rx)),
177 request_ptr: Arc::new(RequestCanceller {
178 request_ptr,
179 is_inferenced: AtomicBool::new(false),
180 }),
181 })
182 }
183}
184
/// State boxed up and passed as `user_data` to `responce_wrapper` when the
/// response callback is registered in `infer_async`.
struct ResponseCallbackItems {
    // Completes the `ResponseFuture` with the inference result.
    response_tx: oneshot::Sender<Result<Response, InferenceError>>,
    // Allocator registered for this request's output buffers; holding the Arc
    // here keeps the pointer given to Triton alive until the response arrives.
    allocator: Arc<Allocator>,
    // Number of model outputs, passed to `Response::new` as a u32.
    outputs_count: usize,
    // Tokio runtime handle forwarded to `Response::new`.
    runtime: Handle,
}
191
192unsafe extern "C" fn release_callback(
194 ptr: *mut sys::TRITONSERVER_InferenceRequest,
195 _flags: u32,
196 user_data: *mut c_void,
197) {
198 trace!("release_callback is called");
199 assert!(!ptr.is_null());
200 assert!(!user_data.is_null());
201
202 let (mut request, input_tx) = *Box::from_raw(user_data as *mut (Request, oneshot::Sender<_>));
203 let mut buffers = HashMap::new();
205 std::mem::swap(&mut buffers, &mut request.input);
206
207 if input_tx.send(buffers).is_err() {
208 log::debug!("InputRelease was dropped before the input buffers returned from triton. Input buffers will be dropped");
209 }
210
211 assert_eq!(request.ptr, ptr);
212 trace!("release_callback is ended");
213}
214
215unsafe extern "C" fn responce_wrapper(
217 response: *mut sys::TRITONSERVER_InferenceResponse,
218 _flags: u32,
219 user_data: *mut c_void,
220) {
221 trace!("response wrapper is called");
222 assert!(!response.is_null());
223 assert!(!user_data.is_null());
224
225 let ResponseCallbackItems {
227 response_tx,
228 allocator,
229 outputs_count,
230 runtime,
231 } = *Box::from_raw(user_data as *mut ResponseCallbackItems);
232
233 let send_res = response_tx.send(Response::new(
234 response,
235 outputs_count as u32,
236 allocator,
237 runtime,
238 ));
239 if send_res.is_err() {
240 log::error!("error sending the result of the inference. It will be lost (including the output buffer)")
241 } else {
242 trace!("response wrapper: result is sent to oneshot");
243 }
244}