tract_hir/ops/
expandable.rs

1use std::any::Any;
2
3use crate::internal::*;
4use tract_core::internal::*;
5
6pub fn expand<E: Expansion>(e: E) -> Box<dyn InferenceOp> {
7    Box::new(Box::new(e) as Box<dyn Expansion>)
8}
9
10pub trait Expansion:
11    tract_core::dyn_clone::DynClone
12    + std::fmt::Debug
13    + Send
14    + Sync
15    + tract_core::downcast_rs::Downcast
16    + Any
17{
18    fn name(&self) -> StaticName;
19    fn validation(&self) -> Validation {
20        Validation::Accurate
21    }
22
23    fn info(&self) -> TractResult<Vec<String>> {
24        Ok(vec![])
25    }
26
27    fn nboutputs(&self) -> TractResult<usize> {
28        Ok(1)
29    }
30
31    fn wire(
32        &self,
33        prefix: &str,
34        model: &mut TypedModel,
35        inputs: &[OutletId],
36    ) -> TractResult<TVec<OutletId>>;
37
38    #[allow(unused_variables)]
39    fn wire_with_inference_model_and_node(
40        &self,
41        prefix: &str,
42        model: &InferenceModel,
43        node: &InferenceNode,
44        typed_model: &mut TypedModel,
45        inputs: &[OutletId],
46    ) -> TractResult<TVec<OutletId>> {
47        self.wire(prefix, typed_model, inputs)
48    }
49
50    fn rules<'r, 'p: 'r, 's: 'r>(
51        &'s self,
52        s: &mut Solver<'r>,
53        inputs: &'p [TensorProxy],
54        outputs: &'p [TensorProxy],
55    ) -> InferenceResult;
56
57    fn is_stateless(&self) -> bool {
58        true
59    }
60}
61
62tract_core::dyn_clone::clone_trait_object!(Expansion);
63
64impl Op for Box<dyn Expansion> {
65    fn name(&self) -> StaticName {
66        self.as_ref().name()
67    }
68
69    fn info(&self) -> TractResult<Vec<String>> {
70        self.as_ref().info()
71    }
72
73    fn validation(&self) -> Validation {
74        self.as_ref().validation()
75    }
76
77    not_a_typed_op!();
78}
79
80impl EvalOp for Box<dyn Expansion> {
81    fn is_stateless(&self) -> bool {
82        self.as_ref().is_stateless()
83    }
84
85    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
86        let mut adhoc = TypedModel::default();
87        let wires = inputs
88            .iter()
89            .enumerate()
90            .map(|(ix, i)| {
91                adhoc.add_source(
92                    format!("adhoc-source-{ix}"),
93                    TypedFact::from(i.clone().into_arc_tensor()),
94                )
95            })
96            .collect::<TractResult<TVec<OutletId>>>()?;
97
98        let wires = self.wire("adhoc", &mut adhoc, &wires)?;
99        adhoc.set_output_outlets(&wires)?;
100        SimplePlan::new(adhoc)?.run(inputs)
101    }
102}
103
104impl InferenceRulesOp for Box<dyn Expansion> {
105    fn rules<'r, 'p: 'r, 's: 'r>(
106        &'s self,
107        s: &mut Solver<'r>,
108        inputs: &'p [TensorProxy],
109        outputs: &'p [TensorProxy],
110    ) -> InferenceResult {
111        self.as_ref().rules(s, inputs, outputs)
112    }
113
114    fn to_typed(
115        &self,
116        source: &InferenceModel,
117        node: &InferenceNode,
118        target: &mut TypedModel,
119        mapping: &HashMap<OutletId, OutletId>,
120    ) -> TractResult<TVec<OutletId>> {
121        let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<Vec<_>>();
122        let outputs =
123            self.wire_with_inference_model_and_node(&node.name, source, node, target, &inputs)?;
124        for (ix, o) in outputs.iter().enumerate() {
125            let expected = &node.outputs[ix].fact;
126            let got = target.outlet_fact(*o)?;
127            if expected.clone().unify_with(&InferenceFact::from(got)).is_err() {
128                bail!("Output mismatch after rewiring expansion for output #{}: expected {:?} got {:?}", ix, expected, got);
129            }
130        }
131        Ok(outputs)
132    }
133
134    fn nboutputs(&self) -> TractResult<usize> {
135        self.as_ref().nboutputs()
136    }
137
138    as_op!();
139}
140
141pub fn inference_wrap<O, R>(op: O, outputs: usize, rules: R) -> Box<dyn InferenceOp>
142where
143    O: TypedOp,
144    R: for<'r, 'p, 's> Fn(
145            &'s dyn Op,
146            &mut Solver<'r>,
147            &'p [TensorProxy],
148            &'p [TensorProxy],
149        ) -> InferenceResult
150        + Send
151        + Sync
152        + 'static,
153{
154    expand(InferenceWrapper { typed_op: Box::new(op), rules: Arc::new(rules), outputs })
155}
156
157type RuleProducer = dyn for<'r, 'p, 's> Fn(
158        &'s dyn Op,
159        &mut Solver<'r>,
160        &'p [TensorProxy],
161        &'p [TensorProxy],
162    ) -> InferenceResult
163    + Send
164    + Sync
165    + 'static;
166
167#[derive(Clone, new)]
168pub struct InferenceWrapper {
169    typed_op: Box<dyn TypedOp>,
170    rules: Arc<RuleProducer>,
171    outputs: usize,
172}
173
174impl std::fmt::Debug for InferenceWrapper {
175    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176        std::fmt::Debug::fmt(&self.typed_op, f)
177    }
178}
179
180impl Expansion for InferenceWrapper {
181    fn name(&self) -> StaticName {
182        self.typed_op.name()
183    }
184
185    fn wire(
186        &self,
187        prefix: &str,
188        model: &mut TypedModel,
189        inputs: &[OutletId],
190    ) -> TractResult<TVec<OutletId>> {
191        model.wire_node(prefix, &self.typed_op, inputs)
192    }
193
194    fn rules<'r, 'p: 'r, 's: 'r>(
195        &'s self,
196        s: &mut Solver<'r>,
197        inputs: &'p [TensorProxy],
198        outputs: &'p [TensorProxy],
199    ) -> InferenceResult {
200        (self.rules)(self.typed_op.as_op(), s, inputs, outputs)
201    }
202
203    fn nboutputs(&self) -> TractResult<usize> {
204        Ok(self.outputs)
205    }
206}