//! EinSum operator implementation (tract_core/ops/einsum/mod.rs).
1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::ops::array::MultiBroadcastTo;
6use crate::tract_data::itertools::Itertools;
7
8mod eval;
9
10#[cfg(feature = "blas")]
11pub mod as_blas;
12pub mod einsum_matmul;
13pub mod kernel_selection;
14pub mod prefix_matmul;
15
16#[cfg(test)]
17mod proptest;
18
19use num_traits::One;
20use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact};
21use tract_linalg::mmm::PackedOpaqueFact;
22
23pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<'_, [TDim]>> {
24    if !fact.datum_type.is_opaque() {
25        return Ok(Cow::Borrowed(&*fact.shape));
26    }
27    let Some(opaque_fact) = fact.opaque_fact() else {
28        bail!("Datum fact is opaque, but no opaque fact was found.")
29    };
30    if let Some(bqf) = opaque_fact.downcast_ref::<BlockQuantFact>() {
31        Ok(Cow::Owned(
32            fact.shape.iter().cloned().chain(bqf.shape().iter().map(|d| d.to_dim())).collect_vec(),
33        ))
34    } else if let Some(pof) = opaque_fact.downcast_ref::<PackedBlockQuantFact>() {
35        Ok(Cow::Owned(
36            fact.shape.iter().cloned().chain(pof.shape.iter().map(|i| i.to_dim())).collect_vec(),
37        ))
38    } else if let Some(pof) = opaque_fact.downcast_ref::<PackedOpaqueFact>() {
39        Ok(Cow::Owned(
40            fact.shape.iter().cloned().chain([pof.mn.clone(), pof.k.to_dim()]).collect_vec(),
41        ))
42    } else {
43        bail!("Unsupported opaque fact {opaque_fact:?}")
44    }
45}
46
/// General Einstein-summation operator.
///
/// `axes` maps every input axis to its output occurrences; `operating_dt` is
/// the datum type in which the computation is carried out.
#[derive(Clone, Hash, PartialEq)]
pub struct EinSum {
    pub axes: AxesMapping,
    pub operating_dt: DatumType,
    // if present, assume we're a binary op.
    // 9 inputs are: A,B,bias, A0,Ascale, B0,BScale, C0,Cscale
    pub q_params: Option<DatumType>,
}
55
impl EinSum {
    /// Plain (non-quantized) einsum.
    pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: None }
    }

    /// Quantized einsum producing `output_type`. Callers are expected to use
    /// the 9-input form documented on `q_params`.
    pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: Some(output_type) }
    }

    /// Input shapes as einsum sees them, unfolding opaque block-quant facts
    /// (see `block_quant_aware_input_shape`). Validates that the number of
    /// inputs and each resulting rank match the axes mapping.
    pub fn actual_input_shapes_from_facts<'m>(
        &self,
        inputs: &'m [impl Borrow<TypedFact>],
    ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
        ensure!(inputs.len() == self.axes.input_count());
        let shapes: TVec<Cow<[TDim]>> = inputs
            .iter()
            .map(|t| block_quant_aware_input_shape(t.borrow()))
            .collect::<TractResult<_>>()?;
        ensure!(shapes
            .iter()
            .enumerate()
            .all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix))));
        Ok(shapes)
    }

    /// Rewrite the node so the axis designated by `(io, axis)` is present on
    /// every interface: inputs missing it get a trailing `AxisOp::Add`, and if
    /// the output did not carry it, an `AxisOp::Rm` is wired after the einsum.
    ///
    /// Returns `Ok(None)` when an input already holds the axis more than once
    /// (unsupported case).
    #[allow(unused_variables)]
    pub(crate) fn propagate_axis(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        axis: usize,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut new_axis = self.axes.axis((io, axis))?.clone();
        let repr = new_axis.repr;
        let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
        let mut taps = tvec!();
        for (ix, input) in node.inputs.iter().enumerate() {
            let mut tap = patch.tap_model(model, *input)?;
            if new_axis.inputs[ix].len() > 1 {
                return Ok(None); // FIXME maybe
            } else if new_axis.inputs[ix].is_empty() {
                // Input does not carry the axis yet: append it as a new
                // trailing axis and record the occurrence in the mapping.
                let insert_at = self.axes.rank(InOut::In(ix));
                tap = patch.wire_node(
                    format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
                    AxisOp::Add(insert_at),
                    &[tap],
                )?[0];
                new_axis.inputs[ix].push(insert_at);
            }
            taps.push(tap);
        }
        // If the output did not carry the axis, add it to the mapping and
        // remember to strip it again after the einsum.
        let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
            let insert_at = self.axes.rank(InOut::Out(0));
            new_axis.outputs[0].push(insert_at);
            Some(insert_at)
        } else {
            None
        };
        // Rebuild the mapping with the updated axis substituted in.
        let new_expr = self
            .axes
            .iter_all_axes()
            .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
            .collect_vec();
        let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
        let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
        if let Some(position) = must_rm_axis {
            wire = patch.wire_node(
                format!("{}.prop_axis.{}.output", &node.name, repr),
                AxisOp::Rm(position),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

    /// Accumulator types acceptable for this einsum: i32 for integer
    /// operations, f16 or f32 for f16, otherwise the operating type itself.
    pub fn acceptable_accumulators(&self) -> TVec<DatumType> {
        if self.operating_dt.is_integer() {
            tvec!(i32::datum_type())
        } else if self.operating_dt == f16::datum_type() {
            tvec!(f16::datum_type(), f32::datum_type())
        } else {
            tvec!(self.operating_dt)
        }
    }
}
143
impl Debug for EinSum {
    // Compact form: the axes expression followed by the operating type.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
    }
}
149
150impl Op for EinSum {
151    fn name(&self) -> StaticName {
152        "EinSum".into()
153    }
154
155    fn info(&self) -> TractResult<Vec<String>> {
156        let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
157        if let Some(qp) = self.q_params {
158            info.push(format!("Quantized output: {qp:?}"));
159        }
160        Ok(info)
161    }
162
163    op_as_typed_op!();
164}
165
166impl EvalOp for EinSum {
167    fn is_stateless(&self) -> bool {
168        true
169    }
170
171    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
172        let output = if let Some(qp) = self.q_params {
173            eval::eval_q(&self.axes, qp, inputs)
174        } else {
175            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
176        }?;
177        Ok(tvec!(output.into_tvalue()))
178    }
179}
180
181impl TypedOp for EinSum {
182    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
183        let shapes = self.actual_input_shapes_from_facts(inputs)?;
184        for i in 0..inputs.len() {
185            ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
186        }
187        for axis in self.axes.iter_all_axes() {
188            assert!(shapes
189                .iter()
190                .enumerate()
191                .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
192                .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
193                .is_ok());
194        }
195        if let Some(qp) = self.q_params {
196            ensure!(inputs.len() == 9);
197            Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
198        } else {
199            Ok(tvec!(TypedFact::dt_shape(
200                self.operating_dt,
201                eval::output_shape(&self.axes, &shapes)?
202            )))
203        }
204    }
205
206    fn axes_mapping(
207        &self,
208        inputs: &[&TypedFact],
209        _outputs: &[&TypedFact],
210    ) -> TractResult<AxesMapping> {
211        let mut axes = self.axes.clone();
212        for (slot, i) in inputs.iter().enumerate() {
213            if i.datum_type.is_opaque()
214                && (i.opaque_fact().is_some_and(|of| {
215                    of.is::<BlockQuantFact>()
216                        || of.is::<PackedOpaqueFact>()
217                        || of.is::<PackedBlockQuantFact>()
218                }))
219            {
220                axes = axes
221                    .remove_axis_occurency(InOut::In(slot), i.rank())?
222                    .remove_axis_occurency(InOut::In(slot), i.rank())?;
223            }
224        }
225        Ok(axes)
226    }
227
228    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
229        let shapes = self.actual_input_shapes_from_facts(inputs)?;
230        let oshape = eval::output_shape(&self.axes, &shapes)?;
231        let ks = self
232            .axes
233            .iter_all_axes()
234            .filter(|axis| axis.outputs[0].len() == 0)
235            .map(|axis| {
236                axis.inputs
237                    .iter()
238                    .enumerate()
239                    .flat_map(|(ix, axes)| {
240                        axes.iter()
241                            .map(|axis| shapes[ix][*axis].clone())
242                            .collect::<TVec<_>>()
243                            .into_iter()
244                    })
245                    .find(|d| !d.is_one())
246                    .unwrap_or_else(|| 1.to_dim())
247            })
248            .product::<TDim>();
249        Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
250    }
251
252    fn slice(
253        &self,
254        patch: &mut TypedModelPatch,
255        model: &TypedModel,
256        node: &TypedNode,
257        prefix: &str,
258        inputs: &[OutletId],
259        output_axis: usize,
260        _start: &TDim,
261        _end: &TDim,
262    ) -> TractResult<Option<TVec<OutletId>>> {
263        let facts = model.node_input_facts(node.id)?;
264        let axis = self.axes.axis((InOut::Out(0), output_axis))?;
265        if facts
266            .iter()
267            .enumerate()
268            .any(|(slot, fact)| axis.inputs[slot].len() > 0 && fact.datum_type.is_opaque())
269        {
270            Ok(None)
271        } else {
272            patch.wire_node(prefix, self.clone(), inputs).map(Some)
273        }
274    }
275
276    #[allow(unused_variables)]
277    fn change_axes(
278        &self,
279        model: &TypedModel,
280        node: &TypedNode,
281        io: InOut,
282        change: &AxisOp,
283    ) -> TractResult<Option<AxisChangeConsequence>> {
284        let (mut inputs, mut outputs) = self.axes.to_strs();
285        let interface: &mut String = match io {
286            InOut::In(i) => &mut inputs[i],
287            InOut::Out(o) => &mut outputs[o],
288        };
289        let mut axes: Vec<char> = interface.chars().collect();
290        match change {
291            AxisOp::Rm(rm) => {
292                axes.remove(*rm);
293            }
294            AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
295            AxisOp::Move(from, to) => {
296                let c = axes.remove(*from);
297                axes.insert(*to, c);
298            }
299            _ => {
300                return Ok(None);
301            }
302        };
303        *interface = axes.into_iter().collect();
304        let axes = AxesMapping::from_strs(&inputs, &outputs)?;
305        Ok(Some(AxisChangeConsequence {
306            substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
307            wire_changes: tvec!((io, change.clone())),
308        }))
309    }
310
311    fn declutter_with_session(
312        &self,
313        session: &mut crate::optim::OptimizerSession,
314        model: &TypedModel,
315        node: &TypedNode,
316    ) -> TractResult<Option<TypedModelPatch>> {
317        if let Some(patch) = declutter_reshape_folding_input_axis(self, session, model, node)? {
318            return Ok(Some(patch));
319        }
320        if let Some(patch) = declutter_broadcast(self, session, model, node)? {
321            return Ok(Some(patch));
322        }
323        Ok(None)
324    }
325
326    fn codegen(
327        &self,
328        model: &TypedModel,
329        node: &TypedNode,
330    ) -> TractResult<Option<TypedModelPatch>> {
331        if (self.q_params.is_none() && node.inputs.len() != 2)
332            || (self.q_params.is_some() && node.inputs.len() != 9)
333        {
334            return Ok(None);
335        }
336        einsum_matmul::detect_rule(&(), model, node, &node.name, self)
337    }
338
339    as_op!();
340}
341
/// If an input comes from a reshape that folds several axes into one, push
/// the reshape across the einsum: re-expand the folded axis on every input
/// that shares it, run the einsum on the expanded axes, and fold the output
/// back afterwards when the axis appears in the output.
fn declutter_reshape_folding_input_axis(
    op: &EinSum,
    _session: &mut crate::optim::OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (slot, prec) in node.inputs.iter().map(|n| model.node(n.node)).enumerate() {
        // Only consider reshapes folding `from` (several axes) into `to` (one).
        let Some(&AxisOp::Reshape(at, ref from, ref to)) = prec.op_as() else { continue };
        if to.len() > 1 {
            continue;
        }
        let mut axes = op.axes.clone();
        // One fresh label for each additional axis the un-folding introduces.
        let extra_labels = axes.available_labels().take(from.len() - 1).collect_vec();
        // add a temporary input to axes to hold the extra axes
        let extra_input = node.inputs.len();
        axes = axes.with_extra_input(extra_input)?;
        for label in &extra_labels {
            axes = axes.with_extra_axis(*label, InOut::In(extra_input), 0)?;
        }
        let folded_axis = op.axes.axis((InOut::In(slot), at))?;
        // Axis repeated in the output: not handled.
        if folded_axis.outputs[0].len() > 1 {
            return Ok(None);
        };
        let mut patch = TypedModelPatch::default();
        let mut taps = patch.taps(model, &node.inputs)?;
        for (input, tap) in taps.iter_mut().enumerate() {
            if folded_axis.inputs[input].len() == 0 {
                continue;
            };
            // Axis repeated inside one input: not handled.
            if folded_axis.inputs[input].len() > 1 {
                return Ok(None);
            };
            let pos = folded_axis.inputs[input][0];
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::In(input), pos)?;
            }
            // Un-fold the axis on this input (inverse of the upstream reshape).
            *tap = patch.wire_node(
                format!("{}.reshape_folded_input_{}", node.name, input),
                AxisOp::Reshape(pos, to.clone(), from.clone()),
                &[*tap],
            )?[0];
        }
        if folded_axis.outputs[0].len() == 1 {
            let pos = folded_axis.outputs[0][0];
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::Out(0), pos)?;
            }
        }
        // Drop the temporary input slot used to host the extra labels.
        axes = axes.remove_slot(InOut::In(extra_input))?;
        let mut wire = patch.wire_node(&node.name, EinSum { axes, ..op.clone() }, &taps)?;
        if folded_axis.outputs[0].len() == 1 {
            // Fold the expanded axes back on the output side.
            let pos = folded_axis.outputs[0][0];
            wire = patch.wire_node(
                format!("{}.reshape_folded_output", node.name),
                AxisOp::Reshape(pos, from.clone(), to.clone()),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        return Ok(Some(patch));
    }
    Ok(None)
}
405
406fn declutter_broadcast(
407    op: &EinSum,
408    _session: &mut crate::optim::OptimizerSession,
409    model: &TypedModel,
410    node: &TypedNode,
411) -> TractResult<Option<TypedModelPatch>> {
412    for (ix, outlet) in node.inputs.iter().enumerate() {
413        let prec = model.node(outlet.node);
414        if prec.op_is::<MultiBroadcastTo>() && prec.outputs[0].successors.len() == 1 {
415            let mut patch = TypedModelPatch::default();
416            let mut wires = patch.taps(model, &node.inputs)?;
417            wires[ix] = patch.tap_model(model, prec.inputs[0])?;
418            let wire = patch.wire_node(&node.name, op.clone(), &wires)?[0];
419            patch.shunt_outside(model, node.id.into(), wire)?;
420            return Ok(Some(patch));
421        }
422    }
423    Ok(None)
424}