tract/
lib.rs

1#![allow(clippy::missing_safety_doc)]
2
3use anyhow::{Context, Result};
4use std::cell::RefCell;
5use std::ffi::{c_char, c_void, CStr, CString};
6use tract_api::{
7    AsFact, DatumType, InferenceModelInterface, ModelInterface, NnefInterface, OnnxInterface,
8    RunnableInterface, StateInterface, ValueInterface,
9};
10use tract_rs::{State, Value};
11
/// Used as a return type of functions that can encounter errors.
///
/// If the function encountered an error, you can retrieve it using the `tract_get_last_error`
/// function.
// `repr(C)` plus explicit discriminants pin the numeric values (0 and 1) seen by C callers.
#[repr(C)]
#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq)]
pub enum TRACT_RESULT {
    /// The function returned successfully
    TRACT_RESULT_OK = 0,
    /// The function returned an error
    TRACT_RESULT_KO = 1,
}
24
thread_local! {
    /// Last error message recorded on the current thread by `wrap`, if any.
    ///
    /// Being thread-local, each thread only ever sees its own failures. The message is exposed
    /// to callers through `tract_get_last_error`.
    pub(crate) static LAST_ERROR: RefCell<Option<CString>> = const { RefCell::new(None) };
}
28
29fn wrap<F: FnOnce() -> anyhow::Result<()>>(func: F) -> TRACT_RESULT {
30    match func() {
31        Ok(_) => TRACT_RESULT::TRACT_RESULT_OK,
32        Err(e) => {
33            let msg = format!("{e:?}");
34            if std::env::var("TRACT_ERROR_STDERR").is_ok() {
35                eprintln!("{msg}");
36            }
37            LAST_ERROR.with(|p| {
38                *p.borrow_mut() = Some(CString::new(msg).unwrap_or_else(|_| {
39                    CString::new("tract error message contains 0, can't convert to CString")
40                        .unwrap()
41                }))
42            });
43            TRACT_RESULT::TRACT_RESULT_KO
44        }
45    }
46}
47
48/// Retrieve the last error that happened in this thread. A function encountered an error if
49/// its return type is of type `TRACT_RESULT` and it returned `TRACT_RESULT_KO`.
50///
51/// # Return value
52///  It returns a pointer to a null-terminated UTF-8 string that will contain the error description.
53///  Rust side keeps ownership of the buffer. It will be valid as long as no other tract calls is
54///  performed by the thread.
55///  If no error occured, null is returned.
56#[unsafe(no_mangle)]
57pub extern "C" fn tract_get_last_error() -> *const std::ffi::c_char {
58    LAST_ERROR.with(|msg| msg.borrow().as_ref().map(|s| s.as_ptr()).unwrap_or(std::ptr::null()))
59}
60
61/// Returns a pointer to a static buffer containing a null-terminated version string.
62///
63/// The returned pointer must not be freed.
64#[unsafe(no_mangle)]
65pub extern "C" fn tract_version() -> *const std::ffi::c_char {
66    unsafe {
67        CStr::from_bytes_with_nul_unchecked(concat!(env!("CARGO_PKG_VERSION"), "\0").as_bytes())
68            .as_ptr()
69    }
70}
71
/// Frees a string allocated by libtract.
///
/// `ptr` must come from a libtract function documented as returning a string to be freed with
/// this function. Passing null is accepted and does nothing.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_free_cstring(ptr: *mut std::ffi::c_char) {
    if ptr.is_null() {
        return;
    }
    // Rebuilding the `CString` hands ownership back to Rust; dropping it releases the buffer.
    drop(unsafe { CString::from_raw(ptr) });
}
81
/// Checks each pointer expression in turn and bails out of the enclosing `anyhow`-returning
/// function or closure with "Unexpected null pointer <expr>" at the first null one.
macro_rules! check_not_null {
    ($($candidate:expr),*) => {
        $(
            if $candidate.is_null() {
                anyhow::bail!(concat!("Unexpected null pointer ", stringify!($candidate)));
            }
        )*
    };
}
91
/// Destroys the boxed object behind a `*mut *mut T` handle and resets the handle to null.
///
/// Expands to a `wrap` call: it fails when the handle or the pointer it holds is null.
macro_rules! release {
    ($handle:expr) => {
        wrap(|| unsafe {
            check_not_null!($handle, *$handle);
            // Taking the pointer back into a `Box` and dropping it frees the object.
            drop(Box::from_raw(*$handle));
            *$handle = std::ptr::null_mut();
            Ok(())
        })
    };
}
102
// NNEF
/// Handle wrapping a `tract_rs::Nnef`; C callers only ever hold it by pointer.
pub struct TractNnef(tract_rs::Nnef);
105
106/// Creates an instance of an NNEF framework and parser that can be used to load and dump NNEF models.
107///
108/// The returned object should be destroyed with `tract_nnef_destroy` once the model
109/// has been loaded.
110#[unsafe(no_mangle)]
111pub unsafe extern "C" fn tract_nnef_create(nnef: *mut *mut TractNnef) -> TRACT_RESULT {
112    wrap(|| unsafe {
113        check_not_null!(nnef);
114        *nnef = Box::into_raw(Box::new(TractNnef(tract_rs::nnef()?)));
115        Ok(())
116    })
117}
118
119#[unsafe(no_mangle)]
120pub unsafe extern "C" fn tract_nnef_transform_model(
121    nnef: *const TractNnef,
122    model: *mut TractModel,
123    transform_spec: *const i8,
124) -> TRACT_RESULT {
125    wrap(|| unsafe {
126        check_not_null!(nnef, model, transform_spec);
127        let transform_spec = CStr::from_ptr(transform_spec as _).to_str()?;
128        (*nnef)
129            .0
130            .transform_model(&mut (*model).0, transform_spec)
131            .with_context(|| format!("performing transform {transform_spec:?}"))?;
132        Ok(())
133    })
134}
135
136#[unsafe(no_mangle)]
137pub unsafe extern "C" fn tract_nnef_enable_tract_core(nnef: *mut TractNnef) -> TRACT_RESULT {
138    wrap(|| unsafe {
139        check_not_null!(nnef);
140        (*nnef).0.enable_tract_core()
141    })
142}
143
144#[unsafe(no_mangle)]
145pub unsafe extern "C" fn tract_nnef_enable_tract_extra(nnef: *mut TractNnef) -> TRACT_RESULT {
146    wrap(|| unsafe {
147        check_not_null!(nnef);
148        (*nnef).0.enable_tract_extra()
149    })
150}
151
152#[unsafe(no_mangle)]
153pub unsafe extern "C" fn tract_nnef_enable_tract_transformers(
154    nnef: *mut TractNnef,
155) -> TRACT_RESULT {
156    wrap(|| unsafe {
157        check_not_null!(nnef);
158        (*nnef).0.enable_tract_transformers()
159    })
160}
161
162#[unsafe(no_mangle)]
163pub unsafe extern "C" fn tract_nnef_enable_onnx(nnef: *mut TractNnef) -> TRACT_RESULT {
164    wrap(|| unsafe {
165        check_not_null!(nnef);
166        (*nnef).0.enable_onnx()
167    })
168}
169
170#[unsafe(no_mangle)]
171pub unsafe extern "C" fn tract_nnef_enable_pulse(nnef: *mut TractNnef) -> TRACT_RESULT {
172    wrap(|| unsafe {
173        check_not_null!(nnef);
174        (*nnef).0.enable_pulse()
175    })
176}
177
178#[unsafe(no_mangle)]
179pub unsafe extern "C" fn tract_nnef_enable_extended_identifier_syntax(
180    nnef: *mut TractNnef,
181) -> TRACT_RESULT {
182    wrap(|| unsafe {
183        check_not_null!(nnef);
184        (*nnef).0.enable_extended_identifier_syntax()
185    })
186}
187
/// Destroy the NNEF parser. It is safe to destroy the NNEF parser once the model has been loaded.
///
/// `*nnef` is reset to null. Fails if `nnef` or `*nnef` is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_nnef_destroy(nnef: *mut *mut TractNnef) -> TRACT_RESULT {
    release!(nnef)
}
193
194/// Parse and load an NNEF model as a tract TypedModel.
195///
196/// `path` is a null-terminated utf-8 string pointer. It can be an archive (tar or tar.gz file) or a
197/// directory.
198#[unsafe(no_mangle)]
199pub unsafe extern "C" fn tract_nnef_model_for_path(
200    nnef: *const TractNnef,
201    path: *const c_char,
202    model: *mut *mut TractModel,
203) -> TRACT_RESULT {
204    wrap(|| unsafe {
205        check_not_null!(nnef, model, path);
206        *model = std::ptr::null_mut();
207        let path = CStr::from_ptr(path).to_str()?;
208        let m = Box::new(TractModel(
209            (*nnef).0.model_for_path(path).with_context(|| format!("opening file {path:?}"))?,
210        ));
211        *model = Box::into_raw(m);
212        Ok(())
213    })
214}
215
216/// Dump a TypedModel as a NNEF tar file.
217///
218/// `path` is a null-terminated utf-8 string pointer to the `.tar` file to be created.
219///
220/// This function creates a plain, non-compressed, archive.
221#[unsafe(no_mangle)]
222pub unsafe extern "C" fn tract_nnef_write_model_to_tar(
223    nnef: *const TractNnef,
224    path: *const c_char,
225    model: *const TractModel,
226) -> TRACT_RESULT {
227    wrap(|| unsafe {
228        check_not_null!(nnef, model, path);
229        let path = CStr::from_ptr(path).to_str()?;
230        (*nnef).0.write_model_to_tar(path, &(*model).0)?;
231        Ok(())
232    })
233}
234
235/// Dump a TypedModel as a NNEF .tar.gz file.
236///
237/// `path` is a null-terminated utf-8 string pointer to the `.tar.gz` file to be created.
238#[unsafe(no_mangle)]
239pub unsafe extern "C" fn tract_nnef_write_model_to_tar_gz(
240    nnef: *const TractNnef,
241    path: *const c_char,
242    model: *const TractModel,
243) -> TRACT_RESULT {
244    wrap(|| unsafe {
245        check_not_null!(nnef, model, path);
246        let path = CStr::from_ptr(path).to_str()?;
247        (*nnef).0.write_model_to_tar_gz(path, &(*model).0)?;
248        Ok(())
249    })
250}
251
252/// Dump a TypedModel as a NNEF directory.
253///
254/// `path` is a null-terminated utf-8 string pointer to the directory to be created.
255///
256/// This function creates a plain, non-compressed, archive.
257#[unsafe(no_mangle)]
258pub unsafe extern "C" fn tract_nnef_write_model_to_dir(
259    nnef: *const TractNnef,
260    path: *const c_char,
261    model: *const TractModel,
262) -> TRACT_RESULT {
263    wrap(|| unsafe {
264        check_not_null!(nnef, model, path);
265        let path = CStr::from_ptr(path).to_str()?;
266        (*nnef).0.write_model_to_dir(path, &(*model).0)?;
267        Ok(())
268    })
269}
270
// ONNX
/// Handle wrapping a `tract_rs::Onnx`; C callers only ever hold it by pointer.
pub struct TractOnnx(tract_rs::Onnx);
273
274/// Creates an instance of an ONNX framework and parser that can be used to load models.
275///
276/// The returned object should be destroyed with `tract_nnef_destroy` once the model
277/// has been loaded.
278#[unsafe(no_mangle)]
279pub unsafe extern "C" fn tract_onnx_create(onnx: *mut *mut TractOnnx) -> TRACT_RESULT {
280    wrap(|| unsafe {
281        check_not_null!(onnx);
282        *onnx = Box::into_raw(Box::new(TractOnnx(tract_rs::onnx()?)));
283        Ok(())
284    })
285}
286
/// Destroy the ONNX parser. It is safe to destroy the ONNX parser once the model has been loaded.
///
/// `*onnx` is reset to null. Fails if `onnx` or `*onnx` is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_onnx_destroy(onnx: *mut *mut TractOnnx) -> TRACT_RESULT {
    release!(onnx)
}
292
293/// Parse and load an ONNX model as a tract InferenceModel.
294///
295/// `path` is a null-terminated utf-8 string pointer. It must point to a `.onnx` model file.
296#[unsafe(no_mangle)]
297pub unsafe extern "C" fn tract_onnx_model_for_path(
298    onnx: *const TractOnnx,
299    path: *const c_char,
300    model: *mut *mut TractInferenceModel,
301) -> TRACT_RESULT {
302    wrap(|| unsafe {
303        check_not_null!(onnx, path, model);
304        *model = std::ptr::null_mut();
305        let path = CStr::from_ptr(path).to_str()?;
306        let m = Box::new(TractInferenceModel((*onnx).0.model_for_path(path)?));
307        *model = Box::into_raw(m);
308        Ok(())
309    })
310}
311
// INFERENCE MODEL
/// Handle wrapping a `tract_rs::InferenceModel`; C callers only ever hold it by pointer.
pub struct TractInferenceModel(tract_rs::InferenceModel);
314
315/// Query an InferenceModel input counts.
316#[unsafe(no_mangle)]
317pub unsafe extern "C" fn tract_inference_model_input_count(
318    model: *const TractInferenceModel,
319    inputs: *mut usize,
320) -> TRACT_RESULT {
321    wrap(|| unsafe {
322        check_not_null!(model, inputs);
323        let model = &(*model).0;
324        *inputs = model.input_count()?;
325        Ok(())
326    })
327}
328
329/// Query an InferenceModel output counts.
330#[unsafe(no_mangle)]
331pub unsafe extern "C" fn tract_inference_model_output_count(
332    model: *const TractInferenceModel,
333    outputs: *mut usize,
334) -> TRACT_RESULT {
335    wrap(|| unsafe {
336        check_not_null!(model, outputs);
337        let model = &(*model).0;
338        *outputs = model.output_count()?;
339        Ok(())
340    })
341}
342
343/// Query the name of a model input.
344///
345/// The returned name must be freed by the caller using tract_free_cstring.
346#[unsafe(no_mangle)]
347pub unsafe extern "C" fn tract_inference_model_input_name(
348    model: *const TractInferenceModel,
349    input: usize,
350    name: *mut *mut c_char,
351) -> TRACT_RESULT {
352    wrap(|| unsafe {
353        check_not_null!(model, name);
354        *name = std::ptr::null_mut();
355        let m = &(*model).0;
356        *name = CString::new(&*m.input_name(input)?)?.into_raw();
357        Ok(())
358    })
359}
360
361/// Query the name of a model output.
362///
363/// The returned name must be freed by the caller using tract_free_cstring.
364#[unsafe(no_mangle)]
365pub unsafe extern "C" fn tract_inference_model_output_name(
366    model: *const TractInferenceModel,
367    output: usize,
368    name: *mut *mut i8,
369) -> TRACT_RESULT {
370    wrap(|| unsafe {
371        check_not_null!(model, name);
372        *name = std::ptr::null_mut();
373        let m = &(*model).0;
374        *name = CString::new(&*m.output_name(output)?)?.into_raw() as _;
375        Ok(())
376    })
377}
378
379#[unsafe(no_mangle)]
380pub unsafe extern "C" fn tract_inference_model_input_fact(
381    model: *const TractInferenceModel,
382    input_id: usize,
383    fact: *mut *mut TractInferenceFact,
384) -> TRACT_RESULT {
385    wrap(|| unsafe {
386        check_not_null!(model, fact);
387        *fact = std::ptr::null_mut();
388        let f = (*model).0.input_fact(input_id)?;
389        *fact = Box::into_raw(Box::new(TractInferenceFact(f)));
390        Ok(())
391    })
392}
393
394/// Set an input fact of an InferenceModel.
395///
396/// The `fact` argument is only borrowed by this function, it still must be destroyed.
397/// `fact` can be set to NULL to erase the current output fact of the model.
398#[unsafe(no_mangle)]
399pub unsafe extern "C" fn tract_inference_model_set_input_fact(
400    model: *mut TractInferenceModel,
401    input_id: usize,
402    fact: *const TractInferenceFact,
403) -> TRACT_RESULT {
404    wrap(|| unsafe {
405        check_not_null!(model);
406        let f = fact.as_ref().map(|f| &f.0).cloned().unwrap_or_default();
407        (*model).0.set_input_fact(input_id, f)?;
408        Ok(())
409    })
410}
411
412/// Change the model outputs nodes (by name).
413///
414/// `names` is an array containing `len` pointers to null terminated strings.
415#[unsafe(no_mangle)]
416pub unsafe extern "C" fn tract_inference_model_set_output_names(
417    model: *mut TractInferenceModel,
418    len: usize,
419    names: *const *const c_char,
420) -> TRACT_RESULT {
421    wrap(|| unsafe {
422        check_not_null!(model, names, *names);
423        let node_names = (0..len)
424            .map(|i| Ok(CStr::from_ptr(*names.add(i)).to_str()?.to_owned()))
425            .collect::<Result<Vec<_>>>()?;
426        (*model).0.set_output_names(&node_names)?;
427        Ok(())
428    })
429}
430
431/// Query an output fact for an InferenceModel.
432///
433/// The return model must be freed using `tract_inference_fact_destroy`.
434#[unsafe(no_mangle)]
435pub unsafe extern "C" fn tract_inference_model_output_fact(
436    model: *const TractInferenceModel,
437    output_id: usize,
438    fact: *mut *mut TractInferenceFact,
439) -> TRACT_RESULT {
440    wrap(|| unsafe {
441        check_not_null!(model, fact);
442        *fact = std::ptr::null_mut();
443        let f = (*model).0.output_fact(output_id)?;
444        *fact = Box::into_raw(Box::new(TractInferenceFact(f)));
445        Ok(())
446    })
447}
448
449/// Set an output fact of an InferenceModel.
450///
451/// The `fact` argument is only borrowed by this function, it still must be destroyed.
452/// `fact` can be set to NULL to erase the current output fact of the model.
453#[unsafe(no_mangle)]
454pub unsafe extern "C" fn tract_inference_model_set_output_fact(
455    model: *mut TractInferenceModel,
456    output_id: usize,
457    fact: *const TractInferenceFact,
458) -> TRACT_RESULT {
459    wrap(|| unsafe {
460        check_not_null!(model);
461        let f = fact.as_ref().map(|f| &f.0).cloned().unwrap_or_default();
462        (*model).0.set_output_fact(output_id, f)?;
463        Ok(())
464    })
465}
466
467/// Analyse an InferencedModel in-place.
468#[unsafe(no_mangle)]
469pub unsafe extern "C" fn tract_inference_model_analyse(
470    model: *mut TractInferenceModel,
471) -> TRACT_RESULT {
472    wrap(|| unsafe {
473        check_not_null!(model);
474        (*model).0.analyse()?;
475        Ok(())
476    })
477}
478
479/// Convenience function to obtain an optimized TypedModel from an InferenceModel.
480///
481/// This function takes ownership of the InferenceModel `model` whether it succeeds
482/// or not. `tract_inference_model_destroy` must not be used on `model`.
483///
484/// On the other hand, caller will be owning the newly created `optimized` model.
485#[unsafe(no_mangle)]
486pub unsafe extern "C" fn tract_inference_model_into_optimized(
487    model: *mut *mut TractInferenceModel,
488    optimized: *mut *mut TractModel,
489) -> TRACT_RESULT {
490    wrap(|| unsafe {
491        check_not_null!(model, *model, optimized);
492        *optimized = std::ptr::null_mut();
493        let m = Box::from_raw(*model);
494        *model = std::ptr::null_mut();
495        let result = m.0.into_optimized()?;
496        *optimized = Box::into_raw(Box::new(TractModel(result))) as _;
497        Ok(())
498    })
499}
500
501/// Transform a fully analysed InferenceModel to a TypedModel.
502///
503/// This function takes ownership of the InferenceModel `model` whether it succeeds
504/// or not. `tract_inference_model_destroy` must not be used on `model`.
505///
506/// On the other hand, caller will be owning the newly created `optimized` model.
507#[unsafe(no_mangle)]
508pub unsafe extern "C" fn tract_inference_model_into_typed(
509    model: *mut *mut TractInferenceModel,
510    typed: *mut *mut TractModel,
511) -> TRACT_RESULT {
512    wrap(|| unsafe {
513        check_not_null!(model, *model, typed);
514        *typed = std::ptr::null_mut();
515        let m = Box::from_raw(*model);
516        *model = std::ptr::null_mut();
517        let result = m.0.into_typed()?;
518        *typed = Box::into_raw(Box::new(TractModel(result))) as _;
519        Ok(())
520    })
521}
522
/// Destroy an InferenceModel.
///
/// `*model` is reset to null. Fails if `model` or `*model` is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_inference_model_destroy(
    model: *mut *mut TractInferenceModel,
) -> TRACT_RESULT {
    release!(model)
}
// TYPED MODEL

