1use anyhow::{Result, ensure};
2use boow::Bow;
3use std::fmt::{Debug, Display};
4use std::path::Path;
5
6#[macro_use]
7pub mod macros;
8pub mod transform;
9
10pub use transform::{ConcretizeSymbols, FloatPrecision, Pulse, TransformConfig, TransformSpec};
11
12pub trait NnefInterface: Debug + Sized {
16 type Model: ModelInterface;
17 fn load(&self, path: impl AsRef<Path>) -> Result<Self::Model>;
21
22 fn load_buffer(&self, data: &[u8]) -> Result<Self::Model>;
26
27 fn enable_tract_core(&mut self) -> Result<()>;
29
30 fn enable_tract_extra(&mut self) -> Result<()>;
32
33 fn enable_tract_transformers(&mut self) -> Result<()>;
35
36 fn enable_onnx(&mut self) -> Result<()>;
39
40 fn enable_pulse(&mut self) -> Result<()>;
42
43 fn enable_extended_identifier_syntax(&mut self) -> Result<()>;
48
49 fn with_tract_core(mut self) -> Result<Self> {
51 self.enable_tract_core()?;
52 Ok(self)
53 }
54
55 fn with_tract_extra(mut self) -> Result<Self> {
57 self.enable_tract_extra()?;
58 Ok(self)
59 }
60
61 fn with_tract_transformers(mut self) -> Result<Self> {
63 self.enable_tract_transformers()?;
64 Ok(self)
65 }
66
67 fn with_onnx(mut self) -> Result<Self> {
69 self.enable_onnx()?;
70 Ok(self)
71 }
72
73 fn with_pulse(mut self) -> Result<Self> {
75 self.enable_pulse()?;
76 Ok(self)
77 }
78
79 fn with_extended_identifier_syntax(mut self) -> Result<Self> {
81 self.enable_extended_identifier_syntax()?;
82 Ok(self)
83 }
84
85 fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
89
90 fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
96 fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
97}
98
/// Front-end for ONNX models: builds an [`InferenceModelInterface`] value from
/// a path or from an in-memory buffer.
pub trait OnnxInterface: Debug {
    /// Model type returned by the loaders.
    type InferenceModel: InferenceModelInterface;

    /// Builds an inference model from the location designated by `path`.
    fn load(&self, path: impl AsRef<Path>) -> Result<Self::InferenceModel>;

    /// Builds an inference model from bytes already held in memory.
    fn load_buffer(&self, data: &[u8]) -> Result<Self::InferenceModel>;
}
105
/// A model whose input and output facts can still be read and overwritten
/// before it is converted into a [`ModelInterface`] value.
pub trait InferenceModelInterface: Debug + Sized {
    /// Model type obtained from [`Self::into_model`].
    type Model: ModelInterface;
    /// Fact type used to describe inputs and outputs of this model.
    type InferenceFact: InferenceFactInterface;

    /// Replaces the model outputs with the ones designated by `outputs`.
    fn set_output_names(
        &mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> Result<()>;

    /// Number of model inputs.
    fn input_count(&self) -> Result<usize>;

    /// Number of model outputs.
    fn output_count(&self) -> Result<usize>;

    /// Name of the input at position `id`.
    fn input_name(&self, id: usize) -> Result<String>;

    /// Name of the output at position `id`.
    fn output_name(&self, id: usize) -> Result<String>;

    /// Fact currently attached to the input at position `id`.
    fn input_fact(&self, id: usize) -> Result<Self::InferenceFact>;

    /// Overwrites the fact of the input at position `id`.
    ///
    /// `fact` is anything [`AsFact`] can turn into a `Self::InferenceFact`
    /// for this model.
    fn set_input_fact(
        &mut self,
        id: usize,
        fact: impl AsFact<Self, Self::InferenceFact>,
    ) -> Result<()>;

    /// Fact currently attached to the output at position `id`.
    fn output_fact(&self, id: usize) -> Result<Self::InferenceFact>;

    /// Overwrites the fact of the output at position `id`.
    fn set_output_fact(
        &mut self,
        id: usize,
        fact: impl AsFact<Self, Self::InferenceFact>,
    ) -> Result<()>;

    /// Runs the analysis step on the model, in place.
    // NOTE(review): what the analysis computes is defined by the implementor.
    fn analyse(&mut self) -> Result<()>;

    /// Consumes the inference model and converts it into a `Self::Model`.
    fn into_model(self) -> Result<Self::Model>;
}
138
139pub trait ModelInterface: Debug + Sized {
140 type Fact: FactInterface;
141 type Runnable: RunnableInterface;
142 type Tensor: TensorInterface;
143 fn input_count(&self) -> Result<usize>;
144
145 fn output_count(&self) -> Result<usize>;
146
147 fn input_name(&self, id: usize) -> Result<String>;
148
149 fn output_name(&self, id: usize) -> Result<String>;
150
151 fn set_output_names(
152 &mut self,
153 outputs: impl IntoIterator<Item = impl AsRef<str>>,
154 ) -> Result<()>;
155
156 fn input_fact(&self, id: usize) -> Result<Self::Fact>;
157
158 fn output_fact(&self, id: usize) -> Result<Self::Fact>;
159
160 fn into_runnable(self) -> Result<Self::Runnable>;
161
162 fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()>;
163
164 fn property_keys(&self) -> Result<Vec<String>>;
165
166 fn property(&self, name: impl AsRef<str>) -> Result<Self::Tensor>;
167
168 fn parse_fact(&self, spec: &str) -> Result<Self::Fact>;
169
170 fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
171 Ok((0..self.input_count()?)
172 .map(|ix| self.input_fact(ix))
173 .collect::<Result<Vec<_>>>()?
174 .into_iter())
175 }
176
177 fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
178 Ok((0..self.output_count()?)
179 .map(|ix| self.output_fact(ix))
180 .collect::<Result<Vec<_>>>()?
181 .into_iter())
182 }
183}
184
/// A named runtime that turns a model into something that can be run.
pub trait RuntimeInterface: Debug {
    /// Executable form produced by [`Self::prepare`].
    type Runnable: RunnableInterface;
    /// Model type this runtime accepts.
    type Model: ModelInterface;

