triton_rs/
request.rs

1use crate::{check_err, decode_string, Error};
2use libc::c_void;
3use std::ffi::CStr;
4use std::ffi::CString;
5use std::ptr;
6use std::slice;
7
8pub struct Request {
9    ptr: *mut triton_sys::TRITONBACKEND_Request,
10}
11
12impl Request {
13    pub fn from_ptr(ptr: *mut triton_sys::TRITONBACKEND_Request) -> Self {
14        Self { ptr }
15    }
16
17    pub fn as_ptr(&self) -> *mut triton_sys::TRITONBACKEND_Request {
18        self.ptr
19    }
20
21    pub fn get_input(&self, name: &str) -> Result<Input, Error> {
22        let name = CString::new(name).expect("CString::new failed");
23
24        let mut input: *mut triton_sys::TRITONBACKEND_Input = ptr::null_mut();
25        check_err(unsafe {
26            triton_sys::TRITONBACKEND_RequestInput(self.ptr, name.as_ptr(), &mut input)
27        })?;
28
29        Ok(Input::from_ptr(input))
30    }
31}
32
33pub struct Input {
34    ptr: *mut triton_sys::TRITONBACKEND_Input,
35}
36impl Input {
37    pub fn from_ptr(ptr: *mut triton_sys::TRITONBACKEND_Input) -> Self {
38        Self { ptr }
39    }
40
41    fn buffer(&self) -> Result<Vec<u8>, Error> {
42        let mut buffer: *const c_void = ptr::null_mut();
43        let index = 0;
44        let mut memory_type: triton_sys::TRITONSERVER_MemoryType = 0;
45        let mut memory_type_id = 0;
46        let mut buffer_byte_size = 0;
47        check_err(unsafe {
48            triton_sys::TRITONBACKEND_InputBuffer(
49                self.ptr,
50                index,
51                &mut buffer,
52                &mut buffer_byte_size,
53                &mut memory_type,
54                &mut memory_type_id,
55            )
56        })?;
57
58        let mem: &[u8] =
59            unsafe { slice::from_raw_parts(buffer as *mut u8, buffer_byte_size as usize) };
60        Ok(mem.to_vec())
61    }
62
63    pub fn as_string(&self) -> Result<String, Error> {
64        let properties = self.properties()?;
65        let buffer = self.buffer()?;
66
67        let strings = decode_string(&buffer)?;
68        Ok(strings.first().unwrap().clone())
69    }
70
71    pub fn as_u64(&self) -> Result<u64, Error> {
72        let properties = self.properties()?;
73        let buffer = self.buffer()?;
74
75        let mut bytes = [0u8; 8];
76        bytes.copy_from_slice(&buffer);
77
78        Ok(u64::from_le_bytes(bytes))
79    }
80
81    fn properties(&self) -> Result<InputProperties, Error> {
82        let mut name = ptr::null();
83        let mut datatype = 0u32;
84        let shape = ptr::null_mut();
85        let mut dims_count = 0u32;
86        let mut byte_size = 0u64;
87        let mut buffer_count = 0u32;
88
89        check_err(unsafe {
90            triton_sys::TRITONBACKEND_InputProperties(
91                self.ptr,
92                &mut name,
93                &mut datatype,
94                shape,
95                &mut dims_count,
96                &mut byte_size,
97                &mut buffer_count,
98            )
99        })?;
100
101        let name: &CStr = unsafe { CStr::from_ptr(name) };
102        let name = name.to_string_lossy().to_string();
103
104        Ok(InputProperties {
105            name,
106            datatype,
107            // shape,
108            dims_count,
109            byte_size,
110            buffer_count,
111        })
112    }
113}
114
/// Metadata for one input tensor, filled in from `TRITONBACKEND_InputProperties`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InputProperties {
    /// Tensor name, converted lossily from the C string Triton returns.
    name: String,
    /// Raw datatype code as returned by Triton.
    datatype: u32,
    /// Number of dimensions in the tensor shape.
    dims_count: u32,
    /// Size of the tensor data in bytes, as reported by Triton.
    byte_size: u64,
    /// Number of buffers the tensor data is split across.
    buffer_count: u32,
}