1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::ops::array::MultiBroadcastTo;
6use crate::tract_data::itertools::Itertools;
7
8mod eval;
9
10#[cfg(feature = "blas")]
11pub mod as_blas;
12pub mod einsum_matmul;
13pub mod kernel_selection;
14pub mod prefix_matmul;
15
16#[cfg(test)]
17mod proptest;
18
19use num_traits::One;
20use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact};
21use tract_linalg::mmm::PackedOpaqueFact;
22
23pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<'_, [TDim]>> {
24 if !fact.datum_type.is_opaque() {
25 return Ok(Cow::Borrowed(&*fact.shape));
26 }
27 let Some(opaque_fact) = fact.opaque_fact() else {
28 bail!("Datum fact is opaque, but no opaque fact was found.")
29 };
30 if let Some(bqf) = opaque_fact.downcast_ref::<BlockQuantFact>() {
31 Ok(Cow::Owned(
32 fact.shape.iter().cloned().chain(bqf.shape().iter().map(|d| d.to_dim())).collect_vec(),
33 ))
34 } else if let Some(pof) = opaque_fact.downcast_ref::<PackedBlockQuantFact>() {
35 Ok(Cow::Owned(
36 fact.shape.iter().cloned().chain(pof.shape.iter().map(|i| i.to_dim())).collect_vec(),
37 ))
38 } else if let Some(pof) = opaque_fact.downcast_ref::<PackedOpaqueFact>() {
39 Ok(Cow::Owned(
40 fact.shape.iter().cloned().chain([pof.mn.clone(), pof.k.to_dim()]).collect_vec(),
41 ))
42 } else {
43 bail!("Unsupported opaque fact {opaque_fact:?}")
44 }
45}
46
/// Generalized tensor contraction (Einstein summation) operator.
#[derive(Clone, Hash, PartialEq)]
pub struct EinSum {
    /// Axes mapping between the inputs and the single output (the einsum
    /// "equation").
    pub axes: AxesMapping,
    /// Datum type used for the internal computation / accumulation.
    pub operating_dt: DatumType,
    // If Some, this is a quantized einsum taking 9 inputs (see
    // `output_facts`), and this is the output datum type. The extra inputs
    // presumably carry the quantization parameters for a, b and output —
    // confirm against eval::eval_q.
    pub q_params: Option<DatumType>,
}
55
impl EinSum {
    /// Builds a plain (non-quantized) einsum computing in `operating_dt`.
    pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: None }
    }

    /// Builds a quantized einsum: accumulation in `operating_dt`, output typed
    /// as `output_type`.
    pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: Some(output_type) }
    }

    /// Logical input shapes: opaque block-quantized / packed inputs contribute
    /// their payload dimensions (see `block_quant_aware_input_shape`). Checks
    /// the input count and each input rank against the axes mapping.
    pub fn actual_input_shapes_from_facts<'m>(
        &self,
        inputs: &'m [impl Borrow<TypedFact>],
    ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
        ensure!(inputs.len() == self.axes.input_count());
        let shapes: TVec<Cow<[TDim]>> = inputs
            .iter()
            .map(|t| block_quant_aware_input_shape(t.borrow()))
            .collect::<TractResult<_>>()?;
        ensure!(
            shapes.iter().enumerate().all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix)))
        );
        Ok(shapes)
    }

    /// Rewrites the graph so the axis at `(io, axis)` is present on every
    /// interface of the einsum: inputs missing it get a trailing `AxisOp::Add`;
    /// if it was absent from the output, it is added to the expression and then
    /// stripped after the node with an `AxisOp::Rm`. Returns `None` when the
    /// axis occurs more than once on some input.
    #[allow(unused_variables)]
    pub(crate) fn propagate_axis(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        axis: usize,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut new_axis = self.axes.axis((io, axis))?.clone();
        let repr = new_axis.repr;
        let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
        let mut taps = tvec!();
        for (ix, input) in node.inputs.iter().enumerate() {
            let mut tap = patch.tap_model(model, *input)?;
            if new_axis.inputs[ix].len() > 1 {
                // Axis occurs several times on this input: no simple propagation.
                return Ok(None);
            } else if new_axis.inputs[ix].is_empty() {
                // Axis absent from this input: append it as a trailing unit axis.
                let insert_at = self.axes.rank(InOut::In(ix));
                tap = patch.wire_node(
                    format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
                    AxisOp::Add(insert_at),
                    &[tap],
                )?[0];
                new_axis.inputs[ix].push(insert_at);
            }
            taps.push(tap);
        }
        // If the axis was not in the output, add it to the expression now and
        // remember to remove it from the wire after the einsum.
        let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
            let insert_at = self.axes.rank(InOut::Out(0));
            new_axis.outputs[0].push(insert_at);
            Some(insert_at)
        } else {
            None
        };
        // Substitute the updated axis into the mapping, keeping all others.
        let new_expr = self
            .axes
            .iter_all_axes()
            .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
            .collect_vec();
        let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
        let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
        if let Some(position) = must_rm_axis {
            wire = patch.wire_node(
                format!("{}.prop_axis.{}.output", &node.name, repr),
                AxisOp::Rm(position),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

    /// Accumulator types acceptable for this einsum: i32 for integer
    /// operation, f16 or f32 for an f16 operation, otherwise the operating
    /// type itself.
    pub fn acceptable_accumulators(&self) -> TVec<DatumType> {
        if self.operating_dt.is_integer() {
            tvec!(i32::datum_type())
        } else if self.operating_dt == f16::datum_type() {
            tvec!(f16::datum_type(), f32::datum_type())
        } else {
            tvec!(self.operating_dt)
        }
    }
}
142
143impl Debug for EinSum {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
146 }
147}
148
149impl Op for EinSum {
150 fn name(&self) -> StaticName {
151 "EinSum".into()
152 }
153
154 fn info(&self) -> TractResult<Vec<String>> {
155 let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
156 if let Some(qp) = self.q_params {
157 info.push(format!("Quantized output: {qp:?}"));
158 }
159 Ok(info)
160 }
161
162 op_as_typed_op!();
163}
164
impl EvalOp for EinSum {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Fast path: wrap the einsum in a throw-away single-node model and
        // optimize it, so evaluation can go through specialized kernels
        // (matmul lowering, etc.) instead of the reference implementation.
        if inputs.iter().all(|i| i.datum_type().is_number()) {
            let mut adhoc_model = TypedModel::default();
            let mut wires = tvec!();
            for (ix, input) in inputs.iter().enumerate() {
                let fact = TypedFact::shape_and_dt_of(input);
                let wire = adhoc_model.add_source(format!("input.{ix}"), fact)?;
                wires.push(wire);
            }
            let output = adhoc_model.wire_node("einsum", self.clone(), &wires)?;
            adhoc_model.set_output_outlets(&output)?;
            let opti = adhoc_model.into_optimized()?;
            // Only run the optimized model if the einsum node was entirely
            // replaced; otherwise evaluating it would recurse into this method.
            if opti.nodes.iter().all(|node| !node.op_is::<Self>()) {
                return opti.into_runnable()?.run(inputs);
            }
        }

