Skip to main content

// tract_core/ops/matmul/pack.rs

1use crate::axes::Axis;
2use crate::internal::*;
3use ndarray::*;
4use tract_linalg::block_quant::{BlockQuantFact, PackedBlockQuantFact, PackedBlockQuantFormat};
5use tract_linalg::mmm::MMMInputValue;
6use tract_linalg::pack::PackedFormat;
7
8use super::ModePicker;
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash)]
11pub struct OptMatMulPack {
12    pub(crate) packers: Vec<PackedFormat>,
13    pub(crate) mode_picker: ModePicker,
14    pub(crate) k_axis: usize,
15    pub(crate) mn_axis: usize,
16}
17
18impl Op for OptMatMulPack {
19    fn name(&self) -> StaticName {
20        "OptMatMulPack".into()
21    }
22
23    fn info(&self) -> TractResult<Vec<String>> {
24        Ok(vec![format!("{:?}. k axis: {}, mn axis: {}", self.packers, self.k_axis, self.mn_axis)])
25    }
26
27    op_as_typed_op!();
28    impl_op_same_as!();
29}
30
31impl EvalOp for OptMatMulPack {
32    fn is_stateless(&self) -> bool {
33        true
34    }
35
36    fn eval_with_session(
37        &self,
38        _node_id: usize,
39        session: &TurnState,
40        mut inputs: TVec<TValue>,
41    ) -> TractResult<TVec<TValue>> {
42        self.do_eval(session, inputs.remove(0))
43    }
44}
45
46impl TypedOp for OptMatMulPack {
47    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
48        match self.mode_picker {
49            ModePicker::Single => ensure!(self.packers.len() == 1),
50            ModePicker::VecVsMat => ensure!(self.packers.len() == 2),
51        }
52        let k = inputs[0].shape[self.k_axis].clone();
53        let mn = inputs[0].shape[self.mn_axis].clone();
54        let opaque_fact = DynPackedOpaqueFact { k, mn, packers: self.packers.clone() };
55        Ok(tvec!(
56            Opaque::datum_type()
57                .fact(self.output_shape(&inputs[0].shape))
58                .with_opaque_fact(opaque_fact)
59        ))
60    }
61
62    fn axes_mapping(
63        &self,
64        inputs: &[&TypedFact],
65        outputs: &[&TypedFact],
66    ) -> TractResult<AxesMapping> {
67        let mut axes: Vec<Axis> = (0..inputs[0].rank())
68            .filter(|&ix| ix != self.k_axis && ix != self.mn_axis)
69            .enumerate()
70            .zip('a'..)
71            .map(|((o, i), repr)| Axis::new(repr, 1, 1).input(0, i).output(0, o))
72            .collect();
73        axes.push(Axis::new('K', 1, 1).input(0, self.k_axis));
74        axes.push(Axis::new('M', 1, 1).input(0, self.mn_axis));
75        axes.push(Axis::new('P', 1, 1).output(0, outputs[0].rank()));
76        AxesMapping::new(1, 1, axes)
77    }
78
79    as_op!();
80}
81
82impl OptMatMulPack {
83    fn do_eval(&self, _session: &TurnState, input: TValue) -> TractResult<TVec<TValue>> {
84        unsafe {
85            let mode = self.mode_picker.pick(input.shape()[self.mn_axis])?;
86            let packer = &self.packers[mode];
87            let output_shape: TVec<usize> = self.output_shape(input.shape());
88            let stores = if output_shape.iter().all(|d| *d == 1) {
89                tensor0::<Opaque>(
90                    packer.pack_tensor_view(&input.view(), self.k_axis, self.mn_axis)?.into(),
91                )
92                .into_shape(&output_shape)?
93            } else {
94                let mut stores = Tensor::uninitialized_dt(Opaque::datum_type(), &output_shape)?;
95                let mut stores_dense = stores.try_as_dense_mut()?;
96                let mut stores_view = stores_dense.to_array_view_mut::<Opaque>()?;
97                let mut bc_shape: TVec<usize> = input.shape().into();
98                bc_shape[self.k_axis] = 1;
99                bc_shape[self.mn_axis] = 1;
100
101                for coord in indices(&*bc_shape) {
102                    let offset = coord
103                        .as_array_view()
104                        .iter()
105                        .zip(input.strides())
106                        .map(|(x, s)| *x as isize * s)
107                        .sum::<isize>()
108                        * input.datum_type().size_of() as isize;
109                    let mut pack_coords: TVec<usize> = coord.slice().into();
110                    pack_coords.remove(self.k_axis.max(self.mn_axis));
111                    pack_coords.remove(self.k_axis.min(self.mn_axis));
112                    stores_view[&*pack_coords] = packer
113                        .pack_tensor_view(
114                            &TensorView::from_bytes(&input, offset, input.shape(), input.strides()),
115                            self.k_axis,
116                            self.mn_axis,
117                        )?
118                        .into();
119                }
120                stores
121            };
122            Ok(tvec!(stores.into_tvalue()))
123        }
124    }
125
126    pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
127        let mut packed_shape: TVec<D> = input.into();
128        packed_shape.remove(self.mn_axis.max(self.k_axis));
129        packed_shape.remove(self.mn_axis.min(self.k_axis));
130        packed_shape
131    }
132}
133
134#[derive(Hash, Clone, Debug, PartialEq, Eq)]
135pub struct DynPackedOpaqueFact {
136    pub k: TDim,
137    pub mn: TDim,
138    pub packers: Vec<PackedFormat>,
139}
140
141impl OpaqueFact for DynPackedOpaqueFact {
142    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
143        other.downcast_ref::<Self>().is_some_and(|o| o == self)
144    }
145
146    fn buffer_sizes(&self) -> TVec<TDim> {
147        tvec!(self.k.clone() * &self.mn * self.packers[0].dt.size_of())
148    }
149}
150
151#[derive(Debug, Clone, Hash, Eq, PartialEq)]
152pub struct OptSimpleMatMulPack {
153    pub(crate) packed_format: PackedBlockQuantFormat,
154    pub(crate) k: usize,
155    pub(crate) m: usize,
156}
157
158impl Op for OptSimpleMatMulPack {
159    fn name(&self) -> StaticName {
160        "OptSimpleMatMulPack".into()
161    }
162    op_as_typed_op!();
163    impl_op_same_as!();
164}
165
166impl EvalOp for OptSimpleMatMulPack {
167    fn is_stateless(&self) -> bool {
168        true
169    }
170
171    fn state(
172        &self,
173        _session: &TurnState,
174        _node_id: usize,
175    ) -> TractResult<Option<Box<dyn OpState>>> {
176        Ok(None)
177    }
178
179    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
180        let input = args_1!(inputs);
181        let mut output = tensor1(
182            &input
183                .try_as_dense()?
184                .as_slice::<Opaque>()?
185                .iter()
186                .map(|i| {
187                    let i = i.downcast_ref::<BlobWithFact>().context("Expected BlockWithFact")?;
188                    let i_bqf = i
189                        .fact
190                        .downcast_ref::<BlockQuantFact>()
191                        .context("Expected BlockQuantFact")?;
192                    let iv: Box<dyn MMMInputValue> =
193                        Box::new(self.packed_format.pack(&i.value, i_bqf.k())?);
194                    Ok(Opaque(Arc::new(iv)))
195                })
196                .collect::<TractResult<Vec<_>>>()?,
197        );
198        output.set_shape(input.shape())?;
199        Ok(tvec!(output.into_tvalue()))
200    }
201}
202
203impl TypedOp for OptSimpleMatMulPack {
204    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
205        let fact = Opaque::fact(inputs[0].shape.clone()).with_opaque_fact(PackedBlockQuantFact {
206            format: self.packed_format.clone(),
207            shape: tvec!(self.m, self.k),
208        });
209        Ok(tvec!(fact))
210    }
211
212    as_op!();
213}