1use anyhow::{ensure, Result};
2use boow::Bow;
3use std::fmt::{Debug, Display};
4use std::path::Path;
5
6#[macro_use]
7pub mod macros;
8
9pub trait NnefInterface: Sized {
13 type Model: ModelInterface;
14 fn model_for_path(&self, path: impl AsRef<Path>) -> Result<Self::Model>;
18
19 fn transform_model(&self, model: &mut Self::Model, transform_spec: &str) -> Result<()>;
21
22 fn enable_tract_core(&mut self) -> Result<()>;
24
25 fn enable_tract_extra(&mut self) -> Result<()>;
27
28 fn enable_tract_transformers(&mut self) -> Result<()>;
30
31 fn enable_onnx(&mut self) -> Result<()>;
34
35 fn enable_pulse(&mut self) -> Result<()>;
37
38 fn enable_extended_identifier_syntax(&mut self) -> Result<()>;
43
44 fn with_tract_core(mut self) -> Result<Self> {
46 self.enable_tract_core()?;
47 Ok(self)
48 }
49
50 fn with_tract_extra(mut self) -> Result<Self> {
52 self.enable_tract_extra()?;
53 Ok(self)
54 }
55
56 fn with_tract_transformers(mut self) -> Result<Self> {
58 self.enable_tract_transformers()?;
59 Ok(self)
60 }
61
62 fn with_onnx(mut self) -> Result<Self> {
64 self.enable_onnx()?;
65 Ok(self)
66 }
67
68 fn with_pulse(mut self) -> Result<Self> {
70 self.enable_pulse()?;
71 Ok(self)
72 }
73
74 fn with_extended_identifier_syntax(mut self) -> Result<Self> {
76 self.enable_extended_identifier_syntax()?;
77 Ok(self)
78 }
79
80 fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
84
85 fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
91 fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Self::Model) -> Result<()>;
92}
93
94pub trait OnnxInterface {
95 type InferenceModel: InferenceModelInterface;
96 fn model_for_path(&self, path: impl AsRef<Path>) -> Result<Self::InferenceModel>;
97}
98
99pub trait InferenceModelInterface: Sized {
100 type Model: ModelInterface;
101 type InferenceFact: InferenceFactInterface;
102 fn set_output_names(
103 &mut self,
104 outputs: impl IntoIterator<Item = impl AsRef<str>>,
105 ) -> Result<()>;
106 fn input_count(&self) -> Result<usize>;
107 fn output_count(&self) -> Result<usize>;
108 fn input_name(&self, id: usize) -> Result<String>;
109 fn output_name(&self, id: usize) -> Result<String>;
110
111 fn input_fact(&self, id: usize) -> Result<Self::InferenceFact>;
112
113 fn set_input_fact(
114 &mut self,
115 id: usize,
116 fact: impl AsFact<Self, Self::InferenceFact>,
117 ) -> Result<()>;
118
119 fn output_fact(&self, id: usize) -> Result<Self::InferenceFact>;
120
121 fn set_output_fact(
122 &mut self,
123 id: usize,
124 fact: impl AsFact<Self, Self::InferenceFact>,
125 ) -> Result<()>;
126
127 fn analyse(&mut self) -> Result<()>;
128
129 fn into_typed(self) -> Result<Self::Model>;
130
131 fn into_optimized(self) -> Result<Self::Model>;
132}
133
134pub trait ModelInterface: Sized {
135 type Fact: FactInterface;
136 type Runnable: RunnableInterface;
137 type Value: ValueInterface;
138 fn input_count(&self) -> Result<usize>;
139
140 fn output_count(&self) -> Result<usize>;
141
142 fn input_name(&self, id: usize) -> Result<String>;
143
144 fn output_name(&self, id: usize) -> Result<String>;
145
146 fn set_output_names(
147 &mut self,
148 outputs: impl IntoIterator<Item = impl AsRef<str>>,
149 ) -> Result<()>;
150
151 fn input_fact(&self, id: usize) -> Result<Self::Fact>;
152
153 fn output_fact(&self, id: usize) -> Result<Self::Fact>;
154
155 fn declutter(&mut self) -> Result<()>;
156
157 fn optimize(&mut self) -> Result<()>;
158
159 fn into_decluttered(self) -> Result<Self>;
160
161 fn into_optimized(self) -> Result<Self>;
162
163 fn into_runnable(self) -> Result<Self::Runnable>;
164
165 fn concretize_symbols(
166 &mut self,
167 values: impl IntoIterator<Item = (impl AsRef<str>, i64)>,
168 ) -> Result<()>;
169
170 fn transform(&mut self, transform: &str) -> Result<()>;
171
172 fn pulse(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<()>;
173
174 fn cost_json(&self) -> Result<String>;
175
176 fn profile_json<I, V, E>(&self, inputs: Option<I>) -> Result<String>
177 where
178 I: IntoIterator<Item = V>,
179 V: TryInto<Self::Value, Error = E>,
180 E: Into<anyhow::Error> + Debug;
181
182 fn property_keys(&self) -> Result<Vec<String>>;
183
184 fn property(&self, name: impl AsRef<str>) -> Result<Self::Value>;
185}
186
187pub trait RunnableInterface {
188 type Value: ValueInterface;
189 type State: StateInterface<Value = Self::Value>;
190 fn run<I, V, E>(&self, inputs: I) -> Result<Vec<Self::Value>>
191 where
192 I: IntoIterator<Item = V>,
193 V: TryInto<Self::Value, Error = E>,
194 E: Into<anyhow::Error>,
195 {
196 self.spawn_state()?.run(inputs)
197 }
198
199 fn input_count(&self) -> Result<usize>;
200 fn output_count(&self) -> Result<usize>;
201
202 fn spawn_state(&self) -> Result<Self::State>;
203}
204
205pub trait StateInterface {
206 type Value: ValueInterface;
207
208 fn input_count(&self) -> Result<usize>;
209 fn output_count(&self) -> Result<usize>;
210
211 fn run<I, V, E>(&mut self, inputs: I) -> Result<Vec<Self::Value>>
212 where
213 I: IntoIterator<Item = V>,
214 V: TryInto<Self::Value, Error = E>,
215 E: Into<anyhow::Error>;
216}
217
218pub trait ValueInterface: Sized + Clone {
219 fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self>;
220 fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])>;
221
222 fn from_slice<T: Datum>(shape: &[usize], data: &[T]) -> Result<Self> {
223 let data = unsafe {
224 std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
225 };
226 Self::from_bytes(T::datum_type(), shape, data)
227 }
228
229 fn as_slice<T: Datum>(&self) -> Result<(&[usize], &[T])> {
230 let (dt, shape, data) = self.as_bytes()?;
231 ensure!(T::datum_type() == dt);
232 let data = unsafe {
233 std::slice::from_raw_parts(
234 data.as_ptr() as *const T,
235 data.len() / std::mem::size_of::<T>(),
236 )
237 };
238 Ok((shape, data))
239 }
240
241 fn view<T: Datum>(&self) -> Result<ndarray::ArrayViewD<T>> {
242 let (shape, data) = self.as_slice()?;
243 Ok(unsafe { ndarray::ArrayViewD::from_shape_ptr(shape, data.as_ptr()) })
244 }
245}
246
/// Description of a typed model input or output.
///
/// It is debuggable, human-displayable and cloneable. The bounds are written as a
/// `where Self:` clause, which declares the same supertraits.
pub trait FactInterface
where
    Self: Debug + Display + Clone,
{
}
248pub trait InferenceFactInterface: Debug + Display + Default + Clone {
249 fn empty() -> Result<Self>;
250}
251
252pub trait AsFact<M, F> {
253 fn as_fact(&self, model: &mut M) -> Result<Bow<F>>;
254}
255
/// Element type of a tensor value.
///
/// The enum is `#[repr(C)]` with explicit discriminants, so each variant has a fixed integer
/// value (presumably mirrored by a C-side enum — confirm against the C API header). Variants
/// marked `#[cfg(feature = "complex")]` exist only when the `complex` feature is enabled.
#[repr(C)]
#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum DatumType {
    TRACT_DATUM_TYPE_BOOL = 0x01,
    TRACT_DATUM_TYPE_U8 = 0x11,
    TRACT_DATUM_TYPE_U16 = 0x12,
    TRACT_DATUM_TYPE_U32 = 0x14,
    TRACT_DATUM_TYPE_U64 = 0x18,
    TRACT_DATUM_TYPE_I8 = 0x21,
    TRACT_DATUM_TYPE_I16 = 0x22,
    TRACT_DATUM_TYPE_I32 = 0x24,
    TRACT_DATUM_TYPE_I64 = 0x28,
    TRACT_DATUM_TYPE_F16 = 0x32,
    TRACT_DATUM_TYPE_F32 = 0x34,
    TRACT_DATUM_TYPE_F64 = 0x38,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I16 = 0x42,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I32 = 0x44,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_I64 = 0x48,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F16 = 0x52,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F32 = 0x54,
    #[cfg(feature = "complex")]
    TRACT_DATUM_TYPE_COMPLEX_F64 = 0x58,
}