        // Slow path: direct reference evaluation, quantized or not.
        let output = if let Some(qp) = self.q_params {
            eval::eval_q(&self.axes, qp, inputs)
        } else {
            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
        }?;
        Ok(tvec!(output.into_tvalue()))
    }
}
195
impl TypedOp for EinSum {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let shapes = self.actual_input_shapes_from_facts(inputs)?;
        for i in 0..inputs.len() {
            ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
        }
        // All occurrences of a given axis across the inputs must be
        // broadcast-compatible.
        for axis in self.axes.iter_all_axes() {
            assert!(
                shapes
                    .iter()
                    .enumerate()
                    .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
                    .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
                    .is_ok()
            );
        }
        if let Some(qp) = self.q_params {
            // Quantized form takes 9 inputs; only the first two (a and b)
            // drive the output shape.
            ensure!(inputs.len() == 9);
            Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
        } else {
            Ok(tvec!(TypedFact::dt_shape(
                self.operating_dt,
                eval::output_shape(&self.axes, &shapes)?
            )))
        }
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut axes = self.axes.clone();
        for (slot, i) in inputs.iter().enumerate() {
            if i.datum_type.is_opaque()
                && (i.opaque_fact().is_some_and(|of| {
                    of.is::<BlockQuantFact>()
                        || of.is::<PackedOpaqueFact>()
                        || of.is::<PackedBlockQuantFact>()
                }))
            {
                // Block-quant / packed inputs carry two extra logical axes
                // inside their opaque payload: hide both from the mapping
                // exposed to the rest of the framework.
                axes = axes
                    .remove_axis_occurency(InOut::In(slot), i.rank())?
                    .remove_axis_occurency(InOut::In(slot), i.rank())?;
            }
        }
        Ok(axes)
    }

    /// Cost model: one FMA per output element per element of the contracted axes.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shapes = self.actual_input_shapes_from_facts(inputs)?;
        let oshape = eval::output_shape(&self.axes, &shapes)?;
        // Product of contracted axes sizes (axes absent from the output),
        // taking for each axis its first non-trivial dimension across inputs.
        let ks = self
            .axes
            .iter_all_axes()
            .filter(|axis| axis.outputs[0].len() == 0)
            .map(|axis| {
                axis.inputs
                    .iter()
                    .enumerate()
                    .flat_map(|(ix, axes)| {
                        axes.iter()
                            .map(|axis| shapes[ix][*axis].clone())
                            .collect::<TVec<_>>()
                            .into_iter()
                    })
                    .find(|d| !d.is_one())
                    .unwrap_or_else(|| 1.to_dim())
            })
            .product::<TDim>();
        Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
    }

    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        model: &TypedModel,
        node: &TypedNode,
        prefix: &str,
        inputs: &[OutletId],
        output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        let facts = model.node_input_facts(node.id)?;
        let axis = self.axes.axis((InOut::Out(0), output_axis))?;
        // Refuse to slice when the axis touches an opaque (packed) input:
        // its payload can not be sliced from outside.
        if facts
            .iter()
            .enumerate()
            .any(|(slot, fact)| axis.inputs[slot].len() > 0 && fact.datum_type.is_opaque())
        {
            Ok(None)
        } else {
            // Otherwise slicing commutes with the einsum: rewire the same op
            // over the (already sliced) inputs.
            patch.wire_node(prefix, self.clone(), inputs).map(Some)
        }
    }

    #[allow(unused_variables)]
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        // Apply the axis change to the textual form of the touched interface,
        // then rebuild the whole mapping from strings.
        let (mut inputs, mut outputs) = self.axes.to_strs();
        let interface: &mut String = match io {
            InOut::In(i) => &mut inputs[i],
            InOut::Out(o) => &mut outputs[o],
        };
        let mut axes: Vec<char> = interface.chars().collect();
        match change {
            AxisOp::Rm(rm) => {
                axes.remove(*rm);
            }
            AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
            AxisOp::Move(from, to) => {
                let c = axes.remove(*from);
                axes.insert(*to, c);
            }
            // Reshape (or anything else) is not handled here.
            _ => {
                return Ok(None);
            }
        };
        *interface = axes.into_iter().collect();
        let axes = AxesMapping::from_strs(&inputs, &outputs)?;
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
            wire_changes: tvec!((io, change.clone())),
        }))
    }

    fn declutter_with_session(
        &self,
        session: &mut crate::optim::OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Try local rewrite rules in order, stopping at the first that fires.
        if let Some(patch) = declutter_reshape_folding_input_axis(self, session, model, node)? {
            return Ok(Some(patch));
        }
        if let Some(patch) = declutter_broadcast(self, session, model, node)? {
            return Ok(Some(patch));
        }
        Ok(None)
    }

    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Matmul lowering only applies to the binary form (or the 9-input
        // quantized form).
        rule_if!(
            (self.q_params.is_none() && node.inputs.len() == 2)
                || (self.q_params.is_some() && node.inputs.len() == 9)
        );
        einsum_matmul::detect_rule(&(), model, node, &node.name, self)
    }

    as_op!();
}
357
/// Declutter rule: when an input is produced by an `AxisOp::Reshape` folding
/// several axes into (at most) one, absorb the reshape into the einsum by
/// un-folding that axis on every interface where it occurs, so the reshapes
/// can cancel out with neighbors.
fn declutter_reshape_folding_input_axis(
    op: &EinSum,
    _session: &mut crate::optim::OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (slot, prec) in node.inputs.iter().map(|n| model.node(n.node)).enumerate() {
        // Only consider reshapes collapsing `from` axes into at most one axis.
        let Some(&AxisOp::Reshape(at, ref from, ref to)) = prec.op_as() else { continue };
        if to.len() > 1 {
            continue;
        }
        let mut axes = op.axes.clone();
        // One fresh label for each extra axis the un-folding introduces.
        let extra_labels = axes.available_labels().take(from.len() - 1).collect_vec();
        // Register the fresh labels through a phantom extra input slot,
        // removed below once their real occurrences are in place.
        let extra_input = node.inputs.len();
        axes = axes.with_extra_input(extra_input)?;
        for label in &extra_labels {
            axes = axes.with_extra_axis(*label, InOut::In(extra_input), 0)?;
        }
        let folded_axis = op.axes.axis((InOut::In(slot), at))?;
        if folded_axis.outputs[0].len() > 1 {
            return Ok(None);
        };
        let mut patch = TypedModelPatch::default();
        let mut taps = patch.taps(model, &node.inputs)?;
        for (input, tap) in taps.iter_mut().enumerate() {
            if folded_axis.inputs[input].len() == 0 {
                continue;
            };
            if folded_axis.inputs[input].len() > 1 {
                // Folded axis occurs twice on one input: give up.
                return Ok(None);
            };
            let pos = folded_axis.inputs[input][0];
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::In(input), pos)?;
            }
            // Un-fold this input's axis (inverse of the absorbed reshape).
            *tap = patch.wire_node(
                format!("{}.reshape_folded_input_{}", node.name, input),
                AxisOp::Reshape(pos, to.clone(), from.clone()),
                &[*tap],
            )?[0];
        }
        if folded_axis.outputs[0].len() == 1 {
            let pos = folded_axis.outputs[0][0];
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::Out(0), pos)?;
            }
        }
        // Drop the phantom input slot now that all labels live on real interfaces.
        axes = axes.remove_slot(InOut::In(extra_input))?;
        let mut wire = patch.wire_node(&node.name, EinSum { axes, ..op.clone() }, &taps)?;
        if folded_axis.outputs[0].len() == 1 {
            // Re-fold the axis on the output to preserve the node's output shape.
            let pos = folded_axis.outputs[0][0];
            wire = patch.wire_node(
                format!("{}.reshape_folded_output", node.name),
                AxisOp::Reshape(pos, from.clone(), to.clone()),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        return Ok(Some(patch));
    }
    Ok(None)
}
421
422fn declutter_broadcast(
423 op: &EinSum,
424 _session: &mut crate::optim::OptimizerSession,
425 model: &TypedModel,
426 node: &TypedNode,
427) -> TractResult<Option<TypedModelPatch>> {
428 for (ix, outlet) in node.inputs.iter().enumerate() {
429 let prec = model.node(outlet.node);
430 if prec.op_is::<MultiBroadcastTo>() && prec.outputs[0].successors.len() == 1 {
431 let mut patch = TypedModelPatch::default();
432 let mut wires = patch.taps(model, &node.inputs)?;
433 wires[ix] = patch.tap_model(model, prec.inputs[0])?;
434 let wire = patch.wire_node(&node.name, op.clone(), &wires)?[0];
435 patch.shunt_outside(model, node.id.into(), wire)?;
436 return Ok(Some(patch));
437 }
438 }
439 Ok(None)
440}