use std::{
collections::HashMap,
ffi::c_void,
ptr::null_mut,
sync::{atomic::AtomicBool, Arc},
};
use log::trace;
use tokio::{
runtime::Handle,
sync::oneshot::{self, Receiver},
};
use crate::{
allocator::Allocator,
error::{Error, ErrorCode},
memory::Buffer,
sys, Request, Response,
};
/// Error produced by an asynchronous inference, carrying back any output
/// buffers that were already allocated so the caller can reclaim the memory.
#[derive(Debug)]
pub struct InferenceError {
    /// The underlying Triton error.
    pub error: Error,
    // NOTE(review): keys are presumably output tensor names — confirm against
    // the Allocator implementation.
    pub output_buffers: HashMap<String, Buffer>,
}
impl From<Error> for InferenceError {
fn from(error: Error) -> Self {
Self {
error,
output_buffers: HashMap::new(),
}
}
}
impl std::fmt::Display for InferenceError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate formatting to the wrapped error; the recovered output
        // buffers are deliberately not part of the display text.
        self.error.fmt(f)
    }
}
// Marker impl: lets `InferenceError` be used behind `Box<dyn Error>` and `?`.
impl std::error::Error for InferenceError {}
/// Handle to an in-flight asynchronous inference started by
/// [`Request::infer_async`].
pub struct ResponseFuture {
    /// Receives the inference result (or failure) from the response callback.
    pub(super) response_receiver: Receiver<Result<Response, InferenceError>>,
    /// Hands the input buffers back once Triton releases the request.
    pub(super) input_release: Option<InputRelease>,
    // NOTE(review): judging by the type's name this enables cancelling the
    // underlying request; the cancellation logic itself is outside this chunk.
    pub(super) request_ptr: Arc<RequestCanceller>,
}
/// Shared wrapper around the raw Triton request handle.
pub(super) struct RequestCanceller {
    // NOTE(review): the flag is only declared here; the code that flips it is
    // outside this chunk — confirm intended semantics before relying on it.
    pub(crate) is_inferenced: AtomicBool,
    /// Raw pointer to the Triton inference request owned by the server.
    pub(crate) request_ptr: *mut sys::TRITONSERVER_InferenceRequest,
}
// SAFETY: the struct only stores a raw Triton handle; soundness relies on the
// Triton C API tolerating use of the request handle from another thread.
// NOTE(review): confirm this guarantee against the Triton server API docs.
unsafe impl Send for RequestCanceller {}
unsafe impl Sync for RequestCanceller {}
/// Resolves with the request's input buffers once Triton releases the request
/// (sent from `release_callback`), letting the caller reuse the memory.
pub struct InputRelease(pub(super) oneshot::Receiver<HashMap<String, Buffer>>);
impl Request<'_> {
    /// Submits this request to Triton asynchronously, consuming it.
    ///
    /// Ownership of the request (and its input buffers) is handed to Triton;
    /// the buffers come back to the caller through [`InputRelease`] once
    /// Triton releases the request, and the inference result arrives through
    /// the returned [`ResponseFuture`].
    ///
    /// # Errors
    /// * `ErrorCode::NotFound` if no input buffers are set.
    /// * `ErrorCode::NotFound` if no custom output-buffer allocator is set.
    /// * Any error reported by the underlying Triton C API calls.
    pub fn infer_async(mut self) -> Result<ResponseFuture, Error> {
        if self.input.is_empty() {
            // BUG FIX: the message previously said "output buffer", but this
            // branch validates the *input* buffers (the allocator check below
            // is the one concerning outputs).
            return Err(Error::new(
                ErrorCode::NotFound,
                "Request's input buffers are not set",
            ));
        }
        if self.custom_allocator.is_none() {
            return Err(Error::new(
                ErrorCode::NotFound,
                "Request's output buffers allocator is not set",
            ));
        }
        // The check above guarantees `take()` yields `Some` here.
        let custom_allocator = self.custom_allocator.take().unwrap();
        let trace = self.custom_trace.take();
        let datatype_hints = self.add_outputs()?;
        let outputs_count = self.server.get_model(&self.model_name)?.outputs.len();
        let runtime = self.server.runtime.clone();
        let request_ptr = self.ptr;
        let server_ptr = self.server.ptr.as_mut_ptr();
        // `self` plus a sender for returning the input buffers travels into
        // Triton as the release-callback payload; `release_callback` reclaims
        // the box and sends the buffers back through `input_tx`.
        let (input_tx, input_rx) = oneshot::channel();
        let boxed_request_input_recover = Box::into_raw(Box::new((self, input_tx)));
        // Reclaims the leaked box if callback registration fails before Triton
        // takes ownership of it.
        let drop_boxed_request = |boxed_request: *mut (Request, _)| {
            let (_restored_request, _) = unsafe { *Box::from_raw(boxed_request) };
        };
        let err = unsafe {
            sys::TRITONSERVER_InferenceRequestSetReleaseCallback(
                request_ptr,
                Some(release_callback),
                boxed_request_input_recover as *mut _,
            )
        };
        if !err.is_null() {
            drop_boxed_request(boxed_request_input_recover);
            let err = Error {
                ptr: err,
                owned: true,
            };
            return Err(err);
        }
        let allocator = Arc::new(Allocator::new(
            custom_allocator,
            datatype_hints,
            runtime.clone(),
        )?);
        let allocator_ptr = Arc::as_ptr(&allocator);
        let (response_tx, response_rx) = oneshot::channel();
        // NOTE(review): if this call (or `Allocator::new` above) fails, the
        // request box already handed to the release callback — and here also
        // the `ResponseCallbackItems` box — is leaked, because the request is
        // never submitted and therefore never released. Confirm whether an
        // explicit request delete is required on these error paths.
        triton_call!(sys::TRITONSERVER_InferenceRequestSetResponseCallback(
            request_ptr,
            allocator.get_allocator(),
            allocator_ptr as *mut c_void,
            Some(responce_wrapper),
            Box::into_raw(Box::new(ResponseCallbackItems {
                response_tx,
                allocator,
                outputs_count,
                runtime,
            })) as *mut _,
        ))?;
        let trace_ptr = trace
            .as_ref()
            .map(|trace| trace.ptr.0)
            .unwrap_or_else(null_mut);
        triton_call!(sys::TRITONSERVER_ServerInferAsync(
            server_ptr,
            request_ptr,
            trace_ptr
        ))?;
        // Triton now owns the trace handle: forget it so Drop does not free it.
        if let Some(trace) = trace {
            std::mem::forget(trace.ptr);
        }
        Ok(ResponseFuture {
            response_receiver: response_rx,
            input_release: Some(InputRelease(input_rx)),
            request_ptr: Arc::new(RequestCanceller {
                request_ptr,
                is_inferenced: AtomicBool::new(false),
            }),
        })
    }
}
/// Payload handed to Triton as the response-callback user data; reclaimed
/// (and dropped) inside `responce_wrapper`.
struct ResponseCallbackItems {
    /// Delivers the finished response to the waiting `ResponseFuture`.
    response_tx: oneshot::Sender<Result<Response, InferenceError>>,
    /// Keeps the output-buffer allocator alive until the response arrives.
    allocator: Arc<Allocator>,
    /// Number of model outputs, forwarded to `Response::new`.
    outputs_count: usize,
    /// Tokio runtime handle, forwarded to `Response::new`.
    runtime: Handle,
}
/// C callback invoked by Triton once it is done with an inference request.
///
/// Reclaims the `(Request, Sender)` box leaked in `infer_async`, hands the
/// input buffers back through the oneshot channel, and drops the request.
unsafe extern "C" fn release_callback(
    ptr: *mut sys::TRITONSERVER_InferenceRequest,
    _flags: u32,
    user_data: *mut c_void,
) {
    trace!("release_callback is called");
    assert!(!ptr.is_null());
    assert!(!user_data.is_null());
    // Take back ownership of the request and the buffer-return channel.
    let (mut restored_request, input_tx) =
        *Box::from_raw(user_data as *mut (Request, oneshot::Sender<_>));
    // Move the input buffers out so they outlive the request's drop.
    let buffers = std::mem::take(&mut restored_request.input);
    if input_tx.send(buffers).is_err() {
        log::debug!("InputRelease was dropped before the input buffers returned from triton. Input buffers will be dropped");
    }
    assert_eq!(restored_request.ptr, ptr);
    trace!("release_callback is ended");
}
/// C callback invoked by Triton when a response is ready.
///
/// Rebuilds the `ResponseCallbackItems` box leaked in `infer_async`, wraps
/// the raw response handle in a `Response`, and forwards it to the waiting
/// `ResponseFuture`.
unsafe extern "C" fn responce_wrapper(
    response: *mut sys::TRITONSERVER_InferenceResponse,
    _flags: u32,
    user_data: *mut c_void,
) {
    trace!("response wrapper is called");
    assert!(!response.is_null());
    assert!(!user_data.is_null());
    // Reclaim the callback payload leaked by `infer_async`.
    let items = *Box::from_raw(user_data as *mut ResponseCallbackItems);
    let result = Response::new(
        response,
        items.outputs_count as u32,
        items.allocator,
        items.runtime,
    );
    match items.response_tx.send(result) {
        Ok(()) => trace!("response wrapper: result is sent to oneshot"),
        Err(_) => {
            log::error!("error sending the result of the inference. It will be lost (including the output buffer)")
        }
    }
}