Skip to main content

tract_core/ops/nn/
reduce.rs

1use crate::internal::Axis;
2use crate::internal::*;
3use crate::ops::binary::TypedBinOp;
4use crate::ops::cast::cast;
5use crate::ops::change_axes::wire_with_rank_broadcast;
6use crate::ops::element_wise::ElementWiseOp;
7use crate::ops::math::{Mul, Square, div, square};
8use std::convert::TryFrom;
9use std::iter::Sum;
10use std::mem::transmute;
11use tract_data::internal::ClampCast;
12use tract_data::itertools::Itertools;
13use tract_ndarray::prelude::*;
14use tract_num_traits::{AsPrimitive, Bounded};
15
// Dispatches a generic call on the concrete Rust storage type of a numeric `DatumType`.
//
// `r!(path::func(dt)(args...))` expands to a `match dt` where each arm calls
// `path::func::<T, _, _, _>(args...)` with `T` the storage type for that datum type.
// `r!(path::func(dt)(args...); q_path::func(q_args...))` does the same, except that the
// quantized types `QI8` / `QU8` call `q_path::func` with `q_args` instead.
// Any other datum type makes the enclosing function `bail!` with "… is not a number".
macro_rules! r {
    // Single-path form: quantized types reuse the plain path and arguments.
    ($($path:ident)::* ($dt:expr) ($($args:expr),*)) => {
        r!($($path)::* ($dt) ($($args),*); $($path)::* ($($args),*))
    };
    // Two-path form: quantized types get their own path and arguments.
    ($($path:ident)::* ($dt:expr) ($($args:expr),*); $($q_path:ident)::* ($($q_args:expr),*)) => {
        match $dt {
            DatumType::U8   => $($path)::*::<u8,_,_,_>($($args),*),
            DatumType::I8   => $($path)::*::<i8,_,_,_>($($args),*),
            DatumType::U16  => $($path)::*::<u16,_,_,_>($($args),*),
            DatumType::I16  => $($path)::*::<i16,_,_,_>($($args),*),
            DatumType::I32  => $($path)::*::<i32,_,_,_>($($args),*),
            DatumType::I64  => $($path)::*::<i64,_,_,_>($($args),*),
            DatumType::F16  => $($path)::*::<f16,_,_,_>($($args),*),
            DatumType::F32  => $($path)::*::<f32,_,_,_>($($args),*),
            DatumType::F64  => $($path)::*::<f64,_,_,_>($($args),*),
            DatumType::QI8(_)  => $($q_path)::*::<i8,_,_,_>($($q_args),*),
            DatumType::QU8(_)  => $($q_path)::*::<u8,_,_,_>($($q_args),*),
            _ => bail!("{:?} is not a number", $dt)
        }
    };
}
50
/// The reduction a [`Reduce`] op applies along its reduced axes.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Reducer {
    /// Index of the largest element; when the flag is set, ties resolve to the last occurrence.
    ArgMax(bool),
    /// Index of the smallest element; when the flag is set, ties resolve to the last occurrence.
    ArgMin(bool),
    /// Largest element.
    Max,
    /// Smallest element.
    Min,
    /// Product of the elements.
    Prod,
    /// Sum of the elements.
    Sum,
    /// Mean of the squared elements.
    MeanOfSquares,
    /// Logical AND over boolean elements.
    All,
    /// Logical OR over boolean elements.
    Any,
}
63
64impl Reducer {
65    pub fn reduce(&self, axes: &[usize], input: &Tensor) -> TractResult<Tensor> {
66        use Reducer::*;
67        let dt = input.datum_type();
68        let output_shape: Vec<usize> = input
69            .shape()
70            .iter()
71            .enumerate()
72            .map(|(ax, &d)| if axes.contains(&ax) { 1 } else { d })
73            .collect();
74        let (zp, scale) = input.datum_type().zp_scale();
75        unsafe {
76            let mut t = match self {
77                ArgMax(last) => {
78                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, argmax_t, *last))
79                }
80                ArgMin(last) => {
81                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, argmin_t, *last))
82                }
83                Min => r!(Self::reduce_t(dt)(self, axes, &output_shape, input, min_t, ())),
84                Max => r!(Self::reduce_t(dt)(self, axes, &output_shape, input, max_t, ())),
85                Prod => {
86                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, prod_t, ()); Self::reduce_t(self, axes, &output_shape, input, q_prod_t, (zp, scale)))
87                }
88                Sum => {
89                    if dt.is_float() {
90                        dispatch_floatlike!(Self::sum(dt)(self, axes, input))
91                    } else {
92                        r!(Self::reduce_t(dt)(
93                            self,
94                            axes,
95                            &output_shape,
96                            input,
97                            q_sum_t,
98                            (zp, scale)
99                        ))
100                    }
101                }
102                MeanOfSquares => self.mean_of_squares(axes, input)?,
103                All => Self::reduce_t(self, axes, &output_shape, input, all_bool, ()),
104                Any => Self::reduce_t(self, axes, &output_shape, input, any_bool, ()),
105            };
106            if input.datum_type().is_quantized()
107                && input.datum_type().unquantized() == t.datum_type().unquantized()
108            {
109                t.set_datum_type(input.datum_type());
110            }
111            Ok(t)
112        }
113    }
114
115    unsafe fn reduce_t<T, TO, F, A>(
116        &self,
117        axes: &[usize],
118        output_shape: &[usize],
119        input_tensor: &Tensor,
120        f: F,
121        args: A,
122    ) -> Tensor
123    where
124        F: for<'a> Fn(ArrayViewD<'a, T>, A) -> TO,
125        T: Copy + Datum,
126        TO: Copy + Datum,
127        A: Copy,
128    {
129        use ndarray::*;
130        let input = unsafe { input_tensor.to_array_view_unchecked::<T>() };
131        let result = Array::from_shape_fn(output_shape, |coords| {
132            let slice_spec: Vec<SliceInfoElem> = coords
133                .slice()
134                .iter()
135                .enumerate()
136                .map(|(ax, &d)| if axes.contains(&ax) { (..).into() } else { d.into() })
137                .collect();
138            let slice_info = SliceInfo::<_, IxDyn, IxDyn>::try_from(slice_spec).unwrap();
139            let slice = input.slice(&slice_info);
140            f(slice, args)
141        });
142        result.into_tensor()
143    }
144
145    // sum is a special citizen: enough activity that it gets "special"
146    // treatment. we could use the same "algo" for min, max and prod, to the
147    // price of more code in the library. argmax and argmin are more
148    // tricky (not associative)
149    unsafe fn sum<T>(&self, axes: &[usize], input: &Tensor) -> Tensor
150    where
151        T: Copy + Datum + num_traits::Zero + Sum,
152        f16: AsPrimitive<T>,
153        f32: AsPrimitive<T>,
154    {
155        if axes.len() == 0 {
156            return input.to_owned();
157        }
158
159        // use tract-optimized path only when single reuction axis and is at end
160        if axes.len() > 1 || axes[0] != input.rank() - 1 {
161            let mut operative_axes = vec![];
162            let mut operative_shape: Vec<usize> = vec![];
163            for (ix, dim) in input.shape().iter().enumerate() {
164                // axis is reduced, but is not the first of a series of reduced axes
165                if ix > 0 && axes.contains(&ix) && axes.contains(&(ix - 1)) {
166                    *operative_shape.last_mut().unwrap() *= *dim;
167                } else if axes.contains(&ix) {
168                    operative_axes.push(operative_shape.len());
169                    operative_shape.push(*dim);
170                } else {
171                    operative_shape.push(*dim);
172                }
173            }
174            let mut output = unsafe {
175                input
176                    .to_array_view_unchecked::<T>()
177                    .into_shape_with_order(operative_shape)
178                    .unwrap()
179                    .sum_axis(Axis(*operative_axes.iter().max().unwrap()))
180            };
181
182            for axis in operative_axes.iter().rev().skip(1) {
183                output = output.sum_axis(Axis(*axis));
184            }
185
186            let mut output = output.into_tensor();
187
188            for &axis in axes {
189                output.insert_axis(axis).unwrap();
190            }
191
192            output
193        } else {
194            let mut output: Option<ArrayD<T>> = None;
195            for axis in axes.iter().copied() {
196                let input_view = output
197                    .as_ref()
198                    .map(|o| o.view())
199                    .unwrap_or_else(|| unsafe { input.to_array_view_unchecked::<T>() });
200
201                // Create array that will contain intermidiate result
202                let reduced_dim = input_view.shape()[axis];
203                let input_stride = input_view.strides()[axis] as usize;
204                let output_shape = input_view
205                    .shape()
206                    .iter()
207                    .enumerate()
208                    .map(|(idx, dim)| if idx != axis { *dim } else { 1 })
209                    .collect_vec();
210
211                output = Some(ArrayD::from_shape_fn(output_shape.clone(), |coords| {
212                    let mut view = input_view.view();
213                    for ix in 0..output_shape.len() {
214                        if ix != axis {
215                            view.collapse_axis(Axis(ix), coords[ix]);
216                        }
217                    }
218
219                    if let Some(slice) = view.as_slice() {
220                        if T::datum_type() == f16::datum_type() {
221                            let slice: &[f16] = unsafe { std::mem::transmute(slice) };
222                            (tract_linalg::ops().sum_f16)()
223                                .run_with_params(slice, ())
224                                .unwrap()
225                                .as_()
226                        } else if T::datum_type() == f32::datum_type() {
227                            let slice: &[f32] = unsafe { std::mem::transmute(slice) };
228                            (tract_linalg::ops().sum_f32)()
229                                .run_with_params(slice, ())
230                                .unwrap()
231                                .as_()
232                        } else {
233                            slice.iter().cloned().sum::<T>()
234                        }
235                    } else {
236                        let first: *const T = &input_view[coords];
237                        let mut sum = T::zero();
238                        for i in 0..reduced_dim {
239                            sum = sum + unsafe { *(first.add(i * input_stride)) };
240                        }
241                        sum
242                    }
243                }));
244            }
245            output.unwrap().into_tensor()
246        }
247    }
248
249    fn mean_of_squares(&self, axis: &[usize], input: &Tensor) -> TractResult<Tensor> {
250        let dt = input.datum_type();
251        let mut input = input.cast_to::<f32>()?.into_owned();
252        input.try_as_plain_mut()?.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x = *x * *x);
253        let mut output = unsafe { self.sum::<f32>(axis, &input) };
254        let norm = output.len() as f32 / input.len() as f32;
255        output.try_as_plain_mut()?.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x *= norm);
256        Ok(output.cast_to_dt(dt)?.into_owned())
257    }
258}
259
260fn argmax_t<T>(v: ArrayViewD<T>, last: bool) -> i64
261where
262    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
263{
264    v.iter()
265        .copied()
266        .enumerate()
267        .fold(
268            (0usize, T::min_value()),
269            |acc, v| {
270                if v.1 > acc.1 || (last && acc.1 == v.1) { v } else { acc }
271            },
272        )
273        .0 as i64
274}
275
276fn argmin_t<T>(v: ArrayViewD<T>, last: bool) -> i64
277where
278    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
279{
280    v.iter()
281        .copied()
282        .enumerate()
283        .fold(
284            (0usize, T::max_value()),
285            |acc, v| {
286                if v.1 < acc.1 || (last && acc.1 == v.1) { v } else { acc }
287            },
288        )
289        .0 as i64
290}
291
292fn max_t<T>(v: ArrayViewD<T>, _: ()) -> T
293where
294    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
295{
296    if T::datum_type() == f32::datum_type()
297        && let Some(slice) = v.as_slice()
298    {
299        let slice = unsafe { transmute::<&[T], &[f32]>(slice) };
300        (tract_linalg::ops().max_f32)().run(slice).unwrap();
301    }
302    v.fold(T::min_value(), |acc, &v| if acc > v { acc } else { v })
303}
304
305fn min_t<T>(v: ArrayViewD<T>, _: ()) -> T
306where
307    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
308{
309    v.fold(T::max_value(), |acc, &v| if acc < v { acc } else { v })
310}
311
312fn prod_t<T>(v: ArrayViewD<T>, _: ()) -> T
313where
314    T: Copy + Datum + num_traits::One,
315{
316    v.fold(T::one(), |acc, &v| acc * v)
317}
318
319fn q_prod_t<T>(v: ArrayViewD<T>, zp_scale: (i32, f32)) -> T
320where
321    T: Copy + num_traits::AsPrimitive<f32> + Bounded + Datum,
322    f32: num_traits::AsPrimitive<T>,
323{
324    let (zp, scale) = zp_scale;
325    (v.fold(1f32, |acc, &v| acc * (v.as_() - zp as f32)) * scale.powi(v.len() as i32 - 1)
326        + zp as f32)
327        .clamp_cast()
328}
329
330fn q_sum_t<T>(v: ArrayViewD<T>, zp_scale: (i32, f32)) -> T
331where
332    T: Copy + Bounded + num_traits::AsPrimitive<i32> + Datum,
333    i32: num_traits::AsPrimitive<T>,
334{
335    let (zp, _) = zp_scale;
336    (v.fold(0i32, |acc, &v| acc + v.as_()) - zp * (v.len() as i32 - 1)).clamp_cast()
337}
338
339fn all_bool(v: ArrayViewD<bool>, _: ()) -> bool {
340    v.iter().all(|v| *v)
341}
342
343fn any_bool(v: ArrayViewD<bool>, _: ()) -> bool {
344    v.iter().any(|v| *v)
345}
346
347#[derive(Clone, Debug, new, Hash, PartialEq, Eq)]
348pub struct Reduce {
349    pub axes: TVec<usize>,
350    pub reducer: Reducer,
351}
352
353impl Op for Reduce {
354    fn name(&self) -> StaticName {
355        format!("Reduce<{:?}>", self.reducer).into()
356    }
357    fn info(&self) -> TractResult<Vec<String>> {
358        Ok(vec![format!("axes: {:?}", self.axes)])
359    }
360    op_as_typed_op!();
361}
362
363impl EvalOp for Reduce {
364    fn is_stateless(&self) -> bool {
365        true
366    }
367
368    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
369        Ok(tvec!(self.reducer.reduce(&self.axes, &inputs[0])?.into()))
370    }
371}
372
373impl TypedOp for Reduce {
374    fn input_roi(
375        &self,
376        model: &TypedModel,
377        node: &TypedNode,
378    ) -> TractResult<Option<TVec<Option<TDim>>>> {
379        crate::optim::propagate_roi::bubble_roi(model, node)
380    }
381
382    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
383        ensure!(self.axes.iter().tuple_windows().all(|(a, b)| a < b));
384        if inputs[0].datum_type == TDim::datum_type() {
385            bail!("Reduce input must be cast from TDim to i64 beforehand")
386        }
387        let mut shape: TVec<_> = inputs[0].shape.to_tvec();
388        for &ax in &self.axes {
389            shape[ax] = 1.to_dim();
390        }
391        let dt = if let Reducer::ArgMax(_) | Reducer::ArgMin(_) = self.reducer {
392            DatumType::I64
393        } else {
394            inputs[0].datum_type
395        };
396        Ok(tvec!(dt.fact(shape)))
397    }
398
399    fn declutter(
400        &self,
401        model: &TypedModel,
402        node: &TypedNode,
403    ) -> TractResult<Option<TypedModelPatch>> {
404        if let Some(patch) = self.declutter_mean_of_square(model, node)? {
405            return Ok(Some(patch));
406        }
407        if let Some(patch) = self.declutter_scalar_mul_then_sum(model, node)? {
408            return Ok(Some(patch));
409        }
410        if let Some(patch) = self.declutter_reduce_reduce(model, node)? {
411            return Ok(Some(patch));
412        }
413        if let Some(patch) = super::rms_norm::detect_rms_norm(self, model, node)? {
414            return Ok(Some(patch));
415        }
416        Ok(None)
417    }
418
419    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
420        let dt = inputs[0].datum_type;
421        let count: TDim = inputs[0].shape.iter().product();
422        match self.reducer {
423            Reducer::Sum
424            | Reducer::Prod
425            | Reducer::Min
426            | Reducer::Max
427            | Reducer::All
428            | Reducer::Any => Ok(tvec!((Cost::FMA(dt), count))),
429            Reducer::MeanOfSquares => Ok(tvec!((Cost::FMA(dt), count * 2))),
430            Reducer::ArgMax(_) | Reducer::ArgMin(_) => Ok(tvec!((Cost::FMA(dt), count))),
431        }
432    }
433
434    fn axes_mapping(
435        &self,
436        inputs: &[&TypedFact],
437        outputs: &[&TypedFact],
438    ) -> TractResult<AxesMapping> {
439        let mut letters = 'a'..;
440        let axes = (0..inputs[0].rank())
441            .flat_map(|ix| {
442                if self.axes.contains(&ix) {
443                    tvec!(
444                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
445                            .input(0, ix),
446                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
447                            .output(0, ix),
448                    )
449                } else {
450                    tvec!(
451                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
452                            .input(0, ix)
453                            .output(0, ix)
454                    )
455                }
456                .into_iter()
457            })
458            .collect_vec();
459        AxesMapping::new(1, 1, axes)
460    }
461
462    fn change_axes(
463        &self,
464        model: &TypedModel,
465        node: &TypedNode,
466        _io: InOut,
467        change: &AxisOp,
468    ) -> TractResult<Option<AxisChangeConsequence>> {
469        let mut axes = tvec!();
470        for reduced in &self.axes {
471            rule_if_some!(axis = change.transform_axis(*reduced));
472            axes.push(axis);
473        }
474        axes.sort();
475        let op = Some(Box::new(Self { axes, ..self.clone() }) as _);
476        Ok(Some(AxisChangeConsequence::new(model, node, op, change)))
477    }
478
479    fn slice(
480        &self,
481        patch: &mut TypedModelPatch,
482        _model: &TypedModel,
483        node: &TypedNode,
484        _prefix: &str,
485        inputs: &[OutletId],
486        output_axis: usize,
487        _start: &TDim,
488        _end: &TDim,
489    ) -> TractResult<Option<TVec<OutletId>>> {
490        rule_if!(!self.axes.contains(&output_axis));
491        patch.wire_node(&node.name, &node.op, inputs).map(Some)
492    }
493
494    as_op!();
495}
496
497impl Reduce {
498    fn declutter_reduce_reduce(
499        &self,
500        model: &TypedModel,
501        node: &TypedNode,
502    ) -> TractResult<Option<TypedModelPatch>> {
503        use Reducer::*;
504        rule_if_some!(prec = model.linear_prec(node.id)?);
505        rule_if_some!(prec_reduce = prec.op_as::<Self>());
506        rule_if!(prec_reduce.reducer == self.reducer);
507        rule_if!([Sum, Prod, Min, Max].contains(&self.reducer));
508        let mut patch = TypedModelPatch::default();
509        let wire = patch.tap_model(model, prec.inputs[0])?;
510        let wire = patch.wire_node(
511            &node.name,
512            Self {
513                reducer: self.reducer,
514                axes: prec_reduce
515                    .axes
516                    .iter()
517                    .chain(self.axes.iter())
518                    .copied()
519                    .sorted()
520                    .dedup()
521                    .collect(),
522            },
523            &[wire],
524        )?;
525        patch.shunt_outside(model, node.id.into(), wire[0])?;
526        Ok(Some(patch))
527    }
528
529    fn declutter_scalar_mul_then_sum(
530        &self,
531        model: &TypedModel,
532        node: &TypedNode,
533    ) -> TractResult<Option<TypedModelPatch>> {
534        if self.reducer == Reducer::Sum {
535            rule_if_some!(prec = model.linear_prec(node.id)?);
536            rule_if_some!(prec_bin = prec.op_as::<TypedBinOp>());
537            rule_if!(prec_bin.0.is::<Mul>());
538            let mul_input_fact = model.node_input_facts(prec.id)?;
539            rule_if_some!(
540                scalar_slot = mul_input_fact
541                    .iter()
542                    .position(|f| f.konst.as_ref().is_some_and(|k| k.volume() == 1))
543            );
544            let mut patch = TypedModelPatch::default();
545            let scalar = patch.tap_model(model, prec.inputs[scalar_slot])?;
546            let wire = patch.tap_model(model, prec.inputs[1 - scalar_slot])?;
547            let wire = patch.wire_node(&node.name, self.clone(), &[wire])?[0];
548            let wire = patch.wire_node(&prec.name, prec_bin.clone(), &[wire, scalar])?[0];
549            patch.shunt_outside(model, node.id.into(), wire)?;
550            return Ok(Some(patch));
551        }
552        Ok(None)
553    }
554
555    fn declutter_mean_of_square(
556        &self,
557        model: &TypedModel,
558        node: &TypedNode,
559    ) -> TractResult<Option<TypedModelPatch>> {
560        if self.reducer == Reducer::Sum {
561            rule_if_some!(prec = model.linear_prec(node.id)?);
562            rule_if_some!(prec_ew = prec.op_as::<ElementWiseOp>());
563            rule_if!(prec_ew.0.is::<Square>());
564            rule_if!(node.outputs.len() == 1);
565            rule_if!(node.outputs[0].successors.len() == 1);
566            let our_inlet = node.outputs[0].successors[0];
567            let succ = model.node(our_inlet.node);
568            rule_if_some!(succ_bin = succ.op_as::<TypedBinOp>());
569            rule_if!(succ_bin.0.is::<Mul>());
570            let other = succ.inputs[1 - our_inlet.slot];
571            rule_if_some!(other_konst = model.outlet_fact(other)?.uniform.as_ref());
572            let norm: TDim = self.axes.iter().map(|&ax| &prec.outputs[0].fact.shape[ax]).product();
573            rule_if_some!(norm = norm.as_i64());
574            rule_if!(norm > 0);
575            let norm = tensor0((norm as f32).recip());
576            if other_konst.close_enough(&norm, Approximation::Close).is_ok() {
577                let mut patch = TypedModelPatch::default();
578                let wire = patch.tap_model(model, prec.inputs[0])?;
579                let wire = patch.wire_node(
580                    &node.name,
581                    Reduce::new(self.axes.clone(), Reducer::MeanOfSquares),
582                    &[wire],
583                )?[0];
584                patch.shunt_outside(model, succ.id.into(), wire)?;
585                return Ok(Some(patch));
586            }
587        }
588        Ok(None)
589    }
590}
591
592pub fn expand_mean_of_squares(
593    _ctx: &(),
594    model: &TypedModel,
595    node: &TypedNode,
596    name: &str,
597    op: &Reduce,
598) -> TractResult<Option<TypedModelPatch>> {
599    rule_if!(op.reducer == Reducer::MeanOfSquares);
600    let mut patch = TypedModelPatch::default();
601    let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
602    let input_fact = model.outlet_fact(node.inputs[0])?;
603    let dt = input_fact.datum_type;
604    if dt != f32::datum_type() {
605        wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
606    }
607    wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
608    wire = patch.wire_node(
609        format!("{name}.sum"),
610        Reduce::new(op.axes.clone(), Reducer::Sum),
611        &wire,
612    )?;
613    let card = input_fact
614        .shape
615        .iter()
616        .enumerate()
617        .filter(|(ix, _dim)| op.axes.contains(ix))
618        .map(|(_ix, dim)| dim)
619        .product::<TDim>();
620    let card = patch.add_const(format!("{name}.card"), tensor0(card))?;
621    let card = patch.wire_node(format!("{name}.card_to_f32"), cast(f32::datum_type()), &[card])?;
622
623    wire =
624        wire_with_rank_broadcast(format!("{name}.norm"), &mut patch, div(), &[wire[0], card[0]])?;
625    if dt != f32::datum_type() {
626        wire = patch.wire_node(format!("{name}.from_f32"), cast(dt), &wire)?;
627    }
628    patch.shunt_outside(model, node.id.into(), wire[0])?;
629    Ok(Some(patch))
630}