1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::ops::array::MultiBroadcastTo;
6use crate::tract_data::itertools::Itertools;
7
8mod eval;
9
10#[cfg(feature = "blas")]
11pub mod as_blas;
12pub mod einsum_matmul;
13pub mod kernel_selection;
14pub mod prefix_matmul;
15
16#[cfg(test)]
17mod proptest;
18
19use num_traits::One;
20use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact};
21use tract_linalg::mmm::PackedExoticFact;
22
23pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<'_, [TDim]>> {
24 if fact.is_plain() {
25 return Ok(Cow::Borrowed(&*fact.shape));
26 }
27 let Some(exotic_fact) = fact.exotic_fact() else {
28 bail!("Datum fact is exotic, but no exotic fact was found.")
29 };
30 if let Some(_bqf) = exotic_fact.downcast_ref::<BlockQuantFact>() {
31 Ok(Cow::Borrowed(&*fact.shape))
32 } else if let Some(pof) = exotic_fact.downcast_ref::<PackedBlockQuantFact>() {
33 Ok(Cow::Owned(
34 fact.shape.iter().cloned().chain(pof.shape.iter().map(|i| i.to_dim())).collect_vec(),
35 ))
36 } else if let Some(pof) = exotic_fact.downcast_ref::<PackedExoticFact>() {
37 Ok(Cow::Owned(
38 fact.shape.iter().cloned().chain([pof.mn.clone(), pof.k.to_dim()]).collect_vec(),
39 ))
40 } else {
41 bail!("Unsupported exotic fact {exotic_fact:?}")
42 }
43}
44
/// Einstein-summation operator: contracts its inputs according to `axes`.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct EinSum {
    /// Per-axis mapping between inputs and output (structured einsum "equation").
    pub axes: AxesMapping,
    /// Datum type used for accumulation during evaluation.
    pub operating_dt: DatumType,
    // If set, this is a quantized einsum and the value is the output datum
    // type. The op then expects 9 inputs (enforced in `output_facts`) —
    // presumably the two operands plus zero-point/scale parameters; confirm
    // exact ordering against eval::eval_q.
    pub q_params: Option<DatumType>,
}
53
54impl EinSum {
55 pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
56 EinSum { axes, operating_dt, q_params: None }
57 }
58
59 pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
60 EinSum { axes, operating_dt, q_params: Some(output_type) }
61 }
62
63 pub fn actual_input_shapes_from_facts<'m>(
64 &self,
65 inputs: &'m [impl Borrow<TypedFact>],
66 ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
67 ensure!(inputs.len() == self.axes.input_count());
68 let shapes: TVec<Cow<[TDim]>> = inputs
69 .iter()
70 .map(|t| block_quant_aware_input_shape(t.borrow()))
71 .collect::<TractResult<_>>()?;
72 ensure!(
73 shapes.iter().enumerate().all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix)))
74 );
75 Ok(shapes)
76 }
77
78 #[allow(unused_variables)]
79 pub(crate) fn propagate_axis(
80 &self,
81 model: &TypedModel,
82 node: &TypedNode,
83 io: InOut,
84 axis: usize,
85 ) -> TractResult<Option<TypedModelPatch>> {
86 let mut new_axis = self.axes.axis((io, axis))?.clone();
87 let repr = new_axis.repr;
88 let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
89 let mut taps = tvec!();
90 for (ix, input) in node.inputs.iter().enumerate() {
91 let mut tap = patch.tap_model(model, *input)?;
92 if new_axis.inputs[ix].len() > 1 {
93 return Ok(None); } else if new_axis.inputs[ix].is_empty() {
95 let insert_at = self.axes.rank(InOut::In(ix));
96 tap = patch.wire_node(
97 format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
98 AxisOp::Add(insert_at),
99 &[tap],
100 )?[0];
101 new_axis.inputs[ix].push(insert_at);
102 }
103 taps.push(tap);
104 }
105 let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
106 let insert_at = self.axes.rank(InOut::Out(0));
107 new_axis.outputs[0].push(insert_at);
108 Some(insert_at)
109 } else {
110 None
111 };
112 let new_expr = self
113 .axes
114 .iter_all_axes()
115 .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
116 .collect_vec();
117 let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
118 let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
119 if let Some(position) = must_rm_axis {
120 wire = patch.wire_node(
121 format!("{}.prop_axis.{}.output", &node.name, repr),
122 AxisOp::Rm(position),
123 &wire,
124 )?;
125 }
126 patch.shunt_outside(model, node.id.into(), wire[0])?;
127 Ok(Some(patch))
128 }
129
130 pub fn acceptable_accumulators(&self) -> TVec<DatumType> {
131 if self.operating_dt.is_integer() {
132 tvec!(i32::datum_type())
133 } else if self.operating_dt == f16::datum_type() {
134 tvec!(f16::datum_type(), f32::datum_type())
135 } else {
136 tvec!(self.operating_dt)
137 }
138 }
139}
140
141impl Debug for EinSum {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
144 }
145}
146
147impl Op for EinSum {
148 fn name(&self) -> StaticName {
149 "EinSum".into()
150 }
151
152 fn info(&self) -> TractResult<Vec<String>> {
153 let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
154 if let Some(qp) = self.q_params {
155 info.push(format!("Quantized output: {qp:?}"));
156 }
157 Ok(info)
158 }
159
160 op_as_typed_op!();
161}
162
impl EvalOp for EinSum {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Fast path: when every input is a plain numeric tensor, wrap this op
        // in a one-node ad-hoc model and optimize it, hoping the optimizer
        // rewrites the einsum to specialized kernels (e.g. matmul).
        if inputs.iter().all(|i| i.datum_type().is_number() && i.is_plain()) {
            let mut adhoc_model = TypedModel::default();
            let mut wires = tvec!();
            for (ix, input) in inputs.iter().enumerate() {
                let fact = TypedFact::shape_and_dt_of(input);
                let wire = adhoc_model.add_source(format!("input.{ix}"), fact)?;
                wires.push(wire);
            }
            let output = adhoc_model.wire_node("einsum", self.clone(), &wires)?;
            adhoc_model.select_output_outlets(&output)?;
            let opti = adhoc_model.into_optimized()?;
            // Only run the optimized model if every EinSum was rewritten away;
            // otherwise running it would re-enter this eval and recurse.
            if opti.nodes.iter().all(|node| !node.op_is::<Self>()) {
                return opti.into_runnable()?.run(inputs);
            }
        }

