Skip to main content

tract_api/
lib.rs

1use anyhow::{Result, ensure};
2use boow::Bow;
3use std::fmt::{Debug, Display};
4use std::path::Path;
5
6#[macro_use]
7pub mod macros;
8pub mod transform;
9
10pub use transform::{ConcretizeSymbols, FloatPrecision, Pulse, TransformConfig, TransformSpec};
11
12/// an implementation of tract's NNEF framework object
13///
14/// Entry point for NNEF model manipulation: loading from file, dumping to file.
15pub trait NnefInterface: Debug + Sized {
16    type Model: ModelInterface;
17    /// Load a NNEF model from the path into a tract-core model.
18    ///
19    /// * `path` can point to a directory, a `tar` file or a `tar.gz` file.
20    fn load(&self, path: impl AsRef<Path>) -> Result<Self::Model>;
21
22    /// Load a NNEF model from a buffer into a tract-core model.
23    ///
24    /// data is the content of a NNEF model, as a `tar` file or a `tar.gz` file.
25    fn load_buffer(&self, data: &[u8]) -> Result<Self::Model>;
26
27    /// Allow the framework to use tract_core extensions instead of a stricter NNEF definition.
28    fn enable_tract_core(&mut self) -> Result<()>;
29
30    /// Allow the framework to use tract_extra extensions.
31    fn enable_tract_extra(&mut self) -> Result<()>;
32
33    /// Allow the framework to use tract_transformers extensions to support common transformer operators.
34    fn enable_tract_transformers(&mut self) -> Result<()>;
35
36    /// Allow the framework to use tract_onnx extensions to support operators in ONNX that are
37    /// absent from NNEF.
38    fn enable_onnx(&mut self) -> Result<()>;
39
40    /// Allow the framework to use tract_pulse extensions to support stateful streaming operation.
41    fn enable_pulse(&mut self) -> Result<()>;
42
43    /// Allow the framework to use a tract-proprietary extension that can support special characters
44    /// in node names. If disable, tract will replace everything by underscore '_' to keep
45    /// compatibility with NNEF. If enabled, the extended syntax will be used, allowing to maintain
46    /// the node names in serialized form.
47    fn enable_extended_identifier_syntax(&mut self) -> Result<()>;
48
49    /// Convenience function, similar with enable_tract_core but allowing method chaining.
50    fn with_tract_core(mut self) -> Result<Self> {
51        self.enable_tract_core()?;
52        Ok(self)
53    }
54
55    /// Convenience function, similar with enable_tract_core but allowing method chaining.
56    fn with_tract_extra(mut self) -> Result<Self> {
57        self.enable_tract_extra()?;
58        Ok(self)
59    }
60
61    /// Convenience function, similar with enable_tract_transformers but allowing method chaining.
62    fn with_tract_transformers(mut self) -> Result<Self> {
63        self.enable_tract_transformers()?;
64        Ok(self)
65    }
66
67    /// Convenience function, similar with enable_onnx but allowing method chaining.
68    fn with_onnx(mut self) -> Result<Self> {
69        self.enable_onnx()?;
70        Ok(self)
71    }
72
73    /// Convenience function, similar with enable_pulse but allowing method chaining.
74    fn with_pulse(mut self) -> Result<Self> {
75        self.enable_pulse()?;
76        Ok(self)
77    }
78
79    /// Convenience function, similar with enable_extended_identifier_syntax but allowing method chaining.
80    fn with_extended_identifier_syntax(mut self) -> Result<Self> {
81        self.enable_extended_identifier_syntax()?;
82        Ok(self)
83    }
84
85    /// Dump a TypedModel as a NNEF directory.
86    ///
87    /// `path` is the directory name to dump to
88    fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
89
90    /// Dump a TypedModel as a NNEF tar file.
91    ///
92    /// This function creates a plain, non-compressed, archive.
93    ///
94    /// `path` is the archive name
95    fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
96    fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
97}
98
99pub trait OnnxInterface: Debug {
100    type InferenceModel: InferenceModelInterface;
101    fn load(&self, path: impl AsRef<Path>) -> Result<Self::InferenceModel>;
102    /// Load a ONNX model from a buffer into an InferenceModel.
103    fn load_buffer(&self, data: &[u8]) -> Result<Self::InferenceModel>;
104}
105
106pub trait InferenceModelInterface: Debug + Sized {
107    type Model: ModelInterface;
108    type InferenceFact: InferenceFactInterface;
109    fn set_output_names(
110        &mut self,
111        outputs: impl IntoIterator<Item = impl AsRef<str>>,
112    ) -> Result<()>;
113    fn input_count(&self) -> Result<usize>;
114    fn output_count(&self) -> Result<usize>;
115    fn input_name(&self, id: usize) -> Result<String>;
116    fn output_name(&self, id: usize) -> Result<String>;
117
118    fn input_fact(&self, id: usize) -> Result<Self::InferenceFact>;
119
120    fn set_input_fact(
121        &mut self,
122        id: usize,
123        fact: impl AsFact<Self, Self::InferenceFact>,
124    ) -> Result<()>;
125
126    fn output_fact(&self, id: usize) -> Result<Self::InferenceFact>;
127
128    fn set_output_fact(
129        &mut self,
130        id: usize,
131        fact: impl AsFact<Self, Self::InferenceFact>,
132    ) -> Result<()>;
133
134    fn analyse(&mut self) -> Result<()>;
135
136    fn into_model(self) -> Result<Self::Model>;
137}
138
139pub trait ModelInterface: Debug + Sized {
140    type Fact: FactInterface;
141    type Runnable: RunnableInterface;
142    type Tensor: TensorInterface;
143    fn input_count(&self) -> Result<usize>;
144
145    fn output_count(&self) -> Result<usize>;
146
147    fn input_name(&self, id: usize) -> Result<String>;
148
149    fn output_name(&self, id: usize) -> Result<String>;
150
151    fn set_output_names(
152        &mut self,
153        outputs: impl IntoIterator<Item = impl AsRef<str>>,
154    ) -> Result<()>;
155
156    fn input_fact(&self, id: usize) -> Result<Self::Fact>;
157
158    fn output_fact(&self, id: usize) -> Result<Self::Fact>;
159
160    fn into_runnable(self) -> Result<Self::Runnable>;
161
162    fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()>;
163
164    fn property_keys(&self) -> Result<Vec<String>>;
165
166    fn property(&self, name: impl AsRef<str>) -> Result<Self::Tensor>;
167
168    fn parse_fact(&self, spec: &str) -> Result<Self::Fact>;
169
170    fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
171        Ok((0..self.input_count()?)
172            .map(|ix| self.input_fact(ix))
173            .collect::<Result<Vec<_>>>()?
174            .into_iter())
175    }
176
177    fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
178        Ok((0..self.output_count()?)
179            .map(|ix| self.output_fact(ix))
180            .collect::<Result<Vec<_>>>()?
181            .into_iter())
182    }
183}
184
185pub trait RuntimeInterface: Debug {
186    type Runnable: RunnableInterface;
187    type Model: ModelInterface;
188    fn name(&self) -> Result<String>;
189    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable>;
190}
191
192pub trait RunnableInterface: Debug + Send + Sync {
193    type Tensor: TensorInterface;
194    type Fact: FactInterface;
195    type State: StateInterface<Tensor = Self::Tensor>;
196    fn run(&self, inputs: impl IntoInputs<Self::Tensor>) -> Result<Vec<Self::Tensor>> {
197        self.spawn_state()?.run(inputs.into_inputs()?)
198    }
199
200    fn input_count(&self) -> Result<usize>;
201    fn output_count(&self) -> Result<usize>;
202    fn input_fact(&self, id: usize) -> Result<Self::Fact>;
203
204    fn output_fact(&self, id: usize) -> Result<Self::Fact>;
205
206    fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
207        Ok((0..self.input_count()?)
208            .map(|ix| self.input_fact(ix))
209            .collect::<Result<Vec<_>>>()?
210            .into_iter())
211    }
212
213    fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
214        Ok((0..self.output_count()?)
215            .map(|ix| self.output_fact(ix))
216            .collect::<Result<Vec<_>>>()?
217            .into_iter())
218    }
219
220    fn property_keys(&self) -> Result<Vec<String>>;
221    fn property(&self, name: impl AsRef<str>) -> Result<Self::Tensor>;
222
223    fn spawn_state(&self) -> Result<Self::State>;
224
225    fn cost_json(&self) -> Result<String>;
226
227    fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
228    where
229        I: IntoIterator<Item = IV>,
230        IV: TryInto<Self::Tensor, Error = IE>,
231        IE: Into<anyhow::Error> + Debug;
232}
233
234pub trait StateInterface: Debug {
235    type Fact: FactInterface;
236    type Tensor: TensorInterface;
237
238    fn input_count(&self) -> Result<usize>;
239    fn output_count(&self) -> Result<usize>;
240
241    fn run(&mut self, inputs: impl IntoInputs<Self::Tensor>) -> Result<Vec<Self::Tensor>>;
242}
243
244pub trait TensorInterface: Debug + Sized + Clone + PartialEq + Send + Sync {
245    fn datum_type(&self) -> Result<DatumType>;
246    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self>;
247    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])>;
248
249    fn from_slice<T: Datum>(shape: &[usize], data: &[T]) -> Result<Self> {
250        let data = unsafe {
251            std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
252        };
253        Self::from_bytes(T::datum_type(), shape, data)
254    }
255
256    fn as_slice<T: Datum>(&self) -> Result<&[T]> {
257        let (dt, _shape, data) = self.as_bytes()?;
258        ensure!(T::datum_type() == dt);
259        let data = unsafe {
260            std::slice::from_raw_parts(
261                data.as_ptr() as *const T,
262                data.len() / std::mem::size_of::<T>(),
263            )
264        };
265        Ok(data)
266    }
267
268    fn as_shape_and_slice<T: Datum>(&self) -> Result<(&[usize], &[T])> {
269        let (_, shape, _) = self.as_bytes()?;
270        let data = self.as_slice()?;
271        Ok((shape, data))
272    }
273
274    fn shape(&self) -> Result<&[usize]> {
275        let (_, shape, _) = self.as_bytes()?;
276        Ok(shape)
277    }
278
279    fn view<T: Datum>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
280        let (shape, data) = self.as_shape_and_slice()?;
281        Ok(unsafe { ndarray::ArrayViewD::from_shape_ptr(shape, data.as_ptr()) })
282    }
283
284    fn view1<T: Datum>(&self) -> Result<ndarray::ArrayView1<'_, T>> {
285        Ok(self.view::<T>()?.into_dimensionality()?)
286    }
287
288    fn view2<T: Datum>(&self) -> Result<ndarray::ArrayView2<'_, T>> {
289        Ok(self.view::<T>()?.into_dimensionality()?)
290    }
291
292    fn view3<T: Datum>(&self) -> Result<ndarray::ArrayView3<'_, T>> {
293        Ok(self.view::<T>()?.into_dimensionality()?)
294    }
295
296    fn view4<T: Datum>(&self) -> Result<ndarray::ArrayView4<'_, T>> {
297        Ok(self.view::<T>()?.into_dimensionality()?)
298    }
299
300    fn view5<T: Datum>(&self) -> Result<ndarray::ArrayView5<'_, T>> {
301        Ok(self.view::<T>()?.into_dimensionality()?)
302    }
303
304    fn view6<T: Datum>(&self) -> Result<ndarray::ArrayView6<'_, T>> {
305        Ok(self.view::<T>()?.into_dimensionality()?)
306    }
307
308    fn convert_to(&self, to: DatumType) -> Result<Self>;
309}
310
311pub trait FactInterface: Debug + Display + Clone {
312    type Dim: DimInterface;
313    fn datum_type(&self) -> Result<DatumType>;
314    fn rank(&self) -> Result<usize>;
315    fn dim(&self, axis: usize) -> Result<Self::Dim>;
316
317    fn dims(&self) -> Result<impl Iterator<Item = Self::Dim>> {
318        Ok((0..self.rank()?).map(|axis| self.dim(axis)).collect::<Result<Vec<_>>>()?.into_iter())
319    }
320}
321
322pub trait DimInterface: Debug + Display + Clone {
323    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self>;
324    fn to_int64(&self) -> Result<i64>;
325}
326
327pub trait InferenceFactInterface: Debug + Display + Default + Clone {
328    fn empty() -> Result<Self>;
329}
330
331pub trait AsFact<M, F>: Debug {
332    fn as_fact(&self, model: &M) -> Result<Bow<'_, F>>;
333}
334
/// Element types a tensor can hold.
///
/// In the discriminants below, the high nibble groups the variants by family (bool, unsigned,
/// signed, float, complex integer, complex float) and the low nibble is the byte width of one
/// scalar component.
#[repr(C)]
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum DatumType {
    Bool = 0x01,
    U8 = 0x11,
    U16 = 0x12,
    U32 = 0x14,
    U64 = 0x18,
    I8 = 0x21,
    I16 = 0x22,
    I32 = 0x24,
    I64 = 0x28,
    F16 = 0x32,
    F32 = 0x34,
    F64 = 0x38,
    #[cfg(feature = "complex")]
    ComplexI16 = 0x42,
    #[cfg(feature = "complex")]
    ComplexI32 = 0x44,
    #[cfg(feature = "complex")]
    ComplexI64 = 0x48,
    #[cfg(feature = "complex")]
    ComplexF16 = 0x52,
    #[cfg(feature = "complex")]
    ComplexF32 = 0x54,
    #[cfg(feature = "complex")]
    ComplexF64 = 0x58,
}

