1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::ops::array::MultiBroadcastTo;
6use crate::tract_data::itertools::Itertools;
7
8mod eval;
9
10#[cfg(feature = "blas")]
11pub mod as_blas;
12pub mod einsum_matmul;
13pub mod kernel_selection;
14pub mod prefix_matmul;
15
16#[cfg(test)]
17mod proptest;
18
19use num_traits::One;
20use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact};
21use tract_linalg::mmm::PackedOpaqueFact;
22
23pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<'_, [TDim]>> {
24 if !fact.datum_type.is_opaque() {
25 return Ok(Cow::Borrowed(&*fact.shape));
26 }
27 let Some(opaque_fact) = fact.opaque_fact() else {
28 bail!("Datum fact is opaque, but no opaque fact was found.")
29 };
30 if let Some(bqf) = opaque_fact.downcast_ref::<BlockQuantFact>() {
31 Ok(Cow::Owned(
32 fact.shape.iter().cloned().chain(bqf.shape().iter().map(|d| d.to_dim())).collect_vec(),
33 ))
34 } else if let Some(pof) = opaque_fact.downcast_ref::<PackedBlockQuantFact>() {
35 Ok(Cow::Owned(
36 fact.shape.iter().cloned().chain(pof.shape.iter().map(|i| i.to_dim())).collect_vec(),
37 ))
38 } else if let Some(pof) = opaque_fact.downcast_ref::<PackedOpaqueFact>() {
39 Ok(Cow::Owned(
40 fact.shape.iter().cloned().chain([pof.mn.clone(), pof.k.to_dim()]).collect_vec(),
41 ))
42 } else {
43 bail!("Unsupported opaque fact {opaque_fact:?}")
44 }
45}
46
/// Generalized tensor contraction (Einstein summation) operator.
#[derive(Clone, Hash, PartialEq)]
pub struct EinSum {
    /// Axis labels mapping between inputs and the single output.
    pub axes: AxesMapping,
    /// Datum type used for the internal computation / accumulation.
    pub operating_dt: DatumType,
    // If present, this is the quantized form: it takes 9 inputs (see
    // `output_facts`/`codegen` which enforce the count), and the value is the
    // output datum type. NOTE(review): exact layout of the 7 quantization
    // parameter inputs is defined by `eval::eval_q` — confirm there.
    pub q_params: Option<DatumType>,
}
55
impl EinSum {
    /// Plain einsum accumulating and producing `operating_dt`.
    pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: None }
    }

    /// Quantized einsum: accumulate in `operating_dt`, output `output_type`.
    pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: Some(output_type) }
    }

    /// Input shapes as the einsum expression sees them, with block-quantized
    /// opaque payloads expanded into trailing dimensions.
    ///
    /// Validates input count and per-input rank against `self.axes`.
    pub fn actual_input_shapes_from_facts<'m>(
        &self,
        inputs: &'m [impl Borrow<TypedFact>],
    ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
        ensure!(inputs.len() == self.axes.input_count());
        let shapes: TVec<Cow<[TDim]>> = inputs
            .iter()
            .map(|t| block_quant_aware_input_shape(t.borrow()))
            .collect::<TractResult<_>>()?;
        ensure!(shapes
            .iter()
            .enumerate()
            .all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix))));
        Ok(shapes)
    }

    /// Make the axis designated by `(io, axis)` present on every input and on
    /// the output: inputs lacking it get an `AxisOp::Add` wired in front, and
    /// if the axis was absent from the output, an `AxisOp::Rm` is wired after
    /// the new einsum so downstream nodes see an unchanged shape.
    ///
    /// Returns `Ok(None)` when the axis occurs more than once on some input
    /// (propagation would be ambiguous).
    #[allow(unused_variables)]
    pub(crate) fn propagate_axis(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        axis: usize,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut new_axis = self.axes.axis((io, axis))?.clone();
        let repr = new_axis.repr;
        let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
        let mut taps = tvec!();
        for (ix, input) in node.inputs.iter().enumerate() {
            let mut tap = patch.tap_model(model, *input)?;
            if new_axis.inputs[ix].len() > 1 {
                // Axis appears several times on this input: give up.
                return Ok(None);
            } else if new_axis.inputs[ix].is_empty() {
                // Axis absent from this input: append it as a trailing axis.
                let insert_at = self.axes.rank(InOut::In(ix));
                tap = patch.wire_node(
                    format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
                    AxisOp::Add(insert_at),
                    &[tap],
                )?[0];
                new_axis.inputs[ix].push(insert_at);
            }
            taps.push(tap);
        }
        // If the axis is absent from the output, add it to the expression now
        // and remember to strip it from the wire afterwards.
        let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
            let insert_at = self.axes.rank(InOut::Out(0));
            new_axis.outputs[0].push(insert_at);
            Some(insert_at)
        } else {
            None
        };
        // Rebuild the mapping, substituting the updated occurrence lists for
        // this axis and keeping every other axis untouched.
        let new_expr = self
            .axes
            .iter_all_axes()
            .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
            .collect_vec();
        let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
        let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
        if let Some(position) = must_rm_axis {
            wire = patch.wire_node(
                format!("{}.prop_axis.{}.output", &node.name, repr),
                AxisOp::Rm(position),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

    /// Accumulator types acceptable for this einsum: i32 for integer types,
    /// f16 or f32 for f16, otherwise the operating type itself.
    pub fn acceptable_accumulators(&self) -> TVec<DatumType> {
        if self.operating_dt.is_integer() {
            tvec!(i32::datum_type())
        } else if self.operating_dt == f16::datum_type() {
            tvec!(f16::datum_type(), f32::datum_type())
        } else {
            tvec!(self.operating_dt)
        }
    }
}
143
144impl Debug for EinSum {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
147 }
148}
149
150impl Op for EinSum {
151 fn name(&self) -> StaticName {
152 "EinSum".into()
153 }
154
155 fn info(&self) -> TractResult<Vec<String>> {
156 let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
157 if let Some(qp) = self.q_params {
158 info.push(format!("Quantized output: {qp:?}"));
159 }
160 Ok(info)
161 }
162
163 op_as_typed_op!();
164}
165
impl EvalOp for EinSum {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Quantized form has a dedicated evaluator; otherwise dispatch on the
        // operating datum type to the generic numeric evaluator.
        let output = if let Some(qp) = self.q_params {
            eval::eval_q(&self.axes, qp, inputs)
        } else {
            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
        }?;
        Ok(tvec!(output.into_tvalue()))
    }
}
180
impl TypedOp for EinSum {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let shapes = self.actual_input_shapes_from_facts(inputs)?;
        for i in 0..inputs.len() {
            ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
        }
        // Every axis must broadcast consistently across all inputs where it
        // appears (all occurrences reduce to a single broadcast dimension).
        for axis in self.axes.iter_all_axes() {
            assert!(shapes
                .iter()
                .enumerate()
                .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
                .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
                .is_ok());
        }
        if let Some(qp) = self.q_params {
            // Quantized form: 9 inputs; only the two data tensors drive the
            // output shape, and the output takes the quantized datum type.
            ensure!(inputs.len() == 9);
            Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
        } else {
            Ok(tvec!(TypedFact::dt_shape(
                self.operating_dt,
                eval::output_shape(&self.axes, &shapes)?
            )))
        }
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut axes = self.axes.clone();
        for (slot, i) in inputs.iter().enumerate() {
            // Packed/block-quantized inputs carry two extra trailing axes in
            // the mapping that are not visible in the tensor's own rank:
            // remove both occurrences so the mapping matches the visible fact.
            if i.datum_type.is_opaque()
                && (i.opaque_fact().is_some_and(|of| {
                    of.is::<BlockQuantFact>()
                        || of.is::<PackedOpaqueFact>()
                        || of.is::<PackedBlockQuantFact>()
                }))
            {
                axes = axes
                    .remove_axis_occurency(InOut::In(slot), i.rank())?
                    .remove_axis_occurency(InOut::In(slot), i.rank())?;
            }
        }
        Ok(axes)
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shapes = self.actual_input_shapes_from_facts(inputs)?;
        let oshape = eval::output_shape(&self.axes, &shapes)?;
        // FMA count: output volume times the product of the reduced (summed)
        // axes' extents. For each reduced axis, pick the first non-one extent
        // found among its input occurrences (they broadcast together).
        let ks = self
            .axes
            .iter_all_axes()
            .filter(|axis| axis.outputs[0].len() == 0)
            .map(|axis| {
                axis.inputs
                    .iter()
                    .enumerate()
                    .flat_map(|(ix, axes)| {
                        axes.iter()
                            .map(|axis| shapes[ix][*axis].clone())
                            .collect::<TVec<_>>()
                            .into_iter()
                    })
                    .find(|d| !d.is_one())
                    .unwrap_or_else(|| 1.to_dim())
            })
            .product::<TDim>();
        Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
    }

    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        model: &TypedModel,
        node: &TypedNode,
        prefix: &str,
        inputs: &[OutletId],
        output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        let facts = model.node_input_facts(node.id)?;
        let axis = self.axes.axis((InOut::Out(0), output_axis))?;
        // Refuse to slice through an axis backed by an opaque (packed/quantized)
        // input: the sliced inputs would not match the packed payload.
        if facts
            .iter()
            .enumerate()
            .any(|(slot, fact)| axis.inputs[slot].len() > 0 && fact.datum_type.is_opaque())
        {
            Ok(None)
        } else {
            // The einsum itself is slice-invariant: rewire it on sliced inputs.
            patch.wire_node(prefix, self.clone(), inputs).map(Some)
        }
    }

    #[allow(unused_variables)]
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        // Apply the axis change to the textual form of the affected interface,
        // then rebuild the mapping from strings. Reshape is not representable
        // as a label edit, hence the fall-through to Ok(None).
        let (mut inputs, mut outputs) = self.axes.to_strs();
        let interface: &mut String = match io {
            InOut::In(i) => &mut inputs[i],
            InOut::Out(o) => &mut outputs[o],
        };
        let mut axes: Vec<char> = interface.chars().collect();
        match change {
            AxisOp::Rm(rm) => {
                axes.remove(*rm);
            }
            AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
            AxisOp::Move(from, to) => {
                let c = axes.remove(*from);
                axes.insert(*to, c);
            }
            _ => {
                return Ok(None);
            }
        };
        *interface = axes.into_iter().collect();
        let axes = AxesMapping::from_strs(&inputs, &outputs)?;
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
            wire_changes: tvec!((io, change.clone())),
        }))
    }

    fn declutter_with_session(
        &self,
        session: &mut crate::optim::OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Try the reshape-folding rewrite first, then broadcast elision.
        if let Some(patch) = declutter_reshape_folding_input_axis(self, session, model, node)? {
            return Ok(Some(patch));
        }
        if let Some(patch) = declutter_broadcast(self, session, model, node)? {
            return Ok(Some(patch));
        }
        Ok(None)
    }

    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Only binary einsums (or the 9-input quantized form) can be lowered
        // to a matmul kernel.
        if (self.q_params.is_none() && node.inputs.len() != 2)
            || (self.q_params.is_some() && node.inputs.len() != 9)
        {
            return Ok(None);
        }
        einsum_matmul::detect_rule(&(), model, node, &node.name, self)
    }

    as_op!();
}
341
/// Declutter rule: when an input comes from a `Reshape` that folds several
/// axes into one, pull the reshape through the einsum — unfold the axis on
/// every input where it appears, run the einsum on the unfolded axes, and
/// re-fold on the output if the axis survives there.
///
/// Bails out (returns `Ok(None)`) when the folded axis occurs more than once
/// on an input or on the output, as the unfolding would be ambiguous.
fn declutter_reshape_folding_input_axis(
    op: &EinSum,
    _session: &mut crate::optim::OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (slot, prec) in node.inputs.iter().map(|n| model.node(n.node)).enumerate() {
        // Only consider predecessors reshaping `from` (several axes) into a
        // single axis at position `at`.
        let Some(&AxisOp::Reshape(at, ref from, ref to)) = prec.op_as() else { continue };
        if to.len() > 1 {
            continue;
        }
        let mut axes = op.axes.clone();
        // One fresh label per extra axis created by unfolding.
        // NOTE(review): assumes `from` is non-empty — `from.len() - 1` would
        // underflow otherwise; presumably guaranteed by Reshape's invariants.
        let extra_labels = axes.available_labels().take(from.len() - 1).collect_vec();
        // Register the fresh labels through a phantom extra input, removed
        // again once every real occurrence has been declared.
        let extra_input = node.inputs.len();
        axes = axes.with_extra_input(extra_input)?;
        for label in &extra_labels {
            axes = axes.with_extra_axis(*label, InOut::In(extra_input), 0)?;
        }
        let folded_axis = op.axes.axis((InOut::In(slot), at))?;
        if folded_axis.outputs[0].len() > 1 {
            return Ok(None);
        };
        let mut patch = TypedModelPatch::default();
        let mut taps = patch.taps(model, &node.inputs)?;
        for (input, tap) in taps.iter_mut().enumerate() {
            if folded_axis.inputs[input].len() == 0 {
                continue;
            };
            if folded_axis.inputs[input].len() > 1 {
                return Ok(None);
            };
            let pos = folded_axis.inputs[input][0];
            // Declare the unfolded axes on this input...
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::In(input), pos)?;
            }
            // ...and wire the inverse reshape (unfold) in front of it.
            *tap = patch.wire_node(
                format!("{}.reshape_folded_input_{}", node.name, input),
                AxisOp::Reshape(pos, to.clone(), from.clone()),
                &[*tap],
            )?[0];
        }
        if folded_axis.outputs[0].len() == 1 {
            let pos = folded_axis.outputs[0][0];
            for label in &extra_labels {
                axes = axes.with_extra_axis_occurency(*label, InOut::Out(0), pos)?;
            }
        }
        // Drop the phantom input now that the labels exist on real interfaces.
        axes = axes.remove_slot(InOut::In(extra_input))?;
        let mut wire = patch.wire_node(&node.name, EinSum { axes, ..op.clone() }, &taps)?;
        if folded_axis.outputs[0].len() == 1 {
            // Re-fold the axis on the output to preserve the node's shape.
            let pos = folded_axis.outputs[0][0];
            wire = patch.wire_node(
                format!("{}.reshape_folded_output", node.name),
                AxisOp::Reshape(pos, from.clone(), to.clone()),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        return Ok(Some(patch));
    }
    Ok(None)
}
405
406fn declutter_broadcast(
407 op: &EinSum,
408 _session: &mut crate::optim::OptimizerSession,
409 model: &TypedModel,
410 node: &TypedNode,
411) -> TractResult<Option<TypedModelPatch>> {
412 for (ix, outlet) in node.inputs.iter().enumerate() {
413 let prec = model.node(outlet.node);
414 if prec.op_is::<MultiBroadcastTo>() && prec.outputs[0].successors.len() == 1 {
415 let mut patch = TypedModelPatch::default();
416 let mut wires = patch.taps(model, &node.inputs)?;
417 wires[ix] = patch.tap_model(model, prec.inputs[0])?;
418 let wire = patch.wire_node(&node.name, op.clone(), &wires)?[0];
419 patch.shunt_outside(model, node.id.into(), wire)?;
420 return Ok(Some(patch));
421 }
422 }
423 Ok(None)
424}