1use std::{
2 collections::HashMap,
3 ffi::{c_void, CStr},
4 mem::transmute,
5 path::Path,
6 ptr::null_mut,
7 sync::Arc,
8 time::Duration,
9};
10
11use serde_json::{from_slice, Value};
12
13use crate::{
14 message::{self, Index, Message, Model},
15 metrics::{self, PrometheusMetrics},
16 options::Options,
17 parameter::{Parameter, ParameterContent},
18 path_to_cstring, sys, to_cstring, Error, ErrorCode, Request,
19};
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
23#[repr(u32)]
24pub enum Batch {
25 Unknown = sys::tritonserver_batchflag_enum_TRITONSERVER_BATCH_UNKNOWN,
28 FirstDim = sys::tritonserver_batchflag_enum_TRITONSERVER_BATCH_FIRST_DIM,
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
35#[repr(u32)]
36pub enum Transaction {
37 OneToOne = sys::tritonserver_txn_property_flag_enum_TRITONSERVER_TXN_ONE_TO_ONE,
39 Decoupled = sys::tritonserver_txn_property_flag_enum_TRITONSERVER_TXN_DECOUPLED,
41}
42
43bitflags::bitflags! {
44 pub struct State: u32 {
46 const READY = sys::tritonserver_modelindexflag_enum_TRITONSERVER_INDEX_FLAG_READY;
48 }
49}
50
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
53#[repr(u32)]
54pub enum InstanceGroup {
55 Auto = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_AUTO,
56 Cpu = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_CPU,
57 Gpu = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_GPU,
58 Model = sys::TRITONSERVER_instancegroupkind_enum_TRITONSERVER_INSTANCEGROUPKIND_MODEL,
59}
60
61impl InstanceGroup {
62 fn as_cstr(self) -> &'static CStr {
63 unsafe { CStr::from_ptr(sys::TRITONSERVER_InstanceGroupKindString(self as u32)) }
64 }
65
66 pub fn as_str(self) -> &'static str {
68 self.as_cstr()
69 .to_str()
70 .unwrap_or(crate::error::CSTR_CONVERT_ERROR_PLUG)
71 }
72}
73
74#[derive(Debug)]
75pub(crate) struct Inner(*mut sys::TRITONSERVER_Server);
76impl Inner {
77 pub(crate) fn stop(&self) -> Result<(), Error> {
78 triton_call!(sys::TRITONSERVER_ServerStop(self.0))
79 }
80
81 pub(crate) fn is_live(&self) -> Result<bool, Error> {
82 let mut result = false;
83 triton_call!(
84 sys::TRITONSERVER_ServerIsLive(self.0, &mut result as *mut _),
85 result
86 )
87 }
88
89 pub(crate) fn delete(&self) -> Result<(), Error> {
90 triton_call!(sys::TRITONSERVER_ServerDelete(self.0))
91 }
92
93 pub(crate) fn as_mut_ptr(&self) -> *mut sys::TRITONSERVER_Server {
94 self.0
95 }
96}
97
98impl Drop for Inner {
99 fn drop(&mut self) {
100 let _ = self
101 .is_live()
102 .and_then(|live| {
103 if live {
104 self.stop().and_then(|_| loop {
105 if !self.is_live()? {
106 return Ok(());
107 }
108 })
109 } else {
110 Ok(())
111 }
112 })
113 .and_then(|_| self.delete());
114 }
115}
116
117unsafe impl Send for Inner {}
122unsafe impl Sync for Inner {}
123
124#[derive(Debug)]
126pub struct Server {
127 pub(crate) ptr: Arc<Inner>,
128 pub(crate) models: HashMap<String, Model>,
129 pub(crate) runtime: tokio::runtime::Handle,
130}
131
132unsafe impl Send for Server {}
133
134impl Server {
135 pub async fn new(options: Options) -> Result<Self, Error> {
137 let mut server = null_mut::<sys::TRITONSERVER_Server>();
138 triton_call!(sys::TRITONSERVER_ServerNew(
139 &mut server as *mut _,
140 *options.0
141 ))?;
142
143 assert!(!server.is_null());
144
145 let mut server = Server {
146 ptr: Arc::new(Inner(server)),
147 models: HashMap::new(),
148 runtime: tokio::runtime::Handle::current(),
149 };
150 server.update_all_models()?;
151
152 Ok(server)
153 }
154
155 pub(crate) fn get_model<M: AsRef<str>>(&self, model: M) -> Result<&Model, Error> {
156 self.models.get(model.as_ref()).ok_or_else(|| {
157 Error::new(
158 ErrorCode::NotFound,
159 format!(
160 "Model {} is not found in server model metadata storage.",
161 model.as_ref()
162 ),
163 )
164 })
165 }
166
167 fn update_all_models(&mut self) -> Result<(), Error> {
168 for model in self.model_index(State::all())? {
169 self.update_model_info(model.name)?;
170 }
171 Ok(())
172 }
173
174 fn update_model_info<M: AsRef<str>>(&mut self, model: M) -> Result<(), Error> {
175 self.models
176 .insert(model.as_ref().to_string(), self.model_metadata(model, -1)?);
177 Ok(())
178 }
179
180 pub fn stop(&self) -> Result<(), Error> {
182 self.ptr.stop()
183 }
184
185 pub fn create_request<M: AsRef<str>>(&self, model: M, version: i64) -> Result<Request, Error> {
188 let model_name = to_cstring(model.as_ref())?;
189 let mut ptr = null_mut::<sys::TRITONSERVER_InferenceRequest>();
190
191 triton_call!(sys::TRITONSERVER_InferenceRequestNew(
192 &mut ptr as *mut _,
193 self.ptr.as_mut_ptr(),
194 model_name.as_ptr(),
195 version,
196 ))?;
197
198 assert!(!ptr.is_null());
199 Request::new(ptr, self, model)
200 }
201
202 pub fn poll_model_repository(&mut self) -> Result<(), Error> {
204 triton_call!(sys::TRITONSERVER_ServerPollModelRepository(
205 self.ptr.as_mut_ptr()
206 ))?;
207
208 self.update_all_models()
209 }
210
211 pub fn set_exit_timeout(&mut self, timeout: Duration) -> Result<&mut Self, Error> {
215 triton_call!(
216 sys::TRITONSERVER_ServerSetExitTimeout(self.ptr.as_mut_ptr(), timeout.as_secs() as _),
217 self
218 )
219 }
220
221 pub fn register_model_repo<P: AsRef<Path>, N: AsRef<str>>(
228 &mut self,
229 repository: P,
230 name_mapping: HashMap<String, String>,
231 ) -> Result<&mut Self, Error> {
232 let path = path_to_cstring(repository)?;
233
234 let mut mapping_params = name_mapping
235 .into_iter()
236 .map(|(k, v)| {
237 Parameter::new(k, ParameterContent::String(v)).map(|param| *param.ptr as *const _)
238 })
239 .collect::<Result<Vec<_>, _>>()?;
240
241 triton_call!(
242 sys::TRITONSERVER_ServerRegisterModelRepository(
243 self.ptr.as_mut_ptr(),
244 path.as_ptr(),
245 mapping_params.as_mut_ptr(),
246 mapping_params.len() as _
247 ),
248 self
249 )
250 }
251
252 pub fn unregister_model_repo<P: AsRef<Path>, N: AsRef<str>>(
256 &mut self,
257 repository: P,
258 ) -> Result<&mut Self, Error> {
259 let path = path_to_cstring(repository)?;
260
261 triton_call!(
262 sys::TRITONSERVER_ServerUnregisterModelRepository(self.ptr.as_mut_ptr(), path.as_ptr()),
263 self
264 )
265 }
266
267 pub fn is_live(&self) -> Result<bool, Error> {
269 self.ptr.is_live()
270 }
271
272 pub fn is_ready(&self) -> Result<bool, Error> {
274 let mut result = false;
275
276 triton_call!(
277 sys::TRITONSERVER_ServerIsReady(self.ptr.as_mut_ptr(), &mut result as *mut _),
278 result
279 )
280 }
281
282 pub fn model_is_ready<N: AsRef<str>>(&self, name: N, version: i64) -> Result<bool, Error> {
286 let name = to_cstring(name)?;
287 let mut result = false;
288
289 triton_call!(
290 sys::TRITONSERVER_ServerModelIsReady(
291 self.ptr.as_mut_ptr(),
292 name.as_ptr(),
293 version,
294 &mut result as *mut _,
295 ),
296 result
297 )
298 }
299
300 pub fn model_batch_properties<N: AsRef<str>>(
304 &self,
305 name: N,
306 version: i64,
307 ) -> Result<Batch, Error> {
308 let name = to_cstring(name)?;
309 let mut result: u32 = 0;
310 let mut ptr = null_mut::<c_void>();
311
312 triton_call!(sys::TRITONSERVER_ServerModelBatchProperties(
313 self.ptr.as_mut_ptr(),
314 name.as_ptr(),
315 version,
316 &mut result as *mut _,
317 &mut ptr as *mut _,
318 ))?;
319 unsafe { Ok(transmute::<u32, Batch>(result)) }
320 }
321
322 pub fn model_transaction_properties<N: AsRef<str>>(
326 &self,
327 name: N,
328 version: i64,
329 ) -> Result<Transaction, Error> {
330 let name = to_cstring(name)?;
331 let mut result: u32 = 0;
332 let mut ptr = null_mut::<c_void>();
333
334 triton_call!(sys::TRITONSERVER_ServerModelTransactionProperties(
335 self.ptr.as_mut_ptr(),
336 name.as_ptr(),
337 version,
338 &mut result as *mut _,
339 &mut ptr as *mut _,
340 ))?;
341 unsafe { Ok(transmute::<u32, Transaction>(result)) }
342 }
343
344 pub fn metadata(&self) -> Result<message::Server, Error> {
346 let mut result = null_mut::<sys::TRITONSERVER_Message>();
347
348 triton_call!(sys::TRITONSERVER_ServerMetadata(
349 self.ptr.as_mut_ptr(),
350 &mut result as *mut _
351 ))?;
352
353 assert!(!result.is_null());
354 Message(result).to_json().and_then(|json| {
355 from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
356 })
357 }
358
359 pub fn model_metadata<N: AsRef<str>>(&self, name: N, version: i64) -> Result<Model, Error> {
363 let name = to_cstring(name)?;
364 let mut result = null_mut::<sys::TRITONSERVER_Message>();
365
366 triton_call!(sys::TRITONSERVER_ServerModelMetadata(
367 self.ptr.as_mut_ptr(),
368 name.as_ptr(),
369 version,
370 &mut result as *mut _,
371 ))?;
372
373 assert!(!result.is_null());
374 Message(result).to_json().and_then(|json| {
375 from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
376 })
377 }
378
379 pub fn model_statistics<N: AsRef<str>>(&self, name: N, version: i64) -> Result<Value, Error> {
383 let name = to_cstring(name)?;
384 let mut result = null_mut::<sys::TRITONSERVER_Message>();
385
386 triton_call!(sys::TRITONSERVER_ServerModelStatistics(
387 self.ptr.as_mut_ptr(),
388 name.as_ptr(),
389 version,
390 &mut result as *mut _,
391 ))?;
392
393 assert!(!result.is_null());
394 Message(result).to_json().and_then(|json| {
395 from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
396 })
397 }
398
399 pub fn model_config<N: AsRef<str>>(
406 &self,
407 name: N,
408 version: i64,
409 config: u32,
410 ) -> Result<Value, Error> {
411 let name = to_cstring(name)?;
412 let mut result = null_mut::<sys::TRITONSERVER_Message>();
413
414 triton_call!(sys::TRITONSERVER_ServerModelConfig(
415 self.ptr.as_mut_ptr(),
416 name.as_ptr(),
417 version,
418 config,
419 &mut result as *mut _,
420 ))?;
421
422 assert!(!result.is_null());
423 Message(result).to_json().and_then(|json| {
424 from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
425 })
426 }
427
428 pub fn model_index(&self, flags: State) -> Result<Vec<Index>, Error> {
430 let mut result = null_mut::<sys::TRITONSERVER_Message>();
431
432 triton_call!(sys::TRITONSERVER_ServerModelIndex(
433 self.ptr.as_mut_ptr(),
434 flags.bits(),
435 &mut result as *mut _,
436 ))?;
437
438 assert!(!result.is_null());
439 Message(result).to_json().and_then(|json| {
440 from_slice(json).map_err(|err| Error::new(ErrorCode::Internal, err.to_string()))
441 })
442 }
443
444 pub fn load_model<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
448 let model_name = to_cstring(&name)?;
449
450 triton_call!(sys::TRITONSERVER_ServerLoadModel(
451 self.ptr.as_mut_ptr(),
452 model_name.as_ptr()
453 ))?;
454
455 self.update_model_info(name)
456 }
457
458 pub fn load_model_with_parametrs<N: AsRef<str>, P: AsRef<[Parameter]>>(
472 &mut self,
473 name: N,
474 parameters: P,
475 ) -> Result<(), Error> {
476 let model_name = to_cstring(&name)?;
477 let params_count = parameters.as_ref().len();
478 let mut parametrs = parameters
479 .as_ref()
480 .iter()
481 .map(|p| p.ptr.cast_const())
482 .collect::<Vec<_>>();
483
484 triton_call!(sys::TRITONSERVER_ServerLoadModelWithParameters(
485 self.ptr.as_mut_ptr(),
486 model_name.as_ptr(),
487 parametrs.as_mut_ptr(),
488 params_count as _,
489 ))?;
490
491 self.update_model_info(name)
492 }
493
494 pub fn unload_model<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
499 let model_name = to_cstring(&name)?;
500
501 triton_call!(sys::TRITONSERVER_ServerUnloadModel(
502 self.ptr.as_mut_ptr(),
503 model_name.as_ptr()
504 ))?;
505
506 self.update_model_info(name)
507 }
508
509 pub fn unload_model_and_dependents<N: AsRef<str>>(&mut self, name: N) -> Result<(), Error> {
515 let model_name = to_cstring(&name)?;
516
517 triton_call!(sys::TRITONSERVER_ServerUnloadModelAndDependents(
518 self.ptr.as_mut_ptr(),
519 model_name.as_ptr(),
520 ))?;
521
522 self.update_model_info(name)
523 }
524
525 pub fn metrics(&self) -> Result<metrics::PrometheusMetrics, Error> {
527 let mut metrics = null_mut::<sys::TRITONSERVER_Metrics>();
528
529 triton_call!(sys::TRITONSERVER_ServerMetrics(
530 self.ptr.as_mut_ptr(),
531 &mut metrics as *mut _
532 ))?;
533
534 assert!(!metrics.is_null());
535 Ok(PrometheusMetrics(Arc::new(metrics)))
536 }
537
538 pub fn is_log_enabled(&self, level: LogLevel) -> bool {
539 unsafe { sys::TRITONSERVER_LogIsEnabled(level as u32) }
540 }
541}
542
543#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
544#[repr(u32)]
545pub enum LogLevel {
546 Info = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_INFO,
547 Warn = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_WARN,
548 Error = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_ERROR,
549 Verbose = sys::TRITONSERVER_loglevel_enum_TRITONSERVER_LOG_VERBOSE,
550}