tract_core/ops/
konst.rs

1use dyn_clone::clone_box;
2use tract_linalg::frame::block_quant::BlockQuantValue;
3use tract_linalg::mmm::WeightType;
4
5use crate::internal::*;
6use crate::ops::array::Gather;
7use crate::ops::einsum::EinSum;
8
9#[derive(Debug, Clone, Hash, Eq, PartialEq)]
10pub struct Const(pub Arc<Tensor>, pub Option<Box<dyn OpaqueFact>>);
11
12impl Const {
13    pub fn new(tensor: Arc<Tensor>) -> Const {
14        Const(tensor, None)
15    }
16
17    pub fn new_with_opaque_fact(tensor: Arc<Tensor>, fact: Box<dyn OpaqueFact>) -> Const {
18        Const(tensor, Some(fact))
19    }
20}
21
22impl Op for Const {
23    fn name(&self) -> Cow<str> {
24        "Const".into()
25    }
26
27    op_as_typed_op!();
28    impl_op_same_as!();
29}
30
31impl EvalOp for Const {
32    fn is_stateless(&self) -> bool {
33        true
34    }
35
36    fn eval(&self, _inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
37        Ok(tvec![Arc::clone(&self.0).into_tvalue()])
38    }
39}
40
41impl TypedOp for Const {
42    as_op!();
43
44    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
45        let fact = TypedFact::from(&self.0);
46        if let Some(opaque) = &self.1 {
47            Ok(tvec!(fact.with_opaque_fact(opaque.clone())))
48        } else {
49            Ok(tvec!(fact))
50        }
51    }
52
53    fn change_axes(
54        &self,
55        _model: &TypedModel,
56        _node: &TypedNode,
57        io: InOut,
58        change: &AxisOp,
59    ) -> TractResult<Option<AxisChangeConsequence>> {
60        anyhow::ensure!(io == InOut::Out(0));
61        let mut new_tensor = self.0.clone().into_tensor();
62        if change.change_tensor(&mut new_tensor, false).is_ok() {
63            Ok(Some(AxisChangeConsequence {
64                substitute_op: Some(Box::new(Const(new_tensor.into_arc_tensor(), self.1.clone()))),
65                wire_changes: tvec!((io, change.clone())),
66            }))
67        } else {
68            Ok(None)
69        }
70    }
71
72    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
73        Ok(tvec!((Cost::Params(self.0.datum_type().unquantized()), self.0.len().into())))
74    }
75
76    fn concretize_dims(
77        &self,
78        _source: &TypedModel,
79        node: &TypedNode,
80        target: &mut TypedModel,
81        _mapping: &HashMap<OutletId, OutletId>,
82        values: &SymbolValues,
83    ) -> TractResult<TVec<OutletId>> {
84        let op = if self.0.datum_type() == TDim::datum_type() {
85            let mut tensor = self.0.clone().into_tensor();
86            for d in tensor.as_slice_mut::<TDim>()? {
87                *d = d.eval(values);
88            }
89            Const(tensor.into_arc_tensor(), self.1.clone())
90        } else {
91            self.clone()
92        };
93        target.wire_node(&node.name, op, &[])
94    }
95
96    fn codegen(
97        &self,
98        model: &TypedModel,
99        node: &TypedNode,
100    ) -> TractResult<Option<TypedModelPatch>> {
101        let looks_like_weights = (self.0.datum_type().is_number() && self.0.rank() == 2)
102            || (self.0.to_scalar::<Opaque>().is_ok_and(|opaque| opaque.is::<BlockQuantValue>()));
103        if !looks_like_weights {
104            return Ok(None);
105        }
106        let mut have_abstract_einsum = false;
107        for succ in &node.outputs[0].successors {
108            let snode = model.node(succ.node);
109            if let Some(gather) = snode.op_as::<Gather>() {
110                if succ.slot != 0 || gather.axis != 0 {
111                    return Ok(None);
112                }
113            } else if let Some(einsum) = snode.op_as::<EinSum>() {
114                if succ.slot != 0 || snode.inputs.len() != 2 {
115                    return Ok(None);
116                }
117                let m_axis = einsum.axes.axis((InOut::In(0), 0))?;
118                if m_axis.inputs[0].len() != 1
119                    || m_axis.inputs[1].len() != 0
120                    || m_axis.outputs[0].len() != 1
121                {
122                    return Ok(None);
123                }
124                let k_axis = einsum.axes.axis((InOut::In(0), 1))?;
125                if k_axis.inputs[0].len() != 1
126                    || k_axis.inputs[1].len() != 1
127                    || k_axis.outputs[0].len() != 0
128                {
129                    return Ok(None);
130                }
131                for axis in einsum.axes.iter_all_axes() {
132                    if axis != k_axis
133                        && axis != m_axis
134                        && axis.inputs[0].len() == 0
135                        && axis.inputs[1].len() == 1
136                        && axis.outputs[0].len() == 1
137                        && snode.outputs[0].fact.shape[axis.outputs[0][0]].as_i64().is_none()
138                    {
139                        have_abstract_einsum = true;
140                    }
141                }
142            } else {
143                return Ok(None);
144            }
145        }
146        if node.outputs[0].successors.len() > 1 || have_abstract_einsum {
147            let weight =
148                self.0.to_scalar::<Opaque>().ok().and_then(|a| a.downcast_ref::<BlockQuantValue>());
149            let weight_type = if let Some(a_payload) = weight {
150                WeightType::BlockQuant(a_payload.fact.format.clone())
151            } else {
152                WeightType::Plain(self.0.datum_type())
153            };
154            let format = tract_linalg::ops().kit_input_format(weight_type);
155            let packed = format.prepare_tensor(&self.0, 1, 0)?;
156            let fact = clone_box(packed.opaque_fact());
157            let opaque = Opaque(Arc::new(packed));
158            let konst = Const(rctensor0(opaque), Some(fact));
159            let mut patch = TypedModelPatch::new(format!("Versatile packing {node}"));
160            let konst = patch.wire_node(&node.name, konst, &[])?;
161            for succ in &node.outputs[0].successors {
162                let succ_node = model.node(succ.node);
163                let mut taps = patch.taps(model, &succ_node.inputs)?;
164                taps[succ.slot] = konst[0];
165                let replacement = patch.wire_node(&succ_node.name, succ_node.op.clone(), &taps)?;
166                patch.shunt_outside(model, succ.node.into(), replacement[0])?;
167            }
168            return Ok(Some(patch));
169        }
170        Ok(None)
171    }
172}