tritonserver_rs/
server.rs

1use std::{
2    collections::HashMap,
3    ffi::{c_void, CStr},
4    mem::transmute,
5    path::Path,
6    ptr::null_mut,
7    sync::Arc,
8    time::Duration,
9};
10
11use serde_json::{from_slice, Value};
12
13use crate::{
14    message::{self, Index, Message, Model},
15    metrics::{self, PrometheusMetrics},
16    options::Options,
17    parameter::{Parameter, ParameterContent},
18    path_to_cstring, sys, to_cstring, Error, ErrorCode, Request,
19};
20
21/// Batch properties of the model.
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
23#[repr(u32)]
24pub enum Batch {
25    /// Triton cannot determine the batching properties of the model.
26    /// This means that the model does not support batching in any way that is useable by Triton.
27    Unknown = sys::tritonserver_batchflag_enum_TRITONSERVER_BATCH_UNKNOWN,
28    /// The model supports batching along the first dimension of every input and output tensor.
29    /// Triton schedulers that perform batching can automatically batch inference requests along this dimension.
30    FirstDim = sys::tritonserver_batchflag_enum_TRITONSERVER_BATCH_FIRST_DIM,
31}
32
33/// Transaction policy of the model.
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
35#[repr(u32)]
36pub enum Transaction {
37    /// The model generates exactly one response per request.
38    OneToOne = sys::tritonserver_txn_property_flag_enum_TRITONSERVER_TXN_ONE_TO_ONE,
39    /// The model may generate zero to many responses per request.
40    Decoupled = sys::tritonserver_txn_property_flag_enum_TRITONSERVER_TXN_DECOUPLED,
41}
42
43bitflags::bitflags! {
44    /// Flags that control how to collect the index.
45    pub struct State: u32 {
46        /// If set in 'flags', only the models that are loaded into the server and ready for inferencing are returned.
47        const READY = sys::tritonserver_modelindexflag_enum_TRITONSERVER_INDEX_FLAG_READY;
48    }
49}
50
51/// Kinds of instance groups recognized by TRITONSERVER.
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
53#[repr(u32)]
54pub enum InstanceGroup {
55    Auto = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_AUTO,
56    Cpu = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_CPU,
57    Gpu = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_GPU,
58    Model = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_MODEL,
59}
60
61impl InstanceGroup {
62    fn as_cstr(self) -> &'static CStr {
63        unsafe { CStr::from_ptr(sys::TRITONSERVER_InstanceGroupKindString(self as u32)) }
64    }
65
66    /// Get the string representation of an instance-group kind.
67    pub fn as_str(self) -> &'static str {
68        self.as_cstr()
69            .to_str()
70            .unwrap_or(crate::error::CSTR_CONVERT_ERROR_PLUG)
71    }
72}
73
74#[derive(Debug)]
75pub(crate) struct Inner(*mut sys::TRITONSERVER_Server);
76impl Inner {
77    pub(crate) fn stop(&self) -> Result<(), Error> {
78        triton_call!(sys::TRITONSERVER_ServerStop(self.0))
79    }
80
81    pub(crate) fn is_live(&self) -> Result<bool, Error> {
82        let mut result = false;
83        triton_call!(
84            sys::TRITONSERVER_ServerIsLive(self.0, &mut result as *mut _),
85            result
86        )
87    }
88
89    pub(crate) fn delete(&self) -> Result<(), Error> {
90        triton_call!(sys::TRITONSERVER_ServerDelete(self.0))
91    }
92
93    pub(crate) fn as_mut_ptr(&self) -> *mut sys::TRITONSERVER_Server {
94        self.0
95    }
96}
97
98impl Drop for Inner {
99    fn drop(&mut self) {
100        let _ = self
101            .is_live()
102            .and_then(|live| {
103                if live {
104                    self.stop().and_then(|_| loop {
105                        if !self.is_live()? {
106                            return Ok(());
107                        }
108                    })
109                } else {
110                    Ok(())
111                }
112            })
113            .and_then(|_| self.delete());
114    }
115}
116
117/// # SAFETY
118/// Inner is Send. But it's not Sync! \
119/// However, it's used only in Server and Server is never clones Inner,
120/// so there is always only 1 copy of it.
121unsafe impl Send for Inner {}
122unsafe impl Sync for Inner {}
123
124/// Inference server object.
125#[derive(Debug)]
126pub struct Server {
127    pub(crate) ptr: Arc<Inner>,
128    pub(crate) models: HashMap<String, Model>,
129    pub(crate) runtime: tokio::runtime::Handle,
130}
131
132unsafe impl Send for Server {}
133
134impl Server {
135    /// Create new server object.
136    pub async fn new(options: Options) -> Result<Self, Error> {
137        let mut server = null_mut::<sys::TRITONSERVER_Server>();
138        triton_call!(sys::TRITONSERVER_ServerNew(
139            &mut server as *mut _,
140            *options.0
141        ))?;
142
143        assert!(!server.is_null());
144
145        let mut server = Server {
146            ptr: Arc::new(Inner(server)),
147            models: HashMap::new(),
148            runtime: tokio::runtime::Handle::current(),
149        };
150        server.update_all_models()?;
151
152        Ok(server)
153    }
154
155    pub(crate) fn get_model<M: AsRef<str>>(&self, model: M) -> Result<&Model, Error> {
156        self.models.get(model.as_ref()).ok_or_else(|| {
157            Error::new(
158                ErrorCode::NotFound,
159                format!(
160                    "Model {} is not found in server model metadata storage.",
161                    model.as_ref()
162                ),
163            )
164        })
165    }
166
167    fn update_all_models(&mut self) -> Result<(), Error> {
168        for model in self.model_index(State::all())? {
169            self.update_model_info(model.name)?;
170        }
171        Ok(())
172    }
173
174    fn update_model_info<M: AsRef<str>>(&mut self, model: M) -> Result<(), Error> {
175        self.models
176            .insert(model.as_ref().to_string(), self.model_metadata(model, -1)?);
177        Ok(())
178    }
179
180    /// Stop a server object. A server can't be restarted once it has been stopped.
181    pub fn stop(&self) -> Result<(), Error> {
182        self.ptr.stop()
183    }
184
185    /// Create a request to the model `model` of version `version`. \
186    /// If version is set as `-1`, the server will choose a version based on the model's policy.
187    pub fn create_request<M: AsRef<str>>(&self, model: M, version: i64) -> Result<Request, Error> {
188        let model_name = to_cstring(model.as_ref())?;
189        let mut ptr = null_mut::<sys::TRITONSERVER_InferenceRequest>();
190
191        triton_call!(sys::TRITONSERVER_InferenceRequestNew(
192            &mut ptr as *mut _,
193            self.ptr.as_mut_ptr(),
194            model_name.as_ptr(),
195            version,
196        ))?;
197
198        assert!(!ptr.is_null());
199        Request::new(ptr, self, model)
200    }
201
202    /// Check the model repository for changes and update server state based on those changes.
203    pub fn poll_model_repository(&mut self) -> Result<(), Error> {
204        triton_call!(sys::TRITONSERVER_ServerPollModelRepository(
205            self.ptr.as_mut_ptr()
206        ))?;
207
208        self.update_all_models()
209    }
210
211    /// Set the exit timeout on the server object. This value overrides the value initially set through server options and provides a mechanism to update the exit timeout while the serving is running.
212    ///
213    /// `timeout` The exit timeout.
214    pub fn set_exit_timeout(&mut self, timeout: Duration) -> Result<&mut Self, Error> {
215        triton_call!(
216            sys::TRITONSERVER_ServerSetExitTimeout(self.ptr.as_mut_ptr(), timeout.as_secs() as _),
217            self
218        )
219    }
220
221    /// Register a new model repository. Not available in polling mode.
222    ///
223    /// `repository` The full path to the model repository. \
224    /// `name_mapping` List of name_mapping parameters.
225    /// Each mapping has the model directory name as its key,
226    /// overridden model name as its value.
227    pub fn register_model_repo<P: AsRef<Path>, N: AsRef<str>>(
228        &mut self,
229        repository: P,
230        name_mapping: HashMap<String, String>,
231    ) -> Result<&mut Self, Error> {
232        let path = path_to_cstring(repository)?;
233
234        let mut mapping_params = name_mapping
235            .into_iter()
236            .map(|(k, v)| {
237                Parameter::new(k, ParameterContent::String(v)).map(|param| *param.ptr as *const _)
238            })
239            .collect::<Result<Vec<_>, _>>()?;
240
241        triton_call!(
242            sys::TRITONSERVER_ServerRegisterModelRepository(
243                self.ptr.as_mut_ptr(),
244                path.as_ptr(),
245                mapping_params.as_mut_ptr(),
246                mapping_params.len() as _
247            ),
248            self
249        )
250    }
251
252    /// Unregister a model repository. Not available in polling mode.
253    ///
254    /// `repository_path` The full path to the model repository.
255    pub fn unregister_model_repo<P: AsRef<Path>, N: AsRef<str>>(
256        &mut self,
257        repository: P,
258    ) -> Result<&mut Self, Error> {
259        let path = path_to_cstring(repository)?;
260
261        triton_call!(
262            sys::TRITONSERVER_ServerUnregisterModelRepository(self.ptr.as_mut_ptr(), path.as_ptr()),
263            self
264        )
265    }
266
267    /// Returns true if server is live, false otherwise.
268    pub fn is_live(&self) -> Result<bool, Error> {
269        self.ptr.is_live()
270    }
271
272    /// Returns true if server is ready, false otherwise.
273    pub fn is_ready(&self) -> Result<bool, Error> {
274        let mut result = false;
275
276        triton_call!(
277            sys::TRITONSERVER_ServerIsReady(self.ptr.as_mut_ptr(), &mut result as *mut _),
278            result
279        )
280    }
281
282    /// Returns true if the model is ready. \
283    /// `name`: The name of the model to get readiness for. \
284    /// `version`: The version of the model to get readiness for. If -1 then the server will choose a version based on the model's policy. \
285    pub fn model_is_ready<N: AsRef<str>>(&self, name: N, version: i64) -> Result<bool, Error> {
286        let name = to_cstring(name)?;
287        let mut result = false;
288
289        triton_call!(
290            sys::TRITONSERVER_ServerModelIsReady(
291                self.ptr.as_mut_ptr(),
292                name.as_ptr(),
293                version,
294                &mut result as *mut _,
295            ),
296            result
297        )
298    }
299
300    /// Get the batch properties of the model. \
301    /// `name`: The name of the model. \
302    /// `version`: The version of the model. If -1 then the server will choose a version based on the model's policy. \
303    pub fn model_batch_properties<N: AsRef<str>>(
304        &self,
305        name: N,
306        version: i64,
307    ) -> Result<Batch, Error> {
308        let name = to_cstring(name)?;
309        let mut result: u32 = 0;
310        let mut ptr = null_mut::<c_void>();
311
312        triton_call!(sys::TRITONSERVER_ServerModelBatchProperties(
313            self.ptr.as_mut_ptr(),
314            name.as_ptr(),
315            version,
316            &mut result as *mut _,
317            &mut ptr as *mut _,
318        ))?;
319        unsafe { Ok(transmute::<u32, Batch>(result)) }
320    }
321
322    /// Get the transaction policy of the model. \
323    /// `name`: The name of the model. \
324    /// `version`: The version of the model. If -1 then the server will choose a version based on the model's policy. \
325    pub fn model_transaction_properties<N: AsRef<str>>(
326        &self,
327        name: N,
328        version: i64,
329    ) -> Result<Transaction, Error> {
330        let name = to_cstring(name)?;
331        let mut result: u32 = 0;
332        let mut ptr = null_mut::<c_void>();
333
334        triton_call!(sys::TRITONSERVER_ServerModelTransactionProperties(
335            self.ptr.as_mut_ptr(),
336            name.as_ptr(),
337            version,
338            &mut result as *mut _,
339            &mut ptr as *mut _,
340        ))?;
341        unsafe { Ok(transmute::<u32, Transaction>(result)) }
342    }
343
344    /// Get the metadata of the server as a Message(json) object.
345    pub fn metadata(&self) -> Result<message::Server, Error> {
346        let mut result = null_mut::<sys::TRITONSERVER_Message>();
347
348        triton_call!(sys::TRITONSERVER_ServerMetadata(
349            self.ptr.as_mut_ptr(),
350            &mut result as *mut _
351        ))?;
352
353        assert!(!result.is_null());
354        Message(result).to_json().and_then(|json| {
355            from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
356        })
357    }
358
359    /// Get the metadata of a model as a Message(json) object.\
360    /// `name`: The name of the model. \
361    /// `version`: The version of the model. If -1 then the server will choose a version based on the model's policy.
362    pub fn model_metadata<N: AsRef<str>>(&self, name: N, version: i64) -> Result<Model, Error> {
363        let name = to_cstring(name)?;
364        let mut result = null_mut::<sys::TRITONSERVER_Message>();
365
366        triton_call!(sys::TRITONSERVER_ServerModelMetadata(
367            self.ptr.as_mut_ptr(),
368            name.as_ptr(),
369            version,
370            &mut result as *mut _,
371        ))?;
372
373        assert!(!result.is_null());
374        Message(result).to_json().and_then(|json| {
375            from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
376        })
377    }
378
379    /// Get the statistics of a model as a Message(json) object. \
380    /// `name`: The name of the model. \
381    /// `version`: The version of the model. If -1 then the server will choose a version based on the model's policy.
382    pub fn model_statistics<N: AsRef<str>>(&self, name: N, version: i64) -> Result<Value, Error> {
383        let name = to_cstring(name)?;
384        let mut result = null_mut::<sys::TRITONSERVER_Message>();
385
386        triton_call!(sys::TRITONSERVER_ServerModelStatistics(
387            self.ptr.as_mut_ptr(),
388            name.as_ptr(),
389            version,
390            &mut result as *mut _,
391        ))?;
392
393        assert!(!result.is_null());
394        Message(result).to_json().and_then(|json| {
395            from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
396        })
397    }
398
399    /// Get the configuration of a model as a Message(json) object. \
400    /// `name`: The name of the model. \
401    /// `version`: The version of the model. If -1 then the server will choose a version based on the model's policy. \
402    /// `config`: The model configuration will be returned in a format matching this version. \
403    /// If the configuration cannot be represented in the requested version's format then an error will be returned.
404    /// Currently only version 1 is supported.
405    pub fn model_config<N: AsRef<str>>(
406        &self,
407        name: N,
408        version: i64,
409        config: u32,
410    ) -> Result<Value, Error> {
411        let name = to_cstring(name)?;
412        let mut result = null_mut::<sys::TRITONSERVER_Message>();
413
414        triton_call!(sys::TRITONSERVER_ServerModelConfig(
415            self.ptr.as_mut_ptr(),
416            name.as_ptr(),
417            version,
418            config,
419            &mut result as *mut _,
420        ))?;
421
422        assert!(!result.is_null());
423        Message(result).to_json().and_then(|json| {
424            from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
425        })
426    }
427
428    /// Get the index of all unique models in the model repositories as a Message(json) object.
429    pub fn model_index(&self, flags: State) -> Result<Vec<Index>, Error> {
430        let mut result = null_mut::<sys::TRITONSERVER_Message>();
431
432        triton_call!(sys::TRITONSERVER_ServerModelIndex(
433            self.ptr.as_mut_ptr(),
434            flags.bits(),
435            &mut result as *mut _,
436        ))?;
437
438        assert!(!result.is_null());
439        Message(result).to_json().and_then(|json| {
440            from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
441        })
442    }
443
444    /// Load the requested model or reload the model if it is already loaded. \
445    /// The function does not return until the model is loaded or fails to load \.
446    /// `name`: The name of the model.
447    pub fn load_model<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
448        let model_name = to_cstring(&name)?;
449
450        triton_call!(sys::TRITONSERVER_ServerLoadModel(
451            self.ptr.as_mut_ptr(),
452            model_name.as_ptr()
453        ))?;
454
455        self.update_model_info(name)
456    }
457
458    /// Load the requested model or reload the model if it is already loaded, with load parameters provided. \
459    /// The function does not return until the model is loaded or fails to load. \
460    /// Currently the below parameter names are recognized:
461    ///
462    /// - "config" : string parameter that contains a JSON representation of the
463    ///   model configuration. This config will be used for loading the model instead
464    ///   of the one in the model directory.
465    ///
466    /// Can be usefull if is needed to load the model with altered config.
467    /// For example, if it's required to load only one exact version of the model (see [Parameter::from_config_with_exact_version] for more info).
468    ///
469    /// `name`: The name of the model. \
470    /// `parameters`: slice of parameters.
471    pub fn load_model_with_parametrs<N: AsRef<str>, P: AsRef<[Parameter]>>(
472        &mut self,
473        name: N,
474        parameters: P,
475    ) -> Result<(), Error> {
476        let model_name = to_cstring(&name)?;
477        let params_count = parameters.as_ref().len();
478        let mut parametrs = parameters
479            .as_ref()
480            .iter()
481            .map(|p| p.ptr.cast_const())
482            .collect::<Vec<_>>();
483
484        triton_call!(sys::TRITONSERVER_ServerLoadModelWithParameters(
485            self.ptr.as_mut_ptr(),
486            model_name.as_ptr(),
487            parametrs.as_mut_ptr(),
488            params_count as _,
489        ))?;
490
491        self.update_model_info(name)
492    }
493
494    /// Unload the requested model. \
495    /// Unloading a model that is not loaded on server has no affect and success code will be returned. \
496    /// The function does not wait for the requested model to be fully unload and success code will be returned. \
497    /// `name`: The name of the model.
498    pub fn unload_model<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
499        let model_name = to_cstring(&name)?;
500
501        triton_call!(sys::TRITONSERVER_ServerUnloadModel(
502            self.ptr.as_mut_ptr(),
503            model_name.as_ptr()
504        ))?;
505
506        self.update_model_info(name)
507    }
508
509    /// Unload the requested model, and also unload any dependent model that was loaded along with the requested model
510    /// (for example, the models composing an ensemble). \
511    /// Unloading a model that is not loaded on server has no affect and success code will be returned. \
512    /// The function does not wait for the requested model and all dependent models to be fully unload and success code will be returned. \
513    /// `name`: The name of the model.
514    pub fn unload_model_and_dependents<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
515        let model_name = to_cstring(&name)?;
516
517        triton_call!(sys::TRITONSERVER_ServerUnloadModelAndDependents(
518            self.ptr.as_mut_ptr(),
519            model_name.as_ptr(),
520        ))?;
521
522        self.update_model_info(name)
523    }
524
525    /// Get the current Prometheus metrics for the server.
526    pub fn metrics(&self) -> Result<metrics::PrometheusMetrics, Error> {
527        let mut metrics = null_mut::<sys::TRITONSERVER_Metrics>();
528
529        triton_call!(sys::TRITONSERVER_ServerMetrics(
530            self.ptr.as_mut_ptr(),
531            &mut metrics as *mut _
532        ))?;
533
534        assert!(!metrics.is_null());
535        Ok(PrometheusMetrics(Arc::new(metrics)))
536    }
537
538    pub fn is_log_enabled(&self, level: LogLevel) -> bool {
539        unsafe { sys::TRITONSERVER_LogIsEnabled(level as u32) }
540    }
541}
542
543#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
544#[repr(u32)]
545pub enum LogLevel {
546    Info = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_INFO,
547    Warn = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_WARN,
548    Error = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_ERROR,
549    Verbose = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_VERBOSE,
550}