Skip to main content

tract_proxy/
lib.rs

1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::{null, null_mut};
4
5use tract_api::*;
6use tract_proxy_sys as sys;
7
8use anyhow::{Context, Result};
9use ndarray::*;
10
// Evaluates an FFI call that yields a tract result code and turns it into a
// `Result<(), anyhow::Error>`. When the code equals `TRACT_RESULT_TRACT_RESULT_KO`, the
// message returned by `tract_get_last_error` (lossily decoded) becomes the error.
// The whole expansion is wrapped in `unsafe`, so the argument may be an unsafe call.
macro_rules! check {
    ($call:expr) => {
        unsafe {
            if $call != sys::TRACT_RESULT_TRACT_RESULT_KO {
                Ok(())
            } else {
                let message = CStr::from_ptr(sys::tract_get_last_error());
                Err(anyhow::anyhow!(message.to_string_lossy().into_owned()))
            }
        }
    };
}
23
// Declares a public tuple struct whose field 0 is a raw `*mut sys::<C type>` handle
// (optionally followed by extra fields), plus a `Drop` impl that passes `&mut self.0`
// to the named destructor.
//
// NOTE(review): the derived `Clone` copies the raw pointer, so a clone and its source
// would both hand the same handle to the destructor when dropped — confirm whether
// cloning is ever exercised or whether the C side tolerates it.
macro_rules! wrapper {
    ($wrapper:ident, $raw:ident, $destroy:ident $(, $extra:ty )*) => {
        #[derive(Debug, Clone)]
        pub struct $wrapper(*mut sys::$raw $(, $extra)*);

        impl Drop for $wrapper {
            fn drop(&mut self) {
                // The destructor receives the address of the handle field itself.
                unsafe {
                    sys::$destroy(&mut self.0);
                }
            }
        }
    };
}
38
39pub fn nnef() -> Result<Nnef> {
40    let mut nnef = null_mut();
41    check!(sys::tract_nnef_create(&mut nnef))?;
42    Ok(Nnef(nnef))
43}
44
45pub fn onnx() -> Result<Onnx> {
46    let mut onnx = null_mut();
47    check!(sys::tract_onnx_create(&mut onnx))?;
48    Ok(Onnx(onnx))
49}
50
51pub fn version() -> &'static str {
52    unsafe { CStr::from_ptr(sys::tract_version()).to_str().unwrap() }
53}
54
55wrapper!(Nnef, TractNnef, tract_nnef_destroy);
56impl NnefInterface for Nnef {
57    type Model = Model;
58    fn load(&self, path: impl AsRef<Path>) -> Result<Model> {
59        let path = path.as_ref();
60        let path = CString::new(
61            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
62        )?;
63        let mut model = null_mut();
64        check!(sys::tract_nnef_load(self.0, path.as_ptr(), &mut model))?;
65        Ok(Model(model))
66    }
67
68    fn load_buffer(&self, data: &[u8]) -> Result<Model> {
69        let mut model = null_mut();
70        check!(sys::tract_nnef_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
71        Ok(Model(model))
72    }
73
74    fn enable_tract_core(&mut self) -> Result<()> {
75        check!(sys::tract_nnef_enable_tract_core(self.0))
76    }
77
78    fn enable_tract_extra(&mut self) -> Result<()> {
79        check!(sys::tract_nnef_enable_tract_extra(self.0))
80    }
81
82    fn enable_tract_transformers(&mut self) -> Result<()> {
83        check!(sys::tract_nnef_enable_tract_transformers(self.0))
84    }
85
86    fn enable_onnx(&mut self) -> Result<()> {
87        check!(sys::tract_nnef_enable_onnx(self.0))
88    }
89
90    fn enable_pulse(&mut self) -> Result<()> {
91        check!(sys::tract_nnef_enable_pulse(self.0))
92    }
93
94    fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
95        check!(sys::tract_nnef_enable_extended_identifier_syntax(self.0))
96    }
97
98    fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
99        let path = path.as_ref();
100        let path = CString::new(
101            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
102        )?;
103        check!(sys::tract_nnef_write_model_to_dir(self.0, path.as_ptr(), model.0))?;
104        Ok(())
105    }
106
107    fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
108        let path = path.as_ref();
109        let path = CString::new(
110            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
111        )?;
112        check!(sys::tract_nnef_write_model_to_tar(self.0, path.as_ptr(), model.0))?;
113        Ok(())
114    }
115
116    fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
117        let path = path.as_ref();
118        let path = CString::new(
119            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
120        )?;
121        check!(sys::tract_nnef_write_model_to_tar_gz(self.0, path.as_ptr(), model.0))?;
122        Ok(())
123    }
124}
125
126// ONNX
127wrapper!(Onnx, TractOnnx, tract_onnx_destroy);
128
129impl OnnxInterface for Onnx {
130    type InferenceModel = InferenceModel;
131    fn load(&self, path: impl AsRef<Path>) -> Result<InferenceModel> {
132        let path = path.as_ref();
133        let path = CString::new(
134            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
135        )?;
136        let mut model = null_mut();
137        check!(sys::tract_onnx_load(self.0, path.as_ptr(), &mut model))?;
138        Ok(InferenceModel(model))
139    }
140
141    fn load_buffer(&self, data: &[u8]) -> Result<InferenceModel> {
142        let mut model = null_mut();
143        check!(sys::tract_onnx_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
144        Ok(InferenceModel(model))
145    }
146}
147
148// INFERENCE MODEL
149wrapper!(InferenceModel, TractInferenceModel, tract_inference_model_destroy);
150impl InferenceModelInterface for InferenceModel {
151    type Model = Model;
152    type InferenceFact = InferenceFact;
153    fn set_output_names(
154        &mut self,
155        outputs: impl IntoIterator<Item = impl AsRef<str>>,
156    ) -> Result<()> {
157        let c_strings: Vec<CString> =
158            outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
159        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
160        check!(sys::tract_inference_model_set_output_names(
161            self.0,
162            c_strings.len(),
163            ptrs.as_ptr()
164        ))?;
165        Ok(())
166    }
167
168    fn input_count(&self) -> Result<usize> {
169        let mut count = 0;
170        check!(sys::tract_inference_model_input_count(self.0, &mut count))?;
171        Ok(count)
172    }
173
174    fn output_count(&self) -> Result<usize> {
175        let mut count = 0;
176        check!(sys::tract_inference_model_output_count(self.0, &mut count))?;
177        Ok(count)
178    }
179
180    fn input_name(&self, id: usize) -> Result<String> {
181        let mut ptr = null_mut();
182        check!(sys::tract_inference_model_input_name(self.0, id, &mut ptr))?;
183        unsafe {
184            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
185            sys::tract_free_cstring(ptr);
186            Ok(ret)
187        }
188    }
189
190    fn output_name(&self, id: usize) -> Result<String> {
191        let mut ptr = null_mut();
192        check!(sys::tract_inference_model_output_name(self.0, id, &mut ptr))?;
193        unsafe {
194            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
195            sys::tract_free_cstring(ptr);
196            Ok(ret)
197        }
198    }
199
200    fn input_fact(&self, id: usize) -> Result<InferenceFact> {
201        let mut ptr = null_mut();
202        check!(sys::tract_inference_model_input_fact(self.0, id, &mut ptr))?;
203        Ok(InferenceFact(ptr))
204    }
205
206    fn set_input_fact(
207        &mut self,
208        id: usize,
209        fact: impl AsFact<Self, Self::InferenceFact>,
210    ) -> Result<()> {
211        let fact = fact.as_fact(self)?;
212        check!(sys::tract_inference_model_set_input_fact(self.0, id, fact.0))?;
213        Ok(())
214    }
215
216    fn output_fact(&self, id: usize) -> Result<InferenceFact> {
217        let mut ptr = null_mut();
218        check!(sys::tract_inference_model_output_fact(self.0, id, &mut ptr))?;
219        Ok(InferenceFact(ptr))
220    }
221
222    fn set_output_fact(
223        &mut self,
224        id: usize,
225        fact: impl AsFact<InferenceModel, InferenceFact>,
226    ) -> Result<()> {
227        let fact = fact.as_fact(self)?;
228        check!(sys::tract_inference_model_set_output_fact(self.0, id, fact.0))?;
229        Ok(())
230    }
231
232    fn analyse(&mut self) -> Result<()> {
233        check!(sys::tract_inference_model_analyse(self.0))?;
234        Ok(())
235    }
236
237    fn into_model(mut self) -> Result<Self::Model> {
238        let mut ptr = null_mut();
239        check!(sys::tract_inference_model_into_model(&mut self.0, &mut ptr))?;
240        Ok(Model(ptr))
241    }
242}
243
244// MODEL
245wrapper!(Model, TractModel, tract_model_destroy);
246
247impl ModelInterface for Model {
248    type Fact = Fact;
249    type Tensor = Tensor;
250    type Runnable = Runnable;
251    fn input_count(&self) -> Result<usize> {
252        let mut count = 0;
253        check!(sys::tract_model_input_count(self.0, &mut count))?;
254        Ok(count)
255    }
256
257    fn output_count(&self) -> Result<usize> {
258        let mut count = 0;
259        check!(sys::tract_model_output_count(self.0, &mut count))?;
260        Ok(count)
261    }
262
263    fn input_name(&self, id: usize) -> Result<String> {
264        let mut ptr = null_mut();
265        check!(sys::tract_model_input_name(self.0, id, &mut ptr))?;
266        unsafe {
267            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
268            sys::tract_free_cstring(ptr);
269            Ok(ret)
270        }
271    }
272
273    fn output_name(&self, id: usize) -> Result<String> {
274        let mut ptr = null_mut();
275        check!(sys::tract_model_output_name(self.0, id, &mut ptr))?;
276        unsafe {
277            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
278            sys::tract_free_cstring(ptr);
279            Ok(ret)
280        }
281    }
282
283    fn set_output_names(
284        &mut self,
285        outputs: impl IntoIterator<Item = impl AsRef<str>>,
286    ) -> Result<()> {
287        let c_strings: Vec<CString> =
288            outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
289        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
290        check!(sys::tract_model_set_output_names(self.0, c_strings.len(), ptrs.as_ptr()))?;
291        Ok(())
292    }
293
294    fn input_fact(&self, id: usize) -> Result<Fact> {
295        let mut ptr = null_mut();
296        check!(sys::tract_model_input_fact(self.0, id, &mut ptr))?;
297        Ok(Fact(ptr))
298    }
299
300    fn output_fact(&self, id: usize) -> Result<Fact> {
301        let mut ptr = null_mut();
302        check!(sys::tract_model_output_fact(self.0, id, &mut ptr))?;
303        Ok(Fact(ptr))
304    }
305
306    fn into_runnable(self) -> Result<Runnable> {
307        let mut model = self;
308        let mut runnable = null_mut();
309        check!(sys::tract_model_into_runnable(&mut model.0, &mut runnable))?;
310        Ok(Runnable(runnable))
311    }
312
313    fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()> {
314        let transform = spec.into().to_transform_string();
315        let t = CString::new(transform)?;
316        check!(sys::tract_model_transform(self.0, t.as_ptr()))?;
317        Ok(())
318    }
319
320    fn property_keys(&self) -> Result<Vec<String>> {
321        let mut len = 0;
322        check!(sys::tract_model_property_count(self.0, &mut len))?;
323        let mut keys = vec![null_mut(); len];
324        check!(sys::tract_model_property_names(self.0, keys.as_mut_ptr()))?;
325        unsafe {
326            keys.into_iter()
327                .map(|pc| {
328                    let s = CStr::from_ptr(pc).to_str()?.to_owned();
329                    sys::tract_free_cstring(pc);
330                    Ok(s)
331                })
332                .collect()
333        }
334    }
335
336    fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
337        let mut v = null_mut();
338        let name = CString::new(name.as_ref())?;
339        check!(sys::tract_model_property(self.0, name.as_ptr(), &mut v))?;
340        Ok(Tensor(v))
341    }
342
343    fn parse_fact(&self, spec: &str) -> Result<Self::Fact> {
344        let spec = CString::new(spec)?;
345        let mut ptr = null_mut();
346        check!(sys::tract_model_parse_fact(self.0, spec.as_ptr(), &mut ptr))?;
347        Ok(Fact(ptr))
348    }
349}
350
351// RUNTIME
352wrapper!(Runtime, TractRuntime, tract_runtime_release);
353
354pub fn runtime_for_name(name: &str) -> Result<Runtime> {
355    let mut rt = null_mut();
356    let name = CString::new(name)?;
357    check!(sys::tract_runtime_for_name(name.as_ptr(), &mut rt))?;
358    Ok(Runtime(rt))
359}
360
361impl RuntimeInterface for Runtime {
362    type Runnable = Runnable;
363
364    type Model = Model;
365
366    fn name(&self) -> Result<String> {
367        let mut ptr = null_mut();
368        check!(sys::tract_runtime_name(self.0, &mut ptr))?;
369        unsafe {
370            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
371            sys::tract_free_cstring(ptr);
372            Ok(ret)
373        }
374    }
375
376    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable> {
377        let mut model = model;
378        let mut runnable = null_mut();
379        check!(sys::tract_runtime_prepare(self.0, &mut model.0, &mut runnable))?;
380        Ok(Runnable(runnable))
381    }
382}
383
384// RUNNABLE
385wrapper!(Runnable, TractRunnable, tract_runnable_release);
386unsafe impl Send for Runnable {}
387unsafe impl Sync for Runnable {}
388
389impl RunnableInterface for Runnable {
390    type Tensor = Tensor;
391    type State = State;
392    type Fact = Fact;
393
394    fn run(&self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
395        StateInterface::run(&mut self.spawn_state()?, inputs.into_inputs()?)
396    }
397
398    fn spawn_state(&self) -> Result<State> {
399        let mut state = null_mut();
400        check!(sys::tract_runnable_spawn_state(self.0, &mut state))?;
401        Ok(State(state))
402    }
403
404    fn input_count(&self) -> Result<usize> {
405        let mut count = 0;
406        check!(sys::tract_runnable_input_count(self.0, &mut count))?;
407        Ok(count)
408    }
409
410    fn output_count(&self) -> Result<usize> {
411        let mut count = 0;
412        check!(sys::tract_runnable_output_count(self.0, &mut count))?;
413        Ok(count)
414    }
415
416    fn input_fact(&self, id: usize) -> Result<Self::Fact> {
417        let mut ptr = null_mut();
418        check!(sys::tract_runnable_input_fact(self.0, id, &mut ptr))?;
419        Ok(Fact(ptr))
420    }
421
422    fn output_fact(&self, id: usize) -> Result<Self::Fact> {
423        let mut ptr = null_mut();
424        check!(sys::tract_runnable_output_fact(self.0, id, &mut ptr))?;
425        Ok(Fact(ptr))
426    }
427
428    fn property_keys(&self) -> Result<Vec<String>> {
429        let mut len = 0;
430        check!(sys::tract_runnable_property_count(self.0, &mut len))?;
431        let mut keys = vec![null_mut(); len];
432        check!(sys::tract_runnable_property_names(self.0, keys.as_mut_ptr()))?;
433        unsafe {
434            keys.into_iter()
435                .map(|pc| {
436                    let s = CStr::from_ptr(pc).to_str()?.to_owned();
437                    sys::tract_free_cstring(pc);
438                    Ok(s)
439                })
440                .collect()
441        }
442    }
443
444    fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
445        let mut v = null_mut();
446        let name = CString::new(name.as_ref())?;
447        check!(sys::tract_runnable_property(self.0, name.as_ptr(), &mut v))?;
448        Ok(Tensor(v))
449    }
450
451    fn cost_json(&self) -> Result<String> {
452        let input: Option<Vec<Tensor>> = None;
453        self.profile_json(input)
454    }
455
456    fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
457    where
458        I: IntoIterator<Item = IV>,
459        IV: TryInto<Self::Tensor, Error = IE>,
460        IE: Into<anyhow::Error>,
461    {
462        let inputs = if let Some(inputs) = inputs {
463            let inputs = inputs
464                .into_iter()
465                .map(|i| i.try_into().map_err(|e| e.into()))
466                .collect::<Result<Vec<Tensor>>>()?;
467            anyhow::ensure!(self.input_count()? == inputs.len());
468            Some(inputs)
469        } else {
470            None
471        };
472        let mut iptrs: Option<Vec<*mut sys::TractTensor>> =
473            inputs.as_ref().map(|is| is.iter().map(|v| v.0).collect());
474        let mut json: *mut i8 = null_mut();
475        let values = iptrs.as_mut().map(|it| it.as_mut_ptr()).unwrap_or(null_mut());
476
477        check!(sys::tract_runnable_profile_json(self.0, values, &mut json))?;
478        anyhow::ensure!(!json.is_null());
479        unsafe {
480            let s = CStr::from_ptr(json).to_owned();
481            sys::tract_free_cstring(json);
482            Ok(s.to_str()?.to_owned())
483        }
484    }
485}
486
487// STATE
488wrapper!(State, TractState, tract_state_destroy);
489
490impl StateInterface for State {
491    type Tensor = Tensor;
492    type Fact = Fact;
493
494    fn run(&mut self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
495        let inputs = inputs.into_inputs()?;
496        let mut outputs = vec![null_mut(); self.output_count()?];
497        let mut inputs: Vec<_> = inputs.iter().map(|v| v.0).collect();
498        check!(sys::tract_state_run(self.0, inputs.as_mut_ptr(), outputs.as_mut_ptr()))?;
499        let outputs = outputs.into_iter().map(Tensor).collect();
500        Ok(outputs)
501    }
502
503    fn input_count(&self) -> Result<usize> {
504        let mut count = 0;
505        check!(sys::tract_state_input_count(self.0, &mut count))?;
506        Ok(count)
507    }
508
509    fn output_count(&self) -> Result<usize> {
510        let mut count = 0;
511        check!(sys::tract_state_output_count(self.0, &mut count))?;
512        Ok(count)
513    }
514}
515
516// TENSOR
517wrapper!(Tensor, TractTensor, tract_tensor_destroy);
518unsafe impl Send for Tensor {}
519unsafe impl Sync for Tensor {}
520
521impl TensorInterface for Tensor {
522    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
523        anyhow::ensure!(data.len() == shape.iter().product::<usize>() * dt.size_of());
524        let mut value = null_mut();
525        check!(sys::tract_tensor_from_bytes(
526            dt as _,
527            shape.len(),
528            shape.as_ptr(),
529            data.as_ptr() as _,
530            &mut value
531        ))?;
532        Ok(Tensor(value))
533    }
534
535    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
536        let mut rank = 0;
537        let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
538        let mut shape = null();
539        let mut data = null();
540        check!(sys::tract_tensor_as_bytes(self.0, &mut dt, &mut rank, &mut shape, &mut data))?;
541        unsafe {
542            let dt: DatumType = std::mem::transmute(dt);
543            let shape = std::slice::from_raw_parts(shape, rank);
544            let len: usize = shape.iter().product();
545            let data = std::slice::from_raw_parts(data as *const u8, len * dt.size_of());
546            Ok((dt, shape, data))
547        }
548    }
549
550    fn datum_type(&self) -> Result<DatumType> {
551        let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
552        check!(sys::tract_tensor_as_bytes(
553            self.0,
554            &mut dt,
555            std::ptr::null_mut(),
556            std::ptr::null_mut(),
557            std::ptr::null_mut()
558        ))?;
559        unsafe {
560            let dt: DatumType = std::mem::transmute(dt);
561            Ok(dt)
562        }
563    }
564
565    fn convert_to(&self, to: DatumType) -> Result<Self> {
566        let mut new = null_mut();
567        check!(sys::tract_tensor_convert_to(self.0, to as _, &mut new))?;
568        Ok(Tensor(new))
569    }
570}
571
572impl PartialEq for Tensor {
573    fn eq(&self, other: &Self) -> bool {
574        let Ok((me_dt, me_shape, me_data)) = self.as_bytes() else { return false };
575        let Ok((other_dt, other_shape, other_data)) = other.as_bytes() else { return false };
576        me_dt == other_dt && me_shape == other_shape && me_data == other_data
577    }
578}
579
580tensor_from_to_ndarray!();
581
582// FACT
583wrapper!(Fact, TractFact, tract_fact_destroy);
584
585impl Fact {
586    fn new(model: &Model, spec: impl ToString) -> Result<Fact> {
587        let cstr = CString::new(spec.to_string())?;
588        let mut fact = null_mut();
589        check!(sys::tract_model_parse_fact(model.0, cstr.as_ptr(), &mut fact))?;
590        Ok(Fact(fact))
591    }
592
593    fn dump(&self) -> Result<String> {
594        let mut ptr = null_mut();
595        check!(sys::tract_fact_dump(self.0, &mut ptr))?;
596        unsafe {
597            let s = CStr::from_ptr(ptr).to_owned();
598            sys::tract_free_cstring(ptr);
599            Ok(s.to_str()?.to_owned())
600        }
601    }
602}
603
604impl FactInterface for Fact {
605    type Dim = Dim;
606
607    fn datum_type(&self) -> Result<DatumType> {
608        let mut dt = 0u32;
609        check!(sys::tract_fact_datum_type(self.0, &mut dt as *const u32 as _))?;
610        Ok(unsafe { std::mem::transmute::<u32, DatumType>(dt) })
611    }
612
613    fn rank(&self) -> Result<usize> {
614        let mut rank = 0;
615        check!(sys::tract_fact_rank(self.0, &mut rank))?;
616        Ok(rank)
617    }
618
619    fn dim(&self, axis: usize) -> Result<Self::Dim> {
620        let mut ptr = null_mut();
621        check!(sys::tract_fact_dim(self.0, axis, &mut ptr))?;
622        Ok(Dim(ptr))
623    }
624}
625
626impl std::fmt::Display for Fact {
627    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
628        match self.dump() {
629            Ok(s) => f.write_str(&s),
630            Err(_) => Err(std::fmt::Error),
631        }
632    }
633}
634
635// INFERENCE FACT
636wrapper!(InferenceFact, TractInferenceFact, tract_inference_fact_destroy);
637
638impl InferenceFact {
639    fn new(model: &InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
640        let cstr = CString::new(spec.to_string())?;
641        let mut fact = null_mut();
642        check!(sys::tract_inference_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
643        Ok(InferenceFact(fact))
644    }
645
646    fn dump(&self) -> Result<String> {
647        let mut ptr = null_mut();
648        check!(sys::tract_inference_fact_dump(self.0, &mut ptr))?;
649        unsafe {
650            let s = CStr::from_ptr(ptr).to_owned();
651            sys::tract_free_cstring(ptr);
652            Ok(s.to_str()?.to_owned())
653        }
654    }
655}
656
657impl InferenceFactInterface for InferenceFact {
658    fn empty() -> Result<InferenceFact> {
659        let mut fact = null_mut();
660        check!(sys::tract_inference_fact_empty(&mut fact))?;
661        Ok(InferenceFact(fact))
662    }
663}
664
665impl Default for InferenceFact {
666    fn default() -> Self {
667        Self::empty().unwrap()
668    }
669}
670
671impl std::fmt::Display for InferenceFact {
672    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
673        match self.dump() {
674            Ok(s) => f.write_str(&s),
675            Err(_) => Err(std::fmt::Error),
676        }
677    }
678}
679
680as_inference_fact_impl!(InferenceModel, InferenceFact);
681as_fact_impl!(Model, Fact);
682
683// Dim
684wrapper!(Dim, TractDim, tract_dim_destroy);
685
686impl Dim {
687    fn dump(&self) -> Result<String> {
688        let mut ptr = null_mut();
689        check!(sys::tract_dim_dump(self.0, &mut ptr))?;
690        unsafe {
691            let s = CStr::from_ptr(ptr).to_owned();
692            sys::tract_free_cstring(ptr);
693            Ok(s.to_str()?.to_owned())
694        }
695    }
696}
697
698impl DimInterface for Dim {
699    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self> {
700        let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
701        let c_strings: Vec<CString> =
702            names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
703        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
704        let mut ptr = null_mut();
705        check!(sys::tract_dim_eval(self.0, ptrs.len(), ptrs.as_ptr(), values.as_ptr(), &mut ptr))?;
706        Ok(Dim(ptr))
707    }
708
709    fn to_int64(&self) -> Result<i64> {
710        let mut i = 0;
711        check!(sys::tract_dim_to_int64(self.0, &mut i))?;
712        Ok(i)
713    }
714}
715
716impl std::fmt::Display for Dim {
717    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
718        match self.dump() {
719            Ok(s) => f.write_str(&s),
720            Err(_) => Err(std::fmt::Error),
721        }
722    }
723}