1use anyhow::{Result, ensure};
2use boow::Bow;
3use std::fmt::{Debug, Display};
4use std::path::Path;
5
6#[macro_use]
7pub mod macros;
8
9pub trait NnefInterface: Sized {
13 type Model: ModelInterface;
14 fn load(&self, path: impl AsRef<Path>) -> Result<Self::Model>;
18
19 fn load_buffer(&self, data: &[u8]) -> Result<Self::Model>;
23
24 fn enable_tract_core(&mut self) -> Result<()>;
26
27 fn enable_tract_extra(&mut self) -> Result<()>;
29
30 fn enable_tract_transformers(&mut self) -> Result<()>;
32
33 fn enable_onnx(&mut self) -> Result<()>;
36
37 fn enable_pulse(&mut self) -> Result<()>;
39
40 fn enable_extended_identifier_syntax(&mut self) -> Result<()>;
45
46 fn with_tract_core(mut self) -> Result<Self> {
48 self.enable_tract_core()?;
49 Ok(self)
50 }
51
52 fn with_tract_extra(mut self) -> Result<Self> {
54 self.enable_tract_extra()?;
55 Ok(self)
56 }
57
58 fn with_tract_transformers(mut self) -> Result<Self> {
60 self.enable_tract_transformers()?;
61 Ok(self)
62 }
63
64 fn with_onnx(mut self) -> Result<Self> {
66 self.enable_onnx()?;
67 Ok(self)
68 }
69
70 fn with_pulse(mut self) -> Result<Self> {
72 self.enable_pulse()?;
73 Ok(self)
74 }
75
76 fn with_extended_identifier_syntax(mut self) -> Result<Self> {
78 self.enable_extended_identifier_syntax()?;
79 Ok(self)
80 }
81
82 fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
86
87 fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
93 fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
94}
95
/// Loader for ONNX models.
///
/// Both loaders return `Self::InferenceModel`, an [`InferenceModelInterface`] whose
/// `into_tract` method converts it into a [`ModelInterface`] model.
pub trait OnnxInterface {
    /// Model type returned by the loaders.
    type InferenceModel: InferenceModelInterface;
    /// Load a model from `path`.
    fn load(&self, path: impl AsRef<Path>) -> Result<Self::InferenceModel>;
    /// Load a model from the in-memory bytes `data`.
    fn load_buffer(&self, data: &[u8]) -> Result<Self::InferenceModel>;
}
102
/// A model whose input/output facts can still be read, replaced and analysed before it is
/// converted into a [`ModelInterface`] model with [`Self::into_tract`].
pub trait InferenceModelInterface: Sized {
    /// Model type produced by [`Self::into_tract`].
    type Model: ModelInterface;
    /// Fact type used for this model's inputs and outputs.
    type InferenceFact: InferenceFactInterface;
    /// Select the model outputs by name.
    fn set_output_names(
        &mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> Result<()>;
    /// Number of model inputs.
    fn input_count(&self) -> Result<usize>;
    /// Number of model outputs.
    fn output_count(&self) -> Result<usize>;
    /// Name of input `id`.
    fn input_name(&self, id: usize) -> Result<String>;
    /// Name of output `id`.
    fn output_name(&self, id: usize) -> Result<String>;

    /// Current fact of input `id`.
    fn input_fact(&self, id: usize) -> Result<Self::InferenceFact>;

    /// Replace the fact of input `id`.
    ///
    /// `fact` is anything that can produce an `InferenceFact` given this model (see [`AsFact`]).
    fn set_input_fact(
        &mut self,
        id: usize,
        fact: impl AsFact<Self, Self::InferenceFact>,
    ) -> Result<()>;

    /// Current fact of output `id`.
    fn output_fact(&self, id: usize) -> Result<Self::InferenceFact>;

    /// Replace the fact of output `id`.
    ///
    /// `fact` is anything that can produce an `InferenceFact` given this model (see [`AsFact`]).
    fn set_output_fact(
        &mut self,
        id: usize,
        fact: impl AsFact<Self, Self::InferenceFact>,
    ) -> Result<()>;

    /// Run the implementation's analysis pass on the model.
    // NOTE(review): presumably this propagates facts through the graph — confirm with implementors.
    fn analyse(&mut self) -> Result<()>;

    /// Convert into a [`ModelInterface`] model, consuming `self`.
    fn into_tract(self) -> Result<Self::Model>;
}
135
136pub trait ModelInterface: Sized {
137 type Fact: FactInterface;
138 type Runnable: RunnableInterface;
139 type Value: ValueInterface;
140 fn input_count(&self) -> Result<usize>;
141
142 fn output_count(&self) -> Result<usize>;
143
144 fn input_name(&self, id: usize) -> Result<String>;
145
146 fn output_name(&self, id: usize) -> Result<String>;
147
148 fn set_output_names(
149 &mut self,
150 outputs: impl IntoIterator<Item = impl AsRef<str>>,
151 ) -> Result<()>;
152
153 fn input_fact(&self, id: usize) -> Result<Self::Fact>;
154
155 fn output_fact(&self, id: usize) -> Result<Self::Fact>;
156
157 fn into_runnable(self) -> Result<Self::Runnable>;
158
159 fn concretize_symbols(
160 &mut self,
161 values: impl IntoIterator<Item = (impl AsRef<str>, i64)>,
162 ) -> Result<()>;
163
164 fn transform(&mut self, transform: &str) -> Result<()>;
165
166 fn pulse(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<()>;
167
168 fn property_keys(&self) -> Result<Vec<String>>;
169
170 fn property(&self, name: impl AsRef<str>) -> Result<Self::Value>;
171
172 fn parse_fact(&self, spec: &str) -> Result<Self::Fact>;
173
174 fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
175 Ok((0..self.input_count()?)
176 .map(|ix| self.input_fact(ix))
177 .collect::<Result<Vec<_>>>()?
178 .into_iter())
179 }
180
181 fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
182 Ok((0..self.output_count()?)
183 .map(|ix| self.output_fact(ix))
184 .collect::<Result<Vec<_>>>()?
185 .into_iter())
186 }
187}
188
/// A runtime that turns a [`ModelInterface`] model into a [`RunnableInterface`] runnable.
pub trait RuntimeInterface {
    /// Runnable type produced by [`Self::prepare`].
    type Runnable: RunnableInterface;
    /// Model type accepted by [`Self::prepare`].
    type Model: ModelInterface;
    /// Prepare `model` for execution on this runtime. The model is taken by value.
    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable>;
}
194
195pub trait RunnableInterface: Send + Sync {
196 type Value: ValueInterface;
197 type Fact: FactInterface;
198 type State: StateInterface<Value = Self::Value>;
199 fn run<I, V, E>(&self, inputs: I) -> Result<Vec<Self::Value>>
200 where
201 I: IntoIterator<Item = V>,
202 V: TryInto<Self::Value, Error = E>,
203 E: Into<anyhow::Error>,
204 {
205 self.spawn_state()?.run(inputs)
206 }
207
208 fn input_count(&self) -> Result<usize>;
209 fn output_count(&self) -> Result<usize>;
210 fn input_fact(&self, id: usize) -> Result<Self::Fact>;
211
212 fn output_fact(&self, id: usize) -> Result<Self::Fact>;
213
214 fn input_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
215 Ok((0..self.input_count()?)
216 .map(|ix| self.input_fact(ix))
217 .collect::<Result<Vec<_>>>()?
218 .into_iter())
219 }
220
221 fn output_facts(&self) -> Result<impl Iterator<Item = Self::Fact>> {
222 Ok((0..self.output_count()?)
223 .map(|ix| self.output_fact(ix))
224 .collect::<Result<Vec<_>>>()?
225 .into_iter())
226 }
227
228 fn property_keys(&self) -> Result<Vec<String>>;
229 fn property(&self, name: impl AsRef<str>) -> Result<Self::Value>;
230
231 fn spawn_state(&self) -> Result<Self::State>;
232
233 fn cost_json(&self) -> Result<String>;
234
235 fn profile_json<I, IV, IE, S, SV, SE>(
236 &self,
237 inputs: Option<I>,
238 state_initializers: Option<S>,
239 ) -> Result<String>
240 where
241 I: IntoIterator<Item = IV>,
242 IV: TryInto<Self::Value, Error = IE>,
243 IE: Into<anyhow::Error> + Debug,
244 S: IntoIterator<Item = SV>,
245 SV: TryInto<Self::Value, Error = SE>,
246 SE: Into<anyhow::Error> + Debug;
247}
248
/// Execution state of a runnable (see `RunnableInterface::spawn_state`).
///
/// `run` takes `&mut self`, which lets an implementation keep state between successive calls.
pub trait StateInterface {
    /// Fact type of this state.
    type Fact: FactInterface;
    /// Tensor value type consumed and produced by [`Self::run`].
    type Value: ValueInterface;

