tract_core/ops/einsum/
mod.rs

1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::ops::array::Slice;
6use crate::tract_data::itertools::Itertools;
7
8mod eval;
9
10#[cfg(feature = "blas")]
11pub mod as_blas;
12use super::array::TypedConcat;
13use super::math::add;
14mod as_matmul;
15pub mod kernel_selection;
16pub mod optimize;
17
18#[cfg(test)]
19mod proptest;
20
21pub use as_matmul::{rewrite_einsums_as_matmul, BasicMatMul};
22use tract_linalg::frame::block_quant::BlockQuantFact;
23use tract_linalg::mmm::PackedOpaqueFact;
24
25pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<[TDim]>> {
26    if !fact.datum_type.is_opaque() {
27        return Ok(Cow::Borrowed(&*fact.shape));
28    }
29    let Some(opaque_fact) = fact.opaque_fact.as_ref() else {
30        bail!("Datum fact is opaque, but no opaque fact was found.")
31    };
32    let inner_shape: Cow<[usize]> = if let Some(bqf) = opaque_fact.downcast_ref::<BlockQuantFact>()
33    {
34        Cow::Borrowed(&*bqf.shape)
35    // } else if let Some(pbqf) = opaque_fact.downcast_ref::<PackedBlockQuantFact>() {
36    //     &pbqf.shape
37    } else if let Some(pof) = opaque_fact.downcast_ref::<PackedOpaqueFact>() {
38        Cow::Owned(vec![pof.mn, pof.k])
39    } else {
40        bail!("Unsupported opaque fact {opaque_fact:?}")
41    };
42    let shape: Vec<TDim> =
43        fact.shape.iter().cloned().chain(inner_shape.iter().map(|d| d.to_dim())).collect();
44    Ok(Cow::Owned(shape))
45}
46
/// Einstein-summation operator: contracts any number of inputs according to
/// an axes mapping, optionally producing a quantized output.
#[derive(Clone, Hash)]
pub struct EinSum {
    // Mapping of every input axis and output axis to a common label set.
    pub axes: AxesMapping,
    // Datum type used for accumulation (and output in the non-quantized case).
    pub operating_dt: DatumType,
    // if present, assume we're a binary op.
    // 9 inputs are: A,B,bias, A0,Ascale, B0,BScale, C0,Cscale
    pub q_params: Option<DatumType>,
}
55
impl EinSum {
    /// Build a plain (non-quantized) einsum.
    pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: None }
    }

    /// Build a quantized einsum; `output_type` is the output datum type.
    /// See the `q_params` field comment for the expected 9-input layout.
    pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: Some(output_type) }
    }

    /// Input shapes as the einsum actually sees them: opaque (block-quant /
    /// packed) inputs get their hidden payload dims appended, so each shape's
    /// rank matches the rank declared by the axes mapping.
    pub fn actual_input_shapes_from_facts<'m>(
        &self,
        inputs: &'m [impl Borrow<TypedFact>],
    ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
        ensure!(inputs.len() == self.axes.input_count());
        let shapes: TVec<Cow<[TDim]>> = inputs
            .iter()
            .map(|t| block_quant_aware_input_shape(t.borrow()))
            .collect::<TractResult<_>>()?;
        // Every unfolded shape must have the rank the mapping expects.
        ensure!(shapes
            .iter()
            .enumerate()
            .all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix))));
        Ok(shapes)
    }

    /// Make axis `(io, axis)` present on every interface of the einsum:
    /// inserts a trailing `AxisOp::Add` on inputs that lack it, and removes it
    /// from the output with `AxisOp::Rm` if it was output-absent, rewiring the
    /// node through a patch. Returns None when the axis appears more than once
    /// on some input (not handled yet).
    #[allow(unused_variables)]
    pub(crate) fn propagate_axis(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        axis: usize,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut new_axis = self.axes.axis((io, axis))?.clone();
        let repr = new_axis.repr;
        let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
        let mut taps = tvec!();
        for (ix, input) in node.inputs.iter().enumerate() {
            let mut tap = patch.tap_model(model, *input)?;
            if new_axis.inputs[ix].len() > 1 {
                return Ok(None); // FIXME maybe
            } else if new_axis.inputs[ix].is_empty() {
                // Axis missing on this input: add a unit dim at the end.
                let insert_at = self.axes.rank(InOut::In(ix));
                tap = patch.wire_node(
                    format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
                    AxisOp::Add(insert_at),
                    &[tap],
                )?[0];
                new_axis.inputs[ix].push(insert_at);
            }
            taps.push(tap);
        }
        // If the axis was absent from the output, give the new einsum an output
        // position for it, and remember to strip it afterwards.
        let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
            let insert_at = self.axes.rank(InOut::Out(0));
            new_axis.outputs[0].push(insert_at);
            Some(insert_at)
        } else {
            None
        };
        // Rebuild the mapping with the updated axis, keeping all others as-is.
        let new_expr = self
            .axes
            .iter_all_axes()
            .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
            .collect_vec();
        let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
        let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
        if let Some(position) = must_rm_axis {
            wire = patch.wire_node(
                format!("{}.prop_axis.{}.output", &node.name, repr),
                AxisOp::Rm(position),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

    /// If some input comes from a concat on a summing axis (an axis absent
    /// from the output), split the einsum into one einsum per concatenated
    /// slice and recombine the results (concat on an output axis, or add for
    /// a pure reduction). Quantized einsums are skipped (FIXME).
    #[allow(clippy::comparison_chain)]
    fn declutter_after_concat(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.q_params.is_some() {
            // FIXME
            return Ok(None);
        }
        'outer: for (slot, input) in node.inputs.iter().enumerate() {
            let precursor = model.node(input.node);
            if let Some(concat) = precursor.op_as::<TypedConcat>() {
                let offsets = concat.offsets(&model.node_input_facts(precursor.id)?)?;
                let axis_info = self.axes.axis((InOut::In(slot), concat.axis))?;
                // only split if axis is a summing axis
                if axis_info.outputs[0].len() > 0 {
                    continue;
                }
                let mut patch = TypedModelPatch::new(format!(
                    "Split Einsum for concat on axis {}",
                    axis_info.repr
                ));
                // inputs[einsum_input_slot][concated_slice]. concated_slice = 0 for broadcast
                let mut inputs: TVec<TVec<OutletId>> = tvec!();
                for (slot, input) in node.inputs.iter().enumerate() {
                    let tap = patch.tap_model(model, *input)?;
                    if axis_info.inputs[slot].len() > 1 {
                        // Axis appears twice on one input: bail out of this split.
                        continue 'outer;
                    } else if axis_info.inputs[slot].len() == 1 {
                        // Slice this input along the concat axis, one slice per
                        // concatenated chunk.
                        let mut slices = tvec!();
                        for (start, end) in offsets.iter().cloned().tuple_windows() {
                            let wire = patch.wire_node(
                                format!(
                                    "{}.concat-einsum-slice-{}.{}.{}..{}",
                                    node.name, axis_info.repr, slot, start, end
                                ),
                                Slice { axis: axis_info.inputs[slot][0], start, end },
                                &[tap],
                            )?;
                            slices.push(wire[0]);
                        }
                        inputs.push(slices);
                    } else {
                        inputs.push(tvec!(tap)); // broadcast
                    };
                }
                // One einsum per chunk; broadcast inputs reuse their single tap.
                let mut einsums = tvec!();
                for (ix, (start, end)) in offsets.iter().tuple_windows().enumerate() {
                    let mut einsum_inputs = tvec!();
                    for input_ix in 0..node.inputs.len() {
                        einsum_inputs
                            .push(inputs[input_ix].get(ix).cloned().unwrap_or(inputs[input_ix][0]));
                    }
                    let einsum = patch.wire_node(
                        format!(
                            "{}.concat-einsum-{}.{}..{}",
                            node.name, axis_info.repr, start, end
                        ),
                        self.clone(),
                        &einsum_inputs,
                    )?[0];
                    einsums.push(einsum);
                }
                // Recombine: concat if the axis survives to the output,
                // otherwise the split axis was summed over, so add the parts.
                let wire = if let Some(axis) = axis_info.outputs[0].first().cloned() {
                    patch.wire_node(
                        format!("{}.concat-einsum-{}.concat", node.name, axis_info.repr),
                        TypedConcat { axis },
                        &einsums,
                    )?[0]
                } else {
                    let mut wire = einsums[0];
                    for ix in 1..einsums.len() {
                        wire = patch.wire_node(
                            format!("{}.concat-einsum-{}.add-{}", node.name, axis_info.repr, ix),
                            add(),
                            &[wire, einsums[ix]],
                        )?[0]
                    }
                    wire
                };
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}
221
222impl Debug for EinSum {
223    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224        write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
225    }
226}
227
228impl Op for EinSum {
229    fn name(&self) -> Cow<str> {
230        "EinSum".into()
231    }
232
233    fn info(&self) -> TractResult<Vec<String>> {
234        let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
235        if let Some(qp) = self.q_params {
236            info.push(format!("Quantized output: {qp:?}"));
237        }
238        Ok(info)
239    }
240
241    op_as_typed_op!();
242}
243
244impl EvalOp for EinSum {
245    fn is_stateless(&self) -> bool {
246        true
247    }
248
249    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
250        let output = if let Some(qp) = self.q_params {
251            eval::eval_q(&self.axes, qp, inputs)
252        } else {
253            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
254        }?;
255        Ok(tvec!(output.into_tvalue()))
256    }
257}
258
259impl TypedOp for EinSum {
260    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
261        let shapes = self.actual_input_shapes_from_facts(inputs)?;
262        for i in 0..inputs.len() {
263            ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
264        }
265        for axis in self.axes.iter_all_axes() {
266            assert!(shapes
267                .iter()
268                .enumerate()
269                .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
270                .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
271                .is_ok());
272        }
273        if let Some(qp) = self.q_params {
274            ensure!(inputs.len() == 9);
275            Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
276        } else {
277            Ok(tvec!(TypedFact::dt_shape(
278                self.operating_dt,
279                eval::output_shape(&self.axes, &shapes)?
280            )))
281        }
282    }
283
284    fn axes_mapping(
285        &self,
286        inputs: &[&TypedFact],
287        _outputs: &[&TypedFact],
288    ) -> TractResult<AxesMapping> {
289        let mut axes = self.axes.clone();
290        for (slot, i) in inputs.iter().enumerate() {
291            if i.datum_type.is_opaque()
292                && i.opaque_fact.as_ref().is_some_and(|of| of.is::<BlockQuantFact>())
293            {
294                axes = axes
295                    .remove_axis_occurency(InOut::In(slot), i.rank())?
296                    .remove_axis_occurency(InOut::In(slot), i.rank())?;
297            }
298        }
299        Ok(axes)
300    }
301
302    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
303        let shapes = self.actual_input_shapes_from_facts(inputs)?;
304        let oshape = eval::output_shape(&self.axes, &shapes)?;
305        let ks = self
306            .axes
307            .iter_all_axes()
308            .filter(|axis| axis.outputs[0].len() == 0)
309            .map(|axis| {
310                axis.inputs
311                    .iter()
312                    .enumerate()
313                    .flat_map(|(ix, axes)| {
314                        axes.iter()
315                            .map(|axis| shapes[ix][*axis].clone())
316                            .collect::<TVec<_>>()
317                            .into_iter()
318                    })
319                    .find(|d| !d.is_one())
320                    .unwrap_or_else(|| 1.to_dim())
321            })
322            .product::<TDim>();
323        Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
324    }
325
326    fn slice(
327        &self,
328        patch: &mut TypedModelPatch,
329        model: &TypedModel,
330        node: &TypedNode,
331        prefix: &str,
332        inputs: &[OutletId],
333        _output_axis: usize,
334        _start: usize,
335        _end: usize,
336    ) -> TractResult<Option<TVec<OutletId>>> {
337        if model.node_input_facts(node.id)?.iter().any(|f| f.datum_type.is_opaque()) {
338            Ok(None)
339        } else {
340            patch.wire_node(prefix, self.clone(), inputs).map(Some)
341        }
342    }
343
344    #[allow(unused_variables)]
345    fn change_axes(
346        &self,
347        model: &TypedModel,
348        node: &TypedNode,
349        io: InOut,
350        change: &AxisOp,
351    ) -> TractResult<Option<AxisChangeConsequence>> {
352        let (mut inputs, mut outputs) = self.axes.to_strs();
353        let interface: &mut String = match io {
354            InOut::In(i) => &mut inputs[i],
355            InOut::Out(o) => &mut outputs[o],
356        };
357        let mut axes: Vec<char> = interface.chars().collect();
358        match change {
359            AxisOp::Rm(rm) => {
360                axes.remove(*rm);
361            }
362            AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
363            AxisOp::Move(from, to) => {
364                let c = axes.remove(*from);
365                axes.insert(*to, c);
366            }
367            _ => return Ok(None),
368        };
369        *interface = axes.into_iter().collect();
370        let axes = AxesMapping::from_strs(&inputs, &outputs)?;
371        Ok(Some(AxisChangeConsequence {
372            substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
373            wire_changes: tvec!((io, change.clone())),
374        }))
375    }
376
377    fn declutter(
378        &self,
379        model: &TypedModel,
380        node: &TypedNode,
381    ) -> TractResult<Option<TypedModelPatch>> {
382        self.declutter_after_concat(model, node)
383    }
384
385    fn codegen(
386        &self,
387        model: &TypedModel,
388        node: &TypedNode,
389    ) -> TractResult<Option<TypedModelPatch>> {
390        optimize::optimize(self, model, node)
391    }
392
393    as_op!();
394}