tritonserver_rs/request/utils.rs

use std::{collections::HashMap, future::Future};

use crate::{
    error::{Error, ErrorCode},
    memory::Buffer,
    request::infer::*,
    sys, Response,
};
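
// `ResponseFuture` resolves once Triton produces a response (or an error).
// Completion flips `is_inferenced` on the shared `RequestCanceller`, which
// keeps `RequestCanceller::drop` from cancelling a request that has already
// finished.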
impl Future for ResponseFuture {
    type Output = Result<Response, InferenceError>;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        if self.input_release.is_some() {
            log::debug!(
                "ResponseFuture has an unhandled InputRelease. \
                 Ignore this message if the returned input resources do not \
                 need to be handled; they will be dropped."
            );
        }
        let request_canceller = self.request_ptr.clone();

        // SAFETY: `response_receiver` is never moved out of `this`, so the
        // pin projection is sound.
        let res = unsafe { self.map_unchecked_mut(|this| &mut this.response_receiver) }
            .poll(cx)
            .map(|recv_res| match recv_res {
                Ok(res) => res,
                Err(recv_err) => Err(Error::new(
                    ErrorCode::Internal,
                    format!("response receive error: {recv_err}"),
                )
                .into()),
            });

        // Mark the request as completed so `RequestCanceller::drop` does not
        // cancel it after the fact.
        if res.is_ready() {
            request_canceller
                .is_inferenced
                .store(true, std::sync::atomic::Ordering::SeqCst);
        }
        res
    }
}

impl ResponseFuture {
    /// Blocking counterpart of awaiting the future: parks the current thread
    /// until the response (or an error) arrives.
    pub fn blocking_recv(self) -> Result<Response, InferenceError> {
        let request_canceller = self.request_ptr.clone();
        let res = match self.response_receiver.blocking_recv() {
            Ok(res) => res,
            Err(recv_err) => Err(Error::new(
                ErrorCode::Internal,
                format!("response receive error: {recv_err}"),
            )
            .into()),
        };

        request_canceller
            .is_inferenced
            .store(true, std::sync::atomic::Ordering::SeqCst);
        res
    }

    /// Takes the `InputRelease` future that resolves once Triton hands the
    /// input buffers back. Calling this more than once yields a future that
    /// resolves to an error, since its sender is dropped immediately.
    pub fn get_input_release(&mut self) -> InputRelease {
        self.input_release.take().unwrap_or_else(|| {
            log::error!(
                "ResponseFuture::get_input_release was invoked more than once. \
                 Returning a future that resolves to an error."
            );
            let (_, rx) = tokio::sync::oneshot::channel();
            InputRelease(rx)
        })
    }
}
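
// A minimal usage sketch (hypothetical: `model.infer` stands in for whatever
// call in this crate produces a `ResponseFuture`):
//
//     let mut fut = model.infer(inputs)?;           // -> ResponseFuture
//     let input_release = fut.get_input_release();  // take it before awaiting
//     let response = fut.await?;                    // or fut.blocking_recv()
//     let buffers = input_release.await?;           // HashMap<String, Buffer>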

impl RequestCanceller {
    fn is_cancelled(&self) -> Result<bool, Error> {
        let mut res = false;
        triton_call!(
            sys::TRITONSERVER_InferenceRequestIsCancelled(self.request_ptr, &mut res),
            res
        )
    }
}

impl Drop for RequestCanceller {
    fn drop(&mut self) {
        // Cancel the request only if it has neither completed nor already
        // been cancelled. If the cancellation status cannot be queried,
        // err on the side of doing nothing (`unwrap_or(true)`).
        if !self.is_inferenced.load(std::sync::atomic::Ordering::SeqCst)
            && !self.is_cancelled().unwrap_or(true)
        {
            let _ = unsafe { sys::TRITONSERVER_InferenceRequestCancel(self.request_ptr) };
        }
    }
}
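
// A minimal cancellation sketch (assuming the `ResponseFuture` holds the last
// clone of the `RequestCanceller`): dropping the future before it completes
// triggers the `Drop` impl above, which cancels the in-flight request.
//
//     let fut = model.infer(inputs)?;  // hypothetical constructor, as above
//     drop(fut);                       // -> TRITONSERVER_InferenceRequestCancel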

impl Future for InputRelease {
    type Output = Result<HashMap<String, Buffer>, Error>;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: the inner receiver is never moved, so the pin projection
        // is sound.
        unsafe { self.map_unchecked_mut(|this| &mut this.0) }
            .poll(cx)
            .map_err(|recv_err| {
                Error::new(
                    ErrorCode::Internal,
                    format!("Receive input buffer error: {recv_err}"),
                )
            })
    }
}

impl InputRelease {
    /// Blocking counterpart of awaiting the future: parks the current thread
    /// until the input buffers are handed back. Note that tokio's
    /// `blocking_recv` panics if called from inside an async runtime context.
    pub fn blocking_recv(self) -> Result<HashMap<String, Buffer>, Error> {
        self.0.blocking_recv().map_err(|recv_error| {
            Error::new(
                ErrorCode::Internal,
                format!("Receive input buffer error: {recv_error}"),
            )
        })
    }
}
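
// A minimal buffer-reuse sketch (hypothetical names): the buffers handed back
// by `InputRelease` are the original inputs, so they can be refilled for the
// next request instead of being reallocated.
//
//     let mut buffers = input_release.blocking_recv()?; // off the runtime
//     refill(&mut buffers);                             // hypothetical helper
//     let fut = model.infer(buffers)?;                  // reuse the memory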