1use crate::internal::*;
2
3pub fn cast(to: DatumType) -> Cast {
4 Cast { to }
5}
6
7pub fn wire_cast(
8 prefix: impl AsRef<str>,
9 target: &mut TypedModel,
10 inputs: &[OutletId],
11 operating_datum_type: DatumType,
12) -> TractResult<TVec<OutletId>> {
13 let prefix = prefix.as_ref();
14 let mut wires = tvec!();
15 for mut wire in inputs.iter().copied() {
16 if target.outlet_fact(wire)?.datum_type != operating_datum_type {
17 wire = target.wire_node(
18 target.unique_name(format!("{prefix}.cast")),
19 crate::ops::cast::cast(operating_datum_type),
20 &[wire],
21 )?[0];
22 }
23 wires.push(wire);
24 }
25 Ok(wires)
26}
27
28#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
29pub struct Cast {
30 pub to: DatumType,
31}
32
33impl Op for Cast {
34 fn name(&self) -> StaticName {
35 "Cast".into()
36 }
37
38 op_as_typed_op!();
39 impl_op_same_as!();
40}
41
42impl EvalOp for Cast {
43 fn is_stateless(&self) -> bool {
44 true
45 }
46
47 fn eval_with_session(
48 &self,
49 _node_id: usize,
50 state: &TurnState,
51 inputs: TVec<TValue>,
52 ) -> TractResult<TVec<TValue>> {
53 let input = args_1!(inputs);
54 if input.datum_type() == self.to {
55 Ok(tvec!(input))
56 } else if input.datum_type() == TDim::datum_type() {
57 let mut tmp = Tensor::zero_dt(i64::datum_type(), input.shape())?;
58 for (dim, i) in
59 tract_itertools::izip!(input.as_slice::<TDim>()?, tmp.as_slice_mut::<i64>()?)
60 {
61 *i = dim.eval(&state.resolved_symbols).to_i64()?
62 }
63 Ok(tvec!(tmp.cast_to_dt(self.to)?.into_owned().into_tvalue()))
64 } else {
65 Ok(tvec!(input.cast_to_dt(self.to)?.into_owned().into_tvalue()))
66 }
67 }
68}
69
70impl TypedOp for Cast {
71 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
72 Ok(tvec!(self.to.fact(inputs[0].shape.clone())))
73 }
74
75 fn declutter(
76 &self,
77 model: &TypedModel,
78 node: &TypedNode,
79 ) -> TractResult<Option<TypedModelPatch>> {
80 if model.outlet_fact(node.inputs[0])?.datum_type == self.to {
81 TypedModelPatch::shunt_one_op(model, node)
82 } else {
83 Ok(None)
84 }
85 }
86
87 fn axes_mapping(
88 &self,
89 inputs: &[&TypedFact],
90 outputs: &[&TypedFact],
91 ) -> TractResult<AxesMapping> {
92 AxesMapping::natural(inputs, outputs)
93 }
94
95 fn change_axes(
96 &self,
97 model: &TypedModel,
98 node: &TypedNode,
99 _io: InOut,
100 change: &AxisOp,
101 ) -> TractResult<Option<AxisChangeConsequence>> {
102 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
103 }
104
105 fn slice(
106 &self,
107 patch: &mut TypedModelPatch,
108 _model: &TypedModel,
109 node: &TypedNode,
110 _prefix: &str,
111 inputs: &[OutletId],
112 _output_axis: usize,
113 _start: &TDim,
114 _end: &TDim,
115 ) -> TractResult<Option<TVec<OutletId>>> {
116 patch.wire_node(&node.name, &node.op, inputs).map(Some)
117 }
118
119 as_op!();
120}