tract_core/ops/konst.rs

1use dyn_clone::clone_box;
2use tract_itertools::Itertools;
3use tract_linalg::block_quant::BlockQuantValue;
4use tract_linalg::mmm::MMMInputValue;
5
6use crate::internal::*;
7use crate::ops::array::Gather;
8use crate::ops::einsum::EinSum;
9
10use super::einsum::optimize::EinSumAnnotatedAsLinear;
11
/// A constant tensor carried by the graph as an operator with no inputs.
///
/// Field 0 is the tensor value; field 1 is the `OpaqueFact` describing the
/// payload when (and only when) the tensor's datum type is opaque — the
/// constructors in `impl Const` enforce this pairing.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct Const(Arc<Tensor>, Option<Box<dyn OpaqueFact>>);
14
15impl Const {
16    pub fn new(tensor: Arc<Tensor>) -> TractResult<Const> {
17        Self::new_with_opt_opaque_fact(tensor, None)
18    }
19
20    pub fn new_with_opaque_fact(
21        tensor: Arc<Tensor>,
22        fact: Box<dyn OpaqueFact>,
23    ) -> TractResult<Const> {
24        Self::new_with_opt_opaque_fact(tensor, Some(fact))
25    }
26
27    pub fn new_with_opt_opaque_fact(
28        tensor: Arc<Tensor>,
29        fact: Option<Box<dyn OpaqueFact>>,
30    ) -> TractResult<Const> {
31        ensure!(fact.is_some() == tensor.datum_type().is_opaque());
32        Ok(Const(tensor, fact))
33    }
34
35    pub fn val(&self) -> &Arc<Tensor> {
36        &self.0
37    }
38
39    pub fn opaque_fact(&self) -> Option<&dyn OpaqueFact> {
40        self.1.as_deref()
41    }
42}
43
impl Op for Const {
    /// Human-readable operator name used in dumps and error messages.
    fn name(&self) -> Cow<str> {
        "Const".into()
    }

