tritonserver_rs/
parameter.rs1use std::{fs::File, path::Path, sync::Arc};
2
3use crate::{error::Error, sys, to_cstring};
4
5#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
7#[repr(u32)]
8pub enum TritonParameterType {
9 String = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_STRING,
10 Int = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_INT,
11 Bool = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_BOOL,
12 Double = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_DOUBLE,
13 Bytes = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_BYTES,
14}
15
/// Rust-side value of a parameter, one variant per supported payload type.
#[derive(Clone, Debug)]
pub enum ParameterContent {
    /// Owned UTF-8 text.
    String(String),
    /// Signed 64-bit integer.
    Int(i64),
    /// Boolean flag.
    Bool(bool),
    /// 64-bit floating point number.
    Double(f64),
    /// Owned byte buffer.
    Bytes(Vec<u8>),
}
25
26#[derive(Debug, Clone)]
28pub struct Parameter {
29 pub(crate) ptr: Arc<*mut sys::TRITONSERVER_Parameter>,
30 pub name: String,
31 pub content: ParameterContent,
32}
33
34unsafe impl Send for Parameter {}
35unsafe impl Sync for Parameter {}
36
37impl Parameter {
38 pub fn new<N: AsRef<str>>(name: N, value: ParameterContent) -> Result<Self, Error> {
40 let c_name = to_cstring(&name)?;
41 let ptr = match &value {
42 ParameterContent::Bool(v) => unsafe {
43 sys::TRITONSERVER_ParameterNew(
44 c_name.as_ptr(),
45 TritonParameterType::Bool as _,
46 v as *const bool as *const _,
47 )
48 },
49 ParameterContent::Int(v) => unsafe {
50 sys::TRITONSERVER_ParameterNew(
51 c_name.as_ptr(),
52 TritonParameterType::Int as _,
53 v as *const i64 as *const _,
54 )
55 },
56 ParameterContent::String(v) => {
57 let v = to_cstring(v)?;
58 unsafe {
59 sys::TRITONSERVER_ParameterNew(
60 c_name.as_ptr(),
61 TritonParameterType::String as _,
62 v.as_ptr() as *const _,
63 )
64 }
65 }
66 ParameterContent::Double(v) => unsafe {
67 sys::TRITONSERVER_ParameterNew(
68 c_name.as_ptr(),
69 TritonParameterType::Double as _,
70 v as *const f64 as *const _,
71 )
72 },
73 ParameterContent::Bytes(v) => unsafe {
74 sys::TRITONSERVER_ParameterBytesNew(
75 c_name.as_ptr(),
76 v.as_ptr() as *const _,
77 v.len() as _,
78 )
79 },
80 };
81
82 Ok(Self {
83 ptr: Arc::new(ptr),
84 name: name.as_ref().to_string(),
85 content: value,
86 })
87 }
88
89 pub fn from_config_with_exact_version(
95 mut config: serde_json::Value,
96 version: i64,
97 ) -> Result<Self, Error> {
98 config["version_policy"] = serde_json::json!({"specific": { "versions": [version]}});
99 Parameter::new("config", ParameterContent::String(config.to_string()))
100 }
101}
102
103impl Drop for Parameter {
104 fn drop(&mut self) {
105 if !self.ptr.is_null() && Arc::strong_count(&self.ptr) == 1 {
106 unsafe { sys::TRITONSERVER_ParameterDelete(*self.ptr) }
107 }
108 }
109}
110
111fn hjson_to_json(value: serde_hjson::Value) -> serde_json::Value {
112 match value {
113 serde_hjson::Value::Null => serde_json::Value::Null,
114 serde_hjson::Value::U64(v) => serde_json::Value::from(v),
115 serde_hjson::Value::I64(v) => serde_json::Value::from(v),
116 serde_hjson::Value::F64(v) => serde_json::Value::from(v),
117 serde_hjson::Value::Bool(v) => serde_json::Value::from(v),
118 serde_hjson::Value::String(v) => serde_json::Value::from(v),
119
120 serde_hjson::Value::Array(v) => {
121 serde_json::Value::from_iter(v.into_iter().map(hjson_to_json))
122 }
123 serde_hjson::Value::Object(v) => serde_json::Value::from_iter(
124 v.into_iter()
125 .map(|(key, value)| (key, hjson_to_json(value))),
126 ),
127 }
128}
129
130pub fn load_config_as_json<P: AsRef<Path>>(config_path: P) -> Result<serde_json::Value, Error> {
137 let content = File::open(config_path).map_err(|err| {
138 Error::new(
139 crate::error::ErrorCode::InvalidArg,
140 format!("Error opening the config file: {err}"),
141 )
142 })?;
143 let value = serde_hjson::from_reader::<_, serde_hjson::Value>(&content).map_err(|err| {
144 Error::new(
145 crate::error::ErrorCode::InvalidArg,
146 format!("Error parsing the config file as hjson: {err}"),
147 )
148 })?;
149 Ok(hjson_to_json(value))
150}
151
152#[test]
153fn test_config_to_json() {
154 let json_cfg = serde_json::json!({
155 "name": "voicenet",
156 "platform": "onnxruntime_onnx",
157 "input": [
158 {
159 "data_type": "TYPE_FP32",
160 "name": "input",
161 "dims": [512, 160000]
162 }
163 ],
164 "output": [
165 {
166 "data_type": "TYPE_FP32",
167 "name": "output",
168 "dims": [512, 512]
169 }
170 ],
171 "instance_group": [
172 {
173 "count": 2,
174 "kind": "KIND_CPU"
175 }
176 ],
177 "optimization": { "execution_accelerators": {
178 "cpu_execution_accelerator" : [ {
179 "name" : "openvino"
180 } ]
181 }}
182 });
183
184 assert_eq!(
185 load_config_as_json("model_repo/voicenet_onnx/voicenet/config.pbtxt").unwrap(),
186 json_cfg
187 );
188}