impl DatumType {
    /// Size in bytes of one element of this type.
    ///
    /// Complex types occupy twice the size of their scalar component, since they store a real
    /// and an imaginary part.
    pub fn size_of(&self) -> usize {
        use DatumType::*;
        match *self {
            TRACT_DATUM_TYPE_BOOL | TRACT_DATUM_TYPE_U8 | TRACT_DATUM_TYPE_I8 => 1,
            TRACT_DATUM_TYPE_U16 | TRACT_DATUM_TYPE_I16 | TRACT_DATUM_TYPE_F16 => 2,
            TRACT_DATUM_TYPE_U32 | TRACT_DATUM_TYPE_I32 | TRACT_DATUM_TYPE_F32 => 4,
            TRACT_DATUM_TYPE_U64 | TRACT_DATUM_TYPE_I64 | TRACT_DATUM_TYPE_F64 => 8,
            // These arms must name the COMPLEX_* variants. Naming the scalar F16/F32/F64 here
            // would be unreachable, leave the complex float variants unmatched, and make the
            // match non-exhaustive whenever the `complex` feature is on.
            #[cfg(feature = "complex")]
            TRACT_DATUM_TYPE_COMPLEX_I16 | TRACT_DATUM_TYPE_COMPLEX_F16 => 4,
            #[cfg(feature = "complex")]
            TRACT_DATUM_TYPE_COMPLEX_I32 | TRACT_DATUM_TYPE_COMPLEX_F32 => 8,
            #[cfg(feature = "complex")]
            TRACT_DATUM_TYPE_COMPLEX_I64 | TRACT_DATUM_TYPE_COMPLEX_F64 => 16,
        }
    }
}
303
304pub trait Datum {
305 fn datum_type() -> DatumType;
306}
307
308macro_rules! impl_datum_type {
309 ($ty:ty, $c_repr:expr) => {
310 impl Datum for $ty {
311 fn datum_type() -> DatumType {
312 $c_repr
313 }
314 }
315 };
316}
317
318impl_datum_type!(bool, DatumType::TRACT_DATUM_TYPE_BOOL);
319impl_datum_type!(u8, DatumType::TRACT_DATUM_TYPE_U8);
320impl_datum_type!(u16, DatumType::TRACT_DATUM_TYPE_U16);
321impl_datum_type!(u32, DatumType::TRACT_DATUM_TYPE_U32);
322impl_datum_type!(u64, DatumType::TRACT_DATUM_TYPE_U64);
323impl_datum_type!(i8, DatumType::TRACT_DATUM_TYPE_I8);
324impl_datum_type!(i16, DatumType::TRACT_DATUM_TYPE_I16);
325impl_datum_type!(i32, DatumType::TRACT_DATUM_TYPE_I32);
326impl_datum_type!(i64, DatumType::TRACT_DATUM_TYPE_I64);
327impl_datum_type!(half::f16, DatumType::TRACT_DATUM_TYPE_F16);
328impl_datum_type!(f32, DatumType::TRACT_DATUM_TYPE_F32);
329impl_datum_type!(f64, DatumType::TRACT_DATUM_TYPE_F64);