1#![allow(clippy::missing_safety_doc)]
2
3use anyhow::{Context, Result};
4use std::cell::RefCell;
5use std::ffi::{c_char, c_void, CStr, CString};
6use tract_api::{
7 AsFact, DatumType, InferenceModelInterface, ModelInterface, NnefInterface, OnnxInterface,
8 RunnableInterface, StateInterface, ValueInterface,
9};
10use tract_rs::{State, Value};
11
/// Status code returned by every fallible `tract_*` entry point.
///
/// When a call returns `TRACT_RESULT_KO`, a description of the failure can be
/// fetched with `tract_get_last_error` on the same thread.
#[derive(Debug, PartialEq, Eq)]
#[allow(non_camel_case_types)]
#[repr(C)]
pub enum TRACT_RESULT {
    /// The call completed successfully.
    TRACT_RESULT_OK = 0,
    /// The call failed; see `tract_get_last_error`.
    TRACT_RESULT_KO = 1,
}
24
thread_local! {
    /// Message of the most recent failure on this thread.
    ///
    /// Written by `wrap` when a wrapped call returns an error, and read by
    /// `tract_get_last_error`. Each thread has its own slot, so errors from one
    /// thread never overwrite another thread's message.
    pub(crate) static LAST_ERROR: RefCell<Option<CString>> = const { RefCell::new(None) };
}
28
29fn wrap<F: FnOnce() -> anyhow::Result<()>>(func: F) -> TRACT_RESULT {
30 match func() {
31 Ok(_) => TRACT_RESULT::TRACT_RESULT_OK,
32 Err(e) => {
33 let msg = format!("{e:?}");
34 if std::env::var("TRACT_ERROR_STDERR").is_ok() {
35 eprintln!("{msg}");
36 }
37 LAST_ERROR.with(|p| {
38 *p.borrow_mut() = Some(CString::new(msg).unwrap_or_else(|_| {
39 CString::new("tract error message contains 0, can't convert to CString")
40 .unwrap()
41 }))
42 });
43 TRACT_RESULT::TRACT_RESULT_KO
44 }
45 }
46}
47
48#[unsafe(no_mangle)]
57pub extern "C" fn tract_get_last_error() -> *const std::ffi::c_char {
58 LAST_ERROR.with(|msg| msg.borrow().as_ref().map(|s| s.as_ptr()).unwrap_or(std::ptr::null()))
59}
60
61#[unsafe(no_mangle)]
65pub extern "C" fn tract_version() -> *const std::ffi::c_char {
66 unsafe {
67 CStr::from_bytes_with_nul_unchecked(concat!(env!("CARGO_PKG_VERSION"), "\0").as_bytes())
68 .as_ptr()
69 }
70}
71
/// Frees a string previously returned by this library (e.g. a model input/output name
/// or a fact dump). Passing null is a no-op.
///
/// # Safety
/// `ptr` must be null or a pointer obtained from `CString::into_raw` in this library,
/// and must not be used after this call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_free_cstring(ptr: *mut std::ffi::c_char) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` came from `CString::into_raw`.
    drop(unsafe { CString::from_raw(ptr) });
}
81
/// Returns early from the enclosing `anyhow::Result` function with an
/// "Unexpected null pointer <expr>" error if any of the given pointer expressions is
/// null.
///
/// Arguments are checked left to right and each check returns early, so a later
/// argument may safely dereference an earlier one (e.g. `check_not_null!(p, *p)`).
macro_rules! check_not_null {
    ($($ptr:expr),*) => {
        $(
            if $ptr.is_null() {
                anyhow::bail!(concat!("Unexpected null pointer ", stringify!($ptr)));
            }
        )*
    }
}
91
/// Frees a boxed handle through a `*mut *mut T` pointer, resets the caller's handle
/// to null, and reports the outcome as a `TRACT_RESULT`.
///
/// Fails without freeing anything if either the outer pointer or the handle it points
/// to is null.
macro_rules! release {
    ($ptr:expr) => {
        wrap(|| unsafe {
            check_not_null!($ptr, *$ptr);
            // Re-box the handle so it is dropped at the end of this statement.
            let _ = Box::from_raw(*$ptr);
            // Null the caller's handle so a second release reports an error instead
            // of double-freeing.
            *$ptr = std::ptr::null_mut();
            Ok(())
        })
    };
}
102
/// Opaque handle wrapping a `tract_rs::Nnef` framework; handed to C as a boxed pointer.
pub struct TractNnef(tract_rs::Nnef);
105
106#[unsafe(no_mangle)]
111pub unsafe extern "C" fn tract_nnef_create(nnef: *mut *mut TractNnef) -> TRACT_RESULT {
112 wrap(|| unsafe {
113 check_not_null!(nnef);
114 *nnef = Box::into_raw(Box::new(TractNnef(tract_rs::nnef()?)));
115 Ok(())
116 })
117}
118
119#[unsafe(no_mangle)]
120pub unsafe extern "C" fn tract_nnef_transform_model(
121 nnef: *const TractNnef,
122 model: *mut TractModel,
123 transform_spec: *const i8,
124) -> TRACT_RESULT {
125 wrap(|| unsafe {
126 check_not_null!(nnef, model, transform_spec);
127 let transform_spec = CStr::from_ptr(transform_spec as _).to_str()?;
128 (*nnef)
129 .0
130 .transform_model(&mut (*model).0, transform_spec)
131 .with_context(|| format!("performing transform {transform_spec:?}"))?;
132 Ok(())
133 })
134}
135
136#[unsafe(no_mangle)]
137pub unsafe extern "C" fn tract_nnef_enable_tract_core(nnef: *mut TractNnef) -> TRACT_RESULT {
138 wrap(|| unsafe {
139 check_not_null!(nnef);
140 (*nnef).0.enable_tract_core()
141 })
142}
143
144#[unsafe(no_mangle)]
145pub unsafe extern "C" fn tract_nnef_enable_tract_extra(nnef: *mut TractNnef) -> TRACT_RESULT {
146 wrap(|| unsafe {
147 check_not_null!(nnef);
148 (*nnef).0.enable_tract_extra()
149 })
150}
151
152#[unsafe(no_mangle)]
153pub unsafe extern "C" fn tract_nnef_enable_tract_transformers(
154 nnef: *mut TractNnef,
155) -> TRACT_RESULT {
156 wrap(|| unsafe {
157 check_not_null!(nnef);
158 (*nnef).0.enable_tract_transformers()
159 })
160}
161
162#[unsafe(no_mangle)]
163pub unsafe extern "C" fn tract_nnef_enable_onnx(nnef: *mut TractNnef) -> TRACT_RESULT {
164 wrap(|| unsafe {
165 check_not_null!(nnef);
166 (*nnef).0.enable_onnx()
167 })
168}
169
170#[unsafe(no_mangle)]
171pub unsafe extern "C" fn tract_nnef_enable_pulse(nnef: *mut TractNnef) -> TRACT_RESULT {
172 wrap(|| unsafe {
173 check_not_null!(nnef);
174 (*nnef).0.enable_pulse()
175 })
176}
177
178#[unsafe(no_mangle)]
179pub unsafe extern "C" fn tract_nnef_enable_extended_identifier_syntax(
180 nnef: *mut TractNnef,
181) -> TRACT_RESULT {
182 wrap(|| unsafe {
183 check_not_null!(nnef);
184 (*nnef).0.enable_extended_identifier_syntax()
185 })
186}
187
/// Destroys an NNEF framework handle (as created by `tract_nnef_create`) and sets
/// `*nnef` to null. Fails if `nnef` or `*nnef` is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_nnef_destroy(nnef: *mut *mut TractNnef) -> TRACT_RESULT {
    release!(nnef)
}
193
194#[unsafe(no_mangle)]
199pub unsafe extern "C" fn tract_nnef_model_for_path(
200 nnef: *const TractNnef,
201 path: *const c_char,
202 model: *mut *mut TractModel,
203) -> TRACT_RESULT {
204 wrap(|| unsafe {
205 check_not_null!(nnef, model, path);
206 *model = std::ptr::null_mut();
207 let path = CStr::from_ptr(path).to_str()?;
208 let m = Box::new(TractModel(
209 (*nnef).0.model_for_path(path).with_context(|| format!("opening file {path:?}"))?,
210 ));
211 *model = Box::into_raw(m);
212 Ok(())
213 })
214}
215
216#[unsafe(no_mangle)]
222pub unsafe extern "C" fn tract_nnef_write_model_to_tar(
223 nnef: *const TractNnef,
224 path: *const c_char,
225 model: *const TractModel,
226) -> TRACT_RESULT {
227 wrap(|| unsafe {
228 check_not_null!(nnef, model, path);
229 let path = CStr::from_ptr(path).to_str()?;
230 (*nnef).0.write_model_to_tar(path, &(*model).0)?;
231 Ok(())
232 })
233}
234
235#[unsafe(no_mangle)]
239pub unsafe extern "C" fn tract_nnef_write_model_to_tar_gz(
240 nnef: *const TractNnef,
241 path: *const c_char,
242 model: *const TractModel,
243) -> TRACT_RESULT {
244 wrap(|| unsafe {
245 check_not_null!(nnef, model, path);
246 let path = CStr::from_ptr(path).to_str()?;
247 (*nnef).0.write_model_to_tar_gz(path, &(*model).0)?;
248 Ok(())
249 })
250}
251
252#[unsafe(no_mangle)]
258pub unsafe extern "C" fn tract_nnef_write_model_to_dir(
259 nnef: *const TractNnef,
260 path: *const c_char,
261 model: *const TractModel,
262) -> TRACT_RESULT {
263 wrap(|| unsafe {
264 check_not_null!(nnef, model, path);
265 let path = CStr::from_ptr(path).to_str()?;
266 (*nnef).0.write_model_to_dir(path, &(*model).0)?;
267 Ok(())
268 })
269}
270
/// Opaque handle wrapping a `tract_rs::Onnx` framework; handed to C as a boxed pointer.
pub struct TractOnnx(tract_rs::Onnx);
273
274#[unsafe(no_mangle)]
279pub unsafe extern "C" fn tract_onnx_create(onnx: *mut *mut TractOnnx) -> TRACT_RESULT {
280 wrap(|| unsafe {
281 check_not_null!(onnx);
282 *onnx = Box::into_raw(Box::new(TractOnnx(tract_rs::onnx()?)));
283 Ok(())
284 })
285}
286
/// Destroys an ONNX framework handle (as created by `tract_onnx_create`) and sets
/// `*onnx` to null. Fails if `onnx` or `*onnx` is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_onnx_destroy(onnx: *mut *mut TractOnnx) -> TRACT_RESULT {
    release!(onnx)
}
292
293#[unsafe(no_mangle)]
297pub unsafe extern "C" fn tract_onnx_model_for_path(
298 onnx: *const TractOnnx,
299 path: *const c_char,
300 model: *mut *mut TractInferenceModel,
301) -> TRACT_RESULT {
302 wrap(|| unsafe {
303 check_not_null!(onnx, path, model);
304 *model = std::ptr::null_mut();
305 let path = CStr::from_ptr(path).to_str()?;
306 let m = Box::new(TractInferenceModel((*onnx).0.model_for_path(path)?));
307 *model = Box::into_raw(m);
308 Ok(())
309 })
310}
311
/// Opaque handle wrapping a `tract_rs::InferenceModel`, as produced by
/// `tract_onnx_model_for_path`; handed to C as a boxed pointer.
pub struct TractInferenceModel(tract_rs::InferenceModel);
314
315#[unsafe(no_mangle)]
317pub unsafe extern "C" fn tract_inference_model_input_count(
318 model: *const TractInferenceModel,
319 inputs: *mut usize,
320) -> TRACT_RESULT {
321 wrap(|| unsafe {
322 check_not_null!(model, inputs);
323 let model = &(*model).0;
324 *inputs = model.input_count()?;
325 Ok(())
326 })
327}
328
329#[unsafe(no_mangle)]
331pub unsafe extern "C" fn tract_inference_model_output_count(
332 model: *const TractInferenceModel,
333 outputs: *mut usize,
334) -> TRACT_RESULT {
335 wrap(|| unsafe {
336 check_not_null!(model, outputs);
337 let model = &(*model).0;
338 *outputs = model.output_count()?;
339 Ok(())
340 })
341}
342
343#[unsafe(no_mangle)]
347pub unsafe extern "C" fn tract_inference_model_input_name(
348 model: *const TractInferenceModel,
349 input: usize,
350 name: *mut *mut c_char,
351) -> TRACT_RESULT {
352 wrap(|| unsafe {
353 check_not_null!(model, name);
354 *name = std::ptr::null_mut();
355 let m = &(*model).0;
356 *name = CString::new(&*m.input_name(input)?)?.into_raw();
357 Ok(())
358 })
359}
360
361#[unsafe(no_mangle)]
365pub unsafe extern "C" fn tract_inference_model_output_name(
366 model: *const TractInferenceModel,
367 output: usize,
368 name: *mut *mut i8,
369) -> TRACT_RESULT {
370 wrap(|| unsafe {
371 check_not_null!(model, name);
372 *name = std::ptr::null_mut();
373 let m = &(*model).0;
374 *name = CString::new(&*m.output_name(output)?)?.into_raw() as _;
375 Ok(())
376 })
377}
378
379#[unsafe(no_mangle)]
380pub unsafe extern "C" fn tract_inference_model_input_fact(
381 model: *const TractInferenceModel,
382 input_id: usize,
383 fact: *mut *mut TractInferenceFact,
384) -> TRACT_RESULT {
385 wrap(|| unsafe {
386 check_not_null!(model, fact);
387 *fact = std::ptr::null_mut();
388 let f = (*model).0.input_fact(input_id)?;
389 *fact = Box::into_raw(Box::new(TractInferenceFact(f)));
390 Ok(())
391 })
392}
393
394#[unsafe(no_mangle)]
399pub unsafe extern "C" fn tract_inference_model_set_input_fact(
400 model: *mut TractInferenceModel,
401 input_id: usize,
402 fact: *const TractInferenceFact,
403) -> TRACT_RESULT {
404 wrap(|| unsafe {
405 check_not_null!(model);
406 let f = fact.as_ref().map(|f| &f.0).cloned().unwrap_or_default();
407 (*model).0.set_input_fact(input_id, f)?;
408 Ok(())
409 })
410}
411
412#[unsafe(no_mangle)]
416pub unsafe extern "C" fn tract_inference_model_set_output_names(
417 model: *mut TractInferenceModel,
418 len: usize,
419 names: *const *const c_char,
420) -> TRACT_RESULT {
421 wrap(|| unsafe {
422 check_not_null!(model, names, *names);
423 let node_names = (0..len)
424 .map(|i| Ok(CStr::from_ptr(*names.add(i)).to_str()?.to_owned()))
425 .collect::<Result<Vec<_>>>()?;
426 (*model).0.set_output_names(&node_names)?;
427 Ok(())
428 })
429}
430
431#[unsafe(no_mangle)]
435pub unsafe extern "C" fn tract_inference_model_output_fact(
436 model: *const TractInferenceModel,
437 output_id: usize,
438 fact: *mut *mut TractInferenceFact,
439) -> TRACT_RESULT {
440 wrap(|| unsafe {
441 check_not_null!(model, fact);
442 *fact = std::ptr::null_mut();
443 let f = (*model).0.output_fact(output_id)?;
444 *fact = Box::into_raw(Box::new(TractInferenceFact(f)));
445 Ok(())
446 })
447}
448
449#[unsafe(no_mangle)]
454pub unsafe extern "C" fn tract_inference_model_set_output_fact(
455 model: *mut TractInferenceModel,
456 output_id: usize,
457 fact: *const TractInferenceFact,
458) -> TRACT_RESULT {
459 wrap(|| unsafe {
460 check_not_null!(model);
461 let f = fact.as_ref().map(|f| &f.0).cloned().unwrap_or_default();
462 (*model).0.set_output_fact(output_id, f)?;
463 Ok(())
464 })
465}
466
467#[unsafe(no_mangle)]
469pub unsafe extern "C" fn tract_inference_model_analyse(
470 model: *mut TractInferenceModel,
471) -> TRACT_RESULT {
472 wrap(|| unsafe {
473 check_not_null!(model);
474 (*model).0.analyse()?;
475 Ok(())
476 })
477}
478
479#[unsafe(no_mangle)]
486pub unsafe extern "C" fn tract_inference_model_into_optimized(
487 model: *mut *mut TractInferenceModel,
488 optimized: *mut *mut TractModel,
489) -> TRACT_RESULT {
490 wrap(|| unsafe {
491 check_not_null!(model, *model, optimized);
492 *optimized = std::ptr::null_mut();
493 let m = Box::from_raw(*model);
494 *model = std::ptr::null_mut();
495 let result = m.0.into_optimized()?;
496 *optimized = Box::into_raw(Box::new(TractModel(result))) as _;
497 Ok(())
498 })
499}
500
501#[unsafe(no_mangle)]
508pub unsafe extern "C" fn tract_inference_model_into_typed(
509 model: *mut *mut TractInferenceModel,
510 typed: *mut *mut TractModel,
511) -> TRACT_RESULT {
512 wrap(|| unsafe {
513 check_not_null!(model, *model, typed);
514 *typed = std::ptr::null_mut();
515 let m = Box::from_raw(*model);
516 *model = std::ptr::null_mut();
517 let result = m.0.into_typed()?;
518 *typed = Box::into_raw(Box::new(TractModel(result))) as _;
519 Ok(())
520 })
521}
522
/// Destroys an inference model handle and sets `*model` to null. Fails if `model` or
/// `*model` is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_inference_model_destroy(
    model: *mut *mut TractInferenceModel,
) -> TRACT_RESULT {
    release!(model)
}
/// Opaque handle wrapping a typed `tract_rs::Model`; handed to C as a boxed pointer.
pub struct TractModel(tract_rs::Model);
533
534#[unsafe(no_mangle)]
536pub unsafe extern "C" fn tract_model_input_count(
537 model: *const TractModel,
538 inputs: *mut usize,
539) -> TRACT_RESULT {
540 wrap(|| unsafe {
541 check_not_null!(model, inputs);
542 let model = &(*model).0;
543 *inputs = model.input_count()?;
544 Ok(())
545 })
546}
547
548#[unsafe(no_mangle)]
550pub unsafe extern "C" fn tract_model_output_count(
551 model: *const TractModel,
552 outputs: *mut usize,
553) -> TRACT_RESULT {
554 wrap(|| unsafe {
555 check_not_null!(model, outputs);
556 let model = &(*model).0;
557 *outputs = model.output_count()?;
558 Ok(())
559 })
560}
561
562#[unsafe(no_mangle)]
566pub unsafe extern "C" fn tract_model_input_name(
567 model: *const TractModel,
568 input: usize,
569 name: *mut *mut c_char,
570) -> TRACT_RESULT {
571 wrap(|| unsafe {
572 check_not_null!(model, name);
573 *name = std::ptr::null_mut();
574 let m = &(*model).0;
575 *name = CString::new(m.input_name(input)?)?.into_raw();
576 Ok(())
577 })
578}
579
580#[unsafe(no_mangle)]
584pub unsafe extern "C" fn tract_model_input_fact(
585 model: *const TractModel,
586 input_id: usize,
587 fact: *mut *mut TractFact,
588) -> TRACT_RESULT {
589 wrap(|| unsafe {
590 check_not_null!(model, fact);
591 *fact = std::ptr::null_mut();
592 let f = (*model).0.input_fact(input_id)?;
593 *fact = Box::into_raw(Box::new(TractFact(f)));
594 Ok(())
595 })
596}
597
598#[unsafe(no_mangle)]
602pub unsafe extern "C" fn tract_model_output_name(
603 model: *const TractModel,
604 output: usize,
605 name: *mut *mut c_char,
606) -> TRACT_RESULT {
607 wrap(|| unsafe {
608 check_not_null!(model, name);
609 *name = std::ptr::null_mut();
610 let m = &(*model).0;
611 *name = CString::new(m.output_name(output)?)?.into_raw();
612 Ok(())
613 })
614}
615
616#[unsafe(no_mangle)]
620pub unsafe extern "C" fn tract_model_output_fact(
621 model: *const TractModel,
622 input_id: usize,
623 fact: *mut *mut TractFact,
624) -> TRACT_RESULT {
625 wrap(|| unsafe {
626 check_not_null!(model, fact);
627 *fact = std::ptr::null_mut();
628 let f = (*model).0.output_fact(input_id)?;
629 *fact = Box::into_raw(Box::new(TractFact(f)));
630 Ok(())
631 })
632}
633
634#[unsafe(no_mangle)]
638pub unsafe extern "C" fn tract_model_set_output_names(
639 model: *mut TractModel,
640 len: usize,
641 names: *const *const c_char,
642) -> TRACT_RESULT {
643 wrap(|| unsafe {
644 check_not_null!(model, names, *names);
645 let node_names = (0..len)
646 .map(|i| Ok(CStr::from_ptr(*names.add(i)).to_str()?.to_owned()))
647 .collect::<Result<Vec<_>>>()?;
648 (*model).0.set_output_names(&node_names)
649 })
650}
651
652#[unsafe(no_mangle)]
658pub unsafe extern "C" fn tract_model_concretize_symbols(
659 model: *mut TractModel,
660 nb_symbols: usize,
661 symbols: *const *const i8,
662 values: *const i64,
663) -> TRACT_RESULT {
664 wrap(|| unsafe {
665 check_not_null!(model, symbols, values);
666 let model = &mut (*model).0;
667 let mut table = vec![];
668 for i in 0..nb_symbols {
669 let name = CStr::from_ptr(*symbols.add(i) as _)
670 .to_str()
671 .with_context(|| {
672 format!("failed to parse symbol name for {i}th symbol (not utf8)")
673 })?
674 .to_owned();
675 table.push((name, *values.add(i)));
676 }
677 model.concretize_symbols(table)
678 })
679}
680
681#[unsafe(no_mangle)]
686pub unsafe extern "C" fn tract_model_pulse_simple(
687 model: *mut *mut TractModel,
688 stream_symbol: *const i8,
689 pulse_expr: *const i8,
690) -> TRACT_RESULT {
691 wrap(|| unsafe {
692 check_not_null!(model, *model, stream_symbol, pulse_expr);
693 let model = &mut (**model).0;
694 let stream_sym = CStr::from_ptr(stream_symbol as _)
695 .to_str()
696 .context("failed to parse stream symbol name (not utf8)")?;
697 let pulse_dim = CStr::from_ptr(pulse_expr as _)
698 .to_str()
699 .context("failed to parse stream symbol name (not utf8)")?;
700 model.pulse(stream_sym, pulse_dim)
701 })
702}
703
704#[unsafe(no_mangle)]
706pub unsafe extern "C" fn tract_model_transform(
707 model: *mut TractModel,
708 transform: *const i8,
709) -> TRACT_RESULT {
710 wrap(|| unsafe {
711 check_not_null!(model, transform);
712 let t = CStr::from_ptr(transform as _)
713 .to_str()
714 .context("failed to parse transform name (not utf8)")?;
715 (*model).0.transform(t)
716 })
717}
718
719#[unsafe(no_mangle)]
721pub unsafe extern "C" fn tract_model_declutter(model: *mut TractModel) -> TRACT_RESULT {
722 wrap(|| unsafe {
723 check_not_null!(model);
724 (*model).0.declutter()
725 })
726}
727
728#[unsafe(no_mangle)]
730pub unsafe extern "C" fn tract_model_optimize(model: *mut TractModel) -> TRACT_RESULT {
731 wrap(|| unsafe {
732 check_not_null!(model);
733 (*model).0.optimize()
734 })
735}
736
737#[unsafe(no_mangle)]
739pub unsafe extern "C" fn tract_model_profile_json(
740 model: *mut TractModel,
741 inputs: *mut *mut TractValue,
742 states: *const *const TractValue,
743 n_states: usize,
744 json: *mut *mut i8,
745) -> TRACT_RESULT {
746 wrap(|| unsafe {
747 check_not_null!(model, json);
748
749 let input: Option<Vec<Value>> = if !inputs.is_null() {
750 let input_len = (*model).0.input_count()?;
751 Some(
752 std::slice::from_raw_parts(inputs, input_len)
753 .iter()
754 .map(|tv| (**tv).0.clone())
755 .collect(),
756 )
757 } else {
758 None
759 };
760
761 let state_initializers: Option<Vec<Value>> = if !states.is_null() {
762 anyhow::ensure!(n_states != 0);
763 let hashmap = std::slice::from_raw_parts(states, n_states).iter()
764 .map(|tv| {
765 (**tv).0.clone()
766 }).collect();
767 Some(hashmap)
768 } else { None };
769
770 let profile = (*model).0.profile_json(input, state_initializers)?;
771 *json = CString::new(profile)?.into_raw() as _;
772 Ok(())
773 })
774}
775
776#[unsafe(no_mangle)]
782pub unsafe extern "C" fn tract_model_into_runnable(
783 model: *mut *mut TractModel,
784 runnable: *mut *mut TractRunnable,
785) -> TRACT_RESULT {
786 wrap(|| unsafe {
787 check_not_null!(model, runnable);
788 let m = Box::from_raw(*model).0;
789 *model = std::ptr::null_mut();
790 *runnable = Box::into_raw(Box::new(TractRunnable(m.into_runnable()?))) as _;
791 Ok(())
792 })
793}
794
795#[unsafe(no_mangle)]
797pub unsafe extern "C" fn tract_model_property_count(
798 model: *const TractModel,
799 count: *mut usize,
800) -> TRACT_RESULT {
801 wrap(|| unsafe {
802 check_not_null!(model, count);
803 *count = (*model).0.property_keys()?.len();
804 Ok(())
805 })
806}
807
808#[unsafe(no_mangle)]
814pub unsafe extern "C" fn tract_model_property_names(
815 model: *const TractModel,
816 names: *mut *mut i8,
817) -> TRACT_RESULT {
818 wrap(|| unsafe {
819 check_not_null!(model, names);
820 for (ix, name) in (*model).0.property_keys()?.iter().enumerate() {
821 *names.add(ix) = CString::new(&**name)?.into_raw() as _;
822 }
823 Ok(())
824 })
825}
826
827#[unsafe(no_mangle)]
829pub unsafe extern "C" fn tract_model_property(
830 model: *const TractModel,
831 name: *const i8,
832 value: *mut *mut TractValue,
833) -> TRACT_RESULT {
834 wrap(|| unsafe {
835 check_not_null!(model, name, value);
836 let name = CStr::from_ptr(name as _)
837 .to_str()
838 .context("failed to parse property name (not utf8)")?
839 .to_owned();
840 let v = (*model).0.property(name).context("Property not found")?;
841 *value = Box::into_raw(Box::new(TractValue(v)));
842 Ok(())
843 })
844}
845
/// Destroys a model handle and sets `*model` to null. Fails if `model` or `*model`
/// is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_model_destroy(model: *mut *mut TractModel) -> TRACT_RESULT {
    release!(model)
}
851
/// Opaque handle wrapping a `tract_rs::Runnable`, as produced by
/// `tract_model_into_runnable`; handed to C as a boxed pointer.
pub struct TractRunnable(tract_rs::Runnable);
854
855#[unsafe(no_mangle)]
864pub unsafe extern "C" fn tract_runnable_spawn_state(
865 runnable: *mut TractRunnable,
866 state: *mut *mut TractState,
867) -> TRACT_RESULT {
868 wrap(|| unsafe {
869 check_not_null!(runnable, state);
870 *state = std::ptr::null_mut();
871 let s = (*runnable).0.spawn_state()?;
872 *state = Box::into_raw(Box::new(TractState(s)));
873 Ok(())
874 })
875}
876
877#[unsafe(no_mangle)]
886pub unsafe extern "C" fn tract_runnable_run(
887 runnable: *mut TractRunnable,
888 inputs: *mut *mut TractValue,
889 outputs: *mut *mut TractValue,
890) -> TRACT_RESULT {
891 wrap(|| unsafe {
892 check_not_null!(runnable);
893 let mut s = (*runnable).0.spawn_state()?;
894 state_run(&mut s, inputs, outputs)
895 })
896}
897
898#[unsafe(no_mangle)]
900pub unsafe extern "C" fn tract_runnable_input_count(
901 model: *const TractRunnable,
902 inputs: *mut usize,
903) -> TRACT_RESULT {
904 wrap(|| unsafe {
905 check_not_null!(model, inputs);
906 let model = &(*model).0;
907 *inputs = model.input_count()?;
908 Ok(())
909 })
910}
911
912#[unsafe(no_mangle)]
914pub unsafe extern "C" fn tract_runnable_output_count(
915 model: *const TractRunnable,
916 outputs: *mut usize,
917) -> TRACT_RESULT {
918 wrap(|| unsafe {
919 check_not_null!(model, outputs);
920 let model = &(*model).0;
921 *outputs = model.output_count()?;
922 Ok(())
923 })
924}
925
/// Destroys a runnable handle and sets `*runnable` to null. Fails if `runnable` or
/// `*runnable` is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_runnable_release(runnable: *mut *mut TractRunnable) -> TRACT_RESULT {
    release!(runnable)
}
930
/// Opaque handle wrapping a `tract_rs::Value` (a model input, output, state or
/// property value); handed to C as a boxed pointer.
pub struct TractValue(tract_rs::Value);
933
934#[unsafe(no_mangle)]
943pub unsafe extern "C" fn tract_value_from_bytes(
944 datum_type: DatumType,
945 rank: usize,
946 shape: *const usize,
947 data: *mut c_void,
948 value: *mut *mut TractValue,
949) -> TRACT_RESULT {
950 wrap(|| unsafe {
951 check_not_null!(value);
952 *value = std::ptr::null_mut();
953 let shape = std::slice::from_raw_parts(shape, rank);
954 let len = shape.iter().product::<usize>();
955 let data = std::slice::from_raw_parts(data as *const u8, len * datum_type.size_of());
956 let it = Value::from_bytes(datum_type, shape, data)?;
957 *value = Box::into_raw(Box::new(TractValue(it)));
958 Ok(())
959 })
960}
961
/// Destroys a value handle and sets `*value` to null. Fails if `value` or `*value`
/// is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_value_destroy(value: *mut *mut TractValue) -> TRACT_RESULT {
    release!(value)
}
967
968#[unsafe(no_mangle)]
971pub unsafe extern "C" fn tract_value_as_bytes(
972 value: *mut TractValue,
973 datum_type: *mut DatumType,
974 rank: *mut usize,
975 shape: *mut *const usize,
976 data: *mut *const c_void,
977) -> TRACT_RESULT {
978 wrap(|| unsafe {
979 check_not_null!(value);
980 let value = &(*value).0;
981 let bits = value.as_bytes()?;
982 if !datum_type.is_null() {
983 *datum_type = bits.0;
984 }
985 if !rank.is_null() {
986 *rank = bits.1.len();
987 }
988 if !shape.is_null() {
989 *shape = bits.1.as_ptr();
990 }
991 if !data.is_null() {
992 *data = bits.2.as_ptr() as _;
993 }
994 Ok(())
995 })
996}
997
/// Opaque handle wrapping a `tract_rs::State`, as produced by
/// `tract_runnable_spawn_state`; handed to C as a boxed pointer.
pub struct TractState(tract_rs::State);
1000
1001#[unsafe(no_mangle)]
1010pub unsafe extern "C" fn tract_state_run(
1011 state: *mut TractState,
1012 inputs: *mut *mut TractValue,
1013 outputs: *mut *mut TractValue,
1014) -> TRACT_RESULT {
1015 wrap(|| unsafe {
1016 check_not_null!(state, inputs, outputs);
1017 state_run(&mut (*state).0, inputs, outputs)
1018 })
1019}
1020
1021#[unsafe(no_mangle)]
1023pub unsafe extern "C" fn tract_state_input_count(
1024 state: *const TractState,
1025 inputs: *mut usize,
1026) -> TRACT_RESULT {
1027 wrap(|| unsafe {
1028 check_not_null!(state, inputs);
1029 let state = &(*state).0;
1030 *inputs = state.input_count()?;
1031 Ok(())
1032 })
1033}
1034
1035#[unsafe(no_mangle)]
1037pub unsafe extern "C" fn tract_state_output_count(
1038 state: *const TractState,
1039 outputs: *mut usize,
1040) -> TRACT_RESULT {
1041 wrap(|| unsafe {
1042 check_not_null!(state, outputs);
1043 let state = &(*state).0;
1044 *outputs = state.output_count()?;
1045 Ok(())
1046 })
1047}
1048
/// Destroys a state handle and sets `*state` to null. Fails if `state` or `*state`
/// is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_state_destroy(state: *mut *mut TractState) -> TRACT_RESULT {
    release!(state)
}
1053
1054#[unsafe(no_mangle)]
1056pub unsafe extern "C" fn tract_state_initializable_states_count(
1057 state: *const TractState,
1058 n_states: *mut usize,
1059) -> TRACT_RESULT {
1060 wrap(|| unsafe {
1061 check_not_null!(state, n_states);
1062 let state = &(*state).0;
1063 *n_states = state.initializable_states_count()?;
1064 Ok(())
1065 })
1066}
1067
1068#[unsafe(no_mangle)]
1070pub unsafe extern "C" fn tract_state_get_states_facts(
1071 state: *const TractState,
1072 states: *mut *mut TractFact,
1073) -> TRACT_RESULT {
1074 wrap(|| unsafe {
1075 check_not_null!(state, states);
1076 let state = &(*state).0;
1077
1078 let state_vec = state.get_states_facts()?;
1079 for (ix, f) in state_vec.into_iter().enumerate() {
1080 *states.add(ix) = Box::into_raw(Box::new(TractFact(f)));
1081 }
1082 Ok(())
1083 })
1084}
1085
1086#[unsafe(no_mangle)]
1088pub unsafe extern "C" fn tract_state_set_states(
1089 state: *mut TractState,
1090 states: *const *const TractValue,
1091) -> TRACT_RESULT {
1092 wrap(|| unsafe {
1093 check_not_null!(state, states);
1094 let state = &mut (*state).0;
1095
1096 let n_states = state.initializable_states_count()?;
1097 let state_initializers: Vec<Value> =
1098 std::slice::from_raw_parts(states, n_states).iter()
1099 .map(|tv| {
1100 (**tv).0.clone()
1101 }).collect();
1102 state.set_states(state_initializers)?;
1103 Ok(())
1104 })
1105}
1106
1107#[unsafe(no_mangle)]
1109pub unsafe extern "C" fn tract_state_get_states(
1110 state: *const TractState,
1111 states: *mut *mut TractValue
1112) -> TRACT_RESULT {
1113 wrap(|| unsafe {
1114 let state = &(*state).0;
1115
1116 let state_vec = state.get_states()?;
1117 for (ix, s) in state_vec.into_iter().enumerate() {
1118 *states.add(ix) = Box::into_raw(Box::new(TractValue(s)));
1119 }
1120 Ok(())
1121 })
1122}
1123
/// Opaque handle wrapping a typed-model `tract_rs::Fact`; handed to C as a boxed pointer.
pub struct TractFact(tract_rs::Fact);
1126
1127#[unsafe(no_mangle)]
1131pub unsafe extern "C" fn tract_fact_parse(
1132 model: *mut TractModel,
1133 spec: *const c_char,
1134 fact: *mut *mut TractFact,
1135) -> TRACT_RESULT {
1136 wrap(|| unsafe {
1137 check_not_null!(model, spec, fact);
1138 let spec = CStr::from_ptr(spec).to_str()?;
1139 let f: tract_rs::Fact = spec.as_fact(&mut (*model).0)?.as_ref().clone();
1140 *fact = Box::into_raw(Box::new(TractFact(f)));
1141 Ok(())
1142 })
1143}
1144
1145#[unsafe(no_mangle)]
1149pub unsafe extern "C" fn tract_fact_dump(
1150 fact: *const TractFact,
1151 spec: *mut *mut c_char,
1152) -> TRACT_RESULT {
1153 wrap(|| unsafe {
1154 check_not_null!(fact, spec);
1155 *spec = CString::new(format!("{}", (*fact).0))?.into_raw();
1156 Ok(())
1157 })
1158}
1159
/// Destroys a fact handle and sets `*fact` to null. Fails if `fact` or `*fact` is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_fact_destroy(fact: *mut *mut TractFact) -> TRACT_RESULT {
    release!(fact)
}
1164
/// Opaque handle wrapping a `tract_rs::InferenceFact` (a possibly partial fact for an
/// inference model); handed to C as a boxed pointer.
pub struct TractInferenceFact(tract_rs::InferenceFact);
1167
1168#[unsafe(no_mangle)]
1172pub unsafe extern "C" fn tract_inference_fact_parse(
1173 model: *mut TractInferenceModel,
1174 spec: *const c_char,
1175 fact: *mut *mut TractInferenceFact,
1176) -> TRACT_RESULT {
1177 wrap(|| unsafe {
1178 check_not_null!(model, spec, fact);
1179 let spec = CStr::from_ptr(spec).to_str()?;
1180 let f: tract_rs::InferenceFact = spec.as_fact(&mut (*model).0)?.as_ref().clone();
1181 *fact = Box::into_raw(Box::new(TractInferenceFact(f)));
1182 Ok(())
1183 })
1184}
1185
1186#[unsafe(no_mangle)]
1190pub unsafe extern "C" fn tract_inference_fact_empty(
1191 fact: *mut *mut TractInferenceFact,
1192) -> TRACT_RESULT {
1193 wrap(|| unsafe {
1194 check_not_null!(fact);
1195 *fact = Box::into_raw(Box::new(TractInferenceFact(Default::default())));
1196 Ok(())
1197 })
1198}
1199
1200#[unsafe(no_mangle)]
1204pub unsafe extern "C" fn tract_inference_fact_dump(
1205 fact: *const TractInferenceFact,
1206 spec: *mut *mut c_char,
1207) -> TRACT_RESULT {
1208 wrap(|| unsafe {
1209 check_not_null!(fact, spec);
1210 *spec = CString::new(format!("{}", (*fact).0))?.into_raw();
1211 Ok(())
1212 })
1213}
1214
/// Destroys an inference fact handle and sets `*fact` to null. Fails if `fact` or
/// `*fact` is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_inference_fact_destroy(
    fact: *mut *mut TractInferenceFact,
) -> TRACT_RESULT {
    release!(fact)
}
1222
1223unsafe fn state_run(
1228 state: &mut State,
1229 inputs: *mut *mut TractValue,
1230 outputs: *mut *mut TractValue,
1231) -> Result<()> {
1232 unsafe {
1233 let values: Vec<_> = std::slice::from_raw_parts(inputs, state.input_count()?)
1234 .iter()
1235 .map(|tv| (**tv).0.clone())
1236 .collect();
1237 let values = state.run(values)?;
1238 for (i, value) in values.into_iter().enumerate() {
1239 *(outputs.add(i)) = Box::into_raw(Box::new(TractValue(value)))
1240 }
1241 Ok(())
1242 }
1243}