Skip to main content

tract_proxy/
lib.rs

1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::{null, null_mut};
4
5use tract_api::*;
6use tract_proxy_sys as sys;
7
8use anyhow::{Context, Result};
9use ndarray::*;
10
/// Evaluates a tract FFI call and converts its status code into an `anyhow::Result<()>`.
///
/// The expression is evaluated inside an `unsafe` block, so callers can pass raw `sys::` calls
/// directly. On `TRACT_RESULT_KO` the error message is read from `tract_get_last_error()`.
macro_rules! check {
    ($expr:expr) => {
        unsafe {
            if $expr == sys::TRACT_RESULT_TRACT_RESULT_KO {
                let err = sys::tract_get_last_error();
                // `CStr::from_ptr` on a null pointer is undefined behaviour: guard against a
                // failure that did not record a message.
                if err.is_null() {
                    Err(anyhow::anyhow!("tract call failed without an error message"))
                } else {
                    let buf = CStr::from_ptr(err);
                    Err(anyhow::anyhow!(buf.to_string_lossy().to_string()))
                }
            } else {
                Ok(())
            }
        }
    };
}
23
/// Declares an owning newtype around a raw tract FFI handle, plus a `Drop` impl releasing it.
///
/// * `$name`    - the public Rust type to declare;
/// * `$ffi`     - the `sys` struct the handle points to;
/// * `$release` - the `sys` function invoked with `&mut handle` when the wrapper is dropped;
/// * `$extra`   - optional additional tuple fields.
///
/// NOTE(review): the derived `Clone` copies the raw pointer, so a clone and its original would
/// both call `$release` on the same handle when dropped (double release). Avoid cloning these
/// wrappers; the derive is kept only because removing it could break callers or trait bounds.
macro_rules! wrapper {
    ($name:ident, $ffi:ident, $release:ident $(, $extra:ty )*) => {
        #[derive(Debug, Clone)]
        pub struct $name(*mut sys::$ffi $(, $extra)*);

        impl Drop for $name {
            fn drop(&mut self) {
                let handle: &mut *mut sys::$ffi = &mut self.0;
                unsafe { sys::$release(handle) };
            }
        }
    };
}
38
39pub fn nnef() -> Result<Nnef> {
40    let mut nnef = null_mut();
41    check!(sys::tract_nnef_create(&mut nnef))?;
42    Ok(Nnef(nnef))
43}
44
45pub fn onnx() -> Result<Onnx> {
46    let mut onnx = null_mut();
47    check!(sys::tract_onnx_create(&mut onnx))?;
48    Ok(Onnx(onnx))
49}
50
51pub fn version() -> &'static str {
52    unsafe { CStr::from_ptr(sys::tract_version()).to_str().unwrap() }
53}
54
55wrapper!(Nnef, TractNnef, tract_nnef_destroy);
56impl NnefInterface for Nnef {
57    type Model = Model;
58    fn load(&self, path: impl AsRef<Path>) -> Result<Model> {
59        let path = path.as_ref();
60        let path = CString::new(
61            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
62        )?;
63        let mut model = null_mut();
64        check!(sys::tract_nnef_load(self.0, path.as_ptr(), &mut model))?;
65        Ok(Model(model))
66    }
67
68    fn load_buffer(&self, data: &[u8]) -> Result<Model> {
69        let mut model = null_mut();
70        check!(sys::tract_nnef_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
71        Ok(Model(model))
72    }
73
74    fn enable_tract_core(&mut self) -> Result<()> {
75        check!(sys::tract_nnef_enable_tract_core(self.0))
76    }
77
78    fn enable_tract_extra(&mut self) -> Result<()> {
79        check!(sys::tract_nnef_enable_tract_extra(self.0))
80    }
81
82    fn enable_tract_transformers(&mut self) -> Result<()> {
83        check!(sys::tract_nnef_enable_tract_transformers(self.0))
84    }
85
86    fn enable_onnx(&mut self) -> Result<()> {
87        check!(sys::tract_nnef_enable_onnx(self.0))
88    }
89
90    fn enable_pulse(&mut self) -> Result<()> {
91        check!(sys::tract_nnef_enable_pulse(self.0))
92    }
93
94    fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
95        check!(sys::tract_nnef_enable_extended_identifier_syntax(self.0))
96    }
97
98    fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
99        let path = path.as_ref();
100        let path = CString::new(
101            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
102        )?;
103        check!(sys::tract_nnef_write_model_to_dir(self.0, path.as_ptr(), model.0))?;
104        Ok(())
105    }
106
107    fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
108        let path = path.as_ref();
109        let path = CString::new(
110            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
111        )?;
112        check!(sys::tract_nnef_write_model_to_tar(self.0, path.as_ptr(), model.0))?;
113        Ok(())
114    }
115
116    fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
117        let path = path.as_ref();
118        let path = CString::new(
119            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
120        )?;
121        check!(sys::tract_nnef_write_model_to_tar_gz(self.0, path.as_ptr(), model.0))?;
122        Ok(())
123    }
124}
125
126// ONNX
127wrapper!(Onnx, TractOnnx, tract_onnx_destroy);
128
129impl OnnxInterface for Onnx {
130    type InferenceModel = InferenceModel;
131    fn load(&self, path: impl AsRef<Path>) -> Result<InferenceModel> {
132        let path = path.as_ref();
133        let path = CString::new(
134            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
135        )?;
136        let mut model = null_mut();
137        check!(sys::tract_onnx_load(self.0, path.as_ptr(), &mut model))?;
138        Ok(InferenceModel(model))
139    }
140
141    fn load_buffer(&self, data: &[u8]) -> Result<InferenceModel> {
142        let mut model = null_mut();
143        check!(sys::tract_onnx_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
144        Ok(InferenceModel(model))
145    }
146}
147
148// INFERENCE MODEL
149wrapper!(InferenceModel, TractInferenceModel, tract_inference_model_destroy);
150impl InferenceModelInterface for InferenceModel {
151    type Model = Model;
152    type InferenceFact = InferenceFact;
153    fn set_output_names(
154        &mut self,
155        outputs: impl IntoIterator<Item = impl AsRef<str>>,
156    ) -> Result<()> {
157        let c_strings: Vec<CString> =
158            outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
159        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
160        check!(sys::tract_inference_model_set_output_names(
161            self.0,
162            c_strings.len(),
163            ptrs.as_ptr()
164        ))?;
165        Ok(())
166    }
167
168    fn input_count(&self) -> Result<usize> {
169        let mut count = 0;
170        check!(sys::tract_inference_model_input_count(self.0, &mut count))?;
171        Ok(count)
172    }
173
174    fn output_count(&self) -> Result<usize> {
175        let mut count = 0;
176        check!(sys::tract_inference_model_output_count(self.0, &mut count))?;
177        Ok(count)
178    }
179
180    fn input_name(&self, id: usize) -> Result<String> {
181        let mut ptr = null_mut();
182        check!(sys::tract_inference_model_input_name(self.0, id, &mut ptr))?;
183        unsafe {
184            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
185            sys::tract_free_cstring(ptr);
186            Ok(ret)
187        }
188    }
189
190    fn output_name(&self, id: usize) -> Result<String> {
191        let mut ptr = null_mut();
192        check!(sys::tract_inference_model_output_name(self.0, id, &mut ptr))?;
193        unsafe {
194            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
195            sys::tract_free_cstring(ptr);
196            Ok(ret)
197        }
198    }
199
200    fn input_fact(&self, id: usize) -> Result<InferenceFact> {
201        let mut ptr = null_mut();
202        check!(sys::tract_inference_model_input_fact(self.0, id, &mut ptr))?;
203        Ok(InferenceFact(ptr))
204    }
205
206    fn set_input_fact(
207        &mut self,
208        id: usize,
209        fact: impl AsFact<Self, Self::InferenceFact>,
210    ) -> Result<()> {
211        let fact = fact.as_fact(self)?;
212        check!(sys::tract_inference_model_set_input_fact(self.0, id, fact.0))?;
213        Ok(())
214    }
215
216    fn output_fact(&self, id: usize) -> Result<InferenceFact> {
217        let mut ptr = null_mut();
218        check!(sys::tract_inference_model_output_fact(self.0, id, &mut ptr))?;
219        Ok(InferenceFact(ptr))
220    }
221
222    fn set_output_fact(
223        &mut self,
224        id: usize,
225        fact: impl AsFact<InferenceModel, InferenceFact>,
226    ) -> Result<()> {
227        let fact = fact.as_fact(self)?;
228        check!(sys::tract_inference_model_set_output_fact(self.0, id, fact.0))?;
229        Ok(())
230    }
231
232    fn analyse(&mut self) -> Result<()> {
233        check!(sys::tract_inference_model_analyse(self.0))?;
234        Ok(())
235    }
236
237    fn into_tract(mut self) -> Result<Self::Model> {
238        let mut ptr = null_mut();
239        check!(sys::tract_inference_model_into_tract(&mut self.0, &mut ptr))?;
240        Ok(Model(ptr))
241    }
242}
243
244// MODEL
245wrapper!(Model, TractModel, tract_model_destroy);
246
247impl ModelInterface for Model {
248    type Fact = Fact;
249    type Value = Value;
250    type Runnable = Runnable;
251    fn input_count(&self) -> Result<usize> {
252        let mut count = 0;
253        check!(sys::tract_model_input_count(self.0, &mut count))?;
254        Ok(count)
255    }
256
257    fn output_count(&self) -> Result<usize> {
258        let mut count = 0;
259        check!(sys::tract_model_output_count(self.0, &mut count))?;
260        Ok(count)
261    }
262
263    fn input_name(&self, id: usize) -> Result<String> {
264        let mut ptr = null_mut();
265        check!(sys::tract_model_input_name(self.0, id, &mut ptr))?;
266        unsafe {
267            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
268            sys::tract_free_cstring(ptr);
269            Ok(ret)
270        }
271    }
272
273    fn output_name(&self, id: usize) -> Result<String> {
274        let mut ptr = null_mut();
275        check!(sys::tract_model_output_name(self.0, id, &mut ptr))?;
276        unsafe {
277            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
278            sys::tract_free_cstring(ptr);
279            Ok(ret)
280        }
281    }
282
283    fn set_output_names(
284        &mut self,
285        outputs: impl IntoIterator<Item = impl AsRef<str>>,
286    ) -> Result<()> {
287        let c_strings: Vec<CString> =
288            outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
289        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
290        check!(sys::tract_model_set_output_names(self.0, c_strings.len(), ptrs.as_ptr()))?;
291        Ok(())
292    }
293
294    fn input_fact(&self, id: usize) -> Result<Fact> {
295        let mut ptr = null_mut();
296        check!(sys::tract_model_input_fact(self.0, id, &mut ptr))?;
297        Ok(Fact(ptr))
298    }
299
300    fn output_fact(&self, id: usize) -> Result<Fact> {
301        let mut ptr = null_mut();
302        check!(sys::tract_model_output_fact(self.0, id, &mut ptr))?;
303        Ok(Fact(ptr))
304    }
305
306    fn into_runnable(self) -> Result<Runnable> {
307        let mut model = self;
308        let mut runnable = null_mut();
309        check!(sys::tract_model_into_runnable(&mut model.0, &mut runnable))?;
310        Ok(Runnable(runnable))
311    }
312
313    fn concretize_symbols(
314        &mut self,
315        values: impl IntoIterator<Item = (impl AsRef<str>, i64)>,
316    ) -> Result<()> {
317        let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
318        let c_strings: Vec<CString> =
319            names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
320        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
321        check!(sys::tract_model_concretize_symbols(
322            self.0,
323            ptrs.len(),
324            ptrs.as_ptr(),
325            values.as_ptr()
326        ))?;
327        Ok(())
328    }
329
330    fn transform(&mut self, transform: &str) -> Result<()> {
331        let t = CString::new(transform)?;
332        check!(sys::tract_model_transform(self.0, t.as_ptr()))?;
333        Ok(())
334    }
335
336    fn pulse(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
337        let name = CString::new(name.as_ref())?;
338        let value = CString::new(value.as_ref())?;
339        check!(sys::tract_model_pulse_simple(&mut self.0, name.as_ptr(), value.as_ptr()))?;
340        Ok(())
341    }
342
343    fn property_keys(&self) -> Result<Vec<String>> {
344        let mut len = 0;
345        check!(sys::tract_model_property_count(self.0, &mut len))?;
346        let mut keys = vec![null_mut(); len];
347        check!(sys::tract_model_property_names(self.0, keys.as_mut_ptr()))?;
348        unsafe {
349            keys.into_iter()
350                .map(|pc| {
351                    let s = CStr::from_ptr(pc).to_str()?.to_owned();
352                    sys::tract_free_cstring(pc);
353                    Ok(s)
354                })
355                .collect()
356        }
357    }
358
359    fn property(&self, name: impl AsRef<str>) -> Result<Value> {
360        let mut v = null_mut();
361        let name = CString::new(name.as_ref())?;
362        check!(sys::tract_model_property(self.0, name.as_ptr(), &mut v))?;
363        Ok(Value(v))
364    }
365
366    fn parse_fact(&self, spec: &str) -> Result<Self::Fact> {
367        let spec = CString::new(spec)?;
368        let mut ptr = null_mut();
369        check!(sys::tract_model_parse_fact(self.0, spec.as_ptr(), &mut ptr))?;
370        Ok(Fact(ptr))
371    }
372}
373
374// RUNTIME
375wrapper!(Runtime, TractRuntime, tract_runtime_release);
376
377pub fn runtime_for_name(name: &str) -> Result<Runtime> {
378    let mut rt = null_mut();
379    let name = CString::new(name)?;
380    check!(sys::tract_runtime_for_name(name.as_ptr(), &mut rt))?;
381    Ok(Runtime(rt))
382}
383
384impl RuntimeInterface for Runtime {
385    type Runnable = Runnable;
386
387    type Model = Model;
388
389    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable> {
390        let mut model = model;
391        let mut runnable = null_mut();
392        check!(sys::tract_runtime_prepare(self.0, &mut model.0, &mut runnable))?;
393        Ok(Runnable(runnable))
394    }
395}
396
397// RUNNABLE
398wrapper!(Runnable, TractRunnable, tract_runnable_release);
399unsafe impl Send for Runnable {}
400unsafe impl Sync for Runnable {}
401
402impl RunnableInterface for Runnable {
403    type Value = Value;
404    type State = State;
405    type Fact = Fact;
406
407    fn run<I, V, E>(&self, inputs: I) -> Result<Vec<Value>>
408    where
409        I: IntoIterator<Item = V>,
410        V: TryInto<Value, Error = E>,
411        E: Into<anyhow::Error>,
412    {
413        self.spawn_state()?.run(inputs)
414    }
415
416    fn spawn_state(&self) -> Result<State> {
417        let mut state = null_mut();
418        check!(sys::tract_runnable_spawn_state(self.0, &mut state))?;
419        Ok(State(state))
420    }
421
422    fn input_count(&self) -> Result<usize> {
423        let mut count = 0;
424        check!(sys::tract_runnable_input_count(self.0, &mut count))?;
425        Ok(count)
426    }
427
428    fn output_count(&self) -> Result<usize> {
429        let mut count = 0;
430        check!(sys::tract_runnable_output_count(self.0, &mut count))?;
431        Ok(count)
432    }
433
434    fn input_fact(&self, id: usize) -> Result<Self::Fact> {
435        let mut ptr = null_mut();
436        check!(sys::tract_runnable_input_fact(self.0, id, &mut ptr))?;
437        Ok(Fact(ptr))
438    }
439
440    fn output_fact(&self, id: usize) -> Result<Self::Fact> {
441        let mut ptr = null_mut();
442        check!(sys::tract_runnable_output_fact(self.0, id, &mut ptr))?;
443        Ok(Fact(ptr))
444    }
445
446    fn property_keys(&self) -> Result<Vec<String>> {
447        let mut len = 0;
448        check!(sys::tract_runnable_property_count(self.0, &mut len))?;
449        let mut keys = vec![null_mut(); len];
450        check!(sys::tract_runnable_property_names(self.0, keys.as_mut_ptr()))?;
451        unsafe {
452            keys.into_iter()
453                .map(|pc| {
454                    let s = CStr::from_ptr(pc).to_str()?.to_owned();
455                    sys::tract_free_cstring(pc);
456                    Ok(s)
457                })
458                .collect()
459        }
460    }
461
462    fn property(&self, name: impl AsRef<str>) -> Result<Value> {
463        let mut v = null_mut();
464        let name = CString::new(name.as_ref())?;
465        check!(sys::tract_runnable_property(self.0, name.as_ptr(), &mut v))?;
466        Ok(Value(v))
467    }
468
469    fn cost_json(&self) -> Result<String> {
470        let input: Option<Vec<Value>> = None;
471        let states: Option<Vec<Value>> = None;
472        self.profile_json(input, states)
473    }
474
475    fn profile_json<I, IV, IE, S, SV, SE>(
476        &self,
477        inputs: Option<I>,
478        state_initializers: Option<S>,
479    ) -> Result<String>
480    where
481        I: IntoIterator<Item = IV>,
482        IV: TryInto<Self::Value, Error = IE>,
483        IE: Into<anyhow::Error>,
484        S: IntoIterator<Item = SV>,
485        SV: TryInto<Self::Value, Error = SE>,
486        SE: Into<anyhow::Error>,
487    {
488        let inputs = if let Some(inputs) = inputs {
489            let inputs = inputs
490                .into_iter()
491                .map(|i| i.try_into().map_err(|e| e.into()))
492                .collect::<Result<Vec<Value>>>()?;
493            anyhow::ensure!(self.input_count()? == inputs.len());
494            Some(inputs)
495        } else {
496            None
497        };
498        let mut iptrs: Option<Vec<*mut sys::TractValue>> =
499            inputs.as_ref().map(|is| is.iter().map(|v| v.0).collect());
500        let mut json: *mut i8 = null_mut();
501        let values = iptrs.as_mut().map(|it| it.as_mut_ptr()).unwrap_or(null_mut());
502
503        let (state_inits, n_states) = if let Some(state_vec) = state_initializers {
504            let mut states: Vec<*const _> = vec![];
505
506            for v in state_vec {
507                let val: Value = v.try_into().map_err(|e| e.into())?;
508                states.push(val.0);
509            }
510            let len = states.len();
511            (Some(states), len)
512        } else {
513            (None, 0)
514        };
515
516        let states = state_inits.map(|is| is.as_ptr()).unwrap_or(null());
517        check!(sys::tract_runnable_profile_json(self.0, values, states, n_states, &mut json))?;
518        anyhow::ensure!(!json.is_null());
519        unsafe {
520            let s = CStr::from_ptr(json).to_owned();
521            sys::tract_free_cstring(json);
522            Ok(s.to_str()?.to_owned())
523        }
524    }
525}
526
527// STATE
528wrapper!(State, TractState, tract_state_destroy);
529
530impl StateInterface for State {
531    type Value = Value;
532    type Fact = Fact;
533
534    fn run<I, V, E>(&mut self, inputs: I) -> Result<Vec<Value>>
535    where
536        I: IntoIterator<Item = V>,
537        V: TryInto<Value, Error = E>,
538        E: Into<anyhow::Error>,
539    {
540        let inputs = inputs
541            .into_iter()
542            .map(|i| i.try_into().map_err(|e| e.into()))
543            .collect::<Result<Vec<Value>>>()?;
544        let mut outputs = vec![null_mut(); self.output_count()?];
545        let mut inputs: Vec<_> = inputs.iter().map(|v| v.0).collect();
546        check!(sys::tract_state_run(self.0, inputs.as_mut_ptr(), outputs.as_mut_ptr()))?;
547        let outputs = outputs.into_iter().map(Value).collect();
548        Ok(outputs)
549    }
550
551    fn input_count(&self) -> Result<usize> {
552        let mut count = 0;
553        check!(sys::tract_state_input_count(self.0, &mut count))?;
554        Ok(count)
555    }
556
557    fn output_count(&self) -> Result<usize> {
558        let mut count = 0;
559        check!(sys::tract_state_output_count(self.0, &mut count))?;
560        Ok(count)
561    }
562
563    #[allow(deprecated)]
564    fn initializable_states_count(&self) -> Result<usize> {
565        let mut n_states = 0;
566        check!(sys::tract_state_initializable_states_count(self.0, &mut n_states))?;
567        Ok(n_states)
568    }
569
570    #[allow(deprecated)]
571    fn get_states_facts(&self) -> Result<Vec<Fact>> {
572        let n_states = self.initializable_states_count()?;
573        let mut fptrs = vec![null_mut(); n_states];
574
575        check!(sys::tract_state_get_states_facts(self.0, fptrs.as_mut_ptr()))?;
576
577        let res = fptrs.into_iter().map(|value| Ok(Fact(value))).collect::<Result<Vec<Fact>>>();
578
579        res
580    }
581
582    #[allow(deprecated)]
583    fn set_states<I, V, E>(&mut self, state_initializers: I) -> Result<()>
584    where
585        I: IntoIterator<Item = V>,
586        V: TryInto<Self::Value, Error = E>,
587        E: Into<anyhow::Error>,
588    {
589        let sptrs = {
590            let mut states: Vec<*const _> = vec![];
591
592            for s in state_initializers {
593                let val: Value = s.try_into().map_err(|e| e.into())?;
594                states.push(val.0);
595            }
596
597            let len = states.len();
598            anyhow::ensure!(
599                len == self.initializable_states_count()?,
600                "Expected {} states, got {len}",
601                self.initializable_states_count()?
602            );
603            Some(states)
604        };
605
606        let sptrs = sptrs.map(|it| it.as_ptr()).unwrap_or(null());
607        check!(sys::tract_state_set_states(self.0, sptrs))?;
608
609        Ok(())
610    }
611
612    #[allow(deprecated)]
613    fn get_states(&self) -> Result<Vec<Self::Value>> {
614        let n_states = self.initializable_states_count()?;
615
616        let mut sptrs = vec![null_mut(); n_states];
617        check!(sys::tract_state_get_states(self.0, sptrs.as_mut_ptr()))?;
618
619        let res = sptrs.into_iter().map(|value| Ok(Value(value))).collect::<Result<Vec<Value>>>();
620
621        res
622    }
623}
624
625// VALUE
626wrapper!(Value, TractValue, tract_value_destroy);
627unsafe impl Send for Value {}
628unsafe impl Sync for Value {}
629
630impl ValueInterface for Value {
631    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
632        anyhow::ensure!(data.len() == shape.iter().product::<usize>() * dt.size_of());
633        let mut value = null_mut();
634        check!(sys::tract_value_from_bytes(
635            dt as _,
636            shape.len(),
637            shape.as_ptr(),
638            data.as_ptr() as _,
639            &mut value
640        ))?;
641        Ok(Value(value))
642    }
643
644    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
645        let mut rank = 0;
646        let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
647        let mut shape = null();
648        let mut data = null();
649        check!(sys::tract_value_as_bytes(self.0, &mut dt, &mut rank, &mut shape, &mut data))?;
650        unsafe {
651            let dt: DatumType = std::mem::transmute(dt);
652            let shape = std::slice::from_raw_parts(shape, rank);
653            let len: usize = shape.iter().product();
654            let data = std::slice::from_raw_parts(data as *const u8, len * dt.size_of());
655            Ok((dt, shape, data))
656        }
657    }
658
659    fn datum_type(&self) -> Result<DatumType> {
660        let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
661        check!(sys::tract_value_as_bytes(
662            self.0,
663            &mut dt,
664            std::ptr::null_mut(),
665            std::ptr::null_mut(),
666            std::ptr::null_mut()
667        ))?;
668        unsafe {
669            let dt: DatumType = std::mem::transmute(dt);
670            Ok(dt)
671        }
672    }
673
674    fn convert_to(&self, to: DatumType) -> Result<Self> {
675        let mut new = null_mut();
676        check!(sys::tract_value_convert_to(self.0, to as _, &mut new))?;
677        Ok(Value(new))
678    }
679}
680
681impl PartialEq for Value {
682    fn eq(&self, other: &Self) -> bool {
683        let Ok((me_dt, me_shape, me_data)) = self.as_bytes() else { return false };
684        let Ok((other_dt, other_shape, other_data)) = other.as_bytes() else { return false };
685        me_dt == other_dt && me_shape == other_shape && me_data == other_data
686    }
687}
688
689value_from_to_ndarray!();
690
691// FACT
692wrapper!(Fact, TractFact, tract_fact_destroy);
693
694impl Fact {
695    fn new(model: &Model, spec: impl ToString) -> Result<Fact> {
696        let cstr = CString::new(spec.to_string())?;
697        let mut fact = null_mut();
698        check!(sys::tract_model_parse_fact(model.0, cstr.as_ptr(), &mut fact))?;
699        Ok(Fact(fact))
700    }
701
702    fn dump(&self) -> Result<String> {
703        let mut ptr = null_mut();
704        check!(sys::tract_fact_dump(self.0, &mut ptr))?;
705        unsafe {
706            let s = CStr::from_ptr(ptr).to_owned();
707            sys::tract_free_cstring(ptr);
708            Ok(s.to_str()?.to_owned())
709        }
710    }
711}
712
713impl FactInterface for Fact {
714    type Dim = Dim;
715
716    fn datum_type(&self) -> Result<DatumType> {
717        let mut dt = 0u32;
718        check!(sys::tract_fact_datum_type(self.0, &mut dt as *const u32 as _))?;
719        Ok(unsafe { std::mem::transmute::<u32, DatumType>(dt) })
720    }
721
722    fn rank(&self) -> Result<usize> {
723        let mut rank = 0;
724        check!(sys::tract_fact_rank(self.0, &mut rank))?;
725        Ok(rank)
726    }
727
728    fn dim(&self, axis: usize) -> Result<Self::Dim> {
729        let mut ptr = null_mut();
730        check!(sys::tract_fact_dim(self.0, axis, &mut ptr))?;
731        Ok(Dim(ptr))
732    }
733}
734
735impl std::fmt::Display for Fact {
736    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
737        match self.dump() {
738            Ok(s) => f.write_str(&s),
739            Err(_) => Err(std::fmt::Error),
740        }
741    }
742}
743
744// INFERENCE FACT
745wrapper!(InferenceFact, TractInferenceFact, tract_inference_fact_destroy);
746
747impl InferenceFact {
748    fn new(model: &InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
749        let cstr = CString::new(spec.to_string())?;
750        let mut fact = null_mut();
751        check!(sys::tract_inference_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
752        Ok(InferenceFact(fact))
753    }
754
755    fn dump(&self) -> Result<String> {
756        let mut ptr = null_mut();
757        check!(sys::tract_inference_fact_dump(self.0, &mut ptr))?;
758        unsafe {
759            let s = CStr::from_ptr(ptr).to_owned();
760            sys::tract_free_cstring(ptr);
761            Ok(s.to_str()?.to_owned())
762        }
763    }
764}
765
766impl InferenceFactInterface for InferenceFact {
767    fn empty() -> Result<InferenceFact> {
768        let mut fact = null_mut();
769        check!(sys::tract_inference_fact_empty(&mut fact))?;
770        Ok(InferenceFact(fact))
771    }
772}
773
774impl Default for InferenceFact {
775    fn default() -> Self {
776        Self::empty().unwrap()
777    }
778}
779
780impl std::fmt::Display for InferenceFact {
781    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
782        match self.dump() {
783            Ok(s) => f.write_str(&s),
784            Err(_) => Err(std::fmt::Error),
785        }
786    }
787}
788
789as_inference_fact_impl!(InferenceModel, InferenceFact);
790as_fact_impl!(Model, Fact);
791
792// Dim
793wrapper!(Dim, TractDim, tract_dim_destroy);
794
795impl Dim {
796    fn dump(&self) -> Result<String> {
797        let mut ptr = null_mut();
798        check!(sys::tract_dim_dump(self.0, &mut ptr))?;
799        unsafe {
800            let s = CStr::from_ptr(ptr).to_owned();
801            sys::tract_free_cstring(ptr);
802            Ok(s.to_str()?.to_owned())
803        }
804    }
805}
806
807impl DimInterface for Dim {
808    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self> {
809        let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
810        let c_strings: Vec<CString> =
811            names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
812        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
813        let mut ptr = null_mut();
814        check!(sys::tract_dim_eval(self.0, ptrs.len(), ptrs.as_ptr(), values.as_ptr(), &mut ptr))?;
815        Ok(Dim(ptr))
816    }
817
818    fn to_int64(&self) -> Result<i64> {
819        let mut i = 0;
820        check!(sys::tract_dim_to_int64(self.0, &mut i))?;
821        Ok(i)
822    }
823}
824
825impl std::fmt::Display for Dim {
826    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
827        match self.dump() {
828            Ok(s) => f.write_str(&s),
829            Err(_) => Err(std::fmt::Error),
830        }
831    }
832}