1use std::fmt::{Debug, Display};
2use std::path::Path;
3use std::sync::Arc;
4
5use anyhow::{Context, Result};
6use ndarray::{Data, Dimension, RawData};
7use tract_extra::WithTractExtra;
8use tract_libcli::annotations::Annotations;
9use tract_libcli::profile::BenchLimits;
10use tract_libcli::tensor::RunTensors;
11use tract_nnef::internal::parse_tdim;
12use tract_nnef::prelude::{
13 Framework, IntoTValue, SymbolValues, TValue, TVec, Tensor, TractResult, TypedFact, TypedModel,
14 TypedRunnableModel, TypedSimplePlan, TypedSimpleState,
15};
16use tract_onnx::prelude::InferenceModelExt;
17use tract_onnx_opl::WithOnnx;
18use tract_pulse::internal::PlanOptions;
19use tract_pulse::model::{PulsedModel, PulsedModelExt};
20use tract_pulse::WithPulse;
21use tract_transformers::WithTractTransformers;
22
23use tract_api::*;
24
25pub fn nnef() -> Result<Nnef> {
27 Ok(Nnef(tract_nnef::nnef()))
28}
29
30pub fn onnx() -> Result<Onnx> {
31 Ok(Onnx(tract_onnx::onnx()))
32}
33
34pub fn version() -> &'static str {
36 env!("CARGO_PKG_VERSION")
37}
38
39pub struct Nnef(tract_nnef::internal::Nnef);
40
41impl NnefInterface for Nnef {
42 type Model = Model;
43 fn model_for_path(&self, path: impl AsRef<Path>) -> Result<Model> {
44 self.0.model_for_path(path).map(Model)
45 }
46
47 fn transform_model(&self, model: &mut Self::Model, transform_spec: &str) -> Result<()> {
48 if let Some(transform) = self.0.get_transform(transform_spec)? {
49 transform.transform(&mut model.0)?;
50 }
51 Ok(())
52 }
53
54 fn enable_tract_core(&mut self) -> Result<()> {
55 self.0.enable_tract_core();
56 Ok(())
57 }
58
59 fn enable_tract_extra(&mut self) -> Result<()> {
60 self.0.enable_tract_extra();
61 Ok(())
62 }
63
64 fn enable_tract_transformers(&mut self) -> Result<()> {
65 self.0.enable_tract_transformers();
66 Ok(())
67 }
68
69 fn enable_onnx(&mut self) -> Result<()> {
70 self.0.enable_onnx();
71 Ok(())
72 }
73
74 fn enable_pulse(&mut self) -> Result<()> {
75 self.0.enable_pulse();
76 Ok(())
77 }
78
79 fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
80 self.0.allow_extended_identifier_syntax(true);
81 Ok(())
82 }
83
84 fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
85 self.0.write_to_dir(&model.0, path)
86 }
87
88 fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
89 let file = std::fs::File::create(path)?;
90 self.0.write_to_tar(&model.0, file)?;
91 Ok(())
92 }
93
94 fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
95 let file = std::fs::File::create(path)?;
96 let gz = flate2::write::GzEncoder::new(file, flate2::Compression::default());
97 self.0.write_to_tar(&model.0, gz)?;
98 Ok(())
99 }
100}
101
102pub struct Onnx(tract_onnx::Onnx);
103impl OnnxInterface for Onnx {
104 type InferenceModel = InferenceModel;
105 fn model_for_path(&self, path: impl AsRef<Path>) -> Result<Self::InferenceModel> {
106 Ok(InferenceModel(self.0.model_for_path(path)?))
107 }
108}
109
110pub struct InferenceModel(tract_onnx::prelude::InferenceModel);
111impl InferenceModelInterface for InferenceModel {
112 type Model = Model;
113 type InferenceFact = InferenceFact;
114
115 fn input_count(&self) -> Result<usize> {
116 Ok(self.0.inputs.len())
117 }
118
119 fn output_count(&self) -> Result<usize> {
120 Ok(self.0.outputs.len())
121 }
122
123 fn input_name(&self, id: usize) -> Result<String> {
124 let node = self.0.inputs[id].node;
125 Ok(self.0.node(node).name.to_string())
126 }
127
128 fn output_name(&self, id: usize) -> Result<String> {
129 let node = self.0.outputs[id].node;
130 Ok(self.0.node(node).name.to_string())
131 }
132
133 fn set_output_names(
134 &mut self,
135 outputs: impl IntoIterator<Item = impl AsRef<str>>,
136 ) -> Result<()> {
137 self.0.set_output_names(outputs)
138 }
139
140 fn input_fact(&self, id: usize) -> Result<InferenceFact> {
141 Ok(InferenceFact(self.0.input_fact(id)?.clone()))
142 }
143
144 fn set_input_fact(
145 &mut self,
146 id: usize,
147 fact: impl AsFact<Self, Self::InferenceFact>,
148 ) -> Result<()> {
149 let fact = fact.as_fact(self)?.0.clone();
150 self.0.set_input_fact(id, fact)
151 }
152
153 fn output_fact(&self, id: usize) -> Result<InferenceFact> {
154 Ok(InferenceFact(self.0.output_fact(id)?.clone()))
155 }
156
157 fn set_output_fact(
158 &mut self,
159 id: usize,
160 fact: impl AsFact<Self, Self::InferenceFact>,
161 ) -> Result<()> {
162 let fact = fact.as_fact(self)?.0.clone();
163 self.0.set_output_fact(id, fact)
164 }
165
166 fn analyse(&mut self) -> Result<()> {
167 self.0.analyse(false)?;
168 Ok(())
169 }
170
171 fn into_typed(self) -> Result<Self::Model> {
172 let typed = self.0.into_typed()?;
173 Ok(Model(typed))
174 }
175
176 fn into_optimized(self) -> Result<Self::Model> {
177 let typed = self.0.into_optimized()?;
178 Ok(Model(typed))
179 }
180}
181
182pub struct Model(TypedModel);
184
185impl ModelInterface for Model {
186 type Fact = Fact;
187 type Runnable = Runnable;
188 type Value = Value;
189
190 fn input_count(&self) -> Result<usize> {
191 Ok(self.0.inputs.len())
192 }
193
194 fn output_count(&self) -> Result<usize> {
195 Ok(self.0.outputs.len())
196 }
197
198 fn input_name(&self, id: usize) -> Result<String> {
199 let node = self.0.inputs[id].node;
200 Ok(self.0.node(node).name.to_string())
201 }
202
203 fn output_name(&self, id: usize) -> Result<String> {
204 let node = self.0.outputs[id].node;
205 Ok(self.0.node(node).name.to_string())
206 }
207
208 fn set_output_names(
209 &mut self,
210 outputs: impl IntoIterator<Item = impl AsRef<str>>,
211 ) -> Result<()> {
212 self.0.set_output_names(outputs)
213 }
214
215 fn input_fact(&self, id: usize) -> Result<Fact> {
216 Ok(Fact(self.0.input_fact(id)?.clone()))
217 }
218
219 fn output_fact(&self, id: usize) -> Result<Fact> {
220 Ok(Fact(self.0.output_fact(id)?.clone()))
221 }
222
223 fn declutter(&mut self) -> Result<()> {
224 self.0.declutter()
225 }
226
227 fn optimize(&mut self) -> Result<()> {
228 self.0.optimize()
229 }
230
231 fn into_decluttered(mut self) -> Result<Model> {
232 self.0.declutter()?;
233 Ok(self)
234 }
235
236 fn into_optimized(self) -> Result<Model> {
237 Ok(Model(self.0.into_optimized()?))
238 }
239
240 fn into_runnable(self) -> Result<Runnable> {
241 Ok(Runnable(Arc::new(self.0.into_runnable()?)))
242 }
243
244 fn concretize_symbols(
245 &mut self,
246 values: impl IntoIterator<Item = (impl AsRef<str>, i64)>,
247 ) -> Result<()> {
248 let mut table = SymbolValues::default();
249 for (k, v) in values {
250 table = table.with(&self.0.symbols.sym(k.as_ref()), v);
251 }
252 self.0 = self.0.concretize_dims(&table)?;
253 Ok(())
254 }
255
256 fn transform(&mut self, transform: &str) -> Result<()> {
257 let transform = tract_onnx::tract_core::transform::get_transform(transform)
258 .with_context(|| format!("transform `{transform}' could not be found"))?;
259 transform.transform(&mut self.0)
260 }
261
262 fn pulse(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
263 let stream_sym = self.0.symbols.sym(name.as_ref());
264 let pulse_dim = parse_tdim(&self.0.symbols, value.as_ref())?;
265 self.0 = PulsedModel::new(&self.0, stream_sym, &pulse_dim)?.into_typed()?;
266 Ok(())
267 }
268
269 fn cost_json(&self) -> Result<String> {
270 let input: Option<Vec<Value>> = None;
271 let states: Option<Vec<Value>> = None;
272 self.profile_json(input, states)
273 }
274
275 fn profile_json<I, IV, IE, S, SV, SE>(
276 &self,
277 inputs: Option<I>,
278 state_initializers: Option<S>,
279 ) -> Result<String>
280 where
281 I: IntoIterator<Item = IV>,
282 IV: TryInto<Self::Value, Error = IE>,
283 IE: Into<anyhow::Error> + Debug,
284 S: IntoIterator<Item = SV>,
285 SV: TryInto<Self::Value, Error = SE>,
286 SE: Into<anyhow::Error> + Debug,
287 {
288 let mut annotations = Annotations::from_model(&self.0)?;
289 tract_libcli::profile::extract_costs(&mut annotations, &self.0, &SymbolValues::default())?;
290 if let Some(inputs) = inputs {
291 let inputs = inputs
292 .into_iter()
293 .map(|v| Ok(v.try_into().unwrap().0))
294 .collect::<TractResult<TVec<_>>>()?;
295
296 let mut state_inits: Vec<TValue> = vec![];
297
298 if let Some(states) = state_initializers {
299 states.into_iter().for_each(|s| state_inits.push(s.try_into().unwrap().0));
300 }
301 tract_libcli::profile::profile(
302 &self.0,
303 &BenchLimits::default(),
304 &mut annotations,
305 &PlanOptions::default(),
306 &RunTensors { sources: vec![inputs], state_initializers: state_inits },
307 None,
308 true,
309 )?;
310 };
311 let export = tract_libcli::export::GraphPerfInfo::from(&self.0, &annotations);
312 Ok(serde_json::to_string(&export)?)
313 }
314
315 fn property_keys(&self) -> Result<Vec<String>> {
316 Ok(self.0.properties.keys().cloned().collect())
317 }
318
319 fn property(&self, name: impl AsRef<str>) -> Result<Value> {
320 let name = name.as_ref();
321 self.0
322 .properties
323 .get(name)
324 .with_context(|| format!("no property for name {name}"))
325 .map(|t| Value(t.clone().into_tvalue()))
326 }
327}
328
329pub struct Runnable(Arc<TypedRunnableModel<TypedModel>>);
331
332impl RunnableInterface for Runnable {
333 type Value = Value;
334 type State = State;
335
336 fn run<I, V, E>(&self, inputs: I) -> Result<Vec<Value>>
337 where
338 I: IntoIterator<Item = V>,
339 V: TryInto<Self::Value, Error = E>,
340 E: Into<anyhow::Error>,
341 {
342 self.spawn_state()?.run(inputs)
343 }
344
345 fn input_count(&self) -> Result<usize> {
346 Ok(self.0.model().inputs.len())
347 }
348
349 fn output_count(&self) -> Result<usize> {
350 Ok(self.0.model().outputs.len())
351 }
352
353 fn spawn_state(&self) -> Result<State> {
354 let state = TypedSimpleState::new(self.0.clone())?;
355 Ok(State(state))
356 }
357}
358
359pub struct State(TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>);
361
362impl StateInterface for State {
363 type Fact = Fact;
364 type Value = Value;
365
366 fn input_count(&self) -> Result<usize> {
367 Ok(self.0.model().inputs.len())
368 }
369
370 fn output_count(&self) -> Result<usize> {
371 Ok(self.0.model().outputs.len())
372 }
373
374 fn run<I, V, E>(&mut self, inputs: I) -> Result<Vec<Value>>
375 where
376 I: IntoIterator<Item = V>,
377 V: TryInto<Value, Error = E>,
378 E: Into<anyhow::Error>,
379 {
380 let inputs: TVec<TValue> = inputs
381 .into_iter()
382 .map(|i| i.try_into().map_err(|e| e.into()).map(|v| v.0))
383 .collect::<Result<_>>()?;
384 let outputs = self.0.run(inputs)?;
385 Ok(outputs.into_iter().map(Value).collect())
386 }
387
388 fn initializable_states_count(&self) -> Result<usize> {
389 Ok(self
390 .0
391 .states
392 .iter()
393 .filter_map(Option::as_ref)
394 .filter(|s| s.init_tensor_fact().is_some())
395 .count())
396 }
397
398 fn get_states_facts(&self) -> Result<Vec<Fact>> {
399 Ok(self
400 .0
401 .states
402 .iter()
403 .filter_map(Option::as_ref)
404 .filter_map(|s| s.init_tensor_fact().map(Fact))
405 .collect::<Vec<Fact>>())
406 }
407
408 fn set_states<I, V, E>(&mut self, state_initializers: I) -> Result<()>
409 where
410 I: IntoIterator<Item = V>,
411 V: TryInto<Self::Value, Error = E>,
412 E: Into<anyhow::Error> + Debug,
413 {
414 let mut states = vec![];
415 state_initializers.into_iter().for_each(|s| {
416 states.push(s.try_into().unwrap().0);
417 });
418
419 self.0.init_states(&mut states)?;
420 Ok(())
421 }
422
423 fn get_states(&self) -> Result<Vec<Self::Value>> {
424 let mut states = vec![];
425 for state in self
426 .0
427 .states
428 .iter()
429 .filter_map(Option::as_ref)
430 .filter(|s| s.init_tensor_fact().is_some())
431 {
432 state.save_to(&mut states)?;
433 }
434
435 let mut res = vec![];
436 for state in states {
437 res.push(Value(state));
438 }
439 Ok(res)
440 }
441}
442
443#[derive(Clone)]
445pub struct Value(TValue);
446
447impl ValueInterface for Value {
448 fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
449 let dt = to_internal_dt(dt);
450 let len = shape.iter().product::<usize>() * dt.size_of();
451 anyhow::ensure!(len == data.len());
452 let tensor = unsafe { Tensor::from_raw_dt(dt, shape, data)? };
453 Ok(Value(tensor.into_tvalue()))
454 }
455
456 fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
457 let dt = from_internal_dt(self.0.datum_type())?;
458 Ok((dt, self.0.shape(), unsafe { self.0.as_slice_unchecked::<u8>() }))
459 }
460
461 }
472
473#[derive(Clone, Debug)]
474pub struct Fact(TypedFact);
475
476impl FactInterface for Fact {}
477
478impl Fact {
479 fn new(model: &mut Model, spec: impl ToString) -> Result<Fact> {
480 let fact = tract_libcli::tensor::parse_spec(&model.0.symbols, &spec.to_string())?;
481 let fact = tract_onnx::prelude::Fact::to_typed_fact(&fact)?.into_owned();
482 Ok(Fact(fact))
483 }
484
485 fn dump(&self) -> Result<String> {
486 Ok(format!("{:?}", self.0))
487 }
488}
489
490impl Display for Fact {
491 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
492 write!(f, "{}", self.dump().unwrap())
493 }
494}
495
496#[derive(Default, Clone, Debug)]
497pub struct InferenceFact(tract_onnx::prelude::InferenceFact);
498
499impl InferenceFactInterface for InferenceFact {
500 fn empty() -> Result<InferenceFact> {
501 Ok(InferenceFact(Default::default()))
502 }
503}
504
505impl InferenceFact {
506 fn new(model: &mut InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
507 let fact = tract_libcli::tensor::parse_spec(&model.0.symbols, &spec.to_string())?;
508 Ok(InferenceFact(fact))
509 }
510
511 fn dump(&self) -> Result<String> {
512 Ok(format!("{:?}", self.0))
513 }
514}
515
516impl Display for InferenceFact {
517 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
518 write!(f, "{}", self.dump().unwrap())
519 }
520}
521
// Macro-generated glue from tract_api. Judging by the macro names (NOTE(review): confirm
// against the macro definitions), these provide the ndarray <-> `Value` conversions and the
// `AsFact` impls that let `set_input_fact`/`set_output_fact` accept facts or fact specs
// for the inference and typed model wrappers respectively.
value_from_to_ndarray!();
as_inference_fact_impl!(InferenceModel, InferenceFact);
as_fact_impl!(Model, Fact);
525
526fn to_internal_dt(it: DatumType) -> tract_nnef::prelude::DatumType {
536 use tract_nnef::prelude::DatumType::*;
537 use DatumType::*;
538 match it {
539 TRACT_DATUM_TYPE_BOOL => Bool,
540 TRACT_DATUM_TYPE_U8 => U8,
541 TRACT_DATUM_TYPE_U16 => U16,
542 TRACT_DATUM_TYPE_U32 => U32,
543 TRACT_DATUM_TYPE_U64 => U64,
544 TRACT_DATUM_TYPE_I8 => I8,
545 TRACT_DATUM_TYPE_I16 => I16,
546 TRACT_DATUM_TYPE_I32 => I32,
547 TRACT_DATUM_TYPE_I64 => I64,
548 TRACT_DATUM_TYPE_F16 => F16,
549 TRACT_DATUM_TYPE_F32 => F32,
550 TRACT_DATUM_TYPE_F64 => F64,
551 #[cfg(feature = "complex")]
552 TRACT_DATUM_TYPE_COMPLEX_I16 => ComplexI16,
553 #[cfg(feature = "complex")]
554 TRACT_DATUM_TYPE_COMPLEX_I32 => ComplexI32,
555 #[cfg(feature = "complex")]
556 TRACT_DATUM_TYPE_COMPLEX_I64 => ComplexI64,
557 #[cfg(feature = "complex")]
558 TRACT_DATUM_TYPE_COMPLEX_F16 => ComplexF16,
559 #[cfg(feature = "complex")]
560 TRACT_DATUM_TYPE_COMPLEX_F32 => ComplexF32,
561 #[cfg(feature = "complex")]
562 TRACT_DATUM_TYPE_COMPLEX_F64 => ComplexF64,
563 }
564}
565
566fn from_internal_dt(it: tract_nnef::prelude::DatumType) -> Result<DatumType> {
567 use tract_nnef::prelude::DatumType::*;
568 use DatumType::*;
569 Ok(match it {
570 Bool => TRACT_DATUM_TYPE_BOOL,
571 U8 => TRACT_DATUM_TYPE_U8,
572 U16 => TRACT_DATUM_TYPE_U16,
573 U32 => TRACT_DATUM_TYPE_U32,
574 U64 => TRACT_DATUM_TYPE_U64,
575 I8 => TRACT_DATUM_TYPE_I8,
576 I16 => TRACT_DATUM_TYPE_I16,
577 I32 => TRACT_DATUM_TYPE_I32,
578 I64 => TRACT_DATUM_TYPE_I64,
579 F16 => TRACT_DATUM_TYPE_F16,
580 F32 => TRACT_DATUM_TYPE_F32,
581 F64 => TRACT_DATUM_TYPE_F64,
582 #[cfg(feature = "complex")]
583 TRACT_DATUM_TYPE_COMPLEX_I16 => ComplexI16,
584 #[cfg(feature = "complex")]
585 TRACT_DATUM_TYPE_COMPLEX_I32 => ComplexI32,
586 #[cfg(feature = "complex")]
587 TRACT_DATUM_TYPE_COMPLEX_I64 => ComplexI64,
588 #[cfg(feature = "complex")]
589 TRACT_DATUM_TYPE_COMPLEX_F16 => ComplexF16,
590 #[cfg(feature = "complex")]
591 TRACT_DATUM_TYPE_COMPLEX_F32 => ComplexF32,
592 #[cfg(feature = "complex")]
593 TRACT_DATUM_TYPE_COMPLEX_F64 => ComplexF64,
594 _ => {
595 anyhow::bail!("Unsupported DatumType in the public API {:?}", it)
596 }
597 })
598}