tract_core/ops/einsum/mod.rs

use std::borrow::Borrow;
use std::fmt::Debug;

use crate::internal::*;
use crate::tract_data::itertools::Itertools;

mod eval;

#[cfg(feature = "blas")]
pub mod as_blas;
mod as_matmul;
pub mod kernel_selection;
pub mod optimize;

#[cfg(test)]
mod proptest;

pub use as_matmul::{rewrite_einsums_as_matmul, BasicMatMul};
use num_traits::One;
use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact};
use tract_linalg::mmm::PackedOpaqueFact;

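/// Compute the effective shape of an input fact. For plain tensors this is the
/// fact's shape; for opaque block-quantized or pre-packed payloads, the
/// payload's logical dimensions are appended so that einsum axes can be
/// matched against them.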
pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<[TDim]>> {
    if !fact.datum_type.is_opaque() {
        return Ok(Cow::Borrowed(&*fact.shape));
    }
    let Some(opaque_fact) = fact.opaque_fact() else {
        bail!("Datum fact is opaque, but no opaque fact was found.")
    };
    if let Some(bqf) = opaque_fact.downcast_ref::<BlockQuantFact>() {
        Ok(Cow::Owned(
            fact.shape.iter().cloned().chain(bqf.shape().iter().map(|d| d.to_dim())).collect_vec(),
        ))
    // } else if let Some(pbqf) = opaque_fact.downcast_ref::<PackedBlockQuantFact>() {
    //     &pbqf.shape
    } else if let Some(pof) = opaque_fact.downcast_ref::<PackedOpaqueFact>() {
        Ok(Cow::Owned(
            fact.shape.iter().cloned().chain([pof.mn.clone(), pof.k.to_dim()]).collect_vec(),
        ))
    } else {
        bail!("Unsupported opaque fact {opaque_fact:?}")
    }
}

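/// A generalized einsum operator: contracts an arbitrary number of inputs
/// according to an `AxesMapping`, accumulating in `operating_dt`.
///
/// A minimal construction sketch (assuming `AxesMapping::from_strs` accepts
/// string slices, as it is used with owned strings in `change_axes` below):
///
/// ```ignore
/// let axes = AxesMapping::from_strs(&["bik", "bkj"], &["bij"])?;
/// let matmul = EinSum::new(axes, f32::datum_type());
/// ```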
#[derive(Clone, Hash)]
pub struct EinSum {
    pub axes: AxesMapping,
    pub operating_dt: DatumType,
    // If present, assume we're a binary op.
    // The 9 inputs are: A, B, bias, A0, Ascale, B0, Bscale, C0, Cscale.
    pub q_params: Option<DatumType>,
}

impl EinSum {
    pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: None }
    }

    pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: Some(output_type) }
    }

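    /// Resolve each input's logical shape via `block_quant_aware_input_shape`
    /// and check that every rank matches the axes mapping.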
    pub fn actual_input_shapes_from_facts<'m>(
        &self,
        inputs: &'m [impl Borrow<TypedFact>],
    ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
        ensure!(inputs.len() == self.axes.input_count());
        let shapes: TVec<Cow<[TDim]>> = inputs
            .iter()
            .map(|t| block_quant_aware_input_shape(t.borrow()))
            .collect::<TractResult<_>>()?;
        ensure!(shapes
            .iter()
            .enumerate()
            .all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix))));
        Ok(shapes)
    }

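    /// Build a patch that makes the axis at `(io, axis)` explicit on every
    /// interface of the node: inputs lacking it get a trailing `AxisOp::Add`,
    /// and if the output did not carry it, the axis is appended to the output
    /// and then stripped with `AxisOp::Rm`. Returns `None` when the axis
    /// already occurs more than once on an input.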
    #[allow(unused_variables)]
    pub(crate) fn propagate_axis(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        axis: usize,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut new_axis = self.axes.axis((io, axis))?.clone();
        let repr = new_axis.repr;
        let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
        let mut taps = tvec!();
        for (ix, input) in node.inputs.iter().enumerate() {
            let mut tap = patch.tap_model(model, *input)?;
            if new_axis.inputs[ix].len() > 1 {
                return Ok(None); // FIXME maybe
            } else if new_axis.inputs[ix].is_empty() {
                let insert_at = self.axes.rank(InOut::In(ix));
                tap = patch.wire_node(
                    format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
                    AxisOp::Add(insert_at),
                    &[tap],
                )?[0];
                new_axis.inputs[ix].push(insert_at);
            }
            taps.push(tap);
        }
        let must_rm_axis: Option<usize> = if new_axis.outputs[0].is_empty() {
            let insert_at = self.axes.rank(InOut::Out(0));
            new_axis.outputs[0].push(insert_at);
            Some(insert_at)
        } else {
            None
        };
        let new_expr = self
            .axes
            .iter_all_axes()
            .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
            .collect_vec();
        let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
        let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
        if let Some(position) = must_rm_axis {
            wire = patch.wire_node(
                format!("{}.prop_axis.{}.output", &node.name, repr),
                AxisOp::Rm(position),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

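    /// Accumulator types acceptable for this einsum's operating type: integer
    /// types accumulate in i32, f16 may accumulate in f16 or f32, and any
    /// other type accumulates in itself.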
    pub fn acceptable_accumulators(&self) -> TVec<DatumType> {
        if self.operating_dt.is_integer() {
            tvec!(i32::datum_type())
        } else if self.operating_dt == f16::datum_type() {
            tvec!(f16::datum_type(), f32::datum_type())
        } else {
            tvec!(self.operating_dt)
        }
    }
}

impl Debug for EinSum {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
    }
}

impl Op for EinSum {
    fn name(&self) -> Cow<str> {
        "EinSum".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
        if let Some(qp) = self.q_params {
            info.push(format!("Quantized output: {qp:?}"));
        }
        Ok(info)
    }

    op_as_typed_op!();
}

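// Stateless evaluation: dispatch to the quantized kernel when `q_params` is
// set, otherwise to the generic typed kernel for `operating_dt`.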
impl EvalOp for EinSum {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let output = if let Some(qp) = self.q_params {
            eval::eval_q(&self.axes, qp, inputs)
        } else {
            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
        }?;
        Ok(tvec!(output.into_tvalue()))
    }
}

impl TypedOp for EinSum {
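    /// Check that each axis broadcasts consistently across its input
    /// occurrences, then derive the output shape. A quantized einsum (9
    /// inputs) outputs `q_params`' type from A and B's shapes; otherwise the
    /// output type is `operating_dt`.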
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let shapes = self.actual_input_shapes_from_facts(inputs)?;
        for i in 0..inputs.len() {
            ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
        }
        for axis in self.axes.iter_all_axes() {
            assert!(shapes
                .iter()
                .enumerate()
                .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
                .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
                .is_ok());
        }
        if let Some(qp) = self.q_params {
            ensure!(inputs.len() == 9);
            Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
        } else {
            Ok(tvec!(TypedFact::dt_shape(
                self.operating_dt,
                eval::output_shape(&self.axes, &shapes)?
            )))
        }
    }

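    /// For axis tracking, opaque packed or block-quantized inputs hide their
    /// two innermost payload axes, so drop both trailing occurrences from the
    /// mapping.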
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut axes = self.axes.clone();
        for (slot, i) in inputs.iter().enumerate() {
            if i.datum_type.is_opaque()
                && (i.opaque_fact().is_some_and(|of| {
                    of.is::<BlockQuantFact>()
                        || of.is::<PackedOpaqueFact>()
                        || of.is::<PackedBlockQuantFact>()
                }))
            {
                axes = axes
                    .remove_axis_occurency(InOut::In(slot), i.rank())?
                    .remove_axis_occurency(InOut::In(slot), i.rank())?;
            }
        }
        Ok(axes)
    }

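    /// FMA count: the output volume times the product of the contracted axes'
    /// dimensions (axes absent from the output), using each such axis' first
    /// non-trivial occurrence among the inputs.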
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shapes = self.actual_input_shapes_from_facts(inputs)?;
        let oshape = eval::output_shape(&self.axes, &shapes)?;
        let ks = self
            .axes
            .iter_all_axes()
            .filter(|axis| axis.outputs[0].is_empty())
            .map(|axis| {
                axis.inputs
                    .iter()
                    .enumerate()
                    .flat_map(|(ix, axes)| {
                        axes.iter()
                            .map(|axis| shapes[ix][*axis].clone())
                            .collect::<TVec<_>>()
                            .into_iter()
                    })
                    .find(|d| !d.is_one())
                    .unwrap_or_else(|| 1.to_dim())
            })
            .product::<TDim>();
        Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
    }

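    /// Slicing an output axis commutes with the einsum as long as the axis
    /// does not map into an opaque (packed) input, whose payload cannot be
    /// sliced.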
    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        model: &TypedModel,
        node: &TypedNode,
        prefix: &str,
        inputs: &[OutletId],
        output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        let facts = model.node_input_facts(node.id)?;
        let axis = self.axes.axis((InOut::Out(0), output_axis))?;
        if facts
            .iter()
            .enumerate()
            .any(|(slot, fact)| !axis.inputs[slot].is_empty() && fact.datum_type.is_opaque())
        {
            Ok(None)
        } else {
            patch.wire_node(prefix, self.clone(), inputs).map(Some)
        }
    }

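    /// Rm, Add and Move axis changes are absorbed by rewriting the einsum
    /// expression itself: edit the affected interface's axis string (Add picks
    /// a fresh label) and rebuild the mapping. Other changes (e.g. Reshape)
    /// are declined.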
    #[allow(unused_variables)]
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let (mut inputs, mut outputs) = self.axes.to_strs();
        let interface: &mut String = match io {
            InOut::In(i) => &mut inputs[i],
            InOut::Out(o) => &mut outputs[o],
        };
        let mut axes: Vec<char> = interface.chars().collect();
        match change {
            AxisOp::Rm(rm) => {
                axes.remove(*rm);
            }
            AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
            AxisOp::Move(from, to) => {
                let c = axes.remove(*from);
                axes.insert(*to, c);
            }
            _ => {
                return Ok(None);
            }
        };
        *interface = axes.into_iter().collect();
        let axes = AxesMapping::from_strs(&inputs, &outputs)?;
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
            wire_changes: tvec!((io, change.clone())),
        }))
    }

    //     fn declutter_with_session(
    //         &self,
    //         session: &mut crate::optim::OptimizerSession,
    //         model: &TypedModel,
    //         node: &TypedNode,
    //     ) -> TractResult<Option<TypedModelPatch>> {
    //         // if let Some(patch) = declutter_reshape_folding_input_axis(self, session, model, node)? {
    //         //     return Ok(Some(patch));
    //         // }
    //         return Ok(None);
    //     }

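    /// Codegen delegates to the einsum optimizer (matmul lowering and kernel
    /// selection); on failure, the context string surfaces the axes mapping
    /// and the input facts.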
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        optimize::optimize(self, model, node).with_context(|| {
            format!(
                "axes: {} — inputs: {}",
                self.axes,
                model
                    .node_input_facts(node.id)
                    .unwrap()
                    .iter()
                    .map(|f| format!("{f:?}"))
                    .join(" • ")
            )
        })
    }

    as_op!();
}

// fn declutter_reshape_folding_input_axis(
//     op: &EinSum,
//     _session: &mut crate::optim::OptimizerSession,
//     model: &TypedModel,
//     node: &TypedNode,
// ) -> TractResult<Option<TypedModelPatch>> {
//     for (slot, prec) in node.inputs.iter().map(|n| model.node(n.node)).enumerate() {
//         let Some(AxisOp::Reshape(at, ref from, ref to)) = prec.op_as() else { continue };
//         if to.len() > 1 {
//             continue;
//         }
//         let mut axes = op.axes.clone();
//         let extra_labels = axes.available_labels().take(from.len() - 1).collect_vec();
//         // add a temporary input to axes to hold the extra axes
//         let extra_input = node.inputs.len();
//         axes = axes.with_extra_input(extra_input)?;
//         for label in &extra_labels {
//             axes = axes.with_extra_axis(*label, InOut::In(extra_input), 0)?;
//         }
//         let folded_axis = op.axes.axis((InOut::In(slot), *at))?;
//         if folded_axis.outputs[0].len() > 1 {
//             return Ok(None);
//         };
//         let mut patch = TypedModelPatch::default();
//         let mut taps = patch.taps(model, &node.inputs)?;
//         for (input, tap) in taps.iter_mut().enumerate() {
//             if folded_axis.inputs[input].len() == 0 {
//                 continue;
//             };
//             if folded_axis.inputs[input].len() > 1 {
//                 return Ok(None);
//             };
//             let pos = folded_axis.inputs[input][0];
//             for label in &extra_labels {
//                 axes = axes.with_extra_axis_occurency(*label, InOut::In(input), pos)?;
//             }
//             *tap = patch.wire_node(
//                 format!("{}.reshape_folded_input_{}", node.name, input),
//                 AxisOp::Reshape(pos, to.clone(), from.clone()),
//                 &[*tap],
//             )?[0];
//         }
//         if folded_axis.outputs[0].len() == 1 {
//             let pos = folded_axis.outputs[0][0];
//             for label in &extra_labels {
//                 axes = axes.with_extra_axis_occurency(*label, InOut::Out(0), pos)?;
//             }
//         }
//         axes = axes.remove_slot(InOut::In(extra_input))?;
//         let mut wire = patch.wire_node(&node.name, EinSum { axes, ..op.clone() }, &taps)?;
//         if folded_axis.outputs[0].len() == 1 {
//             let pos = folded_axis.outputs[0][0];
//             wire = patch.wire_node(
//                 format!("{}.reshape_folded_output", node.name),
//                 AxisOp::Reshape(pos, from.clone(), to.clone()),
//                 &wire,
//             )?;
//         }
//         patch.shunt_outside(model, node.id.into(), wire[0])?;
//         return Ok(Some(patch));
//     }
//     return Ok(None);
// }