1use std::any::Any;
2
3use crate::internal::*;
4use tract_core::internal::*;
5
6pub fn expand<E: Expansion>(e: E) -> Box<dyn InferenceOp> {
7 Box::new(Box::new(e) as Box<dyn Expansion>)
8}
9
/// An inference-level operator defined by how it expands into typed operators.
///
/// Implementors provide inference `rules` and a `wire` method that builds the
/// equivalent subgraph in a `TypedModel`. Use [`expand`] to turn an
/// implementor into a `Box<dyn InferenceOp>`.
pub trait Expansion:
    tract_core::dyn_clone::DynClone
    + std::fmt::Debug
    + Send
    + Sync
    + tract_core::downcast_rs::Downcast
    + Any
{
    /// Name of the operator (forwarded by `Op::name` on `Box<dyn Expansion>`).
    fn name(&self) -> StaticName;

    /// Validation level of the operator. Defaults to `Validation::Accurate`.
    fn validation(&self) -> Validation {
        Validation::Accurate
    }

    /// Extra human-readable information lines. Defaults to an empty list.
    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![])
    }

    /// Number of outputs the operator produces. Defaults to 1.
    fn nboutputs(&self) -> TractResult<usize> {
        Ok(1)
    }

    /// Wires the typed equivalent of this operator into `model` on top of
    /// `inputs` and returns the resulting output outlets.
    ///
    /// `prefix` is a name for the created node(s); when translating an
    /// inference model, the caller passes the inference node's name.
    fn wire(
        &self,
        prefix: &str,
        model: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>>;

    /// Variant of [`Expansion::wire`] that also receives the source inference
    /// model and node, for expansions that need information from them.
    ///
    /// The default implementation ignores `model` and `node` and forwards to
    /// [`Expansion::wire`].
    // `model` and `node` are intentionally unused by the default implementation;
    // the lint is silenced so overriding implementations can keep these names.
    #[allow(unused_variables)]
    fn wire_with_inference_model_and_node(
        &self,
        prefix: &str,
        model: &InferenceModel,
        node: &InferenceNode,
        typed_model: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        self.wire(prefix, typed_model, inputs)
    }

    /// Registers inference rules relating the `inputs` and `outputs` tensor
    /// proxies with the solver `s`.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult;

    /// Value reported by `EvalOp::is_stateless` for the boxed expansion.
    /// Defaults to `true`.
    fn is_stateless(&self) -> bool {
        true
    }
}

