Skip to main content

tract_api/
lib.rs

1use anyhow::{Result, ensure};
2use boow::Bow;
3use std::fmt::{Debug, Display};
4use std::path::Path;
5
6#[macro_use]
7pub mod macros;
8pub mod transform;
9
10pub use transform::{ConcretizeSymbols, FloatPrecision, Pulse, TransformConfig, TransformSpec};
11
12/// an implementation of tract's NNEF framework object
13///
14/// Entry point for NNEF model manipulation: loading from file, dumping to file.
15pub trait NnefInterface: Debug + Sized {
16    type Model: ModelInterface;
17    /// Load a NNEF model from the path into a tract-core model.
18    ///
19    /// * `path` can point to a directory, a `tar` file or a `tar.gz` file.
20    fn load(&self, path: impl AsRef<Path>) -> Result<Self::Model>;
21
22    /// Load a NNEF model from a buffer into a tract-core model.
23    ///
24    /// data is the content of a NNEF model, as a `tar` file or a `tar.gz` file.
25    fn load_buffer(&self, data: &[u8]) -> Result<Self::Model>;
26
27    /// Allow the framework to use tract_core extensions instead of a stricter NNEF definition.
28    fn enable_tract_core(&mut self) -> Result<()>;
29
30    /// Allow the framework to use tract_extra extensions.
31    fn enable_tract_extra(&mut self) -> Result<()>;
32
33    /// Allow the framework to use tract_transformers extensions to support common transformer operators.
34    fn enable_tract_transformers(&mut self) -> Result<()>;
35
36    /// Allow the framework to use tract_onnx extensions to support operators in ONNX that are
37    /// absent from NNEF.
38    fn enable_onnx(&mut self) -> Result<()>;
39
40    /// Allow the framework to use tract_pulse extensions to support stateful streaming operation.
41    fn enable_pulse(&mut self) -> Result<()>;
42
43    /// Allow the framework to use a tract-proprietary extension that can support special characters
44    /// in node names. If disable, tract will replace everything by underscore '_' to keep
45    /// compatibility with NNEF. If enabled, the extended syntax will be used, allowing to maintain
46    /// the node names in serialized form.
47    fn enable_extended_identifier_syntax(&mut self) -> Result<()>;
48
49    /// Convenience function, similar with enable_tract_core but allowing method chaining.
50    fn with_tract_core(mut self) -> Result<Self> {
51        self.enable_tract_core()?;
52        Ok(self)
53    }
54
55    /// Convenience function, similar with enable_tract_core but allowing method chaining.
56    fn with_tract_extra(mut self) -> Result<Self> {
57        self.enable_tract_extra()?;
58        Ok(self)
59    }
60
61    /// Convenience function, similar with enable_tract_transformers but allowing method chaining.
62    fn with_tract_transformers(mut self) -> Result<Self> {
63        self.enable_tract_transformers()?;
64        Ok(self)
65    }
66
67    /// Convenience function, similar with enable_onnx but allowing method chaining.
68    fn with_onnx(mut self) -> Result<Self> {
69        self.enable_onnx()?;
70        Ok(self)
71    }
72
73    /// Convenience function, similar with enable_pulse but allowing method chaining.
74    fn with_pulse(mut self) -> Result<Self> {
75        self.enable_pulse()?;
76        Ok(self)
77    }
78
79    /// Convenience function, similar with enable_extended_identifier_syntax but allowing method chaining.
80    fn with_extended_identifier_syntax(mut self) -> Result<Self> {
81        self.enable_extended_identifier_syntax()?;
82        Ok(self)
83    }
84
85    /// Dump a TypedModel as a NNEF directory.
86    ///
87    /// `path` is the directory name to dump to
88    fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
89
90    /// Dump a TypedModel as a NNEF tar file.
91    ///
92    /// This function creates a plain, non-compressed, archive.
93    ///
94    /// `path` is the archive name
95    fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
96    fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
97}
98
99pub trait OnnxInterface: Debug {
100    type InferenceModel: InferenceModelInterface;
101    fn load(&self, path: impl AsRef<Path>) -> Result<Self::InferenceModel>;
102    /// Load a ONNX model from a buffer into an InferenceModel.
103    fn load_buffer(&self, data: &[u8]) -> Result<Self::InferenceModel>;
104}
105
106pub trait InferenceModelInterface: Debug + Sized {
107    type Model: ModelInterface;
108    type InferenceFact: InferenceFactInterface;
109    fn input_count(&self) -> Result<usize>;
110    fn output_count(&self) -> Result<usize>;
111    fn input_name(&self, id: usize) -> Result<String>;
112    fn output_name(&self, id: usize) -> Result<String>;
113
114    fn input_fact(&self, id: usize) -> Result<Self::InferenceFact>;
115
116    fn set_input_fact(
117        &mut self,
118        id: usize,
119        fact: impl AsFact<Self, Self::InferenceFact>,
120    ) -> Result<()>;
121
122    fn output_fact(&self, id: usize) -> Result<Self::InferenceFact>;
123
124    fn set_output_fact(
125        &mut self,
126        id: usize,
127        fact: impl AsFact<Self, Self::InferenceFact>,
128    ) -> Result<()>;
129
130    fn analyse(&mut self) -> Result<()>;
131
132    fn into_model(self) -> Result<Self::Model>;
133}
134
135pub trait ModelInterface: Debug + Sized {
136    type Fact: FactInterface;
137    type Runnable: RunnableInterface;
138    type Tensor: TensorInterface;
139    fn input_count(&self) -> Result<usize>;
140
141    fn output_count(&self) -> Result<usize>;
142
143    fn input_name(&self, id: usize) -> Result<String>;
144
145    fn output_name(&self, id: usize) -> Result<String>;
146
147    fn input_fact(&self, id: usize) -> Result<Self::Fact>;
148
149    fn output_fact(&self, id: usize) -> Result<Self::Fact>;
150
151    fn into_runnable(self) -> Result<Self::Runnable>;
152
153    fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()>;
154
155    fn property_keys(&self) -> Result<Vec<String>>;
156
157    fn property(&self, name: impl AsRef<str>) -> Result<Self::Tensor>;
158
159    fn parse_fact(&self, spec: &str) -> Result<Self::Fact>;
160
161    fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
162        Ok((0..self.input_count()?)
163            .map(|ix| self.input_fact(ix))
164            .collect::<Result<Vec<_>>>()?
165            .into_iter())
166    }
167
168    fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
169        Ok((0..self.output_count()?)
170            .map(|ix| self.output_fact(ix))
171            .collect::<Result<Vec<_>>>()?
172            .into_iter())
173    }
174}
175
176pub trait RuntimeInterface: Debug {
177    type Runnable: RunnableInterface;
178    type Model: ModelInterface;
179    fn name(&self) -> Result<String>;
180    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable>;
181}
182
183pub trait RunnableInterface: Debug + Send + Sync {
184    type Tensor: TensorInterface;
185    type Fact: FactInterface;
186    type State: StateInterface<Tensor = Self::Tensor>;
187    fn run(&self, inputs: impl IntoInputs<Self::Tensor>) -> Result<Vec<Self::Tensor>> {
188        self.spawn_state()?.run(inputs.into_inputs()?)
189    }
190
191    fn input_count(&self) -> Result<usize>;
192    fn output_count(&self) -> Result<usize>;
193    fn input_fact(&self, id: usize) -> Result<Self::Fact>;
194
195    fn output_fact(&self, id: usize) -> Result<Self::Fact>;
196
197    fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
198        Ok((0..self.input_count()?)
199            .map(|ix| self.input_fact(ix))
200            .collect::<Result<Vec<_>>>()?
201            .into_iter())
202    }
203
204    fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
205        Ok((0..self.output_count()?)
206            .map(|ix| self.output_fact(ix))
207            .collect::<Result<Vec<_>>>()?
208            .into_iter())
209    }
210
211    fn property_keys(&self) -> Result<Vec<String>>;
212    fn property(&self, name: impl AsRef<str>) -> Result<Self::Tensor>;
213
214    fn spawn_state(&self) -> Result<Self::State>;
215
216    fn cost_json(&self) -> Result<String>;
217
218    fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
219    where
220        I: IntoIterator<Item = IV>,
221        IV: TryInto<Self::Tensor, Error = IE>,
222        IE: Into<anyhow::Error> + Debug;
223}
224
225pub trait StateInterface: Debug + Clone + Send {
226    type Fact: FactInterface;
227    type Tensor: TensorInterface;
228
229    fn input_count(&self) -> Result<usize>;
230    fn output_count(&self) -> Result<usize>;
231
232    fn run(&mut self, inputs: impl IntoInputs<Self::Tensor>) -> Result<Vec<Self::Tensor>>;
233}
234
235pub trait TensorInterface: Debug + Sized + Clone + PartialEq + Send + Sync {
236    fn datum_type(&self) -> Result<DatumType>;
237    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self>;
238    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])>;
239
240    fn from_slice<T: Datum>(shape: &[usize], data: &[T]) -> Result<Self> {
241        let data = unsafe {
242            std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
243        };
244        Self::from_bytes(T::datum_type(), shape, data)
245    }
246
247    fn as_slice<T: Datum>(&self) -> Result<&[T]> {
248        let (dt, _shape, data) = self.as_bytes()?;
249        ensure!(T::datum_type() == dt);
250        let data = unsafe {
251            std::slice::from_raw_parts(
252                data.as_ptr() as *const T,
253                data.len() / std::mem::size_of::<T>(),
254            )
255        };
256        Ok(data)
257    }
258
259    fn as_shape_and_slice<T: Datum>(&self) -> Result<(&[usize], &[T])> {
260        let (_, shape, _) = self.as_bytes()?;
261        let data = self.as_slice()?;
262        Ok((shape, data))
263    }
264
265    fn shape(&self) -> Result<&[usize]> {
266        let (_, shape, _) = self.as_bytes()?;
267        Ok(shape)
268    }
269
270    fn view<T: Datum>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
271        let (shape, data) = self.as_shape_and_slice()?;
272        Ok(unsafe { ndarray::ArrayViewD::from_shape_ptr(shape, data.as_ptr()) })
273    }
274
275    fn view1<T: Datum>(&self) -> Result<ndarray::ArrayView1<'_, T>> {
276        Ok(self.view::<T>()?.into_dimensionality()?)
277    }
278
279    fn view2<T: Datum>(&self) -> Result<ndarray::ArrayView2<'_, T>> {
280        Ok(self.view::<T>()?.into_dimensionality()?)
281    }
282
283    fn view3<T: Datum>(&self) -> Result<ndarray::ArrayView3<'_, T>> {
284        Ok(self.view::<T>()?.into_dimensionality()?)
285    }
286
287    fn view4<T: Datum>(&self) -> Result<ndarray::ArrayView4<'_, T>> {
288        Ok(self.view::<T>()?.into_dimensionality()?)
289    }
290
291    fn view5<T: Datum>(&self) -> Result<ndarray::ArrayView5<'_, T>> {
292        Ok(self.view::<T>()?.into_dimensionality()?)
293    }
294
295    fn view6<T: Datum>(&self) -> Result<ndarray::ArrayView6<'_, T>> {
296        Ok(self.view::<T>()?.into_dimensionality()?)
297    }
298
299    fn convert_to(&self, to: DatumType) -> Result<Self>;
300}
301
302pub trait FactInterface: Debug + Display + Clone {
303    type Dim: DimInterface;
304    fn datum_type(&self) -> Result<DatumType>;
305    fn rank(&self) -> Result<usize>;
306    fn dim(&self, axis: usize) -> Result<Self::Dim>;
307
308    fn dims(&self) -> Result<impl Iterator<Item = Self::Dim>> {
309        Ok((0..self.rank()?).map(|axis| self.dim(axis)).collect::<Result<Vec<_>>>()?.into_iter())
310    }
311}
312
313pub trait DimInterface: Debug + Display + Clone {
314    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self>;
315    fn to_int64(&self) -> Result<i64>;
316}
317
318pub trait InferenceFactInterface: Debug + Display + Default + Clone {
319    fn empty() -> Result<Self>;
320}
321
322pub trait AsFact<M, F>: Debug {
323    fn as_fact(&self, model: &M) -> Result<Bow<'_, F>>;
324}
325
/// Element types a tensor can hold.
///
/// The discriminants are explicit and the enum is `#[repr(C)]`. For the non-complex variants the
/// low nibble of the discriminant equals the element size in bytes, as returned by `size_of`.
#[repr(C)]
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum DatumType {
    Bool = 0x01,
    U8 = 0x11,
    U16 = 0x12,
    U32 = 0x14,
    U64 = 0x18,
    I8 = 0x21,
    I16 = 0x22,
    I32 = 0x24,
    I64 = 0x28,
    F16 = 0x32,
    F32 = 0x34,
    F64 = 0x38,
    #[cfg(feature = "complex")]
    ComplexI16 = 0x42,
    #[cfg(feature = "complex")]
    ComplexI32 = 0x44,
    #[cfg(feature = "complex")]
    ComplexI64 = 0x48,
    #[cfg(feature = "complex")]
    ComplexF16 = 0x52,
    #[cfg(feature = "complex")]
    ComplexF32 = 0x54,
    #[cfg(feature = "complex")]
    ComplexF64 = 0x58,
}

