1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::ops::array::MultiBroadcastTo;
6use crate::tract_data::itertools::Itertools;
7
8mod eval;
9
10pub mod einsum_matmul;
11pub mod kernel_selection;
12pub mod prefix_matmul;
13
14#[cfg(test)]
15mod proptest;
16
17use num_traits::One;
18use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact};
19use tract_linalg::mmm::PackedExoticFact;
20
21pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<'_, [TDim]>> {
22 if fact.is_plain() {
23 return Ok(Cow::Borrowed(&*fact.shape));
24 }
25 let Some(exotic_fact) = fact.exotic_fact() else {
26 bail!("Datum fact is exotic, but no exotic fact was found.")
27 };
28 if let Some(_bqf) = exotic_fact.downcast_ref::<BlockQuantFact>() {
29 Ok(Cow::Borrowed(&*fact.shape))
30 } else if let Some(pof) = exotic_fact.downcast_ref::<PackedBlockQuantFact>() {
31 Ok(Cow::Owned(
32 fact.shape.iter().cloned().chain(pof.shape.iter().map(|i| i.to_dim())).collect_vec(),
33 ))
34 } else if let Some(pof) = exotic_fact.downcast_ref::<PackedExoticFact>() {
35 Ok(Cow::Owned(
36 fact.shape.iter().cloned().chain([pof.mn.clone(), pof.k.to_dim()]).collect_vec(),
37 ))
38 } else {
39 bail!("Unsupported exotic fact {exotic_fact:?}")
40 }
41}
42
/// Einstein-summation operator.
///
/// The computation is described by an [`AxesMapping`] that places every axis label in the
/// inputs and in the single output.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct EinSum {
    /// Axis labels and their positions in each input and in the output.
    pub axes: AxesMapping,
    /// Datum type the computation runs in; also the output datum type when not quantized.
    pub operating_dt: DatumType,
    /// When `Some`, the op is in its quantized form (9 inputs) and this is the output datum
    /// type.
    pub q_params: Option<DatumType>,
}
51
52impl EinSum {
53 pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
54 EinSum { axes, operating_dt, q_params: None }
55 }
56
57 pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
58 EinSum { axes, operating_dt, q_params: Some(output_type) }
59 }
60
61 pub fn actual_input_shapes_from_facts<'m>(
62 &self,
63 inputs: &'m [impl Borrow<TypedFact>],
64 ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
65 ensure!(inputs.len() == self.axes.input_count());
66 let shapes: TVec<Cow<[TDim]>> = inputs
67 .iter()
68 .map(|t| block_quant_aware_input_shape(t.borrow()))
69 .collect::<TractResult<_>>()?;
70 ensure!(
71 shapes.iter().enumerate().all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix)))
72 );
73 Ok(shapes)
74 }
75
76 #[allow(unused_variables)]
77 pub(crate) fn propagate_axis(
78 &self,
79 model: &TypedModel,
80 node: &TypedNode,
81 io: InOut,
82 axis: usize,
83 ) -> TractResult<Option<TypedModelPatch>> {
84 let mut new_axis = self.axes.axis((io, axis))?.clone();
85 let repr = new_axis.repr;
86 let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
87 let mut taps = tvec!();
88 for (ix, input) in node.inputs.iter().enumerate() {
89 let mut tap = patch.tap_model(model, *input)?;
90 rule_if!(new_axis.inputs[ix].len() <= 1); if new_axis.inputs[ix].is_empty() {
92 let insert_at = self.axes.rank(InOut::In(ix));
93 tap = patch.wire_node(
94 format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
95 AxisOp::Add(insert_at),
96 &[tap],
97 )?[0];
98 new_axis.inputs[ix].push(insert_at);
99 }
100 taps.push(tap);
101 }
102 let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
103 let insert_at = self.axes.rank(InOut::Out(0));
104 new_axis.outputs[0].push(insert_at);
105 Some(insert_at)
106 } else {
107 None
108 };
109 let new_expr = self
110 .axes
111 .iter_all_axes()
112 .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
113 .collect_vec();
114 let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
115 let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
116 if let Some(position) = must_rm_axis {
117 wire = patch.wire_node(
118 format!("{}.prop_axis.{}.output", &node.name, repr),
119 AxisOp::Rm(position),
120 &wire,
121 )?;
122 }
123 patch.shunt_outside(model, node.id.into(), wire[0])?;
124 Ok(Some(patch))
125 }
126
127 pub fn acceptable_accumulators(&self) -> TVec<DatumType> {
128 if self.operating_dt.is_integer() {
129 tvec!(i32::datum_type())
130 } else if self.operating_dt == f16::datum_type() {
131 tvec!(f16::datum_type(), f32::datum_type())
132 } else {
133 tvec!(self.operating_dt)
134 }
135 }
136}
137
138impl Debug for EinSum {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
141 }
142}
143
144impl Op for EinSum {
145 fn name(&self) -> StaticName {
146 "EinSum".into()
147 }
148
149 fn info(&self) -> TractResult<Vec<String>> {
150 let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
151 if let Some(qp) = self.q_params {
152 info.push(format!("Quantized output: {qp:?}"));
153 }
154 Ok(info)
155 }
156
157 op_as_typed_op!();
158}
159
160impl EvalOp for EinSum {
161 fn is_stateless(&self) -> bool {
162 true
163 }
164
165 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
166 if inputs.iter().all(|i| i.datum_type().is_number() && i.is_plain()) {
167 let mut adhoc_model = TypedModel::default();
168 let mut wires = tvec!();
169 for (ix, input) in inputs.iter().enumerate() {
170 let fact = TypedFact::shape_and_dt_of(input);
171 let wire = adhoc_model.add_source(format!("input.{ix}"), fact)?;
172 wires.push(wire);
173 }
174 let output = adhoc_model.wire_node("einsum", self.clone(), &wires)?;
175 adhoc_model.select_output_outlets(&output)?;
176 let opti = adhoc_model.into_optimized()?;
177 if opti.nodes.iter().all(|node| !node.op_is::<Self>()) {
178 return opti.into_runnable()?.run(inputs);
179 }
180 }
181
182 let output = if let Some(qp) = self.q_params {
183 eval::eval_q(&self.axes, qp, inputs)
184 } else {
185 dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
186 }?;
187 Ok(tvec!(output.into_tvalue()))
188 }
189}
190
191impl TypedOp for EinSum {
192 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
193 let shapes = self.actual_input_shapes_from_facts(inputs)?;
194 for i in 0..inputs.len() {
195 ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
196 }
197 for axis in self.axes.iter_all_axes() {
198 assert!(
199 shapes
200 .iter()
201 .enumerate()
202 .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
203 .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
204 .is_ok()
205 );
206 }
207 if let Some(qp) = self.q_params {
208 ensure!(inputs.len() == 9);
209 Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
210 } else {
211 Ok(tvec!(TypedFact::dt_shape(
212 self.operating_dt,
213 eval::output_shape(&self.axes, &shapes)?
214 )))
215 }
216 }
217
218 fn input_roi(
219 &self,
220 model: &TypedModel,
221 node: &TypedNode,
222 ) -> TractResult<Option<TVec<Option<TDim>>>> {
223 let output_fact = model.outlet_fact(OutletId::new(node.id, 0))?;
230 let Some(roi) = &output_fact.region_of_interest else { return Ok(None) };
231 let input_facts: TVec<&TypedFact> =
232 node.inputs.iter().map(|i| model.outlet_fact(*i)).collect::<TractResult<_>>()?;
233 let output_facts = tvec![output_fact];
234 let inputs_ref: Vec<&TypedFact> = input_facts.iter().copied().collect();
235 let outputs_ref: Vec<&TypedFact> = output_facts.iter().copied().collect();
236 let mapping = self.axes_mapping(&inputs_ref, &outputs_ref)?;
237 let roi_coord_axes: Vec<(usize, Symbol)> = roi
238 .symbols()
239 .into_iter()
240 .filter_map(|s| crate::ops::logic::sym_to_coord_axis(&s).map(|k| (k, s)))
241 .collect();
242
243 let project_for_input = |input_ix: usize| -> Option<TDim> {
244 let mut projected: Vec<Symbol> = vec![];
247 let mut preserved: Vec<(Symbol, usize)> = vec![];
248 for (out_pos, sym) in &roi_coord_axes {
249 let logical = mapping
250 .iter_all_axes()
251 .find(|a| a.outputs.first().is_some_and(|o| o.contains(out_pos)))?;
252 match logical.inputs[input_ix].first() {
253 None => projected.push(sym.clone()),
254 Some(&in_pos) => {
255 if input_facts[input_ix].shape[in_pos] != output_fact.shape[*out_pos] {
256 return None;
257 }
258 preserved.push((sym.clone(), in_pos));
259 }
260 }
261 }
262 if projected.is_empty() {
263 let mut sub_map: HashMap<Symbol, TDim> = HashMap::new();
265 for (sym, in_pos) in &preserved {
266 if crate::ops::logic::sym_to_coord_axis(sym) != Some(*in_pos) {
267 let scope = sym.scope()?;
268 sub_map.insert(sym.clone(), TDim::Sym(scope.coord_sym(*in_pos)));
269 }
270 }
271 return if sub_map.is_empty() {
272 Some(roi.clone())
273 } else {
274 roi.substitute_all(&sub_map).ok()
275 };
276 }
277 for p_sym in &projected {
280 for (k_sym, k_in_pos) in &preserved {
281 if let Some(band) = crate::optim::propagate_roi::recognise_chunked_band_project(
282 roi, p_sym, k_sym,
283 ) {
284 if crate::ops::logic::sym_to_coord_axis(k_sym) != Some(*k_in_pos) {
287 let scope = k_sym.scope()?;
288 let mut m: HashMap<Symbol, TDim> = HashMap::new();
289 m.insert(k_sym.clone(), TDim::Sym(scope.coord_sym(*k_in_pos)));
290 return band.substitute_all(&m).ok();
291 }
292 return Some(band);
293 }
294 }
295 }
296 None
297 };
298 let result: TVec<Option<TDim>> =
299 (0..node.inputs.len()).map(|ix| project_for_input(ix)).collect();
300 Ok(Some(result))
301 }
302
303 fn axes_mapping(
304 &self,
305 inputs: &[&TypedFact],
306 _outputs: &[&TypedFact],
307 ) -> TractResult<AxesMapping> {
308 let mut axes = self.axes.clone();
309 for (slot, i) in inputs.iter().enumerate() {
310 if i.is_exotic()
311 && (i.exotic_fact().is_some_and(|of| {
312 of.is::<PackedExoticFact>() || of.is::<PackedBlockQuantFact>()
313 }))
314 {
315 axes = axes
316 .remove_axis_occurency(InOut::In(slot), i.rank())?
317 .remove_axis_occurency(InOut::In(slot), i.rank())?;
318 }
319 }
320 Ok(axes)
321 }
322
323 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
324 let shapes = self.actual_input_shapes_from_facts(inputs)?;
325 let oshape = eval::output_shape(&self.axes, &shapes)?;
326 let ks = self
327 .axes
328 .iter_all_axes()
329 .filter(|axis| axis.outputs[0].len() == 0)
330 .map(|axis| {
331 axis.inputs
332 .iter()
333 .enumerate()
334 .flat_map(|(ix, axes)| {
335 axes.iter()
336 .map(|axis| shapes[ix][*axis].clone())
337 .collect::<TVec<_>>()
338 .into_iter()
339 })
340 .find(|d| !d.is_one())
341 .unwrap_or_else(|| 1.to_dim())
342 })
343 .product::<TDim>();
344 Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
345 }
346
347 fn slice(
348 &self,
349 patch: &mut TypedModelPatch,
350 model: &TypedModel,
351 node: &TypedNode,
352 prefix: &str,
353 inputs: &[OutletId],
354 output_axis: usize,
355 _start: &TDim,
356 _end: &TDim,
357 ) -> TractResult<Option<TVec<OutletId>>> {
358 let facts = model.node_input_facts(node.id)?;
359 let axis = self.axes.axis((InOut::Out(0), output_axis))?;
360 if facts
361 .iter()
362 .enumerate()
363 .any(|(slot, fact)| axis.inputs[slot].len() > 0 && fact.is_exotic())
364 {
365 Ok(None)
366 } else {
367 patch.wire_node(prefix, self.clone(), inputs).map(Some)
368 }
369 }
370
371 #[allow(unused_variables)]
372 fn change_axes(
373 &self,
374 model: &TypedModel,
375 node: &TypedNode,
376 io: InOut,
377 change: &AxisOp,
378 ) -> TractResult<Option<AxisChangeConsequence>> {
379 let (mut inputs, mut outputs) = self.axes.to_strs();
380 let interface: &mut String = match io {
381 InOut::In(i) => &mut inputs[i],
382 InOut::Out(o) => &mut outputs[o],
383 };
384 let mut axes: Vec<char> = interface.chars().collect();
385 match change {
386 AxisOp::Rm(rm) => {
387 axes.remove(*rm);
388 }
389 AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
390 AxisOp::Move(from, to) => {
391 let c = axes.remove(*from);
392 axes.insert(*to, c);
393 }
394 _ => {
395 return Ok(None);
396 }
397 };
398 *interface = axes.into_iter().collect();
399 let axes = AxesMapping::from_strs(&inputs, &outputs)?;
400 Ok(Some(AxisChangeConsequence {
401 substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
402 wire_changes: tvec!((io, change.clone())),
403 }))
404 }
405
406 fn declutter_with_session(
407 &self,
408 session: &mut crate::optim::OptimizerSession,
409 model: &TypedModel,
410 node: &TypedNode,
411 ) -> TractResult<Option<TypedModelPatch>> {
412 if let Some(patch) = declutter_reshape_folding_input_axis(self, session, model, node)? {
413 return Ok(Some(patch));
414 }
415 if let Some(patch) = declutter_broadcast(self, session, model, node)? {
416 return Ok(Some(patch));
417 }
418 if let Some(patch) = unit_k_to_broadcast_mul(self, model, node)? {
419 return Ok(Some(patch));
420 }
421 Ok(None)
422 }
423
424 fn codegen(
425 &self,
426 model: &TypedModel,
427 node: &TypedNode,
428 ) -> TractResult<Option<TypedModelPatch>> {
429 rule_if!(
430 (self.q_params.is_none() && node.inputs.len() == 2)
431 || (self.q_params.is_some() && node.inputs.len() == 9)
432 );
433 if let Some(patch) = unit_k_to_broadcast_mul(self, model, node)? {
439 return Ok(Some(patch));
440 }
441 einsum_matmul::detect_rule(&(), model, node, &node.name, self)
442 }
443
444 as_op!();
445}
446
447fn declutter_reshape_folding_input_axis(
448 op: &EinSum,
449 _session: &mut crate::optim::OptimizerSession,
450 model: &TypedModel,
451 node: &TypedNode,
452) -> TractResult<Option<TypedModelPatch>> {
453 for (slot, prec) in node.inputs.iter().map(|n| model.node(n.node)).enumerate() {
454 let Some(&AxisOp::Reshape(at, ref from, ref to)) = prec.op_as() else { continue };
455 if to.len() > 1 {
456 continue;
457 }
458 let mut axes = op.axes.clone();
459 let extra_labels = axes.available_labels().take(from.len() - 1).collect_vec();
460 let extra_input = node.inputs.len();
462 axes = axes.with_extra_input(extra_input)?;
463 for label in &extra_labels {
464 axes = axes.with_extra_axis(*label, InOut::In(extra_input), 0)?;
465 }
466 let folded_axis = op.axes.axis((InOut::In(slot), at))?;
467 rule_if!(folded_axis.outputs[0].len() <= 1);
468 let mut patch = TypedModelPatch::default();
469 let mut taps = patch.taps(model, &node.inputs)?;
470 for (input, tap) in taps.iter_mut().enumerate() {
471 if folded_axis.inputs[input].len() == 0 {
472 continue;
473 };
474 rule_if!(folded_axis.inputs[input].len() <= 1);
475 let pos = folded_axis.inputs[input][0];
476 for label in &extra_labels {
477 axes = axes.with_extra_axis_occurency(*label, InOut::In(input), pos)?;
478 }
479 *tap = patch.wire_node(
480 format!("{}.reshape_folded_input_{}", node.name, input),
481 AxisOp::Reshape(pos, to.clone(), from.clone()),
482 &[*tap],
483 )?[0];
484 }
485 if folded_axis.outputs[0].len() == 1 {
486 let pos = folded_axis.outputs[0][0];
487 for label in &extra_labels {
488 axes = axes.with_extra_axis_occurency(*label, InOut::Out(0), pos)?;
489 }
490 }
491 axes = axes.remove_slot(InOut::In(extra_input))?;
492 let mut wire = patch.wire_node(&node.name, EinSum { axes, ..op.clone() }, &taps)?;
493 if folded_axis.outputs[0].len() == 1 {
494 let pos = folded_axis.outputs[0][0];
495 wire = patch.wire_node(
496 format!("{}.reshape_folded_output", node.name),
497 AxisOp::Reshape(pos, from.clone(), to.clone()),
498 &wire,
499 )?;
500 }
501 patch.shunt_outside(model, node.id.into(), wire[0])?;
502 return Ok(Some(patch));
503 }
504 Ok(None)
505}
506
507fn declutter_broadcast(
508 op: &EinSum,
509 _session: &mut crate::optim::OptimizerSession,
510 model: &TypedModel,
511 node: &TypedNode,
512) -> TractResult<Option<TypedModelPatch>> {
513 for (ix, outlet) in node.inputs.iter().enumerate() {
514 let prec = model.node(outlet.node);
515 if prec.op_is::<MultiBroadcastTo>() && prec.outputs[0].successors.len() == 1 {
516 let mut patch = TypedModelPatch::default();
517 let mut wires = patch.taps(model, &node.inputs)?;
518 wires[ix] = patch.tap_model(model, prec.inputs[0])?;
519 let wire = patch.wire_node(&node.name, op.clone(), &wires)?[0];
520 patch.shunt_outside(model, node.id.into(), wire)?;
521 return Ok(Some(patch));
522 }
523 }
524 Ok(None)
525}
526
/// Rewrites a two-input, non-quantized EinSum whose reduced ("k") axes all have unit size into
/// an element-wise broadcasting `Mul`.
///
/// Returns `Ok(None)` (no rewrite) unless:
/// - every axis present once in each input and absent from the output has size 1 on both
///   sides,
/// - the EinSum output feeds at least one node whose op is named `DeconvSum`,
/// - no axis absent from the output is non-unit on exactly one input,
/// - the output has at least one axis.
///
/// Each input is then cast to `operating_dt` if needed, stripped of the k axes and of the other
/// axes absent from the output, padded with unit axes for missing output axes, and permuted into
/// output-axis order, so that `Mul` broadcasting produces the EinSum output layout.
fn unit_k_to_broadcast_mul(
    op: &EinSum,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    if op.q_params.is_some() || node.inputs.len() != 2 {
        return Ok(None);
    }
    let input_facts = model.node_input_facts(node.id)?;
    let input_shapes = op.actual_input_shapes_from_facts(&input_facts)?;
    // "k" axes: present exactly once in each input, absent from the output (summed away).
    let k_axes: TVec<&Axis> = op
        .axes
        .iter_all_axes()
        .filter(|a| a.inputs[0].len() == 1 && a.inputs[1].len() == 1 && a.outputs[0].is_empty())
        .collect();
    // A non-unit k axis is a genuine reduction that a plain Mul cannot express.
    let any_nontrivial_k = k_axes.iter().any(|a| {
        !input_shapes[0][a.inputs[0][0]].is_one() || !input_shapes[1][a.inputs[1][0]].is_one()
    });
    if any_nontrivial_k {
        return Ok(None);
    }
    // NOTE(review): the rewrite is gated on a consumer whose op name is "DeconvSum"; the reason
    // for this restriction is not visible here — confirm.
    let has_deconv_sum_consumer = node.outputs.first().map_or(false, |o| {
        o.successors.iter().any(|inlet| model.node(inlet.node).op.name() == "DeconvSum")
    });
    if !has_deconv_sum_consumer {
        return Ok(None);
    }

    // An axis absent from the output that is non-unit on exactly one side is still a
    // reduction: bail out.
    let one = TDim::one();
    for axis in op.axes.iter_all_axes() {
        let in_left =
            axis.inputs[0].first().map(|pos| &input_shapes[0][*pos]).unwrap_or(&one) != &one;
        let in_right =
            axis.inputs[1].first().map(|pos| &input_shapes[1][*pos]).unwrap_or(&one) != &one;
        let in_out = !axis.outputs[0].is_empty();
        if (in_left ^ in_right) && !in_out {
            return Ok(None);
        }
    }

    // Output axis labels, in output order: the target layout for both Mul operands.
    let c_axes: Vec<char> = op.axes.axes(InOut::Out(0)).map(|a| a.repr).collect();
    if c_axes.is_empty() {
        return Ok(None);
    }

    let k_reprs: TVec<char> = k_axes.iter().map(|a| a.repr).collect();
    let mut patch = TypedModelPatch::new("EinSum unit-K → broadcast Mul");
    let mut wires: TVec<OutletId> = patch.taps(model, &node.inputs)?;
    let name = &node.name;

    for (slot, wire) in wires.iter_mut().enumerate() {
        // Cast to the operating type so the Mul output type matches the EinSum output type.
        let cur_dt = patch.outlet_fact(*wire)?.datum_type;
        if cur_dt != op.operating_dt {
            *wire = patch.wire_node(
                format!("{name}.cast_in{slot}"),
                crate::ops::cast::cast(op.operating_dt),
                &[*wire],
            )?[0];
        }

        // Remove the (unit) k axes, highest position first so earlier removals do not shift
        // the positions still to be removed.
        let mut k_positions: Vec<usize> = k_axes.iter().map(|a| a.inputs[slot][0]).collect();
        k_positions.sort_by(|a, b| b.cmp(a));
        for (i, pos) in k_positions.into_iter().enumerate() {
            *wire =
                patch.wire_node(format!("{name}.rm_k_in{slot}.{i}"), AxisOp::Rm(pos), &[*wire])?[0];
        }

        // Labels of this input's remaining axes, mirroring the wire's current layout.
        // NOTE(review): assumes `axes(InOut::In(slot))` yields axes in positional order — the
        // Rm/Add/Move positions below rely on it.
        let mut current: Vec<char> = op
            .axes
            .axes(InOut::In(slot))
            .map(|a| a.repr)
            .filter(|c| !k_reprs.contains(c))
            .collect();

        // Drop the remaining axes absent from the output, again highest position first.
        let mut to_drop: Vec<(usize, char)> = current
            .iter()
            .enumerate()
            .filter(|(_, c)| !c_axes.contains(c))
            .map(|(i, c)| (i, *c))
            .collect();
        to_drop.sort_by(|a, b| b.0.cmp(&a.0));
        for (pos, c) in to_drop {
            *wire = patch.wire_node(
                format!("{name}.rm_extra_in{slot}_{c}"),
                AxisOp::Rm(pos),
                &[*wire],
            )?[0];
            current.remove(pos);
        }

        // Insert a unit axis for every output label this input lacks. Walking c_axes in order
        // guarantees all earlier labels are already in `current`, so `target_pos` is in range.
        for (target_pos, &t) in c_axes.iter().enumerate() {
            if !current.contains(&t) {
                *wire = patch.wire_node(
                    format!("{name}.add_in{slot}_{t}"),
                    AxisOp::Add(target_pos),
                    &[*wire],
                )?[0];
                current.insert(target_pos, t);
            }
        }

        // Permute into output order. Every output label is in `current` after the loop above,
        // so the `unwrap` cannot fail.
        for (target_pos, &t) in c_axes.iter().enumerate() {
            let cur_pos = current.iter().position(|&c| c == t).unwrap();
            if cur_pos != target_pos {
                *wire = patch.wire_node(
                    format!("{name}.move_in{slot}_{t}"),
                    AxisOp::Move(cur_pos, target_pos),
                    &[*wire],
                )?[0];
                let removed = current.remove(cur_pos);
                current.insert(target_pos, removed);
            }
        }
    }

    let result = patch.wire_node(name, crate::ops::math::mul(), &wires)?;
    patch.shunt_outside(model, node.id.into(), result[0])?;
    Ok(Some(patch))
}
674}