1use crate::axes::Axis;
2use crate::internal::*;
3use ndarray::*;
4use tract_linalg::block_quant::{
5 BlockQuantStorage, PackedBlockQuantFact, PackedBlockQuantFormat, block_quant_slice,
6};
7use tract_linalg::mmm::{MMMInputValue, PackedMatrixStorage};
8use tract_linalg::pack::PackedFormat;
9
10use super::ModePicker;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash)]
13pub struct OptMatMulPack {
14 pub(crate) packers: Vec<PackedFormat>,
15 pub(crate) mode_picker: ModePicker,
16 pub(crate) k_axis: usize,
17 pub(crate) mn_axis: usize,
18}
19
20impl Op for OptMatMulPack {
21 fn name(&self) -> StaticName {
22 "OptMatMulPack".into()
23 }
24
25 fn info(&self) -> TractResult<Vec<String>> {
26 Ok(vec![format!("{:?}. k axis: {}, mn axis: {}", self.packers, self.k_axis, self.mn_axis)])
27 }
28
29 op_as_typed_op!();
30}
31
32impl EvalOp for OptMatMulPack {
33 fn is_stateless(&self) -> bool {
34 true
35 }
36
37 fn eval_with_session(
38 &self,
39 _node_id: usize,
40 session: &TurnState,
41 mut inputs: TVec<TValue>,
42 ) -> TractResult<TVec<TValue>> {
43 self.do_eval(session, inputs.remove(0))
44 }
45}
46
47impl TypedOp for OptMatMulPack {
48 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
49 match self.mode_picker {
50 ModePicker::Single => ensure!(self.packers.len() == 1),
51 ModePicker::VecVsMat => ensure!(self.packers.len() == 2),
52 }
53 let k = inputs[0].shape[self.k_axis].clone();
54 let mn = inputs[0].shape[self.mn_axis].clone();
55 let exotic_fact = DynPackedExoticFact { k, mn, packers: self.packers.clone() };
56 Ok(tvec!(
57 inputs[0]
58 .datum_type
59 .fact(self.output_shape(&inputs[0].shape))
60 .with_exotic_fact(exotic_fact)
61 ))
62 }
63
64 fn axes_mapping(
65 &self,
66 inputs: &[&TypedFact],
67 outputs: &[&TypedFact],
68 ) -> TractResult<AxesMapping> {
69 let mut axes: Vec<Axis> = (0..inputs[0].rank())
70 .filter(|&ix| ix != self.k_axis && ix != self.mn_axis)
71 .enumerate()
72 .zip('a'..)
73 .map(|((o, i), repr)| Axis::new(repr, 1, 1).input(0, i).output(0, o))
74 .collect();
75 axes.push(Axis::new('K', 1, 1).input(0, self.k_axis));
76 axes.push(Axis::new('M', 1, 1).input(0, self.mn_axis));
77 axes.push(Axis::new('P', 1, 1).output(0, outputs[0].rank()));
78 AxesMapping::new(1, 1, axes)
79 }
80
81 as_op!();
82}
83
84impl OptMatMulPack {
85 fn do_eval(&self, _session: &TurnState, input: TValue) -> TractResult<TVec<TValue>> {
86 unsafe {
87 let mode = self.mode_picker.pick(input.shape()[self.mn_axis])?;
88 let packer = &self.packers[mode];
89 let output_shape: TVec<usize> = self.output_shape(input.shape());
90 let stores = if output_shape.iter().all(|d| *d == 1) {
91 let packed = packer.pack_tensor_view(&input.view(), self.k_axis, self.mn_axis)?;
92 PackedMatrixStorage::new_batched(&output_shape, vec![packed])
93 .into_tensor(input.datum_type())
94 } else {
95 let mut bc_shape: TVec<usize> = input.shape().into();
96 bc_shape[self.k_axis] = 1;
97 bc_shape[self.mn_axis] = 1;
98
99 let mut values: Vec<Box<dyn MMMInputValue>> =
100 Vec::with_capacity(output_shape.iter().product());
101 for coord in indices(&*bc_shape) {
102 let offset = coord
103 .as_array_view()
104 .iter()
105 .zip(input.strides())
106 .map(|(x, s)| *x as isize * s)
107 .sum::<isize>()
108 * input.datum_type().size_of() as isize;
109 values.push(packer.pack_tensor_view(
110 &TensorView::from_bytes(&input, offset, input.shape(), input.strides()),
111 self.k_axis,
112 self.mn_axis,
113 )?);
114 }
115 PackedMatrixStorage::new_batched(&output_shape, values)
116 .into_tensor(input.datum_type())
117 };
118 Ok(tvec!(stores.into_tvalue()))
119 }
120 }
121
122 pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
123 let mut packed_shape: TVec<D> = input.into();
124 packed_shape.remove(self.mn_axis.max(self.k_axis));
125 packed_shape.remove(self.mn_axis.min(self.k_axis));
126 packed_shape
127 }
128}
129
130#[derive(Hash, Clone, Debug, PartialEq, Eq)]
131pub struct DynPackedExoticFact {
132 pub k: TDim,
133 pub mn: TDim,
134 pub packers: Vec<PackedFormat>,
135}
136
137impl ExoticFact for DynPackedExoticFact {
138 fn buffer_sizes(&self) -> TVec<TDim> {
139 tvec!(self.k.clone() * &self.mn * self.packers[0].dt.size_of())
140 }
141}
142
143#[derive(Debug, Clone, Hash, Eq, PartialEq)]
144pub struct OptSimpleMatMulPack {
145 pub(crate) packed_format: PackedBlockQuantFormat,
146 pub(crate) k: usize,
147 pub(crate) m: usize,
148}
149
150impl Op for OptSimpleMatMulPack {
151 fn name(&self) -> StaticName {
152 "OptSimpleMatMulPack".into()
153 }
154 op_as_typed_op!();
155}
156
157impl EvalOp for OptSimpleMatMulPack {
158 fn is_stateless(&self) -> bool {
159 true
160 }
161
162 fn state(
163 &self,
164 _session: &TurnState,
165 _node_id: usize,
166 ) -> TractResult<Option<Box<dyn OpState>>> {
167 Ok(None)
168 }
169
170 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
171 let input = args_1!(inputs);
172 let bqs = input.try_storage_as::<BlockQuantStorage>()?;
173 let num_groups: usize = input.shape()[..input.rank().saturating_sub(2)].iter().product();
175 let m_per_group = input.shape()[input.rank() - 2];
176 let k = *input.shape().last().unwrap();
177 let values = (0..num_groups)
178 .map(|g| {
179 let slice = block_quant_slice(bqs.value(), bqs.format(), m_per_group, k, g);
180 let iv: Box<dyn MMMInputValue> = Box::new(self.packed_format.pack(slice, k)?);
181 Ok(iv)
182 })
183 .collect::<TractResult<Vec<_>>>()?;
184 let leading_shape = &input.shape()[..input.rank().saturating_sub(2)];
185 let output =
186 PackedMatrixStorage::new_batched(leading_shape, values).into_tensor(input.datum_type());
187 Ok(tvec!(output.into_tvalue()))
188 }
189}
190
191impl TypedOp for OptSimpleMatMulPack {
192 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
193 let input = inputs[0];
194 let output_shape: TVec<TDim> = if input.rank() > 2 {
196 input.shape[..input.rank() - 2].to_vec().into()
197 } else {
198 tvec!()
199 };
200 let fact =
201 inputs[0].datum_type.fact(&*output_shape).with_exotic_fact(PackedBlockQuantFact {
202 format: self.packed_format.clone(),
203 shape: tvec!(self.m, self.k),
204 });
205 Ok(tvec!(fact))
206 }
207
208 as_op!();
209}