    /// Name of this runtime.
    fn name(&self) -> Result<String>;

    /// Consumes `model` and produces its executable form for this runtime.
    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable>;
}
191
192pub trait RunnableInterface: Debug + Send + Sync {
193 type Tensor: TensorInterface;
194 type Fact: FactInterface;
195 type State: StateInterface<Tensor = Self::Tensor>;
196 fn run(&self, inputs: impl IntoInputs<Self::Tensor>) -> Result<Vec<Self::Tensor>> {
197 self.spawn_state()?.run(inputs.into_inputs()?)
198 }
199
200 fn input_count(&self) -> Result<usize>;
201 fn output_count(&self) -> Result<usize>;
202 fn input_fact(&self, id: usize) -> Result<Self::Fact>;
203
204 fn output_fact(&self, id: usize) -> Result<Self::Fact>;
205
206 fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
207 Ok((0..self.input_count()?)
208 .map(|ix| self.input_fact(ix))
209 .collect::<Result<Vec<_>>>()?
210 .into_iter())
211 }
212
213 fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
214 Ok((0..self.output_count()?)
215 .map(|ix| self.output_fact(ix))
216 .collect::<Result<Vec<_>>>()?
217 .into_iter())
218 }
219
220 fn property_keys(&self) -> Result<Vec<String>>;
221 fn property(&self, name: impl AsRef<str>) -> Result<Self::Tensor>;
222
223 fn spawn_state(&self) -> Result<Self::State>;
224
225 fn cost_json(&self) -> Result<String>;
226
227 fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
228 where
229 I: IntoIterator<Item = IV>,
230 IV: TryInto<Self::Tensor, Error = IE>,
231 IE: Into<anyhow::Error> + Debug;
232}
233
/// Mutable execution state: running takes `&mut self`, so a state may carry
/// data from one call to the next.
pub trait StateInterface: Debug {
    /// Fact type associated with this state.
    type Fact: FactInterface;
    /// Tensor type consumed and produced by [`Self::run`].
    type Tensor: TensorInterface;

    /// Number of inputs.
    fn input_count(&self) -> Result<usize>;

    /// Number of outputs.
    fn output_count(&self) -> Result<usize>;

