Skip to main content

tract_proxy/
lib.rs

1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::{null, null_mut};
4
5use tract_api::*;
6use tract_proxy_sys as sys;
7
8use anyhow::{Context, Result};
9use ndarray::*;
10
/// Evaluates an FFI call returning a `TRACT_RESULT` and converts it into an
/// `anyhow::Result<()>`.
///
/// The expression is evaluated inside an `unsafe` block. When it yields
/// `TRACT_RESULT_KO`, the message returned by `tract_get_last_error()` is
/// copied (lossily, if it is not valid UTF-8) into the returned error. A null
/// message pointer is replaced by a generic message rather than being handed to
/// `CStr::from_ptr`, which requires a non-null pointer.
macro_rules! check {
    ($expr:expr) => {
        unsafe {
            if $expr == sys::TRACT_RESULT_TRACT_RESULT_KO {
                let err = sys::tract_get_last_error();
                let msg = if err.is_null() {
                    "tract call failed without an error message".to_string()
                } else {
                    CStr::from_ptr(err).to_string_lossy().to_string()
                };
                Err(anyhow::anyhow!(msg))
            } else {
                Ok(())
            }
        }
    };
}
23
/// Declares a public newtype owning a raw pointer to the C object `sys::$ffi_type`.
///
/// The generated type derives `Debug`. On drop it passes a mutable reference
/// to its pointer to the C destructor `sys::$destructor`. Any extra `$extra`
/// types become additional tuple fields after the pointer.
macro_rules! wrapper {
    ($name:ident, $ffi_type:ident, $destructor:ident $(, $extra:ty )*) => {
        #[derive(Debug)]
        pub struct $name(*mut sys::$ffi_type $(, $extra)*);

        impl Drop for $name {
            fn drop(&mut self) {
                // The destructor's status is ignored: `drop` cannot report it.
                let handle = &mut self.0;
                unsafe {
                    sys::$destructor(handle);
                }
            }
        }
    };
}
38
/// Implements `Clone` for a single-pointer wrapper by calling the C clone
/// function `sys::$clone_fn`.
///
/// `Clone::clone` cannot return an error. So if the C call reports failure,
/// this panics with tract's error message instead of silently wrapping
/// whatever (possibly null) pointer was left in the out-parameter.
macro_rules! wrapper_clone {
    ($new_type:ident, $clone_fn:ident) => {
        impl Clone for $new_type {
            fn clone(&self) -> Self {
                let mut clone = null_mut();
                check!(sys::$clone_fn(self.0, &mut clone))
                    .expect(concat!(stringify!($clone_fn), " failed"));
                $new_type(clone)
            }
        }
    };
}
52
53pub fn nnef() -> Result<Nnef> {
54    let mut nnef = null_mut();
55    check!(sys::tract_nnef_create(&mut nnef))?;
56    Ok(Nnef(nnef))
57}
58
59pub fn onnx() -> Result<Onnx> {
60    let mut onnx = null_mut();
61    check!(sys::tract_onnx_create(&mut onnx))?;
62    Ok(Onnx(onnx))
63}
64
65pub fn version() -> &'static str {
66    unsafe { CStr::from_ptr(sys::tract_version()).to_str().unwrap() }
67}
68
69wrapper!(Nnef, TractNnef, tract_nnef_destroy);
70impl NnefInterface for Nnef {
71    type Model = Model;
72    fn load(&self, path: impl AsRef<Path>) -> Result<Model> {
73        let path = path.as_ref();
74        let path = CString::new(
75            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
76        )?;
77        let mut model = null_mut();
78        check!(sys::tract_nnef_load(self.0, path.as_ptr(), &mut model))?;
79        Ok(Model(model))
80    }
81
82    fn load_buffer(&self, data: &[u8]) -> Result<Model> {
83        let mut model = null_mut();
84        check!(sys::tract_nnef_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
85        Ok(Model(model))
86    }
87
88    fn enable_tract_core(&mut self) -> Result<()> {
89        check!(sys::tract_nnef_enable_tract_core(self.0))
90    }
91
92    fn enable_tract_extra(&mut self) -> Result<()> {
93        check!(sys::tract_nnef_enable_tract_extra(self.0))
94    }
95
96    fn enable_tract_transformers(&mut self) -> Result<()> {
97        check!(sys::tract_nnef_enable_tract_transformers(self.0))
98    }
99
100    fn enable_onnx(&mut self) -> Result<()> {
101        check!(sys::tract_nnef_enable_onnx(self.0))
102    }
103
104    fn enable_pulse(&mut self) -> Result<()> {
105        check!(sys::tract_nnef_enable_pulse(self.0))
106    }
107
108    fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
109        check!(sys::tract_nnef_enable_extended_identifier_syntax(self.0))
110    }
111
112    fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
113        let path = path.as_ref();
114        let path = CString::new(
115            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
116        )?;
117        check!(sys::tract_nnef_write_model_to_dir(self.0, path.as_ptr(), model.0))?;
118        Ok(())
119    }
120
121    fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
122        let path = path.as_ref();
123        let path = CString::new(
124            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
125        )?;
126        check!(sys::tract_nnef_write_model_to_tar(self.0, path.as_ptr(), model.0))?;
127        Ok(())
128    }
129
130    fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
131        let path = path.as_ref();
132        let path = CString::new(
133            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
134        )?;
135        check!(sys::tract_nnef_write_model_to_tar_gz(self.0, path.as_ptr(), model.0))?;
136        Ok(())
137    }
138}
139
140// ONNX
141wrapper!(Onnx, TractOnnx, tract_onnx_destroy);
142
143impl OnnxInterface for Onnx {
144    type InferenceModel = InferenceModel;
145    fn load(&self, path: impl AsRef<Path>) -> Result<InferenceModel> {
146        let path = path.as_ref();
147        let path = CString::new(
148            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
149        )?;
150        let mut model = null_mut();
151        check!(sys::tract_onnx_load(self.0, path.as_ptr(), &mut model))?;
152        Ok(InferenceModel(model))
153    }
154
155    fn load_buffer(&self, data: &[u8]) -> Result<InferenceModel> {
156        let mut model = null_mut();
157        check!(sys::tract_onnx_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
158        Ok(InferenceModel(model))
159    }
160}
161
162// INFERENCE MODEL
163wrapper!(InferenceModel, TractInferenceModel, tract_inference_model_destroy);
164impl InferenceModelInterface for InferenceModel {
165    type Model = Model;
166    type InferenceFact = InferenceFact;
167    fn input_count(&self) -> Result<usize> {
168        let mut count = 0;
169        check!(sys::tract_inference_model_input_count(self.0, &mut count))?;
170        Ok(count)
171    }
172
173    fn output_count(&self) -> Result<usize> {
174        let mut count = 0;
175        check!(sys::tract_inference_model_output_count(self.0, &mut count))?;
176        Ok(count)
177    }
178
179    fn input_name(&self, id: usize) -> Result<String> {
180        let mut ptr = null_mut();
181        check!(sys::tract_inference_model_input_name(self.0, id, &mut ptr))?;
182        unsafe {
183            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
184            sys::tract_free_cstring(ptr);
185            Ok(ret)
186        }
187    }
188
189    fn output_name(&self, id: usize) -> Result<String> {
190        let mut ptr = null_mut();
191        check!(sys::tract_inference_model_output_name(self.0, id, &mut ptr))?;
192        unsafe {
193            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
194            sys::tract_free_cstring(ptr);
195            Ok(ret)
196        }
197    }
198
199    fn input_fact(&self, id: usize) -> Result<InferenceFact> {
200        let mut ptr = null_mut();
201        check!(sys::tract_inference_model_input_fact(self.0, id, &mut ptr))?;
202        Ok(InferenceFact(ptr))
203    }
204
205    fn set_input_fact(
206        &mut self,
207        id: usize,
208        fact: impl AsFact<Self, Self::InferenceFact>,
209    ) -> Result<()> {
210        let fact = fact.as_fact(self)?;
211        check!(sys::tract_inference_model_set_input_fact(self.0, id, fact.0))?;
212        Ok(())
213    }
214
215    fn output_fact(&self, id: usize) -> Result<InferenceFact> {
216        let mut ptr = null_mut();
217        check!(sys::tract_inference_model_output_fact(self.0, id, &mut ptr))?;
218        Ok(InferenceFact(ptr))
219    }
220
221    fn set_output_fact(
222        &mut self,
223        id: usize,
224        fact: impl AsFact<InferenceModel, InferenceFact>,
225    ) -> Result<()> {
226        let fact = fact.as_fact(self)?;
227        check!(sys::tract_inference_model_set_output_fact(self.0, id, fact.0))?;
228        Ok(())
229    }
230
231    fn analyse(&mut self) -> Result<()> {
232        check!(sys::tract_inference_model_analyse(self.0))?;
233        Ok(())
234    }
235
236    fn into_model(mut self) -> Result<Self::Model> {
237        let mut ptr = null_mut();
238        check!(sys::tract_inference_model_into_model(&mut self.0, &mut ptr))?;
239        Ok(Model(ptr))
240    }
241}
242
243// MODEL
244wrapper!(Model, TractModel, tract_model_destroy);
245
246impl ModelInterface for Model {
247    type Fact = Fact;
248    type Tensor = Tensor;
249    type Runnable = Runnable;
250    fn input_count(&self) -> Result<usize> {
251        let mut count = 0;
252        check!(sys::tract_model_input_count(self.0, &mut count))?;
253        Ok(count)
254    }
255
256    fn output_count(&self) -> Result<usize> {
257        let mut count = 0;
258        check!(sys::tract_model_output_count(self.0, &mut count))?;
259        Ok(count)
260    }
261
262    fn input_name(&self, id: usize) -> Result<String> {
263        let mut ptr = null_mut();
264        check!(sys::tract_model_input_name(self.0, id, &mut ptr))?;
265        unsafe {
266            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
267            sys::tract_free_cstring(ptr);
268            Ok(ret)
269        }
270    }
271
272    fn output_name(&self, id: usize) -> Result<String> {
273        let mut ptr = null_mut();
274        check!(sys::tract_model_output_name(self.0, id, &mut ptr))?;
275        unsafe {
276            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
277            sys::tract_free_cstring(ptr);
278            Ok(ret)
279        }
280    }
281
282    fn input_fact(&self, id: usize) -> Result<Fact> {
283        let mut ptr = null_mut();
284        check!(sys::tract_model_input_fact(self.0, id, &mut ptr))?;
285        Ok(Fact(ptr))
286    }
287
288    fn output_fact(&self, id: usize) -> Result<Fact> {
289        let mut ptr = null_mut();
290        check!(sys::tract_model_output_fact(self.0, id, &mut ptr))?;
291        Ok(Fact(ptr))
292    }
293
294    fn into_runnable(self) -> Result<Runnable> {
295        let mut model = self;
296        let mut runnable = null_mut();
297        check!(sys::tract_model_into_runnable(&mut model.0, &mut runnable))?;
298        Ok(Runnable(runnable))
299    }
300
301    fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()> {
302        let transform = spec.into().to_transform_string();
303        let t = CString::new(transform)?;
304        check!(sys::tract_model_transform(self.0, t.as_ptr()))?;
305        Ok(())
306    }
307
308    fn property_keys(&self) -> Result<Vec<String>> {
309        let mut len = 0;
310        check!(sys::tract_model_property_count(self.0, &mut len))?;
311        let mut keys = vec![null_mut(); len];
312        check!(sys::tract_model_property_names(self.0, keys.as_mut_ptr()))?;
313        unsafe {
314            keys.into_iter()
315                .map(|pc| {
316                    let s = CStr::from_ptr(pc).to_str()?.to_owned();
317                    sys::tract_free_cstring(pc);
318                    Ok(s)
319                })
320                .collect()
321        }
322    }
323
324    fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
325        let mut v = null_mut();
326        let name = CString::new(name.as_ref())?;
327        check!(sys::tract_model_property(self.0, name.as_ptr(), &mut v))?;
328        Ok(Tensor(v))
329    }
330
331    fn parse_fact(&self, spec: &str) -> Result<Self::Fact> {
332        let spec = CString::new(spec)?;
333        let mut ptr = null_mut();
334        check!(sys::tract_model_parse_fact(self.0, spec.as_ptr(), &mut ptr))?;
335        Ok(Fact(ptr))
336    }
337}
338
339// RUNTIME
340wrapper!(Runtime, TractRuntime, tract_runtime_release);
341
342pub fn runtime_for_name(name: &str) -> Result<Runtime> {
343    let mut rt = null_mut();
344    let name = CString::new(name)?;
345    check!(sys::tract_runtime_for_name(name.as_ptr(), &mut rt))?;
346    Ok(Runtime(rt))
347}
348
349impl RuntimeInterface for Runtime {
350    type Runnable = Runnable;
351
352    type Model = Model;
353
354    fn name(&self) -> Result<String> {
355        let mut ptr = null_mut();
356        check!(sys::tract_runtime_name(self.0, &mut ptr))?;
357        unsafe {
358            let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
359            sys::tract_free_cstring(ptr);
360            Ok(ret)
361        }
362    }
363
364    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable> {
365        let mut model = model;
366        let mut runnable = null_mut();
367        check!(sys::tract_runtime_prepare(self.0, &mut model.0, &mut runnable))?;
368        Ok(Runnable(runnable))
369    }
370}
371
372// RUNNABLE
373wrapper!(Runnable, TractRunnable, tract_runnable_release);
374unsafe impl Send for Runnable {}
375unsafe impl Sync for Runnable {}
376
377impl RunnableInterface for Runnable {
378    type Tensor = Tensor;
379    type State = State;
380    type Fact = Fact;
381
382    fn run(&self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
383        StateInterface::run(&mut self.spawn_state()?, inputs.into_inputs()?)
384    }
385
386    fn spawn_state(&self) -> Result<State> {
387        let mut state = null_mut();
388        check!(sys::tract_runnable_spawn_state(self.0, &mut state))?;
389        Ok(State(state))
390    }
391
392    fn input_count(&self) -> Result<usize> {
393        let mut count = 0;
394        check!(sys::tract_runnable_input_count(self.0, &mut count))?;
395        Ok(count)
396    }
397
398    fn output_count(&self) -> Result<usize> {
399        let mut count = 0;
400        check!(sys::tract_runnable_output_count(self.0, &mut count))?;
401        Ok(count)
402    }
403
404    fn input_fact(&self, id: usize) -> Result<Self::Fact> {
405        let mut ptr = null_mut();
406        check!(sys::tract_runnable_input_fact(self.0, id, &mut ptr))?;
407        Ok(Fact(ptr))
408    }
409
410    fn output_fact(&self, id: usize) -> Result<Self::Fact> {
411        let mut ptr = null_mut();
412        check!(sys::tract_runnable_output_fact(self.0, id, &mut ptr))?;
413        Ok(Fact(ptr))
414    }
415
416    fn property_keys(&self) -> Result<Vec<String>> {
417        let mut len = 0;
418        check!(sys::tract_runnable_property_count(self.0, &mut len))?;
419        let mut keys = vec![null_mut(); len];
420        check!(sys::tract_runnable_property_names(self.0, keys.as_mut_ptr()))?;
421        unsafe {
422            keys.into_iter()
423                .map(|pc| {
424                    let s = CStr::from_ptr(pc).to_str()?.to_owned();
425                    sys::tract_free_cstring(pc);
426                    Ok(s)
427                })
428                .collect()
429        }
430    }
431
432    fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
433        let mut v = null_mut();
434        let name = CString::new(name.as_ref())?;
435        check!(sys::tract_runnable_property(self.0, name.as_ptr(), &mut v))?;
436        Ok(Tensor(v))
437    }
438
439    fn cost_json(&self) -> Result<String> {
440        let input: Option<Vec<Tensor>> = None;
441        self.profile_json(input)
442    }
443
444    fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
445    where
446        I: IntoIterator<Item = IV>,
447        IV: TryInto<Self::Tensor, Error = IE>,
448        IE: Into<anyhow::Error>,
449    {
450        let inputs = if let Some(inputs) = inputs {
451            let inputs = inputs
452                .into_iter()
453                .map(|i| i.try_into().map_err(|e| e.into()))
454                .collect::<Result<Vec<Tensor>>>()?;
455            anyhow::ensure!(self.input_count()? == inputs.len());
456            Some(inputs)
457        } else {
458            None
459        };
460        let mut iptrs: Option<Vec<*mut sys::TractTensor>> =
461            inputs.as_ref().map(|is| is.iter().map(|v| v.0).collect());
462        let mut json: *mut i8 = null_mut();
463        let values = iptrs.as_mut().map(|it| it.as_mut_ptr()).unwrap_or(null_mut());
464
465        check!(sys::tract_runnable_profile_json(self.0, values, &mut json))?;
466        anyhow::ensure!(!json.is_null());
467        unsafe {
468            let s = CStr::from_ptr(json).to_owned();
469            sys::tract_free_cstring(json);
470            Ok(s.to_str()?.to_owned())
471        }
472    }
473}
474
475// STATE
476pub struct State(*mut sys::TractState);
477
478impl Drop for State {
479    fn drop(&mut self) {
480        unsafe {
481            sys::tract_state_destroy(&mut self.0);
482        }
483    }
484}
485
486impl std::fmt::Debug for State {
487    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488        write!(f, "State({:?})", self.0)
489    }
490}
491
492impl Clone for State {
493    fn clone(&self) -> Self {
494        let mut clone = null_mut();
495        unsafe {
496            sys::tract_state_clone(self.0, &mut clone);
497        }
498        State(clone)
499    }
500}
501
502// Safety: the underlying FrozenState is Send
503unsafe impl Send for State {}
504
505impl StateInterface for State {
506    type Tensor = Tensor;
507    type Fact = Fact;
508
509    fn run(&mut self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
510        let inputs = inputs.into_inputs()?;
511        let mut outputs = vec![null_mut(); self.output_count()?];
512        let mut inputs: Vec<_> = inputs.iter().map(|v| v.0).collect();
513        check!(sys::tract_state_run(self.0, inputs.as_mut_ptr(), outputs.as_mut_ptr()))?;
514        let outputs = outputs.into_iter().map(Tensor).collect();
515        Ok(outputs)
516    }
517
518    fn input_count(&self) -> Result<usize> {
519        let mut count = 0;
520        check!(sys::tract_state_input_count(self.0, &mut count))?;
521        Ok(count)
522    }
523
524    fn output_count(&self) -> Result<usize> {
525        let mut count = 0;
526        check!(sys::tract_state_output_count(self.0, &mut count))?;
527        Ok(count)
528    }
529}
530
531// TENSOR
532wrapper!(Tensor, TractTensor, tract_tensor_destroy);
533wrapper_clone!(Tensor, tract_tensor_clone);
534unsafe impl Send for Tensor {}
535unsafe impl Sync for Tensor {}
536
537impl TensorInterface for Tensor {
538    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
539        anyhow::ensure!(data.len() == shape.iter().product::<usize>() * dt.size_of());
540        let mut value = null_mut();
541        check!(sys::tract_tensor_from_bytes(
542            dt as _,
543            shape.len(),
544            shape.as_ptr(),
545            data.as_ptr() as _,
546            &mut value
547        ))?;
548        Ok(Tensor(value))
549    }
550
551    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
552        let mut rank = 0;
553        let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
554        let mut shape = null();
555        let mut data = null();
556        check!(sys::tract_tensor_as_bytes(self.0, &mut dt, &mut rank, &mut shape, &mut data))?;
557        unsafe {
558            let dt: DatumType = std::mem::transmute(dt);
559            let shape = std::slice::from_raw_parts(shape, rank);
560            let len: usize = shape.iter().product();
561            let data = std::slice::from_raw_parts(data as *const u8, len * dt.size_of());
562            Ok((dt, shape, data))
563        }
564    }
565
566    fn datum_type(&self) -> Result<DatumType> {
567        let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
568        check!(sys::tract_tensor_as_bytes(
569            self.0,
570            &mut dt,
571            std::ptr::null_mut(),
572            std::ptr::null_mut(),
573            std::ptr::null_mut()
574        ))?;
575        unsafe {
576            let dt: DatumType = std::mem::transmute(dt);
577            Ok(dt)
578        }
579    }
580
581    fn convert_to(&self, to: DatumType) -> Result<Self> {
582        let mut new = null_mut();
583        check!(sys::tract_tensor_convert_to(self.0, to as _, &mut new))?;
584        Ok(Tensor(new))
585    }
586}
587
588impl PartialEq for Tensor {
589    fn eq(&self, other: &Self) -> bool {
590        let Ok((me_dt, me_shape, me_data)) = self.as_bytes() else { return false };
591        let Ok((other_dt, other_shape, other_data)) = other.as_bytes() else { return false };
592        me_dt == other_dt && me_shape == other_shape && me_data == other_data
593    }
594}
595
596tensor_from_to_ndarray!();
597
598// FACT
599wrapper!(Fact, TractFact, tract_fact_destroy);
600wrapper_clone!(Fact, tract_fact_clone);
601
602impl Fact {
603    fn new(model: &Model, spec: impl ToString) -> Result<Fact> {
604        let cstr = CString::new(spec.to_string())?;
605        let mut fact = null_mut();
606        check!(sys::tract_model_parse_fact(model.0, cstr.as_ptr(), &mut fact))?;
607        Ok(Fact(fact))
608    }
609
610    fn dump(&self) -> Result<String> {
611        let mut ptr = null_mut();
612        check!(sys::tract_fact_dump(self.0, &mut ptr))?;
613        unsafe {
614            let s = CStr::from_ptr(ptr).to_owned();
615            sys::tract_free_cstring(ptr);
616            Ok(s.to_str()?.to_owned())
617        }
618    }
619}
620
621impl FactInterface for Fact {
622    type Dim = Dim;
623
624    fn datum_type(&self) -> Result<DatumType> {
625        let mut dt = 0u32;
626        check!(sys::tract_fact_datum_type(self.0, &mut dt as *const u32 as _))?;
627        Ok(unsafe { std::mem::transmute::<u32, DatumType>(dt) })
628    }
629
630    fn rank(&self) -> Result<usize> {
631        let mut rank = 0;
632        check!(sys::tract_fact_rank(self.0, &mut rank))?;
633        Ok(rank)
634    }
635
636    fn dim(&self, axis: usize) -> Result<Self::Dim> {
637        let mut ptr = null_mut();
638        check!(sys::tract_fact_dim(self.0, axis, &mut ptr))?;
639        Ok(Dim(ptr))
640    }
641}
642
643impl std::fmt::Display for Fact {
644    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
645        match self.dump() {
646            Ok(s) => f.write_str(&s),
647            Err(_) => Err(std::fmt::Error),
648        }
649    }
650}
651
652// INFERENCE FACT
653wrapper!(InferenceFact, TractInferenceFact, tract_inference_fact_destroy);
654wrapper_clone!(InferenceFact, tract_inference_fact_clone);
655
656impl InferenceFact {
657    fn new(model: &InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
658        let cstr = CString::new(spec.to_string())?;
659        let mut fact = null_mut();
660        check!(sys::tract_inference_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
661        Ok(InferenceFact(fact))
662    }
663
664    fn dump(&self) -> Result<String> {
665        let mut ptr = null_mut();
666        check!(sys::tract_inference_fact_dump(self.0, &mut ptr))?;
667        unsafe {
668            let s = CStr::from_ptr(ptr).to_owned();
669            sys::tract_free_cstring(ptr);
670            Ok(s.to_str()?.to_owned())
671        }
672    }
673}
674
675impl InferenceFactInterface for InferenceFact {
676    fn empty() -> Result<InferenceFact> {
677        let mut fact = null_mut();
678        check!(sys::tract_inference_fact_empty(&mut fact))?;
679        Ok(InferenceFact(fact))
680    }
681}
682
683impl Default for InferenceFact {
684    fn default() -> Self {
685        Self::empty().unwrap()
686    }
687}
688
689impl std::fmt::Display for InferenceFact {
690    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
691        match self.dump() {
692            Ok(s) => f.write_str(&s),
693            Err(_) => Err(std::fmt::Error),
694        }
695    }
696}
697
698as_inference_fact_impl!(InferenceModel, InferenceFact);
699as_fact_impl!(Model, Fact);
700
701// Dim
702wrapper!(Dim, TractDim, tract_dim_destroy);
703wrapper_clone!(Dim, tract_dim_clone);
704
705impl Dim {
706    fn dump(&self) -> Result<String> {
707        let mut ptr = null_mut();
708        check!(sys::tract_dim_dump(self.0, &mut ptr))?;
709        unsafe {
710            let s = CStr::from_ptr(ptr).to_owned();
711            sys::tract_free_cstring(ptr);
712            Ok(s.to_str()?.to_owned())
713        }
714    }
715}
716
717impl DimInterface for Dim {
718    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self> {
719        let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
720        let c_strings: Vec<CString> =
721            names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
722        let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
723        let mut ptr = null_mut();
724        check!(sys::tract_dim_eval(self.0, ptrs.len(), ptrs.as_ptr(), values.as_ptr(), &mut ptr))?;
725        Ok(Dim(ptr))
726    }
727
728    fn to_int64(&self) -> Result<i64> {
729        let mut i = 0;
730        check!(sys::tract_dim_to_int64(self.0, &mut i))?;
731        Ok(i)
732    }
733}
734
735impl std::fmt::Display for Dim {
736    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
737        match self.dump() {
738            Ok(s) => f.write_str(&s),
739            Err(_) => Err(std::fmt::Error),
740        }
741    }
742}
743
#[cfg(test)]
mod tests {
    use super::*;

    /// Both the tensor and its clone are dropped at the end of this test (the
    /// name indicates it guards against a double free). The clone must also
    /// compare equal to the original.
    #[test]
    fn clone_tensor_no_double_free() {
        let values = [1.0f32, 2.0, 3.0, 4.0];
        let original = Tensor::from_slice::<f32>(&[2, 2], &values).unwrap();
        let copy = original.clone();
        assert_eq!(original, copy);
    }
}