use dyn_clone::clone_box;
use tract_itertools::Itertools;
use tract_linalg::block_quant::BlockQuantValue;
use tract_linalg::mmm::MMMInputValue;

use crate::internal::*;
use crate::ops::array::Gather;
use crate::ops::einsum::EinSum;

use super::einsum::optimize::EinSumAnnotatedAsLinear;

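/// A constant tensor in a typed model.
///
/// The second field carries the `OpaqueFact` describing the payload when the
/// tensor has an opaque datum type (e.g. pre-packed or block-quantized weights);
/// it must be `None` for plain numeric tensors.
///
/// A minimal usage sketch (assuming the usual tract prelude is in scope):
///
/// ```ignore
/// let konst = Const::new(tensor0(1.0f32).into_arc_tensor())?;
/// ```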
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct Const(Arc<Tensor>, Option<Box<dyn OpaqueFact>>);

impl Const {
    pub fn new(tensor: Arc<Tensor>) -> TractResult<Const> {
        Self::new_with_opt_opaque_fact(tensor, None)
    }

    pub fn new_with_opaque_fact(
        tensor: Arc<Tensor>,
        fact: Box<dyn OpaqueFact>,
    ) -> TractResult<Const> {
        Self::new_with_opt_opaque_fact(tensor, Some(fact))
    }

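    /// Builds a constant, checking that an opaque fact is provided if and only
    /// if the tensor datum type is opaque.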
    pub fn new_with_opt_opaque_fact(
        tensor: Arc<Tensor>,
        fact: Option<Box<dyn OpaqueFact>>,
    ) -> TractResult<Const> {
        ensure!(fact.is_some() == tensor.datum_type().is_opaque());
        Ok(Const(tensor, fact))
    }

    pub fn val(&self) -> &Arc<Tensor> {
        &self.0
    }

    pub fn opaque_fact(&self) -> Option<&dyn OpaqueFact> {
        self.1.as_deref()
    }
}

impl Op for Const {
    fn name(&self) -> Cow<str> {
        "Const".into()
    }

    op_as_typed_op!();
    impl_op_same_as!();
}

impl EvalOp for Const {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, _inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        Ok(tvec![Arc::clone(&self.0).into_tvalue()])
    }
}

impl TypedOp for Const {
    as_op!();

    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let fact = TypedFact::from(&self.0);
        if let Some(opaque) = &self.1 {
            Ok(tvec!(fact.with_opaque_fact(opaque.clone())))
        } else {
            Ok(tvec!(fact))
        }
    }

    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        Ok(tvec!((Cost::Params(self.0.datum_type().unquantized()), self.0.len().into())))
    }

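    // TDim-valued constants may contain symbols: when dims get concretized,
    // evaluate each element against the provided symbol values.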
    fn concretize_dims(
        &self,
        _source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        _mapping: &HashMap<OutletId, OutletId>,
        values: &SymbolValues,
    ) -> TractResult<TVec<OutletId>> {
        let op = if self.0.datum_type() == TDim::datum_type() {
            let mut tensor = self.0.clone().into_tensor();
            for d in tensor.as_slice_mut::<TDim>()? {
                *d = d.eval(values);
            }
            Const(tensor.into_arc_tensor(), self.1.clone())
        } else {
            self.clone()
        };
        target.wire_node(&node.name, op, &[])
    }

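    // An axis change on the (single) output can be absorbed by transforming the
    // constant tensor itself, recomputing the opaque fact if there is one.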
    fn change_axes(
        &self,
        _model: &TypedModel,
        _node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        anyhow::ensure!(io == InOut::Out(0));
        let mut new_tensor = self.0.clone().into_tensor();
        if change.change_tensor(&mut new_tensor, false).is_ok() {
            let mut sub = Const(new_tensor.into_arc_tensor(), None);
            if self.1.is_some() {
                let my_fact = self.output_facts(&[])?;
                let changed_fact = change.output_facts(&[&my_fact[0]])?;
                sub.1 = changed_fact[0].opaque_fact.clone();
            }
            Ok(Some(AxisChangeConsequence {
                substitute_op: Some(Box::new(sub)),
                wire_changes: tvec!((io, change.clone())),
            }))
        } else {
            Ok(None)
        }
    }

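    // Codegen: when this constant looks like a weight matrix and all of its
    // consumers are either Gather lookups along axis 0 or EinSums recognizable
    // as linear layers, replace it with a pre-packed version that the matmul
    // kernels can consume directly.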
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let looks_like_weights = (self.0.datum_type().is_number() && self.0.rank() == 2)
            || (self.0.to_scalar::<Opaque>().is_ok_and(|opaque| opaque.is::<BlockQuantValue>()));
        if !looks_like_weights {
            return Ok(None);
        }
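        // Scan the consumers: bail out unless every successor is a Gather along
        // axis 0 taking this constant as its data input, or an EinSum that can
        // be annotated as a linear layer.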
        let mut matmuls = vec![];
        for succ in &node.outputs[0].successors {
            let snode = model.node(succ.node);
            if let Some(gather) = snode.op_as::<Gather>() {
                if succ.slot != 0 || gather.axis != 0 {
                    return Ok(None);
                }
            } else if let Some(einsum) = snode.op_as::<EinSum>() {
                if let Some(linear) = EinSumAnnotatedAsLinear::from(model, snode, einsum)? {
                    matmuls.push(linear);
                } else {
                    return Ok(None);
                }
            } else {
                return Ok(None);
            }
        }
        if matmuls.is_empty() {
            return Ok(None);
        }

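        // All linear consumers must agree on which weight axes play the m and k roles.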
        ensure!(matmuls.iter().map(|linear| linear.m_axis.inputs[0][0]).all_equal());
        ensure!(matmuls.iter().map(|linear| linear.k_axis.inputs[0][0]).all_equal());

        let m_axis = matmuls[0].m_axis.inputs[0][0];
        let k_axis = matmuls[0].k_axis.inputs[0][0];
        let must_swap = m_axis > k_axis;

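        // Pick a single packing format: if every consumer prefers the same one,
        // use it; otherwise choose the candidate that minimizes the worst-case
        // weight cost over all consumers.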
        let ops = tract_linalg::ops();
        let (choice,) = matmuls
            .iter()
            .map(|mm| mm.preferred_packing())
            .dedup_by(|a, b| a.same_as(&**b))
            .collect_tuple::<(_,)>()
            .unwrap_or_else(|| {
                let it = ops
                    .all_possible_packing(matmuls[0].weight_type.clone())
                    .min_by_key(|format| {
                        matmuls
                            .iter()
                            .map(|linear| linear.cost_for_weights(&**format))
                            .max()
                            .unwrap()
                    })
                    .unwrap();
                (clone_box(it),)
            });

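        // Pack the weights once in the chosen format and recover the opaque fact
        // describing the packed payload from the resulting tensor.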
        let packed = choice.prepare_tensor(&self.0, k_axis, m_axis).context("in prepare_tensor")?;
        let fact = clone_box(
            packed.as_slice::<Opaque>()?[0]
                .downcast_ref::<Box<dyn MMMInputValue>>()
                .unwrap()
                .opaque_fact(),
        );
        let konst = Const(packed.into_arc_tensor(), Some(fact));
        let mut patch = TypedModelPatch::new(format!("Packing {node} as {choice:?}"));
        let konst = patch.wire_node(&node.name, konst, &[])?;
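        // Rewire every consumer onto the packed constant. Gathers are rebuilt
        // with an explicit output type; EinSums have the m and k positions in
        // their weight-input axis mapping swapped when needed (must_swap).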
        for succ in &node.outputs[0].successors {
            let succ_node = model.node(succ.node);
            let mut taps = patch.taps(model, &succ_node.inputs)?;
            taps[succ.slot] = konst[0];
            let new_op: Box<dyn TypedOp> = if let Some(gather) = succ_node.op_as::<Gather>() {
                let output_type = succ_node.outputs[0].fact.datum_type;
                Box::new(Gather { axis: gather.axis, output_type: Some(output_type) })
            } else if let Some(linear) = succ_node.op_as::<EinSum>() {
                let mut op = linear.clone();
                if must_swap {
                    op.axes.iter_all_axes_mut().for_each(|axes| {
                        axes.inputs[0].iter_mut().for_each(|pos| {
                            *pos = if *pos == k_axis {
                                m_axis
                            } else if *pos == m_axis {
                                k_axis
                            } else {
                                *pos
                            }
                        })
                    });
                }
                Box::new(op)
            } else {
                bail!("Unexpected op")
            };
            let replacement = patch.wire_node(&succ_node.name, new_op, &taps)?;
            patch.shunt_outside(model, succ.node.into(), replacement[0])?;
        }
        Ok(Some(patch))
    }
}