1use crate::internal::*;
2
3#[derive(Debug, Clone, Hash, Eq, PartialEq)]
4pub struct Const(Arc<Tensor>, Option<Box<dyn ExoticFact>>);
5
6impl Const {
7 pub fn new(tensor: Arc<Tensor>) -> TractResult<Const> {
8 Self::new_with_opt_exotic_fact(tensor, None)
9 }
10
11 pub fn new_with_exotic_fact(
12 tensor: Arc<Tensor>,
13 fact: Box<dyn ExoticFact>,
14 ) -> TractResult<Const> {
15 Self::new_with_opt_exotic_fact(tensor, Some(fact))
16 }
17
18 pub fn new_with_opt_exotic_fact(
19 tensor: Arc<Tensor>,
20 fact: Option<Box<dyn ExoticFact>>,
21 ) -> TractResult<Const> {
22 ensure!(fact.is_some() || tensor.is_plain(), "Exotic tensor requires an exotic_fact");
23 Ok(Const(tensor, fact))
24 }
25
26 pub fn val(&self) -> &Arc<Tensor> {
27 &self.0
28 }
29
30 pub fn exotic_fact(&self) -> Option<&dyn ExoticFact> {
31 self.1.as_deref()
32 }
33}
34
35impl Op for Const {
36 fn name(&self) -> StaticName {
37 "Const".into()
38 }
39
40 op_as_typed_op!();
41}
42
43impl EvalOp for Const {
44 fn is_stateless(&self) -> bool {
45 true
46 }
47
48 fn eval(&self, _inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
49 Ok(tvec![Arc::clone(&self.0).into_tvalue()])
50 }
51}
52
53impl TypedOp for Const {
54 as_op!();
55
56 fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
57 let fact = if self.1.is_some() {
58 let mut f = TypedFact::dt_shape(
62 self.0.datum_type(),
63 ShapeFact::from_dims(self.0.shape().iter().map(TDim::from)),
64 );
65 f.konst = Some(Arc::clone(&self.0));
66 f.exotic_fact.clone_from(&self.1);
67 f
68 } else {
69 TypedFact::try_from(&self.0)?
71 };
72 Ok(tvec!(fact))
73 }
74
75 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
76 Ok(tvec!((Cost::Params(self.0.datum_type().unquantized()), self.0.len().into())))
77 }
78
79 fn concretize_dims(
80 &self,
81 _source: &TypedModel,
82 node: &TypedNode,
83 target: &mut TypedModel,
84 _mapping: &HashMap<OutletId, OutletId>,
85 values: &SymbolValues,
86 ) -> TractResult<TVec<OutletId>> {
87 let op = if self.0.datum_type() == TDim::datum_type() {
88 let mut tensor = self.0.clone().into_tensor();
89 for d in tensor.try_as_plain_mut()?.as_slice_mut::<TDim>()? {
90 *d = d.eval(values);
91 }
92 Const(tensor.into_arc_tensor(), self.1.clone())
93 } else {
94 self.clone()
95 };
96 target.wire_node(&node.name, op, &[])
97 }
98
99 fn change_axes(
100 &self,
101 _model: &TypedModel,
102 _node: &TypedNode,
103 io: InOut,
104 change: &AxisOp,
105 ) -> TractResult<Option<AxisChangeConsequence>> {
106 anyhow::ensure!(io == InOut::Out(0));
107 let mut new_tensor = self.0.clone().into_tensor();
108 if change.change_tensor(&mut new_tensor, false).is_ok() {
109 let mut sub = Const(new_tensor.into_arc_tensor(), None);
110 if self.1.is_some() {
111 let my_fact = self.output_facts(&[])?;
112 let changed_fact = change.output_facts(&[&my_fact[0]])?;
113 sub.1 = changed_fact[0].exotic_fact.clone();
114 }
115 Ok(Some(AxisChangeConsequence {
116 substitute_op: Some(Box::new(sub)),
117 wire_changes: tvec!((io, change.clone())),
118 }))
119 } else {
120 Ok(None)
121 }
122 }
123}