1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::{null, null_mut};
4
5use tract_api::*;
6use tract_proxy_sys as sys;
7
8use anyhow::{Context, Result};
9use ndarray::*;
10
/// Evaluates a tract FFI call returning a `TRACT_RESULT` and converts it into an
/// `anyhow::Result<()>`.
///
/// The whole expression is evaluated inside an `unsafe` block, so callers can pass
/// raw `sys::` calls directly. On `TRACT_RESULT_KO`, the text from
/// `tract_get_last_error` becomes the error message.
macro_rules! check {
    ($expr:expr) => {
        unsafe {
            if $expr == sys::TRACT_RESULT_TRACT_RESULT_KO {
                let msg = sys::tract_get_last_error();
                // `tract_get_last_error` hands back a raw pointer; never feed a null one
                // to `CStr::from_ptr`, which would be undefined behaviour.
                if msg.is_null() {
                    Err(anyhow::anyhow!("tract call failed without an error message"))
                } else {
                    Err(anyhow::anyhow!(CStr::from_ptr(msg).to_string_lossy().into_owned()))
                }
            } else {
                Ok(())
            }
        }
    };
}
23
/// Declares an owning public newtype around a raw tract FFI handle.
///
/// * `$wrapper` – name of the generated tuple struct.
/// * `$raw` – the `sys` type the handle points to (field 0 is `*mut sys::$raw`).
/// * `$destructor` – `sys` function called on drop with a pointer to field 0.
/// * `$extra` – optional additional tuple fields.
///
/// NOTE(review): the derived `Clone` copies the raw pointer, and `Drop` passes it to
/// `$destructor`; dropping both an instance and its clone would hand the same handle
/// to the destructor twice. Confirm whether `Clone` is required by the `tract_api`
/// traits before relying on it.
macro_rules! wrapper {
    ($wrapper:ident, $raw:ident, $destructor:ident $(, $extra:ty )*) => {
        #[derive(Debug, Clone)]
        pub struct $wrapper(*mut sys::$raw $(, $extra)*);

        impl Drop for $wrapper {
            fn drop(&mut self) {
                let handle = &mut self.0;
                unsafe { sys::$destructor(handle) };
            }
        }
    };
}
38
39pub fn nnef() -> Result<Nnef> {
40 let mut nnef = null_mut();
41 check!(sys::tract_nnef_create(&mut nnef))?;
42 Ok(Nnef(nnef))
43}
44
45pub fn onnx() -> Result<Onnx> {
46 let mut onnx = null_mut();
47 check!(sys::tract_onnx_create(&mut onnx))?;
48 Ok(Onnx(onnx))
49}
50
51pub fn version() -> &'static str {
52 unsafe { CStr::from_ptr(sys::tract_version()).to_str().unwrap() }
53}
54
55wrapper!(Nnef, TractNnef, tract_nnef_destroy);
56impl NnefInterface for Nnef {
57 type Model = Model;
58 fn load(&self, path: impl AsRef<Path>) -> Result<Model> {
59 let path = path.as_ref();
60 let path = CString::new(
61 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
62 )?;
63 let mut model = null_mut();
64 check!(sys::tract_nnef_load(self.0, path.as_ptr(), &mut model))?;
65 Ok(Model(model))
66 }
67
68 fn load_buffer(&self, data: &[u8]) -> Result<Model> {
69 let mut model = null_mut();
70 check!(sys::tract_nnef_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
71 Ok(Model(model))
72 }
73
74 fn enable_tract_core(&mut self) -> Result<()> {
75 check!(sys::tract_nnef_enable_tract_core(self.0))
76 }
77
78 fn enable_tract_extra(&mut self) -> Result<()> {
79 check!(sys::tract_nnef_enable_tract_extra(self.0))
80 }
81
82 fn enable_tract_transformers(&mut self) -> Result<()> {
83 check!(sys::tract_nnef_enable_tract_transformers(self.0))
84 }
85
86 fn enable_onnx(&mut self) -> Result<()> {
87 check!(sys::tract_nnef_enable_onnx(self.0))
88 }
89
90 fn enable_pulse(&mut self) -> Result<()> {
91 check!(sys::tract_nnef_enable_pulse(self.0))
92 }
93
94 fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
95 check!(sys::tract_nnef_enable_extended_identifier_syntax(self.0))
96 }
97
98 fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
99 let path = path.as_ref();
100 let path = CString::new(
101 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
102 )?;
103 check!(sys::tract_nnef_write_model_to_dir(self.0, path.as_ptr(), model.0))?;
104 Ok(())
105 }
106
107 fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
108 let path = path.as_ref();
109 let path = CString::new(
110 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
111 )?;
112 check!(sys::tract_nnef_write_model_to_tar(self.0, path.as_ptr(), model.0))?;
113 Ok(())
114 }
115
116 fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
117 let path = path.as_ref();
118 let path = CString::new(
119 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
120 )?;
121 check!(sys::tract_nnef_write_model_to_tar_gz(self.0, path.as_ptr(), model.0))?;
122 Ok(())
123 }
124}
125
126wrapper!(Onnx, TractOnnx, tract_onnx_destroy);
128
129impl OnnxInterface for Onnx {
130 type InferenceModel = InferenceModel;
131 fn load(&self, path: impl AsRef<Path>) -> Result<InferenceModel> {
132 let path = path.as_ref();
133 let path = CString::new(
134 path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
135 )?;
136 let mut model = null_mut();
137 check!(sys::tract_onnx_load(self.0, path.as_ptr(), &mut model))?;
138 Ok(InferenceModel(model))
139 }
140
141 fn load_buffer(&self, data: &[u8]) -> Result<InferenceModel> {
142 let mut model = null_mut();
143 check!(sys::tract_onnx_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
144 Ok(InferenceModel(model))
145 }
146}
147
148wrapper!(InferenceModel, TractInferenceModel, tract_inference_model_destroy);
150impl InferenceModelInterface for InferenceModel {
151 type Model = Model;
152 type InferenceFact = InferenceFact;
153 fn set_output_names(
154 &mut self,
155 outputs: impl IntoIterator<Item = impl AsRef<str>>,
156 ) -> Result<()> {
157 let c_strings: Vec<CString> =
158 outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
159 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
160 check!(sys::tract_inference_model_set_output_names(
161 self.0,
162 c_strings.len(),
163 ptrs.as_ptr()
164 ))?;
165 Ok(())
166 }
167
168 fn input_count(&self) -> Result<usize> {
169 let mut count = 0;
170 check!(sys::tract_inference_model_input_count(self.0, &mut count))?;
171 Ok(count)
172 }
173
174 fn output_count(&self) -> Result<usize> {
175 let mut count = 0;
176 check!(sys::tract_inference_model_output_count(self.0, &mut count))?;
177 Ok(count)
178 }
179
180 fn input_name(&self, id: usize) -> Result<String> {
181 let mut ptr = null_mut();
182 check!(sys::tract_inference_model_input_name(self.0, id, &mut ptr))?;
183 unsafe {
184 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
185 sys::tract_free_cstring(ptr);
186 Ok(ret)
187 }
188 }
189
190 fn output_name(&self, id: usize) -> Result<String> {
191 let mut ptr = null_mut();
192 check!(sys::tract_inference_model_output_name(self.0, id, &mut ptr))?;
193 unsafe {
194 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
195 sys::tract_free_cstring(ptr);
196 Ok(ret)
197 }
198 }
199
200 fn input_fact(&self, id: usize) -> Result<InferenceFact> {
201 let mut ptr = null_mut();
202 check!(sys::tract_inference_model_input_fact(self.0, id, &mut ptr))?;
203 Ok(InferenceFact(ptr))
204 }
205
206 fn set_input_fact(
207 &mut self,
208 id: usize,
209 fact: impl AsFact<Self, Self::InferenceFact>,
210 ) -> Result<()> {
211 let fact = fact.as_fact(self)?;
212 check!(sys::tract_inference_model_set_input_fact(self.0, id, fact.0))?;
213 Ok(())
214 }
215
216 fn output_fact(&self, id: usize) -> Result<InferenceFact> {
217 let mut ptr = null_mut();
218 check!(sys::tract_inference_model_output_fact(self.0, id, &mut ptr))?;
219 Ok(InferenceFact(ptr))
220 }
221
222 fn set_output_fact(
223 &mut self,
224 id: usize,
225 fact: impl AsFact<InferenceModel, InferenceFact>,
226 ) -> Result<()> {
227 let fact = fact.as_fact(self)?;
228 check!(sys::tract_inference_model_set_output_fact(self.0, id, fact.0))?;
229 Ok(())
230 }
231
232 fn analyse(&mut self) -> Result<()> {
233 check!(sys::tract_inference_model_analyse(self.0))?;
234 Ok(())
235 }
236
237 fn into_tract(mut self) -> Result<Self::Model> {
238 let mut ptr = null_mut();
239 check!(sys::tract_inference_model_into_tract(&mut self.0, &mut ptr))?;
240 Ok(Model(ptr))
241 }
242}
243
244wrapper!(Model, TractModel, tract_model_destroy);
246
247impl ModelInterface for Model {
248 type Fact = Fact;
249 type Value = Value;
250 type Runnable = Runnable;
251 fn input_count(&self) -> Result<usize> {
252 let mut count = 0;
253 check!(sys::tract_model_input_count(self.0, &mut count))?;
254 Ok(count)
255 }
256
257 fn output_count(&self) -> Result<usize> {
258 let mut count = 0;
259 check!(sys::tract_model_output_count(self.0, &mut count))?;
260 Ok(count)
261 }
262
263 fn input_name(&self, id: usize) -> Result<String> {
264 let mut ptr = null_mut();
265 check!(sys::tract_model_input_name(self.0, id, &mut ptr))?;
266 unsafe {
267 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
268 sys::tract_free_cstring(ptr);
269 Ok(ret)
270 }
271 }
272
273 fn output_name(&self, id: usize) -> Result<String> {
274 let mut ptr = null_mut();
275 check!(sys::tract_model_output_name(self.0, id, &mut ptr))?;
276 unsafe {
277 let ret = CStr::from_ptr(ptr).to_str()?.to_owned();
278 sys::tract_free_cstring(ptr);
279 Ok(ret)
280 }
281 }
282
283 fn set_output_names(
284 &mut self,
285 outputs: impl IntoIterator<Item = impl AsRef<str>>,
286 ) -> Result<()> {
287 let c_strings: Vec<CString> =
288 outputs.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
289 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
290 check!(sys::tract_model_set_output_names(self.0, c_strings.len(), ptrs.as_ptr()))?;
291 Ok(())
292 }
293
294 fn input_fact(&self, id: usize) -> Result<Fact> {
295 let mut ptr = null_mut();
296 check!(sys::tract_model_input_fact(self.0, id, &mut ptr))?;
297 Ok(Fact(ptr))
298 }
299
300 fn output_fact(&self, id: usize) -> Result<Fact> {
301 let mut ptr = null_mut();
302 check!(sys::tract_model_output_fact(self.0, id, &mut ptr))?;
303 Ok(Fact(ptr))
304 }
305
306 fn into_runnable(self) -> Result<Runnable> {
307 let mut model = self;
308 let mut runnable = null_mut();
309 check!(sys::tract_model_into_runnable(&mut model.0, &mut runnable))?;
310 Ok(Runnable(runnable))
311 }
312
313 fn concretize_symbols(
314 &mut self,
315 values: impl IntoIterator<Item = (impl AsRef<str>, i64)>,
316 ) -> Result<()> {
317 let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
318 let c_strings: Vec<CString> =
319 names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
320 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
321 check!(sys::tract_model_concretize_symbols(
322 self.0,
323 ptrs.len(),
324 ptrs.as_ptr(),
325 values.as_ptr()
326 ))?;
327 Ok(())
328 }
329
330 fn transform(&mut self, transform: &str) -> Result<()> {
331 let t = CString::new(transform)?;
332 check!(sys::tract_model_transform(self.0, t.as_ptr()))?;
333 Ok(())
334 }
335
336 fn pulse(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
337 let name = CString::new(name.as_ref())?;
338 let value = CString::new(value.as_ref())?;
339 check!(sys::tract_model_pulse_simple(&mut self.0, name.as_ptr(), value.as_ptr()))?;
340 Ok(())
341 }
342
343 fn property_keys(&self) -> Result<Vec<String>> {
344 let mut len = 0;
345 check!(sys::tract_model_property_count(self.0, &mut len))?;
346 let mut keys = vec![null_mut(); len];
347 check!(sys::tract_model_property_names(self.0, keys.as_mut_ptr()))?;
348 unsafe {
349 keys.into_iter()
350 .map(|pc| {
351 let s = CStr::from_ptr(pc).to_str()?.to_owned();
352 sys::tract_free_cstring(pc);
353 Ok(s)
354 })
355 .collect()
356 }
357 }
358
359 fn property(&self, name: impl AsRef<str>) -> Result<Value> {
360 let mut v = null_mut();
361 let name = CString::new(name.as_ref())?;
362 check!(sys::tract_model_property(self.0, name.as_ptr(), &mut v))?;
363 Ok(Value(v))
364 }
365
366 fn parse_fact(&self, spec: &str) -> Result<Self::Fact> {
367 let spec = CString::new(spec)?;
368 let mut ptr = null_mut();
369 check!(sys::tract_model_parse_fact(self.0, spec.as_ptr(), &mut ptr))?;
370 Ok(Fact(ptr))
371 }
372}
373
374wrapper!(Runtime, TractRuntime, tract_runtime_release);
376
377pub fn runtime_for_name(name: &str) -> Result<Runtime> {
378 let mut rt = null_mut();
379 let name = CString::new(name)?;
380 check!(sys::tract_runtime_for_name(name.as_ptr(), &mut rt))?;
381 Ok(Runtime(rt))
382}
383
384impl RuntimeInterface for Runtime {
385 type Runnable = Runnable;
386
387 type Model = Model;
388
389 fn prepare(&self, model: Self::Model) -> Result<Self::Runnable> {
390 let mut model = model;
391 let mut runnable = null_mut();
392 check!(sys::tract_runtime_prepare(self.0, &mut model.0, &mut runnable))?;
393 Ok(Runnable(runnable))
394 }
395}
396
397wrapper!(Runnable, TractRunnable, tract_runnable_release);
399unsafe impl Send for Runnable {}
400unsafe impl Sync for Runnable {}
401
402impl RunnableInterface for Runnable {
403 type Value = Value;
404 type State = State;
405 type Fact = Fact;
406
407 fn run<I, V, E>(&self, inputs: I) -> Result<Vec<Value>>
408 where
409 I: IntoIterator<Item = V>,
410 V: TryInto<Value, Error = E>,
411 E: Into<anyhow::Error>,
412 {
413 self.spawn_state()?.run(inputs)
414 }
415
416 fn spawn_state(&self) -> Result<State> {
417 let mut state = null_mut();
418 check!(sys::tract_runnable_spawn_state(self.0, &mut state))?;
419 Ok(State(state))
420 }
421
422 fn input_count(&self) -> Result<usize> {
423 let mut count = 0;
424 check!(sys::tract_runnable_input_count(self.0, &mut count))?;
425 Ok(count)
426 }
427
428 fn output_count(&self) -> Result<usize> {
429 let mut count = 0;
430 check!(sys::tract_runnable_output_count(self.0, &mut count))?;
431 Ok(count)
432 }
433
434 fn input_fact(&self, id: usize) -> Result<Self::Fact> {
435 let mut ptr = null_mut();
436 check!(sys::tract_runnable_input_fact(self.0, id, &mut ptr))?;
437 Ok(Fact(ptr))
438 }
439
440 fn output_fact(&self, id: usize) -> Result<Self::Fact> {
441 let mut ptr = null_mut();
442 check!(sys::tract_runnable_output_fact(self.0, id, &mut ptr))?;
443 Ok(Fact(ptr))
444 }
445
446 fn property_keys(&self) -> Result<Vec<String>> {
447 let mut len = 0;
448 check!(sys::tract_runnable_property_count(self.0, &mut len))?;
449 let mut keys = vec![null_mut(); len];
450 check!(sys::tract_runnable_property_names(self.0, keys.as_mut_ptr()))?;
451 unsafe {
452 keys.into_iter()
453 .map(|pc| {
454 let s = CStr::from_ptr(pc).to_str()?.to_owned();
455 sys::tract_free_cstring(pc);
456 Ok(s)
457 })
458 .collect()
459 }
460 }
461
462 fn property(&self, name: impl AsRef<str>) -> Result<Value> {
463 let mut v = null_mut();
464 let name = CString::new(name.as_ref())?;
465 check!(sys::tract_runnable_property(self.0, name.as_ptr(), &mut v))?;
466 Ok(Value(v))
467 }
468
469 fn cost_json(&self) -> Result<String> {
470 let input: Option<Vec<Value>> = None;
471 let states: Option<Vec<Value>> = None;
472 self.profile_json(input, states)
473 }
474
475 fn profile_json<I, IV, IE, S, SV, SE>(
476 &self,
477 inputs: Option<I>,
478 state_initializers: Option<S>,
479 ) -> Result<String>
480 where
481 I: IntoIterator<Item = IV>,
482 IV: TryInto<Self::Value, Error = IE>,
483 IE: Into<anyhow::Error>,
484 S: IntoIterator<Item = SV>,
485 SV: TryInto<Self::Value, Error = SE>,
486 SE: Into<anyhow::Error>,
487 {
488 let inputs = if let Some(inputs) = inputs {
489 let inputs = inputs
490 .into_iter()
491 .map(|i| i.try_into().map_err(|e| e.into()))
492 .collect::<Result<Vec<Value>>>()?;
493 anyhow::ensure!(self.input_count()? == inputs.len());
494 Some(inputs)
495 } else {
496 None
497 };
498 let mut iptrs: Option<Vec<*mut sys::TractValue>> =
499 inputs.as_ref().map(|is| is.iter().map(|v| v.0).collect());
500 let mut json: *mut i8 = null_mut();
501 let values = iptrs.as_mut().map(|it| it.as_mut_ptr()).unwrap_or(null_mut());
502
503 let (state_inits, n_states) = if let Some(state_vec) = state_initializers {
504 let mut states: Vec<*const _> = vec![];
505
506 for v in state_vec {
507 let val: Value = v.try_into().map_err(|e| e.into())?;
508 states.push(val.0);
509 }
510 let len = states.len();
511 (Some(states), len)
512 } else {
513 (None, 0)
514 };
515
516 let states = state_inits.map(|is| is.as_ptr()).unwrap_or(null());
517 check!(sys::tract_runnable_profile_json(self.0, values, states, n_states, &mut json))?;
518 anyhow::ensure!(!json.is_null());
519 unsafe {
520 let s = CStr::from_ptr(json).to_owned();
521 sys::tract_free_cstring(json);
522 Ok(s.to_str()?.to_owned())
523 }
524 }
525}
526
527wrapper!(State, TractState, tract_state_destroy);
529
530impl StateInterface for State {
531 type Value = Value;
532 type Fact = Fact;
533
534 fn run<I, V, E>(&mut self, inputs: I) -> Result<Vec<Value>>
535 where
536 I: IntoIterator<Item = V>,
537 V: TryInto<Value, Error = E>,
538 E: Into<anyhow::Error>,
539 {
540 let inputs = inputs
541 .into_iter()
542 .map(|i| i.try_into().map_err(|e| e.into()))
543 .collect::<Result<Vec<Value>>>()?;
544 let mut outputs = vec![null_mut(); self.output_count()?];
545 let mut inputs: Vec<_> = inputs.iter().map(|v| v.0).collect();
546 check!(sys::tract_state_run(self.0, inputs.as_mut_ptr(), outputs.as_mut_ptr()))?;
547 let outputs = outputs.into_iter().map(Value).collect();
548 Ok(outputs)
549 }
550
551 fn input_count(&self) -> Result<usize> {
552 let mut count = 0;
553 check!(sys::tract_state_input_count(self.0, &mut count))?;
554 Ok(count)
555 }
556
557 fn output_count(&self) -> Result<usize> {
558 let mut count = 0;
559 check!(sys::tract_state_output_count(self.0, &mut count))?;
560 Ok(count)
561 }
562
563 #[allow(deprecated)]
564 fn initializable_states_count(&self) -> Result<usize> {
565 let mut n_states = 0;
566 check!(sys::tract_state_initializable_states_count(self.0, &mut n_states))?;
567 Ok(n_states)
568 }
569
570 #[allow(deprecated)]
571 fn get_states_facts(&self) -> Result<Vec<Fact>> {
572 let n_states = self.initializable_states_count()?;
573 let mut fptrs = vec![null_mut(); n_states];
574
575 check!(sys::tract_state_get_states_facts(self.0, fptrs.as_mut_ptr()))?;
576
577 let res = fptrs.into_iter().map(|value| Ok(Fact(value))).collect::<Result<Vec<Fact>>>();
578
579 res
580 }
581
582 #[allow(deprecated)]
583 fn set_states<I, V, E>(&mut self, state_initializers: I) -> Result<()>
584 where
585 I: IntoIterator<Item = V>,
586 V: TryInto<Self::Value, Error = E>,
587 E: Into<anyhow::Error>,
588 {
589 let sptrs = {
590 let mut states: Vec<*const _> = vec![];
591
592 for s in state_initializers {
593 let val: Value = s.try_into().map_err(|e| e.into())?;
594 states.push(val.0);
595 }
596
597 let len = states.len();
598 anyhow::ensure!(
599 len == self.initializable_states_count()?,
600 "Expected {} states, got {len}",
601 self.initializable_states_count()?
602 );
603 Some(states)
604 };
605
606 let sptrs = sptrs.map(|it| it.as_ptr()).unwrap_or(null());
607 check!(sys::tract_state_set_states(self.0, sptrs))?;
608
609 Ok(())
610 }
611
612 #[allow(deprecated)]
613 fn get_states(&self) -> Result<Vec<Self::Value>> {
614 let n_states = self.initializable_states_count()?;
615
616 let mut sptrs = vec![null_mut(); n_states];
617 check!(sys::tract_state_get_states(self.0, sptrs.as_mut_ptr()))?;
618
619 let res = sptrs.into_iter().map(|value| Ok(Value(value))).collect::<Result<Vec<Value>>>();
620
621 res
622 }
623}
624
625wrapper!(Value, TractValue, tract_value_destroy);
627unsafe impl Send for Value {}
628unsafe impl Sync for Value {}
629
630impl ValueInterface for Value {
631 fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
632 anyhow::ensure!(data.len() == shape.iter().product::<usize>() * dt.size_of());
633 let mut value = null_mut();
634 check!(sys::tract_value_from_bytes(
635 dt as _,
636 shape.len(),
637 shape.as_ptr(),
638 data.as_ptr() as _,
639 &mut value
640 ))?;
641 Ok(Value(value))
642 }
643
644 fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
645 let mut rank = 0;
646 let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
647 let mut shape = null();
648 let mut data = null();
649 check!(sys::tract_value_as_bytes(self.0, &mut dt, &mut rank, &mut shape, &mut data))?;
650 unsafe {
651 let dt: DatumType = std::mem::transmute(dt);
652 let shape = std::slice::from_raw_parts(shape, rank);
653 let len: usize = shape.iter().product();
654 let data = std::slice::from_raw_parts(data as *const u8, len * dt.size_of());
655 Ok((dt, shape, data))
656 }
657 }
658
659 fn datum_type(&self) -> Result<DatumType> {
660 let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
661 check!(sys::tract_value_as_bytes(
662 self.0,
663 &mut dt,
664 std::ptr::null_mut(),
665 std::ptr::null_mut(),
666 std::ptr::null_mut()
667 ))?;
668 unsafe {
669 let dt: DatumType = std::mem::transmute(dt);
670 Ok(dt)
671 }
672 }
673
674 fn convert_to(&self, to: DatumType) -> Result<Self> {
675 let mut new = null_mut();
676 check!(sys::tract_value_convert_to(self.0, to as _, &mut new))?;
677 Ok(Value(new))
678 }
679}
680
681impl PartialEq for Value {
682 fn eq(&self, other: &Self) -> bool {
683 let Ok((me_dt, me_shape, me_data)) = self.as_bytes() else { return false };
684 let Ok((other_dt, other_shape, other_data)) = other.as_bytes() else { return false };
685 me_dt == other_dt && me_shape == other_shape && me_data == other_data
686 }
687}
688
689value_from_to_ndarray!();
690
691wrapper!(Fact, TractFact, tract_fact_destroy);
693
694impl Fact {
695 fn new(model: &Model, spec: impl ToString) -> Result<Fact> {
696 let cstr = CString::new(spec.to_string())?;
697 let mut fact = null_mut();
698 check!(sys::tract_model_parse_fact(model.0, cstr.as_ptr(), &mut fact))?;
699 Ok(Fact(fact))
700 }
701
702 fn dump(&self) -> Result<String> {
703 let mut ptr = null_mut();
704 check!(sys::tract_fact_dump(self.0, &mut ptr))?;
705 unsafe {
706 let s = CStr::from_ptr(ptr).to_owned();
707 sys::tract_free_cstring(ptr);
708 Ok(s.to_str()?.to_owned())
709 }
710 }
711}
712
713impl FactInterface for Fact {
714 type Dim = Dim;
715
716 fn datum_type(&self) -> Result<DatumType> {
717 let mut dt = 0u32;
718 check!(sys::tract_fact_datum_type(self.0, &mut dt as *const u32 as _))?;
719 Ok(unsafe { std::mem::transmute::<u32, DatumType>(dt) })
720 }
721
722 fn rank(&self) -> Result<usize> {
723 let mut rank = 0;
724 check!(sys::tract_fact_rank(self.0, &mut rank))?;
725 Ok(rank)
726 }
727
728 fn dim(&self, axis: usize) -> Result<Self::Dim> {
729 let mut ptr = null_mut();
730 check!(sys::tract_fact_dim(self.0, axis, &mut ptr))?;
731 Ok(Dim(ptr))
732 }
733}
734
735impl std::fmt::Display for Fact {
736 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
737 match self.dump() {
738 Ok(s) => f.write_str(&s),
739 Err(_) => Err(std::fmt::Error),
740 }
741 }
742}
743
744wrapper!(InferenceFact, TractInferenceFact, tract_inference_fact_destroy);
746
747impl InferenceFact {
748 fn new(model: &InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
749 let cstr = CString::new(spec.to_string())?;
750 let mut fact = null_mut();
751 check!(sys::tract_inference_fact_parse(model.0, cstr.as_ptr(), &mut fact))?;
752 Ok(InferenceFact(fact))
753 }
754
755 fn dump(&self) -> Result<String> {
756 let mut ptr = null_mut();
757 check!(sys::tract_inference_fact_dump(self.0, &mut ptr))?;
758 unsafe {
759 let s = CStr::from_ptr(ptr).to_owned();
760 sys::tract_free_cstring(ptr);
761 Ok(s.to_str()?.to_owned())
762 }
763 }
764}
765
766impl InferenceFactInterface for InferenceFact {
767 fn empty() -> Result<InferenceFact> {
768 let mut fact = null_mut();
769 check!(sys::tract_inference_fact_empty(&mut fact))?;
770 Ok(InferenceFact(fact))
771 }
772}
773
774impl Default for InferenceFact {
775 fn default() -> Self {
776 Self::empty().unwrap()
777 }
778}
779
780impl std::fmt::Display for InferenceFact {
781 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
782 match self.dump() {
783 Ok(s) => f.write_str(&s),
784 Err(_) => Err(std::fmt::Error),
785 }
786 }
787}
788
// Generate the fact-conversion glue for each model flavor. These macros are not
// defined in this file; presumably they come from the `tract_api` glob import and
// implement `AsFact` for the listed model/fact pairs — confirm against tract_api.
as_inference_fact_impl!(InferenceModel, InferenceFact);
as_fact_impl!(Model, Fact);
791
792wrapper!(Dim, TractDim, tract_dim_destroy);
794
795impl Dim {
796 fn dump(&self) -> Result<String> {
797 let mut ptr = null_mut();
798 check!(sys::tract_dim_dump(self.0, &mut ptr))?;
799 unsafe {
800 let s = CStr::from_ptr(ptr).to_owned();
801 sys::tract_free_cstring(ptr);
802 Ok(s.to_str()?.to_owned())
803 }
804 }
805}
806
807impl DimInterface for Dim {
808 fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self> {
809 let (names, values): (Vec<_>, Vec<_>) = values.into_iter().unzip();
810 let c_strings: Vec<CString> =
811 names.into_iter().map(|a| Ok(CString::new(a.as_ref())?)).collect::<Result<_>>()?;
812 let ptrs: Vec<_> = c_strings.iter().map(|cs| cs.as_ptr()).collect();
813 let mut ptr = null_mut();
814 check!(sys::tract_dim_eval(self.0, ptrs.len(), ptrs.as_ptr(), values.as_ptr(), &mut ptr))?;
815 Ok(Dim(ptr))
816 }
817
818 fn to_int64(&self) -> Result<i64> {
819 let mut i = 0;
820 check!(sys::tract_dim_to_int64(self.0, &mut i))?;
821 Ok(i)
822 }
823}
824
825impl std::fmt::Display for Dim {
826 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
827 match self.dump() {
828 Ok(s) => f.write_str(&s),
829 Err(_) => Err(std::fmt::Error),
830 }
831 }
832}