tract_tensorflow/
lib.rs

1#![allow(clippy::len_zero)]
2//! # Tract TensorFlow module
3//!
//! Tiny, no-nonsense, self-contained, portable inference.
5//!
6//! ## Example
7//!
8//! ```
9//! # extern crate tract_tensorflow;
10//! # fn main() {
11//! use tract_tensorflow::prelude::*;
12//!
//! // build a simple model that just adds 3 to each input component
14//! let tf = tensorflow();
15//! let mut model = tf.model_for_path("tests/models/plus3.pb").unwrap();
16//!
//! // set the input type and shape, then optimize the network.
18//! model.set_input_fact(0, f32::fact(&[3]).into()).unwrap();
19//! let model = model.into_optimized().unwrap();
20//!
//! // build an execution plan. The default inputs and outputs are inferred
//! // from the model graph.
23//! let plan = SimplePlan::new(&model).unwrap();
24//!
25//! // run the computation.
26//! let input = tensor1(&[1.0f32, 2.5, 5.0]);
27//! let mut outputs = plan.run(tvec![input]).unwrap();
28//!
29//! // take the first and only output tensor
30//! let mut tensor = outputs.pop().unwrap();
31//!
32//! assert_eq!(tensor, rctensor1(&[4.0f32, 5.5, 8.0]));
33//! # }
34//! ```
35//!
36
37#[macro_use]
38extern crate derive_new;
39#[allow(unused_imports)]
40#[macro_use]
41extern crate log;
42#[cfg(test)]
43extern crate env_logger;
44extern crate prost;
45extern crate prost_types;
46#[cfg(feature = "conform")]
47extern crate tensorflow;
48pub extern crate tract_hir;
49
50#[cfg(feature = "conform")]
51pub mod conform;
52
53pub mod model;
54pub mod ops;
55pub mod tensor;
56pub mod tfpb;
57
58pub use model::Tensorflow;
59
60pub fn tensorflow() -> Tensorflow {
61    let mut ops = crate::model::TfOpRegister::default();
62    ops::register_all_ops(&mut ops);
63    Tensorflow { op_register: ops }
64}
65
66pub use tract_hir::tract_core;
67pub mod prelude {
68    pub use crate::tensorflow;
69    pub use tract_hir::prelude::*;
70    pub use tract_hir::tract_core;
71}
72
/// Installs a global `env_logger` logger for tests.
///
/// The logger is built from the environment (`Builder::from_default_env`) and
/// its filter level is set to `Trace`.
///
/// This helper may be called by several tests in the same process. Only the
/// first call installs the logger. Later calls are no-ops instead of panicking.
#[cfg(test)]
#[allow(dead_code)] // not every test build calls this helper
pub fn setup_test_logger() {
    // `Builder::init` panics if a global logger is already set, which happens
    // as soon as a second test in the same binary calls this helper.
    // `try_init` reports that case as an error, which we deliberately ignore.
    let _ = env_logger::Builder::from_default_env()
        .filter_level(log::LevelFilter::Trace)
        .try_init();
}