1use std::{
2 collections::HashMap,
3 ffi::{c_void, CStr},
4 mem::transmute,
5 path::Path,
6 ptr::null_mut,
7 sync::Arc,
8 time::Duration,
9};
10
11use serde_json::{from_slice, Value};
12
13use crate::{
14 message::{self, Index, Message, Model},
15 metrics::{self, Metrics},
16 options::Options,
17 parameter::{Parameter, ParameterContent},
18 path_to_cstring, sys, to_cstring, Error, ErrorCode, Request,
19};
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
23#[repr(u32)]
24pub enum Batch {
25 Unknown = sys::tritonserver_batchflag_enum_TRITONSERVER_BATCH_UNKNOWN,
28 FirstDim = sys::tritonserver_batchflag_enum_TRITONSERVER_BATCH_FIRST_DIM,
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
35#[repr(u32)]
36pub enum Transaction {
37 OneToOne = sys::tritonserver_txn_property_flag_enum_TRITONSERVER_TXN_ONE_TO_ONE,
39 Decoupled = sys::tritonserver_txn_property_flag_enum_TRITONSERVER_TXN_DECOUPLED,
41}
42
bitflags::bitflags! {
    /// Flags passed to `Server::model_index` (mirrors Triton's model-index flag values).
    pub struct State: u32 {
        /// `TRITONSERVER_INDEX_FLAG_READY` — presumably restricts the index to ready
        /// models; confirm against the Triton header.
        const READY = sys::tritonserver_modelindexflag_enum_TRITONSERVER_INDEX_FLAG_READY;
    }
}
50
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
53#[repr(u32)]
54pub enum InstanceGroup {
55 Auto = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_AUTO,
56 Cpu = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_CPU,
57 Gpu = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_GPU,
58 Model = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_MODEL,
59}
60
61impl InstanceGroup {
62 fn as_cstr(self) -> &'static CStr {
63 unsafe { CStr::from_ptr(sys::TRITONSERVER_InstanceGroupKindString(self as u32)) }
64 }
65
66 pub fn as_str(self) -> &'static str {
68 self.as_cstr()
69 .to_str()
70 .unwrap_or(crate::error::CSTR_CONVERT_ERROR_PLUG)
71 }
72}
73
74#[derive(Debug)]
75pub(crate) struct Inner(*mut sys::TRITONSERVER_Server);
76impl Inner {
77 pub(crate) fn stop(&self) -> Result<(), Error> {
78 triton_call!(sys::TRITONSERVER_ServerStop(self.0))
79 }
80
81 pub(crate) fn is_live(&self) -> Result<bool, Error> {
82 let mut result = false;
83 triton_call!(
84 sys::TRITONSERVER_ServerIsLive(self.0, &mut result as *mut _),
85 result
86 )
87 }
88
89 pub(crate) fn delete(&self) -> Result<(), Error> {
90 triton_call!(sys::TRITONSERVER_ServerDelete(self.0))
91 }
92
93 pub(crate) fn as_mut_ptr(&self) -> *mut sys::TRITONSERVER_Server {
94 self.0
95 }
96}
97
98impl Drop for Inner {
99 fn drop(&mut self) {
100 let _ = self
101 .is_live()
102 .and_then(|live| {
103 if live {
104 self.stop().and_then(|_| loop {
105 if !self.is_live()? {
106 return Ok(());
107 }
108 })
109 } else {
110 Ok(())
111 }
112 })
113 .and_then(|_| self.delete());
114 }
115}
116
// SAFETY (assumed): `Inner` only holds the raw Triton server pointer. These impls assert that
// the Triton server object may be moved to and used from other threads concurrently.
// NOTE(review): confirm the thread-safety guarantees of the Triton in-process C API.
unsafe impl Send for Inner {}
unsafe impl Sync for Inner {}
123
/// Handle to an in-process Triton server plus a cache of model metadata.
#[derive(Debug)]
pub struct Server {
    /// Shared owner of the raw server pointer; the server is stopped and deleted when the last
    /// clone of this `Arc` is dropped.
    pub(crate) ptr: Arc<Inner>,
    /// Model metadata keyed by model name, filled through `model_metadata`.
    pub(crate) models: HashMap<String, Model>,
    /// Tokio runtime handle captured in `Server::new`; its consumers live outside this view.
    pub(crate) runtime: tokio::runtime::Handle,
}