impl DatumType {
    /// Size in bytes of one element of this type (both components for complex types).
    pub fn size_of(&self) -> usize {
        match *self {
            Self::Bool | Self::U8 | Self::I8 => 1,
            Self::U16 | Self::I16 | Self::F16 => 2,
            Self::U32 | Self::I32 | Self::F32 => 4,
            Self::U64 | Self::I64 | Self::F64 => 8,
            #[cfg(feature = "complex")]
            Self::ComplexI16 | Self::ComplexF16 => 4,
            #[cfg(feature = "complex")]
            Self::ComplexI32 | Self::ComplexF32 => 8,
            #[cfg(feature = "complex")]
            Self::ComplexI64 | Self::ComplexF64 => 16,
        }
    }

    /// `true` for `Bool` only.
    pub fn is_bool(&self) -> bool {
        matches!(self, Self::Bool)
    }

    /// `true` for every type except `Bool`.
    pub fn is_number(&self) -> bool {
        !self.is_bool()
    }

    /// `true` for `U8`, `U16`, `U32` and `U64`.
    pub fn is_unsigned(&self) -> bool {
        matches!(self, Self::U8 | Self::U16 | Self::U32 | Self::U64)
    }

    /// `true` for `I8`, `I16`, `I32` and `I64`.
    pub fn is_signed(&self) -> bool {
        matches!(self, Self::I8 | Self::I16 | Self::I32 | Self::I64)
    }

