tritonserver_rs/request/
utils.rs

1use std::{collections::HashMap, future::Future};
2
3use crate::{
4    error::{Error, ErrorCode},
5    memory::Buffer,
6    request::infer::*,
7    sys, Response,
8};
9
10/// Awaiting on this structure will returt result of the inference: Ok([Response]) or Err([InferenceError]).
11impl Future for ResponseFuture {
12    type Output = Result<Response, InferenceError>;
13    fn poll(
14        self: std::pin::Pin<&mut Self>,
15        cx: &mut std::task::Context<'_>,
16    ) -> std::task::Poll<Self::Output> {
17        if self.input_release.is_some() {
18            log::debug!("ResponseFuture has unhandled InputRelease. \
19                Ignore this message if there is no need to handle returned input resources. They will be dropped.
20            ");
21        }
22        let request_canceller = self.request_ptr.clone();
23
24        let res = unsafe { self.map_unchecked_mut(|this| &mut this.response_receiver) }
25            .poll(cx)
26            .map(|recv_res| match recv_res {
27                Ok(res) => res,
28                Err(recv_err) => Err(Error::new(
29                    ErrorCode::Internal,
30                    format!("response receive error: {recv_err}"),
31                )
32                .into()),
33            });
34
35        if res.is_ready() {
36            request_canceller
37                .is_inferenced
38                .store(true, std::sync::atomic::Ordering::SeqCst);
39        }
40        res
41    }
42}
43
44impl ResponseFuture {
45    /// Blocking await to call outside of asynchronous contexts.
46    ///
47    /// # Panics
48    ///
49    /// This function panics if called within an asynchronous execution context.
50    pub fn blocking_recv(self) -> Result<Response, InferenceError> {
51        let request_canceller = self.request_ptr.clone();
52        let res = match self.response_receiver.blocking_recv() {
53            Ok(res) => res,
54            Err(recv_err) => Err(Error::new(
55                ErrorCode::Internal,
56                format!("response receive error: {recv_err}"),
57            )
58            .into()),
59        };
60
61        request_canceller
62            .is_inferenced
63            .store(true, std::sync::atomic::Ordering::SeqCst);
64        res
65    }
66
67    /// Get the future to return the input buffers assigned to the Request.
68    ///
69    /// **NOTE**: this function should be called at most once. Otherwise it will return garbage. \
70    /// **Note** that input buffer can be released in any time from the start of the inference
71    /// to the end of it.
72    pub fn get_input_release(&mut self) -> InputRelease {
73        self.input_release.take().unwrap_or_else(|| {
74            log::error!("ResponseFuture::get_input_release was invoked twice in a row. Empty future is returned");
75            let (_, rx) = tokio::sync::oneshot::channel();
76            InputRelease(rx)
77        })
78    }
79}
80
81impl RequestCanceller {
82    fn is_cancelled(&self) -> Result<bool, Error> {
83        let mut res = false;
84        triton_call!(
85            sys::TRITONSERVER_InferenceRequestIsCancelled(self.request_ptr, &mut res),
86            res
87        )
88    }
89}
90
91impl Drop for RequestCanceller {
92    fn drop(&mut self) {
93        if !self.is_inferenced.load(std::sync::atomic::Ordering::SeqCst)
94            && !self.is_cancelled().unwrap_or(true)
95        {
96            let _ = unsafe { sys::TRITONSERVER_InferenceRequestCancel(self.request_ptr) };
97        }
98    }
99}
100
101/// Awaiting on input buffers returnal from the inference.
102///
103/// Note that input buffer can be released in any time from the start of the inference
104/// to the end of it.
105impl Future for InputRelease {
106    type Output = Result<HashMap<String, Buffer>, Error>;
107    fn poll(
108        self: std::pin::Pin<&mut Self>,
109        cx: &mut std::task::Context<'_>,
110    ) -> std::task::Poll<Self::Output> {
111        unsafe { self.map_unchecked_mut(|this| &mut this.0) }
112            .poll(cx)
113            .map_err(|recv_err| {
114                Error::new(
115                    ErrorCode::Internal,
116                    format!("Receive input buffer error: {recv_err}"),
117                )
118            })
119    }
120}
121
122impl InputRelease {
123    /// Blocking receive to call outside of asynchronous contexts.\
124    /// # Panics
125    /// This function panics if called within an asynchronous execution context.
126    pub fn blocking_recv(self) -> Result<HashMap<String, Buffer>, Error> {
127        self.0.blocking_recv().map_err(|recv_error| {
128            Error::new(
129                ErrorCode::Internal,
130                format!("Receive input buffer error: {recv_error}"),
131            )
132        })
133    }
134}