    // Boilerplate: downcast to TypedOp and structural equality, provided by
    // the framework macros.
    op_as_typed_op!();
    impl_op_same_as!();
}
52
53impl EvalOp for Const {
54    fn is_stateless(&self) -> bool {
55        true
56    }
57
58    fn eval(&self, _inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
59        Ok(tvec![Arc::clone(&self.0).into_tvalue()])
60    }
61}
62
63impl TypedOp for Const {
64    as_op!();
65
66    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
67        let fact = TypedFact::from(&self.0);
68        if let Some(opaque) = &self.1 {
69            Ok(tvec!(fact.with_opaque_fact(opaque.clone())))
70        } else {
71            Ok(tvec!(fact))
72        }
73    }
74
75    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
76        Ok(tvec!((Cost::Params(self.0.datum_type().unquantized()), self.0.len().into())))
77    }
78
79    fn concretize_dims(
80        &self,
81        _source: &TypedModel,
82        node: &TypedNode,
83        target: &mut TypedModel,
84        _mapping: &HashMap<OutletId, OutletId>,
85        values: &SymbolValues,
86    ) -> TractResult<TVec<OutletId>> {
87        let op = if self.0.datum_type() == TDim::datum_type() {
88            let mut tensor = self.0.clone().into_tensor();
89            for d in tensor.as_slice_mut::<TDim>()? {
90                *d = d.eval(values);
91            }
92            Const(tensor.into_arc_tensor(), self.1.clone())
93        } else {
94            self.clone()
95        };
96        target.wire_node(&node.name, op, &[])
97    }
98
99    fn change_axes(
100        &self,
101        _model: &TypedModel,
102        _node: &TypedNode,
103        io: InOut,
104        change: &AxisOp,
105    ) -> TractResult<Option<AxisChangeConsequence>> {
106        anyhow::ensure!(io == InOut::Out(0));
107        let mut new_tensor = self.0.clone().into_tensor();
108        if change.change_tensor(&mut new_tensor, false).is_ok() {
109            let mut sub = Const(new_tensor.into_arc_tensor(), None);
110            if self.1.is_some() {
111                let my_fact = self.output_facts(&[])?;
112                let changed_fact = change.output_facts(&[&my_fact[0]])?;
113                sub.1 = changed_fact[0].opaque_fact.clone();
114            }
115            Ok(Some(AxisChangeConsequence {
116                substitute_op: Some(Box::new(sub)),
117                wire_changes: tvec!((io, change.clone())),
118            }))
119        } else {
120            Ok(None)
121        }
122    }
123
124    fn codegen(
125        &self,
126        model: &TypedModel,
127        node: &TypedNode,
128    ) -> TractResult<Option<TypedModelPatch>> {
129        let looks_like_weights = (self.0.datum_type().is_number() && self.0.rank() == 2)
130            || (self.0.to_scalar::<Opaque>().is_ok_and(|opaque| opaque.is::<BlockQuantValue>()));
131        if !looks_like_weights {
132            return Ok(None);
133        }
134        let mut matmuls = vec![];
135        for succ in &node.outputs[0].successors {
136            let snode = model.node(succ.node);
137            if let Some(gather) = snode.op_as::<Gather>() {
138                if succ.slot != 0 || gather.axis != 0 {
139                    return Ok(None);
140                }
141            } else if let Some(einsum) = snode.op_as::<EinSum>() {
142                if let Some(linear) = EinSumAnnotatedAsLinear::from(model, snode, einsum)? {
143                    matmuls.push(linear);
144                } else {
145                    return Ok(None);
146                }
147            } else {
148                return Ok(None);
149            }
150        }
151        if matmuls.len() == 0 {
152            return Ok(None);
153        }
154
155        ensure!(matmuls.iter().map(|linear| linear.m_axis.inputs[0][0]).all_equal());
156        ensure!(matmuls.iter().map(|linear| linear.k_axis.inputs[0][0]).all_equal());
157
158        let m_axis = matmuls[0].m_axis.inputs[0][0];
159        let k_axis = matmuls[0].k_axis.inputs[0][0];
160        let must_swap = m_axis > k_axis;
161
162        let ops = tract_linalg::ops();
163        let (choice,) = matmuls
164            .iter()
165            .map(|mm| mm.preferred_packing())
166            .dedup_by(|a, b| a.same_as(&**b))
167            .collect_tuple::<(_,)>()
168            .unwrap_or_else(|| {
169                let it = ops
170                    .all_possible_packing(matmuls[0].weight_type.clone())
171                    .min_by_key(|format| {
172                        matmuls
173                            .iter()
174                            .map(|linear| linear.cost_for_weights(&**format))
175                            .max()
176                            .unwrap()
177                    })
178                    .unwrap();
179                (clone_box(it),)
180            });
181
182        let packed = choice.prepare_tensor(&self.0, k_axis, m_axis).context("in prepare_tensor")?;
183        let fact = clone_box(
184            packed.as_slice::<Opaque>()?[0]
185                .downcast_ref::<Box<dyn MMMInputValue>>()
186                .unwrap()
187                .opaque_fact(),
188        );
189        let konst = Const(packed.into_arc_tensor(), Some(fact));
190        let mut patch = TypedModelPatch::new(format!("Packing {node} as {choice:?}"));
191        let konst = patch.wire_node(&node.name, konst, &[])?;
192        for succ in &node.outputs[0].successors {
193            let succ_node = model.node(succ.node);
194            let mut taps = patch.taps(model, &succ_node.inputs)?;
195            taps[succ.slot] = konst[0];
196            let new_op: Box<dyn TypedOp> = if let Some(gather) = succ_node.op_as::<Gather>() {
197                let output_type = succ_node.outputs[0].fact.datum_type;
198                Box::new(Gather { axis: gather.axis, output_type: Some(output_type) })
199            } else if let Some(linear) = succ_node.op_as::<EinSum>() {
200                let mut op = linear.clone();
201                if must_swap {
202                    op.axes.iter_all_axes_mut().for_each(|axes| {
203                        axes.inputs[0].iter_mut().for_each(|pos| {
204                            *pos = if *pos == k_axis {
205                                m_axis
206                            } else if *pos == m_axis {
207                                k_axis
208                            } else {
209                                *pos
210                            }
211                        })
212                    });
213                }
214                Box::new(op)
215            } else {
216                bail!("Unexpected op")
217            };
218            let replacement = patch.wire_node(&succ_node.name, new_op, &taps)?;
219            patch.shunt_outside(model, succ.node.into(), replacement[0])?;
220        }
221        Ok(Some(patch))
222    }
223}