Skip to main content

tract_core/ops/matmul/
pack.rs

1use crate::axes::Axis;
2use crate::internal::*;
3use ndarray::*;
4use tract_linalg::WeightType;
5use tract_linalg::block_quant::{
6    BlockQuantStorage, PackedBlockQuantFact, PackedBlockQuantFormat, block_quant_slice,
7};
8use tract_linalg::mmm::{MMMInputFormat, MMMInputValue, PackedMatrixStorage};
9use tract_linalg::pack::{PackedFormat, PackedI8K4};
10
11use super::ModePicker;
12
13// Pack one (possibly strided) view with a dynamic packing format. Keeps the
14// PackedFormat fast path byte-identical; routes the K=4-inner SMOPA packer
15// (PackedI8K4) through its view packer. Other formats are unsupported here.
16fn pack_view_with(
17    packer: &dyn MMMInputFormat,
18    t: &TensorView,
19    k_axis: usize,
20    mn_axis: usize,
21) -> TractResult<Box<dyn MMMInputValue>> {
22    if let Some(pf) = packer.downcast_ref::<PackedFormat>() {
23        pf.pack_tensor_view(t, k_axis, mn_axis)
24    } else if let Some(p4) = packer.downcast_ref::<PackedI8K4>() {
25        p4.pack_view(t, k_axis, mn_axis)
26    } else {
27        bail!("OptMatMulPack does not support packing format {packer:?}")
28    }
29}
30
31#[derive(Debug, Clone, PartialEq, Eq, Hash)]
32pub struct OptMatMulPack {
33    pub(crate) packers: Vec<Box<dyn MMMInputFormat>>,
34    pub(crate) mode_picker: ModePicker,
35    pub(crate) k_axis: usize,
36    pub(crate) mn_axis: usize,
37}
38
39impl Op for OptMatMulPack {
40    fn name(&self) -> StaticName {
41        "OptMatMulPack".into()
42    }
43
44    fn info(&self) -> TractResult<Vec<String>> {
45        Ok(vec![format!("{:?}. k axis: {}, mn axis: {}", self.packers, self.k_axis, self.mn_axis)])
46    }
47
48    op_as_typed_op!();
49}
50
51impl EvalOp for OptMatMulPack {
52    fn is_stateless(&self) -> bool {
53        true
54    }
55
56    fn eval_with_session(
57        &self,
58        _node_id: usize,
59        session: &TurnState,
60        mut inputs: TVec<TValue>,
61    ) -> TractResult<TVec<TValue>> {
62        self.do_eval(session, inputs.remove(0))
63    }
64}
65
66impl TypedOp for OptMatMulPack {
67    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
68        match self.mode_picker {
69            ModePicker::Single => ensure!(self.packers.len() == 1),
70            ModePicker::VecVsMat => ensure!(self.packers.len() == 2),
71        }
72        let k = inputs[0].shape[self.k_axis].clone();
73        let mn = inputs[0].shape[self.mn_axis].clone();
74        let exotic_fact = DynPackedExoticFact { k, mn, packers: self.packers.clone() };
75        Ok(tvec!(
76            inputs[0]
77                .datum_type
78                .fact(self.output_shape(&inputs[0].shape))
79                .with_exotic_fact(exotic_fact)
80        ))
81    }
82
83    fn axes_mapping(
84        &self,
85        inputs: &[&TypedFact],
86        outputs: &[&TypedFact],
87    ) -> TractResult<AxesMapping> {
88        let mut axes: Vec<Axis> = (0..inputs[0].rank())
89            .filter(|&ix| ix != self.k_axis && ix != self.mn_axis)
90            .enumerate()
91            .zip('a'..)
92            .map(|((o, i), repr)| Axis::new(repr, 1, 1).input(0, i).output(0, o))
93            .collect();
94        axes.push(Axis::new('K', 1, 1).input(0, self.k_axis));
95        axes.push(Axis::new('M', 1, 1).input(0, self.mn_axis));
96        axes.push(Axis::new('P', 1, 1).output(0, outputs[0].rank()));
97        AxesMapping::new(1, 1, axes)
98    }
99
100    as_op!();
101}
102
103impl OptMatMulPack {
104    fn do_eval(&self, _session: &TurnState, input: TValue) -> TractResult<TVec<TValue>> {
105        unsafe {
106            let mode = self.mode_picker.pick(input.shape()[self.mn_axis])?;
107            let packer = &self.packers[mode];
108            let output_shape: TVec<usize> = self.output_shape(input.shape());
109            let stores = if output_shape.iter().all(|d| *d == 1) {
110                let packed = pack_view_with(&**packer, &input.view(), self.k_axis, self.mn_axis)?;
111                PackedMatrixStorage::new_batched(&output_shape, vec![packed])
112                    .into_tensor(input.datum_type())
113            } else {
114                let mut bc_shape: TVec<usize> = input.shape().into();
115                bc_shape[self.k_axis] = 1;
116                bc_shape[self.mn_axis] = 1;
117
118                let mut values: Vec<Box<dyn MMMInputValue>> =
119                    Vec::with_capacity(output_shape.iter().product());
120                for coord in indices(&*bc_shape) {
121                    let offset = coord
122                        .as_array_view()
123                        .iter()
124                        .zip(input.strides())
125                        .map(|(x, s)| *x as isize * s)
126                        .sum::<isize>()
127                        * input.datum_type().size_of() as isize;
128                    values.push(pack_view_with(
129                        &**packer,
130                        &TensorView::from_bytes(&input, offset, input.shape(), input.strides()),
131                        self.k_axis,
132                        self.mn_axis,
133                    )?);
134                }
135                PackedMatrixStorage::new_batched(&output_shape, values)
136                    .into_tensor(input.datum_type())
137            };
138            Ok(tvec!(stores.into_tvalue()))
139        }
140    }
141
142    pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
143        let mut packed_shape: TVec<D> = input.into();
144        packed_shape.remove(self.mn_axis.max(self.k_axis));
145        packed_shape.remove(self.mn_axis.min(self.k_axis));
146        packed_shape
147    }
148}
149
150#[derive(Hash, Clone, Debug, PartialEq, Eq)]
151pub struct DynPackedExoticFact {
152    pub k: TDim,
153    pub mn: TDim,
154    pub packers: Vec<Box<dyn MMMInputFormat>>,
155}
156
157impl ExoticFact for DynPackedExoticFact {
158    fn buffer_sizes(&self) -> TVec<TDim> {
159        let elem_bytes = match self.packers[0].precursor() {
160            WeightType::Plain(dt) => dt.size_of(),
161            // OptMatMulPack only ever carries plain (PackedFormat / PackedI8K4) packers.
162            WeightType::BlockQuant(_) => 1,
163        };
164        tvec!(self.k.clone() * &self.mn * elem_bytes)
165    }
166}
167
168#[derive(Debug, Clone, Hash, Eq, PartialEq)]
169pub struct OptSimpleMatMulPack {
170    pub(crate) packed_format: PackedBlockQuantFormat,
171    pub(crate) k: usize,
172    pub(crate) m: usize,
173}
174
175impl Op for OptSimpleMatMulPack {
176    fn name(&self) -> StaticName {
177        "OptSimpleMatMulPack".into()
178    }
179    op_as_typed_op!();
180}
181
182impl EvalOp for OptSimpleMatMulPack {
183    fn is_stateless(&self) -> bool {
184        true
185    }
186
187    fn state(
188        &self,
189        _session: &TurnState,
190        _node_id: usize,
191    ) -> TractResult<Option<Box<dyn OpState>>> {
192        Ok(None)
193    }
194
195    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
196        let input = args_1!(inputs);
197        let bqs = input.try_storage_as::<BlockQuantStorage>()?;
198        // Leading dims before the last 2 (M, K) are batch/group dims
199        let num_groups: usize = input.shape()[..input.rank().saturating_sub(2)].iter().product();
200        let m_per_group = input.shape()[input.rank() - 2];
201        let k = *input.shape().last().unwrap();
202        let values = (0..num_groups)
203            .map(|g| {
204                let slice = block_quant_slice(bqs.value(), bqs.format(), m_per_group, k, g);
205                let iv: Box<dyn MMMInputValue> = Box::new(self.packed_format.pack(slice, k)?);
206                Ok(iv)
207            })
208            .collect::<TractResult<Vec<_>>>()?;
209        let leading_shape = &input.shape()[..input.rank().saturating_sub(2)];
210        let output =
211            PackedMatrixStorage::new_batched(leading_shape, values).into_tensor(input.datum_type());
212        Ok(tvec!(output.into_tvalue()))
213    }
214}
215
216impl TypedOp for OptSimpleMatMulPack {
217    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
218        let input = inputs[0];
219        // Input shape is [G, M, K] — output removes M and K, keeping leading dims
220        let output_shape: TVec<TDim> = if input.rank() > 2 {
221            input.shape[..input.rank() - 2].to_vec().into()
222        } else {
223            tvec!()
224        };
225        let fact =
226            inputs[0].datum_type.fact(&*output_shape).with_exotic_fact(PackedBlockQuantFact {
227                format: self.packed_format.clone(),
228                shape: tvec!(self.m, self.k),
229            });
230        Ok(tvec!(fact))
231    }
232
233    as_op!();
234}