1use std::borrow::Borrow;
2use std::fmt::Debug;
3
4use crate::internal::*;
5use crate::ops::array::Slice;
6use crate::tract_data::itertools::Itertools;
7
8mod eval;
9
10#[cfg(feature = "blas")]
11pub mod as_blas;
12use super::array::TypedConcat;
13use super::math::add;
14mod as_matmul;
15pub mod kernel_selection;
16pub mod optimize;
17
18#[cfg(test)]
19mod proptest;
20
21pub use as_matmul::{rewrite_einsums_as_matmul, BasicMatMul};
22use tract_linalg::frame::block_quant::BlockQuantFact;
23use tract_linalg::mmm::PackedOpaqueFact;
24
25pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<[TDim]>> {
26 if !fact.datum_type.is_opaque() {
27 return Ok(Cow::Borrowed(&*fact.shape));
28 }
29 let Some(opaque_fact) = fact.opaque_fact.as_ref() else {
30 bail!("Datum fact is opaque, but no opaque fact was found.")
31 };
32 let inner_shape: Cow<[usize]> = if let Some(bqf) = opaque_fact.downcast_ref::<BlockQuantFact>()
33 {
34 Cow::Borrowed(&*bqf.shape)
35 } else if let Some(pof) = opaque_fact.downcast_ref::<PackedOpaqueFact>() {
38 Cow::Owned(vec![pof.mn, pof.k])
39 } else {
40 bail!("Unsupported opaque fact {opaque_fact:?}")
41 };
42 let shape: Vec<TDim> =
43 fact.shape.iter().cloned().chain(inner_shape.iter().map(|d| d.to_dim())).collect();
44 Ok(Cow::Owned(shape))
45}
46
/// Generalized Einstein-summation operator.
#[derive(Clone, Hash)]
pub struct EinSum {
    // Mapping between input axes and output axes (the einsum expression).
    pub axes: AxesMapping,
    // Datum type used for accumulation during evaluation.
    pub operating_dt: DatumType,
    // When set, the einsum is quantized: this is the output datum type, and
    // the node expects 9 inputs (the two operands plus quantization
    // parameters — presumably bias/zero-points/scales; confirm exact order
    // against eval::eval_q).
    pub q_params: Option<DatumType>,
}
55
impl EinSum {
    /// Plain (non-quantized) einsum.
    pub fn new(axes: AxesMapping, operating_dt: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: None }
    }

    /// Quantized einsum: accumulate in `operating_dt`, emit `output_type`.
    pub fn newq(axes: AxesMapping, operating_dt: DatumType, output_type: DatumType) -> EinSum {
        EinSum { axes, operating_dt, q_params: Some(output_type) }
    }

    /// Block-quant-aware shapes for all inputs (see
    /// `block_quant_aware_input_shape`), checked against the mapping: one
    /// shape per mapped input, each with the rank the mapping expects.
    pub fn actual_input_shapes_from_facts<'m>(
        &self,
        inputs: &'m [impl Borrow<TypedFact>],
    ) -> TractResult<TVec<Cow<'m, [TDim]>>> {
        ensure!(inputs.len() == self.axes.input_count());
        let shapes: TVec<Cow<[TDim]>> = inputs
            .iter()
            .map(|t| block_quant_aware_input_shape(t.borrow()))
            .collect::<TractResult<_>>()?;
        ensure!(shapes
            .iter()
            .enumerate()
            .all(|(ix, fact)| fact.len() == self.axes.rank(InOut::In(ix))));
        Ok(shapes)
    }

    /// Build a patch making axis `(io, axis)` present on every interface:
    /// inputs that miss it get a trailing `AxisOp::Add`; if the output missed
    /// it, the new einsum output carries it and it is removed afterwards with
    /// an `AxisOp::Rm`, so the node's visible output is unchanged.
    ///
    /// Returns `Ok(None)` when some input holds the axis more than once
    /// (propagation is not handled for that case).
    #[allow(unused_variables)]
    pub(crate) fn propagate_axis(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        axis: usize,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut new_axis = self.axes.axis((io, axis))?.clone();
        let repr = new_axis.repr;
        let mut patch = TypedModelPatch::new(format!("Propagate axis {}", new_axis.repr));
        let mut taps = tvec!();
        for (ix, input) in node.inputs.iter().enumerate() {
            let mut tap = patch.tap_model(model, *input)?;
            if new_axis.inputs[ix].len() > 1 {
                // Axis appears twice in this input: bail out of the rewrite.
                return Ok(None);
            } else if new_axis.inputs[ix].is_empty() {
                // Input lacks the axis: append it as a trailing unit axis.
                let insert_at = self.axes.rank(InOut::In(ix));
                tap = patch.wire_node(
                    format!("{}.prop_axis.{}.input_{}", &node.name, new_axis.repr, ix),
                    AxisOp::Add(insert_at),
                    &[tap],
                )?[0];
                new_axis.inputs[ix].push(insert_at);
            }
            taps.push(tap);
        }
        // If the output lacks the axis, add it (trailing) and remember to
        // strip it after the new einsum.
        let must_rm_axis: Option<usize> = if new_axis.outputs[0].len() == 0 {
            let insert_at = self.axes.rank(InOut::Out(0));
            new_axis.outputs[0].push(insert_at);
            Some(insert_at)
        } else {
            None
        };
        // Rebuild the mapping with the updated occurrence lists for this axis.
        let new_expr = self
            .axes
            .iter_all_axes()
            .map(|it| if it.repr == new_axis.repr { new_axis.clone() } else { it.clone() })
            .collect_vec();
        let axes = AxesMapping::new(node.inputs.len(), 1, new_expr)?;
        let mut wire = patch.wire_node(&node.name, Self { axes, ..self.clone() }, &taps)?;
        if let Some(position) = must_rm_axis {
            wire = patch.wire_node(
                format!("{}.prop_axis.{}.output", &node.name, repr),
                AxisOp::Rm(position),
                &wire,
            )?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

    /// Declutter rule: when an input comes from a `TypedConcat`, split the
    /// einsum into one einsum per concatenated chunk. Other inputs sharing the
    /// concat axis are sliced accordingly. The partial results are then either
    /// re-concatenated (axis survives in the output) or summed (axis is
    /// reduced away). Skipped for quantized einsums.
    #[allow(clippy::comparison_chain)]
    fn declutter_after_concat(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.q_params.is_some() {
            // Not applied to the quantized form.
            return Ok(None);
        }
        'outer: for (slot, input) in node.inputs.iter().enumerate() {
            let precursor = model.node(input.node);
            if let Some(concat) = precursor.op_as::<TypedConcat>() {
                let offsets = concat.offsets(&model.node_input_facts(precursor.id)?)?;
                let axis_info = self.axes.axis((InOut::In(slot), concat.axis))?;
                if axis_info.outputs[0].len() > 0 {
                    // Only handle the case where the concat axis is reduced or
                    // re-concatenated below; an axis occurring in the output
                    // here is left to other rules.
                    continue;
                }
                let mut patch = TypedModelPatch::new(format!(
                    "Split Einsum for concat on axis {}",
                    axis_info.repr
                ));
                // inputs[slot] holds either one wire (axis absent from that
                // input) or one wire per concat chunk.
                let mut inputs: TVec<TVec<OutletId>> = tvec!();
                for (slot, input) in node.inputs.iter().enumerate() {
                    let tap = patch.tap_model(model, *input)?;
                    if axis_info.inputs[slot].len() > 1 {
                        // Axis appears twice in an input: give up on this concat.
                        continue 'outer;
                    } else if axis_info.inputs[slot].len() == 1 {
                        // Slice this input along the concat axis, one slice
                        // per chunk, using the concat offsets as boundaries.
                        let mut slices = tvec!();
                        for (start, end) in offsets.iter().cloned().tuple_windows() {
                            let wire = patch.wire_node(
                                format!(
                                    "{}.concat-einsum-slice-{}.{}.{}..{}",
                                    node.name, axis_info.repr, slot, start, end
                                ),
                                Slice { axis: axis_info.inputs[slot][0], start, end },
                                &[tap],
                            )?;
                            slices.push(wire[0]);
                        }
                        inputs.push(slices);
                    } else {
                        // Axis absent from this input: same wire for every chunk.
                        inputs.push(tvec!(tap));
                    };
                }
                // One einsum per chunk; inputs without a per-chunk slice reuse
                // their single wire.
                let mut einsums = tvec!();
                for (ix, (start, end)) in offsets.iter().tuple_windows().enumerate() {
                    let mut einsum_inputs = tvec!();
                    for input_ix in 0..node.inputs.len() {
                        einsum_inputs
                            .push(inputs[input_ix].get(ix).cloned().unwrap_or(inputs[input_ix][0]));
                    }
                    let einsum = patch.wire_node(
                        format!(
                            "{}.concat-einsum-{}.{}..{}",
                            node.name, axis_info.repr, start, end
                        ),
                        self.clone(),
                        &einsum_inputs,
                    )?[0];
                    einsums.push(einsum);
                }
                let wire = if let Some(axis) = axis_info.outputs[0].first().cloned() {
                    // Axis present in the output: stitch chunks back together.
                    patch.wire_node(
                        format!("{}.concat-einsum-{}.concat", node.name, axis_info.repr),
                        TypedConcat { axis },
                        &einsums,
                    )?[0]
                } else {
                    // Axis reduced away: partial results sum up.
                    let mut wire = einsums[0];
                    for ix in 1..einsums.len() {
                        wire = patch.wire_node(
                            format!("{}.concat-einsum-{}.add-{}", node.name, axis_info.repr, ix),
                            add(),
                            &[wire, einsums[ix]],
                        )?[0]
                    }
                    wire
                };
                patch.shunt_outside(model, node.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}
221
222impl Debug for EinSum {
223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 write!(f, "EinSum {} ({:?})", self.axes, self.operating_dt)
225 }
226}
227
228impl Op for EinSum {
229 fn name(&self) -> Cow<str> {
230 "EinSum".into()
231 }
232
233 fn info(&self) -> TractResult<Vec<String>> {
234 let mut info = vec![format!("{} ({:?})", self.axes, self.operating_dt)];
235 if let Some(qp) = self.q_params {
236 info.push(format!("Quantized output: {qp:?}"));
237 }
238 Ok(info)
239 }
240
241 op_as_typed_op!();
242}
243
impl EvalOp for EinSum {
    fn is_stateless(&self) -> bool {
        true
    }

    // Evaluate the einsum: quantized path when `q_params` is set, otherwise
    // dispatch `eval_t` over the numeric accumulation type.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let output = if let Some(qp) = self.q_params {
            eval::eval_q(&self.axes, qp, inputs)
        } else {
            // Monomorphizes eval::eval_t for the concrete operating_dt.
            dispatch_numbers!(eval::eval_t(self.operating_dt)(&self.axes, inputs))
        }?;
        Ok(tvec!(output.into_tvalue()))
    }
}
258
259impl TypedOp for EinSum {
260 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
261 let shapes = self.actual_input_shapes_from_facts(inputs)?;
262 for i in 0..inputs.len() {
263 ensure!(shapes[i].len() == self.axes.rank(InOut::In(i)));
264 }
265 for axis in self.axes.iter_all_axes() {
266 assert!(shapes
267 .iter()
268 .enumerate()
269 .flat_map(|(slot, shape)| axis.inputs[slot].iter().map(|a| &shape[*a]))
270 .try_fold(TDim::one(), |a, b| TDim::broadcast(a, b.clone()))
271 .is_ok());
272 }
273 if let Some(qp) = self.q_params {
274 ensure!(inputs.len() == 9);
275 Ok(tvec!(qp.fact(eval::output_shape(&self.axes, &shapes[0..2])?)))
276 } else {
277 Ok(tvec!(TypedFact::dt_shape(
278 self.operating_dt,
279 eval::output_shape(&self.axes, &shapes)?
280 )))
281 }
282 }
283
284 fn axes_mapping(
285 &self,
286 inputs: &[&TypedFact],
287 _outputs: &[&TypedFact],
288 ) -> TractResult<AxesMapping> {
289 let mut axes = self.axes.clone();
290 for (slot, i) in inputs.iter().enumerate() {
291 if i.datum_type.is_opaque()
292 && i.opaque_fact.as_ref().is_some_and(|of| of.is::<BlockQuantFact>())
293 {
294 axes = axes
295 .remove_axis_occurency(InOut::In(slot), i.rank())?
296 .remove_axis_occurency(InOut::In(slot), i.rank())?;
297 }
298 }
299 Ok(axes)
300 }
301
302 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
303 let shapes = self.actual_input_shapes_from_facts(inputs)?;
304 let oshape = eval::output_shape(&self.axes, &shapes)?;
305 let ks = self
306 .axes
307 .iter_all_axes()
308 .filter(|axis| axis.outputs[0].len() == 0)
309 .map(|axis| {
310 axis.inputs
311 .iter()
312 .enumerate()
313 .flat_map(|(ix, axes)| {
314 axes.iter()
315 .map(|axis| shapes[ix][*axis].clone())
316 .collect::<TVec<_>>()
317 .into_iter()
318 })
319 .find(|d| !d.is_one())
320 .unwrap_or_else(|| 1.to_dim())
321 })
322 .product::<TDim>();
323 Ok(tvec!((Cost::FMA(self.operating_dt), oshape.iter().product::<TDim>() * ks)))
324 }
325
326 fn slice(
327 &self,
328 patch: &mut TypedModelPatch,
329 model: &TypedModel,
330 node: &TypedNode,
331 prefix: &str,
332 inputs: &[OutletId],
333 _output_axis: usize,
334 _start: usize,
335 _end: usize,
336 ) -> TractResult<Option<TVec<OutletId>>> {
337 if model.node_input_facts(node.id)?.iter().any(|f| f.datum_type.is_opaque()) {
338 Ok(None)
339 } else {
340 patch.wire_node(prefix, self.clone(), inputs).map(Some)
341 }
342 }
343
344 #[allow(unused_variables)]
345 fn change_axes(
346 &self,
347 model: &TypedModel,
348 node: &TypedNode,
349 io: InOut,
350 change: &AxisOp,
351 ) -> TractResult<Option<AxisChangeConsequence>> {
352 let (mut inputs, mut outputs) = self.axes.to_strs();
353 let interface: &mut String = match io {
354 InOut::In(i) => &mut inputs[i],
355 InOut::Out(o) => &mut outputs[o],
356 };
357 let mut axes: Vec<char> = interface.chars().collect();
358 match change {
359 AxisOp::Rm(rm) => {
360 axes.remove(*rm);
361 }
362 AxisOp::Add(add) => axes.insert(*add, self.axes.available_label()),
363 AxisOp::Move(from, to) => {
364 let c = axes.remove(*from);
365 axes.insert(*to, c);
366 }
367 _ => return Ok(None),
368 };
369 *interface = axes.into_iter().collect();
370 let axes = AxesMapping::from_strs(&inputs, &outputs)?;
371 Ok(Some(AxisChangeConsequence {
372 substitute_op: Some(Box::new(EinSum { axes, ..self.clone() })),
373 wire_changes: tvec!((io, change.clone())),
374 }))
375 }
376
377 fn declutter(
378 &self,
379 model: &TypedModel,
380 node: &TypedNode,
381 ) -> TractResult<Option<TypedModelPatch>> {
382 self.declutter_after_concat(model, node)
383 }
384
385 fn codegen(
386 &self,
387 model: &TypedModel,
388 node: &TypedNode,
389 ) -> TractResult<Option<TypedModelPatch>> {
390 optimize::optimize(self, model, node)
391 }
392
393 as_op!();
394}