// tract_core/ops/einsum/mod.rs

use std::borrow::Borrow;
use std::fmt::Debug;

use crate::internal::*;
use crate::ops::array::MultiBroadcastTo;
use crate::tract_data::itertools::Itertools;

// Generic einsum evaluation (reference implementations).
mod eval;

#[cfg(feature = "blas")]
pub mod as_blas;
pub mod einsum_matmul;
pub mod kernel_selection;
pub mod prefix_matmul;

#[cfg(test)]
mod proptest;

use num_traits::One;
use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact};
use tract_linalg::mmm::PackedExoticFact;
22
23pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<'_, [TDim]>> {
24    if fact.is_plain() {
25        return Ok(Cow::Borrowed(&*fact.shape));
26    }
27    let Some(exotic_fact) = fact.exotic_fact() else {
28        bail!("Datum fact is exotic, but no exotic fact was found.")
29    };
30    if let Some(_bqf) = exotic_fact.downcast_ref::<BlockQuantFact>() {
31        Ok(Cow::Borrowed(&*fact.shape))
32    } else if let Some(pof) = exotic_fact.downcast_ref::<PackedBlockQuantFact>() {
33        Ok(Cow::Owned(
34            fact.shape.iter().cloned().chain(pof.shape.iter().map(|i| i.to_dim())).collect_vec(),
35        ))
36    } else if let Some(pof) = exotic_fact.downcast_ref::<PackedExoticFact>() {
37        Ok(Cow::Owned(
38            fact.shape.iter().cloned().chain([pof.mn.clone(), pof.k.to_dim()]).collect_vec(),
39        ))
40    } else {
41        bail!("Unsupported exotic fact {exotic_fact:?}")
42    }
43}
44
/// Einstein-summation operator.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct EinSum {
    /// Mapping of axes between the inputs and the single output.
    pub axes: AxesMapping,
    /// Datum type in which the accumulation is performed.
    pub operating_dt: DatumType,
    // if present, assume we're a binary op.
    // 9 inputs are: A,B,bias, A0,Ascale, B0,BScale, C0,Cscale
    pub q_params: Option<DatumType>,
}
53
impl EinSum {
    /// Builds a plain (non-quantized) einsum.
    pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: None }
    }

    /// Builds a quantized einsum; `output_type` is stored as `q_params`
    /// (the 9-input convention documented on the struct then applies).
    pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: Some(output_type) }
    }

    /// Resolves the effective (block-quant aware) shape of every input fact
    /// and checks each rank against the axes mapping.
    pub fn actual_input_shapes_from_facts<'m>(
        &self,
        inputs: &'m [impl Borrow<TypedFact>],
    ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
        ensure!(inputs.len() == self.axes.input_count());
        let shapes: TVec<Cow<[TDim]>> = inputs
            .iter()
            .map(|t| block_quant_aware_input_shape(t.borrow()))
            .collect::<TractResult<_>>()?;
        ensure!(
            shapes.iter().enumerate().all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix)))
        );
        Ok(shapes)
    }

    /// Rewrites the node so that the axis designated by `(io, axis)` is
    /// present on every input and on the output: missing occurrences are
    /// materialized with `AxisOp::Add` on the inputs, and if the axis was
    /// absent from the output, it is added to the mapping then removed again
    /// with a trailing `AxisOp::Rm`. Returns `Ok(None)` when the axis appears
    /// more than once on some input (case not handled).
    #[allow(unused_variables)]
    pub(crate) fn propagate_axis(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        axis: usize,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut new_axis = self.axes.axis((io, axis))?.clone();
        let repr = new_axis.repr;
        let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
        let mut taps = tvec!();
        for (ix, input) in node.inputs.iter().enumerate() {
            let mut tap = patch.tap_model(model, *input)?;
            if new_axis.inputs[ix].len() > 1 {
                return Ok(None); // FIXME maybe
            } else if new_axis.inputs[ix].is_empty() {
                // Axis absent from this input: append it as a trailing unit axis.
                let insert_at = self.axes.rank(InOut::In(ix));
                tap = patch.wire_node(
                    format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
                    AxisOp::Add(insert_at),
                    &[tap],
                )?[0];
                new_axis.inputs[ix].push(insert_at);
            }
            taps.push(tap);
        }
        // If the axis does not reach the output, add it to the mapping now and
        // remember to strip it from the wired output below.
        let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
            let insert_at = self.axes.rank(InOut::Out(0));
            new_axis.outputs[0].push(insert_at);
            Some(insert_at)
        } else {
            None
        };
        // Rebuild the mapping with the updated axis substituted in place.
        let new_expr = self
            .axes
            .iter_all_axes()
            .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
            .collect_vec();
        let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
        let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
        if let Some(position) = must_rm_axis {
            wire = patch.wire_node(
                format!("{}.prop_axis.{}.output", &node.name, repr),
                AxisOp::Rm(position),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

    /// Accumulator types acceptable for this einsum: i32 for integer
    /// operating types, f16 or f32 for f16, otherwise the operating type
    /// itself.
    pub fn acceptable_accumulators(&self) -> TVec<DatumType> {
        if self.operating_dt.is_integer() {
            tvec!(i32::datum_type())
        } else if self.operating_dt == f16::datum_type() {
            tvec!(f16::datum_type(), f32::datum_type())
        } else {
            tvec!(self.operating_dt)
        }
    }
}
140
141impl Debug for EinSum {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
144    }
145}
146
147impl Op for EinSum {
148    fn name(&self) -> StaticName {
149        "EinSum".into()
150    }
151
152    fn info(&self) -> TractResult<Vec<String>> {
153        let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
154        if let Some(qp) = self.q_params {
155            info.push(format!("Quantized output: {qp:?}"));
156        }
157        Ok(info)
158    }
159
160    op_as_typed_op!();
161}
162
impl EvalOp for EinSum {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Fast path: wrap this single einsum in a throw-away model and let the
        // optimizer try to replace it (e.g. with an optimized matmul). Only
        // attempted for plain numeric inputs.
        if inputs.iter().all(|i| i.datum_type().is_number() && i.is_plain()) {
            let mut adhoc_model = TypedModel::default();
            let mut wires = tvec!();
            for (ix, input) in inputs.iter().enumerate() {
                let fact = TypedFact::shape_and_dt_of(input);
                let wire = adhoc_model.add_source(format!("input.{ix}"), fact)?;
                wires.push(wire);
            }
            let output = adhoc_model.wire_node("einsum", self.clone(), &wires)?;
            adhoc_model.select_output_outlets(&output)?;
            let opti = adhoc_model.into_optimized()?;
            // Only use the optimized model if the EinSum was fully eliminated,
            // otherwise we would recurse into this same eval.
            if opti.nodes.iter().all(|node| !node.op_is::<Self>()) {
                return opti.into_runnable()?.run(inputs);
            }
        }

