1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
//! # Tract TensorFlow module //! //! Tiny, no-nonsense, self contained, portable inference. //! //! ## Example //! //! ``` //! # extern crate tract_core; //! # extern crate tract_tensorflow; //! # fn main() { //! use tract_core::prelude::*; //! //! // build a simple model that just add 3 to each input component //! let tf = tract_tensorflow::tensorflow(); //! let mut model = tf.model_for_path("tests/models/plus3.pb").unwrap(); //! //! // set input input type and shape, then optimize the network. //! model.set_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), tvec!(3))).unwrap(); //! let model = model.into_optimized().unwrap(); //! //! // we build an execution plan. default input and output are inferred from //! // the model graph //! let plan = SimplePlan::new(&model).unwrap(); //! //! // run the computation. //! let input = tensor1(&[1.0f32, 2.5, 5.0]); //! let mut outputs = plan.run(tvec![input]).unwrap(); //! //! // take the first and only output tensor //! let mut tensor = outputs.pop().unwrap(); //! //! assert_eq!(tensor, rctensor1(&[4.0f32, 5.5, 8.0])); //! # } //! ``` //! #[macro_use] extern crate derive_new; #[macro_use] extern crate error_chain; #[allow(unused_imports)] #[macro_use] extern crate log; #[cfg(any(test, featutre = "conform"))] extern crate env_logger; extern crate num_traits; extern crate prost; extern crate prost_types; #[macro_use] extern crate tract_core; #[cfg(feature = "conform")] extern crate tensorflow; #[cfg(feature = "conform")] pub mod conform; pub mod model; pub mod ops; pub mod tensor; pub mod tfpb; pub use model::Tensorflow; pub fn tensorflow() -> Tensorflow { let mut ops = crate::model::TfOpRegister::default(); ops::register_all_ops(&mut ops); Tensorflow { op_register: ops } } #[cfg(test)] #[allow(dead_code)] pub fn setup_test_logger() { env_logger::Builder::from_default_env().filter_level(log::LevelFilter::Trace).init(); }