1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::tract_data::itertools::Itertools;
6
7mod eval;
8
9#[cfg(feature = "blas")]
10pub mod as_blas;
11mod as_matmul;
12pub mod kernel_selection;
13pub mod optimize;
14
15#[cfg(test)]
16mod proptest;
17
18pub use as_matmul::{rewrite_einsums_as_matmul, BasicMatMul};
19use num_traits::One;
20use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact};
21use tract_linalg::mmm::PackedOpaqueFact;
22
23pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<[TDim]>> {
24 if !fact.datum_type.is_opaque() {
25 return Ok(Cow::Borrowed(&*fact.shape));
26 }
27 let Some(opaque_fact) = fact.opaque_fact() else {
28 bail!("Datum fact is opaque, but no opaque fact was found.")
29 };
30 if let Some(bqf) = opaque_fact.downcast_ref::<BlockQuantFact>() {
31 Ok(Cow::Owned(
32 fact.shape.iter().cloned().chain(bqf.shape().iter().map(|d| d.to_dim())).collect_vec(),
33 ))
34 } else if let Some(pof) = opaque_fact.downcast_ref::<PackedOpaqueFact>() {
37 Ok(Cow::Owned(
38 fact.shape.iter().cloned().chain([pof.mn.clone(), pof.k.to_dim()]).collect_vec(),
39 ))
40 } else {
41 bail!("Unsupported opaque fact {opaque_fact:?}")
42 }
43}
44
/// Generalized tensor contraction (Einstein summation) operator.
#[derive(Clone, Hash)]
pub struct EinSum {
    /// Axes mapping between the inputs and the single output
    /// (the einsum "expression").
    pub axes: AxesMapping,
    /// Datum type used to accumulate during the contraction.
    pub operating_dt: DatumType,
    /// `Some(output_type)` makes this a quantized einsum: the op then expects
    /// 9 inputs (the two operands plus quantization parameters) and emits
    /// `output_type`.
    pub q_params: Option<DatumType>,
}
53
54impl EinSum {
55 pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
56 EinSum { axes, operating_dt, q_params: None }
57 }
58
59 pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
60 EinSum { axes, operating_dt, q_params: Some(output_type) }
61 }
62
63 pub fn actual_input_shapes_from_facts<'m>(
64 &self,
65 inputs: &'m [impl Borrow<TypedFact>],
66 ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
67 ensure!(inputs.len() == self.axes.input_count());
68 let shapes: TVec<Cow<[TDim]>> = inputs
69 .iter()
70 .map(|t| block_quant_aware_input_shape(t.borrow()))
71 .collect::<TractResult<_>>()?;
72 ensure!(shapes
73 .iter()
74 .enumerate()
75 .all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix))));
76 Ok(shapes)
77 }
78
79 #[allow(unused_variables)]
80 pub(crate) fn propagate_axis(
81 &self,
82 model: &TypedModel,
83 node: &TypedNode,
84 io: InOut,
85 axis: usize,
86 ) -> TractResult<Option<TypedModelPatch>> {
87 let mut new_axis = self.axes.axis((io, axis))?.clone();
88 let repr = new_axis.repr;
89 let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
90 let mut taps = tvec!();
91 for (ix, input) in node.inputs.iter().enumerate() {
92 let mut tap = patch.tap_model(model, *input)?;
93 if new_axis.inputs[ix].len() > 1 {
94 return Ok(None); } else if new_axis.inputs[ix].is_empty() {
96 let insert_at = self.axes.rank(InOut::In(ix));
97 tap = patch.wire_node(
98 format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
99 AxisOp::Add(insert_at),
100 &[tap],
101 )?[0];
102 new_axis.inputs[ix].push(insert_at);
103 }
104 taps.push(tap);
105 }
106 let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
107 let insert_at = self.axes.rank(InOut::Out(0));
108 new_axis.outputs[0].push(insert_at);
109 Some(insert_at)
110 } else {
111 None
112 };
113 let new_expr = self
114 .axes
115 .iter_all_axes()
116 .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
117 .collect_vec();
118 let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
119 let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
120 if let Some(position) = must_rm_axis {
121 wire = patch.wire_node(
122 format!("{}.prop_axis.{}.output", &node.name, repr),
123 AxisOp::Rm(position),
124 &wire,
125 )?;
126 }
127 patch.shunt_outside(model, node.id.into(), wire[0])?;
128 Ok(Some(patch))
129 }
130
131 pub fn acceptable_accumulators(&self) -> TVec<DatumType> {
132 if self.operating_dt.is_integer() {
133 tvec!(i32::datum_type())
134 } else if self.operating_dt == f16::datum_type() {
135 tvec!(f16::datum_type(), f32::datum_type())
136 } else {
137 tvec!(self.operating_dt)
138 }
139 }
140}
141
142impl Debug for EinSum {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
145 }
146}
147
148impl Op for EinSum {
149 fn name(&self) -> Cow<str> {
150 "EinSum".into()
151 }
152
153 fn info(&self) -> TractResult<Vec<String>> {
154 let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
155 if let Some(qp) = self.q_params {
156 info.push(format!("Quantized output: {qp:?}"));
157 }
158 Ok(info)
159 }
160
161 op_as_typed_op!();
162}
163
164impl EvalOp for EinSum {
165 fn is_stateless(&self) -> bool {
166 true
167 }
168
169 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
170 let output = if let Some(qp) = self.q_params {
171 eval::eval_q(&self.axes, qp, inputs)
172 } else {
173 dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
174 }?;
175 Ok(tvec!(output.into_tvalue()))
176 }
177}
178
impl TypedOp for EinSum {
    /// Derives the output fact: checks every input rank against the axes
    /// mapping, checks that dims sharing an axis are broadcast-compatible,
    /// then computes the output shape.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let shapes = self.actual_input_shapes_from_facts(inputs)?;
        for i in 0..inputs.len() {
            ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
        }
        // All dims mapped to the same axis must broadcast together.
        // NOTE(review): this is an assert! (panic), unlike the ensure! checks
        // around it — presumably considered an internal invariant here.
        for axis in self.axes.iter_all_axes() {
            assert!(shapes
                .iter()
                .enumerate()
                .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
                .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
                .is_ok());
        }
        if let Some(qp) = self.q_params {
            // Quantized form: the two operands plus quantization parameter
            // inputs; only the operand shapes drive the output shape.
            ensure!(inputs.len() == 9);
            Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
        } else {
            Ok(tvec!(TypedFact::dt_shape(
                self.operating_dt,
                eval::output_shape(&self.axes, &shapes)?
            )))
        }
    }

    /// Returns the op's axes mapping, with the two trailing axis occurrences
    /// dropped for opaque (block-quantized or packed) inputs, whose inner
    /// dims are hidden in the opaque payload rather than the visible shape.
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut axes = self.axes.clone();
        for (slot, i) in inputs.iter().enumerate() {
            if i.datum_type.is_opaque()
                && (i.opaque_fact().is_some_and(|of| {
                    of.is::<BlockQuantFact>()
                        || of.is::<PackedOpaqueFact>()
                        || of.is::<PackedBlockQuantFact>()
                }))
            {
                // Remove two occurrences at the opaque tensor's visible rank
                // (the two hidden innermost axes).
                axes = axes
                    .remove_axis_occurency(InOut::In(slot), i.rank())?
                    .remove_axis_occurency(InOut::In(slot), i.rank())?;
            }
        }
        Ok(axes)
    }

    /// Cost model: FMA count = output size × product of the contracted
    /// ("k") axis dims, i.e. axes absent from the output.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shapes = self.actual_input_shapes_from_facts(inputs)?;
        let oshape = eval::output_shape(&self.axes, &shapes)?;
        let ks = self
            .axes
            .iter_all_axes()
            // Axes with no output occurrence are reduced over.
            .filter(|axis| axis.outputs[0].len() == 0)
            .map(|axis| {
                // Take the first non-trivial dim carrying this axis; fall
                // back to 1 if the axis is degenerate everywhere.
                axis.inputs
                    .iter()
                    .enumerate()
                    .flat_map(|(ix, axes)| {
                        axes.iter()
                            .map(|axis| shapes[ix][*axis].clone())
                            .collect::<TVec<_>>()
                            .into_iter()
                    })
                    .find(|d| !d.is_one())
                    .unwrap_or_else(|| 1.to_dim())
            })
            .product::<TDim>();
        Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
    }

    /// Declines slicing (returns `None`) when the sliced output axis maps
    /// onto any opaque input; otherwise the einsum is sliced by simply
    /// re-wiring it over the pre-sliced inputs.
    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        model: &TypedModel,
        node: &TypedNode,
        prefix: &str,
        inputs: &[OutletId],
        output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        let facts = model.node_input_facts(node.id)?;
        let axis = self.axes.axis((InOut::Out(0), output_axis))?;
        if facts
            .iter()
            .enumerate()
            .any(|(slot, fact)| axis.inputs[slot].len() > 0 && fact.datum_type.is_opaque())
        {
            Ok(None)
        } else {
            patch.wire_node(prefix, self.clone(), inputs).map(Some)
        }
    }

    /// Applies an axis change on one interface by editing the string form of
    /// the axes expression, then rebuilding the mapping. Only `Rm`, `Add` and
    /// `Move` are supported; other changes return `None`.
    #[allow(unused_variables)]
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let (mut inputs, mut outputs) = self.axes.to_strs();
        // Pick the one interface string the change applies to.
        let interface: &mut String = match io {
            InOut::In(i) => &mut inputs[i],
            InOut::Out(o) => &mut outputs[o],
        };
        let mut axes: Vec<char> = interface.chars().collect();
        match change {
            AxisOp::Rm(rm) => {
                axes.remove(*rm);
            }
            // A new axis gets a fresh, unused label.
            AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
            AxisOp::Move(from, to) => {
                let c = axes.remove(*from);
                axes.insert(*to, c);
            }
            _ => {
                return Ok(None);
            }
        };
        *interface = axes.into_iter().collect();
        let axes = AxesMapping::from_strs(&inputs, &outputs)?;
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
            wire_changes: tvec!((io, change.clone())),
        }))
    }

    /// Delegates codegen to the einsum optimizer, attaching the axes
    /// expression and input facts as error context.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        optimize::optimize(self, model, node).with_context(|| {
            format!(
                "axes: {} — inputs: {}",
                self.axes,
                model
                    .node_input_facts(node.id)
                    .unwrap()
                    .iter()
                    .map(|f| format!("{f:?}"))
                    .join(" • ")
            )
        })
    }

    as_op!();
}
342
343