    /// `true` for `F16`, `F32` and `F64`.
    pub fn is_float(&self) -> bool {
        matches!(self, Self::F16 | Self::F32 | Self::F64)
    }
}

/// A Rust type that corresponds to one [`DatumType`].
pub trait Datum {
    /// The [`DatumType`] matching the implementing type.
    fn datum_type() -> DatumType;
}
408
409// IntoInputs trait — ergonomic input conversion for run()
410pub trait IntoInputs<V: TensorInterface> {
411    fn into_inputs(self) -> Result<Vec<V>>;
412}
413
414// Arrays of anything convertible to Tensor
415impl<V, T, E, const N: usize> IntoInputs<V> for [T; N]
416where
417    V: TensorInterface,
418    T: TryInto<V, Error = E>,
419    E: Into<anyhow::Error>,
420{
421    fn into_inputs(self) -> Result<Vec<V>> {
422        self.into_iter().map(|v| v.try_into().map_err(|e| e.into())).collect()
423    }
424}
425
426// Vec<V> passthrough
427impl<V: TensorInterface> IntoInputs<V> for Vec<V> {
428    fn into_inputs(self) -> Result<Vec<V>> {
429        Ok(self)
430    }
431}
432
433// Tuples — each element converts independently
434macro_rules! impl_into_inputs_tuple {
435    ($($idx:tt : $T:ident),+) => {
436        impl<V, $($T),+> IntoInputs<V> for ($($T,)+)
437        where
438            V: TensorInterface,
439            $($T: TryInto<V>,
440              <$T as TryInto<V>>::Error: Into<anyhow::Error>,)+
441        {
442            fn into_inputs(self) -> Result<Vec<V>> {
443                Ok(vec![$(self.$idx.try_into().map_err(|e| e.into())?),+])
444            }
445        }
446    };
447}
448
449impl_into_inputs_tuple!(0: A);
450impl_into_inputs_tuple!(0: A, 1: B);
451impl_into_inputs_tuple!(0: A, 1: B, 2: C);
452impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D);
453impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_);
454impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F);
455impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F, 6: G);
456impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F, 6: G, 7: H);
457
458/// Convert any compatible input into a `V: TensorInterface`.
459pub fn tensor<V, T, E>(v: T) -> Result<V>
460where
461    V: TensorInterface,
462    T: TryInto<V, Error = E>,
463    E: Into<anyhow::Error>,
464{
465    v.try_into().map_err(|e| e.into())
466}
467
468macro_rules! impl_datum_type {
469    ($ty:ty, $c_repr:expr) => {
470        impl Datum for $ty {
471            fn datum_type() -> DatumType {
472                $c_repr
473            }
474        }
475    };
476}
477
478impl_datum_type!(bool, DatumType::Bool);
479impl_datum_type!(u8, DatumType::U8);
480impl_datum_type!(u16, DatumType::U16);
481impl_datum_type!(u32, DatumType::U32);
482impl_datum_type!(u64, DatumType::U64);
483impl_datum_type!(i8, DatumType::I8);
484impl_datum_type!(i16, DatumType::I16);
485impl_datum_type!(i32, DatumType::I32);
486impl_datum_type!(i64, DatumType::I64);
487impl_datum_type!(half::f16, DatumType::F16);
488impl_datum_type!(f32, DatumType::F32);
489impl_datum_type!(f64, DatumType::F64);