1use anyhow::{Result, ensure};
2use boow::Bow;
3use std::fmt::{Debug, Display};
4use std::path::Path;
5
6#[macro_use]
7pub mod macros;
8pub mod transform;
9
10pub use transform::{FloatPrecision, Pulse, SetSymbols, TransformConfig, TransformSpec};
11
12pub trait NnefInterface: Debug + Sized {
16 type Model: ModelInterface;
17 fn load(&self, path: impl AsRef<Path>) -> Result<Self::Model>;
21
22 fn load_buffer(&self, data: &[u8]) -> Result<Self::Model>;
26
27 fn disable_tract_core(&mut self) -> Result<()>;
30
31 fn enable_tract_extra(&mut self) -> Result<()>;
33
34 fn enable_tract_transformers(&mut self) -> Result<()>;
36
37 fn enable_onnx(&mut self) -> Result<()>;
40
41 fn enable_pulse(&mut self) -> Result<()>;
43
44 fn enable_extended_identifier_syntax(&mut self) -> Result<()>;
49
50 fn without_tract_core(mut self) -> Result<Self> {
52 self.disable_tract_core()?;
53 Ok(self)
54 }
55
56 fn with_tract_extra(mut self) -> Result<Self> {
58 self.enable_tract_extra()?;
59 Ok(self)
60 }
61
62 fn with_tract_transformers(mut self) -> Result<Self> {
64 self.enable_tract_transformers()?;
65 Ok(self)
66 }
67
68 fn with_onnx(mut self) -> Result<Self> {
70 self.enable_onnx()?;
71 Ok(self)
72 }
73
74 fn with_pulse(mut self) -> Result<Self> {
76 self.enable_pulse()?;
77 Ok(self)
78 }
79
80 fn with_extended_identifier_syntax(mut self) -> Result<Self> {
82 self.enable_extended_identifier_syntax()?;
83 Ok(self)
84 }
85
86 fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
90
91 fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
97 fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
98}
99
/// Loader for ONNX models.
///
/// Loading yields an [`InferenceModelInterface`] implementor. Its facts can still be set and
/// analysed before it is turned into a [`ModelInterface`] model with `into_model`.
pub trait OnnxInterface: Debug {
    /// Inference model type produced by the loading methods.
    type InferenceModel: InferenceModelInterface;
    /// Load an ONNX model from `path`.
    fn load(&self, path: impl AsRef<Path>) -> Result<Self::InferenceModel>;
    /// Load an ONNX model from an in-memory byte buffer.
    fn load_buffer(&self, data: &[u8]) -> Result<Self::InferenceModel>;
}
106
/// A model whose input/output facts may still be partially known.
///
/// Facts can be read and overridden per input/output index. After `analyse`, the model can be
/// converted into a [`ModelInterface`] model with `into_model`.
pub trait InferenceModelInterface: Debug + Sized {
    /// Model type produced by [`InferenceModelInterface::into_model`].
    type Model: ModelInterface;
    /// Fact type used for inputs and outputs of this model.
    type InferenceFact: InferenceFactInterface;
    /// Number of model inputs.
    fn input_count(&self) -> Result<usize>;
    /// Number of model outputs.
    fn output_count(&self) -> Result<usize>;
    /// Name of input number `id`.
    fn input_name(&self, id: usize) -> Result<String>;
    /// Name of output number `id`.
    fn output_name(&self, id: usize) -> Result<String>;

    /// Current fact of input number `id`.
    fn input_fact(&self, id: usize) -> Result<Self::InferenceFact>;

    /// Replace the fact of input number `id`.
    ///
    /// `fact` goes through [`AsFact`], which receives the model itself, so implementors may
    /// build the fact from something other than an `InferenceFact` value.
    fn set_input_fact(
        &mut self,
        id: usize,
        fact: impl AsFact<Self, Self::InferenceFact>,
    ) -> Result<()>;

    /// Current fact of output number `id`.
    fn output_fact(&self, id: usize) -> Result<Self::InferenceFact>;

    /// Replace the fact of output number `id` (see `set_input_fact` for how `fact` is taken).
    fn set_output_fact(
        &mut self,
        id: usize,
        fact: impl AsFact<Self, Self::InferenceFact>,
    ) -> Result<()>;

    /// Run analysis on the model in place (presumably propagating facts — implementor-defined).
    fn analyse(&mut self) -> Result<()>;