/// Handle wrapping a `tract_rs::Model`; C callers only ever hold it by pointer.
pub struct TractModel(tract_rs::Model);
533
534/// Query an InferenceModel input counts.
535#[unsafe(no_mangle)]
536pub unsafe extern "C" fn tract_model_input_count(
537    model: *const TractModel,
538    inputs: *mut usize,
539) -> TRACT_RESULT {
540    wrap(|| unsafe {
541        check_not_null!(model, inputs);
542        let model = &(*model).0;
543        *inputs = model.input_count()?;
544        Ok(())
545    })
546}
547
548/// Query an InferenceModel output counts.
549#[unsafe(no_mangle)]
550pub unsafe extern "C" fn tract_model_output_count(
551    model: *const TractModel,
552    outputs: *mut usize,
553) -> TRACT_RESULT {
554    wrap(|| unsafe {
555        check_not_null!(model, outputs);
556        let model = &(*model).0;
557        *outputs = model.output_count()?;
558        Ok(())
559    })
560}
561
562/// Query the name of a model input.
563///
564/// The returned name must be freed by the caller using tract_free_cstring.
565#[unsafe(no_mangle)]
566pub unsafe extern "C" fn tract_model_input_name(
567    model: *const TractModel,
568    input: usize,
569    name: *mut *mut c_char,
570) -> TRACT_RESULT {
571    wrap(|| unsafe {
572        check_not_null!(model, name);
573        *name = std::ptr::null_mut();
574        let m = &(*model).0;
575        *name = CString::new(m.input_name(input)?)?.into_raw();
576        Ok(())
577    })
578}
579
580/// Query the input fact of a model.
581///
582/// Thre returned fact must be freed with tract_fact_destroy.
583#[unsafe(no_mangle)]
584pub unsafe extern "C" fn tract_model_input_fact(
585    model: *const TractModel,
586    input_id: usize,
587    fact: *mut *mut TractFact,
588) -> TRACT_RESULT {
589    wrap(|| unsafe {
590        check_not_null!(model, fact);
591        *fact = std::ptr::null_mut();
592        let f = (*model).0.input_fact(input_id)?;
593        *fact = Box::into_raw(Box::new(TractFact(f)));
594        Ok(())
595    })
596}
597
598/// Query the name of a model output.
599///
600/// The returned name must be freed by the caller using tract_free_cstring.
601#[unsafe(no_mangle)]
602pub unsafe extern "C" fn tract_model_output_name(
603    model: *const TractModel,
604    output: usize,
605    name: *mut *mut c_char,
606) -> TRACT_RESULT {
607    wrap(|| unsafe {
608        check_not_null!(model, name);
609        *name = std::ptr::null_mut();
610        let m = &(*model).0;
611        *name = CString::new(m.output_name(output)?)?.into_raw();
612        Ok(())
613    })
614}
615
616/// Query the output fact of a model.
617///
618/// Thre returned fact must be freed with tract_fact_destroy.
619#[unsafe(no_mangle)]
620pub unsafe extern "C" fn tract_model_output_fact(
621    model: *const TractModel,
622    input_id: usize,
623    fact: *mut *mut TractFact,
624) -> TRACT_RESULT {
625    wrap(|| unsafe {
626        check_not_null!(model, fact);
627        *fact = std::ptr::null_mut();
628        let f = (*model).0.output_fact(input_id)?;
629        *fact = Box::into_raw(Box::new(TractFact(f)));
630        Ok(())
631    })
632}
633
634/// Change the model outputs nodes (by name).
635///
636/// `names` is an array containing `len` pointers to null terminated strings.
637#[unsafe(no_mangle)]
638pub unsafe extern "C" fn tract_model_set_output_names(
639    model: *mut TractModel,
640    len: usize,
641    names: *const *const c_char,
642) -> TRACT_RESULT {
643    wrap(|| unsafe {
644        check_not_null!(model, names, *names);
645        let node_names = (0..len)
646            .map(|i| Ok(CStr::from_ptr(*names.add(i)).to_str()?.to_owned()))
647            .collect::<Result<Vec<_>>>()?;
648        (*model).0.set_output_names(&node_names)
649    })
650}
651
652/// Give value one or more symbols used in the model.
653///
654/// * symbols is an array of `nb_symbols` pointers to null-terminated UTF-8 string for the symbols
655///   names to substitue
656/// * values is an array of `nb_symbols` integer values
657#[unsafe(no_mangle)]
658pub unsafe extern "C" fn tract_model_concretize_symbols(
659    model: *mut TractModel,
660    nb_symbols: usize,
661    symbols: *const *const i8,
662    values: *const i64,
663) -> TRACT_RESULT {
664    wrap(|| unsafe {
665        check_not_null!(model, symbols, values);
666        let model = &mut (*model).0;
667        let mut table = vec![];
668        for i in 0..nb_symbols {
669            let name = CStr::from_ptr(*symbols.add(i) as _)
670                .to_str()
671                .with_context(|| {
672                    format!("failed to parse symbol name for {i}th symbol (not utf8)")
673                })?
674                .to_owned();
675            table.push((name, *values.add(i)));
676        }
677        model.concretize_symbols(table)
678    })
679}
680
681/// Pulsify the model
682///
683/// * stream_symbol is the name of the stream symbol
684/// * pulse expression is a dim to use as the pulse size (like "8", "P" or "3*p").
685#[unsafe(no_mangle)]
686pub unsafe extern "C" fn tract_model_pulse_simple(
687    model: *mut *mut TractModel,
688    stream_symbol: *const i8,
689    pulse_expr: *const i8,
690) -> TRACT_RESULT {
691    wrap(|| unsafe {
692        check_not_null!(model, *model, stream_symbol, pulse_expr);
693        let model = &mut (**model).0;
694        let stream_sym = CStr::from_ptr(stream_symbol as _)
695            .to_str()
696            .context("failed to parse stream symbol name (not utf8)")?;
697        let pulse_dim = CStr::from_ptr(pulse_expr as _)
698            .to_str()
699            .context("failed to parse stream symbol name (not utf8)")?;
700        model.pulse(stream_sym, pulse_dim)
701    })
702}
703
704/// Apply a transform to the model.
705#[unsafe(no_mangle)]
706pub unsafe extern "C" fn tract_model_transform(
707    model: *mut TractModel,
708    transform: *const i8,
709) -> TRACT_RESULT {
710    wrap(|| unsafe {
711        check_not_null!(model, transform);
712        let t = CStr::from_ptr(transform as _)
713            .to_str()
714            .context("failed to parse transform name (not utf8)")?;
715        (*model).0.transform(t)
716    })
717}
718
719/// Declutter a TypedModel in-place.
720#[unsafe(no_mangle)]
721pub unsafe extern "C" fn tract_model_declutter(model: *mut TractModel) -> TRACT_RESULT {
722    wrap(|| unsafe {
723        check_not_null!(model);
724        (*model).0.declutter()
725    })
726}
727
728/// Optimize a TypedModel in-place.
729#[unsafe(no_mangle)]
730pub unsafe extern "C" fn tract_model_optimize(model: *mut TractModel) -> TRACT_RESULT {
731    wrap(|| unsafe {
732        check_not_null!(model);
733        (*model).0.optimize()
734    })
735}
736
737/// Perform a profile of the model using the provided inputs.
738#[unsafe(no_mangle)]
739pub unsafe extern "C" fn tract_model_profile_json(
740    model: *mut TractModel,
741    inputs: *mut *mut TractValue,
742    states: *const *const TractValue,
743    n_states: usize,
744    json: *mut *mut i8,
745) -> TRACT_RESULT {
746    wrap(|| unsafe {
747        check_not_null!(model, json);
748
749        let input: Option<Vec<Value>> = if !inputs.is_null() {
750            let input_len = (*model).0.input_count()?;
751            Some(
752                std::slice::from_raw_parts(inputs, input_len)
753                    .iter()
754                    .map(|tv| (**tv).0.clone())
755                    .collect(),
756            )
757        } else {
758            None
759        };
760
761        let state_initializers: Option<Vec<Value>> = if !states.is_null() {
762            anyhow::ensure!(n_states != 0);
763            let hashmap = std::slice::from_raw_parts(states, n_states).iter()
764                    .map(|tv| {
765                        (**tv).0.clone()
766                    }).collect();
767            Some(hashmap)
768        } else { None };
769
770        let profile = (*model).0.profile_json(input, state_initializers)?;
771        *json = CString::new(profile)?.into_raw() as _;
772        Ok(())
773    })
774}
775
776/// Convert a TypedModel into a TypedRunnableModel.
777///
778/// This function transfers ownership of the `model` argument to the newly-created `runnable` model.
779///
780/// Runnable are reference counted. When done, it should be released with `tract_runnable_release`.
781#[unsafe(no_mangle)]
782pub unsafe extern "C" fn tract_model_into_runnable(
783    model: *mut *mut TractModel,
784    runnable: *mut *mut TractRunnable,
785) -> TRACT_RESULT {
786    wrap(|| unsafe {
787        check_not_null!(model, runnable);
788        let m = Box::from_raw(*model).0;
789        *model = std::ptr::null_mut();
790        *runnable = Box::into_raw(Box::new(TractRunnable(m.into_runnable()?))) as _;
791        Ok(())
792    })
793}
794
795/// Query the number of properties in a model.
796#[unsafe(no_mangle)]
797pub unsafe extern "C" fn tract_model_property_count(
798    model: *const TractModel,
799    count: *mut usize,
800) -> TRACT_RESULT {
801    wrap(|| unsafe {
802        check_not_null!(model, count);
803        *count = (*model).0.property_keys()?.len();
804        Ok(())
805    })
806}
807
808/// Query the properties names of a model.
809///
810/// The "names" array should be big enough to fit `tract_model_property_count` string pointers.
811///
812/// Each name will have to be freed using `tract_free_cstring`.
813#[unsafe(no_mangle)]
814pub unsafe extern "C" fn tract_model_property_names(
815    model: *const TractModel,
816    names: *mut *mut i8,
817) -> TRACT_RESULT {
818    wrap(|| unsafe {
819        check_not_null!(model, names);
820        for (ix, name) in (*model).0.property_keys()?.iter().enumerate() {
821            *names.add(ix) = CString::new(&**name)?.into_raw() as _;
822        }
823        Ok(())
824    })
825}
826
827/// Query a property value in a model.
828#[unsafe(no_mangle)]
829pub unsafe extern "C" fn tract_model_property(
830    model: *const TractModel,
831    name: *const i8,
832    value: *mut *mut TractValue,
833) -> TRACT_RESULT {
834    wrap(|| unsafe {
835        check_not_null!(model, name, value);
836        let name = CStr::from_ptr(name as _)
837            .to_str()
838            .context("failed to parse property name (not utf8)")?
839            .to_owned();
840        let v = (*model).0.property(name).context("Property not found")?;
841        *value = Box::into_raw(Box::new(TractValue(v)));
842        Ok(())
843    })
844}
845
/// Destroy a TypedModel.
///
/// `*model` is reset to null. Fails if `model` or `*model` is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_model_destroy(model: *mut *mut TractModel) -> TRACT_RESULT {
    release!(model)
}
851
// RUNNABLE MODEL
/// Handle wrapping a `tract_rs::Runnable`; C callers only ever hold it by pointer.
pub struct TractRunnable(tract_rs::Runnable);
854
855/// Spawn a session state from a runnable model.
856///
857/// This function does not take ownership of the `runnable` object, it can be used again to spawn
858/// other state instances. The runnable object is internally reference counted, it will be
859/// kept alive as long as any associated `State` exists (or as long as the `runnable` is not
860/// explicitely release with `tract_runnable_release`).
861///
862/// `state` is a newly-created object. It should ultimately be detroyed with `tract_state_destroy`.
863#[unsafe(no_mangle)]
864pub unsafe extern "C" fn tract_runnable_spawn_state(
865    runnable: *mut TractRunnable,
866    state: *mut *mut TractState,
867) -> TRACT_RESULT {
868    wrap(|| unsafe {
869        check_not_null!(runnable, state);
870        *state = std::ptr::null_mut();
871        let s = (*runnable).0.spawn_state()?;
872        *state = Box::into_raw(Box::new(TractState(s)));
873        Ok(())
874    })
875}
876
877/// Convenience function to run a stateless model.
878///
879/// `inputs` is a pointer to an pre-existing array of input TractValue. Its length *must* be equal
880/// to the number of inputs of the models. The function does not take ownership of the input
881/// values.
882/// `outputs` is a pointer to a pre-existing array of TractValue pointers that will be overwritten
883/// with pointers to outputs values. These values are under the responsiblity of the caller, it
884/// will have to release them with `tract_value_destroy`.
885#[unsafe(no_mangle)]
886pub unsafe extern "C" fn tract_runnable_run(
887    runnable: *mut TractRunnable,
888    inputs: *mut *mut TractValue,
889    outputs: *mut *mut TractValue,
890) -> TRACT_RESULT {
891    wrap(|| unsafe {
892        check_not_null!(runnable);
893        let mut s = (*runnable).0.spawn_state()?;
894        state_run(&mut s, inputs, outputs)
895    })
896}
897
898/// Query a Runnable input counts.
899#[unsafe(no_mangle)]
900pub unsafe extern "C" fn tract_runnable_input_count(
901    model: *const TractRunnable,
902    inputs: *mut usize,
903) -> TRACT_RESULT {
904    wrap(|| unsafe {
905        check_not_null!(model, inputs);
906        let model = &(*model).0;
907        *inputs = model.input_count()?;
908        Ok(())
909    })
910}
911
912/// Query an Runnable output counts.
913#[unsafe(no_mangle)]
914pub unsafe extern "C" fn tract_runnable_output_count(
915    model: *const TractRunnable,
916    outputs: *mut usize,
917) -> TRACT_RESULT {
918    wrap(|| unsafe {
919        check_not_null!(model, outputs);
920        let model = &(*model).0;
921        *outputs = model.output_count()?;
922        Ok(())
923    })
924}
925
/// Release the caller's handle on a runnable model.
///
/// `*runnable` is reset to null. Fails if `runnable` or `*runnable` is null.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn tract_runnable_release(runnable: *mut *mut TractRunnable) -> TRACT_RESULT {
    release!(runnable)
}
930
// VALUE
/// Handle wrapping a `tract_rs::Value`; C callers only ever hold it by pointer.
pub struct TractValue(tract_rs::Value);
933
934/// Create a TractValue (aka tensor) from caller data and metadata.
935///
936/// This call copies the data into tract space. All the pointers only need to be alive for the
937/// duration of the call.
938///
939/// rank is the number of dimensions of the tensor (i.e. the length of the shape vector).
940///
941/// The returned value must be destroyed by `tract_value_destroy`.
942#[unsafe(no_mangle)]
943pub unsafe extern "C" fn tract_value_from_bytes(
944    datum_type: DatumType,
945    rank: usize,
946    shape: *const usize,
947    data: *mut c_void,
948    value: *mut *mut TractValue,
949) -> TRACT_RESULT {
950    wrap(|| unsafe {
951        check_not_null!(value);
952        *value = std::ptr::null_mut();
953        let shape = std::slice::from_raw_parts(shape, rank);
954        let len = shape.iter().product::<usize>();
955        let data = std::slice::from_raw_parts(data as *const u8, len * datum_type.size_of());
956        let it = Value::from_bytes(datum_type, shape, data)?;
957        *value = Box::into_raw(Box::new(TractValue(it)));
958        Ok(())
959    })
960}
961
962/// Destroy a value.
963#[unsafe(no_mangle)]
964pub unsafe extern "C" fn tract_value_destroy(value: *mut *mut TractValue) -> TRACT_RESULT {
965    release!(value)
966}
967
968/// Inspect part of a value. Except `value`, all argument pointers can be null if only some specific bits
969/// are required.
970#[unsafe(no_mangle)]
971pub unsafe extern "C" fn tract_value_as_bytes(
972    value: *mut TractValue,
973    datum_type: *mut DatumType,
974    rank: *mut usize,
975    shape: *mut *const usize,
976    data: *mut *const c_void,
977) -> TRACT_RESULT {
978    wrap(|| unsafe {
979        check_not_null!(value);
980        let value = &(*value).0;
981        let bits = value.as_bytes()?;
982        if !datum_type.is_null() {
983            *datum_type = bits.0;
984        }
985        if !rank.is_null() {
986            *rank = bits.1.len();
987        }
988        if !shape.is_null() {
989            *shape = bits.1.as_ptr();
990        }
991        if !data.is_null() {
992            *data = bits.2.as_ptr() as _;
993        }
994        Ok(())
995    })
996}
997
998// STATE
999pub struct TractState(tract_rs::State);
1000
1001/// Run a turn on a model state
1002///
1003/// `inputs` is a pointer to an pre-existing array of input TractValue. Its length *must* be equal
1004/// to the number of inputs of the models. The function does not take ownership of the input
1005/// values.
1006/// `outputs` is a pointer to a pre-existing array of TractValue pointers that will be overwritten
1007/// with pointers to outputs values. These values are under the responsiblity of the caller, it
1008/// will have to release them with `tract_value_destroy`.
1009#[unsafe(no_mangle)]
1010pub unsafe extern "C" fn tract_state_run(
1011    state: *mut TractState,
1012    inputs: *mut *mut TractValue,
1013    outputs: *mut *mut TractValue,
1014) -> TRACT_RESULT {
1015    wrap(|| unsafe {
1016        check_not_null!(state, inputs, outputs);
1017        state_run(&mut (*state).0, inputs, outputs)
1018    })
1019}
1020
1021/// Query a State input counts.
1022#[unsafe(no_mangle)]
1023pub unsafe extern "C" fn tract_state_input_count(
1024    state: *const TractState,
1025    inputs: *mut usize,
1026) -> TRACT_RESULT {
1027    wrap(|| unsafe {
1028        check_not_null!(state, inputs);
1029        let state = &(*state).0;
1030        *inputs = state.input_count()?;
1031        Ok(())
1032    })
1033}
1034
1035/// Query an State output counts.
1036#[unsafe(no_mangle)]
1037pub unsafe extern "C" fn tract_state_output_count(
1038    state: *const TractState,
1039    outputs: *mut usize,
1040) -> TRACT_RESULT {
1041    wrap(|| unsafe {
1042        check_not_null!(state, outputs);
1043        let state = &(*state).0;
1044        *outputs = state.output_count()?;
1045        Ok(())
1046    })
1047}
1048
1049#[unsafe(no_mangle)]
1050pub unsafe extern "C" fn tract_state_destroy(state: *mut *mut TractState) -> TRACT_RESULT {
1051    release!(state)
1052}
1053
1054/// Get number of initializable stateful op
1055#[unsafe(no_mangle)]
1056pub unsafe extern "C" fn tract_state_initializable_states_count(
1057    state: *const TractState,
1058    n_states: *mut usize,
1059) -> TRACT_RESULT {
1060    wrap(|| unsafe {
1061        check_not_null!(state, n_states);
1062        let state = &(*state).0;
1063        *n_states = state.initializable_states_count()?;
1064        Ok(())
1065    })
1066}
1067
1068/// Get Stateful Ops's state facts
1069#[unsafe(no_mangle)]
1070pub unsafe extern "C" fn tract_state_get_states_facts(
1071    state: *const TractState,
1072    states: *mut *mut TractFact,
1073) -> TRACT_RESULT {
1074    wrap(|| unsafe {
1075        check_not_null!(state, states);
1076        let state = &(*state).0;
1077    
1078        let state_vec = state.get_states_facts()?;
1079        for (ix, f) in state_vec.into_iter().enumerate() {
1080            *states.add(ix) = Box::into_raw(Box::new(TractFact(f)));
1081        }
1082        Ok(())
1083    })
1084}
1085
1086/// Initialize Stateful Ops with specified values
1087#[unsafe(no_mangle)]
1088pub unsafe extern "C" fn tract_state_set_states(
1089    state: *mut TractState,
1090    states: *const *const TractValue,
1091) -> TRACT_RESULT {
1092    wrap(|| unsafe {
1093        check_not_null!(state, states);
1094        let state = &mut (*state).0;
1095
1096        let n_states = state.initializable_states_count()?;
1097        let state_initializers: Vec<Value> =
1098        std::slice::from_raw_parts(states, n_states).iter()
1099                    .map(|tv| {
1100                        (**tv).0.clone()
1101                    }).collect();
1102        state.set_states(state_initializers)?;
1103        Ok(())
1104    })
1105}
1106
1107/// Get Stateful Ops's current states.
1108#[unsafe(no_mangle)]
1109pub unsafe extern "C" fn tract_state_get_states(
1110    state: *const TractState,
1111    states: *mut *mut TractValue
1112) -> TRACT_RESULT {
1113    wrap(|| unsafe {
1114        let state = &(*state).0;
1115    
1116        let state_vec = state.get_states()?;
1117        for (ix, s) in state_vec.into_iter().enumerate() {
1118            *states.add(ix) = Box::into_raw(Box::new(TractValue(s)));
1119        }
1120        Ok(())
1121    })
1122}
1123
1124// FACT
1125pub struct TractFact(tract_rs::Fact);
1126
1127/// Parse a fact specification string into an Fact.
1128///
1129/// The returned fact must be free with `tract_fact_destroy`.
1130#[unsafe(no_mangle)]
1131pub unsafe extern "C" fn tract_fact_parse(
1132    model: *mut TractModel,
1133    spec: *const c_char,
1134    fact: *mut *mut TractFact,
1135) -> TRACT_RESULT {
1136    wrap(|| unsafe {
1137        check_not_null!(model, spec, fact);
1138        let spec = CStr::from_ptr(spec).to_str()?;
1139        let f: tract_rs::Fact = spec.as_fact(&mut (*model).0)?.as_ref().clone();
1140        *fact = Box::into_raw(Box::new(TractFact(f)));
1141        Ok(())
1142    })
1143}
1144
1145/// Write a fact as its specification string.
1146///
1147/// The returned string must be freed by the caller using tract_free_cstring.
1148#[unsafe(no_mangle)]
1149pub unsafe extern "C" fn tract_fact_dump(
1150    fact: *const TractFact,
1151    spec: *mut *mut c_char,
1152) -> TRACT_RESULT {
1153    wrap(|| unsafe {
1154        check_not_null!(fact, spec);
1155        *spec = CString::new(format!("{}", (*fact).0))?.into_raw();
1156        Ok(())
1157    })
1158}
1159
1160#[unsafe(no_mangle)]
1161pub unsafe extern "C" fn tract_fact_destroy(fact: *mut *mut TractFact) -> TRACT_RESULT {
1162    release!(fact)
1163}
1164
1165// INFERENCE FACT
1166pub struct TractInferenceFact(tract_rs::InferenceFact);
1167
1168/// Parse a fact specification string into an InferenceFact.
1169///
1170/// The returned fact must be free with `tract_inference_fact_destroy`.
1171#[unsafe(no_mangle)]
1172pub unsafe extern "C" fn tract_inference_fact_parse(
1173    model: *mut TractInferenceModel,
1174    spec: *const c_char,
1175    fact: *mut *mut TractInferenceFact,
1176) -> TRACT_RESULT {
1177    wrap(|| unsafe {
1178        check_not_null!(model, spec, fact);
1179        let spec = CStr::from_ptr(spec).to_str()?;
1180        let f: tract_rs::InferenceFact = spec.as_fact(&mut (*model).0)?.as_ref().clone();
1181        *fact = Box::into_raw(Box::new(TractInferenceFact(f)));
1182        Ok(())
1183    })
1184}
1185
1186/// Creates an empty inference fact.
1187///
1188/// The returned fact must be freed by the caller using tract_inference_fact_destroy
1189#[unsafe(no_mangle)]
1190pub unsafe extern "C" fn tract_inference_fact_empty(
1191    fact: *mut *mut TractInferenceFact,
1192) -> TRACT_RESULT {
1193    wrap(|| unsafe {
1194        check_not_null!(fact);
1195        *fact = Box::into_raw(Box::new(TractInferenceFact(Default::default())));
1196        Ok(())
1197    })
1198}
1199
1200/// Write an inference fact as its specification string.
1201///
1202/// The returned string must be freed by the caller using tract_free_cstring.
1203#[unsafe(no_mangle)]
1204pub unsafe extern "C" fn tract_inference_fact_dump(
1205    fact: *const TractInferenceFact,
1206    spec: *mut *mut c_char,
1207) -> TRACT_RESULT {
1208    wrap(|| unsafe {
1209        check_not_null!(fact, spec);
1210        *spec = CString::new(format!("{}", (*fact).0))?.into_raw();
1211        Ok(())
1212    })
1213}
1214
1215/// Destroy a fact.
1216#[unsafe(no_mangle)]
1217pub unsafe extern "C" fn tract_inference_fact_destroy(
1218    fact: *mut *mut TractInferenceFact,
1219) -> TRACT_RESULT {
1220    release!(fact)
1221}
1222
1223// MISC
1224
1225// HELPERS
1226
1227unsafe fn state_run(
1228    state: &mut State,
1229    inputs: *mut *mut TractValue,
1230    outputs: *mut *mut TractValue,
1231) -> Result<()> {
1232    unsafe {
1233        let values: Vec<_> = std::slice::from_raw_parts(inputs, state.input_count()?)
1234            .iter()
1235            .map(|tv| (**tv).0.clone())
1236            .collect();
1237        let values = state.run(values)?;
1238        for (i, value) in values.into_iter().enumerate() {
1239            *(outputs.add(i)) = Box::into_raw(Box::new(TractValue(value)))
1240        }
1241        Ok(())
1242    }
1243}