impl DatumType {
    /// Size in bytes of one element of this type.
    pub fn size_of(&self) -> usize {
        match *self {
            Self::Bool | Self::U8 | Self::I8 => 1,
            Self::U16 | Self::I16 | Self::F16 => 2,
            Self::U32 | Self::I32 | Self::F32 => 4,
            Self::U64 | Self::I64 | Self::F64 => 8,
            #[cfg(feature = "complex")]
            Self::ComplexI16 | Self::ComplexF16 => 4,
            #[cfg(feature = "complex")]
            Self::ComplexI32 | Self::ComplexF32 => 8,
            #[cfg(feature = "complex")]
            Self::ComplexI64 | Self::ComplexF64 => 16,
        }
    }

    /// True for `Bool` only.
    pub fn is_bool(&self) -> bool {
        matches!(self, Self::Bool)
    }

    /// True for every variant except `Bool`.
    pub fn is_number(&self) -> bool {
        !matches!(self, Self::Bool)
    }

    /// True for `U8`, `U16`, `U32` and `U64`.
    pub fn is_unsigned(&self) -> bool {
        matches!(self, Self::U8 | Self::U16 | Self::U32 | Self::U64)
    }

    /// True for `I8`, `I16`, `I32` and `I64`.
    pub fn is_signed(&self) -> bool {
        matches!(self, Self::I8 | Self::I16 | Self::I32 | Self::I64)
    }