    /// Consume the inference model and produce a [`ModelInterface`] model.
    fn into_model(self) -> Result<Self::Model>;
}
135
136pub trait ModelInterface: Debug + Sized {
137 type Fact: FactInterface;
138 type Runnable: RunnableInterface;
139 type Tensor: TensorInterface;
140 fn input_count(&self) -> Result<usize>;
141
142 fn output_count(&self) -> Result<usize>;
143
144 fn input_name(&self, id: usize) -> Result<String>;
145
146 fn output_name(&self, id: usize) -> Result<String>;
147
148 fn input_fact(&self, id: usize) -> Result<Self::Fact>;
149
150 fn output_fact(&self, id: usize) -> Result<Self::Fact>;
151
152 fn into_runnable(self) -> Result<Self::Runnable>;
153
154 fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()>;
155
156 fn property_keys(&self) -> Result<Vec<String>>;
157
158 fn property(&self, name: impl AsRef<str>) -> Result<Self::Tensor>;
159
160 fn parse_fact(&self, spec: &str) -> Result<Self::Fact>;
161
162 fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
163 Ok((0..self.input_count()?)
164 .map(|ix| self.input_fact(ix))
165 .collect::<Result<Vec<_>>>()?
166 .into_iter())
167 }
168
169 fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
170 Ok((0..self.output_count()?)
171 .map(|ix| self.output_fact(ix))
172 .collect::<Result<Vec<_>>>()?
173 .into_iter())
174 }
175}
176
/// A runtime that turns models into runnables.
pub trait RuntimeInterface: Debug {
    /// Runnable type produced by [`RuntimeInterface::prepare`].
    type Runnable: RunnableInterface;
    /// Model type accepted by [`RuntimeInterface::prepare`].
    type Model: ModelInterface;
    /// Name of this runtime.
    fn name(&self) -> Result<String>;
    /// Consume `model` and prepare it for execution on this runtime.
    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable>;
}
183
184pub trait RunnableInterface: Debug + Send + Sync {
185 type Tensor: TensorInterface;
186 type Fact: FactInterface;
187 type State: StateInterface<Tensor = Self::Tensor>;
188 fn run(&self, inputs: impl IntoInputs<Self::Tensor>) -> Result<Vec<Self::Tensor>> {
189 self.spawn_state()?.run(inputs.into_inputs()?)
190 }
191
192 fn input_count(&self) -> Result<usize>;
193 fn output_count(&self) -> Result<usize>;
194 fn input_fact(&self, id: usize) -> Result<Self::Fact>;
195
196 fn output_fact(&self, id: usize) -> Result<Self::Fact>;
197
198 fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
199 Ok((0..self.input_count()?)
200 .map(|ix| self.input_fact(ix))
201 .collect::<Result<Vec<_>>>()?
202 .into_iter())
203 }
204
205 fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
206 Ok((0..self.output_count()?)
207 .map(|ix| self.output_fact(ix))
208 .collect::<Result<Vec<_>>>()?
209 .into_iter())
210 }
211
212 fn property_keys(&self) -> Result<Vec<String>>;
213 fn property(&self, name: impl AsRef<str>) -> Result<Self::Tensor>;
214
215 fn spawn_state(&self) -> Result<Self::State>;
216
217 fn cost_json(&self) -> Result<String>;
218
219 fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
220 where
221 I: IntoIterator<Item = IV>,
222 IV: TryInto<Self::Tensor, Error = IE>,
223 IE: Into<anyhow::Error> + Debug;
224}
225
/// An execution state spawned from a runnable; it is run through `&mut self`.
pub trait StateInterface: Debug + Clone + Send {
    /// Fact type describing inputs and outputs.
    type Fact: FactInterface;
    /// Tensor type consumed and produced when running.
    type Tensor: TensorInterface;

    /// Number of inputs.
    fn input_count(&self) -> Result<usize>;
    /// Number of outputs.
    fn output_count(&self) -> Result<usize>;