        // Slow path: generic reference evaluation (quantized or not).
        let output = if let Some(qp) = self.q_params {
            eval::eval_q(&self.axes, qp, inputs)
        } else {
            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
        }?;
        Ok(tvec!(output.into_tvalue()))
    }
}
193
impl TypedOp for EinSum {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let shapes = self.actual_input_shapes_from_facts(inputs)?;
        for i in 0..inputs.len() {
            ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
        }
        // All occurrences of a given axis across inputs must be broadcastable
        // together (invariant violation here is a bug, hence assert).
        for axis in self.axes.iter_all_axes() {
            assert!(
                shapes
                    .iter()
                    .enumerate()
                    .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
                    .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
                    .is_ok()
            );
        }
        if let Some(qp) = self.q_params {
            // Quantized form: 9 inputs (A,B,bias + zero-points/scales); the
            // output shape is derived from A and B only.
            ensure!(inputs.len() == 9);
            Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
        } else {
            Ok(tvec!(TypedFact::dt_shape(
                self.operating_dt,
                eval::output_shape(&self.axes, &shapes)?
            )))
        }
    }

    fn input_roi(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TVec<Option<TDim>>>> {
        crate::optim::propagate_roi::bubble_roi(model, node)
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut axes = self.axes.clone();
        // Pre-packed inputs carry two extra trailing packing dimensions that
        // are not real einsum axes: drop their two occurrences from the map.
        for (slot, i) in inputs.iter().enumerate() {
            if i.is_exotic()
                && (i.exotic_fact().is_some_and(|of| {
                    of.is::<PackedExoticFact>() || of.is::<PackedBlockQuantFact>()
                }))
            {
                axes = axes
                    .remove_axis_occurency(InOut::In(slot), i.rank())?
                    .remove_axis_occurency(InOut::In(slot), i.rank())?;
            }
        }
        Ok(axes)
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shapes = self.actual_input_shapes_from_facts(inputs)?;
        let oshape = eval::output_shape(&self.axes, &shapes)?;
        // FMA count = output volume x product of the reduced ("k") axes, i.e.
        // axes that appear in inputs but not in the output.
        let ks = self
            .axes
            .iter_all_axes()
            .filter(|axis| axis.outputs[0].len() == 0)
            .map(|axis| {
                axis.inputs
                    .iter()
                    .enumerate()
                    .flat_map(|(ix, axes)| {
                        axes.iter()
                            .map(|axis| shapes[ix][*axis].clone())
                            .collect::<TVec<_>>()
                            .into_iter()
                    })
                    .find(|d| !d.is_one())
                    .unwrap_or_else(|| 1.to_dim())
            })
            .product::<TDim>();
        Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
    }

    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        model: &TypedModel,
        node: &TypedNode,
        prefix: &str,
        inputs: &[OutletId],
        output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        let facts = model.node_input_facts(node.id)?;
        let axis = self.axes.axis((InOut::Out(0), output_axis))?;
        // Refuse to slice through an axis that touches an exotic (packed)
        // input; otherwise the einsum can just be rewired on sliced inputs.
        if facts
            .iter()
            .enumerate()
            .any(|(slot, fact)| axis.inputs[slot].len() > 0 && fact.is_exotic())
        {
            Ok(None)
        } else {
            patch.wire_node(prefix, self.clone(), inputs).map(Some)
        }
    }

    #[allow(unused_variables)]
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        // Apply the axis change on the string form of the interface touched
        // by `change`, then rebuild the mapping from the edited strings.
        let (mut inputs, mut outputs) = self.axes.to_strs();
        let interface: &mut String = match io {
            InOut::In(i) => &mut inputs[i],
            InOut::Out(o) => &mut outputs[o],
        };
        let mut axes: Vec<char> = interface.chars().collect();
        match change {
            AxisOp::Rm(rm) => {
                axes.remove(*rm);
            }
            AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
            AxisOp::Move(from, to) => {
                let c = axes.remove(*from);
                axes.insert(*to, c);
            }
            _ => {
                // Reshape and friends are not handled here.
                return Ok(None);
            }
        };
        *interface = axes.into_iter().collect();
        let axes = AxesMapping::from_strs(&inputs, &outputs)?;
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
            wire_changes: tvec!((io, change.clone())),
        }))
    }

    fn declutter_with_session(
        &self,
        session: &mut crate::optim::OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some(patch) = declutter_reshape_folding_input_axis(self, session, model, node)? {
            return Ok(Some(patch));
        }
        if let Some(patch) = declutter_broadcast(self, session, model, node)? {
            return Ok(Some(patch));
        }
        Ok(None)
    }

    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Only binary (2 inputs) or quantized (9 inputs) forms reach the
        // matmul detection rule.
        rule_if!(
            (self.q_params.is_none() && node.inputs.len() == 2)
                || (self.q_params.is_some() && node.inputs.len() == 9)
        );
        einsum_matmul::detect_rule(&(), model, node, &node.name, self)
    }

    as_op!();
}
361
/// Declutter rule: when an input comes from an `AxisOp::Reshape` that folds
/// several axes into one, move the reshape across the einsum — re-expand the
/// folded axis on every input where it occurs (and re-fold it on the output),
/// so the reshape can eventually be eliminated. Bails out (Ok(None)) when the
/// folded axis occurs more than once on an interface.
fn declutter_reshape_folding_input_axis(
    op: &EinSum,
    _session: &mut crate::optim::OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (slot, prec) in node.inputs.iter().map(|n| model.node(n.node)).enumerate() {
        // Only consider reshapes that collapse `from` (several dims) into a
        // single dim (`to.len() == 1`).
        let Some(&AxisOp::Reshape(at, ref from, ref to)) = prec.op_as() else { continue };
        if to.len() > 1 {
            continue;
        }
        let mut axes = op.axes.clone();
        // One fresh label per extra axis created by re-expanding the fold.
        let extra_labels = axes.available_labels().take(from.len() - 1).collect_vec();
        // add a temporary input to axes to hold the extra axes
        let extra_input = node.inputs.len();
        axes = axes.with_extra_input(extra_input)?;
        for label in &extra_labels {
            axes = axes.with_extra_axis(*label, InOut::In(extra_input), 0)?;
        }
        let folded_axis = op.axes.axis((InOut::In(slot), at))?;
        if folded_axis.outputs[0].len() > 1 {
            return Ok(None);
        };
        let mut patch = TypedModelPatch::default();
        let mut taps = patch.taps(model, &node.inputs)?;
        for (input, tap) in taps.iter_mut().enumerate() {
            if folded_axis.inputs[input].len() == 0 {
                continue;
            };
            if folded_axis.inputs[input].len() > 1 {
                return Ok(None);
            };
            let pos = folded_axis.inputs[input][0];
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::In(input), pos)?;
            }
            // Un-fold this input: reshape the single folded dim back to `from`.
            *tap = patch.wire_node(
                format!("{}.reshape_folded_input_{}", node.name, input),
                AxisOp::Reshape(pos, to.clone(), from.clone()),
                &[*tap],
            )?[0];
        }
        if folded_axis.outputs[0].len() == 1 {
            let pos = folded_axis.outputs[0][0];
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::Out(0), pos)?;
            }
        }
        // The temporary input has served its purpose of hosting the labels.
        axes = axes.remove_slot(InOut::In(extra_input))?;
        let mut wire = patch.wire_node(&node.name, EinSum { axes, ..op.clone() }, &taps)?;
        if folded_axis.outputs[0].len() == 1 {
            // Re-fold the expanded axes on the output to preserve its shape.
            let pos = folded_axis.outputs[0][0];
            wire = patch.wire_node(
                format!("{}.reshape_folded_output", node.name),
                AxisOp::Reshape(pos, from.clone(), to.clone()),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        return Ok(Some(patch));
    }
    Ok(None)
}
425
426fn declutter_broadcast(
427    op: &EinSum,
428    _session: &mut crate::optim::OptimizerSession,
429    model: &TypedModel,
430    node: &TypedNode,
431) -> TractResult<Option<TypedModelPatch>> {
432    for (ix, outlet) in node.inputs.iter().enumerate() {
433        let prec = model.node(outlet.node);
434        if prec.op_is::<MultiBroadcastTo>() && prec.outputs[0].successors.len() == 1 {
435            let mut patch = TypedModelPatch::default();
436            let mut wires = patch.taps(model, &node.inputs)?;
437            wires[ix] = patch.tap_model(model, prec.inputs[0])?;
438            let wire = patch.wire_node(&node.name, op.clone(), &wires)?[0];
439            patch.shunt_outside(model, node.id.into(), wire)?;
440            return Ok(Some(patch));
441        }
442    }
443    Ok(None)
444}