1use dyn_clone::clone_box;
2use tract_linalg::frame::block_quant::BlockQuantValue;
3use tract_linalg::mmm::WeightType;
4
5use crate::internal::*;
6use crate::ops::array::Gather;
7use crate::ops::einsum::EinSum;
8
9#[derive(Debug, Clone, Hash, Eq, PartialEq)]
10pub struct Const(pub Arc<Tensor>, pub Option<Box<dyn OpaqueFact>>);
11
12impl Const {
13 pub fn new(tensor: Arc<Tensor>) -> Const {
14 Const(tensor, None)
15 }
16
17 pub fn new_with_opaque_fact(tensor: Arc<Tensor>, fact: Box<dyn OpaqueFact>) -> Const {
18 Const(tensor, Some(fact))
19 }
20}
21
22impl Op for Const {
23 fn name(&self) -> Cow<str> {
24 "Const".into()
25 }
26
27 op_as_typed_op!();
28 impl_op_same_as!();
29}
30
31impl EvalOp for Const {
32 fn is_stateless(&self) -> bool {
33 true
34 }
35
36 fn eval(&self, _inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
37 Ok(tvec![Arc::clone(&self.0).into_tvalue()])
38 }
39}
40
41impl TypedOp for Const {
42 as_op!();
43
44 fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
45 let fact = TypedFact::from(&self.0);
46 if let Some(opaque) = &self.1 {
47 Ok(tvec!(fact.with_opaque_fact(opaque.clone())))
48 } else {
49 Ok(tvec!(fact))
50 }
51 }
52
53 fn change_axes(
54 &self,
55 _model: &TypedModel,
56 _node: &TypedNode,
57 io: InOut,
58 change: &AxisOp,
59 ) -> TractResult<Option<AxisChangeConsequence>> {
60 anyhow::ensure!(io == InOut::Out(0));
61 let mut new_tensor = self.0.clone().into_tensor();
62 if change.change_tensor(&mut new_tensor, false).is_ok() {
63 Ok(Some(AxisChangeConsequence {
64 substitute_op: Some(Box::new(Const(new_tensor.into_arc_tensor(), self.1.clone()))),
65 wire_changes: tvec!((io, change.clone())),
66 }))
67 } else {
68 Ok(None)
69 }
70 }
71
72 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
73 Ok(tvec!((Cost::Params(self.0.datum_type().unquantized()), self.0.len().into())))
74 }
75
76 fn concretize_dims(
77 &self,
78 _source: &TypedModel,
79 node: &TypedNode,
80 target: &mut TypedModel,
81 _mapping: &HashMap<OutletId, OutletId>,
82 values: &SymbolValues,
83 ) -> TractResult<TVec<OutletId>> {
84 let op = if self.0.datum_type() == TDim::datum_type() {
85 let mut tensor = self.0.clone().into_tensor();
86 for d in tensor.as_slice_mut::<TDim>()? {
87 *d = d.eval(values);
88 }
89 Const(tensor.into_arc_tensor(), self.1.clone())
90 } else {
91 self.clone()
92 };
93 target.wire_node(&node.name, op, &[])
94 }
95
96 fn codegen(
97 &self,
98 model: &TypedModel,
99 node: &TypedNode,
100 ) -> TractResult<Option<TypedModelPatch>> {
101 let looks_like_weights = (self.0.datum_type().is_number() && self.0.rank() == 2)
102 || (self.0.to_scalar::<Opaque>().is_ok_and(|opaque| opaque.is::<BlockQuantValue>()));
103 if !looks_like_weights {
104 return Ok(None);
105 }
106 let mut have_abstract_einsum = false;
107 for succ in &node.outputs[0].successors {
108 let snode = model.node(succ.node);
109 if let Some(gather) = snode.op_as::<Gather>() {
110 if succ.slot != 0 || gather.axis != 0 {
111 return Ok(None);
112 }
113 } else if let Some(einsum) = snode.op_as::<EinSum>() {
114 if succ.slot != 0 || snode.inputs.len() != 2 {
115 return Ok(None);
116 }
117 let m_axis = einsum.axes.axis((InOut::In(0), 0))?;
118 if m_axis.inputs[0].len() != 1
119 || m_axis.inputs[1].len() != 0
120 || m_axis.outputs[0].len() != 1
121 {
122 return Ok(None);
123 }
124 let k_axis = einsum.axes.axis((InOut::In(0), 1))?;
125 if k_axis.inputs[0].len() != 1
126 || k_axis.inputs[1].len() != 1
127 || k_axis.outputs[0].len() != 0
128 {
129 return Ok(None);
130 }
131 for axis in einsum.axes.iter_all_axes() {
132 if axis != k_axis
133 && axis != m_axis
134 && axis.inputs[0].len() == 0
135 && axis.inputs[1].len() == 1
136 && axis.outputs[0].len() == 1
137 && snode.outputs[0].fact.shape[axis.outputs[0][0]].as_i64().is_none()
138 {
139 have_abstract_einsum = true;
140 }
141 }
142 } else {
143 return Ok(None);
144 }
145 }
146 if node.outputs[0].successors.len() > 1 || have_abstract_einsum {
147 let weight =
148 self.0.to_scalar::<Opaque>().ok().and_then(|a| a.downcast_ref::<BlockQuantValue>());
149 let weight_type = if let Some(a_payload) = weight {
150 WeightType::BlockQuant(a_payload.fact.format.clone())
151 } else {
152 WeightType::Plain(self.0.datum_type())
153 };
154 let format = tract_linalg::ops().kit_input_format(weight_type);
155 let packed = format.prepare_tensor(&self.0, 1, 0)?;
156 let fact = clone_box(packed.opaque_fact());
157 let opaque = Opaque(Arc::new(packed));
158 let konst = Const(rctensor0(opaque), Some(fact));
159 let mut patch = TypedModelPatch::new(format!("Versatile packing {node}"));
160 let konst = patch.wire_node(&node.name, konst, &[])?;
161 for succ in &node.outputs[0].successors {
162 let succ_node = model.node(succ.node);
163 let mut taps = patch.taps(model, &succ_node.inputs)?;
164 taps[succ.slot] = konst[0];
165 let replacement = patch.wire_node(&succ_node.name, succ_node.op.clone(), &taps)?;
166 patch.shunt_outside(model, succ.node.into(), replacement[0])?;
167 }
168 return Ok(Some(patch));
169 }
170 Ok(None)
171 }
172}