    /// Converts `inputs` into tensors, executes once and returns the output
    /// tensors.
    fn run(&mut self, inputs: impl IntoInputs<Self::Tensor>) -> Result<Vec<Self::Tensor>>;
}
243
244pub trait TensorInterface: Debug + Sized + Clone + PartialEq + Send + Sync {
245 fn datum_type(&self) -> Result<DatumType>;
246 fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self>;
247 fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])>;
248
249 fn from_slice<T: Datum>(shape: &[usize], data: &[T]) -> Result<Self> {
250 let data = unsafe {
251 std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
252 };
253 Self::from_bytes(T::datum_type(), shape, data)
254 }
255
256 fn as_slice<T: Datum>(&self) -> Result<&[T]> {
257 let (dt, _shape, data) = self.as_bytes()?;
258 ensure!(T::datum_type() == dt);
259 let data = unsafe {
260 std::slice::from_raw_parts(
261 data.as_ptr() as *const T,
262 data.len() / std::mem::size_of::<T>(),
263 )
264 };
265 Ok(data)
266 }
267
268 fn as_shape_and_slice<T: Datum>(&self) -> Result<(&[usize], &[T])> {
269 let (_, shape, _) = self.as_bytes()?;
270 let data = self.as_slice()?;
271 Ok((shape, data))
272 }
273
274 fn shape(&self) -> Result<&[usize]> {
275 let (_, shape, _) = self.as_bytes()?;
276 Ok(shape)
277 }
278
279 fn view<T: Datum>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
280 let (shape, data) = self.as_shape_and_slice()?;
281 Ok(unsafe { ndarray::ArrayViewD::from_shape_ptr(shape, data.as_ptr()) })
282 }
283
284 fn view1<T: Datum>(&self) -> Result<ndarray::ArrayView1<'_, T>> {
285 Ok(self.view::<T>()?.into_dimensionality()?)
286 }
287
288 fn view2<T: Datum>(&self) -> Result<ndarray::ArrayView2<'_, T>> {
289 Ok(self.view::<T>()?.into_dimensionality()?)
290 }
291
292 fn view3<T: Datum>(&self) -> Result<ndarray::ArrayView3<'_, T>> {
293 Ok(self.view::<T>()?.into_dimensionality()?)
294 }
295
296 fn view4<T: Datum>(&self) -> Result<ndarray::ArrayView4<'_, T>> {
297 Ok(self.view::<T>()?.into_dimensionality()?)
298 }
299
300 fn view5<T: Datum>(&self) -> Result<ndarray::ArrayView5<'_, T>> {
301 Ok(self.view::<T>()?.into_dimensionality()?)
302 }
303
304 fn view6<T: Datum>(&self) -> Result<ndarray::ArrayView6<'_, T>> {
305 Ok(self.view::<T>()?.into_dimensionality()?)
306 }
307
308 fn convert_to(&self, to: DatumType) -> Result<Self>;
309}
310
311pub trait FactInterface: Debug + Display + Clone {
312 type Dim: DimInterface;
313 fn datum_type(&self) -> Result<DatumType>;
314 fn rank(&self) -> Result<usize>;
315 fn dim(&self, axis: usize) -> Result<Self::Dim>;
316
317 fn dims(&self) -> Result<impl Iterator<Item = Self::Dim>> {
318 Ok((0..self.rank()?).map(|axis| self.dim(axis)).collect::<Result<Vec<_>>>()?.into_iter())
319 }
320}
321
/// One dimension of a fact. It can be evaluated against named integer values
/// and converted to a plain integer.
pub trait DimInterface: Debug + Display + Clone {
    /// Evaluates the dimension with the given `(name, value)` pairs and
    /// returns the resulting dimension.
    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self>;

    /// Converts the dimension to an `i64`.
    // NOTE(review): presumably fails when the dimension is still symbolic — confirm.
    fn to_int64(&self) -> Result<i64>;
}
326
/// Fact type used by [`InferenceModelInterface`] to describe inputs and
/// outputs.
pub trait InferenceFactInterface: Debug + Display + Default + Clone {
    /// Builds an empty fact.
    // NOTE(review): how this relates to `Default::default()` is up to the implementor.
    fn empty() -> Result<Self>;
}
330
/// Conversion of a value into a fact of type `F`, given a model of type `M`
/// as context.
pub trait AsFact<M, F>: Debug {
    /// Produces the fact for `model`, wrapped in a [`Bow`] that borrows from
    /// `self`.
    fn as_fact(&self, model: &M) -> Result<Bow<'_, F>>;
}
334
/// Element types a tensor can hold.
///
/// The discriminants are fixed and the enum is `#[repr(C)]`. In the values
/// below, the low nibble equals the byte width of one scalar component.
#[repr(C)]
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum DatumType {
    Bool = 0x01,
    U8 = 0x11,
    U16 = 0x12,
    U32 = 0x14,
    U64 = 0x18,
    I8 = 0x21,
    I16 = 0x22,
    I32 = 0x24,
    I64 = 0x28,
    F16 = 0x32,
    F32 = 0x34,
    F64 = 0x38,
    #[cfg(feature = "complex")]
    ComplexI16 = 0x42,
    #[cfg(feature = "complex")]
    ComplexI32 = 0x44,
    #[cfg(feature = "complex")]
    ComplexI64 = 0x48,
    #[cfg(feature = "complex")]
    ComplexF16 = 0x52,
    #[cfg(feature = "complex")]
    ComplexF32 = 0x54,
    #[cfg(feature = "complex")]
    ComplexF64 = 0x58,
}

