tritonserver_rs/
parameter.rs

1use std::{fs::File, path::Path, sync::Arc};
2
3use crate::{error::Error, sys, to_cstring};
4
5/// Types of parameters recognized by TRITONSERVER.
6#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
7#[repr(u32)]
8pub enum TritonParameterType {
9    String = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_STRING,
10    Int = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_INT,
11    Bool = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_BOOL,
12    Double = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_DOUBLE,
13    Bytes = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_BYTES,
14}
15
/// Typed value carried by a [Parameter].
#[derive(Clone, Debug)]
pub enum ParameterContent {
    /// Text value; converted to a C string when the parameter is created.
    String(String),
    /// Signed 64-bit integer value.
    Int(i64),
    /// Boolean value.
    Bool(bool),
    /// 64-bit floating point value.
    Double(f64),
    /// Raw byte payload.
    Bytes(Vec<u8>),
}
25
26/// Parameter of the [Server](crate::Server) or [Response](crate::Response).
27#[derive(Debug, Clone)]
28pub struct Parameter {
29    pub(crate) ptr: Arc<*mut sys::TRITONSERVER_Parameter>,
30    pub name: String,
31    pub content: ParameterContent,
32}
33
34unsafe impl Send for Parameter {}
35unsafe impl Sync for Parameter {}
36
37impl Parameter {
38    /// Create new Parameter.
39    pub fn new<N: AsRef<str>>(name: N, value: ParameterContent) -> Result<Self, Error> {
40        let c_name = to_cstring(&name)?;
41        let ptr = match &value {
42            ParameterContent::Bool(v) => unsafe {
43                sys::TRITONSERVER_ParameterNew(
44                    c_name.as_ptr(),
45                    TritonParameterType::Bool as _,
46                    v as *const bool as *const _,
47                )
48            },
49            ParameterContent::Int(v) => unsafe {
50                sys::TRITONSERVER_ParameterNew(
51                    c_name.as_ptr(),
52                    TritonParameterType::Int as _,
53                    v as *const i64 as *const _,
54                )
55            },
56            ParameterContent::String(v) => {
57                let v = to_cstring(v)?;
58                unsafe {
59                    sys::TRITONSERVER_ParameterNew(
60                        c_name.as_ptr(),
61                        TritonParameterType::String as _,
62                        v.as_ptr() as *const _,
63                    )
64                }
65            }
66            ParameterContent::Double(v) => unsafe {
67                sys::TRITONSERVER_ParameterNew(
68                    c_name.as_ptr(),
69                    TritonParameterType::Double as _,
70                    v as *const f64 as *const _,
71                )
72            },
73            ParameterContent::Bytes(v) => unsafe {
74                sys::TRITONSERVER_ParameterBytesNew(
75                    c_name.as_ptr(),
76                    v.as_ptr() as *const _,
77                    v.len() as _,
78                )
79            },
80        };
81
82        Ok(Self {
83            ptr: Arc::new(ptr),
84            name: name.as_ref().to_string(),
85            content: value,
86        })
87    }
88
89    /// Create String Parameter of model config with exact version of the model. \
90    /// `config`: model config.pbtxt as json value.
91    /// Check [load_config_as_json] to permutate .pbtxt config to json value. \
92    /// If [Options::model_control_mode](crate::options::Options::model_control_mode) set as EXPLICIT and the result of this method is passed to [crate::Server::load_model_with_parametrs],
93    /// the server will load only that exact model and only that exact version of it.
94    pub fn from_config_with_exact_version(
95        mut config: serde_json::Value,
96        version: i64,
97    ) -> Result<Self, Error> {
98        config["version_policy"] = serde_json::json!({"specific": { "versions": [version]}});
99        Parameter::new("config", ParameterContent::String(config.to_string()))
100    }
101}
102
103impl Drop for Parameter {
104    fn drop(&mut self) {
105        if !self.ptr.is_null() && Arc::strong_count(&self.ptr) == 1 {
106            unsafe { sys::TRITONSERVER_ParameterDelete(*self.ptr) }
107        }
108    }
109}
110
111fn hjson_to_json(value: serde_hjson::Value) -> serde_json::Value {
112    match value {
113        serde_hjson::Value::Null => serde_json::Value::Null,
114        serde_hjson::Value::U64(v) => serde_json::Value::from(v),
115        serde_hjson::Value::I64(v) => serde_json::Value::from(v),
116        serde_hjson::Value::F64(v) => serde_json::Value::from(v),
117        serde_hjson::Value::Bool(v) => serde_json::Value::from(v),
118        serde_hjson::Value::String(v) => serde_json::Value::from(v),
119
120        serde_hjson::Value::Array(v) => {
121            serde_json::Value::from_iter(v.into_iter().map(hjson_to_json))
122        }
123        serde_hjson::Value::Object(v) => serde_json::Value::from_iter(
124            v.into_iter()
125                .map(|(key, value)| (key, hjson_to_json(value))),
126        ),
127    }
128}
129
130/// Load config.pbtxt from the `config_path` and parse it to json value. \
131/// Might be useful if it is required to run model with altered config.
132/// In this case String [Parameter] with name 'config' and the result of this method as data should be created
133/// and passed to [Server::load_model_with_parametrs](crate::Server::load_model_with_parametrs) ([Options::model_control_mode](crate::options::Options::model_control_mode) set as EXPLICIT required).
134/// Check realization of [Parameter::from_config_with_exact_version] as an example. \
135/// **Note (Subject to change)**: congig must be in [hjson format](https://hjson.github.io/).
136pub fn load_config_as_json<P: AsRef<Path>>(config_path: P) -> Result<serde_json::Value, Error> {
137    let content = File::open(config_path).map_err(|err| {
138        Error::new(
139            crate::error::ErrorCode::InvalidArg,
140            format!("Error opening the config file: {err}"),
141        )
142    })?;
143    let value = serde_hjson::from_reader::<_, serde_hjson::Value>(&content).map_err(|err| {
144        Error::new(
145            crate::error::ErrorCode::InvalidArg,
146            format!("Error parsing the config file as hjson: {err}"),
147        )
148    })?;
149    Ok(hjson_to_json(value))
150}
151
/// The bundled voicenet config must load into exactly this JSON structure.
#[test]
fn test_config_to_json() {
    let config_path = "model_repo/voicenet_onnx/voicenet/config.pbtxt";
    let expected = serde_json::json!({
        "name": "voicenet",
        "platform": "onnxruntime_onnx",
        "input": [
            {
                "data_type": "TYPE_FP32",
                "name": "input",
                "dims": [512, 160000]
            }
        ],
        "output": [
            {
                "data_type": "TYPE_FP32",
                "name": "output",
                "dims": [512, 512]
            }
        ],
        "instance_group": [
            {
                "count": 2,
                "kind": "KIND_CPU"
            }
        ],
        "optimization": { "execution_accelerators": {
            "cpu_execution_accelerator" : [ {
                "name" : "openvino"
            } ]
        }}
    });

    let actual = load_config_as_json(config_path).unwrap();
    assert_eq!(actual, expected);
}