Skip to main content

tract_core/ops/
cast.rs

1use crate::internal::*;
2
3pub fn cast(to: DatumType) -> Cast {
4    Cast { to }
5}
6
7pub fn wire_cast(
8    prefix: impl AsRef<str>,
9    target: &mut TypedModel,
10    inputs: &[OutletId],
11    operating_datum_type: DatumType,
12) -> TractResult<TVec<OutletId>> {
13    let prefix = prefix.as_ref();
14    let mut wires = tvec!();
15    for mut wire in inputs.iter().copied() {
16        if target.outlet_fact(wire)?.datum_type != operating_datum_type {
17            wire = target.wire_node(
18                target.unique_name(format!("{prefix}.cast")),
19                crate::ops::cast::cast(operating_datum_type),
20                &[wire],
21            )?[0];
22        }
23        wires.push(wire);
24    }
25    Ok(wires)
26}
27
/// Operator converting its single input tensor to the datum type `to`.
///
/// The output keeps the input's shape (see `TypedOp::output_facts`).
#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
pub struct Cast {
    /// Datum type of the output tensor.
    pub to: DatumType,
}
32
impl Op for Cast {
    /// Name reported for this operator: `"Cast"`.
    fn name(&self) -> StaticName {
        "Cast".into()
    }

    op_as_typed_op!();
}
40
41impl EvalOp for Cast {
42    fn is_stateless(&self) -> bool {
43        true
44    }
45
46    fn eval_with_session(
47        &self,
48        _node_id: usize,
49        state: &TurnState,
50        inputs: TVec<TValue>,
51    ) -> TractResult<TVec<TValue>> {
52        let input = args_1!(inputs);
53        if input.datum_type() == self.to {
54            Ok(tvec!(input))
55        } else if input.datum_type() == TDim::datum_type() {
56            let mut tmp = Tensor::zero_dt(i64::datum_type(), input.shape())?;
57            let input_plain = input.try_as_plain()?;
58            let mut tmp_plain = tmp.try_as_plain_mut()?;
59            for (dim, i) in tract_itertools::izip!(
60                input_plain.as_slice::<TDim>()?,
61                tmp_plain.as_slice_mut::<i64>()?
62            ) {
63                *i = dim.eval(&state.resolved_symbols).to_i64()?
64            }
65            Ok(tvec!(tmp.cast_to_dt(self.to)?.into_owned().into_tvalue()))
66        } else {
67            Ok(tvec!(input.cast_to_dt(self.to)?.into_owned().into_tvalue()))
68        }
69    }
70}
71
72impl TypedOp for Cast {
73    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
74        let mut fact = self.to.fact(inputs[0].shape.clone());
75        fact.uniform_tdim = inputs[0].uniform_tdim.clone();
76        if let Some(u) = &inputs[0].uniform {
77            if let Ok(cast_u) = u.cast_to_dt(self.to) {
78                fact.uniform = Some(std::sync::Arc::new(cast_u.into_owned()));
79            }
80        }
81        Ok(tvec!(fact))
82    }
83
84    fn input_roi(
85        &self,
86        model: &TypedModel,
87        node: &TypedNode,
88    ) -> TractResult<Option<TVec<Option<TDim>>>> {
89        crate::optim::propagate_roi::bubble_roi(model, node)
90    }
91
92    fn declutter(
93        &self,
94        model: &TypedModel,
95        node: &TypedNode,
96    ) -> TractResult<Option<TypedModelPatch>> {
97        if model.outlet_fact(node.inputs[0])?.datum_type == self.to {
98            TypedModelPatch::shunt_one_op(model, node)
99        } else {
100            Ok(None)
101        }
102    }
103
104    fn axes_mapping(
105        &self,
106        inputs: &[&TypedFact],
107        outputs: &[&TypedFact],
108    ) -> TractResult<AxesMapping> {
109        AxesMapping::natural(inputs, outputs)
110    }
111
112    fn change_axes(
113        &self,
114        model: &TypedModel,
115        node: &TypedNode,
116        _io: InOut,
117        change: &AxisOp,
118    ) -> TractResult<Option<AxisChangeConsequence>> {
119        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
120    }
121
122    fn slice(
123        &self,
124        patch: &mut TypedModelPatch,
125        _model: &TypedModel,
126        node: &TypedNode,
127        _prefix: &str,
128        inputs: &[OutletId],
129        _output_axis: usize,
130        _start: &TDim,
131        _end: &TDim,
132    ) -> TractResult<Option<TVec<OutletId>>> {
133        patch.wire_node(&node.name, &node.op, inputs).map(Some)
134    }
135
136    as_op!();
137}