    /// Number of inputs.
    fn input_count(&self) -> Result<usize>;
    /// Number of outputs.
    fn output_count(&self) -> Result<usize>;

    /// Run on `inputs` (each converted to `Self::Value` via `TryInto`) and return the outputs.
    fn run<I, V, E>(&mut self, inputs: I) -> Result<Vec<Self::Value>>
    where
        I: IntoIterator<Item = V>,
        V: TryInto<Self::Value, Error = E>,
        E: Into<anyhow::Error>;

    // Deprecated and hidden from the generated documentation.
    #[doc(hidden)]
    #[deprecated]
    fn initializable_states_count(&self) -> Result<usize>;

    // Deprecated and hidden from the generated documentation.
    #[doc(hidden)]
    #[deprecated]
    fn get_states_facts(&self) -> Result<Vec<Self::Fact>>;

    // Deprecated and hidden from the generated documentation.
    #[doc(hidden)]
    #[deprecated]
    fn set_states<I, V, E>(&mut self, state_initializers: I) -> Result<()>
    where
        I: IntoIterator<Item = V>,
        V: TryInto<Self::Value, Error = E>,
        E: Into<anyhow::Error> + Debug;

    // Deprecated and hidden from the generated documentation.
    #[doc(hidden)]
    #[deprecated]
    fn get_states(&self) -> Result<Vec<Self::Value>>;
}
282
283pub trait ValueInterface: Debug + Sized + Clone + PartialEq + Send + Sync {
284 fn datum_type(&self) -> Result<DatumType>;
285 fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self>;
286 fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])>;
287
288 fn from_slice<T: Datum>(shape: &[usize], data: &[T]) -> Result<Self> {
289 let data = unsafe {
290 std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
291 };
292 Self::from_bytes(T::datum_type(), shape, data)
293 }
294
295 fn as_slice<T: Datum>(&self) -> Result<&[T]> {
296 let (dt, _shape, data) = self.as_bytes()?;
297 ensure!(T::datum_type() == dt);
298 let data = unsafe {
299 std::slice::from_raw_parts(
300 data.as_ptr() as *const T,
301 data.len() / std::mem::size_of::<T>(),
302 )
303 };
304 Ok(data)
305 }
306
307 fn as_shape_and_slice<T: Datum>(&self) -> Result<(&[usize], &[T])> {
308 let (_, shape, _) = self.as_bytes()?;
309 let data = self.as_slice()?;
310 Ok((shape, data))
311 }
312
313 fn shape(&self) -> Result<&[usize]> {
314 let (_, shape, _) = self.as_bytes()?;
315 Ok(shape)
316 }
317
318 fn view<T: Datum>(&self) -> Result<ndarray::ArrayViewD<'_, T>> {
319 let (shape, data) = self.as_shape_and_slice()?;
320 Ok(unsafe { ndarray::ArrayViewD::from_shape_ptr(shape, data.as_ptr()) })
321 }
322
323 fn convert_to(&self, to: DatumType) -> Result<Self>;
324}
325
326pub trait FactInterface: Debug + Display + Clone {
327 type Dim: DimInterface;
328 fn datum_type(&self) -> Result<DatumType>;
329 fn rank(&self) -> Result<usize>;
330 fn dim(&self, axis: usize) -> Result<Self::Dim>;
331
332 fn dims(&self) -> Result<impl Iterator<Item = Self::Dim>> {
333 Ok((0..self.rank()?).map(|axis| self.dim(axis)).collect::<Result<Vec<_>>>()?.into_iter())
334 }
335}
336
/// A dimension of a fact.
///
/// [`Self::eval`] takes `(name, value)` pairs, which suggests dimensions may reference named
/// symbols.
pub trait DimInterface: Debug + Display + Clone {
    /// Evaluate the dimension with the given `(name, value)` assignments, returning a new dimension.
    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self>;
    /// Convert the dimension to an `i64`.
    // NOTE(review): presumably fails when the dimension is not a concrete integer — confirm.
    fn to_int64(&self) -> Result<i64>;
}
341
/// Fact type used by [`InferenceModelInterface`] models.
pub trait InferenceFactInterface: Debug + Display + Default + Clone {
    /// Construct an empty fact.
    // NOTE(review): presumably a fact with no type/shape constraints — confirm with implementors.
    fn empty() -> Result<Self>;
}
345
/// Conversion of `self` into a fact of type `F`, with access to a model of type `M`.
///
/// The result is a `boow::Bow`, so an implementation may either borrow an existing `F` or
/// build a new one.
pub trait AsFact<M, F> {
    /// Produce the fact, using `model` as context for the conversion.
    fn as_fact(&self, model: &M) -> Result<Bow<'_, F>>;
}
349
/// Element type tag for tensor values.
///
/// The type is `#[repr(C)]` with explicit discriminants. The values follow a visible pattern: the
/// high nibble selects the family (0 = bool, 1 = unsigned, 2 = signed, 3 = float,
/// 4 = complex integer, 5 = complex float), and the low nibble is the byte width of one scalar
/// component.
// NOTE(review): presumably these values are shared with a C API — do not renumber without checking.
#[repr(C)]
#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum DatumType {
    TRACT_DATUM_TYPE_BOOL = 0x01,
    TRACT_DATUM_TYPE_U8 = 0x11,
    TRACT_DATUM_TYPE_U16 = 0x12,
    TRACT_DATUM_TYPE_U32 = 0x14,
    TRACT_DATUM_TYPE_U64 = 0x18,
    TRACT_DATUM_TYPE_I8 = 0x21,
    TRACT_DATUM_TYPE_I16 = 0x22,
    TRACT_DATUM_TYPE_I32 = 0x24,
    TRACT_DATUM_TYPE_I64 = 0x28,
    TRACT_DATUM_TYPE_F16 = 0x32,
    TRACT_DATUM_TYPE_F32 = 0x34,
    TRACT_DATUM_TYPE_F64 = 0x38,
    // Complex variants exist only when the `complex` feature is enabled.
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I16 = 0x42,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I32 = 0x44,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I64 = 0x48,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F16 = 0x52,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F32 = 0x54,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F64 = 0x58,
}
379
380impl DatumType {
381 pub fn size_of(&self) -> usize {
382 use DatumType::*;
383 match &self {
384 TRACT_DATUM_TYPE_BOOL | TRACT_DATUM_TYPE_U8 | TRACT_DATUM_TYPE_I8 => 1,
385 TRACT_DATUM_TYPE_U16 | TRACT_DATUM_TYPE_I16 | TRACT_DATUM_TYPE_F16 => 2,
386 TRACT_DATUM_TYPE_U32 | TRACT_DATUM_TYPE_I32 | TRACT_DATUM_TYPE_F32 => 4,
387 TRACT_DATUM_TYPE_U64 | TRACT_DATUM_TYPE_I64 | TRACT_DATUM_TYPE_F64 => 8,
388 #[cfg(feature = "complex")]
389 TRACT_DATUM_TYPE_COMPLEX_I16 | TRACT_DATUM_TYPE_F16 => 4,
390 #[cfg(feature = "complex")]
391 TRACT_DATUM_TYPE_COMPLEX_I32 | TRACT_DATUM_TYPE_F32 => 8,
392 #[cfg(feature = "complex")]
393 TRACT_DATUM_TYPE_COMPLEX_I64 | TRACT_DATUM_TYPE_F64 => 16,
394 }
395 }
396
397 pub fn is_bool(&self) -> bool {
398 use DatumType::*;
399 *self == TRACT_DATUM_TYPE_BOOL
400 }
401
402 pub fn is_number(&self) -> bool {
403 use DatumType::*;
404 *self != TRACT_DATUM_TYPE_BOOL
405 }
406
407 pub fn is_unsigned(&self) -> bool {
408 use DatumType::*;
409 *self == TRACT_DATUM_TYPE_U8
410 || *self == TRACT_DATUM_TYPE_U16
411 || *self == TRACT_DATUM_TYPE_U32
412 || *self == TRACT_DATUM_TYPE_U64
413 }
414
415 pub fn is_signed(&self) -> bool {
416 use DatumType::*;
417 *self == TRACT_DATUM_TYPE_I8
418 || *self == TRACT_DATUM_TYPE_I16
419 || *self == TRACT_DATUM_TYPE_I32
420 || *self == TRACT_DATUM_TYPE_I64
421 }
422
423 pub fn is_float(&self) -> bool {
424 use DatumType::*;
425 *self == TRACT_DATUM_TYPE_F16
426 || *self == TRACT_DATUM_TYPE_F32
427 || *self == TRACT_DATUM_TYPE_F64
428 }
429}
430
/// Rust element types that have a corresponding [`DatumType`].
pub trait Datum {
    /// The [`DatumType`] tag for `Self`.
    fn datum_type() -> DatumType;
}
434
// Implements `Datum` for a Rust type by returning a fixed `DatumType` variant.
// Usage: `impl_datum_type!(rust_type, DatumType::VARIANT);`
macro_rules! impl_datum_type {
    ($rust_type:ty, $tag:expr) => {
        impl Datum for $rust_type {
            fn datum_type() -> DatumType {
                $tag
            }
        }
    };
}
444
// `Datum` implementations for the supported primitive element types.
// NOTE(review): the `complex`-feature variants of `DatumType` have no Rust type registered here.
impl_datum_type!(bool, DatumType::TRACT_DATUM_TYPE_BOOL);
impl_datum_type!(u8, DatumType::TRACT_DATUM_TYPE_U8);
impl_datum_type!(u16, DatumType::TRACT_DATUM_TYPE_U16);
impl_datum_type!(u32, DatumType::TRACT_DATUM_TYPE_U32);
impl_datum_type!(u64, DatumType::TRACT_DATUM_TYPE_U64);
impl_datum_type!(i8, DatumType::TRACT_DATUM_TYPE_I8);
impl_datum_type!(i16, DatumType::TRACT_DATUM_TYPE_I16);
impl_datum_type!(i32, DatumType::TRACT_DATUM_TYPE_I32);
impl_datum_type!(i64, DatumType::TRACT_DATUM_TYPE_I64);
impl_datum_type!(half::f16, DatumType::TRACT_DATUM_TYPE_F16);
impl_datum_type!(f32, DatumType::TRACT_DATUM_TYPE_F32);
impl_datum_type!(f64, DatumType::TRACT_DATUM_TYPE_F64);