    /// Run the model on `inputs` using this state and return the outputs.
    fn run(&mut self, inputs: impl IntoInputs<Self::Tensor>) -> Result<Vec<Self::Tensor>>;
}
235
236pub trait TensorInterface: Debug + Sized + Clone + PartialEq + Send + Sync {
237 fn datum_type(&self) -> Result<DatumType>;
238 fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self>;
239 fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])>;
240
241 fn from_slice<T: Datum>(shape: &[usize], data: &[T]) -> Result<Self> {
242 let data = unsafe {
243 std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
244 };
245 Self::from_bytes(T::datum_type(), shape, data)
246 }
247
248 fn as_slice<T: Datum>(&self) -> Result<&[T]> {
249 let (dt, _shape, data) = self.as_bytes()?;
250 ensure!(T::datum_type() == dt);
251 let data = unsafe {
252 std::slice::from_raw_parts(
253 data.as_ptr() as *const T,
254 data.len() / std::mem::size_of::<T>(),
255 )
256 };
257 Ok(data)
258 }
259
260 fn as_shape_and_slice<T: Datum>(&self) -> Result<(&[usize], &[T])> {
261 let (_, shape, _) = self.as_bytes()?;
262 let data = self.as_slice()?;
263 Ok((shape, data))
264 }
265
266 fn shape(&self) -> Result<&[usize]> {
267 let (_, shape, _) = self.as_bytes()?;
268 Ok(shape)
269 }
270
271 fn convert_to(&self, to: DatumType) -> Result<Self>;
272}
273
274pub trait FactInterface: Debug + Display + Clone {
275 type Dim: DimInterface;
276 fn datum_type(&self) -> Result<DatumType>;
277 fn rank(&self) -> Result<usize>;
278 fn dim(&self, axis: usize) -> Result<Self::Dim>;
279
280 fn dims(&self) -> Result<impl Iterator<Item = Self::Dim>> {
281 Ok((0..self.rank()?).map(|axis| self.dim(axis)).collect::<Result<Vec<_>>>()?.into_iter())
282 }
283}
284
/// A single axis dimension, possibly symbolic.
pub trait DimInterface: Debug + Display + Clone {
    /// Evaluate the dimension given `(name, value)` pairs (presumably symbol bindings — confirm
    /// with implementors), returning the resulting dimension.
    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self>;
    /// The dimension as a concrete `i64`; implementors decide how unresolved dimensions fail.
    fn to_int64(&self) -> Result<i64>;
}

/// Fact type used by inference models, where information may be incomplete.
pub trait InferenceFactInterface: Debug + Display + Default + Clone {
    /// Build an "empty" fact (presumably one carrying no information — implementor-defined).
    fn empty() -> Result<Self>;
}

/// Conversion of a value into a fact of type `F`, with access to the model `M`.
///
/// The result is a `boow::Bow`, so implementors can hand back either a borrowed or an owned
/// fact.
pub trait AsFact<M, F>: Debug {
    /// Produce the fact, possibly consulting `model`.
    fn as_fact(&self, model: &M) -> Result<Bow<'_, F>>;
}
297
/// Element type of a tensor.
///
/// The explicit discriminants follow a visible pattern. The high nibble groups the kind:
/// 0 = bool, 1 = unsigned int, 2 = signed int, 3 = float, 4 = complex int, 5 = complex float.
/// The low nibble is the byte width of one scalar component. `#[repr(C)]` pins the
/// representation, presumably for FFI use — TODO confirm.
#[repr(C)]
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum DatumType {
    Bool = 0x01,
    U8 = 0x11,
    U16 = 0x12,
    U32 = 0x14,
    U64 = 0x18,
    I8 = 0x21,
    I16 = 0x22,
    I32 = 0x24,
    I64 = 0x28,
    F16 = 0x32,
    F32 = 0x34,
    F64 = 0x38,
    // Complex variants only exist when the `complex` cargo feature is enabled.
    #[cfg(feature = "complex")]
    ComplexI16 = 0x42,
    #[cfg(feature = "complex")]
    ComplexI32 = 0x44,
    #[cfg(feature = "complex")]
    ComplexI64 = 0x48,
    #[cfg(feature = "complex")]
    ComplexF16 = 0x52,
    #[cfg(feature = "complex")]
    ComplexF32 = 0x54,
    #[cfg(feature = "complex")]
    ComplexF64 = 0x58,
}
326
327impl DatumType {
328 pub fn size_of(&self) -> usize {
329 use DatumType::*;
330 match &self {
331 Bool | U8 | I8 => 1,
332 U16 | I16 | F16 => 2,
333 U32 | I32 | F32 => 4,
334 U64 | I64 | F64 => 8,
335 #[cfg(feature = "complex")]
336 ComplexI16 | ComplexF16 => 4,
337 #[cfg(feature = "complex")]
338 ComplexI32 | ComplexF32 => 8,
339 #[cfg(feature = "complex")]
340 ComplexI64 | ComplexF64 => 16,
341 }
342 }
343
344 pub fn is_bool(&self) -> bool {
345 *self == DatumType::Bool
346 }
347
348 pub fn is_number(&self) -> bool {
349 *self != DatumType::Bool
350 }
351
352 pub fn is_unsigned(&self) -> bool {
353 use DatumType::*;
354 *self == U8 || *self == U16 || *self == U32 || *self == U64
355 }
356
357 pub fn is_signed(&self) -> bool {
358 use DatumType::*;
359 *self == I8 || *self == I16 || *self == I32 || *self == I64
360 }
361
362 pub fn is_float(&self) -> bool {
363 use DatumType::*;
364 *self == F16 || *self == F32 || *self == F64
365 }
366}
367
/// Rust scalar types that map to a [`DatumType`].
///
/// NOTE(review): `TensorInterface::from_slice` / `as_slice` reinterpret memory based on this
/// mapping, so an impl must only declare a `DatumType` whose layout matches `Self`.
pub trait Datum {
    /// The element type corresponding to `Self`.
    fn datum_type() -> DatumType;
}

