1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::{null, null_mut};
4
5use tract_api::*;
6use tract_proxy_sys as sys;
7
8use anyhow::{Context, Result};
9use ndarray::*;
10
/// Evaluates a `tract_proxy_sys` call and turns its status code into a
/// `Result<(), anyhow::Error>` expression.
///
/// The expansion yields `Ok(())` unless the call returned
/// `sys::TRACT_RESULT_TRACT_RESULT_KO`; in that case the message obtained from
/// `sys::tract_get_last_error` is wrapped into an `anyhow` error.
///
/// The whole expansion, `$expr` included, is placed inside an `unsafe` block,
/// so the call site is responsible for handing valid pointers to the FFI call.
macro_rules! check {
    ($expr:expr) => {
        unsafe {
            if $expr == sys::TRACT_RESULT_TRACT_RESULT_KO {
                let last_error = sys::tract_get_last_error();
                if last_error.is_null() {
                    // `CStr::from_ptr` must never be given a null pointer, so
                    // report a generic failure when no message is available.
                    Err(anyhow::anyhow!("tract call failed without reporting an error message"))
                } else {
                    // Lossy conversion: an error message with invalid UTF-8
                    // must not hide the fact that the call failed.
                    Err(anyhow::anyhow!(CStr::from_ptr(last_error)
                        .to_string_lossy()
                        .into_owned()))
                }
            } else {
                Ok(())
            }
        }
    };
}
23
/// Declares a public tuple struct owning a raw `tract_proxy_sys` handle and
/// wires its destructor.
///
/// * `$name`: identifier of the generated Rust type.
/// * `$raw`: name of the C type inside `sys` that the handle points to.
/// * `$destroy`: `sys` function invoked from `Drop`, given a pointer to the
///   stored handle.
/// * `$extra`: optional additional tuple fields appended after the handle.
macro_rules! wrapper {
    ($name:ident, $raw:ident, $destroy:ident $(, $extra:ty )*) => {
        // NOTE(review): the derived `Clone` duplicates the raw pointer while
        // `Drop` hands it to the destructor, so a cloned wrapper looks like it
        // would release the same handle twice. Confirm how the C side copes
        // with that before relying on `clone()`.
        #[derive(Debug, Clone)]
        pub struct $name(*mut sys::$raw $(, $extra)*);

        impl Drop for $name {
            fn drop(&mut self) {
                // The destructor receives the address of the stored handle,
                // not the handle itself.
                let handle: *mut *mut sys::$raw = &mut self.0;
                unsafe {
                    sys::$destroy(handle);
                }
            }
        }
    };
}
38
39pub fn nnef() -> Result<Nnef> {
40 let mut nnef = null_mut();
41 check!(sys::tract_nnef_create(&mut nnef))?;
42 Ok(Nnef(nnef))
43}
44
45pub fn onnx() -> Result<Onnx> {
46 let mut onnx = null_mut();
47 check!(sys::tract_onnx_create(&mut onnx))?;
48 Ok(Onnx(onnx))
49}
50
51pub fn version() -> &'static str {
52 unsafe { CStr::from_ptr(sys::tract_version()).to_str().unwrap() }
53}
54
55wrapper!(Nnef, TractNnef, tract_nnef_destroy);
56impl NnefInterface for Nnef {
57 type Model = Model;
58 fn model_for_path(&self, path: impl AsRef<Path>) -> Result<Model> {
59 let path = path.as_ref();
60 let path = CString::new(
61 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
62 )?;
63 let mut model = null_mut();
64 check!(sys::tract_nnef_model_for_path(self.0, path.as_ptr(), &mut model))?;
65 Ok(Model(model))
66 }
67
68 fn transform_model(&self, model: &mut Self::Model, transform_spec: &str) -> Result<()> {
69 let t = CString::new(transform_spec)?;
70 check!(sys::tract_nnef_transform_model(self.0, model.0, t.as_ptr()))
71 }
72
73 fn enable_tract_core(&mut self) -> Result<()> {
74 check!(sys::tract_nnef_enable_tract_core(self.0))
75 }
76
77 fn enable_tract_extra(&mut self) -> Result<()> {
78 check!(sys::tract_nnef_enable_tract_extra(self.0))
79 }
80
81 fn enable_tract_transformers(&mut self) -> Result<()> {
82 check!(sys::tract_nnef_enable_tract_transformers(self.0))
83 }
84
85 fn enable_onnx(&mut self) -> Result<()> {
86 check!(sys::tract_nnef_enable_onnx(self.0))
87 }
88
89 fn enable_pulse(&mut self) -> Result<()> {
90 check!(sys::tract_nnef_enable_pulse(self.0))
91 }
92
93 fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
94 check!(sys::tract_nnef_enable_extended_identifier_syntax(self.0))
95 }
96
97 fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
98 let path = path.as_ref();
99 let path = CString::new(
100 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
101 )?;
102 check!(sys::tract_nnef_write_model_to_dir(self.0, path.as_ptr(), model.0))?;
103 Ok(())
104 }
105
106 fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
107 let path = path.as_ref();
108 let path = CString::new(
109 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
110 )?;
111 check!(sys::tract_nnef_write_model_to_tar(self.0, path.as_ptr(), model.0))?;
112 Ok(())
113 }
114
115 fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
116 let path = path.as_ref();
117 let path = CString::new(
118 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
119 )?;
120 check!(sys::tract_nnef_write_model_to_tar_gz(self.0, path.as_ptr(), model.0))?;
121 Ok(())
122 }
123}
124
125wrapper!(Onnx, TractOnnx, tract_onnx_destroy);
127
128impl OnnxInterface for Onnx {
129 type InferenceModel = InferenceModel;
130 fn model_for_path(&self, path: impl AsRef<Path>) -> Result<InferenceModel> {
131 let path = path.as_ref();
132 let path = CString::new(
133 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
134 )?;
135 let mut model = null_mut();
136 check!(sys::tract_onnx_model_for_path(self.0, path.as_ptr(), &mut model))?;
137 Ok(InferenceModel(model))
138 }
139}
140
141wrapper!(InferenceModel, TractInferenceModel, tract_inference_model_destroy);
143impl InferenceModelInterface for InferenceModel {
144 type Model = Model;
145 type InferenceFact = InferenceFact;
146 fn set_output_names(
147 &mut self,
148 outputs: impl IntoIterator<Item = impl AsRef<str>>,
149 ) -> Result<()> {
150 let c_strings: Vec<CString> =
151 outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
152 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
153 check!(sys::tract_inference_model_set_output_names(
154 self.0,
155 c_strings.len(),
156 ptrs.as_ptr()
157 ))?;
158 Ok(())
159 }
160
161 fn input_count(&self) -> Result<usize> {
162 let mut count = 0;
163 check!(sys::tract_inference_model_input_count(self.0, &mut count))?;
164 Ok(count)
165 }
166
167 fn output_count(&self) -> Result<usize> {
168 let mut count = 0;
169 check!(sys::tract_inference_model_output_count(self.0, &mut count))?;
170 Ok(count)
171 }
172
173 fn input_name(&self, id: usize) -> Result<String> {
174 let mut ptr = null_mut();
175 check!(sys::tract_inference_model_input_name(self.0, id, &mut ptr))?;
176 unsafe {
177 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
178 sys::tract_free_cstring(ptr);
179 Ok(ret)
180 }
181 }
182
183 fn output_name(&self, id: usize) -> Result<String> {
184 let mut ptr = null_mut();
185 check!(sys::tract_inference_model_output_name(self.0, id, &mut ptr))?;
186 unsafe {
187 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
188 sys::tract_free_cstring(ptr);
189 Ok(ret)
190 }
191 }
192
193 fn input_fact(&self, id: usize) -> Result<InferenceFact> {
194 let mut ptr = null_mut();
195 check!(sys::tract_inference_model_input_fact(self.0, id, &mut ptr))?;
196 Ok(InferenceFact(ptr))
197 }
198
199 fn set_input_fact(
200 &mut self,
201 id: usize,
202 fact: impl AsFact<Self, Self::InferenceFact>,
203 ) -> Result<()> {
204 let fact = fact.as_fact(self)?;
205 check!(sys::tract_inference_model_set_input_fact(self.0, id, fact.0))?;
206 Ok(())
207 }
208
209 fn output_fact(&self, id: usize) -> Result<InferenceFact> {
210 let mut ptr = null_mut();
211 check!(sys::tract_inference_model_output_fact(self.0, id, &mut ptr))?;
212 Ok(InferenceFact(ptr))
213 }
214
215 fn set_output_fact(
216 &mut self,
217 id: usize,
218 fact: impl AsFact<InferenceModel, InferenceFact>,
219 ) -> Result<()> {
220 let fact = fact.as_fact(self)?;
221 check!(sys::tract_inference_model_set_output_fact(self.0, id, fact.0))?;
222 Ok(())
223 }
224
225 fn analyse(&mut self) -> Result<()> {
226 check!(sys::tract_inference_model_analyse(self.0))?;
227 Ok(())
228 }
229
230 fn into_typed(mut self) -> Result<Self::Model> {
231 let mut ptr = null_mut();
232 check!(sys::tract_inference_model_into_typed(&mut self.0, &mut ptr))?;
233 Ok(Model(ptr))
234 }
235
236 fn into_optimized(mut self) -> Result<Self::Model> {
237 let mut ptr = null_mut();
238 check!(sys::tract_inference_model_into_optimized(&mut self.0, &mut ptr))?;
239 Ok(Model(ptr))
240 }
241}
242
243wrapper!(Model, TractModel, tract_model_destroy);
245
246impl ModelInterface for Model {
247 type Fact = Fact;
248 type Value = Value;
249 type Runnable = Runnable;
250 fn input_count(&self) -> Result<usize> {
251 let mut count = 0;
252 check!(sys::tract_model_input_count(self.0, &mut count))?;
253 Ok(count)
254 }
255
256 fn output_count(&self) -> Result<usize> {
257 let mut count = 0;
258 check!(sys::tract_model_output_count(self.0, &mut count))?;
259 Ok(count)
260 }
261
262 fn input_name(&self, id: usize) -> Result<String> {
263 let mut ptr = null_mut();
264 check!(sys::tract_model_input_name(self.0, id, &mut ptr))?;
265 unsafe {
266 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
267 sys::tract_free_cstring(ptr);
268 Ok(ret)
269 }
270 }
271
272 fn output_name(&self, id: usize) -> Result<String> {
273 let mut ptr = null_mut();
274 check!(sys::tract_model_output_name(self.0, id, &mut ptr))?;
275 unsafe {
276 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
277 sys::tract_free_cstring(ptr);
278 Ok(ret)
279 }
280 }
281
282 fn set_output_names(
283 &mut self,
284 outputs: impl IntoIterator<Item = impl AsRef<str>>,
285 ) -> Result<()> {
286 let c_strings: Vec<CString> =
287 outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
288 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
289 check!(sys::tract_model_set_output_names(self.0, c_strings.len(), ptrs.as_ptr()))?;
290 Ok(())
291 }
292
293 fn input_fact(&self, id: usize) -> Result<Fact> {
294 let mut ptr = null_mut();
295 check!(sys::tract_model_input_fact(self.0, id, &mut ptr))?;
296 Ok(Fact(ptr))
297 }
298
299 fn output_fact(&self, id: usize) -> Result<Fact> {
300 let mut ptr = null_mut();
301 check!(sys::tract_model_output_fact(self.0, id, &mut ptr))?;
302 Ok(Fact(ptr))
303 }
304
305 fn declutter(&mut self) -> Result<()> {
306 check!(sys::tract_model_declutter(self.0))?;
307 Ok(())
308 }
309
310 fn optimize(&mut self) -> Result<()> {
311 check!(sys::tract_model_optimize(self.0))?;
312 Ok(())
313 }
314
315 fn into_decluttered(self) -> Result<Model> {
316 check!(sys::tract_model_declutter(self.0))?;
317 Ok(self)
318 }
319
320 fn into_optimized(self) -> Result<Model> {
321 check!(sys::tract_model_optimize(self.0))?;
322 Ok(self)
323 }
324
325 fn into_runnable(self) -> Result<Runnable> {
326 let mut model = self;
327 let mut runnable = null_mut();
328 check!(sys::tract_model_into_runnable(&mut model.0, &mut runnable))?;
329 Ok(Runnable(runnable))
330 }
331
332 fn concretize_symbols(
333 &mut self,
334 values: impl IntoIterator<Item = (impl AsRef<str>, i64)>,
335 ) -> Result<()> {
336 let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
337 let c_strings: Vec<CString> =
338 names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
339 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
340 check!(sys::tract_model_concretize_symbols(
341 self.0,
342 ptrs.len(),
343 ptrs.as_ptr(),
344 values.as_ptr()
345 ))?;
346 Ok(())
347 }
348
349 fn transform(&mut self, transform: &str) -> Result<()> {
350 let t = CString::new(transform)?;
351 check!(sys::tract_model_transform(self.0, t.as_ptr()))?;
352 Ok(())
353 }
354
355 fn pulse(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
356 let name = CString::new(name.as_ref())?;
357 let value = CString::new(value.as_ref())?;
358 check!(sys::tract_model_pulse_simple(&mut self.0, name.as_ptr(), value.as_ptr()))?;
359 Ok(())
360 }
361
362 fn cost_json(&self) -> Result<String> {
363 let input: Option<Vec<Value>> = None;
364 self.profile_json(input)
365 }
366
367 fn profile_json<I, V, E>(&self, inputs: Option<I>) -> Result<String>
368 where
369 I: IntoIterator<Item = V>,
370 V: TryInto<Value, Error = E>,
371 E: Into<anyhow::Error>,
372 {
373 let inputs = if let Some(inputs) = inputs {
374 let inputs = inputs
375 .into_iter()
376 .map(|i| i.try_into().map_err(|e| e.into()))
377 .collect::<Result<Vec<Value>>>()?;
378 anyhow::ensure!(self.input_count()? == inputs.len());
379 Some(inputs)
380 } else {
381 None
382 };
383 let mut iptrs: Option<Vec<*mut sys::TractValue>> =
384 inputs.as_ref().map(|is| is.iter().map(|v| v.0).collect());
385 let mut json: *mut i8 = null_mut();
386 let values = iptrs.as_mut().map(|it| it.as_mut_ptr()).unwrap_or(null_mut());
387 check!(sys::tract_model_profile_json(self.0, values, &mut json))?;
388 anyhow::ensure!(!json.is_null());
389 unsafe {
390 let s = CStr::from_ptr(json).to_owned();
391 sys::tract_free_cstring(json);
392 Ok(s.to_str()?.to_owned())
393 }
394 }
395
396 fn property_keys(&self) -> Result<Vec<String>> {
397 let mut len = 0;
398 check!(sys::tract_model_property_count(self.0, &mut len))?;
399 let mut keys = vec![null_mut(); len];
400 check!(sys::tract_model_property_names(self.0, keys.as_mut_ptr()))?;
401 unsafe {
402 keys.into_iter()
403 .map(|pc| {
404 let s = CStr::from_ptr(pc).to_str()?.to_owned();
405 sys::tract_free_cstring(pc);
406 Ok(s)
407 })
408 .collect()
409 }
410 }
411
412 fn property(&self, name: impl AsRef<str>) -> Result<Value> {
413 let mut v = null_mut();
414 let name = CString::new(name.as_ref())?;
415 check!(sys::tract_model_property(self.0, name.as_ptr(), &mut v))?;
416 Ok(Value(v))
417 }
418}
419
420wrapper!(Runnable, TractRunnable, tract_runnable_release);
422
423impl RunnableInterface for Runnable {
424 type Value = Value;
425 type State = State;
426
427 fn run<I, V, E>(&self, inputs: I) -> Result<Vec<Value>>
428 where
429 I: IntoIterator<Item = V>,
430 V: TryInto<Value, Error = E>,
431 E: Into<anyhow::Error>,
432 {
433 self.spawn_state()?.run(inputs)
434 }
435
436 fn spawn_state(&self) -> Result<State> {
437 let mut state = null_mut();
438 check!(sys::tract_runnable_spawn_state(self.0, &mut state))?;
439 Ok(State(state))
440 }
441
442 fn input_count(&self) -> Result<usize> {
443 let mut count = 0;
444 check!(sys::tract_runnable_input_count(self.0, &mut count))?;
445 Ok(count)
446 }
447
448 fn output_count(&self) -> Result<usize> {
449 let mut count = 0;
450 check!(sys::tract_runnable_output_count(self.0, &mut count))?;
451 Ok(count)
452 }
453}
454
455wrapper!(State, TractState, tract_state_destroy);
457
458impl StateInterface for State {
459 type Value = Value;
460 fn run<I, V, E>(&mut self, inputs: I) -> Result<Vec<Value>>
461 where
462 I: IntoIterator<Item = V>,
463 V: TryInto<Value, Error = E>,
464 E: Into<anyhow::Error>,
465 {
466 let inputs = inputs
467 .into_iter()
468 .map(|i| i.try_into().map_err(|e| e.into()))
469 .collect::<Result<Vec<Value>>>()?;
470 let mut outputs = vec![null_mut(); self.output_count()?];
471 let mut inputs: Vec<_> = inputs.iter().map(|v| v.0).collect();
472 check!(sys::tract_state_run(self.0, inputs.as_mut_ptr(), outputs.as_mut_ptr()))?;
473 let outputs = outputs.into_iter().map(Value).collect();
474 Ok(outputs)
475 }
476
477 fn input_count(&self) -> Result<usize> {
478 let mut count = 0;
479 check!(sys::tract_state_input_count(self.0, &mut count))?;
480 Ok(count)
481 }
482
483 fn output_count(&self) -> Result<usize> {
484 let mut count = 0;
485 check!(sys::tract_state_output_count(self.0, &mut count))?;
486 Ok(count)
487 }
488}
489
490wrapper!(Value, TractValue, tract_value_destroy);
492
493impl ValueInterface for Value {
494 fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
495 anyhow::ensure!(data.len() == shape.iter().product::<usize>() * dt.size_of());
496 let mut value = null_mut();
497 check!(sys::tract_value_from_bytes(
498 dt as _,
499 shape.len(),
500 shape.as_ptr(),
501 data.as_ptr() as _,
502 &mut value
503 ))?;
504 Ok(Value(value))
505 }
506
507 fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
508 let mut rank = 0;
509 let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
510 let mut shape = null();
511 let mut data = null();
512 check!(sys::tract_value_as_bytes(self.0, &mut dt, &mut rank, &mut shape, &mut data))?;
513 unsafe {
514 let dt: DatumType = std::mem::transmute(dt);
515 let shape = std::slice::from_raw_parts(shape, rank);
516 let len: usize = shape.iter().product();
517 let data = std::slice::from_raw_parts(data as *const u8, len * dt.size_of());
518 Ok((dt, shape, data))
519 }
520 }
521}
522
523value_from_to_ndarray!();
524
525wrapper!(Fact, TractFact, tract_fact_destroy);
527
528impl Fact {
529 fn new(model: &mut Model, spec: impl ToString) -> Result<Fact> {
530 let cstr = CString::new(spec.to_string())?;
531 let mut fact = null_mut();
532 check!(sys::tract_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
533 Ok(Fact(fact))
534 }
535
536 fn dump(&self) -> Result<String> {
537 let mut ptr = null_mut();
538 check!(sys::tract_fact_dump(self.0, &mut ptr))?;
539 unsafe {
540 let s = CStr::from_ptr(ptr).to_owned();
541 sys::tract_free_cstring(ptr);
542 Ok(s.to_str()?.to_owned())
543 }
544 }
545}
546
547impl FactInterface for Fact {}
548
549impl std::fmt::Display for Fact {
550 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
551 match self.dump() {
552 Ok(s) => f.write_str(&s),
553 Err(_) => Err(std::fmt::Error),
554 }
555 }
556}
557
558wrapper!(InferenceFact, TractInferenceFact, tract_inference_fact_destroy);
560
561impl InferenceFact {
562 fn new(model: &mut InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
563 let cstr = CString::new(spec.to_string())?;
564 let mut fact = null_mut();
565 check!(sys::tract_inference_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
566 Ok(InferenceFact(fact))
567 }
568
569 fn dump(&self) -> Result<String> {
570 let mut ptr = null_mut();
571 check!(sys::tract_inference_fact_dump(self.0, &mut ptr))?;
572 unsafe {
573 let s = CStr::from_ptr(ptr).to_owned();
574 sys::tract_free_cstring(ptr);
575 Ok(s.to_str()?.to_owned())
576 }
577 }
578}
579
580impl InferenceFactInterface for InferenceFact {
581 fn empty() -> Result<InferenceFact> {
582 let mut fact = null_mut();
583 check!(sys::tract_inference_fact_empty(&mut fact))?;
584 Ok(InferenceFact(fact))
585 }
586}
587
588impl Default for InferenceFact {
589 fn default() -> Self {
590 Self::empty().unwrap()
591 }
592}
593
594impl std::fmt::Display for InferenceFact {
595 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
596 match self.dump() {
597 Ok(s) => f.write_str(&s),
598 Err(_) => Err(std::fmt::Error),
599 }
600 }
601}
602
// Both macros are defined outside this file. Judging by their names and
// arguments they presumably provide the `AsFact` conversions consumed by
// `set_input_fact` / `set_output_fact`, built on the private `new`
// constructors of `InferenceFact` and `Fact`. TODO confirm against the macro
// definitions.
as_inference_fact_impl!(InferenceModel, InferenceFact);
as_fact_impl!(Model, Fact);