1use crate::internal::*;
2
3pub fn cast(to: DatumType) -> Cast {
4 Cast { to }
5}
6
7pub fn wire_cast(
8 prefix: impl AsRef<str>,
9 target: &mut TypedModel,
10 inputs: &[OutletId],
11 operating_datum_type: DatumType,
12) -> TractResult<TVec<OutletId>> {
13 let prefix = prefix.as_ref();
14 let mut wires = tvec!();
15 for mut wire in inputs.iter().copied() {
16 if target.outlet_fact(wire)?.datum_type != operating_datum_type {
17 wire = target.wire_node(
18 target.unique_name(format!("{prefix}.cast")),
19 crate::ops::cast::cast(operating_datum_type),
20 &[wire],
21 )?[0];
22 }
23 wires.push(wire);
24 }
25 Ok(wires)
26}
27
28#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
29pub struct Cast {
30 pub to: DatumType,
31}
32
33impl Op for Cast {
34 fn name(&self) -> StaticName {
35 "Cast".into()
36 }
37
38 op_as_typed_op!();
39}
40
41impl EvalOp for Cast {
42 fn is_stateless(&self) -> bool {
43 true
44 }
45
46 fn eval_with_session(
47 &self,
48 _node_id: usize,
49 state: &TurnState,
50 inputs: TVec<TValue>,
51 ) -> TractResult<TVec<TValue>> {
52 let input = args_1!(inputs);
53 if input.datum_type() == self.to {
54 Ok(tvec!(input))
55 } else if input.datum_type() == TDim::datum_type() {
56 let mut tmp = Tensor::zero_dt(i64::datum_type(), input.shape())?;
57 let input_plain = input.try_as_plain()?;
58 let mut tmp_plain = tmp.try_as_plain_mut()?;
59 for (dim, i) in tract_itertools::izip!(
60 input_plain.as_slice::<TDim>()?,
61 tmp_plain.as_slice_mut::<i64>()?
62 ) {
63 *i = dim.eval(&state.resolved_symbols).to_i64()?
64 }
65 Ok(tvec!(tmp.cast_to_dt(self.to)?.into_owned().into_tvalue()))
66 } else {
67 Ok(tvec!(input.cast_to_dt(self.to)?.into_owned().into_tvalue()))
68 }
69 }
70}
71
72impl TypedOp for Cast {
73 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
74 let mut fact = self.to.fact(inputs[0].shape.clone());
75 fact.uniform_tdim = inputs[0].uniform_tdim.clone();
76 if let Some(u) = &inputs[0].uniform {
77 if let Ok(cast_u) = u.cast_to_dt(self.to) {
78 fact.uniform = Some(std::sync::Arc::new(cast_u.into_owned()));
79 }
80 }
81 Ok(tvec!(fact))
82 }
83
84 fn input_roi(
85 &self,
86 model: &TypedModel,
87 node: &TypedNode,
88 ) -> TractResult<Option<TVec<Option<TDim>>>> {
89 crate::optim::propagate_roi::bubble_roi(model, node)
90 }
91
92 fn declutter(
93 &self,
94 model: &TypedModel,
95 node: &TypedNode,
96 ) -> TractResult<Option<TypedModelPatch>> {
97 if model.outlet_fact(node.inputs[0])?.datum_type == self.to {
98 TypedModelPatch::shunt_one_op(model, node)
99 } else {
100 Ok(None)
101 }
102 }
103
104 fn axes_mapping(
105 &self,
106 inputs: &[&TypedFact],
107 outputs: &[&TypedFact],
108 ) -> TractResult<AxesMapping> {
109 AxesMapping::natural(inputs, outputs)
110 }
111
112 fn change_axes(
113 &self,
114 model: &TypedModel,
115 node: &TypedNode,
116 _io: InOut,
117 change: &AxisOp,
118 ) -> TractResult<Option<AxisChangeConsequence>> {
119 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
120 }
121
122 fn slice(
123 &self,
124 patch: &mut TypedModelPatch,
125 _model: &TypedModel,
126 node: &TypedNode,
127 _prefix: &str,
128 inputs: &[OutletId],
129 _output_axis: usize,
130 _start: &TDim,
131 _end: &TDim,
132 ) -> TractResult<Option<TVec<OutletId>>> {
133 patch.wire_node(&node.name, &node.op, inputs).map(Some)
134 }
135
136 as_op!();
137}