1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::{null, null_mut};
4
5use tract_api::*;
6use tract_proxy_sys as sys;
7
8use anyhow::{Context, Result};
9use ndarray::*;
10
/// Evaluates a `tract_proxy_sys` call and turns its status code into an `anyhow::Result<()>`.
///
/// The expression is evaluated exactly once, inside an `unsafe` block. If it returns
/// `TRACT_RESULT_TRACT_RESULT_KO`, the message from `tract_get_last_error` is copied into the
/// returned error. Any other status yields `Ok(())`.
macro_rules! check {
    ($expr:expr) => {
        unsafe {
            if $expr == sys::TRACT_RESULT_TRACT_RESULT_KO {
                // Guard against a failure with no recorded message: `CStr::from_ptr` on a null
                // pointer is undefined behaviour.
                let msg = sys::tract_get_last_error();
                if msg.is_null() {
                    Err(anyhow::anyhow!("tract call failed without an error message"))
                } else {
                    let buf = CStr::from_ptr(msg);
                    Err(anyhow::anyhow!(buf.to_string_lossy().to_string()))
                }
            } else {
                Ok(())
            }
        }
    };
}
23
/// Declares an owning newtype around a raw `tract_proxy_sys` handle.
///
/// `$new_type` wraps `*mut sys::$c_type` (plus any extra `$typ` fields). Its `Drop` impl passes a
/// pointer to the handle to `sys::$dest`. The status code returned by `$dest` is ignored.
// NOTE(review): the derived `Clone` only copies the raw pointer. Cloning a wrapper and dropping
// both copies therefore calls `$dest` twice on the same handle, which looks like a double-free.
// Confirm whether `Clone` is actually required by callers before relying on it.
macro_rules! wrapper {
    ($new_type:ident, $c_type:ident, $dest:ident $(, $typ:ty )*) => {
        #[derive(Debug, Clone)]
        pub struct $new_type(*mut sys::$c_type $(, $typ)*);

        impl Drop for $new_type {
            fn drop(&mut self) {
                unsafe {
                    sys::$dest(&mut self.0);
                }
            }
        }
    };
}
38
39pub fn nnef() -> Result<Nnef> {
40 let mut nnef = null_mut();
41 check!(sys::tract_nnef_create(&mut nnef))?;
42 Ok(Nnef(nnef))
43}
44
45pub fn onnx() -> Result<Onnx> {
46 let mut onnx = null_mut();
47 check!(sys::tract_onnx_create(&mut onnx))?;
48 Ok(Onnx(onnx))
49}
50
51pub fn version() -> &'static str {
52 unsafe { CStr::from_ptr(sys::tract_version()).to_str().unwrap() }
53}
54
55wrapper!(Nnef, TractNnef, tract_nnef_destroy);
56impl NnefInterface for Nnef {
57 type Model = Model;
58 fn load(&self, path: impl AsRef<Path>) -> Result<Model> {
59 let path = path.as_ref();
60 let path = CString::new(
61 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
62 )?;
63 let mut model = null_mut();
64 check!(sys::tract_nnef_load(self.0, path.as_ptr(), &mut model))?;
65 Ok(Model(model))
66 }
67
68 fn load_buffer(&self, data: &[u8]) -> Result<Model> {
69 let mut model = null_mut();
70 check!(sys::tract_nnef_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
71 Ok(Model(model))
72 }
73
74 fn enable_tract_core(&mut self) -> Result<()> {
75 check!(sys::tract_nnef_enable_tract_core(self.0))
76 }
77
78 fn enable_tract_extra(&mut self) -> Result<()> {
79 check!(sys::tract_nnef_enable_tract_extra(self.0))
80 }
81
82 fn enable_tract_transformers(&mut self) -> Result<()> {
83 check!(sys::tract_nnef_enable_tract_transformers(self.0))
84 }
85
86 fn enable_onnx(&mut self) -> Result<()> {
87 check!(sys::tract_nnef_enable_onnx(self.0))
88 }
89
90 fn enable_pulse(&mut self) -> Result<()> {
91 check!(sys::tract_nnef_enable_pulse(self.0))
92 }
93
94 fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
95 check!(sys::tract_nnef_enable_extended_identifier_syntax(self.0))
96 }
97
98 fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
99 let path = path.as_ref();
100 let path = CString::new(
101 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
102 )?;
103 check!(sys::tract_nnef_write_model_to_dir(self.0, path.as_ptr(), model.0))?;
104 Ok(())
105 }
106
107 fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
108 let path = path.as_ref();
109 let path = CString::new(
110 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
111 )?;
112 check!(sys::tract_nnef_write_model_to_tar(self.0, path.as_ptr(), model.0))?;
113 Ok(())
114 }
115
116 fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
117 let path = path.as_ref();
118 let path = CString::new(
119 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
120 )?;
121 check!(sys::tract_nnef_write_model_to_tar_gz(self.0, path.as_ptr(), model.0))?;
122 Ok(())
123 }
124}
125
126wrapper!(Onnx, TractOnnx, tract_onnx_destroy);
128
129impl OnnxInterface for Onnx {
130 type InferenceModel = InferenceModel;
131 fn load(&self, path: impl AsRef<Path>) -> Result<InferenceModel> {
132 let path = path.as_ref();
133 let path = CString::new(
134 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
135 )?;
136 let mut model = null_mut();
137 check!(sys::tract_onnx_load(self.0, path.as_ptr(), &mut model))?;
138 Ok(InferenceModel(model))
139 }
140
141 fn load_buffer(&self, data: &[u8]) -> Result<InferenceModel> {
142 let mut model = null_mut();
143 check!(sys::tract_onnx_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
144 Ok(InferenceModel(model))
145 }
146}
147
148wrapper!(InferenceModel, TractInferenceModel, tract_inference_model_destroy);
150impl InferenceModelInterface for InferenceModel {
151 type Model = Model;
152 type InferenceFact = InferenceFact;
153 fn set_output_names(
154 &mut self,
155 outputs: impl IntoIterator<Item = impl AsRef<str>>,
156 ) -> Result<()> {
157 let c_strings: Vec<CString> =
158 outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
159 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
160 check!(sys::tract_inference_model_set_output_names(
161 self.0,
162 c_strings.len(),
163 ptrs.as_ptr()
164 ))?;
165 Ok(())
166 }
167
168 fn input_count(&self) -> Result<usize> {
169 let mut count = 0;
170 check!(sys::tract_inference_model_input_count(self.0, &mut count))?;
171 Ok(count)
172 }
173
174 fn output_count(&self) -> Result<usize> {
175 let mut count = 0;
176 check!(sys::tract_inference_model_output_count(self.0, &mut count))?;
177 Ok(count)
178 }
179
180 fn input_name(&self, id: usize) -> Result<String> {
181 let mut ptr = null_mut();
182 check!(sys::tract_inference_model_input_name(self.0, id, &mut ptr))?;
183 unsafe {
184 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
185 sys::tract_free_cstring(ptr);
186 Ok(ret)
187 }
188 }
189
190 fn output_name(&self, id: usize) -> Result<String> {
191 let mut ptr = null_mut();
192 check!(sys::tract_inference_model_output_name(self.0, id, &mut ptr))?;
193 unsafe {
194 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
195 sys::tract_free_cstring(ptr);
196 Ok(ret)
197 }
198 }
199
200 fn input_fact(&self, id: usize) -> Result<InferenceFact> {
201 let mut ptr = null_mut();
202 check!(sys::tract_inference_model_input_fact(self.0, id, &mut ptr))?;
203 Ok(InferenceFact(ptr))
204 }
205
206 fn set_input_fact(
207 &mut self,
208 id: usize,
209 fact: impl AsFact<Self, Self::InferenceFact>,
210 ) -> Result<()> {
211 let fact = fact.as_fact(self)?;
212 check!(sys::tract_inference_model_set_input_fact(self.0, id, fact.0))?;
213 Ok(())
214 }
215
216 fn output_fact(&self, id: usize) -> Result<InferenceFact> {
217 let mut ptr = null_mut();
218 check!(sys::tract_inference_model_output_fact(self.0, id, &mut ptr))?;
219 Ok(InferenceFact(ptr))
220 }
221
222 fn set_output_fact(
223 &mut self,
224 id: usize,
225 fact: impl AsFact<InferenceModel, InferenceFact>,
226 ) -> Result<()> {
227 let fact = fact.as_fact(self)?;
228 check!(sys::tract_inference_model_set_output_fact(self.0, id, fact.0))?;
229 Ok(())
230 }
231
232 fn analyse(&mut self) -> Result<()> {
233 check!(sys::tract_inference_model_analyse(self.0))?;
234 Ok(())
235 }
236
237 fn into_model(mut self) -> Result<Self::Model> {
238 let mut ptr = null_mut();
239 check!(sys::tract_inference_model_into_model(&mut self.0, &mut ptr))?;
240 Ok(Model(ptr))
241 }
242}
243
244wrapper!(Model, TractModel, tract_model_destroy);
246
247impl ModelInterface for Model {
248 type Fact = Fact;
249 type Tensor = Tensor;
250 type Runnable = Runnable;
251 fn input_count(&self) -> Result<usize> {
252 let mut count = 0;
253 check!(sys::tract_model_input_count(self.0, &mut count))?;
254 Ok(count)
255 }
256
257 fn output_count(&self) -> Result<usize> {
258 let mut count = 0;
259 check!(sys::tract_model_output_count(self.0, &mut count))?;
260 Ok(count)
261 }
262
263 fn input_name(&self, id: usize) -> Result<String> {
264 let mut ptr = null_mut();
265 check!(sys::tract_model_input_name(self.0, id, &mut ptr))?;
266 unsafe {
267 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
268 sys::tract_free_cstring(ptr);
269 Ok(ret)
270 }
271 }
272
273 fn output_name(&self, id: usize) -> Result<String> {
274 let mut ptr = null_mut();
275 check!(sys::tract_model_output_name(self.0, id, &mut ptr))?;
276 unsafe {
277 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
278 sys::tract_free_cstring(ptr);
279 Ok(ret)
280 }
281 }
282
283 fn set_output_names(
284 &mut self,
285 outputs: impl IntoIterator<Item = impl AsRef<str>>,
286 ) -> Result<()> {
287 let c_strings: Vec<CString> =
288 outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
289 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
290 check!(sys::tract_model_set_output_names(self.0, c_strings.len(), ptrs.as_ptr()))?;
291 Ok(())
292 }
293
294 fn input_fact(&self, id: usize) -> Result<Fact> {
295 let mut ptr = null_mut();
296 check!(sys::tract_model_input_fact(self.0, id, &mut ptr))?;
297 Ok(Fact(ptr))
298 }
299
300 fn output_fact(&self, id: usize) -> Result<Fact> {
301 let mut ptr = null_mut();
302 check!(sys::tract_model_output_fact(self.0, id, &mut ptr))?;
303 Ok(Fact(ptr))
304 }
305
306 fn into_runnable(self) -> Result<Runnable> {
307 let mut model = self;
308 let mut runnable = null_mut();
309 check!(sys::tract_model_into_runnable(&mut model.0, &mut runnable))?;
310 Ok(Runnable(runnable))
311 }
312
313 fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()> {
314 let transform = spec.into().to_transform_string();
315 let t = CString::new(transform)?;
316 check!(sys::tract_model_transform(self.0, t.as_ptr()))?;
317 Ok(())
318 }
319
320 fn property_keys(&self) -> Result<Vec<String>> {
321 let mut len = 0;
322 check!(sys::tract_model_property_count(self.0, &mut len))?;
323 let mut keys = vec![null_mut(); len];
324 check!(sys::tract_model_property_names(self.0, keys.as_mut_ptr()))?;
325 unsafe {
326 keys.into_iter()
327 .map(|pc| {
328 let s = CStr::from_ptr(pc).to_str()?.to_owned();
329 sys::tract_free_cstring(pc);
330 Ok(s)
331 })
332 .collect()
333 }
334 }
335
336 fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
337 let mut v = null_mut();
338 let name = CString::new(name.as_ref())?;
339 check!(sys::tract_model_property(self.0, name.as_ptr(), &mut v))?;
340 Ok(Tensor(v))
341 }
342
343 fn parse_fact(&self, spec: &str) -> Result<Self::Fact> {
344 let spec = CString::new(spec)?;
345 let mut ptr = null_mut();
346 check!(sys::tract_model_parse_fact(self.0, spec.as_ptr(), &mut ptr))?;
347 Ok(Fact(ptr))
348 }
349}
350
351wrapper!(Runtime, TractRuntime, tract_runtime_release);
353
354pub fn runtime_for_name(name: &str) -> Result<Runtime> {
355 let mut rt = null_mut();
356 let name = CString::new(name)?;
357 check!(sys::tract_runtime_for_name(name.as_ptr(), &mut rt))?;
358 Ok(Runtime(rt))
359}
360
361impl RuntimeInterface for Runtime {
362 type Runnable = Runnable;
363
364 type Model = Model;
365
366 fn name(&self) -> Result<String> {
367 let mut ptr = null_mut();
368 check!(sys::tract_runtime_name(self.0, &mut ptr))?;
369 unsafe {
370 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
371 sys::tract_free_cstring(ptr);
372 Ok(ret)
373 }
374 }
375
376 fn prepare(&self, model: Self::Model) -> Result<Self::Runnable> {
377 let mut model = model;
378 let mut runnable = null_mut();
379 check!(sys::tract_runtime_prepare(self.0, &mut model.0, &mut runnable))?;
380 Ok(Runnable(runnable))
381 }
382}
383
wrapper!(Runnable, TractRunnable, tract_runnable_release);
// NOTE(review): `Send` and `Sync` are asserted manually for a raw FFI handle. Nothing in this file
// establishes that the C runnable may be moved or shared across threads. Presumably the tract
// runnable is immutable once built; confirm against the tract C API before relying on it.
unsafe impl Send for Runnable {}
unsafe impl Sync for Runnable {}
388
389impl RunnableInterface for Runnable {
390 type Tensor = Tensor;
391 type State = State;
392 type Fact = Fact;
393
394 fn run(&self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
395 StateInterface::run(&mut self.spawn_state()?, inputs.into_inputs()?)
396 }
397
398 fn spawn_state(&self) -> Result<State> {
399 let mut state = null_mut();
400 check!(sys::tract_runnable_spawn_state(self.0, &mut state))?;
401 Ok(State(state))
402 }
403
404 fn input_count(&self) -> Result<usize> {
405 let mut count = 0;
406 check!(sys::tract_runnable_input_count(self.0, &mut count))?;
407 Ok(count)
408 }
409
410 fn output_count(&self) -> Result<usize> {
411 let mut count = 0;
412 check!(sys::tract_runnable_output_count(self.0, &mut count))?;
413 Ok(count)
414 }
415
416 fn input_fact(&self, id: usize) -> Result<Self::Fact> {
417 let mut ptr = null_mut();
418 check!(sys::tract_runnable_input_fact(self.0, id, &mut ptr))?;
419 Ok(Fact(ptr))
420 }
421
422 fn output_fact(&self, id: usize) -> Result<Self::Fact> {
423 let mut ptr = null_mut();
424 check!(sys::tract_runnable_output_fact(self.0, id, &mut ptr))?;
425 Ok(Fact(ptr))
426 }
427
428 fn property_keys(&self) -> Result<Vec<String>> {
429 let mut len = 0;
430 check!(sys::tract_runnable_property_count(self.0, &mut len))?;
431 let mut keys = vec![null_mut(); len];
432 check!(sys::tract_runnable_property_names(self.0, keys.as_mut_ptr()))?;
433 unsafe {
434 keys.into_iter()
435 .map(|pc| {
436 let s = CStr::from_ptr(pc).to_str()?.to_owned();
437 sys::tract_free_cstring(pc);
438 Ok(s)
439 })
440 .collect()
441 }
442 }
443
444 fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
445 let mut v = null_mut();
446 let name = CString::new(name.as_ref())?;
447 check!(sys::tract_runnable_property(self.0, name.as_ptr(), &mut v))?;
448 Ok(Tensor(v))
449 }
450
451 fn cost_json(&self) -> Result<String> {
452 let input: Option<Vec<Tensor>> = None;
453 self.profile_json(input)
454 }
455
456 fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
457 where
458 I: IntoIterator<Item = IV>,
459 IV: TryInto<Self::Tensor, Error = IE>,
460 IE: Into<anyhow::Error>,
461 {
462 let inputs = if let Some(inputs) = inputs {
463 let inputs = inputs
464 .into_iter()
465 .map(|i| i.try_into().map_err(|e| e.into()))
466 .collect::<Result<Vec<Tensor>>>()?;
467 anyhow::ensure!(self.input_count()? == inputs.len());
468 Some(inputs)
469 } else {
470 None
471 };
472 let mut iptrs: Option<Vec<*mut sys::TractTensor>> =
473 inputs.as_ref().map(|is| is.iter().map(|v| v.0).collect());
474 let mut json: *mut i8 = null_mut();
475 let values = iptrs.as_mut().map(|it| it.as_mut_ptr()).unwrap_or(null_mut());
476
477 check!(sys::tract_runnable_profile_json(self.0, values, &mut json))?;
478 anyhow::ensure!(!json.is_null());
479 unsafe {
480 let s = CStr::from_ptr(json).to_owned();
481 sys::tract_free_cstring(json);
482 Ok(s.to_str()?.to_owned())
483 }
484 }
485}
486
487wrapper!(State, TractState, tract_state_destroy);
489
490impl StateInterface for State {
491 type Tensor = Tensor;
492 type Fact = Fact;
493
494 fn run(&mut self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
495 let inputs = inputs.into_inputs()?;
496 let mut outputs = vec![null_mut(); self.output_count()?];
497 let mut inputs: Vec<_> = inputs.iter().map(|v| v.0).collect();
498 check!(sys::tract_state_run(self.0, inputs.as_mut_ptr(), outputs.as_mut_ptr()))?;
499 let outputs = outputs.into_iter().map(Tensor).collect();
500 Ok(outputs)
501 }
502
503 fn input_count(&self) -> Result<usize> {
504 let mut count = 0;
505 check!(sys::tract_state_input_count(self.0, &mut count))?;
506 Ok(count)
507 }
508
509 fn output_count(&self) -> Result<usize> {
510 let mut count = 0;
511 check!(sys::tract_state_output_count(self.0, &mut count))?;
512 Ok(count)
513 }
514}
515
wrapper!(Tensor, TractTensor, tract_tensor_destroy);
// NOTE(review): `Send` and `Sync` are asserted manually for a raw FFI handle. Nothing in this file
// proves the C tensor is safe to move or share across threads. Presumably it is immutable once
// created; confirm against the tract C API.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
520
521impl TensorInterface for Tensor {
522 fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
523 anyhow::ensure!(data.len() == shape.iter().product::<usize>() * dt.size_of());
524 let mut value = null_mut();
525 check!(sys::tract_tensor_from_bytes(
526 dt as _,
527 shape.len(),
528 shape.as_ptr(),
529 data.as_ptr() as _,
530 &mut value
531 ))?;
532 Ok(Tensor(value))
533 }
534
535 fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
536 let mut rank = 0;
537 let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
538 let mut shape = null();
539 let mut data = null();
540 check!(sys::tract_tensor_as_bytes(self.0, &mut dt, &mut rank, &mut shape, &mut data))?;
541 unsafe {
542 let dt: DatumType = std::mem::transmute(dt);
543 let shape = std::slice::from_raw_parts(shape, rank);
544 let len: usize = shape.iter().product();
545 let data = std::slice::from_raw_parts(data as *const u8, len * dt.size_of());
546 Ok((dt, shape, data))
547 }
548 }
549
550 fn datum_type(&self) -> Result<DatumType> {
551 let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
552 check!(sys::tract_tensor_as_bytes(
553 self.0,
554 &mut dt,
555 std::ptr::null_mut(),
556 std::ptr::null_mut(),
557 std::ptr::null_mut()
558 ))?;
559 unsafe {
560 let dt: DatumType = std::mem::transmute(dt);
561 Ok(dt)
562 }
563 }
564
565 fn convert_to(&self, to: DatumType) -> Result<Self> {
566 let mut new = null_mut();
567 check!(sys::tract_tensor_convert_to(self.0, to as _, &mut new))?;
568 Ok(Tensor(new))
569 }
570}
571
572impl PartialEq for Tensor {
573 fn eq(&self, other: &Self) -> bool {
574 let Ok((me_dt, me_shape, me_data)) = self.as_bytes() else { return false };
575 let Ok((other_dt, other_shape, other_data)) = other.as_bytes() else { return false };
576 me_dt == other_dt && me_shape == other_shape && me_data == other_data
577 }
578}
579
// Generates `Tensor` <-> ndarray conversions. The macro is defined outside this file (presumably in
// `tract_api`); see its definition for the exact impls it emits.
tensor_from_to_ndarray!();
581
582wrapper!(Fact, TractFact, tract_fact_destroy);
584
585impl Fact {
586 fn new(model: &Model, spec: impl ToString) -> Result<Fact> {
587 let cstr = CString::new(spec.to_string())?;
588 let mut fact = null_mut();
589 check!(sys::tract_model_parse_fact(model.0, cstr.as_ptr(), &mut fact))?;
590 Ok(Fact(fact))
591 }
592
593 fn dump(&self) -> Result<String> {
594 let mut ptr = null_mut();
595 check!(sys::tract_fact_dump(self.0, &mut ptr))?;
596 unsafe {
597 let s = CStr::from_ptr(ptr).to_owned();
598 sys::tract_free_cstring(ptr);
599 Ok(s.to_str()?.to_owned())
600 }
601 }
602}
603
604impl FactInterface for Fact {
605 type Dim = Dim;
606
607 fn datum_type(&self) -> Result<DatumType> {
608 let mut dt = 0u32;
609 check!(sys::tract_fact_datum_type(self.0, &mut dt as *const u32 as _))?;
610 Ok(unsafe { std::mem::transmute::<u32, DatumType>(dt) })
611 }
612
613 fn rank(&self) -> Result<usize> {
614 let mut rank = 0;
615 check!(sys::tract_fact_rank(self.0, &mut rank))?;
616 Ok(rank)
617 }
618
619 fn dim(&self, axis: usize) -> Result<Self::Dim> {
620 let mut ptr = null_mut();
621 check!(sys::tract_fact_dim(self.0, axis, &mut ptr))?;
622 Ok(Dim(ptr))
623 }
624}
625
626impl std::fmt::Display for Fact {
627 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
628 match self.dump() {
629 Ok(s) => f.write_str(&s),
630 Err(_) => Err(std::fmt::Error),
631 }
632 }
633}
634
635wrapper!(InferenceFact, TractInferenceFact, tract_inference_fact_destroy);
637
638impl InferenceFact {
639 fn new(model: &InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
640 let cstr = CString::new(spec.to_string())?;
641 let mut fact = null_mut();
642 check!(sys::tract_inference_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
643 Ok(InferenceFact(fact))
644 }
645
646 fn dump(&self) -> Result<String> {
647 let mut ptr = null_mut();
648 check!(sys::tract_inference_fact_dump(self.0, &mut ptr))?;
649 unsafe {
650 let s = CStr::from_ptr(ptr).to_owned();
651 sys::tract_free_cstring(ptr);
652 Ok(s.to_str()?.to_owned())
653 }
654 }
655}
656
657impl InferenceFactInterface for InferenceFact {
658 fn empty() -> Result<InferenceFact> {
659 let mut fact = null_mut();
660 check!(sys::tract_inference_fact_empty(&mut fact))?;
661 Ok(InferenceFact(fact))
662 }
663}
664
665impl Default for InferenceFact {
666 fn default() -> Self {
667 Self::empty().unwrap()
668 }
669}
670
671impl std::fmt::Display for InferenceFact {
672 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
673 match self.dump() {
674 Ok(s) => f.write_str(&s),
675 Err(_) => Err(std::fmt::Error),
676 }
677 }
678}
679
// Wire up the `AsFact` conversions for both model flavours. These macros are defined outside
// this file (presumably in `tract_api`); see their definitions for the exact impls.
as_inference_fact_impl!(InferenceModel, InferenceFact);
as_fact_impl!(Model, Fact);
682
683wrapper!(Dim, TractDim, tract_dim_destroy);
685
686impl Dim {
687 fn dump(&self) -> Result<String> {
688 let mut ptr = null_mut();
689 check!(sys::tract_dim_dump(self.0, &mut ptr))?;
690 unsafe {
691 let s = CStr::from_ptr(ptr).to_owned();
692 sys::tract_free_cstring(ptr);
693 Ok(s.to_str()?.to_owned())
694 }
695 }
696}
697
698impl DimInterface for Dim {
699 fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self> {
700 let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
701 let c_strings: Vec<CString> =
702 names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
703 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
704 let mut ptr = null_mut();
705 check!(sys::tract_dim_eval(self.0, ptrs.len(), ptrs.as_ptr(), values.as_ptr(), &mut ptr))?;
706 Ok(Dim(ptr))
707 }
708
709 fn to_int64(&self) -> Result<i64> {
710 let mut i = 0;
711 check!(sys::tract_dim_to_int64(self.0, &mut i))?;
712 Ok(i)
713 }
714}
715
716impl std::fmt::Display for Dim {
717 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
718 match self.dump() {
719 Ok(s) => f.write_str(&s),
720 Err(_) => Err(std::fmt::Error),
721 }
722 }
723}