tract_nnef/
registry.rs

1use std::ops::ControlFlow;
2
3use crate::ast::Identifier;
4use crate::internal::*;
5
6use crate::ast;
7use crate::deser::Value;
8
9use tract_core::dyn_clone::clone_box;
10use tract_core::ops::binary::*;
11use tract_core::transform::ModelTransform;
12
13pub type ToTract = fn(&mut ModelBuilder, &ResolvedInvocation) -> TractResult<Value>;
14pub type FromTract =
15    Box<dyn Fn(&mut IntoAst, &TypedNode) -> TractResult<Option<Arc<RValue>>> + Send + Sync>;
16pub type FromTractWithOp<O> =
17    fn(&mut IntoAst, node: &TypedNode, op: &O) -> TractResult<Option<Arc<RValue>>>;
18pub type BinOp = (Identifier, Box<dyn BinMiniOp>);
19pub type GetTransform =
20    Box<dyn Fn(&str) -> TractResult<Option<Box<dyn ModelTransform>>> + Send + Sync>;
21pub type Extension = Box<
22    dyn Fn(&mut crate::deser::ModelBuilder, &Identifier, &str) -> TractResult<ControlFlow<(), ()>>
23        + Send
24        + Sync,
25>;
26
27#[derive(Clone)]
28pub struct PrimitiveDecl {
29    pub decl: FragmentDecl,
30    pub docstrings: Option<Vec<String>>,
31    pub to_tract: ToTract,
32}
33
34impl PrimitiveDecl {
35    pub fn validate(&self) -> TractResult<()> {
36        self.decl.validate().with_context(|| format!("Invalid primitive `{}'", self.decl.id.0))
37    }
38
39    pub fn with_doc(&mut self, docstring: impl Into<String>) -> &mut Self {
40        self.docstrings.get_or_insert_with(Vec::new).push(docstring.into());
41        self
42    }
43}
44
45pub struct Registry {
46    pub id: Identifier,
47    pub docstrings: Option<Vec<String>>,
48    pub aliases: Vec<Identifier>,
49    pub fragments: HashMap<Identifier, FragmentDef>,
50    pub primitives: HashMap<Identifier, PrimitiveDecl>,
51    pub transforms: GetTransform,
52    pub unit_element_wise_ops: Vec<(Identifier, Box<dyn ElementWiseMiniOp>)>,
53    pub element_wise_ops: Vec<(Identifier, TypeId, FromTract, Vec<ast::Parameter>, ToTract)>,
54    pub binary_ops: Vec<BinOp>,
55    pub from_tract: HashMap<TypeId, FromTract>,
56    pub extensions: Vec<Extension>,
57}
58
59impl Registry {
60    pub fn new(id: impl AsRef<str>) -> Registry {
61        Registry {
62            id: id.as_ref().into(),
63            docstrings: None,
64            aliases: Default::default(),
65            primitives: Default::default(),
66            fragments: Default::default(),
67            from_tract: Default::default(),
68            unit_element_wise_ops: Default::default(),
69            element_wise_ops: Default::default(),
70            binary_ops: Default::default(),
71            transforms: Box::new(|_| Ok(None)),
72            extensions: Default::default(),
73        }
74    }
75
76    pub fn with_doc(mut self, docstring: impl Into<String>) -> Registry {
77        self.docstrings.get_or_insert_with(Vec::new).push(docstring.into());
78        self
79    }
80
81    pub fn register_dumper<O: TypedOp>(&mut self, dumper: FromTractWithOp<O>) {
82        self.from_tract.insert(
83            std::any::TypeId::of::<O>(),
84            Box::new(move |ast: &mut IntoAst, node: &TypedNode| {
85                let op = node.op_as::<O>().unwrap();
86                dumper(ast, node, op)
87            }),
88        );
89    }
90
91    pub fn register_primitive(
92        &mut self,
93        id: impl AsRef<str>,
94        params: &[ast::Parameter],
95        results: &[impl Into<ast::Result_> + Clone],
96        func: ToTract,
97    ) -> &mut PrimitiveDecl {
98        let id: Identifier = id.as_ref().into();
99        let decl = FragmentDecl {
100            id: id.clone(),
101            generic_decl: None,
102            parameters: params.to_vec(),
103            results: results.iter().cloned().map(|it| it.into()).collect(),
104        };
105        let primitive_decl = PrimitiveDecl { decl, docstrings: None, to_tract: func };
106        self.primitives.insert(id.clone(), primitive_decl);
107        self.primitives.get_mut(&id).expect("Unexpected empty entry in primitives hashmap")
108    }
109
110    pub fn register_fragment(&mut self, def: FragmentDef) {
111        self.fragments.insert(def.decl.id.clone(), def);
112    }
113
114    pub fn register_unit_element_wise(&mut self, id: impl AsRef<str>, ew: &dyn ElementWiseMiniOp) {
115        assert!(std::mem::size_of_val(ew) == 0);
116        self.unit_element_wise_ops.push((id.as_ref().into(), clone_box(ew)));
117    }
118
119    pub fn register_element_wise(
120        &mut self,
121        id: impl AsRef<str>,
122        type_id: TypeId,
123        dumper: FromTract,
124        parameters: Vec<ast::Parameter>,
125        loader: ToTract,
126    ) {
127        self.element_wise_ops.push((id.as_ref().into(), type_id, dumper, parameters, loader));
128    }
129
130    pub fn register_binary(&mut self, id: impl AsRef<str>, op: &dyn BinMiniOp) {
131        self.binary_ops.push((id.as_ref().into(), clone_box(op)));
132    }
133
134    pub fn serialize(
135        &self,
136        ast: &mut IntoAst,
137        node: &TypedNode,
138    ) -> TractResult<Option<Arc<RValue>>> {
139        use tract_core::ops;
140        if node.op_is::<ops::identity::Identity>() {
141            return Ok(Some(ast.mapping[&node.inputs[0]].clone()));
142        } else if let Some(op) = node.op().downcast_ref::<ops::element_wise::ElementWiseOp>() {
143            if std::mem::size_of_val(op.0.as_ref()) == 0 {
144                if let Some(op) = self
145                    .unit_element_wise_ops
146                    .iter()
147                    .find(|ew| ew.1.as_ref().type_id() == op.0.type_id())
148                {
149                    let a = ast.mapping[&node.inputs[0]].clone();
150                    return Ok(Some(invocation(&op.0, &[a], &[])));
151                }
152            } else if let Some(op) = self.element_wise_ops.iter().find(|ew| ew.1 == op.0.type_id())
153            {
154                if let Some(result) = (op.2)(ast, node)? {
155                    return Ok(Some(result));
156                }
157            }
158        } else if let Some(op) = node.op().downcast_ref::<ops::binary::TypedBinOp>() {
159            if let Some(op) =
160                self.binary_ops.iter().find(|ew| ew.1.as_ref().type_id() == op.0.type_id())
161            {
162                let a = ast.mapping[&node.inputs[0]].clone();
163                let b = ast.mapping[&node.inputs[1]].clone();
164                return Ok(Some(invocation(&op.0, &[a, b], &[])));
165            }
166        } else if let Some(op) = self.from_tract.get(&node.op().type_id()) {
167            if let Some(result) = op(ast, node)? {
168                return Ok(Some(result));
169            }
170        }
171        Ok(None)
172    }
173
174    pub fn deserialize(
175        &self,
176        builder: &mut ModelBuilder,
177        invocation: &ast::Invocation,
178        dt: &[Option<DatumType>],
179    ) -> TractResult<Option<Value>> {
180        if let Some(op) = self.primitives.get(&invocation.id) {
181            let resolved = ResolvedInvocation {
182                invocation,
183                default_params: &op.decl.parameters,
184                dt_from_quant_file: dt,
185            };
186            let out_value = (op.to_tract)(builder, &resolved)
187                .with_context(|| format!("Deserializing op `{}'", invocation.id.0))?;
188            return Ok(Some(out_value));
189        }
190        let c_dt: Option<DatumType> = dt.first().cloned().and_then(|dt| dt);
191        if let Some(ew) = self.unit_element_wise_ops.iter().find(|ew| ew.0 == invocation.id) {
192            let input =
193                invocation.arguments[0].rvalue.resolve(builder, &[])?.to::<OutletId>(builder)?;
194            let outlet = builder.wire_as_outlets(
195                tract_core::ops::element_wise::ElementWiseOp(ew.1.clone(), c_dt),
196                &[input],
197            )?;
198            if let Some(assumed_out_dt) = c_dt {
199                let out_dt = builder.model.outlet_fact(outlet[0])?.datum_type;
200                if out_dt != assumed_out_dt {
201                    return Ok(Some(
202                        builder.wire(tract_core::ops::cast::cast(assumed_out_dt), &outlet)?,
203                    ));
204                }
205            }
206            return Ok(Some(Value::Wire(outlet[0])));
207        }
208        if let Some(ew) = self.element_wise_ops.iter().find(|ew| ew.0 == invocation.id) {
209            let resolved =
210                ResolvedInvocation { invocation, default_params: &ew.3, dt_from_quant_file: dt };
211            return Ok(Some(
212                (ew.4)(builder, &resolved)
213                    .with_context(|| format!("Deserializing op `{}'", invocation.id.0))?,
214            ));
215        }
216        if let Some(bin) = self.binary_ops.iter().find(|bin| bin.0 == invocation.id) {
217            let mut a =
218                invocation.arguments[0].rvalue.resolve(builder, &[])?.to::<OutletId>(builder)?;
219            let mut b =
220                invocation.arguments[1].rvalue.resolve(builder, &[])?.to::<OutletId>(builder)?;
221            let a_fact = builder.model.outlet_fact(a)?;
222            let b_fact = builder.model.outlet_fact(b)?;
223            let a_dt = a_fact.datum_type;
224            let b_dt = b_fact.datum_type;
225
226            // mitigation of nnef "scalar" type mismatch with tract-core more
227            // strict types
228            let operating_dt = if a_dt == b_dt
229                && bin.1.operating_datum_type(a_dt, b_dt).map(|it| it == a_dt).unwrap_or(false)
230            {
231                a_dt
232            } else if a_dt == String::datum_type() || b_dt == String::datum_type() {
233                String::datum_type()
234            } else if a_dt == TDim::datum_type() || b_dt == TDim::datum_type() {
235                bin.1.operating_datum_type(a_dt, b_dt)?
236            // assume scalar are inline and we should not trust their DT
237            } else if a_fact.konst.is_some() && a_fact.shape.volume().is_one() {
238                b_dt
239            } else if b_fact.konst.is_some() && b_fact.shape.volume().is_one() {
240                a_dt
241            } else if builder.model.node(a.node).op_is::<tract_core::ops::konst::Const>() {
242                b_dt
243            } else if builder.model.node(b.node).op_is::<tract_core::ops::konst::Const>() {
244                a_dt
245            } else {
246                bin.1.operating_datum_type(a_dt, b_dt)?
247            };
248
249            if !a_dt.is_quantized() || !b_dt.is_quantized() {
250                a = builder.wire_as_outlets(tract_core::ops::cast::cast(operating_dt), &[a])?[0];
251                b = builder.wire_as_outlets(tract_core::ops::cast::cast(operating_dt), &[b])?[0];
252            }
253            let inputs = multi_rank_broadcast(builder, &[a, b])?;
254
255            let c_dt: Option<DatumType> = dt.first().cloned().and_then(|dt| dt);
256            let required_operating_dt = a_dt.is_quantized() || b_dt.is_quantized();
257            let mut wire = builder.wire_as_outlets(
258                tract_core::ops::binary::TypedBinOp(
259                    bin.1.clone(),
260                    c_dt.filter(|_| required_operating_dt),
261                ),
262                &inputs,
263            )?[0];
264            if let Some(c_dt) = c_dt {
265                wire = builder.wire_as_outlets(tract_core::ops::cast::cast(c_dt), &[wire])?[0];
266            }
267            return Ok(Some(Value::Wire(wire)));
268        }
269        if let Some(frag) = self.fragments.get(&invocation.id) {
270            let resolved = ResolvedInvocation {
271                invocation,
272                default_params: &frag.decl.parameters,
273                dt_from_quant_file: dt,
274            };
275            return Ok(Some(builder.wire_fragment_invocation(
276                &resolved,
277                &frag.decl,
278                frag.body.as_deref().unwrap(),
279            )?));
280        }
281        Ok(None)
282    }
283}
284
285pub fn multi_rank_broadcast(
286    builder: &mut ModelBuilder,
287    inputs: &[OutletId],
288) -> TractResult<TVec<OutletId>> {
289    let ranks = inputs
290        .iter()
291        .map(|&i| Ok(builder.model.outlet_fact(i)?.rank()))
292        .collect::<TractResult<Vec<usize>>>()?;
293    let max_rank = ranks.iter().copied().max().unwrap();
294    (inputs.iter())
295        .zip(ranks.iter())
296        .map(|(&i, &r)| {
297            (r..max_rank).try_fold(i, |w, n| Ok(builder.wire_as_outlets(AxisOp::Add(n), &[w])?[0]))
298        })
299        .collect()
300}