// NOTE(review): presumably required because some field type (e.g. `Model`) is not
// automatically `Send`; if every field is `Send`, this impl is redundant — confirm.
unsafe impl Send for Server {}
133impl Server {
134 pub async fn new(options: Options) -> Result<Self, Error> {
136 let mut server = null_mut::<sys::TRITONSERVER_Server>();
137 triton_call!(sys::TRITONSERVER_ServerNew(
138 &mut server as *mut _,
139 options.0
140 ))?;
141
142 assert!(!server.is_null());
143
144 let mut server = Server {
145 ptr: Arc::new(Inner(server)),
146 models: HashMap::new(),
147 runtime: tokio::runtime::Handle::current(),
148 };
149 server.update_all_models()?;
150
151 Ok(server)
152 }
153
154 pub(crate) fn get_model<M: AsRef<str>>(&self, model: M) -> Result<&Model, Error> {
155 self.models.get(model.as_ref()).ok_or_else(|| {
156 Error::new(
157 ErrorCode::NotFound,
158 format!(
159 "Model {} is not found in server model metadata storage.",
160 model.as_ref()
161 ),
162 )
163 })
164 }
165
166 fn update_all_models(&mut self) -> Result<(), Error> {
167 for model in self.model_index(State::all())? {
168 self.update_model_info(model.name)?;
169 }
170 Ok(())
171 }
172
173 fn update_model_info<M: AsRef<str>>(&mut self, model: M) -> Result<(), Error> {
174 self.models
175 .insert(model.as_ref().to_string(), self.model_metadata(model, -1)?);
176 Ok(())
177 }
178
179 pub fn stop(&self) -> Result<(), Error> {
181 self.ptr.stop()
182 }
183
184 pub fn create_request<M: AsRef<str>>(&self, model: M, version: i64) -> Result<Request, Error> {
187 let model_name = to_cstring(model.as_ref())?;
188 let mut ptr = null_mut::<sys::TRITONSERVER_InferenceRequest>();
189
190 triton_call!(sys::TRITONSERVER_InferenceRequestNew(
191 &mut ptr as *mut _,
192 self.ptr.as_mut_ptr(),
193 model_name.as_ptr(),
194 version,
195 ))?;
196
197 assert!(!ptr.is_null());
198 Request::new(ptr, self, model)
199 }
200
201 pub fn poll_model_repository(&mut self) -> Result<(), Error> {
203 triton_call!(sys::TRITONSERVER_ServerPollModelRepository(
204 self.ptr.as_mut_ptr()
205 ))?;
206
207 self.update_all_models()
208 }
209
210 pub fn set_exit_timeout(&mut self, timeout: Duration) -> Result<&mut Self, Error> {
214 triton_call!(
215 sys::TRITONSERVER_ServerSetExitTimeout(self.ptr.as_mut_ptr(), timeout.as_secs() as _),
216 self
217 )
218 }
219
220 pub fn register_model_repo<P: AsRef<Path>, N: AsRef<str>>(
227 &mut self,
228 repository: P,
229 name_mapping: HashMap<String, String>,
230 ) -> Result<&mut Self, Error> {
231 let path = path_to_cstring(repository)?;
232
233 let mut mapping_params = name_mapping
234 .into_iter()
235 .map(|(k, v)| {
236 Parameter::new(k, ParameterContent::String(v)).map(|param| param.ptr as *const _)
237 })
238 .collect::<Result<Vec<_>, _>>()?;
239
240 triton_call!(
241 sys::TRITONSERVER_ServerRegisterModelRepository(
242 self.ptr.as_mut_ptr(),
243 path.as_ptr(),
244 mapping_params.as_mut_ptr(),
245 mapping_params.len() as _
246 ),
247 self
248 )
249 }
250
251 pub fn unregister_model_repo<P: AsRef<Path>, N: AsRef<str>>(
255 &mut self,
256 repository: P,
257 ) -> Result<&mut Self, Error> {
258 let path = path_to_cstring(repository)?;
259
260 triton_call!(
261 sys::TRITONSERVER_ServerUnregisterModelRepository(self.ptr.as_mut_ptr(), path.as_ptr()),
262 self
263 )
264 }
265
266 pub fn is_live(&self) -> Result<bool, Error> {
268 self.ptr.is_live()
269 }
270
271 pub fn is_ready(&self) -> Result<bool, Error> {
273 let mut result = false;
274
275 triton_call!(
276 sys::TRITONSERVER_ServerIsReady(self.ptr.as_mut_ptr(), &mut result as *mut _),
277 result
278 )
279 }
280
281 pub fn model_is_ready<N: AsRef<str>>(&self, name: N, version: i64) -> Result<bool, Error> {
285 let name = to_cstring(name)?;
286 let mut result = false;
287
288 triton_call!(
289 sys::TRITONSERVER_ServerModelIsReady(
290 self.ptr.as_mut_ptr(),
291 name.as_ptr(),
292 version,
293 &mut result as *mut _,
294 ),
295 result
296 )
297 }
298
299 pub fn model_batch_properties<N: AsRef<str>>(
303 &self,
304 name: N,
305 version: i64,
306 ) -> Result<Batch, Error> {
307 let name = to_cstring(name)?;
308 let mut result: u32 = 0;
309 let mut ptr = null_mut::<c_void>();
310
311 triton_call!(sys::TRITONSERVER_ServerModelBatchProperties(
312 self.ptr.as_mut_ptr(),
313 name.as_ptr(),
314 version,
315 &mut result as *mut _,
316 &mut ptr as *mut _,
317 ))?;
318 unsafe { Ok(transmute::<u32, Batch>(result)) }
319 }
320
321 pub fn model_transaction_properties<N: AsRef<str>>(
325 &self,
326 name: N,
327 version: i64,
328 ) -> Result<Transaction, Error> {
329 let name = to_cstring(name)?;
330 let mut result: u32 = 0;
331 let mut ptr = null_mut::<c_void>();
332
333 triton_call!(sys::TRITONSERVER_ServerModelTransactionProperties(
334 self.ptr.as_mut_ptr(),
335 name.as_ptr(),
336 version,
337 &mut result as *mut _,
338 &mut ptr as *mut _,
339 ))?;
340 unsafe { Ok(transmute::<u32, Transaction>(result)) }
341 }
342
343 pub fn metadata(&self) -> Result<message::Server, Error> {
345 let mut result = null_mut::<sys::TRITONSERVER_Message>();
346
347 triton_call!(sys::TRITONSERVER_ServerMetadata(
348 self.ptr.as_mut_ptr(),
349 &mut result as *mut _
350 ))?;
351
352 assert!(!result.is_null());
353 Message(result).to_json().and_then(|json| {
354 from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
355 })
356 }
357
358 pub fn model_metadata<N: AsRef<str>>(&self, name: N, version: i64) -> Result<Model, Error> {
362 let name = to_cstring(name)?;
363 let mut result = null_mut::<sys::TRITONSERVER_Message>();
364
365 triton_call!(sys::TRITONSERVER_ServerModelMetadata(
366 self.ptr.as_mut_ptr(),
367 name.as_ptr(),
368 version,
369 &mut result as *mut _,
370 ))?;
371
372 assert!(!result.is_null());
373 Message(result).to_json().and_then(|json| {
374 from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
375 })
376 }
377
378 pub fn model_statistics<N: AsRef<str>>(&self, name: N, version: i64) -> Result<Value, Error> {
382 let name = to_cstring(name)?;
383 let mut result = null_mut::<sys::TRITONSERVER_Message>();
384
385 triton_call!(sys::TRITONSERVER_ServerModelStatistics(
386 self.ptr.as_mut_ptr(),
387 name.as_ptr(),
388 version,
389 &mut result as *mut _,
390 ))?;
391
392 assert!(!result.is_null());
393 Message(result).to_json().and_then(|json| {
394 from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
395 })
396 }
397
398 pub fn model_config<N: AsRef<str>>(
405 &self,
406 name: N,
407 version: i64,
408 config: u32,
409 ) -> Result<Value, Error> {
410 let name = to_cstring(name)?;
411 let mut result = null_mut::<sys::TRITONSERVER_Message>();
412
413 triton_call!(sys::TRITONSERVER_ServerModelConfig(
414 self.ptr.as_mut_ptr(),
415 name.as_ptr(),
416 version,
417 config,
418 &mut result as *mut _,
419 ))?;
420
421 assert!(!result.is_null());
422 Message(result).to_json().and_then(|json| {
423 from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
424 })
425 }
426
427 pub fn model_index(&self, flags: State) -> Result<Vec<Index>, Error> {
429 let mut result = null_mut::<sys::TRITONSERVER_Message>();
430
431 triton_call!(sys::TRITONSERVER_ServerModelIndex(
432 self.ptr.as_mut_ptr(),
433 flags.bits(),
434 &mut result as *mut _,
435 ))?;
436
437 assert!(!result.is_null());
438 Message(result).to_json().and_then(|json| {
439 from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
440 })
441 }
442
443 pub fn load_model<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
447 let model_name = to_cstring(&name)?;
448
449 triton_call!(sys::TRITONSERVER_ServerLoadModel(
450 self.ptr.as_mut_ptr(),
451 model_name.as_ptr()
452 ))?;
453
454 self.update_model_info(name)
455 }
456
457 pub fn load_model_with_parametrs<N: AsRef<str>, P: AsRef<[Parameter]>>(
471 &mut self,
472 name: N,
473 parameters: P,
474 ) -> Result<(), Error> {
475 let model_name = to_cstring(&name)?;
476 let params_count = parameters.as_ref().len();
477 let mut parametrs = parameters
478 .as_ref()
479 .iter()
480 .map(|p| p.ptr.cast_const())
481 .collect::<Vec<_>>();
482
483 triton_call!(sys::TRITONSERVER_ServerLoadModelWithParameters(
484 self.ptr.as_mut_ptr(),
485 model_name.as_ptr(),
486 parametrs.as_mut_ptr(),
487 params_count as _,
488 ))?;
489
490 self.update_model_info(name)
491 }
492
493 pub fn unload_model<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
498 let model_name = to_cstring(&name)?;
499
500 triton_call!(sys::TRITONSERVER_ServerUnloadModel(
501 self.ptr.as_mut_ptr(),
502 model_name.as_ptr()
503 ))?;
504
505 self.update_model_info(name)
506 }
507
508 pub fn unload_model_and_dependents<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
514 let model_name = to_cstring(&name)?;
515
516 triton_call!(sys::TRITONSERVER_ServerUnloadModelAndDependents(
517 self.ptr.as_mut_ptr(),
518 model_name.as_ptr(),
519 ))?;
520
521 self.update_model_info(name)
522 }
523
524 pub fn metrics(&self) -> Result<metrics::Metrics, Error> {
526 let mut metrics = null_mut::<sys::TRITONSERVER_Metrics>();
527
528 triton_call!(sys::TRITONSERVER_ServerMetrics(
529 self.ptr.as_mut_ptr(),
530 &mut metrics as *mut _
531 ))?;
532
533 assert!(!metrics.is_null());
534 Ok(Metrics(metrics))
535 }
536
537 pub fn is_log_enabled(&self, level: LogLevel) -> bool {
538 unsafe { sys::TRITONSERVER_LogIsEnabled(level as u32) }
539 }
540}
541
542#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
543#[repr(u32)]
544pub enum LogLevel {
545 Info = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_INFO,
546 Warn = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_WARN,
547 Error = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_ERROR,
548 Verbose = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_VERBOSE,
549}