tract_tensorflow/ops/
mod.rs1use tract_hir::internal::*;
2
3use crate::model::ParsingContext;
4use crate::model::TfOpRegister;
5use crate::tfpb::tensorflow::NodeDef;
6
7pub mod array;
8pub mod control_flow;
9pub mod logic;
10pub mod math;
11pub mod nn;
12pub mod quant;
13pub mod random;
14pub mod rec;
15pub mod vars;
16
17pub fn register_all_ops(reg: &mut TfOpRegister) {
18 array::register_all_ops(reg);
19 control_flow::register_all_ops(reg);
20 logic::register_all_ops(reg);
21 math::register_all_ops(reg);
22 nn::register_all_ops(reg);
23 quant::register_all_ops(reg);
24 random::register_all_ops(reg);
25 rec::register_all_ops(reg);
26 vars::register_all_ops(reg);
27 reg.insert("Cast", cast);
28 reg.insert("Const", konst);
29 reg.insert("Identity", |_, _| Ok(Box::new(tract_hir::ops::identity::Identity)));
30 reg.insert("NoOp", |_, _| Ok(Box::new(Noop)));
31 reg.insert("Placeholder", |_, _| Ok(Box::new(tract_hir::ops::source::Source::new())));
32}
33
34fn cast(_ctx: &ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
35 let dtype = node.get_attr_datum_type("DstT")?;
36 Ok(Box::new(tract_hir::ops::cast::cast(dtype)))
37}
38
39fn konst(_ctx: &ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
40 let dtype = node.get_attr_datum_type("dtype")?;
41 let mat = node.get_attr_tensor("value")?;
42
43 if mat.datum_type() != dtype {
44 bail!("Const node {:?} doesn't have the expected {:?} type.", mat, dtype);
45 }
46
47 Ok(Box::new(tract_hir::ops::konst::Const::new(mat.into())?))
48}
49
50#[derive(Clone, Debug, new, Hash)]
51pub struct Noop;
52
53impl Op for Noop {
54 fn name(&self) -> StaticName {
55 "Noop".into()
56 }
57
58 op_as_typed_op!();
59}
60
61impl EvalOp for Noop {
62 fn is_stateless(&self) -> bool {
63 true
64 }
65
66 fn eval(&self, _inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
67 Ok(tvec!(Tensor::from(false).into()))
68 }
69}
70
71impl InferenceRulesOp for Noop {
72 fn rules<'r, 'p: 'r, 's: 'r>(
73 &'s self,
74 s: &mut Solver<'r>,
75 _inputs: &'p [TensorProxy],
76 outputs: &'p [TensorProxy],
77 ) -> InferenceResult {
78 s.equals(&outputs[0].datum_type, bool::datum_type())?;
79 s.equals(&outputs[0].rank, 0)?;
80 Ok(())
81 }
82
83 as_op!();
84 to_typed!();
85}
86
87impl TypedOp for Noop {
88 fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
89 Ok(tvec!(bool::scalar_fact()))
90 }
91
92 as_op!();
93}