tract_core/ops/matmul/
pack.rs1use crate::axes::Axis;
2use crate::internal::*;
3use ndarray::*;
4use tract_linalg::block_quant::{BlockQuantValue, PackedBlockQuantFact, PackedBlockQuantFormat};
5use tract_linalg::mmm::MMMInputValue;
6use tract_linalg::pack::PackedFormat;
7
8use super::ModePicker;
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash)]
11pub struct OptMatMulPack {
12 pub(crate) packers: Vec<PackedFormat>,
13 pub(crate) mode_picker: ModePicker,
14 pub(crate) k_axis: usize,
15 pub(crate) mn_axis: usize,
16}
17
18impl Op for OptMatMulPack {
19 fn name(&self) -> Cow<str> {
20 "OptMatMulPack".into()
21 }
22
23 fn info(&self) -> TractResult<Vec<String>> {
24 Ok(vec![format!("{:?}. k axis: {}, mn axis: {}", self.packers, self.k_axis, self.mn_axis)])
25 }
26
27 op_as_typed_op!();
28 impl_op_same_as!();
29}
30
31impl EvalOp for OptMatMulPack {
32 fn is_stateless(&self) -> bool {
33 true
34 }
35
36 fn eval_with_session(
37 &self,
38 session: &SessionState,
39 mut inputs: TVec<TValue>,
40 ) -> TractResult<TVec<TValue>> {
41 self.do_eval(session, inputs.remove(0))
42 }
43}
44
45impl TypedOp for OptMatMulPack {
46 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
47 let k = inputs[0].shape[self.k_axis].clone();
48 let mn = inputs[0].shape[self.mn_axis].clone();
49 let opaque_fact = DynPackedOpaqueFact { k, mn, packers: self.packers.clone() };
50 Ok(tvec!(Opaque::datum_type()
51 .fact(self.output_shape(&inputs[0].shape))
52 .with_opaque_fact(opaque_fact)))
53 }
54
55 fn axes_mapping(
56 &self,
57 inputs: &[&TypedFact],
58 outputs: &[&TypedFact],
59 ) -> TractResult<AxesMapping> {
60 let mut axes: Vec<Axis> = (0..inputs[0].rank())
61 .filter(|&ix| ix != self.k_axis && ix != self.mn_axis)
62 .enumerate()
63 .zip('a'..)
64 .map(|((o, i), repr)| Axis::new(repr, 1, 1).input(0, i).output(0, o))
65 .collect();
66 axes.push(Axis::new('K', 1, 1).input(0, self.k_axis));
67 axes.push(Axis::new('M', 1, 1).input(0, self.mn_axis));
68 axes.push(Axis::new('P', 1, 1).output(0, outputs[0].rank()));
69 AxesMapping::new(1, 1, axes)
70 }
71
72 as_op!();
73}
74
75impl OptMatMulPack {
76 fn do_eval(&self, _session: &SessionState, input: TValue) -> TractResult<TVec<TValue>> {
77 unsafe {
78 let mode = self.mode_picker.pick(input.shape()[self.mn_axis])?;
79 let packer = &self.packers[mode];
80 let output_shape: TVec<usize> = self.output_shape(input.shape());
81 let stores = if output_shape.iter().all(|d| *d == 1) {
82 tensor0::<Opaque>(
83 packer.pack_tensor_view(&input.view(), self.k_axis, self.mn_axis)?.into(),
84 )
85 .into_shape(&output_shape)?
86 } else {
87 let mut stores = Tensor::uninitialized_dt(Opaque::datum_type(), &output_shape)?;
88 let mut stores_view = stores.to_array_view_mut::<Opaque>()?;
89 let mut bc_shape: TVec<usize> = input.shape().into();
90 bc_shape[self.k_axis] = 1;
91 bc_shape[self.mn_axis] = 1;
92
93 for coord in indices(&*bc_shape) {
94 let offset = coord
95 .as_array_view()
96 .iter()
97 .zip(input.strides())
98 .map(|(x, s)| *x as isize * s)
99 .sum::<isize>()
100 * input.datum_type().size_of() as isize;
101 let mut pack_coords: TVec<usize> = coord.slice().into();
102 pack_coords.remove(self.k_axis.max(self.mn_axis));
103 pack_coords.remove(self.k_axis.min(self.mn_axis));
104 stores_view[&*pack_coords] = packer
105 .pack_tensor_view(
106 &TensorView::from_bytes(&input, offset, input.shape(), input.strides()),
107 self.k_axis,
108 self.mn_axis,
109 )?
110 .into();
111 }
112 stores
113 };
114 Ok(tvec!(stores.into_tvalue()))
115 }
116 }
117
118 pub fn output_shape<D: DimLike>(&self, input: &[D]) -> TVec<D> {
119 let mut packed_shape: TVec<D> = input.into();
120 packed_shape.remove(self.mn_axis.max(self.k_axis));
121 packed_shape.remove(self.mn_axis.min(self.k_axis));
122 packed_shape
123 }
124}
125
126#[derive(Hash, Clone, Debug, PartialEq, Eq)]
127pub struct DynPackedOpaqueFact {
128 pub k: TDim,
129 pub mn: TDim,
130 pub packers: Vec<PackedFormat>,
131}
132
133impl OpaqueFact for DynPackedOpaqueFact {
134 fn mem_size(&self) -> TDim {
135 self.k.clone() * &self.mn * self.packers[0].dt.size_of()
136 }
137
138 fn same_as(&self, other: &dyn OpaqueFact) -> bool {
139 other.downcast_ref::<Self>().is_some_and(|o| o == self)
140 }
141}
142
143#[derive(Debug, Clone, Hash, Eq, PartialEq)]
144pub struct OptSimpleMatMulPack {
145 pub(crate) packed_format: PackedBlockQuantFormat,
146 pub(crate) k: usize,
147 pub(crate) m: usize,
148}
149
150impl Op for OptSimpleMatMulPack {
151 fn name(&self) -> Cow<str> {
152 "OptSimpleMatMulPack".into()
153 }
154 op_as_typed_op!();
155}
156
157impl EvalOp for OptSimpleMatMulPack {
158 fn is_stateless(&self) -> bool {
159 true
160 }
161
162 fn state(
163 &self,
164 _session: &mut SessionState,
165 _node_id: usize,
166 ) -> TractResult<Option<Box<dyn OpState>>> {
167 Ok(None)
168 }
169
170 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
171 let input = args_1!(inputs);
172 let mut output = tensor1(
173 &input
174 .as_slice::<Opaque>()?
175 .iter()
176 .map(|i| {
177 let i = i.downcast_ref::<BlockQuantValue>().unwrap();
178 let iv: Box<dyn MMMInputValue> =
179 Box::new(self.packed_format.pack(&i.value, i.fact.k())?);
180 Ok(Opaque(Arc::new(iv)))
181 })
182 .collect::<TractResult<Vec<_>>>()?,
183 );
184 output.set_shape(input.shape())?;
185 Ok(tvec!(output.into_tvalue()))
186 }
187}
188
189impl TypedOp for OptSimpleMatMulPack {
190 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
191 let fact = Opaque::fact(inputs[0].shape.clone()).with_opaque_fact(PackedBlockQuantFact {
192 format: self.packed_format.clone(),
193 shape: tvec!(self.m, self.k),
194 });
195 Ok(tvec!(fact))
196 }
197
198 as_op!();
199}