tract_tensorflow/ops/
mod.rs

1use tract_hir::internal::*;
2
3use crate::model::ParsingContext;
4use crate::model::TfOpRegister;
5use crate::tfpb::tensorflow::NodeDef;
6
7pub mod array;
8pub mod control_flow;
9pub mod logic;
10pub mod math;
11pub mod nn;
12pub mod quant;
13pub mod random;
14pub mod rec;
15pub mod vars;
16
17pub fn register_all_ops(reg: &mut TfOpRegister) {
18    array::register_all_ops(reg);
19    control_flow::register_all_ops(reg);
20    logic::register_all_ops(reg);
21    math::register_all_ops(reg);
22    nn::register_all_ops(reg);
23    quant::register_all_ops(reg);
24    random::register_all_ops(reg);
25    rec::register_all_ops(reg);
26    vars::register_all_ops(reg);
27    reg.insert("Cast", cast);
28    reg.insert("Const", konst);
29    reg.insert("Identity", |_, _| Ok(Box::new(tract_hir::ops::identity::Identity)));
30    reg.insert("NoOp", |_, _| Ok(Box::new(Noop)));
31    reg.insert("Placeholder", |_, _| Ok(Box::new(tract_hir::ops::source::Source::new())));
32}
33
34fn cast(_ctx: &ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
35    let dtype = node.get_attr_datum_type("DstT")?;
36    Ok(Box::new(tract_hir::ops::cast::cast(dtype)))
37}
38
39fn konst(_ctx: &ParsingContext, node: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
40    let dtype = node.get_attr_datum_type("dtype")?;
41    let mat = node.get_attr_tensor("value")?;
42
43    if mat.datum_type() != dtype {
44        bail!("Const node {:?} doesn't have the expected {:?} type.", mat, dtype);
45    }
46
47    Ok(Box::new(tract_hir::ops::konst::Const::new(mat.into())?))
48}
49
50#[derive(Clone, Debug, new, Hash)]
51pub struct Noop;
52
53impl Op for Noop {
54    fn name(&self) -> StaticName {
55        "Noop".into()
56    }
57
58    op_as_typed_op!();
59}
60
61impl EvalOp for Noop {
62    fn is_stateless(&self) -> bool {
63        true
64    }
65
66    fn eval(&self, _inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
67        Ok(tvec!(Tensor::from(false).into()))
68    }
69}
70
71impl InferenceRulesOp for Noop {
72    fn rules<'r, 'p: 'r, 's: 'r>(
73        &'s self,
74        s: &mut Solver<'r>,
75        _inputs: &'p [TensorProxy],
76        outputs: &'p [TensorProxy],
77    ) -> InferenceResult {
78        s.equals(&outputs[0].datum_type, bool::datum_type())?;
79        s.equals(&outputs[0].rank, 0)?;
80        Ok(())
81    }
82
83    as_op!();
84    to_typed!();
85}
86
87impl TypedOp for Noop {
88    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
89        Ok(tvec!(bool::scalar_fact()))
90    }
91
92    as_op!();
93}