/// Conversion of accepted input forms (arrays, vectors and tuples — see the impls in this
/// module) into one tensor per model input.
pub trait IntoInputs<V: TensorInterface> {
    /// Convert `self` into the list of input tensors, in input order.
    fn into_inputs(self) -> Result<Vec<V>>;
}
376
377impl<V, T, E, const N: usize> IntoInputs<V> for [T; N]
379where
380 V: TensorInterface,
381 T: TryInto<V, Error = E>,
382 E: Into<anyhow::Error>,
383{
384 fn into_inputs(self) -> Result<Vec<V>> {
385 self.into_iter().map(|v| v.try_into().map_err(|e| e.into())).collect()
386 }
387}
388
/// A vector of tensors is already in input form and is passed through unchanged.
impl<V: TensorInterface> IntoInputs<V> for Vec<V> {
    fn into_inputs(self) -> Result<Vec<V>> {
        Ok(self)
    }
}
395
/// Implement [`IntoInputs`] for a tuple whose members each convert into `V`.
///
/// Invoked as `impl_into_inputs_tuple!(0: A, 1: B, ...)`: each `index: Type` pair names the
/// tuple field and the generic type parameter used for it. Fields are converted in order, and
/// the first conversion error is returned as an `anyhow::Error`.
macro_rules! impl_into_inputs_tuple {
    ($($field:tt : $Item:ident),+) => {
        impl<V, $($Item),+> IntoInputs<V> for ($($Item,)+)
        where
            V: TensorInterface,
            $(
                $Item: TryInto<V>,
                <$Item as TryInto<V>>::Error: Into<anyhow::Error>,
            )+
        {
            fn into_inputs(self) -> Result<Vec<V>> {
                let inputs = vec![$(
                    self.$field.try_into().map_err(Into::<anyhow::Error>::into)?
                ),+];
                Ok(inputs)
            }
        }
    };
}
411
// `IntoInputs` for tuples of one to eight members. `E_` is used as the fifth type parameter
// name instead of `E`.
impl_into_inputs_tuple!(0: A);
impl_into_inputs_tuple!(0: A, 1: B);
impl_into_inputs_tuple!(0: A, 1: B, 2: C);
impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D);
impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_);
impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F);
impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F, 6: G);
impl_into_inputs_tuple!(0: A, 1: B, 2: C, 3: D, 4: E_, 5: F, 6: G, 7: H);
420
421pub fn tensor<V, T, E>(v: T) -> Result<V>
423where
424 V: TensorInterface,
425 T: TryInto<V, Error = E>,
426 E: Into<anyhow::Error>,
427{
428 v.try_into().map_err(|e| e.into())
429}
430
/// Implement [`Datum`] for the Rust type `$rust_type`, reporting `$datum` as its element type.
macro_rules! impl_datum_type {
    ($rust_type:ty, $datum:expr) => {
        impl Datum for $rust_type {
            fn datum_type() -> DatumType {
                $datum
            }
        }
    };
}
440
// `Datum` impls for the primitive Rust scalars. Each Rust type's size matches
// `DatumType::size_of` for its variant, which the byte reinterpretation in `TensorInterface`
// relies on.
impl_datum_type!(bool, DatumType::Bool);
impl_datum_type!(u8, DatumType::U8);
impl_datum_type!(u16, DatumType::U16);
impl_datum_type!(u32, DatumType::U32);
impl_datum_type!(u64, DatumType::U64);
impl_datum_type!(i8, DatumType::I8);
impl_datum_type!(i16, DatumType::I16);
impl_datum_type!(i32, DatumType::I32);
impl_datum_type!(i64, DatumType::I64);
impl_datum_type!(half::f16, DatumType::F16);
impl_datum_type!(f32, DatumType::F32);
impl_datum_type!(f64, DatumType::F64);