tract_rs/
lib.rs

1use std::fmt::{Debug, Display};
2use std::path::Path;
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use ndarray::{Data, Dimension, RawData};
7use tract_extra::WithTractExtra;
8use tract_libcli::annotations::Annotations;
9use tract_libcli::profile::BenchLimits;
10use tract_libcli::tensor::RunTensors;
11use tract_nnef::internal::parse_tdim;
12use tract_nnef::prelude::{
13    Framework, IntoTValue, SymbolValues, TValue, TVec, Tensor, TractResult, TypedFact, TypedModel,
14    TypedRunnableModel, TypedSimplePlan, TypedSimpleState,
15};
16use tract_onnx::prelude::InferenceModelExt;
17use tract_onnx_opl::WithOnnx;
18use tract_pulse::internal::PlanOptions;
19use tract_pulse::model::{PulsedModel, PulsedModelExt};
20use tract_pulse::WithPulse;
21use tract_transformers::WithTractTransformers;
22
23use tract_api::*;
24
25/// Creates an instance of an NNEF framework and parser that can be used to load and dump NNEF models.
26pub fn nnef() -> Result<Nnef> {
27    Ok(Nnef(tract_nnef::nnef()))
28}
29
30pub fn onnx() -> Result<Onnx> {
31    Ok(Onnx(tract_onnx::onnx()))
32}
33
34/// tract version tag
35pub fn version() -> &'static str {
36    env!("CARGO_PKG_VERSION")
37}
38
39pub struct Nnef(tract_nnef::internal::Nnef);
40
41impl NnefInterface for Nnef {
42    type Model = Model;
43    fn model_for_path(&self, path: impl AsRef<Path>) -> Result<Model> {
44        self.0.model_for_path(path).map(Model)
45    }
46
47    fn transform_model(&self, model: &mut Self::Model, transform_spec: &str) -> Result<()> {
48        if let Some(transform) = self.0.get_transform(transform_spec)? {
49            transform.transform(&mut model.0)?;
50        }
51        Ok(())
52    }
53
54    fn enable_tract_core(&mut self) -> Result<()> {
55        self.0.enable_tract_core();
56        Ok(())
57    }
58
59    fn enable_tract_extra(&mut self) -> Result<()> {
60        self.0.enable_tract_extra();
61        Ok(())
62    }
63
64    fn enable_tract_transformers(&mut self) -> Result<()> {
65        self.0.enable_tract_transformers();
66        Ok(())
67    }
68
69    fn enable_onnx(&mut self) -> Result<()> {
70        self.0.enable_onnx();
71        Ok(())
72    }
73
74    fn enable_pulse(&mut self) -> Result<()> {
75        self.0.enable_pulse();
76        Ok(())
77    }
78
79    fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
80        self.0.allow_extended_identifier_syntax(true);
81        Ok(())
82    }
83
84    fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
85        self.0.write_to_dir(&model.0, path)
86    }
87
88    fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
89        let file = std::fs::File::create(path)?;
90        self.0.write_to_tar(&model.0, file)?;
91        Ok(())
92    }
93
94    fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
95        let file = std::fs::File::create(path)?;
96        let gz = flate2::write::GzEncoder::new(file, flate2::Compression::default());
97        self.0.write_to_tar(&model.0, gz)?;
98        Ok(())
99    }
100}
101
102pub struct Onnx(tract_onnx::Onnx);
103impl OnnxInterface for Onnx {
104    type InferenceModel = InferenceModel;
105    fn model_for_path(&self, path: impl AsRef<Path>) -> Result<Self::InferenceModel> {
106        Ok(InferenceModel(self.0.model_for_path(path)?))
107    }
108}
109
110pub struct InferenceModel(tract_onnx::prelude::InferenceModel);
111impl InferenceModelInterface for InferenceModel {
112    type Model = Model;
113    type InferenceFact = InferenceFact;
114
115    fn input_count(&self) -> Result<usize> {
116        Ok(self.0.inputs.len())
117    }
118
119    fn output_count(&self) -> Result<usize> {
120        Ok(self.0.outputs.len())
121    }
122
123    fn input_name(&self, id: usize) -> Result<String> {
124        let node = self.0.inputs[id].node;
125        Ok(self.0.node(node).name.to_string())
126    }
127
128    fn output_name(&self, id: usize) -> Result<String> {
129        let node = self.0.outputs[id].node;
130        Ok(self.0.node(node).name.to_string())
131    }
132
133    fn set_output_names(
134        &mut self,
135        outputs: impl IntoIterator<Item = impl AsRef<str>>,
136    ) -> Result<()> {
137        self.0.set_output_names(outputs)
138    }
139
140    fn input_fact(&self, id: usize) -> Result<InferenceFact> {
141        Ok(InferenceFact(self.0.input_fact(id)?.clone()))
142    }
143
144    fn set_input_fact(
145        &mut self,
146        id: usize,
147        fact: impl AsFact<Self, Self::InferenceFact>,
148    ) -> Result<()> {
149        let fact = fact.as_fact(self)?.0.clone();
150        self.0.set_input_fact(id, fact)
151    }
152
153    fn output_fact(&self, id: usize) -> Result<InferenceFact> {
154        Ok(InferenceFact(self.0.output_fact(id)?.clone()))
155    }
156
157    fn set_output_fact(
158        &mut self,
159        id: usize,
160        fact: impl AsFact<Self, Self::InferenceFact>,
161    ) -> Result<()> {
162        let fact = fact.as_fact(self)?.0.clone();
163        self.0.set_output_fact(id, fact)
164    }
165
166    fn analyse(&mut self) -> Result<()> {
167        self.0.analyse(false)?;
168        Ok(())
169    }
170
171    fn into_typed(self) -> Result<Self::Model> {
172        let typed = self.0.into_typed()?;
173        Ok(Model(typed))
174    }
175
176    fn into_optimized(self) -> Result<Self::Model> {
177        let typed = self.0.into_optimized()?;
178        Ok(Model(typed))
179    }
180}
181
182// MODEL
183pub struct Model(TypedModel);
184
185impl ModelInterface for Model {
186    type Fact = Fact;
187    type Runnable = Runnable;
188    type Value = Value;
189
190    fn input_count(&self) -> Result<usize> {
191        Ok(self.0.inputs.len())
192    }
193
194    fn output_count(&self) -> Result<usize> {
195        Ok(self.0.outputs.len())
196    }
197
198    fn input_name(&self, id: usize) -> Result<String> {
199        let node = self.0.inputs[id].node;
200        Ok(self.0.node(node).name.to_string())
201    }
202
203    fn output_name(&self, id: usize) -> Result<String> {
204        let node = self.0.outputs[id].node;
205        Ok(self.0.node(node).name.to_string())
206    }
207
208    fn set_output_names(
209        &mut self,
210        outputs: impl IntoIterator<Item = impl AsRef<str>>,
211    ) -> Result<()> {
212        self.0.set_output_names(outputs)
213    }
214
215    fn input_fact(&self, id: usize) -> Result<Fact> {
216        Ok(Fact(self.0.input_fact(id)?.clone()))
217    }
218
219    fn output_fact(&self, id: usize) -> Result<Fact> {
220        Ok(Fact(self.0.output_fact(id)?.clone()))
221    }
222
223    fn declutter(&mut self) -> Result<()> {
224        self.0.declutter()
225    }
226
227    fn optimize(&mut self) -> Result<()> {
228        self.0.optimize()
229    }
230
231    fn into_decluttered(mut self) -> Result<Model> {
232        self.0.declutter()?;
233        Ok(self)
234    }
235
236    fn into_optimized(self) -> Result<Model> {
237        Ok(Model(self.0.into_optimized()?))
238    }
239
240    fn into_runnable(self) -> Result<Runnable> {
241        Ok(Runnable(Arc::new(self.0.into_runnable()?)))
242    }
243
244    fn concretize_symbols(
245        &mut self,
246        values: impl IntoIterator<Item = (impl AsRef<str>, i64)>,
247    ) -> Result<()> {
248        let mut table = SymbolValues::default();
249        for (k, v) in values {
250            table = table.with(&self.0.symbols.sym(k.as_ref()), v);
251        }
252        self.0 = self.0.concretize_dims(&table)?;
253        Ok(())
254    }
255
256    fn transform(&mut self, transform: &str) -> Result<()> {
257        let transform = tract_onnx::tract_core::transform::get_transform(transform)
258            .with_context(|| format!("transform `{transform}' could not be found"))?;
259        transform.transform(&mut self.0)
260    }
261
262    fn pulse(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
263        let stream_sym = self.0.symbols.sym(name.as_ref());
264        let pulse_dim = parse_tdim(&self.0.symbols, value.as_ref())?;
265        self.0 = PulsedModel::new(&self.0, stream_sym, &pulse_dim)?.into_typed()?;
266        Ok(())
267    }
268
269    fn cost_json(&self) -> Result<String> {
270        let input: Option<Vec<Value>> = None;
271        let states: Option<Vec<Value>> = None;
272        self.profile_json(input, states)
273    }
274
275    fn profile_json<I, IV, IE, S, SV, SE>(
276        &self,
277        inputs: Option<I>,
278        state_initializers: Option<S>,
279    ) -> Result<String>
280    where
281        I: IntoIterator<Item = IV>,
282        IV: TryInto<Self::Value, Error = IE>,
283        IE: Into<anyhow::Error> + Debug,
284        S: IntoIterator<Item = SV>,
285        SV: TryInto<Self::Value, Error = SE>,
286        SE: Into<anyhow::Error> + Debug,
287    {
288        let mut annotations = Annotations::from_model(&self.0)?;
289        tract_libcli::profile::extract_costs(&mut annotations, &self.0, &SymbolValues::default())?;
290        if let Some(inputs) = inputs {
291            let inputs = inputs
292                .into_iter()
293                .map(|v| Ok(v.try_into().unwrap().0))
294                .collect::<TractResult<TVec<_>>>()?;
295
296            let mut state_inits: Vec<TValue> = vec![];
297
298            if let Some(states) = state_initializers {
299                states.into_iter().for_each(|s| state_inits.push(s.try_into().unwrap().0));
300            }
301            tract_libcli::profile::profile(
302                &self.0,
303                &BenchLimits::default(),
304                &mut annotations,
305                &PlanOptions::default(),
306                &RunTensors { sources: vec![inputs], state_initializers: state_inits },
307                None,
308                true,
309            )?;
310        };
311        let export = tract_libcli::export::GraphPerfInfo::from(&self.0, &annotations);
312        Ok(serde_json::to_string(&export)?)
313    }
314
315    fn property_keys(&self) -> Result<Vec<String>> {
316        Ok(self.0.properties.keys().cloned().collect())
317    }
318
319    fn property(&self, name: impl AsRef<str>) -> Result<Value> {
320        let name = name.as_ref();
321        self.0
322            .properties
323            .get(name)
324            .with_context(|| format!("no property for name {name}"))
325            .map(|t| Value(t.clone().into_tvalue()))
326    }
327}
328
329// RUNNABLE
330pub struct Runnable(Arc<TypedRunnableModel<TypedModel>>);
331
332impl RunnableInterface for Runnable {
333    type Value = Value;
334    type State = State;
335
336    fn run<I, V, E>(&self, inputs: I) -> Result<Vec<Value>>
337    where
338        I: IntoIterator<Item = V>,
339        V: TryInto<Self::Value, Error = E>,
340        E: Into<anyhow::Error>,
341    {
342        self.spawn_state()?.run(inputs)
343    }
344
345    fn input_count(&self) -> Result<usize> {
346        Ok(self.0.model().inputs.len())
347    }
348
349    fn output_count(&self) -> Result<usize> {
350        Ok(self.0.model().outputs.len())
351    }
352
353    fn spawn_state(&self) -> Result<State> {
354        let state = TypedSimpleState::new(self.0.clone())?;
355        Ok(State(state))
356    }
357}
358
359// STATE
360pub struct State(TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>);
361
362impl StateInterface for State {
363    type Fact = Fact;
364    type Value = Value;
365
366    fn input_count(&self) -> Result<usize> {
367        Ok(self.0.model().inputs.len())
368    }
369
370    fn output_count(&self) -> Result<usize> {
371        Ok(self.0.model().outputs.len())
372    }
373
374    fn run<I, V, E>(&mut self, inputs: I) -> Result<Vec<Value>>
375    where
376        I: IntoIterator<Item = V>,
377        V: TryInto<Value, Error = E>,
378        E: Into<anyhow::Error>,
379    {
380        let inputs: TVec<TValue> = inputs
381            .into_iter()
382            .map(|i| i.try_into().map_err(|e| e.into()).map(|v| v.0))
383            .collect::<Result<_>>()?;
384        let outputs = self.0.run(inputs)?;
385        Ok(outputs.into_iter().map(Value).collect())
386    }
387
388    fn initializable_states_count(&self) -> Result<usize> {
389        Ok(self
390            .0
391            .states
392            .iter()
393            .filter_map(Option::as_ref)
394            .filter(|s| s.init_tensor_fact().is_some())
395            .count())
396    }
397
398    fn get_states_facts(&self) -> Result<Vec<Fact>> {
399        Ok(self
400            .0
401            .states
402            .iter()
403            .filter_map(Option::as_ref)
404            .filter_map(|s| s.init_tensor_fact().map(Fact))
405            .collect::<Vec<Fact>>())
406    }
407
408    fn set_states<I, V, E>(&mut self, state_initializers: I) -> Result<()>
409    where
410        I: IntoIterator<Item = V>,
411        V: TryInto<Self::Value, Error = E>,
412        E: Into<anyhow::Error> + Debug,
413    {
414        let mut states = vec![];
415        state_initializers.into_iter().for_each(|s| {
416            states.push(s.try_into().unwrap().0);
417        });
418
419        self.0.init_states(&mut states)?;
420        Ok(())
421    }
422
423    fn get_states(&self) -> Result<Vec<Self::Value>> {
424        let mut states = vec![];
425        for state in self
426            .0
427            .states
428            .iter()
429            .filter_map(Option::as_ref)
430            .filter(|s| s.init_tensor_fact().is_some())
431        {
432            state.save_to(&mut states)?;
433        }
434
435        let mut res = vec![];
436        for state in states {
437            res.push(Value(state));
438        }
439        Ok(res)
440    }
441}
442
443// VALUE
444#[derive(Clone)]
445pub struct Value(TValue);
446
447impl ValueInterface for Value {
448    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
449        let dt = to_internal_dt(dt);
450        let len = shape.iter().product::<usize>() * dt.size_of();
451        anyhow::ensure!(len == data.len());
452        let tensor = unsafe { Tensor::from_raw_dt(dt, shape, data)? };
453        Ok(Value(tensor.into_tvalue()))
454    }
455
456    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
457        let dt = from_internal_dt(self.0.datum_type())?;
458        Ok((dt, self.0.shape(), unsafe { self.0.as_slice_unchecked::<u8>() }))
459    }
460
461    /*
462    fn as_parts<T: 'static>(&self) -> Result<(&[usize], &[T])> {
463        let _dt = to_datum_type::<T>()?;
464        let shape = self.0.shape();
465        let data = unsafe {
466            std::slice::from_raw_parts(self.0.as_ptr_unchecked::<u8>() as *const T, self.0.len())
467        };
468        Ok((shape, data))
469    }
470    */
471}
472
473#[derive(Clone, Debug)]
474pub struct Fact(TypedFact);
475
476impl FactInterface for Fact {}
477
478impl Fact {
479    fn new(model: &mut Model, spec: impl ToString) -> Result<Fact> {
480        let fact = tract_libcli::tensor::parse_spec(&model.0.symbols, &spec.to_string())?;
481        let fact = tract_onnx::prelude::Fact::to_typed_fact(&fact)?.into_owned();
482        Ok(Fact(fact))
483    }
484
485    fn dump(&self) -> Result<String> {
486        Ok(format!("{:?}", self.0))
487    }
488}
489
490impl Display for Fact {
491    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
492        write!(f, "{}", self.dump().unwrap())
493    }
494}
495
496#[derive(Default, Clone, Debug)]
497pub struct InferenceFact(tract_onnx::prelude::InferenceFact);
498
499impl InferenceFactInterface for InferenceFact {
500    fn empty() -> Result<InferenceFact> {
501        Ok(InferenceFact(Default::default()))
502    }
503}
504
505impl InferenceFact {
506    fn new(model: &mut InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
507        let fact = tract_libcli::tensor::parse_spec(&model.0.symbols, &spec.to_string())?;
508        Ok(InferenceFact(fact))
509    }
510
511    fn dump(&self) -> Result<String> {
512        Ok(format!("{:?}", self.0))
513    }
514}
515
516impl Display for InferenceFact {
517    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
518        write!(f, "{}", self.dump().unwrap())
519    }
520}
521
// Boilerplate generated by macros brought into scope from `tract_api` (presumably via the
// glob import). Judging by their names, they provide ndarray <-> `Value` conversions and
// `AsFact` support for `InferenceModel`/`InferenceFact` and `Model`/`Fact`. The private
// `Fact::new` / `InferenceFact::new` constructors are not called elsewhere in this file,
// so these macros presumably use them — TODO confirm against the macro definitions.
value_from_to_ndarray!();
as_inference_fact_impl!(InferenceModel, InferenceFact);
as_fact_impl!(Model, Fact);
525
526/*
527#[inline(always)]
528fn to_datum_type<T: TractProxyDatumType>() -> Result<tract_nnef::prelude::DatumType> {
529macro_rules! dt { ($($t:ty),*) => { $(if TypeId::of::<T>() == TypeId::of::<$t>() { return Ok(<$t>::datum_type()); })* }}
530dt!(f32, f16, f64, i64, i32, i16, i8, bool, u64, u32, u16, u8);
531anyhow::bail!("Unsupported type {}", std::any::type_name::<T>())
532}
533*/
534
535fn to_internal_dt(it: DatumType) -> tract_nnef::prelude::DatumType {
536    use tract_nnef::prelude::DatumType::*;
537    use DatumType::*;
538    match it {
539        TRACT_DATUM_TYPE_BOOL => Bool,
540        TRACT_DATUM_TYPE_U8 => U8,
541        TRACT_DATUM_TYPE_U16 => U16,
542        TRACT_DATUM_TYPE_U32 => U32,
543        TRACT_DATUM_TYPE_U64 => U64,
544        TRACT_DATUM_TYPE_I8 => I8,
545        TRACT_DATUM_TYPE_I16 => I16,
546        TRACT_DATUM_TYPE_I32 => I32,
547        TRACT_DATUM_TYPE_I64 => I64,
548        TRACT_DATUM_TYPE_F16 => F16,
549        TRACT_DATUM_TYPE_F32 => F32,
550        TRACT_DATUM_TYPE_F64 => F64,
551        #[cfg(feature = "complex")]
552        TRACT_DATUM_TYPE_COMPLEX_I16 => ComplexI16,
553        #[cfg(feature = "complex")]
554        TRACT_DATUM_TYPE_COMPLEX_I32 => ComplexI32,
555        #[cfg(feature = "complex")]
556        TRACT_DATUM_TYPE_COMPLEX_I64 => ComplexI64,
557        #[cfg(feature = "complex")]
558        TRACT_DATUM_TYPE_COMPLEX_F16 => ComplexF16,
559        #[cfg(feature = "complex")]
560        TRACT_DATUM_TYPE_COMPLEX_F32 => ComplexF32,
561        #[cfg(feature = "complex")]
562        TRACT_DATUM_TYPE_COMPLEX_F64 => ComplexF64,
563    }
564}
565
566fn from_internal_dt(it: tract_nnef::prelude::DatumType) -> Result<DatumType> {
567    use tract_nnef::prelude::DatumType::*;
568    use DatumType::*;
569    Ok(match it {
570        Bool => TRACT_DATUM_TYPE_BOOL,
571        U8 => TRACT_DATUM_TYPE_U8,
572        U16 => TRACT_DATUM_TYPE_U16,
573        U32 => TRACT_DATUM_TYPE_U32,
574        U64 => TRACT_DATUM_TYPE_U64,
575        I8 => TRACT_DATUM_TYPE_I8,
576        I16 => TRACT_DATUM_TYPE_I16,
577        I32 => TRACT_DATUM_TYPE_I32,
578        I64 => TRACT_DATUM_TYPE_I64,
579        F16 => TRACT_DATUM_TYPE_F16,
580        F32 => TRACT_DATUM_TYPE_F32,
581        F64 => TRACT_DATUM_TYPE_F64,
582        #[cfg(feature = "complex")]
583        TRACT_DATUM_TYPE_COMPLEX_I16 => ComplexI16,
584        #[cfg(feature = "complex")]
585        TRACT_DATUM_TYPE_COMPLEX_I32 => ComplexI32,
586        #[cfg(feature = "complex")]
587        TRACT_DATUM_TYPE_COMPLEX_I64 => ComplexI64,
588        #[cfg(feature = "complex")]
589        TRACT_DATUM_TYPE_COMPLEX_F16 => ComplexF16,
590        #[cfg(feature = "complex")]
591        TRACT_DATUM_TYPE_COMPLEX_F32 => ComplexF32,
592        #[cfg(feature = "complex")]
593        TRACT_DATUM_TYPE_COMPLEX_F64 => ComplexF64,
594        _ => {
595            anyhow::bail!("Unsupported DatumType in the public API {:?}", it)
596        }
597    })
598}