impl DatumType {
    /// Size in bytes of one element of this type. Complex elements count both
    /// components.
    pub fn size_of(&self) -> usize {
        match *self {
            Self::Bool | Self::U8 | Self::I8 => 1,
            Self::U16 | Self::I16 | Self::F16 => 2,
            Self::U32 | Self::I32 | Self::F32 => 4,
            Self::U64 | Self::I64 | Self::F64 => 8,
            #[cfg(feature = "complex")]
            Self::ComplexI16 | Self::ComplexF16 => 4,
            #[cfg(feature = "complex")]
            Self::ComplexI32 | Self::ComplexF32 => 8,
            #[cfg(feature = "complex")]
            Self::ComplexI64 | Self::ComplexF64 => 16,
        }
    }

    /// `true` for `Bool` only.
    pub fn is_bool(&self) -> bool {
        matches!(self, Self::Bool)
    }

    /// `true` for every type except `Bool`.
    pub fn is_number(&self) -> bool {
        !matches!(self, Self::Bool)
    }

    /// `true` for `U8`, `U16`, `U32` and `U64`.
    pub fn is_unsigned(&self) -> bool {
        matches!(self, Self::U8 | Self::U16 | Self::U32 | Self::U64)
    }

    /// `true` for `I8`, `I16`, `I32` and `I64`.
    pub fn is_signed(&self) -> bool {
        matches!(self, Self::I8 | Self::I16 | Self::I32 | Self::I64)
    }

    /// `true` for `F16`, `F32` and `F64`.
    pub fn is_float(&self) -> bool {
        matches!(self, Self::F16 | Self::F32 | Self::F64)
    }
}
404
/// Rust types that map to a [`DatumType`] variant.
///
/// `TensorInterface::from_slice` and `TensorInterface::as_slice` reinterpret
/// values of implementing types as raw bytes and back, so an implementing
/// type must be plain data whose layout matches the reported variant.
pub trait Datum {
    /// The [`DatumType`] variant corresponding to `Self`.
    fn datum_type() -> DatumType;
}
408
/// Conversion of a caller-supplied value into the list of input tensors for a
/// run.
pub trait IntoInputs<V: TensorInterface> {
    /// Consumes `self` and returns the input tensors, or the first conversion
    /// error.
    fn into_inputs(self) -> Result<Vec<V>>;
}
413
414impl<V, T, E, const N: usize> IntoInputs<V> for [T; N]
416where
417 V: TensorInterface,
418 T: TryInto<V, Error = E>,
419 E: Into<anyhow::Error>,
420{
421 fn into_inputs(self) -> Result<Vec<V>> {
422 self.into_iter().map(|v| v.try_into().map_err(|e| e.into())).collect()
423 }
424}
425
/// A vector of tensors is already in the expected form.
impl<V: TensorInterface> IntoInputs<V> for Vec<V> {
    /// Returns the vector unchanged; this conversion cannot fail.
    fn into_inputs(self) -> Result<Vec<V>> {
        Ok(self)
    }
}
432
433macro_rules! impl_into_inputs_tuple {
435 ($($idx:tt : $T:ident),+) => {
436 impl<V, $($T),+> IntoInputs<V> for ($($T,)+)
437 where
438 V: TensorInterface,
439 $($T: TryInto<V>,
440 <$T as TryInto<V>>::Error: Into<anyhow::Error>,)+
441 {
442 fn into_inputs(self) -> Result<Vec<V>> {
443 Ok(vec![$(self.$idx.try_into().map_err(|e| e.into())?),+])
444 }
445 }
446 };
447}
448
449impl_into_inputs_tuple!(0: A);
450impl_into_inputs_tuple!(0: A, 1: B);
451impl_into_inputs_tuple!(0: A, 1: B, 2: C);
452impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D);
453impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_);
454impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F);
455impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F, 6: G);
456impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F, 6: G, 7: H);
457
458pub fn tensor<V, T, E>(v: T) -> Result<V>
460where
461 V: TensorInterface,
462 T: TryInto<V, Error = E>,
463 E: Into<anyhow::Error>,
464{
465 v.try_into().map_err(|e| e.into())
466}
467
468macro_rules! impl_datum_type {
469 ($ty:ty, $c_repr:expr) => {
470 impl Datum for $ty {
471 fn datum_type() -> DatumType {
472 $c_repr
473 }
474 }
475 };
476}
477
478impl_datum_type!(bool, DatumType::Bool);
479impl_datum_type!(u8, DatumType::U8);
480impl_datum_type!(u16, DatumType::U16);
481impl_datum_type!(u32, DatumType::U32);
482impl_datum_type!(u64, DatumType::U64);
483impl_datum_type!(i8, DatumType::I8);
484impl_datum_type!(i16, DatumType::I16);
485impl_datum_type!(i32, DatumType::I32);
486impl_datum_type!(i64, DatumType::I64);
487impl_datum_type!(half::f16, DatumType::F16);
488impl_datum_type!(f32, DatumType::F32);
489impl_datum_type!(f64, DatumType::F64);