tritonserver_rs/
server.rs

1use std::{
2    collections::HashMap,
3    ffi::{c_void, CStr},
4    mem::transmute,
5    path::Path,
6    ptr::null_mut,
7    sync::Arc,
8    time::Duration,
9};
10
11use serde_json::{from_slice, Value};
12
13use crate::{
14    message::{self, Index, Message, Model},
15    metrics::{self, Metrics},
16    options::Options,
17    parameter::{Parameter, ParameterContent},
18    path_to_cstring, sys, to_cstring, Error, ErrorCode, Request,
19};
20
21/// Batch properties of the model.
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
23#[repr(u32)]
24pub enum Batch {
25    /// Triton cannot determine the batching properties of the model.
26    /// This means that the model does not support batching in any way that is useable by Triton.
27    Unknown = sys::tritonserver_batchflag_enum_TRITONSERVER_BATCH_UNKNOWN,
28    /// The model supports batching along the first dimension of every input and output tensor.
29    /// Triton schedulers that perform batching can automatically batch inference requests along this dimension.
30    FirstDim = sys::tritonserver_batchflag_enum_TRITONSERVER_BATCH_FIRST_DIM,
31}
32
33/// Transaction policy of the model.
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
35#[repr(u32)]
36pub enum Transaction {
37    /// The model generates exactly one response per request.
38    OneToOne = sys::tritonserver_txn_property_flag_enum_TRITONSERVER_TXN_ONE_TO_ONE,
39    /// The model may generate zero to many responses per request.
40    Decoupled = sys::tritonserver_txn_property_flag_enum_TRITONSERVER_TXN_DECOUPLED,
41}
42
43bitflags::bitflags! {
44    /// Flags that control how to collect the index.
45    pub struct State: u32 {
46        /// If set in 'flags', only the models that are loaded into the server and ready for inferencing are returned.
47        const READY = sys::tritonserver_modelindexflag_enum_TRITONSERVER_INDEX_FLAG_READY;
48    }
49}
50
51/// Kinds of instance groups recognized by TRITONSERVER.
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
53#[repr(u32)]
54pub enum InstanceGroup {
55    Auto = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_AUTO,
56    Cpu = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_CPU,
57    Gpu = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_GPU,
58    Model = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_MODEL,
59}
60
61impl InstanceGroup {
62    fn as_cstr(self) -> &'static CStr {
63        unsafe { CStr::from_ptr(sys::TRITONSERVER_InstanceGroupKindString(self as u32)) }
64    }
65
66    /// Get the string representation of an instance-group kind.
67    pub fn as_str(self) -> &'static str {
68        self.as_cstr()
69            .to_str()
70            .unwrap_or(crate::error::CSTR_CONVERT_ERROR_PLUG)
71    }
72}
73
74#[derive(Debug)]
75pub(crate) struct Inner(*mut sys::TRITONSERVER_Server);
76impl Inner {
77    pub(crate) fn stop(&self) -> Result<(), Error> {
78        triton_call!(sys::TRITONSERVER_ServerStop(self.0))
79    }
80
81    pub(crate) fn is_live(&self) -> Result<bool, Error> {
82        let mut result = false;
83        triton_call!(
84            sys::TRITONSERVER_ServerIsLive(self.0, &mut result as *mut _),
85            result
86        )
87    }
88
89    pub(crate) fn delete(&self) -> Result<(), Error> {
90        triton_call!(sys::TRITONSERVER_ServerDelete(self.0))
91    }
92
93    pub(crate) fn as_mut_ptr(&self) -> *mut sys::TRITONSERVER_Server {
94        self.0
95    }
96}
97
98impl Drop for Inner {
99    fn drop(&mut self) {
100        let _ = self
101            .is_live()
102            .and_then(|live| {
103                if live {
104                    self.stop().and_then(|_| loop {
105                        if !self.is_live()? {
106                            return Ok(());
107                        }
108                    })
109                } else {
110                    Ok(())
111                }
112            })
113            .and_then(|_| self.delete());
114    }
115}
116
117/// # SAFETY
118/// Inner is Send. But it's not Sync! \
119/// However, it's used only in Server and Server is never clones Inner,
120/// so there is always only 1 copy of it.
121unsafe impl Send for Inner {}
122unsafe impl Sync for Inner {}
123
124/// Inference server object.
125#[derive(Debug)]
126pub struct Server {
127    pub(crate) ptr: Arc<Inner>,
128    pub(crate) models: HashMap<String, Model>,
129    pub(crate) runtime: tokio::runtime::Handle,
130}
131
132unsafe impl Send for Server {}
133impl Server {
134    /// Create new server object.
135    pub async fn new(options: Options) -> Result<Self, Error> {
136        let mut server = null_mut::<sys::TRITONSERVER_Server>();
137        triton_call!(sys::TRITONSERVER_ServerNew(
138            &mut server as *mut _,
139            options.0
140        ))?;
141
142        assert!(!server.is_null());
143
144        let mut server = Server {
145            ptr: Arc::new(Inner(server)),
146            models: HashMap::new(),
147            runtime: tokio::runtime::Handle::current(),
148        };
149        server.update_all_models()?;
150
151        Ok(server)
152    }
153
154    pub(crate) fn get_model<M: AsRef<str>>(&self, model: M) -> Result<&Model, Error> {
155        self.models.get(model.as_ref()).ok_or_else(|| {
156            Error::new(
157                ErrorCode::NotFound,
158                format!(
159                    "Model {} is not found in server model metadata storage.",
160                    model.as_ref()
161                ),
162            )
163        })
164    }
165
166    fn update_all_models(&mut self) -> Result<(), Error> {
167        for model in self.model_index(State::all())? {
168            self.update_model_info(model.name)?;
169        }
170        Ok(())
171    }
172
173    fn update_model_info<M: AsRef<str>>(&mut self, model: M) -> Result<(), Error> {
174        self.models
175            .insert(model.as_ref().to_string(), self.model_metadata(model, -1)?);
176        Ok(())
177    }
178
179    /// Stop a server object. A server can't be restarted once it has been stopped.
180    pub fn stop(&self) -> Result<(), Error> {
181        self.ptr.stop()
182    }
183
184    /// Create a request to the model `model` of version `version`. \
185    /// If version is set as `-1`, the server will choose a version based on the model's policy.
186    pub fn create_request<M: AsRef<str>>(&self, model: M, version: i64) -> Result<Request, Error> {
187        let model_name = to_cstring(model.as_ref())?;
188        let mut ptr = null_mut::<sys::TRITONSERVER_InferenceRequest>();
189
190        triton_call!(sys::TRITONSERVER_InferenceRequestNew(
191            &mut ptr as *mut _,
192            self.ptr.as_mut_ptr(),
193            model_name.as_ptr(),
194            version,
195        ))?;
196
197        assert!(!ptr.is_null());
198        Request::new(ptr, self, model)
199    }
200
201    /// Check the model repository for changes and update server state based on those changes.
202    pub fn poll_model_repository(&mut self) -> Result<(), Error> {
203        triton_call!(sys::TRITONSERVER_ServerPollModelRepository(
204            self.ptr.as_mut_ptr()
205        ))?;
206
207        self.update_all_models()
208    }
209
210    /// Set the exit timeout on the server object. This value overrides the value initially set through server options and provides a mechanism to update the exit timeout while the serving is running.
211    ///
212    /// `timeout` The exit timeout.
213    pub fn set_exit_timeout(&mut self, timeout: Duration) -> Result<&mut Self, Error> {
214        triton_call!(
215            sys::TRITONSERVER_ServerSetExitTimeout(self.ptr.as_mut_ptr(), timeout.as_secs() as _),
216            self
217        )
218    }
219
220    /// Register a new model repository. Not available in polling mode.
221    ///
222    /// `repository` The full path to the model repository. \
223    /// `name_mapping` List of name_mapping parameters.
224    /// Each mapping has the model directory name as its key,
225    /// overridden model name as its value.
226    pub fn register_model_repo<P: AsRef<Path>, N: AsRef<str>>(
227        &mut self,
228        repository: P,
229        name_mapping: HashMap<String, String>,
230    ) -> Result<&mut Self, Error> {
231        let path = path_to_cstring(repository)?;
232
233        let mut mapping_params = name_mapping
234            .into_iter()
235            .map(|(k, v)| {
236                Parameter::new(k, ParameterContent::String(v)).map(|param| param.ptr as *const _)
237            })
238            .collect::<Result<Vec<_>, _>>()?;
239
240        triton_call!(
241            sys::TRITONSERVER_ServerRegisterModelRepository(
242                self.ptr.as_mut_ptr(),
243                path.as_ptr(),
244                mapping_params.as_mut_ptr(),
245                mapping_params.len() as _
246            ),
247            self
248        )
249    }
250
251    /// Unregister a model repository. Not available in polling mode.
252    ///
253    /// `repository_path` The full path to the model repository.
254    pub fn unregister_model_repo<P: AsRef<Path>, N: AsRef<str>>(
255        &mut self,
256        repository: P,
257    ) -> Result<&mut Self, Error> {
258        let path = path_to_cstring(repository)?;
259
260        triton_call!(
261            sys::TRITONSERVER_ServerUnregisterModelRepository(self.ptr.as_mut_ptr(), path.as_ptr()),
262            self
263        )
264    }
265
266    /// Returns true if server is live, false otherwise.
267    pub fn is_live(&self) -> Result<bool, Error> {
268        self.ptr.is_live()
269    }
270
271    /// Returns true if server is ready, false otherwise.
272    pub fn is_ready(&self) -> Result<bool, Error> {
273        let mut result = false;
274
275        triton_call!(
276            sys::TRITONSERVER_ServerIsReady(self.ptr.as_mut_ptr(), &mut result as *mut _),
277            result
278        )
279    }
280
281    /// Returns true if the model is ready. \
282    /// `name`: The name of the model to get readiness for. \
283    /// `version`: The version of the model to get readiness for. If -1 then the server will choose a version based on the model's policy. \
284    pub fn model_is_ready<N: AsRef<str>>(&self, name: N, version: i64) -> Result<bool, Error> {
285        let name = to_cstring(name)?;
286        let mut result = false;
287
288        triton_call!(
289            sys::TRITONSERVER_ServerModelIsReady(
290                self.ptr.as_mut_ptr(),
291                name.as_ptr(),
292                version,
293                &mut result as *mut _,
294            ),
295            result
296        )
297    }
298
299    /// Get the batch properties of the model. \
300    /// `name`: The name of the model. \
301    /// `version`: The version of the model. If -1 then the server will choose a version based on the model's policy. \
302    pub fn model_batch_properties<N: AsRef<str>>(
303        &self,
304        name: N,
305        version: i64,
306    ) -> Result<Batch, Error> {
307        let name = to_cstring(name)?;
308        let mut result: u32 = 0;
309        let mut ptr = null_mut::<c_void>();
310
311        triton_call!(sys::TRITONSERVER_ServerModelBatchProperties(
312            self.ptr.as_mut_ptr(),
313            name.as_ptr(),
314            version,
315            &mut result as *mut _,
316            &mut ptr as *mut _,
317        ))?;
318        unsafe { Ok(transmute::<u32, Batch>(result)) }
319    }
320
321    /// Get the transaction policy of the model. \
322    /// `name`: The name of the model. \
323    /// `version`: The version of the model. If -1 then the server will choose a version based on the model's policy. \
324    pub fn model_transaction_properties<N: AsRef<str>>(
325        &self,
326        name: N,
327        version: i64,
328    ) -> Result<Transaction, Error> {
329        let name = to_cstring(name)?;
330        let mut result: u32 = 0;
331        let mut ptr = null_mut::<c_void>();
332
333        triton_call!(sys::TRITONSERVER_ServerModelTransactionProperties(
334            self.ptr.as_mut_ptr(),
335            name.as_ptr(),
336            version,
337            &mut result as *mut _,
338            &mut ptr as *mut _,
339        ))?;
340        unsafe { Ok(transmute::<u32, Transaction>(result)) }
341    }
342
343    /// Get the metadata of the server as a Message(json) object.
344    pub fn metadata(&self) -> Result<message::Server, Error> {
345        let mut result = null_mut::<sys::TRITONSERVER_Message>();
346
347        triton_call!(sys::TRITONSERVER_ServerMetadata(
348            self.ptr.as_mut_ptr(),
349            &mut result as *mut _
350        ))?;
351
352        assert!(!result.is_null());
353        Message(result).to_json().and_then(|json| {
354            from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
355        })
356    }
357
358    /// Get the metadata of a model as a Message(json) object.\
359    /// `name`: The name of the model. \
360    /// `version`: The version of the model. If -1 then the server will choose a version based on the model's policy.
361    pub fn model_metadata<N: AsRef<str>>(&self, name: N, version: i64) -> Result<Model, Error> {
362        let name = to_cstring(name)?;
363        let mut result = null_mut::<sys::TRITONSERVER_Message>();
364
365        triton_call!(sys::TRITONSERVER_ServerModelMetadata(
366            self.ptr.as_mut_ptr(),
367            name.as_ptr(),
368            version,
369            &mut result as *mut _,
370        ))?;
371
372        assert!(!result.is_null());
373        Message(result).to_json().and_then(|json| {
374            from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
375        })
376    }
377
378    /// Get the statistics of a model as a Message(json) object. \
379    /// `name`: The name of the model. \
380    /// `version`: The version of the model. If -1 then the server will choose a version based on the model's policy.
381    pub fn model_statistics<N: AsRef<str>>(&self, name: N, version: i64) -> Result<Value, Error> {
382        let name = to_cstring(name)?;
383        let mut result = null_mut::<sys::TRITONSERVER_Message>();
384
385        triton_call!(sys::TRITONSERVER_ServerModelStatistics(
386            self.ptr.as_mut_ptr(),
387            name.as_ptr(),
388            version,
389            &mut result as *mut _,
390        ))?;
391
392        assert!(!result.is_null());
393        Message(result).to_json().and_then(|json| {
394            from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
395        })
396    }
397
398    /// Get the configuration of a model as a Message(json) object. \
399    /// `name`: The name of the model. \
400    /// `version`: The version of the model. If -1 then the server will choose a version based on the model's policy. \
401    /// `config`: The model configuration will be returned in a format matching this version. \
402    /// If the configuration cannot be represented in the requested version's format then an error will be returned.
403    /// Currently only version 1 is supported.
404    pub fn model_config<N: AsRef<str>>(
405        &self,
406        name: N,
407        version: i64,
408        config: u32,
409    ) -> Result<Value, Error> {
410        let name = to_cstring(name)?;
411        let mut result = null_mut::<sys::TRITONSERVER_Message>();
412
413        triton_call!(sys::TRITONSERVER_ServerModelConfig(
414            self.ptr.as_mut_ptr(),
415            name.as_ptr(),
416            version,
417            config,
418            &mut result as *mut _,
419        ))?;
420
421        assert!(!result.is_null());
422        Message(result).to_json().and_then(|json| {
423            from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
424        })
425    }
426
427    /// Get the index of all unique models in the model repositories as a Message(json) object.
428    pub fn model_index(&self, flags: State) -> Result<Vec<Index>, Error> {
429        let mut result = null_mut::<sys::TRITONSERVER_Message>();
430
431        triton_call!(sys::TRITONSERVER_ServerModelIndex(
432            self.ptr.as_mut_ptr(),
433            flags.bits(),
434            &mut result as *mut _,
435        ))?;
436
437        assert!(!result.is_null());
438        Message(result).to_json().and_then(|json| {
439            from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
440        })
441    }
442
443    /// Load the requested model or reload the model if it is already loaded. \
444    /// The function does not return until the model is loaded or fails to load \.
445    /// `name`: The name of the model.
446    pub fn load_model<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
447        let model_name = to_cstring(&name)?;
448
449        triton_call!(sys::TRITONSERVER_ServerLoadModel(
450            self.ptr.as_mut_ptr(),
451            model_name.as_ptr()
452        ))?;
453
454        self.update_model_info(name)
455    }
456
457    /// Load the requested model or reload the model if it is already loaded, with load parameters provided. \
458    /// The function does not return until the model is loaded or fails to load. \
459    /// Currently the below parameter names are recognized:
460    ///
461    /// - "config" : string parameter that contains a JSON representation of the
462    ///   model configuration. This config will be used for loading the model instead
463    ///   of the one in the model directory.
464    ///
465    /// Can be usefull if is needed to load the model with altered config.
466    /// For example, if it's required to load only one exact version of the model (see [Parameter::from_config_with_exact_version] for more info).
467    ///
468    /// `name`: The name of the model. \
469    /// `parameters`: slice of parameters.
470    pub fn load_model_with_parametrs<N: AsRef<str>, P: AsRef<[Parameter]>>(
471        &mut self,
472        name: N,
473        parameters: P,
474    ) -> Result<(), Error> {
475        let model_name = to_cstring(&name)?;
476        let params_count = parameters.as_ref().len();
477        let mut parametrs = parameters
478            .as_ref()
479            .iter()
480            .map(|p| p.ptr.cast_const())
481            .collect::<Vec<_>>();
482
483        triton_call!(sys::TRITONSERVER_ServerLoadModelWithParameters(
484            self.ptr.as_mut_ptr(),
485            model_name.as_ptr(),
486            parametrs.as_mut_ptr(),
487            params_count as _,
488        ))?;
489
490        self.update_model_info(name)
491    }
492
493    /// Unload the requested model. \
494    /// Unloading a model that is not loaded on server has no affect and success code will be returned. \
495    /// The function does not wait for the requested model to be fully unload and success code will be returned. \
496    /// `name`: The name of the model.
497    pub fn unload_model<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
498        let model_name = to_cstring(&name)?;
499
500        triton_call!(sys::TRITONSERVER_ServerUnloadModel(
501            self.ptr.as_mut_ptr(),
502            model_name.as_ptr()
503        ))?;
504
505        self.update_model_info(name)
506    }
507
508    /// Unload the requested model, and also unload any dependent model that was loaded along with the requested model
509    /// (for example, the models composing an ensemble). \
510    /// Unloading a model that is not loaded on server has no affect and success code will be returned. \
511    /// The function does not wait for the requested model and all dependent models to be fully unload and success code will be returned. \
512    /// `name`: The name of the model.
513    pub fn unload_model_and_dependents<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
514        let model_name = to_cstring(&name)?;
515
516        triton_call!(sys::TRITONSERVER_ServerUnloadModelAndDependents(
517            self.ptr.as_mut_ptr(),
518            model_name.as_ptr(),
519        ))?;
520
521        self.update_model_info(name)
522    }
523
524    /// Get the current metrics for the server.
525    pub fn metrics(&self) -> Result<metrics::Metrics, Error> {
526        let mut metrics = null_mut::<sys::TRITONSERVER_Metrics>();
527
528        triton_call!(sys::TRITONSERVER_ServerMetrics(
529            self.ptr.as_mut_ptr(),
530            &mut metrics as *mut _
531        ))?;
532
533        assert!(!metrics.is_null());
534        Ok(Metrics(metrics))
535    }
536
537    pub fn is_log_enabled(&self, level: LogLevel) -> bool {
538        unsafe { sys::TRITONSERVER_LogIsEnabled(level as u32) }
539    }
540}
541
542#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
543#[repr(u32)]
544pub enum LogLevel {
545    Info = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_INFO,
546    Warn = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_WARN,
547    Error = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_ERROR,
548    Verbose = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_VERBOSE,
549}