1use crate::internal::*;
2use crate::ops::array::MultiBroadcastTo;
3
4pub fn cast(to: DatumType) -> Cast {
5 Cast { to }
6}
7
8pub fn wire_cast(
9 prefix: impl AsRef<str>,
10 target: &mut TypedModel,
11 inputs: &[OutletId],
12 operating_datum_type: DatumType,
13) -> TractResult<TVec<OutletId>> {
14 let prefix = prefix.as_ref();
15 let mut wires = tvec!();
16 for mut wire in inputs.iter().copied() {
17 if target.outlet_fact(wire)?.datum_type != operating_datum_type {
18 wire = target.wire_node(
19 target.unique_name(format!("{prefix}.cast")),
20 crate::ops::cast::cast(operating_datum_type),
21 &[wire],
22 )?[0];
23 }
24 wires.push(wire);
25 }
26 Ok(wires)
27}
28
29#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
30pub struct Cast {
31 pub to: DatumType,
32}
33
34impl Op for Cast {
35 fn name(&self) -> StaticName {
36 "Cast".into()
37 }
38
39 op_as_typed_op!();
40}
41
42impl EvalOp for Cast {
43 fn is_stateless(&self) -> bool {
44 true
45 }
46
47 fn eval_with_session(
48 &self,
49 _node_id: usize,
50 state: &TurnState,
51 inputs: TVec<TValue>,
52 ) -> TractResult<TVec<TValue>> {
53 let input = args_1!(inputs);
54 if input.datum_type() == self.to {
55 Ok(tvec!(input))
56 } else if input.datum_type() == TDim::datum_type() {
57 let mut tmp = Tensor::zero_dt(i64::datum_type(), input.shape())?;
58 let input_plain = input.try_as_plain()?;
59 let mut tmp_plain = tmp.try_as_plain_mut()?;
60 for (dim, i) in tract_itertools::izip!(
61 input_plain.as_slice::<TDim>()?,
62 tmp_plain.as_slice_mut::<i64>()?
63 ) {
64 *i = dim.eval(&state.resolved_symbols).to_i64()?
65 }
66 Ok(tvec!(tmp.cast_to_dt(self.to)?.into_owned().into_tvalue()))
67 } else {
68 Ok(tvec!(input.cast_to_dt(self.to)?.into_owned().into_tvalue()))
69 }
70 }
71}
72
73impl TypedOp for Cast {
74 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
75 let mut fact = self.to.fact(inputs[0].shape.clone());
76 fact.uniform_tdim = inputs[0].uniform_tdim.clone();
77 if let Some(u) = &inputs[0].uniform {
78 if let Ok(cast_u) = u.cast_to_dt(self.to) {
79 fact.uniform = Some(std::sync::Arc::new(cast_u.into_owned()));
80 }
81 }
82 Ok(tvec!(fact))
83 }
84
85 fn input_roi(
86 &self,
87 model: &TypedModel,
88 node: &TypedNode,
89 ) -> TractResult<Option<TVec<Option<TDim>>>> {
90 crate::optim::propagate_roi::bubble_roi(model, node)
91 }
92
93 fn declutter(
94 &self,
95 model: &TypedModel,
96 node: &TypedNode,
97 ) -> TractResult<Option<TypedModelPatch>> {
98 if model.outlet_fact(node.inputs[0])?.datum_type == self.to {
99 return TypedModelPatch::shunt_one_op(model, node);
100 }
101 if let Some(prec) = model.linear_prec(node.id)?
111 && (prec.op_is::<IntoShape>() || prec.op_is::<MultiBroadcastTo>())
112 {
113 let mut patch = TypedModelPatch::default();
114 let mut wire = tvec!(patch.tap_model(model, prec.inputs[0])?);
115 wire = patch.wire_node(&node.name, &node.op, &wire)?;
116 wire = patch.wire_node(&prec.name, &prec.op, &wire)?;
117 patch.shunt_outside(model, node.id.into(), wire[0])?;
118 return Ok(Some(patch));
119 }
120 Ok(None)
121 }
122
123 fn axes_mapping(
124 &self,
125 inputs: &[&TypedFact],
126 outputs: &[&TypedFact],
127 ) -> TractResult<AxesMapping> {
128 AxesMapping::natural(inputs, outputs)
129 }
130
131 fn change_axes(
132 &self,
133 model: &TypedModel,
134 node: &TypedNode,
135 _io: InOut,
136 change: &AxisOp,
137 ) -> TractResult<Option<AxisChangeConsequence>> {
138 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
139 }
140
141 fn slice(
142 &self,
143 patch: &mut TypedModelPatch,
144 _model: &TypedModel,
145 node: &TypedNode,
146 _prefix: &str,
147 inputs: &[OutletId],
148 _output_axis: usize,
149 _start: &TDim,
150 _end: &TDim,
151 ) -> TractResult<Option<TVec<OutletId>>> {
152 patch.wire_node(&node.name, &node.op, inputs).map(Some)
153 }
154
155 as_op!();
156}