// Implements `Clone` for `Box<dyn Expansion>` through the `DynClone` supertrait.
tract_core::dyn_clone::clone_trait_object!(Expansion);
63
64impl Op for Box<dyn Expansion> {
65 fn name(&self) -> StaticName {
66 self.as_ref().name()
67 }
68
69 fn info(&self) -> TractResult<Vec<String>> {
70 self.as_ref().info()
71 }
72
73 fn validation(&self) -> Validation {
74 self.as_ref().validation()
75 }
76
77 not_a_typed_op!();
78}
79
80impl EvalOp for Box<dyn Expansion> {
81 fn is_stateless(&self) -> bool {
82 self.as_ref().is_stateless()
83 }
84
85 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
86 let mut adhoc = TypedModel::default();
87 let wires = inputs
88 .iter()
89 .enumerate()
90 .map(|(ix, i)| {
91 adhoc.add_source(
92 format!("adhoc-source-{ix}"),
93 TypedFact::from(i.clone().into_arc_tensor()),
94 )
95 })
96 .collect::<TractResult<TVec<OutletId>>>()?;
97
98 let wires = self.wire("adhoc", &mut adhoc, &wires)?;
99 adhoc.set_output_outlets(&wires)?;
100 SimplePlan::new(adhoc)?.run(inputs)
101 }
102}
103
104impl InferenceRulesOp for Box<dyn Expansion> {
105 fn rules<'r, 'p: 'r, 's: 'r>(
106 &'s self,
107 s: &mut Solver<'r>,
108 inputs: &'p [TensorProxy],
109 outputs: &'p [TensorProxy],
110 ) -> InferenceResult {
111 self.as_ref().rules(s, inputs, outputs)
112 }
113
114 fn to_typed(
115 &self,
116 source: &InferenceModel,
117 node: &InferenceNode,
118 target: &mut TypedModel,
119 mapping: &HashMap<OutletId, OutletId>,
120 ) -> TractResult<TVec<OutletId>> {
121 let inputs = node.inputs.iter().map(|i| mapping[i]).collect::<Vec<_>>();
122 let outputs =
123 self.wire_with_inference_model_and_node(&node.name, source, node, target, &inputs)?;
124 for (ix, o) in outputs.iter().enumerate() {
125 let expected = &node.outputs[ix].fact;
126 let got = target.outlet_fact(*o)?;
127 if expected.clone().unify_with(&InferenceFact::from(got)).is_err() {
128 bail!("Output mismatch after rewiring expansion for output #{}: expected {:?} got {:?}", ix, expected, got);
129 }
130 }
131 Ok(outputs)
132 }
133
134 fn nboutputs(&self) -> TractResult<usize> {
135 self.as_ref().nboutputs()
136 }
137
138 as_op!();
139}
140
141pub fn inference_wrap<O, R>(op: O, outputs: usize, rules: R) -> Box<dyn InferenceOp>
142where
143 O: TypedOp,
144 R: for<'r, 'p, 's> Fn(
145 &'s dyn Op,
146 &mut Solver<'r>,
147 &'p [TensorProxy],
148 &'p [TensorProxy],
149 ) -> InferenceResult
150 + Send
151 + Sync
152 + 'static,
153{
154 expand(InferenceWrapper { typed_op: Box::new(op), rules: Arc::new(rules), outputs })
155}
156
/// Type-erased inference-rule callback stored by [`InferenceWrapper`].
///
/// It is called with the wrapped operator (as `&dyn Op`), the solver, and the
/// input and output tensor proxies, and returns an `InferenceResult`.
type RuleProducer = dyn for<'r, 'p, 's> Fn(
        &'s dyn Op,
        &mut Solver<'r>,
        &'p [TensorProxy],
        &'p [TensorProxy],
    ) -> InferenceResult
    + Send
    + Sync
    + 'static;
166
/// Adapts a `TypedOp` into an [`Expansion`] by pairing it with externally
/// supplied inference rules and a fixed output count.
#[derive(Clone, new)]
pub struct InferenceWrapper {
    // Typed operator wired as a single node by `Expansion::wire`.
    typed_op: Box<dyn TypedOp>,
    // Inference rules; held in an `Arc`, so clones share the same callback.
    rules: Arc<RuleProducer>,
    // Output count returned by `Expansion::nboutputs`.
    outputs: usize,
}
173
174impl std::fmt::Debug for InferenceWrapper {
175 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176 std::fmt::Debug::fmt(&self.typed_op, f)
177 }
178}
179
180impl Expansion for InferenceWrapper {
181 fn name(&self) -> StaticName {
182 self.typed_op.name()
183 }
184
185 fn wire(
186 &self,
187 prefix: &str,
188 model: &mut TypedModel,
189 inputs: &[OutletId],
190 ) -> TractResult<TVec<OutletId>> {
191 model.wire_node(prefix, &self.typed_op, inputs)
192 }
193
194 fn rules<'r, 'p: 'r, 's: 'r>(
195 &'s self,
196 s: &mut Solver<'r>,
197 inputs: &'p [TensorProxy],
198 outputs: &'p [TensorProxy],
199 ) -> InferenceResult {
200 (self.rules)(self.typed_op.as_op(), s, inputs, outputs)
201 }
202
203 fn nboutputs(&self) -> TractResult<usize> {
204 Ok(self.outputs)
205 }
206}