        // Fallback: generic reference implementation, quantized or not.
        let output = if let Some(qp) = self.q_params {
            eval::eval_q(&self.axes, qp, inputs)
        } else {
            // Dispatch on the accumulator type to the monomorphized kernel.
            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
        }?;
        Ok(tvec!(output.into_tvalue()))
    }
}
193
194impl TypedOp for EinSum {
195 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
196 let shapes = self.actual_input_shapes_from_facts(inputs)?;
197 for i in 0..inputs.len() {
198 ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
199 }
200 for axis in self.axes.iter_all_axes() {
201 assert!(
202 shapes
203 .iter()
204 .enumerate()
205 .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
206 .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
207 .is_ok()
208 );
209 }
210 if let Some(qp) = self.q_params {
211 ensure!(inputs.len() == 9);
212 Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
213 } else {
214 Ok(tvec!(TypedFact::dt_shape(
215 self.operating_dt,
216 eval::output_shape(&self.axes, &shapes)?
217 )))
218 }
219 }
220
221 fn input_roi(
222 &self,
223 model: &TypedModel,
224 node: &TypedNode,
225 ) -> TractResult<Option<TVec<Option<TDim>>>> {
226 crate::optim::propagate_roi::bubble_roi(model, node)
227 }
228
229 fn axes_mapping(
230 &self,
231 inputs: &[&TypedFact],
232 _outputs: &[&TypedFact],
233 ) -> TractResult<AxesMapping> {
234 let mut axes = self.axes.clone();
235 for (slot, i) in inputs.iter().enumerate() {
236 if i.is_exotic()
237 && (i.exotic_fact().is_some_and(|of| {
238 of.is::<PackedExoticFact>() || of.is::<PackedBlockQuantFact>()
239 }))
240 {
241 axes = axes
242 .remove_axis_occurency(InOut::In(slot), i.rank())?
243 .remove_axis_occurency(InOut::In(slot), i.rank())?;
244 }
245 }
246 Ok(axes)
247 }
248
249 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
250 let shapes = self.actual_input_shapes_from_facts(inputs)?;
251 let oshape = eval::output_shape(&self.axes, &shapes)?;
252 let ks = self
253 .axes
254 .iter_all_axes()
255 .filter(|axis| axis.outputs[0].len() == 0)
256 .map(|axis| {
257 axis.inputs
258 .iter()
259 .enumerate()
260 .flat_map(|(ix, axes)| {
261 axes.iter()
262 .map(|axis| shapes[ix][*axis].clone())
263 .collect::<TVec<_>>()
264 .into_iter()
265 })
266 .find(|d| !d.is_one())
267 .unwrap_or_else(|| 1.to_dim())
268 })
269 .product::<TDim>();
270 Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
271 }
272
273 fn slice(
274 &self,
275 patch: &mut TypedModelPatch,
276 model: &TypedModel,
277 node: &TypedNode,
278 prefix: &str,
279 inputs: &[OutletId],
280 output_axis: usize,
281 _start: &TDim,
282 _end: &TDim,
283 ) -> TractResult<Option<TVec<OutletId>>> {
284 let facts = model.node_input_facts(node.id)?;
285 let axis = self.axes.axis((InOut::Out(0), output_axis))?;
286 if facts
287 .iter()
288 .enumerate()
289 .any(|(slot, fact)| axis.inputs[slot].len() > 0 && fact.is_exotic())
290 {
291 Ok(None)
292 } else {
293 patch.wire_node(prefix, self.clone(), inputs).map(Some)
294 }
295 }
296
297 #[allow(unused_variables)]
298 fn change_axes(
299 &self,
300 model: &TypedModel,
301 node: &TypedNode,
302 io: InOut,
303 change: &AxisOp,
304 ) -> TractResult<Option<AxisChangeConsequence>> {
305 let (mut inputs, mut outputs) = self.axes.to_strs();
306 let interface: &mut String = match io {
307 InOut::In(i) => &mut inputs[i],
308 InOut::Out(o) => &mut outputs[o],
309 };
310 let mut axes: Vec<char> = interface.chars().collect();
311 match change {
312 AxisOp::Rm(rm) => {
313 axes.remove(*rm);
314 }
315 AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
316 AxisOp::Move(from, to) => {
317 let c = axes.remove(*from);
318 axes.insert(*to, c);
319 }
320 _ => {
321 return Ok(None);
322 }
323 };
324 *interface = axes.into_iter().collect();
325 let axes = AxesMapping::from_strs(&inputs, &outputs)?;
326 Ok(Some(AxisChangeConsequence {
327 substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
328 wire_changes: tvec!((io, change.clone())),
329 }))
330 }
331
332 fn declutter_with_session(
333 &self,
334 session: &mut crate::optim::OptimizerSession,
335 model: &TypedModel,
336 node: &TypedNode,
337 ) -> TractResult<Option<TypedModelPatch>> {
338 if let Some(patch) = declutter_reshape_folding_input_axis(self, session, model, node)? {
339 return Ok(Some(patch));
340 }
341 if let Some(patch) = declutter_broadcast(self, session, model, node)? {
342 return Ok(Some(patch));
343 }
344 Ok(None)
345 }
346
347 fn codegen(
348 &self,
349 model: &TypedModel,
350 node: &TypedNode,
351 ) -> TractResult<Option<TypedModelPatch>> {
352 rule_if!(
353 (self.q_params.is_none() && node.inputs.len() == 2)
354 || (self.q_params.is_some() && node.inputs.len() == 9)
355 );
356 einsum_matmul::detect_rule(&(), model, node, &node.name, self)
357 }
358
359 as_op!();
360}
361
/// Declutter rule: when an input comes from a Reshape that folds several axes
/// into one, push the reshape through the einsum — unfold the axis on every
/// input carrying it (and refold on the output) so the upstream Reshape can be
/// absorbed.
fn declutter_reshape_folding_input_axis(
    op: &EinSum,
    _session: &mut crate::optim::OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (slot, prec) in node.inputs.iter().map(|n| model.node(n.node)).enumerate() {
        // Only consider predecessors reshaping `from` axes into a single axis.
        let Some(&AxisOp::Reshape(at, ref from, ref to)) = prec.op_as() else { continue };
        if to.len() > 1 {
            continue;
        }
        let mut axes = op.axes.clone();
        // One fresh label per unfolded axis beyond the first.
        let extra_labels = axes.available_labels().take(from.len() - 1).collect_vec();
        // Temporary phantom input used to register the new labels; removed below.
        let extra_input = node.inputs.len();
        axes = axes.with_extra_input(extra_input)?;
        for label in &extra_labels {
            axes = axes.with_extra_axis(*label, InOut::In(extra_input), 0)?;
        }
        let folded_axis = op.axes.axis((InOut::In(slot), at))?;
        if folded_axis.outputs[0].len() > 1 {
            // Axis repeated on the output: give up on this rule entirely.
            return Ok(None);
        };
        let mut patch = TypedModelPatch::default();
        let mut taps = patch.taps(model, &node.inputs)?;
        for (input, tap) in taps.iter_mut().enumerate() {
            if folded_axis.inputs[input].len() == 0 {
                // This input does not carry the folded axis: leave it alone.
                continue;
            };
            if folded_axis.inputs[input].len() > 1 {
                return Ok(None);
            };
            let pos = folded_axis.inputs[input][0];
            // Register the extra labels at the unfold position on this input...
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::In(input), pos)?;
            }
            // ...and wire the inverse reshape (unfold) in front of it.
            *tap = patch.wire_node(
                format!("{}.reshape_folded_input_{}", node.name, input),
                AxisOp::Reshape(pos, to.clone(), from.clone()),
                &[*tap],
            )?[0];
        }
        if folded_axis.outputs[0].len() == 1 {
            // The folded axis reaches the output: mirror the extra labels there.
            let pos = folded_axis.outputs[0][0];
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::Out(0), pos)?;
            }
        }
        // Drop the phantom input now that the labels exist on real interfaces.
        axes = axes.remove_slot(InOut::In(extra_input))?;
        let mut wire = patch.wire_node(&node.name, EinSum { axes, ..op.clone() }, &taps)?;
        if folded_axis.outputs[0].len() == 1 {
            // Refold the output so downstream consumers see the original shape.
            let pos = folded_axis.outputs[0][0];
            wire = patch.wire_node(
                format!("{}.reshape_folded_output", node.name),
                AxisOp::Reshape(pos, from.clone(), to.clone()),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        return Ok(Some(patch));
    }
    Ok(None)
}
425
426fn declutter_broadcast(
427 op: &EinSum,
428 _session: &mut crate::optim::OptimizerSession,
429 model: &TypedModel,
430 node: &TypedNode,
431) -> TractResult<Option<TypedModelPatch>> {
432 for (ix, outlet) in node.inputs.iter().enumerate() {
433 let prec = model.node(outlet.node);
434 if prec.op_is::<MultiBroadcastTo>() && prec.outputs[0].successors.len() == 1 {
435 let mut patch = TypedModelPatch::default();
436 let mut wires = patch.taps(model, &node.inputs)?;
437 wires[ix] = patch.tap_model(model, prec.inputs[0])?;
438 let wire = patch.wire_node(&node.name, op.clone(), &wires)?[0];
439 patch.shunt_outside(model, node.id.into(), wire)?;
440 return Ok(Some(patch));
441 }
442 }
443 Ok(None)
444}