//! EinSum operator implementation (tract_core/ops/einsum/mod.rs).

1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::ops::array::MultiBroadcastTo;
6use crate::tract_data::itertools::Itertools;
7
8mod eval;
9
10#[cfg(feature = "blas")]
11pub mod as_blas;
12pub mod einsum_matmul;
13pub mod kernel_selection;
14pub mod prefix_matmul;
15
16#[cfg(test)]
17mod proptest;
18
19use num_traits::One;
20use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact};
21use tract_linalg::mmm::PackedOpaqueFact;
22
23pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<'_, [TDim]>> {
24    if !fact.datum_type.is_opaque() {
25        return Ok(Cow::Borrowed(&*fact.shape));
26    }
27    let Some(opaque_fact) = fact.opaque_fact() else {
28        bail!("Datum fact is opaque, but no opaque fact was found.")
29    };
30    if let Some(bqf) = opaque_fact.downcast_ref::<BlockQuantFact>() {
31        Ok(Cow::Owned(
32            fact.shape.iter().cloned().chain(bqf.shape().iter().map(|d| d.to_dim())).collect_vec(),
33        ))
34    } else if let Some(pof) = opaque_fact.downcast_ref::<PackedBlockQuantFact>() {
35        Ok(Cow::Owned(
36            fact.shape.iter().cloned().chain(pof.shape.iter().map(|i| i.to_dim())).collect_vec(),
37        ))
38    } else if let Some(pof) = opaque_fact.downcast_ref::<PackedOpaqueFact>() {
39        Ok(Cow::Owned(
40            fact.shape.iter().cloned().chain([pof.mn.clone(), pof.k.to_dim()]).collect_vec(),
41        ))
42    } else {
43        bail!("Unsupported opaque fact {opaque_fact:?}")
44    }
45}
46
/// Generalized Einstein-summation operator.
#[derive(Clone, Hash, PartialEq)]
pub struct EinSum {
    /// Mapping of each input's axes and the output's axes onto shared labels.
    pub axes: AxesMapping,
    /// Datum type used for the internal computation (and for the output in the
    /// non-quantized case).
    pub operating_dt: DatumType,
    // if present, assume we're a binary op.
    // 9 inputs are: A,B,bias, A0,Ascale, B0,BScale, C0,Cscale
    /// Output datum type of the quantized form, when present.
    pub q_params: Option<DatumType>,
}
55
impl EinSum {
    /// Builds a plain (non-quantized) einsum.
    pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: None }
    }

    /// Builds a quantized einsum: `output_type` becomes the output datum type,
    /// and the op then expects the 9-input convention documented on `q_params`.
    pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: Some(output_type) }
    }

    /// Resolves the einsum-visible shape of every input, unfolding axes hidden
    /// inside opaque (block-quantized / packed) payloads.
    ///
    /// Checks that the input count and each input's rank agree with the axes
    /// mapping.
    pub fn actual_input_shapes_from_facts<'m>(
        &self,
        inputs: &'m [impl Borrow<TypedFact>],
    ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
        ensure!(inputs.len() == self.axes.input_count());
        let shapes: TVec<Cow<[TDim]>> = inputs
            .iter()
            .map(|t| block_quant_aware_input_shape(t.borrow()))
            .collect::<TractResult<_>>()?;
        ensure!(
            shapes.iter().enumerate().all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix)))
        );
        Ok(shapes)
    }

    /// Builds a patch that makes the axis labelled by `(io, axis)` present on
    /// every interface of the einsum: missing on an input, an `AxisOp::Add` is
    /// wired before it; missing on the output, the new op grows the axis and an
    /// `AxisOp::Rm` is wired after. Returns `None` when the axis appears more
    /// than once on some input (not handled).
    // NOTE(review): all four parameters are used below; this allow looks
    // unnecessary — confirm before removing.
    #[allow(unused_variables)]
    pub(crate) fn propagate_axis(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        axis: usize,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut new_axis = self.axes.axis((io, axis))?.clone();
        let repr = new_axis.repr;
        let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
        let mut taps = tvec!();
        for (ix, input) in node.inputs.iter().enumerate() {
            let mut tap = patch.tap_model(model, *input)?;
            if new_axis.inputs[ix].len() > 1 {
                return Ok(None); // FIXME maybe
            } else if new_axis.inputs[ix].is_empty() {
                // Axis absent from this input: append it (as a 1-sized axis) at
                // the end, and record the new position in the mapping.
                let insert_at = self.axes.rank(InOut::In(ix));
                tap = patch.wire_node(
                    format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
                    AxisOp::Add(insert_at),
                    &[tap],
                )?[0];
                new_axis.inputs[ix].push(insert_at);
            }
            taps.push(tap);
        }
        // If the axis was absent from the output, the patched einsum produces it
        // at the last position; remember to remove it after the op.
        let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
            let insert_at = self.axes.rank(InOut::Out(0));
            new_axis.outputs[0].push(insert_at);
            Some(insert_at)
        } else {
            None
        };
        // Rebuild the mapping with the updated axis substituted in.
        let new_expr = self
            .axes
            .iter_all_axes()
            .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
            .collect_vec();
        let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
        let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
        if let Some(position) = must_rm_axis {
            wire = patch.wire_node(
                format!("{}.prop_axis.{}.output", &node.name, repr),
                AxisOp::Rm(position),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

    /// Accumulator datum types acceptable for this op: i32 for integer ops,
    /// f16 or f32 for f16 ops, otherwise the operating type itself.
    pub fn acceptable_accumulators(&self) -> TVec<DatumType> {
        if self.operating_dt.is_integer() {
            tvec!(i32::datum_type())
        } else if self.operating_dt == f16::datum_type() {
            tvec!(f16::datum_type(), f32::datum_type())
        } else {
            tvec!(self.operating_dt)
        }
    }
}
142
143impl Debug for EinSum {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
146    }
147}
148
149impl Op for EinSum {
150    fn name(&self) -> StaticName {
151        "EinSum".into()
152    }
153
154    fn info(&self) -> TractResult<Vec<String>> {
155        let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
156        if let Some(qp) = self.q_params {
157            info.push(format!("Quantized output: {qp:?}"));
158        }
159        Ok(info)
160    }
161
162    op_as_typed_op!();
163}
164
165impl EvalOp for EinSum {
166    fn is_stateless(&self) -> bool {
167        true
168    }
169
170    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
171        if inputs.iter().all(|i| i.datum_type().is_number()) {
172            let mut adhoc_model = TypedModel::default();
173            let mut wires = tvec!();
174            for (ix, input) in inputs.iter().enumerate() {
175                let fact = TypedFact::shape_and_dt_of(input);
176                let wire = adhoc_model.add_source(format!("input.{ix}"), fact)?;
177                wires.push(wire);
178            }
179            let output = adhoc_model.wire_node("einsum", self.clone(), &wires)?;
180            adhoc_model.set_output_outlets(&output)?;
181            let opti = adhoc_model.into_optimized()?;
182            if opti.nodes.iter().all(|node| !node.op_is::<Self>()) {
183                return opti.into_runnable()?.run(inputs);
184            }
185        }
186
187        let output = if let Some(qp) = self.q_params {
188            eval::eval_q(&self.axes, qp, inputs)
189        } else {
190            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
191        }?;
192        Ok(tvec!(output.into_tvalue()))
193    }
194}
195
196impl TypedOp for EinSum {
197    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
198        let shapes = self.actual_input_shapes_from_facts(inputs)?;
199        for i in 0..inputs.len() {
200            ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
201        }
202        for axis in self.axes.iter_all_axes() {
203            assert!(
204                shapes
205                    .iter()
206                    .enumerate()
207                    .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
208                    .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
209                    .is_ok()
210            );
211        }
212        if let Some(qp) = self.q_params {
213            ensure!(inputs.len() == 9);
214            Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
215        } else {
216            Ok(tvec!(TypedFact::dt_shape(
217                self.operating_dt,
218                eval::output_shape(&self.axes, &shapes)?
219            )))
220        }
221    }
222
223    fn axes_mapping(
224        &self,
225        inputs: &[&TypedFact],
226        _outputs: &[&TypedFact],
227    ) -> TractResult<AxesMapping> {
228        let mut axes = self.axes.clone();
229        for (slot, i) in inputs.iter().enumerate() {
230            if i.datum_type.is_opaque()
231                && (i.opaque_fact().is_some_and(|of| {
232                    of.is::<BlockQuantFact>()
233                        || of.is::<PackedOpaqueFact>()
234                        || of.is::<PackedBlockQuantFact>()
235                }))
236            {
237                axes = axes
238                    .remove_axis_occurency(InOut::In(slot), i.rank())?
239                    .remove_axis_occurency(InOut::In(slot), i.rank())?;
240            }
241        }
242        Ok(axes)
243    }
244
245    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
246        let shapes = self.actual_input_shapes_from_facts(inputs)?;
247        let oshape = eval::output_shape(&self.axes, &shapes)?;
248        let ks = self
249            .axes
250            .iter_all_axes()
251            .filter(|axis| axis.outputs[0].len() == 0)
252            .map(|axis| {
253                axis.inputs
254                    .iter()
255                    .enumerate()
256                    .flat_map(|(ix, axes)| {
257                        axes.iter()
258                            .map(|axis| shapes[ix][*axis].clone())
259                            .collect::<TVec<_>>()
260                            .into_iter()
261                    })
262                    .find(|d| !d.is_one())
263                    .unwrap_or_else(|| 1.to_dim())
264            })
265            .product::<TDim>();
266        Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
267    }
268
269    fn slice(
270        &self,
271        patch: &mut TypedModelPatch,
272        model: &TypedModel,
273        node: &TypedNode,
274        prefix: &str,
275        inputs: &[OutletId],
276        output_axis: usize,
277        _start: &TDim,
278        _end: &TDim,
279    ) -> TractResult<Option<TVec<OutletId>>> {
280        let facts = model.node_input_facts(node.id)?;
281        let axis = self.axes.axis((InOut::Out(0), output_axis))?;
282        if facts
283            .iter()
284            .enumerate()
285            .any(|(slot, fact)| axis.inputs[slot].len() > 0 && fact.datum_type.is_opaque())
286        {
287            Ok(None)
288        } else {
289            patch.wire_node(prefix, self.clone(), inputs).map(Some)
290        }
291    }
292
293    #[allow(unused_variables)]
294    fn change_axes(
295        &self,
296        model: &TypedModel,
297        node: &TypedNode,
298        io: InOut,
299        change: &AxisOp,
300    ) -> TractResult<Option<AxisChangeConsequence>> {
301        let (mut inputs, mut outputs) = self.axes.to_strs();
302        let interface: &mut String = match io {
303            InOut::In(i) => &mut inputs[i],
304            InOut::Out(o) => &mut outputs[o],
305        };
306        let mut axes: Vec<char> = interface.chars().collect();
307        match change {
308            AxisOp::Rm(rm) => {
309                axes.remove(*rm);
310            }
311            AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
312            AxisOp::Move(from, to) => {
313                let c = axes.remove(*from);
314                axes.insert(*to, c);
315            }
316            _ => {
317                return Ok(None);
318            }
319        };
320        *interface = axes.into_iter().collect();
321        let axes = AxesMapping::from_strs(&inputs, &outputs)?;
322        Ok(Some(AxisChangeConsequence {
323            substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
324            wire_changes: tvec!((io, change.clone())),
325        }))
326    }
327
328    fn declutter_with_session(
329        &self,
330        session: &mut crate::optim::OptimizerSession,
331        model: &TypedModel,
332        node: &TypedNode,
333    ) -> TractResult<Option<TypedModelPatch>> {
334        if let Some(patch) = declutter_reshape_folding_input_axis(self, session, model, node)? {
335            return Ok(Some(patch));
336        }
337        if let Some(patch) = declutter_broadcast(self, session, model, node)? {
338            return Ok(Some(patch));
339        }
340        Ok(None)
341    }
342
343    fn codegen(
344        &self,
345        model: &TypedModel,
346        node: &TypedNode,
347    ) -> TractResult<Option<TypedModelPatch>> {
348        rule_if!(
349            (self.q_params.is_none() && node.inputs.len() == 2)
350                || (self.q_params.is_some() && node.inputs.len() == 9)
351        );
352        einsum_matmul::detect_rule(&(), model, node, &node.name, self)
353    }
354
355    as_op!();
356}
357
/// Declutter rule: when an input comes from an `AxisOp::Reshape` that folds
/// several axes into one, absorb the fold into the einsum expression — the
/// einsum is rewritten over the unfolded axes, the input reshape is inverted,
/// and (when the folded axis reaches the output) the fold is re-applied after
/// the op.
fn declutter_reshape_folding_input_axis(
    op: &EinSum,
    _session: &mut crate::optim::OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (slot, prec) in node.inputs.iter().map(|n| model.node(n.node)).enumerate() {
        // Only reshapes collapsing `from` axes into a single `to` axis qualify.
        let Some(&AxisOp::Reshape(at, ref from, ref to)) = prec.op_as() else { continue };
        if to.len() > 1 {
            continue;
        }
        let mut axes = op.axes.clone();
        // One fresh label per extra axis the unfold introduces.
        // NOTE(review): assumes `from` is non-empty — `from.len() - 1` would
        // underflow otherwise; confirm Reshape never folds zero axes.
        let extra_labels = axes.available_labels().take(from.len() - 1).collect_vec();
        // add a temporary input to axes to hold the extra axes
        let extra_input = node.inputs.len();
        axes = axes.with_extra_input(extra_input)?;
        for label in &extra_labels {
            axes = axes.with_extra_axis(*label, InOut::In(extra_input), 0)?;
        }
        let folded_axis = op.axes.axis((InOut::In(slot), at))?;
        if folded_axis.outputs[0].len() > 1 {
            return Ok(None);
        };
        let mut patch = TypedModelPatch::default();
        let mut taps = patch.taps(model, &node.inputs)?;
        for (input, tap) in taps.iter_mut().enumerate() {
            if folded_axis.inputs[input].len() == 0 {
                continue;
            };
            if folded_axis.inputs[input].len() > 1 {
                return Ok(None);
            };
            let pos = folded_axis.inputs[input][0];
            // Register the unfolded axes on this input, then wire the inverse
            // reshape (to -> from) so the input arrives unfolded.
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::In(input), pos)?;
            }
            *tap = patch.wire_node(
                format!("{}.reshape_folded_input_{}", node.name, input),
                AxisOp::Reshape(pos, to.clone(), from.clone()),
                &[*tap],
            )?[0];
        }
        if folded_axis.outputs[0].len() == 1 {
            // The folded axis reaches the output: register the unfolded axes
            // there too; the fold is re-applied after the op below.
            let pos = folded_axis.outputs[0][0];
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::Out(0), pos)?;
            }
        }
        // The temporary input has served its purpose; drop it.
        axes = axes.remove_slot(InOut::In(extra_input))?;
        let mut wire = patch.wire_node(&node.name, EinSum { axes, ..op.clone() }, &taps)?;
        if folded_axis.outputs[0].len() == 1 {
            let pos = folded_axis.outputs[0][0];
            wire = patch.wire_node(
                format!("{}.reshape_folded_output", node.name),
                AxisOp::Reshape(pos, from.clone(), to.clone()),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        return Ok(Some(patch));
    }
    Ok(None)
}
421
422fn declutter_broadcast(
423    op: &EinSum,
424    _session: &mut crate::optim::OptimizerSession,
425    model: &TypedModel,
426    node: &TypedNode,
427) -> TractResult<Option<TypedModelPatch>> {
428    for (ix, outlet) in node.inputs.iter().enumerate() {
429        let prec = model.node(outlet.node);
430        if prec.op_is::<MultiBroadcastTo>() && prec.outputs[0].successors.len() == 1 {
431            let mut patch = TypedModelPatch::default();
432            let mut wires = patch.taps(model, &node.inputs)?;
433            wires[ix] = patch.tap_model(model, prec.inputs[0])?;
434            let wire = patch.wire_node(&node.name, op.clone(), &wires)?[0];
435            patch.shunt_outside(model, node.id.into(), wire)?;
436            return Ok(Some(patch));
437        }
438    }
439    Ok(None)
440}