1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::{null, null_mut};
4
5use tract_api::*;
6use tract_proxy_sys as sys;
7
8use anyhow::{Context, Result};
9use ndarray::*;
10
/// Evaluates a call into the tract C API and converts its status code into a `Result`.
///
/// The expression is evaluated inside an `unsafe` block, so callers can pass raw
/// `sys::*` calls directly. When the call reports `TRACT_RESULT_KO`, the message
/// returned by `tract_get_last_error` becomes the error text.
macro_rules! check {
    ($expr:expr) => {
        unsafe {
            if $expr == sys::TRACT_RESULT_TRACT_RESULT_KO {
                let msg = sys::tract_get_last_error();
                if msg.is_null() {
                    // Defensive: `CStr::from_ptr` must never be given a null pointer.
                    Err(anyhow::anyhow!("tract call failed without reporting an error message"))
                } else {
                    Err(anyhow::anyhow!(CStr::from_ptr(msg).to_string_lossy().into_owned()))
                }
            } else {
                Ok(())
            }
        }
    };
}
23
// Declares a public tuple struct `$new_type` whose first field is a raw
// `*mut sys::$c_type` handle (followed by any extra field types given), and
// implements `Drop` so that `sys::$dest` is called on the handle when the
// wrapper goes out of scope.
macro_rules! wrapper {
    ($new_type:ident, $c_type:ident, $dest:ident $(, $typ:ty )*) => {
        #[derive(Debug)]
        pub struct $new_type(*mut sys::$c_type $(, $typ)*);

        impl Drop for $new_type {
            fn drop(&mut self) {
                // The destructor receives the address of the handle itself.
                let handle = &mut self.0;
                unsafe {
                    sys::$dest(handle);
                }
            }
        }
    };
}
38
// Implements `Clone` for a single-field wrapper type by asking the C API
// (`sys::$clone_fn`) for a new handle, so that each wrapper owns its own
// handle. The status code returned by the clone function is not inspected.
macro_rules! wrapper_clone {
    ($new_type:ident, $clone_fn:ident) => {
        impl Clone for $new_type {
            fn clone(&self) -> Self {
                let mut duplicate = null_mut();
                unsafe {
                    sys::$clone_fn(self.0, &mut duplicate);
                }
                Self(duplicate)
            }
        }
    };
}
52
53pub fn nnef() -> Result<Nnef> {
54 let mut nnef = null_mut();
55 check!(sys::tract_nnef_create(&mut nnef))?;
56 Ok(Nnef(nnef))
57}
58
59pub fn onnx() -> Result<Onnx> {
60 let mut onnx = null_mut();
61 check!(sys::tract_onnx_create(&mut onnx))?;
62 Ok(Onnx(onnx))
63}
64
65pub fn version() -> &'static str {
66 unsafe { CStr::from_ptr(sys::tract_version()).to_str().unwrap() }
67}
68
69wrapper!(Nnef, TractNnef, tract_nnef_destroy);
70impl NnefInterface for Nnef {
71 type Model = Model;
72 fn load(&self, path: impl AsRef<Path>) -> Result<Model> {
73 let path = path.as_ref();
74 let path = CString::new(
75 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
76 )?;
77 let mut model = null_mut();
78 check!(sys::tract_nnef_load(self.0, path.as_ptr(), &mut model))?;
79 Ok(Model(model))
80 }
81
82 fn load_buffer(&self, data: &[u8]) -> Result<Model> {
83 let mut model = null_mut();
84 check!(sys::tract_nnef_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
85 Ok(Model(model))
86 }
87
88 fn enable_tract_core(&mut self) -> Result<()> {
89 check!(sys::tract_nnef_enable_tract_core(self.0))
90 }
91
92 fn enable_tract_extra(&mut self) -> Result<()> {
93 check!(sys::tract_nnef_enable_tract_extra(self.0))
94 }
95
96 fn enable_tract_transformers(&mut self) -> Result<()> {
97 check!(sys::tract_nnef_enable_tract_transformers(self.0))
98 }
99
100 fn enable_onnx(&mut self) -> Result<()> {
101 check!(sys::tract_nnef_enable_onnx(self.0))
102 }
103
104 fn enable_pulse(&mut self) -> Result<()> {
105 check!(sys::tract_nnef_enable_pulse(self.0))
106 }
107
108 fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
109 check!(sys::tract_nnef_enable_extended_identifier_syntax(self.0))
110 }
111
112 fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
113 let path = path.as_ref();
114 let path = CString::new(
115 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
116 )?;
117 check!(sys::tract_nnef_write_model_to_dir(self.0, path.as_ptr(), model.0))?;
118 Ok(())
119 }
120
121 fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
122 let path = path.as_ref();
123 let path = CString::new(
124 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
125 )?;
126 check!(sys::tract_nnef_write_model_to_tar(self.0, path.as_ptr(), model.0))?;
127 Ok(())
128 }
129
130 fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
131 let path = path.as_ref();
132 let path = CString::new(
133 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
134 )?;
135 check!(sys::tract_nnef_write_model_to_tar_gz(self.0, path.as_ptr(), model.0))?;
136 Ok(())
137 }
138}
139
140wrapper!(Onnx, TractOnnx, tract_onnx_destroy);
142
143impl OnnxInterface for Onnx {
144 type InferenceModel = InferenceModel;
145 fn load(&self, path: impl AsRef<Path>) -> Result<InferenceModel> {
146 let path = path.as_ref();
147 let path = CString::new(
148 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
149 )?;
150 let mut model = null_mut();
151 check!(sys::tract_onnx_load(self.0, path.as_ptr(), &mut model))?;
152 Ok(InferenceModel(model))
153 }
154
155 fn load_buffer(&self, data: &[u8]) -> Result<InferenceModel> {
156 let mut model = null_mut();
157 check!(sys::tract_onnx_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
158 Ok(InferenceModel(model))
159 }
160}
161
162wrapper!(InferenceModel, TractInferenceModel, tract_inference_model_destroy);
164impl InferenceModelInterface for InferenceModel {
165 type Model = Model;
166 type InferenceFact = InferenceFact;
167 fn input_count(&self) -> Result<usize> {
168 let mut count = 0;
169 check!(sys::tract_inference_model_input_count(self.0, &mut count))?;
170 Ok(count)
171 }
172
173 fn output_count(&self) -> Result<usize> {
174 let mut count = 0;
175 check!(sys::tract_inference_model_output_count(self.0, &mut count))?;
176 Ok(count)
177 }
178
179 fn input_name(&self, id: usize) -> Result<String> {
180 let mut ptr = null_mut();
181 check!(sys::tract_inference_model_input_name(self.0, id, &mut ptr))?;
182 unsafe {
183 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
184 sys::tract_free_cstring(ptr);
185 Ok(ret)
186 }
187 }
188
189 fn output_name(&self, id: usize) -> Result<String> {
190 let mut ptr = null_mut();
191 check!(sys::tract_inference_model_output_name(self.0, id, &mut ptr))?;
192 unsafe {
193 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
194 sys::tract_free_cstring(ptr);
195 Ok(ret)
196 }
197 }
198
199 fn input_fact(&self, id: usize) -> Result<InferenceFact> {
200 let mut ptr = null_mut();
201 check!(sys::tract_inference_model_input_fact(self.0, id, &mut ptr))?;
202 Ok(InferenceFact(ptr))
203 }
204
205 fn set_input_fact(
206 &mut self,
207 id: usize,
208 fact: impl AsFact<Self, Self::InferenceFact>,
209 ) -> Result<()> {
210 let fact = fact.as_fact(self)?;
211 check!(sys::tract_inference_model_set_input_fact(self.0, id, fact.0))?;
212 Ok(())
213 }
214
215 fn output_fact(&self, id: usize) -> Result<InferenceFact> {
216 let mut ptr = null_mut();
217 check!(sys::tract_inference_model_output_fact(self.0, id, &mut ptr))?;
218 Ok(InferenceFact(ptr))
219 }
220
221 fn set_output_fact(
222 &mut self,
223 id: usize,
224 fact: impl AsFact<InferenceModel, InferenceFact>,
225 ) -> Result<()> {
226 let fact = fact.as_fact(self)?;
227 check!(sys::tract_inference_model_set_output_fact(self.0, id, fact.0))?;
228 Ok(())
229 }
230
231 fn analyse(&mut self) -> Result<()> {
232 check!(sys::tract_inference_model_analyse(self.0))?;
233 Ok(())
234 }
235
236 fn into_model(mut self) -> Result<Self::Model> {
237 let mut ptr = null_mut();
238 check!(sys::tract_inference_model_into_model(&mut self.0, &mut ptr))?;
239 Ok(Model(ptr))
240 }
241}
242
243wrapper!(Model, TractModel, tract_model_destroy);
245
246impl ModelInterface for Model {
247 type Fact = Fact;
248 type Tensor = Tensor;
249 type Runnable = Runnable;
250 fn input_count(&self) -> Result<usize> {
251 let mut count = 0;
252 check!(sys::tract_model_input_count(self.0, &mut count))?;
253 Ok(count)
254 }
255
256 fn output_count(&self) -> Result<usize> {
257 let mut count = 0;
258 check!(sys::tract_model_output_count(self.0, &mut count))?;
259 Ok(count)
260 }
261
262 fn input_name(&self, id: usize) -> Result<String> {
263 let mut ptr = null_mut();
264 check!(sys::tract_model_input_name(self.0, id, &mut ptr))?;
265 unsafe {
266 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
267 sys::tract_free_cstring(ptr);
268 Ok(ret)
269 }
270 }
271
272 fn output_name(&self, id: usize) -> Result<String> {
273 let mut ptr = null_mut();
274 check!(sys::tract_model_output_name(self.0, id, &mut ptr))?;
275 unsafe {
276 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
277 sys::tract_free_cstring(ptr);
278 Ok(ret)
279 }
280 }
281
282 fn input_fact(&self, id: usize) -> Result<Fact> {
283 let mut ptr = null_mut();
284 check!(sys::tract_model_input_fact(self.0, id, &mut ptr))?;
285 Ok(Fact(ptr))
286 }
287
288 fn output_fact(&self, id: usize) -> Result<Fact> {
289 let mut ptr = null_mut();
290 check!(sys::tract_model_output_fact(self.0, id, &mut ptr))?;
291 Ok(Fact(ptr))
292 }
293
294 fn into_runnable(self) -> Result<Runnable> {
295 let mut model = self;
296 let mut runnable = null_mut();
297 check!(sys::tract_model_into_runnable(&mut model.0, &mut runnable))?;
298 Ok(Runnable(runnable))
299 }
300
301 fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()> {
302 let transform = spec.into().to_transform_string();
303 let t = CString::new(transform)?;
304 check!(sys::tract_model_transform(self.0, t.as_ptr()))?;
305 Ok(())
306 }
307
308 fn property_keys(&self) -> Result<Vec<String>> {
309 let mut len = 0;
310 check!(sys::tract_model_property_count(self.0, &mut len))?;
311 let mut keys = vec![null_mut(); len];
312 check!(sys::tract_model_property_names(self.0, keys.as_mut_ptr()))?;
313 unsafe {
314 keys.into_iter()
315 .map(|pc| {
316 let s = CStr::from_ptr(pc).to_str()?.to_owned();
317 sys::tract_free_cstring(pc);
318 Ok(s)
319 })
320 .collect()
321 }
322 }
323
324 fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
325 let mut v = null_mut();
326 let name = CString::new(name.as_ref())?;
327 check!(sys::tract_model_property(self.0, name.as_ptr(), &mut v))?;
328 Ok(Tensor(v))
329 }
330
331 fn parse_fact(&self, spec: &str) -> Result<Self::Fact> {
332 let spec = CString::new(spec)?;
333 let mut ptr = null_mut();
334 check!(sys::tract_model_parse_fact(self.0, spec.as_ptr(), &mut ptr))?;
335 Ok(Fact(ptr))
336 }
337}
338
339wrapper!(Runtime, TractRuntime, tract_runtime_release);
341
342pub fn runtime_for_name(name: &str) -> Result<Runtime> {
343 let mut rt = null_mut();
344 let name = CString::new(name)?;
345 check!(sys::tract_runtime_for_name(name.as_ptr(), &mut rt))?;
346 Ok(Runtime(rt))
347}
348
349impl RuntimeInterface for Runtime {
350 type Runnable = Runnable;
351
352 type Model = Model;
353
354 fn name(&self) -> Result<String> {
355 let mut ptr = null_mut();
356 check!(sys::tract_runtime_name(self.0, &mut ptr))?;
357 unsafe {
358 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
359 sys::tract_free_cstring(ptr);
360 Ok(ret)
361 }
362 }
363
364 fn prepare(&self, model: Self::Model) -> Result<Self::Runnable> {
365 let mut model = model;
366 let mut runnable = null_mut();
367 check!(sys::tract_runtime_prepare(self.0, &mut model.0, &mut runnable))?;
368 Ok(Runnable(runnable))
369 }
370}
371
372wrapper!(Runnable, TractRunnable, tract_runnable_release);
374unsafe impl Send for Runnable {}
375unsafe impl Sync for Runnable {}
376
377impl RunnableInterface for Runnable {
378 type Tensor = Tensor;
379 type State = State;
380 type Fact = Fact;
381
382 fn run(&self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
383 StateInterface::run(&mut self.spawn_state()?, inputs.into_inputs()?)
384 }
385
386 fn spawn_state(&self) -> Result<State> {
387 let mut state = null_mut();
388 check!(sys::tract_runnable_spawn_state(self.0, &mut state))?;
389 Ok(State(state))
390 }
391
392 fn input_count(&self) -> Result<usize> {
393 let mut count = 0;
394 check!(sys::tract_runnable_input_count(self.0, &mut count))?;
395 Ok(count)
396 }
397
398 fn output_count(&self) -> Result<usize> {
399 let mut count = 0;
400 check!(sys::tract_runnable_output_count(self.0, &mut count))?;
401 Ok(count)
402 }
403
404 fn input_fact(&self, id: usize) -> Result<Self::Fact> {
405 let mut ptr = null_mut();
406 check!(sys::tract_runnable_input_fact(self.0, id, &mut ptr))?;
407 Ok(Fact(ptr))
408 }
409
410 fn output_fact(&self, id: usize) -> Result<Self::Fact> {
411 let mut ptr = null_mut();
412 check!(sys::tract_runnable_output_fact(self.0, id, &mut ptr))?;
413 Ok(Fact(ptr))
414 }
415
416 fn property_keys(&self) -> Result<Vec<String>> {
417 let mut len = 0;
418 check!(sys::tract_runnable_property_count(self.0, &mut len))?;
419 let mut keys = vec![null_mut(); len];
420 check!(sys::tract_runnable_property_names(self.0, keys.as_mut_ptr()))?;
421 unsafe {
422 keys.into_iter()
423 .map(|pc| {
424 let s = CStr::from_ptr(pc).to_str()?.to_owned();
425 sys::tract_free_cstring(pc);
426 Ok(s)
427 })
428 .collect()
429 }
430 }
431
432 fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
433 let mut v = null_mut();
434 let name = CString::new(name.as_ref())?;
435 check!(sys::tract_runnable_property(self.0, name.as_ptr(), &mut v))?;
436 Ok(Tensor(v))
437 }
438
439 fn cost_json(&self) -> Result<String> {
440 let input: Option<Vec<Tensor>> = None;
441 self.profile_json(input)
442 }
443
444 fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
445 where
446 I: IntoIterator<Item = IV>,
447 IV: TryInto<Self::Tensor, Error = IE>,
448 IE: Into<anyhow::Error>,
449 {
450 let inputs = if let Some(inputs) = inputs {
451 let inputs = inputs
452 .into_iter()
453 .map(|i| i.try_into().map_err(|e| e.into()))
454 .collect::<Result<Vec<Tensor>>>()?;
455 anyhow::ensure!(self.input_count()? == inputs.len());
456 Some(inputs)
457 } else {
458 None
459 };
460 let mut iptrs: Option<Vec<*mut sys::TractTensor>> =
461 inputs.as_ref().map(|is| is.iter().map(|v| v.0).collect());
462 let mut json: *mut i8 = null_mut();
463 let values = iptrs.as_mut().map(|it| it.as_mut_ptr()).unwrap_or(null_mut());
464
465 check!(sys::tract_runnable_profile_json(self.0, values, &mut json))?;
466 anyhow::ensure!(!json.is_null());
467 unsafe {
468 let s = CStr::from_ptr(json).to_owned();
469 sys::tract_free_cstring(json);
470 Ok(s.to_str()?.to_owned())
471 }
472 }
473}
474
475pub struct State(*mut sys::TractState);
477
478impl Drop for State {
479 fn drop(&mut self) {
480 unsafe {
481 sys::tract_state_destroy(&mut self.0);
482 }
483 }
484}
485
486impl std::fmt::Debug for State {
487 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488 write!(f, "State({:?})", self.0)
489 }
490}
491
492impl Clone for State {
493 fn clone(&self) -> Self {
494 let mut clone = null_mut();
495 unsafe {
496 sys::tract_state_clone(self.0, &mut clone);
497 }
498 State(clone)
499 }
500}
501
502unsafe impl Send for State {}
504
505impl StateInterface for State {
506 type Tensor = Tensor;
507 type Fact = Fact;
508
509 fn run(&mut self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
510 let inputs = inputs.into_inputs()?;
511 let mut outputs = vec![null_mut(); self.output_count()?];
512 let mut inputs: Vec<_> = inputs.iter().map(|v| v.0).collect();
513 check!(sys::tract_state_run(self.0, inputs.as_mut_ptr(), outputs.as_mut_ptr()))?;
514 let outputs = outputs.into_iter().map(Tensor).collect();
515 Ok(outputs)
516 }
517
518 fn input_count(&self) -> Result<usize> {
519 let mut count = 0;
520 check!(sys::tract_state_input_count(self.0, &mut count))?;
521 Ok(count)
522 }
523
524 fn output_count(&self) -> Result<usize> {
525 let mut count = 0;
526 check!(sys::tract_state_output_count(self.0, &mut count))?;
527 Ok(count)
528 }
529}
530
531wrapper!(Tensor, TractTensor, tract_tensor_destroy);
533wrapper_clone!(Tensor, tract_tensor_clone);
534unsafe impl Send for Tensor {}
535unsafe impl Sync for Tensor {}
536
537impl TensorInterface for Tensor {
538 fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
539 anyhow::ensure!(data.len() == shape.iter().product::<usize>() * dt.size_of());
540 let mut value = null_mut();
541 check!(sys::tract_tensor_from_bytes(
542 dt as _,
543 shape.len(),
544 shape.as_ptr(),
545 data.as_ptr() as _,
546 &mut value
547 ))?;
548 Ok(Tensor(value))
549 }
550
551 fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
552 let mut rank = 0;
553 let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
554 let mut shape = null();
555 let mut data = null();
556 check!(sys::tract_tensor_as_bytes(self.0, &mut dt, &mut rank, &mut shape, &mut data))?;
557 unsafe {
558 let dt: DatumType = std::mem::transmute(dt);
559 let shape = std::slice::from_raw_parts(shape, rank);
560 let len: usize = shape.iter().product();
561 let data = std::slice::from_raw_parts(data as *const u8, len * dt.size_of());
562 Ok((dt, shape, data))
563 }
564 }
565
566 fn datum_type(&self) -> Result<DatumType> {
567 let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
568 check!(sys::tract_tensor_as_bytes(
569 self.0,
570 &mut dt,
571 std::ptr::null_mut(),
572 std::ptr::null_mut(),
573 std::ptr::null_mut()
574 ))?;
575 unsafe {
576 let dt: DatumType = std::mem::transmute(dt);
577 Ok(dt)
578 }
579 }
580
581 fn convert_to(&self, to: DatumType) -> Result<Self> {
582 let mut new = null_mut();
583 check!(sys::tract_tensor_convert_to(self.0, to as _, &mut new))?;
584 Ok(Tensor(new))
585 }
586}
587
588impl PartialEq for Tensor {
589 fn eq(&self, other: &Self) -> bool {
590 let Ok((me_dt, me_shape, me_data)) = self.as_bytes() else { return false };
591 let Ok((other_dt, other_shape, other_data)) = other.as_bytes() else { return false };
592 me_dt == other_dt && me_shape == other_shape && me_data == other_data
593 }
594}
595
596tensor_from_to_ndarray!();
597
598wrapper!(Fact, TractFact, tract_fact_destroy);
600wrapper_clone!(Fact, tract_fact_clone);
601
602impl Fact {
603 fn new(model: &Model, spec: impl ToString) -> Result<Fact> {
604 let cstr = CString::new(spec.to_string())?;
605 let mut fact = null_mut();
606 check!(sys::tract_model_parse_fact(model.0, cstr.as_ptr(), &mut fact))?;
607 Ok(Fact(fact))
608 }
609
610 fn dump(&self) -> Result<String> {
611 let mut ptr = null_mut();
612 check!(sys::tract_fact_dump(self.0, &mut ptr))?;
613 unsafe {
614 let s = CStr::from_ptr(ptr).to_owned();
615 sys::tract_free_cstring(ptr);
616 Ok(s.to_str()?.to_owned())
617 }
618 }
619}
620
621impl FactInterface for Fact {
622 type Dim = Dim;
623
624 fn datum_type(&self) -> Result<DatumType> {
625 let mut dt = 0u32;
626 check!(sys::tract_fact_datum_type(self.0, &mut dt as *const u32 as _))?;
627 Ok(unsafe { std::mem::transmute::<u32, DatumType>(dt) })
628 }
629
630 fn rank(&self) -> Result<usize> {
631 let mut rank = 0;
632 check!(sys::tract_fact_rank(self.0, &mut rank))?;
633 Ok(rank)
634 }
635
636 fn dim(&self, axis: usize) -> Result<Self::Dim> {
637 let mut ptr = null_mut();
638 check!(sys::tract_fact_dim(self.0, axis, &mut ptr))?;
639 Ok(Dim(ptr))
640 }
641}
642
643impl std::fmt::Display for Fact {
644 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
645 match self.dump() {
646 Ok(s) => f.write_str(&s),
647 Err(_) => Err(std::fmt::Error),
648 }
649 }
650}
651
652wrapper!(InferenceFact, TractInferenceFact, tract_inference_fact_destroy);
654wrapper_clone!(InferenceFact, tract_inference_fact_clone);
655
656impl InferenceFact {
657 fn new(model: &InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
658 let cstr = CString::new(spec.to_string())?;
659 let mut fact = null_mut();
660 check!(sys::tract_inference_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
661 Ok(InferenceFact(fact))
662 }
663
664 fn dump(&self) -> Result<String> {
665 let mut ptr = null_mut();
666 check!(sys::tract_inference_fact_dump(self.0, &mut ptr))?;
667 unsafe {
668 let s = CStr::from_ptr(ptr).to_owned();
669 sys::tract_free_cstring(ptr);
670 Ok(s.to_str()?.to_owned())
671 }
672 }
673}
674
675impl InferenceFactInterface for InferenceFact {
676 fn empty() -> Result<InferenceFact> {
677 let mut fact = null_mut();
678 check!(sys::tract_inference_fact_empty(&mut fact))?;
679 Ok(InferenceFact(fact))
680 }
681}
682
683impl Default for InferenceFact {
684 fn default() -> Self {
685 Self::empty().unwrap()
686 }
687}
688
689impl std::fmt::Display for InferenceFact {
690 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
691 match self.dump() {
692 Ok(s) => f.write_str(&s),
693 Err(_) => Err(std::fmt::Error),
694 }
695 }
696}
697
698as_inference_fact_impl!(InferenceModel, InferenceFact);
699as_fact_impl!(Model, Fact);
700
701wrapper!(Dim, TractDim, tract_dim_destroy);
703wrapper_clone!(Dim, tract_dim_clone);
704
705impl Dim {
706 fn dump(&self) -> Result<String> {
707 let mut ptr = null_mut();
708 check!(sys::tract_dim_dump(self.0, &mut ptr))?;
709 unsafe {
710 let s = CStr::from_ptr(ptr).to_owned();
711 sys::tract_free_cstring(ptr);
712 Ok(s.to_str()?.to_owned())
713 }
714 }
715}
716
717impl DimInterface for Dim {
718 fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self> {
719 let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
720 let c_strings: Vec<CString> =
721 names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
722 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
723 let mut ptr = null_mut();
724 check!(sys::tract_dim_eval(self.0, ptrs.len(), ptrs.as_ptr(), values.as_ptr(), &mut ptr))?;
725 Ok(Dim(ptr))
726 }
727
728 fn to_int64(&self) -> Result<i64> {
729 let mut i = 0;
730 check!(sys::tract_dim_to_int64(self.0, &mut i))?;
731 Ok(i)
732 }
733}
734
735impl std::fmt::Display for Dim {
736 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
737 match self.dump() {
738 Ok(s) => f.write_str(&s),
739 Err(_) => Err(std::fmt::Error),
740 }
741 }
742}
743
#[cfg(test)]
mod tests {
    use super::*;

    /// Clones a tensor, compares both copies, then lets both drop at the end
    /// of the test.
    #[test]
    fn clone_tensor_no_double_free() {
        let original = Tensor::from_slice::<f32>(&[2, 2], &[1.0, 2.0, 3.0, 4.0]).unwrap();
        let copy = original.clone();
        assert_eq!(original, copy);
    }
}