    /// True for `F16`, `F32` and `F64`.
    pub fn is_float(&self) -> bool {
        matches!(self, Self::F16 | Self::F32 | Self::F64)
    }
}
395
396pub trait Datum {
397    fn datum_type() -> DatumType;
398}
399
400// IntoInputs trait — ergonomic input conversion for run()
401pub trait IntoInputs<V: TensorInterface> {
402    fn into_inputs(self) -> Result<Vec<V>>;
403}
404
405// Arrays of anything convertible to Tensor
406impl<V, T, E, const N: usize> IntoInputs<V> for [T; N]
407where
408    V: TensorInterface,
409    T: TryInto<V, Error = E>,
410    E: Into<anyhow::Error>,
411{
412    fn into_inputs(self) -> Result<Vec<V>> {
413        self.into_iter().map(|v| v.try_into().map_err(|e| e.into())).collect()
414    }
415}
416
417// Vec<V> passthrough
418impl<V: TensorInterface> IntoInputs<V> for Vec<V> {
419    fn into_inputs(self) -> Result<Vec<V>> {
420        Ok(self)
421    }
422}
423
424// Tuples — each element converts independently
425macro_rules! impl_into_inputs_tuple {
426    ($($idx:tt : $T:ident),+) => {
427        impl<V, $($T),+> IntoInputs<V> for ($($T,)+)
428        where
429            V: TensorInterface,
430            $($T: TryInto<V>,
431              <$T as TryInto<V>>::Error: Into<anyhow::Error>,)+
432        {
433            fn into_inputs(self) -> Result<Vec<V>> {
434                Ok(vec![$(self.$idx.try_into().map_err(|e| e.into())?),+])
435            }
436        }
437    };
438}
439
440impl_into_inputs_tuple!(0: A);
441impl_into_inputs_tuple!(0: A, 1: B);
442impl_into_inputs_tuple!(0: A, 1: B, 2: C);
443impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D);
444impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_);
445impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F);
446impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F, 6: G);
447impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F, 6: G, 7: H);
448
449/// Convert any compatible input into a `V: TensorInterface`.
450pub fn tensor<V, T, E>(v: T) -> Result<V>
451where
452    V: TensorInterface,
453    T: TryInto<V, Error = E>,
454    E: Into<anyhow::Error>,
455{
456    v.try_into().map_err(|e| e.into())
457}
458
459macro_rules! impl_datum_type {
460    ($ty:ty, $c_repr:expr) => {
461        impl Datum for $ty {
462            fn datum_type() -> DatumType {
463                $c_repr
464            }
465        }
466    };
467}
468
469impl_datum_type!(bool, DatumType::Bool);
470impl_datum_type!(u8, DatumType::U8);
471impl_datum_type!(u16, DatumType::U16);
472impl_datum_type!(u32, DatumType::U32);
473impl_datum_type!(u64, DatumType::U64);
474impl_datum_type!(i8, DatumType::I8);
475impl_datum_type!(i16, DatumType::I16);
476impl_datum_type!(i32, DatumType::I32);
477impl_datum_type!(i64, DatumType::I64);
478impl_datum_type!(half::f16, DatumType::F16);
479impl_datum_type!(f32, DatumType::F32);
480impl_datum_type!(f64, DatumType::F64);