Skip to main content

tract_core/ops/
cast.rs

1use crate::internal::*;
2use crate::ops::array::MultiBroadcastTo;
3
4pub fn cast(to: DatumType) -> Cast {
5    Cast { to }
6}
7
8pub fn wire_cast(
9    prefix: impl AsRef<str>,
10    target: &mut TypedModel,
11    inputs: &[OutletId],
12    operating_datum_type: DatumType,
13) -> TractResult<TVec<OutletId>> {
14    let prefix = prefix.as_ref();
15    let mut wires = tvec!();
16    for mut wire in inputs.iter().copied() {
17        if target.outlet_fact(wire)?.datum_type != operating_datum_type {
18            wire = target.wire_node(
19                target.unique_name(format!("{prefix}.cast")),
20                crate::ops::cast::cast(operating_datum_type),
21                &[wire],
22            )?[0];
23        }
24        wires.push(wire);
25    }
26    Ok(wires)
27}
28
29#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
30pub struct Cast {
31    pub to: DatumType,
32}
33
34impl Op for Cast {
35    fn name(&self) -> StaticName {
36        "Cast".into()
37    }
38
39    op_as_typed_op!();
40}
41
42impl EvalOp for Cast {
43    fn is_stateless(&self) -> bool {
44        true
45    }
46
47    fn eval_with_session(
48        &self,
49        _node_id: usize,
50        state: &TurnState,
51        inputs: TVec<TValue>,
52    ) -> TractResult<TVec<TValue>> {
53        let input = args_1!(inputs);
54        if input.datum_type() == self.to {
55            Ok(tvec!(input))
56        } else if input.datum_type() == TDim::datum_type() {
57            let mut tmp = Tensor::zero_dt(i64::datum_type(), input.shape())?;
58            let input_plain = input.try_as_plain()?;
59            let mut tmp_plain = tmp.try_as_plain_mut()?;
60            for (dim, i) in tract_itertools::izip!(
61                input_plain.as_slice::<TDim>()?,
62                tmp_plain.as_slice_mut::<i64>()?
63            ) {
64                *i = dim.eval(&state.resolved_symbols).to_i64()?
65            }
66            Ok(tvec!(tmp.cast_to_dt(self.to)?.into_owned().into_tvalue()))
67        } else {
68            Ok(tvec!(input.cast_to_dt(self.to)?.into_owned().into_tvalue()))
69        }
70    }
71}
72
73impl TypedOp for Cast {
74    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
75        let mut fact = self.to.fact(inputs[0].shape.clone());
76        fact.uniform_tdim = inputs[0].uniform_tdim.clone();
77        if let Some(u) = &inputs[0].uniform {
78            if let Ok(cast_u) = u.cast_to_dt(self.to) {
79                fact.uniform = Some(std::sync::Arc::new(cast_u.into_owned()));
80            }
81        }
82        Ok(tvec!(fact))
83    }
84
85    fn input_roi(
86        &self,
87        model: &TypedModel,
88        node: &TypedNode,
89    ) -> TractResult<Option<TVec<Option<TDim>>>> {
90        crate::optim::propagate_roi::bubble_roi(model, node)
91    }
92
93    fn declutter(
94        &self,
95        model: &TypedModel,
96        node: &TypedNode,
97    ) -> TractResult<Option<TypedModelPatch>> {
98        if model.outlet_fact(node.inputs[0])?.datum_type == self.to {
99            return TypedModelPatch::shunt_one_op(model, node);
100        }
101        // linear_prec (fan-in=1, fan-out=1) rather than single_prec: swapping
102        // through a fan-out predecessor clones it, and the clone breaks
103        // downstream pattern detectors (e.g. Square+Reduce<Sum>+Mul fusion into
104        // Reduce<MeanOfSquares>, which then feeds RmsNorm detection).
105        //
106        // AxisOp is intentionally NOT in the predicate: pulling Cast above an
107        // AxisOp (Reshape/Move/Add/Rm) prevents the CUDA conversion from
108        // fusing the post-AxisOp Cast into the downstream GEMM-class kernel,
109        // leaving ~64 standalone CudaCast ops on OpenELM-270M (TG128 -4%).
110        if let Some(prec) = model.linear_prec(node.id)?
111            && (prec.op_is::<IntoShape>() || prec.op_is::<MultiBroadcastTo>())
112        {
113            let mut patch = TypedModelPatch::default();
114            let mut wire = tvec!(patch.tap_model(model, prec.inputs[0])?);
115            wire = patch.wire_node(&node.name, &node.op, &wire)?;
116            wire = patch.wire_node(&prec.name, &prec.op, &wire)?;
117            patch.shunt_outside(model, node.id.into(), wire[0])?;
118            return Ok(Some(patch));
119        }
120        Ok(None)
121    }
122
123    fn axes_mapping(
124        &self,
125        inputs: &[&TypedFact],
126        outputs: &[&TypedFact],
127    ) -> TractResult<AxesMapping> {
128        AxesMapping::natural(inputs, outputs)
129    }
130
131    fn change_axes(
132        &self,
133        model: &TypedModel,
134        node: &TypedNode,
135        _io: InOut,
136        change: &AxisOp,
137    ) -> TractResult<Option<AxisChangeConsequence>> {
138        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
139    }
140
141    fn slice(
142        &self,
143        patch: &mut TypedModelPatch,
144        _model: &TypedModel,
145        node: &TypedNode,
146        _prefix: &str,
147        inputs: &[OutletId],
148        _output_axis: usize,
149        _start: &TDim,
150        _end: &TDim,
151    ) -> TractResult<Option<TVec<OutletId>>> {
152        patch.wire_node(&node.name, &node.op, inputs).map(Some)
153    }
154
155    as_op!();
156}