Skip to main content

tract_proxy/
lib.rs

1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::{null, null_mut};
4
5use tract_api::*;
6use tract_proxy_sys as sys;
7
8use anyhow::{Context, Result};
9
10mod ndarray_interop;
11pub use ndarray_interop::__ndarray_interop;
12
/// Evaluates an FFI call returning a tract result code and converts it into an
/// `anyhow::Result<()>`.
///
/// On `TRACT_RESULT_KO`, the message reported by `tract_get_last_error` becomes the error.
/// The whole expression is evaluated inside an `unsafe` block, so callers pass raw FFI calls
/// directly.
macro_rules! check {
    ($expr:expr) => {
        unsafe {
            if $expr == sys::TRACT_RESULT_TRACT_RESULT_KO {
                let msg = sys::tract_get_last_error();
                // `CStr::from_ptr` on a null pointer is UB, so a missing message must be handled.
                if msg.is_null() {
                    Err(anyhow::anyhow!("tract call failed without reporting an error message"))
                } else {
                    let buf = CStr::from_ptr(msg);
                    Err(anyhow::anyhow!(buf.to_string_lossy().to_string()))
                }
            } else {
                Ok(())
            }
        }
    };
}
25
/// Declares an owning newtype around a raw `*mut sys::$c_type` handle, with a `Drop` impl that
/// releases it through `sys::$dest`.
///
/// Extra `$typ` arguments become additional tuple fields after the pointer.
macro_rules! wrapper {
    ($new_type:ident, $c_type:ident, $dest:ident $(, $typ:ty )*) => {
        #[derive(Debug)]
        pub struct $new_type(*mut sys::$c_type $(, $typ)*);

        impl Drop for $new_type {
            fn drop(&mut self) {
                // Consuming calls (e.g. `into_model`, `prepare`) hand `&mut self.0` to the C side,
                // which may leave it null; there is then nothing left to release.
                if self.0.is_null() {
                    return;
                }
                unsafe {
                    sys::$dest(&mut self.0);
                }
            }
        }
    };
}
40
/// Implements `Clone` for a `wrapper!` newtype by delegating to the FFI function
/// `sys::$clone_fn`, which fills in a second handle that the new wrapper owns and drops
/// independently of the original.
///
/// # Panics
/// Panics if the FFI clone call reports an error, since `Clone` cannot return one.
macro_rules! wrapper_clone {
    ($new_type:ident, $clone_fn:ident) => {
        impl Clone for $new_type {
            fn clone(&self) -> Self {
                let mut clone = null_mut();
                // Ignoring the result would yield a wrapper around a null pointer that fails
                // confusingly much later; fail loudly at the point of the error instead.
                check!(sys::$clone_fn(self.0, &mut clone))
                    .expect(concat!(stringify!($clone_fn), " failed"));
                $new_type(clone)
            }
        }
    };
}
54
55pub fn nnef() -> Result<Nnef> {
56    let mut nnef = null_mut();
57    check!(sys::tract_nnef_create(&mut nnef))?;
58    Ok(Nnef(nnef))
59}
60
61pub fn onnx() -> Result<Onnx> {
62    let mut onnx = null_mut();
63    check!(sys::tract_onnx_create(&mut onnx))?;
64    Ok(Onnx(onnx))
65}
66
67pub fn version() -> &'static str {
68    unsafe { CStr::from_ptr(sys::tract_version()).to_str().unwrap() }
69}
70
71wrapper!(Nnef, TractNnef, tract_nnef_destroy);
72impl NnefInterface for Nnef {
73    type Model = Model;
74    fn load(&self, path: impl AsRef<Path>) -> Result<Model> {
75        let path = path.as_ref();
76        let path = CString::new(
77            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
78        )?;
79        let mut model = null_mut();
80        check!(sys::tract_nnef_load(self.0, path.as_ptr(), &mut model))?;
81        Ok(Model(model))
82    }
83
84    fn load_buffer(&self, data: &[u8]) -> Result<Model> {
85        let mut model = null_mut();
86        check!(sys::tract_nnef_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
87        Ok(Model(model))
88    }
89
90    fn enable_tract_core(&mut self) -> Result<()> {
91        check!(sys::tract_nnef_enable_tract_core(self.0))
92    }
93
94    fn enable_tract_extra(&mut self) -> Result<()> {
95        check!(sys::tract_nnef_enable_tract_extra(self.0))
96    }
97
98    fn enable_tract_transformers(&mut self) -> Result<()> {
99        check!(sys::tract_nnef_enable_tract_transformers(self.0))
100    }
101
102    fn enable_onnx(&mut self) -> Result<()> {
103        check!(sys::tract_nnef_enable_onnx(self.0))
104    }
105
106    fn enable_pulse(&mut self) -> Result<()> {
107        check!(sys::tract_nnef_enable_pulse(self.0))
108    }
109
110    fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
111        check!(sys::tract_nnef_enable_extended_identifier_syntax(self.0))
112    }
113
114    fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
115        let path = path.as_ref();
116        let path = CString::new(
117            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
118        )?;
119        check!(sys::tract_nnef_write_model_to_dir(self.0, path.as_ptr(), model.0))?;
120        Ok(())
121    }
122
123    fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
124        let path = path.as_ref();
125        let path = CString::new(
126            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
127        )?;
128        check!(sys::tract_nnef_write_model_to_tar(self.0, path.as_ptr(), model.0))?;
129        Ok(())
130    }
131
132    fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
133        let path = path.as_ref();
134        let path = CString::new(
135            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
136        )?;
137        check!(sys::tract_nnef_write_model_to_tar_gz(self.0, path.as_ptr(), model.0))?;
138        Ok(())
139    }
140}
141
142// ONNX
143wrapper!(Onnx, TractOnnx, tract_onnx_destroy);
144
145impl OnnxInterface for Onnx {
146    type InferenceModel = InferenceModel;
147    fn load(&self, path: impl AsRef<Path>) -> Result<InferenceModel> {
148        let path = path.as_ref();
149        let path = CString::new(
150            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
151        )?;
152        let mut model = null_mut();
153        check!(sys::tract_onnx_load(self.0, path.as_ptr(), &mut model))?;
154        Ok(InferenceModel(model))
155    }
156
157    fn load_buffer(&self, data: &[u8]) -> Result<InferenceModel> {
158        let mut model = null_mut();
159        check!(sys::tract_onnx_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
160        Ok(InferenceModel(model))
161    }
162}
163
164// INFERENCE MODEL
165wrapper!(InferenceModel, TractInferenceModel, tract_inference_model_destroy);
166impl InferenceModelInterface for InferenceModel {
167    type Model = Model;
168    type InferenceFact = InferenceFact;
169    fn input_count(&self) -> Result<usize> {
170        let mut count = 0;
171        check!(sys::tract_inference_model_input_count(self.0, &mut count))?;
172        Ok(count)
173    }
174
175    fn output_count(&self) -> Result<usize> {
176        let mut count = 0;
177        check!(sys::tract_inference_model_output_count(self.0, &mut count))?;
178        Ok(count)
179    }
180
181    fn input_name(&self, id: usize) -> Result<String> {
182        let mut ptr = null_mut();
183        check!(sys::tract_inference_model_input_name(self.0, id, &mut ptr))?;
184        unsafe {
185            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
186            sys::tract_free_cstring(ptr);
187            Ok(ret)
188        }
189    }
190
191    fn output_name(&self, id: usize) -> Result<String> {
192        let mut ptr = null_mut();
193        check!(sys::tract_inference_model_output_name(self.0, id, &mut ptr))?;
194        unsafe {
195            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
196            sys::tract_free_cstring(ptr);
197            Ok(ret)
198        }
199    }
200
201    fn input_fact(&self, id: usize) -> Result<InferenceFact> {
202        let mut ptr = null_mut();
203        check!(sys::tract_inference_model_input_fact(self.0, id, &mut ptr))?;
204        Ok(InferenceFact(ptr))
205    }
206
207    fn set_input_fact(
208        &mut self,
209        id: usize,
210        fact: impl AsFact<Self, Self::InferenceFact>,
211    ) -> Result<()> {
212        let fact = fact.as_fact(self)?;
213        check!(sys::tract_inference_model_set_input_fact(self.0, id, fact.0))?;
214        Ok(())
215    }
216
217    fn output_fact(&self, id: usize) -> Result<InferenceFact> {
218        let mut ptr = null_mut();
219        check!(sys::tract_inference_model_output_fact(self.0, id, &mut ptr))?;
220        Ok(InferenceFact(ptr))
221    }
222
223    fn set_output_fact(
224        &mut self,
225        id: usize,
226        fact: impl AsFact<InferenceModel, InferenceFact>,
227    ) -> Result<()> {
228        let fact = fact.as_fact(self)?;
229        check!(sys::tract_inference_model_set_output_fact(self.0, id, fact.0))?;
230        Ok(())
231    }
232
233    fn analyse(&mut self) -> Result<()> {
234        check!(sys::tract_inference_model_analyse(self.0))?;
235        Ok(())
236    }
237
238    fn into_model(mut self) -> Result<Self::Model> {
239        let mut ptr = null_mut();
240        check!(sys::tract_inference_model_into_model(&mut self.0, &mut ptr))?;
241        Ok(Model(ptr))
242    }
243}
244
245// MODEL
246wrapper!(Model, TractModel, tract_model_destroy);
247
248impl ModelInterface for Model {
249    type Fact = Fact;
250    type Tensor = Tensor;
251    type Runnable = Runnable;
252    fn input_count(&self) -> Result<usize> {
253        let mut count = 0;
254        check!(sys::tract_model_input_count(self.0, &mut count))?;
255        Ok(count)
256    }
257
258    fn output_count(&self) -> Result<usize> {
259        let mut count = 0;
260        check!(sys::tract_model_output_count(self.0, &mut count))?;
261        Ok(count)
262    }
263
264    fn input_name(&self, id: usize) -> Result<String> {
265        let mut ptr = null_mut();
266        check!(sys::tract_model_input_name(self.0, id, &mut ptr))?;
267        unsafe {
268            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
269            sys::tract_free_cstring(ptr);
270            Ok(ret)
271        }
272    }
273
274    fn output_name(&self, id: usize) -> Result<String> {
275        let mut ptr = null_mut();
276        check!(sys::tract_model_output_name(self.0, id, &mut ptr))?;
277        unsafe {
278            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
279            sys::tract_free_cstring(ptr);
280            Ok(ret)
281        }
282    }
283
284    fn input_fact(&self, id: usize) -> Result<Fact> {
285        let mut ptr = null_mut();
286        check!(sys::tract_model_input_fact(self.0, id, &mut ptr))?;
287        Ok(Fact(ptr))
288    }
289
290    fn output_fact(&self, id: usize) -> Result<Fact> {
291        let mut ptr = null_mut();
292        check!(sys::tract_model_output_fact(self.0, id, &mut ptr))?;
293        Ok(Fact(ptr))
294    }
295
296    fn into_runnable(self) -> Result<Runnable> {
297        let mut model = self;
298        let mut runnable = null_mut();
299        check!(sys::tract_model_into_runnable(&mut model.0, &mut runnable))?;
300        Ok(Runnable(runnable))
301    }
302
303    fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()> {
304        let transform = spec.into().to_transform_string();
305        let t = CString::new(transform)?;
306        check!(sys::tract_model_transform(self.0, t.as_ptr()))?;
307        Ok(())
308    }
309
310    fn property_keys(&self) -> Result<Vec<String>> {
311        let mut len = 0;
312        check!(sys::tract_model_property_count(self.0, &mut len))?;
313        let mut keys = vec![null_mut(); len];
314        check!(sys::tract_model_property_names(self.0, keys.as_mut_ptr()))?;
315        unsafe {
316            keys.into_iter()
317                .map(|pc| {
318                    let s = CStr::from_ptr(pc).to_str()?.to_owned();
319                    sys::tract_free_cstring(pc);
320                    Ok(s)
321                })
322                .collect()
323        }
324    }
325
326    fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
327        let mut v = null_mut();
328        let name = CString::new(name.as_ref())?;
329        check!(sys::tract_model_property(self.0, name.as_ptr(), &mut v))?;
330        Ok(Tensor(v))
331    }
332
333    fn parse_fact(&self, spec: &str) -> Result<Self::Fact> {
334        let spec = CString::new(spec)?;
335        let mut ptr = null_mut();
336        check!(sys::tract_model_parse_fact(self.0, spec.as_ptr(), &mut ptr))?;
337        Ok(Fact(ptr))
338    }
339}
340
341// RUNTIME
342wrapper!(Runtime, TractRuntime, tract_runtime_release);
343
344pub fn runtime_for_name(name: &str) -> Result<Runtime> {
345    let mut rt = null_mut();
346    let name = CString::new(name)?;
347    check!(sys::tract_runtime_for_name(name.as_ptr(), &mut rt))?;
348    Ok(Runtime(rt))
349}
350
351impl RuntimeInterface for Runtime {
352    type Runnable = Runnable;
353
354    type Model = Model;
355
356    fn name(&self) -> Result<String> {
357        let mut ptr = null_mut();
358        check!(sys::tract_runtime_name(self.0, &mut ptr))?;
359        unsafe {
360            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
361            sys::tract_free_cstring(ptr);
362            Ok(ret)
363        }
364    }
365
366    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable> {
367        let mut model = model;
368        let mut runnable = null_mut();
369        check!(sys::tract_runtime_prepare(self.0, &mut model.0, &mut runnable))?;
370        Ok(Runnable(runnable))
371    }
372}
373
374// RUNNABLE
375wrapper!(Runnable, TractRunnable, tract_runnable_release);
376unsafe impl Send for Runnable {}
377unsafe impl Sync for Runnable {}
378
379impl RunnableInterface for Runnable {
380    type Tensor = Tensor;
381    type State = State;
382    type Fact = Fact;
383
384    fn run(&self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
385        StateInterface::run(&mut self.spawn_state()?, inputs.into_inputs()?)
386    }
387
388    fn spawn_state(&self) -> Result<State> {
389        let mut state = null_mut();
390        check!(sys::tract_runnable_spawn_state(self.0, &mut state))?;
391        Ok(State(state))
392    }
393
394    fn input_count(&self) -> Result<usize> {
395        let mut count = 0;
396        check!(sys::tract_runnable_input_count(self.0, &mut count))?;
397        Ok(count)
398    }
399
400    fn output_count(&self) -> Result<usize> {
401        let mut count = 0;
402        check!(sys::tract_runnable_output_count(self.0, &mut count))?;
403        Ok(count)
404    }
405
406    fn input_fact(&self, id: usize) -> Result<Self::Fact> {
407        let mut ptr = null_mut();
408        check!(sys::tract_runnable_input_fact(self.0, id, &mut ptr))?;
409        Ok(Fact(ptr))
410    }
411
412    fn output_fact(&self, id: usize) -> Result<Self::Fact> {
413        let mut ptr = null_mut();
414        check!(sys::tract_runnable_output_fact(self.0, id, &mut ptr))?;
415        Ok(Fact(ptr))
416    }
417
418    fn property_keys(&self) -> Result<Vec<String>> {
419        let mut len = 0;
420        check!(sys::tract_runnable_property_count(self.0, &mut len))?;
421        let mut keys = vec![null_mut(); len];
422        check!(sys::tract_runnable_property_names(self.0, keys.as_mut_ptr()))?;
423        unsafe {
424            keys.into_iter()
425                .map(|pc| {
426                    let s = CStr::from_ptr(pc).to_str()?.to_owned();
427                    sys::tract_free_cstring(pc);
428                    Ok(s)
429                })
430                .collect()
431        }
432    }
433
434    fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
435        let mut v = null_mut();
436        let name = CString::new(name.as_ref())?;
437        check!(sys::tract_runnable_property(self.0, name.as_ptr(), &mut v))?;
438        Ok(Tensor(v))
439    }
440
441    fn cost_json(&self) -> Result<String> {
442        let input: Option<Vec<Tensor>> = None;
443        self.profile_json(input)
444    }
445
446    fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
447    where
448        I: IntoIterator<Item = IV>,
449        IV: TryInto<Self::Tensor, Error = IE>,
450        IE: Into<anyhow::Error>,
451    {
452        let inputs = if let Some(inputs) = inputs {
453            let inputs = inputs
454                .into_iter()
455                .map(|i| i.try_into().map_err(|e| e.into()))
456                .collect::<Result<Vec<Tensor>>>()?;
457            anyhow::ensure!(self.input_count()? == inputs.len());
458            Some(inputs)
459        } else {
460            None
461        };
462        let mut iptrs: Option<Vec<*mut sys::TractTensor>> =
463            inputs.as_ref().map(|is| is.iter().map(|v| v.0).collect());
464        let mut json: *mut i8 = null_mut();
465        let values = iptrs.as_mut().map(|it| it.as_mut_ptr()).unwrap_or(null_mut());
466
467        check!(sys::tract_runnable_profile_json(self.0, values, &mut json))?;
468        anyhow::ensure!(!json.is_null());
469        unsafe {
470            let s = CStr::from_ptr(json).to_owned();
471            sys::tract_free_cstring(json);
472            Ok(s.to_str()?.to_owned())
473        }
474    }
475}
476
477// STATE
478pub struct State(*mut sys::TractState);
479
480impl Drop for State {
481    fn drop(&mut self) {
482        unsafe {
483            sys::tract_state_destroy(&mut self.0);
484        }
485    }
486}
487
488impl std::fmt::Debug for State {
489    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490        write!(f, "State({:?})", self.0)
491    }
492}
493
494impl Clone for State {
495    fn clone(&self) -> Self {
496        let mut clone = null_mut();
497        unsafe {
498            sys::tract_state_clone(self.0, &mut clone);
499        }
500        State(clone)
501    }
502}
503
504// Safety: the underlying FrozenState is Send
505unsafe impl Send for State {}
506
507impl StateInterface for State {
508    type Tensor = Tensor;
509    type Fact = Fact;
510
511    fn run(&mut self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
512        let inputs = inputs.into_inputs()?;
513        let mut outputs = vec![null_mut(); self.output_count()?];
514        let mut inputs: Vec<_> = inputs.iter().map(|v| v.0).collect();
515        check!(sys::tract_state_run(self.0, inputs.as_mut_ptr(), outputs.as_mut_ptr()))?;
516        let outputs = outputs.into_iter().map(Tensor).collect();
517        Ok(outputs)
518    }
519
520    fn input_count(&self) -> Result<usize> {
521        let mut count = 0;
522        check!(sys::tract_state_input_count(self.0, &mut count))?;
523        Ok(count)
524    }
525
526    fn output_count(&self) -> Result<usize> {
527        let mut count = 0;
528        check!(sys::tract_state_output_count(self.0, &mut count))?;
529        Ok(count)
530    }
531}
532
533// TENSOR
534wrapper!(Tensor, TractTensor, tract_tensor_destroy);
535wrapper_clone!(Tensor, tract_tensor_clone);
536unsafe impl Send for Tensor {}
537unsafe impl Sync for Tensor {}
538
539impl TensorInterface for Tensor {
540    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
541        anyhow::ensure!(data.len() == shape.iter().product::<usize>() * dt.size_of());
542        let mut value = null_mut();
543        check!(sys::tract_tensor_from_bytes(
544            dt as _,
545            shape.len(),
546            shape.as_ptr(),
547            data.as_ptr() as _,
548            &mut value
549        ))?;
550        Ok(Tensor(value))
551    }
552
553    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
554        let mut rank = 0;
555        let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
556        let mut shape = null();
557        let mut data = null();
558        check!(sys::tract_tensor_as_bytes(self.0, &mut dt, &mut rank, &mut shape, &mut data))?;
559        unsafe {
560            let dt: DatumType = std::mem::transmute(dt);
561            let shape = std::slice::from_raw_parts(shape, rank);
562            let len: usize = shape.iter().product();
563            let data = std::slice::from_raw_parts(data as *const u8, len * dt.size_of());
564            Ok((dt, shape, data))
565        }
566    }
567
568    fn datum_type(&self) -> Result<DatumType> {
569        let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
570        check!(sys::tract_tensor_as_bytes(
571            self.0,
572            &mut dt,
573            std::ptr::null_mut(),
574            std::ptr::null_mut(),
575            std::ptr::null_mut()
576        ))?;
577        unsafe {
578            let dt: DatumType = std::mem::transmute(dt);
579            Ok(dt)
580        }
581    }
582
583    fn convert_to(&self, to: DatumType) -> Result<Self> {
584        let mut new = null_mut();
585        check!(sys::tract_tensor_convert_to(self.0, to as _, &mut new))?;
586        Ok(Tensor(new))
587    }
588}
589
590impl PartialEq for Tensor {
591    fn eq(&self, other: &Self) -> bool {
592        let Ok((me_dt, me_shape, me_data)) = self.as_bytes() else { return false };
593        let Ok((other_dt, other_shape, other_data)) = other.as_bytes() else { return false };
594        me_dt == other_dt && me_shape == other_shape && me_data == other_data
595    }
596}
597
598// FACT
599wrapper!(Fact, TractFact, tract_fact_destroy);
600wrapper_clone!(Fact, tract_fact_clone);
601
602impl Fact {
603    fn new(model: &Model, spec: impl ToString) -> Result<Fact> {
604        let cstr = CString::new(spec.to_string())?;
605        let mut fact = null_mut();
606        check!(sys::tract_model_parse_fact(model.0, cstr.as_ptr(), &mut fact))?;
607        Ok(Fact(fact))
608    }
609
610    fn dump(&self) -> Result<String> {
611        let mut ptr = null_mut();
612        check!(sys::tract_fact_dump(self.0, &mut ptr))?;
613        unsafe {
614            let s = CStr::from_ptr(ptr).to_owned();
615            sys::tract_free_cstring(ptr);
616            Ok(s.to_str()?.to_owned())
617        }
618    }
619}
620
621impl FactInterface for Fact {
622    type Dim = Dim;
623
624    fn datum_type(&self) -> Result<DatumType> {
625        let mut dt = 0u32;
626        check!(sys::tract_fact_datum_type(self.0, &mut dt as *const u32 as _))?;
627        Ok(unsafe { std::mem::transmute::<u32, DatumType>(dt) })
628    }
629
630    fn rank(&self) -> Result<usize> {
631        let mut rank = 0;
632        check!(sys::tract_fact_rank(self.0, &mut rank))?;
633        Ok(rank)
634    }
635
636    fn dim(&self, axis: usize) -> Result<Self::Dim> {
637        let mut ptr = null_mut();
638        check!(sys::tract_fact_dim(self.0, axis, &mut ptr))?;
639        Ok(Dim(ptr))
640    }
641}
642
643impl std::fmt::Display for Fact {
644    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
645        match self.dump() {
646            Ok(s) => f.write_str(&s),
647            Err(_) => Err(std::fmt::Error),
648        }
649    }
650}
651
652// INFERENCE FACT
653wrapper!(InferenceFact, TractInferenceFact, tract_inference_fact_destroy);
654wrapper_clone!(InferenceFact, tract_inference_fact_clone);
655
656impl InferenceFact {
657    fn new(model: &InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
658        let cstr = CString::new(spec.to_string())?;
659        let mut fact = null_mut();
660        check!(sys::tract_inference_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
661        Ok(InferenceFact(fact))
662    }
663
664    fn dump(&self) -> Result<String> {
665        let mut ptr = null_mut();
666        check!(sys::tract_inference_fact_dump(self.0, &mut ptr))?;
667        unsafe {
668            let s = CStr::from_ptr(ptr).to_owned();
669            sys::tract_free_cstring(ptr);
670            Ok(s.to_str()?.to_owned())
671        }
672    }
673}
674
675impl InferenceFactInterface for InferenceFact {
676    fn empty() -> Result<InferenceFact> {
677        let mut fact = null_mut();
678        check!(sys::tract_inference_fact_empty(&mut fact))?;
679        Ok(InferenceFact(fact))
680    }
681}
682
683impl Default for InferenceFact {
684    fn default() -> Self {
685        Self::empty().unwrap()
686    }
687}
688
689impl std::fmt::Display for InferenceFact {
690    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
691        match self.dump() {
692            Ok(s) => f.write_str(&s),
693            Err(_) => Err(std::fmt::Error),
694        }
695    }
696}
697
698as_inference_fact_impl!(InferenceModel, InferenceFact);
699as_fact_impl!(Model, Fact);
700
701// Dim
702wrapper!(Dim, TractDim, tract_dim_destroy);
703wrapper_clone!(Dim, tract_dim_clone);
704
705impl Dim {
706    fn dump(&self) -> Result<String> {
707        let mut ptr = null_mut();
708        check!(sys::tract_dim_dump(self.0, &mut ptr))?;
709        unsafe {
710            let s = CStr::from_ptr(ptr).to_owned();
711            sys::tract_free_cstring(ptr);
712            Ok(s.to_str()?.to_owned())
713        }
714    }
715}
716
717impl DimInterface for Dim {
718    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self> {
719        let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
720        let c_strings: Vec<CString> =
721            names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
722        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
723        let mut ptr = null_mut();
724        check!(sys::tract_dim_eval(self.0, ptrs.len(), ptrs.as_ptr(), values.as_ptr(), &mut ptr))?;
725        Ok(Dim(ptr))
726    }
727
728    fn to_int64(&self) -> Result<i64> {
729        let mut i = 0;
730        check!(sys::tract_dim_to_int64(self.0, &mut i))?;
731        Ok(i)
732    }
733}
734
735impl std::fmt::Display for Dim {
736    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
737        match self.dump() {
738            Ok(s) => f.write_str(&s),
739            Err(_) => Err(std::fmt::Error),
740        }
741    }
742}
743
#[cfg(test)]
mod tests {
    use super::*;

    /// Cloning a tensor yields an independent handle: both the original and the clone are
    /// dropped at the end of the test, which must not double free.
    #[test]
    fn clone_tensor_no_double_free() {
        let values = [1.0f32, 2.0, 3.0, 4.0];
        let original = Tensor::from_slice::<f32>(&[2, 2], &values).unwrap();
        let duplicate = original.clone();
        assert_eq!(original, duplicate);
    }
}